Skip to content

Commit

Permalink
touching up typing of tasks to include TaskDef template
Browse files Browse the repository at this point in the history
  • Loading branch information
tclose committed Jan 23, 2025
1 parent e5c2556 commit 03c7438
Show file tree
Hide file tree
Showing 7 changed files with 77 additions and 177 deletions.
2 changes: 1 addition & 1 deletion pydra/engine/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,7 +409,7 @@ async def run_async(self, rerun: bool = False):
self.audit.start_audit(odir=self.output_dir)
try:
self.audit.monitor()
await self.definition._run(self)
await self.definition._run_async(self)
result.outputs = self.definition.Outputs._from_task(self)
except Exception:
etype, eval, etr = sys.exc_info()
Expand Down
9 changes: 5 additions & 4 deletions pydra/engine/environments.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

if ty.TYPE_CHECKING:
from pydra.engine.core import Task
from pydra.engine.specs import ShellDef


class Environment:
Expand All @@ -17,7 +18,7 @@ class Environment:
def setup(self):
pass

def execute(self, task: "Task") -> dict[str, ty.Any]:
def execute(self, task: "Task[ShellDef]") -> dict[str, ty.Any]:
"""
Execute the task in the environment.
Expand All @@ -42,7 +43,7 @@ class Native(Environment):
Native environment, i.e. the tasks are executed in the current python environment.
"""

def execute(self, task: "Task") -> dict[str, ty.Any]:
def execute(self, task: "Task[ShellDef]") -> dict[str, ty.Any]:
keys = ["return_code", "stdout", "stderr"]
values = execute(task.definition._command_args())
output = dict(zip(keys, values))
Expand Down Expand Up @@ -90,7 +91,7 @@ def bind(self, loc, mode="ro"):
class Docker(Container):
"""Docker environment."""

def execute(self, task: "Task") -> dict[str, ty.Any]:
def execute(self, task: "Task[ShellDef]") -> dict[str, ty.Any]:
docker_img = f"{self.image}:{self.tag}"
# mounting all input locations
mounts = task.definition._get_bindings(root=self.root)
Expand Down Expand Up @@ -125,7 +126,7 @@ def execute(self, task: "Task") -> dict[str, ty.Any]:
class Singularity(Container):
"""Singularity environment."""

def execute(self, task: "Task") -> dict[str, ty.Any]:
def execute(self, task: "Task[ShellDef]") -> dict[str, ty.Any]:
singularity_img = f"{self.image}:{self.tag}"
# mounting all input locations
mounts = task.definition._get_bindings(root=self.root)
Expand Down
6 changes: 4 additions & 2 deletions pydra/engine/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@

PYDRA_ATTR_METADATA = "__PYDRA_METADATA__"

DefType = ty.TypeVar("DefType", bound="TaskDef")


def attrs_fields(definition, exclude_names=()) -> list[attrs.Attribute]:
"""Get the fields of a definition, excluding some names."""
Expand Down Expand Up @@ -132,7 +134,7 @@ def load_result(checksum, cache_locations):
def save(
task_path: Path,
result: "Result | None" = None,
task: "Task | None" = None,
task: "Task[DefType] | None" = None,
name_prefix: str = None,
) -> None:
"""
Expand Down Expand Up @@ -449,7 +451,7 @@ def load_and_run(task_pkl: Path, rerun: bool = False) -> Path:
from .specs import Result

try:
task: Task = load_task(task_pkl=task_pkl)
task: Task[DefType] = load_task(task_pkl=task_pkl)
except Exception:
if task_pkl.parent.exists():
etype, eval, etr = sys.exc_info()
Expand Down
4 changes: 3 additions & 1 deletion pydra/engine/lazy.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,11 @@
from .graph import DiGraph
from .submitter import NodeExecution
from .core import Task, Workflow
from .specs import TaskDef


T = ty.TypeVar("T")
DefType = ty.TypeVar("DefType", bound="TaskDef")

TypeOrAny = ty.Union[type, ty.Any]

Expand Down Expand Up @@ -150,7 +152,7 @@ def get_value(
task = graph.node(self.node.name).task(state_index)
_, split_depth = TypeParser.strip_splits(self.type)

def get_nested(task: "Task", depth: int):
def get_nested(task: "Task[DefType]", depth: int):
if isinstance(task, StateArray):
val = [get_nested(task=t, depth=depth - 1) for t in task]
if depth:
Expand Down
35 changes: 19 additions & 16 deletions pydra/engine/specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,14 @@
from pydra.engine.graph import DiGraph
from pydra.engine.submitter import NodeExecution
from pydra.engine.lazy import LazyOutField
from pydra.engine.task import ShellTask
from pydra.engine.core import Workflow
from pydra.engine.environments import Environment
from pydra.engine.workers import Worker


DefType = ty.TypeVar("DefType", bound="TaskDef")


def is_set(value: ty.Any) -> bool:
"""Check if a value has been set."""
return value not in (attrs.NOTHING, EMPTY)
Expand Down Expand Up @@ -372,7 +374,7 @@ def _compute_hashes(self) -> ty.Tuple[bytes, ty.Dict[str, bytes]]:
}
return hash_function(sorted(field_hashes.items())), field_hashes

def _retrieve_values(self, wf, state_index=None):
def _resolve_lazy_fields(self, wf, state_index=None):
"""Parse output results."""
temp_values = {}
for field in attrs_fields(self):
Expand Down Expand Up @@ -482,7 +484,7 @@ class Runtime:
class Result(ty.Generic[OutputsType]):
"""Metadata regarding the outputs of processing."""

task: "Task"
task: "Task[DefType]"
outputs: OutputsType | None = None
runtime: Runtime | None = None
errored: bool = False
Expand Down Expand Up @@ -548,13 +550,13 @@ class RuntimeSpec:
class PythonOutputs(TaskOutputs):

@classmethod
def _from_task(cls, task: "Task") -> Self:
def _from_task(cls, task: "Task[PythonDef]") -> Self:
"""Collect the outputs of a task from a combination of the provided inputs,
the objects in the output directory, and the stdout and stderr of the process.
Parameters
----------
task : Task
task : Task[PythonDef]
The task whose outputs are being collected.
outputs_dict : dict[str, ty.Any]
The outputs of the task, as a dictionary
Expand All @@ -575,7 +577,7 @@ def _from_task(cls, task: "Task") -> Self:

class PythonDef(TaskDef[PythonOutputsType]):

def _run(self, task: "Task") -> None:
def _run(self, task: "Task[PythonDef]") -> None:
# Prepare the inputs to the function
inputs = attrs_values(self)
del inputs["function"]
Expand All @@ -602,12 +604,12 @@ def _run(self, task: "Task") -> None:
class WorkflowOutputs(TaskOutputs):

@classmethod
def _from_task(cls, task: "Task") -> Self:
def _from_task(cls, task: "Task[WorkflowDef]") -> Self:
"""Collect the outputs of a workflow task from the outputs of the nodes in the
Parameters
----------
task : Task
task : Task[WorfklowDef]
The task whose outputs are being collected.
Returns
Expand Down Expand Up @@ -659,12 +661,13 @@ class WorkflowDef(TaskDef[WorkflowOutputsType]):

_constructed = attrs.field(default=None, init=False)

def _run(self, task: "Task") -> None:
def _run(self, task: "Task[WorkflowDef]") -> None:
"""Run the workflow."""
if task.submitter.worker.is_async:
task.submitter.expand_workflow_async(task)
else:
task.submitter.expand_workflow(task)
task.submitter.expand_workflow(task)

async def _run_async(self, task: "Task[WorkflowDef]") -> None:
"""Run the workflow asynchronously."""
await task.submitter.expand_workflow_async(task)

def construct(self) -> "Workflow":
from pydra.engine.core import Workflow
Expand All @@ -688,7 +691,7 @@ class ShellOutputs(TaskOutputs):
stderr: str = shell.out(help=STDERR_HELP)

@classmethod
def _from_task(cls, task: "ShellTask") -> Self:
def _from_task(cls, task: "Task[ShellDef]") -> Self:
"""Collect the outputs of a shell process from a combination of the provided inputs,
the objects in the output directory, and the stdout and stderr of the process.
Expand Down Expand Up @@ -784,7 +787,7 @@ def _required_fields_satisfied(cls, fld: shell.out, inputs: "ShellDef") -> bool:
def _resolve_value(
cls,
fld: "shell.out",
task: "Task",
task: "Task[DefType]",
) -> ty.Any:
"""Collect output file if metadata specified."""
from pydra.design import shell
Expand Down Expand Up @@ -842,7 +845,7 @@ class ShellDef(TaskDef[ShellOutputsType]):

RESERVED_FIELD_NAMES = TaskDef.RESERVED_FIELD_NAMES + ("cmdline",)

def _run(self, task: "Task") -> None:
def _run(self, task: "Task[ShellDef]") -> None:
"""Run the shell command."""
task.return_values = task.environment.execute(task)

Expand Down
Loading

0 comments on commit 03c7438

Please sign in to comment.