diff --git a/lagent/agents/__init__.py b/lagent/agents/__init__.py index d54f8bf5..3e673c19 100644 --- a/lagent/agents/__init__.py +++ b/lagent/agents/__init__.py @@ -1,8 +1,9 @@ -from .agent import Agent, AsyncAgent +from .agent import Agent, AgentDict, AgentList, AsyncAgent from .react import AsyncReAct, ReAct from .stream import AgentForInternLM, AsyncAgentForInternLM, AsyncMathCoder, MathCoder __all__ = [ - 'Agent', 'AsyncAgent', 'AgentForInternLM', 'AsyncAgentForInternLM', - 'MathCoder', 'AsyncMathCoder', 'ReAct', 'AsyncReAct' + 'Agent', 'AgentDict', 'AgentList', 'AsyncAgent', 'AgentForInternLM', + 'AsyncAgentForInternLM', 'MathCoder', 'AsyncMathCoder', 'ReAct', + 'AsyncReAct' ] diff --git a/lagent/agents/agent.py b/lagent/agents/agent.py index d97bd253..59b34f38 100644 --- a/lagent/agents/agent.py +++ b/lagent/agents/agent.py @@ -1,7 +1,8 @@ import copy import warnings -from collections import OrderedDict -from typing import Any, Callable, Dict, List, Optional, Union +from collections import OrderedDict, UserDict, UserList, abc +from functools import wraps +from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Union from lagent.agents.aggregator import DefaultAggregator from lagent.hooks import Hook, RemovableHandle @@ -46,7 +47,6 @@ def __init__( ): self.name = name or self.__class__.__name__ self.llm: BaseLLM = create_object(llm) - self.memory: MemoryManager = MemoryManager(memory) if memory else None self.output_format: StrParser = create_object(output_format) self.template = template @@ -78,7 +78,6 @@ def __call__( result = hook.before_agent(self, message, session_id) if result: message = result - self.update_memory(message, session_id=session_id) response_message = self.forward( *message, session_id=session_id, **kwargs) @@ -99,7 +98,6 @@ def forward(self, *message: AgentMessage, session_id=0, **kwargs) -> Union[AgentMessage, str]: - formatted_messages = self.aggregator.aggregate( self.memory.get(session_id), self.name, @@ -129,10 +127,11 @@ def state_dict(self, session_id=0): while stack: prefix, node = stack.pop() key = prefix + 'memory' - if session_id not in node.memory.memory_map: - raise KeyError(f'No session id {session_id} in {key}') - state_dict[key] = ( - node.memory.get(session_id).save() if node.memory else None) + if node.memory is not None: + if session_id not in node.memory.memory_map: + warnings.warn(f'No session id {session_id} in {key}') + memory = node.memory.get(session_id) + state_dict[key] = memory and memory.save() or [] if hasattr(node, '_agents'): for name, value in reversed(node._agents.items()): stack.append((prefix + name + '.', value)) @@ -149,10 +148,17 @@ def load_state_dict(self, state_dict: Dict, session_id=0): for key in _state_dict: obj = self for attr in key.split('.')[:-1]: - obj = getattr(obj, attr) - if session_id not in obj.memory.memory_map: - obj.memory.create_instance(session_id) - obj.memory.memory_map[session_id].load(state_dict[key]) + if isinstance(obj, AgentList): + assert attr.isdigit() + obj = obj[int(attr)] + elif isinstance(obj, AgentDict): + obj = obj[attr] + else: + obj = getattr(obj, attr) + if obj.memory is not None: + if session_id not in obj.memory.memory_map: + obj.memory.create_instance(session_id) + obj.memory.memory_map[session_id].load(state_dict[key] or []) def register_hook(self, hook: Callable): handle = RemovableHandle(self._hooks) @@ -179,7 +185,6 @@ async def __call__(self, result = hook.before_agent(self, message, session_id) if result: message = result - self.update_memory(message, session_id=session_id) response_message = await self.forward( *message, session_id=session_id, **kwargs) @@ -200,7 +205,6 @@ async def forward(self, *message: AgentMessage, session_id=0, **kwargs) -> Union[AgentMessage, str]: - formatted_messages = self.aggregator.aggregate( self.memory.get(session_id), self.name, @@ -218,3 +222,71 @@ async def forward(self, formatted=formatted_messages, ) return llm_response + + +class AgentContainerMixin: + + def __init_subclass__(cls): + super().__init_subclass__() + + def wrap_api(func): + + @wraps(func) + def wrapped_func(self, *args, **kwargs): + data = self.data.copy() if hasattr(self, 'data') else None + + def _backup(d): + if d is None: + self.data.clear() + else: + self.data = d + + ret = func(self, *args, **kwargs) + agents = OrderedDict() + for k, item in (self.data.items() if isinstance( + self.data, abc.Mapping) else enumerate(self.data)): + if isinstance(self.data, + abc.Mapping) and not isinstance(k, str): + _backup(data) + raise KeyError( + f'agent name should be a string, got {type(k)}') + if isinstance(k, str) and '.' in k: + _backup(data) + raise KeyError( + f'agent name can\'t contain ".", got {k}') + if not isinstance(item, (Agent, AsyncAgent)): + _backup(data) + raise TypeError( + f'{type(item)} is not an Agent or AsyncAgent subclass' + ) + agents[str(k)] = item + self._agents = agents + return ret + + return wrapped_func + + for method in [ + 'append', 'sort', 'reverse', 'pop', 'clear', 'update', + 'insert', 'extend', 'remove', '__init__', '__setitem__', + '__delitem__', '__add__', '__iadd__', '__radd__', '__mul__', + '__imul__', '__rmul__' + ]: + if hasattr(cls, method): + setattr(cls, method, wrap_api(getattr(cls, method))) + + +class AgentList(Agent, UserList, AgentContainerMixin): + + def __init__(self, + agents: Optional[Iterable[Union[Agent, AsyncAgent]]] = None): + Agent.__init__(self, memory=None) + UserList.__init__(self, agents) + + +class AgentDict(Agent, UserDict, AgentContainerMixin): + + def __init__(self, + agents: Optional[Mapping[str, Union[Agent, + AsyncAgent]]] = None): + Agent.__init__(self, memory=None) + UserDict.__init__(self, agents) diff --git a/lagent/memory/base_memory.py b/lagent/memory/base_memory.py index adb665bf..c60d9780 100644 --- a/lagent/memory/base_memory.py +++ b/lagent/memory/base_memory.py @@ -53,7 +53,7 @@ def load( else: raise TypeError(f'{type(memories)} is not supported') - def save(self) -> Union[Dict, List]: + def save(self) -> List[dict]: memory = [] for m in self.memory: memory.append(m.model_dump())