Skip to content

Commit

Permalink
added unittests for multi_input_obj coercion
Browse files Browse the repository at this point in the history
  • Loading branch information
tclose committed May 21, 2024
1 parent 3715841 commit 2a06afd
Show file tree
Hide file tree
Showing 2 changed files with 103 additions and 19 deletions.
64 changes: 62 additions & 2 deletions pydra/utils/tests/test_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import tempfile
import pytest
from pydra import mark
from ...engine.specs import File, LazyOutField
from ...engine.specs import File, LazyOutField, MultiInputObj
from ..typing import TypeParser
from pydra import Workflow
from fileformats.application import Json, Yaml, Xml
Expand Down Expand Up @@ -244,7 +244,7 @@ def test_type_check_fail3():
def test_type_check_fail4():
with pytest.raises(TypeError) as exc_info:
TypeParser(ty.Sequence)(lz(ty.Dict[str, int]))
assert exc_info_matches(exc_info, "Cannot coerce <class 'dict'> into")
assert exc_info_matches(exc_info, "Cannot coerce .*(d|D)ict.* into")


def test_type_check_fail5():
Expand Down Expand Up @@ -1020,3 +1020,63 @@ def test_type_is_instance11():
@pytest.mark.skipif(sys.version_info < (3, 10), reason="No UnionType < Py3.10")
def test_type_is_instance11a():
assert not TypeParser.is_instance(None, int | str)


def test_multi_input_obj_coerce1():
assert TypeParser(MultiInputObj[str])("a") == ["a"]


def test_multi_input_obj_coerce2():
assert TypeParser(MultiInputObj[str])(["a"]) == ["a"]


def test_multi_input_obj_coerce3():
assert TypeParser(MultiInputObj[ty.List[str]])(["a"]) == [["a"]]


def test_multi_input_obj_coerce3a():
assert TypeParser(MultiInputObj[ty.Union[int, ty.List[str]]])(["a"]) == [["a"]]


def test_multi_input_obj_coerce3b():
assert TypeParser(MultiInputObj[ty.Union[int, ty.List[str]]])([["a"]]) == [["a"]]


def test_multi_input_obj_coerce4():
assert TypeParser(MultiInputObj[ty.Union[int, ty.List[str]]])([1]) == [1]


def test_multi_input_obj_coerce4a():
with pytest.raises(TypeError):
TypeParser(MultiInputObj[ty.Union[int, ty.List[str]]])([[1]])


def test_multi_input_obj_check_type1():
TypeParser(MultiInputObj[str])(lz(str))


def test_multi_input_obj_check_type2():
TypeParser(MultiInputObj[str])(lz(ty.List[str]))


def test_multi_input_obj_check_type3():
TypeParser(MultiInputObj[ty.List[str]])(lz(ty.List[str]))


def test_multi_input_obj_check_type3a():
TypeParser(MultiInputObj[ty.Union[int, ty.List[str]]])(lz(ty.List[str]))


def test_multi_input_obj_check_type3b():
TypeParser(MultiInputObj[ty.Union[int, ty.List[str]]])(lz(ty.List[ty.List[str]]))


def test_multi_input_obj_check_type4():
TypeParser(MultiInputObj[ty.Union[int, ty.List[str]]])(lz(ty.List[int]))


def test_multi_input_obj_check_type4a():
with pytest.raises(TypeError):
TypeParser(MultiInputObj[ty.Union[int, ty.List[str]]])(
lz(ty.List[ty.List[int]])
)
58 changes: 41 additions & 17 deletions pydra/utils/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import inspect
from pathlib import Path
import os
from copy import copy
import sys
import types
import typing as ty
Expand All @@ -13,6 +14,7 @@
MultiInputObj,
MultiOutputObj,
)
from ..utils import add_exc_note
from fileformats import field

try:
Expand Down Expand Up @@ -366,18 +368,26 @@ def coerce_obj(obj, type_):
f"Cannot coerce {obj!r} into {type_}{msg}{self.label_str}"
) from e

# Special handling for MultiInputObjects (which are annoying)
if isinstance(self.pattern, tuple) and self.pattern[0] == MultiInputObj:
try:
self.check_coercible(object_, self.pattern[1][0])
except TypeError:
pass
try:
return expand_and_coerce(object_, self.pattern)
except TypeError as e:
# Special handling for MultiInputObjects (which are annoying)
if isinstance(self.pattern, tuple) and self.pattern[0] == MultiInputObj:
# Attempt to coerce the object into arg type of the MultiInputObj first,
# and if that fails, try to coerce it into a list of the arg type
inner_type_parser = copy(self)
inner_type_parser.pattern = self.pattern[1][0]
try:
return [inner_type_parser.coerce(object_)]
except TypeError:
add_exc_note(
e,
"Also failed to coerce to the arg-type of the MultiInputObj "
f"({self.pattern[1][0]})",
)
raise e
else:
obj = [object_]
else:
obj = object_

return expand_and_coerce(obj, self.pattern)
raise e

def check_type(self, type_: ty.Type[ty.Any]):
"""Checks the given type to see whether it matches or is a subtype of the
Expand Down Expand Up @@ -537,12 +547,26 @@ def check_sequence(tp_args, pattern_args):
for arg in tp_args:
expand_and_check(arg, pattern_args[0])

# Special handling for MultiInputObjects (which are annoying)
if isinstance(self.pattern, tuple) and self.pattern[0] == MultiInputObj:
pattern = (ty.Union, [self.pattern[1][0], (ty.List, self.pattern[1])])
else:
pattern = self.pattern
return expand_and_check(type_, pattern)
try:
return expand_and_check(type_, self.pattern)
except TypeError as e:
# Special handling for MultiInputObjects (which are annoying)
if isinstance(self.pattern, tuple) and self.pattern[0] == MultiInputObj:
# Attempt to coerce the object into arg type of the MultiInputObj first,
# and if that fails, try to coerce it into a list of the arg type
inner_type_parser = copy(self)
inner_type_parser.pattern = self.pattern[1][0]
try:
inner_type_parser.check_type(type_)
except TypeError:
add_exc_note(
e,
"Also failed to coerce to the arg-type of the MultiInputObj "
f"({self.pattern[1][0]})",
)
raise e
else:
raise e

def check_coercible(self, source: ty.Any, target: ty.Union[type, ty.Any]):
"""Checks whether the source object is coercible to the target type given the coercion
Expand Down

0 comments on commit 2a06afd

Please sign in to comment.