diff --git a/cpp/src/parquet/encryption/encryption_internal.cc b/cpp/src/parquet/encryption/encryption_internal.cc index 55ec40682bf60..646ed7f1ada84 100644 --- a/cpp/src/parquet/encryption/encryption_internal.cc +++ b/cpp/src/parquet/encryption/encryption_internal.cc @@ -79,21 +79,10 @@ class AesEncryptionContext { virtual ~AesEncryptionContext() = default; protected: - virtual void InitCipherContext(EVP_CIPHER_CTX* ctx) = 0; - static inline std::function ctx_deleter_ = [](EVP_CIPHER_CTX* ctx) { EVP_CIPHER_CTX_free(ctx); }; - /// Create a new cipher context that auto-frees - std::unique_ptr NewCipherContext() { - auto ctx = std::unique_ptr(EVP_CIPHER_CTX_new(), - ctx_deleter_); - if (!ctx) throw ParquetException("Couldn't init cipher context"); - InitCipherContext(ctx.get()); - return ctx; - } - int32_t aes_mode_; int32_t key_length_; int32_t ciphertext_size_delta_; @@ -130,10 +119,9 @@ class AesEncryptor::AesEncryptorImpl : public AesEncryptionContext { return static_cast(plaintext_len + ciphertext_size_delta_); } - protected: - void InitCipherContext(EVP_CIPHER_CTX* ctx) override; - private: + [[nodiscard]] std::unique_ptr NewCipherContext() const; + int32_t GcmEncrypt(span plaintext, span key, span nonce, span aad, span ciphertext); @@ -147,26 +135,30 @@ AesEncryptor::AesEncryptorImpl::AesEncryptorImpl(ParquetCipher::type alg_id, bool write_length) : AesEncryptionContext(alg_id, key_len, metadata, write_length) { } -void AesEncryptor::AesEncryptorImpl::InitCipherContext(EVP_CIPHER_CTX* ctx) { +std::unique_ptr AesEncryptor::AesEncryptorImpl::NewCipherContext() const { + auto ctx = std::unique_ptr(EVP_CIPHER_CTX_new(), + ctx_deleter_); + if (!ctx) throw ParquetException("Couldn't init cipher context"); if (kGcmMode == aes_mode_) { // Init AES-GCM with specified key length if (16 == key_length_) { - ENCRYPT_INIT(ctx, EVP_aes_128_gcm()); + ENCRYPT_INIT(ctx.get(), EVP_aes_128_gcm()); } else if (24 == key_length_) { - ENCRYPT_INIT(ctx, EVP_aes_192_gcm()); + ENCRYPT_INIT(ctx.get(), EVP_aes_192_gcm()); } else if (32 == key_length_) { - ENCRYPT_INIT(ctx, EVP_aes_256_gcm()); + ENCRYPT_INIT(ctx.get(), EVP_aes_256_gcm()); } } else { // Init AES-CTR with specified key length if (16 == key_length_) { - ENCRYPT_INIT(ctx, EVP_aes_128_ctr()); + ENCRYPT_INIT(ctx.get(), EVP_aes_128_ctr()); } else if (24 == key_length_) { - ENCRYPT_INIT(ctx, EVP_aes_192_ctr()); + ENCRYPT_INIT(ctx.get(), EVP_aes_192_ctr()); } else if (32 == key_length_) { - ENCRYPT_INIT(ctx, EVP_aes_256_ctr()); + ENCRYPT_INIT(ctx.get(), EVP_aes_256_ctr()); } } + return ctx; } int32_t AesEncryptor::AesEncryptorImpl::SignedFooterEncrypt( @@ -424,10 +416,9 @@ class AesDecryptor::AesDecryptorImpl : AesEncryptionContext { return plaintext_len + ciphertext_size_delta_; } - protected: - void InitCipherContext(EVP_CIPHER_CTX* ctx) override; - private: + [[nodiscard]] std::unique_ptr NewCipherContext() const; + /// Get the actual ciphertext length, inclusive of the length buffer length, /// and validate that the provided buffer size is large enough. [[nodiscard]] int32_t GetCiphertextLength(span ciphertext) const; @@ -451,26 +442,30 @@ AesDecryptor::AesDecryptorImpl::AesDecryptorImpl(ParquetCipher::type alg_id, bool contains_length) : AesEncryptionContext(alg_id, key_len, metadata, contains_length) { } -void AesDecryptor::AesDecryptorImpl::InitCipherContext(EVP_CIPHER_CTX* ctx) { +std::unique_ptr AesDecryptor::AesDecryptorImpl::NewCipherContext() const { + auto ctx = std::unique_ptr(EVP_CIPHER_CTX_new(), + ctx_deleter_); + if (!ctx) throw ParquetException("Couldn't init cipher context"); if (kGcmMode == aes_mode_) { // Init AES-GCM with specified key length if (16 == key_length_) { - DECRYPT_INIT(ctx, EVP_aes_128_gcm()); + DECRYPT_INIT(ctx.get(), EVP_aes_128_gcm()); } else if (24 == key_length_) { - DECRYPT_INIT(ctx, EVP_aes_192_gcm()); + DECRYPT_INIT(ctx.get(), EVP_aes_192_gcm()); } else if (32 == key_length_) { - DECRYPT_INIT(ctx, EVP_aes_256_gcm()); + DECRYPT_INIT(ctx.get(), EVP_aes_256_gcm()); } } else { // Init AES-CTR with specified key length if (16 == key_length_) { - DECRYPT_INIT(ctx, EVP_aes_128_ctr()); + DECRYPT_INIT(ctx.get(), EVP_aes_128_ctr()); } else if (24 == key_length_) { - DECRYPT_INIT(ctx, EVP_aes_192_ctr()); + DECRYPT_INIT(ctx.get(), EVP_aes_192_ctr()); } else if (32 == key_length_) { - DECRYPT_INIT(ctx, EVP_aes_256_ctr()); + DECRYPT_INIT(ctx.get(), EVP_aes_256_ctr()); } } + return ctx; } std::unique_ptr AesEncryptor::Make(ParquetCipher::type alg_id,