Fix mask scaling factor offset.
james-choncholas committed Jan 14, 2025
1 parent db4434b commit a787ce7
Showing 5 changed files with 94 additions and 54 deletions.
15 changes: 9 additions & 6 deletions tf_shell/python/shell_tensor.py
@@ -525,9 +525,7 @@ def randomized_rounding(tensor):
Performs randomized rounding on a tensorflow tensor with dynamic/unknown shape.
"""
with tf.name_scope("randomized_rounding"):
with tf.name_scope("input_shape"):
dynamic_shape = tf.shape(tensor)

dynamic_shape = tf.shape(tensor)
floor = tf.floor(tensor)
decimal_part = tensor - floor
random_mask = (
@@ -574,7 +572,7 @@ def _decode_scaling(scaled_tensor, output_dtype, scaling_factor):
raise ValueError(f"Unsupported dtype {output_dtype}")


def to_shell_plaintext(tensor, context):
def to_shell_plaintext(tensor, context, override_scaling_factor=None):
"""Converts a Tensorflow tensor to a ShellTensor which holds a plaintext.
Under the hood, this means encoding the Tensorflow tensor to a BGV style
polynomial representation with the sign and scaling factor encoded as
@@ -591,8 +589,13 @@ def to_shell_plaintext(tensor, context):
return tensor # Do nothing, already a shell plaintext.

elif isinstance(tensor, tf.Tensor):
if override_scaling_factor is not None:
scaling_factor = override_scaling_factor
else:
scaling_factor = context.scaling_factor

# Shell tensor represents floats as integers * scaling_factor.
scaled_tensor = _encode_scaling(tensor, context.scaling_factor)
scaled_tensor = _encode_scaling(tensor, scaling_factor)

# Pad the tensor to the correct number of slots.
with tf.name_scope("pad_to_slots"):
@@ -617,7 +620,7 @@
_level=context.level,
_num_mod_reductions=0,
_underlying_dtype=tensor.dtype,
_scaling_factor=context.scaling_factor,
_scaling_factor=scaling_factor,
_is_enc=False,
)
else:
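Note (illustrative, not part of this commit): the new override_scaling_factor argument lets callers encode values that are already integers without multiplying them by the context's scaling factor. A minimal usage sketch, assuming a tf_shell context object already exists elsewhere; the helper name encode_integer_mask is hypothetical:

import tensorflow as tf
import tf_shell

def encode_integer_mask(mask, context):
    # The mask is already an int64 tensor, so encode it with a scaling
    # factor of 1 instead of context.scaling_factor. The default path
    # would rescale it and push the mask outside [-t/2, t/2).
    return tf_shell.to_shell_plaintext(mask, context, override_scaling_factor=1)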
117 changes: 76 additions & 41 deletions tf_shell_ml/model_base.py
@@ -108,6 +108,9 @@ def train_step(self, features, labels):
Returns:
metrics (dict): The metrics resulting from the training step.
"""
# When called from outside this class, the caches for encryption keys
# and contexts are not guaranteed to be populated, so set
# read_key_from_cache to False.
metrics, batch_size_should_be = self.shell_train_step(
features, labels, read_key_from_cache=False
)
@@ -118,8 +121,6 @@ def train_step_tf_func(self, features, labels, read_key_from_cache):
"""
tf.function wrapper around shell train step.
"""
# When running the training loop, the caches for encryption keys and
# contexts have already been populated.
return self.shell_train_step(features, labels, read_key_from_cache)

def compute_grads(self, features, enc_labels):
@@ -334,6 +335,9 @@ def fit(
zip(features_dataset, labels_dataset)
):
callback_list.on_train_batch_begin(step, logs)
# The caches for encryption keys and contexts have already been
# populated during the dataset preparation step. Set
# read_key_from_cache to True.
logs, batch_size_should_be = self.train_step_tf_func(
batch_x, batch_y, read_key_from_cache=True
)
@@ -530,6 +534,53 @@ def unflatten_grad_list(self, flat_grads, metadata_list):

return grads

def spoof_int_gradients(self, grads):
"""
Spoofs the gradients to have dtype int64 and a scaling factor of 1.
Args:
grads (list of tf_shell.ShellTensor): The list of gradient tensors
to be spoofed.
Returns:
list of tf_shell.ShellTensor64: The list of gradient tensors with
dtype int64 and scaling factor 1.
"""
int64_grads = [
tf_shell.ShellTensor64(
_raw_tensor=g._raw_tensor,
_context=g._context,
_level=g._level,
_num_mod_reductions=g._num_mod_reductions,
_underlying_dtype=tf.int64, # Spoof the dtype.
_scaling_factor=1, # Spoof the scaling factor.
_is_enc=g._is_enc,
_is_fast_rotated=g._is_fast_rotated,
)
for g in grads
]
return int64_grads

def unspoof_int_gradients(self, int64_grads, orig_scaling_factors):
"""
Recover the floating point gradients from spoofed int64 gradients using
the original scaling factors.
Args:
int64_grads (list of tf.Tensor): List of gradients that have been
spoofed as int64 with scaling factor 1.
orig_scaling_factors (list of float): List of original scaling
factors used to spoof the gradients.
Returns:
list of tf.Tensor: List of recovered floating point gradients.
"""
grads = [
tf.cast(g, tf.keras.backend.floatx()) / s
for g, s in zip(int64_grads, orig_scaling_factors)
]
return grads

def mask_gradients(self, context, grads):
"""
Adds random masks to the encrypted gradients within the range [-t/2,
@@ -547,9 +598,6 @@ def mask_gradients(self, context, grads):
- mask_scaling_factors (list): The original scaling factors of the
gradients.
"""
if self.disable_masking or self.disable_encryption:
return grads, None

t = tf.cast(tf.identity(context.plaintext_modulus), dtype=tf.int64)
t_half = t // 2
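The mask sampling code itself is collapsed in this view. For orientation, a minimal sketch of drawing a uniform int64 mask over the signed plaintext range [-t/2, t/2) with standard TensorFlow ops; the helper name and shapes are illustrative, not the repository's exact code:

import tensorflow as tf

def sample_mask(shape, plaintext_modulus):
    # Uniform int64 noise over the full signed plaintext range.
    t = tf.cast(plaintext_modulus, tf.int64)
    t_half = t // 2
    return tf.random.uniform(shape, minval=-t_half, maxval=t_half, dtype=tf.int64)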

@@ -563,29 +611,18 @@
for g in grads
]

# Spoof the gradient's dtype to int64 and scaling factor to 1. This is
# necessary so when adding the masks to the gradients, tf_shell does not
# attempt to do any casting to floats matching scaling factors.
# Conversion back to floating point and handling the scaling factors is
# done in unmask_gradients.
int64_grads = [
tf_shell.ShellTensor64(
_raw_tensor=g._raw_tensor,
_context=g._context,
_level=g._level,
_num_mod_reductions=g._num_mod_reductions,
_underlying_dtype=tf.int64, # Spoof the dtype.
_scaling_factor=1, # Spoof the scaling factor.
_is_enc=g._is_enc,
_is_fast_rotated=g._is_fast_rotated,
)
for g in grads
# Convert to shell plaintexts manually to avoid using the scaling factor
# in the 'grads[i]' contexts when the addition happens in the next line
# below.
shell_masks = [
tf_shell.to_shell_plaintext(m, context, override_scaling_factor=1)
for m in masks
]

# Add the masks.
int64_grads = [(g + m) for g, m in zip(int64_grads, masks)]
masked_grads = [(g + m) for g, m in zip(grads, shell_masks)]

return int64_grads, masks
return masked_grads, masks

def unmask_gradients(self, context, grads, masks, mask_scaling_factors):
"""
@@ -616,12 +653,6 @@ def unmask_gradients(self, context, grads, masks, mask_scaling_factors):
for mg, s in zip(masks_and_grads, mask_scaling_factors)
]

# Recover the floating point values using the scaling factors.
unmasked_grads = [
tf.cast(g, tf.keras.backend.floatx()) / s
for g, s in zip(unmasked_grads, mask_scaling_factors)
]

return unmasked_grads

def compute_noise_factors(
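For intuition about why the masks are added at scaling factor 1 and the original scaling factors are divided out only after unmasking, here is the underlying fixed-point arithmetic with toy numbers (illustrative only, plain TensorFlow, not the repository's code):

import tensorflow as tf

s = tf.constant(8, dtype=tf.int64)        # gradient scaling factor
t = tf.constant(257, dtype=tf.int64)      # toy plaintext modulus
t_half = t // 2

grad = tf.constant(0.75)                  # true gradient value
stored = tf.cast(tf.round(grad * 8.0), tf.int64)   # 6, i.e. round(grad * s)
mask = tf.constant(-100, dtype=tf.int64)  # drawn from [-t/2, t/2)

masked = (stored + mask) % t              # value seen after decryption: 163
unmasked = (masked - mask) % t            # remove the mask modulo t: 6
unmasked = tf.where(unmasked > t_half, unmasked - t, unmasked)  # re-center signed
recovered = tf.cast(unmasked, tf.float32) / 8.0     # 0.75, the original gradient

If the mask were instead encoded with the gradients' scaling factor, the addition would rescale one operand and the recovered value would be offset, which matches the scaling-factor offset the commit title describes.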
@@ -748,10 +779,6 @@ def warn_on_overflow(self, grads, scaling_factors, plaintext_modulus, message):
corresponding to each gradient tensor.
plaintext_modulus (float): The plaintext modulus value.
message (str): The message to print if an overflow is detected.
Returns:
tf.Tensor: A boolean tensor indicating whether an overflow was
detected.
"""
# t = tf.cast(plaintext_modulus, grads[0].dtype)
t_half = plaintext_modulus / 2
@@ -776,8 +803,6 @@ def warn_on_overflow(self, grads, scaling_factors, plaintext_modulus, message):
lambda: tf.identity(overflowed),
)

return overflowed

def shell_train_step(self, features, labels, read_key_from_cache):
"""
The main training loop for the PostScale protocol.
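The body of warn_on_overflow is largely collapsed in this view. A minimal sketch of the kind of check it performs, comparing each decoded gradient against the wrap-around point t/2 of the plaintext modulus; this is an assumed illustration, not the repository's implementation:

import tensorflow as tf

def any_overflow(grads, scaling_factors, plaintext_modulus):
    # A value stored as round(g * s) wraps once |g * s| reaches t/2, so the
    # decoded gradient must stay below (t/2) / s in magnitude.
    t_half = plaintext_modulus / 2.0
    overflowed = tf.constant(False)
    for g, s in zip(grads, scaling_factors):
        limit = tf.cast(t_half, g.dtype) / tf.cast(s, g.dtype)
        overflowed = tf.logical_or(overflowed, tf.reduce_any(tf.abs(g) >= limit))
    return overflowed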
@@ -816,12 +841,11 @@ def shell_train_step(self, features, labels, read_key_from_cache):
grads, max_two_norm, predictions = self.compute_grads(features, enc_labels)

with tf.device(self.features_party_dev):
# Store the scaling factors of the gradients before masking, as
# masking sets them to 1.
if not self.disable_encryption and not self.disable_masking:
backprop_scaling_factors = [g._scaling_factor for g in grads]
else:
# Store the scaling factors of the gradients.
if self.disable_encryption:
backprop_scaling_factors = [1] * len(grads)
else:
backprop_scaling_factors = [g._scaling_factor for g in grads]

# Check if the backproped gradients overflowed.
if not self.disable_encryption and self.check_overflow_INSECURE:
@@ -837,6 +861,12 @@
"WARNING: Backprop gradient may have overflowed.",
)

# Spoof encrypted gradients as int64 with scaling factor 1 for
# subsequent masking and noising, so the intermediate decryption
# is not affected by the scaling factors.
if not self.disable_encryption:
grads = self.spoof_int_gradients(grads)

# Mask the encrypted gradients.
if not self.disable_masking and not self.disable_encryption:
grads, masks = self.mask_gradients(backprop_context, grads)
@@ -942,6 +972,11 @@ def shell_train_step(self, features, labels, read_key_from_cache):
backprop_context, grads, masks, backprop_scaling_factors
)

# Recover the original scaling factor of the gradients if they were
# originally encrypted.
if not self.disable_encryption:
grads = self.unspoof_int_gradients(grads, backprop_scaling_factors)

if not self.disable_noise:
if self.check_overflow_INSECURE:
self.warn_on_overflow(
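Taken together, the model_base.py changes reorder the gradient pipeline so that scaling factors are recorded once, frozen during masking and noising, and restored at the very end. An illustrative outline, not the actual shell_train_step: decryption, noise, and device placement are elided, and model stands for the class instance:

def masked_gradient_round_trip(model, backprop_context, grads):
    # 1. Remember the real scaling factors before they are spoofed away.
    backprop_scaling_factors = [g._scaling_factor for g in grads]

    # 2. Reinterpret the encrypted gradients as int64 with scaling factor 1 so
    #    masking and noising never trigger scaling-factor matching or casts.
    grads = model.spoof_int_gradients(grads)

    # 3. Add masks that were encoded with override_scaling_factor=1.
    grads, masks = model.mask_gradients(backprop_context, grads)

    # ... decryption and aggregation on the other party happen here ...

    # 4. Remove the masks, still in integer space.
    grads = model.unmask_gradients(
        backprop_context, grads, masks, backprop_scaling_factors
    )

    # 5. Divide the original scaling factors back out to recover floats.
    return model.unspoof_int_gradients(grads, backprop_scaling_factors)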
11 changes: 6 additions & 5 deletions tf_shell_ml/postscale_sequential_model.py
@@ -48,18 +48,19 @@ def compute_grads(self, features, enc_labels):
max_two_norms_list = []

with tf.device(self.features_party_dev):
split_features, remainder = self.split_with_padding(
split_features, end_pad = self.split_with_padding(
features, len(self.jacobian_devices)
)

for i, d in enumerate(self.jacobian_devices):
with tf.device(d):
f = tf.identity(split_features[i]) # copy to GPU if needed
prediction, jacobians = self.predict_and_jacobian(f)
if i == len(self.jacobian_devices) - 1 and remainder > 0:
# The last device may have a remainder that was padded.
prediction = prediction[: features.shape[0] - remainder]
jacobians = jacobians[: features.shape[0] - remainder]
if i == len(self.jacobian_devices) - 1 and end_pad > 0:
# The last device's features may have been padded for even
# split jacobian computation across multiple devices.
prediction = prediction[:-end_pad]
jacobians = jacobians[:-end_pad]
predictions_list.append(prediction)
jacobians_list.append(jacobians)
max_two_norms_list.append(self.jacobian_max_two_norm(jacobians))
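For reference, a sketch of the padding behavior the renamed end_pad value refers to: the batch is padded so it splits evenly across the Jacobian devices, and the padded rows of the last shard are dropped after prediction. This is an assumed illustration of split_with_padding, not the repository's implementation:

import tensorflow as tf

def split_with_padding_sketch(features, num_devices):
    batch = features.shape[0]                 # assumes a static batch size
    shard = -(-batch // num_devices)          # ceiling division
    end_pad = shard * num_devices - batch     # rows added to the last shard
    paddings = [[0, end_pad]] + [[0, 0]] * (features.shape.rank - 1)
    padded = tf.pad(features, paddings)
    return tf.split(padded, num_devices, axis=0), end_pad

The last device then trims its outputs with prediction[:-end_pad] (and likewise for the Jacobians) whenever end_pad > 0, as compute_grads does above.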
1 change: 1 addition & 0 deletions tf_shell_ml/test/dpsgd_conv_model_local_test.py
@@ -124,6 +124,7 @@ def _test_model(self, disable_encryption, disable_masking, disable_noise, cache)
validation_data=val_dataset,
)

print(history.history["val_categorical_accuracy"][-1])
self.assertGreater(history.history["val_categorical_accuracy"][-1], 0.13)

def test_model(self):
4 changes: 2 additions & 2 deletions tf_shell_ml/test/postscale_model_distrib_test.py
@@ -70,7 +70,7 @@ def test_model(self):

# Clip dataset images to limit memory usage. The model accuracy will be
# bad but this test only measures functionality.
x_train, x_test = x_train[:, :250], x_test[:, :250]
x_train, x_test = x_train[:, :350], x_test[:, :350]

# Set a seed for shuffling both features and labels the same way.
seed = 42
@@ -130,7 +130,7 @@ def test_model(self):

cache_dir.cleanup()

self.assertGreater(history.history["val_categorical_accuracy"][-1], 0.3)
self.assertGreater(history.history["val_categorical_accuracy"][-1], 0.25)


if __name__ == "__main__":
