Use nfs server for psql by default if it is available
ranlu committed Jul 3, 2024
1 parent 071e198 commit 6329468
Showing 1 changed file with 46 additions and 4 deletions.
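The new setup task resolves an Airflow connection named "NFSServer" and derives the Postgres URL from its extra JSON, keyed on "hostname". A minimal sketch of registering such a connection; the host address and conn_type are hypothetical, only the conn_id and the "hostname" extra key come from the diff:

import json

from airflow import settings
from airflow.models import Connection

# Register the "NFSServer" connection read by supply_database_parameters().
# The host address below is hypothetical; only the conn_id and the
# "hostname" extra key are fixed by the DAG code.
nfs_conn = Connection(
    conn_id="NFSServer",
    conn_type="generic",
    extra=json.dumps({"hostname": "10.0.0.5"}),
)
session = settings.Session()
session.add(nfs_conn)  # assumes no NFSServer connection exists yet
session.commit()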
dags/synaptor_dags.py
@@ -4,12 +4,15 @@
from dataclasses import dataclass

from airflow import DAG
from airflow.operators.python import PythonOperator
from airflow.models import Variable, BaseOperator
from dags.slack_message import slack_message

from helper_ops import placeholder_op, scale_up_cluster_op, scale_down_cluster_op, collect_metrics_op, toggle_nfs_server_op
from param_default import synaptor_param_default, default_synaptor_image
from synaptor_ops import manager_op, drain_op, self_destruct_op
from synaptor_ops import synaptor_op, wait_op, generate_op, nglink_op
from dag_utils import get_connection


# Processing parameters
@@ -38,14 +41,46 @@
# Sanity check DAG
# "update synaptor params"

def supply_database_parameters():
    param = Variable.get(
        "synaptor_param.json", synaptor_param_default, deserialize_json=True
    )
    workflow_params = param.get("Workflow", {})
    if not workflow_params:
        return
    if workflow_params.get("workspacetype", "File") != "Database":
        return
    if "connectionstr" in workflow_params:
        return
    nfs_conn = get_connection("NFSServer")
    if not nfs_conn:
        slack_message(":u7981:*ERROR: You need to specify `connectionstr` to use Database workflow*")
    else:
        extra_args = nfs_conn.extra_dejson
        connectionstr = f"postgresql+psycopg2://postgres:airflow@{extra_args['hostname']}/postgres"
        workflow_params["connectionstr"] = connectionstr
        workflow_params["use_nfs_server"] = True
        slack_message(":exclamation: *Use postgresql database on the NFS server*")

    Variable.set("synaptor_param.json", param, serialize_json=True)


dag_sanity = DAG(
    "synaptor_sanity_check",
    default_args=default_args,
    schedule_interval=None,
    tags=["synaptor"],
)

manager_op(dag_sanity, "sanity_check", image=SYNAPTOR_IMAGE)
setup_database = PythonOperator(
    task_id="setup_database",
    python_callable=supply_database_parameters,
    default_args=default_args,
    dag=dag_sanity,
    queue="manager",
)

setup_database >> manager_op(dag_sanity, "sanity_check", image=SYNAPTOR_IMAGE)


# =========================================
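The fill-in performed by supply_database_parameters above can be sketched standalone, without a live metadata database; the hostname is hypothetical, and the URL template mirrors the function:

# Standalone sketch of the parameter fill-in: given a Database workflow
# with no explicit connection string, point it at the NFS-hosted Postgres.
param = {"Workflow": {"workspacetype": "Database"}}
workflow_params = param["Workflow"]

if (
    workflow_params.get("workspacetype", "File") == "Database"
    and "connectionstr" not in workflow_params
):
    hostname = "10.0.0.5"  # hypothetical; normally read from the NFSServer connection
    workflow_params["connectionstr"] = f"postgresql+psycopg2://postgres:airflow@{hostname}/postgres"
    workflow_params["use_nfs_server"] = True

assert param["Workflow"]["use_nfs_server"] is True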
@@ -92,16 +127,23 @@ def __init__(self, name):

def fill_dag(dag: DAG, tasklist: list[Task], collect_metrics: bool = True) -> DAG:
    """Fills a synaptor DAG from a list of Tasks."""
    start_nfs_server = toggle_nfs_server_op(dag, on=True)
    stop_nfs_server = toggle_nfs_server_op(dag, on=False)
    if WORKFLOW_PARAMS.get("use_nfs_server", False):
        start_nfs_server = toggle_nfs_server_op(dag, on=True)
        stop_nfs_server = toggle_nfs_server_op(dag, on=False)
        db_queue = "nfs"
    else:
        start_nfs_server = placeholder_op(dag, "dummy_start_nfs")
        stop_nfs_server = placeholder_op(dag, "dummy_stop_nfs")
        db_queue = "manager"

    drain_tasks = [drain_op(dag, task_queue_name=f"synaptor-{t}-tasks") for t in ["cpu", "gpu", "seggraph"]]
    init_cloudvols = manager_op(dag, "init_cloudvols", image=SYNAPTOR_IMAGE)

    start_nfs_server >> drain_tasks >> init_cloudvols

    curr_operator = init_cloudvols
    if WORKFLOW_PARAMS.get("workspacetype", "File") == "Database":
        init_db = manager_op(dag, "init_db", queue="nfs", image=SYNAPTOR_IMAGE)
        init_db = manager_op(dag, "init_db", queue=db_queue, image=SYNAPTOR_IMAGE)
        init_cloudvols >> init_db
        curr_operator = init_db
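The queue switch above routes init_db to a worker co-located with the NFS server only when the NFS-hosted Postgres is actually in use. The same gating as a tiny standalone helper; pick_db_queue is a hypothetical name, while the queue names and the use_nfs_server flag come from the diff:

def pick_db_queue(workflow_params: dict) -> str:
    # Mirror of the db_queue gating in fill_dag(): database-bound tasks run
    # on the "nfs" queue when the NFS-hosted Postgres is in use, otherwise
    # on the "manager" queue.
    return "nfs" if workflow_params.get("use_nfs_server", False) else "manager"

assert pick_db_queue({"use_nfs_server": True}) == "nfs"
assert pick_db_queue({}) == "manager"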
