Simplify elaborate Python tests
EnricoMi committed Feb 2, 2025
1 parent edc4011 commit 473273b
Showing 1 changed file with 113 additions and 101 deletions.
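
The key simplification in the diff below is a small cond_raises context manager: it either expects the guarded block to raise (via pytest.raises) or expects it to succeed, so the same assertion line can cover both outcomes. A minimal runnable sketch of that pattern, assuming only pytest (the Arrow tables and encryption fixtures of the real test file are omitted, and test_cond_raises_sketch is an illustrative name, not part of the commit):

from contextlib import contextmanager

import pytest


@contextmanager
def cond_raises(success, error_type, match):
    # Expect no exception when success is True; otherwise require error_type.
    if success:
        yield
    else:
        with pytest.raises(error_type, match=match):
            yield


@pytest.mark.parametrize("success", [True, False])
def test_cond_raises_sketch(success):
    # The guarded block is written once; cond_raises decides whether the
    # ValueError is the expected outcome or would fail the test.
    with cond_raises(success, ValueError, match="Unknown master key"):
        if not success:
            raise ValueError("Unknown master key")
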
214 changes: 113 additions & 101 deletions python/pyarrow/tests/test_dataset_encryption.py
@@ -16,6 +16,7 @@
# under the License.

import base64
from contextlib import contextmanager
from datetime import timedelta
import random
import pyarrow.fs as fs
@@ -53,6 +54,7 @@
COLUMNS = ["year", "n_legs", "animal"]
COLUMN_KEYS = {COL_KEY_NAME: ["n_legs", "animal"]}


def create_sample_table():
return pa.table(
{
@@ -99,8 +101,17 @@ def kms_factory(kms_connection_configuration):
return InMemoryKmsClient(kms_connection_configuration)


@contextmanager
def cond_raises(success, error_type, match):
if success:
yield
else:
with pytest.raises(error_type, match=match):
yield


def do_test_dataset_encryption_decryption(table, extra_column_path=None):
# use extra column key for column extra_column_name if given
# use extra column key for column extra_column_path, if given
if extra_column_path:
keys = dict(**KEYS, **{EXTRA_COL_KEY_NAME: EXTRA_COL_KEY})
column_keys = dict(**COLUMN_KEYS, **{EXTRA_COL_KEY_NAME: [extra_column_path]})
@@ -118,120 +129,119 @@ def do_test_dataset_encryption_decryption(table, extra_column_path=None):
plaintext_column_names = [column_name
for column_name in all_column_names
if column_name not in encrypted_column_names and
(extra_column_path is None or not extra_column_path.startswith(f"{column_name}."))]
(extra_column_path is None or
not extra_column_path.startswith(f"{column_name}."))]
assert len(encrypted_column_names) > 0
assert len(plaintext_column_names) > 0
footer_key_only = {FOOTER_KEY_NAME: FOOTER_KEY}
column_keys_only = {key_name: key for key_name, key in keys.items() if key_name != FOOTER_KEY_NAME}
column_keys_only = {key_name: key
for key_name, key in keys.items()
if key_name != FOOTER_KEY_NAME}

# define the inner test
def assert_decrypts(
read_keys,
read_columns,
to_table_success,
dataset_success=True,
):
# use all keys for writing
write_keys = keys
encryption_config = create_encryption_config(FOOTER_KEY_NAME, column_keys)
decryption_config = create_decryption_config()
encrypt_kms_connection_config = create_kms_connection_config(write_keys)
decrypt_kms_connection_config = create_kms_connection_config(read_keys)

crypto_factory = pe.CryptoFactory(kms_factory)
parquet_encryption_cfg = ds.ParquetEncryptionConfig(
crypto_factory, encrypt_kms_connection_config, encryption_config
)
parquet_decryption_cfg = ds.ParquetDecryptionConfig(
crypto_factory, decrypt_kms_connection_config, decryption_config
)

# create write_options with dataset encryption config
pformat = pa.dataset.ParquetFileFormat()
write_options = pformat.make_write_options(
encryption_config=parquet_encryption_cfg
)

mockfs = fs._MockFileSystem()
mockfs.create_dir("/")

ds.write_dataset(
data=table,
base_dir="sample_dataset",
format=pformat,
file_options=write_options,
filesystem=mockfs,
)

# read without decryption config -> errors if dataset was properly encrypted
pformat = pa.dataset.ParquetFileFormat()
with pytest.raises(IOError, match=r"no decryption"):
ds.dataset("sample_dataset", format=pformat, filesystem=mockfs)

# set decryption config for parquet fragment scan options
pq_scan_opts = ds.ParquetFragmentScanOptions(
decryption_config=parquet_decryption_cfg
)
pformat = pa.dataset.ParquetFileFormat(
default_fragment_scan_options=pq_scan_opts
)
with cond_raises(dataset_success, ValueError, match="Unknown master key"):
dataset = ds.dataset("sample_dataset", format=pformat, filesystem=mockfs)
with cond_raises(to_table_success, ValueError, match="Unknown master key"):
assert table.select(read_columns).equals(dataset.to_table(read_columns))

# set decryption properties for parquet fragment scan options
decryption_properties = crypto_factory.file_decryption_properties(
decrypt_kms_connection_config, decryption_config)
pq_scan_opts = ds.ParquetFragmentScanOptions(
decryption_properties=decryption_properties
)

pformat = pa.dataset.ParquetFileFormat(
default_fragment_scan_options=pq_scan_opts
)
with cond_raises(dataset_success, ValueError, match="Unknown master key"):
dataset = ds.dataset("sample_dataset", format=pformat, filesystem=mockfs)
with cond_raises(to_table_success, ValueError, match="Unknown master key"):
assert table.select(read_columns).equals(dataset.to_table(read_columns))

# read with footer key only
assert_test_dataset_encryption_decryption(table, column_keys, keys, footer_key_only, plaintext_column_names, True)
assert_test_dataset_encryption_decryption(table, column_keys, keys, footer_key_only, encrypted_column_names, False)
assert_test_dataset_encryption_decryption(table, column_keys, keys, footer_key_only, all_column_names, False)
assert_decrypts(footer_key_only, plaintext_column_names, True)
assert_decrypts(footer_key_only, encrypted_column_names, False)
assert_decrypts(footer_key_only, all_column_names, False)

# read with all but footer key
assert_test_dataset_encryption_decryption(table, column_keys, keys, column_keys_only, plaintext_column_names, False, False)
assert_test_dataset_encryption_decryption(table, column_keys, keys, column_keys_only, encrypted_column_names, False, False)
assert_test_dataset_encryption_decryption(table, column_keys, keys, column_keys_only, all_column_names, False, False)
assert_decrypts(column_keys_only, plaintext_column_names, False, False)
assert_decrypts(column_keys_only, encrypted_column_names, False, False)
assert_decrypts(column_keys_only, all_column_names, False, False)

# with the footer key and one column key, all plaintext columns and
# the encrypted columns that use that key can be read
if len(column_keys) > 1:
for column_key_name, column_key_column_names in column_keys.items():
for encrypted_column_name in column_key_column_names:
encrypted_column_name = encrypted_column_name.split(".")[0]
footer_key_and_one_column_key = {key_name: key for key_name, key in keys.items()
if key_name in [FOOTER_KEY_NAME, column_key_name]}
assert_test_dataset_encryption_decryption(table, column_keys, keys, footer_key_and_one_column_key, plaintext_column_names,
True)
assert_test_dataset_encryption_decryption(table, column_keys, keys, footer_key_and_one_column_key, plaintext_column_names + [encrypted_column_name],
encrypted_column_name != extra_column_name)
assert_test_dataset_encryption_decryption(table, column_keys, keys, footer_key_and_one_column_key, encrypted_column_names,
False)
assert_test_dataset_encryption_decryption(table, column_keys, keys, footer_key_and_one_column_key, all_column_names, False)
# decrypt with footer key and that one column key
read_keys = {key_name: key
for key_name, key in keys.items()
if key_name in [FOOTER_KEY_NAME, column_key_name]}
# that one encrypted column can only be read
# if it is not a column path / nested field
plaintext_and_one_success = encrypted_column_name != extra_column_name
plaintext_and_one = plaintext_column_names + [encrypted_column_name]
assert_decrypts(read_keys, plaintext_column_names, True)
assert_decrypts(read_keys, plaintext_and_one, plaintext_and_one_success)
assert_decrypts(read_keys, encrypted_column_names, False)
assert_decrypts(read_keys, all_column_names, False)

# with all column keys, all columns can be read
assert_test_dataset_encryption_decryption(table, column_keys, keys, keys, plaintext_column_names, True)
assert_test_dataset_encryption_decryption(table, column_keys, keys, keys, encrypted_column_names, True)
assert_test_dataset_encryption_decryption(table, column_keys, keys, keys, all_column_names, True)


def assert_test_dataset_encryption_decryption(
table,
column_keys,
write_keys,
read_keys,
read_columns,
to_table_success,
dataset_success = True,
):
encryption_config = create_encryption_config(FOOTER_KEY_NAME, column_keys)
decryption_config = create_decryption_config()
encrypt_kms_connection_config = create_kms_connection_config(write_keys)
decrypt_kms_connection_config = create_kms_connection_config(read_keys)

crypto_factory = pe.CryptoFactory(kms_factory)
parquet_encryption_cfg = ds.ParquetEncryptionConfig(
crypto_factory, encrypt_kms_connection_config, encryption_config
)
parquet_decryption_cfg = ds.ParquetDecryptionConfig(
crypto_factory, decrypt_kms_connection_config, decryption_config
)

# create write_options with dataset encryption config
pformat = pa.dataset.ParquetFileFormat()
write_options = pformat.make_write_options(encryption_config=parquet_encryption_cfg)

mockfs = fs._MockFileSystem()
mockfs.create_dir("/")

ds.write_dataset(
data=table,
base_dir="sample_dataset",
format=pformat,
file_options=write_options,
filesystem=mockfs,
)

# read without decryption config -> should error if dataset was properly encrypted
pformat = pa.dataset.ParquetFileFormat()
with pytest.raises(IOError, match=r"no decryption"):
ds.dataset("sample_dataset", format=pformat, filesystem=mockfs)

# set decryption config for parquet fragment scan options
pq_scan_opts = ds.ParquetFragmentScanOptions(
decryption_config=parquet_decryption_cfg
)
pformat = pa.dataset.ParquetFileFormat(default_fragment_scan_options=pq_scan_opts)
if dataset_success:
dataset = ds.dataset("sample_dataset", format=pformat, filesystem=mockfs)
if to_table_success:
assert table.select(read_columns).equals(dataset.to_table(read_columns))
else:
with pytest.raises(ValueError, match="Unknown master key"):
assert table.select(read_columns).equals(dataset.to_table(read_columns))
else:
with pytest.raises(ValueError, match="Unknown master key"):
_ = ds.dataset("sample_dataset", format=pformat, filesystem=mockfs)

# set decryption properties for parquet fragment scan options
decryption_properties = crypto_factory.file_decryption_properties(
decrypt_kms_connection_config, decryption_config)
pq_scan_opts = ds.ParquetFragmentScanOptions(
decryption_properties=decryption_properties
)

pformat = pa.dataset.ParquetFileFormat(default_fragment_scan_options=pq_scan_opts)
if dataset_success:
dataset = ds.dataset("sample_dataset", format=pformat, filesystem=mockfs)
if to_table_success:
assert table.select(read_columns).equals(dataset.to_table(read_columns))
else:
with pytest.raises(ValueError, match="Unknown master key"):
assert table.select(read_columns).equals(dataset.to_table(read_columns))
else:
with pytest.raises(ValueError, match="Unknown master key"):
_ = ds.dataset("sample_dataset", format=pformat, filesystem=mockfs)
assert_decrypts(keys, plaintext_column_names, True)
assert_decrypts(keys, encrypted_column_names, True)
assert_decrypts(keys, all_column_names, True)


@pytest.mark.skipif(
@@ -260,7 +270,9 @@ def test_list_encryption_decryption(column_name):
reason="Parquet Encryption is not currently enabled"
)
@pytest.mark.parametrize(
"column_name", ["map", "map.key", "map.value", "map.key_value.key", "map.key_value.value"]
"column_name", [
"map", "map.key", "map.value", "map.key_value.key", "map.key_value.value"
]
)
def test_map_encryption_decryption(column_name):
map_type = pa.map_(pa.string(), pa.int32())
@@ -280,7 +292,7 @@ def test_map_encryption_decryption(column_name):
encryption_unavailable, reason="Parquet Encryption is not currently enabled"
)
@pytest.mark.parametrize(
"column_name", [ "struct", "struct.f1", "struct.f2"]
"column_name", ["struct", "struct.f1", "struct.f2"]
)
def test_struct_encryption_decryption(column_name):
struct_fields = [("f1", pa.int32()), ("f2", pa.string())]
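
The other half of the refactor is the inner assert_decrypts helper: the outer test writes the encrypted dataset once, and the nested helper closes over that shared state so each assertion only passes the key set, the columns, and the expected outcome. A rough, self-contained sketch of that closure pattern, using hypothetical stand-ins (make_store, assert_reads) instead of the real Parquet encryption machinery:

import pytest


def make_store(valid_keys):
    # Hypothetical stand-in for the encrypted dataset: reads fail with the
    # same error message the real tests match on when a key is unknown.
    def read(column, key):
        if key not in valid_keys:
            raise ValueError("Unknown master key")
        return f"{column}-data"
    return read


def test_read_with_key_subsets():
    read = make_store({"footer_key", "col_key"})   # shared setup, done once

    def assert_reads(key, expect_ok):
        # Inner helper closes over `read`; only the varying bits are arguments.
        if expect_ok:
            assert read("animal", key) == "animal-data"
        else:
            with pytest.raises(ValueError, match="Unknown master key"):
                read("animal", key)

    assert_reads("footer_key", True)
    assert_reads("col_key", True)
    assert_reads("missing_key", False)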