Memory safety around variant encoding.
james-choncholas committed Oct 2, 2024
1 parent 05fd287 commit b394161
Showing 4 changed files with 123 additions and 30 deletions.
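
The recurring pattern across all four files: serialized payloads change from plain std::string members to std::shared_ptr<std::string> (or a shared_ptr to a vector of strings), and Encode() starts by copying those shared_ptr members into locals so that a concurrent lazy decode, which resets the members, cannot free the bytes while a possibly asynchronous encode is still reading them. Below is a minimal sketch of that idea, not the actual tf_shell code; Payload, serialized, and context are placeholder names.

#include <iostream>
#include <memory>
#include <string>
#include <vector>

// Placeholder for a variant-like payload: `serialized` holds the raw bytes
// after Decode, `context` becomes non-null once the payload is lazily decoded.
struct Payload {
  std::shared_ptr<std::string> serialized;
  std::shared_ptr<int const> context;

  void Encode(std::vector<std::string>* out) const {
    // Snapshot the shared_ptrs first. Even if another thread finishes a lazy
    // decode and nulls out the members, the local copies keep the data alive.
    auto local_serialized = serialized;
    auto local_context = context;

    if (local_context == nullptr) {
      // Decoded but never lazily decoded: re-emit the raw serialized bytes.
      if (local_serialized == nullptr) {
        std::cout << "ERROR: payload not set, cannot encode." << std::endl;
        return;
      }
      out->push_back(*local_serialized);
      return;
    }
    // Otherwise serialize from the fully decoded object (elided here).
  }
};

int main() {
  Payload p;
  p.serialized = std::make_shared<std::string>("raw-bytes");
  std::vector<std::string> out;
  p.Encode(&out);  // takes the "decoded but not lazily decoded" branch
  std::cout << out.size() << " item(s) encoded" << std::endl;
  return 0;
}
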
34 changes: 26 additions & 8 deletions tf_shell/cc/kernels/polynomial_variant.h
@@ -43,7 +43,21 @@ class PolynomialVariant {
   std::string TypeName() const { return kTypeName; }

   void Encode(VariantTensorData* data) const {
-    auto serialized_poly_or = poly.Serialize(ct_context->MainPrimeModuli());
+    auto async_poly_str = poly_str;  // Make sure key string is not deallocated.
+    auto async_ct_context = ct_context;
+
+    if (async_ct_context == nullptr) {
+      // If the context is null, this may have been decoded but not lazy decoded
+      // yet. In this case, directly encode the polynomial string.
+      if (async_poly_str == nullptr) {
+        std::cout << "ERROR: Polynomial not set, cannot encode." << std::endl;
+        return;
+      }
+      data->tensors_.push_back(Tensor(*async_poly_str));
+      return;
+    }
+    auto serialized_poly_or =
+        poly.Serialize(async_ct_context->MainPrimeModuli());
     if (!serialized_poly_or.ok()) {
       std::cout << "ERROR: Failed to serialize polynomial: "
                 << serialized_poly_or.status();
@@ -61,24 +75,27 @@ class PolynomialVariant {
       return false;
     }

-    if (!poly_str.empty()) {
+    if (poly_str != nullptr) {
       std::cout << "ERROR: Polynomial already decoded";
       return false;
     }

-    poly_str = std::string(data.tensors_[0].scalar<tstring>()().begin(),
-                           data.tensors_[0].scalar<tstring>()().end());
+    poly_str = std::make_shared<std::string>(
+        data.tensors_[0].scalar<tstring>()().begin(),
+        data.tensors_[0].scalar<tstring>()().end());

     return true;
   };

   Status MaybeLazyDecode(std::shared_ptr<Context const> ct_context_) {
-    if (poly_str.empty()) {
+    std::lock_guard<std::mutex> lock(mutex.mutex);
+
+    if (ct_context != nullptr) {
       return OkStatus();
     }

     rlwe::SerializedRnsPolynomial serialized_poly;
-    bool ok = serialized_poly.ParseFromString(poly_str);
+    bool ok = serialized_poly.ParseFromString(*poly_str);
     if (!ok) {
       return InvalidArgument("Failed to parse polynomial.");
     }
@@ -91,14 +108,15 @@ class PolynomialVariant {
     ct_context = ct_context_;

     // Clear the serialized polynomial string.
-    poly_str.clear();
+    poly_str = nullptr;

     return OkStatus();
   };

   std::string DebugString() const { return "ShellPolynomialVariant"; }

+  variant_mutex mutex;
   Polynomial poly;
-  std::string poly_str;
+  std::shared_ptr<std::string> poly_str;
   std::shared_ptr<Context const> ct_context;
 };
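
The other half of the pattern, repeated in each variant's MaybeLazyDecode, is: take the variant's mutex, treat a non-null ct_context (rather than a non-empty string) as the already-decoded signal, and null the shared_ptr once the real object has been reconstructed. A hedged sketch of that shape, with placeholder names standing in for Context and the protobuf parsing:

#include <memory>
#include <mutex>
#include <string>

// Placeholder lazy-decode shape; `context` stands in for the shell Context and
// the parse step stands in for ParseFromString + CreateFromSerialized.
struct LazyPayload {
  std::mutex mutex;  // the commit wraps this in variant_mutex
  std::shared_ptr<std::string> serialized;
  std::shared_ptr<int const> context;  // non-null marks "fully decoded"

  bool MaybeLazyDecode(std::shared_ptr<int const> new_context) {
    std::lock_guard<std::mutex> lock(mutex);

    // Already fully decoded: nothing to do. The context pointer, not an
    // empty string, is the signal now.
    if (context != nullptr) return true;
    if (serialized == nullptr) return false;  // nothing to decode from

    // ... parse *serialized into the real object here ...

    context = new_context;
    serialized = nullptr;  // drop the serialized form
    return true;
  }
};

int main() {
  LazyPayload p;
  p.serialized = std::make_shared<std::string>("raw-bytes");
  return p.MaybeLazyDecode(std::make_shared<int const>(42)) ? 0 : 1;
}
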
41 changes: 33 additions & 8 deletions tf_shell/cc/kernels/rotation_variants.h
@@ -71,10 +71,28 @@ class FastRotationKeyVariant {
   std::string TypeName() const { return kTypeName; }

   void Encode(VariantTensorData* data) const {
+    auto async_key_strs = key_strs;  // Make sure key string is not deallocated.
+    auto async_ct_context = ct_context;
+
+    if (async_ct_context == nullptr) {
+      // If the context is null, this may have been decoded but not lazy decoded
+      // yet. In this case, directly encode the key strings.
+      if (async_key_strs == nullptr) {
+        std::cout << "ERROR: Fast rotation key not set, cannot encode."
+                  << std::endl;
+        return;
+      }
+      data->tensors_.reserve(async_key_strs->size());
+      for (auto const& key_str : *async_key_strs) {
+        data->tensors_.push_back(Tensor(key_str));
+      }
+    }
+
     data->tensors_.reserve(keys.size());

     for (auto const& key : keys) {
-      auto serialized_key_or = key.Serialize(ct_context->MainPrimeModuli());
+      auto serialized_key_or =
+          key.Serialize(async_ct_context->MainPrimeModuli());
       if (!serialized_key_or.ok()) {
         std::cout << "ERROR: Failed to serialize key: "
                   << serialized_key_or.status();
@@ -101,32 +119,38 @@ class FastRotationKeyVariant {
       return false;
     }

-    if (!key_strs.empty()) {
+    if (key_strs != nullptr) {
       std::cout << "ERROR: Fast rotation key already decoded." << std::endl;
       return false;
     }

     size_t num_keys = data.tensors_.size();
-    key_strs.reserve(num_keys);
+    std::vector<std::string> building_key_strs;
+    building_key_strs.reserve(num_keys);

     for (size_t i = 0; i < data.tensors_.size(); ++i) {
       std::string const serialized_key(
           data.tensors_[i].scalar<tstring>()().begin(),
           data.tensors_[i].scalar<tstring>()().end());

-      key_strs.push_back(std::move(serialized_key));
+      building_key_strs.push_back(std::move(serialized_key));
     }

+    key_strs = std::make_shared<std::vector<std::string>>(
+        std::move(building_key_strs));
+
     return true;
   };

   Status MaybeLazyDecode(std::shared_ptr<Context const> ct_context_) {
+    std::lock_guard<std::mutex> lock(mutex.mutex);
+
     // If the keys have already been fully decoded, nothing to do.
-    if (key_strs.empty()) {
+    if (ct_context != nullptr) {
       return OkStatus();
     }

-    for (auto const& key_str : key_strs) {
+    for (auto const& key_str : *key_strs) {
       rlwe::SerializedRnsPolynomial serialized_key;
       bool ok = serialized_key.ParseFromString(key_str);
       if (!ok) {
@@ -145,14 +169,15 @@ class FastRotationKeyVariant {
     ct_context = ct_context_;

     // Clear the key strings.
-    key_strs.clear();
+    key_strs = nullptr;

     return OkStatus();
   };

   std::string DebugString() const { return "ShellFastRotationKeyVariant"; }

+  variant_mutex mutex;
   std::vector<RnsPolynomial> keys;
-  std::vector<std::string> key_strs;
+  std::shared_ptr<std::vector<std::string>> key_strs;
   std::shared_ptr<Context const> ct_context;
 };
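
FastRotationKeyVariant differs only in that the serialized form is a vector of strings: Decode now builds into a local std::vector and then hands it to the key_strs member through a single make_shared call, so anything that later snapshots the shared_ptr sees either no vector or the finished one. A small, hypothetical illustration of that build-then-publish step (Tensor handling omitted):

#include <memory>
#include <string>
#include <utility>
#include <vector>

// Hypothetical holder mirroring the key_strs member of FastRotationKeyVariant.
struct KeyStrings {
  std::shared_ptr<std::vector<std::string>> key_strs;  // null until decoded

  bool Decode(std::vector<std::string> const& serialized_tensors) {
    if (key_strs != nullptr) return false;  // already decoded

    // Build into a local vector first...
    std::vector<std::string> building;
    building.reserve(serialized_tensors.size());
    for (auto const& s : serialized_tensors) building.push_back(s);

    // ...then publish it with one shared_ptr assignment, so later readers that
    // copy key_strs work with the complete vector, never a half-filled member.
    key_strs = std::make_shared<std::vector<std::string>>(std::move(building));
    return true;
  }
};

int main() {
  KeyStrings ks;
  return ks.Decode({"key0", "key1"}) ? 0 : 1;
}
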
66 changes: 52 additions & 14 deletions tf_shell/cc/kernels/symmetric_variants.h
@@ -45,6 +45,20 @@ class SymmetricKeyVariant {
   std::string TypeName() const { return kTypeName; }

   void Encode(VariantTensorData* data) const {
+    auto async_key_str = key_str;  // Make sure key string is not deallocated.
+    auto async_ct_context = ct_context;
+
+    if (async_ct_context == nullptr) {
+      // If the context is null, this may have been decoded but not lazy decoded
+      // yet. In this case, directly encode the key string.
+      if (async_key_str == nullptr) {
+        std::cout << "ERROR: Key not set, cannot encode." << std::endl;
+        return;
+      }
+      data->tensors_.push_back(Tensor(*async_key_str));
+      return;
+    }
+
     auto serialized_key_or = key->Key().Serialize(key->Moduli());
     if (!serialized_key_or.ok()) {
       std::cout << "ERROR: Failed to serialize key: "
@@ -70,27 +84,32 @@ class SymmetricKeyVariant {
       return false;
     }

-    if (!key_str.empty()) {
+    if (key_str != nullptr) {
       std::cout << "ERROR: Key already decoded." << std::endl;
       return false;
     }

     // Recover the key polynomial.
-    key_str = std::string(data.tensors_[0].scalar<tstring>()().begin(),
-                          data.tensors_[0].scalar<tstring>()().end());
+    // key_str =
+    //     std::make_shared<std::string>(data.tensors_[0].scalar<tstring>()());
+    key_str = std::make_shared<std::string>(
+        data.tensors_[0].scalar<tstring>()().begin(),
+        data.tensors_[0].scalar<tstring>()().end());

     return true;
   };

   Status MaybeLazyDecode(std::shared_ptr<Context const> ct_context_,
                          int noise_variance) {
+    std::lock_guard<std::mutex> lock(mutex.mutex);
+
     // If this key has already been fully decoded, nothing to do.
-    if (key_str.empty()) {
+    if (ct_context != nullptr) {
       return OkStatus();
     }

     rlwe::SerializedRnsPolynomial serialized_key;
-    bool ok = serialized_key.ParseFromString(key_str);
+    bool ok = serialized_key.ParseFromString(*key_str);
     if (!ok) {
       return InvalidArgument("Failed to parse key polynomial.");
     }
@@ -116,15 +135,16 @@ class SymmetricKeyVariant {
     ct_context = ct_context_;

     // Clear the serialized key string.
-    key_str.clear();
+    key_str = nullptr;

     return OkStatus();
   }

   std::string DebugString() const { return "ShellSymmetricKeyVariant"; }

+  variant_mutex mutex;
   std::shared_ptr<Key> key;
-  std::string key_str;
+  std::shared_ptr<std::string> key_str;
   std::shared_ptr<Context const> ct_context;
 };

@@ -149,6 +169,20 @@ class SymmetricCtVariant {
   std::string TypeName() const { return kTypeName; }

   void Encode(VariantTensorData* data) const {
+    auto async_ct_str = ct_str;  // Make sure key string is not deallocated.
+    auto async_ct_context = ct_context;
+
+    if (async_ct_context == nullptr) {
+      // If the context is null, this may have been decoded but not lazy decoded
+      // yet. In this case, directly encode the ciphertext string.
+      if (async_ct_str == nullptr) {
+        std::cout << "ERROR: Ciphertext not set, cannot encode." << std::endl;
+        return;
+      }
+      data->tensors_.push_back(Tensor(*async_ct_str));
+      return;
+    }
+
     // Store the ciphertext.
     auto serialized_ct_or = ct.Serialize();
     if (!serialized_ct_or.ok()) {
@@ -174,27 +208,30 @@ class SymmetricCtVariant {
       return false;
     }

-    if (!ct_str.empty()) {
+    if (ct_str != nullptr) {
       std::cout << "ERROR: Ciphertext already decoded." << std::endl;
       return false;
     }

     // Recover the serialized ciphertext string.
-    ct_str = std::string(data.tensors_[0].scalar<tstring>()().begin(),
-                         data.tensors_[0].scalar<tstring>()().end());
+    ct_str = std::make_shared<std::string>(
+        data.tensors_[0].scalar<tstring>()().begin(),
+        data.tensors_[0].scalar<tstring>()().end());

     return true;
   };

   Status MaybeLazyDecode(std::shared_ptr<Context const> ct_context_,
                          std::shared_ptr<ErrorParams const> error_params_) {
+    std::lock_guard<std::mutex> lock(mutex.mutex);
+
     // If this ciphertext has already been fully decoded, nothing to do.
-    if (ct_str.empty()) {
+    if (ct_context != nullptr) {
       return OkStatus();
     }

     rlwe::SerializedRnsRlweCiphertext serialized_ct;
-    bool ok = serialized_ct.ParseFromString(ct_str);
+    bool ok = serialized_ct.ParseFromString(*ct_str);
     if (!ok) {
       return InvalidArgument("Failed to parse ciphertext.");
     }
@@ -211,15 +248,16 @@ class SymmetricCtVariant {
     error_params = error_params_;

     // Clear the serialized ciphertext string.
-    ct_str.clear();
+    ct_str = nullptr;

     return OkStatus();
   };

   std::string DebugString() const { return "ShellSymmetricCtVariant"; }

+  variant_mutex mutex;
   SymmetricCt ct;
-  std::string ct_str;
+  std::shared_ptr<std::string> ct_str;
   std::shared_ptr<Context const> ct_context;
   std::shared_ptr<ErrorParams const> error_params;
 };
12 changes: 12 additions & 0 deletions tf_shell/cc/kernels/utils.h
@@ -38,6 +38,18 @@ using tensorflow::errors::Unimplemented;
 // The substitution power for Galois rotation by one slot.
 constexpr int base_power = 5;

+// A mutex for use with variants with appropriate copy/assign.
+struct variant_mutex {
+  std::mutex mutex;
+  variant_mutex() : mutex() {}
+
+  variant_mutex(variant_mutex const& other) : mutex() {}
+  variant_mutex& operator=(variant_mutex const& other) { return *this; }
+
+  variant_mutex(variant_mutex&& other) : mutex() {}
+  variant_mutex& operator=(variant_mutex&& other) { return *this; }
+};
+
 constexpr uint64_t BitWidth(uint64_t n) {
   uint64_t bits = 0;
   while (n) {
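
The new variant_mutex wrapper exists because TensorFlow copies and moves Variant payloads, while std::mutex is neither copyable nor movable; every copy or move simply default-constructs a fresh, unlocked mutex. A quick compile-check of those semantics, where HolderVariant is a made-up stand-in for the variant classes above:

#include <mutex>

// Copied from the diff above: copies/moves get a brand-new mutex.
struct variant_mutex {
  std::mutex mutex;
  variant_mutex() : mutex() {}

  variant_mutex(variant_mutex const& other) : mutex() {}
  variant_mutex& operator=(variant_mutex const& other) { return *this; }

  variant_mutex(variant_mutex&& other) : mutex() {}
  variant_mutex& operator=(variant_mutex&& other) { return *this; }
};

// Made-up stand-in for a variant class holding the mutex plus some state.
struct HolderVariant {
  variant_mutex mutex;
  int state = 0;
};

int main() {
  HolderVariant a;
  HolderVariant b = a;  // compiles: the copy gets its own, unlocked mutex
  std::lock_guard<std::mutex> lock_a(a.mutex.mutex);
  std::lock_guard<std::mutex> lock_b(b.mutex.mutex);  // independent locks
  b.state = 1;
  return b.state - 1;
}
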
