You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
95 lines
3.6 KiB
95 lines
3.6 KiB
# Copyright 2017 Paul Balanca. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TF Extended: additional tensor operations.
"""
|
import tensorflow as tf |
|
|
|
from tensorflow.contrib.framework.python.ops import variables as contrib_variables |
|
from tensorflow.contrib.metrics.python.ops import set_ops |
|
from tensorflow.python.framework import dtypes |
|
from tensorflow.python.framework import ops |
|
from tensorflow.python.framework import sparse_tensor |
|
from tensorflow.python.ops import array_ops |
|
from tensorflow.python.ops import check_ops |
|
from tensorflow.python.ops import control_flow_ops |
|
from tensorflow.python.ops import math_ops |
|
from tensorflow.python.ops import nn |
|
from tensorflow.python.ops import state_ops |
|
from tensorflow.python.ops import variable_scope |
|
from tensorflow.python.ops import variables |
|
|
|
|
|
def get_shape(x, rank=None):
    """Return the dimensions of a Tensor as a list of ints or scalar tensors.

    Args:
      x: N-d Tensor;
      rank: Rank of the Tensor. If None, it is taken from the length of the
        static shape list of `x`.
    Returns:
      A list `[d1, d2, ..., dN]` corresponding to the dimensions of the
      input tensor. Dimensions that are statically known are python integers,
      otherwise they are integer scalar tensors.
    """
    static_shape = x.get_shape()
    # Fast path: everything is known statically, no graph op is needed.
    if static_shape.is_fully_defined():
        return static_shape.as_list()

    if rank is None:
        static_dims = static_shape.as_list()
        rank = len(static_dims)
    else:
        # Check the static shape against the rank given by the caller.
        static_dims = static_shape.with_rank(rank).as_list()

    # One scalar tensor per dimension, used wherever the static value is None.
    dynamic_dims = tf.unstack(tf.shape(x), rank)
    return [s if s is not None else d
            for s, d in zip(static_dims, dynamic_dims)]
|
|
|
|
|
def pad_axis(x, offset, size, axis=0, name=None):
    """Pad a tensor on an axis, with a given offset and output size.

    The tensor is padded with zeros (i.e. CONSTANT mode): `offset` entries
    are added before the data on `axis`, and enough entries are added after
    it to reach `size`. No trailing padding is added if `offset` plus the
    existing dimension already reaches `size`.

    Args:
      x: Tensor to pad;
      offset: Offset to add on the dimension chosen;
      size: Final size of the dimension;
      axis: Python integer, dimension to pad. Negative values count from
        the last dimension;
      name: Optional name of the enclosing name scope.
    Return:
      Padded tensor. When `offset`, `size` and the dimension on `axis` are
      all python integers, the dimension on `axis` is `size`, or
      `offset` + the input dimension if that is greater. Otherwise the
      result is reshaped so that the dimension on `axis` is `size`; the
      caller must then ensure `offset` + the input dimension does not
      exceed `size`.
    """
    with tf.name_scope(name, 'pad_axis'):
        shape = get_shape(x)
        rank = len(shape)
        # Normalise a negative axis so the padding vectors have `rank` items.
        if axis < 0:
            axis += rank
        dim = shape[axis]

        if all(isinstance(v, int) for v in (offset, size, dim)):
            # Everything is static: compute the exact output size, so an
            # input that is already larger than `size` is left un-truncated.
            pad_after = max(size - offset - dim, 0)
            out_size = offset + dim + pad_after
        else:
            # Dynamic case: keep `size` as the declared output dimension so
            # the static shape stays defined when `size` is a python integer.
            pad_after = tf.maximum(size - offset - dim, 0)
            out_size = size

        # Padding description: one (before, after) pair per dimension.
        pad_before = tf.stack([0]*axis + [offset] + [0]*(rank-axis-1))
        pad_behind = tf.stack([0]*axis + [pad_after] + [0]*(rank-axis-1))
        paddings = tf.stack([pad_before, pad_behind], axis=1)
        x = tf.pad(x, paddings, mode='CONSTANT')

        # Reshape, to get fully defined shape if possible.
        shape[axis] = out_size
        x = tf.reshape(x, tf.stack(shape))
        return x
|
|
|
|
|
# def select_at_index(idx, val, t):
#     """Return a tensor.
#     """
#     idx = tf.expand_dims(tf.expand_dims(idx, 0), 0)
#     val = tf.expand_dims(val, 0)
#     t = t + tf.scatter_nd(idx, val, tf.shape(t))
#     return t
|
|
|