diff --git a/tf_shell_ml/conv2d.py b/tf_shell_ml/conv2d.py
index 1a9bf5a..a8cb932 100644
--- a/tf_shell_ml/conv2d.py
+++ b/tf_shell_ml/conv2d.py
@@ -147,7 +147,7 @@ def call(self, inputs, training=False, split_forward_mode=False):

         return outputs

-    def backward(self, dy, rotation_key=None):
+    def backward(self, dy, rotation_key=None, sensitivity_analysis_factor=None):
         """Compute the gradient."""
         x = tf.concat([tf.identity(x) for x in self._layer_input], axis=0)
         z = tf.concat([tf.identity(z) for z in self._layer_intermediate], axis=0)
@@ -155,6 +155,14 @@ def backward(self, dy, rotation_key=None):
         grad_weights = []
         batch_size = tf.shape(x)[0] // 2

+        # To perform sensitivity analysis, add a small perturbation to the
+        # intermediate state, dictated by the sensitivity_analysis_factor.
+        # The perturbation has the same sign as the underlying value.
+        if sensitivity_analysis_factor is not None:
+            x = x + (tf.sign(x) / sensitivity_analysis_factor)
+            z = z + (tf.sign(z) / sensitivity_analysis_factor)
+            kernel = kernel + (tf.sign(kernel) / sensitivity_analysis_factor)
+
         # On the forward pass, x may be batched differently than the
         # ciphertext scheme. Pad them to match the ciphertext scheme.
         if isinstance(dy, tf_shell.ShellTensor64):
@@ -204,7 +212,15 @@ def backward(self, dy, rotation_key=None):

         grad_weights.append(d_w)

-        return grad_weights, d_x
+        # Both dw and dx have a mul depth of 1 in this function. Thus, the
+        # sensitivity factor is squared.
+        new_sensitivity_analysis_factor = (
+            sensitivity_analysis_factor**2
+            if sensitivity_analysis_factor is not None
+            else None
+        )
+
+        return grad_weights, d_x, new_sensitivity_analysis_factor

     @staticmethod
     def unpack(plaintext_packed_dx):
diff --git a/tf_shell_ml/dense.py b/tf_shell_ml/dense.py
index 3b817ed..4472a91 100644
--- a/tf_shell_ml/dense.py
+++ b/tf_shell_ml/dense.py
@@ -111,14 +111,21 @@ def call(self, inputs, training=False, split_forward_mode=False):

         return outputs

-    def backward(self, dy, rotation_key=None):
+    def backward(self, dy, rotation_key=None, sensitivity_analysis_factor=None):
         """dense backward"""
         x = tf.concat([tf.identity(x) for x in self._layer_input], axis=0)
         z = tf.concat([tf.identity(z) for z in self._layer_intermediate], axis=0)
-        self._layer_intermediate = []
         kernel = tf.identity(self.weights[0])
         d_ws = []

+        # To perform sensitivity analysis, add a small perturbation to the
+        # intermediate state, dictated by the sensitivity_analysis_factor.
+        # The perturbation has the same sign as the underlying value.
+        if sensitivity_analysis_factor is not None:
+            x = x + (tf.sign(x) / sensitivity_analysis_factor)
+            z = z + (tf.sign(z) / sensitivity_analysis_factor)
+            kernel = kernel + (tf.sign(kernel) / sensitivity_analysis_factor)
+
         # On the forward pass, inputs may be batched differently than the
         # ciphertext scheme when not in eager mode. Pad them to match the
         # ciphertext scheme.
@@ -158,7 +165,15 @@ def backward(self, dy, rotation_key=None):

         d_ws.append(d_bias)

-        return d_ws, d_x
+        # Both dw and dx have a mul depth of 1 in this function. Thus, the
+        # sensitivity factor is squared.
+        new_sensitivity_analysis_factor = (
+            sensitivity_analysis_factor**2
+            if sensitivity_analysis_factor is not None
+            else None
+        )
+
+        return d_ws, d_x, new_sensitivity_analysis_factor

     @staticmethod
     def unpack(plaintext_packed_dx):
diff --git a/tf_shell_ml/dpsgd_sequential_model.py b/tf_shell_ml/dpsgd_sequential_model.py
index 873b508..1b7e655 100644
--- a/tf_shell_ml/dpsgd_sequential_model.py
+++ b/tf_shell_ml/dpsgd_sequential_model.py
@@ -51,7 +51,27 @@ def call(self, features, training=False, with_softmax=True):
         # purposes.
         return tf.nn.softmax(predictions)

+    def _backward(self, dJ_dz, sensitivity_analysis_factor=None):
+        # Backward pass. dJ_dz is the derivative of the loss with respect to the
+        # last layer pre-activation.
+        dJ_dw = []  # Derivatives of the loss with respect to the weights.
+        dJ_dx = [dJ_dz]  # Derivatives of the loss with respect to the inputs.
+        sa_factor = sensitivity_analysis_factor
+        for l in reversed(self.layers):
+            dw, dx, new_sa_factor = l.backward(
+                dJ_dx[-1], sensitivity_analysis_factor=sa_factor
+            )
+            dJ_dw.extend(dw)
+            dJ_dx.append(dx)
+            sa_factor = new_sa_factor
+
+        return [g for g in reversed(dJ_dw)]
+
     def compute_grads(self, features, enc_labels):
+        scaling_factor = (
+            enc_labels.scaling_factor if hasattr(enc_labels, "scaling_factor") else 1
+        )
+
         # Reset layers for forward pass over multiple devices.
         for l in self.layers:
             l.reset_split_forward_mode()
@@ -67,29 +87,56 @@ def compute_grads(self, features, enc_labels):

         for i, d in enumerate(self.jacobian_devices):
             with tf.device(d):
                 f = tf.identity(split_features[i])  # copy to GPU if needed
-                prediction, jacobians = self.predict_and_jacobian(f)
+
+                # First compute the real prediction.
+                prediction = self.call(f, training=True, with_softmax=True)
+
+                # prediction, jacobians = self.TEST_predict_and_jacobian(f)  # TEST
+
                 if i == len(self.jacobian_devices) - 1 and end_pad > 0:
-                    # The last device's features may have been padded for even
-                    # split jacobian computation across multiple devices.
+                    # The last device's features may have been padded.
                     prediction = prediction[:-end_pad]
-                    jacobians = jacobians[:-end_pad]
                 predictions_list.append(prediction)
-                max_two_norms_list.append(self.jacobian_max_two_norm(jacobians))
+
+                # Next perform the sensitivity analysis. Straightforward
+                # backpropagation has mul/add depth proportional to the number
+                # of layers and the encoding error accumulates through each op.
+                # It is difficult to tightly bound and easier to simply compute.
+                #
+                # Perform backpropagation (in plaintext) for every possible
+                # label using the worst case quantization of weights.
+                sensitivity = tf.constant(0.0, dtype=tf.keras.backend.floatx())
+                worst_case_prediction = prediction + (
+                    tf.sign(prediction) / scaling_factor
+                )
+
+                for possible_label_i in range(self.out_classes):
+                    possible_label = tf.one_hot(
+                        possible_label_i,
+                        self.out_classes,
+                        dtype=tf.keras.backend.floatx(),
+                    )
+                    dJ_dz = (
+                        worst_case_prediction - possible_label
+                    )  # Broadcasts to batch subset.
+                    possible_grads = self._backward(
+                        dJ_dz, sensitivity_analysis_factor=scaling_factor
+                    )
+
+                    max_norm = self.max_per_example_global_norm(possible_grads)
+                    sensitivity = tf.maximum(sensitivity, max_norm)
+
+                max_two_norms_list.append(sensitivity)

         with tf.device(self.features_party_dev):
             predictions = tf.concat(predictions_list, axis=0)
             max_two_norm = tf.reduce_max(max_two_norms_list)
+            dJ_dz = enc_labels.__rsub__(
+                predictions
+            )  # Derivative of CCE loss and softmax.

+            # Backward pass.
-            dx = enc_labels.__rsub__(predictions)  # Derivative of CCE loss and softmax.
-            dJ_dw = []  # Derivatives of the loss with respect to the weights.
-            dJ_dx = [dx]  # Derivatives of the loss with respect to the inputs.
-            for l in reversed(self.layers):
-                dw, dx = l.backward(dJ_dx[-1])
-                dJ_dw.extend(dw)
-                dJ_dx.append(dx)
-
-            # Reverse the order to match the order of the layers.
-            grads = [g for g in reversed(dJ_dw)]
+            grads = self._backward(dJ_dz)

         return grads, max_two_norm, predictions
diff --git a/tf_shell_ml/dropout.py b/tf_shell_ml/dropout.py
index f51349e..ca72f51 100644
--- a/tf_shell_ml/dropout.py
+++ b/tf_shell_ml/dropout.py
@@ -82,9 +82,18 @@ def call(self, inputs, training=False, split_forward_mode=False):
         output = inputs * dropout_mask
         return output

-    def backward(self, dy, rotation_key=None):
+    def backward(self, dy, rotation_key=None, sensitivity_analysis_factor=None):
         dropout_mask = tf.concat(
             [tf.identity(z) for z in self._layer_intermediate], axis=0
         )
+
+        # d_x has a mul depth of 1 in this function. Thus, the sensitivity
+        # factor is squared.
+        new_sensitivity_analysis_factor = (
+            sensitivity_analysis_factor**2
+            if sensitivity_analysis_factor is not None
+            else None
+        )
+
         d_x = dy * dropout_mask
-        return [], d_x
+        return [], d_x, new_sensitivity_analysis_factor
diff --git a/tf_shell_ml/embedding.py b/tf_shell_ml/embedding.py
index c0d2bc4..e0b1fdc 100644
--- a/tf_shell_ml/embedding.py
+++ b/tf_shell_ml/embedding.py
@@ -134,7 +134,10 @@ def backward(self, dy, rotation_key=None):
             reduction=self.grad_reduction,
         )

-        return [summedvalues], tf.zeros(0)
+        # The output of this function has the same scaling factor as the input.
+        new_sensitivity_analysis_factor = sensitivity_analysis_factor
+
+        return [summedvalues], tf.zeros(0), new_sensitivity_analysis_factor

     @staticmethod
     def unpack(plaintext_packed_dx):
diff --git a/tf_shell_ml/flatten.py b/tf_shell_ml/flatten.py
index ac3c35f..f0debfd 100644
--- a/tf_shell_ml/flatten.py
+++ b/tf_shell_ml/flatten.py
@@ -35,7 +35,7 @@ def call(self, inputs, training=False, split_forward_mode=False):
         self.batch_size = tf.shape(inputs)[0]
         return tf.reshape(inputs, [self.batch_size] + self.flat_shape)

-    def backward(self, dy, rotation_key=None):
+    def backward(self, dy, rotation_key=None, sensitivity_analysis_factor=None):
         new_shape = list(self.input_shape)
         # On the forward pass, inputs may be batched differently than the
         # ciphertext scheme when not in eager mode. Pad them to match the
@@ -46,4 +46,8 @@ def backward(self, dy, rotation_key=None):
             new_shape[0] = dy.shape[0]
         dw = []
         dx = tf_shell.reshape(dy, new_shape)
-        return dw, dx
+
+        # The output of this function has the same scaling factor as the input.
+        new_sensitivity_analysis_factor = sensitivity_analysis_factor
+
+        return dw, dx, new_sensitivity_analysis_factor
diff --git a/tf_shell_ml/globalaveragepool1d.py b/tf_shell_ml/globalaveragepool1d.py
index adf995f..f205992 100644
--- a/tf_shell_ml/globalaveragepool1d.py
+++ b/tf_shell_ml/globalaveragepool1d.py
@@ -38,13 +38,17 @@ def call(self, inputs, training=False, split_forward_mode=False):

         return outputs

-    def backward(self, dy, rotation_key=None):
+    def backward(self, dy, rotation_key=None, sensitivity_analysis_factor=None):
         avg_dim = tf.identity(self._layer_intermediate)
         dx = tf_shell.expand_dims(dy, axis=1)
         dx = tf_shell.broadcast_to(
             dx, (tf_shell.shape(dx)[0], avg_dim, tf_shell.shape(dx)[2])
         )
-        return [], dx
+
+        # The output of this function has the same scaling factor as the input.
+        new_sensitivity_analysis_factor = sensitivity_analysis_factor
+
+        return [], dx, new_sensitivity_analysis_factor

     @staticmethod
     def unpack(packed_dw):
diff --git a/tf_shell_ml/max_pool2d.py b/tf_shell_ml/max_pool2d.py
index 32ee3e5..4826575 100644
--- a/tf_shell_ml/max_pool2d.py
+++ b/tf_shell_ml/max_pool2d.py
@@ -88,7 +88,7 @@ def call(self, inputs, training=False, split_forward_mode=False):

         return outputs

-    def backward(self, dy, rotation_key=None):
+    def backward(self, dy, rotation_key=None, sensitivity_analysis_factor=None):
         """Compute the gradient."""
         indices = tf.concat([tf.identity(z) for z in self._layer_intermediate], axis=0)
         grad_weights = []
@@ -114,7 +114,10 @@ def backward(self, dy, rotation_key=None):
             output_shape=self._layer_input_shape,
         )

-        return grad_weights, d_x
+        # The output of this function has the same scaling factor as the input.
+        new_sensitivity_analysis_factor = sensitivity_analysis_factor
+
+        return grad_weights, d_x, new_sensitivity_analysis_factor

     def unpacking_funcs(self):
         return []
diff --git a/tf_shell_ml/model_base.py b/tf_shell_ml/model_base.py
index db24de7..ddc1db1 100644
--- a/tf_shell_ml/model_base.py
+++ b/tf_shell_ml/model_base.py
@@ -19,7 +19,7 @@
 import tf_shell
 import time
 import gc
-from tf_shell_ml import large_tensor
+import tf_shell_ml


 class SequentialBase(keras.Sequential):
@@ -59,7 +59,7 @@ def __init__(
         )
         self.disable_encryption = disable_encryption
         self.disable_masking = disable_masking
-        self.disable_noise = disable_noise
+        self.disable_noise = False if noise_multiplier == 0.0 else disable_noise
         self.check_overflow_INSECURE = check_overflow_INSECURE
         self.dataset_prepped = False

@@ -72,14 +72,17 @@ def __init__(
                 "WARNING: `jacobian_pfor` may be incompatible with `disable_encryption`."
             )

-        if len(self.layers) > 0 and not (
-            self.layers[-1].activation is tf.keras.activations.softmax
-            or self.layers[-1].activation is tf.nn.softmax
-        ):
-            raise ValueError(
-                "The model must have a softmax activation function on the final layer. Saw",
-                self.layers[-1].activation,
-            )
+        if len(self.layers) > 0:
+            if not (
+                self.layers[-1].activation is tf.keras.activations.softmax
+                or self.layers[-1].activation is tf.nn.softmax
+            ):
+                raise ValueError(
+                    "The model must have a softmax activation function on the final layer. Saw",
+                    self.layers[-1].activation,
+                )
+
+            self.out_classes = self.layers[-1].units

         # Unset the last layer's activation function. Derived classes handle
         # the softmax activation manually.
@@ -434,84 +437,6 @@ def split_with_padding(self, tensor, num_splits, axis=0, padding_value=0):
         # Split the tensor
         return tf.split(tensor, num_splits, axis=axis), padded_by

-    def predict_and_jacobian(self, features):
-        """
-        Predicts the output for the given features and optionally computes the
-        Jacobian.
-
-        Args:
-            features (tf.Tensor): Input features for the model.
-
-        Returns:
-            tuple: A tuple containing:
-                - predictions (tf.Tensor): The model output after applying
-                  softmax.
-                - jacobians (list or tf.Tensor): The Jacobian of the last layer
-                  preactivation with respect to the model weights.
-        """
-        with tf.GradientTape(
-            persistent=tf.executing_eagerly() or self.jacobian_pfor
-        ) as tape:
-            predictions = self(features, training=True, with_softmax=False)
-
-        jacobians = tape.jacobian(
-            predictions,
-            self.trainable_variables,
-            # unconnected_gradients=tf.UnconnectedGradients.ZERO, broken with pfor
-            parallel_iterations=self.jacobian_pfor_iterations,
-            experimental_use_pfor=self.jacobian_pfor,
-        )
-        # ^ layers list x (batch size x num output classes x weights) matrix
-        # dy_pred_j/dW_sample_class
-
-        # Compute the last layer's activation manually since we skipped it above.
-        predictions = tf.nn.softmax(predictions)
-
-        return predictions, jacobians
-
-    def jacobian_max_two_norm(self, jacobians):
-        """
-        Computes the maximum two-norm of the Jacobian over all examples in
-        the batch and all output classes, layer-wise to limit memory usage.
-
-        Args:
-            jacobians (list of tf.Tensor): A list of tensors representing the
-                Jacobians of the model last layer pre-activation with respect to
-                the weights.
-
-        Returns:
-            tf.Tensor: A scalar tensor representing the maximum two-norm.
-        """
-        if len(jacobians) == 0:
-            return tf.constant(0.0, dtype=tf.keras.backend.floatx())
-
-        batch_size = jacobians[0].shape[0]
-        num_output_classes = jacobians[0].shape[1]
-
-        # Sum of squares accumulates the l2 norm squared of the per-example,
-        # per-output class gradients.
-        sum_of_squares = tf.zeros(
-            [batch_size, num_output_classes], dtype=tf.keras.backend.floatx()
-        )
-
-        for j in jacobians:
-            # Ignore the batch size and num output classes dimensions and
-            # recover just the number of dimensions in the weights.
-            num_weight_dims = len(j.shape) - 2
-            reduce_sum_dims = range(2, 2 + num_weight_dims)
-            sum_of_squares += tf.reduce_sum(tf.square(j), axis=reduce_sum_dims)
-
-        # Find the largest two l2 norms over the output classes.
-        top_2_l2s = tf.sqrt(tf.math.top_k(sum_of_squares, k=2).values)
-
-        # Compute their sum, i.e. ||v|| + ||w|| where v and w are the two
-        # largest per-class gradients by l2 norm, for every per-example
-        # gradient.
-        max_two_norm = tf.reduce_sum(top_2_l2s, axis=1)
-
-        # Return the maximum over the examples.
-        return tf.reduce_max(max_two_norm)
-
     def flatten_and_pad_grad(self, grad, dim_0_sz):
         num_el = tf.cast(tf.size(grad), dtype=tf.int64)

@@ -552,6 +477,33 @@ def unflatten_grad_list(self, flat_grads, metadata_list):

         return grads

+    def max_per_example_global_norm(self, grads):
+        """
+        Computes the maximum per-example global norm of the gradients.
+
+        Args:
+            grads (list of tf.Tensor): The list of gradient tensors.
+
+        Returns:
+            tf.Tensor: Scalar of the maximum global norm of the gradients.
+        """
+        if len(grads) == 0:
+            return tf.constant(0.0, dtype=tf.keras.backend.floatx())
+
+        # Sum of squares accumulates the l2 norm squared of the per-example
+        # gradients.
+        sum_of_squares = tf.zeros([grads[0].shape[0]], dtype=tf.keras.backend.floatx())
+        for g in grads:
+            # Ignore the batch size dimension and recover just the number of
+            # dimensions in the weights (per-example gradients have no output
+            # class dimension here).
+            num_weight_dims = len(g.shape) - 1
+            weight_dims = range(1, 1 + num_weight_dims)
+            sum_of_squares += tf.reduce_sum(tf.square(g), axis=weight_dims)
+
+        max_norm = tf.sqrt(tf.reduce_max(sum_of_squares))
+        return max_norm
+
     def spoof_int_gradients(self, grads):
         """
         Spoofs the gradients to have dtype int64 and a scaling factor of 1.
@@ -856,6 +808,7 @@ def shell_train_step(self, features, labels, read_key_from_cache, apply_gradient

             # Call the derived class to compute the gradients.
             grads, max_two_norm, predictions = self.compute_grads(features, enc_labels)
+            tf.print("---- max_two_norm: ", max_two_norm)

         with tf.device(self.features_party_dev):
             # Store the scaling factors of the gradients.
@@ -902,8 +855,8 @@ def shell_train_step(self, features, labels, read_key_from_cache, apply_gradient
                 # Note, running this on a single machine sometimes breaks
                 # TensorFlow's optimizer, which then decides to replace the
                 # entire training graph with a const op.
-                chunked_grads, chunked_grads_metadata = large_tensor.split_tensor_list(
-                    grads
+                chunked_grads, chunked_grads_metadata = (
+                    tf_shell_ml.large_tensor.split_tensor_list(grads)
                 )

             if not self.disable_noise:
@@ -944,7 +897,7 @@ def shell_train_step(self, features, labels, read_key_from_cache, apply_gradient
         with tf.device(self.labels_party_dev):
             if self.features_party_dev != self.labels_party_dev:
                 # Reassemble the tensor list after sending it between machines.
-                grads = large_tensor.reassemble_tensor_list(
+                grads = tf_shell_ml.large_tensor.reassemble_tensor_list(
                     chunked_grads, chunked_grads_metadata
                 )

diff --git a/tf_shell_ml/postscale_sequential_model.py b/tf_shell_ml/postscale_sequential_model.py
index 2079d3f..2b043a7 100644
--- a/tf_shell_ml/postscale_sequential_model.py
+++ b/tf_shell_ml/postscale_sequential_model.py
@@ -42,7 +42,77 @@ def call(self, inputs, training=False, with_softmax=True):
         # purposes.
         return tf.nn.softmax(prediction)

+    def _predict_and_jacobian(self, features):
+        """
+        Predicts the output for the given features and optionally computes the
+        Jacobian.
+
+        Args:
+            features (tf.Tensor): Input features for the model.
+
+        Returns:
+            tuple: A tuple containing:
+                - predictions (tf.Tensor): The model output after applying
+                  softmax.
+                - jacobians (list or tf.Tensor): The Jacobian of the last layer
+                  preactivation with respect to the model weights.
+        """
+        with tf.GradientTape(
+            persistent=tf.executing_eagerly() or self.jacobian_pfor
+        ) as tape:
+            predictions = self(features, training=True, with_softmax=False)
+
+        jacobians = tape.jacobian(
+            predictions,
+            self.trainable_variables,
+            # unconnected_gradients=tf.UnconnectedGradients.ZERO, broken with pfor
+            parallel_iterations=self.jacobian_pfor_iterations,
+            experimental_use_pfor=self.jacobian_pfor,
+        )
+        # ^ layers list x (batch size x num output classes x weights) matrix
+        # dy_pred_j/dW_sample_class
+
+        # Compute the last layer's activation manually since we skipped it above.
+        predictions = tf.nn.softmax(predictions)
+
+        return predictions, jacobians
+
+    def _backward(self, dJ_dz, jacobians):
+        # Scale each gradient. Since dJ_dz may be a vector of
+        # ciphertexts, this requires multiplying the plaintext gradient for
+        # the specific layer (2d) by the ciphertext (scalar).
+        grads = []
+        for j in jacobians:
+            # Ignore the batch size and num output classes dimensions and
+            # recover just the number of dimensions in the weights.
+            num_weight_dims = len(j.shape) - 2
+
+            # Make dJ_dz the same shape as the gradients so the
+            # multiplication can be broadcast. Doing this inside the loop
+            # is okay, TensorFlow will reuse the same expanded tensor if
+            # their dimensions match across iterations.
+            dJ_dz_exp = dJ_dz
+            for _ in range(num_weight_dims):
+                dJ_dz_exp = tf_shell.expand_dims(dJ_dz_exp, axis=-1)
+
+            # Scale the jacobian.
+            scaled_grad = dJ_dz_exp * j
+            # ^ batch_size x num output classes x weights
+
+            # Sum over the output classes. At this point, this is a gradient
+            # and no longer a jacobian.
+            scaled_grad = tf_shell.reduce_sum(scaled_grad, axis=1)
+            # ^ batch_size x weights
+
+            grads.append(scaled_grad)
+
+        return grads
+
     def compute_grads(self, features, enc_labels):
+        scaling_factor = (
+            enc_labels.scaling_factor if hasattr(enc_labels, "scaling_factor") else 1
+        )
+
         predictions_list = []
         jacobians_list = []
         max_two_norms_list = []
@@ -55,7 +125,7 @@ def compute_grads(self, features, enc_labels):
         for i, d in enumerate(self.jacobian_devices):
             with tf.device(d):
                 f = tf.identity(split_features[i])  # copy to GPU if needed
-                prediction, jacobians = self.predict_and_jacobian(f)
+                prediction, jacobians = self._predict_and_jacobian(f)
                 if i == len(self.jacobian_devices) - 1 and end_pad > 0:
                     # The last device's features may have been padded for even
                     # split jacobian computation across multiple devices.
@@ -63,7 +133,34 @@ def compute_grads(self, features, enc_labels):
                     jacobians = jacobians[:-end_pad]
                 predictions_list.append(prediction)
                 jacobians_list.append(jacobians)
-                max_two_norms_list.append(self.jacobian_max_two_norm(jacobians))
+
+                # Perform PostScale (dJ/dz * dz/dw) in plaintext for every
+                # possible label using the worst case quantization of the
+                # jacobian.
+                sensitivity = tf.constant(0.0, dtype=tf.keras.backend.floatx())
+                worst_case_quantized_jacobians = [
+                    j + (tf.sign(j) / scaling_factor) for j in jacobians
+                ]
+                worst_case_prediction = prediction + (
+                    tf.sign(prediction) / scaling_factor
+                )
+
+                for possible_label_i in range(self.out_classes):
+                    possible_label = tf.one_hot(
+                        possible_label_i,
+                        self.out_classes,
+                        dtype=tf.keras.backend.floatx(),
+                    )
+                    dJ_dz = (
+                        worst_case_prediction - possible_label
+                    )  # Broadcasts to batch subset.
+                    possible_grads = self._backward(
+                        dJ_dz, worst_case_quantized_jacobians
+                    )
+
+                    max_norm = self.max_per_example_global_norm(possible_grads)
+                    sensitivity = tf.maximum(sensitivity, max_norm)
+                max_two_norms_list.append(sensitivity)

         with tf.device(self.features_party_dev):
             predictions = tf.concat(predictions_list, axis=0)
@@ -71,35 +168,9 @@ def compute_grads(self, features, enc_labels):
             max_two_norm = tf.reduce_max(max_two_norms_list)
             jacobians = [tf.concat(j, axis=0) for j in zip(*jacobians_list)]

             # Compute prediction - labels (where labels may be encrypted).
-            scalars = enc_labels.__rsub__(predictions)  # dJ/dprediction
+            enc_dJ_dz = enc_labels.__rsub__(predictions)  # dJ/dprediction
             # ^ batch_size x num output classes.

-            # Scale each gradient. Since 'scalars' may be a vector of
-            # ciphertexts, this requires multiplying plaintext gradient for the
-            # specific layer (2d) by the ciphertext (scalar).
-            grads = []
-            for j in jacobians:
-                # Ignore the batch size and num output classes dimensions and
-                # recover just the number of dimensions in the weights.
-                num_weight_dims = len(j.shape) - 2
-
-                # Make the scalars the same shape as the gradients so the
-                # multiplication can be broadcasted. Doing this inside the loop
-                # is okay, TensorFlow will reuse the same expanded tensor if
-                # their dimensions match across iterations.
-                scalars_exp = scalars
-                for _ in range(num_weight_dims):
-                    scalars_exp = tf_shell.expand_dims(scalars_exp, axis=-1)
-
-                # Scale the jacobian.
-                scaled_grad = scalars_exp * j
-                # ^ batch_size x num output classes x weights
-
-                # Sum over the output classes. At this point, this is a gradient
-                # and no longer a jacobian.
-                scaled_grad = tf_shell.reduce_sum(scaled_grad, axis=1)
-                # ^ batch_size x weights
-
-                grads.append(scaled_grad)
+            grads = self._backward(enc_dJ_dz, jacobians)

         return grads, max_two_norm, predictions
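Note (illustrative sketch, not part of the patch): the worst-case quantization perturbation introduced above adds sign(x) / scaling_factor to each quantized value, which upper-bounds the encoding error, and a layer whose backward pass has mul depth 1 squares the factor. The quantize helper and sample values below are hypothetical; only the tf.sign(x) / scaling_factor term mirrors the new backward() code.

import tensorflow as tf

def quantize(x, scaling_factor):
    # Hypothetical fixed-point style encoding: round onto a 1/scaling_factor grid.
    return tf.round(x * scaling_factor) / scaling_factor

scaling_factor = 2.0**10
x = tf.constant([0.3141, -2.7182, 1.6180])
w = tf.constant([-0.5772, 1.4142, -0.6931])

# Worst-case perturbation in the direction that grows the magnitude, mirroring
# `x + tf.sign(x) / sensitivity_analysis_factor` in the layers' backward().
x_worst = quantize(x, scaling_factor) + tf.sign(x) / scaling_factor
w_worst = quantize(w, scaling_factor) + tf.sign(w) / scaling_factor

# One multiplication (mul depth 1): the product of two values each encoded at
# `scaling_factor` lives at `scaling_factor**2`, which is why backward()
# returns sensitivity_analysis_factor**2 for layers that multiply dy by an
# intermediate tensor, and passes the factor through unchanged otherwise.
print(tf.reduce_max(tf.abs(x_worst * w_worst - x * w)).numpy())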
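For reference, a self-contained cross-check (again not part of the patch) of what max_per_example_global_norm computes: the largest per-example global L2 norm over a list of per-example gradient tensors. The gradient shapes below are made up for illustration.

import tensorflow as tf

def max_per_example_global_norm(grads):
    # Same logic as model_base.py: accumulate per-example sums of squares over
    # every weight dimension, then take the sqrt of the max over the batch.
    sum_of_squares = tf.zeros([grads[0].shape[0]], dtype=tf.keras.backend.floatx())
    for g in grads:
        sum_of_squares += tf.reduce_sum(tf.square(g), axis=list(range(1, len(g.shape))))
    return tf.sqrt(tf.reduce_max(sum_of_squares))

batch = 4
grads = [tf.random.normal([batch, 5, 3]), tf.random.normal([batch, 3])]

# Reference: flatten each example's gradients into one vector and take tf.norm.
reference = tf.reduce_max(
    tf.stack([
        tf.norm(tf.concat([tf.reshape(g[i], [-1]) for g in grads], axis=0))
        for i in range(batch)
    ])
)
print(max_per_example_global_norm(grads).numpy(), reference.numpy())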