From 941e81c2c427802ebbe201cb3e4a404629d0ef7a Mon Sep 17 00:00:00 2001 From: james-choncholas Date: Wed, 25 Sep 2024 07:12:40 +0000 Subject: [PATCH] Fix issues with memory semantics for distributed execution. SHELL objects like RnsBgvCiphertext hold raw pointers to moduli. These are all derived from a leader unique_ptr stored in an RnsContext object. This causes problems when a ciphertext is sent to another machine. This commit encodes SHELLs memory semantics with TensorFlows execution manager. --- examples/label_dp_sgd.ipynb | 55 ++--- tf_shell/__init__.py | 2 - tf_shell/cc/kernels/add_kernels.cc | 22 +- tf_shell/cc/kernels/context_kernels.cc | 20 +- tf_shell/cc/kernels/mod_switch_kernels.cc | 20 +- tf_shell/cc/kernels/mul_kernels.cc | 37 ++-- tf_shell/cc/kernels/rotation_kernels.cc | 60 ++++-- tf_shell/cc/kernels/rotation_kernels_fast.cc | 16 +- tf_shell/cc/kernels/segment_kernels.cc | 26 ++- tf_shell/cc/ops/shell_ops.cc | 6 + tf_shell/cc/optimizers/moduli_autotune.cc | 10 +- tf_shell/cc/optimizers/pt_pt.cc | 2 +- tf_shell/python/shell_context.py | 47 +++-- tf_shell/python/shell_key.py | 201 +++++++++++++------ tf_shell/python/shell_tensor.py | 92 ++++++--- tf_shell/test/auto_param_optimizer_test.py | 118 +++++++++-- tf_shell/test/distribution_test.py | 7 + tf_shell/test/pt_pt_optimizer_test.py | 12 +- tf_shell/test/rotation_test.py | 11 + tf_shell/test/rotation_test_fast.py | 14 +- tf_shell/test/test_utils.py | 5 +- tf_shell_ml/model.py | 42 ++-- 22 files changed, 557 insertions(+), 268 deletions(-) diff --git a/examples/label_dp_sgd.ipynb b/examples/label_dp_sgd.ipynb index cc2e824..00bcb6c 100644 --- a/examples/label_dp_sgd.ipynb +++ b/examples/label_dp_sgd.ipynb @@ -38,8 +38,8 @@ "name": "stderr", "output_type": "stream", "text": [ - "2024-09-13 00:05:53.757039: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.\n", - "2024-09-13 00:05:53.780113: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n", + "2024-09-23 06:17:56.606856: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.\n", + "2024-09-23 06:17:56.633047: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n", "To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n" ] } @@ -90,27 +90,7 @@ "cell_type": "code", "execution_count": 3, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Model: \"tf_shell_sequential\"\n", - "_________________________________________________________________\n", - " Layer (type) Output Shape Param # \n", - "=================================================================\n", - " shell_dense (ShellDense) (4096, 64) 50176 \n", - " \n", - " shell_dense_1 (ShellDense) (4096, 10) 640 \n", - " \n", - "=================================================================\n", - "Total params: 50816 (198.50 KB)\n", - "Trainable params: 50816 (198.50 KB)\n", - "Non-trainable params: 0 (0.00 Byte)\n", - "_________________________________________________________________\n" - ] - } - ], + "outputs": [], "source": [ "# Turn on the shell optimizer to use autocontext.\n", "shell_optimizers.enable_tf_shell_optimizer()\n", @@ -135,8 +115,6 @@ " scaling_factor=3,\n", " noise_offset_log2=68,\n", " ),\n", - " None,\n", - " None,\n", " True,\n", ")\n", "\n", @@ -145,13 +123,11 @@ " optimizer=tf.keras.optimizers.Adam(0.1),\n", " loss=tf.keras.losses.CategoricalCrossentropy(),\n", " metrics=[tf.keras.metrics.CategoricalAccuracy()],\n", - " # metrics=[\"accuracy\"],\n", - " # metrics=[keras.metrics.SparseCategoricalAccuracy()],\n", ")\n", "\n", - "m.build([batch_size, 784])\n", + "# m.build([batch_size, 784]) # do not build if using autoparams\n", "# m(train_dataset)\n", - "m.summary()\n" + "# m.summary()\n" ] }, { @@ -163,9 +139,9 @@ "name": "stderr", "output_type": "stream", "text": [ - "2024-09-13 00:05:55.276666: I external/local_tsl/tsl/profiler/lib/profiler_session.cc:104] Profiler session initializing.\n", - "2024-09-13 00:05:55.276687: I external/local_tsl/tsl/profiler/lib/profiler_session.cc:119] Profiler session started.\n", - "2024-09-13 00:05:55.276871: I external/local_tsl/tsl/profiler/lib/profiler_session.cc:131] Profiler session tear down.\n" + "2024-09-23 06:17:58.163788: I external/local_tsl/tsl/profiler/lib/profiler_session.cc:104] Profiler session initializing.\n", + "2024-09-23 06:17:58.163815: I external/local_tsl/tsl/profiler/lib/profiler_session.cc:119] Profiler session started.\n", + "2024-09-23 06:17:58.163882: I external/local_tsl/tsl/profiler/lib/profiler_session.cc:131] Profiler session tear down.\n" ] }, { @@ -176,7 +152,11 @@ "log_n: 12\n", "t: 65537\n", "qs: 288230376151760897 288230376152137729 \n", - "14/14 [==============================] - 111s 8s/step - categorical_accuracy: 0.0000e+00 - val_categorical_accuracy: 0.6646\n" + "Final parameters:\n", + "log_n: 12\n", + "t: 65537\n", + "qs: 288230376151760897 288230376152137729 \n", + "15/15 [==============================] - 109s 7s/step - num_slots: 4096.0000 - val_categorical_accuracy: 0.0973\n" ] } ], @@ -192,6 +172,15 @@ "\n", "history = m.fit(train_dataset, epochs=1, validation_data=val_dataset, callbacks = [tboard_callback])" ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "m.summary()" + ] } ], "metadata": { diff --git a/tf_shell/__init__.py b/tf_shell/__init__.py index 09c2d1b..927811d 100644 --- a/tf_shell/__init__.py +++ b/tf_shell/__init__.py @@ -37,8 +37,6 @@ from tf_shell.python.shell_key import ShellKey64 from tf_shell.python.shell_key import create_key64 -from tf_shell.python.shell_key import mod_reduce_key64 - from tf_shell.python.shell_key import ShellRotationKey64 from tf_shell.python.shell_key import create_rotation_key64 from tf_shell.python.shell_key import ShellFastRotationKey64 diff --git a/tf_shell/cc/kernels/add_kernels.cc b/tf_shell/cc/kernels/add_kernels.cc index 504deb6..405d3e6 100644 --- a/tf_shell/cc/kernels/add_kernels.cc +++ b/tf_shell/cc/kernels/add_kernels.cc @@ -144,8 +144,11 @@ class AddCtCtOp : public OpKernel { ShellAddSub add_or_sub; OP_REQUIRES_VALUE(SymmetricCt ct_c, op_ctx, add_or_sub(ct_a, ct_b)); - SymmetricCtVariant ct_c_var(std::move(ct_c), shell_ctx_var->ct_context_, - shell_ctx_var->error_params_); + // SHELL's addition preserves moduli pointers of the first input. + // Ensure the output holds smart pointers to the input's context to + // prevent premature deletion of the moduli. + SymmetricCtVariant ct_c_var(std::move(ct_c), ct_a_var->ct_context, + ct_a_var->error_params); flat_output(i) = std::move(ct_c_var); } } @@ -210,8 +213,12 @@ class AddCtPtOp : public OpKernel { ShellAddSub add_or_sub; OP_REQUIRES_VALUE(SymmetricCt ct_c, op_ctx, add_or_sub(ct_a, pt_b)); - SymmetricCtVariant ct_c_var(std::move(ct_c), shell_ctx_var->ct_context_, - shell_ctx_var->error_params_); + // The output ct will hold raw pointers to moduli stored in the input's + // context. Ensure the output ciphertext Variant wrapper holds smart + // pointers to the input's context to prevent premature deletion of the + // moduli + SymmetricCtVariant ct_c_var(std::move(ct_c), ct_a_var->ct_context, + ct_a_var->error_params); flat_output(i) = std::move(ct_c_var); } } @@ -324,9 +331,10 @@ class NegCtOp : public OpKernel { OP_REQUIRES_VALUE(auto ct_out, op_ctx, ct_a.Negate()); - SymmetricCtVariant ct_out_var(std::move(ct_out), - shell_ctx_var->ct_context_, - shell_ctx_var->error_params_); + // The output ct will hold smart pointers to the input's context + // to prevent premature deletion of the moduli. + SymmetricCtVariant ct_out_var(std::move(ct_out), ct_a_var->ct_context, + ct_a_var->error_params); flat_output(i) = std::move(ct_out_var); } } diff --git a/tf_shell/cc/kernels/context_kernels.cc b/tf_shell/cc/kernels/context_kernels.cc index ff46dde..fd73106 100644 --- a/tf_shell/cc/kernels/context_kernels.cc +++ b/tf_shell/cc/kernels/context_kernels.cc @@ -51,19 +51,35 @@ class ContextImportOp : public OpKernel { OP_REQUIRES_VALUE(tstring t_seed, op_ctx, GetScalar(op_ctx, 5)); std::string seed(t_seed.c_str()); - // Allocate the output. + // Allocate the outputs. Tensor* out0; OP_REQUIRES_OK(op_ctx, op_ctx->allocate_output(0, TensorShape{}, &out0)); Tensor* out1; OP_REQUIRES_OK(op_ctx, op_ctx->allocate_output(1, TensorShape{}, &out1)); + Tensor* out2; + OP_REQUIRES_OK(op_ctx, + op_ctx->allocate_output(2, TensorShape{qs.size()}, &out2)); + Tensor* out3; + OP_REQUIRES_OK(op_ctx, + op_ctx->allocate_output(3, TensorShape{ps.size()}, &out3)); + Tensor* out4; + OP_REQUIRES_OK(op_ctx, op_ctx->allocate_output(4, TensorShape{}, &out4)); // Initialize the context variant and store it in the output. ContextVariant ctx_variant{}; OP_REQUIRES_OK(op_ctx, ctx_variant.Initialize(log_n, qs, ps, pt_modulus, noise_variance, seed)); - out0->scalar()() = std::move(ctx_variant); + + // Output other parameters for usage with auto-context. out1->scalar()() = log_n; + for (size_t i = 0; i < qs.size(); ++i) { + out2->flat()(i) = qs[i]; + } + for (size_t i = 0; i < ps.size(); ++i) { + out3->flat()(i) = ps[i]; + } + out4->scalar()() = pt_modulus; } }; diff --git a/tf_shell/cc/kernels/mod_switch_kernels.cc b/tf_shell/cc/kernels/mod_switch_kernels.cc index 15db3d4..fa7154f 100644 --- a/tf_shell/cc/kernels/mod_switch_kernels.cc +++ b/tf_shell/cc/kernels/mod_switch_kernels.cc @@ -88,12 +88,15 @@ class ModulusReduceKeyOp : public OpKernel { OP_REQUIRES_VALUE(SymmetricKeyVariant const* secret_key_var, op_ctx, GetVariant>(op_ctx, 1)); OP_REQUIRES( - op_ctx, secret_key_var->key != nullptr, + op_ctx, secret_key_var != nullptr, InvalidArgument("SymmetricKeyVariant did not unwrap successfully.")); OP_REQUIRES_OK(op_ctx, const_cast*>(secret_key_var) ->MaybeLazyDecode(shell_ctx_var->ct_context_, shell_ctx_var->noise_variance_)); + OP_REQUIRES(op_ctx, secret_key_var->key != nullptr, + InvalidArgument( + "SymmetricKeyVariant key did not unwrap successfully.")); Key secret_key = *secret_key_var->key; // Deep copy. // Allocate a scalar output tensor to store the reduced key. @@ -102,9 +105,11 @@ class ModulusReduceKeyOp : public OpKernel { OP_REQUIRES_OK(op_ctx, secret_key.ModReduce()); - // Store the reduced key in the output tensor. + // Store the reduced key in the output tensor. Keep a reference to the + // original context (even though it has the un-reduced moduli) to ensure + // the moduli held internally by the key are not deleted prematurely. SymmetricKeyVariant reduced_key_variant(std::move(secret_key), - shell_ctx_var->ct_context_); + secret_key_var->ct_context); out->scalar()() = std::move(reduced_key_variant); } }; @@ -169,10 +174,11 @@ class ModulusReduceCtOp : public OpKernel { OP_REQUIRES_OK(op_ctx, result_ct.ModReduce(t, ql_inv)); - // Store in the output. - SymmetricCtVariant result_var(std::move(result_ct), - shell_ctx_var->ct_context_, - shell_ctx_var->error_params_); + // Store in the output. Keep a reference to the original context to + // ensure the moduli held internally by the ciphertext are not deleted + // prematurely. + SymmetricCtVariant result_var( + std::move(result_ct), ct_a_var->ct_context, ct_a_var->error_params); flat_output(i) = std::move(result_var); } }; diff --git a/tf_shell/cc/kernels/mul_kernels.cc b/tf_shell/cc/kernels/mul_kernels.cc index 442a97b..197d24d 100644 --- a/tf_shell/cc/kernels/mul_kernels.cc +++ b/tf_shell/cc/kernels/mul_kernels.cc @@ -102,8 +102,12 @@ class MulCtCtOp : public OpKernel { OP_REQUIRES_VALUE(SymmetricCt ct_c, op_ctx, ct_a * ct_b); - SymmetricCtVariant ct_c_var(std::move(ct_c), shell_ctx_var->ct_context_, - shell_ctx_var->error_params_); + // Wrap the result in a SymmetricCtVariant and store it in the output. + // SHELL's multiplication preserves moduli pointers of the first input. + // Ensure the output holds smart pointers to the input's context to + // prevent premature deletion of the moduli. + SymmetricCtVariant ct_c_var(std::move(ct_c), ct_a_var->ct_context, + ct_a_var->error_params); flat_output(i) = std::move(ct_c_var); } } @@ -182,8 +186,13 @@ class MulCtPtOp : public OpKernel { OP_REQUIRES_VALUE(SymmetricCt ct_c, op_ctx, ct_a * pt_b); // shell absorb operation - SymmetricCtVariant ct_c_var(std::move(ct_c), shell_ctx_var->ct_context_, - shell_ctx_var->error_params_); + // Wrap the result in a SymmetricCtVariant and store it in the output. + // The output ct will hold raw pointers to moduli stored in the input's + // context. Ensure the output ciphertext Variant wrapper holds smart + // pointers to the input's context to prevent premature deletion of the + // moduli + SymmetricCtVariant ct_c_var(std::move(ct_c), ct_a_var->ct_context, + ct_a_var->error_params); flat_output(i) = std::move(ct_c_var); } }; @@ -261,8 +270,8 @@ class MulShellTfScalarOp : public OpKernel { OP_REQUIRES_VALUE(RnsPolynomial result, op_ctx, poly.Mul(wrapped_b, shell_ctx->MainPrimeModuli())); - CtOrPolyVariant result_var(std::move(result), - shell_ctx_var->ct_context_); + PolynomialVariant result_var(std::move(result), + shell_ctx_var->ct_context_); flat_output(i) = std::move(result_var); } else if constexpr (std::is_same>::value) { @@ -275,9 +284,13 @@ class MulShellTfScalarOp : public OpKernel { OP_REQUIRES_VALUE(SymmetricCt result, op_ctx, ct * wrapped_b); // shell aborb operation + // The output ct will hold raw pointers to moduli stored in the input's + // context. Ensure the output ciphertext Variant wrapper holds smart + // pointers to the input's context to prevent premature deletion of the + // moduli SymmetricCtVariant result_var(std::move(result), - shell_ctx_var->ct_context_, - shell_ctx_var->error_params_); + ct_or_pt_var->ct_context, + ct_or_pt_var->error_params); flat_output(i) = std::move(result_var); } } @@ -493,8 +506,8 @@ class MatMulCtPtOp : public OpKernel { } SymmetricCtVariant ct_result_var(std::move(ct_result), - shell_ctx_var->ct_context_, - shell_ctx_var->error_params_); + ct_a_var->ct_context, + ct_a_var->error_params); flat_output(i) = std::move(ct_result_var); } }; @@ -749,8 +762,8 @@ class MatMulPtCtOp : public OpKernel { // the result of the reduce sum operation. Store in the output // tensor. SymmetricCtVariant ct_result_var(std::move(ct_result), - shell_ctx_var->ct_context_, - shell_ctx_var->error_params_); + ct_b_var->ct_context, + ct_b_var->error_params); flat_output(outer, i, ct_col) = std::move(ct_result_var); } } diff --git a/tf_shell/cc/kernels/rotation_kernels.cc b/tf_shell/cc/kernels/rotation_kernels.cc index 8de339c..7eb70e6 100644 --- a/tf_shell/cc/kernels/rotation_kernels.cc +++ b/tf_shell/cc/kernels/rotation_kernels.cc @@ -61,11 +61,18 @@ class RotationKeyGenOp : public OpKernel { // Get the input tensors. OP_REQUIRES_VALUE(ContextVariant const* shell_ctx_var, op_ctx, GetVariant>(op_ctx, 0)); + OP_REQUIRES(op_ctx, shell_ctx_var != nullptr, + InvalidArgument("ContextVariant did not unwrap successfully.")); Context const* shell_ctx = shell_ctx_var->ct_context_.get(); OP_REQUIRES_VALUE(SymmetricKeyVariant const* secret_key_var, op_ctx, GetVariant>(op_ctx, 1)); - + OP_REQUIRES(op_ctx, secret_key_var != nullptr, + InvalidArgument("SymmetricKeyVariant did not unwrap successfully.")); + OP_REQUIRES_OK(op_ctx, + const_cast*>(secret_key_var) + ->MaybeLazyDecode(shell_ctx_var->ct_context_, + shell_ctx_var->noise_variance_)); std::shared_ptr const secret_key = secret_key_var->key; // Allocate the output tensor which is a scalar containing the rotation key. @@ -144,9 +151,13 @@ class RollOp : public OpKernel { // Get the input tensors. OP_REQUIRES_VALUE(ContextVariant const* shell_ctx_var, op_ctx, GetVariant>(op_ctx, 0)); + OP_REQUIRES(op_ctx, shell_ctx_var != nullptr, + InvalidArgument("ContextVariant did not unwrap successfully.")); + OP_REQUIRES_VALUE(RotationKeyVariant const* rotation_key_var, op_ctx, GetVariant>(op_ctx, 1)); - + OP_REQUIRES(op_ctx, rotation_key_var != nullptr, + InvalidArgument("RotationKeyVariant did not unwrap successfully.")); std::vector> const& keys = rotation_key_var->keys; @@ -218,9 +229,13 @@ class RollOp : public OpKernel { ct.Substitute(key->SubstitutionPower())); OP_REQUIRES_VALUE(auto ct_rot, op_ctx, key->ApplyTo(ct_sub)); - SymmetricCtVariant ct_out_var(std::move(ct_rot), - shell_ctx_var->ct_context_, - shell_ctx_var->error_params_); + // Wrap the result in a SymmetricCtVariant and store it in the output. + // The output ct will hold raw pointers to moduli stored in the + // input's context. Ensure the output ciphertext Variant wrapper holds + // smart pointers to the input's context to prevent premature deletion + // of the moduli + SymmetricCtVariant ct_out_var(std::move(ct_rot), ct_var->ct_context, + ct_var->error_params); flat_output(i) = std::move(ct_out_var); } } @@ -254,16 +269,19 @@ class ReduceSumByRotationOp : public OpKernel { // Recover the inputs. OP_REQUIRES_VALUE(ContextVariant const* shell_ctx_var, op_ctx, GetVariant>(op_ctx, 0)); + OP_REQUIRES(op_ctx, shell_ctx_var != nullptr, + InvalidArgument("ContextVariant did not unwrap successfully.")); OP_REQUIRES_VALUE(RotationKeyVariant const* rotation_key_var, op_ctx, GetVariant>(op_ctx, 1)); - + OP_REQUIRES(op_ctx, rotation_key_var != nullptr, + InvalidArgument("RotationKeyVariant did not unwrap successfully.")); std::vector> const& keys = rotation_key_var->keys; + Tensor const& value = op_ctx->input(2); OP_REQUIRES(op_ctx, value.dim_size(0) > 0, InvalidArgument("Cannot reduce_sum an empty ciphertext.")); - auto flat_value = value.flat(); // Recover num_slots from first ciphertext. @@ -308,10 +326,8 @@ class ReduceSumByRotationOp : public OpKernel { // ciphertext separately. So the max rotation is by half the number // of slots. for (int shift = 1; shift < num_slots / 2; shift <<= 1) { - OP_REQUIRES( - op_ctx, - shift < static_cast(keys.size()), - InvalidArgument("No key for shift of '", shift, "'")); + OP_REQUIRES(op_ctx, shift < static_cast(keys.size()), + InvalidArgument("No key for shift of '", shift, "'")); auto key = keys[shift]; // Rotate by the shift. @@ -323,9 +339,13 @@ class ReduceSumByRotationOp : public OpKernel { OP_REQUIRES_OK(op_ctx, sum.AddInPlace(ct_rot)); } - SymmetricCtVariant ct_out_var(std::move(sum), - shell_ctx_var->ct_context_, - shell_ctx_var->error_params_); + // Wrap the result in a SymmetricCtVariant and store it in the output. + // The output ct will hold raw pointers to moduli stored in the input's + // context. Ensure the output ciphertext Variant wrapper holds smart + // pointers to the input's context to prevent premature deletion of the + // moduli + SymmetricCtVariant ct_out_var(std::move(sum), ct_var->ct_context, + ct_var->error_params); flat_output(i) = std::move(ct_out_var); } }; @@ -368,6 +388,8 @@ class ReduceSumOp : public OpKernel { // Recover the inputs. OP_REQUIRES_VALUE(ContextVariant const* shell_ctx_var, op_ctx, GetVariant>(op_ctx, 0)); + OP_REQUIRES(op_ctx, shell_ctx_var != nullptr, + InvalidArgument("ContextVariant did not unwrap successfully.")); Tensor const& value = op_ctx->input(1); OP_REQUIRES(op_ctx, value.dim_size(0) > 0, @@ -457,9 +479,13 @@ class ReduceSumOp : public OpKernel { OP_REQUIRES_OK(op_ctx, sum.AddInPlace(ct)); } - // Store in the output. - SymmetricCtVariant res_var(std::move(sum), shell_ctx_var->ct_context_, - shell_ctx_var->error_params_); + // Wrap the result in a SymmetricCtVariant and store it in the output. + // The output ct will hold raw pointers to moduli stored in the + // input's context. Ensure the output ciphertext Variant wrapper holds + // smart pointers to the input's context to prevent premature deletion + // of the moduli + SymmetricCtVariant res_var(std::move(sum), first_ct_var->ct_context, + first_ct_var->error_params); flat_output(i, j) = std::move(res_var); } diff --git a/tf_shell/cc/kernels/rotation_kernels_fast.cc b/tf_shell/cc/kernels/rotation_kernels_fast.cc index d855e01..fbe8e5f 100644 --- a/tf_shell/cc/kernels/rotation_kernels_fast.cc +++ b/tf_shell/cc/kernels/rotation_kernels_fast.cc @@ -84,11 +84,18 @@ class FastRotationKeyGenOp : public OpKernel { // Get the input tensors. OP_REQUIRES_VALUE(ContextVariant const* shell_ctx_var, op_ctx, GetVariant>(op_ctx, 0)); + OP_REQUIRES(op_ctx, shell_ctx_var != nullptr, + InvalidArgument("ContextVariant did not unwrap successfully.")); Context const* shell_ctx = shell_ctx_var->ct_context_.get(); OP_REQUIRES_VALUE(SymmetricKeyVariant const* secret_key_var, op_ctx, GetVariant>(op_ctx, 1)); - + OP_REQUIRES(op_ctx, secret_key_var != nullptr, + InvalidArgument("SymmetricKeyVariant did not unwrap successfully.")); + OP_REQUIRES_OK(op_ctx, + const_cast*>(secret_key_var) + ->MaybeLazyDecode(shell_ctx_var->ct_context_, + shell_ctx_var->noise_variance_)); std::shared_ptr const secret_key = secret_key_var->key; // Allocate the output tensor which is a scalar containing the fast rotation @@ -151,6 +158,8 @@ class FastReduceSumByRotationOp : public OpKernel { void Compute(OpKernelContext* op_ctx) override { OP_REQUIRES_VALUE(ContextVariant const* shell_ctx_var, op_ctx, GetVariant>(op_ctx, 0)); + OP_REQUIRES(op_ctx, shell_ctx_var != nullptr, + InvalidArgument("ContextVariant did not unwrap successfully.")); auto const& sub_powers = shell_ctx_var->substitution_powers_; // Recover the input tensor. @@ -229,9 +238,8 @@ class FastReduceSumByRotationOp : public OpKernel { SymmetricCt ct_out(std::move(components), moduli_vector, ct.PowerOfS(), ct.Error() * ct.LogN(), ct.ErrorParams()); - SymmetricCtVariant ct_out_var(std::move(ct_out), - shell_ctx_var->ct_context_, - shell_ctx_var->error_params_); + SymmetricCtVariant ct_out_var(std::move(ct_out), ct_var->ct_context, + ct_var->error_params); flat_output(i) = std::move(ct_out_var); } }; diff --git a/tf_shell/cc/kernels/segment_kernels.cc b/tf_shell/cc/kernels/segment_kernels.cc index 0a0b19d..260ce0f 100644 --- a/tf_shell/cc/kernels/segment_kernels.cc +++ b/tf_shell/cc/kernels/segment_kernels.cc @@ -71,14 +71,14 @@ namespace functor { template struct UnsortedSegmentFunctor { - void operator()( - OpKernelContext* ctx, ContextVariant const* shell_ctx_var, - std::vector>>> const& keys, - TensorShape const& segment_ids_shape, - typename TTypes::ConstTensor segment_ids, - typename TTypes::ConstTensor data, - typename TTypes::Tensor unreduced_output, - typename TTypes::Tensor output); + void operator()(OpKernelContext* ctx, ContextVariant const* shell_ctx_var, + std::vector>>> const& keys, + TensorShape const& segment_ids_shape, + typename TTypes::ConstTensor segment_ids, + typename TTypes::ConstTensor data, + typename TTypes::Tensor unreduced_output, + typename TTypes::Tensor output); }; template @@ -262,9 +262,13 @@ struct UnsortedSegmentFunctor { // No need to lazy decode the output_var, it was created in this op. if (output_var->ct.Len() == 0) { - // Output has not been set yet. - SymmetricCtVariant var(masked_data_ct, shell_ctx_var->ct_context_, - shell_ctx_var->error_params_); + // Output has not been set yet, wrap the result in a + // SymmetricCtVariant and store. The output ct will hold raw + // pointers to moduli stored in the input's context. Ensure the + // output ciphertext Variant wrapper holds smart pointers to the + // input's context to prevent premature deletion of the moduli. + SymmetricCtVariant var(masked_data_ct, data_var->ct_context, + data_var->error_params); unreduced_output((int64_t)j, chip) = std::move(var); } else { OP_REQUIRES_OK(ctx, reduction(masked_data_ct, output_var->ct)); diff --git a/tf_shell/cc/ops/shell_ops.cc b/tf_shell/cc/ops/shell_ops.cc index 0557a27..7283d8e 100644 --- a/tf_shell/cc/ops/shell_ops.cc +++ b/tf_shell/cc/ops/shell_ops.cc @@ -40,6 +40,9 @@ REGISTER_OP("ContextImport64") .Input("seed: string") .Output("shell_context: variant") .Output("new_log_n: uint64") + .Output("new_qs: uint64") + .Output("new_ps: uint64") + .Output("new_pt_modulus: uint64") .SetShapeFn(MultiScalarOut<2>); REGISTER_OP("AutoShellContext64") @@ -49,6 +52,9 @@ REGISTER_OP("AutoShellContext64") .Input("noise_variance: uint64") .Output("shell_context: variant") .Output("new_log_n: uint64") + .Output("new_qs: uint64") + .Output("new_ps: uint64") + .Output("new_pt_modulus: uint64") .SetShapeFn(MultiScalarOut<2>); REGISTER_OP("PolynomialImport64") diff --git a/tf_shell/cc/optimizers/moduli_autotune.cc b/tf_shell/cc/optimizers/moduli_autotune.cc index f31e642..6238b0e 100644 --- a/tf_shell/cc/optimizers/moduli_autotune.cc +++ b/tf_shell/cc/optimizers/moduli_autotune.cc @@ -464,12 +464,12 @@ Status EstimateNodeNoise( rot_noise += BitWidth(params.log_n); // There are log_n rotations. *this_noise = std::max(noise_a, rot_noise); } else if (IsReduceSum(*node_def)) { - auto const* axis_node_def = - node_view->GetRegularFanin(2).node_view()->node(); + // auto const* axis_node_def = + // node_view->GetRegularFanin(2).node_view()->node(); - uint64_t axis = 0; - TF_RETURN_IF_ERROR( - GetScalarConstValue(*axis_node_def, &axis)); + // uint64_t axis = 0; + // TF_RETURN_IF_ERROR( + // GetScalarConstValue(*axis_node_def, &axis)); *this_noise = noise_a; // TODO depends on axis attribute and shape. } diff --git a/tf_shell/cc/optimizers/pt_pt.cc b/tf_shell/cc/optimizers/pt_pt.cc index ef9790d..b575e64 100644 --- a/tf_shell/cc/optimizers/pt_pt.cc +++ b/tf_shell/cc/optimizers/pt_pt.cc @@ -15,7 +15,7 @@ namespace grappler { namespace { -constexpr bool const debug = false; +constexpr bool const debug = true; bool IsReplaceableOp(NodeDef const& node) { return IsAddPtPt(node) || IsSubPtPt(node) || IsMulPtPt(node) || IsNegPt(node); diff --git a/tf_shell/python/shell_context.py b/tf_shell/python/shell_context.py index 3459680..3ff895c 100644 --- a/tf_shell/python/shell_context.py +++ b/tf_shell/python/shell_context.py @@ -25,11 +25,10 @@ class ShellContext64(tf.experimental.ExtensionType): num_slots: tf.Tensor two_n: tf.Tensor main_moduli: tf.Tensor - level: int + level: tf.Tensor aux_moduli: tf.Tensor - plaintext_modulus: int + plaintext_modulus: tf.Tensor noise_variance: int - noise_bits: int scaling_factor: int seed: str @@ -47,23 +46,39 @@ def __init__( ): self._raw_context = _raw_context self.is_autocontext = is_autocontext - self.log_n = log_n + self.log_n = tf.convert_to_tensor(log_n, dtype=tf.uint64) self.num_slots = 2 ** tf.cast(log_n, dtype=tf.int64) self.two_n = self.num_slots * 2 if isinstance(main_moduli, list): - main_moduli = tf.convert_to_tensor(main_moduli) + main_moduli = tf.convert_to_tensor(main_moduli, dtype=tf.uint64) self.main_moduli = main_moduli - self.level = main_moduli.shape[0] + self.level = tf.shape(main_moduli)[0] + if isinstance(aux_moduli, list): + aux_moduli = tf.convert_to_tensor(aux_moduli, dtype=tf.uint64) self.aux_moduli = aux_moduli - self.plaintext_modulus = plaintext_modulus + self.plaintext_modulus = tf.convert_to_tensor( + plaintext_modulus, dtype=tf.uint64 + ) self.noise_variance = noise_variance - if self.noise_variance % 2 == 0: - self.noise_bits = self.noise_variance.bit_length() - else: - self.noise_bits = self.noise_variance.bit_length() + 1 self.scaling_factor = scaling_factor self.seed = seed + def _get_generic_context_spec(self): + return ShellContext64.Spec( + _raw_context=tf.TensorSpec([], dtype=tf.variant), + is_autocontext=self.is_autocontext, + log_n=tf.TensorSpec([], dtype=tf.uint64), + num_slots=tf.TensorSpec([], dtype=tf.int64), + two_n=tf.TensorSpec([], dtype=tf.int64), + main_moduli=tf.TensorSpec(None, dtype=tf.uint64), + level=tf.TensorSpec([], dtype=tf.int32), + aux_moduli=tf.TensorSpec(None, dtype=tf.uint64), + plaintext_modulus=tf.TensorSpec(None, dtype=tf.uint64), + noise_variance=self.noise_variance, + scaling_factor=self.scaling_factor, + seed=self.seed, + ) + def mod_reduce_context64(context): if not isinstance(context, ShellContext64): @@ -100,7 +115,7 @@ def create_context64( elif len(seed) < 64 and seed != "": seed = seed.ljust(64) - shell_context, _ = shell_ops.context_import64( + shell_context, _, _, _, _ = shell_ops.context_import64( log_n=log_n, main_moduli=main_moduli, aux_moduli=aux_moduli, @@ -133,7 +148,7 @@ def create_autocontext64( raise ValueError("Seed must be at most 64 characters long.") seed = seed.ljust(64) - shell_context, new_log_n = shell_ops.auto_shell_context64( + shell_context, new_log_n, new_qs, new_ps, new_t = shell_ops.auto_shell_context64( log2_cleartext_sz=log2_cleartext_sz, scaling_factor=scaling_factor, log2_noise_offset=noise_offset_log2, @@ -144,9 +159,9 @@ def create_autocontext64( _raw_context=shell_context, is_autocontext=True, log_n=new_log_n, - main_moduli=[], - aux_moduli=[], - plaintext_modulus=0, + main_moduli=new_qs, + aux_moduli=new_ps, + plaintext_modulus=new_t, noise_variance=noise_variance, scaling_factor=scaling_factor, seed=seed, diff --git a/tf_shell/python/shell_key.py b/tf_shell/python/shell_key.py index 11330e9..60935a4 100644 --- a/tf_shell/python/shell_key.py +++ b/tf_shell/python/shell_key.py @@ -21,53 +21,73 @@ class ShellKey64(tf.experimental.ExtensionType): - _raw_keys_at_level: typing.Mapping[int, tf.Tensor] + _raw_keys_at_level: tf.Tensor def _get_key_at_level(self, level): - if level not in self._raw_keys_at_level: - raise ValueError(f"No key at level {level}.") + level -= 1 # Keys tensor start at level 1. + tf.Assert(level >= 0, [f"level must be >= 0. Got {level}"]) + tf.Assert( + level < tf.shape(self._raw_keys_at_level)[0], + [f"level must be < {tf.shape(self._raw_keys_at_level)[0]}. Got {level}"], + ) return self._raw_keys_at_level[level] -def mod_reduce_key64(unreduced_context, raw_key): - if not isinstance(unreduced_context, ShellContext64): - raise ValueError("context must be a ShellContext64.") - - if not isinstance(raw_key, tf.Tensor): - raise ValueError("raw_key must be a Tensor") - - -def create_key64(context, skip_at_mul_depth=[]): +def create_key64(context): if not isinstance(context, ShellContext64): raise ValueError("context must be a ShellContext64") + num_keys = context.level + keys = tf.TensorArray(tf.variant, size=context.level, clear_after_read=False) - raw_keys_at_level = {} - - # Generate and store the first key. - key = shell_ops.key_gen64(context._raw_context) - raw_keys_at_level[context.level] = key + # Generate and store the first key in the last index. + keys = keys.write(context.level - 1, shell_ops.key_gen64(context._raw_context)) # Mod reduce to compute the remaining keys. - while context.level > 1: - key = shell_ops.modulus_reduce_key64(context._raw_context, key) - context = mod_reduce_context64(context) - - if context.level not in skip_at_mul_depth: - raw_keys_at_level[context.level] = key - - return ShellKey64(_raw_keys_at_level=raw_keys_at_level) + keys, context = tf.while_loop( + lambda ks, c: c.level > 2, + lambda ks, c: ( + ks.write( + c.level - 2, + shell_ops.modulus_reduce_key64(c._raw_context, ks.read(c.level - 1)), + ), + mod_reduce_context64(c), + ), + loop_vars=[keys, context], + shape_invariants=[ + tf.TensorSpec(None, dtype=tf.variant), + context._get_generic_context_spec(), + ], + ) + + # Store the first key for level 1. + keys = tf.cond( + context.level == 2, + lambda: keys.write( + context.level - 2, + shell_ops.modulus_reduce_key64( + context._raw_context, keys.read(context.level - 1) + ), + ), + lambda: keys, + ) + + return ShellKey64(_raw_keys_at_level=keys.gather(tf.range(0, num_keys))) class ShellRotationKey64(tf.experimental.ExtensionType): - _raw_rot_keys_at_level: typing.Mapping[int, tf.Tensor] + _raw_keys_at_level: tf.Tensor def _get_key_at_level(self, level): - if level not in self._raw_rot_keys_at_level: - raise ValueError(f"No rotation key at level {level}.") - return self._raw_rot_keys_at_level[level] + level -= 1 # Keys tensor start at level 1. + tf.Assert(level >= 0, [f"level must be >= 0. Got {level}"]) + tf.Assert( + level < tf.shape(self._raw_keys_at_level)[0], + [f"level must be < {tf.shape(self._raw_keys_at_level)[0]}. Got {level}"], + ) + return self._raw_keys_at_level[level] -def create_rotation_key64(context, key, skip_at_mul_depth=[]): +def create_rotation_key64(context, key): """Create rotation keys for any multiplicative depth of the given context. Rotation key contains keys to perform an arbitrary number of slot rotations. Since rotation key generation is expensive, the caller can choose to skip @@ -79,32 +99,63 @@ def create_rotation_key64(context, key, skip_at_mul_depth=[]): if not isinstance(key, ShellKey64): raise ValueError("key must be a ShellKey64.") - raw_rot_keys_at_level = {} - while context.level >= 0: - if context.level not in skip_at_mul_depth: - raw_rot_keys_at_level[context.level] = shell_ops.rotation_key_gen64( - context._raw_context, - key._get_key_at_level(context.level), - ) - - if context.level <= 1: - break - - context = mod_reduce_context64(context) + num_keys = context.level + rot_keys = tf.TensorArray( + tf.variant, + size=context.level, + clear_after_read=False, + infer_shape=False, + element_shape=(), + ) + + # Generate rotation keys starting from the highest level. + rot_keys, context = tf.while_loop( + lambda ks, c: c.level > 1, + lambda ks, c: ( + ks.write( + c.level - 1, + shell_ops.rotation_key_gen64( + c._raw_context, key._get_key_at_level(c.level) + ), + ), + mod_reduce_context64(c), + ), + loop_vars=[rot_keys, context], + shape_invariants=[ + tf.TensorSpec(None, dtype=tf.variant), + context._get_generic_context_spec(), + ], + ) + + # Store the first key for level 1. + rot_keys = tf.cond( + context.level == 1, + lambda: rot_keys.write( + context.level - 1, + shell_ops.rotation_key_gen64( + context._raw_context, key._get_key_at_level(context.level) + ), + ), + lambda: rot_keys, + ) - return ShellRotationKey64(_raw_rot_keys_at_level=raw_rot_keys_at_level) + return ShellRotationKey64(_raw_keys_at_level=rot_keys.gather(tf.range(0, num_keys))) class ShellFastRotationKey64(tf.experimental.ExtensionType): - _raw_rot_keys_at_level: typing.Mapping[int, tf.Tensor] + _raw_keys_at_level: tf.Tensor def _get_key_at_level(self, level): - if level not in self._raw_rot_keys_at_level: - raise ValueError(f"No rotation key at level {level}.") - return self._raw_rot_keys_at_level[level] + level -= 1 # Keys tensor start at level 1. + tf.Assert(level >= 0, [f"level must be >= 0. Got {level}"]) + tf.Assert( + level < tf.shape(self._raw_keys_at_level)[0], + [f"level must be < {tf.shape(self._raw_keys_at_level)[0]}. Got {level}"], + ) + return self._raw_keys_at_level[level] -def create_fast_rotation_key64(context, key, skip_at_mul_depth=[]): +def create_fast_rotation_key64(context, key): """Create fast rotation keys for any multiplicative depth of the given context. Rotation key contains keys *decrypt* a previously "fast" rotated ciphertext. These keys are much faster to generated than regular rotation keys, and @@ -117,16 +168,46 @@ def create_fast_rotation_key64(context, key, skip_at_mul_depth=[]): if not isinstance(key, ShellKey64): raise ValueError("key must be a ShellKey64.") - raw_rot_keys_at_level = {} - while context.level >= 0: - if context.level not in skip_at_mul_depth: - raw_rot_keys_at_level[context.level] = shell_ops.fast_rotation_key_gen64( + num_keys = context.level + rot_keys = tf.TensorArray( + tf.variant, + size=context.level, + clear_after_read=False, + infer_shape=False, + element_shape=(), + ) + + # Generate rotation keys starting from the highest level. + rot_keys, context = tf.while_loop( + lambda ks, c: c.level > 1, + lambda ks, c: ( + ks.write( + c.level - 1, + shell_ops.fast_rotation_key_gen64( + c._raw_context, key._get_key_at_level(c.level) + ), + ), + mod_reduce_context64(c), + ), + loop_vars=[rot_keys, context], + shape_invariants=[ + tf.TensorSpec(None, dtype=tf.variant), + context._get_generic_context_spec(), + ], + ) + + # Store the first key for level 1. + rot_keys = tf.cond( + context.level == 1, + lambda: rot_keys.write( + context.level - 1, + shell_ops.fast_rotation_key_gen64( context._raw_context, key._get_key_at_level(context.level) - ) - - if context.level <= 1: - break - - context = mod_reduce_context64(context) - - return ShellFastRotationKey64(_raw_rot_keys_at_level=raw_rot_keys_at_level) + ), + ), + lambda: rot_keys, + ) + + return ShellFastRotationKey64( + _raw_keys_at_level=rot_keys.gather(tf.range(0, num_keys)) + ) diff --git a/tf_shell/python/shell_tensor.py b/tf_shell/python/shell_tensor.py index 4c86165..d248a23 100644 --- a/tf_shell/python/shell_tensor.py +++ b/tf_shell/python/shell_tensor.py @@ -18,7 +18,6 @@ from tf_shell.python.shell_context import ShellContext64 from tf_shell.python.shell_context import mod_reduce_context64 from tf_shell.python.shell_key import ShellKey64 -from tf_shell.python.shell_key import mod_reduce_key64 from tf_shell.python.shell_key import ShellRotationKey64 from tf_shell.python.shell_key import ShellFastRotationKey64 @@ -76,6 +75,7 @@ def __getitem__(self, slice): _underlying_dtype=self._underlying_dtype, _scaling_factor=self._scaling_factor, _is_enc=self.is_encrypted, + _is_fast_rotated=self._is_fast_rotated, ) def __add__(self, other): @@ -115,6 +115,7 @@ def __add__(self, other): _underlying_dtype=self._underlying_dtype, _scaling_factor=self._scaling_factor, _is_enc=self._is_enc or other._is_enc, + _is_fast_rotated=self._is_fast_rotated or other._is_fast_rotated, ) elif isinstance(other, tf.Tensor): @@ -193,6 +194,7 @@ def __sub__(self, other): _underlying_dtype=self._underlying_dtype, _scaling_factor=self._scaling_factor, _is_enc=self._is_enc or other._is_enc, + _is_fast_rotated=self._is_fast_rotated or other._is_fast_rotated, ) elif isinstance(other, tf.Tensor): if other.shape == () or other.shape == (1,): @@ -274,6 +276,7 @@ def __rsub__(self, other): _underlying_dtype=self._underlying_dtype, _scaling_factor=self._scaling_factor, _is_enc=self._is_enc, + _is_fast_rotated=self._is_fast_rotated, ) else: # Try to import the unknown operand to a TensorFlow tensor and @@ -302,6 +305,7 @@ def __neg__(self): _underlying_dtype=self._underlying_dtype, _scaling_factor=self._scaling_factor, _is_enc=self._is_enc, + _is_fast_rotated=self._is_fast_rotated, ) def __mul__(self, other): @@ -345,6 +349,7 @@ def __mul__(self, other): _underlying_dtype=self._underlying_dtype, _scaling_factor=matched_self._scaling_factor**2, _is_enc=self._is_enc or other._is_enc, + _is_fast_rotated=self._is_fast_rotated or other._is_fast_rotated, ) elif isinstance(other, tf.Tensor): # Multiplying by a scalar uses a special op which is more efficient @@ -376,6 +381,7 @@ def __mul__(self, other): _underlying_dtype=self._underlying_dtype, _scaling_factor=self._scaling_factor**2, _is_enc=self._is_enc, + _is_fast_rotated=self._is_fast_rotated, ) else: @@ -399,6 +405,16 @@ def __mul__(self, other): def __rmul__(self, other): return self * other + def _get_generic_shell_tensor_spec(self): + return ShellTensor64.Spec( + _raw_tensor=tf.TensorSpec(self._raw_tensor.shape, dtype=tf.variant), + _context=self._context._get_generic_context_spec(), + _underlying_dtype=self._underlying_dtype, + _scaling_factor=self._scaling_factor, + _is_enc=self._is_enc, + _is_fast_rotated=self._is_fast_rotated, + ) + def mod_reduce_tensor64(shell_tensor): """Switches the ShellTensor to a new context with different moduli. If @@ -427,17 +443,32 @@ def mod_reduce_tensor64(shell_tensor): _underlying_dtype=shell_tensor._underlying_dtype, _scaling_factor=shell_tensor._scaling_factor, _is_enc=shell_tensor._is_enc, + _is_fast_rotated=shell_tensor._is_fast_rotated, ) return reduced_self +def _match_moduli_x_to_y(x, target_level): + x = tf.while_loop( + lambda x_red: x_red._context.level > target_level, + lambda x_red: mod_reduce_tensor64(x_red), + loop_vars=[x], + shape_invariants=[ + x._get_generic_shell_tensor_spec(), + ], + )[0] + return x + + def _match_moduli_and_scaling(x, y): # Mod switch to the smaller modulus of the two. - while x._context.level > y._context.level: - x = mod_reduce_tensor64(x) - while x._context.level < y._context.level: - y = mod_reduce_tensor64(y) + x = _match_moduli_x_to_y(x, y._context.level) + y = _match_moduli_x_to_y(y, x._context.level) + # while x._context.level > y._context.level: + # x = mod_reduce_tensor64(x) + # while x._context.level < y._context.level: + # y = mod_reduce_tensor64(y) # Match the scaling factors. # First make sure the scaling factors are compatible. @@ -514,30 +545,18 @@ def to_shell_plaintext(tensor, context): scaled_tensor = _encode_scaling(tensor, context.scaling_factor) # Pad the tensor to the correct number of slots. - if tf.executing_eagerly(): - # In eager mode, we know the number of slots at graph construction - # time and can check the tensor is the correct size. - if scaled_tensor.shape[0] > context.num_slots: - raise ValueError( - f"Tensor first dimension is too large. Maximum is {context.num_slots}, got {scaled_tensor.shape[0]}." - ) - elif scaled_tensor.shape[0] < context.num_slots: - padding = [[0, context.num_slots - scaled_tensor.shape[0]]] + [ - [0, 0] for _ in range(len(scaled_tensor.shape) - 1) - ] - scaled_tensor = tf.pad(scaled_tensor, padding) - else: - # In graph mode, we may not know the number of slots until runtime. - # Try the padding, but if it fails (e.g. the batching dimension is - # too large), the user will see the error when the tensor is used in - # a SHELL operation at runtime. - try: - padding = [[0, context.num_slots - scaled_tensor.shape[0]]] + [ - [0, 0] for _ in range(len(scaled_tensor.shape) - 1) - ] - scaled_tensor = tf.pad(scaled_tensor, padding) - except: - pass + first_dim = tf.cast(tf.shape(scaled_tensor)[0], dtype=tf.int64) + tf.Assert( + context.num_slots >= first_dim, + [f"First dimension must be <= {context.num_slots}. Got {first_dim}"], + ) + padding = [[0, 0] for _ in range(len(scaled_tensor.shape))] + padding[0][1] = tf.cond( + context.num_slots > first_dim, + lambda: context.num_slots - first_dim, + lambda: tf.constant(0, dtype=tf.int64), + ) + scaled_tensor = tf.pad(scaled_tensor, padding) return ShellTensor64( _raw_tensor=shell_ops.polynomial_import64( @@ -685,6 +704,7 @@ def roll(x, shift, rotation_key): _underlying_dtype=x._underlying_dtype, _scaling_factor=x._scaling_factor, _is_enc=True, + _is_fast_rotated=x._is_fast_rotated, ) elif isinstance(x, tf.Tensor): return tf.roll(x, shift) @@ -718,6 +738,7 @@ def reduce_sum(x, axis, rotation_key=None): _underlying_dtype=x._underlying_dtype, _scaling_factor=x._scaling_factor, _is_enc=True, + _is_fast_rotated=x._is_fast_rotated, ) else: @@ -729,6 +750,7 @@ def reduce_sum(x, axis, rotation_key=None): _underlying_dtype=x._underlying_dtype, _scaling_factor=x._scaling_factor, _is_enc=True, + _is_fast_rotated=x._is_fast_rotated, ) elif isinstance(x, tf.Tensor): return tf.reduce_sum(x, axis) @@ -748,6 +770,8 @@ def fast_reduce_sum(x): raise ValueError("Input must be ShellTensor.") if not x._is_enc: raise ValueError("Unencrypted fast_reduce_sum not supported yet.") + if x._is_fast_rotated: + raise ValueError("Cannot fast_reduce_sum a fast_rotated ShellTensor.") return ShellTensor64( _raw_tensor=shell_ops.fast_reduce_sum_by_rotation64( @@ -796,6 +820,7 @@ def matmul(x, y, rotation_key=None, fast=False): _underlying_dtype=x._underlying_dtype, _scaling_factor=x._scaling_factor**2, _is_enc=True, + _is_fast_rotated=x._is_fast_rotated, ) elif isinstance(x, tf.Tensor) and isinstance(y, ShellTensor64): @@ -808,6 +833,10 @@ def matmul(x, y, rotation_key=None, fast=False): scaled_x = _encode_scaling(x, y._scaling_factor) if fast: + if y._is_fast_rotated: + raise ValueError( + "A ShellTensor which has been fast-reduced-summed cannot be fast-reduced-summed again." + ) return ShellTensor64( _raw_tensor=shell_ops.fast_mat_mul_pt_ct64( y._context._raw_context, @@ -841,6 +870,7 @@ def matmul(x, y, rotation_key=None, fast=False): _underlying_dtype=y._underlying_dtype, _scaling_factor=y._scaling_factor**2, _is_enc=True, + _is_fast_rotated=y._is_fast_rotated, ) elif isinstance(x, ShellTensor64) and isinstance(y, ShellTensor64): @@ -868,6 +898,7 @@ def expand_dims(x, axis=-1): _underlying_dtype=x._underlying_dtype, _scaling_factor=x._scaling_factor, _is_enc=x._is_enc, + _is_fast_rotated=x._is_fast_rotated, ) elif isinstance(x, tf.Tensor): return tf.expand_dims(x, axis) @@ -888,6 +919,7 @@ def reshape(x, shape): _underlying_dtype=x._underlying_dtype, _scaling_factor=x._scaling_factor, _is_enc=x._is_enc, + _is_fast_rotated=x._is_fast_rotated, ) elif isinstance(x, tf.Tensor): return tf.reshape(x, shape) @@ -923,6 +955,7 @@ def broadcast_to(x, shape): _underlying_dtype=x._underlying_dtype, _scaling_factor=x._scaling_factor, _is_enc=x._is_enc, + _is_fast_rotated=x._is_fast_rotated, ) elif isinstance(x, tf.Tensor): return tf.broadcast_to(x, shape) @@ -956,6 +989,7 @@ def segment_sum(x, segments, num_segments, rotation_key=None): _underlying_dtype=x._underlying_dtype, _scaling_factor=x._scaling_factor, _is_enc=x._is_enc, + _is_fast_rotated=x._is_fast_rotated, ), reduction_count, ) diff --git a/tf_shell/test/auto_param_optimizer_test.py b/tf_shell/test/auto_param_optimizer_test.py index 17f93ff..734f0ac 100644 --- a/tf_shell/test/auto_param_optimizer_test.py +++ b/tf_shell/test/auto_param_optimizer_test.py @@ -126,21 +126,69 @@ def long_arith_with_scaling(cleartext_a, cleartext_b, use_auto_context=False): result = tf_shell.to_tensorflow(intermediate, key) return result +@tf.function +def reduce_sum_axis_1(cleartext_a, cleartext_b, use_auto_context=False): + shell_context = ( + gen_autocontext(test_values_num_bits + cleartext_a.shape[1].bit_length(), 52) + if use_auto_context + else gen_context() + ) + key = tf_shell.create_key64(shell_context) + a = tf_shell.to_encrypted(cleartext_a, key, shell_context) -# @tf.function -# def ct_roll(cleartext_a, cleartext_b, use_auto_context=False): -# shell_context = ( -# gen_autocontext(test_values_num_bits) -# if use_auto_context -# else gen_context() -# ) -# key = tf_shell.create_key64(shell_context) -# a = tf_shell.to_encrypted(cleartext_a, key, shell_context) -# b = tf_shell.to_shell_plaintext(cleartext_b, shell_context) + intermediate = tf_shell.reduce_sum(a, axis=1) + + result = tf_shell.to_tensorflow(intermediate, key) + return result + +@tf.function +def reduce_sum_axis_0(cleartext_a, cleartext_b, use_auto_context=False): + shell_context = ( + gen_autocontext(test_values_num_bits + cleartext_a.shape[0].bit_length(), 52) + if use_auto_context + else gen_context() + ) + key = tf_shell.create_key64(shell_context) + public_rotation_key = tf_shell.create_rotation_key64(shell_context, key) + a = tf_shell.to_encrypted(cleartext_a, key, shell_context) -# intermediate = (a * b) + b + a -# result = tf_shell.to_tensorflow(intermediate, key) -# return result + intermediate = tf_shell.reduce_sum(a, axis=0, rotation_key=public_rotation_key) + + result = tf_shell.to_tensorflow(intermediate, key) + return result + +@tf.function +def fast_reduce_sum_axis_0(cleartext_a, cleartext_b, use_auto_context=False): + shell_context = ( + gen_autocontext(test_values_num_bits + cleartext_a.shape[0].bit_length() + 14, 52) + if use_auto_context + else gen_context() + ) + key = tf_shell.create_key64(shell_context) + secret_fast_rotation_key = tf_shell.create_fast_rotation_key64(shell_context, key) + + a = tf_shell.to_encrypted(cleartext_a, key, shell_context) + + intermediate = tf_shell.fast_reduce_sum(a) + + result = tf_shell.to_tensorflow(intermediate, secret_fast_rotation_key) + return result + +@tf.function +def ct_roll(cleartext_a, cleartext_b, use_auto_context=False): + shell_context = ( + gen_autocontext(test_values_num_bits, 52) + if use_auto_context + else gen_context() + ) + key = tf_shell.create_key64(shell_context) + public_rotation_key = tf_shell.create_rotation_key64(shell_context, key) + a = tf_shell.to_encrypted(cleartext_a, key, shell_context) + + intermediate = tf_shell.roll(a, 5, rotation_key=public_rotation_key) + + result = tf_shell.to_tensorflow(intermediate, key) + return result def count_ops(graph, op_name): @@ -201,12 +249,32 @@ def _test_func(self, tf_func): # Check the optimized graph still computes the correct value. eager_c = tf_func(a, b, False) - padded_c = tf.pad( - eager_c, - [[0, c.shape[0] - eager_c.shape[0]]] - + [[0, 0] for _ in range(len(c.shape) - 1)], - ) - self.assertAllEqual(c, padded_c) + + # c and eager c may be different dimensions due to the number of slots + # chosen by the optimizer. Match the dimensions before comparing the + # values. + if tf_func == reduce_sum_axis_0 or tf_func == fast_reduce_sum_axis_0: + # Concatenate the first and middle slots of each value before comparing. + eager_c = tf.concat([eager_c[0], eager_c[eager_c.shape[0] // 2]], axis=0) + c = tf.concat([c[0], c[c.shape[0] // 2]], axis=0) + + else: + # Pad the first (slotting) dimension of the outputs to the same value. + def pad_first_dim(tensor, first_dim): + if first_dim > tensor.shape[0]: + return tf.pad( + tensor, + [[0, first_dim - tensor.shape[0]]] + + [[0, 0] for _ in range(len(tensor.shape) - 1)], + ) + else: + return tensor + + max_fist_dim = tf.maximum(c.shape[0], eager_c.shape[0]) + eager_c = pad_first_dim(eager_c, max_fist_dim) + c = pad_first_dim(c, max_fist_dim) + + self.assertAllEqual(c, eager_c) def test_func(self): with self.subTest(f"Optimizer for func ct_ct_add."): @@ -227,6 +295,18 @@ def test_func(self): with self.subTest(f"Optimizer for func long_arith_with_scaling."): self._test_func(long_arith_with_scaling) + with self.subTest(f"Optimizer for reduce sum axis 1."): + self._test_func(reduce_sum_axis_1) + + with self.subTest(f"Optimizer for reduce sum axis 0."): + self._test_func(reduce_sum_axis_0) + + with self.subTest(f"Optimizer for fast reduce sum axis 0."): + self._test_func(fast_reduce_sum_axis_0) + + with self.subTest(f"Optimizer for roll."): + self._test_func(ct_roll) + class TestAutoParamEnableOptimizer(tf.test.TestCase): def test_func(self): diff --git a/tf_shell/test/distribution_test.py b/tf_shell/test/distribution_test.py index 691353f..ec627df 100644 --- a/tf_shell/test/distribution_test.py +++ b/tf_shell/test/distribution_test.py @@ -16,6 +16,7 @@ import tensorflow as tf import tf_shell import os +import test_utils job_prefix = "tfshell" features_party_job = f"{job_prefix}features" @@ -64,6 +65,7 @@ def test_distribution(self): with tf.device(labels_party_dev): key = tf_shell.create_key64(shell_context) + fast_rotation_key = tf_shell.create_fast_rotation_key64(shell_context, key) a = tf.random.uniform( [shell_context.num_slots, 3], dtype=tf.float32, maxval=10 ) @@ -85,14 +87,19 @@ def test_distribution(self): # pt_d is sent between the parties. with tf.device(features_party_dev): enc_e = enc_a + pt_d + enc_f = tf_shell.fast_reduce_sum(enc_e) # enc_c is sent between the parties. with tf.device(labels_party_dev): c = tf_shell.to_tensorflow(enc_c, key) e = tf_shell.to_tensorflow(enc_e, key) + f = tf_shell.to_tensorflow(enc_f, fast_rotation_key) self.assertAllClose(c, a + b, atol=1) self.assertAllClose(e, a + d, atol=1) + self.assertAllClose( + f, test_utils.plaintext_reduce_sum_axis_0(a + d), atol=1, rtol=1e-2 + ) if __name__ == "__main__": diff --git a/tf_shell/test/pt_pt_optimizer_test.py b/tf_shell/test/pt_pt_optimizer_test.py index 07d330e..cffbd0d 100644 --- a/tf_shell/test/pt_pt_optimizer_test.py +++ b/tf_shell/test/pt_pt_optimizer_test.py @@ -170,9 +170,9 @@ def _test_func( orig_num_ops = count_ops(func.graph, pt_op_name) self.assertEqual(orig_num_ops, num_pt_ops) - # print("\noriginal graph:") - # for node in func.graph.as_graph_def().node: - # print(f"{node.name} {node.op}({node.input})") + print("\noriginal graph:") + for node in func.graph.as_graph_def().node: + print(f"{node.name} {node.op}({node.input})") # Optimize the graph using tf_shells HE-specific optimizers. optimized_func = shell_optimizers.optimize_shell_graph(func) @@ -185,9 +185,9 @@ def _test_func( c = optimized_func.function_type.pack_output(c) opt_num_ops = count_ops(optimized_func.graph, pt_op_name) - # print("\noptimized graph:") - # for node in optimized_func.graph.as_graph_def().node: - # print(f"{node.name} {node.op}({node.input})") + print("\noptimized graph:") + for node in optimized_func.graph.as_graph_def().node: + print(f"{node.name} {node.op}({node.input})") self.assertEqual(opt_num_ops, expected_num_pt_ops) diff --git a/tf_shell/test/rotation_test.py b/tf_shell/test/rotation_test.py index f4697bc..7d70136 100644 --- a/tf_shell/test/rotation_test.py +++ b/tf_shell/test/rotation_test.py @@ -164,7 +164,18 @@ def _test_roll_mod_reduced(self, test_context, roll_num): self.assertAllClose(rolled_tftensor, rolled_result_reduced, atol=1e-3) def test_roll_mod_reduced(self): + # Testing all contexts for all possible rotations is slow. Instead, + # test a subset of rotations for each context, and one context tests + # all rotations. for test_context in self.test_contexts: + rotation_range = test_context.shell_context.num_slots // 2 - 1 + for roll_num in [-rotation_range, rotation_range, -1, 0, 1]: + with self.subTest( + f"roll_mod_reduced with context {test_context}, rotating by {roll_num}" + ): + self._test_roll_mod_reduced(test_context, roll_num) + + for test_context in [self.test_contexts[0]]: rotation_range = test_context.shell_context.num_slots // 2 - 1 for roll_num in range(-rotation_range, rotation_range, 1): with self.subTest( diff --git a/tf_shell/test/rotation_test_fast.py b/tf_shell/test/rotation_test_fast.py index b8917f8..69242a3 100644 --- a/tf_shell/test/rotation_test_fast.py +++ b/tf_shell/test/rotation_test_fast.py @@ -67,16 +67,9 @@ def tearDownClass(cls): def _test_fast_reduce_sum_axis_0(self, test_context): # reduce_sum across axis 0 requires adding over all the slots. - try: - tftensor = test_utils.uniform_for_n_adds( - test_context, num_adds=test_context.shell_context.num_slots / 2 - ) - except Exception as e: - print( - f"Note: Skipping test fast_reduce_sum_axis_0 with context {test_context}. Not enough precision to support this test." - ) - print(e) - return + tftensor = test_utils.uniform_for_n_adds( + test_context, num_adds=test_context.shell_context.num_slots / 2 + ) s = tf_shell.to_shell_plaintext(tftensor, test_context.shell_context) enc = tf_shell.to_encrypted(s, test_context.key) @@ -117,7 +110,6 @@ def test_decrypt_with_wrong_key(self): def test_no_fast_reduce_sum_degree_two_ct(self): test_context = self.test_contexts[0] - print(f"test context num slots {test_context.shell_context.num_slots}") tftensor = tf.ones( [test_context.shell_context.num_slots, 1], dtype=test_context.plaintext_dtype, diff --git a/tf_shell/test/test_utils.py b/tf_shell/test/test_utils.py index 5e87346..9b4de67 100644 --- a/tf_shell/test/test_utils.py +++ b/tf_shell/test/test_utils.py @@ -64,7 +64,8 @@ def get_bounds_for_n_adds(test_context, num_adds): """Returns a safe range for plaintext values when doing a given number of additions.""" dtype = test_context.plaintext_dtype - plaintext_modulus = test_context.shell_context.plaintext_modulus + plaintext_modulus = tf.cast(test_context.shell_context.plaintext_modulus, float) + num_adds = math.ceil(num_adds) scaling_factor = test_context.shell_context.scaling_factor # Make sure not to exceed the range of the dtype. @@ -104,6 +105,8 @@ def get_bounds_for_n_muls(test_context, num_muls): plaintext modulus or the datatype.""" dtype = test_context.plaintext_dtype plaintext_modulus = test_context.shell_context.plaintext_modulus + plaintext_modulus = tf.cast(test_context.shell_context.plaintext_modulus, float) + num_muls = math.ceil(num_muls) # Each multiplication doubles the number of scaling factors in the result. max_scaling_factor = test_context.shell_context.scaling_factor ** (2**num_muls) diff --git a/tf_shell_ml/model.py b/tf_shell_ml/model.py index 22764c8..c95d98d 100644 --- a/tf_shell_ml/model.py +++ b/tf_shell_ml/model.py @@ -47,7 +47,7 @@ def __init__( self.use_encryption = use_encryption self.labels_party_dev = labels_party_dev self.features_party_dev = features_party_dev - self.clipping_threshold = 100000 + self.clipping_threshold = 10000000 self.mpc_bit_width = 16 def compile(self, optimizer, shell_loss, loss, metrics=[], **kwargs): @@ -91,7 +91,6 @@ def train_step(self, data): else: enc_y = y public_rotation_key = None - tf.print("ran labels party encrypt") self.mpc_scaling_factor = shell_context.scaling_factor @@ -127,8 +126,6 @@ def train_step(self, data): (g + m) for g, m in zip(reversed(dJ_dw), reversed(mask)) ] - tf.print("ran features party backprop") - with tf.device(self.labels_party_dev): if self.use_encryption: # Decrypt the weight gradients. @@ -158,12 +155,10 @@ def train_step(self, data): masked_grads = [tf.reshape(mg, [-1]) for mg in masked_grads] masked_grads = tf.concat(masked_grads, axis=0) - tf.print(" half way throuh") # Sample the noise for differential privacy. # TODO: set stddev based on clipping threshold. noise = tf.random.normal(tf.shape(masked_grads), stddev=1) - tf.print(" sampled noise") # After decryption, the mask has dtype float. Encode it back to int # with shells scaling factor for use in the clip and noise protocol. @@ -172,12 +167,10 @@ def train_step(self, data): tf.round(masked_grads * self.mpc_scaling_factor), tf.int64 ) noise = tf.cast(tf.round(noise * self.mpc_scaling_factor), tf.int64) - tf.print(" casted to int") # If running features party and labels party on the same node, # skip the MPC protocol. # if self.labels_party_dev != self.features_party_dev: - # tf.print("running labels party decrypt") # # Start labels party MPC protocol. # tf_shell.clip_and_noise_labels_party( # masked_grads, @@ -187,14 +180,13 @@ def train_step(self, data): # StartPort=5555, # FeaturePartyHost="127.0.0.1", # ) - tf.print("ran labels party decrypt") with tf.device(self.features_party_dev): # Encode the mask with the scaling factor for use in the clip and # noise protocol. mask = [tf.reshape(m, [-1]) for m in mask] mask = tf.concat(mask, axis=0) - mask = tf.cast(tf.round(mask * self.mpc_scaling_factor), tf.int64) + # mask = tf.cast(tf.round(mask * self.mpc_scaling_factor), tf.int64) # If running features party and labels party on the same node, # skip the MPC protocol and clip and noise the gradients directly. @@ -207,24 +199,26 @@ def train_step(self, data): # ) # else: unmasked_grads = masked_grads - mask - clipped_noised_grads = tf.cond( - tf.reduce_sum(unmasked_grads * unmasked_grads) - > self.clipping_threshold, - lambda: self.clipping_threshold + noise, - lambda: unmasked_grads + noise, - ) + # clipped_noised_grads = tf.cond( + # tf.reduce_sum(unmasked_grads * unmasked_grads) + # > self.clipping_threshold, + # lambda: self.clipping_threshold + noise, + # lambda: unmasked_grads + noise, + # ) + # clipped_noised_grads = unmasked_grads + noise + clipped_noised_grads = unmasked_grads # Emulate overflow of 2's complement addition between `Bitwidth` # integers from when grad + noise is computed under the MPC # protocol. Note any overflow in the masking / unmasking cancels # out. - min_val = -(2 ** (self.mpc_bit_width - 1)) - max_val = 2 ** (self.mpc_bit_width - 1) - 1 - clipped_noised_grads = tf.where( - clipped_noised_grads > max_val, - min_val + (clipped_noised_grads - max_val), - clipped_noised_grads, - ) + # min_val = -(2 ** (self.mpc_bit_width - 1)) + # max_val = 2 ** (self.mpc_bit_width - 1) - 1 + # clipped_noised_grads = tf.where( + # clipped_noised_grads > max_val, + # min_val + (clipped_noised_grads - max_val), + # clipped_noised_grads, + # ) # end else # Decode the clipped and noised gradients. @@ -243,9 +237,7 @@ def train_step(self, data): weights += l.weights # Apply the gradients to the model. - tf.print(" clipped and noised grads", clipped_noised_grads) self.optimizer.apply_gradients(zip(clipped_noised_grads, weights)) - tf.print("ran features party apply") # Do not update metrics during secure training. if not self.use_encryption: