Skip to content

Commit

Permalink
debugged is_subclass so it works properly for union types
Browse files Browse the repository at this point in the history
  • Loading branch information
tclose committed Sep 7, 2023
1 parent ff5cb7b commit fc3fcd1
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 29 deletions.
24 changes: 22 additions & 2 deletions pydra/utils/tests/test_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -611,14 +611,34 @@ def test_type_is_subclass3():
assert TypeParser.is_subclass(ty.Type[Json], ty.Type[File])


def test_type_is_subclass4():
def test_union_is_subclass1():
assert TypeParser.is_subclass(ty.Union[Json, Yaml], ty.Union[Json, Yaml, Xml])


def test_type_is_subclass5():
def test_union_is_subclass2():
assert not TypeParser.is_subclass(ty.Union[Json, Yaml, Xml], ty.Union[Json, Yaml])


def test_union_is_subclass3():
assert TypeParser.is_subclass(Json, ty.Union[Json, Yaml])


def test_union_is_subclass4():
assert not TypeParser.is_subclass(ty.Union[Json, Yaml], Json)


def test_generic_is_subclass1():
assert TypeParser.is_subclass(ty.List[int], list)


def test_generic_is_subclass2():
assert not TypeParser.is_subclass(list, ty.List[int])


def test_generic_is_subclass3():
assert not TypeParser.is_subclass(ty.List[float], ty.List[int])


def test_type_is_instance1():
assert TypeParser.is_instance(File, ty.Type[File])

Expand Down
63 changes: 36 additions & 27 deletions pydra/utils/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -562,7 +562,7 @@ def matches_type(
def is_instance(
cls,
obj: object,
candidates: ty.Union[ty.Type[ty.Any], ty.Iterable[ty.Type[ty.Any]]],
candidates: ty.Union[ty.Type[ty.Any], ty.Sequence[ty.Type[ty.Any]]],
) -> bool:
"""Checks whether the object is an instance of cls or that cls is typing.Any,
extending the built-in isinstance to check nested type args
Expand All @@ -574,7 +574,7 @@ def is_instance(
candidates : type or ty.Iterable[type]
the candidate types to check the object against
"""
if not isinstance(candidates, (tuple, list)):
if not isinstance(candidates, ty.Sequence):
candidates = [candidates]
for candidate in candidates:
if candidate is ty.Any:
Expand All @@ -600,7 +600,7 @@ def is_instance(
def is_subclass(
cls,
klass: ty.Type[ty.Any],
candidates: ty.Union[ty.Type[ty.Any], ty.Iterable[ty.Type[ty.Any]]],
candidates: ty.Union[ty.Type[ty.Any], ty.Sequence[ty.Type[ty.Any]]],
any_ok: bool = False,
) -> bool:
"""Checks whether the class a is either the same as b, a subclass of b or b is
Expand All @@ -617,16 +617,23 @@ def is_subclass(
"""
if not isinstance(candidates, ty.Sequence):
candidates = [candidates]
if ty.Any in candidates:
return True
if klass is ty.Any:
return any_ok

origin = get_origin(klass)
args = get_args(klass)

for candidate in candidates:
candidate_origin = get_origin(candidate)
candidate_args = get_args(candidate)
# Handle ty.Type[*] types in klass and candidates
if ty.get_origin(klass) is type and (
candidate is type or ty.get_origin(candidate) is type
):
if origin is type and (candidate is type or candidate_origin is type):
if candidate is type:
return True
return cls.is_subclass(ty.get_args(klass)[0], ty.get_args(candidate)[0])
elif ty.get_origin(klass) is type or ty.get_origin(candidate) is type:
return cls.is_subclass(args[0], candidate_args[0])
elif origin is type or candidate_origin is type:
return False
if NO_GENERIC_ISSUBCLASS:
if klass is type and candidate is not type:
Expand All @@ -636,27 +643,29 @@ def is_subclass(
):
return True
else:
if klass is ty.Any:
if ty.Any in candidates: # type: ignore
return True
else:
return any_ok
origin = get_origin(klass)
if origin is ty.Union:
args = get_args(klass)
if get_origin(candidate) is ty.Union:
candidate_args = get_args(candidate)
else:
candidate_args = [candidate]
return all(
any(cls.is_subclass(a, c) for c in candidate_args) for a in args
union_args = (
candidate_args if candidate_origin is ty.Union else (candidate,)
)
if origin is not None:
klass = origin
if klass is candidate or candidate is ty.Any:
return True
if issubclass(klass, candidate):
return True
matches = all(
any(cls.is_subclass(a, c) for c in union_args) for a in args
)
if matches:
return True
else:
if candidate_args and candidate_origin is not ty.Union:
if (
origin
and issubclass(origin, candidate_origin) # type: ignore[arg-type]
and len(args) == len(candidate_args)
and all(
issubclass(a, c) for a, c in zip(args, candidate_args)
)
):
return True
else:
if issubclass(origin if origin else klass, candidate):
return True
return False

@classmethod
Expand Down

0 comments on commit fc3fcd1

Please sign in to comment.