diff --git a/tf_shell/cc/kernels/mul_kernels.cc b/tf_shell/cc/kernels/mul_kernels.cc index 6755b03..442a97b 100644 --- a/tf_shell/cc/kernels/mul_kernels.cc +++ b/tf_shell/cc/kernels/mul_kernels.cc @@ -574,8 +574,8 @@ class MatMulPtCtOp : public OpKernel { OP_REQUIRES_VALUE(rotation_key_var, op_ctx, GetVariant>(op_ctx, 3)); } - std::vector empty_rot_keys{}; - std::vector const& rot_keys = + std::vector> empty_rot_keys{}; + std::vector> const& rot_keys = use_fast_rotations ? empty_rot_keys : rotation_key_var->keys; // b is a vector of Polynomials so first dimension is the number of @@ -732,7 +732,7 @@ class MatMulPtCtOp : public OpKernel { shift - 1 < static_cast(rot_keys.size()), // Skip key 0. InvalidArgument("No key for shift of '", shift, "'")); - RotationKey const* k = &rot_keys[shift - 1]; // Skip key 0. + RotationKey const* k = rot_keys[shift].get(); // Rotate by the shift. OP_REQUIRES_VALUE(auto ct_sub, op_ctx, diff --git a/tf_shell/cc/kernels/rotation_kernels.cc b/tf_shell/cc/kernels/rotation_kernels.cc index 4db97c7..8de339c 100644 --- a/tf_shell/cc/kernels/rotation_kernels.cc +++ b/tf_shell/cc/kernels/rotation_kernels.cc @@ -71,35 +71,22 @@ class RotationKeyGenOp : public OpKernel { // Allocate the output tensor which is a scalar containing the rotation key. Tensor* out; OP_REQUIRES_OK(op_ctx, op_ctx->allocate_output(0, TensorShape{}, &out)); - Gadget* gadget = nullptr; - { // Add a scope to ensure the gadget created on the stack cannot be used - // after it is std::moved into the output tensor. - // Create the gadget. - int level = shell_ctx->NumMainPrimeModuli() - 1; - OP_REQUIRES_VALUE(auto q_hats, op_ctx, - shell_ctx->MainPrimeModulusComplements(level)); - OP_REQUIRES_VALUE(auto q_hat_invs, op_ctx, - shell_ctx->MainPrimeModulusCrtFactors(level)); - std::vector log_bs(shell_ctx->NumMainPrimeModuli(), - kLogGadgetBase); - OP_REQUIRES_VALUE( - Gadget raw_gadget, op_ctx, - Gadget::Create(shell_ctx->LogN(), log_bs, q_hats, q_hat_invs, - shell_ctx->MainPrimeModuli())); - - // Store the gadget in a variant. Once it has been std::moved into it's - // final memory location in the output tensor, it can be used to create - // the rotation keys. - RotationKeyVariant v_out(std::move(raw_gadget)); - out->scalar()() = std::move(v_out); - } - RotationKeyVariant* key_variant = - out->scalar()(0).get>(); - OP_REQUIRES(op_ctx, key_variant != nullptr, - InvalidArgument( - "RotationKeyVariant did not unwrap successfully. Saw: '", - out->scalar()().DebugString(), "'")); - gadget = &key_variant->gadget; + // Create the output variant + RotationKeyVariant v_out; + + // Create the gadget. + int level = shell_ctx->NumMainPrimeModuli() - 1; + OP_REQUIRES_VALUE(auto q_hats, op_ctx, + shell_ctx->MainPrimeModulusComplements(level)); + OP_REQUIRES_VALUE(auto q_hat_invs, op_ctx, + shell_ctx->MainPrimeModulusCrtFactors(level)); + std::vector log_bs(shell_ctx->NumMainPrimeModuli(), kLogGadgetBase); + OP_REQUIRES_VALUE(Gadget raw_gadget, op_ctx, + Gadget::Create(shell_ctx->LogN(), log_bs, q_hats, + q_hat_invs, shell_ctx->MainPrimeModuli())); + + auto gadget_ptr = std::make_shared(std::move(raw_gadget)); + v_out.gadget = gadget_ptr; // This method of rotation only allows us to rotate within half of the // polynomial slots. E.g. for n slots, slot 0 can be rotated to at most @@ -107,21 +94,13 @@ class RotationKeyGenOp : public OpKernel { // performing back propagation under encryption. int num_rotation_keys = 1 << (shell_ctx->LogN() - 1); int two_n = 1 << (shell_ctx->LogN() + 1); - - // Create a temp Tensor to hold the individual rotation keys as they are - // generated. After they are all finished, they are moved into the scalar - // output tensor so they are all together and easier to keep track of. - Tensor individual_keys; - OP_REQUIRES_OK( - op_ctx, op_ctx->allocate_temp(tensorflow::DT_VARIANT, - {num_rotation_keys}, &individual_keys)); - auto flat_key_buffer = individual_keys.flat(); + v_out.keys.resize(num_rotation_keys); auto variance = secret_key->Variance(); auto t = shell_ctx->PlaintextModulus(); - auto generate_keys_in_range = [&, secret_key, gadget](int start, int end) { - // Skip rotation key at zero. + auto generate_keys_in_range = [&](int start, int end) { + // Skip rotation key at zero, it does not rotate. if (start == 0) ++start; uint sub_power = base_power; @@ -133,10 +112,9 @@ class RotationKeyGenOp : public OpKernel { for (int i = start; i < end; ++i) { OP_REQUIRES_VALUE( RotationKey k, op_ctx, - RotationKey::CreateForBgv(*secret_key, sub_power, variance, gadget, - t, kPrngType)); - SingleRotationKeyVariant k_out(std::move(k)); - flat_key_buffer(i) = std::move(k_out); + RotationKey::CreateForBgv(*secret_key, sub_power, variance, + gadget_ptr.get(), t, kPrngType)); + v_out.keys[i] = std::move(std::make_shared(k)); sub_power *= base_power; sub_power %= two_n; } @@ -148,12 +126,7 @@ class RotationKeyGenOp : public OpKernel { thread_pool->ParallelFor(num_rotation_keys, cost_per_key, generate_keys_in_range); - // Move the keys out of the buffer and into the output tensor. - key_variant->keys.reserve(num_rotation_keys); - for (int i = 1; i < num_rotation_keys; ++i) { - key_variant->keys.push_back(std::move( - flat_key_buffer(i).get>()->key)); - } + out->scalar()() = std::move(v_out); } }; @@ -174,7 +147,8 @@ class RollOp : public OpKernel { OP_REQUIRES_VALUE(RotationKeyVariant const* rotation_key_var, op_ctx, GetVariant>(op_ctx, 1)); - std::vector const& keys = rotation_key_var->keys; + std::vector> const& keys = + rotation_key_var->keys; Tensor const& value = op_ctx->input(2); @@ -216,11 +190,9 @@ class RollOp : public OpKernel { RotationKey const* key; if (shift != 0) { - OP_REQUIRES( - op_ctx, - shift - 1 < static_cast(keys.size()), // Skip key at zero. - InvalidArgument("No key for shift of '", shift, "'")); - key = &keys[shift - 1]; // Skip key at zero. + OP_REQUIRES(op_ctx, shift < static_cast(keys.size()), + InvalidArgument("No key for shift of '", shift, "'")); + key = keys[shift].get(); } auto roll_in_range = [&](int start, int end) { @@ -286,7 +258,8 @@ class ReduceSumByRotationOp : public OpKernel { OP_REQUIRES_VALUE(RotationKeyVariant const* rotation_key_var, op_ctx, GetVariant>(op_ctx, 1)); - std::vector const& keys = rotation_key_var->keys; + std::vector> const& keys = + rotation_key_var->keys; Tensor const& value = op_ctx->input(2); OP_REQUIRES(op_ctx, value.dim_size(0) > 0, InvalidArgument("Cannot reduce_sum an empty ciphertext.")); @@ -337,9 +310,9 @@ class ReduceSumByRotationOp : public OpKernel { for (int shift = 1; shift < num_slots / 2; shift <<= 1) { OP_REQUIRES( op_ctx, - shift - 1 < static_cast(keys.size()), // Skip key at zero. + shift < static_cast(keys.size()), InvalidArgument("No key for shift of '", shift, "'")); - RotationKey const* key = &keys[shift - 1]; // Skip key at zero. + auto key = keys[shift]; // Rotate by the shift. OP_REQUIRES_VALUE(auto ct_sub, op_ctx, diff --git a/tf_shell/cc/kernels/rotation_variants.h b/tf_shell/cc/kernels/rotation_variants.h index b882408..269a13f 100644 --- a/tf_shell/cc/kernels/rotation_variants.h +++ b/tf_shell/cc/kernels/rotation_variants.h @@ -33,9 +33,6 @@ class RotationKeyVariant { public: RotationKeyVariant() {} - // Create with gadget first, then create and add keys. - RotationKeyVariant(Gadget gadget) : gadget(gadget) {} - static inline char const kTypeName[] = "ShellRotationKeyVariant"; std::string TypeName() const { return kTypeName; } @@ -48,34 +45,10 @@ class RotationKeyVariant { std::string DebugString() const { return "ShellRotationKeyVariant"; } - Gadget gadget; - std::vector keys; -}; - -template -class SingleRotationKeyVariant { - using ModularInt = rlwe::MontgomeryInt; - using RotationKey = rlwe::RnsGaloisKey; - - public: - SingleRotationKeyVariant() {} - - // Create with gadget first, then create and add keys. - SingleRotationKeyVariant(RotationKey key) : key(key) {} - - static inline char const kTypeName[] = "SingleRotationKeyVariant"; - - std::string TypeName() const { return kTypeName; } - - // Individual keys are never sent over the network. - void Encode(VariantTensorData* data) const {}; - - // Individual keys are never sent over the network. - bool Decode(VariantTensorData const& data) { return false; }; - - std::string DebugString() const { return "SingleRotationKeyVariant"; } - - RotationKey key; + // Each key holds a raw pointer to gadget. Use a smart pointer to the gadget + // to help with copy semantics. + std::shared_ptr gadget; + std::vector> keys; }; template diff --git a/tf_shell/cc/kernels/segment_kernels.cc b/tf_shell/cc/kernels/segment_kernels.cc index 9fa2c2c..0a0b19d 100644 --- a/tf_shell/cc/kernels/segment_kernels.cc +++ b/tf_shell/cc/kernels/segment_kernels.cc @@ -73,7 +73,7 @@ template const* shell_ctx_var, - std::vector>> const& keys, + std::vector>>> const& keys, TensorShape const& segment_ids_shape, typename TTypes::ConstTensor segment_ids, typename TTypes::ConstTensor data, @@ -117,7 +117,7 @@ struct UnsortedSegmentFunctor { using RotationKey = rlwe::RnsGaloisKey; void operator()(OpKernelContext* ctx, ContextVariant const* shell_ctx_var, - std::vector const& keys, + std::vector> const& keys, TensorShape const& segment_ids_shape, typename TTypes::ConstTensor segment_ids, typename TTypes::ConstTensor data, @@ -308,10 +308,9 @@ struct UnsortedSegmentFunctor { RotationKey const* key; int64_t key_slot = slot; if (key_slot > num_slots / 2) key_slot = slot - num_slots / 2; - --key_slot; // -1 to skip key at zero. OP_REQUIRES(ctx, key_slot < static_cast(keys.size()), InvalidArgument("No key for slot '", key_slot, "'")); - key = &keys[key_slot]; + key = keys[key_slot].get(); SymmetricCt const& ct = unreduced_output(j, chip).get>()->ct;