Skip to content

Commit

Permalink
Rotation keys support copy semantics.
Browse files Browse the repository at this point in the history
  • Loading branch information
james-choncholas committed Sep 24, 2024
1 parent 79f9fff commit a8bdb36
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 97 deletions.
6 changes: 3 additions & 3 deletions tf_shell/cc/kernels/mul_kernels.cc
Original file line number Diff line number Diff line change
Expand Up @@ -574,8 +574,8 @@ class MatMulPtCtOp : public OpKernel {
OP_REQUIRES_VALUE(rotation_key_var, op_ctx,
GetVariant<RotationKeyVariant<T>>(op_ctx, 3));
}
std::vector<RotationKey> empty_rot_keys{};
std::vector<RotationKey> const& rot_keys =
std::vector<std::shared_ptr<RotationKey>> empty_rot_keys{};
std::vector<std::shared_ptr<RotationKey>> const& rot_keys =
use_fast_rotations ? empty_rot_keys : rotation_key_var->keys;

// b is a vector of Polynomials so first dimension is the number of
Expand Down Expand Up @@ -732,7 +732,7 @@ class MatMulPtCtOp : public OpKernel {
shift - 1 <
static_cast<int>(rot_keys.size()), // Skip key 0.
InvalidArgument("No key for shift of '", shift, "'"));
RotationKey const* k = &rot_keys[shift - 1]; // Skip key 0.
RotationKey const* k = rot_keys[shift].get();

// Rotate by the shift.
OP_REQUIRES_VALUE(auto ct_sub, op_ctx,
Expand Down
91 changes: 32 additions & 59 deletions tf_shell/cc/kernels/rotation_kernels.cc
Original file line number Diff line number Diff line change
Expand Up @@ -71,57 +71,36 @@ class RotationKeyGenOp : public OpKernel {
// Allocate the output tensor which is a scalar containing the rotation key.
Tensor* out;
OP_REQUIRES_OK(op_ctx, op_ctx->allocate_output(0, TensorShape{}, &out));
Gadget* gadget = nullptr;
{ // Add a scope to ensure the gadget created on the stack cannot be used
// after it is std::moved into the output tensor.
// Create the gadget.
int level = shell_ctx->NumMainPrimeModuli() - 1;
OP_REQUIRES_VALUE(auto q_hats, op_ctx,
shell_ctx->MainPrimeModulusComplements(level));
OP_REQUIRES_VALUE(auto q_hat_invs, op_ctx,
shell_ctx->MainPrimeModulusCrtFactors(level));
std::vector<size_t> log_bs(shell_ctx->NumMainPrimeModuli(),
kLogGadgetBase);
OP_REQUIRES_VALUE(
Gadget raw_gadget, op_ctx,
Gadget::Create(shell_ctx->LogN(), log_bs, q_hats, q_hat_invs,
shell_ctx->MainPrimeModuli()));

// Store the gadget in a variant. Once it has been std::moved into its
// final memory location in the output tensor, it can be used to create
// the rotation keys.
RotationKeyVariant<T> v_out(std::move(raw_gadget));
out->scalar<Variant>()() = std::move(v_out);
}
RotationKeyVariant<T>* key_variant =
out->scalar<Variant>()(0).get<RotationKeyVariant<T>>();
OP_REQUIRES(op_ctx, key_variant != nullptr,
InvalidArgument(
"RotationKeyVariant did not unwrap successfully. Saw: '",
out->scalar<Variant>()().DebugString(), "'"));
gadget = &key_variant->gadget;
// Create the output variant
RotationKeyVariant<T> v_out;

// Create the gadget.
int level = shell_ctx->NumMainPrimeModuli() - 1;
OP_REQUIRES_VALUE(auto q_hats, op_ctx,
shell_ctx->MainPrimeModulusComplements(level));
OP_REQUIRES_VALUE(auto q_hat_invs, op_ctx,
shell_ctx->MainPrimeModulusCrtFactors(level));
std::vector<size_t> log_bs(shell_ctx->NumMainPrimeModuli(), kLogGadgetBase);
OP_REQUIRES_VALUE(Gadget raw_gadget, op_ctx,
Gadget::Create(shell_ctx->LogN(), log_bs, q_hats,
q_hat_invs, shell_ctx->MainPrimeModuli()));

auto gadget_ptr = std::make_shared<Gadget>(std::move(raw_gadget));
v_out.gadget = gadget_ptr;

// This method of rotation only allows us to rotate within half of the
// polynomial slots. E.g. for n slots, slot 0 can be rotated to at most
// n/2-1 and n/2 to n-1. This has implications for how batching is done if
// performing back propagation under encryption.
int num_rotation_keys = 1 << (shell_ctx->LogN() - 1);
int two_n = 1 << (shell_ctx->LogN() + 1);

// Create a temp Tensor to hold the individual rotation keys as they are
// generated. After they are all finished, they are moved into the scalar
// output tensor so they are all together and easier to keep track of.
Tensor individual_keys;
OP_REQUIRES_OK(
op_ctx, op_ctx->allocate_temp(tensorflow::DT_VARIANT,
{num_rotation_keys}, &individual_keys));
auto flat_key_buffer = individual_keys.flat<Variant>();
v_out.keys.resize(num_rotation_keys);

auto variance = secret_key->Variance();
auto t = shell_ctx->PlaintextModulus();

auto generate_keys_in_range = [&, secret_key, gadget](int start, int end) {
// Skip rotation key at zero.
auto generate_keys_in_range = [&](int start, int end) {
// Skip the rotation key at zero; it does not rotate.
if (start == 0) ++start;

uint sub_power = base_power;
Expand All @@ -133,10 +112,9 @@ class RotationKeyGenOp : public OpKernel {
for (int i = start; i < end; ++i) {
OP_REQUIRES_VALUE(
RotationKey k, op_ctx,
RotationKey::CreateForBgv(*secret_key, sub_power, variance, gadget,
t, kPrngType));
SingleRotationKeyVariant<T> k_out(std::move(k));
flat_key_buffer(i) = std::move(k_out);
RotationKey::CreateForBgv(*secret_key, sub_power, variance,
gadget_ptr.get(), t, kPrngType));
v_out.keys[i] = std::move(std::make_shared<RotationKey>(k));
sub_power *= base_power;
sub_power %= two_n;
}
Expand All @@ -148,12 +126,7 @@ class RotationKeyGenOp : public OpKernel {
thread_pool->ParallelFor(num_rotation_keys, cost_per_key,
generate_keys_in_range);

// Move the keys out of the buffer and into the output tensor.
key_variant->keys.reserve(num_rotation_keys);
for (int i = 1; i < num_rotation_keys; ++i) {
key_variant->keys.push_back(std::move(
flat_key_buffer(i).get<SingleRotationKeyVariant<T>>()->key));
}
out->scalar<Variant>()() = std::move(v_out);
}
};

Expand All @@ -174,7 +147,8 @@ class RollOp : public OpKernel {
OP_REQUIRES_VALUE(RotationKeyVariant<T> const* rotation_key_var, op_ctx,
GetVariant<RotationKeyVariant<T>>(op_ctx, 1));

std::vector<RotationKey> const& keys = rotation_key_var->keys;
std::vector<std::shared_ptr<RotationKey>> const& keys =
rotation_key_var->keys;

Tensor const& value = op_ctx->input(2);

Expand Down Expand Up @@ -216,11 +190,9 @@ class RollOp : public OpKernel {

RotationKey const* key;
if (shift != 0) {
OP_REQUIRES(
op_ctx,
shift - 1 < static_cast<int64>(keys.size()), // Skip key at zero.
InvalidArgument("No key for shift of '", shift, "'"));
key = &keys[shift - 1]; // Skip key at zero.
OP_REQUIRES(op_ctx, shift < static_cast<int64>(keys.size()),
InvalidArgument("No key for shift of '", shift, "'"));
key = keys[shift].get();
}

auto roll_in_range = [&](int start, int end) {
Expand Down Expand Up @@ -286,7 +258,8 @@ class ReduceSumByRotationOp : public OpKernel {
OP_REQUIRES_VALUE(RotationKeyVariant<T> const* rotation_key_var, op_ctx,
GetVariant<RotationKeyVariant<T>>(op_ctx, 1));

std::vector<RotationKey> const& keys = rotation_key_var->keys;
std::vector<std::shared_ptr<RotationKey>> const& keys =
rotation_key_var->keys;
Tensor const& value = op_ctx->input(2);
OP_REQUIRES(op_ctx, value.dim_size(0) > 0,
InvalidArgument("Cannot reduce_sum an empty ciphertext."));
Expand Down Expand Up @@ -337,9 +310,9 @@ class ReduceSumByRotationOp : public OpKernel {
for (int shift = 1; shift < num_slots / 2; shift <<= 1) {
OP_REQUIRES(
op_ctx,
shift - 1 < static_cast<int64>(keys.size()), // Skip key at zero.
shift < static_cast<int64>(keys.size()),
InvalidArgument("No key for shift of '", shift, "'"));
RotationKey const* key = &keys[shift - 1]; // Skip key at zero.
auto key = keys[shift];

// Rotate by the shift.
OP_REQUIRES_VALUE(auto ct_sub, op_ctx,
Expand Down
35 changes: 4 additions & 31 deletions tf_shell/cc/kernels/rotation_variants.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,6 @@ class RotationKeyVariant {
public:
RotationKeyVariant() {}

// Create with gadget first, then create and add keys.
RotationKeyVariant(Gadget gadget) : gadget(gadget) {}

static inline char const kTypeName[] = "ShellRotationKeyVariant";

std::string TypeName() const { return kTypeName; }
Expand All @@ -48,34 +45,10 @@ class RotationKeyVariant {

std::string DebugString() const { return "ShellRotationKeyVariant"; }

Gadget gadget;
std::vector<RotationKey> keys;
};

template <typename T>
class SingleRotationKeyVariant {
using ModularInt = rlwe::MontgomeryInt<T>;
using RotationKey = rlwe::RnsGaloisKey<ModularInt>;

public:
SingleRotationKeyVariant() {}

// Create with gadget first, then create and add keys.
SingleRotationKeyVariant(RotationKey key) : key(key) {}

static inline char const kTypeName[] = "SingleRotationKeyVariant";

std::string TypeName() const { return kTypeName; }

// Individual keys are never sent over the network.
void Encode(VariantTensorData* data) const {};

// Individual keys are never sent over the network.
bool Decode(VariantTensorData const& data) { return false; };

std::string DebugString() const { return "SingleRotationKeyVariant"; }

RotationKey key;
// Each key holds a raw (non-owning) pointer to the gadget, so the gadget is
// stored here as a shared_ptr: copies of this variant share one gadget whose
// lifetime outlasts every key that points at it, making copies safe.
std::shared_ptr<Gadget> gadget;
std::vector<std::shared_ptr<RotationKey>> keys;
};

template <typename T>
Expand Down
7 changes: 3 additions & 4 deletions tf_shell/cc/kernels/segment_kernels.cc
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ template <typename Device, typename T, typename Index, typename InitialValueF,
struct UnsortedSegmentFunctor {
void operator()(
OpKernelContext* ctx, ContextVariant<T> const* shell_ctx_var,
std::vector<rlwe::RnsGaloisKey<rlwe::MontgomeryInt<T>>> const& keys,
std::vector<std::shared_ptr<rlwe::RnsGaloisKey<rlwe::MontgomeryInt<T>>>> const& keys,
TensorShape const& segment_ids_shape,
typename TTypes<Index, 2>::ConstTensor segment_ids,
typename TTypes<Variant, 2>::ConstTensor data,
Expand Down Expand Up @@ -117,7 +117,7 @@ struct UnsortedSegmentFunctor<CPUDevice, T, Index, InitialValueF, ReductionF> {
using RotationKey = rlwe::RnsGaloisKey<ModularInt>;

void operator()(OpKernelContext* ctx, ContextVariant<T> const* shell_ctx_var,
std::vector<RotationKey> const& keys,
std::vector<std::shared_ptr<RotationKey>> const& keys,
TensorShape const& segment_ids_shape,
typename TTypes<Index, 2>::ConstTensor segment_ids,
typename TTypes<Variant, 2>::ConstTensor data,
Expand Down Expand Up @@ -308,10 +308,9 @@ struct UnsortedSegmentFunctor<CPUDevice, T, Index, InitialValueF, ReductionF> {
RotationKey const* key;
int64_t key_slot = slot;
if (key_slot > num_slots / 2) key_slot = slot - num_slots / 2;
--key_slot; // -1 to skip key at zero.
OP_REQUIRES(ctx, key_slot < static_cast<int64_t>(keys.size()),
InvalidArgument("No key for slot '", key_slot, "'"));
key = &keys[key_slot];
key = keys[key_slot].get();

SymmetricCt const& ct =
unreduced_output(j, chip).get<SymmetricCtVariant<T>>()->ct;
Expand Down

0 comments on commit a8bdb36

Please sign in to comment.