From d7455f67676ede80f2085b2f39a43821dcd00b62 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tom=C3=A1s=20Far=C3=ADas=20Santana?= Date: Wed, 8 Jan 2025 09:11:10 +0100 Subject: [PATCH] refactor: Refactor snowflake to use spmc abstractions (#26900) --- posthog/settings/temporal.py | 3 + .../batch_exports/snowflake_batch_export.py | 275 +++++++++--------- .../test_snowflake_batch_export_workflow.py | 6 - 3 files changed, 147 insertions(+), 137 deletions(-) diff --git a/posthog/settings/temporal.py b/posthog/settings/temporal.py index e168d12c46d84..28fc449404ad8 100644 --- a/posthog/settings/temporal.py +++ b/posthog/settings/temporal.py @@ -20,6 +20,9 @@ "BATCH_EXPORT_S3_RECORD_BATCH_QUEUE_MAX_SIZE_BYTES", 0, type_cast=int ) BATCH_EXPORT_SNOWFLAKE_UPLOAD_CHUNK_SIZE_BYTES: int = 1024 * 1024 * 100 # 100MB +BATCH_EXPORT_SNOWFLAKE_RECORD_BATCH_QUEUE_MAX_SIZE_BYTES: int = get_from_env( + "BATCH_EXPORT_SNOWFLAKE_RECORD_BATCH_QUEUE_MAX_SIZE_BYTES", 1024 * 1024 * 300, type_cast=int +) BATCH_EXPORT_POSTGRES_UPLOAD_CHUNK_SIZE_BYTES: int = 1024 * 1024 * 50 # 50MB BATCH_EXPORT_BIGQUERY_UPLOAD_CHUNK_SIZE_BYTES: int = 1024 * 1024 * 100 # 100MB BATCH_EXPORT_BIGQUERY_RECORD_BATCH_QUEUE_MAX_SIZE_BYTES: int = get_from_env( diff --git a/posthog/temporal/batch_exports/snowflake_batch_export.py b/posthog/temporal/batch_exports/snowflake_batch_export.py index 8887fcd52a317..0e3046a371048 100644 --- a/posthog/temporal/batch_exports/snowflake_batch_export.py +++ b/posthog/temporal/batch_exports/snowflake_batch_export.py @@ -12,7 +12,7 @@ import snowflake.connector from django.conf import settings from snowflake.connector.connection import SnowflakeConnection -from snowflake.connector.errors import OperationalError, InterfaceError +from snowflake.connector.errors import InterfaceError, OperationalError from temporalio import activity, workflow from temporalio.common import RetryPolicy @@ -31,32 +31,43 @@ default_fields, execute_batch_export_insert_activity, get_data_interval, - iter_model_records, start_batch_export_run, ) -from posthog.temporal.batch_exports.metrics import ( - get_bytes_exported_metric, - get_rows_exported_metric, +from posthog.temporal.batch_exports.heartbeat import ( + BatchExportRangeHeartbeatDetails, + DateRange, + should_resume_from_activity_heartbeat, +) +from posthog.temporal.batch_exports.spmc import ( + Consumer, + Producer, + RecordBatchQueue, + run_consumer, + wait_for_schema_or_producer, ) from posthog.temporal.batch_exports.temporary_file import ( BatchExportTemporaryFile, - JSONLBatchExportWriter, + WriterFormat, ) from posthog.temporal.batch_exports.utils import ( JsonType, - apeek_first_and_rewind, - cast_record_batch_json_columns, set_status_to_running_task, ) from posthog.temporal.common.clickhouse import get_client from posthog.temporal.common.heartbeat import Heartbeater from posthog.temporal.common.logger import bind_temporal_worker_logger -from posthog.temporal.batch_exports.heartbeat import ( - BatchExportRangeHeartbeatDetails, - DateRange, - HeartbeatParseError, - should_resume_from_activity_heartbeat, -) + +NON_RETRYABLE_ERROR_TYPES = [ + # Raised when we cannot connect to Snowflake. + "DatabaseError", + # Raised by Snowflake when a query cannot be compiled. + # Usually this means we don't have table permissions or something doesn't exist (db, schema). + "ProgrammingError", + # Raised by Snowflake with an incorrect account name. + "ForbiddenError", + # Our own exception when we can't connect to Snowflake, usually due to invalid parameters. + "SnowflakeConnectionError", +] class SnowflakeFileNotUploadedError(Exception): @@ -91,37 +102,9 @@ class SnowflakeRetryableConnectionError(Exception): @dataclasses.dataclass class SnowflakeHeartbeatDetails(BatchExportRangeHeartbeatDetails): - """The Snowflake batch export details included in every heartbeat. - - Attributes: - file_no: The file number of the last file we managed to upload. - """ - - file_no: int = 0 - - @classmethod - def deserialize_details(cls, details: collections.abc.Sequence[typing.Any]) -> dict[str, typing.Any]: - """Attempt to initialize HeartbeatDetails from an activity's details.""" - file_no = 0 - remaining = super().deserialize_details(details) - - if len(remaining["_remaining"]) == 0: - return {"file_no": 0, **remaining} - - first_detail = remaining["_remaining"][0] - remaining["_remaining"] = remaining["_remaining"][1:] - - try: - file_no = int(first_detail) - except (TypeError, ValueError) as e: - raise HeartbeatParseError("file_no") from e - - return {"file_no": file_no, **remaining} + """The Snowflake batch export details included in every heartbeat.""" - def serialize_details(self) -> tuple[typing.Any, ...]: - """Attempt to initialize HeartbeatDetails from an activity's details.""" - serialized_parent_details = super().serialize_details() - return (*serialized_parent_details[:-1], self.file_no, self._remaining) + pass @dataclasses.dataclass @@ -344,7 +327,6 @@ async def put_file_to_snowflake_table( file: BatchExportTemporaryFile, table_stage_prefix: str, table_name: str, - file_no: int, ): """Executes a PUT query using the provided cursor to the provided table_name. @@ -352,14 +334,9 @@ async def put_file_to_snowflake_table( call to run_in_executor: Since execute ends up boiling down to blocking IO (HTTP request), the event loop should not be locked up. - We add a file_no to the file_name when executing PUT as Snowflake will reject any files with the same - name. Since batch exports re-use the same file, our name does not change, but we don't want Snowflake - to reject or overwrite our new data. - Args: file: The name of the local file to PUT. table_name: The name of the Snowflake table where to PUT the file. - file_no: An int to identify which file number this is. Raises: TypeError: If we don't get a tuple back from Snowflake (should never happen). @@ -371,7 +348,7 @@ async def put_file_to_snowflake_table( # So we ask mypy to be nice with us. reader = io.BufferedReader(file) # type: ignore query = f""" - PUT file://{file.name}_{file_no}.jsonl '@%"{table_name}"/{table_stage_prefix}' + PUT file://{file.name} '@%"{table_name}"/{table_stage_prefix}' """ with self.connection.cursor() as cursor: @@ -518,6 +495,60 @@ def snowflake_default_fields() -> list[BatchExportField]: return batch_export_fields +class SnowflakeConsumer(Consumer): + def __init__( + self, + heartbeater: Heartbeater, + heartbeat_details: SnowflakeHeartbeatDetails, + data_interval_start: dt.datetime | str | None, + data_interval_end: dt.datetime | str, + writer_format: WriterFormat, + snowflake_client: SnowflakeClient, + snowflake_table: str, + snowflake_table_stage_prefix: str, + ): + super().__init__( + heartbeater=heartbeater, + heartbeat_details=heartbeat_details, + data_interval_start=data_interval_start, + data_interval_end=data_interval_end, + writer_format=writer_format, + ) + self.heartbeat_details: SnowflakeHeartbeatDetails = heartbeat_details + self.snowflake_table = snowflake_table + self.snowflake_client = snowflake_client + self.snowflake_table_stage_prefix = snowflake_table_stage_prefix + + async def flush( + self, + batch_export_file: BatchExportTemporaryFile, + records_since_last_flush: int, + bytes_since_last_flush: int, + flush_counter: int, + last_date_range: DateRange, + is_last: bool, + error: Exception | None, + ): + await self.logger.ainfo( + "Putting file %s containing %s records with size %s bytes", + flush_counter, + records_since_last_flush, + bytes_since_last_flush, + ) + + await self.snowflake_client.put_file_to_snowflake_table( + batch_export_file, + self.snowflake_table_stage_prefix, + self.snowflake_table, + ) + + await self.logger.adebug("Loaded %s to Snowflake table '%s'", records_since_last_flush, self.snowflake_table) + self.rows_exported_counter.add(records_since_last_flush) + self.bytes_exported_counter.add(bytes_since_last_flush) + + self.heartbeat_details.track_done_range(last_date_range, self.data_interval_start) + + def get_snowflake_fields_from_record_schema( record_schema: pa.Schema, known_variant_columns: list[str] ) -> list[SnowflakeField]: @@ -594,42 +625,63 @@ async def insert_into_snowflake_activity(inputs: SnowflakeInsertInputs) -> Recor details = SnowflakeHeartbeatDetails() done_ranges: list[DateRange] = details.done_ranges - if done_ranges: - data_interval_start: str | None = done_ranges[-1][1].isoformat() - else: - data_interval_start = inputs.data_interval_start - - current_flush_counter = details.file_no - - rows_exported = get_rows_exported_metric() - bytes_exported = get_bytes_exported_metric() model: BatchExportModel | BatchExportSchema | None = None if inputs.batch_export_schema is None and "batch_export_model" in { field.name for field in dataclasses.fields(inputs) }: model = inputs.batch_export_model + if model is not None: + model_name = model.name + extra_query_parameters = model.schema["values"] if model.schema is not None else None + fields = model.schema["fields"] if model.schema is not None else None + else: + model_name = "events" + extra_query_parameters = None + fields = None else: model = inputs.batch_export_schema + model_name = "custom" + extra_query_parameters = model["values"] if model is not None else {} + fields = model["fields"] if model is not None else None - records_iterator = iter_model_records( - client=client, - model=model, + data_interval_start = ( + dt.datetime.fromisoformat(inputs.data_interval_start) if inputs.data_interval_start else None + ) + data_interval_end = dt.datetime.fromisoformat(inputs.data_interval_end) + full_range = (data_interval_start, data_interval_end) + + queue = RecordBatchQueue(max_size_bytes=settings.BATCH_EXPORT_SNOWFLAKE_RECORD_BATCH_QUEUE_MAX_SIZE_BYTES) + producer = Producer(clickhouse_client=client) + producer_task = producer.start( + queue=queue, + model_name=model_name, + is_backfill=inputs.is_backfill, team_id=inputs.team_id, - interval_start=data_interval_start, - interval_end=inputs.data_interval_end, + full_range=full_range, + done_ranges=done_ranges, + fields=fields, + destination_default_fields=snowflake_default_fields(), exclude_events=inputs.exclude_events, include_events=inputs.include_events, - destination_default_fields=snowflake_default_fields(), - is_backfill=inputs.is_backfill, + extra_query_parameters=extra_query_parameters, + ) + records_completed = 0 + + record_batch_schema = await wait_for_schema_or_producer(queue, producer_task) + if record_batch_schema is None: + return records_completed + + record_batch_schema = pa.schema( + # NOTE: For some reason, some batches set non-nullable fields as non-nullable, whereas other + # record batches have them as nullable. + # Until we figure it out, we set all fields to nullable. There are some fields we know + # are not nullable, but I'm opting for the more flexible option until we out why schemas differ + # between batches. + [field.with_nullable(True) for field in record_batch_schema if field.name != "_inserted_at"] ) - first_record_batch, records_iterator = await apeek_first_and_rewind(records_iterator) - - if first_record_batch is None: - return 0 known_variant_columns = ["properties", "people_set", "people_set_once", "person_properties"] - first_record_batch = cast_record_batch_json_columns(first_record_batch, json_columns=known_variant_columns) if model is None or (isinstance(model, BatchExportModel) and model.name == "events"): table_fields = [ @@ -647,10 +699,8 @@ async def insert_into_snowflake_activity(inputs: SnowflakeInsertInputs) -> Recor ] else: - column_names = [column for column in first_record_batch.schema.names if column != "_inserted_at"] - record_schema = first_record_batch.select(column_names).schema table_fields = get_snowflake_fields_from_record_schema( - record_schema, + record_batch_schema, known_variant_columns=known_variant_columns, ) @@ -671,57 +721,30 @@ async def insert_into_snowflake_activity(inputs: SnowflakeInsertInputs) -> Recor stagle_table_name, data_interval_end_str, table_fields, create=requires_merge, delete=requires_merge ) as snow_stage_table, ): - record_columns = [field[0] for field in table_fields] - record_schema = pa.schema( - [field.with_nullable(True) for field in first_record_batch.select(record_columns).schema] + consumer = SnowflakeConsumer( + heartbeater=heartbeater, + heartbeat_details=details, + data_interval_end=data_interval_end, + data_interval_start=data_interval_start, + writer_format=WriterFormat.JSONL, + snowflake_client=snow_client, + snowflake_table=snow_stage_table if requires_merge else snow_table, + snowflake_table_stage_prefix=data_interval_end_str, ) - - async def flush_to_snowflake( - local_results_file, - records_since_last_flush, - bytes_since_last_flush, - flush_counter: int, - last_date_range: DateRange, - last: bool, - error: Exception | None, - ): - logger.info( - "Putting %sfile %s containing %s records with size %s bytes", - "last " if last else "", - flush_counter, - records_since_last_flush, - bytes_since_last_flush, - ) - - table = snow_stage_table if requires_merge else snow_table - - await snow_client.put_file_to_snowflake_table( - local_results_file, data_interval_end_str, table, flush_counter - ) - rows_exported.add(records_since_last_flush) - bytes_exported.add(bytes_since_last_flush) - - details.track_done_range(last_date_range, data_interval_start) - details.file_no = flush_counter - heartbeater.set_from_heartbeat_details(details) - - writer = JSONLBatchExportWriter( + records_completed = await run_consumer( + consumer=consumer, + queue=queue, + producer_task=producer_task, + schema=record_batch_schema, max_bytes=settings.BATCH_EXPORT_SNOWFLAKE_UPLOAD_CHUNK_SIZE_BYTES, - flush_callable=flush_to_snowflake, + json_columns=known_variant_columns, + multiple_files=True, ) - async with writer.open_temporary_file(current_flush_counter): - async for record_batch in records_iterator: - record_batch = cast_record_batch_json_columns(record_batch, json_columns=known_variant_columns) - - await writer.write_record_batch(record_batch) - - details.complete_done_ranges(inputs.data_interval_end) - heartbeater.set_from_heartbeat_details(details) - await snow_client.copy_loaded_files_to_snowflake_table( snow_stage_table if requires_merge else snow_table, data_interval_end_str ) + if requires_merge: merge_key = ( ("team_id", "INT64"), @@ -734,7 +757,7 @@ async def flush_to_snowflake( merge_key=merge_key, ) - return writer.records_total + return records_completed @workflow.defn(name="snowflake-export", failure_exception_types=[workflow.NondeterminismError]) @@ -811,16 +834,6 @@ async def run(self, inputs: SnowflakeBatchExportInputs): insert_into_snowflake_activity, insert_inputs, interval=inputs.interval, - non_retryable_error_types=[ - # Raised when we cannot connect to Snowflake. - "DatabaseError", - # Raised by Snowflake when a query cannot be compiled. - # Usually this means we don't have table permissions or something doesn't exist (db, schema). - "ProgrammingError", - # Raised by Snowflake with an incorrect account name. - "ForbiddenError", - # Our own exception when we can't connect to Snowflake, usually due to invalid parameters. - "SnowflakeConnectionError", - ], + non_retryable_error_types=NON_RETRYABLE_ERROR_TYPES, finish_inputs=finish_inputs, ) diff --git a/posthog/temporal/tests/batch_exports/test_snowflake_batch_export_workflow.py b/posthog/temporal/tests/batch_exports/test_snowflake_batch_export_workflow.py index e99ef3f1ca350..a08e49357e9ee 100644 --- a/posthog/temporal/tests/batch_exports/test_snowflake_batch_export_workflow.py +++ b/posthog/temporal/tests/batch_exports/test_snowflake_batch_export_workflow.py @@ -467,7 +467,6 @@ async def test_snowflake_export_workflow_exports_events( ] assert all(query.startswith("PUT") for query in execute_calls[0:9]) - assert all(f"_{n}.jsonl" in query for n, query in enumerate(execute_calls[0:9])) assert execute_async_calls[3].startswith(f'CREATE TABLE IF NOT EXISTS "{table_name}"') assert execute_async_calls[4].startswith(f"""REMOVE '@%"{table_name}"/{data_interval_end_str}'""") @@ -1656,8 +1655,3 @@ def __init__(self): dt.datetime.fromisoformat(expected_done_ranges[0][1]), ) ] - - if len(details) >= 2: - assert snowflake_details.file_no == details[1] - else: - assert snowflake_details.file_no == 0