From 368f3b1b204c2b511b87c03e178579e6ac8f6341 Mon Sep 17 00:00:00 2001 From: james-choncholas Date: Mon, 20 Jan 2025 19:14:50 +0000 Subject: [PATCH] Autocontext tracks noise bits across split ops. --- tf_shell/cc/optimizers/moduli_autotune.cc | 23 +++++-- tf_shell/cc/optimizers/utils.cc | 7 +- tf_shell/cc/optimizers/utils.h | 13 ++-- tf_shell_ml/large_tensor.py | 79 ++++++++++++----------- 4 files changed, 73 insertions(+), 49 deletions(-) diff --git a/tf_shell/cc/optimizers/moduli_autotune.cc b/tf_shell/cc/optimizers/moduli_autotune.cc index 4d8866f..e35067d 100644 --- a/tf_shell/cc/optimizers/moduli_autotune.cc +++ b/tf_shell/cc/optimizers/moduli_autotune.cc @@ -761,10 +761,8 @@ Status EstimateNodeNoise( } } - // Shape Ops. - else if (IsExpandDimsVariant(*node_def)) { - *this_noise = node_noise[node_view->GetRegularFanin(0).node_index()]; - } else if (IsConcatCt(*node_def)) { + // tf-shell shape ops. + else if (IsConcatCt(*node_def)) { // Fanins from 1 to n - 1 are the input tensors to be concatenated. // The first fanin is the axis. The noise is the maximum of the input // tensors. @@ -776,7 +774,11 @@ Status EstimateNodeNoise( } // Tensorflow Ops. - else if (IsReshape(*node_def)) { + else if (IsExpandDimsVariant(*node_def)) { + *this_noise = node_noise[node_view->GetRegularFanin(0).node_index()]; + } else if (IsReshape(*node_def)) { + *this_noise = node_noise[node_view->GetRegularFanin(0).node_index()]; + } else if (IsSplitV(*node_def)) { *this_noise = node_noise[node_view->GetRegularFanin(0).node_index()]; } @@ -844,6 +846,15 @@ Status EstimateNoiseGrowth(utils::MutableGraphView& graph_view, TF_ASSIGN_OR_RETURN(bool is_same_autocontext, DecryptUsesSameContext(node_view, autocontext)); if (is_same_autocontext) { + // If this is a decrypt node, and the context is the same, the noise + // should never be zero. If it is, this means there is a problem with + // the noise estimation. 
+ if (node_noise[i] == 0) { + return errors::FailedPrecondition( + "Noise budget of decrypt node is zero. Could not estimate noise " + "growth. Check the noise estimation logic."); + } + // Update the maximum noise budget. log_max_noise = std::max(log_max_noise, node_noise[i]); } @@ -1019,7 +1030,7 @@ Status OptimizeAutocontext(utils::MutableGraphView& graph_view, BitWidth(params.t) + log_max_noise + auto_params.noise_offset_bits; // Adjust the noise budget to account for the encryption noise. - if (total_ct_bits > log_q) { + if (total_ct_bits > static_cast<uint64_t>(log_q)) { if constexpr (debug_moduli) { std::cout << "Noise budget exceeded " << "(plaintext bits: " << BitWidth(params.t) diff --git a/tf_shell/cc/optimizers/utils.cc b/tf_shell/cc/optimizers/utils.cc index 4484f56..6020707 100644 --- a/tf_shell/cc/optimizers/utils.cc +++ b/tf_shell/cc/optimizers/utils.cc @@ -111,12 +111,15 @@ bool IsMaxUnpool2d(NodeDef const& node) { return node.op() == kMaxUnpool2dCt64; } +// tf-shell shape ops +bool IsConcatCt(NodeDef const& node) { return node.op() == kConcatCt; } + // TensorFlow ops. 
bool IsExpandDimsVariant(NodeDef const& node) { return node.op() == kExpandDimsVariant; } -bool IsConcatCt(NodeDef const& node) { return node.op() == kConcatCt; } bool IsBroadcastToShape(NodeDef const& node) { return node.op() == kBroadcastToShape; } -bool IsReshape(NodeDef const& node) { return node.op() == kReshape; } \ No newline at end of file +bool IsReshape(NodeDef const& node) { return node.op() == kReshape; } +bool IsSplitV(NodeDef const& node) { return node.op() == kSplitVOpName; } \ No newline at end of file diff --git a/tf_shell/cc/optimizers/utils.h b/tf_shell/cc/optimizers/utils.h index 9c9a1d3..f08c477 100644 --- a/tf_shell/cc/optimizers/utils.h +++ b/tf_shell/cc/optimizers/utils.h @@ -69,13 +69,16 @@ constexpr char kConv2dTransposeWithChanCtCt64[] = constexpr char kMaxUnpool2dCt64[] = "MaxUnpool2dCt64"; -// TensorFlow names -constexpr char kExpandDimsVariant[] = "ExpandDimsVariant"; +// tf-shell shape ops constexpr char kConcatCt[] = "ConcatCt64"; constexpr char kConcatPt[] = "ConcatPt64"; + +// TensorFlow names +constexpr char kExpandDimsVariant[] = "ExpandDimsVariant"; constexpr char kBroadcastToShape[] = "BroadcastToShape"; // TODO check name constexpr char kReshape[] = "Reshape"; // TODO check name constexpr char kConstOpName[] = "Const"; +constexpr char kSplitVOpName[] = "SplitV"; bool IsShellContext(NodeDef const& node); bool IsShellAutoContext(NodeDef const& node); @@ -131,7 +134,9 @@ bool IsConv2d(NodeDef const& node); bool IsMaxUnpool2d(NodeDef const& node); -bool IsExpandDimsVariant(NodeDef const& node); bool IsConcatCt(NodeDef const& node); + +bool IsExpandDimsVariant(NodeDef const& node); bool IsBroadcastToShape(NodeDef const& node); -bool IsReshape(NodeDef const& node); \ No newline at end of file +bool IsReshape(NodeDef const& node); +bool IsSplitV(NodeDef const& node); \ No newline at end of file diff --git a/tf_shell_ml/large_tensor.py b/tf_shell_ml/large_tensor.py index 8a9a721..66030c1 100644 --- a/tf_shell_ml/large_tensor.py +++ 
b/tf_shell_ml/large_tensor.py @@ -26,16 +26,14 @@ def calculate_tf_shell_split_sizes(context, total_elements): Returns: List of split sizes that sum to total_elements """ + num_slots = tf.cast(context.num_slots, dtype=tf.int64) + num_main_moduli = tf.size(context.main_moduli, out_type=tf.int64) # Each element in the shell tensor is a tuple of polynomials, one for # each component of the ciphertext, which have `ring degree` elements. # In tf_shell, these are represented with uint64_t values. Serialziation # also includes the power_of_s (int) and the error (double). - bytes_per_element = ( - tf.cast(context.num_slots, dtype=tf.int64) - * 2 - * (tf.size(context.main_moduli, out_type=tf.int64) * 8 + 4 + 8) - ) + bytes_per_element = num_slots * 2 * (num_main_moduli * 8 + 4 + 8) max_elements = tf.cast( tf.constant(int(UINT32_MAX * SAFETY_FACTOR), dtype=tf.int64) / bytes_per_element, @@ -124,38 +122,45 @@ def split_tensor(tensor): - List of tensor chunks - Metadata dictionary with original shape and other info needed for reassembly """ - if isinstance(tensor, tf_shell.ShellTensor64): - shape = tf_shell.shape(tensor) - total_elements = tensor._raw_tensor.shape.num_elements() + with tf.name_scope("large_tensor_split"): + if isinstance(tensor, tf_shell.ShellTensor64): + shape = tf_shell.shape(tensor) + total_elements = tensor._raw_tensor.shape.num_elements() - # Calculate split sizes - split_sizes = calculate_tf_shell_split_sizes(tensor._context, total_elements) + # Calculate split sizes + split_sizes = calculate_tf_shell_split_sizes( + tensor._context, total_elements + ) - # Reshape tensor to 1D for splitting, ignoring the batch dimension - flat_tensor = tf_shell.reshape(tensor, [tensor._context.num_slots, -1]) + # Reshape tensor to 1D for splitting, ignoring the batch dimension + flat_tensor = tf_shell.reshape(tensor, [tensor._context.num_slots, -1]) - # Split into chunks of calculated sizes - chunks = tf_shell.split(flat_tensor, split_sizes, axis=1, num=MAX_NUM_SPLITS) + 
# Split into chunks of calculated sizes + chunks = tf_shell.split( + flat_tensor, split_sizes, axis=1, num=MAX_NUM_SPLITS + ) - else: - shape = tf.shape(tensor) - total_elements = tf.reduce_prod(tf.cast(shape, tf.int64)) + else: + shape = tf.shape(tensor) + total_elements = tf.reduce_prod(tf.cast(shape, tf.int64)) - # Calculate split sizes - split_sizes = calculate_split_sizes(total_elements, tensor.dtype) + # Calculate split sizes + split_sizes = calculate_split_sizes(total_elements, tensor.dtype) - # Reshape tensor to 1D for splitting. - flat_tensor = tf.reshape(tensor, [-1]) + # Reshape tensor to 1D for splitting. + flat_tensor = tf.reshape(tensor, [-1]) - # Split into chunks of calculated sizes - chunks = tf_shell.split(flat_tensor, split_sizes, axis=0, num=MAX_NUM_SPLITS) + # Split into chunks of calculated sizes + chunks = tf_shell.split( + flat_tensor, split_sizes, axis=0, num=MAX_NUM_SPLITS + ) - metadata = { - "original_shape": shape, - "split_sizes": split_sizes, - } + metadata = { + "original_shape": shape, + "split_sizes": split_sizes, + } - return chunks, metadata + return chunks, metadata def split_tensor_list(tensors): @@ -192,17 +197,17 @@ def reassemble_tensor(chunks, metadata): Returns: Reassembled tensor with original shape """ - # Concatenate chunks - if isinstance(chunks[0], tf_shell.ShellTensor64): - flat_tensor = tf_shell.concat(chunks, axis=1) - else: - flat_tensor = tf.concat(chunks, axis=0) + with tf.name_scope("large_tensor_reassemble"): + # Concatenate chunks + if isinstance(chunks[0], tf_shell.ShellTensor64): + flat_tensor = tf_shell.concat(chunks, axis=1) + else: + flat_tensor = tf.concat(chunks, axis=0) - # Reshape to original shape - original_shape = metadata["original_shape"] - reassembled_tensor = tf_shell.reshape(flat_tensor, original_shape) + # Reshape to original shape + reassembled_tensor = tf_shell.reshape(flat_tensor, metadata["original_shape"]) - return reassembled_tensor + return reassembled_tensor def 
reassemble_tensor_list(all_chunks, all_metadata):