Skip to content

Commit

Permalink
[Feat] Add agent containers (#264)
Browse files Browse the repository at this point in the history
add `AgentList` and `AgentDict`
  • Loading branch information
braisedpork1964 authored Oct 29, 2024
1 parent 8341529 commit cd34d8d
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 19 deletions.
7 changes: 4 additions & 3 deletions lagent/agents/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from .agent import Agent, AsyncAgent
from .agent import Agent, AgentDict, AgentList, AsyncAgent
from .react import AsyncReAct, ReAct
from .stream import AgentForInternLM, AsyncAgentForInternLM, AsyncMathCoder, MathCoder

__all__ = [
'Agent', 'AsyncAgent', 'AgentForInternLM', 'AsyncAgentForInternLM',
'MathCoder', 'AsyncMathCoder', 'ReAct', 'AsyncReAct'
'Agent', 'AgentDict', 'AgentList', 'AsyncAgent', 'AgentForInternLM',
'AsyncAgentForInternLM', 'MathCoder', 'AsyncMathCoder', 'ReAct',
'AsyncReAct'
]
102 changes: 87 additions & 15 deletions lagent/agents/agent.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import copy
import warnings
from collections import OrderedDict
from typing import Any, Callable, Dict, List, Optional, Union
from collections import OrderedDict, UserDict, UserList, abc
from functools import wraps
from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Union

from lagent.agents.aggregator import DefaultAggregator
from lagent.hooks import Hook, RemovableHandle
Expand Down Expand Up @@ -46,7 +47,6 @@ def __init__(
):
self.name = name or self.__class__.__name__
self.llm: BaseLLM = create_object(llm)

self.memory: MemoryManager = MemoryManager(memory) if memory else None
self.output_format: StrParser = create_object(output_format)
self.template = template
Expand Down Expand Up @@ -78,7 +78,6 @@ def __call__(
result = hook.before_agent(self, message, session_id)
if result:
message = result

self.update_memory(message, session_id=session_id)
response_message = self.forward(
*message, session_id=session_id, **kwargs)
Expand All @@ -99,7 +98,6 @@ def forward(self,
*message: AgentMessage,
session_id=0,
**kwargs) -> Union[AgentMessage, str]:

formatted_messages = self.aggregator.aggregate(
self.memory.get(session_id),
self.name,
Expand Down Expand Up @@ -129,10 +127,11 @@ def state_dict(self, session_id=0):
while stack:
prefix, node = stack.pop()
key = prefix + 'memory'
if session_id not in node.memory.memory_map:
raise KeyError(f'No session id {session_id} in {key}')
state_dict[key] = (
node.memory.get(session_id).save() if node.memory else None)
if node.memory is not None:
if session_id not in node.memory.memory_map:
warnings.warn(f'No session id {session_id} in {key}')
memory = node.memory.get(session_id)
state_dict[key] = memory and memory.save() or []
if hasattr(node, '_agents'):
for name, value in reversed(node._agents.items()):
stack.append((prefix + name + '.', value))
Expand All @@ -149,10 +148,17 @@ def load_state_dict(self, state_dict: Dict, session_id=0):
for key in _state_dict:
obj = self
for attr in key.split('.')[:-1]:
obj = getattr(obj, attr)
if session_id not in obj.memory.memory_map:
obj.memory.create_instance(session_id)
obj.memory.memory_map[session_id].load(state_dict[key])
if isinstance(obj, AgentList):
assert attr.isdigit()
obj = obj[int(attr)]
elif isinstance(obj, AgentDict):
obj = obj[attr]
else:
obj = getattr(obj, attr)
if obj.memory is not None:
if session_id not in obj.memory.memory_map:
obj.memory.create_instance(session_id)
obj.memory.memory_map[session_id].load(state_dict[key] or [])

def register_hook(self, hook: Callable):
handle = RemovableHandle(self._hooks)
Expand All @@ -179,7 +185,6 @@ async def __call__(self,
result = hook.before_agent(self, message, session_id)
if result:
message = result

self.update_memory(message, session_id=session_id)
response_message = await self.forward(
*message, session_id=session_id, **kwargs)
Expand All @@ -200,7 +205,6 @@ async def forward(self,
*message: AgentMessage,
session_id=0,
**kwargs) -> Union[AgentMessage, str]:

formatted_messages = self.aggregator.aggregate(
self.memory.get(session_id),
self.name,
Expand All @@ -218,3 +222,71 @@ async def forward(self,
formatted=formatted_messages,
)
return llm_response


class AgentContainerMixin:

def __init_subclass__(cls):
super().__init_subclass__()

def wrap_api(func):

@wraps(func)
def wrapped_func(self, *args, **kwargs):
data = self.data.copy() if hasattr(self, 'data') else None

def _backup(d):
if d is None:
self.data.clear()
else:
self.data = d

ret = func(self, *args, **kwargs)
agents = OrderedDict()
for k, item in (self.data.items() if isinstance(
self.data, abc.Mapping) else enumerate(self.data)):
if isinstance(self.data,
abc.Mapping) and not isinstance(k, str):
_backup(data)
raise KeyError(
f'agent name should be a string, got {type(k)}')
if isinstance(k, str) and '.' in k:
_backup(data)
raise KeyError(
f'agent name can\'t contain ".", got {k}')
if not isinstance(item, (Agent, AsyncAgent)):
_backup(data)
raise TypeError(
f'{type(item)} is not an Agent or AsyncAgent subclass'
)
agents[str(k)] = item
self._agents = agents
return ret

return wrapped_func

for method in [
'append', 'sort', 'reverse', 'pop', 'clear', 'update',
'insert', 'extend', 'remove', '__init__', '__setitem__',
'__delitem__', '__add__', '__iadd__', '__radd__', '__mul__',
'__imul__', '__rmul__'
]:
if hasattr(cls, method):
setattr(cls, method, wrap_api(getattr(cls, method)))


class AgentList(Agent, UserList, AgentContainerMixin):

def __init__(self,
agents: Optional[Iterable[Union[Agent, AsyncAgent]]] = None):
Agent.__init__(self, memory=None)
UserList.__init__(self, agents)


class AgentDict(Agent, UserDict, AgentContainerMixin):

def __init__(self,
agents: Optional[Mapping[str, Union[Agent,
AsyncAgent]]] = None):
Agent.__init__(self, memory=None)
UserDict.__init__(self, agents)
2 changes: 1 addition & 1 deletion lagent/memory/base_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def load(
else:
raise TypeError(f'{type(memories)} is not supported')

def save(self) -> Union[Dict, List]:
def save(self) -> List[dict]:
memory = []
for m in self.memory:
memory.append(m.model_dump())
Expand Down

0 comments on commit cd34d8d

Please sign in to comment.