Commit 94451d1

Masking is performed in the integer domain to avoid leakages associated with floating point.
james-choncholas committed Jan 4, 2025
1 parent 1f5a798 commit 94451d1
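As a standalone illustration of the change (plain NumPy, not the tf-shell API; the modulus, scaling factor, helper center(), and values below are made up for the example), masking in the integer domain means adding a uniformly random integer mod the plaintext modulus t to the fixed-point encoding of a gradient. Every step stays exact mod t, so after the batch reduction the mask cancels bit-for-bit and the only error left is the fixed-point quantization itself. A float-valued mask, by contrast, does not cancel exactly; see the sketch after the first file below.

import numpy as np

t = 2**40                                     # plaintext modulus (illustrative)
s = 2**16                                     # fixed-point scaling factor (illustrative)
rng = np.random.default_rng(0)

def center(x, t):
    # Map a value in [0, t) back to the signed range [-t/2, t/2).
    return ((x + t // 2) % t) - t // 2

grad = np.array([0.125, -3.5, 7.0])           # per-example gradient values (made up)
fixed = np.round(grad * s).astype(np.int64)   # fixed-point encoding held by the ciphertext

mask = rng.integers(-t // 2, t // 2, size=fixed.shape)  # uniform integer mask
masked = (fixed + mask) % t                   # what the features party sees after decryption

summed = masked.sum() % t                     # batch reduction, still mod t
unmasked = center((summed - mask.sum()) % t, t)
recovered = unmasked / s

# Exact up to fixed-point quantization; no float rounding residue to leak.
assert abs(recovered - grad.sum()) < 1 / s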
Showing 6 changed files with 167 additions and 159 deletions.
123 changes: 45 additions & 78 deletions tf_shell_ml/dpsgd_sequential_model.py
@@ -14,9 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import tensorflow as tf
import tensorflow.keras as keras
import tf_shell
import tf_shell_ml
from tf_shell_ml.model_base import SequentialBase


@@ -106,8 +104,8 @@ def shell_train_step(self, features, labels):

# Check if the backproped gradients overflowed.
if not self.disable_encryption and self.check_overflow_INSECURE:
# Note, checking the backprop gradients requires decryption
# on the features party which breaks security of the protocol.
# Note, checking the gradients requires decryption on the
# features party which breaks security of the protocol.
bp_scaling_factors = [g._scaling_factor for g in dJ_dw]
dec_dJ_dw = [
tf_shell.to_tensorflow(g, backprop_secret_key) for g in dJ_dw
@@ -119,39 +117,39 @@ def shell_train_step(self, features, labels):
"WARNING: Backprop gradient may have overflowed.",
)

# Mask the encrypted grads to prepare for decryption. The masks may
# overflow during the reduce_sum over the batch. When the masks are
# operated on, they are multiplied by the scaling factor, so it is
# not necessary to mask the full range -t/2 to t/2. (Though it is
# possible, it unnecessarily introduces noise into the ciphertext.)
if self.disable_masking or self.disable_encryption:
grads = [g for g in reversed(dJ_dw)]
else:
t = tf.cast(
backprop_context.plaintext_modulus, tf.keras.backend.floatx()
)
t_half = t // 2
mask_scaling_factors = [g._scaling_factor for g in reversed(dJ_dw)]
masks = [
tf.random.uniform(
tf_shell.shape(g),
dtype=tf.keras.backend.floatx(),
minval=-t_half / s,
maxval=t_half / s,
)
for g, s in zip(reversed(dJ_dw), mask_scaling_factors)
]
# Reverse the order to match the order of the layers.
grads = [g for g in reversed(dJ_dw)]

# Mask the encrypted gradients and reverse the order to match
# the order of the layers.
grads = [(g + m) for g, m in zip(reversed(dJ_dw), masks)]
# Mask the encrypted gradients.
if not self.disable_masking and not self.disable_encryption:
grads, masks, mask_scaling_factors = self.mask_gradients(
backprop_context, grads
)

if not self.disable_noise:
# Set up the features party size of the distributed noise
# Set up the features party side of the distributed noise
# sampling sub-protocol.
noise_context = self.noise_context_fn()
noise_secret_key = tf_shell.create_key64(noise_context, self.cache_path)

# The noise context must have the same number of slots
# (encryption ring degree) as used in backpropagation.
if not self.disable_encryption:
tf.assert_equal(
backprop_context.num_slots,
noise_context.num_slots,
message="Backprop and noise contexts must have the same number of slots.",
)

# The noise scaling factor must always be 1. Encryptions already
# have the scaling factor applied when the noise is applied,
# and an additional noise scaling factor is not needed.
tf.assert_equal(
noise_context.scaling_factor,
1,
message="Noise scaling factor must be 1.",
)

# Compute the noise factors for the distributed noise sampling
# sub-protocol.
enc_a, enc_b = self.compute_noise_factors(
@@ -164,8 +162,8 @@ def shell_train_step(self, features, labels):
grads = [tf_shell.to_tensorflow(g, backprop_secret_key) for g in grads]

# Sum the masked gradients over the batch.
if self.disable_encryption or self.disable_masking:
# No mask, only sum required.
if self.disable_masking or self.disable_encryption:
# No mask has been added, so only a normal sum is required.
grads = [tf.reduce_sum(g, axis=0) for g in grads]
else:
grads = [
@@ -174,23 +172,13 @@ def shell_train_step(self, features, labels):
]

if not self.disable_noise:
if not self.disable_encryption:
tf.assert_equal(
backprop_context.num_slots,
noise_context.num_slots,
message="Backprop and noise contexts must have the same number of slots.",
)

# Efficiently pack the masked gradients to prepare for adding
# the encrypted noise. This is special because the masked
# gradients are no longer batched, so the packing must be done
# manually.
(
flat_grads,
grad_shapes,
flattened_grad_shapes,
total_grad_size,
) = self.flatten_and_pad_grad_list(grads, noise_context.num_slots)
(flat_grads, grad_shapes, flattened_grad_shapes, total_grad_size) = (
self.flatten_and_pad_grad_list(grads, noise_context.num_slots)
)

# Add the encrypted noise to the masked gradients.
grads = self.noise_gradients(noise_context, flat_grads, enc_a, enc_b)
@@ -201,15 +189,6 @@ def shell_train_step(self, features, labels):
# secret key.
flat_grads = tf_shell.to_tensorflow(grads, noise_secret_key)

if self.check_overflow_INSECURE:
noise_scaling_factors = grads._scaling_factor
self.warn_on_overflow(
[flat_grads],
[noise_scaling_factors],
noise_context.plaintext_modulus,
"WARNING: Noised gradient may have overflowed.",
)

# Unpack the noised grads after decryption.
grads = self.unflatten_and_unpad_grad(
flat_grads,
@@ -218,32 +197,20 @@ def shell_train_step(self, features, labels):
total_grad_size,
)

# Unmask the gradients.
if not self.disable_masking and not self.disable_encryption:
# Sum the masks over the batch.
sum_masks = [
tf_shell.reduce_sum_with_mod(m, 0, backprop_context, s)
for m, s in zip(masks, mask_scaling_factors)
]

# Unmask the batch gradient.
grads = [mg - m for mg, m in zip(grads, sum_masks)]

# SHELL represents floats as integers between [0, t) where t is
# the plaintext modulus. To mimic SHELL's modulo operations in
# TensorFlow, numbers which exceed the range [-t/2, t/2] are
# shifted back into the range. Note, this could be done as a
# custom op with tf-shell, but this is simpler for now.
epsilon = tf.constant(1e-6, dtype=tf.keras.backend.floatx())

def rebalance(x, s):
r_bound = t_half / s + epsilon
l_bound = -t_half / s - epsilon
t_over_s = t / s
x = tf.where(x > r_bound, x - t_over_s, x)
x = tf.where(x < l_bound, x + t_over_s, x)
return x
grads = self.unmask_gradients(
backprop_context, grads, masks, mask_scaling_factors
)

grads = [rebalance(g, s) for g, s in zip(grads, mask_scaling_factors)]
if not self.disable_noise:
if self.check_overflow_INSECURE:
self.warn_on_overflow(
[flat_grads],
[1],
noise_context.plaintext_modulus,
"WARNING: Noised gradient may have overflowed.",
)

# Apply the gradients to the model.
self.optimizer.apply_gradients(zip(grads, self.weights))
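For contrast, a rough sketch of the float-domain masking that this file just dropped (again plain NumPy with made-up parameters; a loose analogy, not the actual ciphertext arithmetic): when the mask is a float, (x + m) - m is generally not bit-exact, and the rounding residue depends on the secret value, which is the kind of floating-point leakage the commit message refers to.

import numpy as np

rng = np.random.default_rng(1)
t = 2**40                                    # plaintext modulus (illustrative)
s = 2**16                                    # scaling factor (illustrative)

x = 3.141592653589793                        # secret gradient value (made up)
m = rng.uniform(-t / (2 * s), t / (2 * s))   # float mask over the same range as the old code

masked = np.float32(x + m)                   # rounding happens here...
recovered = float(masked) - m                # ...and is not undone here

residue = recovered - x
print(residue)                               # typically non-zero; its size depends on x and m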
76 changes: 71 additions & 5 deletions tf_shell_ml/model_base.py
@@ -425,6 +425,73 @@ def unflatten_and_unpad_grad(
grads_list = [tf.reshape(g, s) for g, s in zip(grads_list, grad_shapes)]
return grads_list

def mask_gradients(self, context, grads):
"""Adds a random masks to the encrypted gradients between [-t/2, t/2]."""
if self.disable_masking or self.disable_encryption:
return grads, None, None

t = tf.cast(context.plaintext_modulus, dtype=tf.int64)
t_half = t // 2
mask_scaling_factors = [g._scaling_factor for g in grads]

masks = [
tf.random.uniform(
tf_shell.shape(g),
dtype=tf.int64,
minval=-t_half,
maxval=t_half,
)
for g in grads
]

# Spoof the gradient's dtype to int64 and scaling factor to 1.
# This is necessary so that, when the masks are added to the gradients,
# tf_shell does not attempt any float casts or scaling factor matching.
# Conversion back to floating point and handling the scaling factors is
# done in unmask_gradients.
int64_grads = [
tf_shell.ShellTensor64(
_raw_tensor=g._raw_tensor,
_context=g._context,
_level=g._level,
_num_mod_reductions=g._num_mod_reductions,
_underlying_dtype=tf.int64, # Spoof the dtype.
_scaling_factor=1, # Spoof the scaling factor.
_is_enc=g._is_enc,
_is_fast_rotated=g._is_fast_rotated,
)
for g in grads
]

# Add the masks.
int64_grads = [(g + m) for g, m in zip(int64_grads, masks)]

return int64_grads, masks, mask_scaling_factors

def unmask_gradients(self, context, grads, masks, mask_scaling_factors):
"""Subtracts the masks from the gradients."""

# Sum the masks over the batch.
sum_masks = [
tf_shell.reduce_sum_with_mod(m, 0, context, s)
for m, s in zip(masks, mask_scaling_factors)
]

# Unmask the batch gradient.
masks_and_grads = [tf.stack([-m, g]) for m, g in zip(sum_masks, grads)]
unmasked_grads = [
tf_shell.reduce_sum_with_mod(mg, 0, context, s)
for mg, s in zip(masks_and_grads, mask_scaling_factors)
]

# Recover the floating point values using the scaling factors.
unmasked_grads = [
tf.cast(g, tf.keras.backend.floatx()) / s
for g, s in zip(unmasked_grads, mask_scaling_factors)
]

return unmasked_grads

def compute_noise_factors(self, context, secret_key, sensitivity):
bounded_sensitivity = tf.cast(sensitivity, dtype=tf.float32) * tf.sqrt(2.0)

@@ -437,7 +504,6 @@ def compute_noise_factors(self, context, secret_key, sensitivity):
a, b = tf_shell.sample_centered_gaussian_f(bounded_sensitivity, self.dg_params)

def _prep_noise_factor(x):
x = tf.cast(x, tf.keras.backend.floatx())
x = tf.expand_dims(x, 0)
x = tf.expand_dims(x, 0)
x = tf.repeat(x, context.num_slots, axis=0)
@@ -476,11 +542,11 @@ def warn_on_overflow(self, grads, scaling_factors, plaintext_modulus, message):
may have overflowed. This also must take the scaling factor into account
so the range is divided by the scaling factor.
"""
t = tf.cast(plaintext_modulus, grads[0].dtype)
t_half = t / 2
# t = tf.cast(plaintext_modulus, grads[0].dtype)
t_half = plaintext_modulus / 2

over_by = [
tf.reduce_max(tf.abs(g) - t_half / 2 / s)
tf.reduce_max(tf.abs(g) - tf.cast(t_half / 2 / s, g.dtype))
for g, s in zip(grads, scaling_factors)
]
max_over_by = tf.reduce_max(over_by)
@@ -494,7 +560,7 @@ def warn_on_overflow(self, grads, scaling_factors, plaintext_modulus, message):
over_by,
"(positive number indicates overflow amount).",
"Values should be less than",
[t_half / 2 / s for s in scaling_factors],
[tf.cast(t_half / 2 / s, grads[0].dtype) for s in scaling_factors],
),
lambda: tf.identity(overflowed),
)
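To make the threshold in warn_on_overflow concrete (illustrative numbers, not the library's defaults): the check flags any decrypted gradient whose magnitude exceeds t_half / 2 / s, i.e. a quarter of the plaintext modulus after undoing the fixed-point scaling.

import numpy as np

t = 2**40                      # plaintext modulus (illustrative)
s = 2**16                      # scaling factor (illustrative)
bound = (t / 2) / 2 / s        # t_half / 2 / s in the code above == 4194304.0

grads = np.array([1.0, 100.0, 5.0e6])
over_by = np.abs(grads) - bound
print(over_by.max() > 0)       # True: the 5.0e6 entry may have overflowed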