diff --git a/pydra/utils/tests/test_typing.py b/pydra/utils/tests/test_typing.py index 5e78fde9b0..df87a87f2c 100644 --- a/pydra/utils/tests/test_typing.py +++ b/pydra/utils/tests/test_typing.py @@ -7,7 +7,7 @@ import tempfile import pytest from pydra import mark -from ...engine.specs import File, LazyOutField +from ...engine.specs import File, LazyOutField, MultiInputObj from ..typing import TypeParser from pydra import Workflow from fileformats.application import Json, Yaml, Xml @@ -249,7 +249,7 @@ def test_type_check_fail3(): def test_type_check_fail4(): with pytest.raises(TypeError) as exc_info: TypeParser(ty.Sequence)(lz(ty.Dict[str, int])) - assert exc_info_matches(exc_info, "Cannot coerce into") + assert exc_info_matches(exc_info, "Cannot coerce .*(d|D)ict.* into") def test_type_check_fail5(): @@ -1043,3 +1043,63 @@ def test_type_is_instance11(): @pytest.mark.skipif(sys.version_info < (3, 10), reason="No UnionType < Py3.10") def test_type_is_instance11a(): assert not TypeParser.is_instance(None, int | str) + + +def test_multi_input_obj_coerce1(): + assert TypeParser(MultiInputObj[str])("a") == ["a"] + + +def test_multi_input_obj_coerce2(): + assert TypeParser(MultiInputObj[str])(["a"]) == ["a"] + + +def test_multi_input_obj_coerce3(): + assert TypeParser(MultiInputObj[ty.List[str]])(["a"]) == [["a"]] + + +def test_multi_input_obj_coerce3a(): + assert TypeParser(MultiInputObj[ty.Union[int, ty.List[str]]])(["a"]) == [["a"]] + + +def test_multi_input_obj_coerce3b(): + assert TypeParser(MultiInputObj[ty.Union[int, ty.List[str]]])([["a"]]) == [["a"]] + + +def test_multi_input_obj_coerce4(): + assert TypeParser(MultiInputObj[ty.Union[int, ty.List[str]]])([1]) == [1] + + +def test_multi_input_obj_coerce4a(): + with pytest.raises(TypeError): + TypeParser(MultiInputObj[ty.Union[int, ty.List[str]]])([[1]]) + + +def test_multi_input_obj_check_type1(): + TypeParser(MultiInputObj[str])(lz(str)) + + +def test_multi_input_obj_check_type2(): + TypeParser(MultiInputObj[str])(lz(ty.List[str])) + + +def test_multi_input_obj_check_type3(): + TypeParser(MultiInputObj[ty.List[str]])(lz(ty.List[str])) + + +def test_multi_input_obj_check_type3a(): + TypeParser(MultiInputObj[ty.Union[int, ty.List[str]]])(lz(ty.List[str])) + + +def test_multi_input_obj_check_type3b(): + TypeParser(MultiInputObj[ty.Union[int, ty.List[str]]])(lz(ty.List[ty.List[str]])) + + +def test_multi_input_obj_check_type4(): + TypeParser(MultiInputObj[ty.Union[int, ty.List[str]]])(lz(ty.List[int])) + + +def test_multi_input_obj_check_type4a(): + with pytest.raises(TypeError): + TypeParser(MultiInputObj[ty.Union[int, ty.List[str]]])( + lz(ty.List[ty.List[int]]) + ) diff --git a/pydra/utils/typing.py b/pydra/utils/typing.py index 32968001c9..136bfd443e 100644 --- a/pydra/utils/typing.py +++ b/pydra/utils/typing.py @@ -2,6 +2,7 @@ import inspect from pathlib import Path import os +from copy import copy import sys import types import typing as ty @@ -13,6 +14,7 @@ MultiInputObj, MultiOutputObj, ) +from ..utils import add_exc_note from fileformats import field try: @@ -366,18 +368,26 @@ def coerce_obj(obj, type_): f"Cannot coerce {obj!r} into {type_}{msg}{self.label_str}" ) from e - # Special handling for MultiInputObjects (which are annoying) - if isinstance(self.pattern, tuple) and self.pattern[0] == MultiInputObj: - try: - self.check_coercible(object_, self.pattern[1][0]) - except TypeError: - pass + try: + return expand_and_coerce(object_, self.pattern) + except TypeError as e: + # Special handling for MultiInputObjects (which are annoying) + if isinstance(self.pattern, tuple) and self.pattern[0] == MultiInputObj: + # Attempt to coerce the object into arg type of the MultiInputObj first, + # and if that fails, try to coerce it into a list of the arg type + inner_type_parser = copy(self) + inner_type_parser.pattern = self.pattern[1][0] + try: + return [inner_type_parser.coerce(object_)] + except TypeError: + add_exc_note( + e, + "Also failed to coerce to the arg-type of the MultiInputObj " + f"({self.pattern[1][0]})", + ) + raise e else: - obj = [object_] - else: - obj = object_ - - return expand_and_coerce(obj, self.pattern) + raise e def check_type(self, type_: ty.Type[ty.Any]): """Checks the given type to see whether it matches or is a subtype of the @@ -537,12 +547,26 @@ def check_sequence(tp_args, pattern_args): for arg in tp_args: expand_and_check(arg, pattern_args[0]) - # Special handling for MultiInputObjects (which are annoying) - if isinstance(self.pattern, tuple) and self.pattern[0] == MultiInputObj: - pattern = (ty.Union, [self.pattern[1][0], (ty.List, self.pattern[1])]) - else: - pattern = self.pattern - return expand_and_check(type_, pattern) + try: + return expand_and_check(type_, self.pattern) + except TypeError as e: + # Special handling for MultiInputObjects (which are annoying) + if isinstance(self.pattern, tuple) and self.pattern[0] == MultiInputObj: + # Attempt to coerce the object into arg type of the MultiInputObj first, + # and if that fails, try to coerce it into a list of the arg type + inner_type_parser = copy(self) + inner_type_parser.pattern = self.pattern[1][0] + try: + inner_type_parser.check_type(type_) + except TypeError: + add_exc_note( + e, + "Also failed to coerce to the arg-type of the MultiInputObj " + f"({self.pattern[1][0]})", + ) + raise e + else: + raise e def check_coercible(self, source: ty.Any, target: ty.Union[type, ty.Any]): """Checks whether the source object is coercible to the target type given the coercion