diff --git a/tf_shell/python/shell_optimizers.py b/tf_shell/python/shell_optimizers.py index b174a27..770efbe 100644 --- a/tf_shell/python/shell_optimizers.py +++ b/tf_shell/python/shell_optimizers.py @@ -78,7 +78,12 @@ def optimize_shell_graph( # Grappler determines fetch ops from collection 'train_op'. meta_graph_def.collection_def[ops.GraphKeys.TRAIN_OP].CopyFrom(fetch_collection) - grappler_session_config = config_pb2.ConfigProto() + # For a clean slate, create a new grappler session config as below + # # grappler_session_config = config_pb2.ConfigProto() + # But to retain other settings, like soft device placement, copy from the + # existing config. + grappler_session_config = context.context().config + grappler_session_config.graph_options.rewrite_options.CopyFrom(rewriter_config) optimized_graph_def = tf_optimizer.OptimizeGraph( grappler_session_config, meta_graph_def, graph_id=b"tf_graph" diff --git a/tf_shell_ml/model_base.py b/tf_shell_ml/model_base.py index b507b3b..4ede47d 100644 --- a/tf_shell_ml/model_base.py +++ b/tf_shell_ml/model_base.py @@ -112,16 +112,20 @@ def train_step(self, features, labels): # and contexts are not guaranteed to be populated, so set # read_key_from_cache to False. metrics, batch_size_should_be = self.shell_train_step( - features, labels, read_key_from_cache=False + features, labels, read_key_from_cache=False, apply_gradients=True ) return metrics @tf.function - def train_step_tf_func(self, features, labels, read_key_from_cache): + def train_step_tf_func( + self, features, labels, read_key_from_cache, apply_gradients + ): """ tf.function wrapper around shell train step. """ - return self.shell_train_step(features, labels, read_key_from_cache) + return self.shell_train_step( + features, labels, read_key_from_cache, apply_gradients + ) def compute_grads(self, features, enc_labels): raise NotImplementedError() # Should be overloaded by the subclass. @@ -132,8 +136,8 @@ def prep_dataset_for_model(self, train_features, train_labels): Sets the batch size to the same value as the encryption ring degree. Run the training loop once on dummy inputs to figure out the batch size. - Note the batch size is not always known during graph construction, when - using autocontext. + Note the batch size is not always known during graph construction when + using autocontext, only after graph optimization. Args: train_features (tf.data.Dataset): The training features dataset. @@ -156,6 +160,7 @@ def prep_dataset_for_model(self, train_features, train_labels): next(iter(train_features)), next(iter(train_labels)), read_key_from_cache=False, + apply_gradients=False, ) self.batch_size = batch_size_should_be.numpy() @@ -174,13 +179,11 @@ def fast_prep_dataset_for_model(self, train_features, train_labels): """ Prepare the dataset for training with encryption. - Sets the batch size to the same value as the encryption ring degree. - This method is faster than `prep_dataset_for_model` because it does not - execute the graph. Instead, it traces and optimizes the graph, - extracting the required parameters without actually executing it. - - Note, since the graph is not executed, caches for keys and the shell - context are not written. + Sets the batch size of the training datasets to the same value as the + encryption ring degree. This method is faster than + `prep_dataset_for_model` because it does not execute the graph. Instead, + it traces and optimizes the graph, extracting the required parameters + without actually executing it. Args: train_features (tf.data.Dataset): The training features dataset. @@ -201,6 +204,7 @@ def fast_prep_dataset_for_model(self, train_features, train_labels): next(iter(train_features)), next(iter(train_labels)), read_key_from_cache=False, + apply_gradients=False, ) # Optimize the graph using tf_shells HE-specific optimizers. @@ -229,7 +233,6 @@ def get_tensor_by_name(g, name): log_n = get_tensor_by_name(optimized_graph, context_node.input[0]).tolist() self.batch_size = 2**log_n - # TODO: Create the context and keys from the graph parameters. with tf.device(self.features_party_dev): train_features = train_features.rebatch(2**log_n, drop_remainder=True) @@ -294,7 +297,7 @@ def fit( tf_shell.enable_randomized_rounding() if not self.dataset_prepped: - features_dataset, labels_dataset = self.prep_dataset_for_model( + features_dataset, labels_dataset = self.fast_prep_dataset_for_model( features_dataset, labels_dataset ) tf.keras.backend.clear_session() @@ -324,6 +327,7 @@ def fit( # Begin training. callback_list.on_train_begin() logs = {} + subsequent_run = False for epoch in range(epochs): callback_list.on_epoch_begin(epoch, logs) @@ -339,8 +343,12 @@ def fit( # populated during the dataset preparation step. Set # read_key_from_cache to True. logs, batch_size_should_be = self.train_step_tf_func( - batch_x, batch_y, read_key_from_cache=True + batch_x, + batch_y, + read_key_from_cache=subsequent_run, + apply_gradients=True, ) + subsequent_run = True callback_list.on_train_batch_end(step, logs) gc.collect() if steps_per_epoch is not None and step + 1 >= steps_per_epoch: @@ -803,20 +811,15 @@ def warn_on_overflow(self, grads, scaling_factors, plaintext_modulus, message): lambda: tf.identity(overflowed), ) - def shell_train_step(self, features, labels, read_key_from_cache): + def shell_train_step(self, features, labels, read_key_from_cache, apply_gradients): """ The main training loop for the PostScale protocol. Args: - backprop_context: ShellContext for backpropagation, containing - encryption parameters. - backprop_secret_key: Secret key for decrypting backpropagated - gradients. + features (tf.Tensor): Training features. + labels (tf.Tensor): Training labels. read_key_from_cache: Boolean whether to read keys from cache. - predictions: Model predictions, plaintext. - grads: Gradients to be masked and noised. - max_two_norm: Maximum L2 norm to determine noise scale. - labels: Real labels corresponding to the predictions. + apply_gradients: Boolean whether to apply gradients. Returns: result: Dictionary of metric results. @@ -994,7 +997,19 @@ def shell_train_step(self, features, labels, read_key_from_cache): grads = [g / tf.cast(tf.shape(g)[0], g.dtype) for g in grads] # Apply the gradients to the model. - self.optimizer.apply_gradients(zip(grads, self.weights)) + if apply_gradients: + self.optimizer.apply_gradients(zip(grads, self.weights)) + else: + # If the gradients should not be applied, add zeros instead so + # the optimizer internal variables are created. To ensure the + # `grads` are not removed by the grappler optimizer, add them to + # the UPDATE_OPS collection. + [ + tf.compat.v1.add_to_collection(tf.compat.v1.GraphKeys.UPDATE_OPS, g) + for g in grads + ] + zeros = [tf.zeros_like(g) for g in grads] + self.optimizer.apply_gradients(zip(zeros, self.weights)) for metric in self.metrics: if metric.name == "loss":