diff --git a/openhands/core/config/sandbox_config.py b/openhands/core/config/sandbox_config.py index 3a0b705dd02d..0ea40f29faab 100644 --- a/openhands/core/config/sandbox_config.py +++ b/openhands/core/config/sandbox_config.py @@ -41,7 +41,7 @@ class SandboxConfig: remote_runtime_api_url: str = 'http://localhost:8000' local_runtime_url: str = 'http://localhost' - keep_runtime_alive: bool = True + keep_runtime_alive: bool = False rm_all_containers: bool = False api_key: str | None = None base_container_image: str = 'nikolaik/python-nodejs:python3.12-nodejs22' # default to nikolaik/python-nodejs:python3.12-nodejs22 for eventstream runtime @@ -60,7 +60,7 @@ class SandboxConfig: runtime_startup_env_vars: dict[str, str] = field(default_factory=dict) browsergym_eval_env: str | None = None platform: str | None = None - close_delay: int = 900 + close_delay: int = 15 remote_runtime_resource_factor: int = 1 enable_gpu: bool = False docker_runtime_kwargs: str | None = None diff --git a/openhands/runtime/builder/remote.py b/openhands/runtime/builder/remote.py index b2e869eca3bf..a728460a374e 100644 --- a/openhands/runtime/builder/remote.py +++ b/openhands/runtime/builder/remote.py @@ -9,6 +9,7 @@ from openhands.core.logger import openhands_logger as logger from openhands.runtime.builder import RuntimeBuilder from openhands.runtime.utils.request import send_request +from openhands.utils.http_session import HttpSession from openhands.utils.shutdown_listener import ( should_continue, sleep_if_should_continue, @@ -18,12 +19,10 @@ class RemoteRuntimeBuilder(RuntimeBuilder): """This class interacts with the remote Runtime API for building and managing container images.""" - def __init__( - self, api_url: str, api_key: str, session: requests.Session | None = None - ): + def __init__(self, api_url: str, api_key: str, session: HttpSession | None = None): self.api_url = api_url self.api_key = api_key - self.session = session or requests.Session() + self.session = session or 
HttpSession() self.session.headers.update({'X-API-Key': self.api_key}) def build( diff --git a/openhands/runtime/impl/action_execution/action_execution_client.py b/openhands/runtime/impl/action_execution/action_execution_client.py index 24fb8250b30e..4965fc1752af 100644 --- a/openhands/runtime/impl/action_execution/action_execution_client.py +++ b/openhands/runtime/impl/action_execution/action_execution_client.py @@ -35,6 +35,7 @@ from openhands.runtime.base import Runtime from openhands.runtime.plugins import PluginRequirement from openhands.runtime.utils.request import send_request +from openhands.utils.http_session import HttpSession class ActionExecutionClient(Runtime): @@ -55,7 +56,7 @@ def __init__( attach_to_existing: bool = False, headless_mode: bool = True, ): - self.session = requests.Session() + self.session = HttpSession() self.action_semaphore = threading.Semaphore(1) # Ensure one action at a time self._runtime_initialized: bool = False self._vscode_token: str | None = None # initial dummy value diff --git a/openhands/runtime/utils/request.py b/openhands/runtime/utils/request.py index e05a083e7b0d..0117e019a6a8 100644 --- a/openhands/runtime/utils/request.py +++ b/openhands/runtime/utils/request.py @@ -4,6 +4,7 @@ import requests from tenacity import retry, retry_if_exception, stop_after_attempt, wait_exponential +from openhands.utils.http_session import HttpSession from openhands.utils.tenacity_stop import stop_if_should_exit @@ -34,7 +35,7 @@ def is_retryable_error(exception): wait=wait_exponential(multiplier=1, min=4, max=60), ) def send_request( - session: requests.Session, + session: HttpSession, method: str, url: str, timeout: int = 10, @@ -48,11 +49,11 @@ def send_request( _json = response.json() except (requests.exceptions.JSONDecodeError, json.decoder.JSONDecodeError): _json = None - finally: - response.close() raise RequestHTTPError( e, response=e.response, detail=_json.get('detail') if _json is not None else None, ) from e + finally: + 
response.close() return response diff --git a/openhands/server/routes/manage_conversations.py b/openhands/server/routes/manage_conversations.py index 5cfc0ba82d52..f622c5bad9cf 100644 --- a/openhands/server/routes/manage_conversations.py +++ b/openhands/server/routes/manage_conversations.py @@ -130,7 +130,7 @@ async def search_conversations( for conversation in conversation_metadata_result_set.results if hasattr(conversation, 'created_at') ) - running_conversations = await session_manager.get_agent_loop_running( + running_conversations = await session_manager.get_running_agent_loops( get_user_id(request), set(conversation_ids) ) result = ConversationInfoResultSet( diff --git a/openhands/server/session/agent_session.py b/openhands/server/session/agent_session.py index 70bf6eeca6bb..285acccbfbe4 100644 --- a/openhands/server/session/agent_session.py +++ b/openhands/server/session/agent_session.py @@ -1,4 +1,5 @@ import asyncio +import time from typing import Callable, Optional from openhands.controller import AgentController @@ -16,10 +17,10 @@ from openhands.runtime.base import Runtime from openhands.security import SecurityAnalyzer, options from openhands.storage.files import FileStore -from openhands.utils.async_utils import call_async_from_sync, call_sync_from_async +from openhands.utils.async_utils import call_sync_from_async from openhands.utils.shutdown_listener import should_continue -WAIT_TIME_BEFORE_CLOSE = 300 +WAIT_TIME_BEFORE_CLOSE = 90 WAIT_TIME_BEFORE_CLOSE_INTERVAL = 5 @@ -36,7 +37,8 @@ class AgentSession: controller: AgentController | None = None runtime: Runtime | None = None security_analyzer: SecurityAnalyzer | None = None - _initializing: bool = False + _starting: bool = False + _started_at: float = 0 _closed: bool = False loop: asyncio.AbstractEventLoop | None = None @@ -88,7 +90,8 @@ async def start( if self._closed: logger.warning('Session closed before starting') return - self._initializing = True + self._starting = True + self._started_at = 
time.time() self._create_security_analyzer(config.security.security_analyzer) await self._create_runtime( runtime_name=runtime_name, @@ -109,24 +112,19 @@ async def start( self.event_stream.add_event( ChangeAgentStateAction(AgentState.INIT), EventSource.ENVIRONMENT ) - self._initializing = False + self._starting = False - def close(self): + async def close(self): """Closes the Agent session""" if self._closed: return self._closed = True - call_async_from_sync(self._close) - - async def _close(self): - seconds_waited = 0 - while self._initializing and should_continue(): + while self._starting and should_continue(): logger.debug( f'Waiting for initialization to finish before closing session {self.sid}' ) await asyncio.sleep(WAIT_TIME_BEFORE_CLOSE_INTERVAL) - seconds_waited += WAIT_TIME_BEFORE_CLOSE_INTERVAL - if seconds_waited > WAIT_TIME_BEFORE_CLOSE: + if time.time() > self._started_at + WAIT_TIME_BEFORE_CLOSE: logger.error( f'Waited too long for initialization to finish before closing session {self.sid}' ) @@ -311,3 +309,12 @@ def _maybe_restore_state(self) -> State | None: else: logger.debug('No events found, no state to restore') return restored_state + + def get_state(self) -> AgentState | None: + controller = self.controller + if controller: + return controller.state.agent_state + if time.time() > self._started_at + WAIT_TIME_BEFORE_CLOSE: + # If WAIT_TIME_BEFORE_CLOSE seconds have elapsed and we still don't have a controller, something has gone wrong + return AgentState.ERROR + return None diff --git a/openhands/server/session/manager.py b/openhands/server/session/manager.py index 67358f61fbe8..3c4d929a72de 100644 --- a/openhands/server/session/manager.py +++ b/openhands/server/session/manager.py @@ -2,6 +2,7 @@ import json import time from dataclasses import dataclass, field +from typing import Generic, Iterable, TypeVar from uuid import uuid4 import socketio @@ -9,26 +10,28 @@ from openhands.core.config import AppConfig from openhands.core.exceptions import 
AgentRuntimeUnavailableError from openhands.core.logger import openhands_logger as logger +from openhands.core.schema.agent import AgentState from openhands.events.stream import EventStream, session_exists from openhands.server.session.conversation import Conversation from openhands.server.session.session import ROOM_KEY, Session from openhands.server.settings import Settings from openhands.storage.files import FileStore -from openhands.utils.async_utils import call_sync_from_async +from openhands.utils.async_utils import wait_all from openhands.utils.shutdown_listener import should_continue _REDIS_POLL_TIMEOUT = 1.5 _CHECK_ALIVE_INTERVAL = 15 _CLEANUP_INTERVAL = 15 -_CLEANUP_EXCEPTION_WAIT_TIME = 15 +MAX_RUNNING_CONVERSATIONS = 3 +T = TypeVar('T') @dataclass -class _SessionIsRunningCheck: - request_id: str - request_sids: list[str] - running_sids: set[str] = field(default_factory=set) +class _ClusterQuery(Generic[T]): + query_id: str + request_ids: set[str] | None + result: T flag: asyncio.Event = field(default_factory=asyncio.Event) @@ -38,10 +41,10 @@ class SessionManager: config: AppConfig file_store: FileStore _local_agent_loops_by_sid: dict[str, Session] = field(default_factory=dict) - local_connection_id_to_session_id: dict[str, str] = field(default_factory=dict) + _local_connection_id_to_session_id: dict[str, str] = field(default_factory=dict) _last_alive_timestamps: dict[str, float] = field(default_factory=dict) _redis_listen_task: asyncio.Task | None = None - _session_is_running_checks: dict[str, _SessionIsRunningCheck] = field( + _running_sid_queries: dict[str, _ClusterQuery[set[str]]] = field( default_factory=dict ) _active_conversations: dict[str, tuple[Conversation, int]] = field( @@ -52,7 +55,7 @@ class SessionManager: ) _conversations_lock: asyncio.Lock = field(default_factory=asyncio.Lock) _cleanup_task: asyncio.Task | None = None - _has_remote_connections_flags: dict[str, asyncio.Event] = field( + _connection_queries: dict[str, 
_ClusterQuery[dict[str, str]]] = field( default_factory=dict ) @@ -60,7 +63,7 @@ async def __aenter__(self): redis_client = self._get_redis_client() if redis_client: self._redis_listen_task = asyncio.create_task(self._redis_subscribe()) - self._cleanup_task = asyncio.create_task(self._cleanup_detached_conversations()) + self._cleanup_task = asyncio.create_task(self._cleanup_stale()) return self async def __aexit__(self, exc_type, exc_value, traceback): @@ -82,7 +85,7 @@ async def _redis_subscribe(self): logger.debug('_redis_subscribe') redis_client = self._get_redis_client() pubsub = redis_client.pubsub() - await pubsub.subscribe('oh_event') + await pubsub.subscribe('session_msg') while should_continue(): try: message = await pubsub.get_message( @@ -108,59 +111,71 @@ async def _process_message(self, message: dict): session = self._local_agent_loops_by_sid.get(sid) if session: await session.dispatch(data['data']) - elif message_type == 'is_session_running': + elif message_type == 'running_agent_loops_query': # Another node in the cluster is asking if the current node is running the session given. 
- request_id = data['request_id'] - sids = [ - sid for sid in data['sids'] if sid in self._local_agent_loops_by_sid - ] + query_id = data['query_id'] + sids = self._get_running_agent_loops_locally( + data.get('user_id'), data.get('filter_to_sids') + ) if sids: await self._get_redis_client().publish( - 'oh_event', + 'session_msg', json.dumps( { - 'request_id': request_id, - 'sids': sids, - 'message_type': 'session_is_running', + 'query_id': query_id, + 'sids': list(sids), + 'message_type': 'running_agent_loops_response', } ), ) - elif message_type == 'session_is_running': - request_id = data['request_id'] + elif message_type == 'running_agent_loops_response': + query_id = data['query_id'] for sid in data['sids']: self._last_alive_timestamps[sid] = time.time() - check = self._session_is_running_checks.get(request_id) - if check: - check.running_sids.update(data['sids']) - if len(check.request_sids) == len(check.running_sids): - check.flag.set() - elif message_type == 'has_remote_connections_query': + running_query = self._running_sid_queries.get(query_id) + if running_query: + running_query.result.update(data['sids']) + if running_query.request_ids is not None and len( + running_query.request_ids + ) == len(running_query.result): + running_query.flag.set() + elif message_type == 'connections_query': # Another node in the cluster is asking if the current node is connected to a session - sid = data['sid'] - required = sid in self.local_connection_id_to_session_id.values() - if required: + query_id = data['query_id'] + connections = self._get_connections_locally( + data.get('user_id'), data.get('filter_to_sids') + ) + if connections: await self._get_redis_client().publish( - 'oh_event', + 'session_msg', json.dumps( - {'sid': sid, 'message_type': 'has_remote_connections_response'} + { + 'query_id': query_id, + 'connections': connections, + 'message_type': 'connections_response', + } ), ) - elif message_type == 'has_remote_connections_response': - sid = data['sid'] - flag 
= self._has_remote_connections_flags.get(sid) - if flag: - flag.set() + elif message_type == 'connections_response': + query_id = data['query_id'] + connection_query = self._connection_queries.get(query_id) + if connection_query: + connection_query.result.update(**data['connections']) + if connection_query.request_ids is not None and len( + connection_query.request_ids + ) == len(connection_query.result): + connection_query.flag.set() elif message_type == 'close_session': sid = data['sid'] if sid in self._local_agent_loops_by_sid: - await self._on_close_session(sid) + await self._close_session(sid) elif message_type == 'session_closing': # Session closing event - We only get this in the event of graceful shutdown, # which can't be guaranteed - nodes can simply vanish unexpectedly! sid = data['sid'] logger.debug(f'session_closing:{sid}') # Create a list of items to process to avoid modifying dict during iteration - items = list(self.local_connection_id_to_session_id.items()) + items = list(self._local_connection_id_to_session_id.items()) for connection_id, local_sid in items: if sid == local_sid: logger.warning( @@ -208,7 +223,7 @@ async def join_conversation( ): logger.info(f'join_conversation:{sid}:{connection_id}') await self.sio.enter_room(connection_id, ROOM_KEY.format(sid=sid)) - self.local_connection_id_to_session_id[connection_id] = sid + self._local_connection_id_to_session_id[connection_id] = sid event_stream = await self._get_event_stream(sid) if not event_stream: return await self.maybe_start_agent_loop(sid, settings, user_id) @@ -226,7 +241,7 @@ async def detach_from_conversation(self, conversation: Conversation): self._active_conversations.pop(sid) self._detached_conversations[sid] = (conversation, time.time()) - async def _cleanup_detached_conversations(self): + async def _cleanup_stale(self): while should_continue(): if self._get_redis_client(): # Debug info for HA envs @@ -240,7 +255,7 @@ async def _cleanup_detached_conversations(self): f'Running 
agent loops: {len(self._local_agent_loops_by_sid)}' ) logger.info( - f'Local connections: {len(self.local_connection_id_to_session_id)}' + f'Local connections: {len(self._local_connection_id_to_session_id)}' ) try: async with self._conversations_lock: @@ -250,97 +265,176 @@ async def _cleanup_detached_conversations(self): await conversation.disconnect() self._detached_conversations.pop(sid, None) + close_threshold = time.time() - self.config.sandbox.close_delay + running_loops = list(self._local_agent_loops_by_sid.items()) + running_loops.sort(key=lambda item: item[1].last_active_ts) + sid_to_close: list[str] = [] + for sid, session in running_loops: + state = session.agent_session.get_state() + if session.last_active_ts < close_threshold and state not in [ + AgentState.RUNNING, + None, + ]: + sid_to_close.append(sid) + + connections = self._get_connections_locally( + filter_to_sids=set(sid_to_close) + ) + connected_sids = {sid for _, sid in connections.items()} + sid_to_close = [ + sid for sid in sid_to_close if sid not in connected_sids + ] + + if sid_to_close: + connections = await self._get_connections_remotely( + filter_to_sids=set(sid_to_close) + ) + connected_sids = {sid for _, sid in connections.items()} + sid_to_close = [ + sid for sid in sid_to_close if sid not in connected_sids + ] + + await wait_all(self._close_session(sid) for sid in sid_to_close) await asyncio.sleep(_CLEANUP_INTERVAL) except asyncio.CancelledError: async with self._conversations_lock: for conversation, _ in self._detached_conversations.values(): await conversation.disconnect() self._detached_conversations.clear() + await wait_all( + self._close_session(sid) for sid in self._local_agent_loops_by_sid + ) return except Exception as e: - logger.warning(f'error_cleaning_detached_conversations: {str(e)}') - await asyncio.sleep(_CLEANUP_EXCEPTION_WAIT_TIME) - - async def get_agent_loop_running(self, user_id, sids: set[str]) -> set[str]: - running_sids = set(sid for sid in sids if sid in 
self._local_agent_loops_by_sid) - check_cluster_sids = [sid for sid in sids if sid not in running_sids] - running_cluster_sids = await self.get_agent_loop_running_in_cluster( - check_cluster_sids - ) - running_sids.union(running_cluster_sids) - return running_sids + logger.warning(f'error_cleaning_stale: {str(e)}') + await asyncio.sleep(_CLEANUP_INTERVAL) async def is_agent_loop_running(self, sid: str) -> bool: - if await self.is_agent_loop_running_locally(sid): - return True - if await self.is_agent_loop_running_in_cluster(sid): - return True - return False - - async def is_agent_loop_running_locally(self, sid: str) -> bool: - return sid in self._local_agent_loops_by_sid - - async def is_agent_loop_running_in_cluster(self, sid: str) -> bool: - running_sids = await self.get_agent_loop_running_in_cluster([sid]) - return bool(running_sids) - - async def get_agent_loop_running_in_cluster(self, sids: list[str]) -> set[str]: + sids = await self.get_running_agent_loops(filter_to_sids={sid}) + return bool(sids) + + async def get_running_agent_loops( + self, user_id: str | None = None, filter_to_sids: set[str] | None = None + ) -> set[str]: + """Get the running session ids. If a user is supplied, then the results are limited to session ids for that user. 
If a set of filter_to_sids is supplied, then results are limited to these ids of interest.""" + sids = self._get_running_agent_loops_locally(user_id, filter_to_sids) + remote_sids = await self._get_running_agent_loops_remotely( + user_id, filter_to_sids + ) + return sids.union(remote_sids) + + def _get_running_agent_loops_locally( + self, user_id: str | None = None, filter_to_sids: set[str] | None = None + ) -> set[str]: + items: Iterable[tuple[str, Session]] = self._local_agent_loops_by_sid.items() + if filter_to_sids is not None: + items = (item for item in items if item[0] in filter_to_sids) + if user_id: + items = (item for item in items if item[1].user_id == user_id) + sids = {sid for sid, _ in items} + return sids + + async def _get_running_agent_loops_remotely( + self, + user_id: str | None = None, + filter_to_sids: set[str] | None = None, + ) -> set[str]: """As the rest of the cluster if a session is running. Wait a for a short timeout for a reply""" redis_client = self._get_redis_client() if not redis_client: return set() flag = asyncio.Event() - request_id = str(uuid4()) - check = _SessionIsRunningCheck(request_id=request_id, request_sids=sids) - self._session_is_running_checks[request_id] = check + query_id = str(uuid4()) + query = _ClusterQuery[set[str]]( + query_id=query_id, request_ids=filter_to_sids, result=set() + ) + self._running_sid_queries[query_id] = query try: - logger.debug(f'publish:is_session_running:{sids}') - await redis_client.publish( - 'oh_event', - json.dumps( - { - 'request_id': request_id, - 'sids': sids, - 'message_type': 'is_session_running', - } - ), + logger.debug( + f'publish:_get_running_agent_loops_remotely_query:{user_id}:{filter_to_sids}' ) + data: dict = { + 'query_id': query_id, + 'message_type': 'running_agent_loops_query', + } + if user_id: + data['user_id'] = user_id + if filter_to_sids: + data['filter_to_sids'] = list(filter_to_sids) + await redis_client.publish('session_msg', json.dumps(data)) async with 
asyncio.timeout(_REDIS_POLL_TIMEOUT): await flag.wait() - return check.running_sids + return query.result except TimeoutError: # Nobody replied in time - return check.running_sids + return query.result finally: - self._session_is_running_checks.pop(request_id, None) + self._running_sid_queries.pop(query_id, None) + + async def get_connections( + self, user_id: str | None = None, filter_to_sids: set[str] | None = None + ) -> dict[str, str]: + connection_ids = self._get_connections_locally(user_id, filter_to_sids) + remote_connection_ids = await self._get_connections_remotely( + user_id, filter_to_sids + ) + connection_ids.update(**remote_connection_ids) + return connection_ids + + def _get_connections_locally( + self, user_id: str | None = None, filter_to_sids: set[str] | None = None + ) -> dict[str, str]: + connections = dict(**self._local_connection_id_to_session_id) + if filter_to_sids is not None: + connections = { + connection_id: sid + for connection_id, sid in connections.items() + if sid in filter_to_sids + } + if user_id: + for connection_id, sid in list(connections.items()): + session = self._local_agent_loops_by_sid.get(sid) + if not session or session.user_id != user_id: + connections.pop(connection_id) + return connections + + async def _get_connections_remotely( + self, user_id: str | None = None, filter_to_sids: set[str] | None = None + ) -> dict[str, str]: + redis_client = self._get_redis_client() + if not redis_client: + return {} - async def _has_remote_connections(self, sid: str) -> bool: - """As the rest of the cluster if they still want this session running. 
Wait a for a short timeout for a reply""" - # Create a flag for the callback flag = asyncio.Event() - self._has_remote_connections_flags[sid] = flag + query_id = str(uuid4()) + query = _ClusterQuery[dict[str, str]]( + query_id=query_id, request_ids=filter_to_sids, result={} + ) + self._connection_queries[query_id] = query try: - await self._get_redis_client().publish( - 'oh_event', - json.dumps( - { - 'sid': sid, - 'message_type': 'has_remote_connections_query', - } - ), + logger.debug( + f'publish:get_connections_remotely_query:{user_id}:{filter_to_sids}' ) + data: dict = { + 'query_id': query_id, + 'message_type': 'connections_query', + } + if user_id: + data['user_id'] = user_id + if filter_to_sids: + data['filter_to_sids'] = list(filter_to_sids) + await redis_client.publish('session_msg', json.dumps(data)) async with asyncio.timeout(_REDIS_POLL_TIMEOUT): await flag.wait() - result = flag.is_set() - return result + return query.result except TimeoutError: # Nobody replied in time - return False + return query.result finally: - self._has_remote_connections_flags.pop(sid, None) + self._connection_queries.pop(query_id, None) async def maybe_start_agent_loop( self, sid: str, settings: Settings, user_id: str | None @@ -349,8 +443,18 @@ async def maybe_start_agent_loop( session: Session | None = None if not await self.is_agent_loop_running(sid): logger.info(f'start_agent_loop:{sid}') + + response_ids = await self.get_running_agent_loops(user_id) + if len(response_ids) >= MAX_RUNNING_CONVERSATIONS: + logger.info(f'too_many_sessions_for:{user_id}') + await self.close_session(next(iter(response_ids))) + session = Session( - sid=sid, file_store=self.file_store, config=self.config, sio=self.sio + sid=sid, + file_store=self.file_store, + config=self.config, + sio=self.sio, + user_id=user_id, ) self._local_agent_loops_by_sid[sid] = session asyncio.create_task(session.initialize_agent(settings)) @@ -359,7 +463,6 @@ async def maybe_start_agent_loop( if not event_stream: 
logger.error(f'No event stream after starting agent loop: {sid}') raise RuntimeError(f'no_event_stream:{sid}') - asyncio.create_task(self._cleanup_session_later(sid)) return event_stream async def _get_event_stream(self, sid: str) -> EventStream | None: @@ -369,7 +472,7 @@ async def _get_event_stream(self, sid: str) -> EventStream | None: logger.info(f'found_local_agent_loop:{sid}') return session.agent_session.event_stream - if await self.is_agent_loop_running_in_cluster(sid): + if await self._get_running_agent_loops_remotely(filter_to_sids={sid}): logger.info(f'found_remote_agent_loop:{sid}') return EventStream(sid, self.file_store) @@ -377,7 +480,7 @@ async def _get_event_stream(self, sid: str) -> EventStream | None: async def send_to_event_stream(self, connection_id: str, data: dict): # If there is a local session running, send to that - sid = self.local_connection_id_to_session_id.get(connection_id) + sid = self._local_connection_id_to_session_id.get(connection_id) if not sid: raise RuntimeError(f'no_connected_session:{connection_id}') @@ -393,11 +496,11 @@ async def send_to_event_stream(self, connection_id: str, data: dict): next_alive_check = last_alive_at + _CHECK_ALIVE_INTERVAL if ( next_alive_check > time.time() - or await self.is_agent_loop_running_in_cluster(sid) + or await self._get_running_agent_loops_remotely(filter_to_sids={sid}) ): # Send the event to the other pod await redis_client.publish( - 'oh_event', + 'session_msg', json.dumps( { 'sid': sid, @@ -411,75 +514,37 @@ async def send_to_event_stream(self, connection_id: str, data: dict): raise RuntimeError(f'no_connected_session:{connection_id}:{sid}') async def disconnect_from_session(self, connection_id: str): - sid = self.local_connection_id_to_session_id.pop(connection_id, None) + sid = self._local_connection_id_to_session_id.pop(connection_id, None) logger.info(f'disconnect_from_session:{connection_id}:{sid}') if not sid: # This can occur if the init action was never run. 
logger.warning(f'disconnect_from_uninitialized_session:{connection_id}') return - if should_continue(): - asyncio.create_task(self._cleanup_session_later(sid)) - else: - await self._on_close_session(sid) - - async def _cleanup_session_later(self, sid: str): - # Once there have been no connections to a session for a reasonable period, we close it - try: - await asyncio.sleep(self.config.sandbox.close_delay) - finally: - # If the sleep was cancelled, we still want to close these - await self._cleanup_session(sid) - - async def _cleanup_session(self, sid: str) -> bool: - # Get local connections - logger.info(f'_cleanup_session:{sid}') - has_local_connections = next( - (True for v in self.local_connection_id_to_session_id.values() if v == sid), - False, - ) - if has_local_connections: - return False - - # If no local connections, get connections through redis - redis_client = self._get_redis_client() - if redis_client and await self._has_remote_connections(sid): - return False - - # We alert the cluster in case they are interested - if redis_client: - await redis_client.publish( - 'oh_event', - json.dumps({'sid': sid, 'message_type': 'session_closing'}), - ) - - await self._on_close_session(sid) - return True - async def close_session(self, sid: str): session = self._local_agent_loops_by_sid.get(sid) if session: - await self._on_close_session(sid) + await self._close_session(sid) redis_client = self._get_redis_client() if redis_client: await redis_client.publish( - 'oh_event', + 'session_msg', json.dumps({'sid': sid, 'message_type': 'close_session'}), ) - async def _on_close_session(self, sid: str): + async def _close_session(self, sid: str): logger.info(f'_close_session:{sid}') # Clear up local variables connection_ids_to_remove = list( connection_id - for connection_id, conn_sid in self.local_connection_id_to_session_id.items() + for connection_id, conn_sid in self._local_connection_id_to_session_id.items() if sid == conn_sid ) logger.info(f'removing connections: 
{connection_ids_to_remove}') for connnnection_id in connection_ids_to_remove: - self.local_connection_id_to_session_id.pop(connnnection_id, None) + self._local_connection_id_to_session_id.pop(connnnection_id, None) session = self._local_agent_loops_by_sid.pop(sid, None) if not session: @@ -488,12 +553,17 @@ async def _on_close_session(self, sid: str): logger.info(f'closing_session:{session.sid}') # We alert the cluster in case they are interested - redis_client = self._get_redis_client() - if redis_client: - await redis_client.publish( - 'oh_event', - json.dumps({'sid': session.sid, 'message_type': 'session_closing'}), + try: + redis_client = self._get_redis_client() + if redis_client: + await redis_client.publish( + 'session_msg', + json.dumps({'sid': session.sid, 'message_type': 'session_closing'}), + ) + except Exception: + logger.info( + 'error_publishing_close_session_event', exc_info=True, stack_info=True ) - await call_sync_from_async(session.close) + await session.close() logger.info(f'closed_session:{session.sid}') diff --git a/openhands/server/session/session.py b/openhands/server/session/session.py index 8318ab773129..e77a77101b20 100644 --- a/openhands/server/session/session.py +++ b/openhands/server/session/session.py @@ -62,9 +62,17 @@ def __init__( self.loop = asyncio.get_event_loop() self.user_id = user_id - def close(self): + async def close(self): + if self.sio: + await self.sio.emit( + 'oh_event', + event_to_dict( + AgentStateChangedObservation('', AgentState.STOPPED.value) + ), + to=ROOM_KEY.format(sid=self.sid), + ) self.is_alive = False - self.agent_session.close() + await self.agent_session.close() async def initialize_agent( self, diff --git a/openhands/utils/http_session.py b/openhands/utils/http_session.py new file mode 100644 index 000000000000..4edc4e6546c3 --- /dev/null +++ b/openhands/utils/http_session.py @@ -0,0 +1,24 @@ +from dataclasses import dataclass, field + +import requests + + +@dataclass +class HttpSession: + """ + 
request.Session is reusable after it has been closed. This behavior makes it + likely to leak file descriptors (Especially when combined with tenacity). + We wrap the session to make it unusable after being closed + """ + + session: requests.Session | None = field(default_factory=requests.Session) + + def __getattr__(self, name): + if self.session is None: + raise ValueError('session_was_closed') + return object.__getattribute__(self.session, name) + + def close(self): + if self.session is not None: + self.session.close() + self.session = None diff --git a/tests/unit/test_manager.py b/tests/unit/test_manager.py index f0ac68ff8361..cd2ddf6ba0a6 100644 --- a/tests/unit/test_manager.py +++ b/tests/unit/test_manager.py @@ -44,28 +44,28 @@ async def test_session_not_running_in_cluster(): async with SessionManager( sio, AppConfig(), InMemoryFileStore() ) as session_manager: - result = await session_manager.is_agent_loop_running_in_cluster( - 'non-existant-session' + result = await session_manager._get_running_agent_loops_remotely( + filter_to_sids={'non-existant-session'} ) - assert result is False + assert result == set() assert sio.manager.redis.publish.await_count == 1 sio.manager.redis.publish.assert_called_once_with( - 'oh_event', - '{"request_id": "' + 'session_msg', + '{"query_id": "' + str(id) - + '", "sids": ["non-existant-session"], "message_type": "is_session_running"}', + + '", "message_type": "running_agent_loops_query", "filter_to_sids": ["non-existant-session"]}', ) @pytest.mark.asyncio -async def test_session_is_running_in_cluster(): +async def test_get_running_agent_loops_remotely(): id = uuid4() sio = get_mock_sio( GetMessageMock( { - 'request_id': str(id), + 'query_id': str(id), 'sids': ['existing-session'], - 'message_type': 'session_is_running', + 'message_type': 'running_agent_loops_response', } ) ) @@ -76,16 +76,16 @@ async def test_session_is_running_in_cluster(): async with SessionManager( sio, AppConfig(), InMemoryFileStore() ) as 
session_manager: - result = await session_manager.is_agent_loop_running_in_cluster( - 'existing-session' + result = await session_manager._get_running_agent_loops_remotely( + 1, {'existing-session'} ) - assert result is True + assert result == {'existing-session'} assert sio.manager.redis.publish.await_count == 1 sio.manager.redis.publish.assert_called_once_with( - 'oh_event', - '{"request_id": "' + 'session_msg', + '{"query_id": "' + str(id) - + '", "sids": ["existing-session"], "message_type": "is_session_running"}', + + '", "message_type": "running_agent_loops_query", "user_id": 1, "filter_to_sids": ["existing-session"]}', ) @@ -96,8 +96,8 @@ async def test_init_new_local_session(): mock_session = MagicMock() mock_session.return_value = session_instance sio = get_mock_sio() - is_agent_loop_running_in_cluster_mock = AsyncMock() - is_agent_loop_running_in_cluster_mock.return_value = False + get_running_agent_loops_mock = AsyncMock() + get_running_agent_loops_mock.return_value = set() with ( patch('openhands.server.session.manager.Session', mock_session), patch('openhands.server.session.manager._REDIS_POLL_TIMEOUT', 0.1), @@ -106,8 +106,8 @@ async def test_init_new_local_session(): AsyncMock(), ), patch( - 'openhands.server.session.manager.SessionManager.is_agent_loop_running_in_cluster', - is_agent_loop_running_in_cluster_mock, + 'openhands.server.session.manager.SessionManager.get_running_agent_loops', + get_running_agent_loops_mock, ), ): async with SessionManager( @@ -130,8 +130,8 @@ async def test_join_local_session(): mock_session = MagicMock() mock_session.return_value = session_instance sio = get_mock_sio() - is_agent_loop_running_in_cluster_mock = AsyncMock() - is_agent_loop_running_in_cluster_mock.return_value = False + get_running_agent_loops_mock = AsyncMock() + get_running_agent_loops_mock.return_value = set() with ( patch('openhands.server.session.manager.Session', mock_session), patch('openhands.server.session.manager._REDIS_POLL_TIMEOUT', 0.01), @@ 
-140,8 +140,8 @@ async def test_join_local_session(): AsyncMock(), ), patch( - 'openhands.server.session.manager.SessionManager.is_agent_loop_running_in_cluster', - is_agent_loop_running_in_cluster_mock, + 'openhands.server.session.manager.SessionManager.get_running_agent_loops', + get_running_agent_loops_mock, ), ): async with SessionManager( @@ -167,8 +167,8 @@ async def test_join_cluster_session(): mock_session = MagicMock() mock_session.return_value = session_instance sio = get_mock_sio() - is_agent_loop_running_in_cluster_mock = AsyncMock() - is_agent_loop_running_in_cluster_mock.return_value = True + get_running_agent_loops_mock = AsyncMock() + get_running_agent_loops_mock.return_value = {'new-session-id'} with ( patch('openhands.server.session.manager.Session', mock_session), patch('openhands.server.session.manager._REDIS_POLL_TIMEOUT', 0.01), @@ -177,8 +177,8 @@ async def test_join_cluster_session(): AsyncMock(), ), patch( - 'openhands.server.session.manager.SessionManager.is_agent_loop_running_in_cluster', - is_agent_loop_running_in_cluster_mock, + 'openhands.server.session.manager.SessionManager._get_running_agent_loops_remotely', + get_running_agent_loops_mock, ), ): async with SessionManager( @@ -198,8 +198,8 @@ async def test_add_to_local_event_stream(): mock_session = MagicMock() mock_session.return_value = session_instance sio = get_mock_sio() - is_agent_loop_running_in_cluster_mock = AsyncMock() - is_agent_loop_running_in_cluster_mock.return_value = False + get_running_agent_loops_mock = AsyncMock() + get_running_agent_loops_mock.return_value = set() with ( patch('openhands.server.session.manager.Session', mock_session), patch('openhands.server.session.manager._REDIS_POLL_TIMEOUT', 0.01), @@ -208,8 +208,8 @@ async def test_add_to_local_event_stream(): AsyncMock(), ), patch( - 'openhands.server.session.manager.SessionManager.is_agent_loop_running_in_cluster', - is_agent_loop_running_in_cluster_mock, + 
'openhands.server.session.manager.SessionManager.get_running_agent_loops', + get_running_agent_loops_mock, ), ): async with SessionManager( @@ -234,8 +234,8 @@ async def test_add_to_cluster_event_stream(): mock_session = MagicMock() mock_session.return_value = session_instance sio = get_mock_sio() - is_agent_loop_running_in_cluster_mock = AsyncMock() - is_agent_loop_running_in_cluster_mock.return_value = True + get_running_agent_loops_mock = AsyncMock() + get_running_agent_loops_mock.return_value = {'new-session-id'} with ( patch('openhands.server.session.manager.Session', mock_session), patch('openhands.server.session.manager._REDIS_POLL_TIMEOUT', 0.01), @@ -244,8 +244,8 @@ async def test_add_to_cluster_event_stream(): AsyncMock(), ), patch( - 'openhands.server.session.manager.SessionManager.is_agent_loop_running_in_cluster', - is_agent_loop_running_in_cluster_mock, + 'openhands.server.session.manager.SessionManager._get_running_agent_loops_remotely', + get_running_agent_loops_mock, ), ): async with SessionManager( @@ -259,7 +259,7 @@ async def test_add_to_cluster_event_stream(): ) assert sio.manager.redis.publish.await_count == 1 sio.manager.redis.publish.assert_called_once_with( - 'oh_event', + 'session_msg', '{"sid": "new-session-id", "message_type": "event", "data": {"event_type": "some_event"}}', ) @@ -277,7 +277,7 @@ async def test_cleanup_session_connections(): async with SessionManager( sio, AppConfig(), InMemoryFileStore() ) as session_manager: - session_manager.local_connection_id_to_session_id.update( + session_manager._local_connection_id_to_session_id.update( { 'conn1': 'session1', 'conn2': 'session1', @@ -286,9 +286,9 @@ async def test_cleanup_session_connections(): } ) - await session_manager._on_close_session('session1') + await session_manager._close_session('session1') - remaining_connections = session_manager.local_connection_id_to_session_id + remaining_connections = session_manager._local_connection_id_to_session_id assert 'conn1' not in 
remaining_connections assert 'conn2' not in remaining_connections assert 'conn3' in remaining_connections