Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix inconsistent function inspection for @decorated functions #2246

Merged
merged 5 commits into from
Jan 24, 2025
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 14 additions & 9 deletions package/kedro_viz/models/flowchart/model_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,20 @@ def _parse_filepath(dataset_description: Dict[str, Any]) -> Optional[str]:


def _extract_wrapped_func(func: FunctionType) -> FunctionType:
"""Extract a wrapped decorated function to inspect the source code if available.
Adapted from https://stackoverflow.com/a/43506509/1684058
"""
if func.__closure__ is None:
return func
closure = (c.cell_contents for c in func.__closure__)
wrapped_func = next((c for c in closure if isinstance(c, FunctionType)), None)
# return the original function if it's not a decorated function
return func if wrapped_func is None else wrapped_func
"""Extract a wrapped decorated function to inspect the source code if available."""
# Check if the function has a `__wrapped__` attribute (set by functools.wraps)
if hasattr(func, "__wrapped__"):
return func.__wrapped__

# Inspect the closure for the original function if still wrapped
if func.__closure__:
closure = (c.cell_contents for c in func.__closure__)
wrapped_func = next((c for c in closure if isinstance(c, FunctionType)), None)
if wrapped_func:
return wrapped_func

# Return the original function if no wrapping detected
return func


# =============================================================================
Expand Down
44 changes: 43 additions & 1 deletion package/tests/test_models/test_flowchart/test_node_metadata.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from functools import partial
from functools import partial, wraps
from pathlib import Path
from textwrap import dedent

Expand Down Expand Up @@ -35,11 +35,28 @@ def _new_fun(*args, **kwargs):
return _new_fun


def wrapped_decorator(fun):
"""
Decorator that wraps a function.
"""

@wraps(fun)
def _new_fun(*args, **kwargs):
return fun(*args, **kwargs)

return _new_fun


@decorator
def decorated(x):
return x


@wrapped_decorator
def wrapped_decorated(x):
return x


# A normal function
def full_func(a, b, c, x):
return 1000 * a + 100 * b + 10 * c + x
Expand Down Expand Up @@ -158,6 +175,31 @@ def decorated(x):
)
assert not task_node_metadata.parameters

def test_task_node_metadata_with_wrapped_decorated_func(self):
kedro_node = node(
wrapped_decorated,
inputs="x",
outputs="y",
name="identity_node",
tags={"tag"},
namespace="namespace",
)
task_node = GraphNode.create_task_node(
kedro_node, "identity_node", set(["namespace"])
)
task_node_metadata = TaskNodeMetadata(task_node=task_node)
assert task_node_metadata.code == dedent(
"""\
@wrapped_decorator
def wrapped_decorated(x):
return x
"""
)
assert task_node_metadata.filepath == str(
Path(__file__).relative_to(Path.cwd().parent).expanduser()
)
assert not task_node_metadata.parameters

def test_task_node_metadata_with_partial_func(self):
kedro_node = node(
partial_func,
Expand Down
Loading