Skip to content

Commit

Permalink
Fixes dataloader to use the correct type hinting
Browse files Browse the repository at this point in the history
This makes it work with __future__.annotations. See #1259.
  • Loading branch information
elijahbenizzy committed Dec 17, 2024
1 parent 9f6ea84 commit 2c8c3d7
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 1 deletion.
2 changes: 1 addition & 1 deletion hamilton/function_modifiers/adapters.py
Original file line number Diff line number Diff line change
Expand Up @@ -733,7 +733,7 @@ def load_json_data(json_path: str = "data/my_data.json") -> tuple[pd.DataFrame,

def validate(self, fn: Callable):
"""Validates that the output type is correctly annotated."""
return_annotation = inspect.signature(fn).return_annotation
return_annotation = typing.get_type_hints(fn).get("return")
if return_annotation is inspect.Signature.empty:
raise InvalidDecoratorException(
f"Function: {fn.__qualname__} must have a return annotation."
Expand Down
13 changes: 13 additions & 0 deletions tests/function_modifiers/test_adapters.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
resolve_kwargs,
)
from hamilton.function_modifiers.base import DefaultNodeCreator
from hamilton.htypes import custom_subclass_check
from hamilton.io.data_adapters import DataLoader, DataSaver
from hamilton.io.default_data_loaders import JSONDataSaver
from hamilton.registry import LOADER_REGISTRY
Expand Down Expand Up @@ -792,6 +793,18 @@ def test_dataloader():
}


def test_dataloader_future_annotations():
from tests.resources import nodes_with_future_annotation

fn_to_collect = nodes_with_future_annotation.sample_dataloader
fg = graph.create_function_graph(
ad_hoc_utils.create_temporary_module(fn_to_collect),
config={},
)
# the data loaded is a list
assert custom_subclass_check(fg["sample_dataloader"].type, list)


def test_datasaver():
annotation = datasaver()
(node1,) = annotation.generate_nodes(correct_ds_function, {})
Expand Down
7 changes: 7 additions & 0 deletions tests/resources/nodes_with_future_annotation.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

from hamilton.function_modifiers import dataloader
from hamilton.htypes import Collect, Parallelizable

"""Tests future annotations with common node types"""
Expand All @@ -17,3 +18,9 @@ def standard(parallelized: int) -> int:

def collected(standard: Collect[int]) -> int:
return sum(standard)


@dataloader()
def sample_dataloader() -> tuple[list[str], dict]:
"""Grouping here as the rest test annotations"""
return ["a", "b", "c"], {}

0 comments on commit 2c8c3d7

Please sign in to comment.