Skip to content

Commit

Permalink
TF cifar10 cnn tutorial: use tensorflow-datasets to load the data. (t…
Browse files Browse the repository at this point in the history
…ensorflow#5906)

* TF cifar10 cnn tutorial: use tensorflow-datasets to load the data.

* Load cifar10 in memory.

* Fix imports

* More import fixes
  • Loading branch information
pierrot0 authored and tfboyd committed Apr 22, 2019
1 parent d11aa33 commit 56b5d03
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 256 deletions.
50 changes: 3 additions & 47 deletions tutorials/image/cifar10/cifar10.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,8 @@
from __future__ import division
from __future__ import print_function

import os
import re
import sys
import tarfile

from six.moves import urllib
import tensorflow as tf

import cifar10_input
Expand All @@ -50,8 +46,6 @@
# Basic model parameters.
tf.app.flags.DEFINE_integer('batch_size', 128,
"""Number of images to process in a batch.""")
tf.app.flags.DEFINE_string('data_dir', '/tmp/cifar10_data',
"""Path to the CIFAR-10 data directory.""")
tf.app.flags.DEFINE_boolean('use_fp16', False,
"""Train the model using fp16.""")

Expand All @@ -73,8 +67,6 @@
# names of the summaries when visualizing a model.
TOWER_NAME = 'tower'

DATA_URL = 'https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz'


def _activation_summary(x):
"""Helper to create summaries for activations.
Expand All @@ -91,8 +83,7 @@ def _activation_summary(x):
# session. This helps the clarity of presentation on tensorboard.
tensor_name = re.sub('%s_[0-9]*/' % TOWER_NAME, '', x.op.name)
tf.summary.histogram(tensor_name + '/activations', x)
tf.summary.scalar(tensor_name + '/sparsity',
tf.nn.zero_fraction(x))
tf.summary.scalar(tensor_name + '/sparsity', tf.nn.zero_fraction(x))


def _variable_on_cpu(name, shape, initializer):
Expand Down Expand Up @@ -145,15 +136,8 @@ def distorted_inputs():
Returns:
images: Images. 4D tensor of [batch_size, IMAGE_SIZE, IMAGE_SIZE, 3] size.
labels: Labels. 1D tensor of [batch_size] size.
Raises:
ValueError: If no data_dir
"""
if not FLAGS.data_dir:
raise ValueError('Please supply a data_dir')
data_dir = os.path.join(FLAGS.data_dir, 'cifar-10-batches-bin')
images, labels = cifar10_input.distorted_inputs(data_dir=data_dir,
batch_size=FLAGS.batch_size)
images, labels = cifar10_input.distorted_inputs(batch_size=FLAGS.batch_size)
if FLAGS.use_fp16:
images = tf.cast(images, tf.float16)
labels = tf.cast(labels, tf.float16)
Expand All @@ -169,15 +153,8 @@ def inputs(eval_data):
Returns:
images: Images. 4D tensor of [batch_size, IMAGE_SIZE, IMAGE_SIZE, 3] size.
labels: Labels. 1D tensor of [batch_size] size.
Raises:
ValueError: If no data_dir
"""
if not FLAGS.data_dir:
raise ValueError('Please supply a data_dir')
data_dir = os.path.join(FLAGS.data_dir, 'cifar-10-batches-bin')
images, labels = cifar10_input.inputs(eval_data=eval_data,
data_dir=data_dir,
batch_size=FLAGS.batch_size)
if FLAGS.use_fp16:
images = tf.cast(images, tf.float16)
Expand Down Expand Up @@ -240,7 +217,7 @@ def inference(images):
# local3
with tf.variable_scope('local3') as scope:
# Move everything into depth so we can perform a single matrix multiply.
reshape = tf.reshape(pool2, [images.get_shape().as_list()[0], -1])
reshape = tf.keras.layers.Flatten()(pool2)
dim = reshape.get_shape()[1].value
weights = _variable_with_weight_decay('weights', shape=[dim, 384],
stddev=0.04, wd=0.004)
Expand Down Expand Up @@ -374,24 +351,3 @@ def train(total_loss, global_step):
variables_averages_op = variable_averages.apply(tf.trainable_variables())

return variables_averages_op


def maybe_download_and_extract():
"""Download and extract the tarball from Alex's website."""
dest_directory = FLAGS.data_dir
if not os.path.exists(dest_directory):
os.makedirs(dest_directory)
filename = DATA_URL.split('/')[-1]
filepath = os.path.join(dest_directory, filename)
if not os.path.exists(filepath):
def _progress(count, block_size, total_size):
sys.stdout.write('\r>> Downloading %s %.1f%%' % (filename,
float(count * block_size) / float(total_size) * 100.0))
sys.stdout.flush()
filepath, _ = urllib.request.urlretrieve(DATA_URL, filepath, _progress)
print()
statinfo = os.stat(filepath)
print('Successfully downloaded', filename, statinfo.st_size, 'bytes.')
extracted_dir_path = os.path.join(dest_directory, 'cifar-10-batches-bin')
if not os.path.exists(extracted_dir_path):
tarfile.open(filepath, 'r:gz').extractall(dest_directory)
9 changes: 4 additions & 5 deletions tutorials/image/cifar10/cifar10_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,12 +51,12 @@
"""Either 'test' or 'train_eval'.""")
tf.app.flags.DEFINE_string('checkpoint_dir', '/tmp/cifar10_train',
"""Directory where to read model checkpoints.""")
tf.app.flags.DEFINE_integer('eval_interval_secs', 60 * 5,
tf.app.flags.DEFINE_integer('eval_interval_secs', 5,
"""How often to run the eval.""")
tf.app.flags.DEFINE_integer('num_examples', 10000,
tf.app.flags.DEFINE_integer('num_examples', 1000,
"""Number of examples to run.""")
tf.app.flags.DEFINE_boolean('run_once', False,
"""Whether to run eval only once.""")
"""Whether to run eval only once.""")


def eval_once(saver, summary_writer, top_k_op, summary_op):
Expand Down Expand Up @@ -89,7 +89,7 @@ def eval_once(saver, summary_writer, top_k_op, summary_op):
threads.extend(qr.create_threads(sess, coord=coord, daemon=True,
start=True))

num_iter = int(math.ceil(FLAGS.num_examples / FLAGS.batch_size))
num_iter = int(math.ceil(float(FLAGS.num_examples) / FLAGS.batch_size))
true_count = 0 # Counts the number of correct predictions.
total_sample_count = num_iter * FLAGS.batch_size
step = 0
Expand Down Expand Up @@ -146,7 +146,6 @@ def evaluate():


def main(argv=None): # pylint: disable=unused-argument
cifar10.maybe_download_and_extract()
if tf.gfile.Exists(FLAGS.eval_dir):
tf.gfile.DeleteRecursively(FLAGS.eval_dir)
tf.gfile.MakeDirs(FLAGS.eval_dir)
Expand Down
Loading

0 comments on commit 56b5d03

Please sign in to comment.