Skip to content

Commit

Permalink
fixes issues with super->sub-class auto-cast and handles MultiInputOb…
Browse files Browse the repository at this point in the history
…j coercion
  • Loading branch information
tclose committed May 21, 2024
1 parent 1ff10b0 commit 3715841
Showing 1 changed file with 54 additions and 16 deletions.
70 changes: 54 additions & 16 deletions pydra/utils/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,7 +366,18 @@ def coerce_obj(obj, type_):
f"Cannot coerce {obj!r} into {type_}{msg}{self.label_str}"
) from e

return expand_and_coerce(object_, self.pattern)
# Special handling for MultiInputObjects (which are annoying)
if isinstance(self.pattern, tuple) and self.pattern[0] == MultiInputObj:
try:
self.check_coercible(object_, self.pattern[1][0])
except TypeError:
pass
else:
obj = [object_]
else:
obj = object_

return expand_and_coerce(obj, self.pattern)

def check_type(self, type_: ty.Type[ty.Any]):
"""Checks the given type to see whether it matches or is a subtype of the
Expand Down Expand Up @@ -413,7 +424,7 @@ def expand_and_check(tp, pattern: ty.Union[type, tuple]):
f"{self.pattern}{self.label_str}"
)
tp_args = get_args(tp)
self.check_coercible(tp_origin, pattern_origin)
self.check_type_coercible(tp_origin, pattern_origin)
if issubclass(pattern_origin, ty.Mapping):
return check_mapping(tp_args, pattern_args)
if issubclass(pattern_origin, tuple):
Expand Down Expand Up @@ -446,7 +457,7 @@ def check_basic(tp, target):
+ "\n\n".join(f"{a} -> {e}" for a, e in zip(tp_args, reasons))
)
if not self.is_subclass(tp, target):
self.check_coercible(tp, target)
self.check_type_coercible(tp, target)

def check_union(tp, pattern_args):
if get_origin(tp) in UNION_TYPES:
Expand Down Expand Up @@ -526,19 +537,46 @@ def check_sequence(tp_args, pattern_args):
for arg in tp_args:
expand_and_check(arg, pattern_args[0])

return expand_and_check(type_, self.pattern)
# Special handling for MultiInputObjects (which are annoying)
if isinstance(self.pattern, tuple) and self.pattern[0] == MultiInputObj:
pattern = (ty.Union, [self.pattern[1][0], (ty.List, self.pattern[1])])
else:
pattern = self.pattern
return expand_and_check(type_, pattern)

def check_coercible(self, source: ty.Any, target: ty.Union[type, ty.Any]):
"""Checks whether the source object is coercible to the target type given the coercion
rules defined in the `coercible` and `not_coercible` attrs
Parameters
----------
source : object
the object to be coerced
target : type or typing.Any
the target type for the object to be coerced to
Raises
------
TypeError
If the object cannot be coerced into the target type depending on the explicit
inclusions and exclusions set in the `coercible` and `not_coercible` member attrs
"""
self.check_type_coercible(type(source), target, source_repr=repr(source))

def check_coercible(
self, source: ty.Union[object, type], target: ty.Union[type, ty.Any]
def check_type_coercible(
self,
source: ty.Union[type, ty.Any],
target: ty.Union[type, ty.Any],
source_repr: ty.Optional[str] = None,
):
"""Checks whether the source object or type is coercible to the target type
"""Checks whether the source type is coercible to the target type
given the coercion rules defined in the `coercible` and `not_coercible` attrs
Parameters
----------
source : object or type
source object or type to be coerced
target : type or ty.Any
source : type or typing.Any
source type to be coerced
target : type or typing.Any
target type for the source to be coerced to
Raises
Expand All @@ -548,10 +586,12 @@ def check_coercible(
explicit inclusions and exclusions set in the `coercible` and `not_coercible`
member attrs
"""
if source_repr is None:
source_repr = repr(source)
# Short-circuit the basic cases where the source and target are the same
if source is target:
return
if self.superclass_auto_cast and self.is_subclass(target, type(source)):
if self.superclass_auto_cast and self.is_subclass(target, source):
logger.info(
"Attempting to coerce %s into %s due to super-to-sub class coercion "
"being permitted",
Expand All @@ -563,13 +603,11 @@ def check_coercible(
if source_origin is not None:
source = source_origin

source_check = self.is_subclass if inspect.isclass(source) else self.is_instance

def matches_criteria(criteria):
return [
(src, tgt)
for src, tgt in criteria
if source_check(source, src) and self.is_subclass(target, tgt)
if self.is_subclass(source, src) and self.is_subclass(target, tgt)
]

def type_name(t):
Expand All @@ -580,7 +618,7 @@ def type_name(t):

if not matches_criteria(self.coercible):
raise TypeError(
f"Cannot coerce {repr(source)} into {target}{self.label_str} as the "
f"Cannot coerce {source_repr} into {target}{self.label_str} as the "
"coercion doesn't match any of the explicit inclusion criteria: "
+ ", ".join(
f"{type_name(s)} -> {type_name(t)}" for s, t in self.coercible
Expand All @@ -589,7 +627,7 @@ def type_name(t):
matches_not_coercible = matches_criteria(self.not_coercible)
if matches_not_coercible:
raise TypeError(
f"Cannot coerce {repr(source)} into {target}{self.label_str} as it is explicitly "
f"Cannot coerce {source_repr} into {target}{self.label_str} as it is explicitly "
"excluded by the following coercion criteria: "
+ ", ".join(
f"{type_name(s)} -> {type_name(t)}"
Expand Down

0 comments on commit 3715841

Please sign in to comment.