Upper bound on grad l2 norm takes quantization error into account.
Note: This comes at the cost of higher memory requirements, especially for the PostScaleSequential model.
james-choncholas committed Jan 18, 2025
1 parent 00a3450 commit 092ff42
Showing 10 changed files with 262 additions and 150 deletions.
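
The change bounds the per-example gradient l2 norm by backpropagating against a worst-case quantized copy of the model state: each plaintext value is shifted one quantization step (1 / scaling_factor) away from zero before the backward pass, so encoding error is folded into the bound. A minimal standalone sketch of that perturbation; the names worst_case_quantize, values, and scaling_factor are illustrative and not taken from the repository:

import tensorflow as tf

def worst_case_quantize(values, scaling_factor):
    # Push each value one quantization step (1 / scaling_factor) away
    # from zero, emulating the largest possible fixed-point encoding error.
    return values + tf.sign(values) / scaling_factor

This is the same pattern as the x + (tf.sign(x) / sensitivity_analysis_factor) lines added to each layer's backward() below.
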
20 changes: 18 additions & 2 deletions tf_shell_ml/conv2d.py
@@ -147,14 +147,22 @@ def call(self, inputs, training=False, split_forward_mode=False):

return outputs

def backward(self, dy, rotation_key=None):
def backward(self, dy, rotation_key=None, sensitivity_analysis_factor=None):
"""Compute the gradient."""
x = tf.concat([tf.identity(x) for x in self._layer_input], axis=0)
z = tf.concat([tf.identity(z) for z in self._layer_intermediate], axis=0)
kernel = tf.identity(self.weights[0])
grad_weights = []
batch_size = tf.shape(x)[0] // 2

# To perform sensitivity analysis, add a small perturbation to the
# intermediate state, dictated by the sensitivity_analysis_factor.
# The perturbation has the same sign as the underlying value.
if sensitivity_analysis_factor is not None:
x = x + (tf.sign(x) / sensitivity_analysis_factor)
z = z + (tf.sign(z) / sensitivity_analysis_factor)
kernel = kernel + (tf.sign(kernel) / sensitivity_analysis_factor)

# On the forward pass, x may be batched differently than the
# ciphertext scheme. Pad them to match the ciphertext scheme.
if isinstance(dy, tf_shell.ShellTensor64):
@@ -204,7 +212,15 @@ def backward(self, dy, rotation_key=None):

grad_weights.append(d_w)

return grad_weights, d_x
# Both dw and dx have a mul depth of 1 in this function. Thus, the
# sensitivity factor is squared.
new_sensitivity_analysis_factor = (
sensitivity_analysis_factor**2
if sensitivity_analysis_factor is not None
else None
)

return grad_weights, d_x, new_sensitivity_analysis_factor

@staticmethod
def unpack(plaintext_packed_dx):
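
The "sensitivity factor is squared" comments reflect fixed-point arithmetic: a value encoded at scaling factor f is stored as an integer near v * f, and the product of two such encodings is carried at scale f * f. A small self-contained illustration of that scale bookkeeping (plain Python with assumed encode/decode semantics, not the tf-shell API):

def encode(v, f):
    # Fixed-point encode: represent v as an integer at scale f.
    return round(v * f)

def decode(e, f):
    # Decode an integer representation back to a float at scale f.
    return e / f

f = 2 ** 12
a, b = 1.37, -0.25
prod = encode(a, f) * encode(b, f)   # the integer product carries scale f * f
print(decode(prod, f * f))           # ~= a * b, quantized with step 1 / f**2

Since dw and dx in the conv2d and dense backward passes each involve one multiplication (mul depth 1), the factor handed to the previous layer is sensitivity_analysis_factor**2.
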
21 changes: 18 additions & 3 deletions tf_shell_ml/dense.py
@@ -111,14 +111,21 @@ def call(self, inputs, training=False, split_forward_mode=False):

return outputs

def backward(self, dy, rotation_key=None):
def backward(self, dy, rotation_key=None, sensitivity_analysis_factor=None):
"""dense backward"""
x = tf.concat([tf.identity(x) for x in self._layer_input], axis=0)
z = tf.concat([tf.identity(z) for z in self._layer_intermediate], axis=0)
self._layer_intermediate = []
kernel = tf.identity(self.weights[0])
d_ws = []

# To perform sensitivity analysis, add a small perturbation to the
# intermediate state, dictated by the sensitivity_analysis_factor.
# The perturbation has the same sign as the underlying value.
if sensitivity_analysis_factor is not None:
x = x + (tf.sign(x) / sensitivity_analysis_factor)
z = z + (tf.sign(z) / sensitivity_analysis_factor)
kernel = kernel + (tf.sign(kernel) / sensitivity_analysis_factor)

# On the forward pass, inputs may be batched differently than the
# ciphertext scheme when not in eager mode. Pad them to match the
# ciphertext scheme.
@@ -158,7 +165,15 @@ def backward(self, dy, rotation_key=None):

d_ws.append(d_bias)

return d_ws, d_x
# Both dw and dx have a mul depth of 1 in this function. Thus, the
# sensitivity factor is squared.
new_sensitivity_analysis_factor = (
sensitivity_analysis_factor**2
if sensitivity_analysis_factor is not None
else None
)

return d_ws, d_x, new_sensitivity_analysis_factor

@staticmethod
def unpack(plaintext_packed_dx):
77 changes: 62 additions & 15 deletions tf_shell_ml/dpsgd_sequential_model.py
@@ -51,7 +51,27 @@ def call(self, features, training=False, with_softmax=True):
# purposes.
return tf.nn.softmax(predictions)

def _backward(self, dJ_dz, sensitivity_analysis_factor=None):
# Backward pass. dJ_dz is the derivative of the loss with respect to the
# last layer pre-activation.
dJ_dw = [] # Derivatives of the loss with respect to the weights.
dJ_dx = [dJ_dz] # Derivatives of the loss with respect to the inputs.
sa_factor = sensitivity_analysis_factor
for l in reversed(self.layers):
dw, dx, new_sa_factor = l.backward(
dJ_dx[-1], sensitivity_analysis_factor=sa_factor
)
dJ_dw.extend(dw)
dJ_dx.append(dx)
sa_factor = new_sa_factor

return [g for g in reversed(dJ_dw)]

def compute_grads(self, features, enc_labels):
scaling_factor = (
enc_labels.scaling_factor if hasattr(enc_labels, "scaling_factor") else 1
)

# Reset layers for forward pass over multiple devices.
for l in self.layers:
l.reset_split_forward_mode()
@@ -67,29 +87,56 @@ def compute_grads(self, features, enc_labels):
for i, d in enumerate(self.jacobian_devices):
with tf.device(d):
f = tf.identity(split_features[i]) # copy to GPU if needed
prediction, jacobians = self.predict_and_jacobian(f)

# First compute the real prediction.
prediction = self.call(f, training=True, with_softmax=True)

if i == len(self.jacobian_devices) - 1 and end_pad > 0:
# The last device's features may have been padded for even
# split jacobian computation across multiple devices.
# The last device's features may have been padded.
prediction = prediction[:-end_pad]
jacobians = jacobians[:-end_pad]
predictions_list.append(prediction)
max_two_norms_list.append(self.jacobian_max_two_norm(jacobians))

# Next perform the sensitivity analysis. Straightforward
# backpropagation has mul/add depth proportional to the number
# of layers and the encoding error accumulates through each op.
# It is difficult to tightly bound and easier to simply compute.
#
# Perform backpropagation (in plaintext) for every possible
# label using the worst-case quantization of weights.
sensitivity = tf.constant(0.0, dtype=tf.keras.backend.floatx())
worst_case_prediction = prediction + (
tf.sign(prediction) / scaling_factor
)

for possible_label_i in range(self.out_classes):
possible_label = tf.one_hot(
possible_label_i,
self.out_classes,
dtype=tf.keras.backend.floatx(),
)
dJ_dz = (
worst_case_prediction - possible_label
) # Broadcasts to batch subset.
possible_grads = self._backward(
dJ_dz, sensitivity_analysis_factor=scaling_factor
)

max_norm = self.max_per_example_global_norm(possible_grads)
sensitivity = tf.maximum(sensitivity, max_norm)

max_two_norms_list.append(sensitivity)

with tf.device(self.features_party_dev):
predictions = tf.concat(predictions_list, axis=0)
max_two_norm = tf.reduce_max(max_two_norms_list)

dJ_dz = enc_labels.__rsub__(
predictions
) # Derivative of CCE loss and softmax.

# Backward pass.
dx = enc_labels.__rsub__(predictions) # Derivative of CCE loss and softmax.
dJ_dw = [] # Derivatives of the loss with respect to the weights.
dJ_dx = [dx] # Derivatives of the loss with respect to the inputs.
for l in reversed(self.layers):
dw, dx = l.backward(dJ_dx[-1])
dJ_dw.extend(dw)
dJ_dx.append(dx)

# Reverse the order to match the order of the layers.
grads = [g for g in reversed(dJ_dw)]
grads = self._backward(dJ_dz)

return grads, max_two_norm, predictions
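
The sensitivity loop above calls max_per_example_global_norm, which is not defined in this diff. Assuming each gradient tensor keeps a leading per-example batch dimension, a hypothetical sketch of such a helper could look like the following; the implementation is an assumption, not the repository's code:

import tensorflow as tf

def max_per_example_global_norm(grads):
    # grads: list of dense tensors, each shaped [batch, ...].
    per_example_sq = None
    for g in grads:
        flat = tf.reshape(g, [tf.shape(g)[0], -1])
        sq = tf.reduce_sum(flat * flat, axis=1)  # squared l2 norm per example
        per_example_sq = sq if per_example_sq is None else per_example_sq + sq
    # Global norm per example, then the worst case over the batch.
    return tf.sqrt(tf.reduce_max(per_example_sq))

Maximizing this quantity over every possible label and the worst-case quantized prediction yields the upper bound that compute_grads returns as max_two_norm.
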
13 changes: 11 additions & 2 deletions tf_shell_ml/dropout.py
@@ -82,9 +82,18 @@ def call(self, inputs, training=False, split_forward_mode=False):
output = inputs * dropout_mask
return output

def backward(self, dy, rotation_key=None):
def backward(self, dy, rotation_key=None, sensitivity_analysis_factor=None):
dropout_mask = tf.concat(
[tf.identity(z) for z in self._layer_intermediate], axis=0
)

# dx has a mul depth of 1 in this function. Thus, the sensitivity
# factor is squared.
new_sensitivity_analysis_factor = (
sensitivity_analysis_factor**2
if sensitivity_analysis_factor is not None
else None
)

d_x = dy * dropout_mask
return [], d_x
return [], d_x, new_sensitivity_analysis_factor
5 changes: 4 additions & 1 deletion tf_shell_ml/embedding.py
@@ -134,7 +134,10 @@ def backward(self, dy, rotation_key=None):
reduction=self.grad_reduction,
)

return [summedvalues], tf.zeros(0)
# The output of this function has the same scaling factor as the input.
new_sensitivity_analysis_factor = sensitivity_analysis_factor

return [summedvalues], tf.zeros(0), new_sensitivity_analysis_factor

@staticmethod
def unpack(plaintext_packed_dx):
8 changes: 6 additions & 2 deletions tf_shell_ml/flatten.py
@@ -35,7 +35,7 @@ def call(self, inputs, training=False, split_forward_mode=False):
self.batch_size = tf.shape(inputs)[0]
return tf.reshape(inputs, [self.batch_size] + self.flat_shape)

def backward(self, dy, rotation_key=None):
def backward(self, dy, rotation_key=None, sensitivity_analysis_factor=None):
new_shape = list(self.input_shape)
# On the forward pass, inputs may be batched differently than the
# ciphertext scheme when not in eager mode. Pad them to match the
@@ -46,4 +46,8 @@ def backward(self, dy, rotation_key=None):
new_shape[0] = dy.shape[0]
dw = []
dx = tf_shell.reshape(dy, new_shape)
return dw, dx

# The output of this function has the same scaling factor as the input.
new_sensitivity_analysis_factor = sensitivity_analysis_factor

return dw, dx, new_sensitivity_analysis_factor
8 changes: 6 additions & 2 deletions tf_shell_ml/globalaveragepool1d.py
@@ -38,13 +38,17 @@ def call(self, inputs, training=False, split_forward_mode=False):

return outputs

def backward(self, dy, rotation_key=None):
def backward(self, dy, rotation_key=None, sensitivity_analysis_factor=None):
avg_dim = tf.identity(self._layer_intermediate)
dx = tf_shell.expand_dims(dy, axis=1)
dx = tf_shell.broadcast_to(
dx, (tf_shell.shape(dx)[0], avg_dim, tf_shell.shape(dx)[2])
)
return [], dx

# The output of this function has the same scaling factor as the input.
new_sensitivity_analysis_factor = sensitivity_analysis_factor

return [], dx, new_sensitivity_analysis_factor

@staticmethod
def unpack(packed_dw):
7 changes: 5 additions & 2 deletions tf_shell_ml/max_pool2d.py
@@ -88,7 +88,7 @@ def call(self, inputs, training=False, split_forward_mode=False):

return outputs

def backward(self, dy, rotation_key=None):
def backward(self, dy, rotation_key=None, sensitivity_analysis_factor=None):
"""Compute the gradient."""
indices = tf.concat([tf.identity(z) for z in self._layer_intermediate], axis=0)
grad_weights = []
@@ -114,7 +114,10 @@ def backward(self, dy, rotation_key=None):
output_shape=self._layer_input_shape,
)

return grad_weights, d_x
# The output of this function has the same scaling factor as the input.
new_sensitivity_analysis_factor = sensitivity_analysis_factor

return grad_weights, d_x, new_sensitivity_analysis_factor

def unpacking_funcs(self):
return []