More accurate upper bound on gradient l2 norm.
james-choncholas committed Jan 17, 2025
1 parent b41afbc commit 00a3450
Showing 1 changed file with 32 additions and 8 deletions.
tf_shell_ml/model_base.py (32 additions, 8 deletions)
@@ -487,6 +487,9 @@ def jacobian_max_two_norm(self, jacobians):

batch_size = jacobians[0].shape[0]
num_output_classes = jacobians[0].shape[1]
+
+ # Sum of squares accumulates the l2 norm squared of the per-example,
+ # per-output class gradients.
sum_of_squares = tf.zeros(
[batch_size, num_output_classes], dtype=tf.keras.backend.floatx()
)
@@ -496,9 +499,18 @@ def jacobian_max_two_norm(self, jacobians):
# recover just the number of dimensions in the weights.
num_weight_dims = len(j.shape) - 2
reduce_sum_dims = range(2, 2 + num_weight_dims)
- sum_of_squares += tf.reduce_sum(j * j, axis=reduce_sum_dims)
+ sum_of_squares += tf.reduce_sum(tf.square(j), axis=reduce_sum_dims)
+
+ # Find the largest two l2 norms over the output classes.
+ top_2_l2s = tf.sqrt(tf.math.top_k(sum_of_squares, k=2).values)
+
+ # Compute their sum, i.e. ||v|| + ||w|| where v and w are the two
+ # largest per-class gradients by l2 norm, for every per-example
+ # gradient.
+ max_two_norm = tf.reduce_sum(top_2_l2s, axis=1)

- return tf.sqrt(tf.reduce_max(sum_of_squares))
+ # Return the maximum over the examples.
+ return tf.reduce_max(max_two_norm)

def flatten_and_pad_grad(self, grad, dim_0_sz):
num_el = tf.cast(tf.size(grad), dtype=tf.int64)
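The change above tightens the bound returned by jacobian_max_two_norm: instead of the single largest per-class norm, it now sums the two largest per-class l2 norms for each example and takes the maximum over the batch (the separate sqrt(2) factor applied later in compute_noise_factors is removed in the same commit). One plausible reading, not stated in the commit, is that each per-example gradient is a weighted combination of the per-class jacobian rows, so the two largest row norms bound the worst-case combination. The standalone sketch below is an illustration, not part of the commit; the toy jacobian shape and variable names are assumptions.

import tensorflow as tf

# Toy jacobian: [batch, output classes, weight dim 1, weight dim 2].
batch_size, num_output_classes = 4, 3
jacobian = tf.random.normal([batch_size, num_output_classes, 5, 2])

# Per-example, per-class squared l2 norms over the weight dimensions.
sum_of_squares = tf.reduce_sum(tf.square(jacobian), axis=[2, 3])

# Two largest per-class l2 norms for each example, i.e. ||v|| and ||w||.
top_2_l2s = tf.sqrt(tf.math.top_k(sum_of_squares, k=2).values)

# ||v|| + ||w|| per example, then the maximum over the batch.
max_two_norm = tf.reduce_max(tf.reduce_sum(top_2_l2s, axis=1))
print(float(max_two_norm))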
@@ -684,12 +696,7 @@ def compute_noise_factors(
tf.errors.InvalidArgumentError: If the bounded sensitivity exceeds
the maximum noise scale defined in `self.dg_params`.
"""

# The sensitivity of the PostScale protocol is the maximum L2 norm of
# the per-example gradients (computed with tf.jacobian) multiplied by
# sqrt(2).
sensitivity = tf.cast(max_two_norm, dtype=tf.float32) * tf.sqrt(2.0)

sensitivity = tf.cast(max_two_norm, dtype=tf.float32)
sensitivity *= self.noise_multiplier

# Ensure the sensitivity is at least as large as the base scale, since
@@ -708,6 +715,15 @@
# same amount.
scaled_sensitivity = bounded_sensitivity * sf

+ # The noise protocol generates samples in the range [-6s, 6s], where
+ # s is the scale. Ensure the maximum noise is less than the
+ # plaintext modulus.
+ tf.assert_less(
+ scaled_sensitivity * 6,
+ tf.cast(context.plaintext_modulus, tf.float32),
+ message="Sensitivity is too large for the noise context's plaintext modulus and WILL overflow. Reduce the sensitivity by reducing the gradient norms (e.g. reducing the backprop context's scaling factor), or increase the noise context's plaintext modulus.",
+ )
+
tf.assert_less(
scaled_sensitivity,
self.dg_params.max_scale,
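The comment above pins the assumption behind the factor of 6: noise samples are taken to lie in [-6s, 6s] for scale s, so the worst-case noise magnitude is 6 * scaled_sensitivity and must stay below the noise context's plaintext modulus. A minimal sketch of that arithmetic follows, with made-up values standing in for the real context; it is not the library's code.

import tensorflow as tf

scaled_sensitivity = tf.constant(1.5e4, dtype=tf.float32)  # made-up value
plaintext_modulus = tf.constant(2.0e5, dtype=tf.float32)   # made-up value

# Worst-case noise magnitude under the [-6s, 6s] assumption.
worst_case_noise = 6.0 * scaled_sensitivity

# Mirrors the added assert: overflow would wrap values around the modulus.
tf.debugging.assert_less(worst_case_noise, plaintext_modulus)
print("headroom:", float(plaintext_modulus - worst_case_noise))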
@@ -848,6 +864,14 @@ def shell_train_step(self, features, labels, read_key_from_cache, apply_gradient
else:
backprop_scaling_factors = [g._scaling_factor for g in grads]

+ # Simple correctness test to check if the gradients are smaller
+ # than the plaintext modulus.
+ tf.assert_less(
+ max_two_norm,
+ tf.cast(backprop_context.plaintext_modulus, tf.float32),
+ message="Gradient may be too large for the backprop context's plaintext modulus. Reduce the sensitivity by reducing the gradient norms (e.g. reducing the backprop scaling factor), or increase the backprop context's plaintext modulus.",
+ )
+
# Check if the backproped gradients overflowed.
if not self.disable_encryption and self.check_overflow_INSECURE:
# Note, checking the gradients requires decryption on the
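The assert added in this hunk is a cheap plaintext-side check that the gradient bound fits within the backprop context's plaintext modulus, in contrast to the decryption-based overflow check guarded by check_overflow_INSECURE. A small hedged sketch with made-up numbers, not library code:

import tensorflow as tf

max_two_norm = tf.constant(3.2e3, dtype=tf.float32)                # made-up bound
backprop_plaintext_modulus = tf.constant(65537.0, dtype=tf.float32)  # made-up modulus

# Mirrors the added assert: the bound on the per-example gradient norm must
# be representable under the backprop context's plaintext modulus.
tf.debugging.assert_less(max_two_norm, backprop_plaintext_modulus)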