diff --git a/etuples/core.py b/etuples/core.py index 926195c..05b1880 100644 --- a/etuples/core.py +++ b/etuples/core.py @@ -3,7 +3,7 @@ import warnings from collections import deque from collections.abc import Generator, Sequence -from typing import Any, Callable +from typing import Callable etuple_repr = reprlib.Repr() etuple_repr.maxstring = 100 @@ -178,8 +178,16 @@ def eval_obj(self): ) return trampoline_eval(self._eval_step()) - def _eval_apply(self, op: Callable, op_args: inspect.BoundArguments) -> Any: - return op(*op_args.args, **op_args.kwargs) + def _eval_apply_fn(self, op: Callable) -> Callable: + """Return the callable used to evaluate the expression tuple. + + The expression tuple's operator can be any `Callable`, i.e. either + a function or an instance of a class that defines `__call__`. In + the latter case, one can evalute the expression tuple using a + method other than `__call__` by overloading this method. + + """ + return op def _eval_step(self): if len(self._tuple) == 0: @@ -210,7 +218,7 @@ def _eval_step(self): evaled_args.append(i) try: - op_sig = inspect.signature(op) + op_sig = inspect.signature(self._eval_apply_fn(op)) except ValueError: # This handles some builtin function types _evaled_obj = op(*(evaled_args + [kw.value for kw in evaled_kwargs])) @@ -220,7 +228,7 @@ def _eval_step(self): ) op_args.apply_defaults() - _evaled_obj = self._eval_apply(op, op_args) + _evaled_obj = self._eval_apply_fn(op)(*op_args.args, **op_args.kwargs) if isinstance(_evaled_obj, Generator): self._evaled_obj = _evaled_obj diff --git a/tests/test_core.py b/tests/test_core.py index 959a569..7e2e23b 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -68,6 +68,22 @@ def test_ExpressionTuple(capsys): ExpressionTuple((print, "hi")).eval_obj +def test_eval_apply_fn(): + class Add(object): + def __call__(self): + return None + + def add(self, x, y): + return x + y + + class AddExpressionTuple(ExpressionTuple): + def _eval_apply_fn(self, op): + return op.add + + op = Add() + assert AddExpressionTuple((op, 1, 2)).evaled_obj == 3 + + def test_etuple(): """Test basic `etuple` functionality."""