From ce5923718b15aeb0b88b9ffa7950cac01c5e96b2 Mon Sep 17 00:00:00 2001 From: Enrico Minack Date: Fri, 6 Dec 2024 15:52:23 +0100 Subject: [PATCH] Move common code into AesEncryptionContext, remove wipeout from AesEncryptor --- .../parquet/encryption/encryption_internal.cc | 273 +++++++++--------- .../parquet/encryption/encryption_internal.h | 4 +- .../encryption/internal_file_encryptor.cc | 15 - .../encryption/internal_file_encryptor.h | 1 - cpp/src/parquet/file_writer.cc | 3 - cpp/src/parquet/metadata.cc | 2 - 6 files changed, 134 insertions(+), 164 deletions(-) diff --git a/cpp/src/parquet/encryption/encryption_internal.cc b/cpp/src/parquet/encryption/encryption_internal.cc index 507d6290910be..6cff4b2f0bffc 100644 --- a/cpp/src/parquet/encryption/encryption_internal.cc +++ b/cpp/src/parquet/encryption/encryption_internal.cc @@ -52,12 +52,78 @@ constexpr int32_t kBufferSizeLength = 4; throw ParquetException("Couldn't init ALG decryption"); \ } -class AesEncryptor::AesEncryptorImpl { +class AesEncryptionContext { + public: + AesEncryptionContext(ParquetCipher::type alg_id, int32_t key_len, bool metadata, + bool write_length) { + openssl::EnsureInitialized(); + + length_buffer_length_ = write_length ? 
kBufferSizeLength : 0; + ciphertext_size_delta_ = length_buffer_length_ + kNonceLength; + if (metadata || (ParquetCipher::AES_GCM_V1 == alg_id)) { + aes_mode_ = kGcmMode; + ciphertext_size_delta_ += kGcmTagLength; + } else { + aes_mode_ = kCtrMode; + } + + if (16 != key_len && 24 != key_len && 32 != key_len) { + std::stringstream ss; + ss << "Wrong key length: " << key_len; + throw ParquetException(ss.str()); + } + + key_length_ = key_len; + }; + + virtual ~AesEncryptionContext() = default; + + protected: + void InitCipherContext() { + if (ctx_) return; + + ctx_ = std::unique_ptr(EVP_CIPHER_CTX_new(), + ctxDeleter); + if (!ctx_) throw ParquetException("Couldn't init cipher context"); + InitCipherContext(ctx_.get()); + } + + virtual void InitCipherContext(EVP_CIPHER_CTX* ctx) = 0; + + std::function ctxDeleter = [](EVP_CIPHER_CTX* ctx) { + EVP_CIPHER_CTX_free(ctx); + }; + + /// Create a new cipher context that auto-frees + /// This duplicates an unused but initialized private context to avoid going through + /// initialization + std::unique_ptr GetCipherContext() { + // could use EVP_CIPHER_CTX_dup instead (requires OpenSSL 3.2.0 and above) + auto ctx = std::unique_ptr(EVP_CIPHER_CTX_new(), + ctxDeleter); + if (ctx && !EVP_CIPHER_CTX_copy(ctx.get(), ctx_.get())) { + // ctx gets freed when leaving this method + throw ParquetException("Couldn't init cipher context"); + } + return ctx; + } + + int32_t aes_mode_; + int32_t key_length_; + int32_t ciphertext_size_delta_; + int32_t length_buffer_length_; + + private: + // an EVP_CIPHER_CTX that gets auto-freed + std::unique_ptr ctx_; +}; + +class AesEncryptor::AesEncryptorImpl : public AesEncryptionContext { public: explicit AesEncryptorImpl(ParquetCipher::type alg_id, int32_t key_len, bool metadata, bool write_length); - ~AesEncryptorImpl() { WipeOut(); } + ~AesEncryptorImpl() override = default; int32_t Encrypt(span plaintext, span key, span aad, span ciphertext); @@ -65,12 +131,6 @@ class AesEncryptor::AesEncryptorImpl { 
int32_t SignedFooterEncrypt(span footer, span key, span aad, span nonce, span encrypted_footer); - void WipeOut() { - if (nullptr != ctx_) { - EVP_CIPHER_CTX_free(ctx_); - ctx_ = nullptr; - } - } [[nodiscard]] int32_t CiphertextLength(int64_t plaintext_len) const { if (plaintext_len < 0) { @@ -88,19 +148,10 @@ class AesEncryptor::AesEncryptorImpl { return static_cast(plaintext_len + ciphertext_size_delta_); } - private: - void CheckValid() const { - if (ctx_ == nullptr) { - throw ParquetException("AesEncryptor was wiped out"); - } - } - - EVP_CIPHER_CTX* ctx_; - int32_t aes_mode_; - int32_t key_length_; - int32_t ciphertext_size_delta_; - int32_t length_buffer_length_; + protected: + void InitCipherContext(EVP_CIPHER_CTX* ctx) override; + private: int32_t GcmEncrypt(span plaintext, span key, span nonce, span aad, span ciphertext); @@ -111,50 +162,29 @@ class AesEncryptor::AesEncryptorImpl { AesEncryptor::AesEncryptorImpl::AesEncryptorImpl(ParquetCipher::type alg_id, int32_t key_len, bool metadata, - bool write_length) { - openssl::EnsureInitialized(); - - ctx_ = nullptr; - - length_buffer_length_ = write_length ? 
kBufferSizeLength : 0; - ciphertext_size_delta_ = length_buffer_length_ + kNonceLength; - if (metadata || (ParquetCipher::AES_GCM_V1 == alg_id)) { - aes_mode_ = kGcmMode; - ciphertext_size_delta_ += kGcmTagLength; - } else { - aes_mode_ = kCtrMode; - } - - if (16 != key_len && 24 != key_len && 32 != key_len) { - std::stringstream ss; - ss << "Wrong key length: " << key_len; - throw ParquetException(ss.str()); - } - - key_length_ = key_len; - - ctx_ = EVP_CIPHER_CTX_new(); - if (nullptr == ctx_) { - throw ParquetException("Couldn't init cipher context"); - } + bool write_length) + : AesEncryptionContext(alg_id, key_len, metadata, write_length) { + AesEncryptionContext::InitCipherContext(); +} +void AesEncryptor::AesEncryptorImpl::InitCipherContext(EVP_CIPHER_CTX* ctx) { if (kGcmMode == aes_mode_) { // Init AES-GCM with specified key length - if (16 == key_len) { - ENCRYPT_INIT(ctx_, EVP_aes_128_gcm()); - } else if (24 == key_len) { - ENCRYPT_INIT(ctx_, EVP_aes_192_gcm()); - } else if (32 == key_len) { - ENCRYPT_INIT(ctx_, EVP_aes_256_gcm()); + if (16 == key_length_) { + ENCRYPT_INIT(ctx, EVP_aes_128_gcm()); + } else if (24 == key_length_) { + ENCRYPT_INIT(ctx, EVP_aes_192_gcm()); + } else if (32 == key_length_) { + ENCRYPT_INIT(ctx, EVP_aes_256_gcm()); } } else { // Init AES-CTR with specified key length - if (16 == key_len) { - ENCRYPT_INIT(ctx_, EVP_aes_128_ctr()); - } else if (24 == key_len) { - ENCRYPT_INIT(ctx_, EVP_aes_192_ctr()); - } else if (32 == key_len) { - ENCRYPT_INIT(ctx_, EVP_aes_256_ctr()); + if (16 == key_length_) { + ENCRYPT_INIT(ctx, EVP_aes_128_ctr()); + } else if (24 == key_length_) { + ENCRYPT_INIT(ctx, EVP_aes_192_ctr()); + } else if (32 == key_length_) { + ENCRYPT_INIT(ctx, EVP_aes_256_ctr()); } } } @@ -162,8 +192,6 @@ AesEncryptor::AesEncryptorImpl::AesEncryptorImpl(ParquetCipher::type alg_id, int32_t AesEncryptor::AesEncryptorImpl::SignedFooterEncrypt( span footer, span key, span aad, span nonce, span encrypted_footer) { - CheckValid(); - 
if (static_cast(key_length_) != key.size()) { std::stringstream ss; ss << "Wrong key length " << key.size() << ". Should be " << key_length_; @@ -188,8 +216,6 @@ int32_t AesEncryptor::AesEncryptorImpl::Encrypt(span plaintext, span key, span aad, span ciphertext) { - CheckValid(); - if (static_cast(key_length_) != key.size()) { std::stringstream ss; ss << "Wrong key length " << key.size() << ". Should be " << key_length_; @@ -231,8 +257,10 @@ int32_t AesEncryptor::AesEncryptorImpl::GcmEncrypt(span plaintext throw ParquetException(ss.str()); } + auto ctx = GetCipherContext(); + // Setting key and IV (nonce) - if (1 != EVP_EncryptInit_ex(ctx_, nullptr, nullptr, key.data(), nonce.data())) { + if (1 != EVP_EncryptInit_ex(ctx.get(), nullptr, nullptr, key.data(), nonce.data())) { throw ParquetException("Couldn't set key and nonce"); } @@ -242,7 +270,7 @@ int32_t AesEncryptor::AesEncryptorImpl::GcmEncrypt(span plaintext ss << "AAD size " << aad.size() << " overflows int"; throw ParquetException(ss.str()); } - if ((!aad.empty()) && (1 != EVP_EncryptUpdate(ctx_, nullptr, &len, aad.data(), + if ((!aad.empty()) && (1 != EVP_EncryptUpdate(ctx.get(), nullptr, &len, aad.data(), static_cast(aad.size())))) { throw ParquetException("Couldn't set AAD"); } @@ -253,25 +281,26 @@ int32_t AesEncryptor::AesEncryptorImpl::GcmEncrypt(span plaintext ss << "Plaintext size " << plaintext.size() << " overflows int"; throw ParquetException(ss.str()); } - if (1 != - EVP_EncryptUpdate(ctx_, ciphertext.data() + length_buffer_length_ + kNonceLength, - &len, plaintext.data(), static_cast(plaintext.size()))) { + if (1 != EVP_EncryptUpdate( + ctx.get(), ciphertext.data() + length_buffer_length_ + kNonceLength, &len, + plaintext.data(), static_cast(plaintext.size()))) { throw ParquetException("Failed encryption update"); } ciphertext_len = len; // Finalization - if (1 != - EVP_EncryptFinal_ex( - ctx_, ciphertext.data() + length_buffer_length_ + kNonceLength + len, &len)) { + if (1 != 
EVP_EncryptFinal_ex( + ctx.get(), ciphertext.data() + length_buffer_length_ + kNonceLength + len, + &len)) { throw ParquetException("Failed encryption finalization"); } ciphertext_len += len; // Getting the tag - if (1 != EVP_CIPHER_CTX_ctrl(ctx_, EVP_CTRL_GCM_GET_TAG, kGcmTagLength, tag.data())) { + if (1 != + EVP_CIPHER_CTX_ctrl(ctx.get(), EVP_CTRL_GCM_GET_TAG, kGcmTagLength, tag.data())) { throw ParquetException("Couldn't get AES-GCM tag"); } @@ -312,8 +341,10 @@ int32_t AesEncryptor::AesEncryptorImpl::CtrEncrypt(span plaintext std::copy(nonce.begin(), nonce.begin() + kNonceLength, iv.begin()); iv[kCtrIvLength - 1] = 1; + auto ctx = GetCipherContext(); + // Setting key and IV - if (1 != EVP_EncryptInit_ex(ctx_, nullptr, nullptr, key.data(), iv.data())) { + if (1 != EVP_EncryptInit_ex(ctx.get(), nullptr, nullptr, key.data(), iv.data())) { throw ParquetException("Couldn't set key and IV"); } @@ -323,18 +354,18 @@ int32_t AesEncryptor::AesEncryptorImpl::CtrEncrypt(span plaintext ss << "Plaintext size " << plaintext.size() << " overflows int"; throw ParquetException(ss.str()); } - if (1 != - EVP_EncryptUpdate(ctx_, ciphertext.data() + length_buffer_length_ + kNonceLength, - &len, plaintext.data(), static_cast(plaintext.size()))) { + if (1 != EVP_EncryptUpdate( + ctx.get(), ciphertext.data() + length_buffer_length_ + kNonceLength, &len, + plaintext.data(), static_cast(plaintext.size()))) { throw ParquetException("Failed encryption update"); } ciphertext_len = len; // Finalization - if (1 != - EVP_EncryptFinal_ex( - ctx_, ciphertext.data() + length_buffer_length_ + kNonceLength + len, &len)) { + if (1 != EVP_EncryptFinal_ex( + ctx.get(), ciphertext.data() + length_buffer_length_ + kNonceLength + len, + &len)) { throw ParquetException("Failed encryption finalization"); } @@ -354,7 +385,7 @@ int32_t AesEncryptor::AesEncryptorImpl::CtrEncrypt(span plaintext return length_buffer_length_ + buffer_size; } -AesEncryptor::~AesEncryptor() {} +AesEncryptor::~AesEncryptor() = 
default; int32_t AesEncryptor::SignedFooterEncrypt(span footer, span key, @@ -364,8 +395,6 @@ int32_t AesEncryptor::SignedFooterEncrypt(span footer, return impl_->SignedFooterEncrypt(footer, key, aad, nonce, encrypted_footer); } -void AesEncryptor::WipeOut() { impl_->WipeOut(); } - int32_t AesEncryptor::CiphertextLength(int64_t plaintext_len) const { return impl_->CiphertextLength(plaintext_len); } @@ -380,16 +409,12 @@ AesEncryptor::AesEncryptor(ParquetCipher::type alg_id, int32_t key_len, bool met : impl_{std::unique_ptr( new AesEncryptorImpl(alg_id, key_len, metadata, write_length))} {} -class AesDecryptor::AesDecryptorImpl { +class AesDecryptor::AesDecryptorImpl : AesEncryptionContext { public: explicit AesDecryptorImpl(ParquetCipher::type alg_id, int32_t key_len, bool metadata, bool contains_length); - ~AesDecryptorImpl() = default; - - std::function ctxDeleter = [](EVP_CIPHER_CTX* ctx) { - EVP_CIPHER_CTX_free(ctx); - }; + ~AesDecryptorImpl() override = default; int32_t Decrypt(span ciphertext, span key, span aad, span plaintext); @@ -419,13 +444,10 @@ class AesDecryptor::AesDecryptorImpl { return plaintext_len + ciphertext_size_delta_; } - private: - std::unique_ptr ctx_; - int32_t aes_mode_; - int32_t key_length_; - int32_t ciphertext_size_delta_; - int32_t length_buffer_length_; + protected: + void InitCipherContext(EVP_CIPHER_CTX* ctx) override; + private: /// Get the actual ciphertext length, inclusive of the length buffer length, /// and validate that the provided buffer size is large enough. 
[[nodiscard]] int32_t GetCiphertextLength(span ciphertext) const; @@ -435,17 +457,6 @@ class AesDecryptor::AesDecryptorImpl { int32_t CtrDecrypt(span ciphertext, span key, span plaintext); - - /// Create a new cipher context, duplicates unused ctx_ to avoid going through initialization - std::unique_ptr GetCipherContext() { - // could use EVP_CIPHER_CTX_dup instead (requires OpenSSL 3.2.0 and above) - auto ctx = std::unique_ptr(EVP_CIPHER_CTX_new(), ctxDeleter); - if (ctx && !EVP_CIPHER_CTX_copy(ctx.get(), ctx_.get())) { - // ctx gets freed when leaving this method - throw ParquetException("Couldn't init cipher context"); - } - return ctx; - } }; int32_t AesDecryptor::Decrypt(span ciphertext, span key, @@ -457,49 +468,29 @@ AesDecryptor::~AesDecryptor() {} AesDecryptor::AesDecryptorImpl::AesDecryptorImpl(ParquetCipher::type alg_id, int32_t key_len, bool metadata, - bool contains_length) { - openssl::EnsureInitialized(); - - length_buffer_length_ = contains_length ? kBufferSizeLength : 0; - ciphertext_size_delta_ = length_buffer_length_ + kNonceLength; - if (metadata || (ParquetCipher::AES_GCM_V1 == alg_id)) { - aes_mode_ = kGcmMode; - ciphertext_size_delta_ += kGcmTagLength; - } else { - aes_mode_ = kCtrMode; - } - - if (16 != key_len && 24 != key_len && 32 != key_len) { - std::stringstream ss; - ss << "Wrong key length: " << key_len; - throw ParquetException(ss.str()); - } - - key_length_ = key_len; - - // create a EVP_CIPHER_CTX that gets auto-freed - ctx_ = std::unique_ptr(EVP_CIPHER_CTX_new(), ctxDeleter); - if (!ctx_) { - throw ParquetException("Couldn't init cipher context"); - } + bool contains_length) + : AesEncryptionContext(alg_id, key_len, metadata, contains_length) { + AesEncryptionContext::InitCipherContext(); +} +void AesDecryptor::AesDecryptorImpl::InitCipherContext(EVP_CIPHER_CTX* ctx) { if (kGcmMode == aes_mode_) { // Init AES-GCM with specified key length - if (16 == key_len) { - DECRYPT_INIT(ctx_.get(), EVP_aes_128_gcm()); - } else if (24 == 
key_len) { - DECRYPT_INIT(ctx_.get(), EVP_aes_192_gcm()); - } else if (32 == key_len) { - DECRYPT_INIT(ctx_.get(), EVP_aes_256_gcm()); + if (16 == key_length_) { + DECRYPT_INIT(ctx, EVP_aes_128_gcm()); + } else if (24 == key_length_) { + DECRYPT_INIT(ctx, EVP_aes_192_gcm()); + } else if (32 == key_length_) { + DECRYPT_INIT(ctx, EVP_aes_256_gcm()); } } else { // Init AES-CTR with specified key length - if (16 == key_len) { - DECRYPT_INIT(ctx_.get(), EVP_aes_128_ctr()); - } else if (24 == key_len) { - DECRYPT_INIT(ctx_.get(), EVP_aes_192_ctr()); - } else if (32 == key_len) { - DECRYPT_INIT(ctx_.get(), EVP_aes_256_ctr()); + if (16 == key_length_) { + DECRYPT_INIT(ctx, EVP_aes_128_ctr()); + } else if (24 == key_length_) { + DECRYPT_INIT(ctx, EVP_aes_192_ctr()); + } else if (32 == key_length_) { + DECRYPT_INIT(ctx, EVP_aes_256_ctr()); } } } diff --git a/cpp/src/parquet/encryption/encryption_internal.h b/cpp/src/parquet/encryption/encryption_internal.h index 62dadb058ccea..e3c09493f444f 100644 --- a/cpp/src/parquet/encryption/encryption_internal.h +++ b/cpp/src/parquet/encryption/encryption_internal.h @@ -44,6 +44,8 @@ constexpr int8_t kOffsetIndex = 7; constexpr int8_t kBloomFilterHeader = 8; constexpr int8_t kBloomFilterBitset = 9; +class AesEncryptionContext; + /// Performs AES encryption operations with GCM or CTR ciphers. 
class PARQUET_EXPORT AesEncryptor { public: @@ -77,8 +79,6 @@ class PARQUET_EXPORT AesEncryptor { ::arrow::util::span nonce, ::arrow::util::span encrypted_footer); - void WipeOut(); - private: // PIMPL Idiom class AesEncryptorImpl; diff --git a/cpp/src/parquet/encryption/internal_file_encryptor.cc b/cpp/src/parquet/encryption/internal_file_encryptor.cc index 94094e6aca228..80ec6412fd6e5 100644 --- a/cpp/src/parquet/encryption/internal_file_encryptor.cc +++ b/cpp/src/parquet/encryption/internal_file_encryptor.cc @@ -50,21 +50,6 @@ InternalFileEncryptor::InternalFileEncryptor(FileEncryptionProperties* propertie properties_->set_utilized(); } -void InternalFileEncryptor::WipeOutEncryptionKeys() { - properties_->WipeOutEncryptionKeys(); - - for (auto const& i : meta_encryptor_) { - if (i != nullptr) { - i->WipeOut(); - } - } - for (auto const& i : data_encryptor_) { - if (i != nullptr) { - i->WipeOut(); - } - } -} - std::shared_ptr InternalFileEncryptor::GetFooterEncryptor() { if (footer_encryptor_ != nullptr) { return footer_encryptor_; diff --git a/cpp/src/parquet/encryption/internal_file_encryptor.h b/cpp/src/parquet/encryption/internal_file_encryptor.h index 5a3d743ce5365..a7108ab66f610 100644 --- a/cpp/src/parquet/encryption/internal_file_encryptor.h +++ b/cpp/src/parquet/encryption/internal_file_encryptor.h @@ -77,7 +77,6 @@ class InternalFileEncryptor { std::shared_ptr GetFooterSigningEncryptor(); std::shared_ptr GetColumnMetaEncryptor(const std::string& column_path); std::shared_ptr GetColumnDataEncryptor(const std::string& column_path); - void WipeOutEncryptionKeys(); private: FileEncryptionProperties* properties_; diff --git a/cpp/src/parquet/file_writer.cc b/cpp/src/parquet/file_writer.cc index baa9e00da2351..1d65868025956 100644 --- a/cpp/src/parquet/file_writer.cc +++ b/cpp/src/parquet/file_writer.cc @@ -435,9 +435,6 @@ class FileSerializer : public ParquetFileWriter::Contents { WriteEncryptedFileMetadata(*file_metadata_, sink_.get(), 
footer_signing_encryptor, false); } - if (file_encryptor_) { - file_encryptor_->WipeOutEncryptionKeys(); - } } void WritePageIndex() { diff --git a/cpp/src/parquet/metadata.cc b/cpp/src/parquet/metadata.cc index 8f577be45b96d..23fb15fcc830d 100644 --- a/cpp/src/parquet/metadata.cc +++ b/cpp/src/parquet/metadata.cc @@ -736,8 +736,6 @@ class FileMetaData::FileMetaDataImpl { int32_t encrypted_len = aes_encryptor->SignedFooterEncrypt( serialized_data_span, str2span(key), str2span(aad), nonce, encrypted_buffer->mutable_span_as()); - // Delete AES encryptor object. It was created only to verify the footer signature. - aes_encryptor->WipeOut(); return 0 == memcmp(encrypted_buffer->data() + encrypted_len - encryption::kGcmTagLength, tag, encryption::kGcmTagLength);