Fix sensitivity analysis for DP-SGD.
To perform pt*ct operations, tf-shell encodes the pt at the ct's original scaling factor. The original code assumed it was encoded at the ct's current scaling factor.
james-choncholas committed Jan 20, 2025
1 parent 137f9cd commit b97a62d
Showing 14 changed files with 49 additions and 76 deletions.
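
For context on the scaling-factor argument in the commit message, below is a minimal, self-contained fixed-point sketch. It is illustrative only: `encode` and `decode` are hypothetical helpers, not tf-shell functions, and the numbers are made up. It shows why a plaintext encoded at the ciphertext's original scaling factor has a quantization granularity that does not depend on how many multiplications the ciphertext has already absorbed, which is consistent with the layers below no longer squaring and propagating `sensitivity_analysis_factor` on each backward pass.

```python
# Illustrative toy model only -- plain fixed-point arithmetic, not the
# tf-shell API. A value is stored as round(value * scale).
def encode(value, scale):
    """Quantize a real value at the given scaling factor."""
    return round(value * scale)

def decode(quantized, scale):
    """Recover the real value from its quantized representation."""
    return quantized / scale

s = 2.0**12        # a ciphertext's original (per-encoding) scaling factor
s_current = s * s  # its accumulated scale after one prior multiplication

pt = 0.3
# Encoding the plaintext at the ORIGINAL factor s bounds its quantization
# error by 1/(2*s), regardless of the ciphertext's current depth ...
err_at_original = abs(decode(encode(pt, s), s) - pt)
# ... whereas encoding at the CURRENT factor s**2 (the assumption the
# commit message says the old analysis made) would give a bound that
# shrinks with each multiplication, hence the per-layer squaring that
# this commit removes.
err_at_current = abs(decode(encode(pt, s_current), s_current) - pt)

print(f"error bound at original scale: {err_at_original:.2e} (<= {1/(2*s):.2e})")
print(f"error bound at current  scale: {err_at_current:.2e} (<= {1/(2*s_current):.2e})")
```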
27 changes: 10 additions & 17 deletions tf_shell_ml/conv2d.py
@@ -155,25 +155,26 @@ def call(self, inputs, training=False, split_forward_mode=False):

def backward(self, dy, rotation_key=None, sensitivity_analysis_factor=None):
"""Compute the gradient."""
kernel = tf.identity(self.weights[0])

if sensitivity_analysis_factor is not None:
# When performing sensitivity analysis, use the most recent
# intermediate state.
x = self._layer_input[-1]
z = self._layer_intermediate[-1]

# To perform sensitivity analysis, assume the worst case rounding for
# the intermediate state, dictated by the sensitivity_analysis_factor.
x = tf_shell.worst_case_rounding(x, sensitivity_analysis_factor)
z = tf_shell.worst_case_rounding(z, sensitivity_analysis_factor)
kernel = tf_shell.worst_case_rounding(kernel, sensitivity_analysis_factor)
else:
x = tf.concat([tf.identity(x) for x in self._layer_input], axis=0)
z = tf.concat([tf.identity(z) for z in self._layer_intermediate], axis=0)
kernel = tf.identity(self.weights[0])

grad_weights = []
batch_size = tf.shape(x)[0]

# To perform sensitivity analysis, assume the worst case rounding for
# the intermediate state, dictated by the sensitivity_analysis_factor.
if sensitivity_analysis_factor is not None:
x = tf_shell.worst_case_rounding(x, sensitivity_analysis_factor)
z = tf_shell.worst_case_rounding(z, sensitivity_analysis_factor)
kernel = tf_shell.worst_case_rounding(kernel, sensitivity_analysis_factor)

# On the forward pass, x may be batched differently than the
# ciphertext scheme. Pad them to match the ciphertext scheme.
if isinstance(dy, tf_shell.ShellTensor64):
@@ -224,15 +225,7 @@ def backward(self, dy, rotation_key=None, sensitivity_analysis_factor=None):

grad_weights.append(d_w)

# Both dw an dx have a mul depth of 1 in this function. Thus, the
# sensitivity factor is squared.
new_sensitivity_analysis_factor = (
sensitivity_analysis_factor**2
if sensitivity_analysis_factor is not None
else None
)

return grad_weights, d_x, new_sensitivity_analysis_factor
return grad_weights, d_x

@staticmethod
def unpack(plaintext_packed_dx):
28 changes: 11 additions & 17 deletions tf_shell_ml/dense.py
@@ -119,23 +119,25 @@ def call(self, inputs, training=False, split_forward_mode=False):

def backward(self, dy, rotation_key=None, sensitivity_analysis_factor=None):
"""dense backward"""
kernel = tf.identity(self.weights[0])

if sensitivity_analysis_factor is not None:
# When performing sensitivity analysis, use the most recent
# intermediate state.
x = self._layer_input[-1]
z = self._layer_intermediate[-1]

# To perform sensitivity analysis, assume the worst case rounding
# for the intermediate state, dictated by the
# sensitivity_analysis_factor.
x = tf_shell.worst_case_rounding(x, sensitivity_analysis_factor)
z = tf_shell.worst_case_rounding(z, sensitivity_analysis_factor)
kernel = tf_shell.worst_case_rounding(kernel, sensitivity_analysis_factor)
else:
x = tf.concat([tf.identity(x) for x in self._layer_input], axis=0)
z = tf.concat([tf.identity(z) for z in self._layer_intermediate], axis=0)
kernel = tf.identity(self.weights[0])
d_ws = []

# To perform sensitivity analysis, assume the worst case rounding for
# the intermediate state, dictated by the sensitivity_analysis_factor.
if sensitivity_analysis_factor is not None:
x = tf_shell.worst_case_rounding(x, sensitivity_analysis_factor)
x = tf_shell.worst_case_rounding(x, sensitivity_analysis_factor)
kernel = tf_shell.worst_case_rounding(kernel, sensitivity_analysis_factor)
d_ws = []

# On the forward pass, inputs may be batched differently than the
# ciphertext scheme when not in eager mode. Pad them to match the
@@ -176,15 +178,7 @@ def backward(self, dy, rotation_key=None, sensitivity_analysis_factor=None):

d_ws.append(d_bias)

# Both dw an dx have a mul depth of 1 in this function. Thus, the
# sensitivity factor is squared.
new_sensitivity_analysis_factor = (
sensitivity_analysis_factor**2
if sensitivity_analysis_factor is not None
else None
)

return d_ws, d_x, new_sensitivity_analysis_factor
return d_ws, d_x

@staticmethod
def unpack(plaintext_packed_dx):
6 changes: 2 additions & 4 deletions tf_shell_ml/dpsgd_sequential_model.py
@@ -56,14 +56,12 @@ def _backward(self, dJ_dz, sensitivity_analysis_factor=None):
# last layer pre-activation.
dJ_dw = [] # Derivatives of the loss with respect to the weights.
dJ_dx = [dJ_dz] # Derivatives of the loss with respect to the inputs.
sa_factor = sensitivity_analysis_factor
for l in reversed(self.layers):
dw, dx, new_sa_factor = l.backward(
dJ_dx[-1], sensitivity_analysis_factor=sa_factor
dw, dx = l.backward(
dJ_dx[-1], sensitivity_analysis_factor=sensitivity_analysis_factor
)
dJ_dw.extend(dw)
dJ_dx.append(dx)
sa_factor = new_sa_factor

return [g for g in reversed(dJ_dw)]

18 changes: 9 additions & 9 deletions tf_shell_ml/dropout.py
@@ -87,18 +87,18 @@ def backward(self, dy, rotation_key=None, sensitivity_analysis_factor=None):
# When performing sensitivity analysis, use the most recent
# intermediate state.
dropout_mask = self._layer_intermediate[-1]

# To perform sensitivity analysis, assume the worst case rounding
# for the intermediate state, dictated by the
# sensitivity_analysis_factor.
# Note: The dropout mask is not necessarily 0 or 1.
dropout_mask = tf_shell.worst_case_rounding(
dropout_mask, sensitivity_analysis_factor
)
else:
dropout_mask = tf.concat(
[tf.identity(z) for z in self._layer_intermediate], axis=0
)

# Both dx has a mul depth of 1 in this function. Thus, the sensitivity
# factor is squared.
new_sensitivity_analysis_factor = (
sensitivity_analysis_factor**2
if sensitivity_analysis_factor is not None
else None
)

d_x = dy * dropout_mask
return [], d_x, new_sensitivity_analysis_factor
return [], d_x
5 changes: 1 addition & 4 deletions tf_shell_ml/embedding.py
@@ -141,10 +141,7 @@ def backward(self, dy, rotation_key=None, sensitivity_analysis_factor=None):
reduction=self.grad_reduction,
)

# The output of this function has the same scaling factor as the input.
new_sensitivity_analysis_factor = sensitivity_analysis_factor

return [summedvalues], tf.zeros(0), new_sensitivity_analysis_factor
return [summedvalues], None

@staticmethod
def unpack(plaintext_packed_dx):
5 changes: 1 addition & 4 deletions tf_shell_ml/flatten.py
@@ -47,7 +47,4 @@ def backward(self, dy, rotation_key=None, sensitivity_analysis_factor=None):
dw = []
dx = tf_shell.reshape(dy, new_shape)

# The output of this function has the same scaling factor as the input.
new_sensitivity_analysis_factor = sensitivity_analysis_factor

return dw, dx, new_sensitivity_analysis_factor
return dw, dx
5 changes: 1 addition & 4 deletions tf_shell_ml/globalaveragepool1d.py
@@ -45,10 +45,7 @@ def backward(self, dy, rotation_key=None, sensitivity_analysis_factor=None):
dx, (tf_shell.shape(dx)[0], avg_dim, tf_shell.shape(dx)[2])
)

# The output of this function has the same scaling factor as the input.
new_sensitivity_analysis_factor = sensitivity_analysis_factor

return [], dx, new_sensitivity_analysis_factor
return [], dx

@staticmethod
def unpack(packed_dw):
5 changes: 1 addition & 4 deletions tf_shell_ml/max_pool2d.py
@@ -132,10 +132,7 @@ def backward(self, dy, rotation_key=None, sensitivity_analysis_factor=None):
output_shape=input_shape,
)

# The output of this function has the same scaling factor as the input.
new_sensitivity_analysis_factor = sensitivity_analysis_factor

return grad_weights, d_x, new_sensitivity_analysis_factor
return grad_weights, d_x

def unpacking_funcs(self):
return []
6 changes: 3 additions & 3 deletions tf_shell_ml/test/conv2d_test.py
@@ -76,7 +76,7 @@ def _test_conv2d_plaintext_forward_backward_correct(
self.assertAllClose(y, tf_y)

# Next check backward pass.
dws, dx, _ = conv_layer.backward(tf.ones_like(y), rotation_key)
dws, dx = conv_layer.backward(tf.ones_like(y), rotation_key)
with tf.GradientTape(persistent=True) as tape:
tape.watch(im)
y = tf_conv_layer(im)
@@ -130,13 +130,13 @@ def forward_backward(x):
enc_dy = tf_shell.to_encrypted(dy, key, context)

# Encrypted backward pass.
enc_dw, enc_dx, _ = conv_layer.backward(enc_dy, rotation_key)
enc_dw, enc_dx = conv_layer.backward(enc_dy, rotation_key)
dw = tf_shell.to_tensorflow(enc_dw[0], key)
# dw = conv_layer.unpack(dw) # for layer reduction 'fast' or 'galois'
dx = tf_shell.to_tensorflow(enc_dx, key)

# Plaintext backward pass.
pt_dws, pt_dx, _ = conv_layer.backward(dy, None)
pt_dws, pt_dx = conv_layer.backward(dy, None)
pt_dw = pt_dws[0] # No unpack required for pt.

return dw, dx, dw.shape, dx.shape, pt_dw, pt_dx
4 changes: 2 additions & 2 deletions tf_shell_ml/test/dropout_test.py
@@ -70,10 +70,10 @@ def _test_dropout_back(self, per_batch):
notrain_y = dropout_layer.call(x, training=True)
dy = tf.ones_like(notrain_y)

dw, dx, _ = dropout_layer.backward(dy, None)
dw, dx = dropout_layer.backward(dy)

enc_dy = tf_shell.to_encrypted(dy, key, context)
enc_dw, enc_dx, _ = dropout_layer.backward(enc_dy, None)
enc_dw, enc_dx = dropout_layer.backward(enc_dy)
dec_dx = tf_shell.to_tensorflow(enc_dx, key)

self.assertEmpty(dw)
2 changes: 1 addition & 1 deletion tf_shell_ml/test/embedding_test.py
@@ -83,7 +83,7 @@ def forward_backward(x):
dy = tf.ones_like(y)
enc_dy = tf_shell.to_encrypted(dy, key, context)

enc_dw, _, _ = embedding_layer.backward(enc_dy, rotation_key)
enc_dw, _ = embedding_layer.backward(enc_dy, rotation_key)
dw = tf_shell.to_tensorflow(enc_dw[0], key)
if reduction == "none":
dw = tf.reduce_sum(dw, axis=0)
6 changes: 3 additions & 3 deletions tf_shell_ml/test/max_pool2d_test.py
@@ -66,7 +66,7 @@ def _test_max_pool2d_plaintext_forward_backward_correct(
self.assertAllClose(y, tf_y)

# Next check backward pass.
_, dx, _ = layer.backward(tf.ones_like(y), rotation_key)
_, dx = layer.backward(tf.ones_like(y), rotation_key)
with tf.GradientTape(persistent=True) as tape:
tape.watch(im)
y = tf_layer(im)
@@ -107,11 +107,11 @@ def forward_backward(x):
enc_dy = tf_shell.to_encrypted(dy, key, context)

# Encrypted backward pass.
_, enc_dx, _ = layer.backward(enc_dy, rotation_key)
_, enc_dx = layer.backward(enc_dy, rotation_key)
dx = tf_shell.to_tensorflow(enc_dx, key)

# Plaintext backward pass.
_, pt_dx, _ = layer.backward(dy, None)
_, pt_dx = layer.backward(dy)

return dx, dx.shape, pt_dx

4 changes: 2 additions & 2 deletions tf_shell_ml/test/mnist_enc_backprop_test.py
@@ -70,13 +70,13 @@ def train_step(x, y, hidden_layer, output_layer):
# Backward pass.
dJ_dy_pred = y.__rsub__(y_pred) # Derivative of CCE loss and softmax.

dJ_dw1, dJ_dx1, _ = output_layer.backward(dJ_dy_pred, rotation_key)
dJ_dw1, dJ_dx1 = output_layer.backward(dJ_dy_pred, rotation_key)

# Mod reduce will reduce noise but increase the plaintext error.
# if isinstance(dJ_dx1, tf_shell.ShellTensor64):
# dJ_dx1.get_mod_reduced()

dJ_dw0, dJ_dx0_unused, _ = hidden_layer.backward(dJ_dx1, rotation_key)
dJ_dw0, dJ_dx0_unused = hidden_layer.backward(dJ_dx1, rotation_key)

# Only return the weight gradients at [0], not the bias gradients at [1].
return dJ_dw1[0], dJ_dw0[0]
4 changes: 2 additions & 2 deletions tf_shell_ml/test/mnist_noenc_backprop_test.py
@@ -72,9 +72,9 @@ def train_step(x, y):
# Backward pass.
dJ_dy_pred = y.__rsub__(y_pred) # Derivative of CCE loss and softmax.

dJ_dw1, dJ_dx1, _ = output_layer.backward(dJ_dy_pred, None)
dJ_dw1, dJ_dx1 = output_layer.backward(dJ_dy_pred)

dJ_dw0, dJ_dx0_unused, _ = hidden_layer.backward(dJ_dx1, None)
dJ_dw0, dJ_dx0_unused = hidden_layer.backward(dJ_dx1)

return dJ_dw1, dJ_dw0

