From 5581fad8cefa688442555c3e91deff26b06ffd80 Mon Sep 17 00:00:00 2001 From: BraisedPork <46232992+braisedpork1964@users.noreply.github.com> Date: Tue, 30 Jan 2024 15:12:25 +0800 Subject: [PATCH] Validate actions implement `run` XOR other tool APIs (#111) Validate actions implement `run` XOR other tool APIs --- docs/en/tutorials/action.md | 5 ++++- docs/zh_cn/tutorials/action.md | 5 ++++- lagent/actions/base_action.py | 22 +++++++++++++++++++--- 3 files changed, 27 insertions(+), 5 deletions(-) diff --git a/docs/en/tutorials/action.md b/docs/en/tutorials/action.md index 36c07358..31fbd857 100644 --- a/docs/en/tutorials/action.md +++ b/docs/en/tutorials/action.md @@ -209,10 +209,13 @@ Only Google style Python docstrings is currently supported. A simple tool must have its `run` method implemented, while APIs of toolkits should avoid naming conflicts with this reserved word. +```{tip} +`run` is allowed not to be decorated by `tool_api` for simple tools unless you want to hint the return data. +``` + ```python class Bold(BaseAction): - @tool_api def run(self, text: str): """make text bold diff --git a/docs/zh_cn/tutorials/action.md b/docs/zh_cn/tutorials/action.md index d816648c..284a7522 100644 --- a/docs/zh_cn/tutorials/action.md +++ b/docs/zh_cn/tutorials/action.md @@ -206,10 +206,13 @@ def list_args(a: str, b: int, c: float = 0.0) -> dict: 一个简单工具必须实现 `run` 方法,而工具包则应当避免将各子API名称定义为该保留字段。 +```{tip} +对于非工具包的 Action,`run` 允许不被 `tool_api` 装饰,除非你想提示返回信息。 +``` + ```python class Bold(BaseAction): - @tool_api def run(self, text: str): """make text bold diff --git a/lagent/actions/base_action.py b/lagent/actions/base_action.py index c7a93ccf..d3aa263c 100644 --- a/lagent/actions/base_action.py +++ b/lagent/actions/base_action.py @@ -221,8 +221,25 @@ def __new__(mcs, name, base, attrs): if api_desc.get('return_data'): tool_desc['return_data'] = api_desc['return_data'] is_toolkit = False - break - tool_desc.setdefault('api_list', []).append(api_desc) + else: + tool_desc.setdefault('api_list', []).append(api_desc) + if not is_toolkit and 'api_list' in tool_desc: + raise KeyError('`run` and other tool APIs can not be implemented ' + 'at the same time') + if is_toolkit and 'api_list' not in tool_desc: + is_toolkit = False + if callable(attrs.get('run')): + run_api = tool_api(attrs['run']) + api_desc = run_api.api_description + tool_desc['parameters'] = api_desc['parameters'] + tool_desc['required'] = api_desc['required'] + if api_desc['description']: + tool_desc['description'] = api_desc['description'] + if api_desc.get('return_data'): + tool_desc['return_data'] = api_desc['return_data'] + attrs['run'] = run_api + else: + tool_desc['parameters'], tool_desc['required'] = [], [] attrs['_is_toolkit'] = is_toolkit attrs['__tool_description__'] = tool_desc return super().__new__(mcs, name, base, attrs) @@ -248,7 +265,6 @@ class BaseAction(metaclass=AutoRegister(TOOL_REGISTRY, ToolMeta)): class Bold(BaseAction): '''Make text bold''' - @tool_api def run(self, text: str): ''' Args: