diff --git a/jupyter_scheduler/executors.py b/jupyter_scheduler/executors.py index 7e1a9974..d1e41dd2 100644 --- a/jupyter_scheduler/executors.py +++ b/jupyter_scheduler/executors.py @@ -1,20 +1,30 @@ import io +import multiprocessing as mp import os import shutil import tarfile import traceback from abc import ABC, abstractmethod -from typing import Dict +from functools import lru_cache +from pathlib import Path +from typing import Dict, List +import dask import fsspec import nbconvert import nbformat from nbconvert.preprocessors import CellExecutionError, ExecutePreprocessor -from jupyter_scheduler.models import DescribeJob, JobFeature, Status -from jupyter_scheduler.orm import Job, create_session +from jupyter_scheduler.models import CreateJob, DescribeJob, JobFeature, Status +from jupyter_scheduler.orm import Job, Workflow, WorkflowDefinition, create_session from jupyter_scheduler.parameterize import add_parameters +from jupyter_scheduler.scheduler import Scheduler from jupyter_scheduler.utils import get_utc_timestamp +from jupyter_scheduler.workflows import ( + CreateWorkflow, + DescribeWorkflow, + DescribeWorkflowDefinition, +) class ExecutionManager(ABC): @@ -29,14 +39,42 @@ class ExecutionManager(ABC): _model = None _db_session = None - def __init__(self, job_id: str, root_dir: str, db_url: str, staging_paths: Dict[str, str]): + def __init__( + self, + db_url: str, + job_id: str = None, + workflow_id: str = None, + workflow_definition_id: str = None, + root_dir: str = None, + staging_paths: Dict[str, str] = None, + ): self.job_id = job_id + self.workflow_id = workflow_id + self.workflow_definition_id = workflow_definition_id self.staging_paths = staging_paths self.root_dir = root_dir self.db_url = db_url @property def model(self): + if self.workflow_id: + with self.db_session() as session: + workflow = ( + session.query(Workflow).filter(Workflow.workflow_id == self.workflow_id).first() + ) + self._model = DescribeWorkflow.from_orm(workflow) + return self._model + if self.workflow_definition_id: + with self.db_session() as session: + workflow_definition = ( + session.query(WorkflowDefinition) + .filter( + WorkflowDefinition.workflow_definition_id == self.workflow_definition_id + ) + .first() + ) + self._model = DescribeWorkflowDefinition.from_orm(workflow_definition) + return self._model if self._model is None: with self.db_session() as session: job = session.query(Job).filter(Job.job_id == self.job_id).first() @@ -65,6 +103,18 @@ def process(self): else: self.on_complete() + def process_workflow(self): + print(f"calling ExecutionManager(ABC).process_workflow for {self.model}") + self.before_start_workflow() + try: + self.execute_workflow() + except CellExecutionError as e: + self.on_failure_workflow(e) + except Exception as e: + self.on_failure_workflow(e) + else: + self.on_complete_workflow() + @abstractmethod def execute(self): """Performs notebook execution, @@ -74,6 +124,11 @@ def execute(self): """ pass + @abstractmethod + def execute_workflow(self): + """Performs workflow execution""" + pass + @classmethod @abstractmethod def supported_features(cls) -> Dict[JobFeature, bool]: @@ -98,6 +153,16 @@ def before_start(self): ) session.commit() + def before_start_workflow(self): + """Called before start of execute""" + print(f"calling ExecutionManager(ABC).before_start_workflow for {self.model}") + workflow = self.model + with self.db_session() as session: + session.query(Workflow).filter(Workflow.workflow_id == workflow.workflow_id).update( + {"status": Status.IN_PROGRESS} + ) + 
session.commit() + def on_failure(self, e: Exception): """Called after failure of execute""" job = self.model @@ -109,6 +174,17 @@ def on_failure(self, e: Exception): traceback.print_exc() + def on_failure_workflow(self, e: Exception): + """Called after failure of execute""" + workflow = self.model + with self.db_session() as session: + session.query(Workflow).filter(Workflow.workflow_id == workflow.workflow_id).update( + {"status": Status.FAILED, "status_message": str(e)} + ) + session.commit() + + traceback.print_exc() + def on_complete(self): """Called after job is completed""" job = self.model @@ -118,10 +194,60 @@ def on_complete(self): ) session.commit() + def on_complete_workflow(self): + workflow = self.model + with self.db_session() as session: + session.query(Workflow).filter(Workflow.workflow_id == workflow.workflow_id).update( + {"status": Status.COMPLETED} + ) + session.commit() + class DefaultExecutionManager(ExecutionManager): """Default execution manager that executes notebooks""" + def get_tasks_records(self, task_ids: List[str]) -> List[Job]: + print(f"getting task records for task: {task_ids}") + with self.db_session() as session: + tasks = session.query(Job).filter(Job.job_id.in_(task_ids)).all() + print(f"gotten task records for task {task_ids}: {tasks}") + + return tasks + + # @dask.delayed(name="Execute workflow") + def execute_workflow(self): + tasks_info: List[Job] = self.get_tasks_records(self.model.tasks) + print(f"tasks_info in execute_workflow: {tasks_info}") + tasks = {task.job_id: task for task in tasks_info} + print(f"tasks in execute_workflow: {tasks}") + + @lru_cache(maxsize=None) + def make_task(task_id): + """Create a delayed object for the given task recursively creating delayed objects for all tasks it depends on""" + print("making task for") + print(task_id) + deps = tasks[task_id].depends_on or [] + print(deps) + print(f"dependencies in make_task for {task_id}") + print(deps) + + execute_task_delayed = execute_task( + job=tasks[task_id], + root_dir=self.root_dir, + db_url=self.db_url, + dependencies=[make_task(dep_id) for dep_id in deps], + ) + print("execute task result from make_task") + print(execute_task_delayed) + + return execute_task_delayed + + final_tasks = [make_task(task_id) for task_id in tasks] + print("Final tasks:") + print(final_tasks) + print(f"Calling compute after loops") + dask.compute(*final_tasks) + def execute(self): job = self.model @@ -144,6 +270,7 @@ def execute(self): self.add_side_effects_files(staging_dir) self.create_output_files(job, nb) + # @dask.delayed(name="Check for and add side effect files") def add_side_effects_files(self, staging_dir: str): """Scan for side effect files potentially created after input file execution and update the job's packaged_files with these files""" input_notebook = os.path.relpath(self.staging_paths["input"]) @@ -166,6 +293,7 @@ def add_side_effects_files(self, staging_dir: str): ) session.commit() + # @dask.delayed(name="Create output files") def create_output_files(self, job: DescribeJob, notebook_node): for output_format in job.output_formats: cls = nbconvert.get_exporter(output_format) @@ -201,6 +329,19 @@ def validate(cls, input_path: str) -> bool: return True +@dask.delayed(name="Execute workflow task") +def execute_task(job: Job, root_dir: str, db_url: str, dependencies: List[str] = []): + print(f"executing task {job.job_id} with dependencies {dependencies}") + staging_paths = Scheduler.get_staging_paths(DescribeJob.from_orm(job)) + process_job = DefaultExecutionManager( + 
        job_id=job.job_id,
+        staging_paths=staging_paths,
+        root_dir=root_dir,
+        db_url=db_url,
+    ).process
+    return process_job()
+
+
 class ArchivingExecutionManager(DefaultExecutionManager):
     """Execution manager that archives all output files
     in and under the output directory into a single archive
     file
diff --git a/jupyter_scheduler/extension.py b/jupyter_scheduler/extension.py
index 1a4ba373..f6cafb63 100644
--- a/jupyter_scheduler/extension.py
+++ b/jupyter_scheduler/extension.py
@@ -6,6 +6,14 @@
 from traitlets import Bool, Type, Unicode, default
 
 from jupyter_scheduler.orm import create_tables
+from jupyter_scheduler.workflows import (
+    WorkflowDefinitionsDeploymentHandler,
+    WorkflowDefinitionsHandler,
+    WorkflowDefinitionsTasksHandler,
+    WorkflowsHandler,
+    WorkflowsRunHandler,
+    WorkflowsTasksHandler,
+)
 
 from .handlers import (
     BatchJobHandler,
@@ -20,6 +28,8 @@
 JOB_DEFINITION_ID_REGEX = r"(?P<job_definition_id>\w+(?:-\w+)+)"
 JOB_ID_REGEX = r"(?P<job_id>\w+(?:-\w+)+)"
+WORKFLOW_DEFINITION_ID_REGEX = r"(?P<workflow_definition_id>\w+(?:-\w+)+)"
+WORKFLOW_ID_REGEX = r"(?P<workflow_id>\w+(?:-\w+)+)"
 
 
 class SchedulerApp(ExtensionApp):
@@ -35,6 +45,29 @@ class SchedulerApp(ExtensionApp):
         (r"scheduler/job_definitions/%s/jobs" % JOB_DEFINITION_ID_REGEX, JobFromDefinitionHandler),
         (r"scheduler/runtime_environments", RuntimeEnvironmentsHandler),
         (r"scheduler/config", ConfigHandler),
+        (r"scheduler/workflows", WorkflowsHandler),
+        (rf"scheduler/workflows/{WORKFLOW_ID_REGEX}", WorkflowsHandler),
+        (
+            rf"scheduler/workflows/{WORKFLOW_ID_REGEX}/run",
+            WorkflowsRunHandler,
+        ),
+        (
+            rf"scheduler/workflows/{WORKFLOW_ID_REGEX}/tasks",
+            WorkflowsTasksHandler,
+        ),
+        (r"scheduler/workflow_definitions", WorkflowDefinitionsHandler),
+        (
+            rf"scheduler/workflow_definitions/{WORKFLOW_DEFINITION_ID_REGEX}",
+            WorkflowDefinitionsHandler,
+        ),
+        (
+            rf"scheduler/workflow_definitions/{WORKFLOW_DEFINITION_ID_REGEX}/deploy",
+            WorkflowDefinitionsDeploymentHandler,
+        ),
+        (
+            rf"scheduler/workflow_definitions/{WORKFLOW_DEFINITION_ID_REGEX}/tasks",
+            WorkflowDefinitionsTasksHandler,
+        ),
     ]
 
     drop_tables = Bool(False, config=True, help="Drop the database tables before starting.")
@@ -91,3 +124,30 @@ def initialize_settings(self):
         if scheduler.task_runner:
             loop = asyncio.get_event_loop()
             loop.create_task(scheduler.task_runner.start())
+
+        if scheduler.workflow_runner:
+            loop = asyncio.get_event_loop()
+            loop.create_task(scheduler.workflow_runner.start())
+
+    async def stop_extension(self):
+        """
+        Public method called by Jupyter Server when the server is stopping.
+        This calls the cleanup code defined in `self._stop_extension()` inside
+        an exception handler, as the server halts if this method raises an
+        exception.
+        """
+        try:
+            await self._stop_extension()
+        except Exception as e:
+            self.log.error("Jupyter Scheduler raised an exception while stopping:")
+
+            self.log.exception(e)
+
+    async def _stop_extension(self):
+        """
+        Private method that defines the cleanup code to run when the server is
+        stopping.
+ """ + if "scheduler" in self.settings: + scheduler: SchedulerApp = self.settings["scheduler"] + await scheduler.stop_extension() diff --git a/jupyter_scheduler/job_files_manager.py b/jupyter_scheduler/job_files_manager.py index 0e39c2b7..77d0e811 100644 --- a/jupyter_scheduler/job_files_manager.py +++ b/jupyter_scheduler/job_files_manager.py @@ -4,6 +4,7 @@ from multiprocessing import Process from typing import Dict, List, Optional, Type +import dask import fsspec from jupyter_server.utils import ensure_async @@ -23,17 +24,14 @@ async def copy_from_staging(self, job_id: str, redownload: Optional[bool] = Fals output_filenames = self.scheduler.get_job_filenames(job) output_dir = self.scheduler.get_local_output_path(model=job, root_dir_relative=True) - p = Process( - target=Downloader( - output_formats=job.output_formats, - output_filenames=output_filenames, - staging_paths=staging_paths, - output_dir=output_dir, - redownload=redownload, - include_staging_files=job.package_input_folder, - ).download - ) - p.start() + target = Downloader( + output_formats=job.output_formats, + output_filenames=output_filenames, + staging_paths=staging_paths, + output_dir=output_dir, + redownload=redownload, + include_staging_files=job.package_input_folder, + ).download class Downloader: @@ -77,6 +75,7 @@ def download_tar(self, archive_format: str = "tar"): with tarfile.open(fileobj=f, mode=read_mode) as tar: tar.extractall(self.output_dir) + # @dask.delayed(name="Download job files") def download(self): # ensure presence of staging paths if not self.staging_paths: diff --git a/jupyter_scheduler/models.py b/jupyter_scheduler/models.py index 38e240e0..8697e008 100644 --- a/jupyter_scheduler/models.py +++ b/jupyter_scheduler/models.py @@ -42,6 +42,8 @@ def __str__(self) -> str: class Status(str, Enum): + DRAFT = "DRAFT" + DEPLOYED = "DEPLOYED" CREATED = "CREATED" QUEUED = "QUEUED" IN_PROGRESS = "IN_PROGRESS" @@ -70,6 +72,20 @@ def __str__(self): OUTPUT_FILENAME_TEMPLATE = "{{input_filename}}-{{create_time}}" +class TriggerRule(str, Enum): + ALL_SUCCESS = "all_success" + ALL_FAILES = "all_failed" + ALL_DONE = "all_done" + ONE_FAILED = "one_failed" + ONE_SUCCESS = "one_success" + NONE_FAILED = "none_failed" + NONE_SKIPPED = "none_skipped" + DUMMY = "dummy" + + def __str__(self): + return self.value + + class CreateJob(BaseModel): """Defines the model for creating a new job""" @@ -86,6 +102,9 @@ class CreateJob(BaseModel): output_filename_template: Optional[str] = OUTPUT_FILENAME_TEMPLATE compute_type: Optional[str] = None package_input_folder: Optional[bool] = None + depends_on: Optional[List[str]] = None + workflow_id: Optional[str] = None + trigger_rule: Optional[TriggerRule] = None @root_validator def compute_input_filename(cls, values) -> Dict: @@ -148,6 +167,9 @@ class DescribeJob(BaseModel): downloaded: bool = False package_input_folder: Optional[bool] = None packaged_files: Optional[List[str]] = [] + depends_on: Optional[List[str]] = None + workflow_id: Optional[str] = None + trigger_rule: Optional[TriggerRule] = None class Config: orm_mode = True @@ -193,6 +215,8 @@ class UpdateJob(BaseModel): status: Optional[Status] = None name: Optional[str] = None compute_type: Optional[str] = None + depends_on: Optional[List[str]] = None + trigger_rule: Optional[TriggerRule] = None class DeleteJob(BaseModel): @@ -213,6 +237,9 @@ class CreateJobDefinition(BaseModel): schedule: Optional[str] = None timezone: Optional[str] = None package_input_folder: Optional[bool] = None + depends_on: Optional[List[str]] = None + 
workflow_id: Optional[str] = None + trigger_rule: Optional[TriggerRule] = None @root_validator def compute_input_filename(cls, values) -> Dict: @@ -240,6 +267,9 @@ class DescribeJobDefinition(BaseModel): active: bool package_input_folder: Optional[bool] = None packaged_files: Optional[List[str]] = [] + depends_on: Optional[List[str]] = None + workflow_id: Optional[str] = None + trigger_rule: Optional[TriggerRule] = None class Config: orm_mode = True @@ -259,6 +289,9 @@ class UpdateJobDefinition(BaseModel): active: Optional[bool] = None compute_type: Optional[str] = None input_uri: Optional[str] = None + depends_on: Optional[List[str]] = None + workflow_id: Optional[str] = None + trigger_rule: Optional[TriggerRule] = None class ListJobDefinitionsQuery(BaseModel): diff --git a/jupyter_scheduler/orm.py b/jupyter_scheduler/orm.py index dbbbfad8..f4ca1353 100644 --- a/jupyter_scheduler/orm.py +++ b/jupyter_scheduler/orm.py @@ -7,7 +7,7 @@ from sqlalchemy.orm import declarative_base, declarative_mixin, registry, sessionmaker from sqlalchemy.sql import text -from jupyter_scheduler.models import EmailNotifications, Status +from jupyter_scheduler.models import EmailNotifications, Status, TriggerRule from jupyter_scheduler.utils import get_utc_timestamp Base = declarative_base() @@ -89,6 +89,9 @@ class CommonColumns: # Any default values specified for new columns will be ignored during the migration process. package_input_folder = Column(Boolean) packaged_files = Column(JsonType, default=[]) + depends_on = Column(JsonType) + workflow_id = Column(String(36)) + trigger_rule = Column(String(64)) class Job(CommonColumns, Base): @@ -107,6 +110,35 @@ class Job(CommonColumns, Base): # Any default values specified for new columns will be ignored during the migration process. +class Workflow(Base): + __tablename__ = "workflows" + __table_args__ = {"extend_existing": True} + workflow_id = Column(String(36), primary_key=True, default=generate_uuid) + tasks = Column(JsonType) + status = Column(String(64), default=Status.DRAFT) + active = Column(Boolean, default=False) + name = Column(String(256)) + parameters = Column(JsonType(1024)) + create_time = Column(Integer, default=get_utc_timestamp) + # All new columns added to this table must be nullable to ensure compatibility during database migrations. + # Any default values specified for new columns will be ignored during the migration process. + + +class WorkflowDefinition(Base): + __tablename__ = "workflow_definitions" + __table_args__ = {"extend_existing": True} + workflow_definition_id = Column(String(36), primary_key=True, default=generate_uuid) + tasks = Column(JsonType) + status = Column(String(64), default=Status.DRAFT) + active = Column(Boolean, default=False) + schedule = Column(String(256)) + timezone = Column(String(36)) + name = Column(String(256)) + parameters = Column(JsonType(1024)) + # All new columns added to this table must be nullable to ensure compatibility during database migrations. + # Any default values specified for new columns will be ignored during the migration process. 
+ + class JobDefinition(CommonColumns, Base): __tablename__ = "job_definitions" __table_args__ = {"extend_existing": True} diff --git a/jupyter_scheduler/scheduler.py b/jupyter_scheduler/scheduler.py index 867034c6..4bad9884 100644 --- a/jupyter_scheduler/scheduler.py +++ b/jupyter_scheduler/scheduler.py @@ -1,4 +1,3 @@ -import multiprocessing as mp import os import random import shutil @@ -6,6 +5,8 @@ import fsspec import psutil +from dask.distributed import Client as DaskClient +from distributed import LocalCluster from jupyter_core.paths import jupyter_data_dir from jupyter_server.transutils import _i18n from jupyter_server.utils import to_os_path @@ -38,12 +39,26 @@ UpdateJob, UpdateJobDefinition, ) -from jupyter_scheduler.orm import Job, JobDefinition, create_session +from jupyter_scheduler.orm import ( + Job, + JobDefinition, + Workflow, + WorkflowDefinition, + create_session, +) from jupyter_scheduler.utils import ( copy_directory, create_output_directory, create_output_filename, ) +from jupyter_scheduler.workflows import ( + CreateWorkflow, + CreateWorkflowDefinition, + DescribeWorkflow, + DescribeWorkflowDefinition, + UpdateWorkflow, + UpdateWorkflowDefinition, +) class BaseScheduler(LoggingConfigurable): @@ -109,6 +124,32 @@ def create_job(self, model: CreateJob) -> str: """ raise NotImplementedError("must be implemented by subclass") + def create_workflow(self, model: CreateWorkflow) -> str: + """Creates a new workflow record.""" + raise NotImplementedError("must be implemented by subclass") + + def run_workflow(self, workflow_id: str) -> str: + """Triggers execution of the workflow.""" + raise NotImplementedError("must be implemented by subclass") + + def deploy_workflow_definition(self, workflow_definition_id: str) -> str: + """Activates workflow marking it as ready for execution.""" + raise NotImplementedError("must be implemented by subclass") + + def get_workflow(self, workflow_id: str) -> DescribeWorkflow: + """Returns workflow record for a single workflow.""" + raise NotImplementedError("must be implemented by subclass") + + def create_workflow_task(self, workflow_id: str, model: CreateJob) -> str: + """Adds a task to a workflow.""" + raise NotImplementedError("must be implemented by subclass") + + def create_workflow_definition_task( + self, workflow_definition_id: str, model: CreateJobDefinition + ) -> str: + """Adds a task to a workflow definition.""" + raise NotImplementedError("must be implemented by subclass") + def update_job(self, job_id: str, model: UpdateJob): """Updates job metadata in the persistence store, for example name, status etc. In case of status @@ -160,6 +201,13 @@ def create_job_definition(self, model: CreateJobDefinition) -> str: """ raise NotImplementedError("must be implemented by subclass") + def create_workflow_definition(self, model: CreateWorkflowDefinition) -> str: + """Creates a new workflow definition record, + consider this as the template for creating + recurring/scheduled workflows. + """ + raise NotImplementedError("must be implemented by subclass") + def update_job_definition(self, job_definition_id: str, model: UpdateJobDefinition): """Updates job definition metadata in the persistence store, should only impact all future jobs. 
@@ -176,6 +224,10 @@ def get_job_definition(self, job_definition_id: str) -> DescribeJobDefinition: """Returns job definition record for a single job definition""" raise NotImplementedError("must be implemented by subclass") + def get_workflow_definition(self, workflow_definition_id: str) -> DescribeWorkflowDefinition: + """Returns workflow definition record for a single workflow definition""" + raise NotImplementedError("must be implemented by subclass") + def list_job_definitions(self, query: ListJobDefinitionsQuery) -> ListJobDefinitionsResponse: """Returns list of all job definitions filtered by query""" raise NotImplementedError("must be implemented by subclass") @@ -381,6 +433,12 @@ def get_local_output_path( else: return os.path.join(self.root_dir, self.output_directory, output_dir_name) + async def stop_extension(self): + """ + Placeholder method for a cleanup code to run when the server is stopping. + """ + pass + class Scheduler(BaseScheduler): _db_session = None @@ -395,10 +453,30 @@ class Scheduler(BaseScheduler): ), ) + workflow_runner_class = TType( + allow_none=True, + config=True, + default_value="jupyter_scheduler.workflow_runner.WorkflowRunner", + klass="jupyter_scheduler.workflow_runner.BaseWorkflowRunner", + help=_i18n( + "The class that handles the workflow creation of scheduled workflows from workflow definitions." + ), + ) + + dask_cluster_url = Unicode( + allow_none=True, + config=True, + help="URL of the Dask cluster to connect to.", + ) + db_url = Unicode(help=_i18n("Scheduler database url")) task_runner = Instance(allow_none=True, klass="jupyter_scheduler.task_runner.BaseTaskRunner") + workflow_runner = Instance( + allow_none=True, klass="jupyter_scheduler.workflow_runner.BaseWorkflowRunner" + ) + def __init__( self, root_dir: str, @@ -413,6 +491,20 @@ def __init__( self.db_url = db_url if self.task_runner_class: self.task_runner = self.task_runner_class(scheduler=self, config=config) + if self.workflow_runner_class: + self.workflow_runner = self.workflow_runner_class(scheduler=self, config=config) + self.dask_client: DaskClient = self._get_dask_client() + + def _get_dask_client(self): + """Creates and configures a Dask client.""" + if self.dask_cluster_url: + return DaskClient(self.dask_cluster_url) + print("Starting local Dask cluster") + cluster = LocalCluster(processes=True) + client = DaskClient(cluster) + print(client) + print(f"Dask dashboard link: {client.dashboard_link}") + return client @property def db_session(self): @@ -437,7 +529,7 @@ def copy_input_folder(self, input_uri: str, nb_copy_to_path: str) -> List[str]: destination_dir=staging_dir, ) - def create_job(self, model: CreateJob) -> str: + def create_job(self, model: CreateJob, run: bool = True) -> str: if not model.job_definition_id and not self.file_exists(model.input_uri): raise InputUriError(model.input_uri) @@ -478,31 +570,184 @@ def create_job(self, model: CreateJob) -> str: else: self.copy_input_file(model.input_uri, staging_paths["input"]) - # The MP context forces new processes to not be forked on Linux. - # This is necessary because `asyncio.get_event_loop()` is bugged in - # forked processes in Python versions below 3.12. This method is - # called by `jupyter_core` by `nbconvert` in the default executor. 
- # - # See: https://github.com/python/cpython/issues/66285 - # See also: https://github.com/jupyter/jupyter_core/pull/362 - mp_ctx = mp.get_context("spawn") - p = mp_ctx.Process( - target=self.execution_manager_class( - job_id=job.job_id, - staging_paths=staging_paths, - root_dir=self.root_dir, - db_url=self.db_url, - ).process - ) - p.start() + if not run: + return job.job_id + + job_id = self.run_job(job=job, staging_paths=staging_paths) + return job_id - job.pid = p.pid + def run_job(self, job: Job, staging_paths: Dict[str, str]) -> str: + with self.db_session() as session: + process_job = self.execution_manager_class( + job_id=job.job_id, + staging_paths=staging_paths, + root_dir=self.root_dir, + db_url=self.db_url, + ).process + future = self.dask_client.submit(process_job) + job.pid = future.key session.commit() job_id = job.job_id return job_id + def run_workflow_from_definition(self, model: DescribeWorkflowDefinition) -> str: + print( + f"calling scheduler.run_workflow_from_definition with DescribeWorkflowDefinition {model}" + ) + workflow_id = self.create_workflow( + CreateWorkflow( + **model.dict(exclude={"schedule", "timezone", "tasks"}, exclude_none=True), + ) + ) + task_definitions = self.get_workflow_definition_tasks(model.workflow_definition_id) + for task_definition in task_definitions: + self.create_workflow_task( + workflow_id=workflow_id, + model=CreateJob(**task_definition.dict(exclude={"schedule", "timezone"})), + ) + return workflow_id + + def create_workflow(self, model: CreateWorkflow) -> str: + print(f"calling scheduler.create_workflow with {model}") + print(model.dict) + with self.db_session() as session: + workflow = Workflow(**model.dict(exclude_none=True)) + session.add(workflow) + session.commit() + return workflow.workflow_id + + def run_workflow(self, workflow_id: str) -> str: + print(f"calling scheduler.run_workflow for {workflow_id}") + process_workflow = self.execution_manager_class( + workflow_id=workflow_id, + root_dir=self.root_dir, + db_url=self.db_url, + ).process_workflow + self.dask_client.submit(process_workflow) + return workflow_id + + def create_workflow_definition(self, model: CreateWorkflowDefinition) -> str: + with self.db_session() as session: + workflow_definition = WorkflowDefinition(**model.dict(exclude_none=True)) + session.add(workflow_definition) + session.commit() + return workflow_definition.workflow_definition_id + + def deploy_workflow_definition(self, workflow_definition_id: str) -> str: + with self.db_session() as session: + workflow_definition = ( + session.query(WorkflowDefinition) + .filter(WorkflowDefinition.workflow_definition_id == workflow_definition_id) + .with_for_update() + .one() + ) + workflow_definition_schedule = workflow_definition.schedule + session.query(WorkflowDefinition).filter( + WorkflowDefinition.workflow_definition_id == workflow_definition_id + ).update({"active": True, "status": Status.DEPLOYED}) + session.commit() + + if self.workflow_runner and workflow_definition_schedule: + self.workflow_runner.add_workflow_definition(workflow_definition_id) + + return workflow_definition_id + + def get_workflow(self, workflow_id: str) -> DescribeWorkflow: + with self.db_session() as session: + workflow_record = ( + session.query(Workflow).filter(Workflow.workflow_id == workflow_id).one() + ) + model = DescribeWorkflow.from_orm(workflow_record) + return model + + def get_all_workflows(self) -> List[DescribeWorkflow]: + with self.db_session() as session: + workflow_records = session.query(Workflow).all() + models = 
[ + DescribeWorkflow.from_orm(workflow_record) for workflow_record in workflow_records + ] + return models + + def get_workflow_definition(self, workflow_definition_id: str) -> List[Workflow]: + with self.db_session() as session: + workflow_definition_record = ( + session.query(WorkflowDefinition) + .filter(WorkflowDefinition.workflow_definition_id == workflow_definition_id) + .one() + ) + model = DescribeWorkflowDefinition.from_orm(workflow_definition_record) + return model + + def get_workflow_definition_tasks( + self, workflow_definition_id: str + ) -> List[DescribeJobDefinition]: + print(f"calling scheduler.get_workflow_definition_tasks for{workflow_definition_id}") + with self.db_session() as session: + task_records = ( + session.query(JobDefinition) + .filter(JobDefinition.workflow_id == workflow_definition_id) + .all() + ) + tasks = [DescribeJobDefinition.from_orm(task_record) for task_record in task_records] + return tasks + + def create_workflow_task(self, workflow_id: str, model: CreateJob) -> str: + print( + f"calling scheduler.create_workflow_task with workflow_id {workflow_id},\n CreateJob {model},\n about to call scheduler.create_job" + ) + job_id = self.create_job(model, run=False) + print(f"create_workflow_task job_id: {job_id}") + workflow: DescribeWorkflow = self.get_workflow(workflow_id) + print(f"workflow in create_workflow_task: {workflow}") + updated_tasks = (workflow.tasks or [])[:] + print(f"updated_tasks before update: {updated_tasks}") + updated_tasks.append(job_id) + print(f"updated_tasks after update: {updated_tasks}") + + self.update_workflow(workflow_id, UpdateWorkflow(tasks=updated_tasks)) + return job_id + + def create_workflow_definition_task( + self, workflow_definition_id: str, model: CreateJobDefinition + ) -> str: + job_definition_id = self.create_job_definition(model, add_to_task_runner=False) + workflow_definition: DescribeWorkflowDefinition = self.get_workflow_definition( + workflow_definition_id + ) + updated_tasks = (workflow_definition.tasks or [])[:] + updated_tasks.append(job_definition_id) + self.update_workflow_definition( + workflow_definition_id, UpdateWorkflowDefinition(tasks=updated_tasks) + ) + return job_definition_id + + def get_all_workflow_definition_tasks(self) -> List[DescribeWorkflowDefinition]: + with self.db_session() as session: + workflow_definition_records = session.query(WorkflowDefinition).all() + models = [ + DescribeWorkflowDefinition.from_orm(workflow_definition_record) + for workflow_definition_record in workflow_definition_records + ] + return models + + def update_workflow(self, workflow_id: str, model: UpdateWorkflow): + with self.db_session() as session: + session.query(Workflow).filter(Workflow.workflow_id == workflow_id).update( + model.dict(exclude_none=True) + ) + session.commit() + + def update_workflow_definition( + self, workflow_definition_id: str, model: UpdateWorkflowDefinition + ): + with self.db_session() as session: + session.query(WorkflowDefinition).filter( + WorkflowDefinition.workflow_definition_id == workflow_definition_id + ).update(model.dict(exclude_none=True)) + session.commit() + def update_job(self, job_id: str, model: UpdateJob): with self.db_session() as session: session.query(Job).filter(Job.job_id == job_id).update(model.dict(exclude_none=True)) @@ -604,7 +849,9 @@ def stop_job(self, job_id): session.commit() break - def create_job_definition(self, model: CreateJobDefinition) -> str: + def create_job_definition( + self, model: CreateJobDefinition, add_to_task_runner: bool = True + ) -> 
str: with self.db_session() as session: if not self.file_exists(model.input_uri): raise InputUriError(model.input_uri) @@ -628,7 +875,7 @@ def create_job_definition(self, model: CreateJobDefinition) -> str: else: self.copy_input_file(model.input_uri, staging_paths["input"]) - if self.task_runner and job_definition_schedule: + if add_to_task_runner and self.task_runner and job_definition_schedule: self.task_runner.add_job_definition(job_definition_id) return job_definition_id @@ -777,6 +1024,37 @@ def get_staging_paths(self, model: Union[DescribeJob, DescribeJobDefinition]) -> return staging_paths + @staticmethod + def get_staging_paths(model: Union[DescribeJob, DescribeJobDefinition]) -> Dict[str, str]: + staging_paths = {} + if not model: + return staging_paths + + id = model.job_id if isinstance(model, DescribeJob) else model.job_definition_id + + for output_format in model.output_formats: + filename = create_output_filename( + model.input_filename, model.create_time, output_format + ) + staging_paths[output_format] = os.path.join( + os.path.join(jupyter_data_dir(), "scheduler_staging_area"), id, filename + ) + + staging_paths["input"] = os.path.join( + os.path.join(jupyter_data_dir(), "scheduler_staging_area"), id, model.input_filename + ) + + return staging_paths + + async def stop_extension(self): + """ + Cleanup code to run when the server is stopping. + """ + if self.dask_client is None: + return + if self.dask_client and self.dask_client.close: + self.dask_client.close() + class ArchivingScheduler(Scheduler): """Scheduler that captures all files in output directory in an archive.""" diff --git a/jupyter_scheduler/workflow_runner.py b/jupyter_scheduler/workflow_runner.py new file mode 100644 index 00000000..a569349b --- /dev/null +++ b/jupyter_scheduler/workflow_runner.py @@ -0,0 +1,351 @@ +import asyncio +from dataclasses import dataclass +from datetime import datetime +from heapq import heappop, heappush +from typing import List, Optional + +import traitlets +from jupyter_server.transutils import _i18n +from sqlalchemy import Boolean, Column, Integer, String, create_engine +from sqlalchemy.orm import sessionmaker +from traitlets.config import LoggingConfigurable + +from jupyter_scheduler.orm import WorkflowDefinition, declarative_base +from jupyter_scheduler.pydantic_v1 import BaseModel +from jupyter_scheduler.utils import ( + compute_next_run_time, + get_localized_timestamp, + get_utc_timestamp, +) +from jupyter_scheduler.workflows import ( + CreateWorkflow, + DescribeWorkflowDefinition, + UpdateWorkflowDefinition, +) + +Base = declarative_base() + + +class WorkflowDefinitionCache(Base): + __tablename__ = "workflow_definitions_cache" + workflow_definition_id = Column(String(36), primary_key=True) + next_run_time = Column(Integer) + active = Column(Boolean) + timezone = Column(String(36)) + schedule = Column(String(256)) + + +class DescribeWorkflowDefinitionCache(BaseModel): + workflow_definition_id: str + next_run_time: int + active: bool + timezone: Optional[str] = None + schedule: str + + class Config: + orm_mode = True + + +class UpdateWorkflowDefinitionCache(BaseModel): + next_run_time: Optional[int] = None + active: Optional[bool] = None + timezone: Optional[str] = None + schedule: Optional[str] = None + + +@dataclass +class WorkflowDefinitionTask: + workflow_definition_id: str + next_run_time: int + + def __lt__(self, other): + return self.next_run_time < other.next_run_time + + def __str__(self): + next_run_time = datetime.fromtimestamp(self.next_run_time / 1e3) + return 
f"Id: {self.workflow_definition_id}, Run-time: {next_run_time}" + + +class PriorityQueue: + """A priority queue using heapq""" + + def __init__(self): + self._heap = [] + + def peek(self): + if self.isempty(): + raise "Queue is empty" + + return self._heap[0] + + def push(self, task: WorkflowDefinitionTask): + heappush(self._heap, task) + + def pop(self): + task = heappop(self._heap) + return task + + def __len__(self): + return len(self._heap) + + def isempty(self): + return len(self._heap) < 1 + + def __str__(self): + tasks = [] + for task in self._heap: + tasks.append(str(task)) + + return "\n".join(tasks) + + +class Cache: + def __init__(self) -> None: + self.cache_url = "sqlite://" + engine = create_engine(self.cache_url, echo=False) + Base.metadata.create_all(engine) + self.session = sessionmaker(bind=engine) + + def load(self, models: List[DescribeWorkflowDefinitionCache]): + with self.session() as session: + for model in models: + session.add(WorkflowDefinitionCache(**model.dict())) + session.commit() + + def get(self, workflow_definition_id: str) -> DescribeWorkflowDefinitionCache: + with self.session() as session: + definition = ( + session.query(WorkflowDefinitionCache) + .filter(WorkflowDefinitionCache.workflow_definition_id == workflow_definition_id) + .first() + ) + + if definition: + return DescribeWorkflowDefinitionCache.from_orm(definition) + else: + return None + + def put(self, model: DescribeWorkflowDefinitionCache): + with self.session() as session: + session.add(WorkflowDefinitionCache(**model.dict())) + session.commit() + + def update(self, workflow_definition_id: str, model: UpdateWorkflowDefinitionCache): + with self.session() as session: + session.query(WorkflowDefinitionCache).filter( + WorkflowDefinitionCache.workflow_definition_id == workflow_definition_id + ).update(model.dict(exclude_none=True)) + session.commit() + + def delete(self, workflow_definition_id: str): + with self.session() as session: + session.query(WorkflowDefinitionCache).filter( + WorkflowDefinitionCache.workflow_definition_id == workflow_definition_id + ).delete() + session.commit() + + +class BaseWorkflowRunner(LoggingConfigurable): + """Base task runner, this class's start method is called + at the start of jupyter server, and is responsible for + polling for the workflow definitions and creating new workflows + based on the schedule/timezone in the workflow definition. + """ + + def __init__(self, config=None, **kwargs): + super().__init__(config=config) + + poll_interval = traitlets.Integer( + default_value=10, + config=True, + help=_i18n( + "The interval in seconds that the task runner polls for scheduled workflows to run." 
+        ),
+    )
+
+    async def start(self):
+        """Async method that is called by extension at server start"""
+        raise NotImplementedError("must be implemented by subclass")
+
+    def add_workflow_definition(self, workflow_definition_id: str):
+        """This should handle adding data for new
+        workflow definition to the PriorityQueue and Cache."""
+        raise NotImplementedError("must be implemented by subclass")
+
+    def update_workflow_definition(
+        self, workflow_definition_id: str, model: UpdateWorkflowDefinition
+    ):
+        """This should handle updates to workflow definitions"""
+        raise NotImplementedError("must be implemented by subclass")
+
+    def delete_workflow_definition(self, workflow_definition_id: str):
+        """Handles deletion of workflow definitions"""
+        raise NotImplementedError("must be implemented by subclass")
+
+    def pause_workflows(self, workflow_definition_id: str):
+        """Handles pausing a workflow definition"""
+        raise NotImplementedError("must be implemented by subclass")
+
+    def resume_workflows(self, workflow_definition_id: str):
+        """Handles resuming of a workflow definition"""
+        raise NotImplementedError("must be implemented by subclass")
+
+
+class WorkflowRunner(BaseWorkflowRunner):
+    """Default workflow runner that maintains a workflow definition cache and a
+    priority queue, and polls the queue every `poll_interval` seconds
+    for new workflows to create.
+    """
+
+    def __init__(self, scheduler, config=None) -> None:
+        super().__init__(config=config)
+        self.scheduler = scheduler
+        self.db_session = scheduler.db_session
+        self.cache = Cache()
+        self.queue = PriorityQueue()
+
+    def compute_next_run_time(self, schedule: str, timezone: Optional[str] = None):
+        return compute_next_run_time(schedule, timezone)
+
+    def populate_cache(self):
+        with self.db_session() as session:
+            definitions: List[WorkflowDefinition] = (
+                session.query(WorkflowDefinition).filter(WorkflowDefinition.schedule != None).all()
+            )
+
+        for definition in definitions:
+            next_run_time = self.compute_next_run_time(definition.schedule, definition.timezone)
+            self.cache.put(
+                DescribeWorkflowDefinitionCache(
+                    workflow_definition_id=definition.workflow_definition_id,
+                    next_run_time=next_run_time,
+                    active=definition.active,
+                    timezone=definition.timezone,
+                    schedule=definition.schedule,
+                )
+            )
+            if definition.active:
+                self.queue.push(
+                    WorkflowDefinitionTask(
+                        workflow_definition_id=definition.workflow_definition_id,
+                        next_run_time=next_run_time,
+                    )
+                )
+
+    def add_workflow_definition(self, workflow_definition_id: str):
+        with self.db_session() as session:
+            definition = (
+                session.query(WorkflowDefinition)
+                .filter(WorkflowDefinition.workflow_definition_id == workflow_definition_id)
+                .first()
+            )
+
+        next_run_time = self.compute_next_run_time(definition.schedule, definition.timezone)
+
+        self.cache.put(
+            DescribeWorkflowDefinitionCache(
+                workflow_definition_id=definition.workflow_definition_id,
+                active=definition.active,
+                next_run_time=next_run_time,
+                timezone=definition.timezone,
+                schedule=definition.schedule,
+            )
+        )
+        if definition.active:
+            self.queue.push(
+                WorkflowDefinitionTask(
+                    workflow_definition_id=definition.workflow_definition_id,
+                    next_run_time=next_run_time,
+                )
+            )
+
+    def update_workflow_definition(
+        self, workflow_definition_id: str, model: UpdateWorkflowDefinition
+    ):
+        cache = self.cache.get(workflow_definition_id)
+        schedule = model.schedule or cache.schedule
+        timezone = model.timezone or cache.timezone
+        active = model.active if model.active is not None else cache.active
+        cached_next_run_time =
cache.next_run_time + next_run_time = self.compute_next_run_time(schedule, timezone) + + self.cache.update( + workflow_definition_id, + UpdateWorkflowDefinitionCache( + timezone=timezone, next_run_time=next_run_time, active=active, schedule=schedule + ), + ) + + next_run_time_changed = cached_next_run_time != next_run_time and active + resumed_workflow = model.active and not cache.active + + if next_run_time_changed or resumed_workflow: + self.log.debug("Updating queue...") + task = WorkflowDefinitionTask( + workflow_definition_id=workflow_definition_id, next_run_time=next_run_time + ) + self.queue.push(task) + self.log.debug(f"Updated queue, {task}") + + def delete_workflow_definition(self, workflow_definition_id: str): + self.cache.delete(workflow_definition_id) + + def create_and_run_workflow(self, workflow_definition_id: str): + definition: DescribeWorkflowDefinition = self.scheduler.get_workflow_definition( + workflow_definition_id + ) + print(f"calling workflow_runner.create_and_run_workflow with {definition.dict}") + if definition and definition.active: + print( + f"calling self.scheduler.run_workflow_from_definition from workflow_runner.create_and_run_workflow with {definition.dict}" + ) + self.scheduler.run_workflow_from_definition(definition) + + def compute_time_diff(self, queue_run_time: int, timezone: str): + local_time = get_localized_timestamp(timezone) if timezone else get_utc_timestamp() + return local_time - queue_run_time + + def process_queue(self): + self.log.debug(self.queue) + while not self.queue.isempty(): + task: WorkflowDefinitionTask = self.queue.peek() + cache = self.cache.get(task.workflow_definition_id) + + if not cache: + self.queue.pop() + continue + + cache_run_time = cache.next_run_time + queue_run_time = task.next_run_time + + if not cache.active or queue_run_time != cache_run_time: + self.queue.pop() + continue + + time_diff = self.compute_time_diff(queue_run_time, cache.timezone) + + # if run time is in future + if time_diff < 0: + break + else: + try: + self.create_and_run_workflow(task.workflow_definition_id) + except Exception as e: + self.log.exception(e) + self.queue.pop() + run_time = self.compute_next_run_time(cache.schedule, cache.timezone) + self.cache.update( + task.workflow_definition_id, + UpdateWorkflowDefinitionCache(next_run_time=run_time), + ) + self.queue.push( + WorkflowDefinitionTask( + workflow_definition_id=task.workflow_definition_id, next_run_time=run_time + ) + ) + + async def start(self): + self.populate_cache() + while True: + self.process_queue() + await asyncio.sleep(self.poll_interval) diff --git a/jupyter_scheduler/workflows.py b/jupyter_scheduler/workflows.py new file mode 100644 index 00000000..60c281bc --- /dev/null +++ b/jupyter_scheduler/workflows.py @@ -0,0 +1,445 @@ +import json +from typing import Dict, List, Optional + +from jupyter_server.utils import ensure_async +from tornado.web import HTTPError, authenticated + +from jupyter_scheduler.exceptions import ( + IdempotencyTokenError, + InputUriError, + SchedulerError, +) +from jupyter_scheduler.handlers import ( + APIHandler, + ExtensionHandlerMixin, + JobHandlersMixin, +) +from jupyter_scheduler.models import ( + CreateJob, + CreateJobDefinition, + Status, + UpdateJob, + UpdateJobDefinition, +) +from jupyter_scheduler.pydantic_v1 import BaseModel, ValidationError + + +class WorkflowsHandler(ExtensionHandlerMixin, JobHandlersMixin, APIHandler): + @authenticated + async def post(self): + payload = self.get_json_body() or {} + try: + workflow_id = await 
ensure_async( + self.scheduler.create_workflow(CreateWorkflow(**payload)) + ) + except ValidationError as e: + self.log.exception(e) + raise HTTPError(500, str(e)) from e + except InputUriError as e: + self.log.exception(e) + raise HTTPError(500, str(e)) from e + except IdempotencyTokenError as e: + self.log.exception(e) + raise HTTPError(409, str(e)) from e + except SchedulerError as e: + self.log.exception(e) + raise HTTPError(500, str(e)) from e + except Exception as e: + self.log.exception(e) + raise HTTPError(500, "Unexpected error occurred during creation of a workflow.") from e + else: + self.finish(json.dumps(dict(workflow_id=workflow_id))) + + @authenticated + async def get(self, workflow_id: str = None): + if not workflow_id: + raise HTTPError(400, "Missing workflow_id in the request URL.") + try: + workflow = await ensure_async(self.scheduler.get_workflow(workflow_id)) + except SchedulerError as e: + self.log.exception(e) + raise HTTPError(500, str(e)) from e + except Exception as e: + self.log.exception(e) + raise HTTPError(500, "Unexpected error occurred while getting workflow details.") from e + else: + self.finish(workflow.json()) + + @authenticated + async def get(self, workflow_id: str = None): + if workflow_id: + try: + workflow = await ensure_async(self.scheduler.get_workflow(workflow_id)) + except SchedulerError as e: + self.log.exception(e) + raise HTTPError(500, str(e)) from e + except Exception as e: + self.log.exception(e) + raise HTTPError( + 500, "Unexpected error occurred while getting workflow details." + ) from e + else: + self.finish(workflow.json()) + else: + try: + workflows = await ensure_async(self.scheduler.get_all_workflows()) + workflows_json = [workflow.dict() for workflow in workflows] + except SchedulerError as e: + self.log.exception(e) + raise HTTPError(500, str(e)) from e + except Exception as e: + self.log.exception(e) + raise HTTPError( + 500, "Unexpected error occurred while getting all workflows details." + ) from e + else: + self.finish(json.dumps(workflows_json)) + + +class WorkflowsTasksHandler(ExtensionHandlerMixin, JobHandlersMixin, APIHandler): + @authenticated + async def post(self, workflow_id: str): + payload = self.get_json_body() + try: + task_id = await ensure_async( + self.scheduler.create_workflow_task( + workflow_id=workflow_id, model=CreateJob(**payload) + ) + ) + except ValidationError as e: + self.log.exception(e) + raise HTTPError(500, str(e)) from e + except InputUriError as e: + self.log.exception(e) + raise HTTPError(500, str(e)) from e + except IdempotencyTokenError as e: + self.log.exception(e) + raise HTTPError(409, str(e)) from e + except SchedulerError as e: + self.log.exception(e) + raise HTTPError(500, str(e)) from e + except Exception as e: + self.log.exception(e) + raise HTTPError( + 500, "Unexpected error occurred during creation of workflow job." + ) from e + else: + self.finish(json.dumps(dict(task_id=task_id))) + + @authenticated + async def patch(self, _: str, task_id: str): + payload = self.get_json_body() + status = payload.get("status") + status = Status(status) if status else None + + if status and status != Status.STOPPED: + raise HTTPError( + 500, + "Invalid value for field 'status'. 
Workflow task status can only be updated to status 'STOPPED' after creation.", + ) + try: + if status: + await ensure_async(self.scheduler.stop_job(task_id)) + else: + await ensure_async(self.scheduler.update_job(task_id, UpdateJob(**payload))) + except ValidationError as e: + self.log.exception(e) + raise HTTPError(500, str(e)) from e + except SchedulerError as e: + self.log.exception(e) + raise HTTPError(500, str(e)) from e + except Exception as e: + self.log.exception(e) + raise HTTPError( + 500, "Unexpected error occurred while updating the workflow job." + ) from e + else: + self.set_status(204) + self.finish() + + +class WorkflowsRunHandler(ExtensionHandlerMixin, JobHandlersMixin, APIHandler): + @authenticated + async def post(self, workflow_id: str): + try: + workflow_id = await ensure_async(self.scheduler.run_workflow(workflow_id)) + except ValidationError as e: + self.log.exception(e) + raise HTTPError(500, str(e)) from e + except InputUriError as e: + self.log.exception(e) + raise HTTPError(500, str(e)) from e + except IdempotencyTokenError as e: + self.log.exception(e) + raise HTTPError(409, str(e)) from e + except SchedulerError as e: + self.log.exception(e) + raise HTTPError(500, str(e)) from e + except Exception as e: + self.log.exception(e) + raise HTTPError( + 500, "Unexpected error occurred during attempt to run a workflow." + ) from e + else: + self.finish(json.dumps(dict(workflow_id=workflow_id))) + + +class WorkflowDefinitionsHandler(ExtensionHandlerMixin, JobHandlersMixin, APIHandler): + @authenticated + async def post(self): + payload = self.get_json_body() or {} + try: + workflow_definition_id = await ensure_async( + self.scheduler.create_workflow_definition(CreateWorkflowDefinition(**payload)) + ) + except ValidationError as e: + self.log.exception(e) + raise HTTPError(500, str(e)) from e + except InputUriError as e: + self.log.exception(e) + raise HTTPError(500, str(e)) from e + except IdempotencyTokenError as e: + self.log.exception(e) + raise HTTPError(409, str(e)) from e + except SchedulerError as e: + self.log.exception(e) + raise HTTPError(500, str(e)) from e + except Exception as e: + self.log.exception(e) + raise HTTPError( + 500, "Unexpected error occurred during creation of a workflow definition." + ) from e + else: + self.finish(json.dumps(dict(workflow_definition_id=workflow_definition_id))) + + @authenticated + async def get(self, workflow_definition_id: str = None): + if workflow_definition_id: + try: + workflow_definition = await ensure_async( + self.scheduler.get_workflow_definition(workflow_definition_id) + ) + except SchedulerError as e: + self.log.exception(e) + raise HTTPError(500, str(e)) from e + except Exception as e: + self.log.exception(e) + raise HTTPError( + 500, "Unexpected error occurred while getting workflow definition details." 
+                ) from e
+            else:
+                self.finish(workflow_definition.json())
+        else:
+            try:
+                workflow_definitions = await ensure_async(
+                    self.scheduler.get_all_workflow_definitions()
+                )
+                workflow_definitions_json = [
+                    workflow_definition.dict() for workflow_definition in workflow_definitions
+                ]
+            except SchedulerError as e:
+                self.log.exception(e)
+                raise HTTPError(500, str(e)) from e
+            except Exception as e:
+                self.log.exception(e)
+                raise HTTPError(
+                    500,
+                    "Unexpected error occurred while getting all workflow definitions details.",
+                ) from e
+            else:
+                self.finish(json.dumps(workflow_definitions_json))
+
+
+class WorkflowDefinitionsTasksHandler(ExtensionHandlerMixin, JobHandlersMixin, APIHandler):
+    @authenticated
+    async def post(self, workflow_definition_id: str):
+        payload = self.get_json_body()
+        try:
+            task_definition_id = await ensure_async(
+                self.scheduler.create_workflow_definition_task(
+                    workflow_definition_id=workflow_definition_id,
+                    model=CreateJobDefinition(**payload),
+                )
+            )
+        except ValidationError as e:
+            self.log.exception(e)
+            raise HTTPError(500, str(e)) from e
+        except InputUriError as e:
+            self.log.exception(e)
+            raise HTTPError(500, str(e)) from e
+        except IdempotencyTokenError as e:
+            self.log.exception(e)
+            raise HTTPError(409, str(e)) from e
+        except SchedulerError as e:
+            self.log.exception(e)
+            raise HTTPError(500, str(e)) from e
+        except Exception as e:
+            self.log.exception(e)
+            raise HTTPError(
+                500, "Unexpected error occurred during creation of workflow definition task."
+            ) from e
+        else:
+            self.finish(json.dumps(dict(task_definition_id=task_definition_id)))
+
+    @authenticated
+    async def patch(self, _: str, task_definition_id: str):
+        payload = self.get_json_body()
+        status = payload.get("status")
+        status = Status(status) if status else None
+
+        try:
+            await ensure_async(
+                self.scheduler.update_job_definition(
+                    task_definition_id, UpdateJobDefinition(**payload)
+                )
+            )
+        except ValidationError as e:
+            self.log.exception(e)
+            raise HTTPError(500, str(e)) from e
+        except SchedulerError as e:
+            self.log.exception(e)
+            raise HTTPError(500, str(e)) from e
+        except Exception as e:
+            self.log.exception(e)
+            raise HTTPError(
+                500, "Unexpected error occurred while updating the workflow definition task."
+ ) from e + else: + self.set_status(204) + self.finish() + + @authenticated + async def get(self, workflow_definition_id: str = None): + if workflow_definition_id: + try: + task_definitions = await ensure_async( + self.scheduler.get_workflow_definition_tasks(workflow_definition_id) + ) + task_definitions_json = [ + task_definition.dict() for task_definition in task_definitions + ] + except SchedulerError as e: + self.log.exception(e) + raise HTTPError(500, str(e)) from e + except Exception as e: + self.log.exception(e) + raise HTTPError( + 500, + "Unexpected error occurred while getting workflow task definitions details.", + ) from e + else: + self.finish(json.dumps(task_definitions_json)) + else: + try: + task_definitions = await ensure_async( + self.scheduler.get_all_workflow_definition_tasks() + ) + task_definitions_json = [ + task_definition.dict() for task_definition in task_definitions + ] + except SchedulerError as e: + self.log.exception(e) + raise HTTPError(500, str(e)) from e + except Exception as e: + self.log.exception(e) + raise HTTPError( + 500, + "Unexpected error occurred while getting all task definitions details.", + ) from e + else: + self.finish(json.dumps(task_definitions_json)) + + +class WorkflowDefinitionsDeploymentHandler(ExtensionHandlerMixin, JobHandlersMixin, APIHandler): + @authenticated + async def post(self, workflow_definition_id: str): + try: + workflow_definition_id = await ensure_async( + self.scheduler.deploy_workflow_definition(workflow_definition_id) + ) + except ValidationError as e: + self.log.exception(e) + raise HTTPError(500, str(e)) from e + except InputUriError as e: + self.log.exception(e) + raise HTTPError(500, str(e)) from e + except IdempotencyTokenError as e: + self.log.exception(e) + raise HTTPError(409, str(e)) from e + except SchedulerError as e: + self.log.exception(e) + raise HTTPError(500, str(e)) from e + except Exception as e: + self.log.exception(e) + raise HTTPError( + 500, "Unexpected error occurred during attempt to run a workflow." 
+ ) from e + else: + self.finish(json.dumps(dict(workflow_definition_id=workflow_definition_id))) + + +class CreateWorkflow(BaseModel): + tasks: List[str] = [] + name: str + parameters: Optional[Dict[str, str]] = None + + +class DescribeWorkflow(BaseModel): + name: str + parameters: Optional[Dict[str, str]] = None + workflow_id: str + tasks: List[str] = None + status: Status = Status.CREATED + active: Optional[bool] = None + + class Config: + orm_mode = True + + +class UpdateWorkflow(BaseModel): + name: Optional[str] = None + parameters: Optional[Dict[str, str]] = None + tasks: Optional[List[str]] = None + status: Optional[Status] = None + active: Optional[bool] = None + + class Config: + orm_mode = True + + +class CreateWorkflowDefinition(BaseModel): + tasks: List[str] = [] + # any field added to CreateWorkflow should also be added to this model as well + name: str = "" + parameters: Optional[Dict[str, str]] = None + schedule: Optional[str] = None + timezone: Optional[str] = None + + class Config: + orm_mode = True + + +class DescribeWorkflowDefinition(BaseModel): + name: str + parameters: Optional[Dict[str, str]] = None + workflow_definition_id: str + tasks: List[str] = None + schedule: Optional[str] = None + timezone: Optional[str] = None + status: Status = Status.CREATED + active: Optional[bool] = None + + class Config: + orm_mode = True + + +class UpdateWorkflowDefinition(BaseModel): + name: Optional[str] = None + parameters: Optional[Dict[str, str]] = None + tasks: Optional[List[str]] = None + schedule: Optional[str] = None + timezone: Optional[str] = None + active: Optional[bool] = None + + class Config: + orm_mode = True diff --git a/src/handler.ts b/src/handler.ts index 4381bbd3..285e44ef 100644 --- a/src/handler.ts +++ b/src/handler.ts @@ -372,6 +372,7 @@ export namespace Scheduler { timezone?: string; active?: boolean; input_uri?: string; + depends_on?: string[]; } export interface IDescribeJobDefinition { @@ -418,6 +419,8 @@ export namespace Scheduler { output_formats?: string[]; compute_type?: string; package_input_folder?: boolean; + depends_on?: string[]; + workflow_id?: string; } export interface ICreateJobFromDefinition { @@ -467,6 +470,8 @@ export namespace Scheduler { end_time?: number; downloaded: boolean; package_input_folder?: boolean; + depends_on?: string[]; + workflow_id?: string; } export interface ICreateJobResponse { diff --git a/src/model.ts b/src/model.ts index 01b501cb..f4ed1326 100644 --- a/src/model.ts +++ b/src/model.ts @@ -100,6 +100,8 @@ export interface ICreateJobModel // Is the create button disabled due to a submission in progress? createInProgress?: boolean; packageInputFolder?: boolean; + dependsOn?: string[]; + workflowId?: string; } export const defaultScheduleFields: ModelWithScheduleFields = { @@ -312,6 +314,8 @@ export interface IJobDetailModel { job_files: Scheduler.IJobFile[]; downloaded: boolean; packageInputFolder?: boolean; + dependsOn?: string[]; + workflowId?: string; } export interface IJobDefinitionModel { @@ -388,7 +392,9 @@ export function convertDescribeJobtoJobDetail( startTime: describeJob.start_time, endTime: describeJob.end_time, downloaded: describeJob.downloaded, - packageInputFolder: describeJob.package_input_folder + packageInputFolder: describeJob.package_input_folder, + dependsOn: describeJob.depends_on, + workflowId: describeJob.workflow_id }; }
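
Reviewer note: the DAG execution in DefaultExecutionManager.execute_workflow builds one dask.delayed node per task and wires dependencies by passing each task's delayed parents into execute_task, so dask.compute() resolves the ordering. Below is a minimal standalone sketch of that pattern, not code from this changeset; the task ids, the depends_on mapping, and run_task are illustrative stand-ins for the Job records and execute_task used above.

from functools import lru_cache
from typing import Dict, List

import dask

# Hypothetical DAG: "c" depends on "a" and "b"; "a" and "b" have no dependencies.
depends_on: Dict[str, List[str]] = {"a": [], "b": [], "c": ["a", "b"]}


@dask.delayed
def run_task(task_id: str, dependencies: List[str]) -> str:
    # Stand-in for DefaultExecutionManager(...).process(); dask has already
    # computed `dependencies` by the time this body runs.
    print(f"running {task_id} after {dependencies}")
    return task_id


@lru_cache(maxsize=None)
def make_task(task_id: str):
    # Recursively build delayed nodes; lru_cache ensures each task id maps to a
    # single node even when several downstream tasks depend on it.
    return run_task(task_id, [make_task(dep) for dep in depends_on[task_id]])


if __name__ == "__main__":
    # Runs "a" and "b" first (in either order), then "c".
    dask.compute(*[make_task(task_id) for task_id in depends_on])

In the changeset the same structure is driven by the tasks stored on the Workflow record, and the real work happens in execute_task, which wraps DefaultExecutionManager.process and is submitted through the Dask client created in Scheduler._get_dask_client.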