Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat: AgentLegoToolkit #164

Open
wants to merge 14 commits into
base: main
Choose a base branch
from
8 changes: 8 additions & 0 deletions examples/internlm2_agent_web_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import streamlit as st

from lagent.actions import ActionExecutor, ArxivSearch, IPythonInterpreter
from lagent.actions.agentlego_wrapper import AgentLegoToolkit
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
from lagent.actions.agentlego_wrapper import AgentLegoToolkit
# from lagent.actions.agentlego_wrapper import AgentLegoToolkit

from lagent.agents.internlm2_agent import INTERPRETER_CN, META_CN, PLUGIN_CN, Internlm2Agent, Internlm2Protocol
from lagent.llms.lmdepoly_wrapper import LMDeployClient
from lagent.llms.meta_template import INTERNLM2_META as META
Expand All @@ -23,6 +24,13 @@ def init_state(self):

action_list = [
ArxivSearch(),
AgentLegoToolkit(
type='ImageDescription',
url='http://127.0.0.1:16180/openapi.json'),
AgentLegoToolkit(
type='Calculator', url='http://127.0.0.1:16181/openapi.json'),
AgentLegoToolkit(
type='PluginMarket', url='http://127.0.0.1:16182/openapi.json')
]
st.session_state['plugin_map'] = {
action.name: action
Expand Down
51 changes: 51 additions & 0 deletions lagent/actions/agentlego_wrapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
from typing import Optional

# from agentlego.parsers import DefaultParser
from agentlego.tools.remote import RemoteTool

from lagent import BaseAction
from lagent.actions.parser import JsonParser


class AgentLegoToolkit(BaseAction):

def __init__(self,
type: str,
url: Optional[str] = None,
text: Optional[str] = None,
spec_dict: Optional[dict] = None,
parser=JsonParser,
enable: bool = True):

if url is not None:
spec = dict(url=url)
elif text is not None:
spec = dict(text=text)
else:
assert spec_dict is not None
spec = dict(spec_dict=spec_dict)
if url is not None and not url.endswith('.json'):
api_list = [RemoteTool.from_url(url).to_lagent()]
else:
api_list = [
api.to_lagent() for api in RemoteTool.from_openapi(**spec)
]
api_desc = []
for api in api_list:
api_desc.append(api.description)
if len(api_list) > 1:
tool_description = dict(name=type, api_list=api_desc)
self.add_method(api_list)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
self.add_method(api_list)
self.add_method(api_list)
Suggested change
self.add_method(api_list)
for func in api_list:
setattr(self, func.name, func.run)

else:
tool_description = api_desc[0]
setattr(self, 'run', api_list[0].run)
super().__init__(
description=tool_description, parser=parser, enable=enable)

@property
def is_toolkit(self):
return 'api_list' in self.description

def add_method(self, funcs):
for func in funcs:
setattr(self, func.name, func.run)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove