diff --git a/dlt/destinations/adapters.py b/dlt/destinations/adapters.py index 554bd88924..1c3e094e19 100644 --- a/dlt/destinations/adapters.py +++ b/dlt/destinations/adapters.py @@ -5,6 +5,7 @@ from dlt.destinations.impl.bigquery import bigquery_adapter from dlt.destinations.impl.synapse import synapse_adapter from dlt.destinations.impl.clickhouse import clickhouse_adapter +from dlt.destinations.impl.athena import athena_adapter __all__ = [ "weaviate_adapter", @@ -12,4 +13,5 @@ "bigquery_adapter", "synapse_adapter", "clickhouse_adapter", + "athena_adapter", ] diff --git a/dlt/destinations/impl/athena/athena.py b/dlt/destinations/impl/athena/athena.py index 7e1ab8fc27..8f043ba4d5 100644 --- a/dlt/destinations/impl/athena/athena.py +++ b/dlt/destinations/impl/athena/athena.py @@ -11,6 +11,7 @@ Callable, Iterable, Type, + cast, ) from copy import deepcopy import re @@ -69,6 +70,7 @@ from dlt.destinations.impl.athena.configuration import AthenaClientConfiguration from dlt.destinations.type_mapping import TypeMapper from dlt.destinations import path_utils +from dlt.destinations.impl.athena.athena_adapter import PARTITION_HINT class AthenaTypeMapper(TypeMapper): @@ -405,6 +407,16 @@ def _get_column_def_sql(self, c: TColumnSchema, table_format: TTableFormat = Non f"{self.sql_client.escape_ddl_identifier(c['name'])} {self.type_mapper.to_db_type(c, table_format)}" ) + def _iceberg_partition_clause(self, partition_hints: Optional[Dict[str, str]]) -> str: + if not partition_hints: + return "" + formatted_strings = [] + for column_name, template in partition_hints.items(): + formatted_strings.append( + template.format(column_name=self.sql_client.escape_ddl_identifier(column_name)) + ) + return f"PARTITIONED BY ({', '.join(formatted_strings)})" + def _get_table_update_sql( self, table_name: str, new_columns: Sequence[TColumnSchema], generate_alter: bool ) -> List[str]: @@ -431,8 +443,12 @@ def _get_table_update_sql( sql.append(f"""ALTER TABLE {qualified_table_name} ADD COLUMNS 
({columns});""") else: if is_iceberg: + partition_clause = self._iceberg_partition_clause( + cast(Optional[Dict[str, str]], table.get(PARTITION_HINT)) + ) sql.append(f"""CREATE TABLE {qualified_table_name} ({columns}) + {partition_clause} LOCATION '{location.rstrip('/')}' TBLPROPERTIES ('table_type'='ICEBERG', 'format'='parquet');""") elif table_format == "jsonl": diff --git a/dlt/destinations/impl/athena/athena_adapter.py b/dlt/destinations/impl/athena/athena_adapter.py new file mode 100644 index 0000000000..cb600335c0 --- /dev/null +++ b/dlt/destinations/impl/athena/athena_adapter.py @@ -0,0 +1,117 @@ +from typing import Any, Optional, Dict, Protocol, Sequence, Union, Final + +from dateutil import parser + +from dlt.common.pendulum import timezone +from dlt.common.schema.typing import TColumnNames, TTableSchemaColumns, TColumnSchema +from dlt.destinations.utils import ensure_resource +from dlt.extract import DltResource +from dlt.extract.items import TTableHintTemplate + + +PARTITION_HINT: Final[str] = "x-athena-partition" + + +class PartitionTransformation: + template: str + """Template string of the transformation including column name placeholder. E.g. `bucket(16, {column_name})`""" + column_name: str + """Column name to apply the transformation to""" + + def __init__(self, template: str, column_name: str) -> None: + self.template = template + self.column_name = column_name + + +class athena_partition: + """Helper class to generate iceberg partition transformations + + E.g. `athena_partition.bucket(16, "id")` will return a transformation with template `bucket(16, {column_name})` + This can be correctly rendered by the athena loader with escaped column name. 
+ """ + + @staticmethod + def year(column_name: str) -> PartitionTransformation: + """Partition by year part of a date or timestamp column.""" + return PartitionTransformation("year({column_name})", column_name) + + @staticmethod + def month(column_name: str) -> PartitionTransformation: + """Partition by month part of a date or timestamp column.""" + return PartitionTransformation("month({column_name})", column_name) + + @staticmethod + def day(column_name: str) -> PartitionTransformation: + """Partition by day part of a date or timestamp column.""" + return PartitionTransformation("day({column_name})", column_name) + + @staticmethod + def hour(column_name: str) -> PartitionTransformation: + """Partition by hour part of a date or timestamp column.""" + return PartitionTransformation("hour({column_name})", column_name) + + @staticmethod + def bucket(n: int, column_name: str) -> PartitionTransformation: + """Partition by hashed value to n buckets.""" + return PartitionTransformation(f"bucket({n}, {{column_name}})", column_name) + + @staticmethod + def truncate(length: int, column_name: str) -> PartitionTransformation: + """Partition by value truncated to length.""" + return PartitionTransformation(f"truncate({length}, {{column_name}})", column_name) + + +def athena_adapter( + data: Any, + partition: Union[ + str, PartitionTransformation, Sequence[Union[str, PartitionTransformation]] + ] = None, +) -> DltResource: + """ + Prepares data for loading into Athena + + Args: + data: The data to be transformed. + This can be raw data or an instance of DltResource. + If raw data is provided, the function will wrap it into a `DltResource` object. + partition: Column name(s) or instances of `PartitionTransformation` to partition the table by. + To use a transformation it's best to use the methods of the helper class `athena_partition` + to generate correctly escaped SQL in the loader. + + Returns: + A `DltResource` object that is ready to be loaded into Athena. 
+ + Raises: + ValueError: If any hint is invalid or none are specified. + + Examples: + >>> data = [{"name": "Marcel", "department": "Engineering", "date_hired": "2024-01-30"}] + >>> athena_adapter(data, partition=["department", athena_partition.year("date_hired"), athena_partition.bucket(8, "name")]) + [DltResource with hints applied] + """ + resource = ensure_resource(data) + additional_table_hints: Dict[str, TTableHintTemplate[Any]] = {} + + if partition: + if isinstance(partition, str) or not isinstance(partition, Sequence): + partition = [partition] + + # Partition hint is `{column_name: template}`, e.g. `{"department": "{column_name}", "date_hired": "year({column_name})"}` + # Use one dict for all hints instead of storing on column so order is preserved + partition_hint: Dict[str, str] = {} + + for item in partition: + if isinstance(item, PartitionTransformation): + # Client will generate the final SQL string with escaped column name injected + partition_hint[item.column_name] = item.template + else: + # Item is the column name + partition_hint[item] = "{column_name}" + + additional_table_hints[PARTITION_HINT] = partition_hint + + if additional_table_hints: + resource.apply_hints(additional_table_hints=additional_table_hints) + else: + raise ValueError("A value for `partition` must be specified.") + return resource diff --git a/docs/website/docs/dlt-ecosystem/destinations/athena.md b/docs/website/docs/dlt-ecosystem/destinations/athena.md index 7c907664d3..93291bfe9a 100644 --- a/docs/website/docs/dlt-ecosystem/destinations/athena.md +++ b/docs/website/docs/dlt-ecosystem/destinations/athena.md @@ -161,5 +161,62 @@ aws_data_catalog="awsdatacatalog" You can choose the following file formats: * [parquet](../file-formats/parquet.md) is used by default + +## Athena adapter + +You can use the `athena_adapter` to add partitioning to Athena tables. This is currently only supported for Iceberg tables. 
+ + Iceberg tables support a few transformation functions for partitioning. Information on all supported functions can be found in the [AWS documentation](https://docs.aws.amazon.com/athena/latest/ug/querying-iceberg-creating-tables.html#querying-iceberg-creating-tables-query-editor). + + Use the `athena_partition` helper to generate the partitioning hints for these functions: + + * `athena_partition.year(column_name: str)`: Partition by year of date/datetime column. + * `athena_partition.month(column_name: str)`: Partition by month of date/datetime column. + * `athena_partition.day(column_name: str)`: Partition by day of date/datetime column. + * `athena_partition.hour(column_name: str)`: Partition by hour of date/datetime column. + * `athena_partition.bucket(n: int, column_name: str)`: Partition by hashed value to `n` buckets + * `athena_partition.truncate(length: int, column_name: str)`: Partition by truncated value to `length` (or width for numbers) + + Here is an example of how to use the adapter to partition a table: + + ```py + from datetime import date + + import dlt + from dlt.destinations.impl.athena.athena_adapter import athena_partition, athena_adapter + + data_items = [ + (1, "A", date(2021, 1, 1)), + (2, "A", date(2021, 1, 2)), + (3, "A", date(2021, 1, 3)), + (4, "A", date(2021, 2, 1)), + (5, "A", date(2021, 2, 2)), + (6, "B", date(2021, 1, 1)), + (7, "B", date(2021, 1, 2)), + (8, "B", date(2021, 1, 3)), + (9, "B", date(2021, 2, 1)), + (10, "B", date(2021, 3, 2)), + ] + + @dlt.resource(table_format="iceberg") + def partitioned_data(): + yield [{"id": i, "category": c, "created_at": d} for i, c, d in data_items] + + + # Add partitioning hints to the table + athena_adapter( + partitioned_data, + partition=[ + # Partition per category and month + "category", + athena_partition.month("created_at"), + ], + ) + + + pipeline = dlt.pipeline("athena_example") + pipeline.run(partitioned_data) + ``` + diff --git a/tests/load/athena_iceberg/__init__.py b/tests/load/athena_iceberg/__init__.py index 
e69de29bb2..56e5d539c2 100644 --- a/tests/load/athena_iceberg/__init__.py +++ b/tests/load/athena_iceberg/__init__.py @@ -0,0 +1,4 @@ +from tests.utils import skip_if_not_active + + +skip_if_not_active("athena") diff --git a/tests/load/athena_iceberg/test_athena_adapter.py b/tests/load/athena_iceberg/test_athena_adapter.py new file mode 100644 index 0000000000..3144eb9cc9 --- /dev/null +++ b/tests/load/athena_iceberg/test_athena_adapter.py @@ -0,0 +1,69 @@ +import pytest + +import dlt +from dlt.destinations import filesystem +from dlt.destinations.impl.athena.athena_adapter import athena_adapter, athena_partition + +# mark all tests as essential, do not remove +pytestmark = pytest.mark.essential + + +def test_iceberg_partition_hints(): + """Create a table with athena partition hints and check that the SQL is generated correctly.""" + + @dlt.resource(table_format="iceberg") + def partitioned_table(): + yield { + "product_id": 1, + "name": "product 1", + "created_at": "2021-01-01T00:00:00Z", + "category": "category 1", + "price": 100.0, + "quantity": 10, + } + + @dlt.resource(table_format="iceberg") + def not_partitioned_table(): + yield {"a": 1, "b": 2} + + athena_adapter( + partitioned_table, + partition=[ + "category", + athena_partition.month("created_at"), + athena_partition.bucket(10, "product_id"), + athena_partition.truncate(2, "name"), + ], + ) + + pipeline = dlt.pipeline( + "athena_test", + destination="athena", + staging=filesystem("s3://not-a-real-bucket"), + full_refresh=True, + ) + + pipeline.extract([partitioned_table, not_partitioned_table]) + pipeline.normalize() + + with pipeline._sql_job_client(pipeline.default_schema) as client: + sql_partitioned = client._get_table_update_sql( + "partitioned_table", + list(pipeline.default_schema.tables["partitioned_table"]["columns"].values()), + False, + )[0] + sql_not_partitioned = client._get_table_update_sql( + "not_partitioned_table", + 
list(pipeline.default_schema.tables["not_partitioned_table"]["columns"].values()), + False, + )[0] + + # Partition clause is generated with original order + expected_clause = ( + "PARTITIONED BY (`category`, month(`created_at`), bucket(10, `product_id`), truncate(2," + " `name`))" + ) + assert expected_clause in sql_partitioned + + # No partition clause otherwise + assert "PARTITIONED BY" not in sql_not_partitioned diff --git a/tests/load/athena_iceberg/test_athena_iceberg.py b/tests/load/athena_iceberg/test_athena_iceberg.py index dbcdc5c23e..d3bb9eb5f5 100644 --- a/tests/load/athena_iceberg/test_athena_iceberg.py +++ b/tests/load/athena_iceberg/test_athena_iceberg.py @@ -11,14 +11,11 @@ from tests.load.pipeline.utils import destinations_configs, DestinationTestConfiguration -from tests.utils import skip_if_not_active from dlt.destinations.exceptions import DatabaseTerminalException # mark all tests as essential, do not remove pytestmark = pytest.mark.essential -skip_if_not_active("athena") - def test_iceberg() -> None: """ diff --git a/tests/load/pipeline/test_athena.py b/tests/load/pipeline/test_athena.py index 8c034a066b..a5bb6efc0d 100644 --- a/tests/load/pipeline/test_athena.py +++ b/tests/load/pipeline/test_athena.py @@ -9,6 +9,8 @@ from tests.pipeline.utils import assert_load_info, load_table_counts from tests.pipeline.utils import load_table_counts from dlt.destinations.exceptions import CantExtractTablePrefix +from dlt.destinations.impl.athena.athena_adapter import athena_partition, athena_adapter +from dlt.destinations.fs_client import FSClientBase from tests.load.pipeline.utils import destinations_configs, DestinationTestConfiguration from tests.load.utils import ( @@ -231,3 +233,69 @@ def test_athena_file_layouts(destination_config: DestinationTestConfiguration, l pipeline, *[t["name"] for t in pipeline.default_schema.data_tables()] ) assert table_counts == {"items1": 3, "items2": 7} + + +@pytest.mark.parametrize( + "destination_config", + 
destinations_configs(default_sql_configs=True, subset=["athena"], force_iceberg=True), + ids=lambda x: x.name, +) +def test_athena_partitioned_iceberg_table(destination_config: DestinationTestConfiguration): + """Load an iceberg table with partition hints and verify partitions are created correctly.""" + pipeline = destination_config.setup_pipeline("athena_" + uniq_id(), full_refresh=True) + + data_items = [ + (1, "A", datetime.date.fromisoformat("2021-01-01")), + (2, "A", datetime.date.fromisoformat("2021-01-02")), + (3, "A", datetime.date.fromisoformat("2021-01-03")), + (4, "A", datetime.date.fromisoformat("2021-02-01")), + (5, "A", datetime.date.fromisoformat("2021-02-02")), + (6, "B", datetime.date.fromisoformat("2021-01-01")), + (7, "B", datetime.date.fromisoformat("2021-01-02")), + (8, "B", datetime.date.fromisoformat("2021-01-03")), + (9, "B", datetime.date.fromisoformat("2021-02-01")), + (10, "B", datetime.date.fromisoformat("2021-03-02")), + ] + + @dlt.resource(table_format="iceberg") + def partitioned_table(): + yield [{"id": i, "category": c, "created_at": d} for i, c, d in data_items] + + athena_adapter( + partitioned_table, + partition=[ + "category", + athena_partition.month("created_at"), + ], + ) + + info = pipeline.run(partitioned_table) + assert_load_info(info) + + # Get partitions from metadata + with pipeline.sql_client() as sql_client: + tbl_name = sql_client.make_qualified_table_name("partitioned_table$partitions") + rows = sql_client.execute_sql(f"SELECT partition FROM {tbl_name}") + partition_keys = {r[0] for r in rows} + + data_rows = sql_client.execute_sql( + "SELECT id, category, created_at FROM" + f" {sql_client.make_qualified_table_name('partitioned_table')}" + ) + # data_rows = [(i, c, d.toisoformat()) for i, c, d in data_rows] + + # All data is in table + assert len(data_rows) == len(data_items) + assert set(data_rows) == set(data_items) + + # Compare with expected partitions + # Months are number of months since epoch + 
expected_partitions = { + "{category=A, created_at_month=612}", + "{category=A, created_at_month=613}", + "{category=B, created_at_month=612}", + "{category=B, created_at_month=613}", + "{category=B, created_at_month=614}", + } + + assert partition_keys == expected_partitions diff --git a/tests/load/utils.py b/tests/load/utils.py index c03470676f..e6b860c723 100644 --- a/tests/load/utils.py +++ b/tests/load/utils.py @@ -180,6 +180,7 @@ def destinations_configs( file_format: Union[TLoaderFileFormat, Sequence[TLoaderFileFormat]] = None, supports_merge: Optional[bool] = None, supports_dbt: Optional[bool] = None, + force_iceberg: Optional[bool] = None, ) -> List[DestinationTestConfiguration]: # sanity check for item in subset: @@ -495,6 +496,11 @@ def destinations_configs( conf for conf in destination_configs if conf.name not in EXCLUDED_DESTINATION_CONFIGURATIONS ] + if force_iceberg is not None: + destination_configs = [ + conf for conf in destination_configs if conf.force_iceberg is force_iceberg + ] + return destination_configs