From 03c7438bafafb16a186b97b3956b9b756cb1140c Mon Sep 17 00:00:00 2001 From: Tom Close Date: Thu, 23 Jan 2025 17:25:33 +1100 Subject: [PATCH] touching up typing of tasks to include TaskDef template --- pydra/engine/core.py | 2 +- pydra/engine/environments.py | 9 ++- pydra/engine/helpers.py | 6 +- pydra/engine/lazy.py | 4 +- pydra/engine/specs.py | 35 +++++---- pydra/engine/submitter.py | 148 ++++++----------------------------- pydra/engine/workers.py | 50 +++++------- 7 files changed, 77 insertions(+), 177 deletions(-) diff --git a/pydra/engine/core.py b/pydra/engine/core.py index d6f4c2ac8..949c4373a 100644 --- a/pydra/engine/core.py +++ b/pydra/engine/core.py @@ -409,7 +409,7 @@ async def run_async(self, rerun: bool = False): self.audit.start_audit(odir=self.output_dir) try: self.audit.monitor() - await self.definition._run(self) + await self.definition._run_async(self) result.outputs = self.definition.Outputs._from_task(self) except Exception: etype, eval, etr = sys.exc_info() diff --git a/pydra/engine/environments.py b/pydra/engine/environments.py index f0d1d9ee9..06fd3fdef 100644 --- a/pydra/engine/environments.py +++ b/pydra/engine/environments.py @@ -4,6 +4,7 @@ if ty.TYPE_CHECKING: from pydra.engine.core import Task + from pydra.engine.specs import ShellDef class Environment: @@ -17,7 +18,7 @@ class Environment: def setup(self): pass - def execute(self, task: "Task") -> dict[str, ty.Any]: + def execute(self, task: "Task[ShellDef]") -> dict[str, ty.Any]: """ Execute the task in the environment. @@ -42,7 +43,7 @@ class Native(Environment): Native environment, i.e. the tasks are executed in the current python environment. """ - def execute(self, task: "Task") -> dict[str, ty.Any]: + def execute(self, task: "Task[ShellDef]") -> dict[str, ty.Any]: keys = ["return_code", "stdout", "stderr"] values = execute(task.definition._command_args()) output = dict(zip(keys, values)) @@ -90,7 +91,7 @@ def bind(self, loc, mode="ro"): class Docker(Container): """Docker environment.""" - def execute(self, task: "Task") -> dict[str, ty.Any]: + def execute(self, task: "Task[ShellDef]") -> dict[str, ty.Any]: docker_img = f"{self.image}:{self.tag}" # mounting all input locations mounts = task.definition._get_bindings(root=self.root) @@ -125,7 +126,7 @@ def execute(self, task: "Task") -> dict[str, ty.Any]: class Singularity(Container): """Singularity environment.""" - def execute(self, task: "Task") -> dict[str, ty.Any]: + def execute(self, task: "Task[ShellDef]") -> dict[str, ty.Any]: singularity_img = f"{self.image}:{self.tag}" # mounting all input locations mounts = task.definition._get_bindings(root=self.root) diff --git a/pydra/engine/helpers.py b/pydra/engine/helpers.py index 8ce9c7209..a158c5c31 100644 --- a/pydra/engine/helpers.py +++ b/pydra/engine/helpers.py @@ -25,6 +25,8 @@ PYDRA_ATTR_METADATA = "__PYDRA_METADATA__" +DefType = ty.TypeVar("DefType", bound="TaskDef") + def attrs_fields(definition, exclude_names=()) -> list[attrs.Attribute]: """Get the fields of a definition, excluding some names.""" @@ -132,7 +134,7 @@ def load_result(checksum, cache_locations): def save( task_path: Path, result: "Result | None" = None, - task: "Task | None" = None, + task: "Task[DefType] | None" = None, name_prefix: str = None, ) -> None: """ @@ -449,7 +451,7 @@ def load_and_run(task_pkl: Path, rerun: bool = False) -> Path: from .specs import Result try: - task: Task = load_task(task_pkl=task_pkl) + task: Task[DefType] = load_task(task_pkl=task_pkl) except Exception: if task_pkl.parent.exists(): etype, eval, etr 
= sys.exc_info()
diff --git a/pydra/engine/lazy.py b/pydra/engine/lazy.py
index c1b0a8820..c938833c8 100644
--- a/pydra/engine/lazy.py
+++ b/pydra/engine/lazy.py
@@ -10,9 +10,11 @@
     from .graph import DiGraph
     from .submitter import NodeExecution
     from .core import Task, Workflow
+    from .specs import TaskDef

 T = ty.TypeVar("T")
+DefType = ty.TypeVar("DefType", bound="TaskDef")
 TypeOrAny = ty.Union[type, ty.Any]
@@ -150,7 +152,7 @@ def get_value(
         task = graph.node(self.node.name).task(state_index)
         _, split_depth = TypeParser.strip_splits(self.type)

-        def get_nested(task: "Task", depth: int):
+        def get_nested(task: "Task[DefType]", depth: int):
             if isinstance(task, StateArray):
                 val = [get_nested(task=t, depth=depth - 1) for t in task]
                 if depth:
diff --git a/pydra/engine/specs.py b/pydra/engine/specs.py
index 6898b9274..62e7b6c6d 100644
--- a/pydra/engine/specs.py
+++ b/pydra/engine/specs.py
@@ -40,12 +40,14 @@
     from pydra.engine.graph import DiGraph
     from pydra.engine.submitter import NodeExecution
     from pydra.engine.lazy import LazyOutField
-    from pydra.engine.task import ShellTask
     from pydra.engine.core import Workflow
     from pydra.engine.environments import Environment
     from pydra.engine.workers import Worker

+DefType = ty.TypeVar("DefType", bound="TaskDef")
+
+
 def is_set(value: ty.Any) -> bool:
     """Check if a value has been set."""
     return value not in (attrs.NOTHING, EMPTY)
@@ -372,7 +374,7 @@ def _compute_hashes(self) -> ty.Tuple[bytes, ty.Dict[str, bytes]]:
         }
         return hash_function(sorted(field_hashes.items())), field_hashes

-    def _retrieve_values(self, wf, state_index=None):
+    def _resolve_lazy_fields(self, wf, state_index=None):
         """Parse output results."""
         temp_values = {}
         for field in attrs_fields(self):
@@ -482,7 +484,7 @@ class Runtime:
 class Result(ty.Generic[OutputsType]):
     """Metadata regarding the outputs of processing."""

-    task: "Task"
+    task: "Task[DefType]"
     outputs: OutputsType | None = None
     runtime: Runtime | None = None
     errored: bool = False
@@ -548,13 +550,13 @@ class RuntimeSpec:
 class PythonOutputs(TaskOutputs):

     @classmethod
-    def _from_task(cls, task: "Task") -> Self:
+    def _from_task(cls, task: "Task[PythonDef]") -> Self:
         """Collect the outputs of a task from a combination of the provided inputs,
         the objects in the output directory, and the stdout and stderr of the process.

         Parameters
         ----------
-        task : Task
+        task : Task[PythonDef]
             The task whose outputs are being collected.
         outputs_dict : dict[str, ty.Any]
             The outputs of the task, as a dictionary
@@ -575,7 +577,7 @@ class PythonDef(TaskDef[PythonOutputsType]):
-    def _run(self, task: "Task") -> None:
+    def _run(self, task: "Task[PythonDef]") -> None:
         # Prepare the inputs to the function
         inputs = attrs_values(self)
         del inputs["function"]
@@ -602,12 +604,12 @@ class WorkflowOutputs(TaskOutputs):
     @classmethod
-    def _from_task(cls, task: "Task") -> Self:
+    def _from_task(cls, task: "Task[WorkflowDef]") -> Self:
         """Collect the outputs of a workflow task from the outputs of the nodes in the

         Parameters
         ----------
-        task : Task
+        task : Task[WorkflowDef]
             The task whose outputs are being collected.
Returns @@ -659,12 +661,13 @@ class WorkflowDef(TaskDef[WorkflowOutputsType]): _constructed = attrs.field(default=None, init=False) - def _run(self, task: "Task") -> None: + def _run(self, task: "Task[WorkflowDef]") -> None: """Run the workflow.""" - if task.submitter.worker.is_async: - task.submitter.expand_workflow_async(task) - else: - task.submitter.expand_workflow(task) + task.submitter.expand_workflow(task) + + async def _run_async(self, task: "Task[WorkflowDef]") -> None: + """Run the workflow asynchronously.""" + await task.submitter.expand_workflow_async(task) def construct(self) -> "Workflow": from pydra.engine.core import Workflow @@ -688,7 +691,7 @@ class ShellOutputs(TaskOutputs): stderr: str = shell.out(help=STDERR_HELP) @classmethod - def _from_task(cls, task: "ShellTask") -> Self: + def _from_task(cls, task: "Task[ShellDef]") -> Self: """Collect the outputs of a shell process from a combination of the provided inputs, the objects in the output directory, and the stdout and stderr of the process. @@ -784,7 +787,7 @@ def _required_fields_satisfied(cls, fld: shell.out, inputs: "ShellDef") -> bool: def _resolve_value( cls, fld: "shell.out", - task: "Task", + task: "Task[DefType]", ) -> ty.Any: """Collect output file if metadata specified.""" from pydra.design import shell @@ -842,7 +845,7 @@ class ShellDef(TaskDef[ShellOutputsType]): RESERVED_FIELD_NAMES = TaskDef.RESERVED_FIELD_NAMES + ("cmdline",) - def _run(self, task: "Task") -> None: + def _run(self, task: "Task[ShellDef]") -> None: """Run the shell command.""" task.return_values = task.environment.execute(task) diff --git a/pydra/engine/submitter.py b/pydra/engine/submitter.py index df0512e07..c059a20ac 100644 --- a/pydra/engine/submitter.py +++ b/pydra/engine/submitter.py @@ -32,6 +32,8 @@ from .specs import TaskDef, WorkflowDef from .environments import Environment +DefType = ty.TypeVar("DefType", bound="TaskDef") + # Used to flag development mode of Audit develop = False @@ -155,9 +157,9 @@ def Split(): ) task = Task(task_def, submitter=self, name="task", environment=self.environment) if task.is_async: - self.loop.run_until_complete(self.expand_runnable_async(task)) + self.loop.run_until_complete(task.run_async(rerun=self.rerun)) else: - self.expand_runnable(task) + task.run(rerun=self.rerun) PersistentCache().clean_up() result = task.result() if result is None: @@ -187,72 +189,6 @@ def __setstate__(self, state): self._worker = WORKERS[self.worker_name](**self.worker_kwargs) self.worker.loop = self.loop - def expand_runnable(self, task: "Task"): - """ - This coroutine handles state expansion. - - Removes any states from `runnable`. If `wait` is - set to False (default), aggregates all worker - execution coroutines and returns them. If `wait` is - True, waits for all coroutines to complete / error - and returns None. - - Parameters - ---------- - runnable : pydra Task - Task instance (`Task`, `Workflow`) - wait : bool (False) - Await all futures before completing - - Returns - ------- - futures : set or None - Coroutines for :class:`~pydra.engine.core.TaskBase` execution. - - """ - task.run(rerun=self.rerun) - - async def expand_runnable_async( - self, task: "Task", wait=False - ) -> set[ty.Coroutine] | None: - """ - This coroutine handles state expansion. - - Removes any states from `runnable`. If `wait` is - set to False (default), aggregates all worker - execution coroutines and returns them. If `wait` is - True, waits for all coroutines to complete / error - and returns None. 
- - Parameters - ---------- - runnable : pydra Task - Task instance (`Task`, `Workflow`) - wait : bool (False) - Await all futures before completing - - Returns - ------- - futures : set or None - Coroutines for :class:`~pydra.engine.core.TaskBase` execution. - - """ - futures = set() - - if is_workflow(task.definition): - futures.add(asyncio.create_task(task.run(self.rerun))) - else: - task_pkl = await prepare_runnable(task) - futures.add(self.worker.run((task_pkl, task), rerun=self.rerun)) - - if wait and futures: - # if wait is True, we are at the end of the graph / state expansion. - # Once the remaining jobs end, we will exit `submit_from_call` - await asyncio.gather(*futures) - return - # pass along futures to be awaited independently - return futures - def expand_workflow(self, workflow_task: "Task[WorkflowDef]") -> None: """Expands and executes a workflow task synchronously. Typically only used during debugging and testing, as the asynchronous version is more efficient. @@ -273,12 +209,12 @@ def expand_workflow(self, workflow_task: "Task[WorkflowDef]") -> None: # grab inputs if needed logger.debug(f"Retrieving inputs for {task}") # TODO: add state idx to retrieve values to reduce waiting - task.definition._retrieve_values(wf) + task.definition._resolve_lazy_fields(wf) self.worker.run(task, rerun=self.rerun) tasks = self.get_runnable_tasks(exec_graph) workflow_task.return_values = {"workflow": wf, "exec_graph": exec_graph} - async def expand_workflow_async(self, task: "Task[WorkflowDef]") -> None: + async def expand_workflow_async(self, workflow_task: "Task[WorkflowDef]") -> None: """ Expand and execute a workflow task asynchronously. @@ -287,7 +223,7 @@ async def expand_workflow_async(self, task: "Task[WorkflowDef]") -> None: task : :obj:`~pydra.engine.core.Task[WorkflowDef]` Workflow Task object """ - wf = task.definition.construct() + wf = workflow_task.definition.construct() # Generate the execution graph exec_graph = wf.execution_graph(submitter=self) # keep track of pending futures @@ -301,7 +237,7 @@ async def expand_workflow_async(self, task: "Task[WorkflowDef]") -> None: # so try to get_runnable_tasks for another minute ii = 0 while not tasks and exec_graph.nodes: - tasks, follow_err = self.get_runnable_tasks(exec_graph) + tasks = self.get_runnable_tasks(exec_graph) ii += 1 # don't block the event loop! 
await asyncio.sleep(1) @@ -312,7 +248,7 @@ async def expand_workflow_async(self, task: "Task[WorkflowDef]") -> None: "results predecessors:\n\n" ) # Get blocked tasks and the predecessors they are waiting on - outstanding: dict[Task, list[Task]] = { + outstanding: dict[Task[DefType], list[Task[DefType]]] = { t: [ p for p in exec_graph.predecessors[t.name] if not p.done ] @@ -359,7 +295,7 @@ async def expand_workflow_async(self, task: "Task[WorkflowDef]") -> None: # grab inputs if needed logger.debug(f"Retrieving inputs for {task}") # TODO: add state idx to retrieve values to reduce waiting - task.definition._retrieve_values(wf) + task.definition._resolve_lazy_fields(wf) if is_workflow(task): await task.run(self) # single task @@ -367,7 +303,7 @@ async def expand_workflow_async(self, task: "Task[WorkflowDef]") -> None: task_futures.add(self.worker.run(task, rerun=self.rerun)) task_futures = await self.worker.fetch_finished(task_futures) tasks = self.get_runnable_tasks(exec_graph) - task.return_values = {"workflow": wf, "exec_graph": exec_graph} + workflow_task.return_values = {"workflow": wf, "exec_graph": exec_graph} def __enter__(self): return self @@ -386,10 +322,7 @@ def close(self): if self._own_loop: self.loop.close() - def get_runnable_tasks( - self, - graph: DiGraph, - ) -> tuple[list["Task"], dict["NodeExecution", list[str]]]: + def get_runnable_tasks(self, graph: DiGraph) -> list["Task[DefType]"]: """Parse a graph and return all runnable tasks. Parameters @@ -435,7 +368,7 @@ def cache_dir(self, location): self._cache_dir = Path(self._cache_dir).resolve() -class NodeExecution: +class NodeExecution(ty.Generic[DefType]): """A wrapper around a workflow node containing the execution state of the tasks that are generated from it""" @@ -444,17 +377,17 @@ class NodeExecution: submitter: Submitter # List of tasks that were completed successfully - successful: dict[StateIndex | None, list["Task"]] + successful: dict[StateIndex | None, list["Task[DefType]"]] # List of tasks that failed - errored: dict[StateIndex | None, "Task"] + errored: dict[StateIndex | None, "Task[DefType]"] # List of tasks that couldn't be run due to upstream errors - unrunnable: dict[StateIndex | None, list["Task"]] + unrunnable: dict[StateIndex | None, list["Task[DefType]"]] # List of tasks that are running - running: dict[StateIndex | None, "Task"] + running: dict[StateIndex | None, "Task[DefType]"] # List of tasks that are waiting on other tasks to complete before they can be run - waiting: dict[StateIndex | None, "Task"] + waiting: dict[StateIndex | None, "Task[DefType]"] - _tasks: dict[StateIndex | None, "Task"] | None + _tasks: dict[StateIndex | None, "Task[DefType]"] | None def __init__(self, node: "Node", submitter: Submitter): self.name = node.name @@ -478,12 +411,12 @@ def _definition(self) -> "Node": return self.node._definition @property - def tasks(self) -> ty.Iterable["Task"]: + def tasks(self) -> ty.Iterable["Task[DefType]"]: if self._tasks is None: self._tasks = {t.state_index: t for t in self._generate_tasks()} return self._tasks.values() - def task(self, index: StateIndex | None = None) -> "Task | list[Task]": + def task(self, index: StateIndex | None = None) -> "Task | list[Task[DefType]]": """Get a task object for a given state index.""" self.tasks # Ensure tasks are loaded try: @@ -513,7 +446,7 @@ def all_failed(self) -> bool: self.successful or self.waiting or self.running ) - def _generate_tasks(self) -> ty.Iterable["Task"]: + def _generate_tasks(self) -> ty.Iterable["Task[DefType]"]: if 
self.node.state is None: yield Task( definition=self.node._definition, @@ -529,40 +462,7 @@ def _generate_tasks(self) -> ty.Iterable["Task"]: state_index=index, ) - # if state_index is None: - # # if state_index=None, collecting all results - # if self.node.state.combiner: - # return self._combined_output(return_inputs=return_inputs) - # else: - # results = [] - # for ind in range(len(self.node.state.inputs_ind)): - # checksum = self.checksum_states(state_index=ind) - # result = load_result(checksum, cache_locations) - # if result is None: - # return None - # results.append(result) - # if return_inputs is True or return_inputs == "val": - # return list(zip(self.node.state.states_val, results)) - # elif return_inputs == "ind": - # return list(zip(self.node.state.states_ind, results)) - # else: - # return results - # else: # state_index is not None - # if self.node.state.combiner: - # return self._combined_output(return_inputs=return_inputs)[ - # state_index - # ] - # result = load_result(self.checksum_states(state_index), cache_locations) - # if return_inputs is True or return_inputs == "val": - # return (self.node.state.states_val[state_index], result) - # elif return_inputs == "ind": - # return (self.node.state.states_ind[state_index], result) - # else: - # return result - # else: - # return load_result(self._definition._checksum, cache_locations) - - def get_runnable_tasks(self, graph: DiGraph) -> list["Task"]: + def get_runnable_tasks(self, graph: DiGraph) -> list["Task[DefType]"]: """For a given node, check to see which tasks have been successfully run, are ready to run, can't be run due to upstream errors, or are waiting on other tasks to complete. @@ -579,7 +479,7 @@ def get_runnable_tasks(self, graph: DiGraph) -> list["Task"]: runnable : list[NodeExecution] List of tasks that are ready to run """ - runnable: list["Task"] = [] + runnable: list["Task[DefType]"] = [] self.tasks # Ensure tasks are loaded if not self.started: self.waiting = copy(self._tasks) diff --git a/pydra/engine/workers.py b/pydra/engine/workers.py index 766b585f1..2c4dc533f 100644 --- a/pydra/engine/workers.py +++ b/pydra/engine/workers.py @@ -4,13 +4,13 @@ import sys import json import re +import typing as ty from tempfile import gettempdir from pathlib import Path from shutil import copyfile, which - import concurrent.futures as cf - from .core import Task +from .specs import TaskDef from .helpers import ( get_available_cpus, read_and_display_async, @@ -20,11 +20,12 @@ ) import logging -from pydra.engine.environments import Environment import random logger = logging.getLogger("pydra.worker") +DefType = ty.TypeVar("DefType", bound="TaskDef") + class Worker: """A base class for execution of tasks.""" @@ -37,7 +38,7 @@ def __init__(self, loop=None): logger.debug(f"Initializing {self.__class__.__name__}") self.loop = loop - def run(self, task: "Task", **kwargs): + def run(self, task: "Task[DefType]", **kwargs): """Return coroutine for task execution.""" raise NotImplementedError @@ -140,7 +141,7 @@ def __init__(self, **kwargs): def run( self, - task: "Task | tuple[Path, Task]", + task: "Task[DefType] | tuple[Path, Task[DefType]]", rerun: bool = False, ): """Run a task.""" @@ -178,21 +179,18 @@ def __init__(self, n_procs=None): def run( self, - task: "Task", + task: "Task[DefType]", rerun: bool = False, - environment: Environment | None = None, **kwargs, ): """Run a task.""" assert self.loop, "No event loop available to submit tasks" - return self.exec_as_coro(task, rerun=rerun, environment=environment) + return 
self.exec_as_coro(task, rerun=rerun) - async def exec_as_coro(self, runnable, rerun=False, environment=None): + async def exec_as_coro(self, runnable: "Task[DefType]", rerun: bool = False): """Run a task (coroutine wrapper).""" if isinstance(runnable, Task): - res = await self.loop.run_in_executor( - self.pool, runnable._run, rerun, environment - ) + res = await self.loop.run_in_executor(self.pool, runnable.run, rerun) else: # it could be tuple that includes pickle files with tasks and inputs task_main_pkl, task_orig = runnable res = await self.loop.run_in_executor( @@ -235,9 +233,7 @@ def __init__(self, loop=None, max_jobs=None, poll_delay=1, sbatch_args=None): self.sbatch_args = sbatch_args or "" self.error = {} - def run( - self, task: "Task", rerun: bool = False, environment: Environment | None = None - ): + def run(self, task: "Task[DefType]", rerun: bool = False): """Worker submission API.""" script_dir, batch_script = self._prepare_runscripts(task, rerun=rerun) if (script_dir / script_dir.parts[1]) == gettempdir(): @@ -467,9 +463,7 @@ def __init__( self.default_qsub_args = default_qsub_args self.max_mem_free = max_mem_free - def run( - self, task: "Task", rerun: bool = False, environment: Environment | None = None - ): # TODO: add env + def run(self, task: "Task[DefType]", rerun: bool = False): # TODO: add env """Worker submission API.""" ( script_dir, @@ -899,17 +893,14 @@ def __init__(self, **kwargs): def run( self, - task: "Task", + task: "Task[DefType]", rerun: bool = False, - environment: Environment | None = None, **kwargs, ): """Run a task.""" - return self.exec_dask(task, rerun=rerun, environment=environment) + return self.exec_dask(task, rerun=rerun) - async def exec_dask( - self, task: "Task", rerun: bool = False, environment: Environment | None = None - ): + async def exec_dask(self, task: "Task[DefType]", rerun: bool = False): """Run a task (coroutine wrapper).""" from dask.distributed import Client @@ -950,13 +941,12 @@ def __init__(self, **kwargs): def run( self, - task: "Task", + task: "Task[DefType]", rerun: bool = False, - environment: Environment | None = None, **kwargs, ): """Run a task.""" - return self.exec_psij(task, rerun=rerun, environment=environment) + return self.exec_psij(task, rerun=rerun) def make_spec(self, cmd=None, arg=None): """ @@ -1001,7 +991,9 @@ def make_job(self, definition, attributes): return job async def exec_psij( - self, task: "Task", rerun: bool = False, environment: Environment | None = None + self, + task: "Task[DefType]", + rerun: bool = False, ): """ Run a task (coroutine wrapper). @@ -1025,7 +1017,7 @@ async def exec_psij( cache_dir = task.cache_dir file_path = cache_dir / "runnable_function.pkl" with open(file_path, "wb") as file: - pickle.dump(task._run, file) + pickle.dump(task.run, file) func_path = absolute_path / "run_pickled.py" definition = self.make_spec("python", [func_path, file_path]) else: # it could be tuple that includes pickle files with tasks and inputs