Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

(feat) Add trajectory replay for headless mode #6215

Merged
merged 9 commits into from
Jan 18, 2025
5 changes: 5 additions & 0 deletions config.template.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,11 @@ workspace_base = "./workspace"
# If it's a folder, the session id will be used as the file name
#save_trajectory_path="./trajectories"

# Path to replay a trajectory, must be a file path
# If provided, trajectory will be loaded and replayed before the
# agent responds to any user instruction
#replay_trajectory_path = ""

# File store path
#file_store_path = "/tmp/file_store"

Expand Down
5 changes: 5 additions & 0 deletions docs/modules/usage/configuration-options.md
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,11 @@ The core configuration options are defined in the `[core]` section of the `confi
- Default: `"./trajectories"`
- Description: Path to store trajectories (can be a folder or a file). If it's a folder, the trajectories will be saved in a file named with the session id name and .json extension, in that folder.

- `replay_trajectory_path`
- Type: `str`
- Default: `""`
- Description: Path to load a trajectory and replay. If given, must be a path to the trajectory file in JSON format. The actions in the trajectory file would be replayed first before any user instruction is executed.

### File Store
- `file_store_path`
- Type: `str`
Expand Down
87 changes: 53 additions & 34 deletions openhands/controller/agent_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
)

from openhands.controller.agent import Agent
from openhands.controller.replay import ReplayManager
from openhands.controller.state.state import State, TrafficControlState
from openhands.controller.stuck import StuckDetector
from openhands.core.config import AgentConfig, LLMConfig
Expand Down Expand Up @@ -90,6 +91,7 @@ def __init__(
is_delegate: bool = False,
headless_mode: bool = True,
status_callback: Callable | None = None,
replay_events: list[Event] | None = None,
):
"""Initializes a new instance of the AgentController class.

Expand All @@ -108,6 +110,7 @@ def __init__(
is_delegate: Whether this controller is a delegate.
headless_mode: Whether the agent is run in headless mode.
status_callback: Optional callback function to handle status updates.
replay_events: A list of logs to replay.
"""
self.id = sid
self.agent = agent
Expand Down Expand Up @@ -139,6 +142,9 @@ def __init__(
self._stuck_detector = StuckDetector(self.state)
self.status_callback = status_callback

# replay-related
self._replay_manager = ReplayManager(replay_events)

async def close(self) -> None:
"""Closes the agent controller, canceling any ongoing tasks and unsubscribing from the event stream.

Expand Down Expand Up @@ -234,6 +240,11 @@ async def _step_with_exception_handling(self):
await self._react_to_exception(reported)

def should_step(self, event: Event) -> bool:
"""
Whether the agent should take a step based on an event. In general,
the agent should take a step if it receives a message from the user,
or observes something in the environment (after acting).
"""
# it might be the delegate's day in the sun
if self.delegate is not None:
return False
Expand Down Expand Up @@ -641,42 +652,50 @@ async def _step(self) -> None:

self.update_state_before_step()
action: Action = NullAction()
try:
action = self.agent.step(self.state)
if action is None:
raise LLMNoActionError('No action was returned')
except (
LLMMalformedActionError,
LLMNoActionError,
LLMResponseError,
FunctionCallValidationError,
FunctionCallNotExistsError,
) as e:
self.event_stream.add_event(
ErrorObservation(
content=str(e),
),
EventSource.AGENT,
)
return
except (ContextWindowExceededError, BadRequestError) as e:
# FIXME: this is a hack until a litellm fix is confirmed
# Check if this is a nested context window error
error_str = str(e).lower()
if (
'contextwindowexceedederror' in error_str
or 'prompt is too long' in error_str
or isinstance(e, ContextWindowExceededError)
):
# When context window is exceeded, keep roughly half of agent interactions
self.state.history = self._apply_conversation_window(self.state.history)

# Save the ID of the first event in our truncated history for future reloading
if self.state.history:
self.state.start_id = self.state.history[0].id
# Don't add error event - let the agent retry with reduced context
if self._replay_manager.should_replay():
# in replay mode, we don't let the agent to proceed
# instead, we replay the action from the replay trajectory
action = self._replay_manager.step()
else:
try:
action = self.agent.step(self.state)
if action is None:
raise LLMNoActionError('No action was returned')
except (
LLMMalformedActionError,
LLMNoActionError,
LLMResponseError,
FunctionCallValidationError,
FunctionCallNotExistsError,
) as e:
self.event_stream.add_event(
ErrorObservation(
content=str(e),
),
EventSource.AGENT,
)
return
raise
except (ContextWindowExceededError, BadRequestError) as e:
# FIXME: this is a hack until a litellm fix is confirmed
# Check if this is a nested context window error
error_str = str(e).lower()
if (
'contextwindowexceedederror' in error_str
or 'prompt is too long' in error_str
or isinstance(e, ContextWindowExceededError)
):
# When context window is exceeded, keep roughly half of agent interactions
self.state.history = self._apply_conversation_window(
self.state.history
)

# Save the ID of the first event in our truncated history for future reloading
if self.state.history:
self.state.start_id = self.state.history[0].id
# Don't add error event - let the agent retry with reduced context
return
raise

if action.runnable:
if self.state.confirmation_mode and (
Expand Down
52 changes: 52 additions & 0 deletions openhands/controller/replay.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
from openhands.core.logger import openhands_logger as logger
from openhands.events.action.action import Action
from openhands.events.event import Event, EventSource


class ReplayManager:
"""ReplayManager manages the lifecycle of a replay session of a given trajectory.

Replay manager keeps track of a list of events, replays actions, and ignore
messages and observations. It could lead to unexpected or even errorneous
results if any action is non-deterministic, or if the initial state before
the replay session is different from the initial state of the trajectory.
"""

def __init__(self, replay_events: list[Event] | None):
if replay_events:
logger.info(f'Replay logs loaded, events length = {len(replay_events)}')
self.replay_events = replay_events
self.replay_mode = bool(replay_events)
self.replay_index = 0

def _replayable(self) -> bool:
return (
self.replay_events is not None
and self.replay_index < len(self.replay_events)
and isinstance(self.replay_events[self.replay_index], Action)
and self.replay_events[self.replay_index].source != EventSource.USER
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah I see! This has to work just fine... I do wonder though, what could break it?

  • delegation? because in that case, the controller itself handles AgentDelegateAction and creates a MessageAction for the new guy, and puts in the stream; the trajectory must have saved the old one too. I don't think it needs to be in scope of this PR, though.
  • MessageActions with source user, which happened after the initial task?

Copy link
Collaborator Author

@li-boxuan li-boxuan Jan 17, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I bet delegation won't work; I am not even sure if trajectory properly records delegation.

MessageActions with source user, which happened after the initial task?

yeah that is just not possible with headless mode, but it would be interesting when we enable this in GUI mode as well... I'd love to have that functionality working, so that people can just upload trajectory to replay some recorded events first and then start working with agents.

Oh actually headless mode does allow interactive inputs, that would be interesting to test out as the next step.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I bet delegation won't work; I am not even sure if trajectory properly records delegation.

I did think of it here! Though I didn't test this use case after the last changes (and for a long time really).

Copy link
Collaborator Author

@li-boxuan li-boxuan Jan 17, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

MessageActions with source user, which happened after the initial task?

Tested in headless mode and it just worked!

Steps 0-5 were from replay

image

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So it continues normally. But it can't replay the new one, together with that user message? 😅

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I mean, at this end of the run in your image, the event stream has all events, those first 5 steps plus 3 other steps. The agent history has them too. If we save trajectory now, traj2.json should contain all of them, including the user message in the middle. Can we replay this traj2.json?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💭 no, because "wait_for_response": true from agent message would trigger a AWAITING_USER_INPUT state.

A workaround is to manually fix the trajectory and change AWAITING_USER_INPUT to false. And that works!

A hack in the code is to somehow not AWAITING_USER_INPUT if there's a next action to replay. (Not implemented)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree, I wouldn't hack this. Maybe a next iteration of this feature, when we accept user messages, could take care of it because the user message should (?) change the agent state.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah it needs some design in the next iteration. I am thinking that the agent state space should be a subset under replay mode. AWAITING_USER_INPUT is not a valid state during replay; it's only valid after a replay (or equivalently, without replay).

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That makes sense too!

)

def should_replay(self) -> bool:
"""
Whether the controller is in trajectory replay mode, and the replay
hasn't finished. Note: after the replay is finished, the user and
the agent could continue to message/act.

This method also moves "replay_index" to the next action, if applicable.
"""
if not self.replay_mode:
return False

assert self.replay_events is not None
while self.replay_index < len(self.replay_events) and not self._replayable():
self.replay_index += 1

return self._replayable()

def step(self) -> Action:
assert self.replay_events is not None
event = self.replay_events[self.replay_index]
assert isinstance(event, Action)
enyst marked this conversation as resolved.
Show resolved Hide resolved
self.replay_index += 1
return event
2 changes: 2 additions & 0 deletions openhands/core/config/app_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ class AppConfig:
file_store: Type of file store to use.
file_store_path: Path to the file store.
save_trajectory_path: Either a folder path to store trajectories with auto-generated filenames, or a designated trajectory file path.
replay_trajectory_path: Path to load trajectory and replay. If provided, trajectory would be replayed first before user's instruction.
workspace_base: Base path for the workspace. Defaults to `./workspace` as absolute path.
workspace_mount_path: Path to mount the workspace. Defaults to `workspace_base`.
workspace_mount_path_in_sandbox: Path to mount the workspace in sandbox. Defaults to `/workspace`.
Expand Down Expand Up @@ -55,6 +56,7 @@ class AppConfig:
file_store: str = 'local'
file_store_path: str = '/tmp/openhands_file_store'
save_trajectory_path: str | None = None
replay_trajectory_path: str | None = None
workspace_base: str | None = None
workspace_mount_path: str | None = None
workspace_mount_path_in_sandbox: str = '/workspace'
Expand Down
65 changes: 60 additions & 5 deletions openhands/core/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import json
import os
import sys
from pathlib import Path
from typing import Callable, Protocol

import openhands.agenthub # noqa F401 (we import this to get the agents registered)
Expand All @@ -22,10 +23,11 @@
generate_sid,
)
from openhands.events import EventSource, EventStreamSubscriber
from openhands.events.action import MessageAction
from openhands.events.action import MessageAction, NullAction
from openhands.events.action.action import Action
from openhands.events.event import Event
from openhands.events.observation import AgentStateChangedObservation
from openhands.events.serialization import event_from_dict
from openhands.events.serialization.event import event_to_trajectory
from openhands.runtime.base import Runtime

Expand Down Expand Up @@ -101,7 +103,17 @@ async def run_controller(
if agent is None:
agent = create_agent(runtime, config)

controller, initial_state = create_controller(agent, runtime, config)
replay_events: list[Event] | None = None
if config.replay_trajectory_path:
logger.info('Trajectory replay is enabled')
assert isinstance(initial_user_action, NullAction)
replay_events, initial_user_action = load_replay_log(
xingyaoww marked this conversation as resolved.
Show resolved Hide resolved
config.replay_trajectory_path
)

controller, initial_state = create_controller(
agent, runtime, config, replay_events=replay_events
)

assert isinstance(
initial_user_action, Action
Expand Down Expand Up @@ -194,21 +206,64 @@ def auto_continue_response(
return message


def load_replay_log(trajectory_path: str) -> tuple[list[Event] | None, Action]:
"""
Load trajectory from given path, serialize it to a list of events, and return
two things:
1) A list of events except the first action
2) First action (user message, a.k.a. initial task)
"""
try:
path = Path(trajectory_path).resolve()

if not path.exists():
raise ValueError(f'Trajectory file not found: {path}')

if not path.is_file():
raise ValueError(f'Trajectory path is a directory, not a file: {path}')

with open(path, 'r', encoding='utf-8') as file:
data = json.load(file)
if not isinstance(data, list):
raise ValueError(
f'Expected a list in {path}, got {type(data).__name__}'
)
events = []
for item in data:
event = event_from_dict(item)
# cannot add an event with _id to event stream
event._id = None # type: ignore[attr-defined]
events.append(event)
assert isinstance(events[0], MessageAction)
return events[1:], events[0]
except json.JSONDecodeError as e:
raise ValueError(f'Invalid JSON format in {trajectory_path}: {e}')


if __name__ == '__main__':
args = parse_arguments()

config = setup_config_from_args(args)

# Determine the task
task_str = ''
if args.file:
task_str = read_task_from_file(args.file)
elif args.task:
task_str = args.task
elif not sys.stdin.isatty():
task_str = read_task_from_stdin()

initial_user_action: Action = NullAction()
if config.replay_trajectory_path:
if task_str:
raise ValueError(
'User-specified task is not supported under trajectory replay mode'
)
xingyaoww marked this conversation as resolved.
Show resolved Hide resolved
elif task_str:
initial_user_action = MessageAction(content=task_str)
else:
raise ValueError('No task provided. Please specify a task through -t, -f.')
initial_user_action: MessageAction = MessageAction(content=task_str)

config = setup_config_from_args(args)

# Set session name
session_name = args.name
Expand Down
8 changes: 7 additions & 1 deletion openhands/core/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
)
from openhands.core.logger import openhands_logger as logger
from openhands.events import EventStream
from openhands.events.event import Event
from openhands.llm.llm import LLM
from openhands.runtime import get_runtime_cls
from openhands.runtime.base import Runtime
Expand Down Expand Up @@ -78,7 +79,11 @@ def create_agent(runtime: Runtime, config: AppConfig) -> Agent:


def create_controller(
agent: Agent, runtime: Runtime, config: AppConfig, headless_mode: bool = True
agent: Agent,
runtime: Runtime,
config: AppConfig,
headless_mode: bool = True,
replay_events: list[Event] | None = None,
) -> Tuple[AgentController, State | None]:
event_stream = runtime.event_stream
initial_state = None
Expand All @@ -101,6 +106,7 @@ def create_controller(
initial_state=initial_state,
headless_mode=headless_mode,
confirmation_mode=config.security.confirmation_mode,
replay_events=replay_events,
)
return (controller, initial_state)

Expand Down
4 changes: 3 additions & 1 deletion openhands/events/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ class FileReadSource(str, Enum):

@dataclass
class Event:
INVALID_ID = -1

@property
def message(self) -> str | None:
if hasattr(self, '_message'):
Expand All @@ -34,7 +36,7 @@ def message(self) -> str | None:
def id(self) -> int:
if hasattr(self, '_id'):
return self._id # type: ignore[attr-defined]
return -1
return Event.INVALID_ID
li-boxuan marked this conversation as resolved.
Show resolved Hide resolved

@property
def timestamp(self):
Expand Down
2 changes: 1 addition & 1 deletion openhands/events/observation/browse.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ class BrowserOutputObservation(Observation):

url: str
trigger_by_action: str
screenshot: str = field(repr=False) # don't show in repr
screenshot: str = field(repr=False, default='') # don't show in repr
error: bool = False
observation: str = ObservationType.BROWSE
# do not include in the memory
Expand Down
Loading