Skip to content

Commit

Permalink
Adds single step action
Browse files Browse the repository at this point in the history
This makes the function first-class, and runsa s ingle-step
function/reducer. THe tricky thing here is with inheritance -- we have a
`run` and an `update` but we don't call it as it subclasses `Action`. We
will be cleaning this up shortly (insomuch as we can) and thinking
through the abstraction soon.

This is a solution for #11
  • Loading branch information
elijahbenizzy committed Feb 14, 2024
1 parent 3bff5e6 commit 784d6dd
Show file tree
Hide file tree
Showing 3 changed files with 151 additions and 27 deletions.
63 changes: 46 additions & 17 deletions burr/core/action.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,10 @@ def name(self) -> str:
across your application."""
return self._name

@property
def single_step(self) -> bool:
return False

def __repr__(self):
read_repr = ", ".join(self.reads) if self.reads else "{}"
write_repr = ", ".join(self.writes) if self.writes else "{}"
Expand Down Expand Up @@ -218,7 +222,45 @@ def writes(self) -> list[str]:
return []


class FunctionBasedAction(Action):
class SingleStepAction(Action, abc.ABC):
"""Internal representation of a "single-step" action. While most actions will have
a run and an update, this is a convenience class for actions that return them both at the same time.
Note this is not user-facing, as the internal API is meant to change. This is largely special-cased
for the function-based action, which users will not be extending.
Currently this keeps a cache of the state created, which is not ideal. This is a temporary
measure to make the API work, and will be removed in the future.
"""

def __init__(self):
super(SingleStepAction, self).__init__()
self._state_created = None

@property
def single_step(self) -> bool:
return True

@abc.abstractmethod
def run_and_update(self, state: State) -> Tuple[dict, State]:
"""Performs a run/update at the same time.
:param state:
:return:
"""
pass

def run(self, state: State) -> dict:
result, new_state = self.run_and_update(state)
self._state_created = new_state
return result

def update(self, result: dict, state: State) -> State:
if self._state_created is None:
raise ValueError("SingleStepAction.run must be called before SingleStepAction.update")
return self._state_created


class FunctionBasedAction(SingleStepAction):
ACTION_FUNCTION = "action_function"

def __init__(
Expand All @@ -239,7 +281,6 @@ def __init__(
self._fn = fn
self._reads = reads
self._writes = writes
self._state_created = None
self._bound_params = bound_params if bound_params is not None else {}

@property
Expand All @@ -250,25 +291,10 @@ def fn(self) -> Callable:
def reads(self) -> list[str]:
return self._reads

def run(self, state: State) -> dict:
result, new_state = self._fn(state, **self._bound_params)
self._state_created = new_state
return result

@property
def writes(self) -> list[str]:
return self._writes

def update(self, result: dict, state: State) -> State:
if self._state_created is None:
raise ValueError(
"FunctionBasedAction.run must be called before FunctionBasedAction.update"
)
# TODO -- validate that all the keys are contained -- fix up subset to handle this
# TODO -- validate that we've (a) written only to the write ones (by diffing the read ones),
# and (b) written to no more than the write ones
return self._state_created.subset(*self._writes)

def with_params(self, **kwargs: Any) -> "FunctionBasedAction":
"""Binds parameters to the function.
Note that there is no reason to call this by the user. This *could*
Expand All @@ -282,6 +308,9 @@ def with_params(self, **kwargs: Any) -> "FunctionBasedAction":
new_action._bound_params = {**self._bound_params, **kwargs}
return new_action

def run_and_update(self, state: State) -> Tuple[dict, State]:
return self._fn(state, **self._bound_params)


def _validate_action_function(fn: Callable):
"""Validates that an action has the signature: (state: State) -> Tuple[dict, State]
Expand Down
49 changes: 44 additions & 5 deletions burr/core/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,15 @@
Union,
)

from burr.core.action import Action, Condition, Function, Reducer, create_action, default
from burr.core.action import (
Action,
Condition,
Function,
Reducer,
SingleStepAction,
create_action,
default,
)
from burr.core.state import State
from burr.lifecycle.base import LifecycleAdapter
from burr.lifecycle.internal import LifecycleAdapterSet
Expand Down Expand Up @@ -77,6 +85,7 @@ def _run_reducer(reducer: Reducer, state: State, result: dict, name: str) -> Sta
:param result:
:return:
"""

state_to_use = state.subset(*reducer.writes)
new_state = reducer.update(result, state_to_use).subset(*reducer.writes)
keys_in_new_state = set(new_state.keys())
Expand Down Expand Up @@ -123,6 +132,30 @@ def _format_error_message(action: Action, input_state: State) -> str:
logger.exception("\n" + border + "\n" + message + "\n" + border)


def _run_single_step_action(action: SingleStepAction, state: State) -> Tuple[dict, State]:
"""Runs a single step action. This API is internal-facing and a bit in flux, but
it corresponds to the SingleStepAction class.
:param action: Action to run
:param state: State to run with
:return: The result of running the action, and the new state
"""
state_to_use = state.subset(
*action.reads, *action.writes
) # TODO -- specify some as required and some as not
result, new_state = action.run_and_update(state_to_use)
return result, state.merge(new_state.subset(*action.writes)) # we just want the writes action


async def _arun_single_step_action(action: SingleStepAction, state: State) -> Tuple[dict, State]:
"""Runs a single step action in async. See the synchronous version for more details."""
state_to_use = state.subset(
*action.reads, *action.writes
) # TODO -- specify some as required and some as not
result, new_state = await action.run_and_update(state_to_use)
return result, state.merge(new_state.subset(*action.writes)) # we just want the writes action


class Application:
def __init__(
self,
Expand Down Expand Up @@ -162,8 +195,11 @@ def step(self) -> Optional[Tuple[Action, dict, State]]:
result = None
new_state = self._state
try:
result = _run_function(next_action, self._state)
new_state = _run_reducer(next_action, self._state, result, next_action.name)
if next_action.single_step:
result, new_state = _run_single_step_action(next_action, self._state)
else:
result = _run_function(next_action, self._state)
new_state = _run_reducer(next_action, self._state, result, next_action.name)
self._set_state(new_state)
except Exception as e:
exc = e
Expand Down Expand Up @@ -197,8 +233,11 @@ async def astep(self) -> Optional[Tuple[Action, dict, State]]:
# which this is supposed to be its OK).
# this delegatees hooks to the synchronous version, so we'll call all of them as well
return self.step()
result = await _arun_function(next_action, self._state)
new_state = _run_reducer(next_action, self._state, result, next_action.name)
if next_action.single_step:
result, new_state = await _arun_single_step_action(next_action, self._state)
else:
result = await _arun_function(next_action, self._state)
new_state = _run_reducer(next_action, self._state, result, next_action.name)
except Exception as e:
exc = e
logger.exception(_format_error_message(next_action, self._state))
Expand Down
66 changes: 61 additions & 5 deletions tests/core/test_application.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,20 @@
import asyncio
import logging
from typing import Awaitable, Callable
from typing import Awaitable, Callable, Tuple

import pytest

from burr.core import State
from burr.core.action import Action, Condition, Result, default
from burr.core.action import Action, Condition, Result, SingleStepAction, default
from burr.core.application import (
Application,
ApplicationBuilder,
Transition,
_arun_function,
_arun_single_step_action,
_assert_set,
_run_function,
_run_single_step_action,
_validate_actions,
_validate_start,
_validate_transitions,
Expand Down Expand Up @@ -110,30 +112,84 @@ class BrokenStepException(Exception):
)


def test_run_function():
def test__run_function():
"""Tests that we can run a function"""
action = base_counter_action
state = State({})
result = _run_function(action, state)
assert result == {"counter": 1}


def test_run_function_cant_run_async():
def test__run_function_cant_run_async():
"""Tests that we can't run an async function"""
action = base_counter_action_async
state = State({})
with pytest.raises(ValueError, match="async"):
_run_function(action, state)


async def test_a_run_function():
async def test__a_run_function():
"""Tests that we can run an async function"""
action = base_counter_action_async
state = State({})
result = await _arun_function(action, state)
assert result == {"counter": 1}


class SingleStepCounter(SingleStepAction):
def run_and_update(self, state: State) -> Tuple[dict, State]:
result = {"count": state["count"] + 1}
return result, state.update(**result).append(tracker=result["count"])

@property
def reads(self) -> list[str]:
return ["count"]

@property
def writes(self) -> list[str]:
return ["count", "tracker"]


class SingleStepCounterAsync(SingleStepCounter):
async def run_and_update(self, state: State) -> Tuple[dict, State]:
await asyncio.sleep(0.0001) # just so we can make this *truly* async
return super(SingleStepCounterAsync, self).run_and_update(state)

@property
def reads(self) -> list[str]:
return ["count"]

@property
def writes(self) -> list[str]:
return ["count", "tracker"]


base_single_step_counter = SingleStepCounter()
base_single_step_counter_async = SingleStepCounterAsync()


def test__run_single_step_action():
action = base_single_step_counter.with_name("counter")
state = State({"count": 0, "tracker": []})
result, state = _run_single_step_action(action, state)
assert result == {"count": 1}
assert state.subset("count", "tracker").get_all() == {"count": 1, "tracker": [1]}
result, state = _run_single_step_action(action, state)
assert result == {"count": 2}
assert state.subset("count", "tracker").get_all() == {"count": 2, "tracker": [1, 2]}


async def test__arun_single_step_action():
action = base_single_step_counter_async.with_name("counter")
state = State({"count": 0, "tracker": []})
result, state = await _arun_single_step_action(action, state)
assert result == {"count": 1}
assert state.subset("count", "tracker").get_all() == {"count": 1, "tracker": [1]}
result, state = await _arun_single_step_action(action, state)
assert result == {"count": 2}
assert state.subset("count", "tracker").get_all() == {"count": 2, "tracker": [1, 2]}


def test_app_step():
"""Tests that we can run a step in an app"""
counter_action = base_counter_action.with_name("counter")
Expand Down

0 comments on commit 784d6dd

Please sign in to comment.