Skip to content

Commit

Permalink
Use unique_ptr for non-shareable objects + address review comments
Browse files: browse the repository at this point in the history
  • Loading branch information
pitrou authored and EnricoMi committed Mar 10, 2025
1 parent a53a046 commit f164eee
Show file tree
Hide file tree
Showing 9 changed files with 103 additions and 121 deletions.
46 changes: 25 additions & 21 deletions cpp/src/parquet/column_reader.cc
Original file line number Diff line number Diff line change
Expand Up @@ -273,8 +273,7 @@ class SerializedPageReader : public PageReader {
void set_max_page_header_size(uint32_t size) override { max_page_header_size_ = size; }

private:
void UpdateDecryption(const std::shared_ptr<Decryptor>& decryptor, int8_t module_type,
std::string* page_aad);
void UpdateDecryption(Decryptor* decryptor, int8_t module_type, std::string* page_aad);

void InitDecryption();

Expand Down Expand Up @@ -309,8 +308,8 @@ class SerializedPageReader : public PageReader {
// The CryptoContext used by this PageReader.
CryptoContext crypto_ctx_;
// This PageReader has its own Decryptor instances in order to be thread-safe.
std::shared_ptr<Decryptor> meta_decryptor_;
std::shared_ptr<Decryptor> data_decryptor_;
std::unique_ptr<Decryptor> meta_decryptor_;
std::unique_ptr<Decryptor> data_decryptor_;

// The ordinal fields in the context below are used for AAD suffix calculation.
int32_t page_ordinal_; // page ordinal does not count the dictionary page
Expand All @@ -336,24 +335,28 @@ class SerializedPageReader : public PageReader {

void SerializedPageReader::InitDecryption() {
// Prepare the AAD for quick update later.
if (crypto_ctx_.data_decryptor) {
data_decryptor_ = crypto_ctx_.data_decryptor();
ARROW_DCHECK(!data_decryptor_->file_aad().empty());
data_page_aad_ = encryption::CreateModuleAad(
data_decryptor_->file_aad(), encryption::kDataPage, crypto_ctx_.row_group_ordinal,
crypto_ctx_.column_ordinal, kNonPageOrdinal);
}
if (crypto_ctx_.meta_decryptor) {
meta_decryptor_ = crypto_ctx_.meta_decryptor();
ARROW_DCHECK(!meta_decryptor_->file_aad().empty());
data_page_header_aad_ = encryption::CreateModuleAad(
meta_decryptor_->file_aad(), encryption::kDataPageHeader,
crypto_ctx_.row_group_ordinal, crypto_ctx_.column_ordinal, kNonPageOrdinal);
if (crypto_ctx_.data_decryptor_factory) {
data_decryptor_ = crypto_ctx_.data_decryptor_factory();
if (data_decryptor_) {
ARROW_DCHECK(!data_decryptor_->file_aad().empty());
data_page_aad_ = encryption::CreateModuleAad(
data_decryptor_->file_aad(), encryption::kDataPage,
crypto_ctx_.row_group_ordinal, crypto_ctx_.column_ordinal, kNonPageOrdinal);
}
}
if (crypto_ctx_.meta_decryptor_factory) {
meta_decryptor_ = crypto_ctx_.meta_decryptor_factory();
if (meta_decryptor_) {
ARROW_DCHECK(!meta_decryptor_->file_aad().empty());
data_page_header_aad_ = encryption::CreateModuleAad(
meta_decryptor_->file_aad(), encryption::kDataPageHeader,
crypto_ctx_.row_group_ordinal, crypto_ctx_.column_ordinal, kNonPageOrdinal);
}
}
}

void SerializedPageReader::UpdateDecryption(const std::shared_ptr<Decryptor>& decryptor,
int8_t module_type, std::string* page_aad) {
void SerializedPageReader::UpdateDecryption(Decryptor* decryptor, int8_t module_type,
std::string* page_aad) {
ARROW_DCHECK(decryptor != nullptr);
if (crypto_ctx_.start_decrypt_with_dictionary_page) {
UpdateDecryptor(decryptor, crypto_ctx_.row_group_ordinal, crypto_ctx_.column_ordinal,
Expand Down Expand Up @@ -433,7 +436,7 @@ std::shared_ptr<Page> SerializedPageReader::NextPage() {
header_size = static_cast<uint32_t>(view.size());
try {
if (meta_decryptor_ != nullptr) {
UpdateDecryption(meta_decryptor_, encryption::kDictionaryPageHeader,
UpdateDecryption(meta_decryptor_.get(), encryption::kDictionaryPageHeader,
&data_page_header_aad_);
}
// Reset current page header so that no __isset flag is left uncleared.
Expand Down Expand Up @@ -469,7 +472,8 @@ std::shared_ptr<Page> SerializedPageReader::NextPage() {
}

if (data_decryptor_ != nullptr) {
UpdateDecryption(data_decryptor_, encryption::kDictionaryPage, &data_page_aad_);
UpdateDecryption(data_decryptor_.get(), encryption::kDictionaryPage,
&data_page_aad_);
}

// Read the compressed data page.
Expand Down
14 changes: 2 additions & 12 deletions cpp/src/parquet/column_reader.h
Original file line number Diff line number Diff line change
Expand Up @@ -102,21 +102,11 @@ class PARQUET_EXPORT LevelDecoder {
};

struct CryptoContext {
CryptoContext(bool start_with_dictionary_page, int16_t rg_ordinal, int16_t col_ordinal,
std::function<std::shared_ptr<Decryptor>()> meta,
std::function<std::shared_ptr<Decryptor>()> data)
: start_decrypt_with_dictionary_page(start_with_dictionary_page),
row_group_ordinal(rg_ordinal),
column_ordinal(col_ordinal),
meta_decryptor(std::move(meta)),
data_decryptor(std::move(data)) {}
CryptoContext() {}

bool start_decrypt_with_dictionary_page = false;
int16_t row_group_ordinal = -1;
int16_t column_ordinal = -1;
std::function<std::shared_ptr<Decryptor>()> meta_decryptor;
std::function<std::shared_ptr<Decryptor>()> data_decryptor;
std::function<std::unique_ptr<Decryptor>()> meta_decryptor_factory;
std::function<std::unique_ptr<Decryptor>()> data_decryptor_factory;
};

// Abstract page iterator interface. This way, we can feed column pages to the
Expand Down
51 changes: 18 additions & 33 deletions cpp/src/parquet/encryption/encryption_internal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -55,24 +55,30 @@ constexpr int32_t kBufferSizeLength = 4;
class AesCryptoContext {
public:
AesCryptoContext(ParquetCipher::type alg_id, int32_t key_len, bool metadata,
bool write_length) {
bool include_length) {
openssl::EnsureInitialized();

length_buffer_length_ = write_length ? kBufferSizeLength : 0;
length_buffer_length_ = include_length ? kBufferSizeLength : 0;
ciphertext_size_delta_ = length_buffer_length_ + kNonceLength;
if (metadata || (ParquetCipher::AES_GCM_V1 == alg_id)) {
aes_mode_ = kGcmMode;
ciphertext_size_delta_ += kGcmTagLength;
} else {
aes_mode_ = kCtrMode;
}

if (ParquetCipher::AES_GCM_V1 != alg_id && ParquetCipher::AES_GCM_CTR_V1 != alg_id) {
std::stringstream ss;
ss << "Crypto algorithm " << alg_id << " is not supported";
throw ParquetException(ss.str());
}
if (16 != key_len && 24 != key_len && 32 != key_len) {
std::stringstream ss;
ss << "Wrong key length: " << key_len;
throw ParquetException(ss.str());
}

if (metadata || (ParquetCipher::AES_GCM_V1 == alg_id)) {
aes_mode_ = kGcmMode;
ciphertext_size_delta_ += kGcmTagLength;
} else {
aes_mode_ = kCtrMode;
}

key_length_ = key_len;
}

Expand Down Expand Up @@ -102,8 +108,6 @@ class AesEncryptor::AesEncryptorImpl : public AesCryptoContext {
explicit AesEncryptorImpl(ParquetCipher::type alg_id, int32_t key_len, bool metadata,
bool write_length);

~AesEncryptorImpl() override = default;

int32_t Encrypt(span<const uint8_t> plaintext, span<const uint8_t> key,
span<const uint8_t> aad, span<uint8_t> ciphertext);

Expand Down Expand Up @@ -393,8 +397,6 @@ class AesDecryptor::AesDecryptorImpl : AesCryptoContext {
explicit AesDecryptorImpl(ParquetCipher::type alg_id, int32_t key_len, bool metadata,
bool contains_length);

~AesDecryptorImpl() override = default;

int32_t Decrypt(span<const uint8_t> ciphertext, span<const uint8_t> key,
span<const uint8_t> aad, span<uint8_t> plaintext);

Expand Down Expand Up @@ -474,37 +476,20 @@ AesCryptoContext::CipherContext AesDecryptor::AesDecryptorImpl::MakeCipherContex
return ctx;
}

std::unique_ptr<AesEncryptor> AesEncryptor::Make(ParquetCipher::type alg_id,
int32_t key_len, bool metadata) {
return Make(alg_id, key_len, metadata, true /*write_length*/);
}

std::unique_ptr<AesEncryptor> AesEncryptor::Make(ParquetCipher::type alg_id,
int32_t key_len, bool metadata,
bool write_length) {
if (ParquetCipher::AES_GCM_V1 != alg_id && ParquetCipher::AES_GCM_CTR_V1 != alg_id) {
std::stringstream ss;
ss << "Crypto algorithm " << alg_id << " is not supported";
throw ParquetException(ss.str());
}

return std::make_unique<AesEncryptor>(alg_id, key_len, metadata, write_length);
}

AesDecryptor::AesDecryptor(ParquetCipher::type alg_id, int32_t key_len, bool metadata,
bool contains_length)
: impl_{std::unique_ptr<AesDecryptorImpl>(
new AesDecryptorImpl(alg_id, key_len, metadata, contains_length))} {}
: impl_{std::make_unique<AesDecryptorImpl>(alg_id, key_len, metadata,
contains_length)} {}

std::shared_ptr<AesDecryptor> AesDecryptor::Make(ParquetCipher::type alg_id,
std::unique_ptr<AesDecryptor> AesDecryptor::Make(ParquetCipher::type alg_id,
int32_t key_len, bool metadata) {
if (ParquetCipher::AES_GCM_V1 != alg_id && ParquetCipher::AES_GCM_CTR_V1 != alg_id) {
std::stringstream ss;
ss << "Crypto algorithm " << alg_id << " is not supported";
throw ParquetException(ss.str());
}

return std::make_shared<AesDecryptor>(alg_id, key_len, metadata);
return std::make_unique<AesDecryptor>(alg_id, key_len, metadata);
}

int32_t AesDecryptor::PlaintextLength(int32_t ciphertext_len) const {
Expand Down
20 changes: 7 additions & 13 deletions cpp/src/parquet/encryption/encryption_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,7 @@ class PARQUET_EXPORT AesEncryptor {
bool write_length = true);

static std::unique_ptr<AesEncryptor> Make(ParquetCipher::type alg_id, int32_t key_len,
bool metadata);

static std::unique_ptr<AesEncryptor> Make(ParquetCipher::type alg_id, int32_t key_len,
bool metadata, bool write_length);
bool metadata, bool write_length = true);

~AesEncryptor();

Expand Down Expand Up @@ -86,19 +83,16 @@ class PARQUET_EXPORT AesEncryptor {
/// Performs AES decryption operations with GCM or CTR ciphers.
class PARQUET_EXPORT AesDecryptor {
public:
/// Can serve one key length only. Possible values: 16, 24, 32 bytes.
/// If contains_length is true, expect ciphertext length prepended to the ciphertext
explicit AesDecryptor(ParquetCipher::type alg_id, int32_t key_len, bool metadata,
bool contains_length = true);

/// \brief Factory function to create an AesDecryptor
/// \brief Construct an AesDecryptor
///
/// \param alg_id the encryption algorithm to use
/// \param key_len key length. Possible values: 16, 24, 32 bytes.
/// \param metadata if true then this is a metadata decryptor
/// out when decryption is finished
/// \return shared pointer to a new AesDecryptor
static std::shared_ptr<AesDecryptor> Make(ParquetCipher::type alg_id, int32_t key_len,
/// \param contains_length if true, expect ciphertext length prepended to the ciphertext
explicit AesDecryptor(ParquetCipher::type alg_id, int32_t key_len, bool metadata,
bool contains_length = true);

static std::unique_ptr<AesDecryptor> Make(ParquetCipher::type alg_id, int32_t key_len,
bool metadata);

~AesDecryptor();
Expand Down
8 changes: 1 addition & 7 deletions cpp/src/parquet/encryption/encryption_internal_nossl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -68,12 +68,6 @@ int32_t AesDecryptor::Decrypt(::arrow::util::span<const uint8_t> ciphertext,

AesDecryptor::~AesDecryptor() {}

std::unique_ptr<AesEncryptor> AesEncryptor::Make(ParquetCipher::type alg_id,
int32_t key_len, bool metadata) {
ThrowOpenSSLRequiredException();
return NULLPTR;
}

std::unique_ptr<AesEncryptor> AesEncryptor::Make(ParquetCipher::type alg_id,
int32_t key_len, bool metadata,
bool write_length) {
Expand All @@ -86,7 +80,7 @@ AesDecryptor::AesDecryptor(ParquetCipher::type alg_id, int32_t key_len, bool met
ThrowOpenSSLRequiredException();
}

std::shared_ptr<AesDecryptor> AesDecryptor::Make(ParquetCipher::type alg_id,
std::unique_ptr<AesDecryptor> AesDecryptor::Make(ParquetCipher::type alg_id,
int32_t key_len, bool metadata) {
ThrowOpenSSLRequiredException();
return NULLPTR;
Expand Down
30 changes: 15 additions & 15 deletions cpp/src/parquet/encryption/internal_file_decryptor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
namespace parquet {

// Decryptor
Decryptor::Decryptor(std::shared_ptr<encryption::AesDecryptor> aes_decryptor,
Decryptor::Decryptor(std::unique_ptr<encryption::AesDecryptor> aes_decryptor,
const std::string& key, const std::string& file_aad,
const std::string& aad, ::arrow::MemoryPool* pool)
: aes_decryptor_(std::move(aes_decryptor)),
Expand All @@ -33,6 +33,8 @@ Decryptor::Decryptor(std::shared_ptr<encryption::AesDecryptor> aes_decryptor,
aad_(aad),
pool_(pool) {}

Decryptor::~Decryptor() = default;

int32_t Decryptor::PlaintextLength(int32_t ciphertext_len) const {
return aes_decryptor_->PlaintextLength(ciphertext_len);
}
Expand Down Expand Up @@ -104,19 +106,18 @@ std::string InternalFileDecryptor::GetFooterKey() {
return footer_key;
}

std::shared_ptr<Decryptor> InternalFileDecryptor::GetFooterDecryptor() {
std::unique_ptr<Decryptor> InternalFileDecryptor::GetFooterDecryptor() {
std::string aad = encryption::CreateFooterAad(file_aad_);
return GetFooterDecryptor(aad, true);
}

std::shared_ptr<Decryptor> InternalFileDecryptor::GetFooterDecryptor(
std::unique_ptr<Decryptor> InternalFileDecryptor::GetFooterDecryptor(
const std::string& aad, bool metadata) {
std::string footer_key = GetFooterKey();

auto key_len = static_cast<int32_t>(footer_key.size());
std::shared_ptr<encryption::AesDecryptor> aes_decryptor =
encryption::AesDecryptor::Make(algorithm_, key_len, /*metadata=*/metadata);
return std::make_shared<Decryptor>(std::move(aes_decryptor), footer_key, file_aad_, aad,
auto aes_decryptor = encryption::AesDecryptor::Make(algorithm_, key_len, metadata);
return std::make_unique<Decryptor>(std::move(aes_decryptor), footer_key, file_aad_, aad,
pool_);
}

Expand All @@ -141,17 +142,17 @@ std::string InternalFileDecryptor::GetColumnKey(const std::string& column_path,
return column_key;
}

std::shared_ptr<Decryptor> InternalFileDecryptor::GetColumnDecryptor(
std::unique_ptr<Decryptor> InternalFileDecryptor::GetColumnDecryptor(
const std::string& column_path, const std::string& column_key_metadata,
const std::string& aad, bool metadata) {
std::string column_key = GetColumnKey(column_path, column_key_metadata);
auto key_len = static_cast<int32_t>(column_key.size());
auto aes_decryptor = encryption::AesDecryptor::Make(algorithm_, key_len, metadata);
return std::make_shared<Decryptor>(std::move(aes_decryptor), column_key, file_aad_, aad,
return std::make_unique<Decryptor>(std::move(aes_decryptor), column_key, file_aad_, aad,
pool_);
}

std::function<std::shared_ptr<Decryptor>()>
std::function<std::unique_ptr<Decryptor>()>
InternalFileDecryptor::GetColumnDecryptorFactory(
const ColumnCryptoMetaData* crypto_metadata, const std::string& aad, bool metadata) {
if (crypto_metadata->encrypted_with_footer_key()) {
Expand All @@ -166,12 +167,12 @@ InternalFileDecryptor::GetColumnDecryptorFactory(
return [this, aad, metadata, column_key = std::move(column_key)]() {
auto key_len = static_cast<int32_t>(column_key.size());
auto aes_decryptor = encryption::AesDecryptor::Make(algorithm_, key_len, metadata);
return std::make_shared<Decryptor>(std::move(aes_decryptor), column_key, file_aad_,
return std::make_unique<Decryptor>(std::move(aes_decryptor), column_key, file_aad_,
aad, pool_);
};
}

std::function<std::shared_ptr<Decryptor>()>
std::function<std::unique_ptr<Decryptor>()>
InternalFileDecryptor::GetColumnMetaDecryptorFactory(
InternalFileDecryptor* file_descryptor, const ColumnCryptoMetaData* crypto_metadata,
const std::string& aad) {
Expand All @@ -186,7 +187,7 @@ InternalFileDecryptor::GetColumnMetaDecryptorFactory(
/*metadata=*/true);
}

std::function<std::shared_ptr<Decryptor>()>
std::function<std::unique_ptr<Decryptor>()>
InternalFileDecryptor::GetColumnDataDecryptorFactory(
InternalFileDecryptor* file_descryptor, const ColumnCryptoMetaData* crypto_metadata,
const std::string& aad) {
Expand All @@ -201,9 +202,8 @@ InternalFileDecryptor::GetColumnDataDecryptorFactory(
/*metadata=*/false);
}

void UpdateDecryptor(const std::shared_ptr<Decryptor>& decryptor,
int16_t row_group_ordinal, int16_t column_ordinal,
int8_t module_type) {
void UpdateDecryptor(Decryptor* decryptor, int16_t row_group_ordinal,
int16_t column_ordinal, int8_t module_type) {
ARROW_DCHECK(!decryptor->file_aad().empty());
const std::string aad =
encryption::CreateModuleAad(decryptor->file_aad(), module_type, row_group_ordinal,
Expand Down
Loading

0 comments on commit f164eee

Please sign in to comment.