Skip to content

Commit

Permalink
Autocontext tracks noise bits across split ops.
Browse files Browse the repository at this point in the history
  • Loading branch information
james-choncholas committed Jan 20, 2025
1 parent 8394649 commit 368f3b1
Show file tree
Hide file tree
Showing 4 changed files with 73 additions and 49 deletions.
23 changes: 17 additions & 6 deletions tf_shell/cc/optimizers/moduli_autotune.cc
Original file line number Diff line number Diff line change
Expand Up @@ -761,10 +761,8 @@ Status EstimateNodeNoise(
}
}

// Shape Ops.
else if (IsExpandDimsVariant(*node_def)) {
*this_noise = node_noise[node_view->GetRegularFanin(0).node_index()];
} else if (IsConcatCt(*node_def)) {
// tf-shell shape ops.
else if (IsConcatCt(*node_def)) {
// Fanins from 1 to n - 1 are the input tensors to be concatenated.
// The first fanin is the axis. The noise is the maximum of the input
// tensors.
Expand All @@ -776,7 +774,11 @@ Status EstimateNodeNoise(
}

// Tensorflow Ops.
else if (IsReshape(*node_def)) {
else if (IsExpandDimsVariant(*node_def)) {
*this_noise = node_noise[node_view->GetRegularFanin(0).node_index()];
} else if (IsReshape(*node_def)) {
*this_noise = node_noise[node_view->GetRegularFanin(0).node_index()];
} else if (IsSplitV(*node_def)) {
*this_noise = node_noise[node_view->GetRegularFanin(0).node_index()];
}

Expand Down Expand Up @@ -844,6 +846,15 @@ Status EstimateNoiseGrowth(utils::MutableGraphView& graph_view,
TF_ASSIGN_OR_RETURN(bool is_same_autocontext,
DecryptUsesSameContext(node_view, autocontext));
if (is_same_autocontext) {
// If this is a decrypt node, and the context is the same, the noise
// should never be zero. If it is, this means there is a problem with
// the noise estimation.
if (node_noise[i] == 0) {
return errors::FailedPrecondition(
"Noise budget of decrypt node is zero. Could not estimate noise "
"growth. Check the noise estimation logic.");
}

// Update the maximum noise budget.
log_max_noise = std::max(log_max_noise, node_noise[i]);
}
Expand Down Expand Up @@ -1019,7 +1030,7 @@ Status OptimizeAutocontext(utils::MutableGraphView& graph_view,
BitWidth(params.t) + log_max_noise + auto_params.noise_offset_bits;

// Adjust the noise budget to account for the encryption noise.
if (total_ct_bits > log_q) {
if (total_ct_bits > static_cast<uint64_t>(log_q)) {
if constexpr (debug_moduli) {
std::cout << "Noise budget exceeded "
<< "(plaintext bits: " << BitWidth(params.t)
Expand Down
7 changes: 5 additions & 2 deletions tf_shell/cc/optimizers/utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -111,12 +111,15 @@ bool IsMaxUnpool2d(NodeDef const& node) {
return node.op() == kMaxUnpool2dCt64;
}

// tf-shell shape ops.
// True iff the node's op name matches the tf-shell ciphertext concat op.
bool IsConcatCt(NodeDef const& node) { return node.op() == kConcatCt; }

// TensorFlow ops.
// True iff the node's op name matches kExpandDimsVariant.
bool IsExpandDimsVariant(NodeDef const& node) {
  return node.op() == kExpandDimsVariant;
}
// True iff the node's op name matches kBroadcastToShape.
bool IsBroadcastToShape(NodeDef const& node) {
  return node.op() == kBroadcastToShape;
}
// True iff the node's op name matches kReshape.
bool IsReshape(NodeDef const& node) { return node.op() == kReshape; }
// True iff the node's op name matches kSplitVOpName.
bool IsSplitV(NodeDef const& node) { return node.op() == kSplitVOpName; }
13 changes: 9 additions & 4 deletions tf_shell/cc/optimizers/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,13 +69,16 @@ constexpr char kConv2dTransposeWithChanCtCt64[] =

constexpr char kMaxUnpool2dCt64[] = "MaxUnpool2dCt64";

// tf-shell shape ops
constexpr char kConcatCt[] = "ConcatCt64";
constexpr char kConcatPt[] = "ConcatPt64";

// TensorFlow names
constexpr char kExpandDimsVariant[] = "ExpandDimsVariant";
constexpr char kBroadcastToShape[] = "BroadcastToShape";  // TODO check name
constexpr char kReshape[] = "Reshape";                    // TODO check name
constexpr char kConstOpName[] = "Const";
constexpr char kSplitVOpName[] = "SplitV";

bool IsShellContext(NodeDef const& node);
bool IsShellAutoContext(NodeDef const& node);
Expand Down Expand Up @@ -131,7 +134,9 @@ bool IsConv2d(NodeDef const& node);

bool IsMaxUnpool2d(NodeDef const& node);

// tf-shell shape op predicates.
bool IsConcatCt(NodeDef const& node);

// TensorFlow op predicates.
bool IsExpandDimsVariant(NodeDef const& node);
bool IsBroadcastToShape(NodeDef const& node);
bool IsReshape(NodeDef const& node);
bool IsSplitV(NodeDef const& node);
79 changes: 42 additions & 37 deletions tf_shell_ml/large_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,16 +26,14 @@ def calculate_tf_shell_split_sizes(context, total_elements):
Returns:
List of split sizes that sum to total_elements
"""
num_slots = tf.cast(context.num_slots, dtype=tf.int64)
num_main_moduli = tf.size(context.main_moduli, out_type=tf.int64)

# Each element in the shell tensor is a tuple of polynomials, one for
# each component of the ciphertext, which have `ring degree` elements.
# In tf_shell, these are represented with uint64_t values. Serialization
# also includes the power_of_s (int) and the error (double).
bytes_per_element = (
tf.cast(context.num_slots, dtype=tf.int64)
* 2
* (tf.size(context.main_moduli, out_type=tf.int64) * 8 + 4 + 8)
)
bytes_per_element = num_slots * 2 * (num_main_moduli * 8 + 4 + 8)
max_elements = tf.cast(
tf.constant(int(UINT32_MAX * SAFETY_FACTOR), dtype=tf.int64)
/ bytes_per_element,
Expand Down Expand Up @@ -124,38 +122,45 @@ def split_tensor(tensor):
- List of tensor chunks
- Metadata dictionary with original shape and other info needed for reassembly
"""
if isinstance(tensor, tf_shell.ShellTensor64):
shape = tf_shell.shape(tensor)
total_elements = tensor._raw_tensor.shape.num_elements()
with tf.name_scope("large_tensor_split"):
if isinstance(tensor, tf_shell.ShellTensor64):
shape = tf_shell.shape(tensor)
total_elements = tensor._raw_tensor.shape.num_elements()

# Calculate split sizes
split_sizes = calculate_tf_shell_split_sizes(tensor._context, total_elements)
# Calculate split sizes
split_sizes = calculate_tf_shell_split_sizes(
tensor._context, total_elements
)

# Reshape tensor to 1D for splitting, ignoring the batch dimension
flat_tensor = tf_shell.reshape(tensor, [tensor._context.num_slots, -1])
# Reshape tensor to 1D for splitting, ignoring the batch dimension
flat_tensor = tf_shell.reshape(tensor, [tensor._context.num_slots, -1])

# Split into chunks of calculated sizes
chunks = tf_shell.split(flat_tensor, split_sizes, axis=1, num=MAX_NUM_SPLITS)
# Split into chunks of calculated sizes
chunks = tf_shell.split(
flat_tensor, split_sizes, axis=1, num=MAX_NUM_SPLITS
)

else:
shape = tf.shape(tensor)
total_elements = tf.reduce_prod(tf.cast(shape, tf.int64))
else:
shape = tf.shape(tensor)
total_elements = tf.reduce_prod(tf.cast(shape, tf.int64))

# Calculate split sizes
split_sizes = calculate_split_sizes(total_elements, tensor.dtype)
# Calculate split sizes
split_sizes = calculate_split_sizes(total_elements, tensor.dtype)

# Reshape tensor to 1D for splitting.
flat_tensor = tf.reshape(tensor, [-1])
# Reshape tensor to 1D for splitting.
flat_tensor = tf.reshape(tensor, [-1])

# Split into chunks of calculated sizes
chunks = tf_shell.split(flat_tensor, split_sizes, axis=0, num=MAX_NUM_SPLITS)
# Split into chunks of calculated sizes
chunks = tf_shell.split(
flat_tensor, split_sizes, axis=0, num=MAX_NUM_SPLITS
)

metadata = {
"original_shape": shape,
"split_sizes": split_sizes,
}
metadata = {
"original_shape": shape,
"split_sizes": split_sizes,
}

return chunks, metadata
return chunks, metadata


def split_tensor_list(tensors):
Expand Down Expand Up @@ -192,17 +197,17 @@ def reassemble_tensor(chunks, metadata):
Returns:
Reassembled tensor with original shape
"""
# Concatenate chunks
if isinstance(chunks[0], tf_shell.ShellTensor64):
flat_tensor = tf_shell.concat(chunks, axis=1)
else:
flat_tensor = tf.concat(chunks, axis=0)
with tf.name_scope("large_tensor_reassemble"):
# Concatenate chunks
if isinstance(chunks[0], tf_shell.ShellTensor64):
flat_tensor = tf_shell.concat(chunks, axis=1)
else:
flat_tensor = tf.concat(chunks, axis=0)

# Reshape to original shape
original_shape = metadata["original_shape"]
reassembled_tensor = tf_shell.reshape(flat_tensor, original_shape)
# Reshape to original shape
reassembled_tensor = tf_shell.reshape(flat_tensor, metadata["original_shape"])

return reassembled_tensor
return reassembled_tensor


def reassemble_tensor_list(all_chunks, all_metadata):
Expand Down

0 comments on commit 368f3b1

Please sign in to comment.