From 032fd4e04886ccdcfffb565317f2c4c06e54ee75 Mon Sep 17 00:00:00 2001 From: Tom Close Date: Mon, 9 Dec 2024 09:51:05 +1100 Subject: [PATCH] restored functionality from specs --- pydra/design/base.py | 39 +- pydra/design/python.py | 8 +- pydra/design/shell.py | 33 +- pydra/design/tests/test_workflow.py | 4 +- pydra/design/workflow.py | 14 +- pydra/engine/core.py | 108 +-- pydra/engine/helpers.py | 19 +- pydra/engine/helpers_file.py | 5 +- pydra/engine/specs.py | 1307 +++++++++++---------------- pydra/engine/task.py | 162 +--- pydra/engine/workflow/base.py | 23 +- pydra/engine/workflow/node.py | 21 +- pydra/utils/typing.py | 4 +- 13 files changed, 629 insertions(+), 1118 deletions(-) diff --git a/pydra/design/base.py b/pydra/design/base.py index 0fbf79ac82..3f959358af 100644 --- a/pydra/design/base.py +++ b/pydra/design/base.py @@ -14,6 +14,7 @@ ensure_list, PYDRA_ATTR_METADATA, list_fields, + is_lazy, ) from pydra.utils.typing import ( MultiInputObj, @@ -21,11 +22,10 @@ MultiOutputObj, MultiOutputFile, ) -from pydra.engine.workflow.lazy import LazyField if ty.TYPE_CHECKING: - from pydra.engine.specs import OutputsSpec + from pydra.engine.specs import TaskSpec, OutSpec from pydra.engine.core import Task __all__ = [ @@ -84,7 +84,9 @@ class Field: validator=is_type, default=ty.Any, converter=default_if_none(ty.Any) ) help_string: str = "" - requires: list | None = None + requires: list[str] | list[list[str]] = attrs.field( + factory=list, converter=ensure_list + ) converter: ty.Callable | None = None validator: ty.Callable | None = None @@ -240,6 +242,8 @@ def get_fields(klass, field_type, auto_attribs, helps) -> dict[str, Field]: def make_task_spec( + spec_type: type["TaskSpec"], + out_type: type["OutSpec"], task_type: type["Task"], inputs: dict[str, Arg], outputs: dict[str, Out], @@ -281,14 +285,16 @@ def make_task_spec( if name is None and klass is not None: name = klass.__name__ - outputs_klass = make_outputs_spec(outputs, outputs_bases, name) - if klass is None or not issubclass(klass, TaskSpec): + outputs_klass = make_outputs_spec(out_type, outputs, outputs_bases, name) + if klass is None or not issubclass(klass, spec_type): if name is None: raise ValueError("name must be provided if klass is not") + if klass is not None and issubclass(klass, TaskSpec): + raise ValueError(f"Cannot change type of spec {klass} to {spec_type}") bases = tuple(bases) # Ensure that TaskSpec is a base class - if not any(issubclass(b, TaskSpec) for b in bases): - bases = bases + (TaskSpec,) + if not any(issubclass(b, spec_type) for b in bases): + bases = bases + (spec_type,) # If building from a decorated class (as opposed to dynamically from a function # or shell-template), add any base classes not already in the bases tuple if klass is not None: @@ -346,8 +352,11 @@ def make_task_spec( def make_outputs_spec( - outputs: dict[str, Out], bases: ty.Sequence[type], spec_name: str -) -> type["OutputsSpec"]: + spec_type: type["OutSpec"], + outputs: dict[str, Out], + bases: ty.Sequence[type], + spec_name: str, +) -> type["OutSpec"]: """Create an outputs specification class and its outputs specification class from the output fields provided to the decorator/function. 
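For orientation, a minimal sketch of how the reworked `make_task_spec` is expected
to be called once `spec_type` and `out_type` are threaded through (the positional
order follows the call sites in this patch; the exact keyword arguments of the
`arg`/`out` field helpers are assumptions based on their use elsewhere in the diff,
not verbatim API):

    # hypothetical usage of make_task_spec() with the new leading parameters
    from pydra.design.base import make_task_spec
    from pydra.design import python
    from pydra.engine.specs import PythonSpec, PythonOutSpec
    from pydra.engine.task import FunctionTask

    # stand-ins for the field dicts normally parsed from a decorated function
    parsed_inputs = {"x": python.arg(name="x", type=int, help_string="an input")}
    parsed_outputs = {"out": python.out(name="out", type=int, help_string="an output")}

    SpecClass = make_task_spec(
        PythonSpec,     # spec_type: base class enforced for the generated task spec
        PythonOutSpec,  # out_type: base class enforced for the generated outputs spec
        FunctionTask,   # task_type: the engine Task that will execute the spec
        parsed_inputs,
        parsed_outputs,
        name="MyTask",
    )
    assert issubclass(SpecClass, PythonSpec)

The same pattern is repeated for the shell and workflow decorators below, with
ShellSpec/ShellOutSpec and WorkflowSpec/WorkflowOutSpec respectively.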
@@ -368,10 +377,14 @@ def make_outputs_spec(
     klass : type
         The class created using the attrs package
     """
-    from pydra.engine.specs import OutputsSpec
+    from pydra.engine.specs import OutSpec
 
-    if not any(issubclass(b, OutputsSpec) for b in bases):
-        outputs_bases = bases + (OutputsSpec,)
+    if not any(issubclass(b, spec_type) for b in bases):
+        if out_spec_bases := [b for b in bases if issubclass(b, OutSpec)]:
+            raise ValueError(
+                f"Cannot make {spec_type} output spec from {out_spec_bases} bases"
+            )
+        outputs_bases = bases + (spec_type,)
     if reserved_names := [n for n in outputs if n in RESERVED_OUTPUT_NAMES]:
         raise ValueError(
             f"{reserved_names} are reserved and cannot be used for output field names"
@@ -549,7 +562,7 @@ def make_validator(field: Field, interface_name: str) -> ty.Callable[..., None]
 def allowed_values_validator(_, attribute, value):
     """checking if the values is in allowed_values"""
     allowed = attribute.metadata[PYDRA_ATTR_METADATA].allowed_values
-    if value is attrs.NOTHING or isinstance(value, LazyField):
+    if value is attrs.NOTHING or is_lazy(value):
         pass
     elif value not in allowed:
         raise ValueError(
diff --git a/pydra/design/python.py b/pydra/design/python.py
index b25d36e010..9de6860e1d 100644
--- a/pydra/design/python.py
+++ b/pydra/design/python.py
@@ -2,7 +2,7 @@
 import inspect
 import attrs
 from pydra.engine.task import FunctionTask
-from pydra.engine.specs import TaskSpec
+from pydra.engine.specs import PythonSpec, PythonOutSpec
 from .base import (
     Arg,
     Out,
@@ -87,7 +87,7 @@ def define(
     bases: ty.Sequence[type] = (),
     outputs_bases: ty.Sequence[type] = (),
     auto_attribs: bool = True,
-) -> TaskSpec:
+) -> PythonSpec:
     """
     Create an interface for a function or a class.
 
@@ -103,7 +103,7 @@ def define(
         Whether to use auto_attribs mode when creating the class.
     """
 
-    def make(wrapped: ty.Callable | type) -> TaskSpec:
+    def make(wrapped: ty.Callable | type) -> PythonSpec:
         if inspect.isclass(wrapped):
             klass = wrapped
             function = klass.function
@@ -139,6 +139,8 @@ def make(wrapped: ty.Callable | type) -> TaskSpec:
             )
 
         interface = make_task_spec(
+            PythonSpec,
+            PythonOutSpec,
             FunctionTask,
             parsed_inputs,
             parsed_outputs,
diff --git a/pydra/design/shell.py b/pydra/design/shell.py
index 21d5d435c9..6587608960 100644
--- a/pydra/design/shell.py
+++ b/pydra/design/shell.py
@@ -11,7 +11,7 @@
 from fileformats.core import from_mime
 from fileformats import generic
 from fileformats.core.exceptions import FormatRecognitionError
-from pydra.engine.specs import TaskSpec
+from pydra.engine.specs import ShellSpec, ShellOutSpec
 from .base import (
     Arg,
     Out,
@@ -177,9 +177,9 @@ class outarg(Out, arg):
         inputs (entire inputs will be passed) or any input field name (a specific
         input field will be sent).
     path_template: str, optional
-        If provided, the field is treated also as an output field and it is added to
-        the output spec. The template can use other fields, e.g. {file1}. Used in order
-        to create an output specification.
+        The template used to specify where the output file will be written. It can
+        reference other input fields, e.g. {file1}, and is used to add the field
+        to the output specification.
     """
 
     path_template: str | None = attrs.field(default=None)
@@ -202,7 +201,7 @@ def define(
     outputs_bases: ty.Sequence[type] = (),
     auto_attribs: bool = True,
     name: str | None = None,
-) -> TaskSpec:
+) -> ShellSpec:
     """Create a task specification for a shell command.
Can be used either as a decorator on the "canonical" dataclass-form of a task specification or as a function that takes a "shell-command template string" of the form @@ -251,13 +250,13 @@ def define( Returns ------- - TaskSpec + ShellSpec The interface for the shell command """ def make( wrapped: ty.Callable | type | None = None, - ) -> TaskSpec: + ) -> ShellSpec: if inspect.isclass(wrapped): klass = wrapped @@ -272,6 +271,14 @@ def make( f"Shell task class {wrapped} must have an `executable` " "attribute that specifies the command to run" ) from None + if not isinstance(executable, str) and not ( + isinstance(executable, ty.Sequence) + and all(isinstance(e, str) for e in executable) + ): + raise ValueError( + "executable must be a string or a sequence of strings" + f", not {executable!r}" + ) class_name = klass.__name__ check_explicit_fields_are_none(klass, inputs, outputs) parsed_inputs, parsed_outputs = extract_fields_from_class( @@ -309,7 +316,15 @@ def make( {o.name: o for o in parsed_outputs.values() if isinstance(o, arg)} ) parsed_inputs["executable"] = arg( - name="executable", type=str, argstr="", position=0, default=executable + name="executable", + type=str | ty.Sequence[str], + argstr="", + position=0, + default=executable, + help_string=( + "the first part of the command, can be a string, " + "e.g. 'ls', or a list, e.g. ['ls', '-l', 'dirname']" + ), ) # Set positions for the remaining inputs that don't have an explicit position @@ -319,6 +334,8 @@ def make( inpt.position = position_stack.pop(0) interface = make_task_spec( + ShellSpec, + ShellOutSpec, ShellCommandTask, parsed_inputs, parsed_outputs, diff --git a/pydra/design/tests/test_workflow.py b/pydra/design/tests/test_workflow.py index 9f25bf81bd..f4c1c0c19d 100644 --- a/pydra/design/tests/test_workflow.py +++ b/pydra/design/tests/test_workflow.py @@ -344,7 +344,9 @@ def MyTestWorkflow(a: list[int], b: list[float]) -> list[float]: wf = Workflow.construct(MyTestWorkflow(a=[1, 2, 3], b=[1.0, 10.0, 100.0])) assert wf["Mul"].splitter == ["Mul.x", "Mul.y"] assert wf["Mul"].combiner == ["Mul.x"] - assert wf.outputs.out == LazyOutField(node=wf["Sum"], field="out", type=list[float]) + assert wf.outputs.out == LazyOutField( + node=wf["Sum"], field="out", type=list[float], type_checked=True + ) def test_workflow_split_combine2(): diff --git a/pydra/design/workflow.py b/pydra/design/workflow.py index 75ac13197f..86a9f3ca9a 100644 --- a/pydra/design/workflow.py +++ b/pydra/design/workflow.py @@ -13,7 +13,7 @@ check_explicit_fields_are_none, extract_fields_from_class, ) -from pydra.engine.specs import TaskSpec +from pydra.engine.specs import TaskSpec, OutSpec, WorkflowSpec, WorkflowOutSpec __all__ = ["define", "add", "this", "arg", "out"] @@ -154,6 +154,8 @@ def make(wrapped: ty.Callable | type) -> TaskSpec: parsed_inputs[inpt_name].lazy = True interface = make_task_spec( + WorkflowSpec, + WorkflowOutSpec, WorkflowTask, parsed_inputs, parsed_outputs, @@ -172,9 +174,6 @@ def make(wrapped: ty.Callable | type) -> TaskSpec: return make -OutputType = ty.TypeVar("OutputType") - - def this() -> Workflow: """Get the workflow currently being constructed. 
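To make the new outputs-spec typing concrete, a small hypothetical example of how
`this()` and `add()` are used while a workflow is being constructed (`Mul` is an
illustrative task spec, not part of this patch):

    from pydra.design import python, workflow

    @python.define
    def Mul(x: float, y: float) -> float:
        return x * y

    @workflow.define
    def MyWorkflow(a: float, b: float) -> float:
        # add() returns the node's outputs spec (an OutSpec subclass instance)
        # whose fields are lazy references to that node's outputs
        mul = workflow.add(Mul(x=a, y=b))
        # this() exposes the Workflow object currently under construction
        assert workflow.this()["Mul"] is not None
        return mul.out

Typing `add()` against a TypeVar bound to the outputs-spec base class (see the
following hunk) means the returned outputs object is statically known to match
the spec that was added.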
@@ -186,7 +185,10 @@ def this() -> Workflow: return Workflow.under_construction -def add(task_spec: TaskSpec[OutputType], name: str = None) -> OutputType: +OutSpecType = ty.TypeVar("OutSpecType", bound=OutSpec) + + +def add(task_spec: TaskSpec[OutSpecType], name: str = None) -> OutSpecType: """Add a node to the workflow currently being constructed Parameters @@ -199,7 +201,7 @@ def add(task_spec: TaskSpec[OutputType], name: str = None) -> OutputType: Returns ------- - OutputType + OutSpec The outputs specification of the node """ return this().add(task_spec, name=name) diff --git a/pydra/engine/core.py b/pydra/engine/core.py index 4607e23f71..7cf35f455f 100644 --- a/pydra/engine/core.py +++ b/pydra/engine/core.py @@ -3,8 +3,6 @@ import abc import json import logging -import itertools -from functools import cached_property import os import sys from pathlib import Path @@ -21,16 +19,12 @@ from . import helpers_state as hlpst from .specs import ( File, - # BaseSpec, RuntimeSpec, Result, - # SpecInfo, - # LazyIn, - # LazyOut, TaskHook, ) +from .workflow.lazy import is_lazy from .helpers import ( - # make_klass, create_checksum, attr_fields, print_help, @@ -138,8 +132,6 @@ def __init__( self.interface = spec # raise error if name is same as of attributes - if name in dir(self): - raise ValueError("Cannot use names of attributes or methods as task name") self.name = name if not self.input_spec: raise Exception("No input_spec in class: %s" % self.__class__.__name__) @@ -227,10 +219,6 @@ def __setstate__(self, state): state["inputs"] = self.interface(**state["inputs"]) self.__dict__.update(state) - @cached_property - def lzout(self): - return LazyOut(self) - def help(self, returnhelp=False): """Print class help.""" help_obj = print_help(self) @@ -818,80 +806,6 @@ def _check_for_hash_changes(self): DEFAULT_COPY_COLLATION = FileSet.CopyCollation.any -def _sanitize_spec( - spec: ty.Union[ty.List[str], ty.Dict[str, ty.Type[ty.Any]], None], - wf_name: str, - spec_name: str, - allow_empty: bool = False, -): - """Makes sure the provided input specifications are valid. - - If the input specification is a list of strings, this will - build a proper SpecInfo object out of it. - - Parameters - ---------- - spec : SpecInfo or List[str] or Dict[str, type] - Specification to be sanitized. - wf_name : str - The name of the workflow for which the input specifications - spec_name : str - name given to generated SpecInfo object - - Returns - ------- - spec : SpecInfo - Sanitized specification. - - Raises - ------ - ValueError - If provided `spec` is None. 
- """ - graph_checksum_input = ("_graph_checksums", ty.Any) - if spec: - if isinstance(spec, SpecInfo): - if BaseSpec not in spec.bases: - raise ValueError("Provided SpecInfo must have BaseSpec as its base.") - if "_graph_checksums" not in {f[0] for f in spec.fields}: - spec.fields.insert(0, graph_checksum_input) - return spec - else: - base = BaseSpec - if isinstance(spec, list): - typed_spec = zip(spec, itertools.repeat(ty.Any)) - elif isinstance(spec, dict): - typed_spec = spec.items() # type: ignore - elif isinstance(spec, BaseSpec): - base = spec - typed_spec = [] - else: - raise TypeError( - f"Unrecognised spec type, {spec}, should be SpecInfo, list or dict" - ) - return SpecInfo( - name=spec_name, - fields=[graph_checksum_input] - + [ - ( - nm, - attr.ib( - type=tp, - metadata={ - "help_string": f"{nm} input from {wf_name} workflow" - }, - ), - ) - for nm, tp in typed_spec - ], - bases=(base,), - ) - elif allow_empty: - return None - else: - raise ValueError(f'Empty "{spec_name}" spec provided to Workflow {wf_name}.') - - class WorkflowTask(Task): """A composite task with structure of computational graph.""" @@ -939,10 +853,6 @@ def __init__( TODO """ - self.input_spec = _sanitize_spec(input_spec, name, "Inputs") - self.output_spec = _sanitize_spec( - output_spec, name, "Outputs", allow_empty=True - ) if name in dir(self): raise ValueError( @@ -974,10 +884,6 @@ def __init__( # propagating rerun if task_rerun=True self.propagate_rerun = propagate_rerun - @cached_property - def lzin(self): - return LazyIn(self) - def __getattr__(self, name): if name in self.name2obj: return self.name2obj[name] @@ -1075,7 +981,7 @@ def create_connections(self, task, detailed=False): other_states = {} for field in attr_fields(task.inputs): val = getattr(task.inputs, field.name) - if isinstance(val, LazyField): + if is_lazy(val): # saving all connections with LazyFields task.inp_lf[field.name] = val # adding an edge to the graph if task id expecting output from a different task @@ -1292,7 +1198,7 @@ def _collect_outputs(self): # collecting outputs from tasks output_wf = {} for name, val in self._connections: - if not isinstance(val, LazyField): + if not is_lazy(val): raise ValueError("all connections must be lazy") try: val_out = val.get_value(self) @@ -1395,11 +1301,3 @@ def is_task(obj): def is_workflow(obj): """Check whether an object is a :class:`Workflow` instance.""" return isinstance(obj, WorkflowTask) - - -def is_lazy(obj): - """Check whether an object has any field that is a Lazy Field""" - for f in attr_fields(obj): - if isinstance(getattr(obj, f.name), LazyField): - return True - return False diff --git a/pydra/engine/helpers.py b/pydra/engine/helpers.py index 92efc9de53..dc85205521 100644 --- a/pydra/engine/helpers.py +++ b/pydra/engine/helpers.py @@ -45,11 +45,10 @@ def list_fields(interface: "TaskSpec") -> list["Field"]: def from_list_if_single(obj): """Converts a list to a single item if it is of length == 1""" - from pydra.engine.workflow.lazy import LazyField if obj is attrs.NOTHING: return obj - if isinstance(obj, LazyField): + if is_lazy(obj): return obj obj = list(obj) if len(obj) == 1: @@ -637,7 +636,6 @@ def ensure_list(obj, tuple2list=False): [5.0] """ - from pydra.engine.workflow.lazy import LazyField if obj is attrs.NOTHING: return attrs.NOTHING @@ -648,6 +646,19 @@ def ensure_list(obj, tuple2list=False): return obj elif tuple2list and isinstance(obj, tuple): return list(obj) - elif isinstance(obj, LazyField): + elif is_lazy(obj): return obj return [obj] + + +def is_lazy(obj): + 
"""Check whether an object has any field that is a Lazy Field""" + from pydra.engine.workflow.lazy import LazyField + + if is_lazy(obj): + return True + + for f in attr_fields(obj): + if isinstance(getattr(obj, f.name), LazyField): + return True + return False diff --git a/pydra/engine/helpers_file.py b/pydra/engine/helpers_file.py index 8be955b20e..f846e40db2 100644 --- a/pydra/engine/helpers_file.py +++ b/pydra/engine/helpers_file.py @@ -10,6 +10,7 @@ from contextlib import contextmanager import attr from fileformats.core import FileSet +from pydra.engine.helpers import is_lazy logger = logging.getLogger("pydra") @@ -151,7 +152,7 @@ def template_update_single( # if input_dict_st with state specific value is not available, # the dictionary will be created from inputs object from pydra.utils.typing import TypeParser # noqa - from pydra.engine.specs import LazyField, OUTPUT_TEMPLATE_TYPES + from pydra.engine.specs import OUTPUT_TEMPLATE_TYPES if inputs_dict_st is None: inputs_dict_st = attr.asdict(inputs, recurse=False) @@ -162,7 +163,7 @@ def template_update_single( raise TypeError( f"type of '{field.name}' is Path, consider using Union[Path, bool]" ) - if inp_val_set is not attr.NOTHING and not isinstance(inp_val_set, LazyField): + if inp_val_set is not attr.NOTHING and not is_lazy(inp_val_set): inp_val_set = TypeParser(ty.Union[OUTPUT_TEMPLATE_TYPES])(inp_val_set) elif spec_type == "output": if not TypeParser.contains_type(FileSet, field.type): diff --git a/pydra/engine/specs.py b/pydra/engine/specs.py index 0cdc4f07f2..4b45e9cf7b 100644 --- a/pydra/engine/specs.py +++ b/pydra/engine/specs.py @@ -1,805 +1,22 @@ """Task I/O specifications.""" +import os from pathlib import Path +import re +import inspect import typing as ty - -# import inspect -# import re -# import os -from pydra.engine.audit import AuditFlag - -# from glob import glob -import attrs +from glob import glob from typing_extensions import Self - -# from fileformats.core import FileSet -from fileformats.generic import ( - File, - # Directory, -) -from .helpers import attr_fields - -# from .helpers_file import template_update_single -# from pydra.utils.hash import hash_function, Cache - -# from pydra.utils.misc import add_exc_note - - -# @attrs.define(auto_attribs=True, kw_only=True) -# class SpecInfo: -# """Base data structure for metadata of specifications.""" - -# name: str -# """A name for the specification.""" -# fields: ty.List[ty.Tuple] = attrs.field(factory=list) -# """List of names of fields (can be inputs or outputs).""" -# bases: ty.Sequence[ty.Type["BaseSpec"]] = attrs.field(factory=tuple) -# """Keeps track of specification inheritance. 
-# Should be a tuple containing at least one BaseSpec """ - - -# @attrs.define(auto_attribs=True, kw_only=True) -# class BaseSpec: -# """The base dataclass specs for all inputs and outputs.""" - -# def collect_additional_outputs(self, inputs, output_dir, outputs): -# """Get additional outputs.""" -# return {} - -# @property -# def hash(self): -# hsh, self._hashes = self._compute_hashes() -# return hsh - -# def hash_changes(self): -# """Detects any changes in the hashed values between the current inputs and the -# previously calculated values""" -# _, new_hashes = self._compute_hashes() -# return [k for k, v in new_hashes.items() if v != self._hashes[k]] - -# def _compute_hashes(self) -> ty.Tuple[bytes, ty.Dict[str, bytes]]: -# """Compute a basic hash for any given set of fields.""" -# inp_dict = {} -# for field in attr_fields( -# self, exclude_names=("_graph_checksums", "bindings", "files_hash") -# ): -# if field.metadata.get("output_file_template"): -# continue -# # removing values that are not set from hash calculation -# if getattr(self, field.name) is attrs.NOTHING: -# continue -# if "container_path" in field.metadata: -# continue -# inp_dict[field.name] = getattr(self, field.name) -# hash_cache = Cache() -# field_hashes = { -# k: hash_function(v, cache=hash_cache) for k, v in inp_dict.items() -# } -# if hasattr(self, "_graph_checksums"): -# field_hashes["_graph_checksums"] = self._graph_checksums -# return hash_function(sorted(field_hashes.items())), field_hashes - -# def retrieve_values(self, wf, state_index: ty.Optional[int] = None): -# """Get values contained by this spec.""" -# retrieved_values = {} -# for field in attr_fields(self): -# value = getattr(self, field.name) -# if isinstance(value, LazyField): -# retrieved_values[field.name] = value.get_value( -# wf, state_index=state_index -# ) -# for field, val in retrieved_values.items(): -# setattr(self, field, val) - -# def check_fields_input_spec(self): -# """ -# Check fields from input spec based on the medatada. - -# e.g., if xor, requires are fulfilled, if value provided when mandatory. - -# """ -# fields = attr_fields(self) - -# for field in fields: -# field_is_mandatory = bool(field.metadata.get("mandatory")) -# field_is_unset = getattr(self, field.name) is attrs.NOTHING - -# if field_is_unset and not field_is_mandatory: -# continue - -# # Collect alternative fields associated with this field. -# alternative_fields = { -# name: getattr(self, name) is not attrs.NOTHING -# for name in field.metadata.get("xor", []) -# if name != field.name -# } -# alternatives_are_set = any(alternative_fields.values()) - -# # Raise error if no field in mandatory alternative group is set. -# if field_is_unset: -# if alternatives_are_set: -# continue -# message = f"{field.name} is mandatory and unset." -# if alternative_fields: -# raise AttributeError( -# message[:-1] -# + f", but no alternative provided by {list(alternative_fields)}." -# ) -# else: -# raise AttributeError(message) - -# # Raise error if multiple alternatives are set. -# elif alternatives_are_set: -# set_alternative_fields = [ -# name for name, is_set in alternative_fields.items() if is_set -# ] -# raise AttributeError( -# f"{field.name} is mutually exclusive with {set_alternative_fields}" -# ) - -# # Collect required fields associated with this field. -# required_fields = { -# name: getattr(self, name) is not attrs.NOTHING -# for name in field.metadata.get("requires", []) -# if name != field.name -# } - -# # Raise error if any required field is unset. 
-# if not all(required_fields.values()): -# unset_required_fields = [ -# name for name, is_set in required_fields.items() if not is_set -# ] -# raise AttributeError(f"{field.name} requires {unset_required_fields}") - -# def check_metadata(self): -# """Check contained metadata.""" - -# def template_update(self): -# """Update template.""" - -# def copyfile_input(self, output_dir): -# """Copy the file pointed by a :class:`File` input.""" - - -@attrs.define(auto_attribs=True, kw_only=True) -class Runtime: - """Represent run time metadata.""" - - rss_peak_gb: ty.Optional[float] = None - """Peak in consumption of physical RAM.""" - vms_peak_gb: ty.Optional[float] = None - """Peak in consumption of virtual memory.""" - cpu_peak_percent: ty.Optional[float] = None - """Peak in cpu consumption.""" - - -@attrs.define(auto_attribs=True, kw_only=True) -class Result: - """Metadata regarding the outputs of processing.""" - - output: ty.Optional[ty.Any] = None - runtime: ty.Optional[Runtime] = None - errored: bool = False - - def __getstate__(self): - state = self.__dict__.copy() - if state["output"] is not None: - fields = tuple((el.name, el.type) for el in attr_fields(state["output"])) - state["output_spec"] = (state["output"].__class__.__name__, fields) - state["output"] = attrs.asdict(state["output"], recurse=False) - return state - - def __setstate__(self, state): - if "output_spec" in state: - spec = list(state["output_spec"]) - del state["output_spec"] - klass = attrs.make_class( - spec[0], {k: attrs.field(type=v) for k, v in list(spec[1])} - ) - state["output"] = klass(**state["output"]) - self.__dict__.update(state) - - def get_output_field(self, field_name): - """Used in get_values in Workflow - - Parameters - ---------- - field_name : `str` - Name of field in LazyField object - """ - if field_name == "all_": - return attrs.asdict(self.output, recurse=False) - else: - return getattr(self.output, field_name) - - -@attrs.define(auto_attribs=True, kw_only=True) -class RuntimeSpec: - """ - Specification for a task. - - From CWL:: - - InlineJavascriptRequirement - SchemaDefRequirement - DockerRequirement - SoftwareRequirement - InitialWorkDirRequirement - EnvVarRequirement - ShellCommandRequirement - ResourceRequirement - - InlineScriptRequirement - - """ - - outdir: ty.Optional[str] = None - container: ty.Optional[str] = "shell" - network: bool = False - - -# @attrs.define(auto_attribs=True, kw_only=True) -# class FunctionSpec(BaseSpec): -# """Specification for a process invoked from a shell.""" - -# def check_metadata(self): -# """ -# Check the metadata for fields in input_spec and fields. - -# Also sets the default values when available and needed. 
- -# """ -# supported_keys = { -# "allowed_values", -# "copyfile", -# "help_string", -# "mandatory", -# # "readonly", #likely not needed -# # "output_field_name", #likely not needed -# # "output_file_template", #likely not needed -# "requires", -# "keep_extension", -# "xor", -# "sep", -# } -# for fld in attr_fields(self, exclude_names=("_func", "_graph_checksums")): -# mdata = fld.metadata -# # checking keys from metadata -# if set(mdata.keys()) - supported_keys: -# raise AttributeError( -# f"only these keys are supported {supported_keys}, but " -# f"{set(mdata.keys()) - supported_keys} provided" -# ) -# # checking if the help string is provided (required field) -# if "help_string" not in mdata: -# raise AttributeError(f"{fld.name} doesn't have help_string field") -# # not allowing for default if the field is mandatory -# if not fld.default == attrs.NOTHING and mdata.get("mandatory"): -# raise AttributeError( -# f"default value ({fld.default!r}) should not be set when the field " -# f"('{fld.name}') in {self}) is mandatory" -# ) -# # setting default if value not provided and default is available -# if getattr(self, fld.name) is None: -# if not fld.default == attrs.NOTHING: -# setattr(self, fld.name, fld.default) - - -# @attrs.define(auto_attribs=True, kw_only=True) -# class ShellSpec(BaseSpec): -# """Specification for a process invoked from a shell.""" - -# executable: ty.Union[str, ty.List[str]] = attrs.field( -# metadata={ -# "help_string": "the first part of the command, can be a string, " -# "e.g. 'ls', or a list, e.g. ['ls', '-l', 'dirname']" -# } -# ) -# args: ty.Union[str, ty.List[str], None] = attrs.field( -# default=None, -# metadata={ -# "help_string": "the last part of the command, can be a string, " -# "e.g. , or a list" -# }, -# ) - -# def retrieve_values(self, wf, state_index=None): -# """Parse output results.""" -# temp_values = {} -# for field in attr_fields(self): -# # retrieving values that do not have templates -# if not field.metadata.get("output_file_template"): -# value = getattr(self, field.name) -# if isinstance(value, LazyField): -# temp_values[field.name] = value.get_value( -# wf, state_index=state_index -# ) -# for field, val in temp_values.items(): -# value = path_to_string(value) -# setattr(self, field, val) - -# def check_metadata(self): -# """ -# Check the metadata for fields in input_spec and fields. - -# Also sets the default values when available and needed. 
- -# """ -# from pydra.utils.typing import TypeParser - -# supported_keys = { -# "allowed_values", -# "argstr", -# "container_path", -# "copyfile", -# "help_string", -# "mandatory", -# "readonly", -# "output_field_name", -# "output_file_template", -# "position", -# "requires", -# "keep_extension", -# "xor", -# "sep", -# "formatter", -# "_output_type", -# } - -# for fld in attr_fields(self, exclude_names=("_func", "_graph_checksums")): -# mdata = fld.metadata -# # checking keys from metadata -# if set(mdata.keys()) - supported_keys: -# raise AttributeError( -# f"only these keys are supported {supported_keys}, but " -# f"{set(mdata.keys()) - supported_keys} provided for '{fld.name}' " -# f"field in {self}" -# ) -# # checking if the help string is provided (required field) -# if "help_string" not in mdata: -# raise AttributeError( -# f"{fld.name} doesn't have help_string field in {self}" -# ) -# # assuming that fields with output_file_template shouldn't have default -# if mdata.get("output_file_template"): -# if not any( -# TypeParser.matches_type(fld.type, t) for t in OUTPUT_TEMPLATE_TYPES -# ): -# raise TypeError( -# f"Type of '{fld.name}' should be one of {OUTPUT_TEMPLATE_TYPES} " -# f"(not {fld.type}) because it has a value for output_file_template " -# f"({mdata['output_file_template']!r})" -# ) -# if fld.default not in [attrs.NOTHING, True, False]: -# raise AttributeError( -# f"default value ({fld.default!r}) should not be set together with " -# f"output_file_template ({mdata['output_file_template']!r}) for " -# f"'{fld.name}' field in {self}" -# ) -# # not allowing for default if the field is mandatory -# if not fld.default == attrs.NOTHING and mdata.get("mandatory"): -# raise AttributeError( -# f"default value ({fld.default!r}) should not be set when the field " -# f"('{fld.name}') in {self}) is mandatory" -# ) -# # setting default if value not provided and default is available -# if getattr(self, fld.name) is None: -# if not fld.default == attrs.NOTHING: -# setattr(self, fld.name, fld.default) - - -# @attrs.define(auto_attribs=True, kw_only=True) -# class ShellOutSpec: -# """Output specification of a generic shell process.""" - -# return_code: int -# """The process' exit code.""" -# stdout: str -# """The process' standard output.""" -# stderr: str -# """The process' standard input.""" - -# def collect_additional_outputs(self, inputs, output_dir, outputs): -# from pydra.utils.typing import TypeParser - -# """Collect additional outputs from shelltask output_spec.""" -# additional_out = {} -# for fld in attr_fields(self, exclude_names=("return_code", "stdout", "stderr")): -# if not TypeParser.is_subclass( -# fld.type, -# ( -# os.PathLike, -# MultiOutputObj, -# int, -# float, -# bool, -# str, -# list, -# ), -# ): -# raise TypeError( -# f"Support for {fld.type} type, required for '{fld.name}' in {self}, " -# "has not been implemented in collect_additional_output" -# ) -# # assuming that field should have either default or metadata, but not both -# input_value = getattr(inputs, fld.name, attrs.NOTHING) -# if fld.metadata and "callable" in fld.metadata: -# fld_out = self._field_metadata(fld, inputs, output_dir, outputs) -# elif fld.type in [int, float, bool, str, list]: -# raise AttributeError(f"{fld.type} has to have a callable in metadata") -# elif input_value: # Map input value through to output -# fld_out = input_value -# elif fld.default != attrs.NOTHING: -# fld_out = self._field_defaultvalue(fld, output_dir) -# else: -# raise AttributeError("File has to have default value or 
metadata") -# if TypeParser.contains_type(FileSet, fld.type): -# label = f"output field '{fld.name}' of {self}" -# fld_out = TypeParser(fld.type, label=label).coerce(fld_out) -# additional_out[fld.name] = fld_out -# return additional_out - -# def generated_output_names(self, inputs, output_dir): -# """Returns a list of all outputs that will be generated by the task. -# Takes into account the task input and the requires list for the output fields. -# TODO: should be in all Output specs? -# """ -# # checking the input (if all mandatory fields are provided, etc.) -# inputs.check_fields_input_spec() -# output_names = ["return_code", "stdout", "stderr"] -# for fld in attr_fields(self, exclude_names=("return_code", "stdout", "stderr")): -# if fld.type not in [File, MultiOutputFile, Directory]: -# raise Exception("not implemented (collect_additional_output)") -# # assuming that field should have either default or metadata, but not both -# if ( -# fld.default in (None, attrs.NOTHING) and not fld.metadata -# ): # TODO: is it right? -# raise AttributeError("File has to have default value or metadata") -# elif fld.default != attrs.NOTHING: -# output_names.append(fld.name) -# elif ( -# fld.metadata -# and self._field_metadata( -# fld, inputs, output_dir, outputs=None, check_existance=False -# ) -# != attrs.NOTHING -# ): -# output_names.append(fld.name) -# return output_names - -# def _field_defaultvalue(self, fld, output_dir): -# """Collect output file if the default value specified.""" -# if not isinstance(fld.default, (str, Path)): -# raise AttributeError( -# f"{fld.name} is a File, so default value " -# f"should be a string or a Path, " -# f"{fld.default} provided" -# ) -# default = fld.default -# if isinstance(default, str): -# default = Path(default) - -# default = output_dir / default -# if "*" not in str(default): -# if default.exists(): -# return default -# else: -# raise AttributeError(f"file {default} does not exist") -# else: -# all_files = [Path(el) for el in glob(str(default.expanduser()))] -# if len(all_files) > 1: -# return all_files -# elif len(all_files) == 1: -# return all_files[0] -# else: -# raise AttributeError(f"no file matches {default.name}") - -# def _field_metadata( -# self, fld, inputs, output_dir, outputs=None, check_existance=True -# ): -# """Collect output file if metadata specified.""" -# if self._check_requires(fld, inputs) is False: -# return attrs.NOTHING - -# if "value" in fld.metadata: -# return output_dir / fld.metadata["value"] -# # this block is only run if "output_file_template" is provided in output_spec -# # if the field is set in input_spec with output_file_template, -# # than the field already should have value -# elif "output_file_template" in fld.metadata: -# value = template_update_single( -# fld, inputs=inputs, output_dir=output_dir, spec_type="output" -# ) - -# if fld.type is MultiOutputFile and type(value) is list: -# # TODO: how to deal with mandatory list outputs -# ret = [] -# for val in value: -# val = Path(val) -# if check_existance and not val.exists(): -# ret.append(attrs.NOTHING) -# else: -# ret.append(val) -# return ret -# else: -# val = Path(value) -# # checking if the file exists -# if check_existance and not val.exists(): -# # if mandatory raise exception -# if "mandatory" in fld.metadata: -# if fld.metadata["mandatory"]: -# raise Exception( -# f"mandatory output for variable {fld.name} does not exist" -# ) -# return attrs.NOTHING -# return val -# elif "callable" in fld.metadata: -# callable_ = fld.metadata["callable"] -# if 
isinstance(callable_, staticmethod): -# # In case callable is defined as a static method, -# # retrieve the function wrapped in the descriptor. -# callable_ = callable_.__func__ -# call_args = inspect.getfullargspec(callable_) -# call_args_val = {} -# for argnm in call_args.args: -# if argnm == "field": -# call_args_val[argnm] = fld -# elif argnm == "output_dir": -# call_args_val[argnm] = output_dir -# elif argnm == "inputs": -# call_args_val[argnm] = inputs -# elif argnm == "stdout": -# call_args_val[argnm] = outputs["stdout"] -# elif argnm == "stderr": -# call_args_val[argnm] = outputs["stderr"] -# else: -# try: -# call_args_val[argnm] = getattr(inputs, argnm) -# except AttributeError: -# raise AttributeError( -# f"arguments of the callable function from {fld.name} " -# f"has to be in inputs or be field or output_dir, " -# f"but {argnm} is used" -# ) -# return callable_(**call_args_val) -# else: -# raise Exception( -# f"Metadata for '{fld.name}', does not not contain any of the required fields " -# f'("callable", "output_file_template" or "value"): {fld.metadata}.' -# ) - -# def _check_requires(self, fld, inputs): -# """checking if all fields from the requires and template are set in the input -# if requires is a list of list, checking if at least one list has all elements set -# """ -# from .helpers import ensure_list - -# if "requires" in fld.metadata: -# # if requires is a list of list it is treated as el[0] OR el[1] OR... -# required_fields = ensure_list(fld.metadata["requires"]) -# if all([isinstance(el, list) for el in required_fields]): -# field_required_OR = required_fields -# # if requires is a list of tuples/strings - I'm creating a 1-el nested list -# elif all([isinstance(el, (str, tuple)) for el in required_fields]): -# field_required_OR = [required_fields] -# else: -# raise Exception( -# f"requires field can be a list of list, or a list " -# f"of strings/tuples, but {fld.metadata['requires']} " -# f"provided for {fld.name}" -# ) -# else: -# field_required_OR = [[]] - -# for field_required in field_required_OR: -# # if the output has output_file_template field, -# # adding all input fields from the template to requires -# if "output_file_template" in fld.metadata: -# template = fld.metadata["output_file_template"] -# # if a template is a function it has to be run first with the inputs as the only arg -# if callable(template): -# template = template(inputs) -# inp_fields = re.findall(r"{\w+}", template) -# field_required += [ -# el[1:-1] for el in inp_fields if el[1:-1] not in field_required -# ] - -# # it's a flag, of the field from the list is not in input it will be changed to False -# required_found = True -# for field_required in field_required_OR: -# required_found = True -# # checking if the input fields from requires have set values -# for inp in field_required: -# if isinstance(inp, str): # name of the input field -# if not hasattr(inputs, inp): -# raise Exception( -# f"{inp} is not a valid input field, can't be used in requires" -# ) -# elif getattr(inputs, inp) in [attrs.NOTHING, None]: -# required_found = False -# break -# elif isinstance(inp, tuple): # (name, allowed values) -# inp, allowed_val = inp[0], ensure_list(inp[1]) -# if not hasattr(inputs, inp): -# raise Exception( -# f"{inp} is not a valid input field, can't be used in requires" -# ) -# elif getattr(inputs, inp) not in allowed_val: -# required_found = False -# break -# else: -# raise Exception( -# f"each element of the requires element should be a string or a tuple, " -# f"but {inp} is found in 
{field_required}" -# ) -# # if the specific list from field_required_OR has all elements set, no need to check more -# if required_found: -# break - -# if required_found: -# return True -# else: -# return False - - -# @attrs.define -# class LazyInterface: -# _task: "core.Task" = attrs.field() -# _attr_type: str - -# def __getattr__(self, name): -# if name in ("_task", "_attr_type", "_field_names"): -# raise AttributeError(f"{name} hasn't been set yet") -# if name not in self._field_names: -# raise AttributeError( -# f"Task '{self._task.name}' has no {self._attr_type} attribute '{name}', " -# "available: '" + "', '".join(self._field_names) + "'" -# ) -# type_ = self._get_type(name) -# splits = self._get_task_splits() -# combines = self._get_task_combines() -# if combines and self._attr_type == "output": -# # Add in any scalar splits referencing upstream splits, i.e. "_myupstreamtask", -# # "_myarbitrarytask" -# combined_upstreams = set() -# if self._task.state: -# for scalar in LazyField.normalize_splitter( -# self._task.state.splitter, strip_previous=False -# ): -# for field in scalar: -# if field.startswith("_"): -# node_name = field[1:] -# if any(c.split(".")[0] == node_name for c in combines): -# combines.update( -# f for f in scalar if not f.startswith("_") -# ) -# combined_upstreams.update( -# f[1:] for f in scalar if f.startswith("_") -# ) -# if combines: -# # Wrap type in list which holds the combined items -# type_ = ty.List[type_] -# # Iterate through splits to remove any splits which are removed by the -# # combiner -# for splitter in copy(splits): -# remaining = tuple( -# s -# for s in splitter -# if not any( -# (x in combines or x.split(".")[0] in combined_upstreams) -# for x in s -# ) -# ) -# if remaining != splitter: -# splits.remove(splitter) -# if remaining: -# splits.add(remaining) -# # Wrap the type in a nested StateArray type -# if splits: -# type_ = StateArray[type_] -# lf_klass = LazyInField if self._attr_type == "input" else LazyOutField -# return lf_klass[type_]( -# name=self._task.name, -# field=name, -# type=type_, -# splits=splits, -# ) - -# def _get_task_splits(self) -> ty.Set[ty.Tuple[ty.Tuple[str, ...], ...]]: -# """Returns the states over which the inputs of the task are split""" -# splitter = self._task.state.splitter if self._task.state else None -# splits = set() -# if splitter: -# # Ensure that splits is of tuple[tuple[str, ...], ...] 
form
-#             splitter = LazyField.normalize_splitter(splitter)
-#             if splitter:
-#                 splits.add(splitter)
-#         for inpt in attrs.asdict(self._task.inputs, recurse=False).values():
-#             if isinstance(inpt, LazyField):
-#                 splits.update(inpt.splits)
-#         return splits
-
-#     def _get_task_combines(self) -> ty.Set[ty.Union[str, ty.Tuple[str, ...]]]:
-#         """Returns the states over which the outputs of the task are combined"""
-#         combiner = (
-#             self._task.state.combiner
-#             if self._task.state is not None
-#             else getattr(self._task, "fut_combiner", None)
-#         )
-#         return set(combiner) if combiner else set()
-
-
-# class LazyIn(LazyInterface):
-#     _attr_type = "input"
-
-#     def _get_type(self, name):
-#         attr = next(t for n, t in self._task.input_spec.fields if n == name)
-#         if attr is None:
-#             return ty.Any
-#         elif inspect.isclass(attr):
-#             return attr
-#         else:
-#             return attr.type
-
-#     @property
-#     def _field_names(self):
-#         return [field[0] for field in self._task.input_spec.fields]
-
-
-# class LazyOut(LazyInterface):
-#     _attr_type = "output"
-
-#     def _get_type(self, name):
-#         try:
-#             type_ = next(f[1] for f in self._task.output_spec.fields if f[0] == name)
-#         except StopIteration:
-#             type_ = ty.Any
-#         else:
-#             if not inspect.isclass(type_):
-#                 try:
-#                     type_ = type_.type  # attrs _CountingAttribute
-#                 except AttributeError:
-#                     pass  # typing._SpecialForm
-#         return type_
-
-#     @property
-#     def _field_names(self):
-#         return self._task.output_names + ["all_"]
-
-
-def donothing(*args, **kwargs):
-    return None
-
-
-@attrs.define(auto_attribs=True, kw_only=True)
-class TaskHook:
-    """Callable task hooks."""
-
-    pre_run_task: ty.Callable = donothing
-    post_run_task: ty.Callable = donothing
-    pre_run: ty.Callable = donothing
-    post_run: ty.Callable = donothing
-
-    def __setattr__(self, attr, val):
-        if attr not in ["pre_run_task", "post_run_task", "pre_run", "post_run"]:
-            raise AttributeError("Cannot set unknown hook")
-        super().__setattr__(attr, val)
-
-    def reset(self):
-        for val in ["pre_run_task", "post_run_task", "pre_run", "post_run"]:
-            setattr(self, val, donothing)
-
-
-def path_to_string(value):
-    """Convert paths to strings."""
-    if isinstance(value, Path):
-        value = str(value)
-    elif isinstance(value, list) and len(value) and isinstance(value[0], Path):
-        value = [str(val) for val in value]
-    return value
+import attrs
+from fileformats.generic import File, FileSet, Directory
+from pydra.engine.audit import AuditFlag
+from pydra.utils.typing import MultiOutputObj, MultiOutputFile
+from .helpers import attr_fields, is_lazy
+from .helpers_file import template_update_single
+from pydra.utils.hash import hash_function, Cache
 
 
-class OutputsSpec:
+class OutSpec:
     """Base class for all output specifications"""
 
     def split(
@@ -865,14 +82,17 @@ def combine(
         return self
 
 
-OutputType = ty.TypeVar("OutputType", bound=OutputsSpec)
+OutSpecType = ty.TypeVar("OutSpecType", bound=OutSpec)
 
 
-class TaskSpec(ty.Generic[OutputType]):
+class TaskSpec(ty.Generic[OutSpecType]):
     """Base class for all task specifications"""
 
     Task: "ty.Type[core.Task]"
 
+    def __attrs_post_init__(self):
+        self._check_rules()
+
     def __call__(
         self,
         name: str | None = None,
@@ -912,5 +132,498 @@ def _check_for_unset_values(self):
             "before the workflow can be constructed"
         )
 
+    @property
+    def hash(self):
+        hsh, self._hashes = self._compute_hashes()
+        return hsh
+
+    def _hash_changes(self):
+        """Detects any changes in the hashed values between the current inputs and the
+        previously calculated values"""
+        _, new_hashes = self._compute_hashes()
+        return [k for k, v in new_hashes.items() if v != self._hashes[k]]
+
+    def _compute_hashes(self) -> ty.Tuple[bytes, ty.Dict[str, bytes]]:
+        """Compute a basic hash for any given set of fields."""
+        inp_dict = {}
+        for field in attr_fields(
+            self, exclude_names=("_graph_checksums", "bindings", "files_hash")
+        ):
+            if field.metadata.get("output_file_template"):
+                continue
+            # removing values that are not set from hash calculation
+            if getattr(self, field.name) is attrs.NOTHING:
+                continue
+            if "container_path" in field.metadata:
+                continue
+            inp_dict[field.name] = getattr(self, field.name)
+        hash_cache = Cache()
+        field_hashes = {
+            k: hash_function(v, cache=hash_cache) for k, v in inp_dict.items()
+        }
+        if hasattr(self, "_graph_checksums"):
+            field_hashes["_graph_checksums"] = self._graph_checksums
+        return hash_function(sorted(field_hashes.items())), field_hashes
+
+    def _retrieve_values(self, wf, state_index=None):
+        """Parse output results."""
+        temp_values = {}
+        for field in attr_fields(self):
+            # retrieving values that do not have templates
+            if not field.metadata.get("output_file_template"):
+                value = getattr(self, field.name)
+                if is_lazy(value):
+                    temp_values[field.name] = value.get_value(
+                        wf, state_index=state_index
+                    )
+        for field, val in temp_values.items():
+            val = path_to_string(val)
+            setattr(self, field, val)
+
+    def _check_rules(self):
+        fields = attr_fields(self)
+
+        for field in fields:
+            field_is_mandatory = bool(field.metadata.get("mandatory"))
+            field_is_unset = getattr(self, field.name) is attrs.NOTHING
+
+            if field_is_unset and not field_is_mandatory:
+                continue
+
+            # Collect alternative fields associated with this field.
+            alternative_fields = {
+                name: getattr(self, name) is not attrs.NOTHING
+                for name in field.metadata.get("xor", [])
+                if name != field.name
+            }
+            alternatives_are_set = any(alternative_fields.values())
+
+            # Raise error if no field in mandatory alternative group is set.
+            if field_is_unset:
+                if alternatives_are_set:
+                    continue
+                message = f"{field.name} is mandatory and unset."
+                if alternative_fields:
+                    raise AttributeError(
+                        message[:-1]
+                        + f", but no alternative provided by {list(alternative_fields)}."
+                    )
+                else:
+                    raise AttributeError(message)
+
+            # Raise error if multiple alternatives are set.
+            elif alternatives_are_set:
+                set_alternative_fields = [
+                    name for name, is_set in alternative_fields.items() if is_set
+                ]
+                raise AttributeError(
+                    f"{field.name} is mutually exclusive with {set_alternative_fields}"
+                )
+
+            # Collect required fields associated with this field.
+            required_fields = {
+                name: getattr(self, name) is not attrs.NOTHING
+                for name in field.metadata.get("requires", [])
+                if name != field.name
+            }
+
+            # Raise error if any required field is unset.
+            if not all(required_fields.values()):
+                unset_required_fields = [
+                    name for name, is_set in required_fields.items() if not is_set
+                ]
+                raise AttributeError(f"{field.name} requires {unset_required_fields}")
+
+
+@attrs.define(kw_only=True)
+class Runtime:
+    """Represent run time metadata."""
+
+    rss_peak_gb: ty.Optional[float] = None
+    """Peak in consumption of physical RAM."""
+    vms_peak_gb: ty.Optional[float] = None
+    """Peak in consumption of virtual memory."""
+    cpu_peak_percent: ty.Optional[float] = None
+    """Peak in cpu consumption."""
+
+
+@attrs.define(kw_only=True)
+class Result:
+    """Metadata regarding the outputs of processing."""
+
+    output: ty.Optional[ty.Any] = None
+    runtime: ty.Optional[Runtime] = None
+    errored: bool = False
+
+    def __getstate__(self):
+        state = self.__dict__.copy()
+        if state["output"] is not None:
+            fields = tuple((el.name, el.type) for el in attr_fields(state["output"]))
+            state["output_spec"] = (state["output"].__class__.__name__, fields)
+            state["output"] = attrs.asdict(state["output"], recurse=False)
+        return state
+
+    def __setstate__(self, state):
+        if "output_spec" in state:
+            spec = list(state["output_spec"])
+            del state["output_spec"]
+            klass = attrs.make_class(
+                spec[0], {k: attrs.field(type=v) for k, v in list(spec[1])}
+            )
+            state["output"] = klass(**state["output"])
+        self.__dict__.update(state)
+
+    def get_output_field(self, field_name):
+        """Used in get_values in Workflow
+
+        Parameters
+        ----------
+        field_name : `str`
+            Name of field in LazyField object
+        """
+        if field_name == "all_":
+            return attrs.asdict(self.output, recurse=False)
+        else:
+            return getattr(self.output, field_name)
+
+
+@attrs.define(kw_only=True)
+class RuntimeSpec:
+    """
+    Specification for a task.
+
+    From CWL::
+
+        InlineJavascriptRequirement
+        SchemaDefRequirement
+        DockerRequirement
+        SoftwareRequirement
+        InitialWorkDirRequirement
+        EnvVarRequirement
+        ShellCommandRequirement
+        ResourceRequirement
+
+        InlineScriptRequirement
+
+    """
+
+    outdir: ty.Optional[str] = None
+    container: ty.Optional[str] = "shell"
+    network: bool = False
+
+
+class PythonOutSpec(OutSpec):
+    pass
+
+
+class PythonSpec(TaskSpec):
+    pass
+
+
+class WorkflowOutSpec(OutSpec):
+    pass
+
+
+class WorkflowSpec(TaskSpec):
+    pass
+
+
+@attrs.define(kw_only=True)
+class ShellOutSpec(OutSpec):
+    """Output specification of a generic shell process."""
+
+    return_code: int
+    """The process' exit code."""
+    stdout: str
+    """The process' standard output."""
+    stderr: str
+    """The process' standard error."""
+
+    def _collect_additional_outputs(self, inputs, output_dir, outputs):
+        """Collect additional outputs from shelltask output_spec."""
+        from ..utils.typing import TypeParser
+
+        additional_out = {}
+        for fld in attr_fields(self, exclude_names=("return_code", "stdout", "stderr")):
+            if not TypeParser.is_subclass(
+                fld.type,
+                (
+                    os.PathLike,
+                    MultiOutputObj,
+                    int,
+                    float,
+                    bool,
+                    str,
+                    list,
+                ),
+            ):
+                raise TypeError(
+                    f"Support for {fld.type} type, required for '{fld.name}' in {self}, "
+                    "has not been implemented in collect_additional_output"
+                )
+            # assuming that field should have either default or metadata, but not both
+            input_value = getattr(inputs, fld.name, attrs.NOTHING)
+            if input_value is not attrs.NOTHING:
+                if TypeParser.contains_type(FileSet, fld.type):
+                    if input_value is not False:
+                        label = f"output field '{fld.name}' of {self}"
+                        input_value = TypeParser(fld.type, label=label).coerce(
+                            input_value
+                        )
+                additional_out[fld.name] = input_value
+            elif (
+                fld.default is None or fld.default == attrs.NOTHING
+            ) and not fld.metadata:  # TODO: is it right?
+                raise AttributeError("File has to have default value or metadata")
+            elif fld.default != attrs.NOTHING:
+                additional_out[fld.name] = self._field_defaultvalue(fld, output_dir)
+            elif fld.metadata:
+                if (
+                    fld.type in [int, float, bool, str, list]
+                    and "callable" not in fld.metadata
+                ):
+                    raise AttributeError(
+                        f"{fld.type} has to have a callable in metadata"
+                    )
+                additional_out[fld.name] = self._field_metadata(
+                    fld, inputs, output_dir, outputs
+                )
+        return additional_out
+
+    def _generated_output_names(self, inputs, output_dir):
+        """Returns a list of all outputs that will be generated by the task.
+        Takes into account the task input and the requires list for the output fields.
+        TODO: should be in all Output specs?
+        """
+        # checking the input (if all mandatory fields are provided, etc.)
+        inputs._check_rules()
+        output_names = ["return_code", "stdout", "stderr"]
+        for fld in attr_fields(self, exclude_names=("return_code", "stdout", "stderr")):
+            if fld.type not in [File, MultiOutputFile, Directory]:
+                raise Exception("not implemented (collect_additional_output)")
+            # assuming that field should have either default or metadata, but not both
+            if (
+                fld.default in (None, attrs.NOTHING) and not fld.metadata
+            ):  # TODO: is it right?
+                raise AttributeError("File has to have default value or metadata")
+            elif fld.default != attrs.NOTHING:
+                output_names.append(fld.name)
+            elif (
+                fld.metadata
+                and self._field_metadata(
+                    fld, inputs, output_dir, outputs=None, check_existance=False
+                )
+                != attrs.NOTHING
+            ):
+                output_names.append(fld.name)
+        return output_names
+
+    def _field_defaultvalue(self, fld, output_dir):
+        """Collect output file if the default value specified."""
+        if not isinstance(fld.default, (str, Path)):
+            raise AttributeError(
+                f"{fld.name} is a File, so default value "
+                f"should be a string or a Path, "
+                f"{fld.default} provided"
+            )
+        default = fld.default
+        if isinstance(default, str):
+            default = Path(default)
+
+        default = output_dir / default
+        if "*" not in str(default):
+            if default.exists():
+                return default
+            else:
+                raise AttributeError(f"file {default} does not exist")
+        else:
+            all_files = [Path(el) for el in glob(str(default.expanduser()))]
+            if len(all_files) > 1:
+                return all_files
+            elif len(all_files) == 1:
+                return all_files[0]
+            else:
+                raise AttributeError(f"no file matches {default.name}")
+
+    def _field_metadata(
+        self, fld, inputs, output_dir, outputs=None, check_existance=True
+    ):
+        """Collect output file if metadata specified."""
+        if self._check_requires(fld, inputs) is False:
+            return attrs.NOTHING
+
+        if "value" in fld.metadata:
+            return output_dir / fld.metadata["value"]
+        # this block is only run if "output_file_template" is provided in output_spec
+        # if the field is set in input_spec with output_file_template,
+        # then the field should already have a value
+        elif "output_file_template" in fld.metadata:
+            value = template_update_single(
+                fld, inputs=inputs, output_dir=output_dir, spec_type="output"
+            )
+
+            if fld.type is MultiOutputFile and type(value) is list:
+                # TODO: how to deal with mandatory list outputs
+                ret = []
+                for val in value:
+                    val = Path(val)
+                    if check_existance and not val.exists():
+                        ret.append(attrs.NOTHING)
+                    else:
+                        ret.append(val)
+                return ret
+            else:
+                val = Path(value)
+                # checking if the file exists
+                if check_existance and not val.exists():
+                    # if mandatory raise exception
+                    if "mandatory" in fld.metadata:
+                        if fld.metadata["mandatory"]:
fld.metadata["mandatory"]: + raise Exception( + f"mandatory output for variable {fld.name} does not exist" + ) + return attrs.NOTHING + return val + elif "callable" in fld.metadata: + callable_ = fld.metadata["callable"] + if isinstance(callable_, staticmethod): + # In case callable is defined as a static method, + # retrieve the function wrapped in the descriptor. + callable_ = callable_.__func__ + call_args = inspect.getfullargspec(callable_) + call_args_val = {} + for argnm in call_args.args: + if argnm == "field": + call_args_val[argnm] = fld + elif argnm == "output_dir": + call_args_val[argnm] = output_dir + elif argnm == "inputs": + call_args_val[argnm] = inputs + elif argnm == "stdout": + call_args_val[argnm] = outputs["stdout"] + elif argnm == "stderr": + call_args_val[argnm] = outputs["stderr"] + else: + try: + call_args_val[argnm] = getattr(inputs, argnm) + except AttributeError: + raise AttributeError( + f"arguments of the callable function from {fld.name} " + f"has to be in inputs or be field or output_dir, " + f"but {argnm} is used" + ) + return callable_(**call_args_val) + else: + raise Exception( + f"Metadata for '{fld.name}', does not not contain any of the required fields " + f'("callable", "output_file_template" or "value"): {fld.metadata}.' + ) + + def _check_requires(self, fld, inputs): + """checking if all fields from the requires and template are set in the input + if requires is a list of list, checking if at least one list has all elements set + """ + from .helpers import ensure_list + + if "requires" in fld.metadata: + # if requires is a list of list it is treated as el[0] OR el[1] OR... + required_fields = ensure_list(fld.metadata["requires"]) + if all([isinstance(el, list) for el in required_fields]): + field_required_OR = required_fields + # if requires is a list of tuples/strings - I'm creating a 1-el nested list + elif all([isinstance(el, (str, tuple)) for el in required_fields]): + field_required_OR = [required_fields] + else: + raise Exception( + f"requires field can be a list of list, or a list " + f"of strings/tuples, but {fld.metadata['requires']} " + f"provided for {fld.name}" + ) + else: + field_required_OR = [[]] + + for field_required in field_required_OR: + # if the output has output_file_template field, + # adding all input fields from the template to requires + if self.path_template: + # if a template is a function it has to be run first with the inputs as the only arg + if callable(self.path_template): + template = self.path_template(inputs) + inp_fields = re.findall(r"{(\w+)(?:\:[^\}]+)?}", template) + field_required += [ + el[1:-1] for el in inp_fields if el[1:-1] not in field_required + ] + + # it's a flag, of the field from the list is not in input it will be changed to False + required_found = True + for field_required in field_required_OR: + required_found = True + # checking if the input fields from requires have set values + for inp in field_required: + if isinstance(inp, str): # name of the input field + if not hasattr(inputs, inp): + raise Exception( + f"{inp} is not a valid input field, can't be used in requires" + ) + elif getattr(inputs, inp) in [attrs.NOTHING, None]: + required_found = False + break + elif isinstance(inp, tuple): # (name, allowed values) + inp, allowed_val = inp[0], ensure_list(inp[1]) + if not hasattr(inputs, inp): + raise Exception( + f"{inp} is not a valid input field, can't be used in requires" + ) + elif getattr(inputs, inp) not in allowed_val: + required_found = False + break + else: + raise Exception( + f"each 
+
+
+class ShellSpec(TaskSpec):
+    pass
+
+
+def donothing(*args, **kwargs):
+    return None
+
+
+@attrs.define(kw_only=True)
+class TaskHook:
+    """Callable task hooks."""
+
+    pre_run_task: ty.Callable = donothing
+    post_run_task: ty.Callable = donothing
+    pre_run: ty.Callable = donothing
+    post_run: ty.Callable = donothing
+
+    def __setattr__(self, attr, val):
+        if attr not in ["pre_run_task", "post_run_task", "pre_run", "post_run"]:
+            raise AttributeError("Cannot set unknown hook")
+        super().__setattr__(attr, val)
+
+    def reset(self):
+        for val in ["pre_run_task", "post_run_task", "pre_run", "post_run"]:
+            setattr(self, val, donothing)
+
+
+def path_to_string(value):
+    """Convert paths to strings."""
+    if isinstance(value, Path):
+        value = str(value)
+    elif isinstance(value, list) and len(value) and isinstance(value[0], Path):
+        value = [str(val) for val in value]
+    return value
+
 from pydra.engine import core  # noqa: E402
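As a usage sketch for the hook container above (assuming task is any constructed Task instance; the engine decides what arguments the hooks receive, so *args is used defensively here):

    def announce(*args, **kwargs):
        print("about to run")

    task.hooks.pre_run = announce  # one of the four known hooks: accepted
    task.hooks.reset()             # restores all four hooks to `donothing`
    task.hooks.prerun = announce   # misspelt hook: raises AttributeError

The __setattr__ guard means a typo in a hook name fails loudly instead of being silently ignored.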
diff --git a/pydra/engine/task.py b/pydra/engine/task.py
index 68731f47bc..f2d3a2283e 100644
--- a/pydra/engine/task.py
+++ b/pydra/engine/task.py
@@ -45,28 +45,22 @@
 import re
 import attr
 import attrs
-import warnings
 import inspect
 import typing as ty
 import shlex
 from pathlib import Path
 import cloudpickle as cp
-from fileformats.core import FileSet, DataType
+from fileformats.core import FileSet
 from .core import Task, is_lazy
 from pydra.utils.messenger import AuditFlag
 from .specs import (
-    # BaseSpec,
-    # SpecInfo,
-    # ShellSpec,
-    # ShellOutSpec,
-    TaskSpec,
+    ShellSpec,
     attr_fields,
 )
 from .helpers import (
     parse_format_string,
     position_sort,
     ensure_list,
-    # output_from_inputfields,
     parse_copyfile,
 )
 from .helpers_file import template_update
@@ -77,124 +71,6 @@
 class FunctionTask(Task):
     """Wrap a Python callable as a task element."""
 
-    def __init__(
-        self,
-        spec: TaskSpec,
-        audit_flags: AuditFlag = AuditFlag.NONE,
-        cache_dir=None,
-        cache_locations=None,
-        cont_dim=None,
-        messenger_args=None,
-        messengers=None,
-        name=None,
-        rerun=False,
-        **kwargs,
-    ):
-        """
-        Initialize this task.
-
-        Parameters
-        ----------
-        func : :obj:`callable`
-            A Python executable function.
-        audit_flags : :obj:`pydra.utils.messenger.AuditFlag`
-            Auditing configuration
-        cache_dir : :obj:`os.pathlike`
-            Cache directory
-        cache_locations : :obj:`list` of :obj:`os.pathlike`
-            List of alternative cache locations.
-        input_spec : :obj:`pydra.engine.specs.SpecInfo`
-            Specification of inputs.
-        cont_dim : :obj:`dict`, or `None`
-            Container dimensions for input fields,
-            if any of the container should be treated as a container
-        messenger_args :
-            TODO
-        messengers :
-            TODO
-        name : :obj:`str`
-            Name of this task.
-        output_spec : :obj:`pydra.engine.specs.BaseSpec`
-            Specification of inputs.
-
-        """
-        if input_spec is None:
-            fields = []
-            for val in inspect.signature(func).parameters.values():
-                if val.default is not inspect.Signature.empty:
-                    val_dflt = val.default
-                else:
-                    val_dflt = attr.NOTHING
-                if isinstance(val.annotation, ty.TypeVar):
-                    raise NotImplementedError(
-                        "Template types are not currently supported in task signatures "
-                        f"(found in '{val.name}' field of '{name}' task), "
-                        "see https://github.com/nipype/pydra/issues/672"
-                    )
-                fields.append(
-                    (
-                        val.name,
-                        attr.ib(
-                            default=val_dflt,
-                            type=val.annotation,
-                            metadata={
-                                "help_string": f"{val.name} parameter from {func.__name__}"
-                            },
-                        ),
-                    )
-                )
-            fields.append(("_func", attr.ib(default=cp.dumps(func), type=bytes)))
-            input_spec = SpecInfo(name="Inputs", fields=fields, bases=(BaseSpec,))
-        else:
-            input_spec.fields.append(
-                ("_func", attr.ib(default=cp.dumps(func), type=bytes))
-            )
-        self.input_spec = input_spec
-        if name is None:
-            name = func.__name__
-        super().__init__(
-            name,
-            inputs=kwargs,
-            cont_dim=cont_dim,
-            audit_flags=audit_flags,
-            messengers=messengers,
-            messenger_args=messenger_args,
-            cache_dir=cache_dir,
-            cache_locations=cache_locations,
-            rerun=rerun,
-        )
-        if output_spec is None:
-            name = "Output"
-            fields = [("out", ty.Any)]
-            if "return" in func.__annotations__:
-                return_info = func.__annotations__["return"]
-                # # e.g. python annotation: fun() -> ty.NamedTuple("Output", [("out", float)])
-                # # or pydra decorator: @pydra.mark.annotate({"return": ty.NamedTuple(...)})
-                #
-
-                if (
-                    hasattr(return_info, "__name__")
-                    and getattr(return_info, "__annotations__", None)
-                    and not issubclass(return_info, DataType)
-                ):
-                    name = return_info.__name__
-                    fields = list(return_info.__annotations__.items())
-                # e.g. python annotation: fun() -> {"out": int}
-                # or pydra decorator: @pydra.mark.annotate({"return": {"out": int}})
-                elif isinstance(return_info, dict):
-                    fields = list(return_info.items())
-                # e.g. python annotation: fun() -> (int, int)
-                # or pydra decorator: @pydra.mark.annotate({"return": (int, int)})
-                elif isinstance(return_info, tuple):
-                    fields = [(f"out{i}", t) for i, t in enumerate(return_info, 1)]
-                # e.g. python annotation: fun() -> int
-                # or pydra decorator: @pydra.mark.annotate({"return": int})
-                else:
-                    fields = [("out", return_info)]
-            output_spec = SpecInfo(name=name, fields=fields, bases=(BaseSpec,))
-
-        self.output_spec = output_spec
-
     def _run_task(self, environment=None):
         inputs = attr.asdict(self.inputs, recurse=False)
         del inputs["_func"]
@@ -220,12 +96,9 @@
 class ShellCommandTask(Task):
     """Wrap a shell command as a task element."""
 
-    input_spec = None
-    output_spec = None
-
     def __init__(
         self,
-        spec: TaskSpec,
+        spec: ShellSpec,
         audit_flags: AuditFlag = AuditFlag.NONE,
         cache_dir=None,
         cont_dim=None,
@@ -261,36 +134,7 @@ def __init__(
         Specification of inputs.
         strip : :obj:`bool`
             TODO
-
         """
-
-        # using default name for task if no name provided
-        if name is None:
-            name = "ShellTask_noname"
-
-        # # using provided spec, class attribute or setting the default SpecInfo
-        # self.input_spec = (
-        #     input_spec
-        #     or self.input_spec
-        #     or SpecInfo(name="Inputs", fields=[], bases=(ShellSpec,))
-        # )
-        # self.output_spec = (
-        #     output_spec
-        #     or self.output_spec
-        #     or SpecInfo(name="Output", fields=[], bases=(ShellOutSpec,))
-        # )
-        # self.output_spec = output_from_inputfields(self.output_spec, self.input_spec)
-
-        for special_inp in ["executable", "args"]:
-            if hasattr(self, special_inp):
-                if special_inp not in kwargs:
-                    kwargs[special_inp] = getattr(self, special_inp)
-                elif kwargs[special_inp] != getattr(self, special_inp):
-                    warnings.warn(
-                        f"you are changing the executable from {getattr(self, special_inp)} "
-                        f"to {kwargs[special_inp]}"
-                    )
-
         super().__init__(
             name=name,
             inputs=kwargs,
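The output-spec inference deleted from FunctionTask.__init__ above moves into the design layer, but the mapping rules it encoded are unchanged. Roughly, as a standalone sketch (infer_output_fields is a made-up name, not a helper introduced by this patch):

    import typing as ty

    def infer_output_fields(func) -> list:
        ret = func.__annotations__.get("return")
        if ret is None:                    # no annotation: fun()
            return [("out", ty.Any)]
        if isinstance(ret, dict):          # fun() -> {"out": int}
            return list(ret.items())
        if isinstance(ret, tuple):         # fun() -> (int, int)
            return [(f"out{i}", t) for i, t in enumerate(ret, 1)]
        if getattr(ret, "__annotations__", None):  # fun() -> ty.NamedTuple(...)
            return list(ret.__annotations__.items())
        return [("out", ret)]              # fun() -> int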
diff --git a/pydra/engine/workflow/base.py b/pydra/engine/workflow/base.py
index cbfbe6d1c2..11a8434c22 100644
--- a/pydra/engine/workflow/base.py
+++ b/pydra/engine/workflow/base.py
@@ -4,18 +4,19 @@
 from typing_extensions import Self
 import attrs
 from pydra.engine.helpers import list_fields
-from pydra.engine.specs import TaskSpec, OutputsSpec
+from pydra.engine.specs import TaskSpec, OutSpec, WorkflowOutSpec
 from .lazy import LazyInField
 from pydra.utils.hash import hash_function
 from pydra.utils.typing import TypeParser, StateArray
 from .node import Node
 
 
-OutputType = ty.TypeVar("OutputType", bound=OutputsSpec)
+OutSpecType = ty.TypeVar("OutSpecType", bound=OutSpec)
+WorkflowOutSpecType = ty.TypeVar("WorkflowOutSpecType", bound=WorkflowOutSpec)
 
 
 @attrs.define(auto_attribs=False)
-class Workflow(ty.Generic[OutputType]):
+class Workflow(ty.Generic[WorkflowOutSpecType]):
     """A workflow, constructed from a workflow specification
 
     Parameters
@@ -29,14 +30,14 @@
     """
 
     name: str = attrs.field()
-    inputs: TaskSpec[OutputType] = attrs.field()
-    outputs: OutputType = attrs.field()
+    inputs: TaskSpec[WorkflowOutSpecType] = attrs.field()
+    outputs: WorkflowOutSpecType = attrs.field()
     _nodes: dict[str, Node] = attrs.field(factory=dict)
 
     @classmethod
     def construct(
         cls,
-        spec: TaskSpec[OutputType],
+        spec: TaskSpec[WorkflowOutSpecType],
     ) -> Self:
         """Construct a workflow from a specification, caching the constructed workflow"""
 
@@ -104,11 +105,9 @@ def construct(
                     f"{len(output_lazy_fields)} ({output_lazy_fields})"
                 )
                 for outpt, outpt_lf in zip(output_fields, output_lazy_fields):
+                    # Automatically combine any uncombined state arrays into lists
                     if TypeParser.get_origin(outpt_lf.type) is StateArray:
-                        # Automatically combine any uncombined state arrays into lists
-                        tp, _ = TypeParser.strip_splits(outpt_lf.type)
-                        outpt_lf.type = list[tp]
-                        outpt_lf.splits = frozenset()
+                        outpt_lf.type = list[TypeParser.strip_splits(outpt_lf.type)[0]]
                     setattr(outputs, outpt.name, outpt_lf)
             else:
                 if unset_outputs := [
@@ -127,7 +126,7 @@
 
         return wf
 
-    def add(self, task_spec: TaskSpec[OutputType], name=None) -> OutputType:
+    def add(self, task_spec: TaskSpec[OutSpecType], name=None) -> OutSpecType:
         """Add a node to the workflow
 
         Parameters
         ----------
         task_spec : TaskSpec[OutputType]
             The specification of the task to add to the workflow as a node
         name : str, optional
             The name of the node, by default it will be the name of the task spec class
 
         Returns
         -------
         OutputType
             The outputs specification of the node
         """
         if name is None:
             name = type(task_spec).__name__
         if name in self._nodes:
             raise ValueError(f"Node with name {name!r} already exists in the workflow")
-        node = Node[OutputType](name=name, spec=task_spec, workflow=self)
+        node = Node[OutSpecType](name=name, spec=task_spec, workflow=self)
         self._nodes[name] = node
         return node.lzout
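The add() method above is what hands back the lazy-output handle used to wire nodes together, defaulting the node name to the spec's class name. Schematically (Add is a stand-in task spec and wf an under-construction Workflow, both hypothetical):

    lzout = wf.add(Add(a=1, b=2))               # stored under the name "Add"
    other = wf.add(Add(a=3, b=4), name="add2")  # explicit name avoids the clash
    wf.add(Add(a=5, b=6))                       # ValueError: "Add" already exists

Returning node.lzout rather than the node itself keeps downstream code working only with lazy fields, never with Node internals.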
diff --git a/pydra/engine/workflow/node.py b/pydra/engine/workflow/node.py
index 7f5b32972a..197d3ca32d 100644
--- a/pydra/engine/workflow/node.py
+++ b/pydra/engine/workflow/node.py
@@ -4,7 +4,7 @@
 import attrs
 from pydra.utils.typing import TypeParser, StateArray
 from . import lazy
-from ..specs import TaskSpec, OutputsSpec
+from ..specs import TaskSpec, OutSpec
 from ..helpers import ensure_list
 from .. import helpers_state as hlpst
 from ..state import State
@@ -13,7 +13,7 @@
     from .base import Workflow
 
 
-OutputType = ty.TypeVar("OutputType", bound=OutputsSpec)
+OutputType = ty.TypeVar("OutputType", bound=OutSpec)
 Splitter = ty.Union[str, ty.Tuple[str, ...]]
 
 _not_set = Enum("_not_set", "NOT_SET")
@@ -43,8 +43,8 @@ class Node(ty.Generic[OutputType]):
     _cont_dim: dict[str, int] | None = attrs.field(
         init=False, default=None
     )  # QUESTION: should this be included in the state?
-    _inner_cont_dim: dict[str, int] | None = attrs.field(
-        init=False, default=None
+    _inner_cont_dim: dict[str, int] = attrs.field(
+        init=False, factory=dict
     )  # QUESTION: should this be included in the state?
 
     class Inputs:
@@ -247,7 +249,9 @@ def combine(
         if not isinstance(combiner, (str, list)):
             raise Exception("combiner has to be a string or a list")
         combiner = hlpst.add_name_combiner(ensure_list(combiner), self.name)
-        if not_split := [c for c in combiner if not any(c in s for s in self.splitter)]:
+        if not_split := [
+            c for c in combiner if not any(c in s for s in self.state.splitter)
+        ]:
             raise ValueError(
                 f"Combiner fields {not_split} for Node {self.name!r} are not in the "
                 f"splitter fields {self.splitter}"
             )
@@ -343,7 +345,14 @@ def _wrap_lzout_types_in_state_arrays(self) -> None:
         if not self.state:
             return
         outpt_lf: lazy.LazyOutField
-        state_depth = len(self.state.splitter_rpn)
+        remaining_splits = []
+        for split in self.state.splitter:
+            if isinstance(split, str):
+                if split not in self.state.combiner:
+                    remaining_splits.append(split)
+            elif all(s not in self.state.combiner for s in split):
+                remaining_splits.append(split)
+        state_depth = len(remaining_splits)
         for outpt_lf in attrs.asdict(self.lzout, recurse=False).values():
             assert not outpt_lf.type_checked
             type_, _ = TypeParser.strip_splits(outpt_lf.type)
diff --git a/pydra/utils/typing.py b/pydra/utils/typing.py
index 58249ddbfd..ee21d26db3 100644
--- a/pydra/utils/typing.py
+++ b/pydra/utils/typing.py
@@ -10,6 +10,7 @@
 import attr
 from pydra.utils import add_exc_note
 from fileformats import field, core, generic
+from pydra.engine.helpers import is_lazy
 
 try:
     from typing import get_origin, get_args
@@ -213,12 +214,11 @@ def __call__(self, obj: ty.Any) -> T:
         if the coercion is not possible, or not specified by the
         `coercible`/`not_coercible` parameters, then a TypeError is raised
         """
-        from pydra.engine.workflow.lazy import LazyField
 
         coerced: T
         if obj is attr.NOTHING:
             coerced = attr.NOTHING  # type: ignore[assignment]
-        elif isinstance(obj, LazyField):
+        elif is_lazy(obj):
             try:
                 self.check_type(obj.type)
             except TypeError as e:
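The replacement for len(self.state.splitter_rpn) above counts only splitter fields that are not subsequently combined, treating a tuple of jointly-split fields as a single level. The rule in isolation (a standalone sketch with illustrative splitter/combiner values, following pydra's convention that a tuple groups fields split together):

    def remaining_depth(splitter, combiner):
        depth = 0
        for split in splitter:
            if isinstance(split, str):
                depth += split not in combiner
            else:  # tuple of fields split together counts as one level
                depth += all(s not in combiner for s in split)
        return depth

    assert remaining_depth(["a", ("b", "c")], combiner=["a"]) == 1

So an output combined over "a" keeps exactly one StateArray wrapping for the remaining joint ("b", "c") split, which is what _wrap_lzout_types_in_state_arrays applies to each lazy output field's type.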