Do not apply gradients on first run.
The `fast_prep_dataset` call works now, as the caches are populated on the first run in `fit` (rather than in the slow `prep_dataset`).
james-choncholas committed Jan 14, 2025
1 parent a787ce7 commit 960cc51
Showing 2 changed files with 46 additions and 26 deletions.
7 changes: 6 additions & 1 deletion tf_shell/python/shell_optimizers.py
@@ -78,7 +78,12 @@ def optimize_shell_graph(
# Grappler determines fetch ops from collection 'train_op'.
meta_graph_def.collection_def[ops.GraphKeys.TRAIN_OP].CopyFrom(fetch_collection)

grappler_session_config = config_pb2.ConfigProto()
# For a clean slate, create a new grappler session config as below
# # grappler_session_config = config_pb2.ConfigProto()
# But to retain other settings, like soft device placement, copy from the
# existing config.
grappler_session_config = context.context().config

grappler_session_config.graph_options.rewrite_options.CopyFrom(rewriter_config)
optimized_graph_def = tf_optimizer.OptimizeGraph(
grappler_session_config, meta_graph_def, graph_id=b"tf_graph"
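The change above swaps a blank `config_pb2.ConfigProto()` for a copy of the eager context's config, so options set through `tf.config` (such as soft device placement) survive into the Grappler pass. Below is a minimal sketch of how such a Grappler invocation fits together, reconstructed around the visible context lines; the helper name `optimize_concrete_function`, the `rewriter_config` argument, and everything outside the hunk are assumptions rather than the repository's exact code.

```python
import tensorflow as tf
from tensorflow.core.protobuf import meta_graph_pb2, rewriter_config_pb2
from tensorflow.python.eager import context
from tensorflow.python.grappler import tf_optimizer


def optimize_concrete_function(concrete_fn, rewriter_config):
    """Run Grappler over a ConcreteFunction's graph and return a GraphDef."""
    # Export the function's graph as a MetaGraphDef so Grappler can consume it.
    meta_graph_def = tf.compat.v1.train.export_meta_graph(
        graph_def=concrete_fn.graph.as_graph_def(), graph=concrete_fn.graph
    )

    # Grappler determines fetch ops from the 'train_op' collection, so list
    # the function's outputs there to keep them from being pruned.
    fetch_collection = meta_graph_pb2.CollectionDef()
    for out in concrete_fn.outputs:
        fetch_collection.node_list.value.append(out.name)
    meta_graph_def.collection_def[tf.compat.v1.GraphKeys.TRAIN_OP].CopyFrom(
        fetch_collection
    )

    # Copy the eager context's config instead of starting from an empty
    # ConfigProto; the property builds a fresh proto each call, so mutating
    # the copy here should not alter global eager settings.
    grappler_session_config = context.context().config
    grappler_session_config.graph_options.rewrite_options.CopyFrom(rewriter_config)

    return tf_optimizer.OptimizeGraph(
        grappler_session_config, meta_graph_def, graph_id=b"tf_graph"
    )


# Example: run only the constant-folding pass over a traced toy function.
rewriter_config = rewriter_config_pb2.RewriterConfig(
    optimizers=["constfold"],
    meta_optimizer_iterations=rewriter_config_pb2.RewriterConfig.ONE,
)
fn = tf.function(lambda: tf.constant(2.0) * tf.constant(3.0)).get_concrete_function()
optimized_graph_def = optimize_concrete_function(fn, rewriter_config)
```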
65 changes: 40 additions & 25 deletions tf_shell_ml/model_base.py
@@ -112,16 +112,20 @@ def train_step(self, features, labels):
# and contexts are not guaranteed to be populated, so set
# read_key_from_cache to False.
metrics, batch_size_should_be = self.shell_train_step(
features, labels, read_key_from_cache=False
features, labels, read_key_from_cache=False, apply_gradients=True
)
return metrics

@tf.function
def train_step_tf_func(self, features, labels, read_key_from_cache):
def train_step_tf_func(
self, features, labels, read_key_from_cache, apply_gradients
):
"""
tf.function wrapper around shell train step.
"""
return self.shell_train_step(features, labels, read_key_from_cache)
return self.shell_train_step(
features, labels, read_key_from_cache, apply_gradients
)

def compute_grads(self, features, enc_labels):
raise NotImplementedError() # Should be overloaded by the subclass.
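Because `read_key_from_cache` and the new `apply_gradients` flag are plain Python booleans, `tf.function` bakes them in at trace time: each distinct combination of flag values yields its own concrete graph rather than a runtime branch. A toy sketch of that behavior (illustrative names and a stand-in model, not the tf_shell step):

```python
import tensorflow as tf

w = tf.Variable(1.0)


@tf.function
def train_step(x, y, read_key_from_cache, apply_gradients):
    # Python bools are captured at trace time, so the `if` below decides
    # which ops end up in the graph instead of branching at runtime.
    # (read_key_from_cache is unused here; in the real step it selects
    # whether cached keys are read back.)
    with tf.GradientTape() as tape:
        loss = tf.reduce_mean((w * x - y) ** 2)
    (grad,) = tape.gradient(loss, [w])
    if apply_gradients:
        w.assign_sub(0.1 * grad)
    return loss


x = tf.constant([1.0, 2.0])
y = tf.constant([2.0, 4.0])
# The dataset-preparation call traces a graph that leaves the weight alone...
train_step(x, y, read_key_from_cache=False, apply_gradients=False)
# ...while the training loop's call traces a second graph that updates it.
train_step(x, y, read_key_from_cache=True, apply_gradients=True)
```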
@@ -132,8 +136,8 @@ def prep_dataset_for_model(self, train_features, train_labels):
Sets the batch size to the same value as the encryption ring degree. Run
the training loop once on dummy inputs to figure out the batch size.
Note the batch size is not always known during graph construction, when
using autocontext.
Note the batch size is not always known during graph construction when
using autocontext, only after graph optimization.
Args:
train_features (tf.data.Dataset): The training features dataset.
@@ -156,6 +160,7 @@ def prep_dataset_for_model(self, train_features, train_labels):
next(iter(train_features)),
next(iter(train_labels)),
read_key_from_cache=False,
apply_gradients=False,
)

self.batch_size = batch_size_should_be.numpy()
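`prep_dataset_for_model` discovers the batch size by actually executing one training step on dummy inputs, then rebatches both datasets to the ring degree. A small illustration of the rebatch step with made-up numbers (a hypothetical log_n of 12, i.e. a ring degree and batch size of 2**12 = 4096; the toy datasets are not from the repository):

```python
import tensorflow as tf

batch_size = 2 ** 12  # hypothetical ring degree discovered by the dummy step

features = tf.data.Dataset.from_tensor_slices(tf.random.uniform((10000, 8)))
labels = tf.data.Dataset.from_tensor_slices(tf.random.uniform((10000, 1)))

# rebatch() expects batched inputs; drop_remainder=True guarantees every
# batch fills exactly one ring's worth of slots.
features = features.batch(64).rebatch(batch_size, drop_remainder=True)
labels = labels.batch(64).rebatch(batch_size, drop_remainder=True)

for fb, lb in zip(features, labels):
    print(fb.shape, lb.shape)  # (4096, 8) (4096, 1); the partial batch is dropped
```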
@@ -174,13 +179,11 @@ def fast_prep_dataset_for_model(self, train_features, train_labels):
"""
Prepare the dataset for training with encryption.
Sets the batch size to the same value as the encryption ring degree.
This method is faster than `prep_dataset_for_model` because it does not
execute the graph. Instead, it traces and optimizes the graph,
extracting the required parameters without actually executing it.
Note, since the graph is not executed, caches for keys and the shell
context are not written.
Sets the batch size of the training datasets to the same value as the
encryption ring degree. This method is faster than
`prep_dataset_for_model` because it does not execute the graph. Instead,
it traces and optimizes the graph, extracting the required parameters
without actually executing it.
Args:
train_features (tf.data.Dataset): The training features dataset.
@@ -201,6 +204,7 @@ def fast_prep_dataset_for_model(self, train_features, train_labels):
next(iter(train_features)),
next(iter(train_labels)),
read_key_from_cache=False,
apply_gradients=False,
)

# Optimize the graph using tf_shells HE-specific optimizers.
@@ -229,7 +233,6 @@ def get_tensor_by_name(g, name):

log_n = get_tensor_by_name(optimized_graph, context_node.input[0]).tolist()
self.batch_size = 2**log_n
# TODO: Create the context and keys from the graph parameters.

with tf.device(self.features_party_dev):
train_features = train_features.rebatch(2**log_n, drop_remainder=True)
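`fast_prep_dataset_for_model` avoids running the step at all: it traces the graph, lets the shell optimizers fold the autocontext parameters into constants, and then reads `log_n` straight out of the optimized `GraphDef`. A self-contained sketch of that "read a constant from a traced graph without executing it" trick; the helper, node names, and toy function here are illustrative, not the repository's `get_tensor_by_name`:

```python
import tensorflow as tf
from tensorflow.python.framework import tensor_util


def get_const_by_name(graph_def, name):
    """Return the value of the Const node `name` from a GraphDef."""
    name = name.split(":")[0]  # drop an output index such as ":0"
    for node in graph_def.node:
        if node.name == name and node.op == "Const":
            return tensor_util.MakeNdarray(node.attr["value"].tensor)
    raise ValueError(f"no Const node named {name!r}")


# Toy stand-in for the traced train step: once log_n exists as a Const node,
# the batch size can be recovered without running the graph.
@tf.function
def toy_step():
    log_n = tf.constant(12, name="log_n")
    return tf.pow(2, log_n)


graph_def = toy_step.get_concrete_function().graph.as_graph_def()
log_n = int(get_const_by_name(graph_def, "log_n"))
batch_size = 2 ** log_n  # 4096
```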
@@ -294,7 +297,7 @@ def fit(
tf_shell.enable_randomized_rounding()

if not self.dataset_prepped:
features_dataset, labels_dataset = self.prep_dataset_for_model(
features_dataset, labels_dataset = self.fast_prep_dataset_for_model(
features_dataset, labels_dataset
)
tf.keras.backend.clear_session()
@@ -324,6 +327,7 @@
# Begin training.
callback_list.on_train_begin()
logs = {}
subsequent_run = False

for epoch in range(epochs):
callback_list.on_epoch_begin(epoch, logs)
@@ -339,8 +343,12 @@
# populated during the dataset preparation step. Set
# read_key_from_cache to True.
logs, batch_size_should_be = self.train_step_tf_func(
batch_x, batch_y, read_key_from_cache=True
batch_x,
batch_y,
read_key_from_cache=subsequent_run,
apply_gradients=True,
)
subsequent_run = True
callback_list.on_train_batch_end(step, logs)
gc.collect()
if steps_per_epoch is not None and step + 1 >= steps_per_epoch:
@@ -803,20 +811,15 @@ def warn_on_overflow(self, grads, scaling_factors, plaintext_modulus, message):
lambda: tf.identity(overflowed),
)

def shell_train_step(self, features, labels, read_key_from_cache):
def shell_train_step(self, features, labels, read_key_from_cache, apply_gradients):
"""
The main training loop for the PostScale protocol.
Args:
backprop_context: ShellContext for backpropagation, containing
encryption parameters.
backprop_secret_key: Secret key for decrypting backpropagated
gradients.
features (tf.Tensor): Training features.
labels (tf.Tensor): Training labels.
read_key_from_cache: Boolean whether to read keys from cache.
predictions: Model predictions, plaintext.
grads: Gradients to be masked and noised.
max_two_norm: Maximum L2 norm to determine noise scale.
labels: Real labels corresponding to the predictions.
apply_gradients: Boolean whether to apply gradients.
Returns:
result: Dictionary of metric results.
@@ -994,7 +997,19 @@ def shell_train_step(self, features, labels, read_key_from_cache):
grads = [g / tf.cast(tf.shape(g)[0], g.dtype) for g in grads]

# Apply the gradients to the model.
self.optimizer.apply_gradients(zip(grads, self.weights))
if apply_gradients:
self.optimizer.apply_gradients(zip(grads, self.weights))
else:
# If the gradients should not be applied, add zeros instead so
# the optimizer internal variables are created. To ensure the
# `grads` are not removed by the grappler optimizer, add them to
# the UPDATE_OPS collection.
[
tf.compat.v1.add_to_collection(tf.compat.v1.GraphKeys.UPDATE_OPS, g)
for g in grads
]
zeros = [tf.zeros_like(g) for g in grads]
self.optimizer.apply_gradients(zip(zeros, self.weights))

for metric in self.metrics:
if metric.name == "loss":
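The last hunk is the heart of the commit: when `apply_gradients` is `False` (the dataset-preparation trace), the step still calls `apply_gradients`, but with zeros, so the optimizer's internal variables are created on that first trace, and it pins the real gradients in the `UPDATE_OPS` collection so the graph optimizer does not prune them as dead code. A standalone sketch of the same pattern on a toy model (the names and the Adam optimizer are illustrative, not the tf_shell code):

```python
import tensorflow as tf

weights = [tf.Variable(tf.ones((2, 2))), tf.Variable(tf.zeros((2,)))]
optimizer = tf.keras.optimizers.Adam(1e-3)


def apply_or_hold(grads, weights, apply_gradients):
    if apply_gradients:
        optimizer.apply_gradients(zip(grads, weights))
    else:
        # Pin the real gradients to a collection so a later graph
        # optimization pass keeps them, then step with zeros: the weights
        # are unchanged but the optimizer slot variables still get built.
        for g in grads:
            tf.compat.v1.add_to_collection(tf.compat.v1.GraphKeys.UPDATE_OPS, g)
        zeros = [tf.zeros_like(g) for g in grads]
        optimizer.apply_gradients(zip(zeros, weights))


grads = [tf.ones_like(w) for w in weights]  # stand-in gradients
apply_or_hold(grads, weights, apply_gradients=False)  # builds Adam slots only
apply_or_hold(grads, weights, apply_gradients=True)   # performs a real update
```

Presumably this lets the later training trace (with `apply_gradients=True`) reuse the optimizer variables created during the preparation trace instead of trying to create them on a non-first call of the `tf.function`.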
