
Support partitioning hints for athena iceberg #1403

Merged · 6 commits · May 27, 2024
Changes from 4 commits
2 changes: 2 additions & 0 deletions dlt/destinations/adapters.py
@@ -5,11 +5,13 @@
from dlt.destinations.impl.bigquery import bigquery_adapter
from dlt.destinations.impl.synapse import synapse_adapter
from dlt.destinations.impl.clickhouse import clickhouse_adapter
+from dlt.destinations.impl.athena import athena_adapter

__all__ = [
    "weaviate_adapter",
    "qdrant_adapter",
    "bigquery_adapter",
    "synapse_adapter",
    "clickhouse_adapter",
+    "athena_adapter",
]
38 changes: 29 additions & 9 deletions dlt/destinations/impl/athena/athena.py
@@ -11,6 +11,7 @@
    Callable,
    Iterable,
    Type,
+    cast,
)
from copy import deepcopy
import re
@@ -69,6 +70,7 @@
from dlt.destinations.impl.athena.configuration import AthenaClientConfiguration
from dlt.destinations.type_mapping import TypeMapper
from dlt.destinations import path_utils
+from dlt.destinations.impl.athena.athena_adapter import PARTITION_HINT


class AthenaTypeMapper(TypeMapper):
@@ -401,9 +403,17 @@ def _from_db_type(
        return self.type_mapper.from_db_type(hive_t, precision, scale)

    def _get_column_def_sql(self, c: TColumnSchema, table_format: TTableFormat = None) -> str:
-        return (
-            f"{self.sql_client.escape_ddl_identifier(c['name'])} {self.type_mapper.to_db_type(c, table_format)}"
-        )
+        return f"{self.sql_client.escape_ddl_identifier(c['name'])} {self.type_mapper.to_db_type(c, table_format)}"

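+    # The partition hint maps column name -> SQL template, e.g.
+    # {"category": "{column_name}", "created_at": "month({column_name})"}
+    # which this method renders as: PARTITIONED BY (`category`, month(`created_at`))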
+    def _iceberg_partition_clause(self, partition_hints: Optional[Dict[str, str]]) -> str:
+        if not partition_hints:
+            return ""
+        formatted_strings = []
+        for column_name, template in partition_hints.items():
+            formatted_strings.append(
+                template.format(column_name=self.sql_client.escape_ddl_identifier(column_name))
+            )
+        return f"PARTITIONED BY ({', '.join(formatted_strings)})"

    def _get_table_update_sql(
        self, table_name: str, new_columns: Sequence[TColumnSchema], generate_alter: bool
@@ -431,20 +441,30 @@ def _get_table_update_sql(
sql.append(f"""ALTER TABLE {qualified_table_name} ADD COLUMNS ({columns});""")
else:
if is_iceberg:
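+                # Render the PARTITIONED BY clause from the hint stored by athena_adapter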
sql.append(f"""CREATE TABLE {qualified_table_name}
partition_clause = self._iceberg_partition_clause(
cast(Optional[Dict[str, str]], table.get(PARTITION_HINT))
)
sql.append(
f"""CREATE TABLE {qualified_table_name}
({columns})
{partition_clause}
LOCATION '{location.rstrip('/')}'
TBLPROPERTIES ('table_type'='ICEBERG', 'format'='parquet');""")
TBLPROPERTIES ('table_type'='ICEBERG', 'format'='parquet');"""
)
elif table_format == "jsonl":
sql.append(f"""CREATE EXTERNAL TABLE {qualified_table_name}
sql.append(
f"""CREATE EXTERNAL TABLE {qualified_table_name}
({columns})
ROW FORMAT SERDE 'org.openx.data.jsonserde.JsonSerDe'
LOCATION '{location}';""")
LOCATION '{location}';"""
)
else:
sql.append(f"""CREATE EXTERNAL TABLE {qualified_table_name}
sql.append(
f"""CREATE EXTERNAL TABLE {qualified_table_name}
({columns})
STORED AS PARQUET
LOCATION '{location}';""")
LOCATION '{location}';"""
)
return sql

    def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> LoadJob:
117 changes: 117 additions & 0 deletions dlt/destinations/impl/athena/athena_adapter.py
@@ -0,0 +1,117 @@
from typing import Any, Optional, Dict, Protocol, Sequence, Union, Final

from dateutil import parser

from dlt.common.pendulum import timezone
from dlt.common.schema.typing import TColumnNames, TTableSchemaColumns, TColumnSchema
from dlt.destinations.utils import ensure_resource
from dlt.extract import DltResource
from dlt.extract.items import TTableHintTemplate


PARTITION_HINT: Final[str] = "x-athena-partition"


class PartitionTransformation:
    template: str
    """Template string of the transformation, including a column name placeholder, e.g. `bucket(16, {column_name})`"""
    column_name: str
    """Name of the column the transformation is applied to"""

    def __init__(self, template: str, column_name: str) -> None:
        self.template = template
        self.column_name = column_name


class athena_partition:
    """Helper class to generate Iceberg partition transformations.

    E.g. `athena_partition.bucket(16, "id")` returns a transformation with the template `bucket(16, {column_name})`,
    which the Athena loader renders with the properly escaped column name.
    """

    @staticmethod
    def year(column_name: str) -> PartitionTransformation:
        """Partition by year part of a date or timestamp column."""
        return PartitionTransformation("year({column_name})", column_name)

    @staticmethod
    def month(column_name: str) -> PartitionTransformation:
        """Partition by month part of a date or timestamp column."""
        return PartitionTransformation("month({column_name})", column_name)

    @staticmethod
    def day(column_name: str) -> PartitionTransformation:
        """Partition by day part of a date or timestamp column."""
        return PartitionTransformation("day({column_name})", column_name)

    @staticmethod
    def hour(column_name: str) -> PartitionTransformation:
        """Partition by hour part of a date or timestamp column."""
        return PartitionTransformation("hour({column_name})", column_name)

    @staticmethod
    def bucket(n: int, column_name: str) -> PartitionTransformation:
        """Partition by hashed value into n buckets."""
        return PartitionTransformation(f"bucket({n}, {{column_name}})", column_name)

    @staticmethod
    def truncate(length: int, column_name: str) -> PartitionTransformation:
        """Partition by value truncated to length."""
        return PartitionTransformation(f"truncate({length}, {{column_name}})", column_name)


def athena_adapter(
    data: Any,
    partition: Optional[
        Union[str, PartitionTransformation, Sequence[Union[str, PartitionTransformation]]]
    ] = None,
) -> DltResource:
    """
    Prepares data for loading into Athena.

    Args:
        data: The data to be transformed.
            This can be raw data or an instance of DltResource.
            If raw data is provided, the function wraps it in a `DltResource` object.
        partition: Column name(s) or instances of `PartitionTransformation` to partition the table by.
            To use a transformation, it's best to use the methods of the helper class `athena_partition`
            so that the loader generates correctly escaped SQL.

    Returns:
        A `DltResource` object that is ready to be loaded into Athena.

    Raises:
        ValueError: If any hint is invalid or none are specified.

    Examples:
        >>> data = [{"name": "Marcel", "department": "Engineering", "date_hired": "2024-01-30"}]
        >>> athena_adapter(data, partition=["department", athena_partition.year("date_hired"), athena_partition.bucket(8, "name")])
        [DltResource with hints applied]
    """
    resource = ensure_resource(data)
    additional_table_hints: Dict[str, TTableHintTemplate[Any]] = {}

    if partition:
        if isinstance(partition, str) or not isinstance(partition, Sequence):
            partition = [partition]

        # Partition hint is `{column_name: template}`, e.g. `{"department": "{column_name}", "date_hired": "year({column_name})"}`
        # Use one dict for all hints instead of storing them on columns so that order is preserved
        partition_hint: Dict[str, str] = {}

        for item in partition:
            if isinstance(item, PartitionTransformation):
                # The client generates the final SQL string with the escaped column name injected
                partition_hint[item.column_name] = item.template
            else:
                # Item is the column name
                partition_hint[item] = "{column_name}"

        additional_table_hints[PARTITION_HINT] = partition_hint

    if additional_table_hints:
        resource.apply_hints(additional_table_hints=additional_table_hints)
    else:
        raise ValueError("A value for `partition` must be specified.")
    return resource
57 changes: 57 additions & 0 deletions docs/website/docs/dlt-ecosystem/destinations/athena.md
@@ -161,5 +161,62 @@
You can choose the following file formats:
* [parquet](../file-formats/parquet.md) is used by default


## Athena adapter

You can use the `athena_adapter` to add partitioning to Athena tables. This is currently only supported for Iceberg tables.

Iceberg tables support a few transformation functions for partitioning. See the [AWS documentation](https://docs.aws.amazon.com/athena/latest/ug/querying-iceberg-creating-tables.html#querying-iceberg-creating-tables-query-editor) for information on all supported functions.

Use the `athena_partition` helper to generate the partitioning hints for these functions (a short sketch of what the helpers return follows the list):

* `athena_partition.year(column_name: str)`: Partition by the year part of a date/datetime column.
* `athena_partition.month(column_name: str)`: Partition by the month part of a date/datetime column.
* `athena_partition.day(column_name: str)`: Partition by the day part of a date/datetime column.
* `athena_partition.hour(column_name: str)`: Partition by the hour part of a date/datetime column.
* `athena_partition.bucket(n: int, column_name: str)`: Partition by the hashed value into `n` buckets.
* `athena_partition.truncate(length: int, column_name: str)`: Partition by the value truncated to `length` characters (or width, for numbers).
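
Each helper returns a `PartitionTransformation` object pairing a SQL template with a column name; the loader fills in the escaped column name when it renders the `CREATE TABLE` statement. A minimal sketch of what the helpers produce:

```py
from dlt.destinations.impl.athena.athena_adapter import athena_partition

# Each helper pairs a SQL template with the column it applies to
t = athena_partition.bucket(8, "name")
print(t.template)     # bucket(8, {column_name})
print(t.column_name)  # name
```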

Here is an example of how to use the adapter to partition a table:

```py
from datetime import date

import dlt
from dlt.destinations.impl.athena.athena_adapter import athena_partition, athena_adapter

data_items = [
    (1, "A", date(2021, 1, 1)),
    (2, "A", date(2021, 1, 2)),
    (3, "A", date(2021, 1, 3)),
    (4, "A", date(2021, 2, 1)),
    (5, "A", date(2021, 2, 2)),
    (6, "B", date(2021, 1, 1)),
    (7, "B", date(2021, 1, 2)),
    (8, "B", date(2021, 1, 3)),
    (9, "B", date(2021, 2, 1)),
    (10, "B", date(2021, 3, 2)),
]


@dlt.resource(table_format="iceberg")
def partitioned_data():
    yield [{"id": i, "category": c, "created_at": d} for i, c, d in data_items]


# Add partitioning hints to the table
athena_adapter(
    partitioned_data,
    partition=[
        # Partition per category and month
        "category",
        athena_partition.month("created_at"),
    ],
)


pipeline = dlt.pipeline("athena_example", destination="athena")
pipeline.run(partitioned_data)
```
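
The hints are rendered in the order given, so the example above produces ``PARTITIONED BY (`category`, month(`created_at`))`` in the generated `CREATE TABLE` statement.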

<!--@@@DLT_TUBA athena-->

4 changes: 4 additions & 0 deletions tests/load/athena_iceberg/__init__.py
@@ -0,0 +1,4 @@
from tests.utils import skip_if_not_active


skip_if_not_active("athena")
64 changes: 64 additions & 0 deletions tests/load/athena_iceberg/test_athena_adapter.py
@@ -0,0 +1,64 @@
import pytest

import dlt
from dlt.destinations import filesystem
from dlt.destinations.impl.athena.athena_adapter import athena_adapter, athena_partition
from tests.load.utils import destinations_configs, DestinationTestConfiguration


def test_iceberg_partition_hints():
    """Create a table with athena partition hints and check that the SQL is generated correctly."""

    @dlt.resource(table_format="iceberg")
    def partitioned_table():
        yield {
            "product_id": 1,
            "name": "product 1",
            "created_at": "2021-01-01T00:00:00Z",
            "category": "category 1",
            "price": 100.0,
            "quantity": 10,
        }

    @dlt.resource(table_format="iceberg")
    def not_partitioned_table():
        yield {"a": 1, "b": 2}

    athena_adapter(
        partitioned_table,
        partition=[
            "category",
            athena_partition.month("created_at"),
            athena_partition.bucket(10, "product_id"),
            athena_partition.truncate(2, "name"),
        ],
    )

    pipeline = dlt.pipeline(
        "athena_test",
        destination="athena",
        staging=filesystem("s3://not-a-real-bucket"),
        full_refresh=True,
    )

    pipeline.extract([partitioned_table, not_partitioned_table])
    pipeline.normalize()

    with pipeline._sql_job_client(pipeline.default_schema) as client:
        sql_partitioned = client._get_table_update_sql(
            "partitioned_table",
            list(pipeline.default_schema.tables["partitioned_table"]["columns"].values()),
            False,
        )[0]
        sql_not_partitioned = client._get_table_update_sql(
            "not_partitioned_table",
            list(pipeline.default_schema.tables["not_partitioned_table"]["columns"].values()),
            False,
        )[0]

    # Partition clause is generated with the original order preserved
    expected_clause = "PARTITIONED BY (`category`, month(`created_at`), bucket(10, `product_id`), truncate(2, `name`))"
    assert expected_clause in sql_partitioned

    # No partition clause otherwise
    assert "PARTITIONED BY" not in sql_not_partitioned
3 changes: 0 additions & 3 deletions tests/load/athena_iceberg/test_athena_iceberg.py
@@ -11,14 +11,11 @@

from tests.load.pipeline.utils import destinations_configs, DestinationTestConfiguration

-from tests.utils import skip_if_not_active
from dlt.destinations.exceptions import DatabaseTerminalException

# mark all tests as essential, do not remove
pytestmark = pytest.mark.essential

skip_if_not_active("athena")


def test_iceberg() -> None:
"""