Skip to content

Commit

Permalink
Models support rebatching inputs with backprop encryption disabled.
Browse files (browse the repository at this point in the history)
  • Loading branch information
james-choncholas committed Jan 14, 2025
1 parent 960cc51 commit 5d4d5b5
Show file tree
Hide file tree
Showing 8 changed files with 18 additions and 25 deletions.
6 changes: 0 additions & 6 deletions tf_shell_ml/model_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -900,12 +900,6 @@ def shell_train_step(self, features, labels, read_key_from_cache, apply_gradient
noise_context.num_slots,
message="Backprop and noise contexts must have the same number of slots.",
)
else:
tf.assert_equal(
tf.identity(noise_context.num_slots),
tf.shape(features, out_type=tf.int64)[0],
message="Noise context must have the same number of slots as the batch size.",
)

# The noise scaling factor must always be 1. Encryptions already
# have the scaling factor applied when the noise is applied,
Expand Down
13 changes: 6 additions & 7 deletions tf_shell_ml/test/dpsgd_conv_model_local_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,19 +38,19 @@ def _test_model(self, disable_encryption, disable_masking, disable_noise, cache)
x_test = x_test[:, clip_by : (28 - clip_by), clip_by : (28 - clip_by), :]

labels_dataset = tf.data.Dataset.from_tensor_slices(y_train)
labels_dataset = labels_dataset.batch(2**12)
labels_dataset = labels_dataset.batch(2**10)

features_dataset = tf.data.Dataset.from_tensor_slices(x_train)
features_dataset = features_dataset.batch(2**12)
features_dataset = features_dataset.batch(2**10)

val_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test))
val_dataset = val_dataset.batch(32)

m = tf_shell_ml.DpSgdSequential(
[
# Model from tensorflow-privacy tutorial. The first layer may
# be skipped and the model still has ~95% accuracy (plaintext,
# no input clipping).
# Model from tensorflow-privacy tutorial. The first conv and
# maxpool layer may be skipped and the model still has ~95%
# accuracy (plaintext, no input image clipping).
# tf_shell_ml.Conv2D(
# filters=16,
# kernel_size=8,
Expand Down Expand Up @@ -124,8 +124,7 @@ def _test_model(self, disable_encryption, disable_masking, disable_noise, cache)
validation_data=val_dataset,
)

print(history.history["val_categorical_accuracy"][-1])
self.assertGreater(history.history["val_categorical_accuracy"][-1], 0.13)
self.assertGreater(history.history["val_categorical_accuracy"][-1], 0.20)

def test_model(self):
with tempfile.TemporaryDirectory() as cache_dir:
Expand Down
4 changes: 2 additions & 2 deletions tf_shell_ml/test/dpsgd_model_distrib_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,11 +74,11 @@ def test_model(self):

with tf.device(labels_party_dev):
labels_dataset = tf.data.Dataset.from_tensor_slices(y_train)
labels_dataset = labels_dataset.batch(2**12)
labels_dataset = labels_dataset.batch(2**10)

with tf.device(features_party_dev):
features_dataset = tf.data.Dataset.from_tensor_slices(x_train)
features_dataset = features_dataset.batch(2**12)
features_dataset = features_dataset.batch(2**10)

val_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test))
val_dataset = val_dataset.batch(32)
Expand Down
4 changes: 2 additions & 2 deletions tf_shell_ml/test/dpsgd_model_float64_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,10 @@ def _test_model(self, disable_encryption, disable_masking, disable_noise, cache)
x_train, x_test = x_train[:, :350], x_test[:, :350]

labels_dataset = tf.data.Dataset.from_tensor_slices(y_train)
labels_dataset = labels_dataset.batch(2**12)
labels_dataset = labels_dataset.batch(2**10)

features_dataset = tf.data.Dataset.from_tensor_slices(x_train)
features_dataset = features_dataset.batch(2**12)
features_dataset = features_dataset.batch(2**10)

val_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test))
val_dataset = val_dataset.batch(32)
Expand Down
4 changes: 2 additions & 2 deletions tf_shell_ml/test/dpsgd_model_local_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,10 @@ def _test_model(self, disable_encryption, disable_masking, disable_noise, cache)
x_train, x_test = x_train[:, :350], x_test[:, :350]

labels_dataset = tf.data.Dataset.from_tensor_slices(y_train)
labels_dataset = labels_dataset.batch(2**12)
labels_dataset = labels_dataset.batch(2**10)

features_dataset = tf.data.Dataset.from_tensor_slices(x_train)
features_dataset = features_dataset.batch(2**12)
features_dataset = features_dataset.batch(2**10)

val_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test))
val_dataset = val_dataset.batch(32)
Expand Down
4 changes: 2 additions & 2 deletions tf_shell_ml/test/postscale_model_distrib_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,11 +77,11 @@ def test_model(self):

with tf.device(labels_party_dev):
labels_dataset = tf.data.Dataset.from_tensor_slices(y_train)
labels_dataset = labels_dataset.batch(2**12)
labels_dataset = labels_dataset.batch(2**10)

with tf.device(features_party_dev):
features_dataset = tf.data.Dataset.from_tensor_slices(x_train)
features_dataset = features_dataset.batch(2**12)
features_dataset = features_dataset.batch(2**10)

val_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test))
val_dataset = val_dataset.batch(32)
Expand Down
4 changes: 2 additions & 2 deletions tf_shell_ml/test/postscale_model_float64_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,10 @@ def _test_model(self, disable_encryption, disable_masking, disable_noise, cache)
x_train, x_test = x_train[:, :350], x_test[:, :350]

labels_dataset = tf.data.Dataset.from_tensor_slices(y_train)
labels_dataset = labels_dataset.batch(2**12)
labels_dataset = labels_dataset.batch(2**10)

features_dataset = tf.data.Dataset.from_tensor_slices(x_train)
features_dataset = features_dataset.batch(2**12)
features_dataset = features_dataset.batch(2**10)

val_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test))
val_dataset = val_dataset.batch(32)
Expand Down
4 changes: 2 additions & 2 deletions tf_shell_ml/test/postscale_model_local_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,10 @@ def _test_model(self, disable_encryption, disable_masking, disable_noise, cache)
x_train, x_test = x_train[:, :350], x_test[:, :350]

labels_dataset = tf.data.Dataset.from_tensor_slices(y_train)
labels_dataset = labels_dataset.batch(2**12)
labels_dataset = labels_dataset.batch(2**10)

features_dataset = tf.data.Dataset.from_tensor_slices(x_train)
features_dataset = features_dataset.batch(2**12)
features_dataset = features_dataset.batch(2**10)

val_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test))
val_dataset = val_dataset.batch(32)
Expand Down

0 comments on commit 5d4d5b5

Please sign in to comment.