Skip to content

Commit

Permalink
Glint everything (tensorflow#3654)
Browse files Browse the repository at this point in the history
* Glint everything

* Adding rcfile and pylinting

* Extra newline

* Few last lints
  • Loading branch information
karmel authored Mar 20, 2018
1 parent adfd5a3 commit 7cfb6bb
Show file tree
Hide file tree
Showing 27 changed files with 382 additions and 162 deletions.
14 changes: 0 additions & 14 deletions official/__init__.py
Original file line number Diff line number Diff line change
@@ -1,14 +0,0 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
8 changes: 5 additions & 3 deletions official/mnist/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@
from __future__ import division
from __future__ import print_function

import gzip
import os
import shutil
import gzip

import numpy as np
from six.moves import urllib
Expand All @@ -36,7 +36,7 @@ def check_image_file_header(filename):
"""Validate that filename corresponds to images for the MNIST dataset."""
with tf.gfile.Open(filename, 'rb') as f:
magic = read32(f)
num_images = read32(f)
read32(f) # num_images, unused
rows = read32(f)
cols = read32(f)
if magic != 2051:
Expand All @@ -52,7 +52,7 @@ def check_labels_file_header(filename):
"""Validate that filename corresponds to labels for the MNIST dataset."""
with tf.gfile.Open(filename, 'rb') as f:
magic = read32(f)
num_items = read32(f)
read32(f) # num_items, unused
if magic != 2049:
raise ValueError('Invalid magic number %d in MNIST file %s' % (magic,
f.name))
Expand All @@ -77,6 +77,8 @@ def download(directory, filename):


def dataset(directory, images_file, labels_file):
"""Download and parse MNIST dataset."""

images_file = download(directory, images_file)
labels_file = download(directory, labels_file)

Expand Down
33 changes: 21 additions & 12 deletions official/mnist/mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,15 @@
import argparse
import sys

import tensorflow as tf
import tensorflow as tf # pylint: disable=g-bad-import-order

from official.mnist import dataset
from official.utils.arg_parsers import parsers
from official.utils.logging import hooks_helper

LEARNING_RATE = 1e-4


class Model(tf.keras.Model):
"""Model to recognize digits in the MNIST dataset.
Expand Down Expand Up @@ -145,31 +146,36 @@ def model_fn(features, labels, mode, params):


def validate_batch_size_for_multi_gpu(batch_size):
"""For multi-gpu, batch-size must be a multiple of the number of
available GPUs.
"""For multi-gpu, batch-size must be a multiple of the number of GPUs.
Note that this should eventually be handled by replicate_model_fn
directly. Multi-GPU support is currently experimental, however,
so doing the work here until that feature is in place.
Args:
batch_size: the number of examples processed in each training batch.
Raises:
ValueError: if no GPUs are found, or selected batch_size is invalid.
"""
from tensorflow.python.client import device_lib
from tensorflow.python.client import device_lib # pylint: disable=g-import-not-at-top

local_device_protos = device_lib.list_local_devices()
num_gpus = sum([1 for d in local_device_protos if d.device_type == 'GPU'])
if not num_gpus:
raise ValueError('Multi-GPU mode was specified, but no GPUs '
'were found. To use CPU, run without --multi_gpu.')
'were found. To use CPU, run without --multi_gpu.')

remainder = batch_size % num_gpus
if remainder:
err = ('When running with multiple GPUs, batch size '
'must be a multiple of the number of available GPUs. '
'Found {} GPUs with a batch size of {}; try --batch_size={} instead.'
).format(num_gpus, batch_size, batch_size - remainder)
'must be a multiple of the number of available GPUs. '
'Found {} GPUs with a batch size of {}; try --batch_size={} instead.'
).format(num_gpus, batch_size, batch_size - remainder)
raise ValueError(err)


def main(unused_argv):
def main(_):
model_function = model_fn

if FLAGS.multi_gpu:
Expand All @@ -195,6 +201,8 @@ def main(unused_argv):

# Set up training and evaluation input functions.
def train_input_fn():
"""Prepare data for training."""

# When choosing shuffle buffer sizes, larger sizes result in better
# randomness, while smaller sizes use less memory. MNIST is a small
# enough dataset that we can easily shuffle the full epoch.
Expand All @@ -215,7 +223,7 @@ def eval_input_fn():
FLAGS.hooks, batch_size=FLAGS.batch_size)

# Train and evaluate model.
for n in range(FLAGS.train_epochs // FLAGS.epochs_between_evals):
for _ in range(FLAGS.train_epochs // FLAGS.epochs_between_evals):
mnist_classifier.train(input_fn=train_input_fn, hooks=train_hooks)
eval_results = mnist_classifier.evaluate(input_fn=eval_input_fn)
print('\nEvaluation results:\n\t%s\n' % eval_results)
Expand All @@ -231,10 +239,11 @@ def eval_input_fn():

class MNISTArgParser(argparse.ArgumentParser):
"""Argument parser for running MNIST model."""

def __init__(self):
super(MNISTArgParser, self).__init__(parents=[
parsers.BaseParser(),
parsers.ImageModelParser()])
parsers.BaseParser(),
parsers.ImageModelParser()])

self.add_argument(
'--export_dir',
Expand Down
19 changes: 10 additions & 9 deletions official/mnist/mnist_eager.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,11 @@
import sys
import time

import tensorflow as tf
import tensorflow.contrib.eager as tfe
import tensorflow as tf # pylint: disable=g-bad-import-order
import tensorflow.contrib.eager as tfe # pylint: disable=g-bad-import-order

from official.mnist import dataset as mnist_dataset
from official.mnist import mnist
from official.mnist import dataset
from official.utils.arg_parsers import parsers

FLAGS = None
Expand Down Expand Up @@ -110,9 +110,9 @@ def main(_):
print('Using device %s, and data format %s.' % (device, data_format))

# Load the datasets
train_ds = dataset.train(FLAGS.data_dir).shuffle(60000).batch(
train_ds = mnist_dataset.train(FLAGS.data_dir).shuffle(60000).batch(
FLAGS.batch_size)
test_ds = dataset.test(FLAGS.data_dir).batch(FLAGS.batch_size)
test_ds = mnist_dataset.test(FLAGS.data_dir).batch(FLAGS.batch_size)

# Create the model and optimizer
model = mnist.Model(data_format)
Expand Down Expand Up @@ -159,12 +159,13 @@ def main(_):


class MNISTEagerArgParser(argparse.ArgumentParser):
"""Argument parser for running MNIST model with eager trainng loop."""
"""Argument parser for running MNIST model with eager training loop."""

def __init__(self):
super(MNISTEagerArgParser, self).__init__(parents=[
parsers.BaseParser(epochs_between_evals=False, multi_gpu=False,
hooks=False),
parsers.ImageModelParser()])
parsers.BaseParser(
epochs_between_evals=False, multi_gpu=False, hooks=False),
parsers.ImageModelParser()])

self.add_argument(
'--log_interval', '-li',
Expand Down
5 changes: 3 additions & 2 deletions official/mnist/mnist_eager_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@
from __future__ import division
from __future__ import print_function

import tensorflow as tf
import tensorflow.contrib.eager as tfe
import tensorflow as tf # pylint: disable=g-bad-import-order
import tensorflow.contrib.eager as tfe # pylint: disable=g-bad-import-order

from official.mnist import mnist
from official.mnist import mnist_eager
Expand Down Expand Up @@ -60,6 +60,7 @@ def evaluate(defun=False):


class MNISTTest(tf.test.TestCase):
"""Run tests for MNIST eager loop."""

def test_train(self):
train(defun=False)
Expand Down
7 changes: 5 additions & 2 deletions official/mnist/mnist_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,10 @@
from __future__ import division
from __future__ import print_function

import tensorflow as tf
import time

import tensorflow as tf # pylint: disable=g-bad-import-order

from official.mnist import mnist

BATCH_SIZE = 100
Expand All @@ -42,6 +43,7 @@ def make_estimator():


class Tests(tf.test.TestCase):
"""Run tests for MNIST model."""

def test_mnist(self):
classifier = make_estimator()
Expand All @@ -57,7 +59,7 @@ def test_mnist(self):

input_fn = lambda: tf.random_uniform([3, 784])
predictions_generator = classifier.predict(input_fn)
for i in range(3):
for _ in range(3):
predictions = next(predictions_generator)
self.assertEqual(predictions['probabilities'].shape, (10,))
self.assertEqual(predictions['classes'].shape, ())
Expand Down Expand Up @@ -103,6 +105,7 @@ def test_mnist_model_fn_predict_mode(self):


class Benchmarks(tf.test.Benchmark):
"""Simple speed benchmarking for MNIST."""

def benchmark_train_step_time(self):
classifier = make_estimator()
Expand Down
5 changes: 3 additions & 2 deletions official/mnist/mnist_tpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@
from __future__ import division
from __future__ import print_function

import tensorflow as tf
import tensorflow as tf # pylint: disable=g-bad-import-order

from official.mnist import dataset
from official.mnist import mnist

Expand Down Expand Up @@ -132,7 +133,7 @@ def main(argv):
tf.logging.set_verbosity(tf.logging.INFO)

tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
FLAGS.tpu, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)
FLAGS.tpu, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)

run_config = tf.contrib.tpu.RunConfig(
cluster=tpu_cluster_resolver,
Expand Down
2 changes: 1 addition & 1 deletion official/resnet/cifar10_download_and_extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
help='Directory to download data and extract the tarball')


def main(unused_argv):
def main(_):
"""Download and extract the tarball from Alex's website."""
if not os.path.exists(FLAGS.data_dir):
os.makedirs(FLAGS.data_dir)
Expand Down
20 changes: 13 additions & 7 deletions official/resnet/cifar10_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import os
import sys

import tensorflow as tf
import tensorflow as tf # pylint: disable=g-bad-import-order

from official.resnet import resnet_model
from official.resnet import resnet_run_loop
Expand Down Expand Up @@ -127,22 +127,25 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1,

num_images = is_training and _NUM_IMAGES['train'] or _NUM_IMAGES['validation']

return resnet_run_loop.process_record_dataset(dataset, is_training, batch_size,
_NUM_IMAGES['train'], parse_record, num_epochs, num_parallel_calls,
examples_per_epoch=num_images, multi_gpu=multi_gpu)
return resnet_run_loop.process_record_dataset(
dataset, is_training, batch_size, _NUM_IMAGES['train'],
parse_record, num_epochs, num_parallel_calls,
examples_per_epoch=num_images, multi_gpu=multi_gpu)


def get_synth_input_fn():
return resnet_run_loop.get_synth_input_fn(_HEIGHT, _WIDTH, _NUM_CHANNELS, _NUM_CLASSES)
return resnet_run_loop.get_synth_input_fn(
_HEIGHT, _WIDTH, _NUM_CHANNELS, _NUM_CLASSES)


###############################################################################
# Running the model
###############################################################################
class Cifar10Model(resnet_model.Model):
"""Model class with appropriate defaults for CIFAR-10 data."""

def __init__(self, resnet_size, data_format=None, num_classes=_NUM_CLASSES,
version=resnet_model.DEFAULT_VERSION):
version=resnet_model.DEFAULT_VERSION):
"""These are the parameters that work for CIFAR-10 data.
Args:
Expand All @@ -153,6 +156,9 @@ def __init__(self, resnet_size, data_format=None, num_classes=_NUM_CLASSES,
enables users to extend the same model to their own datasets.
version: Integer representing which version of the ResNet network to use.
See README for details. Valid values: [1, 2]
Raises:
ValueError: if invalid resnet_size is chosen
"""
if resnet_size % 6 != 2:
raise ValueError('resnet_size must be 6n + 2:', resnet_size)
Expand Down Expand Up @@ -195,7 +201,7 @@ def cifar10_model_fn(features, labels, mode, params):
# for the CIFAR-10 dataset, perhaps because the regularization prevents
# overfitting on the small data set. We therefore include all vars when
# regularizing and computing loss during training.
def loss_filter_fn(name):
def loss_filter_fn(_):
return True

return resnet_run_loop.resnet_model_fn(features, labels, mode, Cifar10Model,
Expand Down
14 changes: 9 additions & 5 deletions official/resnet/cifar10_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from tempfile import mkstemp

import numpy as np
import tensorflow as tf
import tensorflow as tf # pylint: disable=g-bad-import-order

from official.resnet import cifar10_main
from official.utils.testing import integration
Expand All @@ -34,6 +34,8 @@


class BaseTest(tf.test.TestCase):
"""Tests for the Cifar10 version of Resnet.
"""

def tearDown(self):
super(BaseTest, self).tearDown()
Expand All @@ -52,7 +54,7 @@ def test_dataset_input_fn(self):
data_file.close()

fake_dataset = tf.data.FixedLengthRecordDataset(
filename, cifar10_main._RECORD_BYTES)
filename, cifar10_main._RECORD_BYTES) # pylint: disable=protected-access
fake_dataset = fake_dataset.map(
lambda val: cifar10_main.parse_record(val, False))
image, label = fake_dataset.make_one_shot_iterator().get_next()
Expand Down Expand Up @@ -133,9 +135,11 @@ def test_cifar10model_shape(self):
num_classes = 246

for version in (1, 2):
model = cifar10_main.Cifar10Model(32, data_format='channels_last',
num_classes=num_classes, version=version)
fake_input = tf.random_uniform([batch_size, _HEIGHT, _WIDTH, _NUM_CHANNELS])
model = cifar10_main.Cifar10Model(
32, data_format='channels_last', num_classes=num_classes,
version=version)
fake_input = tf.random_uniform(
[batch_size, _HEIGHT, _WIDTH, _NUM_CHANNELS])
output = model(fake_input, training=True)

self.assertAllEqual(output.shape, (batch_size, num_classes))
Expand Down
Loading

0 comments on commit 7cfb6bb

Please sign in to comment.