From 160f014919374ca17edb2438e24ee502ccfcd700 Mon Sep 17 00:00:00 2001 From: wongwill86 Date: Thu, 24 Aug 2017 13:57:50 -0400 Subject: [PATCH] Create MultiTriggerDag Operator (#6) * Added patch module helper to allow patching of loaded modules (i.e. via import.load_source) * Adding multi trigger dag example. * Reorganized example dags. * Use pylama for testing. --- .env | 1 + README.md | 4 +- dags/chunkflow/{noop_dag.py => noop.py} | 0 dags/{many.py => examples/interleaved.py} | 12 +- dags/examples/multi_trigger.py | 54 +++++ dags/simple.py | 49 ---- docker/Dockerfile.test | 2 +- docker/config/airflow.cfg | 2 +- docker/docker-compose.test.yml | 2 +- plugins/custom/multi_trigger_dag.py | 77 +++++++ requirements.txt | 3 +- tests/__init__.py | 0 tests/plugins/__init__.py | 0 tests/plugins/chunkflow/__init__.py | 0 tests/plugins/custom/__init__.py | 0 .../plugins/custom/test_multi_trigger_dag.py | 215 ++++++++++++++++++ tests/utils/__init__.py | 0 tests/utils/mock_helpers.py | 39 ++++ 18 files changed, 398 insertions(+), 62 deletions(-) create mode 100644 .env rename dags/chunkflow/{noop_dag.py => noop.py} (100%) rename dags/{many.py => examples/interleaved.py} (87%) create mode 100644 dags/examples/multi_trigger.py delete mode 100644 dags/simple.py create mode 100644 plugins/custom/multi_trigger_dag.py create mode 100644 tests/__init__.py create mode 100644 tests/plugins/__init__.py create mode 100644 tests/plugins/chunkflow/__init__.py create mode 100644 tests/plugins/custom/__init__.py create mode 100644 tests/plugins/custom/test_multi_trigger_dag.py create mode 100644 tests/utils/__init__.py create mode 100644 tests/utils/mock_helpers.py diff --git a/.env b/.env new file mode 100644 index 00000000..fd13f06d --- /dev/null +++ b/.env @@ -0,0 +1 @@ +IMAGE_NAME=wongwill86/air-tasks:latest diff --git a/README.md b/README.md index 4a1f3551..dfb74dbc 100644 --- a/README.md +++ b/README.md @@ -5,9 +5,11 @@ DooD support and AWS ECR Credential Helper NOTES: Chunkflow: make 
from airflow import DAG
from datetime import datetime, timedelta
from airflow.operators.custom_plugin import MultiTriggerDagRunOperator
from airflow.operators.bash_operator import BashOperator

# Example pair of DAGs demonstrating MultiTriggerDagRunOperator: the
# scheduler DAG fans out one DagRun of the target DAG per generated value,
# and the target DAG echoes the value it received via dag_run.conf.
SCHEDULE_DAG_ID = 'example_multi_trigger_scheduler'
TARGET_DAG_ID = 'example_multi_trigger_target'

default_args = {
    'owner': 'airflow',
    'depends_on_past': False,
    'start_date': datetime(2017, 5, 1),
    # NOTE(review): misspelling of 'catchup_by_default' (repeated in the
    # sibling dags) -- Airflow ignores unknown default_args keys, so this
    # has no effect either way; kept for consistency with the other dags.
    'cactchup_by_default': False,
    'retries': 1,
    'retry_delay': timedelta(seconds=2),
    'retry_exponential_backoff': True,
}

# ####################### SCHEDULER #################################
scheduler_dag = DAG(
    dag_id=SCHEDULE_DAG_ID,
    default_args=default_args,
    schedule_interval=None
)


def param_generator():
    """Yield the per-run conf payloads (here: the ints 0..99)."""
    # range (not the Python-2-only xrange) keeps this example portable
    # across Python 2 and 3; 100 items is cheap to materialize anyway.
    for i in range(0, 100):
        yield i


operator = MultiTriggerDagRunOperator(
    task_id='trigger_%s' % TARGET_DAG_ID,
    trigger_dag_id=TARGET_DAG_ID,
    params_list=param_generator(),
    default_args=default_args,
    dag=scheduler_dag)

# ####################### TARGET DAG #################################

target_dag = DAG(
    dag_id=TARGET_DAG_ID,
    default_args=default_args,
    schedule_interval=None
)

start = BashOperator(
    task_id='bash_task',
    bash_command='sleep 1; echo "Hello from message #' +
                 '{{ dag_run.conf if dag_run else "NO MESSAGE" }}"',
    default_args=default_args,
    dag=target_dag
)
BashOperator( - task_id='print_date', - bash_command='date', - dag=dag) - -t2 = DockerOperator( - task_id='watershed_sleep', - image='watershed', - command='/bin/sleep 10', - network_mode='bridge', - dag=dag) - -t3 = BashOperator( - task_id='print_hello', - bash_command='echo "hello world!"', - dag=dag) - -t4 = BashOperator( - task_id='print_hello_2', - bash_command='echo "hello world 2!"', - dag=dag) - -t5 = BashOperator( - task_id='print_goodbye', - bash_command='echo "goodbye world!"', - dag=dag) - -t1.set_downstream(t2) -t2.set_downstream(t3) -t2.set_downstream(t4) -t3.set_downstream(t5) -t4.set_downstream(t5) diff --git a/docker/Dockerfile.test b/docker/Dockerfile.test index 8c15b8d4..d5c4aa49 100644 --- a/docker/Dockerfile.test +++ b/docker/Dockerfile.test @@ -3,6 +3,6 @@ FROM $image_name ARG IMAGE_NAME USER root COPY docker/scripts/entrypoint-test.sh /entrypoint-test.sh -RUN pip install pytest pytest-watch pytest-env flake8 +RUN pip install pytest pytest-watch pytest-env pylama mock USER airflow ENTRYPOINT ["/entrypoint-test.sh"] diff --git a/docker/config/airflow.cfg b/docker/config/airflow.cfg index 6d313f03..8c5e78ab 100644 --- a/docker/config/airflow.cfg +++ b/docker/config/airflow.cfg @@ -61,7 +61,7 @@ max_active_runs_per_dag = 16 # Whether to load the examples that ship with Airflow. It's good to # get started, but you probably want to set this to False in a production # environment -load_examples = True +load_examples = False # Where your Airflow plugins are stored plugins_folder = /usr/local/airflow/plugins diff --git a/docker/docker-compose.test.yml b/docker/docker-compose.test.yml index 7433bf8e..fa3e0d6e 100644 --- a/docker/docker-compose.test.yml +++ b/docker/docker-compose.test.yml @@ -18,4 +18,4 @@ services: - AWS_SECRET_ACCESS_KEY - AWS_DEFAULT_REGION command: - - pytest && flake8 . 
from airflow.plugins_manager import AirflowPlugin
from datetime import datetime
import logging
import types
import collections

from airflow.models import BaseOperator
from airflow.models import DagBag
from airflow.utils.decorators import apply_defaults
from airflow.utils.state import State
from airflow import settings


class MultiTriggerDagRunOperator(BaseOperator):
    """
    Triggers multiple DAG runs for a specified ``dag_id``.

    Draws inspiration from:
    airflow.operators.dagrun_operator.TriggerDagRunOperator

    :param trigger_dag_id: the dag_id to trigger
    :type trigger_dag_id: str
    :param params_list: iterable of per-run ``conf`` payloads; one DagRun is
        created per item, with the item made accessible in templates
        (namespaced under params) for that run.
    :type params_list: collections.Iterable or types.GeneratorType
    """

    @apply_defaults
    def __init__(
            self,
            trigger_dag_id,
            params_list,
            *args, **kwargs):
        super(MultiTriggerDagRunOperator, self).__init__(*args, **kwargs)
        self.trigger_dag_id = trigger_dag_id
        self.params_list = params_list
        if hasattr(self.params_list, '__len__'):
            # Sized inputs can be validated eagerly; an empty list is a
            # configuration error (generators can only be checked lazily).
            assert len(self.params_list) > 0
        else:
            assert (isinstance(self.params_list, collections.Iterable) or
                    isinstance(self.params_list, types.GeneratorType))

    def execute(self, context):
        """Create one externally-triggered, RUNNING DagRun per params item."""
        session = settings.Session()
        dbag = DagBag(settings.DAGS_FOLDER)
        trigger_dag = dbag.get_dag(self.trigger_dag_id)

        assert trigger_dag is not None

        trigger_id = 0
        for params in self.params_list:
            dr = trigger_dag.create_dagrun(run_id='trig_%s_%d_%s' %
                                           (self.trigger_dag_id, trigger_id,
                                            datetime.now().isoformat()),
                                           state=State.RUNNING,
                                           conf=params,
                                           external_trigger=True)
            logging.info("Creating DagRun {}".format(dr))
            session.add(dr)
            trigger_id = trigger_id + 1
            # BUG FIX: the original condition was ``if trigger_id % 10:``,
            # which committed on every iteration EXCEPT multiples of 10 --
            # the inverse of the intended batch-commit-every-10 behavior.
            if trigger_id % 10 == 0:
                session.commit()
        session.commit()  # flush any remainder below a full batch of 10
        session.close()


class CustomPlugin(AirflowPlugin):
    """Registers MultiTriggerDagRunOperator under
    airflow.operators.custom_plugin."""
    name = "custom_plugin"
    operators = [MultiTriggerDagRunOperator]
    hooks = []
    executors = []
    macros = []
    admin_views = []
    flask_blueprints = []
    menu_links = []
@mock.patch('airflow.settings.Session', autospec=True)
class TestMultiTriggerDag(unittest.TestCase):
    """Unit tests for MultiTriggerDagRunOperator with DagBag and the
    SQLAlchemy Session mocked out.

    Decorator order: ``@patch_plugin_file(... 'DagBag' ...)`` (outermost)
    supplies the last mock argument, ``@mock.patch`` on Session supplies the
    first, so test methods receive ``(mock_session, dag_bag_class)``.
    """

    class DagRunWithParams(object):
        # Equality helper: matches the kwargs dict captured from the mocked
        # create_dagrun() when it describes a RUNNING, externally-triggered
        # DagRun carrying ``parameters`` as (or within) its conf.
        def __init__(self, parameters):
            self.parameters = parameters

        def __eq__(self, other):
            if not (other['state'] == State.RUNNING and
                    other['external_trigger']):
                return False
            # dict params may be a subset of the conf; anything else must
            # compare equal. (items()-based subset test is portable across
            # Python 2 and 3, unlike the original Python-2-only
            # ``viewitems() <= conf`` comparison.)
            if (type(self.parameters) is dict and
                    isinstance(other['conf'], dict)):
                return all(key in other['conf'] and
                           other['conf'][key] == value
                           for key, value in self.parameters.items())
            return self.parameters == other['conf']

        def __str__(self):
            return "Dag Run with parameters \"%s\"" % self.parameters

        def __repr__(self):
            return self.__str__()

    @staticmethod
    def create_mock_dag_bag():
        """Build a DagBag mock whose create_dagrun echoes back its kwargs."""
        mock_dag = mock.MagicMock(name='Dag')
        mock_dag.create_dagrun.side_effect = lambda *args, **kwargs: kwargs

        test_dags = {TRIGGER_DAG_ID: mock_dag}

        mock_dag_bag = mock.MagicMock(name='DagBag')
        mock_dag_bag.get_dag.side_effect = lambda dag_id: test_dags.get(dag_id)

        return mock_dag_bag

    @staticmethod
    def verify_session(params_list):
        """
        Verify the session has added one DagRun per item of params_list.
        Assumes params_list is truthy.
        """
        if not hasattr(params_list, '__len__'):
            params_list = [params for params in params_list]

        # settings.Session is patched class-wide, so this returns the mock.
        session = settings.Session()

        for params in params_list:
            session.add.assert_any_call(
                TestMultiTriggerDag.DagRunWithParams(params))

        assert session.add.call_count == len(params_list)

        session.commit.assert_called()

    def test_should_fail_when_execute_none(self, mock_session, mock_dag_bag):
        params_list = None

        with self.assertRaises(Exception):
            operator = MultiTriggerDagRunOperator(
                task_id=TASK_ID,
                trigger_dag_id=TRIGGER_DAG_ID,
                params_list=params_list,
                default_args=DAG_ARGS)

            operator.execute(None)

        mock_session.add.assert_not_called()
        mock_session.commit.assert_not_called()

    # NOTE(review): duplicate of test_should_fail_when_execute_none; kept to
    # avoid changing the suite's surface, but one of the two could be removed.
    def test_execute_none_should_fail(self, mock_session, mock_dag_bag):
        params_list = None

        with self.assertRaises(Exception):
            operator = MultiTriggerDagRunOperator(
                task_id=TASK_ID,
                trigger_dag_id=TRIGGER_DAG_ID,
                params_list=params_list,
                default_args=DAG_ARGS)

            operator.execute(None)

        mock_session.add.assert_not_called()
        mock_session.commit.assert_not_called()

    def test_should_fail_execute_empty_params_list(self, mock_session,
                                                   mock_dag_bag):
        params_list = []

        with self.assertRaises(Exception):
            operator = MultiTriggerDagRunOperator(
                task_id=TASK_ID,
                trigger_dag_id=TRIGGER_DAG_ID,
                params_list=params_list,
                default_args=DAG_ARGS)

            operator.execute(None)

        mock_session.add.assert_not_called()
        mock_session.commit.assert_not_called()

    def test_should_add_single_params_list_single(self, mock_session,
                                                  dag_bag_class):
        params_list = ["a"]

        dag_bag_class.return_value = TestMultiTriggerDag.create_mock_dag_bag()

        operator = MultiTriggerDagRunOperator(
            task_id=TASK_ID,
            trigger_dag_id=TRIGGER_DAG_ID,
            params_list=params_list,
            default_args=DAG_ARGS)

        operator.execute(None)

        TestMultiTriggerDag.verify_session(params_list)

    def test_should_add_params_list(self, mock_session, dag_bag_class):
        params_list = ["a", "b", "c", "d"]

        dag_bag_class.return_value = TestMultiTriggerDag.create_mock_dag_bag()

        operator = MultiTriggerDagRunOperator(
            task_id=TASK_ID,
            trigger_dag_id=TRIGGER_DAG_ID,
            params_list=params_list,
            default_args=DAG_ARGS)

        operator.execute(None)

        TestMultiTriggerDag.verify_session(params_list)

    def test_should_execute_params_list_of_nones(self, mock_session,
                                                 dag_bag_class):
        params_list = [None, None, None, None]

        dag_bag_class.return_value = TestMultiTriggerDag.create_mock_dag_bag()

        operator = MultiTriggerDagRunOperator(
            task_id=TASK_ID,
            trigger_dag_id=TRIGGER_DAG_ID,
            params_list=params_list,
            default_args=DAG_ARGS)

        operator.execute(None)

        TestMultiTriggerDag.verify_session(params_list)

    def test_should_execute_generator_function(self, mock_session,
                                               dag_bag_class):
        def param_generator():
            # range, not the Python-2-only xrange, for py2/py3 portability.
            for i in range(1, 10):
                yield i

        dag_bag_class.return_value = TestMultiTriggerDag.create_mock_dag_bag()

        operator = MultiTriggerDagRunOperator(
            task_id=TASK_ID,
            trigger_dag_id=TRIGGER_DAG_ID,
            params_list=param_generator(),
            default_args=DAG_ARGS)

        operator.execute(None)

        TestMultiTriggerDag.verify_session(param_generator())

    def test_should_execute_iterable(self, mock_session, dag_bag_class):
        # range, not the Python-2-only xrange, for py2/py3 portability.
        params_list = range(1, 10)

        dag_bag_class.return_value = TestMultiTriggerDag.create_mock_dag_bag()

        operator = MultiTriggerDagRunOperator(
            task_id=TASK_ID,
            trigger_dag_id=TRIGGER_DAG_ID,
            params_list=params_list,
            default_args=DAG_ARGS)

        operator.execute(None)

        TestMultiTriggerDag.verify_session(params_list)
from functools import wraps
import inspect
import os
import re
import sys

# Prefer the stdlib mock (Python 3), falling back to the backport package on
# Python 2 -- consistent with tests/plugins/custom/test_multi_trigger_dag.py.
try:
    import unittest.mock as mock
except ImportError:
    import mock

# Collapses path separators / dots when deriving a module namespace.
norm_pattern = re.compile(r'[/|.]')


def patch_plugin_file(*patch_args, **patch_kwargs):
    """
    Patch an attribute of a module that was loaded from a plugin *file path*
    (e.g. via imp.load_source / Airflow's plugins_manager) rather than by a
    regular dotted import.

    ``patch_args[0]`` is the plugin file path without extension (e.g.
    ``'plugins/custom/multi_trigger_dag'``); the remaining positional and
    keyword arguments are forwarded to ``mock.patch.object``. Works as a
    decorator on either a test function or a test class.

    :raises NameError: if exactly one loaded module matching the path's
        normalized namespace cannot be found in ``sys.modules``.
    """
    root, filename = os.path.split(patch_args[0])
    module_name, file_ext = os.path.splitext(filename)
    # 'plugins/custom' + 'multi_trigger_dag' -> 'plugins__custom_multi_trigger_dag'
    namespace = '_'.join([norm_pattern.sub('__', root), module_name])

    found_modules = [key for key in sys.modules.keys() if namespace in key]

    if len(found_modules) != 1:
        # BUG FIX: the original interpolated (found_modules, namespace) in
        # the wrong order, yielding e.g. "from file [] but found: plugins__...".
        raise NameError('Tried to find 1 module from file %s but found: %s' %
                        (namespace, found_modules))

    module = sys.modules[found_modules.pop()]

    def patch_decorator(func, *patch_decorator_args):
        if not inspect.isclass(func):
            @wraps(func)
            @mock.patch.object(module, *patch_args[1:], **patch_kwargs)
            def wrapper(*args, **kwargs):
                return func(*(args + patch_decorator_args), **kwargs)
            return wrapper
        else:
            # Class target: subclass so mock.patch.object decorates every
            # test method, mirroring mock's class-decorator behavior.
            @mock.patch.object(module, *patch_args[1:], **patch_kwargs)
            class WrappedClass(func):
                pass
            return WrappedClass
    return patch_decorator