chore: refactor delete dag #27661

Open · wants to merge 14 commits into master
163 changes: 65 additions & 98 deletions dags/deletes.py
@@ -1,6 +1,5 @@
import os
from datetime import datetime
import pandas as pd
from clickhouse_driver.client import Client

from dagster import (
asset,
@@ -22,7 +21,6 @@

class DeleteConfig(Config):
team_id: int | None = None
file_path: str = "/tmp/pending_person_deletions.parquet"
run_id: str = datetime.now().strftime("%Y%m%d_%H%M%S")


@@ -31,63 +29,6 @@ def get_versioned_names(run_id: str) -> dict[str, str]:
return {"table": f"pending_person_deletes_{run_id}", "dictionary": f"pending_person_deletes_dictionary_{run_id}"}


@asset
def pending_person_deletions(context: AssetExecutionContext, config: DeleteConfig) -> dict[str, str]:
"""Query postgres using django ORM to get pending person deletions and write to parquet."""

if not config.team_id:
# Use Django's queryset iterator for memory efficiency
pending_deletions = (
AsyncDeletion.objects.filter(deletion_type=DeletionType.Person, delete_verified_at__isnull=True)
.values("team_id", "key", "created_at")
.iterator()
)
else:
pending_deletions = AsyncDeletion.objects.filter(
deletion_type=DeletionType.Person,
team_id=config.team_id,
delete_verified_at__isnull=True,
).values("team_id", "key", "created_at")

# Create a temporary directory for our parquet file
output_path = config.file_path

# Write to parquet in chunks
chunk_size = 10000
current_chunk = []
total_rows = 0

for deletion in pending_deletions:
current_chunk.append(deletion)
if len(current_chunk) >= chunk_size:
if total_rows == 0:
# First chunk, create new file
pd.DataFrame(current_chunk).to_parquet(output_path, index=False)
else:
# Append to existing file
pd.DataFrame(current_chunk).to_parquet(output_path, index=False, append=True)
total_rows += len(current_chunk)
current_chunk = []

# Write any remaining records
if current_chunk:
if total_rows == 0:
pd.DataFrame(current_chunk).to_parquet(output_path, index=False)
else:
pd.DataFrame(current_chunk).to_parquet(output_path, index=False, append=True)
total_rows += len(current_chunk)

context.add_output_metadata(
{
"total_rows": MetadataValue.int(total_rows),
"file_path": MetadataValue.text(output_path),
"file_size": MetadataValue.int(os.path.getsize(output_path)),
}
)

return {"file_path": output_path, "total_rows": str(total_rows)}
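A side note on the removed asset above: pandas' DataFrame.to_parquet passes extra keyword arguments straight to the underlying engine, and the default pyarrow engine has no append option, so the append=True calls would most likely raise a TypeError on the second chunk. Incremental Parquet output is usually done with pyarrow's ParquetWriter instead. A minimal illustrative sketch of that pattern (hypothetical helper, not part of this PR):

import pyarrow as pa
import pyarrow.parquet as pq


def write_chunks_to_parquet(chunks, output_path):
    """Stream an iterable of row-dict chunks into a single Parquet file."""
    writer = None
    try:
        for chunk in chunks:
            table = pa.Table.from_pylist(chunk)
            if writer is None:
                # Create the writer lazily from the first chunk's schema.
                writer = pq.ParquetWriter(output_path, table.schema)
            writer.write_table(table)
    finally:
        if writer is not None:
            writer.close()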


@asset
def create_pending_deletes_table(context: AssetExecutionContext, config: DeleteConfig):
"""Create a merge tree table in ClickHouse to store pending deletes."""
@@ -108,50 +49,81 @@ def create_pending_deletes_table(context: AssetExecutionContext, config: DeleteConfig):
return {"table_name": names["table"]}
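The CREATE TABLE statement itself is collapsed in this diff. Going only by the column names used in the insert below (team_id, person_id, created_at) and the versioned-name helper, the staging table could look roughly like the following; the engine, column types, and ordering key are assumptions rather than what the PR actually contains:

# Illustrative sketch only: engine, types, and ORDER BY are assumptions.
# person_id might equally be stored as a String; created_at's default mirrors the insert skipping it.
from dags.deletes import get_versioned_names
from posthog.clickhouse.client import sync_execute  # assumed import path for sync_execute

names = get_versioned_names("example_run")
sync_execute(
    f"""
    CREATE TABLE IF NOT EXISTS {names["table"]}
    (
        team_id Int64,
        person_id UUID,
        created_at DateTime DEFAULT now()
    )
    ENGINE = MergeTree()
    ORDER BY (team_id, person_id)
    """
)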


@asset(deps=[pending_person_deletions, create_pending_deletes_table])
def insert_pending_deletes(context: AssetExecutionContext, pending_person_deletions, create_pending_deletes_table):
"""Insert pending deletes from parquet file into ClickHouse merge tree using Arrow."""
if not pending_person_deletions.get("total_rows", 0):
return 0

import pyarrow.parquet as pq
from clickhouse_driver.client import Client
@asset(deps=[create_pending_deletes_table])
def pending_person_deletions(context: AssetExecutionContext, config: DeleteConfig, create_pending_deletes_table) -> int:
"""Query postgres using django ORM to get pending person deletions and insert directly into ClickHouse."""

# Read the parquet file into an Arrow table
table = pq.read_table(pending_person_deletions["file_path"])

# Rename the 'key' column to 'person_id' to match our schema
table = table.rename_columns(["team_id", "person_id", "created_at"])
if not config.team_id:
# Use Django's queryset iterator for memory efficiency
pending_deletions = (
AsyncDeletion.objects.filter(deletion_type=DeletionType.Person, delete_verified_at__isnull=True)
.values("team_id", "key", "created_at")
.iterator()
)
else:
pending_deletions = (
AsyncDeletion.objects.filter(
deletion_type=DeletionType.Person,
team_id=config.team_id,
delete_verified_at__isnull=True,
)
.values("team_id", "key", "created_at")
.iterator()
)

# Create a ClickHouse client that supports Arrow
# Create a ClickHouse client
client = Client(
host=CLICKHOUSE_HOST,
user=CLICKHOUSE_USER,
password=CLICKHOUSE_PASSWORD,
secure=CLICKHOUSE_SECURE,
settings={"use_numpy": True}, # Required for Arrow support
)

# Insert the Arrow table directly
client.execute(
f"""
INSERT INTO {create_pending_deletes_table["table_name"]} (team_id, person_id, created_at)
VALUES
""",
table.to_pydict(),
types_check=True,
settings={
"input_format_arrow_skip_columns": ["created_at"], # Skip created_at as it's a default value
},
)
# Process and insert in chunks
chunk_size = 10000
current_chunk = []
total_rows = 0

for deletion in pending_deletions:
# Rename 'key' to 'person_id' to match our schema
current_chunk.append(
{"team_id": deletion["team_id"], "person_id": deletion["key"], "created_at": deletion["created_at"]}
)

if len(current_chunk) >= chunk_size:
client.execute(
f"""
INSERT INTO {create_pending_deletes_table["table_name"]} (team_id, person_id, created_at)
VALUES
""",
current_chunk,
)
total_rows += len(current_chunk)
current_chunk = []

# Insert any remaining records
if current_chunk:
client.execute(
f"""
INSERT INTO {create_pending_deletes_table["table_name"]} (team_id, person_id, created_at)
VALUES
""",
current_chunk,
)
total_rows += len(current_chunk)

context.add_output_metadata({"num_deletions": MetadataValue.int(pending_person_deletions["total_rows"])})
context.add_output_metadata(
{
"total_rows": MetadataValue.int(total_rows),
"table_name": MetadataValue.text(create_pending_deletes_table["table_name"]),
}
)

return pending_person_deletions["total_rows"]
return total_rows
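For local experimentation, the refactored assets can be materialized in-process with Dagster's materialize helper. A minimal sketch, assuming the module path dags.deletes from this PR and example config values; every asset has to receive the same run_id so the versioned table and dictionary names line up:

from dagster import materialize

from dags.deletes import (
    cleanup_delete_assets,
    create_pending_deletes_dictionary,
    create_pending_deletes_table,
    delete_person_events,
    pending_person_deletions,
)

# Example values only; team_id=None would process deletions for all teams.
config = {"config": {"team_id": 2, "run_id": "manual_test"}}

result = materialize(
    [
        create_pending_deletes_table,
        pending_person_deletions,
        create_pending_deletes_dictionary,
        delete_person_events,
        cleanup_delete_assets,
    ],
    run_config={
        "ops": {
            # Each asset takes its own DeleteConfig; reuse the same run_id everywhere.
            "create_pending_deletes_table": config,
            "pending_person_deletions": config,
            "create_pending_deletes_dictionary": config,
            "delete_person_events": config,
            "cleanup_delete_assets": config,
        }
    },
)
assert result.success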


@asset(deps=[insert_pending_deletes])
def create_pending_deletes_dictionary(context: AssetExecutionContext, config: DeleteConfig):
@asset(deps=[pending_person_deletions])
def create_pending_deletes_dictionary(context: AssetExecutionContext, config: DeleteConfig, pending_person_deletions):
"""Create a dictionary table that wraps pending_person_deletes for efficient lookups."""
names = get_versioned_names(config.run_id)
sync_execute(
@@ -176,7 +148,7 @@ def create_pending_deletes_dictionary(context: AssetExecutionContext, config: DeleteConfig):
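The dictionary DDL is collapsed at this point in the diff. Based on the staging table sketched above and the assertions in the tests below (CREATE DICTIONARY IF NOT EXISTS ... and TABLE ...), the wrapper might look roughly like this; layout, lifetime, and source settings are assumptions, and the imports are the same assumed ones as in the earlier sketch:

# Illustrative sketch only: layout, lifetime, and source settings are assumptions.
from dags.deletes import get_versioned_names
from posthog.clickhouse.client import sync_execute  # assumed import path

names = get_versioned_names("example_run")
sync_execute(
    f"""
    CREATE DICTIONARY IF NOT EXISTS {names["dictionary"]}
    (
        team_id Int64,
        person_id UUID,
        created_at DateTime
    )
    PRIMARY KEY team_id, person_id
    SOURCE(CLICKHOUSE(TABLE '{names["table"]}'))
    LAYOUT(COMPLEX_KEY_HASHED())
    LIFETIME(0)
    """
)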


@asset(deps=[create_pending_deletes_dictionary])
def delete_person_events(context: AssetExecutionContext, config: DeleteConfig):
def delete_person_events(context: AssetExecutionContext, config: DeleteConfig, create_pending_deletes_dictionary):
"""Delete events from sharded_events table for persons pending deletion."""

# First check if there are any pending deletes
@@ -219,7 +191,7 @@ def delete_person_events(context: AssetExecutionContext, config: DeleteConfig):
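The body of the delete is collapsed as well. The usual ClickHouse pattern for this kind of dictionary-driven cleanup is an ALTER TABLE ... DELETE mutation guarded by dictHas; a rough sketch of that pattern follows, though the statement actually used in the PR, the exact key expression, and any ON CLUSTER or mutations_sync settings may differ:

# Illustrative sketch only: the real statement and key expression may differ from the PR.
from dags.deletes import get_versioned_names
from posthog.clickhouse.client import sync_execute  # assumed import path

names = get_versioned_names("example_run")
sync_execute(
    f"""
    ALTER TABLE sharded_events
    DELETE WHERE dictHas('{names["dictionary"]}', (team_id, person_id))
    """
)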


@asset(deps=[delete_person_events])
def cleanup_delete_assets(context: AssetExecutionContext, config: DeleteConfig):
def cleanup_delete_assets(context: AssetExecutionContext, config: DeleteConfig, delete_person_events):
"""Clean up temporary tables, dictionary, and mark deletions as verified."""
names = get_versioned_names(config.run_id)

@@ -237,11 +209,6 @@ def cleanup_delete_assets(context: AssetExecutionContext, config: DeleteConfig):
"""
)

# Remove the temporary parquet file
parquet_path = "/tmp/pending_person_deletions.parquet"
if os.path.exists(parquet_path):
os.remove(parquet_path)

# Mark deletions as verified in Django
if not config.team_id:
AsyncDeletion.objects.filter(deletion_type=DeletionType.Person, delete_verified_at__isnull=True).update(
75 changes: 48 additions & 27 deletions dags/tests/test_deletes.py
@@ -1,7 +1,5 @@
import os
import uuid
import pytest
import pandas as pd
from unittest.mock import patch, MagicMock
from dagster import build_asset_context

@@ -22,20 +20,28 @@ def mock_async_deletion():

@pytest.fixture
def test_config():
return DeleteConfig(team_id=1, file_path="/tmp/test_pending_deletions.parquet", run_id="test_run")
return DeleteConfig(team_id=1, run_id="test_run")


@pytest.fixture
def test_config_no_team():
return DeleteConfig(file_path="/tmp/test_pending_deletions.parquet", run_id="test_run")
return DeleteConfig(run_id="test_run")


@pytest.fixture
def expected_names():
return get_versioned_names("test_run")


def test_pending_person_deletions_with_team_id():
@pytest.fixture
def mock_clickhouse_client():
with patch("dags.deletes.Client") as mock_client:
mock_instance = MagicMock()
mock_client.return_value = mock_instance
yield mock_instance


def test_pending_person_deletions_with_team_id(mock_clickhouse_client, expected_names):
# Setup test data
mock_deletions = [
{"team_id": 1, "key": str(uuid.uuid4()), "created_at": "2025-01-15T00:00:00Z"},
@@ -44,24 +50,33 @@ def test_pending_person_deletions_with_team_id():

with patch("dags.deletes.AsyncDeletion.objects") as mock_objects:
mock_filter = MagicMock()
mock_filter.values.return_value = mock_deletions
mock_filter.values.return_value.iterator.return_value = mock_deletions
mock_objects.filter.return_value = mock_filter

context = build_asset_context()
config = DeleteConfig(team_id=1, file_path="/tmp/test_pending_deletions.parquet", run_id="test_run")
config = DeleteConfig(team_id=1, run_id="test_run")
table_info = {"table_name": expected_names["table"]}

result = pending_person_deletions(context, config, table_info)

result = pending_person_deletions(context, config)
assert result == 2

assert result["total_rows"] == "2"
assert result["file_path"] == "/tmp/test_pending_deletions.parquet"
# Verify ClickHouse client was called with correct data
expected_data = [
{"team_id": deletion["team_id"], "person_id": deletion["key"], "created_at": deletion["created_at"]}
for deletion in mock_deletions
]

# Verify the parquet file was created with correct data
df = pd.read_parquet("/tmp/test_pending_deletions.parquet")
assert len(df) == 2
assert list(df.columns) == ["team_id", "key", "created_at"]
mock_clickhouse_client.execute.assert_called_once_with(
f"""
INSERT INTO {expected_names["table"]} (team_id, person_id, created_at)
VALUES
""",
expected_data,
)

Review comment: logic: SQL query is missing VALUES clause formatting - could cause syntax error in ClickHouse. Should include parentheses around values.
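On the review note above: clickhouse-driver's documented bulk-insert style is to end the query text at VALUES and pass the rows as a separate argument, which the driver then encodes itself, so parentheses are not normally written into the query string. A minimal standalone sketch of that pattern (connection details and table name are placeholders):

import uuid
from datetime import datetime

from clickhouse_driver import Client

client = Client(host="localhost")  # placeholder connection details

rows = [
    {"team_id": 1, "person_id": uuid.uuid4(), "created_at": datetime(2025, 1, 15)},
]

# The query ends at VALUES; clickhouse-driver serializes `rows` itself.
client.execute(
    "INSERT INTO pending_person_deletes_example (team_id, person_id, created_at) VALUES",
    rows,
)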


def test_pending_person_deletions_without_team_id(test_config_no_team):
def test_pending_person_deletions_without_team_id(mock_clickhouse_client, test_config_no_team, expected_names):
# Setup test data
mock_deletions = [
{"team_id": 1, "key": str(uuid.uuid4()), "created_at": "2025-01-15T00:00:00Z"},
@@ -74,11 +89,25 @@ def test_pending_person_deletions_without_team_id(test_config_no_team):
mock_objects.filter.return_value = mock_filter

context = build_asset_context()
table_info = {"table_name": expected_names["table"]}

result = pending_person_deletions(context, test_config_no_team)
result = pending_person_deletions(context, test_config_no_team, table_info)

assert result["total_rows"] == "2"
assert result["file_path"] == "/tmp/test_pending_deletions.parquet"
assert result == 2

# Verify ClickHouse client was called with correct data
expected_data = [
{"team_id": deletion["team_id"], "person_id": deletion["key"], "created_at": deletion["created_at"]}
for deletion in mock_deletions
]

mock_clickhouse_client.execute.assert_called_once_with(
f"""
INSERT INTO {expected_names["table"]} (team_id, person_id, created_at)
VALUES
""",
expected_data,
)


@patch("dags.deletes.sync_execute")
@@ -96,19 +125,11 @@ def test_create_pending_deletes_table(mock_sync_execute, test_config, expected_names):

@patch("dags.deletes.sync_execute")
def test_create_pending_deletes_dictionary(mock_sync_execute, test_config, expected_names):
result = create_pending_deletes_dictionary(build_asset_context(), test_config)
result = create_pending_deletes_dictionary(build_asset_context(), test_config, 2) # Pass mock total rows

assert result["dictionary_name"] == expected_names["dictionary"]
mock_sync_execute.assert_called_once()
# Verify the SQL contains the expected dictionary creation
call_args = mock_sync_execute.call_args[0][0]
assert f"CREATE DICTIONARY IF NOT EXISTS {expected_names['dictionary']}" in call_args
assert f"TABLE {expected_names['table']}" in call_args


def teardown_module(module):
# Clean up test files
try:
os.remove("/tmp/test_pending_deletions.parquet")
except FileNotFoundError:
pass