From 94451d143d432bb168b0799c3e36c782d12b039c Mon Sep 17 00:00:00 2001
From: james-choncholas
Date: Sat, 4 Jan 2025 06:15:12 +0000
Subject: [PATCH] Masking is performed in the integer domain to avoid leakages associated with floating point.

---
 tf_shell_ml/dpsgd_sequential_model.py        | 123 +++++++------------
 tf_shell_ml/model_base.py                    |  76 +++++++++++-
 tf_shell_ml/postscale_sequential_model.py    | 105 ++++++----------
 tf_shell_ml/test/dpsgd_model_distrib_test.py |  10 +-
 tf_shell_ml/test/dpsgd_model_float64_test.py |   6 +-
 tf_shell_ml/test/dpsgd_model_local_test.py   |   6 +-
 6 files changed, 167 insertions(+), 159 deletions(-)

diff --git a/tf_shell_ml/dpsgd_sequential_model.py b/tf_shell_ml/dpsgd_sequential_model.py
index 2074695..2e02c5b 100644
--- a/tf_shell_ml/dpsgd_sequential_model.py
+++ b/tf_shell_ml/dpsgd_sequential_model.py
@@ -14,9 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import tensorflow as tf
-import tensorflow.keras as keras
 import tf_shell
-import tf_shell_ml
 from tf_shell_ml.model_base import SequentialBase
 
 
@@ -106,8 +104,8 @@ def shell_train_step(self, features, labels):
         # Check if the backproped gradients overflowed.
         if not self.disable_encryption and self.check_overflow_INSECURE:
-            # Note, checking the backprop gradients requires decryption
-            # on the features party which breaks security of the protocol.
+            # Note, checking the gradients requires decryption on the
+            # features party which breaks security of the protocol.
             bp_scaling_factors = [g._scaling_factor for g in dJ_dw]
             dec_dJ_dw = [
                 tf_shell.to_tensorflow(g, backprop_secret_key) for g in dJ_dw
             ]
@@ -119,39 +117,39 @@ def shell_train_step(self, features, labels):
                 "WARNING: Backprop gradient may have overflowed.",
             )
 
-        # Mask the encrypted grads to prepare for decryption. The masks may
-        # overflow during the reduce_sum over the batch. When the masks are
-        # operated on, they are multiplied by the scaling factor, so it is
-        # not necessary to mask the full range -t/2 to t/2. (Though it is
-        # possible, it unnecessarily introduces noise into the ciphertext.)
-        if self.disable_masking or self.disable_encryption:
-            grads = [g for g in reversed(dJ_dw)]
-        else:
-            t = tf.cast(
-                backprop_context.plaintext_modulus, tf.keras.backend.floatx()
-            )
-            t_half = t // 2
-            mask_scaling_factors = [g._scaling_factor for g in reversed(dJ_dw)]
-            masks = [
-                tf.random.uniform(
-                    tf_shell.shape(g),
-                    dtype=tf.keras.backend.floatx(),
-                    minval=-t_half / s,
-                    maxval=t_half / s,
-                )
-                for g, s in zip(reversed(dJ_dw), mask_scaling_factors)
-            ]
+        # Reverse the order to match the order of the layers.
+        grads = [g for g in reversed(dJ_dw)]
 
-            # Mask the encrypted gradients and reverse the order to match
-            # the order of the layers.
-            grads = [(g + m) for g, m in zip(reversed(dJ_dw), masks)]
+        # Mask the encrypted gradients.
+        if not self.disable_masking and not self.disable_encryption:
+            grads, masks, mask_scaling_factors = self.mask_gradients(
+                backprop_context, grads
+            )
 
         if not self.disable_noise:
-            # Set up the features party size of the distributed noise
+            # Set up the features party side of the distributed noise
             # sampling sub-protocol.
             noise_context = self.noise_context_fn()
             noise_secret_key = tf_shell.create_key64(noise_context, self.cache_path)
 
+            # The noise context must have the same number of slots
+            # (encryption ring degree) as used in backpropagation.
+            if not self.disable_encryption:
+                tf.assert_equal(
+                    backprop_context.num_slots,
+                    noise_context.num_slots,
+                    message="Backprop and noise contexts must have the same number of slots.",
+                )
+
+            # The noise scaling factor must always be 1. The encrypted
+            # gradients already carry their scaling factor when the noise is
+            # added, so an additional noise scaling factor is not needed.
+            tf.assert_equal(
+                noise_context.scaling_factor,
+                1,
+                message="Noise scaling factor must be 1.",
+            )
+
             # Compute the noise factors for the distributed noise sampling
             # sub-protocol.
             enc_a, enc_b = self.compute_noise_factors(
@@ -164,8 +162,8 @@ def shell_train_step(self, features, labels):
             grads = [tf_shell.to_tensorflow(g, backprop_secret_key) for g in grads]
 
         # Sum the masked gradients over the batch.
-        if self.disable_encryption or self.disable_masking:
-            # No mask, only sum required.
+        if self.disable_masking or self.disable_encryption:
+            # No mask has been added so only a normal sum is required.
             grads = [tf.reduce_sum(g, axis=0) for g in grads]
         else:
             grads = [
@@ -174,23 +172,13 @@ def shell_train_step(self, features, labels):
             ]
 
         if not self.disable_noise:
-            if not self.disable_encryption:
-                tf.assert_equal(
-                    backprop_context.num_slots,
-                    noise_context.num_slots,
-                    message="Backprop and noise contexts must have the same number of slots.",
-                )
-
             # Efficiently pack the masked gradients to prepare for adding
             # the encrypted noise. This is special because the masked
             # gradients are no longer batched, so the packing must be done
             # manually.
-            (
-                flat_grads,
-                grad_shapes,
-                flattened_grad_shapes,
-                total_grad_size,
-            ) = self.flatten_and_pad_grad_list(grads, noise_context.num_slots)
+            (flat_grads, grad_shapes, flattened_grad_shapes, total_grad_size) = (
+                self.flatten_and_pad_grad_list(grads, noise_context.num_slots)
+            )
 
             # Add the encrypted noise to the masked gradients.
             grads = self.noise_gradients(noise_context, flat_grads, enc_a, enc_b)
@@ -201,15 +189,6 @@ def shell_train_step(self, features, labels):
             # secret key.
             flat_grads = tf_shell.to_tensorflow(grads, noise_secret_key)
 
-            if self.check_overflow_INSECURE:
-                nosie_scaling_factors = grads._scaling_factor
-                self.warn_on_overflow(
-                    [flat_grads],
-                    [nosie_scaling_factors],
-                    noise_context.plaintext_modulus,
-                    "WARNING: Noised gradient may have overflowed.",
-                )
-
             # Unpack the noised grads after decryption.
             grads = self.unflatten_and_unpad_grad(
                 flat_grads,
@@ -218,32 +197,20 @@ def shell_train_step(self, features, labels):
                 grad_shapes,
                 flattened_grad_shapes,
                 total_grad_size,
             )
 
+        # Unmask the gradients.
         if not self.disable_masking and not self.disable_encryption:
-            # Sum the masks over the batch.
-            sum_masks = [
-                tf_shell.reduce_sum_with_mod(m, 0, backprop_context, s)
-                for m, s in zip(masks, mask_scaling_factors)
-            ]
-
-            # Unmask the batch gradient.
-            grads = [mg - m for mg, m in zip(grads, sum_masks)]
-
-            # SHELL represents floats as integers between [0, t) where t is
-            # the plaintext modulus. To mimic SHELL's modulo operations in
-            # TensorFlow, numbers which exceed the range [-t/2, t/2] are
-            # shifted back into the range. Note, this could be done as a
-            # custom op with tf-shell, but this is simpler for now.
-            epsilon = tf.constant(1e-6, dtype=tf.keras.backend.floatx())
-
-            def rebalance(x, s):
-                r_bound = t_half / s + epsilon
-                l_bound = -t_half / s - epsilon
-                t_over_s = t / s
-                x = tf.where(x > r_bound, x - t_over_s, x)
-                x = tf.where(x < l_bound, x + t_over_s, x)
-                return x
+            grads = self.unmask_gradients(
+                backprop_context, grads, masks, mask_scaling_factors
+            )
 
-            grads = [rebalance(g, s) for g, s in zip(grads, mask_scaling_factors)]
+        if not self.disable_noise:
+            if self.check_overflow_INSECURE:
+                self.warn_on_overflow(
+                    [flat_grads],
+                    [1],
+                    noise_context.plaintext_modulus,
+                    "WARNING: Noised gradient may have overflowed.",
+                )
 
         # Apply the gradients to the model.
         self.optimizer.apply_gradients(zip(grads, self.weights))
diff --git a/tf_shell_ml/model_base.py b/tf_shell_ml/model_base.py
index 80f5937..e81d6c9 100644
--- a/tf_shell_ml/model_base.py
+++ b/tf_shell_ml/model_base.py
@@ -425,6 +425,73 @@ def unflatten_and_unpad_grad(
         grads_list = [tf.reshape(g, s) for g, s in zip(grads_list, grad_shapes)]
         return grads_list
 
+    def mask_gradients(self, context, grads):
+        """Adds random masks in [-t/2, t/2) to the encrypted gradients."""
+        if self.disable_masking or self.disable_encryption:
+            return grads, None, None
+
+        t = tf.cast(context.plaintext_modulus, dtype=tf.int64)
+        t_half = t // 2
+        mask_scaling_factors = [g._scaling_factor for g in grads]
+
+        masks = [
+            tf.random.uniform(
+                tf_shell.shape(g),
+                dtype=tf.int64,
+                minval=-t_half,
+                maxval=t_half,
+            )
+            for g in grads
+        ]
+
+        # Spoof the gradient's dtype to int64 and scaling factor to 1.
+        # This is necessary so that, when the masks are added to the
+        # gradients, tf_shell does not attempt any casting to floats or
+        # scaling factor matching. Conversion back to floating point and
+        # handling of the scaling factors is done in unmask_gradients.
+        int64_grads = [
+            tf_shell.ShellTensor64(
+                _raw_tensor=g._raw_tensor,
+                _context=g._context,
+                _level=g._level,
+                _num_mod_reductions=g._num_mod_reductions,
+                _underlying_dtype=tf.int64,  # Spoof the dtype.
+                _scaling_factor=1,  # Spoof the scaling factor.
+                _is_enc=g._is_enc,
+                _is_fast_rotated=g._is_fast_rotated,
+            )
+            for g in grads
+        ]
+
+        # Add the masks.
+        int64_grads = [(g + m) for g, m in zip(int64_grads, masks)]
+
+        return int64_grads, masks, mask_scaling_factors
+
+    def unmask_gradients(self, context, grads, masks, mask_scaling_factors):
+        """Subtracts the masks from the gradients."""
+
+        # Sum the masks over the batch.
+        sum_masks = [
+            tf_shell.reduce_sum_with_mod(m, 0, context, s)
+            for m, s in zip(masks, mask_scaling_factors)
+        ]
+
+        # Unmask the batch gradient.
+        masks_and_grads = [tf.stack([-m, g]) for m, g in zip(sum_masks, grads)]
+        unmasked_grads = [
+            tf_shell.reduce_sum_with_mod(mg, 0, context, s)
+            for mg, s in zip(masks_and_grads, mask_scaling_factors)
+        ]
+
+        # Recover the floating point values using the scaling factors.
+        unmasked_grads = [
+            tf.cast(g, tf.keras.backend.floatx()) / s
+            for g, s in zip(unmasked_grads, mask_scaling_factors)
+        ]
+
+        return unmasked_grads
+
     def compute_noise_factors(self, context, secret_key, sensitivity):
         bounded_sensitivity = tf.cast(sensitivity, dtype=tf.float32) * tf.sqrt(2.0)
 
@@ -437,7 +504,6 @@ def compute_noise_factors(self, context, secret_key, sensitivity):
         a, b = tf_shell.sample_centered_gaussian_f(bounded_sensitivity, self.dg_params)
 
         def _prep_noise_factor(x):
-            x = tf.cast(x, tf.keras.backend.floatx())
             x = tf.expand_dims(x, 0)
             x = tf.expand_dims(x, 0)
             x = tf.repeat(x, context.num_slots, axis=0)
@@ -476,11 +542,11 @@ def warn_on_overflow(self, grads, scaling_factors, plaintext_modulus, message):
         may have overflowed. This also must take the scaling factor into account
         so the range is divided by the scaling factor.
         """
-        t = tf.cast(plaintext_modulus, grads[0].dtype)
-        t_half = t / 2
+        # The modulus is kept in its original dtype; it is cast per-gradient below.
+        t_half = plaintext_modulus / 2
 
         over_by = [
-            tf.reduce_max(tf.abs(g) - t_half / 2 / s)
+            tf.reduce_max(tf.abs(g) - tf.cast(t_half / 2 / s, g.dtype))
             for g, s in zip(grads, scaling_factors)
         ]
         max_over_by = tf.reduce_max(over_by)
@@ -494,7 +560,7 @@ def warn_on_overflow(self, grads, scaling_factors, plaintext_modulus, message):
                 over_by,
                 "(positive number indicates overflow amount).",
                 "Values should be less than",
-                [t_half / 2 / s for s in scaling_factors],
+                [tf.cast(t_half / 2 / s, grads[0].dtype) for s in scaling_factors],
             ),
             lambda: tf.identity(overflowed),
         )
diff --git a/tf_shell_ml/postscale_sequential_model.py b/tf_shell_ml/postscale_sequential_model.py
index a65b412..fe554b6 100644
--- a/tf_shell_ml/postscale_sequential_model.py
+++ b/tf_shell_ml/postscale_sequential_model.py
@@ -14,7 +14,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import tensorflow as tf
-import tensorflow.keras as keras
 import tf_shell
 from tf_shell_ml.model_base import SequentialBase
 
@@ -117,8 +116,8 @@ def shell_train_step(self, features, labels):
 
         # Check if the post-scaled gradients overflowed.
         if not self.disable_encryption and self.check_overflow_INSECURE:
-            # Note, checking the backprop gradients requires decryption
-            # on the features party which breaks security of the protocol.
+            # Note, checking the gradients requires decryption on the
+            # features party which breaks security of the protocol.
             bp_scaling_factors = [g._scaling_factor for g in grads]
             dec_grads = [tf_shell.to_tensorflow(g, secret_key) for g in grads]
             self.warn_on_overflow(
@@ -128,29 +127,11 @@ def shell_train_step(self, features, labels):
                 "WARNING: Backprop gradient may have overflowed.",
             )
 
-        # Mask the encrypted grads to prepare for decryption. The masks may
-        # overflow during the reduce_sum over the batch. When the masks are
-        # operated on, they are multiplied by the scaling factor, so it is
-        # not necessary to mask the full range -t/2 to t/2. (Though it is
-        # possible, it unnecessarily introduces noise into the ciphertext.)
+        # Mask the encrypted gradients.
         if not self.disable_masking and not self.disable_encryption:
-            t = tf.cast(
-                backprop_context.plaintext_modulus, tf.keras.backend.floatx()
+            grads, masks, mask_scaling_factors = self.mask_gradients(
+                backprop_context, grads
             )
-            t_half = t // 2
-            mask_scaling_factors = [g._scaling_factor for g in grads]
-            masks = [
-                tf.random.uniform(
-                    tf_shell.shape(g),
-                    dtype=tf.keras.backend.floatx(),
-                    minval=-t_half / s,
-                    maxval=t_half / s,
-                )
-                for g, s in zip(grads, mask_scaling_factors)
-            ]
-
-            # Mask the encrypted gradients to prepare for decryption.
-            grads = [g + m for g, m in zip(grads, masks)]
 
         if not self.disable_noise:
             # Features party encrypts the max two norm to send to the labels
@@ -158,6 +139,24 @@ def shell_train_step(self, features, labels):
             noise_context = self.noise_context_fn()
             noise_secret_key = tf_shell.create_key64(noise_context, self.cache_path)
 
+            # The noise context must have the same number of slots
+            # (encryption ring degree) as used in backpropagation.
+            if not self.disable_encryption:
+                tf.assert_equal(
+                    backprop_context.num_slots,
+                    noise_context.num_slots,
+                    message="Backprop and noise contexts must have the same number of slots.",
+                )
+
+            # The noise scaling factor must always be 1. The encrypted
+            # gradients already carry their scaling factor when the noise is
+            # added, so an additional noise scaling factor is not needed.
+            tf.assert_equal(
+                noise_context.scaling_factor,
+                1,
+                message="Noise scaling factor must be 1.",
+            )
+
             # Compute the noise factors for the distributed noise sampling
             # sub-protocol.
             enc_a, enc_b = self.compute_noise_factors(
@@ -170,22 +169,16 @@ def shell_train_step(self, features, labels):
             grads = [tf_shell.to_tensorflow(g, secret_key) for g in grads]
 
         # Sum the masked gradients over the batch.
-        if not self.disable_masking and not self.disable_encryption:
+        if self.disable_masking or self.disable_encryption:
+            # No mask has been added so only a normal sum is required.
+            grads = [tf.reduce_sum(g, 0) for g in grads]
+        else:
             grads = [
                 tf_shell.reduce_sum_with_mod(g, 0, backprop_context, s)
                 for g, s in zip(grads, mask_scaling_factors)
             ]
-        else:
-            grads = [tf.reduce_sum(g, 0) for g in grads]
 
         if not self.disable_noise:
-            if not self.disable_encryption:
-                tf.assert_equal(
-                    backprop_context.num_slots,
-                    noise_context.num_slots,
-                    message="Backprop and noise contexts must have the same number of slots.",
-                )
-
             # Efficiently pack the masked gradients to prepare for
             # encryption. This is special because the masked gradients are
             # no longer batched so the packing must be done manually.
@@ -202,14 +195,6 @@ def shell_train_step(self, features, labels):
             # secret key.
             flat_grads = tf_shell.to_tensorflow(grads, noise_secret_key)
 
-            if self.check_overflow_INSECURE:
-                nosie_scaling_factor = grads._scaling_factor
-                self.warn_on_overflow(
-                    [flat_grads],
-                    [nosie_scaling_factor],
-                    noise_context.plaintext_modulus,
-                    "WARNING: Noised gradient may have overflowed.",
-                )
             # Unpack the noise after decryption.
             grads = self.unflatten_and_unpad_grad(
                 flat_grads,
                 grad_shapes,
                 flattened_grad_shapes,
                 total_grad_size,
             )
 
+        # Unmask the gradients.
         if not self.disable_masking and not self.disable_encryption:
-            # Sum the masks over the batch.
-            sum_masks = [
-                tf_shell.reduce_sum_with_mod(m, 0, backprop_context, s)
-                for m, s in zip(masks, mask_scaling_factors)
-            ]
-
-            # Unmask the batch gradient.
-            grads = [mg - m for mg, m in zip(grads, sum_masks)]
-
-            # SHELL represents floats as integers between [0, t) where t is
-            # the plaintext modulus. To mimic SHELL's modulo operations in
-            # TensorFlow, numbers which exceed the range [-t/2, t/2] are
-            # shifted back into the range.
-            epsilon = tf.constant(1e-6, dtype=tf.keras.backend.floatx())
-
-            def rebalance(x, s):
-                r_bound = t_half / s + epsilon
-                l_bound = -t_half / s - epsilon
-                t_over_s = t / s
-                x = tf.where(x > r_bound, x - t_over_s, x)
-                x = tf.where(x < l_bound, x + t_over_s, x)
-                return x
+            grads = self.unmask_gradients(
+                backprop_context, grads, masks, mask_scaling_factors
+            )
 
-            grads = [rebalance(g, s) for g, s in zip(grads, mask_scaling_factors)]
+        if not self.disable_noise:
+            if self.check_overflow_INSECURE:
+                self.warn_on_overflow(
+                    [flat_grads],
+                    [1],
+                    noise_context.plaintext_modulus,
+                    "WARNING: Noised gradient may have overflowed.",
+                )
 
         # Apply the gradients to the model.
         self.optimizer.apply_gradients(zip(grads, self.weights))
diff --git a/tf_shell_ml/test/dpsgd_model_distrib_test.py b/tf_shell_ml/test/dpsgd_model_distrib_test.py
index 6b0ac9b..c20c86d 100644
--- a/tf_shell_ml/test/dpsgd_model_distrib_test.py
+++ b/tf_shell_ml/test/dpsgd_model_distrib_test.py
@@ -101,13 +101,13 @@ def test_model(self):
             lambda: tf_shell.create_autocontext64(
                 log2_cleartext_sz=23,
                 scaling_factor=32,
-                noise_offset_log2=14,
+                noise_offset_log2=8,
                 cache_path=cache,
             ),
             lambda: tf_shell.create_autocontext64(
-                log2_cleartext_sz=24,
+                log2_cleartext_sz=25,
                 scaling_factor=1,
-                noise_offset_log2=0,
+                noise_offset_log2=2,
                 cache_path=cache,
             ),
             labels_party_dev=labels_party_dev,
@@ -117,14 +117,14 @@ def test_model(self):
 
         m.compile(
             loss=tf.keras.losses.CategoricalCrossentropy(),
-            optimizer=tf.keras.optimizers.Adam(0.10),
+            optimizer=tf.keras.optimizers.Adam(0.05),
             metrics=[tf.keras.metrics.CategoricalAccuracy()],
         )
 
         history = m.fit(
             features_dataset,
             labels_dataset,
-            steps_per_epoch=2,
+            steps_per_epoch=4,
             epochs=1,
             verbose=2,
             validation_data=val_dataset,
diff --git a/tf_shell_ml/test/dpsgd_model_float64_test.py b/tf_shell_ml/test/dpsgd_model_float64_test.py
index ef6edb5..2513a8e 100644
--- a/tf_shell_ml/test/dpsgd_model_float64_test.py
+++ b/tf_shell_ml/test/dpsgd_model_float64_test.py
@@ -59,11 +59,11 @@ def _test_model(self, disable_encryption, disable_masking, disable_noise, cache)
             backprop_context_fn=lambda: tf_shell.create_autocontext64(
                 log2_cleartext_sz=23,
                 scaling_factor=32,
-                noise_offset_log2=14,
+                noise_offset_log2=8,
                 cache_path=cache,
             ),
             noise_context_fn=lambda: tf_shell.create_autocontext64(
-                log2_cleartext_sz=24,
+                log2_cleartext_sz=25,
                 scaling_factor=1,
                 noise_offset_log2=0,
                 cache_path=cache,
@@ -87,7 +87,7 @@ def _test_model(self, disable_encryption, disable_masking, disable_noise, cache)
         history = m.fit(
             features_dataset,
             labels_dataset,
-            steps_per_epoch=2,
+            steps_per_epoch=4,
             epochs=1,
             verbose=2,
             validation_data=val_dataset,
diff --git a/tf_shell_ml/test/dpsgd_model_local_test.py b/tf_shell_ml/test/dpsgd_model_local_test.py
index 5ef25ef..64aadf3 100644
--- a/tf_shell_ml/test/dpsgd_model_local_test.py
+++ b/tf_shell_ml/test/dpsgd_model_local_test.py
@@ -59,11 +59,11 @@ def _test_model(self, disable_encryption, disable_masking, disable_noise, cache)
             backprop_context_fn=lambda: tf_shell.create_autocontext64(
                 log2_cleartext_sz=23,
                 scaling_factor=32,
-                noise_offset_log2=14,
+                noise_offset_log2=8,
                 cache_path=cache,
             ),
             noise_context_fn=lambda: tf_shell.create_autocontext64(
-                log2_cleartext_sz=24,
+                log2_cleartext_sz=25,
                 scaling_factor=1,
                 noise_offset_log2=2,
                 cache_path=cache,
@@ -87,7 +87,7 @@ def _test_model(self, disable_encryption, disable_masking, disable_noise, cache)
         history = m.fit(
             features_dataset,
             labels_dataset,
-            steps_per_epoch=2,
+            steps_per_epoch=4,
             epochs=1,
             verbose=2,
             validation_data=val_dataset,
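
The following is a minimal NumPy sketch, added for illustration and not part of the patch, of the integer-domain masking round trip that mask_gradients and unmask_gradients above implement: encode to fixed point, add uniform masks in [-t/2, t/2), reduce over the batch modulo the plaintext modulus, strip the summed mask, and only then convert back to floating point. The modulus, scaling factor, shapes, and helper names (encode, decode) are toy assumptions; tf_shell itself is not used.

# Illustrative sketch only; mimics the integer-domain masking round trip
# of mask_gradients / unmask_gradients with plain NumPy and toy parameters.
import numpy as np

t = 2**16 + 1        # toy plaintext modulus
t_half = t // 2
s = 32               # fixed-point scaling factor
rng = np.random.default_rng(0)

def encode(x):
    # Fixed-point encode into [0, t), as the integer plaintext domain would.
    return np.round(x * s).astype(np.int64) % t

def decode(x):
    # Map [0, t) back to [-t/2, t/2) and undo the scaling factor.
    x = np.where(x > t_half, x - t, x)
    return x.astype(np.float64) / s

per_example_grads = rng.normal(size=(4, 3))   # batch of 4, 3 weights
enc = encode(per_example_grads)               # stand-in for the ciphertexts

# Features party: add a uniform integer mask in [-t/2, t/2) to each row.
masks = rng.integers(-t_half, t_half, size=enc.shape, dtype=np.int64)
masked = (enc + masks) % t

# Labels party: "decrypt" and reduce over the batch modulo t.
masked_sum = masked.sum(axis=0) % t

# Features party: remove the summed mask modulo t, then decode to floats.
mask_sum = masks.sum(axis=0) % t
recovered = decode((masked_sum - mask_sum) % t)

expected = per_example_grads.sum(axis=0)
assert np.allclose(recovered, expected, atol=0.5 * enc.shape[0] / s)

Keeping the mask arithmetic in int64 modulo t means the masked values carry no floating-point rounding structure; the conversion to floats happens exactly once, after the summed mask has been removed.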