Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat: AgentLegoToolkit #164

Open
wants to merge 14 commits into
base: main
Choose a base branch
from
5 changes: 2 additions & 3 deletions examples/internlm2_agent_web_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import streamlit as st

from lagent.actions import ActionExecutor, ArxivSearch, IPythonInterpreter
# from lagent.actions.agentlego_wrapper import AgentLegoToolkit
from lagent.agents.internlm2_agent import INTERPRETER_CN, META_CN, PLUGIN_CN, Internlm2Agent, Internlm2Protocol
from lagent.llms.lmdepoly_wrapper import LMDeployClient
from lagent.llms.meta_template import INTERNLM2_META as META
Expand All @@ -21,9 +22,7 @@ def init_state(self):
st.session_state['assistant'] = []
st.session_state['user'] = []

action_list = [
ArxivSearch(),
]
action_list = [ArxivSearch()]
st.session_state['plugin_map'] = {
action.name: action
for action in action_list
Expand Down
48 changes: 48 additions & 0 deletions lagent/actions/agentlego_wrapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
from typing import Optional

# from agentlego.parsers import DefaultParser
from agentlego.tools.remote import RemoteTool

from lagent import BaseAction
from lagent.actions.parser import JsonParser


class AgentLegoToolkit(BaseAction):

def __init__(self,
name: str,
url: Optional[str] = None,
text: Optional[str] = None,
spec_dict: Optional[dict] = None,
parser=JsonParser,
enable: bool = True):

if url is not None:
spec = dict(url=url)
elif text is not None:
spec = dict(text=text)
else:
assert spec_dict is not None
spec = dict(spec_dict=spec_dict)
if url is not None and not url.endswith('.json'):
api_list = [RemoteTool.from_url(url).to_lagent()]
else:
api_list = [
api.to_lagent() for api in RemoteTool.from_openapi(**spec)
]
api_desc = []
for api in api_list:
api_desc.append(api.description)
if len(api_list) > 1:
tool_description = dict(name=name, api_list=api_desc)
for func in api_list:
setattr(self, func.name, func.run)
else:
tool_description = api_desc[0]
setattr(self, 'run', api_list[0].run)
super().__init__(
description=tool_description, parser=parser, enable=enable)

@property
def is_toolkit(self):
return 'api_list' in self.description
1 change: 1 addition & 0 deletions requirements/optional.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
agentlego
google-search-results
lmdeploy>=0.2.3
pillow
Expand Down