# Copyright 2016 Paul Balanca. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Diverse TensorFlow utils, for training, evaluation and so on!
"""
import os
from pprint import pprint

import tensorflow as tf
from tensorflow.contrib.slim.python.slim.data import parallel_reader

slim = tf.contrib.slim


# =========================================================================== #
# General tools.
# =========================================================================== #
def reshape_list(l, shape=None):
    """Reshape a list of (lists): 1D to 2D, or the other way around.

    Args:
      l: List, or list of lists.
      shape: 1D or 2D shape.
    Returns:
      Reshaped list.
    """
    r = []
    if shape is None:
        # Flatten everything.
        for a in l:
            if isinstance(a, (list, tuple)):
                r = r + list(a)
            else:
                r.append(a)
    else:
        # Reshape to list of list.
        i = 0
        for s in shape:
            if s == 1:
                r.append(l[i])
            else:
                r.append(l[i:i+s])
            i += s
    return r

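# A minimal usage sketch for `reshape_list` (illustrative values, not from
# the original file). Flatten with `shape=None`, then restore the nesting by
# passing the per-group sizes:
#
#   flat = reshape_list([[1, 2], 3, [4, 5]])        # -> [1, 2, 3, 4, 5]
#   nested = reshape_list(flat, shape=[2, 1, 2])    # -> [[1, 2], 3, [4, 5]]
#
# Note that a shape entry of 1 yields a bare element rather than a singleton
# list, which is convenient when packing/unpacking lists of tensors.
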
# =========================================================================== #
# Training utils.
# =========================================================================== #
def print_configuration(flags, ssd_params, data_sources, save_dir=None):
    """Print the training configuration, and save it to a text file if
    `save_dir` is provided. `ssd_params` is expected to be a namedtuple
    (its `_asdict()` method is used below).
    """
    def print_config(stream=None):
        print('\n# =========================================================================== #', file=stream)
        print('# Training | Evaluation flags:', file=stream)
        print('# =========================================================================== #', file=stream)
        pprint(flags, stream=stream)

        print('\n# =========================================================================== #', file=stream)
        print('# SSD net parameters:', file=stream)
        print('# =========================================================================== #', file=stream)
        pprint(dict(ssd_params._asdict()), stream=stream)

        print('\n# =========================================================================== #', file=stream)
        print('# Training | Evaluation dataset files:', file=stream)
        print('# =========================================================================== #', file=stream)
        data_files = parallel_reader.get_data_files(data_sources)
        pprint(sorted(data_files), stream=stream)
        print('', file=stream)

    print_config(None)
    # Save to a text file as well.
    if save_dir is not None:
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)
        path = os.path.join(save_dir, 'training_config.txt')
        with open(path, "w") as out:
            print_config(out)

def configure_learning_rate(flags, num_samples_per_epoch, global_step):
    """Configures the learning rate.

    Args:
      flags: Training flags; provide the batch size, base learning rate and
        decay parameters used below.
      num_samples_per_epoch: The number of samples in each epoch of training.
      global_step: The global_step tensor.
    Returns:
      A `Tensor` representing the learning rate.
    Raises:
      ValueError: if `flags.learning_rate_decay_type` is not recognized.
    """
    decay_steps = int(num_samples_per_epoch / flags.batch_size *
                      flags.num_epochs_per_decay)

    if flags.learning_rate_decay_type == 'exponential':
        return tf.train.exponential_decay(flags.learning_rate,
                                          global_step,
                                          decay_steps,
                                          flags.learning_rate_decay_factor,
                                          staircase=True,
                                          name='exponential_decay_learning_rate')
    elif flags.learning_rate_decay_type == 'fixed':
        return tf.constant(flags.learning_rate, name='fixed_learning_rate')
    elif flags.learning_rate_decay_type == 'polynomial':
        return tf.train.polynomial_decay(flags.learning_rate,
                                         global_step,
                                         decay_steps,
                                         flags.end_learning_rate,
                                         power=1.0,
                                         cycle=False,
                                         name='polynomial_decay_learning_rate')
    else:
        raise ValueError('learning_rate_decay_type [%s] was not recognized'
                         % flags.learning_rate_decay_type)

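# A minimal usage sketch (hedged: `Flags` and its values below are
# hypothetical stand-ins for the real command-line FLAGS object, which only
# needs the attributes read above):
#
#   from collections import namedtuple
#   Flags = namedtuple('Flags', ['learning_rate', 'batch_size',
#                                'num_epochs_per_decay',
#                                'learning_rate_decay_type',
#                                'learning_rate_decay_factor',
#                                'end_learning_rate'])
#   flags = Flags(learning_rate=0.01, batch_size=32, num_epochs_per_decay=2.0,
#                 learning_rate_decay_type='exponential',
#                 learning_rate_decay_factor=0.94, end_learning_rate=0.0001)
#   global_step = tf.train.create_global_step()
#   lr = configure_learning_rate(flags, num_samples_per_epoch=50000,
#                                global_step=global_step)
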
def configure_optimizer(flags, learning_rate):
    """Configures the optimizer used for training.

    Args:
      flags: Training flags; `flags.optimizer` selects the optimizer and the
        remaining fields provide its hyper-parameters.
      learning_rate: A scalar or `Tensor` learning rate.
    Returns:
      An instance of an optimizer.
    Raises:
      ValueError: if `flags.optimizer` is not recognized.
    """
    if flags.optimizer == 'adadelta':
        optimizer = tf.train.AdadeltaOptimizer(
            learning_rate,
            rho=flags.adadelta_rho,
            epsilon=flags.opt_epsilon)
    elif flags.optimizer == 'adagrad':
        optimizer = tf.train.AdagradOptimizer(
            learning_rate,
            initial_accumulator_value=flags.adagrad_initial_accumulator_value)
    elif flags.optimizer == 'adam':
        optimizer = tf.train.AdamOptimizer(
            learning_rate,
            beta1=flags.adam_beta1,
            beta2=flags.adam_beta2,
            epsilon=flags.opt_epsilon)
    elif flags.optimizer == 'ftrl':
        optimizer = tf.train.FtrlOptimizer(
            learning_rate,
            learning_rate_power=flags.ftrl_learning_rate_power,
            initial_accumulator_value=flags.ftrl_initial_accumulator_value,
            l1_regularization_strength=flags.ftrl_l1,
            l2_regularization_strength=flags.ftrl_l2)
    elif flags.optimizer == 'momentum':
        optimizer = tf.train.MomentumOptimizer(
            learning_rate,
            momentum=flags.momentum,
            name='Momentum')
    elif flags.optimizer == 'rmsprop':
        optimizer = tf.train.RMSPropOptimizer(
            learning_rate,
            decay=flags.rmsprop_decay,
            momentum=flags.rmsprop_momentum,
            epsilon=flags.opt_epsilon)
    elif flags.optimizer == 'sgd':
        optimizer = tf.train.GradientDescentOptimizer(learning_rate)
    else:
        raise ValueError('Optimizer [%s] was not recognized' % flags.optimizer)
    return optimizer

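# Hedged usage sketch chaining the two helpers above into a training op.
# `loss` is a hypothetical scalar loss tensor, and `flags` is the same
# stand-in as in the previous sketch, extended with `optimizer='rmsprop'`,
# `rmsprop_decay`, `rmsprop_momentum` and `opt_epsilon` fields:
#
#   lr = configure_learning_rate(flags, num_samples_per_epoch, global_step)
#   optimizer = configure_optimizer(flags, lr)
#   train_op = optimizer.minimize(loss, global_step=global_step)
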
def add_variables_summaries(learning_rate):
    """Build a histogram summary for every model variable, plus a scalar
    summary for the learning rate. Returns the list of summary ops.
    """
    summaries = []
    for variable in slim.get_model_variables():
        summaries.append(tf.summary.histogram(variable.op.name, variable))
    summaries.append(tf.summary.scalar('training/Learning Rate', learning_rate))
    return summaries

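# Hedged usage sketch: the returned list can be merged into a single op for
# a summary writer (plain tf.summary API, not specific to this file):
#
#   summaries = add_variables_summaries(learning_rate)
#   summary_op = tf.summary.merge(summaries)
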
def update_model_scope(var, ckpt_scope, new_scope):
    """Rename a variable from the model scope back to the checkpoint scope,
    so it can be matched against checkpoint entries.
    """
    return var.op.name.replace(new_scope, ckpt_scope)

def get_init_fn(flags):
    """Returns a function run by the chief worker to warm-start the training.
    Note that the init_fn is only run when initializing the model during the
    very first global step.

    Returns:
      An init function run by the supervisor.
    """
    if flags.checkpoint_path is None:
        return None
    # Warn the user if a checkpoint exists in the train_dir. Then ignore.
    if tf.train.latest_checkpoint(flags.train_dir):
        tf.logging.info(
            'Ignoring --checkpoint_path because a checkpoint already exists in %s'
            % flags.train_dir)
        return None

    exclusions = []
    if flags.checkpoint_exclude_scopes:
        exclusions = [scope.strip()
                      for scope in flags.checkpoint_exclude_scopes.split(',')]

    # TODO(sguada) variables.filter_variables()
    variables_to_restore = []
    for var in slim.get_model_variables():
        excluded = False
        for exclusion in exclusions:
            if var.op.name.startswith(exclusion):
                excluded = True
                break
        if not excluded:
            variables_to_restore.append(var)
    # Change model scope if necessary.
    if flags.checkpoint_model_scope is not None:
        variables_to_restore = \
            {var.op.name.replace(flags.model_name,
                                 flags.checkpoint_model_scope): var
             for var in variables_to_restore}

    if tf.gfile.IsDirectory(flags.checkpoint_path):
        checkpoint_path = tf.train.latest_checkpoint(flags.checkpoint_path)
    else:
        checkpoint_path = flags.checkpoint_path
    tf.logging.info('Fine-tuning from %s. Ignoring missing vars: %s'
                    % (checkpoint_path, flags.ignore_missing_vars))

    return slim.assign_from_checkpoint_fn(
        checkpoint_path,
        variables_to_restore,
        ignore_missing_vars=flags.ignore_missing_vars)

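# Hedged usage sketch: with TF-Slim's training loop, the returned init_fn
# warm-starts the model from the checkpoint (slim.learning.train accepts an
# `init_fn` keyword argument):
#
#   init_fn = get_init_fn(flags)
#   slim.learning.train(train_op, logdir=flags.train_dir, init_fn=init_fn)
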
def get_variables_to_train(flags):
    """Returns a list of variables to train.

    Returns:
      A list of variables to be trained by the optimizer.
    """
    if flags.trainable_scopes is None:
        return tf.trainable_variables()
    else:
        scopes = [scope.strip() for scope in flags.trainable_scopes.split(',')]

    variables_to_train = []
    for scope in scopes:
        variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope)
        variables_to_train.extend(variables)
    return variables_to_train

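# Hedged usage sketch: restricting gradient computation to the selected
# variables via slim's create_train_op (which takes a `variables_to_train`
# keyword); `total_loss` and `optimizer` are assumed built elsewhere:
#
#   variables_to_train = get_variables_to_train(flags)
#   train_op = slim.learning.create_train_op(
#       total_loss, optimizer, variables_to_train=variables_to_train)
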
# =========================================================================== #
# Evaluation utils.
# =========================================================================== #