# Copyright 2016 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """Contains utilities for downloading and converting datasets.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function import os import sys import tarfile from six.moves import urllib import tensorflow as tf LABELS_FILENAME = 'labels.txt' def int64_feature(value): """Wrapper for inserting int64 features into Example proto. """ if not isinstance(value, list): value = [value] return tf.train.Feature(int64_list=tf.train.Int64List(value=value)) def float_feature(value): """Wrapper for inserting float features into Example proto. """ if not isinstance(value, list): value = [value] return tf.train.Feature(float_list=tf.train.FloatList(value=value)) def bytes_feature(value): """Wrapper for inserting bytes features into Example proto. """ if not isinstance(value, list): value = [value] return tf.train.Feature(bytes_list=tf.train.BytesList(value=value)) def image_to_tfexample(image_data, image_format, height, width, class_id): return tf.train.Example(features=tf.train.Features(feature={ 'image/encoded': bytes_feature(image_data), 'image/format': bytes_feature(image_format), 'image/class/label': int64_feature(class_id), 'image/height': int64_feature(height), 'image/width': int64_feature(width), })) def download_and_uncompress_tarball(tarball_url, dataset_dir): """Downloads the `tarball_url` and uncompresses it locally. Args: tarball_url: The URL of a tarball file. dataset_dir: The directory where the temporary files are stored. """ filename = tarball_url.split('/')[-1] filepath = os.path.join(dataset_dir, filename) def _progress(count, block_size, total_size): sys.stdout.write('\r>> Downloading %s %.1f%%' % ( filename, float(count * block_size) / float(total_size) * 100.0)) sys.stdout.flush() filepath, _ = urllib.request.urlretrieve(tarball_url, filepath, _progress) print() statinfo = os.stat(filepath) print('Successfully downloaded', filename, statinfo.st_size, 'bytes.') tarfile.open(filepath, 'r:gz').extractall(dataset_dir) def write_label_file(labels_to_class_names, dataset_dir, filename=LABELS_FILENAME): """Writes a file with the list of class names. Args: labels_to_class_names: A map of (integer) labels to class names. dataset_dir: The directory in which the labels file should be written. filename: The filename where the class names are written. """ labels_filename = os.path.join(dataset_dir, filename) with tf.gfile.Open(labels_filename, 'w') as f: for label in labels_to_class_names: class_name = labels_to_class_names[label] f.write('%d:%s\n' % (label, class_name)) def has_labels(dataset_dir, filename=LABELS_FILENAME): """Specifies whether or not the dataset directory contains a label map file. Args: dataset_dir: The directory in which the labels file is found. filename: The filename where the class names are written. Returns: `True` if the labels file exists and `False` otherwise. """ return tf.gfile.Exists(os.path.join(dataset_dir, filename)) def read_label_file(dataset_dir, filename=LABELS_FILENAME): """Reads the labels file and returns a mapping from ID to class name. Args: dataset_dir: The directory in which the labels file is found. filename: The filename where the class names are written. Returns: A map from a label (integer) to class name. """ labels_filename = os.path.join(dataset_dir, filename) with tf.gfile.Open(labels_filename, 'rb') as f: lines = f.read() lines = lines.split(b'\n') lines = filter(None, lines) labels_to_class_names = {} for line in lines: index = line.index(b':') labels_to_class_names[int(line[:index])] = line[index+1:] return labels_to_class_names