diff --git a/dags/synaptor_dags.py b/dags/synaptor_dags.py
index 5a7d7389..4fd8475e 100644
--- a/dags/synaptor_dags.py
+++ b/dags/synaptor_dags.py
@@ -4,12 +4,15 @@
 from dataclasses import dataclass
 
 from airflow import DAG
+from airflow.operators.python import PythonOperator
 from airflow.models import Variable, BaseOperator
 
+from dags.slack_message import slack_message
 from helper_ops import placeholder_op, scale_up_cluster_op, scale_down_cluster_op, collect_metrics_op, toggle_nfs_server_op
 from param_default import synaptor_param_default, default_synaptor_image
 from synaptor_ops import manager_op, drain_op, self_destruct_op
 from synaptor_ops import synaptor_op, wait_op, generate_op, nglink_op
+from dag_utils import get_connection
 
 
 # Processing parameters
@@ -38,6 +41,30 @@
 # Sanity check DAG
 # "update synaptor params"
 
+def supply_database_parameters():
+    param = Variable.get(
+        "synaptor_param.json", synaptor_param_default, deserialize_json=True
+    )
+    workflow_params = param.get("Workflow", {})
+    if not workflow_params:
+        return
+    if workflow_params.get("workspacetype", "File") != "Database":
+        return
+    if "connectionstr" in workflow_params:
+        return
+    nfs_conn = get_connection("NFSServer")
+    if not nfs_conn:
+        slack_message(":u7981:*ERROR: You need to specify `connectionstr` to use Database workflow*")
+    else:
+        extra_args = nfs_conn.extra_dejson
+        connectionstr = f"postgresql+psycopg2://postgres:airflow@{extra_args['hostname']}/postgres"
+        workflow_params["connectionstr"] = connectionstr
+        workflow_params["use_nfs_server"] = True
+        slack_message(":exclamation: *Use postgresql database on the NFS server*")
+
+    Variable.set("synaptor_param.json", param, serialize_json=True)
+
+
 dag_sanity = DAG(
     "synaptor_sanity_check",
     default_args=default_args,
@@ -45,7 +72,15 @@
     tags=["synaptor"],
 )
 
-manager_op(dag_sanity, "sanity_check", image=SYNAPTOR_IMAGE)
+setup_database = PythonOperator(
+    task_id="setup_database",
+    python_callable=supply_database_parameters,
+    default_args=default_args,
+    dag=dag_sanity,
+    queue="manager"
+)
+
+setup_database >> manager_op(dag_sanity, "sanity_check", image=SYNAPTOR_IMAGE)
 
 
 # =========================================
@@ -92,8 +127,15 @@ def __init__(self, name):
 
 def fill_dag(dag: DAG, tasklist: list[Task], collect_metrics: bool = True) -> DAG:
     """Fills a synaptor DAG from a list of Tasks."""
-    start_nfs_server = toggle_nfs_server_op(dag, on=True)
-    stop_nfs_server = toggle_nfs_server_op(dag, on=False)
+    if WORKFLOW_PARAMS.get("use_nfs_server", False):
+        start_nfs_server = toggle_nfs_server_op(dag, on=True)
+        stop_nfs_server = toggle_nfs_server_op(dag, on=False)
+        db_queue = "nfs"
+    else:
+        start_nfs_server = placeholder_op(dag, "dummy_start_nfs")
+        stop_nfs_server = placeholder_op(dag, "dummy_stop_nfs")
+        db_queue = "manager"
+
     drain_tasks = [drain_op(dag, task_queue_name=f"synaptor-{t}-tasks") for t in ["cpu", "gpu", "seggraph"]]
 
     init_cloudvols = manager_op(dag, "init_cloudvols", image=SYNAPTOR_IMAGE)
@@ -101,7 +143,7 @@ def fill_dag(dag: DAG, tasklist: list[Task], collect_metrics: bool = True) -> DA
     curr_operator = init_cloudvols
 
     if WORKFLOW_PARAMS.get("workspacetype", "File") == "Database":
-        init_db = manager_op(dag, "init_db", queue="nfs", image=SYNAPTOR_IMAGE)
+        init_db = manager_op(dag, "init_db", queue=db_queue, image=SYNAPTOR_IMAGE)
         init_cloudvols >> init_db
         curr_operator = init_db
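
Reviewer note (not part of the patch): a minimal sketch of the Variable rewrite that
`supply_database_parameters` performs, assuming an Airflow connection named "NFSServer"
whose `extra` JSON carries a `hostname` key; the host name "nfs-host" below is made up
for illustration.

    # Before the setup_database task runs: a Database workflow with no
    # connection string configured.
    param = {
        "Workflow": {
            "workspacetype": "Database",
        }
    }

    # After the task runs, the "synaptor_param.json" Variable holds:
    param = {
        "Workflow": {
            "workspacetype": "Database",
            # SQLAlchemy URL pointing at the postgres instance that runs
            # alongside the NFS server
            "connectionstr": "postgresql+psycopg2://postgres:airflow@nfs-host/postgres",
            # makes fill_dag toggle the NFS server on/off around the workflow
            # and route init_db to the "nfs" queue instead of "manager"
            "use_nfs_server": True,
        }
    }

    # If the "NFSServer" connection is missing, the task posts the Slack
    # error and writes the parameters back unchanged.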