You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 

95 lines
3.6 KiB

# Copyright 2017 Paul Balanca. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TF Extended: additional tensors operations.
"""
import tensorflow as tf
from tensorflow.contrib.framework.python.ops import variables as contrib_variables
from tensorflow.contrib.metrics.python.ops import set_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
def get_shape(x, rank=None):
    """Return the dimensions of a Tensor as a list of ints or scalar tensors.

    Args:
      x: N-d Tensor;
      rank: Rank of the Tensor. If None, it is taken from the static shape.
    Returns:
      A list `[d1, d2, ..., dN]` with one entry per dimension of the input
      tensor. Statically known dimensions are python integers; the others
      are the matching integer scalar tensors unstacked from `tf.shape(x)`.
    """
    # Fast path: everything is known statically, no graph ops are needed.
    if x.get_shape().is_fully_defined():
        return x.get_shape().as_list()

    if rank is None:
        known_dims = x.get_shape().as_list()
        rank = len(known_dims)
    else:
        # `with_rank` checks the static shape against the requested rank.
        known_dims = x.get_shape().with_rank(rank).as_list()

    runtime_dims = tf.unstack(tf.shape(x), rank)
    # Prefer the static value, fall back on the run-time scalar.
    return [runtime if known is None else known
            for known, runtime in zip(known_dims, runtime_dims)]
def pad_axis(x, offset, size, axis=0, name=None):
    """Pad a tensor along one axis, given a leading offset and a target size.

    The tensor is padded with zeros (i.e. CONSTANT mode): `offset` zeros are
    inserted before the existing values on `axis`, and as many zeros as
    needed to reach `size` are appended after them. If `size` is smaller
    than the existing size + `offset`, nothing is appended and the output
    dimension is the latter.

    Args:
      x: Tensor to pad;
      offset: Offset to add on the dimension chosen;
      size: Final size of the dimension;
      axis: Axis to pad. Negative values count from the last axis;
      name: Optional scope name, passed to `tf.name_scope` with 'pad_axis'
        as the default.
    Return:
      Padded tensor whose dimension on `axis` is `size`, or greater if
      the input tensor was larger.
    """
    with tf.name_scope(name, 'pad_axis'):
        shape = get_shape(x)
        rank = len(shape)
        # Normalise a negative axis: the list arithmetic below must produce
        # exactly `rank` entries, which it only does for 0 <= axis < rank.
        if axis < 0:
            axis += rank
        # Padding description.
        new_size = tf.maximum(size-offset-shape[axis], 0)
        pad1 = tf.stack([0]*axis + [offset] + [0]*(rank-axis-1))
        pad2 = tf.stack([0]*axis + [new_size] + [0]*(rank-axis-1))
        paddings = tf.stack([pad1, pad2], axis=1)
        x = tf.pad(x, paddings, mode='CONSTANT')
        # Reshape, to get fully defined shape if possible.
        # TODO: fix with tf.slice
        if all(isinstance(v, int) for v in (offset, size, shape[axis])):
            # Everything is known statically: the padded extent is the
            # larger of `size` and `offset` + input extent, so use that
            # rather than forcing `size` onto a larger tensor.
            shape[axis] = max(size, offset + shape[axis])
        else:
            # At least one quantity is a tensor, so the final extent cannot
            # be computed here. Keep `size`, which preserves a static
            # dimension when `size` is a python int.
            # NOTE(review): the reshape presumably still fails at run time
            # if the padded extent exceeds `size` -- confirm against callers.
            shape[axis] = size
        x = tf.reshape(x, tf.stack(shape))
        return x
# def select_at_index(idx, val, t):
# """Return a tensor.
# """
# idx = tf.expand_dims(tf.expand_dims(idx, 0), 0)
# val = tf.expand_dims(val, 0)
# t = t + tf.scatter_nd(idx, val, tf.shape(t))
# return t