diff --git a/tf_shell/python/shell_tensor.py b/tf_shell/python/shell_tensor.py
index 209f18b..781a78a 100644
--- a/tf_shell/python/shell_tensor.py
+++ b/tf_shell/python/shell_tensor.py
@@ -525,9 +525,7 @@ def randomized_rounding(tensor):
     Performs randomized rounding on a tensorflow tensor with dynamic/unknown shape.
     """
     with tf.name_scope("randomized_rounding"):
-        with tf.name_scope("input_shape"):
-            dynamic_shape = tf.shape(tensor)
-
+        dynamic_shape = tf.shape(tensor)
         floor = tf.floor(tensor)
         decimal_part = tensor - floor
         random_mask = (
@@ -574,7 +572,7 @@ def _decode_scaling(scaled_tensor, output_dtype, scaling_factor):
             raise ValueError(f"Unsupported dtype {output_dtype}")
 
 
-def to_shell_plaintext(tensor, context):
+def to_shell_plaintext(tensor, context, override_scaling_factor=None):
     """Converts a Tensorflow tensor to a ShellTensor which holds a plaintext.
     Under the hood, this means encoding the Tensorflow tensor to a BGV style
     polynomial representation with the sign and scaling factor encoded as
@@ -591,8 +589,13 @@ def to_shell_plaintext(tensor, context):
         return tensor  # Do nothing, already a shell plaintext.
 
     elif isinstance(tensor, tf.Tensor):
+        if override_scaling_factor is not None:
+            scaling_factor = override_scaling_factor
+        else:
+            scaling_factor = context.scaling_factor
+
         # Shell tensor represents floats as integers * scaling_factor.
-        scaled_tensor = _encode_scaling(tensor, context.scaling_factor)
+        scaled_tensor = _encode_scaling(tensor, scaling_factor)
 
         # Pad the tensor to the correct number of slots.
         with tf.name_scope("pad_to_slots"):
@@ -617,7 +620,7 @@ def to_shell_plaintext(tensor, context):
             _level=context.level,
             _num_mod_reductions=0,
             _underlying_dtype=tensor.dtype,
-            _scaling_factor=context.scaling_factor,
+            _scaling_factor=scaling_factor,
             _is_enc=False,
         )
     else:
diff --git a/tf_shell_ml/model_base.py b/tf_shell_ml/model_base.py
index 85bf04f..b507b3b 100644
--- a/tf_shell_ml/model_base.py
+++ b/tf_shell_ml/model_base.py
@@ -108,6 +108,9 @@ def train_step(self, features, labels):
         Returns:
             metrics (dict): The metrics resulting from the training step.
         """
+        # When called from outside this class, the caches for encryption keys
+        # and contexts are not guaranteed to be populated, so set
+        # read_key_from_cache to False.
         metrics, batch_size_should_be = self.shell_train_step(
             features, labels, read_key_from_cache=False
         )
@@ -118,8 +121,6 @@ def train_step_tf_func(self, features, labels, read_key_from_cache):
         """
         tf.function wrapper around shell train step.
         """
-        # When running the training loop, the caches for encryption keys and
-        # contexts have already been populated.
         return self.shell_train_step(features, labels, read_key_from_cache)
 
     def compute_grads(self, features, enc_labels):
@@ -334,6 +335,9 @@ def fit(
                 zip(features_dataset, labels_dataset)
             ):
                 callback_list.on_train_batch_begin(step, logs)
+                # The caches for encryption keys and contexts have already been
+                # populated during the dataset preparation step. Set
+                # read_key_from_cache to True.
                 logs, batch_size_should_be = self.train_step_tf_func(
                     batch_x, batch_y, read_key_from_cache=True
                 )
@@ -530,6 +534,53 @@ def unflatten_grad_list(self, flat_grads, metadata_list):
 
         return grads
 
+    def spoof_int_gradients(self, grads):
+        """
+        Spoofs the gradients to have dtype int64 and a scaling factor of 1.
+
+        Args:
+            grads (list of tf_shell.ShellTensor): The list of gradient tensors
+                to be spoofed.
+
+        Returns:
+            list of tf_shell.ShellTensor64: The list of gradient tensors with
+            dtype int64 and scaling factor 1.
+ """ + int64_grads = [ + tf_shell.ShellTensor64( + _raw_tensor=g._raw_tensor, + _context=g._context, + _level=g._level, + _num_mod_reductions=g._num_mod_reductions, + _underlying_dtype=tf.int64, # Spoof the dtype. + _scaling_factor=1, # Spoof the scaling factor. + _is_enc=g._is_enc, + _is_fast_rotated=g._is_fast_rotated, + ) + for g in grads + ] + return int64_grads + + def unspoof_int_gradients(self, int64_grads, orig_scaling_factors): + """ + Recover the floating point gradients from spoofed int64 gradients using + the original scaling factors. + + Args: + int64_grads (list of tf.Tensor): List of gradients that have been + spoofed as int64 with scaling factor 1. + orig_scaling_factors (list of float): List of original scaling + factors used to spoof the gradients. + + Returns: + list of tf.Tensor: List of recovered floating point gradients. + """ + grads = [ + tf.cast(g, tf.keras.backend.floatx()) / s + for g, s in zip(int64_grads, orig_scaling_factors) + ] + return grads + def mask_gradients(self, context, grads): """ Adds random masks to the encrypted gradients within the range [-t/2, @@ -547,9 +598,6 @@ def mask_gradients(self, context, grads): - mask_scaling_factors (list): The original scaling factors of the gradients. """ - if self.disable_masking or self.disable_encryption: - return grads, None - t = tf.cast(tf.identity(context.plaintext_modulus), dtype=tf.int64) t_half = t // 2 @@ -563,29 +611,18 @@ def mask_gradients(self, context, grads): for g in grads ] - # Spoof the gradient's dtype to int64 and scaling factor to 1. This is - # necessary so when adding the masks to the gradients, tf_shell does not - # attempt to do any casting to floats matching scaling factors. - # Conversion back to floating point and handling the scaling factors is - # done in unmask_gradients. - int64_grads = [ - tf_shell.ShellTensor64( - _raw_tensor=g._raw_tensor, - _context=g._context, - _level=g._level, - _num_mod_reductions=g._num_mod_reductions, - _underlying_dtype=tf.int64, # Spoof the dtype. - _scaling_factor=1, # Spoof the scaling factor. - _is_enc=g._is_enc, - _is_fast_rotated=g._is_fast_rotated, - ) - for g in grads + # Convert to shell plaintexts manually to avoid using the scaling factor + # in the 'grads[i]' contexts when the addition happens in the next line + # below. + shell_masks = [ + tf_shell.to_shell_plaintext(m, context, override_scaling_factor=1) + for m in masks ] # Add the masks. - int64_grads = [(g + m) for g, m in zip(int64_grads, masks)] + masked_grads = [(g + m) for g, m in zip(grads, shell_masks)] - return int64_grads, masks + return masked_grads, masks def unmask_gradients(self, context, grads, masks, mask_scaling_factors): """ @@ -616,12 +653,6 @@ def unmask_gradients(self, context, grads, masks, mask_scaling_factors): for mg, s in zip(masks_and_grads, mask_scaling_factors) ] - # Recover the floating point values using the scaling factors. - unmasked_grads = [ - tf.cast(g, tf.keras.backend.floatx()) / s - for g, s in zip(unmasked_grads, mask_scaling_factors) - ] - return unmasked_grads def compute_noise_factors( @@ -748,10 +779,6 @@ def warn_on_overflow(self, grads, scaling_factors, plaintext_modulus, message): corresponding to each gradient tensor. plaintext_modulus (float): The plaintext modulus value. message (str): The message to print if an overflow is detected. - - Returns: - tf.Tensor: A boolean tensor indicating whether an overflow was - detected. 
""" # t = tf.cast(plaintext_modulus, grads[0].dtype) t_half = plaintext_modulus / 2 @@ -776,8 +803,6 @@ def warn_on_overflow(self, grads, scaling_factors, plaintext_modulus, message): lambda: tf.identity(overflowed), ) - return overflowed - def shell_train_step(self, features, labels, read_key_from_cache): """ The main training loop for the PostScale protocol. @@ -816,12 +841,11 @@ def shell_train_step(self, features, labels, read_key_from_cache): grads, max_two_norm, predictions = self.compute_grads(features, enc_labels) with tf.device(self.features_party_dev): - # Store the scaling factors of the gradients before masking, as - # masking sets them to 1. - if not self.disable_encryption and not self.disable_masking: - backprop_scaling_factors = [g._scaling_factor for g in grads] - else: + # Store the scaling factors of the gradients. + if self.disable_encryption: backprop_scaling_factors = [1] * len(grads) + else: + backprop_scaling_factors = [g._scaling_factor for g in grads] # Check if the backproped gradients overflowed. if not self.disable_encryption and self.check_overflow_INSECURE: @@ -837,6 +861,12 @@ def shell_train_step(self, features, labels, read_key_from_cache): "WARNING: Backprop gradient may have overflowed.", ) + # Spoof encrypted gradients as int64 with scaling factor 1 for + # subsequent masking and noising, so the intermediate decryption + # is not affected by the scaling factors. + if not self.disable_encryption: + grads = self.spoof_int_gradients(grads) + # Mask the encrypted gradients. if not self.disable_masking and not self.disable_encryption: grads, masks = self.mask_gradients(backprop_context, grads) @@ -942,6 +972,11 @@ def shell_train_step(self, features, labels, read_key_from_cache): backprop_context, grads, masks, backprop_scaling_factors ) + # Recover the original scaling factor of the gradients if they were + # originally encrypted. + if not self.disable_encryption: + grads = self.unspoof_int_gradients(grads, backprop_scaling_factors) + if not self.disable_noise: if self.check_overflow_INSECURE: self.warn_on_overflow( diff --git a/tf_shell_ml/postscale_sequential_model.py b/tf_shell_ml/postscale_sequential_model.py index 287fd87..2079d3f 100644 --- a/tf_shell_ml/postscale_sequential_model.py +++ b/tf_shell_ml/postscale_sequential_model.py @@ -48,7 +48,7 @@ def compute_grads(self, features, enc_labels): max_two_norms_list = [] with tf.device(self.features_party_dev): - split_features, remainder = self.split_with_padding( + split_features, end_pad = self.split_with_padding( features, len(self.jacobian_devices) ) @@ -56,10 +56,11 @@ def compute_grads(self, features, enc_labels): with tf.device(d): f = tf.identity(split_features[i]) # copy to GPU if needed prediction, jacobians = self.predict_and_jacobian(f) - if i == len(self.jacobian_devices) - 1 and remainder > 0: - # The last device may have a remainder that was padded. - prediction = prediction[: features.shape[0] - remainder] - jacobians = jacobians[: features.shape[0] - remainder] + if i == len(self.jacobian_devices) - 1 and end_pad > 0: + # The last device's features may have been padded for even + # split jacobian computation across multiple devices. 
+                    prediction = prediction[:-end_pad]
+                    jacobians = jacobians[:-end_pad]
                 predictions_list.append(prediction)
                 jacobians_list.append(jacobians)
                 max_two_norms_list.append(self.jacobian_max_two_norm(jacobians))
diff --git a/tf_shell_ml/test/dpsgd_conv_model_local_test.py b/tf_shell_ml/test/dpsgd_conv_model_local_test.py
index 0e79963..06405b8 100644
--- a/tf_shell_ml/test/dpsgd_conv_model_local_test.py
+++ b/tf_shell_ml/test/dpsgd_conv_model_local_test.py
@@ -124,6 +124,7 @@ def _test_model(self, disable_encryption, disable_masking, disable_noise, cache)
             validation_data=val_dataset,
         )
 
+        print(history.history["val_categorical_accuracy"][-1])
         self.assertGreater(history.history["val_categorical_accuracy"][-1], 0.13)
 
     def test_model(self):
diff --git a/tf_shell_ml/test/postscale_model_distrib_test.py b/tf_shell_ml/test/postscale_model_distrib_test.py
index b35f9df..34b014c 100644
--- a/tf_shell_ml/test/postscale_model_distrib_test.py
+++ b/tf_shell_ml/test/postscale_model_distrib_test.py
@@ -70,7 +70,7 @@ def test_model(self):
 
         # Clip dataset images to limit memory usage. The model accuracy will be
         # bad but this test only measures functionality.
-        x_train, x_test = x_train[:, :250], x_test[:, :250]
+        x_train, x_test = x_train[:, :350], x_test[:, :350]
 
         # Set a seed for shuffling both features and labels the same way.
         seed = 42
@@ -130,7 +130,7 @@ def test_model(self):
 
         cache_dir.cleanup()
 
-        self.assertGreater(history.history["val_categorical_accuracy"][-1], 0.3)
+        self.assertGreater(history.history["val_categorical_accuracy"][-1], 0.25)
 
 
 if __name__ == "__main__":
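Note (not part of the patch): the sketch below models, in plain TensorFlow, the fixed-point arithmetic this change reorganizes. Gradients are handled as scaling-factor-1 integers while a mask drawn from [-t/2, t/2) is added mod t, and the original scaling factor is only divided back out after unmasking. It is illustrative only, under assumed values: the modulus `t`, `scaling_factor`, and `grad` are made-up stand-ins, and no tf_shell API or encryption is involved.

```python
import tensorflow as tf

t = tf.constant(2**20 + 7, dtype=tf.int64)  # stand-in plaintext modulus
t_half = t // 2
scaling_factor = 256.0  # stand-in fixed-point scaling factor

# A gradient as it would be encoded: integer = round(float * scaling_factor).
grad = tf.constant([0.125, -0.5, 0.75])
int_grad = tf.cast(tf.round(grad * scaling_factor), tf.int64)

# Masking: add a uniform mask in [-t/2, t/2). Both operands are treated as
# plain integers (the "spoofed", scaling-factor-1 view), so no rescaling
# happens before the addition.
mask = tf.random.uniform(
    tf.shape(int_grad), minval=-t_half, maxval=t_half, dtype=tf.int64
)
masked = tf.math.floormod(int_grad + mask, t)

# Unmasking: subtract the mask mod t, then recenter into [-t/2, t/2).
unmasked = tf.math.floormod(masked - mask, t)
unmasked = tf.where(unmasked > t_half, unmasked - t, unmasked)

# "Unspoof": divide by the original scaling factor to recover the floats.
recovered = tf.cast(unmasked, tf.keras.backend.floatx()) / scaling_factor
tf.debugging.assert_near(recovered, grad)
```

This mirrors why `spoof_int_gradients` / `unspoof_int_gradients` now bracket the masking and noising steps in `shell_train_step`: the mask addition and the intermediate decryption must operate on raw integers, and the gradients' scaling factors are reapplied only at the end.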