diff --git a/posthog/temporal/data_imports/pipelines/pipeline/delta_table_helper.py b/posthog/temporal/data_imports/pipelines/pipeline/delta_table_helper.py
index 63191aaf918cd..542ab40bc3744 100644
--- a/posthog/temporal/data_imports/pipelines/pipeline/delta_table_helper.py
+++ b/posthog/temporal/data_imports/pipelines/pipeline/delta_table_helper.py
@@ -1,6 +1,7 @@
 from collections.abc import Sequence
 from conditional_cache import lru_cache
 from typing import Any
+import deltalake.exceptions
 import pyarrow as pa
 from dlt.common.libs.deltalake import ensure_delta_compatible_arrow_schema
 from dlt.common.normalizers.naming.snake_case import NamingConvention
@@ -8,6 +9,7 @@
 from django.conf import settings
 from sentry_sdk import capture_exception
 from posthog.settings.base_variables import TEST
+from posthog.temporal.common.logger import FilteringBoundLogger
 from posthog.warehouse.models import ExternalDataJob
 from posthog.warehouse.s3 import get_s3_client
 
@@ -15,10 +17,12 @@
 class DeltaTableHelper:
     _resource_name: str
     _job: ExternalDataJob
+    _logger: FilteringBoundLogger
 
-    def __init__(self, resource_name: str, job: ExternalDataJob) -> None:
+    def __init__(self, resource_name: str, job: ExternalDataJob, logger: FilteringBoundLogger) -> None:
         self._resource_name = resource_name
         self._job = job
+        self._logger = logger
 
     def _get_credentials(self):
         if TEST:
@@ -110,15 +114,27 @@ def write_to_deltalake(
             delta_table = deltalake.DeltaTable.create(
                 table_uri=self._get_delta_table_uri(), schema=data.schema, storage_options=storage_options
             )
 
+        try:
+            deltalake.write_deltalake(
+                table_or_uri=delta_table,
+                data=data,
+                partition_by=None,
+                mode=mode,
+                schema_mode=schema_mode,
+                engine="rust",
+            )  # type: ignore
+        except deltalake.exceptions.SchemaMismatchError as e:
+            self._logger.debug("SchemaMismatchError: attempting to overwrite schema instead", exc_info=e)
+            capture_exception(e)
-        deltalake.write_deltalake(
-            table_or_uri=delta_table,
-            data=data,
-            partition_by=None,
-            mode=mode,
-            schema_mode=schema_mode,
-            engine="rust",
-        )  # type: ignore
+            deltalake.write_deltalake(
+                table_or_uri=delta_table,
+                data=data,
+                partition_by=None,
+                mode=mode,
+                schema_mode="overwrite",
+                engine="rust",
+            )  # type: ignore
 
         delta_table = self.get_delta_table()
         assert delta_table is not None
diff --git a/posthog/temporal/data_imports/pipelines/pipeline/pipeline.py b/posthog/temporal/data_imports/pipelines/pipeline/pipeline.py
index 68a9f715631a2..817a32a1f2830 100644
--- a/posthog/temporal/data_imports/pipelines/pipeline/pipeline.py
+++ b/posthog/temporal/data_imports/pipelines/pipeline/pipeline.py
@@ -47,7 +47,7 @@ def __init__(self, source: DltSource, logger: FilteringBoundLogger, job_id: str,
         assert schema is not None
         self._schema = schema
 
-        self._delta_table_helper = DeltaTableHelper(resource_name, self._job)
+        self._delta_table_helper = DeltaTableHelper(resource_name, self._job, self._logger)
         self._internal_schema = HogQLSchema()
 
     def run(self):
diff --git a/posthog/temporal/data_imports/pipelines/pipeline/utils.py b/posthog/temporal/data_imports/pipelines/pipeline/utils.py
index 42c6aada0e3aa..622a66090f7cc 100644
--- a/posthog/temporal/data_imports/pipelines/pipeline/utils.py
+++ b/posthog/temporal/data_imports/pipelines/pipeline/utils.py
@@ -50,6 +50,23 @@ def _evolve_pyarrow_schema(table: pa.Table, delta_schema: deltalake.Schema | Non
             )
             table = table.append_column(field, new_column_data)
 
+        # If the delta table schema has a larger scale/precision, then update the
+        # pyarrow schema to use the larger values so that we're not trying to downscale
+        if isinstance(field.type, pa.Decimal128Type):
+            py_arrow_table_column = table.column(field.name)
+            if (
+                field.type.precision > py_arrow_table_column.type.precision
+                or field.type.scale > py_arrow_table_column.type.scale
+            ):
+                field_index = table.schema.get_field_index(field.name)
+                new_schema = table.schema.set(
+                    field_index,
+                    table.schema.field(field_index).with_type(
+                        pa.decimal128(field.type.precision, field.type.scale)
+                    ),
+                )
+                table = table.cast(new_schema)
+
     # Change types based on what deltalake tables support
     return table.cast(ensure_delta_compatible_arrow_schema(table.schema))
 
diff --git a/posthog/temporal/data_imports/pipelines/sql_database_v2/__init__.py b/posthog/temporal/data_imports/pipelines/sql_database_v2/__init__.py
index 0c49b04ba1d1c..c5bc6db6674ab 100644
--- a/posthog/temporal/data_imports/pipelines/sql_database_v2/__init__.py
+++ b/posthog/temporal/data_imports/pipelines/sql_database_v2/__init__.py
@@ -124,6 +124,7 @@ def sql_source_for_type(
         db_incremental_field_last_value=db_incremental_field_last_value,
         team_id=team_id,
         connect_args=connect_args,
+        chunk_size=DEFAULT_CHUNK_SIZE,
     )
 
     return db_source
@@ -198,6 +199,7 @@ def snowflake_source(
         table_names=table_names,
         incremental=incremental,
         db_incremental_field_last_value=db_incremental_field_last_value,
+        chunk_size=DEFAULT_CHUNK_SIZE,
     )
 
     return db_source
@@ -243,6 +245,7 @@ def bigquery_source(
         table_names=[table_name],
         incremental=incremental,
         db_incremental_field_last_value=db_incremental_field_last_value,
+        chunk_size=DEFAULT_CHUNK_SIZE,
     )
diff --git a/posthog/temporal/tests/data_imports/test_end_to_end.py b/posthog/temporal/tests/data_imports/test_end_to_end.py
index 120686381c590..7a953f8d679a2 100644
--- a/posthog/temporal/tests/data_imports/test_end_to_end.py
+++ b/posthog/temporal/tests/data_imports/test_end_to_end.py
@@ -1094,3 +1094,54 @@ async def test_postgres_uuid_type(team, postgres_config, postgres_connection):
         job_inputs={"stripe_secret_key": "test-key", "stripe_account_id": "acct_id"},
         mock_data_response=[{"id": uuid.uuid4()}],
     )
+
+
+@pytest.mark.django_db(transaction=True)
+@pytest.mark.asyncio
+async def test_decimal_down_scales(team, postgres_config, postgres_connection):
+    if settings.TEMPORAL_TASK_QUEUE == DATA_WAREHOUSE_TASK_QUEUE_V2:
+        await postgres_connection.execute(
+            "CREATE TABLE IF NOT EXISTS {schema}.downsizing_column (id integer, dec_col numeric(10, 2))".format(
+                schema=postgres_config["schema"]
+            )
+        )
+        await postgres_connection.execute(
+            "INSERT INTO {schema}.downsizing_column (id, dec_col) VALUES (1, 12345.60)".format(
+                schema=postgres_config["schema"]
+            )
+        )
+
+        await postgres_connection.commit()
+
+        workflow_id, inputs = await _run(
+            team=team,
+            schema_name="downsizing_column",
+            table_name="postgres_downsizing_column",
+            source_type="Postgres",
+            job_inputs={
+                "host": postgres_config["host"],
+                "port": postgres_config["port"],
+                "database": postgres_config["database"],
+                "user": postgres_config["user"],
+                "password": postgres_config["password"],
+                "schema": postgres_config["schema"],
+                "ssh_tunnel_enabled": "False",
+            },
+            mock_data_response=[],
+        )
+
+        await postgres_connection.execute(
+            "ALTER TABLE {schema}.downsizing_column ALTER COLUMN dec_col type numeric(9, 2) using dec_col::numeric(9, 2);".format(
+                schema=postgres_config["schema"]
+            )
+        )
+
+        await postgres_connection.execute(
+            "INSERT INTO {schema}.downsizing_column (id, dec_col) VALUES (1, 1234567.89)".format(
+                schema=postgres_config["schema"]
+            )
+        )
+
+        await postgres_connection.commit()
+
+        await _execute_run(str(uuid.uuid4()), inputs, [])
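
For reference, the decimal handling added to _evolve_pyarrow_schema above can be exercised in isolation. The snippet below is a minimal sketch, assuming plain pyarrow with no Delta table involved; the column name and values mirror the new test scenario, and delta_field is a hypothetical stand-in for the field taken from the existing Delta schema.

# Minimal sketch (illustration only, not pipeline code): widen a pyarrow decimal
# column to the precision/scale already stored in the Delta schema so the later
# cast never has to downscale.
import decimal

import pyarrow as pa

# Incoming batch inferred as numeric(9, 2); the existing Delta schema says numeric(10, 2).
table = pa.table({"dec_col": pa.array([decimal.Decimal("1234567.89")], type=pa.decimal128(9, 2))})
delta_field = pa.field("dec_col", pa.decimal128(10, 2))  # hypothetical stand-in for the Delta schema field

column_type = table.column(delta_field.name).type
if isinstance(delta_field.type, pa.Decimal128Type) and (
    delta_field.type.precision > column_type.precision or delta_field.type.scale > column_type.scale
):
    field_index = table.schema.get_field_index(delta_field.name)
    widened_schema = table.schema.set(field_index, table.schema.field(field_index).with_type(delta_field.type))
    table = table.cast(widened_schema)

print(table.schema)  # dec_col: decimal128(10, 2)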
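
Similarly, the schema-mismatch fallback added to DeltaTableHelper.write_to_deltalake boils down to the pattern sketched below. This is a hypothetical standalone helper, not the pipeline's code: the function name and parameters are illustrative, and the logging and Sentry capture from the diff are omitted.

# Rough sketch of the retry-on-SchemaMismatchError pattern (assumptions: a plain
# table URI and caller-supplied mode/schema_mode; no logging or Sentry capture).
import deltalake
import deltalake.exceptions
import pyarrow as pa


def write_with_schema_fallback(table_uri: str, data: pa.Table, mode: str, schema_mode: str | None) -> None:
    try:
        # First attempt: respect whatever schema_mode the caller asked for.
        deltalake.write_deltalake(
            table_or_uri=table_uri,
            data=data,
            mode=mode,
            schema_mode=schema_mode,
            engine="rust",
        )
    except deltalake.exceptions.SchemaMismatchError:
        # The batch no longer matches the stored Delta schema; retry and let
        # delta-rs rewrite the table schema to match the incoming data.
        deltalake.write_deltalake(
            table_or_uri=table_uri,
            data=data,
            mode=mode,
            schema_mode="overwrite",
            engine="rust",
        )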