Skip to content

Commit

Permalink
Add a method to specify ExpressionTuple evaluation function
Browse files Browse the repository at this point in the history
  • Loading branch information
rlouf authored and brandonwillard committed Sep 2, 2022
1 parent b5acdc8 commit e4e6fe0
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 5 deletions.
18 changes: 13 additions & 5 deletions etuples/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import warnings
from collections import deque
from collections.abc import Generator, Sequence
from typing import Any, Callable
from typing import Callable

etuple_repr = reprlib.Repr()
etuple_repr.maxstring = 100
Expand Down Expand Up @@ -178,8 +178,16 @@ def eval_obj(self):
)
return trampoline_eval(self._eval_step())

def _eval_apply(self, op: Callable, op_args: inspect.BoundArguments) -> Any:
return op(*op_args.args, **op_args.kwargs)
def _eval_apply_fn(self, op: Callable) -> Callable:
"""Return the callable used to evaluate the expression tuple.
The expression tuple's operator can be any `Callable`, i.e. either
a function or an instance of a class that defines `__call__`. In
the latter case, one can evalute the expression tuple using a
method other than `__call__` by overloading this method.
"""
return op

def _eval_step(self):
if len(self._tuple) == 0:
Expand Down Expand Up @@ -210,7 +218,7 @@ def _eval_step(self):
evaled_args.append(i)

try:
op_sig = inspect.signature(op)
op_sig = inspect.signature(self._eval_apply_fn(op))
except ValueError:
# This handles some builtin function types
_evaled_obj = op(*(evaled_args + [kw.value for kw in evaled_kwargs]))
Expand All @@ -220,7 +228,7 @@ def _eval_step(self):
)
op_args.apply_defaults()

_evaled_obj = self._eval_apply(op, op_args)
_evaled_obj = self._eval_apply_fn(op)(*op_args.args, **op_args.kwargs)

if isinstance(_evaled_obj, Generator):
self._evaled_obj = _evaled_obj
Expand Down
16 changes: 16 additions & 0 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,22 @@ def test_ExpressionTuple(capsys):
ExpressionTuple((print, "hi")).eval_obj


def test_eval_apply_fn():
class Add(object):
def __call__(self):
return None

def add(self, x, y):
return x + y

class AddExpressionTuple(ExpressionTuple):
def _eval_apply_fn(self, op):
return op.add

op = Add()
assert AddExpressionTuple((op, 1, 2)).evaled_obj == 3


def test_etuple():
"""Test basic `etuple` functionality."""

Expand Down

0 comments on commit e4e6fe0

Please sign in to comment.