# Copyright 2016 Paul Balanca. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Diverse TensorFlow utils, for training, evaluation and so on!
"""
import os
from pprint import pprint

import tensorflow as tf
from tensorflow.contrib.slim.python.slim.data import parallel_reader

slim = tf.contrib.slim


# =========================================================================== #
# General tools.
# =========================================================================== #
def reshape_list(l, shape=None):
    """Reshape a list of (lists): 1D to 2D or the other way around.

    Args:
      l: List or List of list.
      shape: 1D or 2D shape.
    Returns:
      Reshaped list.
    """
    r = []
    if shape is None:
        # Flatten everything.
        for a in l:
            if isinstance(a, (list, tuple)):
                r = r + list(a)
            else:
                r.append(a)
    else:
        # Reshape to list of list.
        i = 0
        for s in shape:
            if s == 1:
                r.append(l[i])
            else:
                r.append(l[i:i+s])
            i += s
    return r
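# Example (illustrative): reshape_list is handy for flattening nested network
# outputs before batching and restoring the nesting afterwards.
#   reshape_list([1, (2, 3), 4])                 -> [1, 2, 3, 4]
#   reshape_list([1, 2, 3, 4], shape=[1, 2, 1])  -> [1, [2, 3], 4]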


# =========================================================================== #
# Training utils.
# =========================================================================== #
def print_configuration(flags, ssd_params, data_sources, save_dir=None):
    """Print the training configuration.
    """
    def print_config(stream=None):
        print('\n# =========================================================================== #', file=stream)
        print('# Training | Evaluation flags:', file=stream)
        print('# =========================================================================== #', file=stream)
        pprint(flags, stream=stream)

        print('\n# =========================================================================== #', file=stream)
        print('# SSD net parameters:', file=stream)
        print('# =========================================================================== #', file=stream)
        pprint(dict(ssd_params._asdict()), stream=stream)

        print('\n# =========================================================================== #', file=stream)
        print('# Training | Evaluation dataset files:', file=stream)
        print('# =========================================================================== #', file=stream)
        data_files = parallel_reader.get_data_files(data_sources)
        pprint(sorted(data_files), stream=stream)
        print('', file=stream)

    print_config(None)
    # Save to a text file as well.
    if save_dir is not None:
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)
        path = os.path.join(save_dir, 'training_config.txt')
        with open(path, "w") as out:
            print_config(out)
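# Typical call from a training/evaluation script (illustrative; FLAGS, ssd_net
# and dataset are assumed to be defined by that script):
#   print_configuration(FLAGS.__flags, ssd_net.params,
#                       dataset.data_sources, FLAGS.train_dir)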


def configure_learning_rate(flags, num_samples_per_epoch, global_step):
    """Configures the learning rate.

    Args:
      num_samples_per_epoch: The number of samples in each epoch of training.
      global_step: The global_step tensor.
    Returns:
      A `Tensor` representing the learning rate.
    """
    decay_steps = int(num_samples_per_epoch / flags.batch_size *
                      flags.num_epochs_per_decay)

    if flags.learning_rate_decay_type == 'exponential':
        return tf.train.exponential_decay(flags.learning_rate,
                                          global_step,
                                          decay_steps,
                                          flags.learning_rate_decay_factor,
                                          staircase=True,
                                          name='exponential_decay_learning_rate')
    elif flags.learning_rate_decay_type == 'fixed':
        return tf.constant(flags.learning_rate, name='fixed_learning_rate')
    elif flags.learning_rate_decay_type == 'polynomial':
        return tf.train.polynomial_decay(flags.learning_rate,
                                         global_step,
                                         decay_steps,
                                         flags.end_learning_rate,
                                         power=1.0,
                                         cycle=False,
                                         name='polynomial_decay_learning_rate')
    else:
        raise ValueError('learning_rate_decay_type [%s] was not recognized'
                         % flags.learning_rate_decay_type)
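# Worked example (illustrative numbers): with 5000 training samples,
# batch_size=32 and num_epochs_per_decay=2,
#   decay_steps = int(5000 / 32 * 2) = 312,
# so the 'exponential' schedule multiplies the rate by
# learning_rate_decay_factor every 312 steps (staircase decay).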


def configure_optimizer(flags, learning_rate):
    """Configures the optimizer used for training.

    Args:
      learning_rate: A scalar or `Tensor` learning rate.
    Returns:
      An instance of an optimizer.
    """
    if flags.optimizer == 'adadelta':
        optimizer = tf.train.AdadeltaOptimizer(
            learning_rate,
            rho=flags.adadelta_rho,
            epsilon=flags.opt_epsilon)
    elif flags.optimizer == 'adagrad':
        optimizer = tf.train.AdagradOptimizer(
            learning_rate,
            initial_accumulator_value=flags.adagrad_initial_accumulator_value)
    elif flags.optimizer == 'adam':
        optimizer = tf.train.AdamOptimizer(
            learning_rate,
            beta1=flags.adam_beta1,
            beta2=flags.adam_beta2,
            epsilon=flags.opt_epsilon)
    elif flags.optimizer == 'ftrl':
        optimizer = tf.train.FtrlOptimizer(
            learning_rate,
            learning_rate_power=flags.ftrl_learning_rate_power,
            initial_accumulator_value=flags.ftrl_initial_accumulator_value,
            l1_regularization_strength=flags.ftrl_l1,
            l2_regularization_strength=flags.ftrl_l2)
    elif flags.optimizer == 'momentum':
        optimizer = tf.train.MomentumOptimizer(
            learning_rate,
            momentum=flags.momentum,
            name='Momentum')
    elif flags.optimizer == 'rmsprop':
        optimizer = tf.train.RMSPropOptimizer(
            learning_rate,
            decay=flags.rmsprop_decay,
            momentum=flags.rmsprop_momentum,
            epsilon=flags.opt_epsilon)
    elif flags.optimizer == 'sgd':
        optimizer = tf.train.GradientDescentOptimizer(learning_rate)
    else:
        raise ValueError('Optimizer [%s] was not recognized' % flags.optimizer)
    return optimizer
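# Sketch of how the two helpers fit together in a training script
# (illustrative; FLAGS, dataset and global_step come from that script):
#   learning_rate = configure_learning_rate(FLAGS, dataset.num_samples, global_step)
#   optimizer = configure_optimizer(FLAGS, learning_rate)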


def add_variables_summaries(learning_rate):
    """Add histogram summaries for model variables and a learning rate summary."""
    summaries = []
    for variable in slim.get_model_variables():
        summaries.append(tf.summary.histogram(variable.op.name, variable))
    summaries.append(tf.summary.scalar('training/Learning Rate', learning_rate))
    return summaries


def update_model_scope(var, ckpt_scope, new_scope):
    """Map a variable name from the model scope back to the checkpoint scope."""
    return var.op.name.replace(new_scope, ckpt_scope)


def get_init_fn(flags):
    """Returns a function run by the chief worker to warm-start the training.
    Note that the init_fn is only run when initializing the model during the very
    first global step.

    Returns:
      An init function run by the supervisor.
    """
    if flags.checkpoint_path is None:
        return None
    # Warn the user if a checkpoint exists in the train_dir. Then ignore.
    if tf.train.latest_checkpoint(flags.train_dir):
        tf.logging.info(
            'Ignoring --checkpoint_path because a checkpoint already exists in %s'
            % flags.train_dir)
        return None

    exclusions = []
    if flags.checkpoint_exclude_scopes:
        exclusions = [scope.strip()
                      for scope in flags.checkpoint_exclude_scopes.split(',')]

    # TODO(sguada) variables.filter_variables()
    variables_to_restore = []
    for var in slim.get_model_variables():
        excluded = False
        for exclusion in exclusions:
            if var.op.name.startswith(exclusion):
                excluded = True
                break
        if not excluded:
            variables_to_restore.append(var)
    # Change model scope if necessary.
    if flags.checkpoint_model_scope is not None:
        variables_to_restore = \
            {var.op.name.replace(flags.model_name,
                                 flags.checkpoint_model_scope): var
             for var in variables_to_restore}

    if tf.gfile.IsDirectory(flags.checkpoint_path):
        checkpoint_path = tf.train.latest_checkpoint(flags.checkpoint_path)
    else:
        checkpoint_path = flags.checkpoint_path
    tf.logging.info('Fine-tuning from %s. Ignoring missing vars: %s'
                    % (checkpoint_path, flags.ignore_missing_vars))

    return slim.assign_from_checkpoint_fn(
        checkpoint_path,
        variables_to_restore,
        ignore_missing_vars=flags.ignore_missing_vars)
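# Example (illustrative flag values): fine-tuning an SSD model from a VGG-16
# checkpoint might use
#   --checkpoint_path=./checkpoints/vgg_16.ckpt
#   --checkpoint_model_scope=vgg_16
#   --checkpoint_exclude_scopes=ssd_300_vgg/conv6,ssd_300_vgg/conv7
# so that only the matching backbone weights are restored and the excluded
# SSD-specific layers keep their random initialization.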


def get_variables_to_train(flags):
    """Returns a list of variables to train.

    Returns:
      A list of variables to be trained by the optimizer.
    """
    if flags.trainable_scopes is None:
        return tf.trainable_variables()
    else:
        scopes = [scope.strip() for scope in flags.trainable_scopes.split(',')]

    variables_to_train = []
    for scope in scopes:
        variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope)
        variables_to_train.extend(variables)
    return variables_to_train
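# Example (illustrative scope names):
#   --trainable_scopes=ssd_300_vgg/conv6,ssd_300_vgg/conv7
# restricts the optimizer to variables under those scopes; every other
# variable stays frozen at its restored value.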


# =========================================================================== #
# Evaluation utils.
# =========================================================================== #