From b394161603b0b3b973eddb3a8dcb31fc4b962335 Mon Sep 17 00:00:00 2001 From: james-choncholas Date: Wed, 2 Oct 2024 00:33:07 +0000 Subject: [PATCH] Memory safety around variant encoding. --- tf_shell/cc/kernels/polynomial_variant.h | 34 +++++++++--- tf_shell/cc/kernels/rotation_variants.h | 41 ++++++++++++--- tf_shell/cc/kernels/symmetric_variants.h | 66 +++++++++++++++++++----- tf_shell/cc/kernels/utils.h | 12 +++++ 4 files changed, 123 insertions(+), 30 deletions(-) diff --git a/tf_shell/cc/kernels/polynomial_variant.h b/tf_shell/cc/kernels/polynomial_variant.h index e354825..ff2d497 100644 --- a/tf_shell/cc/kernels/polynomial_variant.h +++ b/tf_shell/cc/kernels/polynomial_variant.h @@ -43,7 +43,21 @@ class PolynomialVariant { std::string TypeName() const { return kTypeName; } void Encode(VariantTensorData* data) const { - auto serialized_poly_or = poly.Serialize(ct_context->MainPrimeModuli()); + auto async_poly_str = poly_str; // Make sure key string is not deallocated. + auto async_ct_context = ct_context; + + if (async_ct_context == nullptr) { + // If the context is null, this may have been decoded but not lazy decoded + // yet. In this case, directly encode the polynomial string. + if (async_poly_str == nullptr) { + std::cout << "ERROR: Polynomial not set, cannot encode." << std::endl; + return; + } + data->tensors_.push_back(Tensor(*async_poly_str)); + return; + } + auto serialized_poly_or = + poly.Serialize(async_ct_context->MainPrimeModuli()); if (!serialized_poly_or.ok()) { std::cout << "ERROR: Failed to serialize polynomial: " << serialized_poly_or.status(); @@ -61,24 +75,27 @@ class PolynomialVariant { return false; } - if (!poly_str.empty()) { + if (poly_str != nullptr) { std::cout << "ERROR: Polynomial already decoded"; return false; } - poly_str = std::string(data.tensors_[0].scalar()().begin(), - data.tensors_[0].scalar()().end()); + poly_str = std::make_shared( + data.tensors_[0].scalar()().begin(), + data.tensors_[0].scalar()().end()); return true; }; Status MaybeLazyDecode(std::shared_ptr ct_context_) { - if (poly_str.empty()) { + std::lock_guard lock(mutex.mutex); + + if (ct_context != nullptr) { return OkStatus(); } rlwe::SerializedRnsPolynomial serialized_poly; - bool ok = serialized_poly.ParseFromString(poly_str); + bool ok = serialized_poly.ParseFromString(*poly_str); if (!ok) { return InvalidArgument("Failed to parse polynomial."); } @@ -91,14 +108,15 @@ class PolynomialVariant { ct_context = ct_context_; // Clear the serialized polynomial string. - poly_str.clear(); + poly_str = nullptr; return OkStatus(); }; std::string DebugString() const { return "ShellPolynomialVariant"; } + variant_mutex mutex; Polynomial poly; - std::string poly_str; + std::shared_ptr poly_str; std::shared_ptr ct_context; }; diff --git a/tf_shell/cc/kernels/rotation_variants.h b/tf_shell/cc/kernels/rotation_variants.h index 74ccb12..f1230f8 100644 --- a/tf_shell/cc/kernels/rotation_variants.h +++ b/tf_shell/cc/kernels/rotation_variants.h @@ -71,10 +71,28 @@ class FastRotationKeyVariant { std::string TypeName() const { return kTypeName; } void Encode(VariantTensorData* data) const { + auto async_key_strs = key_strs; // Make sure key string is not deallocated. + auto async_ct_context = ct_context; + + if (async_ct_context == nullptr) { + // If the context is null, this may have been decoded but not lazy decoded + // yet. In this case, directly encode the key strings. + if (async_key_strs == nullptr) { + std::cout << "ERROR: Fast rotation key not set, cannot encode." + << std::endl; + return; + } + data->tensors_.reserve(async_key_strs->size()); + for (auto const& key_str : *async_key_strs) { + data->tensors_.push_back(Tensor(key_str)); + } + } + data->tensors_.reserve(keys.size()); for (auto const& key : keys) { - auto serialized_key_or = key.Serialize(ct_context->MainPrimeModuli()); + auto serialized_key_or = + key.Serialize(async_ct_context->MainPrimeModuli()); if (!serialized_key_or.ok()) { std::cout << "ERROR: Failed to serialize key: " << serialized_key_or.status(); @@ -101,32 +119,38 @@ class FastRotationKeyVariant { return false; } - if (!key_strs.empty()) { + if (key_strs != nullptr) { std::cout << "ERROR: Fast rotation key already decoded." << std::endl; return false; } size_t num_keys = data.tensors_.size(); - key_strs.reserve(num_keys); + std::vector building_key_strs; + building_key_strs.reserve(num_keys); for (size_t i = 0; i < data.tensors_.size(); ++i) { std::string const serialized_key( data.tensors_[i].scalar()().begin(), data.tensors_[i].scalar()().end()); - key_strs.push_back(std::move(serialized_key)); + building_key_strs.push_back(std::move(serialized_key)); } + key_strs = std::make_shared>( + std::move(building_key_strs)); + return true; }; Status MaybeLazyDecode(std::shared_ptr ct_context_) { + std::lock_guard lock(mutex.mutex); + // If the keys have already been fully decoded, nothing to do. - if (key_strs.empty()) { + if (ct_context != nullptr) { return OkStatus(); } - for (auto const& key_str : key_strs) { + for (auto const& key_str : *key_strs) { rlwe::SerializedRnsPolynomial serialized_key; bool ok = serialized_key.ParseFromString(key_str); if (!ok) { @@ -145,14 +169,15 @@ class FastRotationKeyVariant { ct_context = ct_context_; // Clear the key strings. - key_strs.clear(); + key_strs = nullptr; return OkStatus(); }; std::string DebugString() const { return "ShellFastRotationKeyVariant"; } + variant_mutex mutex; std::vector keys; - std::vector key_strs; + std::shared_ptr> key_strs; std::shared_ptr ct_context; }; \ No newline at end of file diff --git a/tf_shell/cc/kernels/symmetric_variants.h b/tf_shell/cc/kernels/symmetric_variants.h index 74e4857..4938229 100644 --- a/tf_shell/cc/kernels/symmetric_variants.h +++ b/tf_shell/cc/kernels/symmetric_variants.h @@ -45,6 +45,20 @@ class SymmetricKeyVariant { std::string TypeName() const { return kTypeName; } void Encode(VariantTensorData* data) const { + auto async_key_str = key_str; // Make sure key string is not deallocated. + auto async_ct_context = ct_context; + + if (async_ct_context == nullptr) { + // If the context is null, this may have been decoded but not lazy decoded + // yet. In this case, directly encode the key string. + if (async_key_str == nullptr) { + std::cout << "ERROR: Key not set, cannot encode." << std::endl; + return; + } + data->tensors_.push_back(Tensor(*async_key_str)); + return; + } + auto serialized_key_or = key->Key().Serialize(key->Moduli()); if (!serialized_key_or.ok()) { std::cout << "ERROR: Failed to serialize key: " @@ -70,27 +84,32 @@ class SymmetricKeyVariant { return false; } - if (!key_str.empty()) { + if (key_str != nullptr) { std::cout << "ERROR: Key already decoded." << std::endl; return false; } // Recover the key polynomial. - key_str = std::string(data.tensors_[0].scalar()().begin(), - data.tensors_[0].scalar()().end()); + // key_str = + // std::make_shared(data.tensors_[0].scalar()()); + key_str = std::make_shared( + data.tensors_[0].scalar()().begin(), + data.tensors_[0].scalar()().end()); return true; }; Status MaybeLazyDecode(std::shared_ptr ct_context_, int noise_variance) { + std::lock_guard lock(mutex.mutex); + // If this key has already been fully decoded, nothing to do. - if (key_str.empty()) { + if (ct_context != nullptr) { return OkStatus(); } rlwe::SerializedRnsPolynomial serialized_key; - bool ok = serialized_key.ParseFromString(key_str); + bool ok = serialized_key.ParseFromString(*key_str); if (!ok) { return InvalidArgument("Failed to parse key polynomial."); } @@ -116,15 +135,16 @@ class SymmetricKeyVariant { ct_context = ct_context_; // Clear the serialized key string. - key_str.clear(); + key_str = nullptr; return OkStatus(); } std::string DebugString() const { return "ShellSymmetricKeyVariant"; } + variant_mutex mutex; std::shared_ptr key; - std::string key_str; + std::shared_ptr key_str; std::shared_ptr ct_context; }; @@ -149,6 +169,20 @@ class SymmetricCtVariant { std::string TypeName() const { return kTypeName; } void Encode(VariantTensorData* data) const { + auto async_ct_str = ct_str; // Make sure key string is not deallocated. + auto async_ct_context = ct_context; + + if (async_ct_context == nullptr) { + // If the context is null, this may have been decoded but not lazy decoded + // yet. In this case, directly encode the ciphertext string. + if (async_ct_str == nullptr) { + std::cout << "ERROR: Ciphertext not set, cannot encode." << std::endl; + return; + } + data->tensors_.push_back(Tensor(*async_ct_str)); + return; + } + // Store the ciphertext. auto serialized_ct_or = ct.Serialize(); if (!serialized_ct_or.ok()) { @@ -174,27 +208,30 @@ class SymmetricCtVariant { return false; } - if (!ct_str.empty()) { + if (ct_str != nullptr) { std::cout << "ERROR: Ciphertext already decoded." << std::endl; return false; } // Recover the serialized ciphertext string. - ct_str = std::string(data.tensors_[0].scalar()().begin(), - data.tensors_[0].scalar()().end()); + ct_str = std::make_shared( + data.tensors_[0].scalar()().begin(), + data.tensors_[0].scalar()().end()); return true; }; Status MaybeLazyDecode(std::shared_ptr ct_context_, std::shared_ptr error_params_) { + std::lock_guard lock(mutex.mutex); + // If this ciphertext has already been fully decoded, nothing to do. - if (ct_str.empty()) { + if (ct_context != nullptr) { return OkStatus(); } rlwe::SerializedRnsRlweCiphertext serialized_ct; - bool ok = serialized_ct.ParseFromString(ct_str); + bool ok = serialized_ct.ParseFromString(*ct_str); if (!ok) { return InvalidArgument("Failed to parse ciphertext."); } @@ -211,15 +248,16 @@ class SymmetricCtVariant { error_params = error_params_; // Clear the serialized ciphertext string. - ct_str.clear(); + ct_str = nullptr; return OkStatus(); }; std::string DebugString() const { return "ShellSymmetricCtVariant"; } + variant_mutex mutex; SymmetricCt ct; - std::string ct_str; + std::shared_ptr ct_str; std::shared_ptr ct_context; std::shared_ptr error_params; }; diff --git a/tf_shell/cc/kernels/utils.h b/tf_shell/cc/kernels/utils.h index d1290a0..af72559 100644 --- a/tf_shell/cc/kernels/utils.h +++ b/tf_shell/cc/kernels/utils.h @@ -38,6 +38,18 @@ using tensorflow::errors::Unimplemented; // The substitution power for Galois rotation by one slot. constexpr int base_power = 5; +// A mutex for use with variants with appropriate copy/assign. +struct variant_mutex { + std::mutex mutex; + variant_mutex() : mutex() {} + + variant_mutex(variant_mutex const& other) : mutex() {} + variant_mutex& operator=(variant_mutex const& other) { return *this; } + + variant_mutex(variant_mutex&& other) : mutex() {} + variant_mutex& operator=(variant_mutex&& other) { return *this; } +}; + constexpr uint64_t BitWidth(uint64_t n) { uint64_t bits = 0; while (n) {