Skip to content

Commit

Permalink
[airlift] Add migration state tests to Airlift tutorial (#23993)
Browse files Browse the repository at this point in the history
## Summary

Adds an integration test which tries loading the tutorial project at
each stage of migration & validates that Airflow & Dagster show the
correct status.

## Changelog [New]

`NOCHANGELOG`
  • Loading branch information
benpankow authored Aug 28, 2024
1 parent be8463a commit ef1da09
Show file tree
Hide file tree
Showing 4 changed files with 178 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import time
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Any, Generator
from typing import Any, Callable, Generator

import pytest
import requests
Expand Down Expand Up @@ -47,6 +47,16 @@ def setup_fixture(airflow_home: Path, dags_dir: Path) -> Generator[Path, None, N
yield airflow_home


@pytest.fixture(name="reserialize_dags")
def reserialize_fixture(airflow_instance: None) -> Callable[[], None]:
"""Forces airflow to reserialize dags, to ensure that the latest changes are picked up."""

def _reserialize_dags() -> None:
subprocess.check_output(["airflow", "dags", "reserialize"])

return _reserialize_dags


@pytest.fixture(name="airflow_instance")
def airflow_instance_fixture(setup: None) -> Generator[subprocess.Popen, None, None]:
process = subprocess.Popen(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,5 +36,6 @@ airflow_setup:
airflow_run:
airflow standalone


dagster_run:
dagster dev -m tutorial_example.dagster_defs.definitions -p 3000
Original file line number Diff line number Diff line change
@@ -1,15 +1,23 @@
import contextlib
import os
import shutil
import subprocess
import tempfile
from pathlib import Path
from typing import Generator
from typing import AbstractSet, Callable, Generator, Iterator

import pytest
import yaml
from dagster._core.test_utils import environ


@pytest.fixture(name="makefile_dir")
def makefile_dir_fixture() -> Path:
return Path(__file__).parent.parent.parent


@pytest.fixture(name="local_env")
def local_env_fixture() -> Generator[None, None, None]:
makefile_dir = Path(__file__).parent.parent.parent
def local_env_fixture(makefile_dir: Path) -> Generator[None, None, None]:
subprocess.run(["make", "airflow_setup"], cwd=makefile_dir, check=True)
with environ(
{
Expand All @@ -23,10 +31,56 @@ def local_env_fixture() -> Generator[None, None, None]:


@pytest.fixture(name="dags_dir")
def dags_dir_fixture() -> Path:
return Path(__file__).parent.parent.parent / "tutorial_example" / "airflow_dags"
def dags_dir_fixture(makefile_dir: Path) -> Iterator[Path]:
# Creates a temporary directory and copies the dags into it
# So we can manipulate the migration state without affecting the original files
with tempfile.TemporaryDirectory() as tmpdir:
shutil.copytree(
makefile_dir / "tutorial_example" / "airflow_dags", tmpdir, dirs_exist_ok=True
)
yield Path(tmpdir)


@pytest.fixture(name="airflow_home")
def airflow_home_fixture(local_env) -> Path:
return Path(os.environ["AIRFLOW_HOME"])


@pytest.fixture(name="mark_tasks_migrated")
def mark_tasks_migrated_fixture(
dags_dir: Path,
reserialize_dags: Callable[[], None],
) -> Callable[[AbstractSet[str]], contextlib.AbstractContextManager[None]]:
"""Returns a context manager that marks the specified tasks as migrated in the migration state file
for the duration of the context manager's scope.
"""
migration_state_file = dags_dir / "migration_state" / "rebuild_customers_list.yaml"
all_tasks = {"load_raw_customers", "build_dbt_models", "export_customers"}

@contextlib.contextmanager
def mark_tasks_migrated(migrated_tasks: AbstractSet[str]) -> Iterator[None]:
"""Updates the contents of the migration state file to mark the specified tasks as migrated."""
with open(migration_state_file, "r") as f:
contents = f.read()

try:
with open(migration_state_file, "w") as f:
f.write(
yaml.dump(
{
"tasks": [
{"id": task, "migrated": task in migrated_tasks}
for task in all_tasks
]
}
)
)

reserialize_dags()
yield

finally:
with open(migration_state_file, "w") as f:
f.write(contents)

return mark_tasks_migrated
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
import contextlib
import importlib
from typing import AbstractSet, Callable, Optional

from dagster._core.definitions.asset_spec import AssetSpec
from dagster_airlift.core import AirflowInstance, BasicAuthBackend
from dagster_airlift.core.utils import MIGRATED_TAG, TASK_ID_TAG


def _assert_dagster_migration_states_are(
state: bool, where: Optional[Callable[[AssetSpec], bool]] = None
) -> None:
"""Loads the Dagster asset definitions and checks that all asset specs have the correct migration state.
This is a helper function so that we can call this as many times as we need to - if there are dangling references
to any of the imports from the module, the importlib.reload won't work properly.
Args:
state: The expected migration state.
where: A function that takes an AssetSpec and returns True if the spec should be checked, False otherwise.
"""
import tutorial_example
from tutorial_example.dagster_defs import definitions
from tutorial_example.dagster_defs.definitions import defs

importlib.reload(tutorial_example)
importlib.reload(definitions)

assert defs
specs = defs.get_all_asset_specs()
spec_migration_states = {
spec.key: spec.tags.get(MIGRATED_TAG) for spec in specs if not where or where(spec)
}

assert all(
value == str(state)
for key, value in spec_migration_states.items()
if key.path[0] != "airflow_instance" # ignore overall dag, which doesn't have tag
), str(spec_migration_states)


def test_migration_status(
airflow_instance,
mark_tasks_migrated: Callable[[AbstractSet[str]], contextlib.AbstractContextManager],
) -> None:
"""Iterates through various combinations of marking tasks as migrated and checks that the migration state is updated correctly in
both the Airflow DAGs and the Dagster asset definitions.
"""
instance = AirflowInstance(
auth_backend=BasicAuthBackend(
webserver_url="http://localhost:8080",
username="admin",
password="admin",
),
name="airflow_instance_one",
)

with mark_tasks_migrated(set()):
assert len(instance.list_dags()) == 1
dag = instance.list_dags()[0]
assert dag.dag_id == "rebuild_customers_list"
assert not dag.migration_state.is_task_migrated("load_raw_customers")
assert not dag.migration_state.is_task_migrated("build_dbt_models")
assert not dag.migration_state.is_task_migrated("export_customers")

_assert_dagster_migration_states_are(False)

with mark_tasks_migrated({"load_raw_customers"}):
assert len(instance.list_dags()) == 1
dag = instance.list_dags()[0]

assert dag.dag_id == "rebuild_customers_list"
assert dag.migration_state.is_task_migrated("load_raw_customers")
assert not dag.migration_state.is_task_migrated("build_dbt_models")
assert not dag.migration_state.is_task_migrated("export_customers")

_assert_dagster_migration_states_are(
True, where=lambda spec: spec.tags.get(TASK_ID_TAG) == "load_raw_customers"
)
_assert_dagster_migration_states_are(
False, where=lambda spec: spec.tags.get(TASK_ID_TAG) != "load_raw_customers"
)

with mark_tasks_migrated({"build_dbt_models"}):
assert len(instance.list_dags()) == 1
dag = instance.list_dags()[0]
assert dag.dag_id == "rebuild_customers_list"
assert not dag.migration_state.is_task_migrated("load_raw_customers")
assert dag.migration_state.is_task_migrated("build_dbt_models")
assert not dag.migration_state.is_task_migrated("export_customers")

_assert_dagster_migration_states_are(
True, where=lambda spec: spec.tags.get(TASK_ID_TAG) == "build_dbt_models"
)
_assert_dagster_migration_states_are(
False, where=lambda spec: spec.tags.get(TASK_ID_TAG) != "build_dbt_models"
)

with mark_tasks_migrated({"load_raw_customers", "build_dbt_models", "export_customers"}):
assert len(instance.list_dags()) == 1
dag = instance.list_dags()[0]
assert dag.dag_id == "rebuild_customers_list"
assert dag.migration_state.is_task_migrated("load_raw_customers")
assert dag.migration_state.is_task_migrated("build_dbt_models")
assert dag.migration_state.is_task_migrated("export_customers")

_assert_dagster_migration_states_are(True)

0 comments on commit ef1da09

Please sign in to comment.