Skip to content

Commit

Permalink
Add the TensorFlow official models directory (tensorflow#2384)
Browse files Browse the repository at this point in the history
  • Loading branch information
nealwu authored Sep 21, 2017
1 parent c96ef83 commit 2c5c3f3
Show file tree
Hide file tree
Showing 20 changed files with 2,487 additions and 0 deletions.
3 changes: 3 additions & 0 deletions official/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
cnn/data
MNIST-data
labels.txt
13 changes: 13 additions & 0 deletions official/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# TensorFlow Official Models

The TensorFlow official models are a collection of example models that use TensorFlow's high-level APIs. They are intended to be well-maintained, tested, and kept up to date with the latest stable TensorFlow API. They should also be reasonably optimized for fast performance while still being easy to read.

Below is the list of models contained in the garden:

[mnist](mnist): A basic model to classify digits from the MNIST dataset.

[resnet](resnet): A deep residual network that can be used to classify both CIFAR-10 and ImageNet's dataset of 1000 classes.

More models to come!

If you would like to make any fixes or improvements to the models, please [submit a pull request](https://github.com/tensorflow/models/compare).
29 changes: 29 additions & 0 deletions official/mnist/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# MNIST in TensorFlow

This directory builds a convolutional neural net to classify the [MNIST
dataset](http://yann.lecun.com/exdb/mnist/) using the
[tf.contrib.data](https://www.tensorflow.org/api_docs/python/tf/contrib/data),
[tf.estimator.Estimator](https://www.tensorflow.org/api_docs/python/tf/estimator/Estimator),
and
[tf.layers](https://www.tensorflow.org/api_docs/python/tf/layers)
APIs.


## Setup

To begin, you'll simply need the latest version of TensorFlow installed.

First convert the MNIST data to TFRecord file format by running the following:

```
python convert_to_records.py
```

Then to train the model, run the following:

```
python mnist.py
```

The model will begin training and will automatically evaluate itself on the
test data once training finishes.
Empty file added official/mnist/__init__.py
Empty file.
92 changes: 92 additions & 0 deletions official/mnist/convert_to_records.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Converts MNIST data to TFRecords file format with Example protos."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import os
import sys

import tensorflow as tf

from tensorflow.contrib.learn.python.learn.datasets import mnist

# Command-line options understood by the converter script.
parser = argparse.ArgumentParser()

# Directory used both for the downloaded data files and the converted output.
parser.add_argument(
    '--directory',
    type=str,
    default='/tmp/mnist_data',
    help='Directory to download data files and write the converted result.')

# Number of training examples set aside as the validation split.
parser.add_argument(
    '--validation_size',
    type=int,
    default=0,
    help='Number of examples to separate from the training data for the '
         'validation set.')


def _int64_feature(value):
  """Wraps a single integer in a tf.train.Feature backed by an Int64List."""
  int64_list = tf.train.Int64List(value=[value])
  return tf.train.Feature(int64_list=int64_list)


def _bytes_feature(value):
  """Wraps a single bytes value in a tf.train.Feature backed by a BytesList."""
  bytes_list = tf.train.BytesList(value=[value])
  return tf.train.Feature(bytes_list=bytes_list)


def convert_to(data_set, name):
  """Converts a dataset to TFRecords.

  Writes one serialized tf.train.Example per image to
  `<FLAGS.directory>/<name>.tfrecords`. Each Example carries the features
  'height', 'width', 'depth', 'label' and 'image_raw'.

  Args:
    data_set: Object exposing `images`, `labels` and `num_examples`. `images`
      is indexed as a 4-D array: (example, row, column, depth).
    name: Base name of the output file, without the '.tfrecords' extension.

  Raises:
    ValueError: If the number of images differs from `data_set.num_examples`.
  """
  images = data_set.images
  labels = data_set.labels
  num_examples = data_set.num_examples

  if images.shape[0] != num_examples:
    raise ValueError('Images size %d does not match label size %d.' %
                     (images.shape[0], num_examples))

  # The image geometry is the same for every example, so build these features
  # once instead of once per record.
  height_feature = _int64_feature(images.shape[1])
  width_feature = _int64_feature(images.shape[2])
  depth_feature = _int64_feature(images.shape[3])

  filename = os.path.join(FLAGS.directory, name + '.tfrecords')
  print('Writing', filename)

  # The context manager closes the writer even if serialization fails midway,
  # so the file handle is never leaked.
  with tf.python_io.TFRecordWriter(filename) as writer:
    for index in range(num_examples):
      # tobytes() returns the same raw buffer as the deprecated tostring().
      image_raw = images[index].tobytes()
      example = tf.train.Example(features=tf.train.Features(feature={
          'height': height_feature,
          'width': width_feature,
          'depth': depth_feature,
          'label': _int64_feature(int(labels[index])),
          'image_raw': _bytes_feature(image_raw)}))
      writer.write(example.SerializeToString())


def main(unused_argv):
  """Reads the MNIST data sets and writes each split as a TFRecords file."""
  # Load the splits as uint8 images, keeping the image dimensions (no reshape).
  data_sets = mnist.read_data_sets(
      FLAGS.directory,
      dtype=tf.uint8,
      reshape=False,
      validation_size=FLAGS.validation_size)

  # Write train.tfrecords, validation.tfrecords and test.tfrecords, in order.
  for split_name in ('train', 'validation', 'test'):
    convert_to(getattr(data_sets, split_name), split_name)


if __name__ == '__main__':
  tf.logging.set_verbosity(tf.logging.INFO)
  # argparse consumes only the flags declared above; whatever is left over is
  # passed to tf.app.run explicitly. Without an explicit argv, tf.app.run
  # parses the full sys.argv itself, and TensorFlow's own flag parser can
  # reject this script's flags as unknown.
  FLAGS, unparsed = parser.parse_known_args()
  tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
226 changes: 226 additions & 0 deletions official/mnist/mnist.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,226 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convolutional Neural Network Estimator for MNIST, built with tf.layers."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import os

import numpy as np
import tensorflow as tf

# Command-line options for training and evaluating the MNIST model.
parser = argparse.ArgumentParser()

# Basic model parameters.
parser.add_argument(
    '--batch_size',
    type=int,
    default=100,
    help='Number of images to process in a batch')

parser.add_argument(
    '--data_dir',
    type=str,
    default='/tmp/mnist_data',
    help='Path to the MNIST data directory.')

parser.add_argument(
    '--model_dir',
    type=str,
    default='/tmp/mnist_model',
    help='The directory where the model will be stored.')

parser.add_argument(
    '--steps',
    type=int,
    default=20000,
    help='Number of steps to train.')


def input_fn(mode, batch_size=1):
  """A simple input_fn using the contrib.data input pipeline.

  Args:
    mode: tf.estimator.ModeKeys.TRAIN reads `train.tfrecords` from
      FLAGS.data_dir and repeats it forever; tf.estimator.ModeKeys.EVAL reads
      `test.tfrecords` once.
    batch_size: Maximum number of examples per batch.

  Returns:
    A tuple (images, labels) of tensors for one batch: float32 images of
    28 * 28 values scaled to [-0.5, 0.5], and one-hot labels of depth 10.

  Raises:
    ValueError: If `mode` is neither TRAIN nor EVAL.
    IOError: If the expected TFRecord file does not exist.
  """

  # Named parse_example rather than `parser` so it does not shadow the
  # module-level argparse parser.
  def parse_example(serialized_example):
    """Parses a single tf.Example into image and label tensors."""
    features = tf.parse_single_example(
        serialized_example,
        features={
            'image_raw': tf.FixedLenFeature([], tf.string),
            'label': tf.FixedLenFeature([], tf.int64),
        })
    image = tf.decode_raw(features['image_raw'], tf.uint8)
    image.set_shape([28 * 28])

    # Normalize the values of the image from the range [0, 255] to [-0.5, 0.5]
    image = tf.cast(image, tf.float32) / 255 - 0.5
    label = tf.cast(features['label'], tf.int32)
    return image, tf.one_hot(label, 10)

  # Explicit exceptions instead of `assert`: assertions are stripped when
  # Python runs with -O, which would silently skip these checks.
  if mode == tf.estimator.ModeKeys.TRAIN:
    tfrecords_file = os.path.join(FLAGS.data_dir, 'train.tfrecords')
  elif mode == tf.estimator.ModeKeys.EVAL:
    tfrecords_file = os.path.join(FLAGS.data_dir, 'test.tfrecords')
  else:
    raise ValueError('invalid mode')

  if not os.path.exists(tfrecords_file):
    raise IOError('Run convert_to_records.py first to '
                  'convert the MNIST data to TFRecord file format.')

  dataset = tf.contrib.data.TFRecordDataset([tfrecords_file])

  # For training, repeat the dataset forever
  if mode == tf.estimator.ModeKeys.TRAIN:
    dataset = dataset.repeat()

  # Map the parser over dataset, and batch results by up to batch_size
  dataset = dataset.map(
      parse_example, num_threads=1, output_buffer_size=batch_size)
  dataset = dataset.batch(batch_size)
  images, labels = dataset.make_one_shot_iterator().get_next()

  return images, labels


def mnist_model(inputs, mode):
  """Builds the MNIST convolutional network and returns its logits.

  The network is two (5x5 convolution + ReLU -> 2x2 max-pool) stages with 32
  and 64 filters, a 1024-unit ReLU dense layer, dropout, and a 10-unit
  output layer.

  Args:
    inputs: Tensor of images; it is reshaped to [-1, 28, 28, 1].
    mode: A tf.estimator.ModeKeys value; dropout is only active for TRAIN.

  Returns:
    A tensor of logits with shape [batch_size, 10].
  """
  is_training = (mode == tf.estimator.ModeKeys.TRAIN)

  # View the input as a batch of 28x28 images with one color channel:
  # [batch_size, height, width, channels].
  net = tf.reshape(inputs, [-1, 28, 28, 1])

  if tf.test.is_built_with_cuda():
    # On a CUDA build, move from channels_last (NHWC) to channels_first
    # (NCHW) to improve performance.
    data_format = 'channels_first'
    net = tf.transpose(net, [0, 3, 1, 2])
  else:
    data_format = 'channels_last'

  # Two convolution + pooling stages. 'same' padding keeps width and height
  # through each convolution; each 2x2 pool with stride 2 halves them.
  # Shapes in channels_last order:
  #   [batch_size, 28, 28, 1] -> conv -> [batch_size, 28, 28, 32]
  #                           -> pool -> [batch_size, 14, 14, 32]
  #                           -> conv -> [batch_size, 14, 14, 64]
  #                           -> pool -> [batch_size, 7, 7, 64]
  for num_filters in (32, 64):
    net = tf.layers.conv2d(
        inputs=net,
        filters=num_filters,
        kernel_size=[5, 5],
        padding='same',
        activation=tf.nn.relu,
        data_format=data_format)
    net = tf.layers.max_pooling2d(
        inputs=net, pool_size=[2, 2], strides=2, data_format=data_format)

  # Flatten each example into a vector: [batch_size, 7 * 7 * 64].
  flat = tf.reshape(net, [-1, 7 * 7 * 64])

  # Densely connected layer with 1024 neurons: [batch_size, 1024].
  hidden = tf.layers.dense(inputs=flat, units=1024, activation=tf.nn.relu)

  # Dropout with rate 0.4, i.e. each element is kept with probability 0.6.
  hidden = tf.layers.dropout(inputs=hidden, rate=0.4, training=is_training)

  # Logits layer: [batch_size, 1024] -> [batch_size, 10].
  return tf.layers.dense(inputs=hidden, units=10)


def mnist_model_fn(features, labels, mode):
  """Model function for MNIST.

  Args:
    features: Image tensor passed straight to `mnist_model`.
    labels: One-hot label tensor; not used in PREDICT mode.
    mode: A tf.estimator.ModeKeys value.

  Returns:
    A tf.estimator.EstimatorSpec. In PREDICT mode it carries only the
    predictions; otherwise it also carries the loss, the accuracy metric and,
    in TRAIN mode, the training op.
  """
  logits = mnist_model(features, mode)

  predicted_classes = tf.argmax(input=logits, axis=1)
  probabilities = tf.nn.softmax(logits, name='softmax_tensor')
  predictions = {
      'classes': predicted_classes,
      'probabilities': probabilities,
  }

  # Prediction needs neither labels nor a loss, so return early.
  if mode == tf.estimator.ModeKeys.PREDICT:
    return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)

  loss = tf.losses.softmax_cross_entropy(onehot_labels=labels, logits=logits)

  # Only TRAIN mode gets a training op; EVAL leaves it as None.
  train_op = None
  if mode == tf.estimator.ModeKeys.TRAIN:
    optimizer = tf.train.AdamOptimizer(learning_rate=1e-4)
    global_step = tf.train.get_or_create_global_step()
    train_op = optimizer.minimize(loss, global_step)

  # Labels are one-hot, so convert them back to class indices for the metric.
  true_classes = tf.argmax(labels, axis=1)
  accuracy = tf.metrics.accuracy(true_classes, predictions['classes'])

  # Expose the metric's update op under the name 'train_accuracy' so it can
  # be looked up by name for logging, and record it as a summary.
  accuracy_update = accuracy[1]
  tf.identity(accuracy_update, name='train_accuracy')
  tf.summary.scalar('train_accuracy', accuracy_update)

  return tf.estimator.EstimatorSpec(
      mode=mode,
      predictions=predictions,
      loss=loss,
      train_op=train_op,
      eval_metric_ops={'accuracy': accuracy})


def main(unused_argv):
  """Trains the MNIST estimator, then evaluates it and prints the results."""

  def train_input_fn():
    return input_fn(tf.estimator.ModeKeys.TRAIN, FLAGS.batch_size)

  def eval_input_fn():
    # Evaluation relies on input_fn's default batch size.
    return input_fn(tf.estimator.ModeKeys.EVAL)

  mnist_classifier = tf.estimator.Estimator(
      model_fn=mnist_model_fn, model_dir=FLAGS.model_dir)

  # Log the tensor named 'train_accuracy' every 100 iterations of training.
  logging_hook = tf.train.LoggingTensorHook(
      tensors={'train_accuracy': 'train_accuracy'}, every_n_iter=100)

  mnist_classifier.train(
      input_fn=train_input_fn,
      steps=FLAGS.steps,
      hooks=[logging_hook])

  eval_results = mnist_classifier.evaluate(input_fn=eval_input_fn)
  print()
  print('Evaluation results:\n %s' % eval_results)


if __name__ == '__main__':
  import sys  # Only the script entry point needs the raw argv.

  tf.logging.set_verbosity(tf.logging.INFO)
  # argparse consumes only the flags declared above; whatever is left over is
  # passed to tf.app.run explicitly. Without an explicit argv, tf.app.run
  # parses the full sys.argv itself, and TensorFlow's own flag parser can
  # reject this script's flags as unknown.
  FLAGS, unparsed = parser.parse_known_args()
  tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
Loading

0 comments on commit 2c5c3f3

Please sign in to comment.