You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
134 lines
4.6 KiB
134 lines
4.6 KiB
# Copyright 2016 The TensorFlow Authors. All Rights Reserved. |
|
# |
|
# Licensed under the Apache License, Version 2.0 (the "License"); |
|
# you may not use this file except in compliance with the License. |
|
# You may obtain a copy of the License at |
|
# |
|
# http://www.apache.org/licenses/LICENSE-2.0 |
|
# |
|
# Unless required by applicable law or agreed to in writing, software |
|
# distributed under the License is distributed on an "AS IS" BASIS, |
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|
# See the License for the specific language governing permissions and |
|
# limitations under the License. |
|
# ============================================================================== |
|
"""Contains utilities for downloading and converting datasets.""" |
|
from __future__ import absolute_import |
|
from __future__ import division |
|
from __future__ import print_function |
|
|
|
import os |
|
import sys |
|
import tarfile |
|
|
|
from six.moves import urllib |
|
import tensorflow as tf |
|
|
|
LABELS_FILENAME = 'labels.txt' |
|
|
|
|
|
def int64_feature(value):
  """Wraps int64 value(s) in a `tf.train.Feature` for an Example proto.

  Args:
    value: A single integer, or a list/tuple of integers.

  Returns:
    A `tf.train.Feature` whose `int64_list` holds the given value(s).
  """
  # Only bare scalars need wrapping. Tuples are treated like lists; wrapping a
  # tuple would yield a nested `[(...)]` instead of a flat sequence of values.
  if not isinstance(value, (list, tuple)):
    value = [value]
  return tf.train.Feature(int64_list=tf.train.Int64List(value=value))
|
|
|
|
|
def float_feature(value):
  """Wraps float value(s) in a `tf.train.Feature` for an Example proto.

  Args:
    value: A single float, or a list/tuple of floats.

  Returns:
    A `tf.train.Feature` whose `float_list` holds the given value(s).
  """
  # Only bare scalars need wrapping. Tuples are treated like lists; wrapping a
  # tuple would yield a nested `[(...)]` instead of a flat sequence of values.
  if not isinstance(value, (list, tuple)):
    value = [value]
  return tf.train.Feature(float_list=tf.train.FloatList(value=value))
|
|
|
|
|
def bytes_feature(value):
  """Wraps bytes value(s) in a `tf.train.Feature` for an Example proto.

  Args:
    value: A single bytes object, or a list/tuple of bytes objects.

  Returns:
    A `tf.train.Feature` whose `bytes_list` holds the given value(s).
  """
  # Only a single item needs wrapping. Tuples are treated like lists; wrapping
  # a tuple would yield a nested `[(...)]` instead of a flat sequence of items.
  if not isinstance(value, (list, tuple)):
    value = [value]
  return tf.train.Feature(bytes_list=tf.train.BytesList(value=value))
|
|
|
|
|
def image_to_tfexample(image_data, image_format, height, width, class_id):
  """Builds a `tf.train.Example` describing a single labeled image.

  Args:
    image_data: Stored under 'image/encoded' as a bytes feature.
    image_format: Stored under 'image/format' as a bytes feature.
    height: Stored under 'image/height' as an int64 feature.
    width: Stored under 'image/width' as an int64 feature.
    class_id: Stored under 'image/class/label' as an int64 feature.

  Returns:
    A `tf.train.Example` holding the five features listed above.
  """
  feature_map = {
      'image/encoded': bytes_feature(image_data),
      'image/format': bytes_feature(image_format),
      'image/class/label': int64_feature(class_id),
      'image/height': int64_feature(height),
      'image/width': int64_feature(width),
  }
  example_features = tf.train.Features(feature=feature_map)
  return tf.train.Example(features=example_features)
|
|
|
|
|
def download_and_uncompress_tarball(tarball_url, dataset_dir):
  """Downloads the `tarball_url` and uncompresses it locally.

  The archive is saved inside `dataset_dir` under the last path component of
  the URL and is then extracted into `dataset_dir`. Download progress and a
  final summary line are written to stdout.

  Args:
    tarball_url: The URL of a gzip-compressed tarball file.
    dataset_dir: The directory where the tarball is stored and extracted.
  """
  filename = tarball_url.split('/')[-1]
  filepath = os.path.join(dataset_dir, filename)

  def _progress(count, block_size, total_size):
    """Report hook for `urlretrieve`; rewrites a single status line."""
    downloaded = count * block_size
    if total_size > 0:
      # The final block is usually partial, so `downloaded` can overshoot the
      # real size; clamp so the display never exceeds 100%.
      percent = min(float(downloaded) / float(total_size) * 100.0, 100.0)
      sys.stdout.write('\r>> Downloading %s %.1f%%' % (filename, percent))
    else:
      # Size is zero or unknown (reported as a non-positive value), so a
      # percentage cannot be computed; show the byte count instead of
      # dividing by it.
      sys.stdout.write(
          '\r>> Downloading %s %d bytes' % (filename, downloaded))
    sys.stdout.flush()

  filepath, _ = urllib.request.urlretrieve(tarball_url, filepath, _progress)
  print()
  statinfo = os.stat(filepath)
  print('Successfully downloaded', filename, statinfo.st_size, 'bytes.')
  # NOTE(review): extractall() trusts the member paths stored in the archive,
  # so this should only be pointed at tarballs from trusted sources.
  # The context manager closes the archive even if extraction fails.
  with tarfile.open(filepath, 'r:gz') as tarball:
    tarball.extractall(dataset_dir)
|
|
|
|
|
def write_label_file(labels_to_class_names, dataset_dir,
                     filename=LABELS_FILENAME):
  """Writes the label map to a text file, one `label:class_name` per line.

  Args:
    labels_to_class_names: A map of (integer) labels to class names.
    dataset_dir: The directory in which the labels file should be written.
    filename: The filename where the class names are written.
  """
  output_path = os.path.join(dataset_dir, filename)
  with tf.gfile.Open(output_path, 'w') as label_file:
    # Entries are emitted in the mapping's own iteration order.
    for label, class_name in labels_to_class_names.items():
      label_file.write('%d:%s\n' % (label, class_name))
|
|
|
|
|
def has_labels(dataset_dir, filename=LABELS_FILENAME):
  """Reports whether a label map file is present in the dataset directory.

  Args:
    dataset_dir: The directory in which the labels file is found.
    filename: The filename where the class names are written.

  Returns:
    `True` if the labels file exists and `False` otherwise.
  """
  candidate_path = os.path.join(dataset_dir, filename)
  return tf.gfile.Exists(candidate_path)
|
|
|
|
|
def read_label_file(dataset_dir, filename=LABELS_FILENAME):
  """Reads the labels file and returns a mapping from ID to class name.

  Each non-empty line is expected to look like `<integer label>:<class name>`.
  Only the first ':' separates the two parts, so class names may themselves
  contain colons.

  Args:
    dataset_dir: The directory in which the labels file is found.
    filename: The filename where the class names are written.

  Returns:
    A map from a label (integer) to class name (text string).

  Raises:
    ValueError: If a non-empty line has no ':' separator or its label part is
      not an integer.
  """
  labels_filename = os.path.join(dataset_dir, filename)
  with tf.gfile.Open(labels_filename, 'rb') as f:
    # Decode once at the I/O boundary so class names come back as text rather
    # than raw bytes (which is what a binary read yields on Python 3).
    contents = f.read().decode('utf-8')

  labels_to_class_names = {}
  for line in contents.split('\n'):
    if not line:
      # Skip blank lines, including the one after the trailing newline.
      continue
    # maxsplit=1 keeps any further ':' characters inside the class name;
    # a line without ':' fails the unpacking with ValueError.
    label, class_name = line.split(':', 1)
    labels_to_class_names[int(label)] = class_name
  return labels_to_class_names
|
|
|