Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds helpful error message for action failure #13

Merged
merged 1 commit into from
Feb 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions burr/core/application.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import collections
import dataclasses
import logging
import pprint
from typing import (
Any,
AsyncGenerator,
Expand Down Expand Up @@ -88,6 +89,40 @@ def _run_reducer(reducer: Reducer, state: State, result: dict, name: str) -> Sta
return state.merge(new_state.update(**{PRIOR_STEP: name}))


def _create_dict_string(kwargs: dict) -> str:
"""This is a utility function to create a string representation of a dict.
This is the state that was passed into the function usually. This is useful for debugging,
as it can be printed out to see what the state was.

:param kwargs: The inputs to the function that errored.
:return: The string representation of the inputs, truncated appropriately.
"""
pp = pprint.PrettyPrinter(width=80)
inputs = {}
for k, v in kwargs.items():
item_repr = repr(v)
if len(item_repr) > 50:
item_repr = item_repr[:50] + "..."
else:
item_repr = v
inputs[k] = item_repr
input_string = pp.pformat(inputs)
if len(input_string) > 1000:
input_string = input_string[:1000] + "..."
return input_string


def _format_error_message(action: Action, input_state: State) -> str:
"""Formats the error string, given that we're inside an action"""
message = f"> Action: {action.name} encountered an error!"
padding = " " * (80 - len(message) - 1)
message += padding + "<"
input_string = _create_dict_string(input_state.get_all())
message += "\n> State (at time of action):\n" + input_string
border = "*" * 80
logger.exception("\n" + border + "\n" + message + "\n" + border)


class Application:
def __init__(
self,
Expand Down Expand Up @@ -132,6 +167,7 @@ def step(self) -> Optional[Tuple[Action, dict, State]]:
self._set_state(new_state)
except Exception as e:
exc = e
logger.exception(_format_error_message(next_action, self._state))
raise e
finally:
self._adapter_set.call_all_lifecycle_hooks_sync(
Expand Down Expand Up @@ -165,6 +201,7 @@ async def astep(self) -> Optional[Tuple[Action, dict, State]]:
new_state = _run_reducer(next_action, self._state, result, next_action.name)
except Exception as e:
exc = e
logger.exception(_format_error_message(next_action, self._state))
raise e
finally:
await self._adapter_set.call_all_lifecycle_hooks_sync_and_async(
Expand Down
50 changes: 50 additions & 0 deletions tests/core/test_application.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio
import logging
from typing import Awaitable, Callable

import pytest
Expand Down Expand Up @@ -90,6 +91,25 @@ async def _counter_update_async(state: State) -> dict:
)


class BrokenStepException(Exception):
pass


base_broken_action = PassedInAction(
reads=[],
writes=[],
fn=lambda x: exec("raise(BrokenStepException(x))"),
update_fn=lambda result, state: state,
)

base_broken_action_async = PassedInActionAsync(
reads=[],
writes=[],
fn=lambda x: exec("raise(BrokenStepException(x))"),
update_fn=lambda result, state: state,
)


def test_run_function():
"""Tests that we can run a function"""
action = base_counter_action
Expand Down Expand Up @@ -128,6 +148,21 @@ def test_app_step():
assert result == {"counter": 1}


def test_app_step_broken(caplog):
"""Tests that we can run a step in an app"""
broken_action = base_broken_action.with_name("broken_action_unique_name")
app = Application(
actions=[broken_action],
transitions=[Transition(broken_action, broken_action, default)],
state=State({}),
initial_step="broken_action_unique_name",
)
with caplog.at_level(logging.ERROR): # it should say the name, that's the only contract for now
with pytest.raises(BrokenStepException):
app.step()
assert "broken_action_unique_name" in caplog.text


def test_app_step_done():
"""Tests that when we cannot run a step, we return None"""
counter_action = base_counter_action.with_name("counter")
Expand All @@ -152,6 +187,21 @@ async def test_app_astep():
assert result == {"counter": 1}


async def test_app_astep_broken(caplog):
"""Tests that we can run a step in an app"""
broken_action = base_broken_action_async.with_name("broken_action_unique_name")
app = Application(
actions=[broken_action],
transitions=[Transition(broken_action, broken_action, default)],
state=State({}),
initial_step="broken_action_unique_name",
)
with caplog.at_level(logging.ERROR): # it should say the name, that's the only contract for now
with pytest.raises(BrokenStepException):
await app.astep()
assert "broken_action_unique_name" in caplog.text


async def test_app_astep_done():
"""Tests that when we cannot run a step, we return None"""
counter_action = base_counter_action_async.with_name("counter_async")
Expand Down
Loading