Skip to content

Commit

Permalink
Validate actions implement run XOR other tool APIs (#111)
Browse files Browse the repository at this point in the history
Validate actions implement `run` XOR other tool APIs
  • Loading branch information
braisedpork1964 authored Jan 30, 2024
1 parent 95b68c8 commit 5581fad
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 5 deletions.
5 changes: 4 additions & 1 deletion docs/en/tutorials/action.md
Original file line number Diff line number Diff line change
Expand Up @@ -209,10 +209,13 @@ Only Google style Python docstrings is currently supported.

A simple tool must have its `run` method implemented, while APIs of toolkits should avoid naming conflicts with this reserved word.

```{tip}
`run` is allowed not to be decorated by `tool_api` for simple tools unless you want to hint the return data.
```

```python
class Bold(BaseAction):

@tool_api
def run(self, text: str):
"""make text bold
Expand Down
5 changes: 4 additions & 1 deletion docs/zh_cn/tutorials/action.md
Original file line number Diff line number Diff line change
Expand Up @@ -206,10 +206,13 @@ def list_args(a: str, b: int, c: float = 0.0) -> dict:

一个简单工具必须实现 `run` 方法,而工具包则应当避免将各子API名称定义为该保留字段。

```{tip}
对于非工具包的 Action,`run` 允许不被 `tool_api` 装饰,除非你想提示返回信息。
```

```python
class Bold(BaseAction):

@tool_api
def run(self, text: str):
"""make text bold
Expand Down
22 changes: 19 additions & 3 deletions lagent/actions/base_action.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,8 +221,25 @@ def __new__(mcs, name, base, attrs):
if api_desc.get('return_data'):
tool_desc['return_data'] = api_desc['return_data']
is_toolkit = False
break
tool_desc.setdefault('api_list', []).append(api_desc)
else:
tool_desc.setdefault('api_list', []).append(api_desc)
if not is_toolkit and 'api_list' in tool_desc:
raise KeyError('`run` and other tool APIs can not be implemented '
'at the same time')
if is_toolkit and 'api_list' not in tool_desc:
is_toolkit = False
if callable(attrs.get('run')):
run_api = tool_api(attrs['run'])
api_desc = run_api.api_description
tool_desc['parameters'] = api_desc['parameters']
tool_desc['required'] = api_desc['required']
if api_desc['description']:
tool_desc['description'] = api_desc['description']
if api_desc.get('return_data'):
tool_desc['return_data'] = api_desc['return_data']
attrs['run'] = run_api
else:
tool_desc['parameters'], tool_desc['required'] = [], []
attrs['_is_toolkit'] = is_toolkit
attrs['__tool_description__'] = tool_desc
return super().__new__(mcs, name, base, attrs)
Expand All @@ -248,7 +265,6 @@ class BaseAction(metaclass=AutoRegister(TOOL_REGISTRY, ToolMeta)):
class Bold(BaseAction):
'''Make text bold'''
@tool_api
def run(self, text: str):
'''
Args:
Expand Down

0 comments on commit 5581fad

Please sign in to comment.