Skip to content

Commit

Permalink
Fixes issue in which dataloader did not accept subscripted generics as the output type
Browse files Browse the repository at this point in the history

Note that for the metadata dict we only allow a bare dict or dict[str, ...], nothing else.
Also updates the tests to work with Python 3.8 -- we actually remove a few, as 3.8 is past EOL now.
  • Loading branch information
elijahbenizzy committed Dec 17, 2024
1 parent 2c8c3d7 commit e79a66b
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 10 deletions.
10 changes: 9 additions & 1 deletion hamilton/function_modifiers/adapters.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
ParametrizedDependency,
UpstreamDependency,
)
from hamilton.htypes import custom_subclass_check
from hamilton.io.data_adapters import AdapterCommon, DataLoader, DataSaver
from hamilton.node import DependencyType
from hamilton.registry import LOADER_REGISTRY, SAVER_REGISTRY
Expand Down Expand Up @@ -748,10 +749,17 @@ def validate(self, fn: Callable):
)
# check that the second is a dict
second_arg = typing_inspect.get_args(return_annotation)[1]
if not (second_arg == dict or second_arg == Dict):
if not (custom_subclass_check(second_arg, dict)):
raise InvalidDecoratorException(
f"Function: {fn.__qualname__} must return a tuple of type (SOME_TYPE, dict)."
)
second_arg_params = typing_inspect.get_args(second_arg)
if (
len(second_arg_params) > 0 and not second_arg_params[0] == str
): # metadata must have string keys
raise InvalidDecoratorException(
f"Function: {fn.__qualname__} must return a tuple of type (SOME_TYPE, dict[str, ...]). Instead got (SOME_TYPE, dict[{second_arg_params[0]}, ...]"
)

def generate_nodes(self, fn: Callable, config) -> List[node.Node]:
"""Generates two nodes. We have to add tags appropriately.
Expand Down
38 changes: 30 additions & 8 deletions tests/function_modifiers/test_adapters.py
Original file line number Diff line number Diff line change
Expand Up @@ -697,19 +697,23 @@ def fn(data1: dict, data2: dict) -> dict:
import sys

if sys.version_info >= (3, 9):
dl_type = tuple[int, dict]
ds_type = dict
dict_ = dict
tuple_ = tuple
else:
dl_type = Tuple[int, Dict]
ds_type = Dict
dict_ = Dict
tuple_ = Tuple


# Mock functions for dataloader & datasaver testing
def correct_dl_function(foo: int) -> dl_type:
def correct_dl_function(foo: int) -> tuple_[int, dict_]:
return 1, {}


def correct_ds_function(data: float) -> ds_type:
def correct_dl_function_with_subscripts(foo: int) -> tuple_[Dict[str, int], Dict[str, str]]:
return {"a": 1}, {"b": "c"}


def correct_ds_function(data: float) -> dict_:
return {}


Expand All @@ -721,19 +725,24 @@ def non_tuple_return_function() -> int:
return 1


def incorrect_tuple_length_function() -> Tuple[int]:
def incorrect_tuple_length_function() -> tuple_[int]:
return (1,)


def incorrect_second_element_function() -> Tuple[int, list]:
def incorrect_second_element_function() -> tuple_[int, list]:
return 1, []


def incorrect_dict_subscript() -> tuple_[int, Dict[int, str]]:
return 1, {1: "a"}


incorrect_funcs = [
no_return_annotation_function,
non_tuple_return_function,
incorrect_tuple_length_function,
incorrect_second_element_function,
incorrect_dict_subscript,
]


Expand All @@ -744,6 +753,10 @@ def test_dl_validate_incorrect_functions(func):
dl.validate(func)


@pytest.mark.skipif(
sys.version_info < (3, 9, 0),
reason="dataloader not guarenteed to work with subscripted tuples on 3.8",
)
def test_dl_validate_with_correct_function():
dl = dataloader()
try:
Expand All @@ -753,6 +766,15 @@ def test_dl_validate_with_correct_function():
pytest.fail("validate() raised InvalidDecoratorException unexpectedly!")


def test_dl_validate_with_subscripts():
dl = dataloader()
try:
dl.validate(correct_dl_function_with_subscripts)
except InvalidDecoratorException:
# i.e. fail the test if there's an error
pytest.fail("validate() raised InvalidDecoratorException unexpectedly!")


def test_ds_validate_with_correct_function():
dl = datasaver()
try:
Expand Down
8 changes: 7 additions & 1 deletion tests/resources/nodes_with_future_annotation.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,16 @@
from __future__ import annotations

import sys
from typing import List, Tuple

from hamilton.function_modifiers import dataloader
from hamilton.htypes import Collect, Parallelizable

"""Tests future annotations with common node types"""

tuple_ = Tuple if sys.version_info < (3, 9, 0) else tuple
list_ = List if sys.version_info < (3, 9, 0) else list


def parallelized() -> Parallelizable[int]:
yield 1
Expand All @@ -21,6 +27,6 @@ def collected(standard: Collect[int]) -> int:


@dataloader()
def sample_dataloader() -> tuple[list[str], dict]:
def sample_dataloader() -> tuple_[list_[str], dict]:
"""Grouping here as the rest test annotations"""
return ["a", "b", "c"], {}

0 comments on commit e79a66b

Please sign in to comment.