Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix CapturedCallable to work better with kwargs #121

Merged
merged 6 commits into from
Oct 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
<!--
A new scriv changelog fragment.

Uncomment the section that is right (remove the HTML comment wrapper).
-->

<!--
### Removed

- A bullet item for the Removed category.

-->
<!--
### Added

- A bullet item for the Added category.

-->
<!--
### Changed

- A bullet item for the Changed category.

-->
<!--
### Deprecated

- A bullet item for the Deprecated category.

-->

### Fixed

- `CapturedCallable` now handles variadic keywords arguments (`**kwargs`) correctly ([#121](https://github.com/mckinsey/vizro/pull/121))

<!--
### Security

- A bullet item for the Security category.

-->
70 changes: 51 additions & 19 deletions vizro-core/src/vizro/models/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

import functools
import inspect
from copy import deepcopy
from typing import Any, Dict, List, Literal, Protocol, Union, runtime_checkable

from pydantic import Field, StrictBool
Expand Down Expand Up @@ -40,37 +39,72 @@ def __init__(self, function, /, *args, **kwargs):
"""Creates a new CapturedCallable object that will be able to re-run `function`.

Partially binds *args and **kwargs to the function call.

Raises:
ValueError if `function` contains positional-only or variadic positional parameters (*args).
"""
# It is difficult to get positional-only and variadic positional arguments working at the same time as
# variadic keyword arguments. Ideally we would do the __call__ as
# self.__function(*bound_arguments.args, **bound_arguments.kwargs) as in the
# Python documentation. This would handle positional-only and variadic positional arguments better but makes
# it more difficult to handle variadic keyword arguments due to https://bugs.python.org/issue41745.
# Hence we abandon bound_arguments.args and bound_arguments.kwargs in favor of just using
# self.__function(**bound_arguments.arguments).
parameters = inspect.signature(function).parameters
invalid_params = {
param.name
for param in parameters.values()
if param.kind in [inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.VAR_POSITIONAL]
}

if invalid_params:
raise ValueError(
f"Invalid parameter {', '.join(invalid_params)}. CapturedCallable does not accept functions with "
maxschulz-COL marked this conversation as resolved.
Show resolved Hide resolved
f"positional-only or variadic positional parameters (*args)."
)

self.__function = function
self.__bound_arguments = inspect.signature(function).bind_partial(*args, **kwargs)
self.__arguments = inspect.signature(function).bind_partial(*args, **kwargs).arguments

# A function can only ever have one variadic keyword parameter. {""} is just here so that var_keyword_param
# is always unpacking a one element set.
(var_keyword_param,) = {
maxschulz-COL marked this conversation as resolved.
Show resolved Hide resolved
param.name for param in parameters.values() if param.kind == inspect.Parameter.VAR_KEYWORD
} or {""}

# Since we do __call__ as self.__function(**bound_arguments.arguments), we need to restructure the arguments
# a bit to put the kwargs in the right place.
# For a function with parameter **kwargs this converts self.__arguments = {"kwargs": {"a": 1}} into
# self.__arguments = {"a": 1}.
if var_keyword_param in self.__arguments:
self.__arguments.update(self.__arguments[var_keyword_param])
del self.__arguments[var_keyword_param]

def __call__(self, **kwargs):
"""Run the `function` with the initial arguments overridden by **kwargs.

Note *args are not possible here, but you can still override positional arguments using argument name.
"""
if not kwargs:
return self.__function(*self.__bound_arguments.args, **self.__bound_arguments.kwargs)

bound_arguments = deepcopy(self.__bound_arguments)
bound_arguments.arguments.update(kwargs)
# This looks like it should be self.__function(*bound_arguments.args, **bound_arguments.kwargs) as in the
# Python documentation, but that leads to problems due to https://bugs.python.org/issue41745.
return self.__function(**bound_arguments.arguments)
return self.__function(**{**self.__arguments, **kwargs})
antonymilne marked this conversation as resolved.
Show resolved Hide resolved

def __getitem__(self, arg_name: str):
"""Gets the value of a bound argument."""
return self.__bound_arguments.arguments[arg_name]
return self.__arguments[arg_name]

def __delitem__(self, arg_name: str):
"""Deletes a bound argument."""
del self.__bound_arguments.arguments[arg_name]
del self.__arguments[arg_name]

@property
def _arguments(self):
# TODO: This is used twice: in _get_parametrized_config and in vm.Action and should be removed when those
# references are removed.
return self.__bound_arguments.arguments
return self.__arguments

# TODO-actions: Find a way how to compare CapturedCallable and function
@property
def _function(self):
return self.__function

@classmethod
def __get_validators__(cls):
Expand Down Expand Up @@ -137,11 +171,6 @@ def _parse_json(
else:
raise ValueError(f"_target_={function_name} must be wrapped in the @capture decorator.")

# TODO-actions: Find a way how to compare CapturedCallable and function
@property
def _function(self):
return self.__function


class capture:
"""Captures a function call to create a [`CapturedCallable`][vizro.models.types.CapturedCallable].
Expand Down Expand Up @@ -175,6 +204,8 @@ def __call__(self, func, /):
# The more difficult case, where we need to still have a valid plotly figure that renders in a notebook.
# Hence we attach the CapturedCallable as a property instead of returning it directly.
# TODO: move point of checking that data_frame argument exists earlier on.
# TODO: also would be nice to raise errors in CapturedCallable.__init__ at point of function definition
# rather than point of calling if possible.
@functools.wraps(func)
def wrapped(*args, **kwargs) -> _DashboardReadyFigure:
if "data_frame" not in inspect.signature(func).parameters:
Expand Down Expand Up @@ -278,7 +309,8 @@ class OptionsDictType(TypedDict):
NavigationPagesType = Annotated[
Union[List[str], Dict[str, List[str]]],
Field(
None, description="List of Page IDs or dict mapping of Page IDs and titles (for hierarchical sub-navigation)"
None,
description="List of Page IDs or dict mapping of Page IDs and titles (for hierarchical sub-navigation)",
),
]
"""Permissible value types for page attribute. Values are displayed as default."""
158 changes: 78 additions & 80 deletions vizro-core/tests/unit/vizro/models/test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,113 +8,111 @@
from vizro.models.types import CapturedCallable, capture


@pytest.fixture
antonymilne marked this conversation as resolved.
Show resolved Hide resolved
def varargs_function():
def function(*args, b=2):
return args[0] + b

return CapturedCallable(function, 1)

def positional_only_function(a, /):
pass

@pytest.mark.xfail
# Known bug: *args doesn't work properly. Fix while keeping the more important test_varkwargs
# passing due to https://bugs.python.org/issue41745.
# Error raised is IndexError: tuple index out of range
def test_varargs(varargs_function):
assert varargs_function(b=2) == 1 + 2

def var_positional_function(*args):
pass

@pytest.fixture
def positional_only_function():
def function(a, /, b):
return a + b

return CapturedCallable(function, 1)

@pytest.mark.parametrize("function", [positional_only_function, var_positional_function])
def test_invalid_parameter_kind(function):
with pytest.raises(
ValueError,
match="CapturedCallable does not accept functions with positional-only or variadic positional parameters",
):
CapturedCallable(function)

@pytest.mark.xfail
# Known bug: position-only argument doesn't work properly. Fix while keeping the more important
# test_varkwargs passing due to https://bugs.python.org/issue41745.
# Error raised is TypeError: function got some positional-only arguments passed as keyword arguments: 'a'
def test_positional_only(positional_only_function):
assert positional_only_function(b=2) == 1 + 2

def positional_or_keyword_function(a, b, c):
return a + b + c

@pytest.fixture
def keyword_only_function():
def function(a, *, b):
return a + b

return CapturedCallable(function, 1)
def keyword_only_function(a, *, b, c):
maxschulz-COL marked this conversation as resolved.
Show resolved Hide resolved
return a + b + c


def test_keyword_only(keyword_only_function):
antonymilne marked this conversation as resolved.
Show resolved Hide resolved
assert keyword_only_function(b=2) == 1 + 2
def var_keyword_function(a, **kwargs):
petar-qb marked this conversation as resolved.
Show resolved Hide resolved
return a + kwargs["b"] + kwargs["c"]


@pytest.fixture
def varkwargs_function():
def function(a, b=2, **kwargs):
return a + b + kwargs["c"]

return CapturedCallable(function, 1)


def test_varkwargs(varkwargs_function):
antonymilne marked this conversation as resolved.
Show resolved Hide resolved
varkwargs_function(c=3, d=4) == 1 + 2 + 3


@pytest.fixture
def simple_function():
antonymilne marked this conversation as resolved.
Show resolved Hide resolved
def function(a, b, c, d=4):
return a + b + c + d

return CapturedCallable(function, 1, b=2)
def captured_callable(request):
return CapturedCallable(request.param, 1, b=2)


@pytest.mark.parametrize(
"captured_callable",
[positional_or_keyword_function, keyword_only_function, var_keyword_function],
indirect=True,
)
class TestCall:
def test_call_missing_argument(self, simple_function):
with pytest.raises(TypeError, match="missing 1 required positional argument"):
simple_function()

def test_call_needs_keyword_arguments(self, simple_function):
def test_call_needs_keyword_arguments(self, captured_callable):
with pytest.raises(TypeError, match="takes 1 positional argument but 2 were given"):
simple_function(2)

def test_call_provide_required_argument(self, simple_function):
assert simple_function(c=3) == 1 + 2 + 3 + 4
captured_callable(2)

def test_call_override_existing_arguments(self, simple_function):
assert simple_function(a=5, b=2, c=6) == 5 + 2 + 6 + 4
def test_call_provide_required_argument(self, captured_callable):
assert captured_callable(c=3) == 1 + 2 + 3

def test_call_is_memoryless(self, simple_function):
simple_function(c=3)

with pytest.raises(TypeError, match="missing 1 required positional argument"):
simple_function()

def test_call_unknown_argument(self, simple_function):
with pytest.raises(TypeError, match="got an unexpected keyword argument"):
simple_function(e=1)
def test_call_override_existing_arguments(self, captured_callable):
assert captured_callable(a=5, b=2, c=6) == 5 + 2 + 6


@pytest.mark.parametrize(
"captured_callable",
[positional_or_keyword_function, keyword_only_function, var_keyword_function],
indirect=True,
)
class TestDunderMethods:
def test_getitem_known_args(self, simple_function):
assert simple_function["a"] == 1
assert simple_function["b"] == 2

def test_getitem_unknown_args(self, simple_function):
with pytest.raises(KeyError):
simple_function["c"]
def test_getitem_known_args(self, captured_callable):
assert captured_callable["a"] == 1
assert captured_callable["b"] == 2

def test_getitem_unknown_args(self, captured_callable):
with pytest.raises(KeyError):
simple_function["d"]
captured_callable["c"]

def test_delitem(self, simple_function):
del simple_function["a"]
def test_delitem(self, captured_callable):
del captured_callable["a"]

with pytest.raises(KeyError):
simple_function["a"]
captured_callable["a"]


@pytest.mark.parametrize(
"captured_callable, expectation",
[
(positional_or_keyword_function, pytest.raises(TypeError, match="missing 1 required positional argument: 'c'")),
antonymilne marked this conversation as resolved.
Show resolved Hide resolved
(keyword_only_function, pytest.raises(TypeError, match="missing 1 required keyword-only argument: 'c'")),
(var_keyword_function, pytest.raises(KeyError, match="'c'")),
],
indirect=["captured_callable"],
)
class TestCallMissingArgument:
def test_call_missing_argument(self, captured_callable, expectation):
with expectation:
captured_callable()

def test_call_is_memoryless(self, captured_callable, expectation):
captured_callable(c=3)

with expectation:
captured_callable()


@pytest.mark.parametrize(
"captured_callable, expectation",
[
(positional_or_keyword_function, pytest.raises(TypeError, match="got an unexpected keyword argument")),
(keyword_only_function, pytest.raises(TypeError, match="got an unexpected keyword argument")),
(var_keyword_function, pytest.raises(KeyError, match="'c'")),
],
indirect=["captured_callable"],
)
def test_call_unknown_argument(captured_callable, expectation):
with expectation:
captured_callable(e=1)


def undecorated_function(a, b, c, d=4):
Expand Down
Loading