Skip to content

Commit

Permalink
Segment sum plaintext implementation to mimic tf-shell.
Browse files Browse the repository at this point in the history
  • Loading branch information
james-choncholas committed Jan 21, 2025
1 parent fd5cbed commit afe5601
Show file tree
Hide file tree
Showing 3 changed files with 102 additions and 45 deletions.
75 changes: 61 additions & 14 deletions tf_shell/python/shell_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1327,26 +1327,73 @@ def segment_sum(x, segments, num_segments, rotation_key=None, reduction="galois"
reduction_count,
)
elif isinstance(x, tf.Tensor):
# tf-shell segment functions differs from tensorflow in the following
# ways: First, the ciphertext dimension is included in the output, but
# only one dimension is valid. For the top half of the ciphertext, the
# first dimension is valid, and for the bottom half, the `num_slots //
# 2`th dimension is valid.
# Second, the reduction only happens across half of the batching
# tf-shell segment functions differ from tensorflow in the
# following ways:
# 1) The ciphertext dimension is included in the output, but only one
# dimension is valid. For the top half of the ciphertext, the first
# dimension is valid, and for the bottom half, the `num_slots // 2`th
# dimension is valid.
# 2) The reduction (if enabled) only happens across half of the batching
# dimension, due to how rotations in tf-shell work. Segment reduction
# happens on the top and bottom halves of the ciphertext independently.

def bincount(x):
return tf.math.bincount(
x, minlength=num_segments, maxlength=num_segments, dtype=segments.dtype
)

segments_nonnegative = tf.where(segments >= 0, segments, num_segments + 1)
pt_counts = tf.map_fn(
bincount, segments_nonnegative, fn_output_signature=segments.dtype
)

if reduction == "none":
raise ValueError("Plaintext segment_sum does not support `none` reduction.")
half_slots = x.shape[0] // 2
padding = tf.zeros_like(x[:half_slots])

x_top = tf.concat([x[:half_slots], padding], 0)
x_bottom = tf.concat([padding, x[half_slots:]], 0)
# TensorFlow's segment_sum does not work "per example" over the
# first dimension the way tf-shell does. Emulate this with map_fn.
def per_example_segment_sum(tupl):
x, segments = tupl

# Fake batch size.
x = tf.expand_dims(x, 0)
segments = tf.expand_dims(segments, 0)

return tf.math.unsorted_segment_sum(x, segments, num_segments)

output_signature = tf.TensorSpec.from_tensor(
per_example_segment_sum((x[0], segments[0]))
)
res = tf.map_fn(
per_example_segment_sum,
(x, segments),
fn_output_signature=output_signature,
name="shell_segment_sum_tf",
)
return res, pt_counts

else:
x_top, x_bottom = tf.split(x, num_or_size_splits=2, axis=0)

padding = tf.zeros_like(x_top)

x_top = tf.concat([x_top, padding], 0)
x_bottom = tf.concat([padding, x_bottom], 0)

res_top = tf.math.unsorted_segment_sum(x_top, segments, num_segments)
res_bottom = tf.math.unsorted_segment_sum(x_bottom, segments, num_segments)

res_top = tf.expand_dims(res_top, 0)
res_top = tf.repeat(res_top, repeats=x.shape[0] // 2, axis=0)
res_top = tf.concat([res_top, tf.zeros_like(res_top)], 0)

top_answer = tf.math.unsorted_segment_sum(x_top, segments, num_segments)
bottom_answer = tf.math.unsorted_segment_sum(x_bottom, segments, num_segments)
res_bottom = tf.expand_dims(res_bottom, 0)
res_bottom = tf.repeat(res_bottom, repeats=x.shape[0] // 2, axis=0)
res_bottom = tf.concat([tf.zeros_like(res_bottom), res_bottom], 0)

return top_answer, bottom_answer
res_top = tf.expand_dims(res_top, 1)
res_bottom = tf.expand_dims(res_bottom, 1)
res = tf.concat([res_top, res_bottom], 1)
return res, pt_counts

else:
raise ValueError("Unsupported type for segment_sum")
Expand Down
40 changes: 13 additions & 27 deletions tf_shell/test/segment_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,24 +124,19 @@ def test_functor(ea, segments, num_segments, rot_key):
ess, counts = test_functor(
ea, segments, num_segments, test_context.rotation_key
)

ss = tf_shell.to_tensorflow(ess, test_context.key)

pt_ss_top, pt_ss_bottom = tf_shell.segment_sum(a, segments, num_segments)
pt_ss, pt_counts = tf_shell.segment_sum(a, segments, num_segments)

# Ensure the reduced data is correct.
self.assertAllClose(pt_ss_top, ss[0][0])
self.assertAllClose(ss[0][0], pt_ss[0][0])
self.assertAllClose(
pt_ss_bottom, ss[test_context.shell_context.num_slots // 2][1]
ss[test_context.shell_context.num_slots // 2][1],
pt_ss[test_context.shell_context.num_slots // 2][1],
)

# Ensure the counts are correct.
def bincount(x):
return tf.math.bincount(x, minlength=num_segments, maxlength=num_segments)

segments_nonnegative = tf.where(segments >= 0, segments, num_segments + 1)
pt_counts = tf.map_fn(bincount, segments_nonnegative)
self.assertAllEqual(pt_counts, counts)
self.assertAllEqual(counts, pt_counts)

# Ensure initial arguments are not modified.
self.assertAllClose(a, tf_shell.to_tensorflow(sa))
Expand Down Expand Up @@ -197,20 +192,16 @@ def test_functor(ea, segments, num_segments):
return ess, counts

ess, counts = test_functor(ea, segments, num_segments)

ss = tf_shell.to_tensorflow(ess, test_context.key)

pt_result = tf.math.unsorted_segment_sum(a, segments, num_segments)
pt_ss, pt_counts = tf_shell.segment_sum(
a, segments, num_segments, reduction="none"
)

# Ensure the data is correct.
self.assertAllClose(pt_result, tf.reduce_sum(ss, axis=0))
self.assertAllClose(ss, pt_ss)

# Ensure the counts are correct.
def bincount(x):
return tf.math.bincount(x, minlength=num_segments, maxlength=num_segments)

segments_nonnegative = tf.where(segments >= 0, segments, num_segments + 1)
pt_counts = tf.map_fn(bincount, segments_nonnegative)
self.assertAllEqual(pt_counts, counts)

# Ensure initial arguments are not modified.
Expand Down Expand Up @@ -276,23 +267,18 @@ def test_functor(ea, segments, num_segments, rot_key):
ess, counts = test_functor(
ea, segments, num_segments, test_context.rotation_key
)

ss = tf_shell.to_tensorflow(ess, test_context.key)

pt_ss_top, pt_ss_bottom = tf_shell.segment_sum(a, segments, num_segments)
pt_ss, pt_counts = tf_shell.segment_sum(a, segments, num_segments)

# Ensure the data is correctly reduced.
self.assertAllClose(pt_ss_top, ss[0][0])
self.assertAllClose(ss[0][0], pt_ss[0][0])
self.assertAllClose(
pt_ss_bottom, ss[test_context.shell_context.num_slots // 2][1]
ss[test_context.shell_context.num_slots // 2][1],
pt_ss[test_context.shell_context.num_slots // 2][1],
)

# Ensure the counts are correct.
def bincount(x):
return tf.math.bincount(x, minlength=num_segments, maxlength=num_segments)

segments_nonnegative = tf.where(segments >= 0, segments, num_segments + 1)
pt_counts = tf.map_fn(bincount, segments_nonnegative)
self.assertAllEqual(pt_counts, counts)

# Ensure initial arguments are not modified.
Expand Down
32 changes: 28 additions & 4 deletions tf_shell_ml/test/embedding_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,20 +85,37 @@ def forward_backward(x):

enc_dw, _ = embedding_layer.backward(enc_dy, rotation_key)
dw = tf_shell.to_tensorflow(enc_dw[0], key)

# Plaintext backward pass.
pt_dws, _ = embedding_layer.backward(dy)
pt_dw = pt_dws[0]

if reduction == "none":
dw = tf.reduce_sum(dw, axis=0)
pt_dw = tf.reduce_sum(pt_dw, axis=0)
else:
dw = embedding_layer.unpack(dw)
return dw, enc_dw[0].shape
pt_dw = embedding_layer.unpack(pt_dw)

dw, shape_inf = forward_backward(x)
return dw, enc_dw[0].shape, pt_dw, pt_dws[0].shape

# dw, shape_inf = forward_backward(x)
dw, dw_shape_inf, pt_dw, pt_dw_shape_inf = forward_backward(x)

# Check the inferred shape of the gradient is correct.
if reduction == "none":
self.assertAllEqual(shape_inf, [context.num_slots, input_dim, output_dim])
self.assertAllEqual(
dw_shape_inf, [context.num_slots, input_dim, output_dim]
)
self.assertAllEqual(
pt_dw_shape_inf, [context.num_slots, input_dim, output_dim]
)
else:
self.assertAllEqual(
shape_inf, [context.num_slots, 2, input_dim, output_dim]
dw_shape_inf, [context.num_slots, 2, input_dim, output_dim]
)
self.assertAllEqual(
pt_dw_shape_inf, [context.num_slots, 2, input_dim, output_dim]
)

for i in range(0, input_dim):
Expand All @@ -110,9 +127,16 @@ def forward_backward(x):
context.num_slots * sentence_length, shape=(output_dim,)
),
)
self.assertAllEqual(
pt_dw[special_index, :],
tf.constant(
context.num_slots * sentence_length, shape=(output_dim,)
),
)
# Make sure the rest of the gradient elements are 0.
else:
self.assertAllEqual(dw[i, :], tf.constant(0, shape=(output_dim,)))
self.assertAllEqual(pt_dw[i, :], tf.constant(0, shape=(output_dim,)))

def test_embedding_eager(self):
tf.config.run_functions_eagerly(True)
Expand Down

0 comments on commit afe5601

Please sign in to comment.