Commit 8394649

ReLU derivative does not affect scaling factor.

james-choncholas committed Jan 20, 2025
1 parent 5c8ebd1 commit 8394649
Showing 3 changed files with 46 additions and 1 deletion.
1 change: 1 addition & 0 deletions tf_shell/__init__.py
@@ -24,6 +24,7 @@
 from tf_shell.python.shell_tensor import reduce_sum
 from tf_shell.python.shell_tensor import fast_reduce_sum
 from tf_shell.python.shell_tensor import reduce_sum_with_mod
+from tf_shell.python.shell_tensor import mask_with_pt
 from tf_shell.python.shell_tensor import matmul
 from tf_shell.python.shell_tensor import expand_dims
 from tf_shell.python.shell_tensor import reshape
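This export makes the new helper available as tf_shell.mask_with_pt, which is how tf_shell_ml/activation.py (below) calls it. A small smoke test of its plain-tensor branch, with made-up values and assuming tf_shell is installed:

import tensorflow as tf
import tf_shell

# The tf.Tensor branch needs no encryption context, so the new
# package-level export can be exercised directly.
dy = tf.constant([0.1, 0.2, 0.3])
mask = tf.constant([0.0, 1.0, 1.0])
print(tf_shell.mask_with_pt(dy, mask).numpy())  # [0.  0.2 0.3]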
41 changes: 41 additions & 0 deletions tf_shell/python/shell_tensor.py
@@ -766,6 +766,47 @@ def to_tensorflow(s_tensor, key=None):
     )


+def mask_with_pt(x, mask):
+    if not isinstance(mask, tf.Tensor):
+        raise ValueError(f"Mask must be a TensorFlow Tensor. Got {type(mask)}.")
+
+    if isinstance(x, ShellTensor64):
+        # Do not encode the mask with a scaling factor; it is 0 or 1.
+        shell_mask = to_shell_plaintext(mask, x._context, override_scaling_factor=1)
+
+        matched_x, matched_mask = _match_moduli(x, shell_mask)
+
+        if x._is_enc:
+            raw_result = shell_ops.mul_ct_pt64(
+                matched_x._context._get_context_at_level(matched_x._level),
+                matched_x._raw_tensor,
+                matched_mask._raw_tensor,
+            )
+        else:
+            raw_result = shell_ops.mul_pt_pt64(
+                matched_x._context._get_context_at_level(matched_x._level),
+                matched_x._raw_tensor,
+                matched_mask._raw_tensor,
+            )
+
+        return ShellTensor64(
+            _raw_tensor=raw_result,
+            _context=matched_x._context,
+            _level=matched_x._level,
+            _num_mod_reductions=matched_x._num_mod_reductions,
+            _underlying_dtype=matched_x._underlying_dtype,
+            _scaling_factor=matched_x._scaling_factor,  # Same scaling factor.
+            _is_enc=matched_x._is_enc,
+            _is_fast_rotated=matched_x._is_fast_rotated,
+        )
+
+    elif isinstance(x, tf.Tensor):
+        return x * mask
+
+    else:
+        raise ValueError(f"Unsupported type for mask_with_pt. Got {type(x)}.")
+
+
 def roll(x, shift, rotation_key=None):
     if isinstance(x, ShellTensor64):
         if not isinstance(rotation_key, ShellRotationKey64):
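Why encode the mask with override_scaling_factor=1? tf_shell's fixed-point encoding stores value * scaling_factor, and multiplying two encodings multiplies their scaling factors. Because the mask entries are exactly 0 or 1, encoding the mask at scale 1 leaves the other operand's scaling factor untouched, which is what the _scaling_factor=matched_x._scaling_factor line above records. A minimal plain-Python sketch of the bookkeeping (SCALE and encode are illustrative names, not part of the tf_shell API):

SCALE = 2**12  # hypothetical scaling factor

def encode(x, scale):
    # Fixed-point encoding: store x scaled up to an integer.
    return round(x * scale)

dy = encode(0.75, SCALE)

# Regular plaintext multiply: both operands carry SCALE, so the product
# is encoded at SCALE**2 and would need a rescale afterwards.
mask_scaled = encode(1.0, SCALE)
assert dy * mask_scaled == encode(0.75, SCALE**2)

# Mask encoded at scale 1 (what override_scaling_factor=1 achieves):
# the product keeps dy's original scaling factor.
mask_unit = encode(1.0, 1)
assert dy * mask_unit == encode(0.75, SCALE)

This is what lets relu_deriv below return the masked gradient without any extra rescaling step.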
5 changes: 4 additions & 1 deletion tf_shell_ml/activation.py
@@ -31,7 +31,10 @@ def relu_deriv(y, dy):
         tf.constant(1, dtype=tf.keras.backend.floatx()),
         tf.constant(0, dtype=tf.keras.backend.floatx()),
     )
-    return dy * mask
+
+    # Use a special tf_shell function to multiply by a binary mask without
+    # affecting the scaling factor when dy is a ShellTensor.
+    return tf_shell.mask_with_pt(dy, mask)
 
 
 def sigmoid(x):
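For intuition, a quick check of the plaintext path with made-up values (the exact mask condition sits outside the hunk shown above, so y > 0 is assumed here; for tf.Tensor inputs mask_with_pt reduces to an ordinary elementwise multiply):

import tensorflow as tf

y = tf.constant([-1.0, 0.5, 2.0])   # made-up activation outputs
dy = tf.constant([0.1, 0.2, 0.3])   # made-up upstream gradients

# Binary mask, mirroring the tf.where call in relu_deriv above.
mask = tf.where(
    y > 0,
    tf.constant(1, dtype=tf.keras.backend.floatx()),
    tf.constant(0, dtype=tf.keras.backend.floatx()),
)

# For plain tensors, mask_with_pt(dy, mask) is just dy * mask: the
# gradient passes only where the activation was positive.
print((dy * mask).numpy())  # [0.  0.2 0.3]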
