You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
193 lines
7.3 KiB
193 lines
7.3 KiB
# Copyright 2016 The TensorFlow Authors. All Rights Reserved. |
|
# |
|
# Licensed under the Apache License, Version 2.0 (the "License"); |
|
# you may not use this file except in compliance with the License. |
|
# You may obtain a copy of the License at |
|
# |
|
# http://www.apache.org/licenses/LICENSE-2.0 |
|
# |
|
# Unless required by applicable law or agreed to in writing, software |
|
# distributed under the License is distributed on an "AS IS" BASIS, |
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|
# See the License for the specific language governing permissions and |
|
# limitations under the License. |
|
# ============================================================================== |
|
"""Provides data for the ImageNet ILSVRC 2012 Dataset plus some bounding boxes. |
|
|
|
Some images have one or more bounding boxes associated with the label of the |
|
image. See details here: http://image-net.org/download-bboxes |
|
|
|
ImageNet is based upon WordNet 3.0. To uniquely identify a synset, we use |
|
"WordNet ID" (wnid), which is a concatenation of POS ( i.e. part of speech ) |
|
and SYNSET OFFSET of WordNet. For more information, please refer to the |
|
WordNet documentation[http://wordnet.princeton.edu/wordnet/documentation/]. |
|
|
|
"There are bounding boxes for over 3000 popular synsets available. |
|
For each synset, there are on average 150 images with bounding boxes." |
|
|
|
WARNING: Don't use for object detection, in this case all the bounding boxes |
|
of the image belong to just one class. |
|
""" |
|
from __future__ import absolute_import |
|
from __future__ import division |
|
from __future__ import print_function |
|
|
|
import os |
|
from six.moves import urllib |
|
import tensorflow as tf |
|
|
|
from datasets import dataset_utils |
|
|
|
# Short alias for the TF-Slim contrib module; its `tfexample_decoder` and
# `dataset` submodules are used below to describe the ImageNet split.
slim = tf.contrib.slim
|
|
|
# TODO(nsilberman): Add tfrecord file type once the script is updated.
# Default file-name pattern for a split's shards; '%s' is replaced by the
# split name (e.g. 'train-*').
_FILE_PATTERN = '%s-*'

# Number of examples in each supported split.
_SPLITS_TO_SIZES = {
    'train': 1281167,
    'validation': 50000,
}

# Human-readable descriptions of the items produced by the decoder.
# Label 0 is the background class and labels 1..1000 are the ILSVRC synsets,
# so valid label ids span 0..1000 (consistent with _NUM_CLASSES below).
_ITEMS_TO_DESCRIPTIONS = {
    'image': 'A color image of varying height and width.',
    'label': 'The label id of the image, integer between 0 and 1000',
    'label_text': 'The text of the label.',
    'object/bbox': 'A list of bounding boxes.',
    'object/label': 'A list of labels, one per each object.',
}

# 1000 ILSVRC classes plus one background class at label 0.
_NUM_CLASSES = 1001
|
|
|
|
|
def create_readable_names_for_imagenet_labels():
  """Create a dict mapping label id to human readable string.

  Returns:
      labels_to_names: dictionary where keys are integers from 0 to 1000
      and values are human-readable names (0 is 'background').

  We retrieve a synset file, which contains a list of valid synset labels used
  by the ILSVRC competition. There is one synset per line, e.g.
          #   n01440764
          #   n01443537
  We also retrieve a synset_to_human_file, which contains a mapping from synsets
  to human-readable names for every synset in Imagenet. These are stored in a
  tsv format, as follows:
          #   n02119247    black fox
          #   n02119359    silver fox
  We assign each synset an integer in the order it appears in the synset file
  (presumably alphabetical), starting from 1, since 0 is reserved for the
  background class.

  Code is based on
  https://github.com/tensorflow/models/blob/master/inception/inception/data/build_imagenet_data.py#L463

  Raises:
    ValueError: if the downloaded files do not contain the expected number of
      synsets, or a metadata line is not exactly two tab-separated fields.
    Errors raised by `urllib.request.urlretrieve` (e.g. on network failure)
    are not caught and propagate to the caller.
  """

  # pylint: disable=g-line-too-long
  base_url = 'https://raw.githubusercontent.com/tensorflow/models/master/inception/inception/data/'
  # base_url already ends with '/', so append file names without another one
  # (the previous '{}/...' form produced '...data//file.txt').
  synset_url = '{}imagenet_lsvrc_2015_synsets.txt'.format(base_url)
  synset_to_human_url = '{}imagenet_metadata.txt'.format(base_url)

  filename, _ = urllib.request.urlretrieve(synset_url)
  # Close the downloaded file promptly instead of leaking the handle.
  with open(filename) as f:
    synset_list = [s.strip() for s in f.readlines()]
  num_synsets_in_ilsvrc = len(synset_list)
  # Explicit checks (not `assert`) so validation still runs under `python -O`.
  if num_synsets_in_ilsvrc != 1000:
    raise ValueError('Expected 1000 ILSVRC synsets, got %d.' %
                     num_synsets_in_ilsvrc)

  filename, _ = urllib.request.urlretrieve(synset_to_human_url)
  with open(filename) as f:
    synset_to_human_list = f.readlines()
  num_synsets_in_all_imagenet = len(synset_to_human_list)
  if num_synsets_in_all_imagenet != 21842:
    raise ValueError('Expected 21842 ImageNet synsets, got %d.' %
                     num_synsets_in_all_imagenet)

  synset_to_human = {}
  for line in synset_to_human_list:
    parts = line.strip().split('\t')
    if len(parts) != 2:
      raise ValueError('Malformed synset metadata line: %r' % line)
    synset, human = parts
    synset_to_human[synset] = human

  labels_to_names = {0: 'background'}
  # Label ids start at 1 because 0 is the background class.
  for label_index, synset in enumerate(synset_list, start=1):
    labels_to_names[label_index] = synset_to_human[synset]

  return labels_to_names
|
|
|
|
|
def get_split(split_name, dataset_dir, file_pattern=None, reader=None):
  """Gets a dataset tuple with instructions for reading ImageNet.

  Args:
    split_name: A split name; must be a key of `_SPLITS_TO_SIZES`
      ('train' or 'validation').
    dataset_dir: The base directory of the dataset sources. If no label file
      is found there (per `dataset_utils.has_labels`), one is generated with
      `create_readable_names_for_imagenet_labels()` (which downloads synset
      metadata) and written into this directory.
    file_pattern: The file pattern to use when matching the dataset sources.
      It is assumed that the pattern contains a '%s' string so that the split
      name can be inserted. Defaults to `_FILE_PATTERN`.
    reader: The TensorFlow reader type. Defaults to `tf.TFRecordReader`.

  Returns:
    A `slim.dataset.Dataset` describing the requested split.

  Raises:
    ValueError: if `split_name` is not a valid train/validation split.
  """
  if split_name not in _SPLITS_TO_SIZES:
    raise ValueError('split name %s was not recognized.' % split_name)

  if not file_pattern:
    file_pattern = _FILE_PATTERN
  # Substitute the split name, then anchor the pattern in dataset_dir
  # (with the default pattern: '<dataset_dir>/<split_name>-*').
  file_pattern = os.path.join(dataset_dir, file_pattern % split_name)

  # Allowing None in the signature so that dataset_factory can use the default.
  if reader is None:
    reader = tf.TFRecordReader

  keys_to_features = {
      # Per-image scalar features; the default is used when a key is absent.
      'image/encoded': tf.FixedLenFeature(
          (), tf.string, default_value=''),
      'image/format': tf.FixedLenFeature(
          (), tf.string, default_value='jpeg'),
      'image/class/label': tf.FixedLenFeature(
          [], dtype=tf.int64, default_value=-1),
      'image/class/text': tf.FixedLenFeature(
          [], dtype=tf.string, default_value=''),
      # Per-object features: one entry per bounding box, variable count.
      'image/object/bbox/xmin': tf.VarLenFeature(
          dtype=tf.float32),
      'image/object/bbox/ymin': tf.VarLenFeature(
          dtype=tf.float32),
      'image/object/bbox/xmax': tf.VarLenFeature(
          dtype=tf.float32),
      'image/object/bbox/ymax': tf.VarLenFeature(
          dtype=tf.float32),
      'image/object/class/label': tf.VarLenFeature(
          dtype=tf.int64),
  }

  items_to_handlers = {
      'image': slim.tfexample_decoder.Image('image/encoded', 'image/format'),
      'label': slim.tfexample_decoder.Tensor('image/class/label'),
      'label_text': slim.tfexample_decoder.Tensor('image/class/text'),
      # Combines the four 'image/object/bbox/*' features; presumably the box
      # coordinates follow the listed key order (ymin, xmin, ymax, xmax).
      'object/bbox': slim.tfexample_decoder.BoundingBox(
          ['ymin', 'xmin', 'ymax', 'xmax'], 'image/object/bbox/'),
      'object/label': slim.tfexample_decoder.Tensor('image/object/class/label'),
  }

  decoder = slim.tfexample_decoder.TFExampleDecoder(
      keys_to_features, items_to_handlers)

  # Reuse a label file already in dataset_dir; otherwise build the mapping
  # (network download) once and write it there for subsequent calls.
  if dataset_utils.has_labels(dataset_dir):
    labels_to_names = dataset_utils.read_label_file(dataset_dir)
  else:
    labels_to_names = create_readable_names_for_imagenet_labels()
    dataset_utils.write_label_file(labels_to_names, dataset_dir)

  return slim.dataset.Dataset(
      data_sources=file_pattern,
      reader=reader,
      decoder=decoder,
      num_samples=_SPLITS_TO_SIZES[split_name],
      items_to_descriptions=_ITEMS_TO_DESCRIPTIONS,
      num_classes=_NUM_CLASSES,
      labels_to_names=labels_to_names)
|
|
|