Fix for tensorflow datasets len() bug (#114)
evsmithx authored Feb 18, 2021
commit d3a70be (1 parent: cb28e0f)
Showing 6 changed files with 21 additions and 18 deletions.
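Why len() failed: depending on the TensorFlow version, calling len() on a tf.data.Dataset is either not implemented at all or raises a TypeError when the number of elements is not statically known. The fix applied in every file below is the same: request the dataset metadata with with_info=True and read the example count from it. A minimal sketch of the pattern (illustrative only, not part of the commit; the exact error text varies by TensorFlow version):

    import tensorflow_datasets as tfds

    ds = tfds.load('mnist', split='train+test', as_supervised=True)
    # len(ds)  # TypeError on affected TF versions: length unknown/unsupported

    # Fixed pattern: take the count from the TFDS metadata instead.
    ds, info = tfds.load('mnist', split='train+test',
                         as_supervised=True, with_info=True)
    n_datapoints = info.splits['train+test'].num_examples  # 70000 for MNIST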
colearn_keras/keras_cifar10.py (4 changes: 2 additions & 2 deletions)

@@ -190,8 +190,8 @@ def split_to_folders(
     data_split = [1 / n_learners] * n_learners

     # Load CIFAR10 from tfds
-    train_dataset = tfds.load('cifar10', split='train+test', as_supervised=True)
-    n_datapoints = len(train_dataset)
+    train_dataset, info = tfds.load('cifar10', split='train+test', as_supervised=True, with_info=True)
+    n_datapoints = info.splits['train+test'].num_examples
     train_dataset = train_dataset.map(normalize_img).batch(n_datapoints)

     # there is only one batch in the iterator, and this contains all the data
colearn_keras/keras_mnist.py (4 changes: 2 additions & 2 deletions)

@@ -192,8 +192,8 @@ def split_to_folders(
     data_split = [1 / n_learners] * n_learners

     # Load MNIST from tfds
-    train_dataset = tfds.load('mnist', split='train+test', as_supervised=True)
-    n_datapoints = len(train_dataset)
+    train_dataset, info = tfds.load('mnist', split='train+test', as_supervised=True, with_info=True)
+    n_datapoints = info.splits['train+test'].num_examples
     train_dataset = train_dataset.map(normalize_img).batch(n_datapoints)

     # there is only one batch in the iterator, and this contains all the data
docs/python_src/mnist_keras.py (6 changes: 3 additions & 3 deletions)

@@ -28,17 +28,17 @@
 batch_size = 64

 # Load the data
-train_dataset = tfds.load('mnist', split='train', as_supervised=True)
+train_dataset, info = tfds.load('mnist', split='train', as_supervised=True, with_info=True)
+n_train = info.splits['train'].num_examples
 test_dataset = tfds.load('mnist', split='test', as_supervised=True)

 train_dataset = train_dataset.map(normalize_img,
                                   num_parallel_calls=tf.data.experimental.AUTOTUNE)
-train_dataset = train_dataset.shuffle(len(train_dataset))
+train_dataset = train_dataset.shuffle(n_train)
 train_dataset = train_dataset.batch(batch_size)

 test_dataset = test_dataset.map(normalize_img,
                                 num_parallel_calls=tf.data.experimental.AUTOTUNE)
-test_dataset = test_dataset.shuffle(len(test_dataset))
 test_dataset = test_dataset.batch(batch_size)

 # Define the model
examples/keras_cifar.py (9 changes: 5 additions & 4 deletions)

@@ -55,9 +55,10 @@
 loss = "sparse_categorical_crossentropy"
 vote_batches = 2

-train_datasets = tfds.load('cifar10',
-                           split=tfds.even_splits('train', n=n_learners),
-                           as_supervised=True)
+train_datasets, info = tfds.load('cifar10',
+                                 split=tfds.even_splits('train', n=n_learners),
+                                 as_supervised=True, with_info=True)
+n_datapoints = info.splits['train'].num_examples

 test_datasets = tfds.load('cifar10',
                           split=tfds.even_splits('test', n=n_learners),

@@ -67,7 +68,7 @@
     ds_train = train_datasets[i].map(
         normalize_img, num_parallel_calls=tf.data.experimental.AUTOTUNE)
     ds_train = ds_train.cache()
-    ds_train = ds_train.shuffle(len(ds_train))
+    ds_train = ds_train.shuffle(n_datapoints // n_learners)
     ds_train = ds_train.batch(batch_size)
     train_datasets[i] = ds_train.prefetch(tf.data.experimental.AUTOTUNE)
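A note on the new buffer size: tfds.even_splits partitions a split into n_learners equal slices, so each learner's dataset holds about n_datapoints // n_learners examples; passing that to shuffle() preserves the old full-shard shuffle without ever calling len(). A hedged sketch (the exact split-string format returned by even_splits depends on the TFDS version):

    import tensorflow_datasets as tfds

    n_learners = 5
    splits = tfds.even_splits('train', n=n_learners)  # e.g. ['train[0%:20%]', ...]
    train_datasets, info = tfds.load('cifar10', split=splits,
                                     as_supervised=True, with_info=True)
    buffer_size = info.splits['train'].num_examples // n_learners  # 50000 // 5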
examples/keras_mnist.py (7 changes: 4 additions & 3 deletions)

@@ -52,7 +52,9 @@
 batch_size = 64

 # Load data for each learner
-train_dataset = tfds.load('mnist', split='train', as_supervised=True)
+train_dataset, info = tfds.load('mnist', split='train', as_supervised=True, with_info=True)
+n_datapoints = info.splits['train'].num_examples
+
 train_datasets = [train_dataset.shard(num_shards=n_learners, index=i) for i in range(n_learners)]

 test_dataset = tfds.load('mnist', split='test', as_supervised=True)

@@ -62,12 +64,11 @@
 for i in range(n_learners):
     train_datasets[i] = train_datasets[i].map(
         normalize_img, num_parallel_calls=tf.data.experimental.AUTOTUNE)
-    train_datasets[i] = train_datasets[i].shuffle(len(train_datasets[i]))
+    train_datasets[i] = train_datasets[i].shuffle(n_datapoints // n_learners)
     train_datasets[i] = train_datasets[i].batch(batch_size)

     test_datasets[i] = test_datasets[i].map(
         normalize_img, num_parallel_calls=tf.data.experimental.AUTOTUNE)
-    test_datasets[i] = test_datasets[i].shuffle(len(test_datasets[i]))
     test_datasets[i] = test_datasets[i].batch(batch_size)
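This file shards with Dataset.shard rather than even_splits, but the arithmetic is the same: shard(num_shards, index) keeps every num_shards-th element, so each learner again sees roughly n_datapoints // n_learners examples. A small sketch (names follow the diff; counts assume MNIST's 60,000 training examples):

    n_learners = 5
    # Shard i keeps elements whose position % n_learners == i,
    # i.e. about 60000 // 5 = 12000 examples per learner.
    train_datasets = [train_dataset.shard(num_shards=n_learners, index=i)
                      for i in range(n_learners)]
    buffer_size = n_datapoints // n_learners  # used as the shuffle buffer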
examples/keras_mnist_diffpriv.py (9 changes: 5 additions & 4 deletions)

@@ -46,9 +46,10 @@
 noise_multiplier = 1.3  # more noise -> more privacy, less utility
 num_microbatches = 64  # how many batches to split a batch into

-train_datasets = tfds.load('mnist',
-                           split=tfds.even_splits('train', n=n_learners),
-                           as_supervised=True)
+train_datasets, info = tfds.load('mnist',
+                                 split=tfds.even_splits('train', n=n_learners),
+                                 as_supervised=True, with_info=True)
+n_datapoints = info.splits['train'].num_examples

 test_datasets = tfds.load('mnist',
                           split=tfds.even_splits('test', n=n_learners),

@@ -58,7 +59,7 @@
     ds_train = train_datasets[i].map(
         normalize_img, num_parallel_calls=tf.data.experimental.AUTOTUNE)
     ds_train = ds_train.cache()
-    ds_train = ds_train.shuffle(len(ds_train))
+    ds_train = ds_train.shuffle(n_datapoints // n_learners)
     ds_train = ds_train.batch(batch_size)
     train_datasets[i] = ds_train.prefetch(tf.data.experimental.AUTOTUNE)
