From dcd0a800608895d2ded219921b54dc4e4b6625a1 Mon Sep 17 00:00:00 2001 From: Tom Close Date: Thu, 7 Sep 2023 18:43:14 +1000 Subject: [PATCH] added tests for explict and auto-superclass casting --- pydra/engine/helpers.py | 2 +- pydra/utils/tests/test_typing.py | 94 +++++++++++++++++++++++++++----- pydra/utils/tests/utils.py | 89 +++++++++++++++++++++++++++++- pydra/utils/typing.py | 19 +++++-- 4 files changed, 182 insertions(+), 22 deletions(-) diff --git a/pydra/engine/helpers.py b/pydra/engine/helpers.py index d91d9e544a..3dc3457da9 100644 --- a/pydra/engine/helpers.py +++ b/pydra/engine/helpers.py @@ -263,7 +263,7 @@ def make_klass(spec): ) checker_label = f"'{name}' field of {spec.name}" type_checker = TypeParser[newfield.type]( - newfield.type, label=checker_label, allow_lazy_super=True + newfield.type, label=checker_label, superclass_auto_cast=True ) if newfield.type in (MultiInputObj, MultiInputFile): converter = attr.converters.pipe(ensure_list, type_checker) diff --git a/pydra/utils/tests/test_typing.py b/pydra/utils/tests/test_typing.py index 7580070511..be362cec86 100644 --- a/pydra/utils/tests/test_typing.py +++ b/pydra/utils/tests/test_typing.py @@ -14,7 +14,10 @@ GenericShellTask, specific_func_task, SpecificShellTask, + other_specific_func_task, + OtherSpecificShellTask, MyFormatX, + MyOtherFormatX, MyHeader, ) @@ -168,12 +171,12 @@ def test_type_check_permit_superclass(): # Typical case as Json is subclass of File TypeParser(ty.List[File])(lz(ty.List[Json])) # Permissive super class, as File is superclass of Json - TypeParser(ty.List[Json], allow_lazy_super=True)(lz(ty.List[File])) + TypeParser(ty.List[Json], superclass_auto_cast=True)(lz(ty.List[File])) with pytest.raises(TypeError, match="Cannot coerce"): - TypeParser(ty.List[Json], allow_lazy_super=False)(lz(ty.List[File])) + TypeParser(ty.List[Json], superclass_auto_cast=False)(lz(ty.List[File])) # Fails because Yaml is neither sub or super class of Json with pytest.raises(TypeError, match="Cannot coerce"): - TypeParser(ty.List[Json], allow_lazy_super=True)(lz(ty.List[Yaml])) + TypeParser(ty.List[Json], superclass_auto_cast=True)(lz(ty.List[Yaml])) def test_type_check_fail1(): @@ -550,7 +553,17 @@ def specific_task(request): assert False -def test_typing_cast(tmp_path, generic_task, specific_task): +@pytest.fixture(params=["func", "shell"]) +def other_specific_task(request): + if request.param == "func": + return other_specific_func_task + elif request.param == "shell": + return OtherSpecificShellTask + else: + assert False + + +def test_typing_implicit_cast_from_super(tmp_path, generic_task, specific_task): """Check the casting of lazy fields and whether specific file-sets can be recovered from generic `File` classes""" @@ -574,33 +587,86 @@ def test_typing_cast(tmp_path, generic_task, specific_task): ) ) + wf.add( + specific_task( + in_file=wf.generic.lzout.out, + name="specific2", + ) + ) + + wf.set_output( + [ + ("out_file", wf.specific2.lzout.out), + ] + ) + + in_file = MyFormatX.sample() + + result = wf(in_file=in_file, plugin="serial") + + out_file: MyFormatX = result.output.out_file + assert type(out_file) is MyFormatX + assert out_file.parent != in_file.parent + assert type(out_file.header) is MyHeader + assert out_file.header.parent != in_file.header.parent + + +def test_typing_cast(tmp_path, specific_task, other_specific_task): + """Check the casting of lazy fields and whether specific file-sets can be recovered + from generic `File` classes""" + + wf = Workflow( + name="test", + input_spec={"in_file": MyFormatX}, + output_spec={"out_file": MyFormatX}, + ) + + wf.add( + specific_task( + in_file=wf.lzin.in_file, + name="entry", + ) + ) + + with pytest.raises(TypeError, match="Cannot coerce"): + # No cast of generic task output to MyFormatX + wf.add( # Generic task + other_specific_task( + in_file=wf.entry.lzout.out, + name="inner", + ) + ) + + wf.add( # Generic task + other_specific_task( + in_file=wf.entry.lzout.out.cast(MyOtherFormatX), + name="inner", + ) + ) + with pytest.raises(TypeError, match="Cannot coerce"): # No cast of generic task output to MyFormatX wf.add( specific_task( - in_file=wf.generic.lzout.out, - name="specific2", + in_file=wf.inner.lzout.out, + name="exit", ) ) wf.add( specific_task( - in_file=wf.generic.lzout.out.cast(MyFormatX), - name="specific2", + in_file=wf.inner.lzout.out.cast(MyFormatX), + name="exit", ) ) wf.set_output( [ - ("out_file", wf.specific2.lzout.out), + ("out_file", wf.exit.lzout.out), ] ) - my_fspath = tmp_path / "in_file.my" - hdr_fspath = tmp_path / "in_file.hdr" - my_fspath.write_text("my-format") - hdr_fspath.write_text("my-header") - in_file = MyFormatX([my_fspath, hdr_fspath]) + in_file = MyFormatX.sample() result = wf(in_file=in_file, plugin="serial") diff --git a/pydra/utils/tests/utils.py b/pydra/utils/tests/utils.py index eb452edf91..9611ef0a4e 100644 --- a/pydra/utils/tests/utils.py +++ b/pydra/utils/tests/utils.py @@ -1,12 +1,15 @@ +from pathlib import Path +import typing as ty from fileformats.generic import File -from fileformats.core.mixin import WithSeparateHeader +from fileformats.core.mixin import WithSeparateHeader, WithMagicNumber from pydra import mark from pydra.engine.task import ShellCommandTask from pydra.engine import specs -class MyFormat(File): +class MyFormat(WithMagicNumber, File): ext = ".my" + magic_number = b"MYFORMAT" class MyHeader(File): @@ -17,6 +20,34 @@ class MyFormatX(WithSeparateHeader, MyFormat): header_type = MyHeader +class MyOtherFormatX(WithMagicNumber, WithSeparateHeader, File): + magic_number = b"MYFORMAT" + ext = ".my" + header_type = MyHeader + + +@File.generate_sample_data.register +def my_format_x_generate_sample_data( + my_format_x: MyFormatX, dest_dir: Path +) -> ty.List[Path]: + fspath = dest_dir / "file.my" + with open(fspath, "wb") as f: + f.write(b"MYFORMAT\nsome data goes here") + header_fspath = dest_dir / "file.hdr" + header_fspath.write_text("a: 1\nb: 2\nc: 3\n") + return [fspath, header_fspath] + + +@File.generate_sample_data.register +def my_other_format_generate_sample_data( + my_other_format: MyOtherFormatX, dest_dir: Path +) -> ty.List[Path]: + fspath = dest_dir / "file.my" + with open(fspath, "wb") as f: + f.write(b"MYFORMAT\nsome data goes here") + return [fspath] + + @mark.task def generic_func_task(in_file: File) -> File: return in_file @@ -118,3 +149,57 @@ class SpecificShellTask(ShellCommandTask): input_spec = specific_shell_input_spec output_spec = specific_shelloutput_spec executable = "echo" + + +@mark.task +def other_specific_func_task(in_file: MyOtherFormatX) -> MyOtherFormatX: + return in_file + + +other_specific_shell_input_fields = [ + ( + "in_file", + MyOtherFormatX, + { + "help_string": "the input file", + "argstr": "", + "copyfile": "copy", + "sep": " ", + }, + ), + ( + "out", + str, + { + "help_string": "output file name", + "argstr": "", + "position": -1, + "output_file_template": "{in_file}", # Pass through un-altered + }, + ), +] + +other_specific_shell_input_spec = specs.SpecInfo( + name="Input", fields=other_specific_shell_input_fields, bases=(specs.ShellSpec,) +) + +other_specific_shell_output_fields = [ + ( + "out", + MyOtherFormatX, + { + "help_string": "output file", + }, + ), +] +other_specific_shelloutput_spec = specs.SpecInfo( + name="Output", + fields=other_specific_shell_output_fields, + bases=(specs.ShellOutSpec,), +) + + +class OtherSpecificShellTask(ShellCommandTask): + input_spec = other_specific_shell_input_spec + output_spec = other_specific_shelloutput_spec + executable = "echo" diff --git a/pydra/utils/typing.py b/pydra/utils/typing.py index fbbb07cc76..d0ae8fbaba 100644 --- a/pydra/utils/typing.py +++ b/pydra/utils/typing.py @@ -58,7 +58,7 @@ class TypeParser(ty.Generic[T]): the tree of more complex nested container types. Overrides 'coercible' to enable you to carve out exceptions, such as TypeParser(list, coercible=[(ty.Iterable, list)], not_coercible=[(str, list)]) - allow_lazy_super : bool + superclass_auto_cast : bool Allow lazy fields to pass the type check if their types are superclasses of the specified pattern (instead of matching or being subclasses of the pattern) label : str @@ -69,7 +69,7 @@ class TypeParser(ty.Generic[T]): tp: ty.Type[T] coercible: ty.List[ty.Tuple[TypeOrAny, TypeOrAny]] not_coercible: ty.List[ty.Tuple[TypeOrAny, TypeOrAny]] - allow_lazy_super: bool + superclass_auto_cast: bool label: str COERCIBLE_DEFAULT: ty.Tuple[ty.Tuple[type, type], ...] = ( @@ -113,7 +113,7 @@ def __init__( not_coercible: ty.Optional[ ty.Iterable[ty.Tuple[TypeOrAny, TypeOrAny]] ] = NOT_COERCIBLE_DEFAULT, - allow_lazy_super: bool = False, + superclass_auto_cast: bool = False, label: str = "", ): def expand_pattern(t): @@ -142,7 +142,7 @@ def expand_pattern(t): ) self.not_coercible = list(not_coercible) if not_coercible is not None else [] self.pattern = expand_pattern(tp) - self.allow_lazy_super = allow_lazy_super + self.superclass_auto_cast = superclass_auto_cast def __call__(self, obj: ty.Any) -> ty.Union[T, LazyField[T]]: """Attempts to coerce the object to the specified type, unless the value is @@ -172,7 +172,7 @@ def __call__(self, obj: ty.Any) -> ty.Union[T, LazyField[T]]: try: self.check_type(obj.type) except TypeError as e: - if self.allow_lazy_super: + if self.superclass_auto_cast: try: # Check whether the type of the lazy field isn't a superclass of # the type to check against, and if so, allow it due to permissive @@ -492,8 +492,17 @@ def check_coercible( explicit inclusions and exclusions set in the `coercible` and `not_coercible` member attrs """ + # Short-circuit the basic cases where the source and target are the same if source is target: return + if self.superclass_auto_cast and self.is_subclass(target, type(source)): + logger.info( + "Attempting to coerce %s into %s due to super-to-sub class coercion " + "being permitted", + source, + target, + ) + return source_origin = get_origin(source) if source_origin is not None: source = source_origin