diff --git a/hamilton/function_modifiers/adapters.py b/hamilton/function_modifiers/adapters.py
index ea6aa0d10..11133822d 100644
--- a/hamilton/function_modifiers/adapters.py
+++ b/hamilton/function_modifiers/adapters.py
@@ -17,6 +17,7 @@
     ParametrizedDependency,
     UpstreamDependency,
 )
+from hamilton.htypes import custom_subclass_check
 from hamilton.io.data_adapters import AdapterCommon, DataLoader, DataSaver
 from hamilton.node import DependencyType
 from hamilton.registry import LOADER_REGISTRY, SAVER_REGISTRY
@@ -748,10 +749,17 @@ def validate(self, fn: Callable):
             )
         # check that the second is a dict
         second_arg = typing_inspect.get_args(return_annotation)[1]
-        if not (second_arg == dict or second_arg == Dict):
+        if not (custom_subclass_check(second_arg, dict)):
             raise InvalidDecoratorException(
                 f"Function: {fn.__qualname__} must return a tuple of type (SOME_TYPE, dict)."
             )
+        second_arg_params = typing_inspect.get_args(second_arg)
+        if (
+            len(second_arg_params) > 0 and second_arg_params[0] != str
+        ):  # a subscripted metadata dict must declare string keys; a bare dict is accepted as-is
+            raise InvalidDecoratorException(
+                f"Function: {fn.__qualname__} must return a tuple of type (SOME_TYPE, dict[str, ...]). Instead got (SOME_TYPE, dict[{second_arg_params[0]}, ...])."
+            )
 
     def generate_nodes(self, fn: Callable, config) -> List[node.Node]:
         """Generates two nodes. We have to add tags appropriately.
diff --git a/tests/function_modifiers/test_adapters.py b/tests/function_modifiers/test_adapters.py
index 06307dd96..06ba57b6b 100644
--- a/tests/function_modifiers/test_adapters.py
+++ b/tests/function_modifiers/test_adapters.py
@@ -697,19 +697,23 @@ def fn(data1: dict, data2: dict) -> dict:
 import sys
 
 if sys.version_info >= (3, 9):
-    dl_type = tuple[int, dict]
-    ds_type = dict
+    dict_ = dict
+    tuple_ = tuple
 else:
-    dl_type = Tuple[int, Dict]
-    ds_type = Dict
+    dict_ = Dict
+    tuple_ = Tuple
 
 
 # Mock functions for dataloader & datasaver testing
-def correct_dl_function(foo: int) -> dl_type:
+def correct_dl_function(foo: int) -> tuple_[int, dict_]:
     return 1, {}
 
 
-def correct_ds_function(data: float) -> ds_type:
+def correct_dl_function_with_subscripts(foo: int) -> tuple_[Dict[str, int], Dict[str, str]]:
+    return {"a": 1}, {"b": "c"}
+
+
+def correct_ds_function(data: float) -> dict_:
     return {}
 
 
@@ -721,19 +725,24 @@ def non_tuple_return_function() -> int:
     return 1
 
 
-def incorrect_tuple_length_function() -> Tuple[int]:
+def incorrect_tuple_length_function() -> tuple_[int]:
     return (1,)
 
 
-def incorrect_second_element_function() -> Tuple[int, list]:
+def incorrect_second_element_function() -> tuple_[int, list]:
     return 1, []
 
 
+def incorrect_dict_subscript() -> tuple_[int, Dict[int, str]]:
+    return 1, {1: "a"}
+
+
 incorrect_funcs = [
     no_return_annotation_function,
     non_tuple_return_function,
     incorrect_tuple_length_function,
     incorrect_second_element_function,
+    incorrect_dict_subscript,
 ]
 
 
@@ -744,6 +753,10 @@ def test_dl_validate_incorrect_functions(func):
         dl.validate(func)
 
 
+@pytest.mark.skipif(
+    sys.version_info < (3, 9, 0),
+    reason="dataloader not guaranteed to work with subscripted tuples on 3.8",
+)
 def test_dl_validate_with_correct_function():
     dl = dataloader()
     try:
@@ -753,6 +766,15 @@ def test_dl_validate_with_correct_function():
         pytest.fail("validate() raised InvalidDecoratorException unexpectedly!")
 
 
+def test_dl_validate_with_subscripts():
+    dl = dataloader()
+    try:
+        dl.validate(correct_dl_function_with_subscripts)
+    except InvalidDecoratorException:
+        # i.e. fail the test if there's an error
+        pytest.fail("validate() raised InvalidDecoratorException unexpectedly!")
+
+
 def test_ds_validate_with_correct_function():
     dl = datasaver()
     try:
diff --git a/tests/resources/nodes_with_future_annotation.py b/tests/resources/nodes_with_future_annotation.py
index 2a937ece9..b8566d362 100644
--- a/tests/resources/nodes_with_future_annotation.py
+++ b/tests/resources/nodes_with_future_annotation.py
@@ -1,10 +1,16 @@
 from __future__ import annotations
 
+import sys
+from typing import List, Tuple
+
 from hamilton.function_modifiers import dataloader
 from hamilton.htypes import Collect, Parallelizable
 
 """Tests future annotations with common node types"""
 
+tuple_ = Tuple if sys.version_info < (3, 9, 0) else tuple
+list_ = List if sys.version_info < (3, 9, 0) else list
+
 
 def parallelized() -> Parallelizable[int]:
     yield 1
@@ -21,6 +27,6 @@ def collected(standard: Collect[int]) -> int:
 
 
 @dataloader()
-def sample_dataloader() -> tuple[list[str], dict]:
+def sample_dataloader() -> tuple_[list_[str], dict]:
     """Grouping here as the rest test annotations"""
     return ["a", "b", "c"], {}