Skip to content

Commit

Permalink
Fix CapturedCallable to work better with kwargs
Browse files Browse the repository at this point in the history
  • Loading branch information
antonymilne committed Oct 23, 2023
1 parent f391d06 commit 4a8c2ce
Show file tree
Hide file tree
Showing 2 changed files with 124 additions and 99 deletions.
65 changes: 46 additions & 19 deletions vizro-core/src/vizro/models/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

import functools
import inspect
from copy import deepcopy
from typing import Any, Dict, List, Literal, Protocol, Union, runtime_checkable

from pydantic import Field, StrictBool
Expand Down Expand Up @@ -41,36 +40,68 @@ def __init__(self, function, /, *args, **kwargs):
Partially binds *args and **kwargs to the function call.
"""
# It is difficult to get positional-only and variadic positional arguments working at the same time as
# variadic keyword arguments. Ideally we would do the __call__ as
# self.__function(*bound_arguments.args, **bound_arguments.kwargs) as in the
# Python documentation. This would handle positional-only and variadic positional arguments better but makes
# it more difficult to handle variadic keyword arguments due to https://bugs.python.org/issue41745.
# Hence we abandon bound_arguments.args and bound_arguments.kwargs in favor of just using
# self.__function(**bound_arguments.arguments).
parameters = inspect.signature(function).parameters
invalid_params = {
param.name
for param in parameters.values()
if param.kind in [inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.VAR_POSITIONAL]
}

if invalid_params:
raise ValueError(
f"Invalid parameter {', '.join(invalid_params)}. CapturedCallable does not accept functions with "
f"positional-only or variadic positional parameters (*args)."
)

self.__function = function
self.__bound_arguments = inspect.signature(function).bind_partial(*args, **kwargs)
self.__arguments = inspect.signature(function).bind_partial(*args, **kwargs).arguments

# A function can only ever have one variadic keyword parameter. {""} is just here so that var_keyword_param
# is always unpacking a one element set.
(var_keyword_param,) = {
param.name for param in parameters.values() if param.kind == inspect.Parameter.VAR_KEYWORD
} or {""}

# Since we do __call__ as self.__function(**bound_arguments.arguments), we need to restructure the arguments
# a bit to put the kwargs in the right place.
# For a function with parameter **kwargs this converts self.__arguments = {"kwargs": {"a": 1}} into
# self.__arguments = {"a": 1}.
if var_keyword_param in self.__arguments:
self.__arguments.update(self.__arguments[var_keyword_param])
del self.__arguments[var_keyword_param]

def __call__(self, **kwargs):
"""Run the `function` with the initial arguments overridden by **kwargs.
Note *args are not possible here, but you can still override positional arguments using argument name.
"""
if not kwargs:
return self.__function(*self.__bound_arguments.args, **self.__bound_arguments.kwargs)

bound_arguments = deepcopy(self.__bound_arguments)
bound_arguments.arguments.update(kwargs)
# This looks like it should be self.__function(*bound_arguments.args, **bound_arguments.kwargs) as in the
# Python documentation, but that leads to problems due to https://bugs.python.org/issue41745.
return self.__function(**bound_arguments.arguments)
return self.__function(**{**self.__arguments, **kwargs})

def __getitem__(self, arg_name: str):
"""Gets the value of a bound argument."""
return self.__bound_arguments.arguments[arg_name]
return self.__arguments[arg_name]

def __delitem__(self, arg_name: str):
"""Deletes a bound argument."""
del self.__bound_arguments.arguments[arg_name]
del self.__arguments[arg_name]

@property
def _arguments(self):
# TODO: This is used twice: in _get_parametrized_config and in vm.Action and should be removed when those
# references are removed.
return self.__bound_arguments.arguments
return self.__arguments

# TODO-actions: Find a way how to compare CapturedCallable and function
@property
def _function(self):
return self.__function

@classmethod
def __get_validators__(cls):
Expand Down Expand Up @@ -137,11 +168,6 @@ def _parse_json(
else:
raise ValueError(f"_target_={function_name} must be wrapped in the @capture decorator.")

# TODO-actions: Find a way how to compare CapturedCallable and function
@property
def _function(self):
return self.__function


class capture:
"""Captures a function call to create a [`CapturedCallable`][vizro.models.types.CapturedCallable].
Expand Down Expand Up @@ -278,7 +304,8 @@ class OptionsDictType(TypedDict):
NavigationPagesType = Annotated[
Union[List[str], Dict[str, List[str]]],
Field(
None, description="List of Page IDs or dict mapping of Page IDs and titles (for hierarchical sub-navigation)"
None,
description="List of Page IDs or dict mapping of Page IDs and titles (for hierarchical sub-navigation)",
),
]
"""Permissible value types for page attribute. Values are displayed as default."""
158 changes: 78 additions & 80 deletions vizro-core/tests/unit/vizro/models/test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,113 +8,111 @@
from vizro.models.types import CapturedCallable, capture


@pytest.fixture
def varargs_function():
def function(*args, b=2):
return args[0] + b

return CapturedCallable(function, 1)

def positional_only_function(a, /):
pass

@pytest.mark.xfail
# Known bug: *args doesn't work properly. Fix while keeping the more important test_varkwargs
# passing due to https://bugs.python.org/issue41745.
# Error raised is IndexError: tuple index out of range
def test_varargs(varargs_function):
assert varargs_function(b=2) == 1 + 2

def var_positional_function(*args):
pass

@pytest.fixture
def positional_only_function():
def function(a, /, b):
return a + b

return CapturedCallable(function, 1)

@pytest.mark.parametrize("function", [positional_only_function, var_positional_function])
def test_invalid_parameter_kind(function):
with pytest.raises(
ValueError,
match="CapturedCallable does not accept functions with positional-only or variadic positional parameters",
):
CapturedCallable(function)

@pytest.mark.xfail
# Known bug: position-only argument doesn't work properly. Fix while keeping the more important
# test_varkwargs passing due to https://bugs.python.org/issue41745.
# Error raised is TypeError: function got some positional-only arguments passed as keyword arguments: 'a'
def test_positional_only(positional_only_function):
assert positional_only_function(b=2) == 1 + 2

def positional_or_keyword_function(a, b, c):
return a + b + c

@pytest.fixture
def keyword_only_function():
def function(a, *, b):
return a + b

return CapturedCallable(function, 1)
def keyword_only_function(a, *, b, c):
return a + b + c


def test_keyword_only(keyword_only_function):
assert keyword_only_function(b=2) == 1 + 2
def var_keyword_function(a, **kwargs):
return a + kwargs["b"] + kwargs["c"]


@pytest.fixture
def varkwargs_function():
def function(a, b=2, **kwargs):
return a + b + kwargs["c"]

return CapturedCallable(function, 1)


def test_varkwargs(varkwargs_function):
varkwargs_function(c=3, d=4) == 1 + 2 + 3


@pytest.fixture
def simple_function():
def function(a, b, c, d=4):
return a + b + c + d

return CapturedCallable(function, 1, b=2)
def captured_callable(request):
return CapturedCallable(request.param, 1, b=2)


@pytest.mark.parametrize(
"captured_callable",
[positional_or_keyword_function, keyword_only_function, var_keyword_function],
indirect=True,
)
class TestCall:
def test_call_missing_argument(self, simple_function):
with pytest.raises(TypeError, match="missing 1 required positional argument"):
simple_function()

def test_call_needs_keyword_arguments(self, simple_function):
def test_call_needs_keyword_arguments(self, captured_callable):
with pytest.raises(TypeError, match="takes 1 positional argument but 2 were given"):
simple_function(2)

def test_call_provide_required_argument(self, simple_function):
assert simple_function(c=3) == 1 + 2 + 3 + 4
captured_callable(2)

def test_call_override_existing_arguments(self, simple_function):
assert simple_function(a=5, b=2, c=6) == 5 + 2 + 6 + 4
def test_call_provide_required_argument(self, captured_callable):
assert captured_callable(c=3) == 1 + 2 + 3

def test_call_is_memoryless(self, simple_function):
simple_function(c=3)

with pytest.raises(TypeError, match="missing 1 required positional argument"):
simple_function()

def test_call_unknown_argument(self, simple_function):
with pytest.raises(TypeError, match="got an unexpected keyword argument"):
simple_function(e=1)
def test_call_override_existing_arguments(self, captured_callable):
assert captured_callable(a=5, b=2, c=6) == 5 + 2 + 6


@pytest.mark.parametrize(
"captured_callable",
[positional_or_keyword_function, keyword_only_function, var_keyword_function],
indirect=True,
)
class TestDunderMethods:
def test_getitem_known_args(self, simple_function):
assert simple_function["a"] == 1
assert simple_function["b"] == 2

def test_getitem_unknown_args(self, simple_function):
with pytest.raises(KeyError):
simple_function["c"]
def test_getitem_known_args(self, captured_callable):
assert captured_callable["a"] == 1
assert captured_callable["b"] == 2

def test_getitem_unknown_args(self, captured_callable):
with pytest.raises(KeyError):
simple_function["d"]
captured_callable["c"]

def test_delitem(self, simple_function):
del simple_function["a"]
def test_delitem(self, captured_callable):
del captured_callable["a"]

with pytest.raises(KeyError):
simple_function["a"]
captured_callable["a"]


@pytest.mark.parametrize(
"captured_callable, expectation",
[
(positional_or_keyword_function, pytest.raises(TypeError, match="missing 1 required positional argument: 'c'")),
(keyword_only_function, pytest.raises(TypeError, match="missing 1 required keyword-only argument: 'c'")),
(var_keyword_function, pytest.raises(KeyError, match="'c'")),
],
indirect=["captured_callable"],
)
class TestCallMissingArgument:
def test_call_missing_argument(self, captured_callable, expectation):
with expectation:
captured_callable()

def test_call_is_memoryless(self, captured_callable, expectation):
captured_callable(c=3)

with expectation:
captured_callable()


@pytest.mark.parametrize(
"captured_callable, expectation",
[
(positional_or_keyword_function, pytest.raises(TypeError, match="got an unexpected keyword argument")),
(keyword_only_function, pytest.raises(TypeError, match="got an unexpected keyword argument")),
(var_keyword_function, pytest.raises(KeyError, match="'c'")),
],
indirect=["captured_callable"],
)
def test_call_unknown_argument(captured_callable, expectation):
with expectation:
captured_callable(e=1)


def undecorated_function(a, b, c, d=4):
Expand Down

0 comments on commit 4a8c2ce

Please sign in to comment.