From e46e59cc6aa5a5ed759d7d57293fb7540bcbcbdb Mon Sep 17 00:00:00 2001
From: liujiangning30 <147385819+liujiangning30@users.noreply.github.com>
Date: Thu, 18 Jan 2024 16:22:16 +0800
Subject: [PATCH 01/20] Feature: redesign BaseModel (#80)

* redesign BaseModel

* update docstring

* update baseModel
---
 lagent/llms/base_llm.py | 117 ++++++++++++++++++++++++++++++++--------
 1 file changed, 95 insertions(+), 22 deletions(-)

diff --git a/lagent/llms/base_llm.py b/lagent/llms/base_llm.py
index 34729ce8..05919787 100644
--- a/lagent/llms/base_llm.py
+++ b/lagent/llms/base_llm.py
@@ -1,4 +1,5 @@
 from abc import abstractclassmethod
+from copy import copy
 from typing import Dict, List, Optional, Tuple, Union
@@ -21,7 +22,7 @@ def __init__(self, meta_template: Optional[List[Dict]] = None):
                 'role in meta prompt must be unique!'
             self.roles[item['role']] = item.copy()

-    def parse_template(self, dialog) -> str:
+    def __call__(self, dialog) -> str:
         """Parse a prompt template, and wrap it with meta template if
         applicable.
@@ -114,7 +115,14 @@ def __init__(self,
                  max_seq_len: int = 2048,
                  tokenizer_only: bool = False,
                  template_parser: 'LMTemplateParser' = LMTemplateParser,
-                 meta_template: Optional[List[Dict]] = None):
+                 meta_template: Optional[List[Dict]] = None,
+                 *,
+                 max_out_len: int = 512,
+                 top_p: float = 0.8,
+                 top_k: float = None,
+                 temperature: float = 0.8,
+                 repetition_penalty: float = 1.0,
+                 stop_words: Union[List[str], str] = None):
         self.path = path
         self.max_seq_len = max_seq_len
         self.tokenizer_only = tokenizer_only
@@ -124,41 +132,106 @@ def __init__(self,
         if meta_template and 'eos_token_id' in meta_template:
             self.eos_token_id = meta_template['eos_token_id']

+        self.completion_params = dict(
+            max_out_len=max_out_len,
+            top_p=top_p,
+            top_k=top_k,
+            temperature=temperature,
+            repetition_penalty=repetition_penalty,
+            stop_words=stop_words)
+
     @abstractclassmethod
-    def generate(self, inputs: List[str], max_out_len: int) -> List[str]:
-        """Generate results given a list of inputs.
+    def completion(
+        self,
+        inputs: Union[str, List[str]],
+        **completion_params
+    ) -> str:
+        """Generate results given a str (or list of) inputs.

         Args:
-            inputs (List[str]): A list of strings.
-            max_out_len (int): The maximum length of the output.
+            inputs (Union[str, List[str]]):
+            completion_params (dict): The input params for completion.

         Returns:
-            List[str]: A list of generated strings.
+            Union[str, List[str]]: A (list of) generated strings.
+
+        e.g.
+            batched = True
+            if isinstance(inputs, str):
+                inputs = [inputs]
+                batched = False
+            response = ['']
+            if batched:
+                return response
+            return response[0]
         """

-    def parse_template(self, dialog) -> str:
-        """Parse a prompt template, and wrap it with meta template if
-        applicable.
+    def stream_completion(
+        self,
+        inputs: str,
+        **completion_params
+    ) -> List[str]:
+        """Generate results as streaming given a str inputs.

         Args:
-            dialog (List[str or PromptList]): A prompt
-                template (potentially before being wrapped by meta template).
-            mode (str): Parsing mode. Choices are 'ppl' and 'gen'.
+            inputs (str):
+            completion_params (dict): The input params for completion.

         Returns:
-            str: The final string.
+            str: A generated string.
         """
-        return self.template_parser.parse_template(dialog)
+        raise NotImplementedError

-    def generate_from_template(self, templates, max_out_len: int, **kwargs):
+    def chat(
+        self,
+        inputs: Union[List[dict], List[List[dict]]],
+        **completion_params
+    ):
         """Generate completion from a list of templates.

         Args:
-            templates (List[PromptType]): A list of templates.
-            max_out_len (int): The maximum length of the output.
+            inputs (Union[List[dict], List[List[dict]]]):
+            completion_params (dict): The input params for completion.
+        Returns:
+        """
+        if isinstance(inputs[0], list):
+            _inputs = list()
+            for msg in inputs:
+                _inputs.append(self.template_parser(msg))
+            inputs = _inputs
+        else:
+            inputs = self.template_parser(inputs)
+        return self.completion(inputs, **completion_params)
+
+    def stream_chat(
+        self,
+        inputs: List[dict],
+        **completion_params
+    ):
+        """Generate results as streaming given a list of templates.
+
+        Args:
+            inputs (List[dict]):
+            completion_params (dict): The input params for completion.
+        Returns:
+        """
+        raise NotImplementedError
+
+    def tokenize(
+        self,
+        prompts: Union[str, List[str], List[dict], List[List[dict]]]
+    ):
+        """Tokenize the input prompts.
+
+        Args:
+            prompts(str | List[str]): user's prompt, or a batch prompts
+
+        Returns:
+            Tuple(numpy.ndarray, numpy.ndarray, numpy.ndarray): prompt's token
+            ids, ids' length and requested output length
         """
-        inputs = self.parse_template(templates)
-        return self.generate(inputs, max_out_len=max_out_len, **kwargs)
+        raise NotImplementedError

-    def to(self, device):
-        self.model.to(device)
+    def update_completion_params(self, **kwargs):
+        completion_params = copy(self.completion_params)
+        completion_params.update(kwargs)
+        return completion_params
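Before the next patch reworks the action layer, it is worth seeing how this redesigned `BaseModel` is meant to be driven. A minimal sketch, assuming this patch is applied and using a hypothetical `EchoLLM` stub in place of a real backend:

```python
from lagent.llms.base_llm import BaseModel


class EchoLLM(BaseModel):
    """Hypothetical stub that fakes the abstract `completion` API."""

    def completion(self, inputs, **completion_params):
        # Echo the prompt back; a real subclass would query a model here.
        batched = isinstance(inputs, list)
        outputs = [f'echo: {x}' for x in (inputs if batched else [inputs])]
        return outputs if batched else outputs[0]


llm = EchoLLM(path='', temperature=0.2)
# The dialog is flattened by LMTemplateParser.__call__ before completion.
print(llm.chat([dict(role='user', content='hi')]))
# Per-call overrides merge into the defaults captured at construction.
print(llm.update_completion_params(top_p=0.5)['top_p'])  # 0.5
```

The design choice here is that generation defaults (`max_out_len`, `top_p`, ...) now live on the model instance instead of being threaded through every call.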
From 30f77daa071c71dd6c352901d775e89fa9bf1c8a Mon Sep 17 00:00:00 2001
From: BraisedPork <46232992+braisedpork1964@users.noreply.github.com>
Date: Fri, 19 Jan 2024 14:56:11 +0800
Subject: [PATCH 02/20] [Refactor] improve `Action` and `ActionExecutor` (#83)

* [Fix]: fix turbomind (#81)

fix turbomind

* add parsers

* skip ActionReturn in postprocessing

* check existence of API name

* add exception catch in action executing

* validate input arguments

* modify returned structure of `get_actions_info`

* adapt tools to the new protocol

* remove `LLMQA` action

---------

Co-authored-by: RangiLyu
Co-authored-by: wangzy
---
 lagent/actions/__init__.py           |   3 +-
 lagent/actions/action_executor.py    |  42 ++++----
 lagent/actions/base_action.py        | 152 +++++++++++++++++++++------
 lagent/actions/builtin_actions.py    |   5 +-
 lagent/actions/google_search.py      |  66 ++++++------
 lagent/actions/llm_qa.py             |  56 ----------
 lagent/actions/parser.py             | 127 ++++++++++++++++++++++
 lagent/actions/python_interpreter.py |  40 +++----
 8 files changed, 322 insertions(+), 169 deletions(-)
 delete mode 100644 lagent/actions/llm_qa.py
 create mode 100644 lagent/actions/parser.py

diff --git a/lagent/actions/__init__.py b/lagent/actions/__init__.py
index e3d4928b..d62bede1 100644
--- a/lagent/actions/__init__.py
+++ b/lagent/actions/__init__.py
@@ -2,10 +2,9 @@
 from .base_action import BaseAction
 from .builtin_actions import FinishAction, InvalidAction, NoAction
 from .google_search import GoogleSearch
-from .llm_qa import LLMQA
 from .python_interpreter import PythonInterpreter

 __all__ = [
     'BaseAction', 'ActionExecutor', 'InvalidAction', 'NoAction',
-    'FinishAction', 'GoogleSearch', 'PythonInterpreter', 'LLMQA'
+    'FinishAction', 'GoogleSearch', 'PythonInterpreter'
 ]

diff --git a/lagent/actions/action_executor.py b/lagent/actions/action_executor.py
index c3186e64..f1870a73 100644
--- a/lagent/actions/action_executor.py
+++ b/lagent/actions/action_executor.py
@@ -1,4 +1,4 @@
-from typing import Any, Dict, List, Union
+from typing import Dict, List, Union

 from lagent.schema import ActionReturn, ActionValidCode
 from .base_action import BaseAction
@@ -39,14 
+39,20 @@ def __init__(self, self.no_action = no_action self.finish_action = finish_action - def get_actions_info(self, only_enable: bool = True) -> Dict: - if only_enable: - return { - k: v.description - for k, v in self.actions.items() if v.enable - } - else: - return {k: v.description for k, v in self.actions.items()} + def get_actions_info(self) -> List[Dict]: + actions = [] + for action_name, action in self.actions.items(): + if not action.enable: + continue + if action.is_toolkit: + for api in action.description['api_list']: + api_desc = api.copy() + api_desc['name'] = f"{action_name}.{api_desc['name']}" + actions.append(api_desc) + else: + action_desc = action.description.copy() + actions.append(action_desc) + return actions def is_valid(self, name: str): return name in self.actions and self.actions[name].enable @@ -66,19 +72,17 @@ def del_action(self, name: str): if name in self.actions: del self.actions[name] - def __call__(self, name: str, command: Any) -> ActionReturn: - if isinstance(command, str): - args, kwargs = (command, ), {} - else: - args, kwargs = (), command - if not self.is_valid(name): + def __call__(self, name: str, command: str) -> ActionReturn: + action_name, api_name = ( + name.split('.') if '.' in name else (name, 'run')) + if not self.is_valid(action_name): if name == self.no_action.name: - action_return = self.no_action.run(*args, **kwargs) + action_return = self.no_action(command) elif name == self.finish_action.name: - action_return = self.finish_action.run(*args, **kwargs) + action_return = self.finish_action(command) else: - action_return = self.invalid_action(*args, **kwargs) + action_return = self.invalid_action(command) else: - action_return = self.actions[name].run(*args, **kwargs) + action_return = self.actions[action_name](command, api_name) action_return.valid = ActionValidCode.OPEN return action_return diff --git a/lagent/actions/base_action.py b/lagent/actions/base_action.py index 34cb933e..c890960c 100644 --- a/lagent/actions/base_action.py +++ b/lagent/actions/base_action.py @@ -1,57 +1,139 @@ -from typing import Optional +from typing import Optional, Type -from lagent.schema import ActionReturn +from lagent.actions.parser import BaseParser, JsonParser, ParseError +from lagent.schema import ActionReturn, ActionStatusCode class BaseAction: """Base class for all actions. Args: - description (str, optional): The description of the action. Defaults to - None. - name (str, optional): The name of the action. If None, the name will - be class name. Defaults to None. - enable (bool, optional): Whether the action is enabled. Defaults to - True. - disable_description (str, optional): The description of the action when - it is disabled. Defaults to None. + description (:class:`Optional[dict]`): The description of the action. + Defaults to ``None``. + parser (:class:`Type[BaseParser]`): The parser class to process the + action's inputs and outputs. Defaults to :class:`JsonParser`. + enable (:class:`bool`): Whether the action is enabled. Defaults to + ``True``. + + Examples: + + * simple tool + + .. code-block:: python + + class Bold(BaseAction): + def run(self, text): + return '**' + text + '**' + + desc = dict( + name='bold', + description='make text bold', + parameters=[dict(name='text', type='STRING', description='input text')], + required=['text'], + ) + action = Bold(desc) + + * toolkit with multiple APIs + + .. 
code-block:: python
+
+            class Calculator(BaseAction):
+                def add(self, a, b):
+                    return a + b
+
+                def sub(self, a, b):
+                    return a - b
+
+            desc = dict(
+                name='calculate',
+                description='perform arithmetic operations',
+                api_list=[
+                    dict(
+                        name='add',
+                        description='addition operation',
+                        parameters=[
+                            dict(name='a', type='NUMBER', description='augend'),
+                            dict(name='b', type='NUMBER', description='addend'),
+                        ],
+                        required=['a', 'b'],
+                    ),
+                    dict(
+                        name='sub',
+                        description='subtraction operation',
+                        parameters=[
+                            dict(name='a', type='NUMBER', description='minuend'),
+                            dict(name='b', type='NUMBER', description='subtrahend'),
+                        ],
+                        required=['a', 'b'],
+                    )
+                ]
+            )
+            action = Calculator(desc)
    """

    def __init__(self,
-                 description: Optional[str] = None,
-                 name: Optional[str] = None,
-                 enable: bool = True,
-                 disable_description: Optional[str] = None) -> None:
-        if name is None:
-            name = self.__class__.__name__
-        self._name = name
-        self._description = description
-        self._disable_description = disable_description
+                 description: Optional[dict] = None,
+                 parser: Type[BaseParser] = JsonParser,
+                 enable: bool = True):
+        self._description = description.copy() if description else {}
+        self._name = self._description.get('name', self.__class__.__name__)
         self._enable = enable
+        self._is_toolkit = 'api_list' in self._description
+        self._parser = parser(self)

-    def __call__(self, *args, **kwargs) -> ActionReturn:
-        raise NotImplementedError
-
-    def __repr__(self):
-        return f'{self.name}:{self.description}'
+    def __call__(self, inputs: str, name='run') -> ActionReturn:
+        fallback_args = {'inputs': inputs, 'name': name}
+        if not hasattr(self, name):
+            return ActionReturn(
+                fallback_args,
+                type=self.name,
+                errmsg=f'invalid API: {name}',
+                state=ActionStatusCode.API_ERROR)
+        try:
+            inputs = self._parser.parse_inputs(inputs, name)
+        except ParseError as exc:
+            return ActionReturn(
+                fallback_args,
+                type=self.name,
+                errmsg=exc.err_msg,
+                state=ActionStatusCode.ARGS_ERROR)
+        try:
+            outputs = getattr(self, name)(**inputs)
+        except Exception as exc:
+            return ActionReturn(
+                inputs,
+                type=self.name,
+                errmsg=str(exc),
+                state=ActionStatusCode.API_ERROR)
+        if isinstance(outputs, ActionReturn):
+            action_return = outputs
+            if not action_return.args:
+                action_return.args = inputs
+        else:
+            result = self._parser.parse_outputs(outputs)
+            action_return = ActionReturn(inputs, type=self.name, result=result)
+        return action_return

-    def __str__(self):
-        return self.__repr__()
+    def run(self):
+        raise NotImplementedError

-    def run(self, *args, **kwargs) -> ActionReturn:
-        return self.__call__(*args, **kwargs)
+    @property
+    def name(self):
+        return self._name

     @property
     def enable(self):
         return self._enable

     @property
-    def name(self):
-        return self._name
+    def description(self):
+        return self._description

     @property
-    def description(self):
-        if self.enable:
-            return self._description
-        else:
-            return self._disable_description
+    def is_toolkit(self):
+        return self._is_toolkit
+
+    def __repr__(self):
+        return f'{self.description}'
+
+    __str__ = __repr__

diff --git a/lagent/actions/builtin_actions.py b/lagent/actions/builtin_actions.py
index 20f5c2a7..76aea6c1 100644
--- a/lagent/actions/builtin_actions.py
+++ b/lagent/actions/builtin_actions.py
@@ -35,7 +35,7 @@ def __call__(self, err_msg: Optional[str] = None):
         action_return = ActionReturn(
             url=None,
             args=dict(text=err_msg),
-            errmsg=err_msg if err_msg else self._err_msg,
+            errmsg=err_msg or self._err_msg,
             type=self.name,
             valid=ActionValidCode.INVALID,
state=ActionStatusCode.API_ERROR) @@ -52,7 +52,6 @@ class NoAction(BaseAction): """ def __init__(self, err_msg: str = 'Please follow the format', **kwargs): - super().__init__(enable=False, **kwargs) self._err_msg = err_msg @@ -71,7 +70,7 @@ def __call__(self, err_msg: Optional[str] = None): url=None, args=dict(text=err_msg), type=self.name, - errmsg=err_msg if err_msg else self._err_msg, + errmsg=err_msg or self._err_msg, valid=ActionValidCode.INVALID, state=ActionStatusCode.API_ERROR) return action_return diff --git a/lagent/actions/google_search.py b/lagent/actions/google_search.py index a907d4cf..0ed08082 100644 --- a/lagent/actions/google_search.py +++ b/lagent/actions/google_search.py @@ -1,15 +1,25 @@ import os -from typing import List, Optional, Tuple, Union +from typing import List, Optional, Tuple, Type, Union import requests from lagent.schema import ActionReturn, ActionStatusCode from .base_action import BaseAction - -DEFAULT_DESCRIPTION = """一个可以从谷歌搜索结果的API。 -当你需要对于一个特定问题找到简短明了的回答时,可以使用它。 -输入应该是一个搜索查询。 -""" +from .parser import BaseParser, JsonParser + +DEFAULT_DESCRIPTION = dict( + name='GoogleSearch', + description='一个可以从谷歌搜索结果的API。当你需要对于一个特定问题找到简短明了的回答时,可以使用它。输入应该是一个搜索查询。', + parameters=[ + dict(name='query', type='STRING', description='the search content'), + dict( + name='k', + type='NUMBER', + description= + 'select first k results in the search results as response'), + ], + required=['query'], +) class GoogleSearch(BaseAction): @@ -28,15 +38,12 @@ class GoogleSearch(BaseAction): timeout (int): Upper bound of waiting time for a serper request. search_type (str): Serper API support ['search', 'images', 'news', 'places'] types of search, currently we only support 'search'. - k (int): select first k results in the search results as response. - description (str): The description of the action. Defaults to - None. - name (str, optional): The name of the action. If None, the name will - be class name. Defaults to None. + description (dict): The description of the action. Defaults to + :py:data:`~DEFAULT_DESCRIPTION`. + parser (Type[BaseParser]): The parser class to process the + action's inputs and outputs. Defaults to :class:`JsonParser`. enable (bool, optional): Whether the action is enabled. Defaults to True. - disable_description (str, optional): The description of the action when - it is disabled. Defaults to None. """ result_key_for_type = { 'news': 'news', @@ -49,36 +56,23 @@ def __init__(self, api_key: Optional[str] = None, timeout: int = 5, search_type: str = 'search', - k: int = 10, - description: str = DEFAULT_DESCRIPTION, - name: Optional[str] = None, - enable: bool = True, - disable_description: Optional[str] = None) -> None: - super().__init__(description, name, enable, disable_description) - + description: Optional[dict] = None, + parser: Type[BaseParser] = JsonParser, + enable: bool = True): + super().__init__(description or DEFAULT_DESCRIPTION, parser, enable) api_key = os.environ.get('SERPER_API_KEY', api_key) if api_key is None: raise ValueError( 'Please set Serper API key either in the environment ' - ' as SERPER_API_KEY or pass it as `api_key` parameter.') + 'as SERPER_API_KEY or pass it as `api_key` parameter.') self.api_key = api_key self.timeout = timeout self.search_type = search_type - self.k = k - - def __call__(self, query: str) -> ActionReturn: - """Return the search response. - - Args: - query (str): The search content. - - Returns: - ActionReturn: The action return. 
- """ + def run(self, query: str, k: int = 10) -> ActionReturn: + """Return the search response.""" tool_return = ActionReturn(url=None, args=None, type=self.name) - status_code, response = self._search( - query, search_type=self.search_type, k=self.k) + status_code, response = self._search(query, k=k) # convert search results to ToolReturn format if status_code == -1: tool_return.errmsg = response @@ -139,7 +133,7 @@ def _parse_results(self, results: dict) -> Union[str, List[str]]: def _search(self, search_term: str, - search_type: str = 'search', + search_type: Optional[str] = None, **kwargs) -> Tuple[int, Union[dict, str]]: """HTTP requests to Serper API. @@ -166,7 +160,7 @@ def _search(self, } try: response = requests.post( - f'https://google.serper.dev/{search_type}', + f'https://google.serper.dev/{search_type or self.search_type}', headers=headers, params=params, timeout=self.timeout) diff --git a/lagent/actions/llm_qa.py b/lagent/actions/llm_qa.py deleted file mode 100644 index adaffa29..00000000 --- a/lagent/actions/llm_qa.py +++ /dev/null @@ -1,56 +0,0 @@ -from typing import Optional, Union - -from lagent.llms.base_api import BaseAPIModel -from lagent.llms.base_llm import BaseModel -from lagent.schema import ActionReturn, ActionStatusCode -from .base_action import BaseAction - -DEFAULT_DESCRIPTION = """一个像你一样的大语言预训练模型,当你需要获得一些常识或简单世界知识时可以问它。 -当你很有把握自己直接解决问题时可以优先使用它。输入应该是一个询问语句, 且每个问题尽可能简单。 -""" - - -class LLMQA(BaseAction): - """An LLM Wrapper as BaseAction type. - - Args: - llm (BaseModel or BaseAPIModel): a LLM service which can chat. - description (str): The description of the action. Defaults to - None. - name (str, optional): The name of the action. If None, the name will - be class name. Defaults to None. - enable (bool, optional): Whether the action is enabled. Defaults to - True. - disable_description (str, optional): The description of the action when - it is disabled. Defaults to None. - """ - - def __init__(self, - llm: Union[BaseModel, BaseAPIModel], - description: str = DEFAULT_DESCRIPTION, - name: Optional[str] = None, - enable: bool = True, - disable_description: Optional[str] = None) -> None: - super().__init__(description, name, enable, disable_description) - - self._llm = llm - - def __call__(self, query: str) -> ActionReturn: - """Return the QA response. - - Args: - query (str): The query content. - - Returns: - ActionReturn: The action return. - """ - - tool_return = ActionReturn(url=None, args=None) - try: - response = self._llm.generate_from_template(query, 512) - tool_return.result = dict(text=str(response)) - tool_return.state = ActionStatusCode.SUCCESS - except Exception as e: - tool_return.result = dict(text=str(e)) - tool_return.state = ActionStatusCode.API_ERROR - return tool_return diff --git a/lagent/actions/parser.py b/lagent/actions/parser.py new file mode 100644 index 00000000..e9cdd4a4 --- /dev/null +++ b/lagent/actions/parser.py @@ -0,0 +1,127 @@ +import json +from ast import literal_eval +from typing import Any + + +class ParseError(Exception): + """Parsing exception class""" + + def __init__(self, err_msg: str): + self.err_msg = err_msg + + +class BaseParser: + """Base parser to process inputs and outputs of actions. + + Args: + action (:class:`BaseAction`): action to validate + + Attributes: + PARAMETER_DESCRIPTION (:class:`str`): declare the input format which + LLMs should follow when generating arguments for decided tools. 
+ """ + + PARAMETER_DESCRIPTION: str = '' + + def __init__(self, action): + self.action = action + self._api2param = {} + self._api2required = {} + # perform basic argument validation + if action.description: + for api in action.description.get('api_list', + [action.description]): + name = (f'{action.name}.{api["name"]}' + if self.action.is_toolkit else api['name']) + required_parameters = set(api['required']) + all_parameters = {j['name'] for j in api['parameters']} + if not required_parameters.issubset(all_parameters): + raise ValueError( + f'unknown parameters for function "{name}": ' + f'{required_parameters - all_parameters}') + if self.PARAMETER_DESCRIPTION: + api['parameter_description'] = self.PARAMETER_DESCRIPTION + api_name = api['name'] if self.action.is_toolkit else 'run' + self._api2param[api_name] = api['parameters'] + self._api2required[api_name] = api['required'] + + def parse_inputs(self, inputs: str, name: str = 'run') -> dict: + """parse inputs LLMs generate for the action + + Args: + inputs (:class:`str`): input string extracted from responses + + Returns: + :class:`dict`: processed input + """ + inputs = {self._api2param[name][0]['name']: inputs} + return inputs + + def parse_outputs(self, outputs: Any) -> dict: + """parser outputs returned by the action + + Args: + outputs (:class:`Any`): raw output of the action + + Returns: + :class:`dict`: processed output + """ + if isinstance(outputs, dict): + outputs = json.dumps(outputs, ensure_ascii=False) + elif not isinstance(outputs, str): + outputs = str(outputs) + return {'text': outputs} + + +class JsonParser(BaseParser): + """Json parser to convert input string into a dictionary. + + Args: + action (:class:`BaseAction`): action to validate + """ + + PARAMETER_DESCRIPTION = '如果调用该工具,你必须使用 dict(key=value) 格式传参,其中key为参数名称' + + def parse_inputs(self, inputs: str, name: str = 'run') -> dict: + try: + inputs = json.loads(inputs) + except json.JSONDecodeError as exc: + raise ParseError(f'invalid json format: {inputs}') from exc + input_keys = set(inputs) + all_keys = {param['name'] for param in self._api2param[name]} + if not input_keys.issubset(all_keys): + raise ParseError(f'unknown arguments: {input_keys - all_keys}') + required_keys = set(self._api2required[name]) + if not input_keys.issuperset(required_keys): + raise ParseError( + f'missing required arguments: {required_keys - input_keys}') + return inputs + + +class TupleParser(BaseParser): + """Tuple parser to convert input string into a tuple. 
+ + Args: + action (:class:`BaseAction`): action to validate + """ + + PARAMETER_DESCRIPTION = '如果调用该工具,你必须使用 (arg1, arg2, arg3) 格式传参,且参数是有序的' + + def parse_inputs(self, inputs: str, name: str = 'run') -> dict: + try: + inputs = literal_eval(inputs) + except Exception as exc: + raise ParseError(f'invalid tuple format: {inputs}') from exc + if len(inputs) < len(self._api2required[name]): + raise ParseError( + f'API takes {len(self._api2required[name])} required positional ' + f'arguments but {len(inputs)} were given') + if len(inputs) > len(self._api2param[name]): + raise ParseError( + f'API takes {len(self._api2param[name])} positional arguments ' + f'but {len(inputs)} were given') + inputs = { + self._api2param[name][i]['name']: item + for i, item in enumerate(inputs) + } + return inputs diff --git a/lagent/actions/python_interpreter.py b/lagent/actions/python_interpreter.py index e10102ba..823d91a5 100644 --- a/lagent/actions/python_interpreter.py +++ b/lagent/actions/python_interpreter.py @@ -1,11 +1,12 @@ import copy import io from contextlib import redirect_stdout -from typing import Any, Optional +from typing import Any, Optional, Type from func_timeout import FunctionTimedOut, func_set_timeout from lagent.actions.base_action import BaseAction +from lagent.actions.parser import BaseParser, JsonParser from lagent.schema import ActionReturn, ActionStatusCode @@ -29,7 +30,9 @@ def eval_code(self, expr: str) -> Any: return eval(expr, self._global_vars) -DEFAULT_DESCRIPTION = """用来执行Python代码。代码必须是一个函数, +DEFAULT_DESCRIPTION = dict( + name='PythonInterpreter', + description="""用来执行Python代码。代码必须是一个函数, 函数名必须得是 'solution',代码对应你的思考过程。代码实例格式如下: ```python # import 依赖包 @@ -44,46 +47,47 @@ def solution(): # 最后结果 final_answer = func(mid_variable) return final_answer -```""" +```""", + parameters=[ + dict(name='command', type='STRING', description='Python code snippet') + ], + required=['command'], +) class PythonInterpreter(BaseAction): """A Python executor that can execute Python scripts. Args: - description (str): The description of the action. Defaults to - DEFAULT_DESCRIPTION. answer_symbol (str, Optional): the answer symbol from LLM answer_expr (str, Optional): the answer function name of the Python script. Default to 'solution()'. answer_from_stdout (boolean): whether the execution results is from stdout. - name (str, optional): The name of the action. If None, the name will - be class nameDefaults to None. + timeout (int): Upper bound of waiting time for Python script execution. + description (dict): The description of the action. Defaults to + :py:data:`~DEFAULT_DESCRIPTION`. + parser (Type[BaseParser]): The parser class to process the + action's inputs and outputs. Defaults to :class:`JsonParser`. enable (bool, optional): Whether the action is enabled. Defaults to True. - disable_description (str, optional): The description of the action when - it is disabled. Defaults to None. - timeout (int): Upper bound of waiting time for Python script execution. 
""" def __init__(self, - description: str = DEFAULT_DESCRIPTION, answer_symbol: Optional[str] = None, answer_expr: Optional[str] = 'solution()', answer_from_stdout: bool = False, - name: Optional[str] = None, - enable: bool = True, - disable_description: Optional[str] = None, - timeout: int = 20) -> None: - super().__init__(description, name, enable, disable_description) - + timeout: int = 20, + description: Optional[dict] = None, + parser: Type[BaseParser] = JsonParser, + enable: bool = True) -> None: + super().__init__(description or DEFAULT_DESCRIPTION, parser, enable) self.answer_symbol = answer_symbol self.answer_expr = answer_expr self.answer_from_stdout = answer_from_stdout self.timeout = timeout - def __call__(self, command: str) -> ActionReturn: + def run(self, command: str) -> ActionReturn: self.runtime = GenericRuntime() try: tool_return = func_set_timeout(self.timeout)(self._call)(command) From c42c884900eea7bd52a0332455586c8d04a76eb2 Mon Sep 17 00:00:00 2001 From: BraisedPork <46232992+braisedpork1964@users.noreply.github.com> Date: Mon, 22 Jan 2024 15:18:25 +0800 Subject: [PATCH 03/20] [Feature] add tools (#89) * add new tools * update PPT * chores * update action module init --------- Co-authored-by: wangzy --- lagent/actions/__init__.py | 11 +- lagent/actions/arxiv_search.py | 70 ++++ lagent/actions/base_action.py | 2 + lagent/actions/bing_map.py | 185 ++++++++++ lagent/actions/google_scholar_search.py | 459 ++++++++++++++++++++++++ lagent/actions/google_search.py | 2 +- lagent/actions/parser.py | 4 +- lagent/actions/ppt.py | 211 +++++++++++ lagent/actions/python_interpreter.py | 4 +- lagent/agents/react.py | 2 +- lagent/schema.py | 4 +- 11 files changed, 944 insertions(+), 10 deletions(-) create mode 100644 lagent/actions/arxiv_search.py create mode 100644 lagent/actions/bing_map.py create mode 100644 lagent/actions/google_scholar_search.py create mode 100644 lagent/actions/ppt.py diff --git a/lagent/actions/__init__.py b/lagent/actions/__init__.py index d62bede1..7146cfbf 100644 --- a/lagent/actions/__init__.py +++ b/lagent/actions/__init__.py @@ -1,10 +1,17 @@ from .action_executor import ActionExecutor +from .arxiv_search import ArxivSearch from .base_action import BaseAction +from .bing_map import BINGMap from .builtin_actions import FinishAction, InvalidAction, NoAction +from .google_scholar_search import GoogleScholar from .google_search import GoogleSearch +from .parser import BaseParser, JsonParser, TupleParser +from .ppt import PPT from .python_interpreter import PythonInterpreter __all__ = [ - 'BaseAction', 'ActionExecutor', 'InvalidAction', 'NoAction', - 'FinishAction', 'GoogleSearch', 'PythonInterpreter' + 'BaseAction', 'ActionExecutor', 'InvalidAction', 'FinishAction', + 'NoAction', 'BINGMap', 'ArxivSearch', 'FinishAction', 'GoogleSearch', + 'GoogleScholar', 'PythonInterpreter', 'PPT', 'BaseParser', 'JsonParser', + 'TupleParser' ] diff --git a/lagent/actions/arxiv_search.py b/lagent/actions/arxiv_search.py new file mode 100644 index 00000000..c2d5c9d6 --- /dev/null +++ b/lagent/actions/arxiv_search.py @@ -0,0 +1,70 @@ +from typing import Optional, Type + +import arxiv + +from lagent.actions.base_action import BaseAction +from lagent.actions.parser import BaseParser, JsonParser +from lagent.schema import ActionReturn, ActionStatusCode + +DEFAULT_DESCRIPTION = dict( + name='ArxivSearch', + description='Search information from Arxiv.org ' + 'Useful for when you need to answer questions about Physics, Mathematics, ' + 'Computer Science, Quantitative Biology, 
From c42c884900eea7bd52a0332455586c8d04a76eb2 Mon Sep 17 00:00:00 2001
From: BraisedPork <46232992+braisedpork1964@users.noreply.github.com>
Date: Mon, 22 Jan 2024 15:18:25 +0800
Subject: [PATCH 03/20] [Feature] add tools (#89)

* add new tools

* update PPT

* chores

* update action module init

---------

Co-authored-by: wangzy
---
 lagent/actions/__init__.py              |  11 +-
 lagent/actions/arxiv_search.py          |  70 ++++
 lagent/actions/base_action.py           |   2 +
 lagent/actions/bing_map.py              | 185 ++++++++++
 lagent/actions/google_scholar_search.py | 459 ++++++++++++++++++++++++
 lagent/actions/google_search.py         |   2 +-
 lagent/actions/parser.py                |   4 +-
 lagent/actions/ppt.py                   | 211 +++++++++++
 lagent/actions/python_interpreter.py    |   4 +-
 lagent/agents/react.py                  |   2 +-
 lagent/schema.py                        |   4 +-
 11 files changed, 944 insertions(+), 10 deletions(-)
 create mode 100644 lagent/actions/arxiv_search.py
 create mode 100644 lagent/actions/bing_map.py
 create mode 100644 lagent/actions/google_scholar_search.py
 create mode 100644 lagent/actions/ppt.py

diff --git a/lagent/actions/__init__.py b/lagent/actions/__init__.py
index d62bede1..7146cfbf 100644
--- a/lagent/actions/__init__.py
+++ b/lagent/actions/__init__.py
@@ -1,10 +1,17 @@
 from .action_executor import ActionExecutor
+from .arxiv_search import ArxivSearch
 from .base_action import BaseAction
+from .bing_map import BINGMap
 from .builtin_actions import FinishAction, InvalidAction, NoAction
+from .google_scholar_search import GoogleScholar
 from .google_search import GoogleSearch
+from .parser import BaseParser, JsonParser, TupleParser
+from .ppt import PPT
 from .python_interpreter import PythonInterpreter

 __all__ = [
-    'BaseAction', 'ActionExecutor', 'InvalidAction', 'NoAction',
-    'FinishAction', 'GoogleSearch', 'PythonInterpreter'
+    'BaseAction', 'ActionExecutor', 'InvalidAction', 'FinishAction',
+    'NoAction', 'BINGMap', 'ArxivSearch', 'GoogleSearch',
+    'GoogleScholar', 'PythonInterpreter', 'PPT', 'BaseParser', 'JsonParser',
+    'TupleParser'
 ]

diff --git a/lagent/actions/arxiv_search.py b/lagent/actions/arxiv_search.py
new file mode 100644
index 00000000..c2d5c9d6
--- /dev/null
+++ b/lagent/actions/arxiv_search.py
@@ -0,0 +1,70 @@
+from typing import Optional, Type
+
+import arxiv
+
+from lagent.actions.base_action import BaseAction
+from lagent.actions.parser import BaseParser, JsonParser
+from lagent.schema import ActionReturn, ActionStatusCode
+
+DEFAULT_DESCRIPTION = dict(
+    name='ArxivSearch',
+    description='Search information from Arxiv.org. '
+    'Useful when you need to answer questions about Physics, Mathematics, '
+    'Computer Science, Quantitative Biology, Quantitative Finance, '
+    'Statistics, Electrical Engineering, and Economics '
+    'from scientific articles on arxiv.org',
+    api_list=[
+        dict(
+            name='get_arxiv_article_information',
+            description='Run Arxiv search and get the article meta information.',
+            parameters=[
+                dict(
+                    name='query',
+                    type='STRING',
+                    description='the content of search query')
+            ],
+            required=['query'],
+            return_data=[
+                dict(
+                    name='content',
+                    description='a list of 3 arxiv search papers'),
+            ],
+        )
+    ],
+)


+class ArxivSearch(BaseAction):
+    """ArxivSearch action"""
+
+    def __init__(self,
+                 top_k_results: int = 3,
+                 max_query_len: int = 300,
+                 doc_content_chars_max: int = 1500,
+                 description: Optional[dict] = None,
+                 parser: Type[BaseParser] = JsonParser,
+                 enable: bool = True) -> None:
+        super().__init__(description or DEFAULT_DESCRIPTION, parser, enable)
+        self.top_k_results = top_k_results
+        self.max_query_len = max_query_len
+        self.doc_content_chars_max = doc_content_chars_max
+
+    def get_arxiv_article_information(self, query: str):
+        try:
+            results = arxiv.Search(  # type: ignore
+                query[:self.max_query_len],
+                max_results=self.top_k_results).results()
+        except Exception as exc:
+            return ActionReturn(
+                errmsg=f'Arxiv exception: {exc}',
+                state=ActionStatusCode.HTTP_ERROR)
+        docs = [
+            f'Published: {result.updated.date()}\nTitle: {result.title}\n'
+            f'Authors: {", ".join(a.name for a in result.authors)}\n'
+            f'Summary: {result.summary[:self.doc_content_chars_max]}'
+            for result in results
+        ]
+        if docs:
+            return {'content': '\n\n'.join(docs)}
+        return {'content': 'No good Arxiv Result was found'}

diff --git a/lagent/actions/base_action.py b/lagent/actions/base_action.py
index c890960c..7f9bb1d6 100644
--- a/lagent/actions/base_action.py
+++ b/lagent/actions/base_action.py
@@ -109,6 +109,8 @@ def __call__(self, inputs: str, name='run') -> ActionReturn:
             action_return = outputs
             if not action_return.args:
                 action_return.args = inputs
+            if not action_return.type:
+                action_return.type = self.name
         else:
             result = self._parser.parse_outputs(outputs)
             action_return = ActionReturn(inputs, type=self.name, result=result)

diff --git a/lagent/actions/bing_map.py b/lagent/actions/bing_map.py
new file mode 100644
index 00000000..7ebaff2a
--- /dev/null
+++ b/lagent/actions/bing_map.py
@@ -0,0 +1,185 @@
+import json
+import os
+from typing import Optional, Type
+
+import requests
+
+from lagent.actions.base_action import BaseAction
+from lagent.actions.parser import BaseParser, JsonParser
+
+DEFAULT_DESCRIPTION = dict(
+    name='BINGMap',
+    description='Plugin for looking up map information',
+    api_list=[
+        dict(
+            name='get_distance',
+            description='Get the distance between two locations in km.',
+            parameters=[
+                dict(
+                    name='start',
+                    type='STRING',
+                    description='The start location.'),
+                dict(
+                    name='end', type='STRING', description='The end location.')
+            ],
+            required=['start', 'end'],
+            return_data=[
+                dict(name='distance', description='the distance in km.')
+            ]),
+        dict(
+            name='get_route',
+            description='Get the driving route between two locations.',
+            parameters=[
+                dict(
+                    name='start',
+                    type='STRING',
+                    description='The start location.'),
+                dict(
+                    name='end', type='STRING', description='The end location.')
+            ],
+            required=['start', 'end'],
+            return_data=[
+                dict(
+                    name='route', description='the route, a list of actions.')
+            ]),
+        dict(
+            name='get_coordinates',
+            description='Get the coordinates of a location.',
+            parameters=[
+                dict(
+                    name='location',
+                    type='STRING',
+                    description='the location to get 
coordinates.') + ], + required=['location'], + return_data=[ + dict( + name='latitude', + description='the latitude of the location.'), + dict( + name='longitude', + description='the longitude of the location.') + ]), + dict( + name='search_nearby', + description= + 'Search for places nearby a location, within a given radius, and return the results into a list. You can use either the places name or the latitude and longitude.', + parameters=[ + dict( + name='search_term', + type='STRING', + description='the place name'), + dict( + name='places', + type='STRING', + description='the name of the location.'), + dict( + name='latitude', + type='FLOAT', + description='the latitude of the location.'), + dict( + name='longitude', + type='FLOAT', + description='the longitude of the location.'), + dict( + name='radius', + type='NUMBER', + description='radius in meters.') + ], + required=['search_term'], + return_data=[ + dict( + name='places', + description= + 'the list of places, each place is a dict with name and address, at most 5 places.' + ) + ]), + ]) + + +class BINGMap(BaseAction): + """BING Map plugin for looking up map information""" + + def __init__(self, + key: Optional[str] = None, + description: Optional[dict] = None, + parser: Type[BaseParser] = JsonParser, + enable: bool = True) -> None: + super().__init__(description or DEFAULT_DESCRIPTION, parser, enable) + key = os.environ.get('BING_MAP_KEY') + if key is None: + raise ValueError( + 'Please set BING Map API key either in the environment ' + 'as BING_MAP_KEY or pass it as `key` parameter.') + self.key = key + self.base_url = 'http://dev.virtualearth.net/REST/V1/' + + def get_distance(self, start: str, end: str) -> dict: + # Request URL + url = self.base_url + 'Routes/Driving?o=json&wp.0=' + start + '&wp.1=' + end + '&key=' + self.key + # GET request + r = requests.get(url) + # TODO check request status? 
+        data = json.loads(r.text)
+        # Extract route information
+        route = data['resourceSets'][0]['resources'][0]
+        # Extract distance in km
+        distance = route['travelDistance']
+        return dict(distance=distance)
+
+    def get_route(self, start: str, end: str) -> dict:
+        # Request URL
+        url = self.base_url + 'Routes/Driving?o=json&wp.0=' + start + '&wp.1=' + end + '&key=' + self.key
+        # GET request
+        r = requests.get(url)
+        data = json.loads(r.text)
+        # Extract route information
+        route = data['resourceSets'][0]['resources'][0]
+        itinerary = route['routeLegs'][0]['itineraryItems']
+        # Extract route text information
+        route_text = []
+        for item in itinerary:
+            if 'instruction' in item:
+                route_text.append(item['instruction']['text'])
+        return dict(route=route_text)
+
+    def get_coordinates(self, location: str) -> dict:
+        url = self.base_url + 'Locations'
+        params = {'query': location, 'key': self.key}
+        response = requests.get(url, params=params)
+        json_data = response.json()
+        coordinates = json_data['resourceSets'][0]['resources'][0]['point'][
+            'coordinates']
+        return dict(latitude=coordinates[0], longitude=coordinates[1])
+
+    def search_nearby(self,
+                      search_term: str,
+                      places: str = 'unknown',
+                      latitude: float = 0.0,
+                      longitude: float = 0.0,
+                      radius: int = 5000) -> dict:  # radius in meters
+        url = self.base_url + 'LocalSearch'
+        if places != 'unknown':
+            pos = self.get_coordinates(**{'location': places})
+            latitude, longitude = pos['latitude'], pos['longitude']
+        # Build the request query string
+        params = {
+            'query': search_term,
+            'userLocation': f'{latitude},{longitude}',
+            'radius': radius,
+            'key': self.key
+        }
+        # Make the request
+        response = requests.get(url, params=params)
+        # Parse the response
+        response_data = json.loads(response.content)
+        # Get the results
+        results = response_data['resourceSets'][0]['resources']
+        addresses = []
+        for result in results:
+            name = result['name']
+            address = result['Address']['formattedAddress']
+            addresses.append(dict(name=name, address=address))
+            if len(addresses) == 5:
+                break
+        return dict(places=addresses)

diff --git a/lagent/actions/google_scholar_search.py b/lagent/actions/google_scholar_search.py
new file mode 100644
index 00000000..941870e6
--- /dev/null
+++ b/lagent/actions/google_scholar_search.py
@@ -0,0 +1,459 @@
+import os
+from typing import Optional, Type
+
+from serpapi import GoogleSearch
+
+from lagent.actions.base_action import BaseAction
+from lagent.schema import ActionReturn, ActionStatusCode
+from .parser import BaseParser, JsonParser
+
+DEFAULT_DESCRIPTION = dict(
+    name='GoogleScholar',
+    description='Plugin for google scholar search',
+    api_list=[{
+        'name': 'search_google_scholar',
+        'description':
+        'Search for scholarly articles on Google Scholar based on a query',
+        'parameters': [
+            {
+                'name': 'query',
+                'description': 'The query to search for.',
+                'type': 'STRING'
+            },
+            {
+                'name': 'cites',
+                'description':
+                'The unique ID of an article for triggering "Cited By" searches',
+                'type': 'STRING'
+            },
+            {
+                'name': 'as_ylo',
+                'description':
+                'The starting year for results (e.g., if as_ylo=2018, results before this year will be omitted)',
+                'type': 'NUMBER'
+            },
+            {
+                'name': 'as_yhi',
+                'description':
+                'The ending year for results (e.g., if as_yhi=2018, results after this year will be omitted)',
+                'type': 'NUMBER'
+            },
+            {
+                'name': 'scisbd',
+                'description':
+                'Defines articles added in the last year, sorted by date. 
It can be set to 1 to include only abstracts, or 2 to include everything', + 'type': 'NUMBER' + }, + { + 'name': 'cluster', + 'description': + 'The unique ID of an article for triggering "All Versions" searches', + 'type': 'STRING' + }, + { + 'name': 'hl', + 'description': + 'The language to use for the Google Scholar search', + 'type': 'STRING' + }, + { + 'name': 'lr', + 'description': + 'One or multiple languages to limit the search to', + 'type': 'STRING' + }, + { + 'name': 'start', + 'description': + 'The result offset for pagination (0 is the first page of results, 10 is the 2nd page, etc.)', + 'type': 'NUMBER' + }, + { + 'name': 'num', + 'description': + 'The maximum number of results to return, limited to 20', + 'type': 'NUMBER' + }, + { + 'name': 'as_sdt', + 'description': + 'Can be used either as a search type or a filter', + 'type': 'STRING' + }, + { + 'name': 'safe', + 'description': 'The level of filtering for adult content', + 'type': 'STRING' + }, + { + 'name': 'filter', + 'description': + "Defines if the filters for 'Similar Results' and 'Omitted Results' are on or off", + 'type': 'STRING' + }, + { + 'name': 'as_vis', + 'description': 'Defines whether to include citations or not', + 'type': 'STRING' + }, + ], + 'required': ['query'], + 'return_data': [{ + 'name': + 'title', + 'description': + 'a list of the titles of the three selected papers' + }, { + 'name': + 'cited_by', + 'description': + 'a list of the citation numbers of the three selected papers' + }, { + 'name': + 'organic_id', + 'description': + 'a list of the organic results\' ids of the three selected papers' + }, { + 'name': 'snippets', + 'description': 'snippets of the papers' + }, { + 'name': + 'pub_info', + 'description': + 'publication information of selected papers' + }] + }, { + 'name': + 'get_author_information', + 'description': + 'Search for an author\'s information by author\'s id provided by get_author_id.', + 'parameters': [{ + 'name': 'author_id', + 'description': 'Required. The ID of an author.', + 'type': 'STRING' + }, { + 'name': 'hl', + 'description': + "The language to use for the Google Scholar Author search. Default is 'en'.", + 'type': 'STRING' + }, { + 'name': 'view_op', + 'description': 'Used for viewing specific parts of a page.', + 'type': 'STRING' + }, { + 'name': 'sort', + 'description': 'Used for sorting and refining articles.', + 'type': 'STRING' + }, { + 'name': 'citation_id', + 'description': 'Used for retrieving individual article citation.', + 'type': 'STRING' + }, { + 'name': 'start', + 'description': 'Defines the result offset. Default is 0.', + 'type': 'NUMBER' + }, { + 'name': 'num', + 'description': + 'Defines the number of results to return. Default is 20.', + 'type': 'NUMBER' + }, { + 'name': 'no_cache', + 'description': + 'Forces SerpApi to fetch the results even if a cached version is already present. Default is False.', + 'type': 'BOOLEAN' + }, { + 'name': 'async_req', + 'description': + 'Defines the way you want to submit your search to SerpApi. Default is False.', + 'type': 'BOOLEAN' + }, { + 'name': 'output', + 'description': + "Defines the final output you want. 
Default is 'json'.",
+            'type': 'STRING'
+        }],
+        'required': ['author_id'],
+        'return_data': [{
+            'name': 'name',
+            'description': "author's name"
+        }, {
+            'name': 'affiliation',
+            'description': 'the affiliation of the author'
+        }, {
+            'name': 'articles',
+            'description': 'at most 3 articles by the author'
+        }, {
+            'name': 'website',
+            'description': "the author's homepage url"
+        }]
+    }, {
+        'name': 'get_citation_format',
+        'description':
+        'Function to get the MLA citation format of an article, identified by the organic result\'s id provided by search_google_scholar.',
+        'parameters': [{
+            'name': 'q',
+            'description':
+            'ID of an individual Google Scholar organic search result.',
+            'type': 'STRING'
+        }, {
+            'name': 'no_cache',
+            'description':
+            'If set to True, will force SerpApi to fetch the Google Scholar Cite results even if a cached version is already present. Defaults to None.',
+            'type': 'BOOLEAN'
+        }, {
+            'name': 'async_',
+            'description':
+            'If set to True, will submit search to SerpApi and retrieve results later. Defaults to None.',
+            'type': 'BOOLEAN'
+        }, {
+            'name': 'output',
+            'description':
+            "Final output format. Set to 'json' to get a structured JSON of the results, or 'html' to get the raw html retrieved. Defaults to 'json'.",
+            'type': 'STRING'
+        }],
+        'required': ['q'],
+        'return_data': [{
+            'name': 'authors',
+            'description': 'the authors of the article'
+        }, {
+            'name': 'citation',
+            'description': 'the citation format of the article'
+        }]
+    }, {
+        'name': 'get_author_id',
+        'description':
+        'The get_author_id function is used to get the author\'s id by his or her name.',
+        'parameters': [{
+            'name': 'mauthors',
+            'description': 'Defines the author you want to search for.',
+            'type': 'STRING'
+        }, {
+            'name': 'hl',
+            'description':
+            "Defines the language to use for the Google Scholar Profiles search. It's a two-letter language code. (e.g., 'en' for English, 'es' for Spanish, or 'fr' for French). Defaults to 'en'.",
+            'type': 'STRING'
+        }, {
+            'name': 'after_author',
+            'description':
+            'Defines the next page token. It is used for retrieving the next page results. The parameter has the precedence over before_author parameter. Defaults to None.',
+            'type': 'STRING'
+        }, {
+            'name': 'before_author',
+            'description':
+            'Defines the previous page token. It is used for retrieving the previous page results. Defaults to None.',
+            'type': 'STRING'
+        }, {
+            'name': 'no_cache',
+            'description':
+            'Will force SerpApi to fetch the Google Scholar Profiles results even if a cached version is already present. Defaults to False.',
+            'type': 'BOOLEAN'
+        }, {
+            'name': '_async',
+            'description':
+            'Defines the way you want to submit your search to SerpApi. Defaults to False.',
+            'type': 'BOOLEAN'
+        }, {
+            'name': 'output',
+            'description':
+            "Defines the final output you want. It can be set to 'json' (default) to get a structured JSON of the results, or 'html' to get the raw html retrieved. Defaults to 'json'.",
+            'type': 'STRING'
+        }],
+        'required': ['mauthors'],
+        'return_data': [{
+            'name': 'author_id',
+            'description': 'the author_id of the author'
+        }],
+    }])


+class GoogleScholar(BaseAction):
+    """Wrapper around the Serper.dev Google Search API.
+
+    To use, you should pass your serper API key to the constructor. 
+ + Code is modified from lang-chain GoogleSerperAPIWrapper + (https://github.com/langchain-ai/langchain/blob/ba5f + baba704a2d729a4b8f568ed70d7c53e799bb/libs/langchain/ + langchain/utilities/google_serper.py) + + Args: + api_key (str): API KEY to use serper google search API, + You can create a free API key at https://serper.dev. + description (dict): The description of the action. Defaults to + :py:data:`~DEFAULT_DESCRIPTION`. + parser (Type[BaseParser]): The parser class to process the + action's inputs and outputs. Defaults to :class:`JsonParser`. + enable (bool, optional): Whether the action is enabled. Defaults to + True. + """ + + def __init__(self, + api_key: Optional[str] = None, + description: Optional[dict] = None, + parser: Type[BaseParser] = JsonParser, + enable: bool = True): + super().__init__(description or DEFAULT_DESCRIPTION, parser, enable) + api_key = os.environ.get('SERPER_API_KEY', api_key) + if api_key is None: + raise ValueError( + 'Please set Serper API key either in the environment ' + 'as SERPER_API_KEY or pass it as `api_key` parameter.') + self.api_key = api_key + + def search_google_scholar( + self, + query: str, + cites: Optional[str] = None, + as_ylo: Optional[int] = None, + as_yhi: Optional[int] = None, + scisbd: Optional[int] = None, + cluster: Optional[str] = None, + hl: Optional[str] = None, + lr: Optional[str] = None, + start: Optional[int] = None, + num: Optional[int] = None, + as_sdt: Optional[str] = None, + safe: Optional[str] = None, + filter: Optional[str] = None, + as_vis: Optional[str] = None, + ): + params = { + 'q': query, + 'engine': 'google_scholar', + 'api_key': self.api_key, + 'cites': cites, + 'as_ylo': as_ylo, + 'as_yhi': as_yhi, + 'scisbd': scisbd, + 'cluster': cluster, + 'hl': hl, + 'lr': lr, + 'start': start, + 'num': num, + 'as_sdt': as_sdt, + 'safe': safe, + 'filter': filter, + 'as_vis': as_vis + } + search = GoogleSearch(params) + try: + r = search.get_dict() + results = r['organic_results'] + title = [] + snippets = [] + cited_by = [] + organic_id = [] + pub_info = [] + for item in results[:3]: + title.append(item['title']) + pub_info.append(item['publication_info']['summary']) + citation = item['inline_links'].get('cited_by', {'total': ''}) + cited_by.append(citation['total']) + snippets.append(item['snippet']) + organic_id.append(item['result_id']) + return dict( + title=title, + cited_by=cited_by, + organic_id=organic_id, + snippets=snippets) + except Exception as e: + return ActionReturn( + errmsg=str(e), state=ActionStatusCode.HTTP_ERROR) + + def get_author_information(self, + author_id: str, + hl: Optional[str] = None, + view_op: Optional[str] = None, + sort: Optional[str] = None, + citation_id: Optional[str] = None, + start: Optional[int] = None, + num: Optional[int] = None, + no_cache: Optional[bool] = None, + async_req: Optional[bool] = None, + output: Optional[str] = None): + params = { + 'engine': 'google_scholar_author', + 'author_id': author_id, + 'api_key': self.api_key, + 'hl': hl, + 'view_op': view_op, + 'sort': sort, + 'citation_id': citation_id, + 'start': start, + 'num': num, + 'no_cache': no_cache, + 'async': async_req, + 'output': output + } + try: + search = GoogleSearch(params) + results = search.get_dict() + author = results['author'] + articles = results.get('articles', []) + return dict( + name=author['name'], + affiliations=author.get('affiliations', ''), + website=author.get('website', ''), + articles=[ + dict(title=article['title'], authors=article['authors']) + for article in articles[:3] + ]) + except 
Exception as e: + return ActionReturn( + errmsg=str(e), state=ActionStatusCode.HTTP_ERROR) + + def get_citation_format(self, + q: str, + no_cache: Optional[bool] = None, + async_: Optional[bool] = None, + output: Optional[str] = 'json'): + params = { + 'q': q, + 'engine': 'google_scholar_cite', + 'api_key': self.api_key, + 'no_cache': no_cache, + 'async': async_, + 'output': output + } + try: + search = GoogleSearch(params) + results = search.get_dict() + citation = results['citations'] + citation_info = citation[0]['snippet'] + return citation_info + except Exception as e: + return ActionReturn( + errmsg=str(e), state=ActionStatusCode.HTTP_ERROR) + + def get_author_id(self, + mauthors: str, + hl: Optional[str] = 'en', + after_author: Optional[str] = None, + before_author: Optional[str] = None, + no_cache: Optional[bool] = False, + _async: Optional[bool] = False, + output: Optional[str] = 'json'): + params = { + 'mauthors': mauthors, + 'engine': 'google_scholar_profiles', + 'api_key': self.api_key, + 'hl': hl, + 'after_author': after_author, + 'before_author': before_author, + 'no_cache': no_cache, + 'async': _async, + 'output': output + } + try: + search = GoogleSearch(params) + results = search.get_dict() + profile = results['profiles'] + author_info = dict(author_id=profile[0]['author_id']) + return author_info + except Exception as e: + return ActionReturn( + errmsg=str(e), state=ActionStatusCode.HTTP_ERROR) diff --git a/lagent/actions/google_search.py b/lagent/actions/google_search.py index 0ed08082..e26d829a 100644 --- a/lagent/actions/google_search.py +++ b/lagent/actions/google_search.py @@ -71,7 +71,7 @@ def __init__(self, def run(self, query: str, k: int = 10) -> ActionReturn: """Return the search response.""" - tool_return = ActionReturn(url=None, args=None, type=self.name) + tool_return = ActionReturn(type=self.name) status_code, response = self._search(query, k=k) # convert search results to ToolReturn format if status_code == -1: diff --git a/lagent/actions/parser.py b/lagent/actions/parser.py index e9cdd4a4..d31bc69e 100644 --- a/lagent/actions/parser.py +++ b/lagent/actions/parser.py @@ -80,7 +80,7 @@ class JsonParser(BaseParser): action (:class:`BaseAction`): action to validate """ - PARAMETER_DESCRIPTION = '如果调用该工具,你必须使用 dict(key=value) 格式传参,其中key为参数名称' + PARAMETER_DESCRIPTION = '如果调用该工具,你必须使用Json格式 {key: value} 传参,其中key为参数名称' def parse_inputs(self, inputs: str, name: str = 'run') -> dict: try: @@ -105,7 +105,7 @@ class TupleParser(BaseParser): action (:class:`BaseAction`): action to validate """ - PARAMETER_DESCRIPTION = '如果调用该工具,你必须使用 (arg1, arg2, arg3) 格式传参,且参数是有序的' + PARAMETER_DESCRIPTION = '如果调用该工具,你必须使用Tuple格式 (arg1, arg2, arg3) 传参,且参数是有序的' def parse_inputs(self, inputs: str, name: str = 'run') -> dict: try: diff --git a/lagent/actions/ppt.py b/lagent/actions/ppt.py new file mode 100644 index 00000000..26715f34 --- /dev/null +++ b/lagent/actions/ppt.py @@ -0,0 +1,211 @@ +from typing import Dict, Optional, Type + +from pptx import Presentation + +from lagent.actions.base_action import BaseAction +from lagent.actions.parser import BaseParser, JsonParser + +DEFAULT_DESCRIPTION = dict( + name='PPT', + description= + 'This tool allows you to create ppt slides with text, paragraph, images, with good looking styles', + api_list=[ + dict( + name='create_file', + description='Create a pptx file with specific themes', + parameters=[ + dict( + name='theme', type='STRING', description='the theme used'), + dict( + name='abs_location', + type='STRING', + description='the ppt 
file\'s absolute location')
+            ],
+            required=['theme', 'abs_location'],
+            return_data=[
+                dict(name='status', description='the result of the execution')
+            ]),
+        dict(
+            name='get_image',
+            description=
+            'Get an image given comma separated keywords, return the image path.',
+            parameters=[
+                dict(
+                    name='keywords',
+                    type='STRING',
+                    description=
+                    'the comma separated keywords to describe the image')
+            ],
+            required=['keywords'],
+            return_data=[
+                dict(name='status', description='the result of the execution')
+            ]),
+        dict(
+            name='add_first_page',
+            description='Add the first page of ppt.',
+            parameters=[
+                dict(
+                    name='title',
+                    type='STRING',
+                    description='the title of ppt'),
+                dict(
+                    name='subtitle',
+                    type='STRING',
+                    description='the subtitle of ppt')
+            ],
+            required=['title', 'subtitle'],
+            return_data=[
+                dict(name='status', description='the result of the execution')
+            ]),
+        dict(
+            name='add_text_page',
+            description='Add text page of ppt',
+            parameters=[
+                dict(
+                    name='title',
+                    type='STRING',
+                    description='the title of the page'),
+                dict(
+                    name='bullet_items',
+                    type='STRING',
+                    description=
+                    'bullet_items should be string, for multiple bullet items, please use [SPAN] to separate them.'
+                )
+            ],
+            required=['title', 'bullet_items'],
+            return_data=[
+                dict(name='status', description='the result of the execution')
+            ]),
+        dict(
+            name='add_text_image_page',
+            description=
+            'Add a text page with one image. Image should be a path',
+            parameters=[
+                dict(
+                    name='title',
+                    type='STRING',
+                    description='the title of the page'),
+                dict(
+                    name='bullet_items',
+                    type='STRING',
+                    description=
+                    'bullet_items should be string, for multiple bullet items, please use [SPAN] to separate them.'
+                ),
+                dict(
+                    name='image',
+                    type='STRING',
+                    description='the path of the image')
+            ],
+            required=['title', 'bullet_items', 'image'],
+            return_data=[
+                dict(name='status', description='the result of the execution')
+            ]),
+        dict(
+            name='submit_file',
+            description=
+            'When all steps done, YOU MUST use submit_file() to submit your work.',
+            parameters=[],
+            required=[],
+            return_data=[
+                dict(name='status', description='the result of the execution')
+            ])
+    ])
+
+THEME_MAPPING = {
+    'Default': {
+        'template': None,
+        'title': 'Title Slide',
+        'single': 'Title and Content',
+        'two': 'Two Content',
+    }
+}


+class PPT(BaseAction):
+    """Plugin to create PPT slides with text, paragraphs, and images in good-looking styles."""
+
+    def __init__(self,
+                 theme_mapping: Optional[Dict[str, dict]] = None,
+                 description: Optional[dict] = None,
+                 parser: Type[BaseParser] = JsonParser,
+                 enable: bool = True):
+        super().__init__(description or DEFAULT_DESCRIPTION, parser, enable)
+        self.theme_mapping = theme_mapping or THEME_MAPPING
+        self.pointer = None
+        self.location = None
+
+    def create_file(self, theme: str, abs_location: str) -> dict:
+        self.location = abs_location
+        try:
+            self.pointer = Presentation(self.theme_mapping[theme]['template'])
+            self.pointer.slide_master.name = theme
+            # print('created')
+        except Exception as e:
+            print(e)
+        return dict(status='created a ppt file.')
+
+    def add_first_page(self, title: str, subtitle: str) -> dict:
+        layout_name = self.theme_mapping[
+            self.pointer.slide_master.name]['title']
+        layout = next(i for i in self.pointer.slide_master.slide_layouts
+                      if i.name == layout_name)
+        slide = self.pointer.slides.add_slide(layout)
+        ph_title, ph_subtitle = slide.placeholders
+        ph_title.text = title
+        if subtitle:
+            ph_subtitle.text = subtitle
+        return dict(status='added page')
+
+    def add_text_page(self, title: str, bullet_items: str) -> dict:
+        layout_name = self.theme_mapping[
+            self.pointer.slide_master.name]['single']
+        layout = next(i for i in self.pointer.slide_master.slide_layouts
+                      if i.name == layout_name)
+        slide = self.pointer.slides.add_slide(layout)
+        ph_title, ph_body = slide.placeholders
+        ph_title.text = title
+        ph = ph_body
+        tf = ph.text_frame
+        for i, item in enumerate(bullet_items.split('[SPAN]')):
+            if i == 0:
+                p = tf.paragraphs[0]
+            else:
+                p = tf.add_paragraph()
+            p.text = item.strip()
+            p.level = 0
+        return dict(status='added page')
+
+    def add_text_image_page(self, title: str, bullet_items: str,
+                            image: str) -> dict:
+        layout_name = self.theme_mapping[self.pointer.slide_master.name]['two']
+        layout = next(i for i in self.pointer.slide_master.slide_layouts
+                      if i.name == layout_name)
+        slide = self.pointer.slides.add_slide(layout)
+        ph_title, ph_body1, ph_body2 = slide.placeholders
+        ph_title.text = title
+        ph = ph_body2
+        from PIL import Image  # Pillow is a python-pptx dependency
+        image_pil = Image.open(image)  # open the image file at `image` path
+        left = ph.left
+        width = ph.width
+        height = int(width / image_pil.width * image_pil.height)
+        top = (ph.top + (ph.top + ph.height)) // 2 - height // 2
+        slide.shapes.add_picture(image, left, top, width, height)
+
+        ph = ph_body1
+        tf = ph.text_frame
+        for i, item in enumerate(bullet_items.split('[SPAN]')):
+            if i == 0:
+                p = tf.paragraphs[0]
+            else:
+                p = tf.add_paragraph()
+            p.text = item.strip()
+            p.level = 0
+
+        return dict(status='added page')
+
+    def submit_file(self) -> dict:
+        # file_path = os.path.join(self.CACHE_DIR, f'{self._return_timestamp()}.pptx')
+        # self.pointer.save(file_path)
+        # retrieval_url = upload_file(file_path)
+        self.pointer.save(self.location)
+        return dict(status=f'submitted. view ppt at {self.location}')
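Driving the new PPT toolkit directly (outside an agent) looks roughly like this; the theme name matches `THEME_MAPPING` above, and the output path is illustrative:

```python
from lagent.actions.ppt import PPT

ppt = PPT()
# Each API returns a status dict; an agent would call these via JSON args.
ppt.create_file(theme='Default', abs_location='/tmp/demo.pptx')
ppt.add_first_page(title='Lagent', subtitle='tool demo')
# Multiple bullets are packed into one string, separated by [SPAN].
ppt.add_text_page(title='Agenda', bullet_items='tools[SPAN]agents')
print(ppt.submit_file())  # {'status': 'submitted. view ppt at /tmp/demo.pptx'}
```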
diff --git a/lagent/actions/python_interpreter.py b/lagent/actions/python_interpreter.py
index 823d91a5..9720ef72 100644
--- a/lagent/actions/python_interpreter.py
+++ b/lagent/actions/python_interpreter.py
@@ -92,13 +92,13 @@ def run(self, command: str) -> ActionReturn:
         try:
             tool_return = func_set_timeout(self.timeout)(self._call)(command)
         except FunctionTimedOut as e:
-            tool_return = ActionReturn(url=None, args=None, type=self.name)
+            tool_return = ActionReturn(type=self.name)
             tool_return.errmsg = repr(e)
             tool_return.state = ActionStatusCode.API_ERROR
         return tool_return

     def _call(self, command: str) -> ActionReturn:
-        tool_return = ActionReturn(url=None, args=None, type=self.name)
+        tool_return = ActionReturn(type=self.name)
         try:
             if '```python' in command:
                 command = command.split('```python')[1].split('```')[0]

diff --git a/lagent/agents/react.py b/lagent/agents/react.py
index 033ffa62..dfb5397f 100644
--- a/lagent/agents/react.py
+++ b/lagent/agents/react.py
@@ -43,7 +43,7 @@
 The response after utilizing tools should using the following format:
 ```
 {response}the results after call the tool.
-``
+```
 If you already know the answer, or you do not need to use tools, please using the following format to reply:
 ```

diff --git a/lagent/schema.py b/lagent/schema.py
index 925df6c5..91c7001b 100644
--- a/lagent/schema.py
+++ b/lagent/schema.py
@@ -1,6 +1,6 @@
 from dataclasses import asdict, dataclass, field
 from enum import Enum
-from typing import Dict, List, Optional, Union
+from typing import List, Optional, Union

 from lagent.utils import is_module_exist
@@ -33,7 +33,7 @@ class ActionValidCode(int, Enum):
 @dataclass
 class ActionReturn:
-    args: Dict
+    args: Optional[dict] = None
     url: Optional[str] = None
     type: Optional[str] = None
     result: Optional[str] = None
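All of the tools above default to `JsonParser`, but the parser is swappable per instance. A sketch using `TupleParser` with the Arxiv tool (network access to arxiv.org assumed; the query string is illustrative):

```python
from lagent.actions import ArxivSearch, TupleParser

# Same tool, positional-tuple arguments instead of a JSON object.
action = ArxivSearch(parser=TupleParser)
ret = action('("transformer architectures",)',
             name='get_arxiv_article_information')
if ret.errmsg is None:
    print(ret.result['text'][:200])
```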
-`` +``` If you already know the answer, or you do not need to use tools, please using the following format to reply: ``` diff --git a/lagent/schema.py b/lagent/schema.py index 925df6c5..91c7001b 100644 --- a/lagent/schema.py +++ b/lagent/schema.py @@ -1,6 +1,6 @@ from dataclasses import asdict, dataclass, field from enum import Enum -from typing import Dict, List, Optional, Union +from typing import List, Optional, Union from lagent.utils import is_module_exist @@ -33,7 +33,7 @@ class ActionValidCode(int, Enum): @dataclass class ActionReturn: - args: Dict + args: Optional[dict] = None url: Optional[str] = None type: Optional[str] = None result: Optional[str] = None From fdaacb87a4339915199f75c05864ce9934effe38 Mon Sep 17 00:00:00 2001 From: liujiangning30 <147385819+liujiangning30@users.noreply.github.com> Date: Mon, 22 Jan 2024 15:32:57 +0800 Subject: [PATCH 04/20] rename func 'completion' to 'generate' (#90) --- lagent/llms/base_llm.py | 32 ++++++++++++++++---------------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/lagent/llms/base_llm.py b/lagent/llms/base_llm.py index 05919787..000b8918 100644 --- a/lagent/llms/base_llm.py +++ b/lagent/llms/base_llm.py @@ -132,7 +132,7 @@ def __init__(self, if meta_template and 'eos_token_id' in meta_template: self.eos_token_id = meta_template['eos_token_id'] - self.completion_params = dict( + self.gen_params = dict( max_out_len=max_out_len, top_p=top_p, top_k=top_k, @@ -141,16 +141,16 @@ def __init__(self, stop_words=stop_words) @abstractclassmethod - def completion( + def generate( self, inputs: Union[str, List[str]], - **completion_params + **gen_params ) -> str: """Generate results given a str (or list of) inputs. Args: inputs (Union[str, List[str]]): - completion_params (dict): The input params for completion. + gen_params (dict): The input params for generation. Returns: Union[str, List[str]]: A (list of) generated strings. @@ -166,16 +166,16 @@ def completion( return response[0] """ - def stream_completion( + def stream_generate( self, inputs: str, - **completion_params + **gen_params ) -> List[str]: """Generate results as streaming given a str inputs. Args: inputs (str): - completion_params (dict): The input params for completion. + gen_params (dict): The input params for generation. Returns: str: A generated string. @@ -185,13 +185,13 @@ def stream_completion( def chat( self, inputs: Union[List[dict], List[List[dict]]], - **completion_params + **gen_params ): """Generate completion from a list of templates. Args: inputs (Union[List[dict], List[List[dict]]]): - completion_params (dict): The input params for completion. + gen_params (dict): The input params for generation. Returns: """ if isinstance(inputs[0], list): @@ -200,18 +200,18 @@ def chat( inputs.append(self.template_parser(msg)) else: inputs = self.template_parser(inputs) - return self.completion(inputs, **completion_params) + return self.generate(inputs, **gen_params) def stream_chat( self, inputs: List[dict], - **completion_params + **gen_params ): """Generate results as streaming given a list of templates. Args: inputs (Union[List[dict]): - completion_params (dict): The input params for completion. + gen_params (dict): The input params for generation. 
Returns: """ raise NotImplementedError @@ -231,7 +231,7 @@ def tokenize( """ raise NotImplementedError - def update_completion_params(self, **kwargs): - completion_params = copy(self.completion_params) - completion_params.update(kwargs) - return completion_params + def update_gen_params(self, **kwargs): + gen_params = copy(self.gen_params) + gen_params.update(kwargs) + return gen_params From adefd97a273ba44543c3b8e814821aa326f330bc Mon Sep 17 00:00:00 2001 From: BraisedPork <46232992+braisedpork1964@users.noreply.github.com> Date: Tue, 23 Jan 2024 18:40:20 +0800 Subject: [PATCH 05/20] [Feature] support batch inference in API models (#91) * implement `chat` * update agent interfaces * redundancy reduction --------- Co-authored-by: wangzy --- lagent/agents/autogpt.py | 16 +++---- lagent/agents/base_agent.py | 10 +--- lagent/agents/react.py | 32 +++++++------ lagent/agents/rewoo.py | 33 +++++++------ lagent/llms/base_api.py | 57 +++++++--------------- lagent/llms/base_llm.py | 32 +++---------- lagent/llms/openai.py | 96 ++++++++++++++++--------------------- 7 files changed, 108 insertions(+), 168 deletions(-) diff --git a/lagent/agents/autogpt.py b/lagent/agents/autogpt.py index 519715d8..f0eedf06 100644 --- a/lagent/agents/autogpt.py +++ b/lagent/agents/autogpt.py @@ -1,6 +1,5 @@ # flake8: noqa import ast -import copy import platform from typing import Dict, List, Optional, Tuple, Union @@ -261,18 +260,17 @@ def __init__(self, super().__init__( llm=llm, action_executor=action_executor, protocol=protocol) - def chat(self, goal: str) -> AgentReturn: - self._inner_history = [] + def chat(self, goal: str, **kwargs) -> AgentReturn: + inner_history = [] agent_return = AgentReturn() default_response = 'Sorry that I cannot answer your question.' for _ in range(self.max_turn): prompt = self._protocol.format( goal=goal, - inner_history=self._inner_history, + inner_history=inner_history, action_executor=self._action_executor) - response = self._llm.generate_from_template(prompt, 512) - self._inner_history.append( - dict(role='assistant', content=response)) + response = self._llm.chat(prompt, **kwargs) + inner_history.append(dict(role='assistant', content=response)) action, action_input = self._protocol.parse( response, self._action_executor) action_return: ActionReturn = self._action_executor( @@ -281,10 +279,10 @@ def chat(self, goal: str) -> AgentReturn: if action_return.type == self._action_executor.finish_action.name: agent_return.response = action_return.result['text'] return agent_return - self._inner_history.append( + inner_history.append( dict( role='system', content=self._protocol.format_response(action_return))) - agent_return.inner_steps = copy.deepcopy(self._inner_history) + agent_return.inner_steps = inner_history agent_return.response = default_response return agent_return diff --git a/lagent/agents/base_agent.py b/lagent/agents/base_agent.py index d3118602..6ffdea1c 100644 --- a/lagent/agents/base_agent.py +++ b/lagent/agents/base_agent.py @@ -1,5 +1,3 @@ -from typing import List - from lagent.actions import ActionExecutor from lagent.actions.base_action import BaseAction from lagent.llms.base_llm import BaseModel @@ -19,8 +17,6 @@ class BaseAgent: def __init__(self, llm: BaseModel, action_executor: ActionExecutor, protocol: object) -> None: - - self._session_history = [] self._llm = llm self._action_executor = action_executor self._protocol = protocol @@ -41,9 +37,5 @@ def del_action(self, name: str) -> None: """ self._action_executor.del_action(name) - def chat(self, message: str) -> 
AgentReturn: + def chat(self, message: str, **kwargs) -> AgentReturn: raise NotImplementedError - - @property - def session_history(self) -> List: - return self._session_history diff --git a/lagent/agents/react.py b/lagent/agents/react.py index dfb5397f..3e09d924 100644 --- a/lagent/agents/react.py +++ b/lagent/agents/react.py @@ -1,4 +1,3 @@ -import copy from typing import Dict, List, Tuple, Union from lagent.actions import ActionExecutor @@ -210,20 +209,27 @@ def __init__(self, super().__init__( llm=llm, action_executor=action_executor, protocol=protocol) - def chat(self, message: str) -> AgentReturn: - self._inner_history = [] - self._inner_history.append(dict(role='user', content=message)) + def chat(self, message: Union[str, dict, List[dict]], + **kwargs) -> AgentReturn: + if isinstance(message, str): + inner_history = [dict(role='user', content=message)] + elif isinstance(message, dict): + inner_history = [message] + elif isinstance(message, list): + inner_history = message[:] + else: + raise TypeError(f'unsupported type: {type(message)}') + offset = len(inner_history) agent_return = AgentReturn() default_response = 'Sorry that I cannot answer your question.' for turn in range(self.max_turn): prompt = self._protocol.format( - chat_history=self.session_history, - inner_step=self._inner_history, + chat_history=[], + inner_step=inner_history, action_executor=self._action_executor, force_stop=(turn == self.max_turn - 1)) - response = self._llm.generate_from_template(prompt, 512) - self._inner_history.append( - dict(role='assistant', content=response)) + response = self._llm.chat(prompt, **kwargs) + inner_history.append(dict(role='assistant', content=response)) thought, action, action_input = self._protocol.parse( response, self._action_executor) action_return: ActionReturn = self._action_executor( @@ -233,15 +239,11 @@ def chat(self, message: str) -> AgentReturn: if action_return.type == self._action_executor.finish_action.name: agent_return.response = action_return.result['text'] break - self._inner_history.append( + inner_history.append( dict( role='system', content=self._protocol.format_response(action_return))) else: agent_return.response = default_response - agent_return.inner_steps = copy.deepcopy(self._inner_history) - # only append the user and final response - self._session_history.append(dict(role='user', content=message)) - self._session_history.append( - dict(role='assistant', content=agent_return.response)) + agent_return.inner_steps = inner_history[offset:] return agent_return diff --git a/lagent/agents/rewoo.py b/lagent/agents/rewoo.py index 92f5f859..a9bd3163 100644 --- a/lagent/agents/rewoo.py +++ b/lagent/agents/rewoo.py @@ -1,4 +1,3 @@ -import copy import re import warnings from typing import Dict, List, Optional, Tuple, Union @@ -227,9 +226,17 @@ def __init__(self, self.max_turn = max_turn - def chat(self, message: str) -> AgentReturn: - self._inner_history = [] - self._inner_history.append(dict(role='user', content=message)) + def chat(self, message: Union[str, dict, List[dict]], + **kwargs) -> AgentReturn: + if isinstance(message, str): + inner_history = [dict(role='user', content=message)] + elif isinstance(message, dict): + inner_history = [message] + elif isinstance(message, list): + inner_history = message[:] + else: + raise TypeError(f'unsupported type: {type(message)}') + offset = len(inner_history) agent_return = AgentReturn() # planner @@ -237,13 +244,12 @@ def chat(self, message: str) -> AgentReturn: reformat_request = '' while turn_id < self.max_turn: 
planner_prompt = self._protocol.format_planner( - chat_history=self.session_history, - inner_step=self._inner_history, + chat_history=[], + inner_step=inner_history, action_executor=self._action_executor, reformat_request=reformat_request) - response = self._llm.generate_from_template(planner_prompt, 512) - self._inner_history.append( - dict(role='assistant', content=response)) + response = self._llm.chat(planner_prompt, **kwargs) + inner_history.append(dict(role='assistant', content=response)) try: thoughts, actions, actions_input = self._protocol.parse_worker( response) @@ -274,11 +280,10 @@ def chat(self, message: str) -> AgentReturn: solver_prompt, worker_log = self._protocol.format_solver( message, thoughts, action_responses) - self._inner_history.append(dict(role='system', content=worker_log)) + inner_history.append(dict(role='system', content=worker_log)) - final_response = self._llm.generate_from_template(solver_prompt, 512) - self._inner_history.append( - dict(role='assistant', content=final_response)) - agent_return.inner_steps = copy.deepcopy(self._inner_history) + final_response = self._llm.chat(solver_prompt, **kwargs) + inner_history.append(dict(role='assistant', content=final_response)) + agent_return.inner_steps = inner_history[offset:] agent_return.response = final_response return agent_return diff --git a/lagent/llms/base_api.py b/lagent/llms/base_api.py index 898bcfd0..d276d4c0 100644 --- a/lagent/llms/base_api.py +++ b/lagent/llms/base_api.py @@ -155,7 +155,14 @@ def __init__(self, retry: int = 2, max_seq_len: int = 2048, template_parser: 'APITemplateParser' = APITemplateParser, - meta_template: Optional[Dict] = None): + meta_template: Optional[Dict] = None, + *, + max_out_len: int = 512, + top_p: float = 0.8, + top_k: float = None, + temperature: float = 0.8, + repetition_penalty: float = 1.0, + stop_words: Union[List[str], str] = None): self.model_type = model_type self.max_seq_len = max_seq_len self.meta_template = meta_template @@ -165,53 +172,21 @@ def __init__(self, if template_parser: self.template_parser = template_parser(meta_template) - @abstractclassmethod - def generate(self, inputs, max_out_len: int) -> List[str]: - """Generate results given a list of inputs. + self.gen_params = dict( + max_out_len=max_out_len, + top_p=top_p, + top_k=top_k, + temperature=temperature, + repetition_penalty=repetition_penalty, + stop_words=stop_words) - Args: - inputs (List[str or list]): A list of strings or PromptDicts. - The PromptDict should be organized in OpenCompass' - API format. - max_out_len (int): The maximum length of the output. - - Returns: - List[str]: A list of generated strings. - """ - - def get_token_len(self, prompt: str) -> int: - """Get lengths of the tokenized string. Only English and Chinese - characters are counted for now. Users are encouraged to override this - method if more accurate length is needed. - - Args: - prompt (str): Input string. - - Returns: - int: Length of the input tokens - """ - - english_parts = re.findall(r'[A-Za-z0-9]+', prompt) - chinese_parts = re.findall(r'[\u4e00-\u9FFF]+', prompt) - - # Count English words - english_count = sum(len(part.split()) for part in english_parts) - - # Count Chinese words - chinese_count = sum(len(part) for part in chinese_parts) - - return english_count + chinese_count - - def wait(self): + def _wait(self): """Wait till the next query can be sent. Applicable in both single-thread and multi-thread environments. 
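+
+        A sketch of the mechanism: each call draws a token from the
+        ``TokenBucket`` instance below, which enforces the configured
+        query rate.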
""" return self.token_bucket.get_token() - def to(self, device): - pass - class TokenBucket: """A token bucket for rate limiting. diff --git a/lagent/llms/base_llm.py b/lagent/llms/base_llm.py index 000b8918..05db2ffa 100644 --- a/lagent/llms/base_llm.py +++ b/lagent/llms/base_llm.py @@ -140,12 +140,7 @@ def __init__(self, repetition_penalty=repetition_penalty, stop_words=stop_words) - @abstractclassmethod - def generate( - self, - inputs: Union[str, List[str]], - **gen_params - ) -> str: + def generate(self, inputs: Union[str, List[str]], **gen_params) -> str: """Generate results given a str (or list of) inputs. Args: @@ -165,12 +160,9 @@ def generate( return response return response[0] """ + raise NotImplementedError - def stream_generate( - self, - inputs: str, - **gen_params - ) -> List[str]: + def stream_generate(self, inputs: str, **gen_params) -> List[str]: """Generate results as streaming given a str inputs. Args: @@ -182,11 +174,7 @@ def stream_generate( """ raise NotImplementedError - def chat( - self, - inputs: Union[List[dict], List[List[dict]]], - **gen_params - ): + def chat(self, inputs: Union[List[dict], List[List[dict]]], **gen_params): """Generate completion from a list of templates. Args: @@ -202,11 +190,7 @@ def chat( inputs = self.template_parser(inputs) return self.generate(inputs, **gen_params) - def stream_chat( - self, - inputs: List[dict], - **gen_params - ): + def stream_chat(self, inputs: List[dict], **gen_params): """Generate results as streaming given a list of templates. Args: @@ -216,10 +200,8 @@ def stream_chat( """ raise NotImplementedError - def tokenize( - self, - prompts: Union[str, List[str], List[dict], List[List[dict]]] - ): + def tokenize(self, prompts: Union[str, List[str], List[dict], + List[List[dict]]]): """Tokenize the input prompts. Args: diff --git a/lagent/llms/openai.py b/lagent/llms/openai.py index 0ec59669..344c79b3 100644 --- a/lagent/llms/openai.py +++ b/lagent/llms/openai.py @@ -1,7 +1,7 @@ import json import os import time -# from concurrent.futures import ThreadPoolExecutor +from concurrent.futures import ThreadPoolExecutor, wait from logging import getLogger from threading import Lock from typing import Dict, List, Optional, Union @@ -38,9 +38,8 @@ class GPTAPI(BaseAPIModel): wrapping of any meta instructions. openai_api_base (str): The base url of OpenAI's API. Defaults to 'https://api.openai.com/v1/chat/completions'. - temperature (float, optional): What sampling temperature to use. - If not None, will override the temperature in the `generate()` - call. Defaults to None. + gen_params: Default generation configuration which could be overrided + on the fly of generation. 
""" is_api: bool = True @@ -58,16 +57,15 @@ def __init__(self, dict(role='assistant', api_role='assistant') ], openai_api_base: str = OPENAI_API_BASE, - temperature: Optional[float] = None): - + **gen_params): super().__init__( model_type=model_type, max_seq_len=max_seq_len, meta_template=meta_template, query_per_second=query_per_second, - retry=retry) + retry=retry, + **gen_params) self.logger = getLogger(__name__) - self.temperature = temperature if isinstance(key, str): self.keys = [os.getenv('OPENAI_API_KEY') if key == 'ENV' else key] @@ -97,67 +95,56 @@ def __init__(self, context_window = 8192 self.context_window = context_window - def generate( + def chat( self, - inputs: Union[List, str], - max_out_len: int = 512, - temperature: float = 0.7, + inputs: Union[List[dict], List[List[dict]]], + **gen_params, ) -> List[str]: - """Generate results given a list of inputs. + """Generate responses given the contexts. Args: - inputs (List[str or List]): A list of strings or PromptDicts. - The PromptDict should be organized in OpenCompass' - API format. - max_out_len (int): The maximum length of the output. - temperature (float): What sampling temperature to use, - between 0 and 2. Higher values like 0.8 will make the output - more random, while lower values like 0.2 will make it more - focused and deterministic. Defaults to 0.7. + inputs (Union[List[dict], List[List[dict]]]): a list of messages + or list of lists of messages + gen_params: additional generation configuration Returns: List[str]: A list of generated strings. """ - if self.temperature is not None: - temperature = self.temperature - return self._generate(inputs, max_out_len, temperature) - - def _generate(self, input: str or List, max_out_len: int, - temperature: float) -> str: - """Generate results given a list of inputs. + assert isinstance(inputs, list) + if isinstance(inputs[0], dict): + inputs = [inputs] + gen_params = {**self.gen_params, **gen_params} + with ThreadPoolExecutor(max_workers=20) as executor: + tasks = [ + executor.submit(self._chat, messages, **gen_params) + for messages in inputs + ] + wait(tasks) + return [task.result() for task in tasks] + + def _chat(self, messages: List[dict], **gen_params) -> str: + """Generate completion from a list of templates. Args: - inputs (str or List): A string or PromptDict. - The PromptDict should be organized in OpenCompass' - API format. - max_out_len (int): The maximum length of the output. - temperature (float): What sampling temperature to use, - between 0 and 2. Higher values like 0.8 will make the output - more random, while lower values like 0.2 will make it more - focused and deterministic. + messages (List[dict]): a list of prompt dictionaries + gen_params: additional generation configuration Returns: str: The generated string. 
""" - assert isinstance(input, (str, list, dict)) - - if isinstance(input, str): - messages = [{'role': 'user', 'content': input}] - elif isinstance(input, dict): - messages = [input] - else: - messages = input + assert isinstance(messages, list) + gen_params = gen_params.copy() # Hold out 100 tokens due to potential errors in tiktoken calculation max_out_len = min( - max_out_len, - self.context_window - self.get_token_len(str(input)) - 100) + gen_params.pop('max_out_len'), + self.context_window - len(self.tokenize(str(input))) - 100) if max_out_len <= 0: return '' max_num_retries = 0 while max_num_retries < self.retry: - self.wait() + self._wait() with Lock(): if len(self.invalid_keys) == len(self.keys): @@ -192,8 +179,9 @@ def _generate(self, input: str or List, max_out_len: int, messages=messages, max_tokens=max_out_len, n=1, - stop=None, - temperature=temperature, + stop=gen_params.pop('stop_words'), + frequency_penalty=gen_params.pop('repetition_penalty'), + **gen_params, ) raw_response = requests.post( self.url, headers=header, data=json.dumps(data)) @@ -225,18 +213,16 @@ def _generate(self, input: str or List, max_out_len: int, f'{max_num_retries} times. Check the logs for ' 'details.') - def get_token_len(self, prompt: str) -> int: - """Get lengths of the tokenized string. Only English and Chinese - characters are counted for now. Users are encouraged to override this - method if more accurate length is needed. + def tokenize(self, prompt: str) -> list: + """Tokenize the input prompt. Args: prompt (str): Input string. Returns: - int: Length of the input tokens + list: token ids """ import tiktoken self.tiktoken = tiktoken enc = self.tiktoken.encoding_for_model(self.model_type) - return len(enc.encode(prompt)) + return enc.encode(prompt) From 73de598d42545d79eb3b7182c7aff4cf2b97dd11 Mon Sep 17 00:00:00 2001 From: liujiangning30 <147385819+liujiangning30@users.noreply.github.com> Date: Tue, 23 Jan 2024 20:39:21 +0800 Subject: [PATCH 06/20] Feature: lmdeploy_wrapper implemented BaseMode (#86) * [Fix]: fix turbomind (#81) fix turbomind * Feature: lmdeploy_wrapper implemented BaseMode * remove comments of 'llms.__init__' * update of 'llms.__init__' * update lmdepoly_wrapper with 'gen_params' * add property 'state_map' in __init_ and use APIClient to stream infer_ * func 'generate' in LMDeployClient with 'APIClient' * fix bug of TritonClient * add docstr for LMDeployPipeline & LMDeployServer * class LMDeployClient inherits class LMDeployServer * LMDeployClient with BaseModel.__init__ and use field 'max_tokens' control model output * add TODO * move 'import mmengine' to func '_update_gen_params' --------- Co-authored-by: RangiLyu --- lagent/llms/__init__.py | 4 - lagent/llms/base_llm.py | 29 ++- lagent/llms/lmdeploy.py | 143 ------------ lagent/llms/lmdepoly_wrapper.py | 388 ++++++++++++++++++++++++++++++++ lagent/schema.py | 35 ++- lagent/utils/util.py | 30 +++ 6 files changed, 454 insertions(+), 175 deletions(-) delete mode 100644 lagent/llms/lmdeploy.py create mode 100644 lagent/llms/lmdepoly_wrapper.py create mode 100644 lagent/utils/util.py diff --git a/lagent/llms/__init__.py b/lagent/llms/__init__.py index 9809dfcd..202910c0 100644 --- a/lagent/llms/__init__.py +++ b/lagent/llms/__init__.py @@ -8,7 +8,3 @@ if is_module_exist('transformers'): from .huggingface import HFTransformer, HFTransformerCasualLM # noqa: F401 __all__.extend(['HFTransformer', 'HFTransformerCasualLM']) - -if is_module_exist('lmdeploy'): - from .lmdeploy import TritonClient, TurboMind # noqa: F401 - 
__all__.extend(['TritonClient', 'TurboMind']) diff --git a/lagent/llms/base_llm.py b/lagent/llms/base_llm.py index 05db2ffa..35bcad5a 100644 --- a/lagent/llms/base_llm.py +++ b/lagent/llms/base_llm.py @@ -117,7 +117,7 @@ def __init__(self, template_parser: 'LMTemplateParser' = LMTemplateParser, meta_template: Optional[List[Dict]] = None, *, - max_out_len: int = 512, + max_tokens: int = 512, top_p: float = 0.8, top_k: float = None, temperature: float = 0.8, @@ -133,14 +133,19 @@ def __init__(self, self.eos_token_id = meta_template['eos_token_id'] self.gen_params = dict( - max_out_len=max_out_len, + max_tokens=max_tokens, top_p=top_p, top_k=top_k, temperature=temperature, repetition_penalty=repetition_penalty, stop_words=stop_words) - def generate(self, inputs: Union[str, List[str]], **gen_params) -> str: + @abstractclassmethod + def generate( + self, + inputs: Union[str, List[str]], + **gen_params + ) -> str: """Generate results given a str (or list of) inputs. Args: @@ -162,7 +167,11 @@ def generate(self, inputs: Union[str, List[str]], **gen_params) -> str: """ raise NotImplementedError - def stream_generate(self, inputs: str, **gen_params) -> List[str]: + def stream_generate( + self, + inputs: str, + **gen_params + ) -> List[str]: """Generate results as streaming given a str inputs. Args: @@ -174,7 +183,11 @@ def stream_generate(self, inputs: str, **gen_params) -> List[str]: """ raise NotImplementedError - def chat(self, inputs: Union[List[dict], List[List[dict]]], **gen_params): + def chat( + self, + inputs: Union[List[dict], List[List[dict]]], + **gen_params + ): """Generate completion from a list of templates. Args: @@ -190,7 +203,11 @@ def chat(self, inputs: Union[List[dict], List[List[dict]]], **gen_params): inputs = self.template_parser(inputs) return self.generate(inputs, **gen_params) - def stream_chat(self, inputs: List[dict], **gen_params): + def stream_chat( + self, + inputs: List[dict], + **gen_params + ): """Generate results as streaming given a list of templates. Args: diff --git a/lagent/llms/lmdeploy.py b/lagent/llms/lmdeploy.py deleted file mode 100644 index 27275d1c..00000000 --- a/lagent/llms/lmdeploy.py +++ /dev/null @@ -1,143 +0,0 @@ -import dataclasses -import os.path as osp -import random - -import lmdeploy.turbomind.chat as tm_chat -from lmdeploy import turbomind as tm -from lmdeploy.serve.turbomind.chatbot import Chatbot, Session, get_logger -from lmdeploy.tokenizer import Tokenizer - -from .base_llm import BaseModel - - -class TritonClient(Chatbot, BaseModel): - - def __init__(self, meta_template=None, **kwargs): - """TritonClient is a wrapper of TritonClient for LLM. - - Args: - model_name (str): the name of the model - max_out_len (int): the expected generated token numbers - log_level (str): log level - """ - BaseModel.__init__(self, meta_template=meta_template, path=None) - Chatbot.__init__(self, **kwargs) - - def generate(self, - prompt: str, - session_id: int = 2967, - request_id: str = '', - max_out_len: int = None, - sequence_start: bool = True, - sequence_end: bool = True, - *args, - **kwargs): - """Start a new round conversation of a session. Return the chat - completions in non-stream mode. 
-
-        Args:
-            session_id (int): the identical id of a session
-            prompt (str): user's prompt in this round conversation
-            request_id (str): the identical id of this round conversation
-            max_out_len (int): the expected generated token numbers
-            sequence_start (bool): start flag of a session
-            sequence_end (bool): end flag of a session
-
-        Returns:
-            tuple(Status, str, int): status, text/chat completion,
-            generated token number
-        """
-        assert isinstance(session_id, int), \
-            f'INT session id is required, but got {type(session_id)}'
-
-        logger = get_logger(log_level=self.log_level)
-        logger.info(f'session {session_id}, request_id {request_id}, '
-                    f'max_out_len {max_out_len}')
-
-        if self._session is None:
-            sequence_start = True
-            self._session = Session(session_id=session_id)
-        elif self._session.status == 0:
-            logger.error(f'session {session_id} has been ended. Please set '
-                         f'`sequence_start` be True if you want to restart it')
-            return ''
-
-        self._session.status = 1
-        self._session.request_id = request_id
-        self._session.response = ''
-
-        status, res, _ = None, '', 0
-        for status, res, _ in self._stream_infer(self._session, prompt,
-                                                 max_out_len, sequence_start,
-                                                 sequence_end):
-            if status.value < 0:
-                break
-        if status.value == 0:
-            self._session.histories = \
-                self._session.histories + self._session.prompt + \
-                self._session.response
-            return res
-        else:
-            return ''
-
-    def generate_from_template(self, templates, max_out_len: int, **kwargs):
-        """Generate completion from a list of templates.
-
-        Args:
-            templates (List[PromptType]): A list of templates.
-            max_out_len (int): The maximum length of the output.
-        """
-        inputs = self.parse_template(templates)
-        response = self.generate(inputs, max_out_len=max_out_len, **kwargs)
-        # The return of turbomind contains the assistant end token; we hard-code its removal here.
- response = response.replace( - self.template_parser.roles['assistant']['end'].strip(), - '').strip() - return response - - -class TurboMind(BaseModel): - - def __init__(self, - path: str, - max_seq_len: int = 8192, - tokenizer_only: bool = False, - meta_template=None, - tp=1, - **kwargs): - super().__init__( - path=path, - max_seq_len=max_seq_len, - tokenizer_only=tokenizer_only, - meta_template=meta_template) - tokenizer_model_path = osp.join(path, 'triton_models', 'tokenizer') - self.tokenizer = Tokenizer(tokenizer_model_path) - self.tm_model = tm.TurboMind( - path, eos_id=self.tokenizer.eos_token_id, tp=tp) - self.generator = self.tm_model.create_instance() - - model_name = self.tm_model.model_name - self.model = tm_chat.MODELS.get(model_name)( - capability='completion', **kwargs) - self._session_id = 0 - - def generate(self, prompt, **kwargs): - seed = random.getrandbits(64) - input_ids = self.tokenizer.encode(prompt) - gen_param = tm_chat.get_gen_param( - 'completion', self.model.sampling_param, step=0, nth_round=1) - response_size = 0 - self._session_id = (self._session_id + 1) % 100000 - for outputs in self.generator.stream_infer( - session_id=self._session_id, - input_ids=[input_ids], - stream_output=False, - **dataclasses.asdict(gen_param), - ignore_eos=False, - random_seed=seed): - res, tokens = outputs[0] - # decode res - response = self.tokenizer.decode( - res.tolist(), offset=response_size) - response = tm_chat.valid_str(response) - return response diff --git a/lagent/llms/lmdepoly_wrapper.py b/lagent/llms/lmdepoly_wrapper.py new file mode 100644 index 00000000..dc93e907 --- /dev/null +++ b/lagent/llms/lmdepoly_wrapper.py @@ -0,0 +1,388 @@ +import json +from typing import List, Optional, Union + + +import requests + +from lagent.llms.base_llm import BaseModel +from lagent.schema import AgentStatusCode +from lagent.utils.util import filter_suffix + + +class TritonClient(BaseModel): + """TritonClient is a wrapper of TritonClient for LLM. + + Args: + tritonserver_addr (str): the address in format "ip:port" of + triton inference server + model_name (str): the name of the model + session_len (int): the context size + max_tokens (int): the expected generated token numbers + """ + + def __init__(self, + tritonserver_addr: str, + model_name: str, + session_len: int = 32768, + log_level: str = 'WARNING', + **kwargs): + super().__init__(path=None, **kwargs) + from lmdeploy.serve.turbomind.chatbot import Chatbot, StatusCode + self.state_map = { + StatusCode.TRITON_STREAM_END: AgentStatusCode.END, + StatusCode.TRITON_SERVER_ERR: AgentStatusCode.SERVER_ERR, + StatusCode.TRITON_SESSION_CLOSED: AgentStatusCode.SESSION_CLOSED, + StatusCode.TRITON_STREAM_ING: AgentStatusCode.STREAM_ING, + StatusCode.TRITON_SESSION_OUT_OF_LIMIT: + AgentStatusCode.SESSION_OUT_OF_LIMIT, + StatusCode.TRITON_SESSION_INVALID_ARG: + AgentStatusCode.SESSION_INVALID_ARG, + StatusCode.TRITON_SESSION_READY: AgentStatusCode.SESSION_READY + } + self.chatbot = Chatbot( + tritonserver_addr=tritonserver_addr, + model_name=model_name, + session_len=session_len, + log_level=log_level, + **kwargs) + + def generate(self, + inputs: Union[str, List[str]], + session_id: int = 2967, + request_id: str = '', + max_tokens: int = 512, + sequence_start: bool = True, + sequence_end: bool = True, + **kwargs): + """Start a new round conversation of a session. Return the chat + completions in non-stream mode. 
+
+        Args:
+            inputs (str, List[str]): user's prompt(s) in this round
+            session_id (int): the identical id of a session
+            request_id (str): the identical id of this round conversation
+            max_tokens (int): the expected generated token numbers
+            sequence_start (bool): start flag of a session
+            sequence_end (bool): end flag of a session
+
+        Returns:
+            (a list of/batched) text/chat completion
+        """
+        from lmdeploy.serve.turbomind.chatbot import Session, get_logger
+        if isinstance(inputs, str):
+            inputs = [inputs]
+        prompt = inputs
+
+        assert isinstance(session_id, int), \
+            f'INT session id is required, but got {type(session_id)}'
+
+        logger = get_logger(log_level=self.chatbot.log_level)
+        logger.info(f'session {session_id}, request_id {request_id}, '
+                    f'max_out_len {max_tokens}')
+
+        if self.chatbot._session is None:
+            sequence_start = True
+            self.chatbot._session = Session(session_id=session_id)
+        elif self.chatbot._session.status == 0:
+            logger.error(f'session {session_id} has been ended. Please set '
+                         f'`sequence_start` be True if you want to restart it')
+            return ''
+
+        self.chatbot._session.status = 1
+        self.chatbot._session.request_id = request_id
+        self.chatbot._session.response = ''
+
+        self.chatbot.cfg = self._update_gen_params(
+            max_tokens=max_tokens, **kwargs)
+
+        status, res, _ = None, '', 0
+        for status, res, _ in self.chatbot._stream_infer(
+                self.chatbot._session, prompt, max_tokens, sequence_start,
+                sequence_end):
+            if status.value < 0:
+                break
+        if status.value == 0:
+            self.chatbot._session.histories = \
+                self.chatbot._session.histories + self.chatbot._session.prompt + \
+                self.chatbot._session.response
+            # remove stop_words
+            res = filter_suffix(res, self.gen_params.get('stop_words'))
+            return res
+        else:
+            return ''
+
+    def stream_chat(self,
+                    inputs: List[dict],
+                    session_id: int = 2967,
+                    request_id: str = '',
+                    max_tokens: int = 512,
+                    sequence_start: bool = True,
+                    sequence_end: bool = True,
+                    **kwargs):
+        """Start a new round conversation of a session. Return the chat
+        completions in stream mode.
+
+        Args:
+            session_id (int): the identical id of a session
+            inputs (List[dict]): user's inputs in this round conversation
+            request_id (str): the identical id of this round conversation
+            max_tokens (int): the expected generated token numbers
+            sequence_start (bool): start flag of a session
+            sequence_end (bool): end flag of a session
+
+        Returns:
+            tuple(Status, str, int): status, text/chat completion,
+                generated token number
+        """
+        from lmdeploy.serve.turbomind.chatbot import (Session, StatusCode,
+                                                      get_logger)
+        assert isinstance(session_id, int), \
+            f'INT session id is required, but got {type(session_id)}'
+
+        logger = get_logger(log_level=self.chatbot.log_level)
+        logger.info(f'session {session_id}, request_id {request_id}, '
+                    f'max_out_len {max_tokens}')
+
+        if self.chatbot._session is None:
+            sequence_start = True
+            self.chatbot._session = Session(session_id=session_id)
+        elif self.chatbot._session.status == 0:
+            logger.error(f'session {session_id} has been ended.
Please set ' + f'`sequence_start` be True if you want to restart it') + return '' + + self.chatbot._session.status = 1 + self.chatbot._session.request_id = request_id + self.chatbot._session.response = '' + + self.chatbot.cfg = self._update_gen_params( + max_tokens=max_tokens, **kwargs) + prompt = self.template_parser(inputs) + + status, res, _ = None, '', 0 + for status, res, _ in self.chatbot._stream_infer( + self.chatbot._session, prompt, max_tokens, sequence_start, + sequence_end): + if status == StatusCode.TRITON_STREAM_END: # remove stop_words + res = filter_suffix(res, self.gen_params.get('stop_words')) + if status.value < 0: + break + else: + yield self.state_map.get(status), res, _ + if status.value == 0: + self.chatbot._session.histories = \ + self.chatbot._session.histories + self.chatbot._session.prompt + \ + self.chatbot._session.response + yield self.state_map.get(status), res, _ + else: + return '' + + def _update_gen_params(self, **kwargs): + import mmengine + new_gen_params = self.update_gen_params(**kwargs) + self.gen_params['stop_words'] = new_gen_params.pop('stop_words') + stop_words = self.chatbot._stop_words(self.gen_params.get('stop_words')) + cfg = mmengine.Config( + dict( + session_len=self.chatbot.model.session_len, + stop_words=stop_words, + bad_words=self.chatbot.cfg.bad_words, + **new_gen_params)) + return cfg + + +class LMDeployPipeline(BaseModel): + """ + + Args: + path (str): The path to the model. + It could be one of the following options: + - i) A local directory path of a turbomind model which is + converted by `lmdeploy convert` command or download from + ii) and iii). + - ii) The model_id of a lmdeploy-quantized model hosted + inside a model repo on huggingface.co, such as + "InternLM/internlm-chat-20b-4bit", + "lmdeploy/llama2-chat-70b-4bit", etc. + - iii) The model_id of a model hosted inside a model repo + on huggingface.co, such as "internlm/internlm-chat-7b", + "Qwen/Qwen-7B-Chat ", "baichuan-inc/Baichuan2-7B-Chat" + and so on. + model_name (str): needed when model_path is a pytorch model on + huggingface.co, such as "internlm/internlm-chat-7b", + "Qwen/Qwen-7B-Chat ", "baichuan-inc/Baichuan2-7B-Chat" and so on. + tp (int): + pipeline_cfg (dict): + """ + + def __init__(self, + path: str, + model_name: Optional[str] = None, + tp: int = 1, + pipeline_cfg=dict(), + **kwargs): + + super().__init__(path=path, **kwargs) + from lmdeploy import pipeline + self.model = pipeline( + model_path=self.path, model_name=model_name, tp=tp, **pipeline_cfg) + + def generate(self, + inputs: Union[str, List[str]], + do_preprocess=None, + **kwargs): + batched = True + if isinstance(inputs, str): + inputs = [inputs] + batched = False + prompt = inputs + gen_params = self.update_gen_params(**kwargs) + response = self.model.batch_infer( + prompt, do_preprocess=do_preprocess, **gen_params) + response = [resp.text for resp in response] + # remove stop_words + response = filter_suffix(response, self.gen_params.get('stop_words')) + if batched: + return response + return response[0] + + +class LMDeployServer(BaseModel): + """ + + Args: + path (str): The path to the model. + It could be one of the following options: + - i) A local directory path of a turbomind model which is + converted by `lmdeploy convert` command or download from + ii) and iii). + - ii) The model_id of a lmdeploy-quantized model hosted + inside a model repo on huggingface.co, such as + "InternLM/internlm-chat-20b-4bit", + "lmdeploy/llama2-chat-70b-4bit", etc. 
+ - iii) The model_id of a model hosted inside a model repo + on huggingface.co, such as "internlm/internlm-chat-7b", + "Qwen/Qwen-7B-Chat ", "baichuan-inc/Baichuan2-7B-Chat" + and so on. + model_name (str): needed when model_path is a pytorch model on + huggingface.co, such as "internlm/internlm-chat-7b", + "Qwen/Qwen-7B-Chat ", "baichuan-inc/Baichuan2-7B-Chat" and so on. + server_name (str): host ip for serving + server_port (int): server port + tp (int): + log_level (str): set log level whose value among [CRITICAL, ERROR, WARNING, INFO, DEBUG] + """ + + def __init__( + self, + path: str, + model_name: Optional[str] = None, + server_name: str = '0.0.0.0', + server_port: int = 23333, + tp: int = 1, + log_level: str = 'WARNING', + serve_cfg=dict(), + **kwargs + ): + super().__init__(path=path, **kwargs) + # TODO get_logger issue in multi processing + import lmdeploy + self.client = lmdeploy.serve( + model_path=self.path, + model_name=model_name, + server_name=server_name, + server_port=server_port, + tp=tp, + log_level=log_level, + **serve_cfg) + + def generate( + self, + inputs: Union[str, List[str]], + session_id: int = 2967, + sequence_start: bool = True, + sequence_end: bool = True, + ignore_eos: bool = False, + timeout: int = 30, + **kwargs) -> List[str]: + batched = True + if isinstance(inputs, str): + inputs = [inputs] + batched = False + + gen_params = self.update_gen_params(**kwargs) + + resp = [''] * len(inputs) + for text in self.client.completions_v1( + self.path, + inputs, + session_id=session_id, + sequence_start=sequence_start, + sequence_end=sequence_end, + stream=False, + ignore_eos=ignore_eos, + timeout=timeout, + **gen_params + ): + resp = [ + resp[i] + item['text'] + for i, item in enumerate(text['choices']) + ] + # remove stop_words + resp = filter_suffix(resp, self.gen_params.get('stop_words')) + if not batched: + return resp[0] + return resp + + def stream_chat(self, + inputs: List[dict], + session_id=0, + sequence_start: bool = True, + sequence_end: bool = True, + stream: bool = True, + ignore_eos: bool = False, + timeout: int = 30, + **kwargs): + + gen_params = self.update_gen_params(**kwargs) + + resp = '' + finished = False + stop_words = self.gen_params.get('stop_words') + for text in self.client.completions_v1( + self.path, + inputs, + session_id=session_id, + sequence_start=sequence_start, + sequence_end=sequence_end, + stream=stream, + ignore_eos=ignore_eos, + timeout=timeout, + **gen_params): + resp += text['choices'][0]['text'] + if not resp: + continue + # remove stop_words + for sw in stop_words: + if sw in resp: + resp = filter_suffix(resp, stop_words) + finished = True + break + yield AgentStatusCode.STREAM_ING, resp, None + if finished: + break + yield AgentStatusCode.END, resp, None + + +class LMDeployClient(LMDeployServer): + """ + + Args: + path (str): The path to the model. 
+        url (str):
+    """
+
+    def __init__(self, path: str, url: str, **kwargs):
+        BaseModel.__init__(self, path=path, **kwargs)
+        from lmdeploy.serve.openai.api_client import APIClient
+        self.client = APIClient(url)
diff --git a/lagent/schema.py b/lagent/schema.py
index 91c7001b..e6752e16 100644
--- a/lagent/schema.py
+++ b/lagent/schema.py
@@ -43,37 +43,28 @@ class ActionReturn:
     valid: Optional[ActionValidCode] = ActionValidCode.OPEN
 
 
-class AgentStatusCode(Enum):
-    END = 0  # end of streaming
+# Inherit from int so that asdict can convert AgentStatusCode to int
+class AgentStatusCode(int, Enum):
+    END = 0  # end of streaming; return this round's history
     STREAM_ING = 1  # response is in streaming
     SERVER_ERR = -1  # triton server's error
     SESSION_CLOSED = -2  # session has been closed
     SESSION_OUT_OF_LIMIT = -3  # request length out of limit
-    CMD = 2  # return command
+    PLUGIN_START = 3  # start tool
+    PLUGIN_END = 4  # finish tool
+    PLUGIN_RETURN = 5  # finish tool
+
+    CODING = 6  # start python
+    CODE_END = 7  # end python
+    CODE_RETURN = 8  # python return
     SESSION_INVALID_ARG = -4  # invalid argument
-    SESSION_READY = 3  # session is ready for inference
+    SESSION_READY = 2  # session is ready for inference
 
 
 @dataclass
 class AgentReturn:
+    state: Union[AgentStatusCode, int] = AgentStatusCode.END
     actions: List[ActionReturn] = field(default_factory=list)
     response: str = ''
     inner_steps: List = field(default_factory=list)
-    errmsg: Optional[str] = None
-
-
-if is_module_exist('lmdeploy'):
-    from lmdeploy.serve.turbomind.chatbot import StatusCode
-    STATE_MAP = {
-        StatusCode.TRITON_STREAM_END: AgentStatusCode.END,
-        StatusCode.TRITON_SERVER_ERR: AgentStatusCode.SERVER_ERR,
-        StatusCode.TRITON_SESSION_CLOSED: AgentStatusCode.SESSION_CLOSED,
-        StatusCode.TRITON_STREAM_ING: AgentStatusCode.STREAM_ING,
-        StatusCode.TRITON_SESSION_OUT_OF_LIMIT:
-        AgentStatusCode.SESSION_OUT_OF_LIMIT,
-        StatusCode.TRITON_SESSION_INVALID_ARG:
-        AgentStatusCode.SESSION_INVALID_ARG,
-        StatusCode.TRITON_SESSION_READY: AgentStatusCode.SESSION_READY
-    }
-else:
-    STATE_MAP = {}
+    errmsg: Optional[str] = None
\ No newline at end of file
diff --git a/lagent/utils/util.py b/lagent/utils/util.py
new file mode 100644
index 00000000..dc0e1790
--- /dev/null
+++ b/lagent/utils/util.py
@@ -0,0 +1,30 @@
+from typing import List, Optional, Union
+
+
+def filter_suffix(response: Union[str, List[str]], suffixes: Optional[List[str]] = None) -> Union[str, List[str]]:
+    """Filter response with suffixes.
+
+    Args:
+        response (Union[str, List[str]]): generated responses by LLMs.
+        suffixes (List[str], optional): a list of suffixes to be deleted.
+
+    Return:
+        Union[str, List[str]]: a clean response (or a list of them).
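+
+    Example (a sketch of the intended behaviour):
+        >>> filter_suffix('Hello<stop>', ['<stop>'])
+        'Hello'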
+ """ + if suffixes is None: + return response + batched = True + if isinstance(response, str): + response = [response] + batched = False + processed = [] + for resp in response: + for item in suffixes: + # if response.endswith(item): + # response = response[:len(response) - len(item)] + if item in resp: + resp = resp.split(item)[0] + processed.append(resp) + if not batched: + return processed[0] + return processed From 54e5b615f4153df55020482ddf5fe4e7da0d968d Mon Sep 17 00:00:00 2001 From: BraisedPork <46232992+braisedpork1964@users.noreply.github.com> Date: Wed, 24 Jan 2024 16:44:06 +0800 Subject: [PATCH 07/20] Fix APITemplateParser object is not callable (#95) fix APITemplateParser object is not callable Co-authored-by: wangzy --- lagent/llms/base_api.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lagent/llms/base_api.py b/lagent/llms/base_api.py index d276d4c0..15e33dcf 100644 --- a/lagent/llms/base_api.py +++ b/lagent/llms/base_api.py @@ -27,7 +27,7 @@ def __init__(self, meta_template: Optional[Dict] = None): 'role in meta prompt must be unique!' self.roles[item['role']] = item.copy() - def parse_template(self, dialog: List[Union[str, List]]): + def __call__(self, dialog: List[Union[str, List]]): """Parse the intermidate prompt template, and wrap it with meta template if applicable. When the meta template is set and the input is a list, the return value will be a list containing the full @@ -161,7 +161,7 @@ def __init__(self, top_p: float = 0.8, top_k: float = None, temperature: float = 0.8, - repetition_penalty: float = 1.0, + repetition_penalty: float = 0.0, stop_words: Union[List[str], str] = None): self.model_type = model_type self.max_seq_len = max_seq_len From a53bad2077e059e023eeaf05f8f5ed7e68481eb9 Mon Sep 17 00:00:00 2001 From: liukuikun <24622904+Harold-lkk@users.noreply.github.com> Date: Wed, 24 Jan 2024 20:15:45 +0800 Subject: [PATCH 08/20] [Feat] support StreamAgent (#82) * [Feat] support StreamAgent * update `StreamAgent` * truncate inner history --------- Co-authored-by: wangzy --- lagent/agents/stream_agent.py | 287 ++++++++++++++++++++++++++++++++++ 1 file changed, 287 insertions(+) create mode 100644 lagent/agents/stream_agent.py diff --git a/lagent/agents/stream_agent.py b/lagent/agents/stream_agent.py new file mode 100644 index 00000000..b2c1258a --- /dev/null +++ b/lagent/agents/stream_agent.py @@ -0,0 +1,287 @@ +import json +import logging +from copy import deepcopy +from typing import Dict, List, Union + +from ilagent.schema import AgentReturn, AgentStatusCode + +from lagent import BaseAgent +from lagent.actions import ActionExecutor +from lagent.llms import BaseAPIModel, BaseModel +from lagent.schema import ActionReturn, ActionStatusCode + +API_PREFIX = ( + "This is the subfunction for tool '{tool_name}', you can use this tool. " + 'The description of this function is: \n{description}') + +INTERPRETER_CN = ('你现在可以使用一个支持 Python 代码执行的 Jupyter 笔记本环境。只需向 python 发' + '送代码,即可在这个有状态环境中进行运行。这个功能适用于数据分析或处理(如数据操作和' + '图形制作),复杂计算(如数学和物理问题),编程示例(用于理解编程概念或语言特性),文' + '本处理和分析(包括文本分析和自然语言处理),机器学习和数据科学(模型训练和数据可视化' + '展示),以及文件操作和数据导入(处理CSV、JSON等格式文件)。') + +PLUGIN_CN = ('你可以使用如下工具:' + '\n{prompt}\n' + '如果你已经获得足够信息,请直接给出答案. 避免不必要的工具调用! 
' + '同时注意你可以使用的工具,不要随意捏造!') + + +class StreamProtocol: + + def __init__( + self, + meta_prompt=None, + interpreter_prompt=INTERPRETER_CN, + plugin_prompt=PLUGIN_CN, + few_shot=None, + language=dict( + begin='', + end='', + belong='assistant', + ), + tool=dict( + begin='{start_token}{name}\n', + start_token='[UNUSED_TOKEN_144]', + name_map=dict( + plugin='[UNUSED_TOKEN_141]', interpreter='[UNUSED_TOKEN_142]'), + belong='assistant', + end='[UNUSED_TOKEN_143]\n', + ), + execute: dict = dict( + role='execute', begin='', end='', fallback_role='environment'), + ) -> None: + self.meta_prompt = meta_prompt + self.interpreter_prompt = interpreter_prompt + self.plugin_prompt = plugin_prompt + self.roles_cfg = dict(tool=tool, language=language) + self.language = language + self.execute = execute + self.tool = tool + self.few_shot = few_shot + + def format_sub_role(self, messages: List[Dict]) -> List[Dict]: + + def format_interpreter(message): + if isinstance(message['content'], dict): + # assert message['content']['name'] == 'IPythonInterpreter' + return dict( + role=message['role'], + name=message['name'], + content=message['content']['parameters']['command']) + else: + return message + + def format_plugin(message): + if isinstance(message['content'], dict): + return dict( + role=message['role'], + name=message['name'], + content=json.dumps(message['content'])) + else: + return message + + new_message = list() + for message in messages: + if message['role'] in [ + 'assistant', 'user', 'system', 'environment' + ]: + new_message.append(message) + continue + role_cfg = self.roles_cfg[message['role']] + begin = role_cfg['begin'] + if message['role'] == 'tool': + if message['name'] == 'interpreter': + message = format_interpreter(message) + elif message['name'] == 'plugin': + message = format_plugin(message) + else: + raise NotImplementedError + begin = role_cfg['begin'].format( + start_token=role_cfg.get('start_token', ''), + name=role_cfg.get('name_map', {}).get(message['name'], '')) + new_content = begin + message['content'] + role_cfg['end'] + if role_cfg.get('fallback_role'): + new_message.append( + dict(role=role_cfg['fallback_role'], content=new_content)) + elif role_cfg.get('belong'): + if new_message[-1]['role'] != role_cfg.get('belong'): + new_message.append( + dict(role=role_cfg.get('belong'), content=new_content)) + else: + new_message[-1]['content'] += new_content + else: + new_message.append( + dict(role=message['role'], content=new_content)) + + return new_message + + def format(self, + inner_step: List[Dict], + plugin_executor: ActionExecutor = None, + interpreter_executor: ActionExecutor = None, + **kwargs) -> list: + formatted = [] + if self.meta_prompt: + formatted.append(dict(role='system', content=self.meta_prompt)) + if interpreter_executor and self.interpreter_prompt: + interpreter_info = list( + interpreter_executor.get_actions_info().items())[0] + interpreter_prompt = self.interpreter_prompt.format( + code_prompt=interpreter_info[1]) + formatted.append( + dict( + role='system', + content=interpreter_prompt, + name='interpreter')) + if plugin_executor and plugin_executor.actions and self.plugin_prompt: + plugin_descriptions = [] + for api_name, api_info in plugin_executor.get_actions_info().items( + ): + if isinstance(api_info, dict): + plugin = deepcopy(api_info) + tool_name = api_name.split('.')[0] + plugin['name'] = api_name + plugin['description'] = API_PREFIX.format( + tool_name=tool_name, description=plugin['description']) + else: + plugin = dict(name=api_name, 
description=api_info) + plugin_descriptions.append(plugin) + plugin_prompt = self.plugin_prompt.format( + prompt=json.dumps( + plugin_descriptions, ensure_ascii=False, indent=4)) + formatted.append( + dict(role='system', content=plugin_prompt, name='plugin')) + if self.few_shot: + for few_shot in self.few_shot: + formatted += self.format_sub_role(few_shot) + formatted += self.format_sub_role(inner_step) + return formatted + + def parse(self, message, plugin_executor: ActionExecutor, + interpreter_executor: ActionExecutor): + if self.language['begin']: + message = message.split(self.language['begin'])[-1] + if self.tool['name_map']['plugin'] in message: + message, action = message.split( + f"{self.tool['start_token']}{self.tool['name_map']['plugin']}") + action = action.split(self.tool['end'].strip())[0] + action = json.loads(action) + return 'plugin', message, action + if self.tool['name_map']['interpreter'] in message: + message, code = message.split( + f"{self.tool['start_token']}" + f"{self.tool['name_map']['interpreter']}") + code = code.split(self.tool['end'].strip())[0].strip() + return 'interpreter', message, dict( + name=interpreter_executor.action_names()[0], + parameters=dict(command=code)) + return None, message, None + + def format_response(self, action_return, name) -> str: + if action_return.state == ActionStatusCode.SUCCESS: + if isinstance(action_return.result, list): + response = [] + for item in action_return.result: + if item['type'] == 'text': + response.append(item['content']) + else: + response.append(f"[{item['type']}]({item['content']})") + response = '\n'.join(response) + elif isinstance(action_return.result, dict): + response = action_return.result['text'] + if 'image' in action_return.result: + response += '\n'.join([ + f'[image]({im})' + for im in action_return.result['image'] + ]) + if 'audio' in action_return.result: + response += '\n'.join([ + f'[audio]({im})' + for im in action_return.result['audio'] + ]) + elif isinstance(action_return.result, str): + response = action_return.result + else: + raise NotImplementedError + else: + response = action_return.errmsg + content = self.execute['begin'] + response + self.execute['end'] + if self.execute.get('fallback_role'): + return dict( + role=self.execute['fallback_role'], content=content, name=name) + elif self.execute.get('belong'): + return dict( + role=self.execute['belong'], content=content, name=name) + else: + return dict(role=self.execute['role'], content=response, name=name) + + +class StreamAgent(BaseAgent): + + def __init__(self, + llm: Union[BaseModel, BaseAPIModel], + plugin_executor: ActionExecutor = None, + interpreter_executor: ActionExecutor = None, + protocol=StreamProtocol(), + max_turn: int = 3) -> None: + self.max_turn = max_turn + self._interpreter_executor = interpreter_executor + super().__init__( + llm=llm, action_executor=plugin_executor, protocol=protocol) + + def chat(self, message: Union[str, Dict], **kwargs) -> AgentReturn: + if isinstance(message, str): + message = dict(role='user', content=message) + if isinstance(message, dict): + message = [message] + inner_history = message[:] + offset = len(inner_history) + agent_return = AgentReturn() + for _ in range(self.max_turn): + # list of dict + prompt = self._protocol.format( + inner_step=inner_history, + plugin_executor=self._action_executor, + interpreter_executor=self._interpreter_executor, + ) + response = self._llm.chat(prompt, **kwargs) + name, language, action = self._protocol.parse( + message=response, + 
plugin_executor=self._action_executor,
+                interpreter_executor=self._interpreter_executor,
+            )
+            if name:
+                if name == 'plugin':
+                    if self._action_executor:
+                        executor = self._action_executor
+                    else:
+                        logging.info(msg='No plugin is instantiated!')
+                        continue
+                elif name == 'interpreter':
+                    if self._interpreter_executor:
+                        executor = self._interpreter_executor
+                    else:
+                        logging.info(msg='No interpreter is instantiated!')
+                        continue
+                else:
+                    logging.info(
+                        msg=(f"Invalid name '{name}'. Currently only 'plugin' "
+                             "and 'interpreter' are supported."))
+                    continue
+                action_return: ActionReturn = executor(action['name'],
+                                                       action['parameters'])
+                action_return.thought = language
+                agent_return.actions.append(action_return)
+            inner_history.append(dict(role='language', content=language))
+            if not name or action_return.type == executor.finish_action.name:
+                agent_return.response = language
+                agent_return.state = AgentStatusCode.END
+                break
+            else:
+                inner_history.append(
+                    dict(role='tool', content=action, name=name))
+                inner_history.append(
+                    self._protocol.format_response(action_return, name=name))
+
+        agent_return.inner_steps = inner_history[offset:]
+        return agent_return

From bdd6a9b4df6ba0f737ef7e4b115dd310df74a42e Mon Sep 17 00:00:00 2001
From: liujiangning30 <147385819+liujiangning30@users.noreply.github.com>
Date: Thu, 25 Jan 2024 11:00:00 +0800
Subject: [PATCH 09/20] [Feat] hf llm implemented BaseMode (#92)

* Feature: huggingface implemented BaseMode

* hf llm implemented BaseMode

* fix bug of hf llm

* inject attention_mask during inference

* remove unnecessary
---
 lagent/llms/base_llm.py    |   2 -
 lagent/llms/huggingface.py | 275 ++++++++++++++++++++++++++-----------
 2 files changed, 196 insertions(+), 81 deletions(-)

diff --git a/lagent/llms/base_llm.py b/lagent/llms/base_llm.py
index 35bcad5a..a729ad35 100644
--- a/lagent/llms/base_llm.py
+++ b/lagent/llms/base_llm.py
@@ -112,7 +112,6 @@ class BaseModel:
 
     def __init__(self,
                  path: str,
-                 max_seq_len: int = 2048,
                  tokenizer_only: bool = False,
                  template_parser: 'LMTemplateParser' = LMTemplateParser,
                  meta_template: Optional[List[Dict]] = None,
@@ -124,7 +123,6 @@ def __init__(self,
                  repetition_penalty: float = 1.0,
                  stop_words: Union[List[str], str] = None):
         self.path = path
-        self.max_seq_len = max_seq_len
         self.tokenizer_only = tokenizer_only
         # meta template
         self.template_parser = template_parser(meta_template)
diff --git a/lagent/llms/huggingface.py b/lagent/llms/huggingface.py
index 8f991015..5b2252ad 100644
--- a/lagent/llms/huggingface.py
+++ b/lagent/llms/huggingface.py
@@ -1,15 +1,19 @@
-from typing import Dict, List, Optional
-
-import torch
+import copy
+import warnings
+import logging
+from typing import Dict, List, Optional, Union
+from dataclasses import asdict
 
 from .base_llm import BaseModel
 
+logger = logging.getLogger(__name__)
+
 
 class HFTransformer(BaseModel):
     """Model wrapper around HuggingFace general models.
 
-    Adapted from OpenCompass (https://github.com/InternLM/opencompass
-    /blob/main/opencompass/models/huggingface.py)
+    Adapted from InternLM (https://github.com/InternLM/InternLM/blob/main/
+    chat/web_demo.py)
 
     Args:
         path (str): The name or path to HuggingFace's model.
@@ -25,52 +29,40 @@ class HFTransformer(BaseModel):
         meta_template (Dict, optional): The model's meta prompt template
             if needed, in case the requirement of injecting or wrapping of
             any meta instructions.
- extract_pred_after_decode (bool): Whether to extract the prediction - string from the decoded output string, instead of extract the - prediction tokens before decoding. Defaults to False. - batch_padding (bool): If False, inference with be performed in for-loop - without batch padding. - - Note: - About ``extract_pred_after_decode``: Commonly, we should extract the - the prediction tokens before decoding. But for some tokenizers using - ``sentencepiece``, like LLaMA, this behavior may change the number of - whitespaces, which is harmful for Python programming tasks. """ - def __init__( - self, - path: str, - max_seq_len: int = 2048, - tokenizer_path: Optional[str] = None, - tokenizer_kwargs: dict = dict(), - tokenizer_only: bool = False, - model_kwargs: dict = dict(device_map='auto'), - meta_template: Optional[Dict] = [ - dict(role='system', begin='<|System|>:', end='\n'), - dict(role='user', begin='<|User|>:', end='\n'), - dict( - role='assistant', - begin='<|Bot|>:', - end='\n', - generate=True) - ], # default meta template for InternLM-7b - extract_pred_after_decode: bool = False, - batch_padding: bool = False): + def __init__(self, + path: str, + tokenizer_path: Optional[str] = None, + tokenizer_kwargs: dict = dict(), + tokenizer_only: bool = False, + model_kwargs: dict = dict(device_map='auto'), + meta_template: Optional[Dict] = None, + **kwargs): super().__init__( path=path, - max_seq_len=max_seq_len, tokenizer_only=tokenizer_only, - meta_template=meta_template) + meta_template=meta_template, + **kwargs) + self._load_tokenizer( path=path, tokenizer_path=tokenizer_path, tokenizer_kwargs=tokenizer_kwargs) - self.batch_padding = batch_padding - self.extract_pred_after_decode = extract_pred_after_decode if not tokenizer_only: self._load_model(path=path, model_kwargs=model_kwargs) + from transformers.generation.utils import (LogitsProcessorList, + StoppingCriteriaList) + self.logits_processor = LogitsProcessorList() + self.stopping_criteria = StoppingCriteriaList() + self.prefix_allowed_tokens_fn = None + + stop_words_id = [] + for sw in self.gen_params.get('stop_words', []): + stop_words_id.append(self.tokenizer(sw)['input_ids'][1]) + self.additional_eos_token_id = stop_words_id + def _load_tokenizer(self, path: str, tokenizer_path: Optional[str], tokenizer_kwargs: dict): from transformers import AutoTokenizer @@ -82,57 +74,182 @@ def _load_tokenizer(self, path: str, tokenizer_path: Optional[str], self.tokenizer.pad_token = self.tokenizer.eos_token def _load_model(self, path: str, model_kwargs: dict): + import torch from transformers import AutoModel model_kwargs.setdefault('torch_dtype', torch.float16) self.model = AutoModel.from_pretrained( path, trust_remote_code=True, **model_kwargs) self.model.eval() - def generate(self, inputs: List[str], max_out_len: int, - **kwargs) -> List[str]: - if isinstance(inputs, str): - inputs = [inputs] - if self.extract_pred_after_decode: - prompt_lens = [len(input_) for input_ in inputs] - - input_ids = self.tokenizer( - inputs, truncation=True, - max_length=self.max_seq_len - max_out_len)['input_ids'] - input_ids = torch.tensor(input_ids, device=self.model.device) - outputs = self.model.generate( - input_ids=input_ids, max_new_tokens=max_out_len, **kwargs) - - if not self.extract_pred_after_decode: - outputs = outputs[:, input_ids.shape[1]:] - - decodeds = self.tokenizer.batch_decode( - outputs, skip_special_tokens=True) - if self.extract_pred_after_decode: - decodeds = [ - token[len_:] for token, len_ in zip(decodeds, prompt_lens) - ] - - return 
decodeds[0] - - def generate_from_template(self, templates, max_out_len: int, **kwargs): - """Generate completion from a list of templates. - - Args: - templates (List[PromptType]): A list of templates. - max_out_len (int): The maximum length of the output. - """ - inputs = self.parse_template(templates) - response = self.generate(inputs, max_out_len=max_out_len, **kwargs) - end_token = self.template_parser.meta_template[0]['end'].strip() - # return response.replace( - # self.template_parser.roles['assistant']['end'].strip(), - # '').strip() - return response.split(end_token.strip())[0] + def tokenize(self, inputs: str): + assert isinstance(inputs, str) + inputs = self.tokenizer( + inputs, return_tensors='pt', return_length=True) + return inputs['input_ids'].tolist() + + def generate( + self, + inputs: List[str], + do_sample=True, + **kwargs, + ): + for chunk in self.stream_generate(inputs, do_sample, **kwargs): + response = chunk + return response + + def stream_generate( + self, + inputs: List[str], + do_sample=True, + **kwargs, + ): + import torch + from torch import nn + with torch.no_grad(): + batched = True + if isinstance(inputs, str): + inputs = [inputs] + batched = False + # import pdb; pdb.set_trace() + inputs = self.tokenizer( + inputs, padding=True, return_tensors='pt', return_length=True) + input_length = inputs['length'] + for k, v in inputs.items(): + inputs[k] = v.cuda() + input_ids = inputs['input_ids'] + attention_mask = inputs['attention_mask'] + batch_size, input_ids_seq_length = input_ids.shape[ + 0], input_ids.shape[-1] # noqa: F841 # pylint: disable=W0612 + generation_config = self.model.generation_config + generation_config = copy.deepcopy(generation_config) + new_gen_params = self.update_gen_params(**kwargs) + generation_config.update(**new_gen_params) + generation_config.update(**kwargs) + model_kwargs = generation_config.to_dict() + model_kwargs['attention_mask'] = attention_mask + _, eos_token_id = ( # noqa: F841 # pylint: disable=W0612 + generation_config.bos_token_id, + generation_config.eos_token_id, + ) + if isinstance(eos_token_id, int): + eos_token_id = [eos_token_id] + if self.additional_eos_token_id is not None: + eos_token_id.extend(self.additional_eos_token_id) + eos_token_id_tensor = torch.tensor(eos_token_id).to( + input_ids.device) if eos_token_id is not None else None + has_default_max_length = ( + kwargs.get('max_length') is None + and generation_config.max_length is not None) + if has_default_max_length and generation_config.max_new_tokens is None: + warnings.warn( + f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. " + 'This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we' + ' recommend using `max_new_tokens` to control the maximum length of the generation.', + UserWarning, + ) + elif generation_config.max_new_tokens is not None: + generation_config.max_length = ( + generation_config.max_new_tokens + input_ids_seq_length) + if not has_default_max_length: + logger.warn( # pylint: disable=W4902 + f'Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(=' + f'{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. ' + 'Please refer to the documentation for more information. 
' + '(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)', + UserWarning, + ) + + if input_ids_seq_length >= generation_config.max_length: + input_ids_string = 'input_ids' + logger.warning( + f'Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to' + f' {generation_config.max_length}. This can lead to unexpected behavior. You should consider' + ' increasing `max_new_tokens`.') + + # 2. Set generation parameters if not already defined + logits_processor = self.logits_processor + stopping_criteria = self.stopping_criteria + + logits_processor = self.model._get_logits_processor( + generation_config=generation_config, + input_ids_seq_length=input_ids_seq_length, + encoder_input_ids=input_ids, + prefix_allowed_tokens_fn=self.prefix_allowed_tokens_fn, + logits_processor=logits_processor, + ) + + stopping_criteria = self.model._get_stopping_criteria( + generation_config=generation_config, + stopping_criteria=stopping_criteria) + logits_warper = self.model._get_logits_warper(generation_config) + + unfinished_sequences = input_ids.new(batch_size).fill_(1) + scores = None + while True: + model_inputs = self.model.prepare_inputs_for_generation( + input_ids, **model_kwargs) + # forward pass to get next token + outputs = self.model( + **model_inputs, + return_dict=True, + output_attentions=False, + output_hidden_states=False, + ) + + next_token_logits = outputs.logits[:, -1, :] + + # pre-process distribution + next_token_scores = logits_processor(input_ids, + next_token_logits) + next_token_scores = logits_warper(input_ids, next_token_scores) + + # sample + probs = nn.functional.softmax(next_token_scores, dim=-1) + if do_sample: + next_tokens = torch.multinomial( + probs, num_samples=1).squeeze(1) + else: + next_tokens = torch.argmax(probs, dim=-1) + + # update generated ids, model inputs, and length for next step + input_ids = torch.cat([input_ids, next_tokens[:, None]], + dim=-1) + model_kwargs = self.model._update_model_kwargs_for_generation( + outputs, model_kwargs, is_encoder_decoder=False) + unfinished_sequences = unfinished_sequences.mul( + next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne( + eos_token_id_tensor.unsqueeze(1)).prod(dim=0)) + # output_token_ids = input_ids.cpu()[:, input_length:].tolist() + output_token_ids = input_ids.cpu().tolist() + for i in range(len(output_token_ids)): + output_token_ids[i] = output_token_ids[i][:][ + input_length[i]:] + # Find the first occurrence of an EOS token in the sequence + first_eos_idx = next( + (idx + for idx, token_id in enumerate(output_token_ids[i]) + if token_id in eos_token_id), None) + # If an EOS token is found, only the previous part of it is retained + if first_eos_idx is not None: + output_token_ids[i] = output_token_ids[ + i][:first_eos_idx] + + response = self.tokenizer.batch_decode(output_token_ids) + # print(response) + if not batched: + yield response[0] + else: + yield response + # stop when each sentence is finished, or if we exceed the maximum length + if (unfinished_sequences.max() == 0 + or stopping_criteria(input_ids, scores)): + break class HFTransformerCasualLM(HFTransformer): def _load_model(self, path: str, model_kwargs: dict): + import torch from transformers import AutoModelForCausalLM model_kwargs.setdefault('torch_dtype', torch.float16) self.model = AutoModelForCausalLM.from_pretrained( From 799c6a32c46fa2aa2f7efb860f4822fd125327d8 Mon Sep 17 00:00:00 2001 From: BraisedPork <46232992+braisedpork1964@users.noreply.github.com> Date: Thu, 25 Jan 2024 
11:00:54 +0800 Subject: [PATCH 10/20] [Feature] support building tool descriptions automatically (#96) * redundancy reduction * add `tool_api` to annotate a tool method * improve json parsing * enhance parsers * update README.md --------- Co-authored-by: wangzy --- README.md | 8 +- lagent/actions/__init__.py | 46 ++- lagent/actions/arxiv_search.py | 51 ++-- lagent/actions/base_action.py | 320 +++++++++++++++++---- lagent/actions/bing_map.py | 138 ++++----- lagent/actions/builtin_actions.py | 28 +- lagent/actions/google_scholar_search.py | 357 ++++++------------------ lagent/actions/google_search.py | 26 +- lagent/actions/parser.py | 33 ++- lagent/actions/ppt.py | 157 ++++------- lagent/actions/python_interpreter.py | 49 ++-- lagent/llms/base_llm.py | 25 +- requirements/runtime.txt | 2 + 13 files changed, 590 insertions(+), 650 deletions(-) diff --git a/README.md b/README.md index 6f679853..b38ba072 100644 --- a/README.md +++ b/README.md @@ -83,7 +83,7 @@ Below is an example of running ReWOO with GPT-3.5 ```python # Import necessary modules and classes from the "lagent" library. from lagent.agents import ReWOO -from lagent.actions import ActionExecutor, GoogleSearch, LLMQA +from lagent.actions import ActionExecutor, GoogleSearch from lagent.llms import GPTAPI # Initialize the Language Model (llm) and provide your API key. @@ -92,14 +92,11 @@ llm = GPTAPI(model_type='gpt-3.5-turbo', key=['Your OPENAI_API_KEY']) # Initialize the Google Search tool and provide your API key. search_tool = GoogleSearch(api_key='Your SERPER_API_KEY') -# Initialize the LLMQA tool using the Language Model (llm). -llmqa_tool = LLMQA(llm) - # Create a chatbot by configuring the ReWOO agent. chatbot = ReWOO( llm=llm, # Provide the Language Model instance. action_executor=ActionExecutor( - actions=[search_tool, llmqa_tool] # Specify the actions the chatbot can perform. + actions=[search_tool] # Specify the actions the chatbot can perform. ), ) @@ -154,6 +151,7 @@ response = chatbot.chat( print(response.response) # Output the response generated by the chatbot. >>> $-\\frac{1}{3}+\\frac{{\\sqrt{3}}}{3}i$ ``` + ### All Thanks To Our Contributors: diff --git a/lagent/actions/__init__.py b/lagent/actions/__init__.py index 7146cfbf..4737418f 100644 --- a/lagent/actions/__init__.py +++ b/lagent/actions/__init__.py @@ -1,6 +1,8 @@ +from typing import Type + from .action_executor import ActionExecutor from .arxiv_search import ArxivSearch -from .base_action import BaseAction +from .base_action import TOOL_REGISTRY, BaseAction, tool_api from .bing_map import BINGMap from .builtin_actions import FinishAction, InvalidAction, NoAction from .google_scholar_search import GoogleScholar @@ -13,5 +15,45 @@ 'BaseAction', 'ActionExecutor', 'InvalidAction', 'FinishAction', 'NoAction', 'BINGMap', 'ArxivSearch', 'FinishAction', 'GoogleSearch', 'GoogleScholar', 'PythonInterpreter', 'PPT', 'BaseParser', 'JsonParser', - 'TupleParser' + 'TupleParser', 'tool_api', 'list_tools', 'get_tool_cls', 'get_tool' ] + + +def list_tools(with_class: bool = False): + """List available tools + + Args: + with_class (bool): whether to return the action class along + with its name. Defaults to ``False``. 
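The registry helpers being added here (``list_tools`` above, ``get_tool_cls`` and ``get_tool`` just below) are easiest to see in use. A sketch; the names returned depend on which action classes have been imported and registered:

```python
from lagent.actions import get_tool, list_tools

print(list_tools())  # e.g. ['ArxivSearch', 'BINGMap', 'GoogleSearch', ...]

# Resolved by name through TOOL_REGISTRY; equivalent to
# ArxivSearch(top_k_results=1).
arxiv = get_tool('ArxivSearch', top_k_results=1)
print(arxiv.name)  # 'ArxivSearch'
```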
+ + Returns: + list: all action names + """ + return list(TOOL_REGISTRY.items()) if with_class else list( + TOOL_REGISTRY.keys()) + + +def get_tool_cls(specifier: str) -> Type[BaseAction]: + """Get the action class + + Args: + specifier (:class:`str`): tool name + + Returns: + Type[BaseAction]: action class + """ + return TOOL_REGISTRY.get_class(specifier) + + +def get_tool(specifier: str, *args, **kwargs) -> BaseAction: + """Instantiate an action + + Args: + specifier (str): tool name + args: positional arguments passed to the action's ``__init__`` method + kwargs: keyword arguments passed to the action's ``__init__`` method + + Returns: + :class:`BaseAction`: action object + """ + return TOOL_REGISTRY.get(specifier, *args, **kwargs) diff --git a/lagent/actions/arxiv_search.py b/lagent/actions/arxiv_search.py index c2d5c9d6..0d833ccc 100644 --- a/lagent/actions/arxiv_search.py +++ b/lagent/actions/arxiv_search.py @@ -2,41 +2,17 @@ import arxiv -from lagent.actions.base_action import BaseAction +from lagent.actions.base_action import BaseAction, tool_api from lagent.actions.parser import BaseParser, JsonParser from lagent.schema import ActionReturn, ActionStatusCode -DEFAULT_DESCRIPTION = dict( - name='ArxivSearch', - description='Search information from Arxiv.org ' - 'Useful for when you need to answer questions about Physics, Mathematics, ' - 'Computer Science, Quantitative Biology, Quantitative Finance, Statistics, ' - 'Electrical Engineering, and Economics ' - 'from scientific articles on arxiv.org', - api_list=[ - dict( - name='get_arxiv_article_information', - description= - 'Run Arxiv search and get the article meta information.', - parameters=[ - dict( - name='query', - type='STRING', - description='the content of search query') - ], - required=['query'], - return_data=[ - dict( - name='content', - description='a list of 3 arxiv search papers'), - ], - ) - ], -) - class ArxivSearch(BaseAction): - """ArxivSearch action""" + """Search information from Arxiv.org. \ +Useful for when you need to answer questions about Physics, Mathematics, \ +Computer Science, Quantitative Biology, Quantitative Finance, Statistics, \ +Electrical Engineering, and Economics from scientific articles on arxiv.org. + """ def __init__(self, top_k_results: int = 3, @@ -44,13 +20,22 @@ def __init__(self, doc_content_chars_max: int = 1500, description: Optional[dict] = None, parser: Type[BaseParser] = JsonParser, - enable: bool = True) -> None: - super().__init__(description or DEFAULT_DESCRIPTION, parser, enable) + enable: bool = True): + super().__init__(description, parser, enable) self.top_k_results = top_k_results self.max_query_len = max_query_len self.doc_content_chars_max = doc_content_chars_max - def get_arxiv_article_information(self, query: str): + @tool_api(return_dict=True) + def get_arxiv_article_information(self, query: str) -> dict: + """Run Arxiv search and get the article meta information. 
+
+ Args:
+ query (:class:`str`): the content of search query
+
+ Returns:
+ content (:class:`str`): a list of 3 arxiv search papers
+ """
+ try:
+ results = arxiv.Search( # type: ignore
+ query[:self.max_query_len],
+ max_results=self.top_k_results).results()
diff --git a/lagent/actions/base_action.py b/lagent/actions/base_action.py
index 7f9bb1d6..9561c7a0 100644
--- a/lagent/actions/base_action.py
+++ b/lagent/actions/base_action.py
@@ -1,10 +1,231 @@
-from typing import Optional, Type
+import inspect
+import logging
+import re
+from abc import ABCMeta
+from copy import deepcopy
+from functools import wraps
+from typing import Annotated, Callable, Optional, Type, get_args, get_origin

-from lagent.actions.parser import BaseParser, JsonParser, ParseError
-from lagent.schema import ActionReturn, ActionStatusCode
+from class_registry import AutoRegister, ClassRegistry
+from griffe import Docstring
+from griffe.enumerations import DocstringSectionKind

+from ..schema import ActionReturn, ActionStatusCode
+from .parser import BaseParser, JsonParser, ParseError

-class BaseAction:
+logging.getLogger('griffe').setLevel(logging.ERROR)
+
+TOOL_REGISTRY = ClassRegistry('__tool_name__', unique=True)
+
+
+def tool_api(func: Optional[Callable] = None,
+ *,
+ return_dict: bool = False,
+ returns_named_value: bool = False,
+ **kwargs):
+ """Turn functions into tools. It will parse type hints as well as docstrings
+ to build the tool description and attach it to functions via an attribute
+ ``api_description``.
+
+ Examples:
+
+ .. code-block:: python
+
+ # type hints take priority over docstrings
+ from typing import Annotated
+
+ @tool_api
+ def add(a: Annotated[int, 'augend'], b: Annotated[int, 'addend'] = 1):
+ '''Add operation
+
+ Args:
+ x (int): a
+ y (int): b
+ '''
+ return a + b
+
+ print(add.api_description)
+
+ Args:
+ func (Optional[Callable]): function to decorate. Defaults to ``None``.
+ return_dict (bool): indicate whether the returned data is a single dictionary.
+ When enabled, the returns sections in docstrings should describe the
+ key-value information of the dictionary rather than hint at a standard
+ tuple return. Defaults to ``False``.
+
+ .. code-block:: python
+
+ # setting `return_dict` to True forces `returns_named_value` to be enabled
+ @tool_api(return_dict=True)
+ def foo(a, b):
+ '''A simple function
+
+ Args:
+ a (int): a
+ b (int): b
+
+ Returns:
+ x: the value of input a
+ y: the value of input b
+ '''
+ return {'x': a, 'y': b}
+
+ print(foo.api_description)
+
+ returns_named_value (bool): whether to parse ``thing: Description`` in
+ returns sections as a name and description, rather than a type and
+ description. When true, type must be wrapped in parentheses:
+ ``(int): Description.``. When false, parentheses are optional but
+ the items cannot be named: ``int: Description``. Defaults to ``False``.
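What the decorator produces is easiest to see on a toy function. A sketch following the parsing rules above; the commented schema is the expected shape, paraphrased rather than verbatim output:

```python
from typing import Annotated

from lagent.actions import tool_api

@tool_api
def scale(x: Annotated[int, 'the value'], factor: float = 2.0) -> float:
    """Scale a number.

    Args:
        x (int): the value
        factor (float): multiplier
    """
    return x * factor

# scale.api_description should come out roughly as:
# {'name': 'scale',
#  'description': 'Scale a number.',
#  'parameters': [
#      {'name': 'x', 'type': 'NUMBER', 'description': 'the value'},
#      {'name': 'factor', 'type': 'FLOAT', 'description': 'multiplier'}],
#  'required': ['x']}
```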
+ + Returns: + Callable: wrapped function or partial decorator + """ + if return_dict: + returns_named_value = True + + def _detect_type(string): + field_type = 'STRING' + if 'list' in string: + field_type = 'Array' + elif 'str' not in string: + if 'float' in string: + field_type = 'FLOAT' + elif 'int' in string: + field_type = 'NUMBER' + elif 'bool' in string: + field_type = 'BOOLEAN' + return field_type + + def _parse_tool(function): + # remove rst syntax + docs = Docstring( + re.sub(':(.+?):`(.+?)`', '\\2', function.__doc__ or '')).parse( + 'google', returns_named_value=returns_named_value, **kwargs) + desc = dict( + name=function.__name__, + description=docs[0].value + if docs[0].kind is DocstringSectionKind.text else '', + parameters=[], + required=[], + ) + args_doc, returns_doc = {}, [] + for doc in docs: + if doc.kind is DocstringSectionKind.parameters: + for d in doc.value: + d = d.as_dict() + d['description'] = d['description'] + d['type'] = _detect_type(d.pop('annotation').lower()) + args_doc[d['name']] = d + if doc.kind is DocstringSectionKind.returns: + for d in doc.value: + d = d.as_dict() + d['description'] = d['description'] + if not d['name']: + d.pop('name') + if not d['annotation']: + d.pop('annotation') + else: + d['type'] = _detect_type(d.pop('annotation').lower()) + returns_doc.append(d) + + sig = inspect.signature(function) + for name, param in sig.parameters.items(): + if name == 'self': + continue + parameter = dict( + name=param.name, + type='STRING', + description=args_doc.get(param.name, + {}).get('description', '')) + annotation = param.annotation + if annotation is inspect.Signature.empty: + parameter['type'] = args_doc.get(param.name, + {}).get('type', 'STRING') + else: + if get_origin(annotation) is Annotated: + annotation, info = get_args(annotation) + if info: + parameter['description'] = info + while get_origin(annotation): + annotation = get_args(annotation) + parameter['type'] = _detect_type(str(annotation)) + desc['parameters'].append(parameter) + if param.default is inspect.Signature.empty: + desc['required'].append(param.name) + + return_data, return_annotation = [], sig.return_annotation + if return_dict: + return_data = returns_doc + elif return_annotation is not inspect.Signature.empty: + if return_annotation is tuple: + return_data = returns_doc + elif get_origin(return_annotation) is tuple: + return_annotation = get_args(return_annotation) + if not return_annotation: + return_data = returns_doc + elif len(return_annotation) >= 2: + for i, item in enumerate(return_annotation): + info = returns_doc[i]['description'] if i < len( + returns_doc) else '' + if get_origin(item) is Annotated: + item, info = get_args(item) + return_data.append({ + 'description': info, + 'type': _detect_type(str(item)) + }) + if return_data: + desc['return_data'] = return_data + return desc + + if callable(func): + + @wraps(func) + def wrapper(self, *args, **kwargs): + return func(self, *args, **kwargs) + + wrapper.api_description = _parse_tool(func) + return wrapper + + def decorate(func): + + @wraps(func) + def wrapper(self, *args, **kwargs): + return func(self, *args, **kwargs) + + wrapper.api_description = _parse_tool(func) + return wrapper + + return decorate + + +class ToolMeta(ABCMeta): + """Metaclass of tools""" + + def __new__(mcs, name, base, attrs): + is_toolkit, tool_desc = True, dict( + name=attrs.setdefault('__tool_name__', name), + description=Docstring(attrs.get('__doc__', + '')).parse('google')[0].value) + for key, value in attrs.items(): + if 
callable(value) and hasattr(value, 'api_description'): + api_desc = getattr(value, 'api_description') + if key == 'run': + tool_desc['parameters'] = api_desc['parameters'] + tool_desc['required'] = api_desc['required'] + if api_desc['description']: + tool_desc['description'] = api_desc['description'] + if api_desc.get('return_data'): + tool_desc['return_data'] = api_desc['return_data'] + is_toolkit = False + break + tool_desc.setdefault('api_list', []).append(api_desc) + attrs['_is_toolkit'] = is_toolkit + attrs['__tool_description__'] = tool_desc + return super().__new__(mcs, name, base, attrs) + + +class BaseAction(metaclass=AutoRegister(TOOL_REGISTRY, ToolMeta)): """Base class for all actions. Args: @@ -14,7 +235,7 @@ class BaseAction: action's inputs and outputs. Defaults to :class:`JsonParser`. enable (:class:`bool`): Whether the action is enabled. Defaults to ``True``. - + Examples: * simple tool @@ -22,64 +243,65 @@ class BaseAction: .. code-block:: python class Bold(BaseAction): - def run(self, text): + '''Make text bold''' + + @tool_api + def run(self, text: str): + ''' + Args: + text (str): input text + + Returns: + str: bold text + ''' return '**' + text + '**' - desc = dict( - name='bold', - description='make text bold', - parameters=[dict(name='text', type='STRING', description='input text')], - required=['text'], - ) - action = Bold(desc) + action = Bold() * toolkit with multiple APIs .. code-block:: python class Calculator(BaseAction): + '''Calculator''' + + @tool_api def add(self, a, b): - return a + b + '''Add operation + Args: + a (int): augend + b (int): addend + + Returns: + int: sum + ''' + return a + b + + @tool_api def sub(self, a, b): + '''Subtraction operation + + Args: + a (int): minuend + b (int): subtrahend + + Returns: + int: difference + ''' return a - b - desc = dict( - name='calculate', - description='perform arithmetic operations', - api_list=[ - dict( - name='add', - descrition='addition operation', - parameters=[ - dict(name='a', type='NUMBER', description='augend'), - dict(name='b', type='NUMBER', description='addend'), - ], - required=['a', 'b'], - ), - dict( - name='sub', - description='subtraction operation', - parameters=[ - dict(name='a', type='NUMBER', description='minuend'), - dict(name='b', type='NUMBER', description='subtrahend'), - ], - required=['a', 'b'], - ) - ] - ) - action = Calculator(desc) + action = Calculator() """ def __init__(self, description: Optional[dict] = None, parser: Type[BaseParser] = JsonParser, enable: bool = True): - self._description = description.copy() if description else {} - self._name = self._description.get('name', self.__class__.__name__) - self._enable = enable - self._is_toolkit = 'api_list' in self._description + self._description = deepcopy(description or self.__tool_description__) + self._name = self._description['name'] self._parser = parser(self) + self._enable = enable def __call__(self, inputs: str, name='run') -> ActionReturn: fallback_args = {'inputs': inputs, 'name': name} @@ -116,9 +338,6 @@ def __call__(self, inputs: str, name='run') -> ActionReturn: action_return = ActionReturn(inputs, type=self.name, result=result) return action_return - def run(self): - return NotImplementedError - @property def name(self): return self._name @@ -127,14 +346,15 @@ def name(self): def enable(self): return self._enable - @property - def description(self): - return self._description - @property def is_toolkit(self): return self._is_toolkit + @property + def description(self) -> dict: + """Description of the tool""" + 
return self._description + def __repr__(self): return f'{self.description}' diff --git a/lagent/actions/bing_map.py b/lagent/actions/bing_map.py index 7ebaff2a..6906cb62 100644 --- a/lagent/actions/bing_map.py +++ b/lagent/actions/bing_map.py @@ -4,98 +4,9 @@ import requests -from lagent.actions.base_action import BaseAction +from lagent.actions.base_action import BaseAction, tool_api from lagent.actions.parser import BaseParser, JsonParser -DEFAULT_DESCRIPTION = dict( - name='BINGMap', - description='Plugin for looking up map information', - api_list=[ - dict( - name='get_distance', - description='Get the distance between two locations in km.', - parameters=[ - dict( - name='start', - type='STRING', - description='The start location.'), - dict( - name='end', type='STRING', description='The end location.') - ], - required=['start', 'end'], - return_data=[ - dict(name='distance', description='the distance in km.') - ]), - dict( - name='get_route', - description='Get the route between two locations in km.', - parameters=[ - dict( - name='start', - type='STRING', - description='The start location.'), - dict( - name='end', type='STRING', description='The end location.') - ], - required=['start', 'end'], - return_data=[ - dict( - name='route', description='the route, a list of actions.') - ]), - dict( - name='get_coordinates', - description='Get the coordinates of a location.', - parameters=[ - dict( - name='location', - type='STRING', - description='the location need to get coordinates.') - ], - required=['location'], - return_data=[ - dict( - name='latitude', - description='the latitude of the location.'), - dict( - name='longitude', - description='the longitude of the location.') - ]), - dict( - name='search_nearby', - description= - 'Search for places nearby a location, within a given radius, and return the results into a list. You can use either the places name or the latitude and longitude.', - parameters=[ - dict( - name='search_term', - type='STRING', - description='the place name'), - dict( - name='places', - type='STRING', - description='the name of the location.'), - dict( - name='latitude', - type='FLOAT', - description='the latitude of the location.'), - dict( - name='longitude', - type='FLOAT', - description='the longitude of the location.'), - dict( - name='radius', - type='NUMBER', - description='radius in meters.') - ], - required=['search_term'], - return_data=[ - dict( - name='places', - description= - 'the list of places, each place is a dict with name and address, at most 5 places.' - ) - ]), - ]) - class BINGMap(BaseAction): """BING Map plugin for looking up map information""" @@ -105,7 +16,7 @@ def __init__(self, description: Optional[dict] = None, parser: Type[BaseParser] = JsonParser, enable: bool = True) -> None: - super().__init__(description or DEFAULT_DESCRIPTION, parser, enable) + super().__init__(description, parser, enable) key = os.environ.get('BING_MAP_KEY') if key is None: raise ValueError( @@ -114,7 +25,17 @@ def __init__(self, self.key = key self.base_url = 'http://dev.virtualearth.net/REST/V1/' + @tool_api(return_dict=True) def get_distance(self, start: str, end: str) -> dict: + """Get the distance between two locations in km. + + Args: + start (:class:`str`): The start location + end (:class:`str`): The end location + + Returns: + distance (:class:`str`): the distance in km. 
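As with every ``tool_api`` method, ``get_distance`` can be reached either directly or through ``BaseAction.__call__`` with JSON arguments. A sketch for the API documented above; the key and locations are placeholders, and the exact shape of ``result`` depends on the parser in use:

```python
import os

from lagent.actions import BINGMap

os.environ.setdefault('BING_MAP_KEY', '<your-bing-maps-key>')  # placeholder
bing = BINGMap()

# Dispatch by API name with JSON arguments, as an agent would.
ret = bing('{"start": "Seattle", "end": "Portland"}', name='get_distance')
print(ret.result)
```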
+ """ # Request URL url = self.base_url + 'Routes/Driving?o=json&wp.0=' + start + '&wp.1=' + end + '&key=' + self.key # GET request @@ -127,7 +48,17 @@ def get_distance(self, start: str, end: str) -> dict: distance = route['travelDistance'] return dict(distance=distance) + @tool_api(return_dict=True) def get_route(self, start: str, end: str) -> dict: + """Get the route between two locations in km. + + Args: + start (:class:`str`): The start location + end (:class:`str`): The end location + + Returns: + route (:class:`list`): the route, a list of actions. + """ # Request URL url = self.base_url + 'Routes/Driving?o=json&wp.0=' + start + '&wp.1=' + end + '&key=' + self.key # GET request @@ -143,7 +74,17 @@ def get_route(self, start: str, end: str) -> dict: route_text.append(item['instruction']['text']) return dict(route=route_text) + @tool_api(return_dict=True) def get_coordinates(self, location: str) -> dict: + """Get the coordinates of a location. + + Args: + location (:class:`str`): the location need to get coordinates. + + Returns: + latitude (:class:`float`): the latitude of the location. + longitude (:class:`float`): the longitude of the location. + """ url = self.base_url + 'Locations' params = {'query': location, 'key': self.key} response = requests.get(url, params=params) @@ -152,12 +93,27 @@ def get_coordinates(self, location: str) -> dict: 'coordinates'] return dict(latitude=coordinates[0], longitude=coordinates[1]) + @tool_api(return_dict=True) def search_nearby(self, search_term: str, places: str = 'unknown', latitude: float = 0.0, longitude: float = 0.0, radius: int = 5000) -> dict: # radius in meters + """Search for places nearby a location, within a given radius, and \ +return the results into a list. You can use either the places name or the \ +latitude and longitude. + + Args: + search_term (:class:`str`): the place name. + places (:class:`str`): the name of the location. Defaults to ``'unknown'``. + latitude (:class:`float`): the latitude of the location. Defaults to ``0.0``. + longitude (:class:`float`): the longitude of the location. Defaults to ``0.0``. + radius (:class:`int`): radius in meters. Defaults to ``5000``. + + Returns: + places (:class:`list`): the list of places, each place is a dict with name and address, at most 5 places. + """ url = self.base_url + 'LocalSearch' if places != 'unknown': pos = self.get_coordinates(**{'location': places}) diff --git a/lagent/actions/builtin_actions.py b/lagent/actions/builtin_actions.py index 76aea6c1..805f99f1 100644 --- a/lagent/actions/builtin_actions.py +++ b/lagent/actions/builtin_actions.py @@ -1,6 +1,7 @@ from typing import Optional -from lagent.actions.base_action import BaseAction +from lagent.actions.base_action import BaseAction, tool_api +from lagent.actions.parser import BaseParser from lagent.schema import ActionReturn, ActionStatusCode, ActionValidCode @@ -19,12 +20,13 @@ class InvalidAction(BaseAction): def __init__(self, err_msg: str = 'The action is invalid, please check the action name.', - **kwargs) -> None: - - super().__init__(enable=False, **kwargs) + description: Optional[dict] = None, + parser=BaseParser) -> None: + super().__init__(description, parser, enable=False) self._err_msg = err_msg - def __call__(self, err_msg: Optional[str] = None): + @tool_api + def run(self, err_msg: Optional[str] = None) -> ActionReturn: """Return the error message. Args: @@ -51,11 +53,15 @@ class NoAction(BaseAction): 'Please follow the format'. 
""" - def __init__(self, err_msg: str = 'Please follow the format', **kwargs): - super().__init__(enable=False, **kwargs) + def __init__(self, + err_msg: str = 'Please follow the format', + description: Optional[dict] = None, + parser=BaseParser): + super().__init__(description, parser, enable=False) self._err_msg = err_msg - def __call__(self, err_msg: Optional[str] = None): + @tool_api + def run(self, err_msg: Optional[str] = None) -> ActionReturn: """Return the error message. Args: @@ -80,7 +86,11 @@ class FinishAction(BaseAction): """This is a finish action class, which is used to return the final result.""" - def __call__(self, response: str) -> ActionReturn: + def __init__(self, description: Optional[dict] = None, parser=BaseParser): + super().__init__(description, parser, enable=True) + + @tool_api + def run(self, response: str) -> ActionReturn: """Return the final result. Args: diff --git a/lagent/actions/google_scholar_search.py b/lagent/actions/google_scholar_search.py index 941870e6..9209a0e6 100644 --- a/lagent/actions/google_scholar_search.py +++ b/lagent/actions/google_scholar_search.py @@ -3,282 +3,13 @@ from serpapi import GoogleSearch -from lagent.actions.base_action import BaseAction +from lagent.actions.base_action import BaseAction, tool_api from lagent.schema import ActionReturn, ActionStatusCode from .parser import BaseParser, JsonParser -DEFAULT_DESCRIPTION = dict( - name='GoogleScholar', - description='Plugin for google scholar search', - api_list=[{ - 'name': - 'search_google_scholar', - 'description': - 'Search for scholarly articles based on a query according to the google scholar', - 'parameters': [ - { - 'name': 'query', - 'description': 'The query to search for.', - 'type': 'STRING' - }, - { - 'name': 'cites', - 'description': - 'The unique ID of an article for triggering "Cited By" searches', - 'type': 'STRING' - }, - { - 'name': 'as_ylo', - 'description': - 'The starting year for results (e.g., if as_ylo=2018, results before this year will be omitted)', - 'type': 'NUMBER' - }, - { - 'name': 'as_yhi', - 'description': - 'The ending year for results (e.g., if as_yhi=2018, results after this year will be omitted)', - 'type': 'NUMBER' - }, - { - 'name': 'scisbd', - 'description': - 'Defines articles added in the last year, sorted by date. 
It can be set to 1 to include only abstracts, or 2 to include everything', - 'type': 'NUMBER' - }, - { - 'name': 'cluster', - 'description': - 'The unique ID of an article for triggering "All Versions" searches', - 'type': 'STRING' - }, - { - 'name': 'hl', - 'description': - 'The language to use for the Google Scholar search', - 'type': 'STRING' - }, - { - 'name': 'lr', - 'description': - 'One or multiple languages to limit the search to', - 'type': 'STRING' - }, - { - 'name': 'start', - 'description': - 'The result offset for pagination (0 is the first page of results, 10 is the 2nd page, etc.)', - 'type': 'NUMBER' - }, - { - 'name': 'num', - 'description': - 'The maximum number of results to return, limited to 20', - 'type': 'NUMBER' - }, - { - 'name': 'as_sdt', - 'description': - 'Can be used either as a search type or a filter', - 'type': 'STRING' - }, - { - 'name': 'safe', - 'description': 'The level of filtering for adult content', - 'type': 'STRING' - }, - { - 'name': 'filter', - 'description': - "Defines if the filters for 'Similar Results' and 'Omitted Results' are on or off", - 'type': 'STRING' - }, - { - 'name': 'as_vis', - 'description': 'Defines whether to include citations or not', - 'type': 'STRING' - }, - ], - 'required': ['query'], - 'return_data': [{ - 'name': - 'title', - 'description': - 'a list of the titles of the three selected papers' - }, { - 'name': - 'cited_by', - 'description': - 'a list of the citation numbers of the three selected papers' - }, { - 'name': - 'organic_id', - 'description': - 'a list of the organic results\' ids of the three selected papers' - }, { - 'name': 'snippets', - 'description': 'snippets of the papers' - }, { - 'name': - 'pub_info', - 'description': - 'publication information of selected papers' - }] - }, { - 'name': - 'get_author_information', - 'description': - 'Search for an author\'s information by author\'s id provided by get_author_id.', - 'parameters': [{ - 'name': 'author_id', - 'description': 'Required. The ID of an author.', - 'type': 'STRING' - }, { - 'name': 'hl', - 'description': - "The language to use for the Google Scholar Author search. Default is 'en'.", - 'type': 'STRING' - }, { - 'name': 'view_op', - 'description': 'Used for viewing specific parts of a page.', - 'type': 'STRING' - }, { - 'name': 'sort', - 'description': 'Used for sorting and refining articles.', - 'type': 'STRING' - }, { - 'name': 'citation_id', - 'description': 'Used for retrieving individual article citation.', - 'type': 'STRING' - }, { - 'name': 'start', - 'description': 'Defines the result offset. Default is 0.', - 'type': 'NUMBER' - }, { - 'name': 'num', - 'description': - 'Defines the number of results to return. Default is 20.', - 'type': 'NUMBER' - }, { - 'name': 'no_cache', - 'description': - 'Forces SerpApi to fetch the results even if a cached version is already present. Default is False.', - 'type': 'BOOLEAN' - }, { - 'name': 'async_req', - 'description': - 'Defines the way you want to submit your search to SerpApi. Default is False.', - 'type': 'BOOLEAN' - }, { - 'name': 'output', - 'description': - "Defines the final output you want. 
Default is 'json'.", - 'type': 'STRING' - }], - 'required': ['author_id'], - 'return_data': [{ - 'name': 'name', - 'description': "author's name" - }, { - 'name': 'affliation', - 'description': 'the affliation of the author' - }, { - 'name': 'articles', - 'description': 'at most 3 articles by the author' - }, { - 'name': 'website', - 'description': "the author's homepage url" - }] - }, { - 'name': - 'get_citation_format', - 'description': - 'Function to get MLA citation format by an identification of organic_result\'s id provided by search_google_scholar.', - 'parameters': [{ - 'name': 'q', - 'description': - 'ID of an individual Google Scholar organic search result.', - 'type': 'STRING' - }, { - 'name': 'no_cache', - 'description': - 'If set to True, will force SerpApi to fetch the Google Scholar Cite results even if a cached version is already present. Defaults to None.', - 'type': 'BOOLEAN' - }, { - 'name': 'async_', - 'description': - 'If set to True, will submit search to SerpApi and retrieve results later. Defaults to None.', - 'type': 'BOOLEAN' - }, { - 'name': 'output', - 'description': - "Final output format. Set to 'json' to get a structured JSON of the results, or 'html' to get the raw html retrieved. Defaults to 'json'.", - 'type': 'STRING' - }], - 'required': ['q'], - 'return_data': [{ - 'name': 'authors', - 'description': 'the authors of the article' - }, { - 'name': 'citation', - 'description': 'the citation format of the article' - }] - }, { - 'name': - 'get_author_id', - 'description': - 'The getAuthorId function is used to get the author\'s id by his or her name.', - 'parameters': [{ - 'name': 'mauthors', - 'description': 'Defines the author you want to search for.', - 'type': 'STRING' - }, { - 'name': 'hl', - 'description': - "Defines the language to use for the Google Scholar Profiles search.It's a two-letter language code. (e.g., 'en' for English, 'es' for Spanish, or 'fr' for French). Defaults to 'en'.", - 'type': 'STRING' - }, { - 'name': 'after_author', - 'description': - 'Defines the next page token. It is used for retrieving the next page results. The parameter has the precedence over before_author parameter. Defaults to None.', - 'type': 'STRING' - }, { - 'name': 'before_author', - 'description': - 'Defines the previous page token. It is used for retrieving the previous page results. Defaults to None.', - 'type': 'STRING' - }, { - 'name': 'no_cache', - 'description': - 'Will force SerpApi to fetch the Google Scholar Profiles results even if a cached version is already present. Defaults to False.', - 'type': 'BOOLEAN' - }, { - 'name': '_async', - 'description': - 'Defines the way you want to submit your search to SerpApi. Defaults to False.', - 'type': 'BOOLEAN' - }, { - 'name': 'output', - 'description': - "Defines the final output you want. It can be set to 'json' (default) to get a structured JSON of the results, or 'html' to get the raw html retrieved. Defaults to 'json'.", - 'type': 'STRING' - }], - 'required': ['mauthors'], - 'return_data': [{ - 'name': 'author_id', - 'description': 'the author_id of the author' - }], - }]) - class GoogleScholar(BaseAction): - """Wrapper around the Serper.dev Google Search API. - - To use, you should pass your serper API key to the constructor. 
- - Code is modified from lang-chain GoogleSerperAPIWrapper - (https://github.com/langchain-ai/langchain/blob/ba5f - baba704a2d729a4b8f568ed70d7c53e799bb/libs/langchain/ - langchain/utilities/google_serper.py) + """Plugin for google scholar search Args: api_key (str): API KEY to use serper google search API, @@ -296,7 +27,7 @@ def __init__(self, description: Optional[dict] = None, parser: Type[BaseParser] = JsonParser, enable: bool = True): - super().__init__(description or DEFAULT_DESCRIPTION, parser, enable) + super().__init__(description, parser, enable) api_key = os.environ.get('SERPER_API_KEY', api_key) if api_key is None: raise ValueError( @@ -304,6 +35,7 @@ def __init__(self, 'as SERPER_API_KEY or pass it as `api_key` parameter.') self.api_key = api_key + @tool_api(return_dict=True) def search_google_scholar( self, query: str, @@ -320,7 +52,31 @@ def search_google_scholar( safe: Optional[str] = None, filter: Optional[str] = None, as_vis: Optional[str] = None, - ): + ) -> dict: + """Search for scholarly articles based on a query according to the google scholar + + Args: + query (str): The query to search for. + cites (Optional[str]): The unique ID of an article for triggering "Cited By" searches. + as_ylo (Optional[int]): The starting year for results (e.g., if as_ylo=2018, results before this year will be omitted). + as_yhi (Optional[int]): The ending year for results (e.g., if as_yhi=2018, results after this year will be omitted). + scisbd (Optional[int]): Defines articles added in the last year, sorted by date. It can be set to 1 to include only abstracts, or 2 to include everything. + cluster (Optional[str]): The unique ID of an article for triggering "All Versions" searches. + hl (Optional[str]): The language to use for the Google Scholar search. + lr (Optional[str]): One or multiple languages to limit the search to. + start (Optional[int]): The result offset for pagination (0 is the first page of results, 10 is the 2nd page, etc.) + num (Optional[int]): The maximum number of results to return, limited to 20. + as_sdt (Optional[str]): Can be used either as a search type or a filter. + safe (Optional[str]): The level of filtering for adult content. + filter (Optional[str]): Defines if the filters for 'Similar Results' and 'Omitted Results' are on or off. + as_vis (Optional[str]): Defines whether to include citations or not. + + Returns: + title: a list of the titles of the three selected papers + cited_by: a list of the citation numbers of the three selected papers + organic_id: a list of the organic results' ids of the three selected papers + pub_info: publication information of selected papers + """ params = { 'q': query, 'engine': 'google_scholar', @@ -364,6 +120,7 @@ def search_google_scholar( return ActionReturn( errmsg=str(e), state=ActionStatusCode.HTTP_ERROR) + @tool_api(return_dict=True) def get_author_information(self, author_id: str, hl: Optional[str] = None, @@ -374,7 +131,27 @@ def get_author_information(self, num: Optional[int] = None, no_cache: Optional[bool] = None, async_req: Optional[bool] = None, - output: Optional[str] = None): + output: Optional[str] = None) -> dict: + """Search for an author's information by author's id provided by get_author_id. + + Args: + author_id (str): Required. The ID of an author. + hl (Optional[str]): The language to use for the Google Scholar Author search. Default is 'en'. + view_op (Optional[str]): Used for viewing specific parts of a page. + sort (Optional[str]): Used for sorting and refining articles. 
+ citation_id (Optional[str]): Used for retrieving individual article citation. + start (Optional[int]): Defines the result offset. Default is 0. + num (Optional[int]): Defines the number of results to return. Default is 20. + no_cache (Optional[bool]): Forces SerpApi to fetch the results even if a cached version is already present. Default is False. + async_req (Optional[bool]): Defines the way you want to submit your search to SerpApi. Default is False. + output (Optional[str]): Defines the final output you want. Default is 'json'. + + Returns: + name: author's name + affliation: the affliation of the author + articles: at most 3 articles by the author + website: the author's homepage url + """ params = { 'engine': 'google_scholar_author', 'author_id': author_id, @@ -406,11 +183,24 @@ def get_author_information(self, return ActionReturn( errmsg=str(e), state=ActionStatusCode.HTTP_ERROR) + @tool_api(return_dict=True) def get_citation_format(self, q: str, no_cache: Optional[bool] = None, async_: Optional[bool] = None, - output: Optional[str] = 'json'): + output: Optional[str] = 'json') -> dict: + """Function to get MLA citation format by an identification of organic_result's id provided by search_google_scholar. + + Args: + q (str): ID of an individual Google Scholar organic search result. + no_cache (Optional[bool]): If set to True, will force SerpApi to fetch the Google Scholar Cite results even if a cached version is already present. Defaults to None. + async_ (Optional[bool]): If set to True, will submit search to SerpApi and retrieve results later. Defaults to None. + output (Optional[str]): Final output format. Set to 'json' to get a structured JSON of the results, or 'html' to get the raw html retrieved. Defaults to 'json'. + + Returns: + authors: the authors of the article + citation: the citation format of the article + """ params = { 'q': q, 'engine': 'google_scholar_cite', @@ -429,6 +219,7 @@ def get_citation_format(self, return ActionReturn( errmsg=str(e), state=ActionStatusCode.HTTP_ERROR) + @tool_api(return_dict=True) def get_author_id(self, mauthors: str, hl: Optional[str] = 'en', @@ -436,7 +227,21 @@ def get_author_id(self, before_author: Optional[str] = None, no_cache: Optional[bool] = False, _async: Optional[bool] = False, - output: Optional[str] = 'json'): + output: Optional[str] = 'json') -> dict: + """The getAuthorId function is used to get the author's id by his or her name. + + Args: + mauthors (str): Defines the author you want to search for. + hl (Optional[str]): Defines the language to use for the Google Scholar Profiles search. It's a two-letter language code. (e.g., 'en' for English, 'es' for Spanish, or 'fr' for French). Defaults to 'en'. + after_author (Optional[str]): Defines the next page token. It is used for retrieving the next page results. The parameter has the precedence over before_author parameter. Defaults to None. + before_author (Optional[str]): Defines the previous page token. It is used for retrieving the previous page results. Defaults to None. + no_cache (Optional[bool]): Will force SerpApi to fetch the Google Scholar Profiles results even if a cached version is already present. Defaults to False. + _async (Optional[bool]): Defines the way you want to submit your search to SerpApi. Defaults to False. + output (Optional[str]): Defines the final output you want. It can be set to 'json' (default) to get a structured JSON of the results, or 'html' to get the raw html retrieved. Defaults to 'json'. 
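Chaining the author APIs is the intended workflow: resolve an ID by name with ``get_author_id``, then fetch the profile with ``get_author_information``. A sketch assuming the documented return keys; a real call needs a valid Serper key, and failures come back as ``ActionReturn`` objects rather than dicts:

```python
from lagent.actions import GoogleScholar

scholar = GoogleScholar(api_key='<your-serper-key>')  # placeholder key
found = scholar.get_author_id(mauthors='Geoffrey Hinton')
profile = scholar.get_author_information(author_id=found['author_id'])
print(profile['name'], profile['articles'])
```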
+ + Returns: + author_id: the author_id of the author + """ params = { 'mauthors': mauthors, 'engine': 'google_scholar_profiles', diff --git a/lagent/actions/google_search.py b/lagent/actions/google_search.py index e26d829a..3f5b116a 100644 --- a/lagent/actions/google_search.py +++ b/lagent/actions/google_search.py @@ -4,23 +4,9 @@ import requests from lagent.schema import ActionReturn, ActionStatusCode -from .base_action import BaseAction +from .base_action import BaseAction, tool_api from .parser import BaseParser, JsonParser -DEFAULT_DESCRIPTION = dict( - name='GoogleSearch', - description='一个可以从谷歌搜索结果的API。当你需要对于一个特定问题找到简短明了的回答时,可以使用它。输入应该是一个搜索查询。', - parameters=[ - dict(name='query', type='STRING', description='the search content'), - dict( - name='k', - type='NUMBER', - description= - 'select first k results in the search results as response'), - ], - required=['query'], -) - class GoogleSearch(BaseAction): """Wrapper around the Serper.dev Google Search API. @@ -59,7 +45,7 @@ def __init__(self, description: Optional[dict] = None, parser: Type[BaseParser] = JsonParser, enable: bool = True): - super().__init__(description or DEFAULT_DESCRIPTION, parser, enable) + super().__init__(description, parser, enable) api_key = os.environ.get('SERPER_API_KEY', api_key) if api_key is None: raise ValueError( @@ -69,8 +55,14 @@ def __init__(self, self.timeout = timeout self.search_type = search_type + @tool_api def run(self, query: str, k: int = 10) -> ActionReturn: - """Return the search response.""" + """一个可以从谷歌搜索结果的API。当你需要对于一个特定问题找到简短明了的回答时,可以使用它。输入应该是一个搜索查询。 + + Args: + query (str): the search content + k (int): select first k results in the search results as response + """ tool_return = ActionReturn(type=self.name) status_code, response = self._search(query, k=k) # convert search results to ToolReturn format diff --git a/lagent/actions/parser.py b/lagent/actions/parser.py index d31bc69e..9dbf6104 100644 --- a/lagent/actions/parser.py +++ b/lagent/actions/parser.py @@ -1,6 +1,7 @@ import json +import re from ast import literal_eval -from typing import Any +from typing import Any, Union class ParseError(Exception): @@ -82,11 +83,18 @@ class JsonParser(BaseParser): PARAMETER_DESCRIPTION = '如果调用该工具,你必须使用Json格式 {key: value} 传参,其中key为参数名称' - def parse_inputs(self, inputs: str, name: str = 'run') -> dict: - try: - inputs = json.loads(inputs) - except json.JSONDecodeError as exc: - raise ParseError(f'invalid json format: {inputs}') from exc + def parse_inputs(self, + inputs: Union[str, dict], + name: str = 'run') -> dict: + if not isinstance(inputs, dict): + try: + match = re.search(r'^\s*(```json\n)?(.*)\n```\s*$', inputs, + re.S) + if match: + inputs = match.group(2).strip() + inputs = json.loads(inputs) + except json.JSONDecodeError as exc: + raise ParseError(f'invalid json format: {inputs}') from exc input_keys = set(inputs) all_keys = {param['name'] for param in self._api2param[name]} if not input_keys.issubset(all_keys): @@ -107,11 +115,14 @@ class TupleParser(BaseParser): PARAMETER_DESCRIPTION = '如果调用该工具,你必须使用Tuple格式 (arg1, arg2, arg3) 传参,且参数是有序的' - def parse_inputs(self, inputs: str, name: str = 'run') -> dict: - try: - inputs = literal_eval(inputs) - except Exception as exc: - raise ParseError(f'invalid tuple format: {inputs}') from exc + def parse_inputs(self, + inputs: Union[str, tuple], + name: str = 'run') -> dict: + if not isinstance(inputs, tuple): + try: + inputs = literal_eval(inputs) + except Exception as exc: + raise ParseError(f'invalid tuple format: {inputs}') from exc if 
len(inputs) < len(self._api2required[name]): raise ParseError( f'API takes {len(self._api2required[name])} required positional ' diff --git a/lagent/actions/ppt.py b/lagent/actions/ppt.py index 26715f34..9cd021bc 100644 --- a/lagent/actions/ppt.py +++ b/lagent/actions/ppt.py @@ -2,115 +2,9 @@ from pptx import Presentation -from lagent.actions.base_action import BaseAction +from lagent.actions.base_action import BaseAction, tool_api from lagent.actions.parser import BaseParser, JsonParser -DEFAULT_DESCRIPTION = dict( - name='PPT', - description= - 'This tool allows you to create ppt slides with text, paragraph, images, with good looking styles', - api_list=[ - dict( - name='create_file', - description='Create a pptx file with specific themes', - parameters=[ - dict( - name='theme', type='STRING', description='the theme used'), - dict( - name='abs_location', - type='STRING', - description='the ppt file\'s absolute location') - ], - required=['theme', 'abs_location'], - return_data=[ - dict(name='status', description='the result of the execution') - ]), - dict( - name='get_image', - description= - 'Get an image given comma separated keywords, return the image path.', - parameters=[ - dict( - name='keywords', - type='STRING', - description= - 'the comma separated keywords to describe the image') - ], - required=['keywords'], - return_data=[ - dict(name='status', description='the result of the execution') - ]), - dict( - name='add_first_page', - description='Add the first page of ppt.', - parameters=[ - dict( - name='title', - type='STRING', - description='the title of ppt'), - dict( - name='subtitle', - type='STRING', - description='the subtitle of ppt') - ], - required=['title', 'subtitle'], - return_data=[ - dict(name='status', description='the result of the execution') - ]), - dict( - name='add_text_page', - description='Add text page of ppt', - parameters=[ - dict( - name='title', - type='STRING', - description='the title of the page'), - dict( - name='bullet_items', - type='STRING', - description= - 'bullet_items should be string, for multiple bullet items, please use [SPAN] to separate them.' - ) - ], - required=['title', 'bullet_items'], - return_data=[ - dict(name='status', description='the result of the execution') - ]), - dict( - name='add_text_image_page', - description= - 'Add a text page with one image. Image should be a path', - parameters=[ - dict( - name='title', - type='STRING', - description='the title of the page'), - dict( - name='bullet_items', - type='STRING', - description= - 'bullet_items should be string, for multiple bullet items, please use [SPAN] to separate them.' 
- ), - dict( - name='image', - type='STRING', - description='the path of the image') - ], - required=['title', 'bullet_items', 'image'], - return_data=[ - dict(name='status', description='the result of the execution') - ]), - dict( - name='submit_file', - description= - 'When all steps done, YOU MUST use submit_file() to submit your work.', - parameters=[], - required=[], - return_data=[ - dict(name='status', description='the result of the execution') - ]) - ]) - THEME_MAPPING = { 'Default': { 'template': None, @@ -129,12 +23,22 @@ def __init__(self, description: Optional[dict] = None, parser: Type[BaseParser] = JsonParser, enable: bool = True): - super().__init__(description or DEFAULT_DESCRIPTION, parser, enable) + super().__init__(description, parser, enable) self.theme_mapping = theme_mapping or THEME_MAPPING self.pointer = None self.location = None + @tool_api(return_dict=True) def create_file(self, theme: str, abs_location: str) -> dict: + """Create a pptx file with specific themes + + Args: + theme (:class:`str`): the theme used + abs_location (:class:`str`): the ppt file's absolute location + + Returns: + status: the result of the execution + """ self.location = abs_location try: self.pointer = Presentation(self.theme_mapping[theme]['template']) @@ -144,7 +48,17 @@ def create_file(self, theme: str, abs_location: str) -> dict: print(e) return dict(status='created a ppt file.') + @tool_api(return_dict=True) def add_first_page(self, title: str, subtitle: str) -> dict: + """Add the first page of ppt. + + Args: + title (:class:`str`): the title of ppt + subtitle (:class:`str`): the subtitle of ppt + + Returns: + status: the result of the execution + """ layout_name = self.theme_mapping[ self.pointer.slide_master.name]['title'] layout = next(i for i in self.pointer.slide_master.slide_layouts @@ -156,7 +70,17 @@ def add_first_page(self, title: str, subtitle: str) -> dict: ph_subtitle.text = subtitle return dict(status='added page') + @tool_api(return_dict=True) def add_text_page(self, title: str, bullet_items: str) -> dict: + """Add text page of ppt + + Args: + title (:class:`str`): the title of the page + bullet_items (:class:`str`): bullet_items should be string, for multiple bullet items, please use [SPAN] to separate them. + + Returns: + status: the result of the execution + """ layout_name = self.theme_mapping[ self.pointer.slide_master.name]['single'] layout = next(i for i in self.pointer.slide_master.slide_layouts @@ -175,8 +99,19 @@ def add_text_page(self, title: str, bullet_items: str) -> dict: p.level = 0 return dict(status='added page') + @tool_api(return_dict=True) def add_text_image_page(self, title: str, bullet_items: str, image: str) -> dict: + """Add a text page with one image. Image should be a path + + Args: + title (:class:`str`): the title of the page + bullet_items (:class:`str`): bullet_items should be string, for multiple bullet items, please use [SPAN] to separate them. + image (:class:`str`): the path of the image + + Returns: + status: the result of the execution + """ layout_name = self.theme_mapping[self.pointer.slide_master.name]['two'] layout = next(i for i in self.pointer.slide_master.slide_layouts if i.name == layout_name) @@ -203,7 +138,13 @@ def add_text_image_page(self, title: str, bullet_items: str, return dict(status='added page') + @tool_api(return_dict=True) def submit_file(self) -> dict: + """When all steps done, YOU MUST use submit_file() to submit your work. 
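The toolkit is stateful: ``create_file`` initializes the presentation that the page-adding APIs mutate in order, and ``submit_file`` closes out the session. A sketch of the intended call sequence, assuming the ``Default`` theme resolves in ``THEME_MAPPING``; the path and slide contents are placeholders:

```python
from lagent.actions import PPT

ppt = PPT()
ppt.create_file(theme='Default', abs_location='/tmp/demo.pptx')
ppt.add_first_page(title='Quarterly Report', subtitle='2024 Q1')
ppt.add_text_page(title='Highlights',
                  bullet_items='Revenue up[SPAN]Costs down')
ppt.submit_file()
```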
+ + Returns: + status: the result of the execution + """ # file_path = os.path.join(self.CACHE_DIR, f'{self._return_timestamp()}.pptx') # self.pointer.save(file_path) # retreival_url = upload_file(file_path) diff --git a/lagent/actions/python_interpreter.py b/lagent/actions/python_interpreter.py index 9720ef72..88424a45 100644 --- a/lagent/actions/python_interpreter.py +++ b/lagent/actions/python_interpreter.py @@ -5,7 +5,7 @@ from func_timeout import FunctionTimedOut, func_set_timeout -from lagent.actions.base_action import BaseAction +from lagent.actions.base_action import BaseAction, tool_api from lagent.actions.parser import BaseParser, JsonParser from lagent.schema import ActionReturn, ActionStatusCode @@ -30,31 +30,6 @@ def eval_code(self, expr: str) -> Any: return eval(expr, self._global_vars) -DEFAULT_DESCRIPTION = dict( - name='PythonInterpreter', - description="""用来执行Python代码。代码必须是一个函数, -函数名必须得是 'solution',代码对应你的思考过程。代码实例格式如下: -```python -# import 依赖包 -import xxx -def solution(): - # 初始化一些变量 - variable_names_with_real_meaning = xxx - # 步骤一 - mid_variable = func(variable_names_with_real_meaning) - # 步骤 x - mid_variable = func(mid_variable) - # 最后结果 - final_answer = func(mid_variable) - return final_answer -```""", - parameters=[ - dict(name='command', type='STRING', description='Python code snippet') - ], - required=['command'], -) - - class PythonInterpreter(BaseAction): """A Python executor that can execute Python scripts. @@ -81,13 +56,33 @@ def __init__(self, description: Optional[dict] = None, parser: Type[BaseParser] = JsonParser, enable: bool = True) -> None: - super().__init__(description or DEFAULT_DESCRIPTION, parser, enable) + super().__init__(description, parser, enable) self.answer_symbol = answer_symbol self.answer_expr = answer_expr self.answer_from_stdout = answer_from_stdout self.timeout = timeout + @tool_api def run(self, command: str) -> ActionReturn: + """用来执行Python代码。代码必须是一个函数,函数名必须得是 'solution',代码对应你的思考过程。代码实例格式如下: + ```python + # import 依赖包 + import xxx + def solution(): + # 初始化一些变量 + variable_names_with_real_meaning = xxx + # 步骤一 + mid_variable = func(variable_names_with_real_meaning) + # 步骤 x + mid_variable = func(mid_variable) + # 最后结果 + final_answer = func(mid_variable) + return final_answer + ``` + + Args: + command (:class:`str`): Python code snippet + """ self.runtime = GenericRuntime() try: tool_return = func_set_timeout(self.timeout)(self._call)(command) diff --git a/lagent/llms/base_llm.py b/lagent/llms/base_llm.py index a729ad35..e350027f 100644 --- a/lagent/llms/base_llm.py +++ b/lagent/llms/base_llm.py @@ -138,12 +138,7 @@ def __init__(self, repetition_penalty=repetition_penalty, stop_words=stop_words) - @abstractclassmethod - def generate( - self, - inputs: Union[str, List[str]], - **gen_params - ) -> str: + def generate(self, inputs: Union[str, List[str]], **gen_params) -> str: """Generate results given a str (or list of) inputs. Args: @@ -165,11 +160,7 @@ def generate( """ raise NotImplementedError - def stream_generate( - self, - inputs: str, - **gen_params - ) -> List[str]: + def stream_generate(self, inputs: str, **gen_params) -> List[str]: """Generate results as streaming given a str inputs. Args: @@ -181,11 +172,7 @@ def stream_generate( """ raise NotImplementedError - def chat( - self, - inputs: Union[List[dict], List[List[dict]]], - **gen_params - ): + def chat(self, inputs: Union[List[dict], List[List[dict]]], **gen_params): """Generate completion from a list of templates. 
Args: @@ -201,11 +188,7 @@ def chat( inputs = self.template_parser(inputs) return self.generate(inputs, **gen_params) - def stream_chat( - self, - inputs: List[dict], - **gen_params - ): + def stream_chat(self, inputs: List[dict], **gen_params): """Generate results as streaming given a list of templates. Args: diff --git a/requirements/runtime.txt b/requirements/runtime.txt index 2e601beb..e58b6b15 100644 --- a/requirements/runtime.txt +++ b/requirements/runtime.txt @@ -3,3 +3,5 @@ func_timeout jsonschema requests tiktoken +griffe +phx-class-registry From 1b5fd996d3037b84b78b39a8c0525015d66ee10c Mon Sep 17 00:00:00 2001 From: BraisedPork <46232992+braisedpork1964@users.noreply.github.com> Date: Thu, 25 Jan 2024 15:43:28 +0800 Subject: [PATCH 11/20] Enhance tool annotation (#98) improve `tool_api` Co-authored-by: wangzy --- lagent/actions/arxiv_search.py | 5 +- lagent/actions/base_action.py | 69 +++++++++++++------------ lagent/actions/bing_map.py | 22 ++++---- lagent/actions/google_scholar_search.py | 34 ++++++------ lagent/actions/ppt.py | 25 +++++---- 5 files changed, 86 insertions(+), 69 deletions(-) diff --git a/lagent/actions/arxiv_search.py b/lagent/actions/arxiv_search.py index 0d833ccc..0b3e6332 100644 --- a/lagent/actions/arxiv_search.py +++ b/lagent/actions/arxiv_search.py @@ -26,7 +26,7 @@ def __init__(self, self.max_query_len = max_query_len self.doc_content_chars_max = doc_content_chars_max - @tool_api(return_dict=True) + @tool_api(explode_return=True) def get_arxiv_article_information(self, query: str) -> dict: """Run Arxiv search and get the article meta information. @@ -34,7 +34,8 @@ def get_arxiv_article_information(self, query: str) -> dict: query (:class:`str`): the content of search query Returns: - content (:class:`str`): a list of 3 arxiv search papers + :class:`dict`: article information + * content (str): a list of 3 arxiv search papers """ try: results = arxiv.Search( # type: ignore diff --git a/lagent/actions/base_action.py b/lagent/actions/base_action.py index 9561c7a0..9c7e0dca 100644 --- a/lagent/actions/base_action.py +++ b/lagent/actions/base_action.py @@ -20,7 +20,7 @@ def tool_api(func: Optional[Callable] = None, *, - return_dict: bool = False, + explode_return: bool = False, returns_named_value: bool = False, **kwargs): """Turn functions into tools. It will parse typehints as well as docstrings @@ -48,15 +48,13 @@ def add(a: Annotated[int, 'augend'], b: Annotated[int, 'addend'] = 1): Args: func (Optional[Callable]): function to decorate. Defaults to ``None``. - return_dict (bool): suggest if returned data is a single dictionary. - When enabled, the returns sections in docstrings should indicate the - key-value infomation of the dictionary rather than hint a standard - tuple return. Defaults to ``False``. + explode_return (bool): whether to flatten the dictionary or tuple return + as the ``return_data`` field. When enabled, it is recommended to + annotate the member in docstrings. Defaults to ``False``. .. 
code-block:: python - # set `return_dict` True will force `returns_named_value` to be enabled - @tool_api(return_dict=True) + @tool_api(explode_return=True) def foo(a, b): '''A simple function @@ -65,8 +63,9 @@ def foo(a, b): b (int): b Returns: - x: the value of input a - y: the value of input b + dict: information of inputs + * x: value of a + * y: value of b ''' return {'x': a, 'y': b} @@ -75,14 +74,16 @@ def foo(a, b): returns_named_value (bool): whether to parse ``thing: Description`` in returns sections as a name and description, rather than a type and description. When true, type must be wrapped in parentheses: - ``(int): Description.``. When false, parentheses are optional but + ``(int): Description``. When false, parentheses are optional but the items cannot be named: ``int: Description``. Defaults to ``False``. + + Important: + ``return_data`` field will be added to ``api_description`` only + when ``explode_return`` or ``returns_named_value`` is enabled. Returns: Callable: wrapped function or partial decorator """ - if return_dict: - returns_named_value = True def _detect_type(string): field_type = 'STRING' @@ -97,6 +98,25 @@ def _detect_type(string): field_type = 'BOOLEAN' return field_type + def _explode(desc): + kvs = [] + desc = '\nArgs:\n' + '\n'.join([ + ' ' + item.lstrip(' -+*#.') + for item in desc.split('\n')[1:] if item.strip() + ]) + docs = Docstring(desc).parse('google') + if not docs: + return kvs + if docs[0].kind is DocstringSectionKind.parameters: + for d in docs[0].value: + d = d.as_dict() + if not d['annotation']: + d.pop('annotation') + else: + d['type'] = _detect_type(d.pop('annotation').lower()) + kvs.append(d) + return kvs + def _parse_tool(function): # remove rst syntax docs = Docstring( @@ -114,13 +134,11 @@ def _parse_tool(function): if doc.kind is DocstringSectionKind.parameters: for d in doc.value: d = d.as_dict() - d['description'] = d['description'] d['type'] = _detect_type(d.pop('annotation').lower()) args_doc[d['name']] = d if doc.kind is DocstringSectionKind.returns: for d in doc.value: d = d.as_dict() - d['description'] = d['description'] if not d['name']: d.pop('name') if not d['annotation']: @@ -154,26 +172,11 @@ def _parse_tool(function): if param.default is inspect.Signature.empty: desc['required'].append(param.name) - return_data, return_annotation = [], sig.return_annotation - if return_dict: + return_data = [] + if explode_return: + return_data = _explode(returns_doc[0]['description']) + elif returns_named_value: return_data = returns_doc - elif return_annotation is not inspect.Signature.empty: - if return_annotation is tuple: - return_data = returns_doc - elif get_origin(return_annotation) is tuple: - return_annotation = get_args(return_annotation) - if not return_annotation: - return_data = returns_doc - elif len(return_annotation) >= 2: - for i, item in enumerate(return_annotation): - info = returns_doc[i]['description'] if i < len( - returns_doc) else '' - if get_origin(item) is Annotated: - item, info = get_args(item) - return_data.append({ - 'description': info, - 'type': _detect_type(str(item)) - }) if return_data: desc['return_data'] = return_data return desc diff --git a/lagent/actions/bing_map.py b/lagent/actions/bing_map.py index 6906cb62..01efe698 100644 --- a/lagent/actions/bing_map.py +++ b/lagent/actions/bing_map.py @@ -25,7 +25,7 @@ def __init__(self, self.key = key self.base_url = 'http://dev.virtualearth.net/REST/V1/' - @tool_api(return_dict=True) + @tool_api(explode_return=True) def get_distance(self, start: str, end: 
str) -> dict: """Get the distance between two locations in km. @@ -34,7 +34,8 @@ def get_distance(self, start: str, end: str) -> dict: end (:class:`str`): The end location Returns: - distance (:class:`str`): the distance in km. + :class:`dict`: distance information + * distance (str): the distance in km. """ # Request URL url = self.base_url + 'Routes/Driving?o=json&wp.0=' + start + '&wp.1=' + end + '&key=' + self.key @@ -48,7 +49,7 @@ def get_distance(self, start: str, end: str) -> dict: distance = route['travelDistance'] return dict(distance=distance) - @tool_api(return_dict=True) + @tool_api(explode_return=True) def get_route(self, start: str, end: str) -> dict: """Get the route between two locations in km. @@ -57,7 +58,8 @@ def get_route(self, start: str, end: str) -> dict: end (:class:`str`): The end location Returns: - route (:class:`list`): the route, a list of actions. + :class:`dict`: route information + * route (list): the route, a list of actions. """ # Request URL url = self.base_url + 'Routes/Driving?o=json&wp.0=' + start + '&wp.1=' + end + '&key=' + self.key @@ -74,7 +76,7 @@ def get_route(self, start: str, end: str) -> dict: route_text.append(item['instruction']['text']) return dict(route=route_text) - @tool_api(return_dict=True) + @tool_api(explode_return=True) def get_coordinates(self, location: str) -> dict: """Get the coordinates of a location. @@ -82,8 +84,9 @@ def get_coordinates(self, location: str) -> dict: location (:class:`str`): the location need to get coordinates. Returns: - latitude (:class:`float`): the latitude of the location. - longitude (:class:`float`): the longitude of the location. + :class:`dict`: coordinates information + * latitude (float): the latitude of the location. + * longitude (float): the longitude of the location. """ url = self.base_url + 'Locations' params = {'query': location, 'key': self.key} @@ -93,7 +96,7 @@ def get_coordinates(self, location: str) -> dict: 'coordinates'] return dict(latitude=coordinates[0], longitude=coordinates[1]) - @tool_api(return_dict=True) + @tool_api(explode_return=True) def search_nearby(self, search_term: str, places: str = 'unknown', @@ -112,7 +115,8 @@ def search_nearby(self, radius (:class:`int`): radius in meters. Defaults to ``5000``. Returns: - places (:class:`list`): the list of places, each place is a dict with name and address, at most 5 places. + :class:`dict`: places information + * places (list): the list of places, each place is a dict with name and address, at most 5 places. """ url = self.base_url + 'LocalSearch' if places != 'unknown': diff --git a/lagent/actions/google_scholar_search.py b/lagent/actions/google_scholar_search.py index 9209a0e6..4098f614 100644 --- a/lagent/actions/google_scholar_search.py +++ b/lagent/actions/google_scholar_search.py @@ -35,7 +35,7 @@ def __init__(self, 'as SERPER_API_KEY or pass it as `api_key` parameter.') self.api_key = api_key - @tool_api(return_dict=True) + @tool_api(explode_return=True) def search_google_scholar( self, query: str, @@ -72,10 +72,11 @@ def search_google_scholar( as_vis (Optional[str]): Defines whether to include citations or not. 
Returns:
-            title: a list of the titles of the three selected papers
-            cited_by: a list of the citation numbers of the three selected papers
-            organic_id: a list of the organic results' ids of the three selected papers
-            pub_info: publication information of selected papers
+            :class:`dict`: article information
+                - title: a list of the titles of the three selected papers
+                - cited_by: a list of the citation numbers of the three selected papers
+                - organic_id: a list of the organic results' ids of the three selected papers
+                - pub_info: publication information of selected papers
         """
         params = {
             'q': query,
@@ -120,7 +121,7 @@ def search_google_scholar(
             return ActionReturn(
                 errmsg=str(e), state=ActionStatusCode.HTTP_ERROR)
 
-    @tool_api(return_dict=True)
+    @tool_api(explode_return=True)
     def get_author_information(self,
                                author_id: str,
                                hl: Optional[str] = None,
@@ -147,10 +148,11 @@ def get_author_information(self,
             output (Optional[str]): Defines the final output you want. Default is 'json'.
 
         Returns:
-            name: author's name
-            affliation: the affliation of the author
-            articles: at most 3 articles by the author
-            website: the author's homepage url
+            :class:`dict`: author information
+                * name: author's name
+                * affiliation: the affiliation of the author
+                * articles: at most 3 articles by the author
+                * website: the author's homepage url
         """
         params = {
             'engine': 'google_scholar_author',
@@ -183,7 +185,7 @@ def get_author_information(self,
             return ActionReturn(
                 errmsg=str(e), state=ActionStatusCode.HTTP_ERROR)
 
-    @tool_api(return_dict=True)
+    @tool_api(explode_return=True)
     def get_citation_format(self,
                             q: str,
                             no_cache: Optional[bool] = None,
@@ -198,8 +200,9 @@ def get_citation_format(self,
             output (Optional[str]): Final output format. Set to 'json' to get a structured JSON of the results, or 'html' to get the raw html retrieved. Defaults to 'json'.
 
         Returns:
-            authors: the authors of the article
-            citation: the citation format of the article
+            :class:`dict`: citation format
+                * authors: the authors of the article
+                * citation: the citation format of the article
         """
         params = {
             'q': q,
@@ -219,7 +222,7 @@ def get_citation_format(self,
             return ActionReturn(
                 errmsg=str(e), state=ActionStatusCode.HTTP_ERROR)
 
-    @tool_api(return_dict=True)
+    @tool_api(explode_return=True)
     def get_author_id(self,
                       mauthors: str,
                       hl: Optional[str] = 'en',
@@ -240,7 +243,8 @@ def get_author_id(self,
             output (Optional[str]): Defines the final output you want. It can be set to 'json' (default) to get a structured JSON of the results, or 'html' to get the raw html retrieved. Defaults to 'json'.
Returns: - author_id: the author_id of the author + :class:`dict`: author id + * author_id: the author_id of the author """ params = { 'mauthors': mauthors, diff --git a/lagent/actions/ppt.py b/lagent/actions/ppt.py index 9cd021bc..f0e68502 100644 --- a/lagent/actions/ppt.py +++ b/lagent/actions/ppt.py @@ -28,7 +28,7 @@ def __init__(self, self.pointer = None self.location = None - @tool_api(return_dict=True) + @tool_api(explode_return=True) def create_file(self, theme: str, abs_location: str) -> dict: """Create a pptx file with specific themes @@ -37,7 +37,8 @@ def create_file(self, theme: str, abs_location: str) -> dict: abs_location (:class:`str`): the ppt file's absolute location Returns: - status: the result of the execution + :class:`dict`: operation status + * status: the result of the execution """ self.location = abs_location try: @@ -48,7 +49,7 @@ def create_file(self, theme: str, abs_location: str) -> dict: print(e) return dict(status='created a ppt file.') - @tool_api(return_dict=True) + @tool_api(explode_return=True) def add_first_page(self, title: str, subtitle: str) -> dict: """Add the first page of ppt. @@ -57,7 +58,8 @@ def add_first_page(self, title: str, subtitle: str) -> dict: subtitle (:class:`str`): the subtitle of ppt Returns: - status: the result of the execution + :class:`dict`: operation status + * status: the result of the execution """ layout_name = self.theme_mapping[ self.pointer.slide_master.name]['title'] @@ -70,7 +72,7 @@ def add_first_page(self, title: str, subtitle: str) -> dict: ph_subtitle.text = subtitle return dict(status='added page') - @tool_api(return_dict=True) + @tool_api(explode_return=True) def add_text_page(self, title: str, bullet_items: str) -> dict: """Add text page of ppt @@ -79,7 +81,8 @@ def add_text_page(self, title: str, bullet_items: str) -> dict: bullet_items (:class:`str`): bullet_items should be string, for multiple bullet items, please use [SPAN] to separate them. Returns: - status: the result of the execution + :class:`dict`: operation status + * status: the result of the execution """ layout_name = self.theme_mapping[ self.pointer.slide_master.name]['single'] @@ -99,7 +102,7 @@ def add_text_page(self, title: str, bullet_items: str) -> dict: p.level = 0 return dict(status='added page') - @tool_api(return_dict=True) + @tool_api(explode_return=True) def add_text_image_page(self, title: str, bullet_items: str, image: str) -> dict: """Add a text page with one image. Image should be a path @@ -110,7 +113,8 @@ def add_text_image_page(self, title: str, bullet_items: str, image (:class:`str`): the path of the image Returns: - status: the result of the execution + :class:`dict`: operation status + * status: the result of the execution """ layout_name = self.theme_mapping[self.pointer.slide_master.name]['two'] layout = next(i for i in self.pointer.slide_master.slide_layouts @@ -138,12 +142,13 @@ def add_text_image_page(self, title: str, bullet_items: str, return dict(status='added page') - @tool_api(return_dict=True) + @tool_api(explode_return=True) def submit_file(self) -> dict: """When all steps done, YOU MUST use submit_file() to submit your work. 
Returns: - status: the result of the execution + :class:`dict`: operation status + * status: the result of the execution """ # file_path = os.path.join(self.CACHE_DIR, f'{self._return_timestamp()}.pptx') # self.pointer.save(file_path) From 016ee62cc631d3dc0b4d8f8e070fc95680cc76a3 Mon Sep 17 00:00:00 2001 From: BraisedPork <46232992+braisedpork1964@users.noreply.github.com> Date: Thu, 25 Jan 2024 15:54:39 +0800 Subject: [PATCH 12/20] [Docs] initialize the documentation (#99) init the docs Co-authored-by: wangzy --- docs/en/_templates/autoapi/index.rst | 14 +++ docs/en/_templates/autoapi/python/module.rst | 112 ++++++++++++++++++ docs/en/conf.py | 111 +++++++---------- docs/en/get_started/overview.md | 4 +- docs/en/index.rst | 12 ++ docs/zh_cn/_templates/autoapi/index.rst | 14 +++ .../_templates/autoapi/python/module.rst | 112 ++++++++++++++++++ docs/zh_cn/conf.py | 112 +++++++----------- docs/zh_cn/get_started/overview.md | 23 ++++ docs/zh_cn/index.rst | 12 ++ requirements/docs.txt | 12 +- 11 files changed, 397 insertions(+), 141 deletions(-) create mode 100644 docs/en/_templates/autoapi/index.rst create mode 100644 docs/en/_templates/autoapi/python/module.rst create mode 100644 docs/zh_cn/_templates/autoapi/index.rst create mode 100644 docs/zh_cn/_templates/autoapi/python/module.rst create mode 100644 docs/zh_cn/get_started/overview.md diff --git a/docs/en/_templates/autoapi/index.rst b/docs/en/_templates/autoapi/index.rst new file mode 100644 index 00000000..e3cca6a7 --- /dev/null +++ b/docs/en/_templates/autoapi/index.rst @@ -0,0 +1,14 @@ +API Reference +============= + +This page contains auto-generated API reference documentation. + +.. toctree:: + :titlesonly: + :maxdepth: 3 + + {% for page in pages %} + {% if page.top_level_object and page.display %} + {{ page.include_path }} + {% endif %} + {% endfor %} diff --git a/docs/en/_templates/autoapi/python/module.rst b/docs/en/_templates/autoapi/python/module.rst new file mode 100644 index 00000000..7cb039f1 --- /dev/null +++ b/docs/en/_templates/autoapi/python/module.rst @@ -0,0 +1,112 @@ +{% if not obj.display %} +:orphan: + +{% endif %} +:py:mod:`{{ obj.name if obj.name.count(".") <= 1 else obj.short_name }}` +=========={{ "=" * (obj.name|length if obj.name.count(".") <= 1 else obj.short_name|length) }} + +.. py:module:: {{ obj.name }} + +{% if obj.docstring %} +.. autoapi-nested-parse:: + + {{ obj.docstring|indent(3) }} + +{% endif %} + +{% block subpackages %} +{% set visible_subpackages = obj.subpackages|selectattr("display")|list %} +{% if visible_subpackages %} +Subpackages +----------- +.. toctree:: + :titlesonly: + :maxdepth: 3 + +{% for subpackage in visible_subpackages %} + {{ subpackage.short_name }}/index.rst +{% endfor %} + + +{% endif %} +{% endblock %} +{% block submodules %} +{% set visible_submodules = obj.submodules|selectattr("display")|list %} +{% if visible_submodules %} +Submodules +---------- +.. 
toctree:: + :titlesonly: + :maxdepth: 1 + +{% for submodule in visible_submodules %} + {{ submodule.short_name }}/index.rst +{% endfor %} + + +{% endif %} +{% endblock %} +{% block content %} +{% if obj.type is equalto("package") %} +{% set visible_children = obj.children|selectattr("display")|list %} +{% else %} +{% set visible_children = obj.children|selectattr("display")|rejectattr("imported")|list %} +{% endif %} +{% if visible_children %} +{{ obj.type|title }} Contents +{{ "-" * obj.type|length }}--------- + +{% set visible_classes = visible_children|selectattr("type", "equalto", "class")|list %} +{% set visible_functions = visible_children|selectattr("type", "equalto", "function")|list %} +{% set visible_attributes = visible_children|selectattr("type", "equalto", "data")|list %} +{% if "show-module-summary" in autoapi_options and (visible_classes or visible_functions) %} +{% block classes scoped %} +{% if visible_classes %} +Classes +~~~~~~~ + +.. autoapisummary:: + +{% for klass in visible_classes %} + {{ klass.id }} +{% endfor %} + + +{% endif %} +{% endblock %} + +{% block functions scoped %} +{% if visible_functions %} +Functions +~~~~~~~~~ + +.. autoapisummary:: + +{% for function in visible_functions %} + {{ function.id }} +{% endfor %} + + +{% endif %} +{% endblock %} + +{% block attributes scoped %} +{% if visible_attributes %} +Attributes +~~~~~~~~~~ + +.. autoapisummary:: + +{% for attribute in visible_attributes %} + {{ attribute.id }} +{% endfor %} + + +{% endif %} +{% endblock %} +{% endif %} +{% for obj_item in visible_children %} +{{ obj_item.render()|indent(0) }} +{% endfor %} +{% endif %} +{% endblock %} diff --git a/docs/en/conf.py b/docs/en/conf.py index 40378e9f..e9ccfeb4 100644 --- a/docs/en/conf.py +++ b/docs/en/conf.py @@ -11,17 +11,16 @@ # documentation root, use os.path.abspath to make it absolute, like shown here. import os +import re import sys -import pytorch_sphinx_theme - -sys.path.insert(0, os.path.abspath('../../')) +sys.path.insert(0, os.path.abspath('../..')) # -- Project information ----------------------------------------------------- - project = 'Lagent' copyright = '2020-2030, InternLM' author = 'InternLM' +language = 'en' # The full version, including alpha/beta/rc tags version_file = '../../lagent/version.py' @@ -36,97 +35,75 @@ # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # ones. extensions = [ + 'sphinx_rtd_theme', + 'myst_nb', + 'autoapi.extension', + 'sphinx_markdown_tables', 'sphinx.ext.autodoc', 'sphinx.ext.napoleon', 'sphinx.ext.viewcode', - 'sphinx_markdown_tables', - 'sphinx_copybutton', - 'myst_parser', - 'sphinx.ext.intersphinx', - 'sphinx.ext.autodoc.typehints', - 'sphinx.ext.autosummary', - 'sphinx.ext.autosectionlabel', - 'sphinx_tabs.tabs', ] -autodoc_typehints = 'description' -autosummary_generate = True # Turn on sphinx.ext.autosummary -# Ignore >>> when copying code -copybutton_prompt_text = r'>>> |\.\.\. ' -copybutton_prompt_is_regexp = True +nb_output_stderr = 'remove-warn' +autodoc_typehints = 'description' -myst_enable_extensions = ['colon_fence'] +# sphinx-autoapi configuration +autoapi_dirs = ['../../lagent'] +autoapi_options = [ + 'members', + 'undoc-members', + 'show-inheritance', + 'show-module-summary', +] +autoapi_ignore = ['*migrations*', '*command.py', '*cli.py'] +autoapi_template_dir = '_templates/autoapi' +autoapi_add_toctree_entry = False # Add any paths that contain templates here, relative to this directory. templates_path = ['_templates'] -# The suffix(es) of source filenames. 
-# You can specify multiple suffix as a list of string:
-#
-source_suffix = {
-    '.rst': 'restructuredtext',
-    '.md': 'markdown',
-}
-
-# The master toctree document.
-master_doc = 'index'
-
 # List of patterns, relative to source directory, that match files and
 # directories to ignore when looking for source files.
 # This pattern also affects html_static_path and html_extra_path.
-exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']
+exclude_patterns = []
 
 # -- Options for HTML output -------------------------------------------------
 
 # The theme to use for HTML and HTML Help pages.  See the documentation for
 # a list of builtin themes.
 #
-# html_theme = 'sphinx_rtd_theme'
-html_theme = 'pytorch_sphinx_theme'
-html_theme_path = [pytorch_sphinx_theme.get_html_theme_path()]
+html_theme = 'sphinx_rtd_theme'
 html_theme_options = {
-    'menu': [
-        {
-            'name': 'GitHub',
-            'url': 'https://github.com/InternLM/lagent'
-        },
-    ],
-    # Specify the language of shared menu
-    'menu_lang': 'en'
+    'navigation_depth': 3,
+    'titles_only': False,
+    'style_nav_header_background': '#4fabab',
 }
-
-language = 'en'
+html_context = {
+    'display_github': True,
+    'github_host': 'github.com',
+    'github_user': 'InternLM',
+    'github_repo': 'lagent',
+    'github_version': 'main',
+    'conf_py_path': '/docs/en/',
+}
+html_title = 'Lagent'
+html_logo = '../imgs/lagent_logo.png'
+html_favicon = '../imgs/lagent_icon.png'
 
 master_doc = 'index'
 
 # Add any paths that contain custom static files (such as style sheets) here,
 # relative to this directory. They are copied after the builtin static files,
-# so a file named "default.css" will overwrite the builtin "default.css".
+# so a file named 'default.css' will overwrite the builtin 'default.css'.
 html_static_path = ['_static']
-html_css_files = [
-    'https://cdn.datatables.net/1.13.2/css/dataTables.bootstrap5.min.css',
-    'css/readthedocs.css'
-]
-html_js_files = [
-    'https://cdn.datatables.net/1.13.2/js/jquery.dataTables.min.js',
-    'https://cdn.datatables.net/1.13.2/js/dataTables.bootstrap5.min.js',
-    'js/collapsed.js',
-    'js/table.js',
-]
-
-myst_heading_anchors = 4
-
-intersphinx_mapping = {
-    'python': ('https://docs.python.org/3', None),
-    'numpy': ('https://numpy.org/doc/stable', None),
-    'torch': ('https://pytorch.org/docs/stable/', None),
-}
-
 
-def builder_inited_handler(app):
-    pass
+def custom_skip(app, what, name, obj, skip, options):
+    if what in ['data', 'function', 'class'] and re.search(
+            'logger|wrapper', name):
+        skip = True
+    return skip
 
 
-def setup(app):
-    app.connect('builder-inited', builder_inited_handler)
+def setup(sphinx):
+    sphinx.connect('autoapi-skip-member', custom_skip)
diff --git a/docs/en/get_started/overview.md b/docs/en/get_started/overview.md
index d996039f..c22f63b8 100644
--- a/docs/en/get_started/overview.md
+++ b/docs/en/get_started/overview.md
@@ -18,6 +18,6 @@ Lagent consists of 3 main parts, agents, llms, and actions.
 
 Here is a detailed step-by-step guide to learn more about Lagent:
 
-1. For installation instructions, please see [README](../README.md).
+1. For installation instructions, please see [README](https://github.com/InternLM/lagent/blob/main/README.md).
 
-2. We provide several examples to build agents with Lagent in [examples](examples/) by simply run `python examples/react_example.py`.
+2. We provide several examples to build agents with Lagent in [examples](https://github.com/InternLM/lagent/tree/main/examples) by simply running `python examples/react_example.py`.
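The overview above only points readers at `react_example.py`. For orientation while reading the rest of the series, the three modules compose roughly as follows; this is a minimal sketch, not part of the patch, and the `GPTAPI` backend with its `model_type` argument is an assumption to be checked against the installed version:

```python
# Minimal sketch: one class from each module (llms, actions, agents).
from lagent.actions import ActionExecutor, PythonInterpreter
from lagent.agents import ReAct
from lagent.llms import GPTAPI  # assumed backend; any lagent.llms model works

llm = GPTAPI(model_type='gpt-3.5-turbo')  # reads the API key from the env
agent = ReAct(
    llm=llm,
    action_executor=ActionExecutor(actions=[PythonInterpreter()]))

ret = agent.chat('Solve x**2 - 2*x + 1 = 0 with Python.')
print(ret.response)  # final answer carried back through FinishAction
```

Here `ret` is the `AgentReturn` dataclass from `lagent/schema.py`; intermediate tool calls are recorded in `ret.actions` as the `ActionReturn` objects that a later patch in this series restructures.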
diff --git a/docs/en/index.rst b/docs/en/index.rst index 8745bf5f..ab272ae7 100644 --- a/docs/en/index.rst +++ b/docs/en/index.rst @@ -8,6 +8,7 @@ You can switch between English and Chinese in the lower-left corner of the layou :caption: Get Started get_started/overview.md + get_started/action.md .. toctree:: @@ -15,6 +16,17 @@ You can switch between English and Chinese in the lower-left corner of the layou switch_language.md +.. toctree:: + :maxdepth: 1 + :caption: API Reference + + autoapi/lagent/actions/index + autoapi/lagent/agents/index + autoapi/lagent/llms/index + autoapi/lagent/utils/index + autoapi/lagent/schema/index + autoapi/lagent/version/index + Indices and tables ================== diff --git a/docs/zh_cn/_templates/autoapi/index.rst b/docs/zh_cn/_templates/autoapi/index.rst new file mode 100644 index 00000000..e3cca6a7 --- /dev/null +++ b/docs/zh_cn/_templates/autoapi/index.rst @@ -0,0 +1,14 @@ +API Reference +============= + +This page contains auto-generated API reference documentation. + +.. toctree:: + :titlesonly: + :maxdepth: 3 + + {% for page in pages %} + {% if page.top_level_object and page.display %} + {{ page.include_path }} + {% endif %} + {% endfor %} diff --git a/docs/zh_cn/_templates/autoapi/python/module.rst b/docs/zh_cn/_templates/autoapi/python/module.rst new file mode 100644 index 00000000..7cb039f1 --- /dev/null +++ b/docs/zh_cn/_templates/autoapi/python/module.rst @@ -0,0 +1,112 @@ +{% if not obj.display %} +:orphan: + +{% endif %} +:py:mod:`{{ obj.name if obj.name.count(".") <= 1 else obj.short_name }}` +=========={{ "=" * (obj.name|length if obj.name.count(".") <= 1 else obj.short_name|length) }} + +.. py:module:: {{ obj.name }} + +{% if obj.docstring %} +.. autoapi-nested-parse:: + + {{ obj.docstring|indent(3) }} + +{% endif %} + +{% block subpackages %} +{% set visible_subpackages = obj.subpackages|selectattr("display")|list %} +{% if visible_subpackages %} +Subpackages +----------- +.. toctree:: + :titlesonly: + :maxdepth: 3 + +{% for subpackage in visible_subpackages %} + {{ subpackage.short_name }}/index.rst +{% endfor %} + + +{% endif %} +{% endblock %} +{% block submodules %} +{% set visible_submodules = obj.submodules|selectattr("display")|list %} +{% if visible_submodules %} +Submodules +---------- +.. toctree:: + :titlesonly: + :maxdepth: 1 + +{% for submodule in visible_submodules %} + {{ submodule.short_name }}/index.rst +{% endfor %} + + +{% endif %} +{% endblock %} +{% block content %} +{% if obj.type is equalto("package") %} +{% set visible_children = obj.children|selectattr("display")|list %} +{% else %} +{% set visible_children = obj.children|selectattr("display")|rejectattr("imported")|list %} +{% endif %} +{% if visible_children %} +{{ obj.type|title }} Contents +{{ "-" * obj.type|length }}--------- + +{% set visible_classes = visible_children|selectattr("type", "equalto", "class")|list %} +{% set visible_functions = visible_children|selectattr("type", "equalto", "function")|list %} +{% set visible_attributes = visible_children|selectattr("type", "equalto", "data")|list %} +{% if "show-module-summary" in autoapi_options and (visible_classes or visible_functions) %} +{% block classes scoped %} +{% if visible_classes %} +Classes +~~~~~~~ + +.. autoapisummary:: + +{% for klass in visible_classes %} + {{ klass.id }} +{% endfor %} + + +{% endif %} +{% endblock %} + +{% block functions scoped %} +{% if visible_functions %} +Functions +~~~~~~~~~ + +.. 
autoapisummary:: + +{% for function in visible_functions %} + {{ function.id }} +{% endfor %} + + +{% endif %} +{% endblock %} + +{% block attributes scoped %} +{% if visible_attributes %} +Attributes +~~~~~~~~~~ + +.. autoapisummary:: + +{% for attribute in visible_attributes %} + {{ attribute.id }} +{% endfor %} + + +{% endif %} +{% endblock %} +{% endif %} +{% for obj_item in visible_children %} +{{ obj_item.render()|indent(0) }} +{% endfor %} +{% endif %} +{% endblock %} diff --git a/docs/zh_cn/conf.py b/docs/zh_cn/conf.py index c443908f..cf5d7238 100644 --- a/docs/zh_cn/conf.py +++ b/docs/zh_cn/conf.py @@ -11,18 +11,16 @@ # documentation root, use os.path.abspath to make it absolute, like shown here. import os -import subprocess +import re import sys -import pytorch_sphinx_theme - -sys.path.insert(0, os.path.abspath('../../')) +sys.path.insert(0, os.path.abspath('../..')) # -- Project information ----------------------------------------------------- - project = 'Lagent' copyright = '2020-2030, InternLM' author = 'InternLM' +language = 'zh_CN' # The full version, including alpha/beta/rc tags version_file = '../../lagent/version.py' @@ -37,97 +35,75 @@ # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # ones. extensions = [ + 'sphinx_rtd_theme', + 'myst_nb', + 'autoapi.extension', + 'sphinx_markdown_tables', 'sphinx.ext.autodoc', 'sphinx.ext.napoleon', 'sphinx.ext.viewcode', - 'sphinx_markdown_tables', - 'sphinx_copybutton', - 'myst_parser', - 'sphinx.ext.intersphinx', - 'sphinx.ext.autodoc.typehints', - 'sphinx.ext.autosummary', - 'sphinx.ext.autosectionlabel', - 'sphinx_tabs.tabs', ] -autodoc_typehints = 'description' -autosummary_generate = True # Turn on sphinx.ext.autosummary -# Ignore >>> when copying code -copybutton_prompt_text = r'>>> |\.\.\. ' -copybutton_prompt_is_regexp = True +nb_output_stderr = 'remove-warn' +autodoc_typehints = 'description' -myst_enable_extensions = ['colon_fence'] +# sphinx-autoapi configuration +autoapi_dirs = ['../../lagent'] +autoapi_options = [ + 'members', + 'undoc-members', + 'show-inheritance', + 'show-module-summary', +] +autoapi_ignore = ['*migrations*', '*command.py', '*cli.py'] +autoapi_template_dir = '_templates/autoapi' +autoapi_add_toctree_entry = False # Add any paths that contain templates here, relative to this directory. templates_path = ['_templates'] -# The suffix(es) of source filenames. -# You can specify multiple suffix as a list of string: -# -source_suffix = { - '.rst': 'restructuredtext', - '.md': 'markdown', -} - -# The master toctree document. -master_doc = 'index' - # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. # This pattern also affects html_static_path and html_extra_path. -exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] +exclude_patterns = [] # -- Options for HTML output ------------------------------------------------- # The theme to use for HTML and HTML Help pages. See the documentation for # a list of builtin themes. 
# -# html_theme = 'sphinx_rtd_theme' -html_theme = 'pytorch_sphinx_theme' -html_theme_path = [pytorch_sphinx_theme.get_html_theme_path()] +html_theme = 'sphinx_rtd_theme' html_theme_options = { - 'menu': [ - { - 'name': 'GitHub', - 'url': 'https://github.com/InternLM/lagent' - }, - ], - # Specify the language of shared menu - 'menu_lang': 'cn', + 'navigation_depth': 3, + 'titles_only': False, + 'style_nav_header_background': '#4fabab', } - -language = 'zh_CN' +html_context = { + 'display_github': True, + 'github_host': 'github.com', + 'github_user': 'InternLM', + 'github_repo': 'lagent', + 'github_version': 'main', + 'conf_py_path': '/docs/en/', +} +html_title = 'Lagent' +html_logo = '../imgs/lagent_logo.png' +html_favicon = '../imgs/lagent_icon.png' master_doc = 'index' # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, -# so a file named "default.css" will overwrite the builtin "default.css". +# so a file named 'default.css' will overwrite the builtin 'default.css'. html_static_path = ['_static'] -html_css_files = [ - 'https://cdn.datatables.net/1.13.2/css/dataTables.bootstrap5.min.css', - 'css/readthedocs.css' -] -html_js_files = [ - 'https://cdn.datatables.net/1.13.2/js/jquery.dataTables.min.js', - 'https://cdn.datatables.net/1.13.2/js/dataTables.bootstrap5.min.js', - 'js/collapsed.js', - 'js/table.js', -] - -myst_heading_anchors = 4 - -# Configuration for intersphinx -intersphinx_mapping = { - 'python': ('https://docs.python.org/3', None), - 'numpy': ('https://numpy.org/doc/stable', None), - 'torch': ('https://pytorch.org/docs/stable/', None), -} -def builder_inited_handler(app): - subprocess.run(['./cp_origin_docs.sh']) +def custom_skip(app, what, name, obj, skip, options): + if what in ['data', 'function', 'class'] and re.search( + 'logger|wrapper', name): + skip = True + return skip -def setup(app): - app.connect('builder-inited', builder_inited_handler) +def setup(sphinx): + sphinx.connect('autoapi-skip-member', custom_skip) diff --git a/docs/zh_cn/get_started/overview.md b/docs/zh_cn/get_started/overview.md new file mode 100644 index 00000000..9250c99b --- /dev/null +++ b/docs/zh_cn/get_started/overview.md @@ -0,0 +1,23 @@ +# 总览 + +本章节将介绍 Lagent 的架构,并提供 Lagent 详细教程的链接。 + +## Lagent 是什么 + +Lagent 是一个开源的 LLM 智能体框架,允许使用者快速将一个大语言模型转换成智能体,并提供一些典型工具来激发大语言模型的潜能。Lagent 框架图如下: + +![image](https://github.com/InternLM/lagent/assets/24351120/e104171e-4baf-43b3-8e6d-90cff1b298b6) + +Lagent 包含三个主要模块:agents,llms 和 actions。 + +- **agents** 实现了多种智能体,如 ReAct,AutoGPT。 +- **llms** 支持多种大语言模型,包括在 HuggingFace 上托管的开源模型(Llama-2, InterLM)及 GPT3.5/4 等闭源模型。 +- **actions** 包含一系列工具,并提供工具执行器来统一管理。 + +## 如何使用 + +以下是帮助您了解关于 Lagent 更多信息的详细教程: + +1. 安装请参考 [README](https://github.com/InternLM/lagent/blob/main/README.md). + +2. 一些构建智能体的实例 [examples](https://github.com/InternLM/lagent/tree/main/examples),直接运行脚本即可,如 `python examples/react_example.py`. diff --git a/docs/zh_cn/index.rst b/docs/zh_cn/index.rst index 180c38bb..ac7141dc 100644 --- a/docs/zh_cn/index.rst +++ b/docs/zh_cn/index.rst @@ -14,6 +14,18 @@ switch_language.md +.. 
toctree::
+   :maxdepth: 1
+   :caption: API 参考
+
+   autoapi/lagent/actions/index
+   autoapi/lagent/agents/index
+   autoapi/lagent/llms/index
+   autoapi/lagent/utils/index
+   autoapi/lagent/schema/index
+   autoapi/lagent/version/index
+
+
 导引
 ==================
 
diff --git a/requirements/docs.txt b/requirements/docs.txt
index 16ddccda..c4e6b5df 100644
--- a/requirements/docs.txt
+++ b/requirements/docs.txt
@@ -1,9 +1,13 @@
-docutils==0.16.0
+docutils==0.18.1
 markdown>=3.4.0
-myst-parser
--e git+https://github.com/open-mmlab/pytorch_sphinx_theme.git#egg=pytorch_sphinx_theme
-sphinx==4.0.2
+myst-nb
+# -e git+https://github.com/open-mmlab/pytorch_sphinx_theme.git#egg=pytorch_sphinx_theme
+# sphinx==4.0.2
+sphinx==6.1.0
 sphinx-tabs
 sphinx_copybutton
 sphinx_markdown_tables>=0.0.16
+sphinx-rtd-theme==1.3.0
 tabulate
+astroid<3.0.0
+sphinx-autoapi

From 1292ea7dc7b5bb7bf104bbe26b0e828e59f7d19b Mon Sep 17 00:00:00 2001
From: BraisedPork <46232992+braisedpork1964@users.noreply.github.com>
Date: Thu, 25 Jan 2024 20:57:34 +0800
Subject: [PATCH 13/20] Modify the structure of `ActionReturn`'s result (#102)

* modify the structure of action results

* fix docstrings

---------

Co-authored-by: wangzy
---
 lagent/actions/builtin_actions.py |  2 +-
 lagent/actions/parser.py          |  9 +++++----
 lagent/agents/autogpt.py          | 16 ++++++----------
 lagent/agents/react.py            | 17 ++++++++---------
 lagent/agents/rewoo.py            |  4 ++--
 lagent/agents/stream_agent.py     | 30 +++---------------------------
 lagent/schema.py                  | 15 +++++++++++++--
 7 files changed, 38 insertions(+), 55 deletions(-)

diff --git a/lagent/actions/builtin_actions.py b/lagent/actions/builtin_actions.py
index 805f99f1..33702132 100644
--- a/lagent/actions/builtin_actions.py
+++ b/lagent/actions/builtin_actions.py
@@ -102,7 +102,7 @@ def run(self, response: str) -> ActionReturn:
         action_return = ActionReturn(
             url=None,
             args=dict(text=response),
-            result=dict(text=response),
+            result=[dict(type='text', content=response)],
             type=self.name,
             valid=ActionValidCode.FINISH,
             state=ActionStatusCode.SUCCESS)
diff --git a/lagent/actions/parser.py b/lagent/actions/parser.py
index 9dbf6104..75fe06e5 100644
--- a/lagent/actions/parser.py
+++ b/lagent/actions/parser.py
@@ -1,7 +1,7 @@
 import json
 import re
 from ast import literal_eval
-from typing import Any, Union
+from typing import Any, List, Union
 
 
 class ParseError(Exception):
@@ -58,20 +58,21 @@ def parse_inputs(self, inputs: str, name: str = 'run') -> dict:
             inputs = {self._api2param[name][0]['name']: inputs}
         return inputs
 
-    def parse_outputs(self, outputs: Any) -> dict:
+    def parse_outputs(self, outputs: Any) -> List[dict]:
         """parser outputs returned by the action
 
         Args:
            outputs (:class:`Any`): raw output of the action
 
         Returns:
-            :class:`dict`: processed output
+            :class:`List[dict]`: processed output, each member of which is a
+                dictionary with two keys - 'type' and 'content'.
""" if isinstance(outputs, dict): outputs = json.dumps(outputs, ensure_ascii=False) elif not isinstance(outputs, str): outputs = str(outputs) - return {'text': outputs} + return [{'type': 'text', 'content': outputs}] class JsonParser(BaseParser): diff --git a/lagent/agents/autogpt.py b/lagent/agents/autogpt.py index f0eedf06..15601d53 100644 --- a/lagent/agents/autogpt.py +++ b/lagent/agents/autogpt.py @@ -219,21 +219,20 @@ def format(self, goal: str, inner_history: List[Dict], dict(role='user', content=self.triggering_prompt)) return formatted_data - def format_response(self, action_return): + def format_response(self, action_return) -> dict: """format the final response at current step. Args: action_return (ActionReturn): return value of the current action. Returns: - str: the final response at current step. + dict: the final response at current step. """ if action_return.state == ActionStatusCode.SUCCESS: - response = action_return.result['text'] - response = f'Command {action_return.type} returned: {response}' + response = f'Command {action_return.type} returned: {response.format_result()}' else: response = action_return.errmsg - return response + return dict(role='system', content=response) class AutoGPT(BaseAgent): @@ -277,12 +276,9 @@ def chat(self, goal: str, **kwargs) -> AgentReturn: action, action_input) agent_return.actions.append(action_return) if action_return.type == self._action_executor.finish_action.name: - agent_return.response = action_return.result['text'] + agent_return.response = action_return.format_result() return agent_return - inner_history.append( - dict( - role='system', - content=self._protocol.format_response(action_return))) + inner_history.append(self._protocol.format_response(action_return)) agent_return.inner_steps = inner_history agent_return.response = default_response return agent_return diff --git a/lagent/agents/react.py b/lagent/agents/react.py index 3e09d924..2a284784 100644 --- a/lagent/agents/react.py +++ b/lagent/agents/react.py @@ -169,20 +169,22 @@ def parse( action_input = arg_match[-1] return thought, action.strip(), action_input.strip().strip('"') - def format_response(self, action_return: ActionReturn) -> str: + def format_response(self, action_return: ActionReturn) -> dict: """format the final response at current step. Args: action_return (ActionReturn): return value of the current action. Returns: - str: the final response at current step. + dict: the final response at current step. 
""" if action_return.state == ActionStatusCode.SUCCESS: - response = action_return.result['text'] + response = action_return.format_result() else: response = action_return.errmsg - return self.response['begin'] + response + self.response['end'] + return dict( + role='system', + content=self.response['begin'] + response + self.response['end']) class ReAct(BaseAgent): @@ -237,12 +239,9 @@ def chat(self, message: Union[str, dict, List[dict]], action_return.thought = thought agent_return.actions.append(action_return) if action_return.type == self._action_executor.finish_action.name: - agent_return.response = action_return.result['text'] + agent_return.response = action_return.format_result() break - inner_history.append( - dict( - role='system', - content=self._protocol.format_response(action_return))) + inner_history.append(self._protocol.format_response(action_return)) else: agent_return.response = default_response agent_return.inner_steps = inner_history[offset:] diff --git a/lagent/agents/rewoo.py b/lagent/agents/rewoo.py index a9bd3163..6a6c020a 100644 --- a/lagent/agents/rewoo.py +++ b/lagent/agents/rewoo.py @@ -191,7 +191,7 @@ def format_solver( worker_log = '' for thought, action_return in zip(thought_list, action_return_list): if action_return.state == ActionStatusCode.SUCCESS: - action_resp = action_return.result['text'] + action_resp = action_return.format_result() else: action_resp = action_return.errmsg worker_response = self.worker_prompt.format( @@ -273,7 +273,7 @@ def chat(self, message: Union[str, dict, List[dict]], for prev_ptr in prev_ptrs: ptr_num = int(prev_ptr.strip('#E')) - 1 # start from 0 actions_input[action_id] = actions_input[action_id].replace( - prev_ptr, action_responses[ptr_num].result['text']) + prev_ptr, action_responses[ptr_num].format_result()) action_return: ActionReturn = self._action_executor( actions[action_id], actions_input[action_id]) action_responses.append(action_return) diff --git a/lagent/agents/stream_agent.py b/lagent/agents/stream_agent.py index b2c1258a..a7295d22 100644 --- a/lagent/agents/stream_agent.py +++ b/lagent/agents/stream_agent.py @@ -177,32 +177,9 @@ def parse(self, message, plugin_executor: ActionExecutor, parameters=dict(command=code)) return None, message, None - def format_response(self, action_return, name) -> str: + def format_response(self, action_return, name) -> dict: if action_return.state == ActionStatusCode.SUCCESS: - if isinstance(action_return.result, list): - response = [] - for item in action_return.result: - if item['type'] == 'text': - response.append(item['content']) - else: - response.append(f"[{item['type']}]({item['content']})") - response = '\n'.join(response) - elif isinstance(action_return.result, dict): - response = action_return.result['text'] - if 'image' in action_return.result: - response += '\n'.join([ - f'[image]({im})' - for im in action_return.result['image'] - ]) - if 'audio' in action_return.result: - response += '\n'.join([ - f'[audio]({im})' - for im in action_return.result['audio'] - ]) - elif isinstance(action_return.result, str): - response = action_return.result - else: - raise NotImplementedError + response = action_return.format_result() else: response = action_return.errmsg content = self.execute['begin'] + response + self.execute['end'] @@ -212,8 +189,7 @@ def format_response(self, action_return, name) -> str: elif self.execute.get('belong'): return dict( role=self.execute['belong'], content=content, name=name) - else: - return dict(role=self.execute['role'], content=response, name=name) 
+ return dict(role=self.execute['role'], content=response, name=name) class StreamAgent(BaseAgent): diff --git a/lagent/schema.py b/lagent/schema.py index e6752e16..2fc0689d 100644 --- a/lagent/schema.py +++ b/lagent/schema.py @@ -36,12 +36,23 @@ class ActionReturn: args: Optional[dict] = None url: Optional[str] = None type: Optional[str] = None - result: Optional[str] = None + result: Optional[List[dict]] = None errmsg: Optional[str] = None state: Union[ActionStatusCode, int] = ActionStatusCode.SUCCESS thought: Optional[str] = None valid: Optional[ActionValidCode] = ActionValidCode.OPEN + def format_result(self) -> str: + """Concatenate items in result""" + result = [] + for item in self.result or []: + if item['type'] == 'text': + result.append(item['content']) + else: + result.append(f"[{item['type']}]({item['content']})") + result = '\n'.join(result) + return result + # 需要集成int,如此asdict可以把AgentStatusCode 转换成 int class AgentStatusCode(int, Enum): @@ -67,4 +78,4 @@ class AgentReturn: actions: List[ActionReturn] = field(default_factory=list) response: str = '' inner_steps: List = field(default_factory=list) - errmsg: Optional[str] = None \ No newline at end of file + errmsg: Optional[str] = None From 487919476b481ef9b8f1c4e580c42ce86b834f1b Mon Sep 17 00:00:00 2001 From: BraisedPork <46232992+braisedpork1964@users.noreply.github.com> Date: Fri, 26 Jan 2024 11:12:32 +0800 Subject: [PATCH 14/20] Fix .readthedocs.yml (#104) fix rtd config Co-authored-by: wangzy --- .readthedocs.yml | 10 +++++++--- docs/en/conf.py | 3 +-- docs/zh_cn/conf.py | 3 +-- 3 files changed, 9 insertions(+), 7 deletions(-) diff --git a/.readthedocs.yml b/.readthedocs.yml index f0174e05..49078c0b 100644 --- a/.readthedocs.yml +++ b/.readthedocs.yml @@ -2,7 +2,11 @@ version: 2 formats: all +build: + os: ubuntu-22.04 + tools: + python: "3.10" + python: - version: 3.7 - install: - - requirements: requirements/docs.txt + install: + - requirements: requirements/docs.txt diff --git a/docs/en/conf.py b/docs/en/conf.py index e9ccfeb4..0d92c9f4 100644 --- a/docs/en/conf.py +++ b/docs/en/conf.py @@ -99,8 +99,7 @@ def custom_skip(app, what, name, obj, skip, options): - if what in ['data', 'function', 'class'] and re.search( - 'logger|wrapper', name): + if what in ['data', 'function', 'class'] and re.search('logger', name): skip = True return skip diff --git a/docs/zh_cn/conf.py b/docs/zh_cn/conf.py index cf5d7238..baaf05ef 100644 --- a/docs/zh_cn/conf.py +++ b/docs/zh_cn/conf.py @@ -99,8 +99,7 @@ def custom_skip(app, what, name, obj, skip, options): - if what in ['data', 'function', 'class'] and re.search( - 'logger|wrapper', name): + if what in ['data', 'function', 'class'] and re.search('logger', name): skip = True return skip From 85853da26300e4ac197f5966f520cb936e5c3edd Mon Sep 17 00:00:00 2001 From: BraisedPork <46232992+braisedpork1964@users.noreply.github.com> Date: Fri, 26 Jan 2024 11:17:55 +0800 Subject: [PATCH 15/20] [Feature] support IPython interpreter action (#103) * add ipython interpreter * update requirements * remove `return_list` argument --------- Co-authored-by: wangzy --- lagent/actions/__init__.py | 6 +- lagent/actions/google_scholar_search.py | 3 +- lagent/actions/google_search.py | 6 +- lagent/actions/ipython_interpreter.py | 296 ++++++++++++++++++++++++ lagent/actions/python_interpreter.py | 16 +- requirements/runtime.txt | 4 + 6 files changed, 315 insertions(+), 16 deletions(-) create mode 100644 lagent/actions/ipython_interpreter.py diff --git a/lagent/actions/__init__.py b/lagent/actions/__init__.py 
index 4737418f..3d05ba63 100644 --- a/lagent/actions/__init__.py +++ b/lagent/actions/__init__.py @@ -7,6 +7,7 @@ from .builtin_actions import FinishAction, InvalidAction, NoAction from .google_scholar_search import GoogleScholar from .google_search import GoogleSearch +from .ipython_interpreter import IPythonInterpreter from .parser import BaseParser, JsonParser, TupleParser from .ppt import PPT from .python_interpreter import PythonInterpreter @@ -14,8 +15,9 @@ __all__ = [ 'BaseAction', 'ActionExecutor', 'InvalidAction', 'FinishAction', 'NoAction', 'BINGMap', 'ArxivSearch', 'FinishAction', 'GoogleSearch', - 'GoogleScholar', 'PythonInterpreter', 'PPT', 'BaseParser', 'JsonParser', - 'TupleParser', 'tool_api', 'list_tools', 'get_tool_cls', 'get_tool' + 'GoogleScholar', 'IPythonInterpreter', 'PythonInterpreter', 'PPT', + 'BaseParser', 'JsonParser', 'TupleParser', 'tool_api', 'list_tools', + 'get_tool_cls', 'get_tool' ] diff --git a/lagent/actions/google_scholar_search.py b/lagent/actions/google_scholar_search.py index 4098f614..5060bba7 100644 --- a/lagent/actions/google_scholar_search.py +++ b/lagent/actions/google_scholar_search.py @@ -14,8 +14,7 @@ class GoogleScholar(BaseAction): Args: api_key (str): API KEY to use serper google search API, You can create a free API key at https://serper.dev. - description (dict): The description of the action. Defaults to - :py:data:`~DEFAULT_DESCRIPTION`. + description (dict): The description of the action. Defaults to ``None``. parser (Type[BaseParser]): The parser class to process the action's inputs and outputs. Defaults to :class:`JsonParser`. enable (bool, optional): Whether the action is enabled. Defaults to diff --git a/lagent/actions/google_search.py b/lagent/actions/google_search.py index 3f5b116a..25ac7437 100644 --- a/lagent/actions/google_search.py +++ b/lagent/actions/google_search.py @@ -24,12 +24,10 @@ class GoogleSearch(BaseAction): timeout (int): Upper bound of waiting time for a serper request. search_type (str): Serper API support ['search', 'images', 'news', 'places'] types of search, currently we only support 'search'. - description (dict): The description of the action. Defaults to - :py:data:`~DEFAULT_DESCRIPTION`. + description (dict): The description of the action. Defaults to ``None``. parser (Type[BaseParser]): The parser class to process the action's inputs and outputs. Defaults to :class:`JsonParser`. - enable (bool, optional): Whether the action is enabled. Defaults to - True. + enable (bool): Whether the action is enabled. Defaults to ``True``. """ result_key_for_type = { 'news': 'news', diff --git a/lagent/actions/ipython_interpreter.py b/lagent/actions/ipython_interpreter.py new file mode 100644 index 00000000..2d49e641 --- /dev/null +++ b/lagent/actions/ipython_interpreter.py @@ -0,0 +1,296 @@ +import base64 +import io +import logging +import os +import queue +import re +import signal +import sys +import traceback +import uuid +from typing import Optional, Tuple, Type + +import json5 +import PIL.Image +from jupyter_client import KernelManager + +from lagent.actions.base_action import BaseAction, tool_api +from lagent.actions.parser import BaseParser, JsonParser +from lagent.schema import ActionReturn, ActionStatusCode + +START_CODE = """ +def input(*args, **kwargs): + raise NotImplementedError('Python input() function is disabled.') + +get_ipython().system = lambda *args: print('Assume we have this package, ! 
is disabled!')
+{}
+"""  # noqa
+
+
+class TimeoutError(Exception):
+    pass
+
+
+class IPythonInterpreter(BaseAction):
+    """An IPython executor that can execute Python scripts in a jupyter manner.
+
+    Args:
+        timeout (int): Upper bound of waiting time for Python script execution.
+            Defaults to 20.
+        user_data_dir (str, optional): Specify the user data directory for file
+            loading. If set to `ENV`, use `USER_DATA_DIR` environment variable.
+            Defaults to `ENV`.
+        work_dir (str, optional): Specify which directory to save output images to.
+            Defaults to ``'./work_dir/tmp_dir'``.
+        description (dict): The description of the action. Defaults to ``None``.
+        parser (Type[BaseParser]): The parser class to process the
+            action's inputs and outputs. Defaults to :class:`JsonParser`.
+        enable (bool, optional): Whether the action is enabled. Defaults to ``True``.
+    """
+
+    _KERNEL_CLIENTS = {}
+
+    def __init__(self,
+                 timeout: int = 20,
+                 user_data_dir: str = 'ENV',
+                 work_dir='./work_dir/tmp_dir',
+                 description: Optional[dict] = None,
+                 parser: Type[BaseParser] = JsonParser,
+                 enable: bool = True):
+        super().__init__(description, parser, enable)
+
+        self.timeout = timeout
+        if user_data_dir == 'ENV':
+            user_data_dir = os.environ.get('USER_DATA_DIR', '')
+
+        if user_data_dir:
+            user_data_dir = os.path.dirname(user_data_dir)
+            user_data_dir = f"import os\nos.chdir('{user_data_dir}')"
+        self.user_data_dir = user_data_dir
+        self._initialized = False
+        self.work_dir = work_dir
+        if not os.path.exists(self.work_dir):
+            os.makedirs(self.work_dir, exist_ok=True)
+
+    @staticmethod
+    def start_kernel():
+        # start the kernel and manager
+        km = KernelManager()
+        km.start_kernel()
+        kc = km.client()
+        return km, kc
+
+    def initialize(self):
+        if self._initialized:
+            return
+        pid = os.getpid()
+        if pid not in self._KERNEL_CLIENTS:
+            self._KERNEL_CLIENTS[pid] = self.start_kernel()
+        self.kernel_manager, self.kernel_client = self._KERNEL_CLIENTS[pid]
+        self._initialized = True
+        self._call(START_CODE.format(self.user_data_dir), None)
+
+    def reset(self):
+        if not self._initialized:
+            self.initialize()
+        else:
+            code = "get_ipython().run_line_magic('reset', '-f')\n" + \
+                START_CODE.format(self.user_data_dir)
+            self._call(code, None)
+
+    def _call(self,
+              command: str,
+              timeout: Optional[int] = None) -> Tuple[str, bool]:
+        self.initialize()
+        command = extract_code(command)
+
+        # check previous remaining result
+        while True:
+            try:
+                msg = self.kernel_client.get_iopub_msg(timeout=5)
+                msg_type = msg['msg_type']
+                if msg_type == 'status':
+                    if msg['content'].get('execution_state') == 'idle':
+                        break
+            except queue.Empty:
+                # assume no result
+                break
+
+        self.kernel_client.execute(command)
+
+        def _inner_call():
+            result = ''
+            images = []
+            succeed = True
+            image_idx = 0
+
+            while True:
+                text = ''
+                image = ''
+                finished = False
+                msg_type = 'error'
+                try:
+                    msg = self.kernel_client.get_iopub_msg(timeout=20)
+                    msg_type = msg['msg_type']
+                    if msg_type == 'status':
+                        if msg['content'].get('execution_state') == 'idle':
+                            finished = True
+                    elif msg_type == 'execute_result':
+                        text = msg['content']['data'].get('text/plain', '')
+                        if 'image/png' in msg['content']['data']:
+                            image_b64 = msg['content']['data']['image/png']
+                            image_url = publish_image_to_local(
+                                image_b64, self.work_dir)
+                            image_idx += 1
+                            image = '![fig-%03d](%s)' % (image_idx, image_url)
+
+                    elif msg_type == 'display_data':
+                        if 'image/png' in msg['content']['data']:
+                            image_b64 = msg['content']['data']['image/png']
+                            image_url = publish_image_to_local(
image_b64, self.work_dir) + image_idx += 1 + image = '![fig-%03d](%s)' % (image_idx, image_url) + + else: + text = msg['content']['data'].get('text/plain', '') + elif msg_type == 'stream': + msg_type = msg['content']['name'] # stdout, stderr + text = msg['content']['text'] + elif msg_type == 'error': + succeed = False + text = escape_ansi('\n'.join( + msg['content']['traceback'])) + if 'M6_CODE_INTERPRETER_TIMEOUT' in text: + text = f'Timeout. No response after {timeout} seconds.' # noqa + except queue.Empty: + # stop current task in case break next input. + self.kernel_manager.interrupt_kernel() + succeed = False + text = f'Timeout. No response after {timeout} seconds.' + finished = True + except Exception: + succeed = False + msg = ''.join(traceback.format_exception(*sys.exc_info())) + # text = 'The code interpreter encountered an unexpected error.' # noqa + text = msg + logging.warning(msg) + finished = True + if text: + # result += f'\n\n{msg_type}:\n\n```\n{text}\n```' + result += f'{text}' + + if image: + images.append(image_url) + if finished: + return succeed, dict(text=result, image=images) + + try: + if timeout: + + def handler(signum, frame): + raise TimeoutError() + + signal.signal(signal.SIGALRM, handler) + signal.alarm(timeout) + succeed, result = _inner_call() + except TimeoutError: + succeed = False + text = 'The code interpreter encountered an unexpected error.' + result = f'\n\nerror:\n\n```\n{text}\n```' + finally: + if timeout: + signal.alarm(0) + + # result = result.strip('\n') + return succeed, result + + @tool_api + def run(self, command: str, timeout: Optional[int] = None) -> ActionReturn: + """When you send a message containing Python code to python, it will be \ +executed in a stateful Jupyter notebook environment. python will respond with \ +the output of the execution or time out after 60.0 seconds. The drive at '/mnt/data' \ +can be used to save and persist user files. Internet access for this session is \ +disabled. Do not make external web requests or API calls as they will fail. + + Args: + command (:class:`str`): Python code + timeout (:class:`Optional[int]`): Upper bound of waiting time for Python script execution. 
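+
+        Returns:
+            :class:`ActionReturn`: the execution outcome, including captured
+                text output and paths of any generated images.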
+ """ + tool_return = ActionReturn(url=None, args=None, type=self.name) + tool_return.args = dict(text=command) + succeed, result = self._call(command, timeout) + if succeed: + text = result['text'] + image = result.get('image', []) + resp = [dict(type="text", content=text)] + if image: + resp.extend([dict(type="image", content=im) for im in image]) + tool_return.result = resp + # tool_return.result = dict( + # text=result['text'], image=result.get('image', [])[0]) + tool_return.state = ActionStatusCode.SUCCESS + else: + tool_return.errmsg = result.get("text", "") if isinstance( + result, dict) else result + tool_return.state = ActionStatusCode.API_ERROR + return tool_return + + +def extract_code(text): + # Match triple backtick blocks first + triple_match = re.search(r'```[^\n]*\n(.+?)```', text, re.DOTALL) + # Match single backtick blocks second + single_match = re.search(r'`([^`]*)`', text, re.DOTALL) + if triple_match: + text = triple_match.group(1) + elif single_match: + text = single_match.group(1) + else: + try: + text = json5.loads(text)['code'] + except Exception: + pass + # If no code blocks found, return original text + return text + + +def escape_ansi(line): + ansi_escape = re.compile(r'(?:\x1B[@-_]|[\x80-\x9F])[0-?]*[ -/]*[@-~]') + return ansi_escape.sub('', line) + + +def publish_image_to_local(image_base64: str, work_dir='./work_dir/tmp_dir'): + image_file = str(uuid.uuid4()) + '.png' + local_image_file = os.path.join(work_dir, image_file) + + png_bytes = base64.b64decode(image_base64) + assert isinstance(png_bytes, bytes) + bytes_io = io.BytesIO(png_bytes) + PIL.Image.open(bytes_io).save(local_image_file, 'png') + + return local_image_file + + +# local test for code interpreter +def get_multiline_input(hint): + print(hint) + print('// Press ENTER to make a new line. Press CTRL-D to end input.') + lines = [] + while True: + try: + line = input() + except EOFError: # CTRL-D + break + lines.append(line) + print('// Input received.') + if lines: + return '\n'.join(lines) + else: + return '' + + +if __name__ == '__main__': + code_interpreter = IPythonInterpreter() + while True: + print(code_interpreter(get_multiline_input('Enter python code:'))) diff --git a/lagent/actions/python_interpreter.py b/lagent/actions/python_interpreter.py index 88424a45..cc2a8278 100644 --- a/lagent/actions/python_interpreter.py +++ b/lagent/actions/python_interpreter.py @@ -34,18 +34,18 @@ class PythonInterpreter(BaseAction): """A Python executor that can execute Python scripts. Args: - answer_symbol (str, Optional): the answer symbol from LLM + answer_symbol (str, Optional): the answer symbol from LLM. Defaults to ``None``. answer_expr (str, Optional): the answer function name of the Python - script. Default to 'solution()'. - answer_from_stdout (boolean): whether the execution results is from - stdout. - timeout (int): Upper bound of waiting time for Python script execution. - description (dict): The description of the action. Defaults to - :py:data:`~DEFAULT_DESCRIPTION`. + script. Defaults to ``'solution()'``. + answer_from_stdout (boolean, Optional): whether the execution results is from + stdout. Defaults to ``False``. + timeout (int, Optional): Upper bound of waiting time for Python script execution. + Defaults to ``20``. + description (dict, Optional): The description of the action. Defaults to ``None``. parser (Type[BaseParser]): The parser class to process the action's inputs and outputs. Defaults to :class:`JsonParser`. enable (bool, optional): Whether the action is enabled. 
Defaults to - True. + ``True``. """ def __init__(self, diff --git a/requirements/runtime.txt b/requirements/runtime.txt index e58b6b15..44e5293a 100644 --- a/requirements/runtime.txt +++ b/requirements/runtime.txt @@ -5,3 +5,7 @@ requests tiktoken griffe phx-class-registry +jupyter +jupyter_client +json5 +pillow From a2a0d91b29b49fed2ede456f8d900e1a8dd0c887 Mon Sep 17 00:00:00 2001 From: BraisedPork <46232992+braisedpork1964@users.noreply.github.com> Date: Fri, 26 Jan 2024 11:19:56 +0800 Subject: [PATCH 16/20] Fix BINGMap key (#105) fix the fallback value Co-authored-by: wangzy --- lagent/actions/bing_map.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lagent/actions/bing_map.py b/lagent/actions/bing_map.py index 01efe698..9e0e1bc8 100644 --- a/lagent/actions/bing_map.py +++ b/lagent/actions/bing_map.py @@ -17,7 +17,7 @@ def __init__(self, parser: Type[BaseParser] = JsonParser, enable: bool = True) -> None: super().__init__(description, parser, enable) - key = os.environ.get('BING_MAP_KEY') + key = os.environ.get('BING_MAP_KEY', key) if key is None: raise ValueError( 'Please set BING Map API key either in the environment ' From 183ef77a84a9b558014ce48a8f1472ef3b04ad5c Mon Sep 17 00:00:00 2001 From: liukuikun <24622904+Harold-lkk@users.noreply.github.com> Date: Mon, 29 Jan 2024 17:38:52 +0800 Subject: [PATCH 17/20] StreamAgent infer demo (#106) * update cfg & fix bug of StreamAgent * fix bug of func 'stream_chat' * streamlit demo with full response * enchance stream chat * fix bug of stream chat * fix and file rename * add exception catch for func 'chat' --------- Co-authored-by: liujiangning --- examples/internlm2_agent_web_demo.py | 333 ++++++++++++++++++ .../{stream_agent.py => internlm2_agent.py} | 174 +++++++-- lagent/llms/lmdepoly_wrapper.py | 85 +++-- 3 files changed, 510 insertions(+), 82 deletions(-) create mode 100644 examples/internlm2_agent_web_demo.py rename lagent/agents/{stream_agent.py => internlm2_agent.py} (57%) diff --git a/examples/internlm2_agent_web_demo.py b/examples/internlm2_agent_web_demo.py new file mode 100644 index 00000000..6713a2c7 --- /dev/null +++ b/examples/internlm2_agent_web_demo.py @@ -0,0 +1,333 @@ +import copy +import hashlib +import json +import os + +import streamlit as st + +from lagent.actions import ActionExecutor, ArxivSearch, GoogleScholar, IPythonInterpreter +from lagent.agents.internlm2_agent import (INTERPRETER_CN, META_INS, PLUGIN_CN, + Internlm2Agent, Interlm2Protocol) +from lagent.llms.lmdepoly_wrapper import LMDeployClient +from lagent.llms.meta_template import INTERNLM2_META as META +from lagent.schema import AgentStatusCode + +# from streamlit.logger import get_logger + + +class SessionState: + + def init_state(self): + """Initialize session state variables.""" + st.session_state['assistant'] = [] + st.session_state['user'] = [] + + action_list = [ + GoogleScholar( + api_key=('a558de7dee10146326ca86fbaa0736b' + 'dd947c9e646cd3f14da5aff177d6b2ff0')), + ArxivSearch(), + ] + st.session_state['plugin_map'] = { + action.name: action + for action in action_list + } + st.session_state['model_map'] = {} + st.session_state['model_selected'] = None + st.session_state['plugin_actions'] = set() + st.session_state['history'] = [] + + def clear_state(self): + """Clear the existing session state.""" + st.session_state['assistant'] = [] + st.session_state['user'] = [] + st.session_state['model_selected'] = None + st.session_state['file'] = set() + if 'chatbot' in st.session_state: + st.session_state['chatbot']._session_history = [] 
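+
+# NOTE: an illustrative sketch, not part of the original patch. The serper key
+# hard-coded in `init_state` above could instead be read from the environment;
+# the variable name `SERPER_API_KEY` is an assumption borrowed from the
+# quickstart docs:
+#
+#   import os
+#   serper_key = os.environ.get('SERPER_API_KEY', '')
+#   action_list = [GoogleScholar(api_key=serper_key), ArxivSearch()]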
+ + +class StreamlitUI: + + def __init__(self, session_state: SessionState): + self.init_streamlit() + self.session_state = session_state + + def init_streamlit(self): + """Initialize Streamlit's UI settings.""" + st.set_page_config( + layout='wide', + page_title='lagent-web', + page_icon='./docs/imgs/lagent_icon.png') + st.header(':robot_face: :blue[Lagent] Web Demo ', divider='rainbow') + st.sidebar.title('模型控制') + st.session_state['file'] = set() + st.session_state['ip'] = None + + def setup_sidebar(self): + """Setup the sidebar for model and plugin selection.""" + model_name = st.sidebar.selectbox('模型选择:', options=['internlm']) + meta_prompt = st.sidebar.text_area('系统提示词', value=META_INS) + da_prompt = st.sidebar.text_area('数据分析提示词', value=INTERPRETER_CN) + plugin_prompt = st.sidebar.text_area('插件提示词', value=PLUGIN_CN) + model_ip = st.sidebar.text_input('模型IP:', value='10.140.0.220:23333') + if model_name != st.session_state[ + 'model_selected'] or st.session_state['ip'] != model_ip: + st.session_state['ip'] = model_ip + model = self.init_model(model_name, model_ip) + self.session_state.clear_state() + st.session_state['model_selected'] = model_name + if 'chatbot' in st.session_state: + del st.session_state['chatbot'] + else: + model = st.session_state['model_map'][model_name] + + plugin_name = st.sidebar.multiselect( + '插件选择', + options=list(st.session_state['plugin_map'].keys()), + default=[], + ) + da_flag = st.sidebar.checkbox( + '数据分析', + value=False, + ) + plugin_action = [ + st.session_state['plugin_map'][name] for name in plugin_name + ] + + if 'chatbot' in st.session_state: + if len(plugin_action) > 0: + st.session_state['chatbot']._action_executor = ActionExecutor( + actions=plugin_action) + else: + st.session_state['chatbot']._action_executor = None + if da_flag: + st.session_state[ + 'chatbot']._interpreter_executor = ActionExecutor( + actions=[IPythonInterpreter()]) + else: + st.session_state['chatbot']._interpreter_executor = None + st.session_state['chatbot']._protocol._meta_template = meta_prompt + st.session_state['chatbot']._protocol.plugin_prompt = plugin_prompt + st.session_state[ + 'chatbot']._protocol.interpreter_prompt = da_prompt + if st.sidebar.button('清空对话', key='clear'): + self.session_state.clear_state() + uploaded_file = st.sidebar.file_uploader('上传文件') + + return model_name, model, plugin_action, uploaded_file, model_ip + + def init_model(self, option, ip=None): + """Initialize the model based on the selected option.""" + model_url = f'http://{ip}' + st.session_state['model_map'][option] = LMDeployClient( + path='internlm2-chat-20b', + url=model_url, + meta_template=META, + top_p=0.8, + top_k=100, + temperature=0, + repetition_penalty=1.0, + stop_words=['<|im_end|>']) + return st.session_state['model_map'][option] + + def initialize_chatbot(self, model, plugin_action): + """Initialize the chatbot with the given model and plugin actions.""" + return Internlm2Agent( + llm=model, + protocol=Interlm2Protocol( + tool=dict( + begin='{start_token}{name}\n', + start_token='<|action_start|>', + name_map=dict( + plugin='<|plugin|>', interpreter='<|interpreter|>'), + belong='assistant', + end='<|action_end|>\n', + ), ), + ) + + def render_user(self, prompt: str): + with st.chat_message('user'): + st.markdown(prompt) + + def render_assistant(self, agent_return): + with st.chat_message('assistant'): + for action in agent_return.actions: + if (action) and (action.type != 'FinishAction'): + self.render_action(action) + st.markdown(agent_return.response) + + def 
render_plugin_args(self, action): + action_name = action.type + args = action.args + import json + parameter_dict = dict(name=action_name, parameters=args) + parameter_str = '```json\n' + json.dumps( + parameter_dict, indent=4, ensure_ascii=False) + '\n```' + st.markdown(parameter_str) + + def render_interpreter_args(self, action): + st.info(action.type) + st.markdown(action.args['text']) + + def render_action(self, action): + st.markdown(action.thought) + if action.type == 'IPythonInterpreter': + self.render_interpreter_args(action) + elif action.type == 'FinishAction': + pass + else: + self.render_plugin_args(action) + self.render_action_results(action) + + def render_action_results(self, action): + """Render the results of action, including text, images, videos, and + audios.""" + if (isinstance(action.result, dict)): + if 'text' in action.result: + st.markdown('```\n' + action.result['text'] + '\n```') + if 'image' in action.result: + # image_path = action.result['image'] + for image_path in action.result['image']: + image_data = open(image_path, 'rb').read() + st.image(image_data, caption='Generated Image') + if 'video' in action.result: + video_data = action.result['video'] + video_data = open(video_data, 'rb').read() + st.video(video_data) + if 'audio' in action.result: + audio_data = action.result['audio'] + audio_data = open(audio_data, 'rb').read() + st.audio(audio_data) + elif isinstance(action.result, list): + for item in action.result: + if item['type'] == 'text': + st.markdown('```\n' + item['content'] + '\n```') + elif item['type'] == 'image': + image_data = open(item['content'], 'rb').read() + st.image(image_data, caption='Generated Image') + elif item['type'] == 'video': + video_data = open(item['content'], 'rb').read() + st.video(video_data) + elif item['type'] == 'audio': + audio_data = open(item['content'], 'rb').read() + st.audio(audio_data) + if action.errmsg: + st.error(action.errmsg) + + +def main(): + # logger = get_logger(__name__) + # Initialize Streamlit UI and setup sidebar + if 'ui' not in st.session_state: + session_state = SessionState() + session_state.init_state() + st.session_state['ui'] = StreamlitUI(session_state) + + else: + st.set_page_config( + layout='wide', + page_title='lagent-web', + page_icon='./docs/imgs/lagent_icon.png') + st.header(':robot_face: :blue[Lagent] Web Demo ', divider='rainbow') + _, model, plugin_action, uploaded_file, _ = st.session_state[ + 'ui'].setup_sidebar() + + # Initialize chatbot if it is not already initialized + # or if the model has changed + if 'chatbot' not in st.session_state or model != st.session_state[ + 'chatbot']._llm: + st.session_state['chatbot'] = st.session_state[ + 'ui'].initialize_chatbot(model, plugin_action) + st.session_state['session_history'] = [] + + for prompt, agent_return in zip(st.session_state['user'], + st.session_state['assistant']): + st.session_state['ui'].render_user(prompt) + st.session_state['ui'].render_assistant(agent_return) + + if user_input := st.chat_input(''): + with st.container(): + st.session_state['ui'].render_user(user_input) + st.session_state['user'].append(user_input) + # Add file uploader to sidebar + if (uploaded_file + and uploaded_file.name not in st.session_state['file']): + + st.session_state['file'].add(uploaded_file.name) + file_bytes = uploaded_file.read() + file_type = uploaded_file.type + if 'image' in file_type: + st.image(file_bytes, caption='Uploaded Image') + elif 'video' in file_type: + st.video(file_bytes, caption='Uploaded Video') + elif 'audio' in 
file_type: + st.audio(file_bytes, caption='Uploaded Audio') + # Save the file to a temporary location and get the path + + postfix = uploaded_file.name.split('.')[-1] + # prefix = str(uuid.uuid4()) + prefix = hashlib.md5(file_bytes).hexdigest() + filename = f'{prefix}.{postfix}' + file_path = os.path.join(root_dir, filename) + with open(file_path, 'wb') as tmpfile: + tmpfile.write(file_bytes) + file_size = os.stat(file_path).st_size / 1024 / 1024 + file_size = f'{round(file_size, 2)} MB' + # st.write(f'File saved at: {file_path}') + user_input = [ + dict(role='user', content=user_input), + dict( + role='user', + content=json.dumps(dict(path=file_path, size=file_size)), + name='file') + ] + if isinstance(user_input, str): + user_input = [dict(role='user', content=user_input)] + st.session_state['last_status'] = AgentStatusCode.SESSION_READY + for agent_return in st.session_state['chatbot'].stream_chat( + st.session_state['session_history'] + user_input): + if agent_return.state == AgentStatusCode.PLUGIN_RETURN: + with st.container(): + st.session_state['ui'].render_plugin_args( + agent_return.actions[-1]) + st.session_state['ui'].render_action_results( + agent_return.actions[-1]) + elif agent_return.state == AgentStatusCode.CODE_RETURN: + with st.container(): + st.session_state['ui'].render_action_results( + agent_return.actions[-1]) + elif (agent_return.state == AgentStatusCode.STREAM_ING + or agent_return.state == AgentStatusCode.CODING): + # st.markdown(agent_return.response) + # 清除占位符的当前内容,并显示新内容 + with st.container(): + if agent_return.state != st.session_state['last_status']: + st.session_state['temp'] = '' + placeholder = st.empty() + st.session_state['placeholder'] = placeholder + if isinstance(agent_return.response, dict): + action = f"\n\n {agent_return.response['name']}: \n\n" + action_input = agent_return.response['parameters'] + if agent_return.response['name'] == 'IPythonInterpreter': + action_input = action_input['command'] + response = action + action_input + else: + response = agent_return.response + st.session_state['temp'] = response + st.session_state['placeholder'].markdown( + st.session_state['temp']) + elif agent_return.state == AgentStatusCode.END: + st.session_state['session_history'] += (user_input + agent_return.inner_steps) + agent_return = copy.deepcopy(agent_return) + agent_return.response = st.session_state['temp'] + st.session_state['assistant'].append( + copy.deepcopy(agent_return)) + st.session_state['last_status'] = agent_return.state + + +if __name__ == '__main__': + root_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + root_dir = os.path.join(root_dir, 'tmp_dir') + os.makedirs(root_dir, exist_ok=True) + main() diff --git a/lagent/agents/stream_agent.py b/lagent/agents/internlm2_agent.py similarity index 57% rename from lagent/agents/stream_agent.py rename to lagent/agents/internlm2_agent.py index a7295d22..9d071efc 100644 --- a/lagent/agents/stream_agent.py +++ b/lagent/agents/internlm2_agent.py @@ -1,7 +1,7 @@ import json import logging from copy import deepcopy -from typing import Dict, List, Union +from typing import Dict, List, Union, Optional from ilagent.schema import AgentReturn, AgentStatusCode @@ -14,40 +14,47 @@ "This is the subfunction for tool '{tool_name}', you can use this tool. 
" 'The description of this function is: \n{description}') -INTERPRETER_CN = ('你现在可以使用一个支持 Python 代码执行的 Jupyter 笔记本环境。只需向 python 发' - '送代码,即可在这个有状态环境中进行运行。这个功能适用于数据分析或处理(如数据操作和' - '图形制作),复杂计算(如数学和物理问题),编程示例(用于理解编程概念或语言特性),文' - '本处理和分析(包括文本分析和自然语言处理),机器学习和数据科学(模型训练和数据可视化' - '展示),以及文件操作和数据导入(处理CSV、JSON等格式文件)。') +META_INS = ('You are InternLM, a large language model trained by PJLab. ' + 'Answer as concisely as possible. ' + '当开启工具以及代码时,根据需求选择合适的工具进行调用') -PLUGIN_CN = ('你可以使用如下工具:' - '\n{prompt}\n' - '如果你已经获得足够信息,请直接给出答案. 避免不必要的工具调用! ' - '同时注意你可以使用的工具,不要随意捏造!') +INTERPRETER_CN = ('你现在可以通过如下格式向 Jupyter Notebook 发送并执行代码:' + '\n<|action_start|><|interpreter|>```python\n\n代码\n\n```\n' + '\n当遇到以下问题时,请使用上述格式调用 Jupyter Notebook 去解决,并根据执行结果做出友好的回复:\n' + '1. 文件操作和数据导入,比如处理CSV、JSON等格式文件\n' + '2. 数据分析或处理,比如数据操作或图像绘制如折线图、柱状图等\n' + '3. 数学相关的问题。当遇到数学问题时,你需要分析题目,并给出代码去解决这个题目') +PLUGIN_CN = ( + '你可以使用如下工具:' + '\n{prompt}\n' + '当你需要使用工具时,你可以使用如下格式:\n' + '<|action_start|><|plugin|>{{"name": "工具名称", "parameters": {{参数}}}}\n' + '如果你已经获得足够信息,请直接给出答案. 避免不必要的工具调用! ' + '同时注意你可以使用的工具,不要随意捏造!') -class StreamProtocol: + +class Interlm2Protocol: def __init__( self, - meta_prompt=None, - interpreter_prompt=INTERPRETER_CN, - plugin_prompt=PLUGIN_CN, - few_shot=None, - language=dict( + meta_prompt: str=META_INS, + interpreter_prompt: str=INTERPRETER_CN, + plugin_prompt: str=PLUGIN_CN, + few_shot: Optional[List]=None, + language: Dict=dict( begin='', end='', belong='assistant', ), - tool=dict( + tool: Dict=dict( begin='{start_token}{name}\n', - start_token='[UNUSED_TOKEN_144]', - name_map=dict( - plugin='[UNUSED_TOKEN_141]', interpreter='[UNUSED_TOKEN_142]'), + start_token='<|action_start|>', + name_map=dict(plugin='<|plugin|>', interpreter='<|interpreter|>'), belong='assistant', - end='[UNUSED_TOKEN_143]\n', + end='<|action_end|>\n', ), - execute: dict = dict( + execute: Dict = dict( role='execute', begin='', end='', fallback_role='environment'), ) -> None: self.meta_prompt = meta_prompt @@ -124,10 +131,9 @@ def format(self, if self.meta_prompt: formatted.append(dict(role='system', content=self.meta_prompt)) if interpreter_executor and self.interpreter_prompt: - interpreter_info = list( - interpreter_executor.get_actions_info().items())[0] + interpreter_info = interpreter_executor.get_actions_info()[0] interpreter_prompt = self.interpreter_prompt.format( - code_prompt=interpreter_info[1]) + code_prompt=interpreter_info['description']) formatted.append( dict( role='system', @@ -135,16 +141,12 @@ def format(self, name='interpreter')) if plugin_executor and plugin_executor.actions and self.plugin_prompt: plugin_descriptions = [] - for api_name, api_info in plugin_executor.get_actions_info().items( - ): + for api_info in plugin_executor.get_actions_info(): + plugin = deepcopy(api_info) if isinstance(api_info, dict): - plugin = deepcopy(api_info) - tool_name = api_name.split('.')[0] - plugin['name'] = api_name + tool_name = api_info['name'].split('.')[0] plugin['description'] = API_PREFIX.format( tool_name=tool_name, description=plugin['description']) - else: - plugin = dict(name=api_name, description=api_info) plugin_descriptions.append(plugin) plugin_prompt = self.plugin_prompt.format( prompt=json.dumps( @@ -165,7 +167,6 @@ def parse(self, message, plugin_executor: ActionExecutor, message, action = message.split( f"{self.tool['start_token']}{self.tool['name_map']['plugin']}") action = action.split(self.tool['end'].strip())[0] - action = json.loads(action) return 'plugin', message, action if self.tool['name_map']['interpreter'] 
in message:
            message, code = message.split(
@@ -192,13 +193,13 @@ def format_response(self, action_return, name) -> dict:
         return dict(role=self.execute['role'], content=response, name=name)


-class StreamAgent(BaseAgent):
+class Internlm2Agent(BaseAgent):

     def __init__(self,
                  llm: Union[BaseModel, BaseAPIModel],
                  plugin_executor: ActionExecutor = None,
                  interpreter_executor: ActionExecutor = None,
-                 protocol=StreamProtocol(),
+                 protocol=Interlm2Protocol(),
                  max_turn: int = 3) -> None:
         self.max_turn = max_turn
         self._interpreter_executor = interpreter_executor
@@ -233,6 +234,12 @@ def chat(self, message: Union[str, Dict], **kwargs) -> AgentReturn:
                 else:
                     logging.info(msg='No plugin is instantiated!')
                     continue
+                try:
+                    action = json.loads(action)
+                except Exception as e:
+                    logging.info(
+                        msg=f'Invalid action {e}')
+                    continue
             elif name == 'interpreter':
                 if self._interpreter_executor:
                     executor = self._interpreter_executor
@@ -258,6 +265,99 @@ def chat(self, message: Union[str, Dict], **kwargs) -> AgentReturn:
                 dict(role='tool', content=action, name=name))
             inner_history.append(
                 self._protocol.format_response(action_return, name=name))
-
+            yield agent_return
         agent_return.inner_steps = inner_history[offset:]
-        return agent_return
+        yield agent_return
+
+    def stream_chat(self, message: List[dict], **kwargs) -> AgentReturn:
+        if isinstance(message, str):
+            message = dict(role='user', content=message)
+        if isinstance(message, dict):
+            message = [message]
+        inner_history = message[:]
+        offset = len(inner_history)
+        agent_return = AgentReturn()
+        last_agent_state = AgentStatusCode.SESSION_READY
+        for _ in range(self.max_turn):
+            # list of dict
+            prompt = self._protocol.format(
+                inner_step=inner_history,
+                plugin_executor=self._action_executor,
+                interpreter_executor=self._interpreter_executor,
+            )
+            response = ''
+            for model_state, res, _ in self._llm.stream_chat(
+                    prompt, **kwargs):
+                response = res
+                if model_state.value < 0:
+                    agent_return.state = model_state
+                    yield deepcopy(agent_return)
+                    return
+                else:
+                    name, language, action = self._protocol.parse(
+                        message=response,
+                        plugin_executor=self._action_executor,
+                        interpreter_executor=self._interpreter_executor,
+                    )
+                    if name:
+                        if model_state == AgentStatusCode.END:
+                            agent_state = last_agent_state + 1
+                            if name == 'plugin':
+                                if self._action_executor:
+                                    executor = self._action_executor
+                                else:
+                                    logging.info(
+                                        msg='No plugin is instantiated!')
+                                    continue
+                                try:
+                                    action = json.loads(action)
+                                except Exception as e:
+                                    logging.info(
+                                        msg=f'Invalid action {e}')
+                                    continue
+                            elif name == 'interpreter':
+                                if self._interpreter_executor:
+                                    executor = self._interpreter_executor
+                                else:
+                                    logging.info(
+                                        msg='No interpreter is instantiated!')
+                                    continue
+                            agent_return.state = agent_state
+                            agent_return.response = action
+                        else:
+                            agent_state = (
+                                AgentStatusCode.PLUGIN_START if name
+                                == 'plugin' else AgentStatusCode.CODING)
+                            if agent_state != last_agent_state:
+                                # agent_return.state = agent_state
+                                agent_return.response = language
+                                yield deepcopy(agent_return)
+                            agent_return.state = agent_state
+                            agent_return.response = action
+                    else:
+                        agent_state = AgentStatusCode.STREAM_ING
+                        agent_return.state = agent_state
+                        agent_return.response = language
+                    last_agent_state = agent_state
+                    yield deepcopy(agent_return)
+            if name:
+                action_return: ActionReturn = executor(action['name'],
+                                                       action['parameters'])
+                action_return.thought = language
+                agent_return.actions.append(action_return)
+            inner_history.append(dict(role='language', content=language))
+            if not name
or action_return.type == executor.finish_action.name: + agent_return.response = language + agent_return.state = AgentStatusCode.END + break + else: + inner_history.append( + dict(role='tool', content=action, name=name)) + inner_history.append( + self._protocol.format_response(action_return, name=name)) + agent_state += 1 + agent_return.state = agent_state + yield agent_return + agent_return.inner_steps = deepcopy(inner_history[offset:]) + agent_return.state = AgentStatusCode.END + yield agent_return diff --git a/lagent/llms/lmdepoly_wrapper.py b/lagent/llms/lmdepoly_wrapper.py index dc93e907..ed4d0d5e 100644 --- a/lagent/llms/lmdepoly_wrapper.py +++ b/lagent/llms/lmdepoly_wrapper.py @@ -1,9 +1,5 @@ -import json from typing import List, Optional, Union - -import requests - from lagent.llms.base_llm import BaseModel from lagent.schema import AgentStatusCode from lagent.utils.util import filter_suffix @@ -102,9 +98,9 @@ def generate(self, if status.value < 0: break if status.value == 0: - self.chatbot._session.histories = \ - self.chatbot._session.histories + self.chatbot._session.prompt + \ - self.chatbot._session.response + self.chatbot._session.histories = ( + self.chatbot._session.histories + + self.chatbot._session.prompt + self.chatbot._session.response) # remove stop_words res = filter_suffix(res, self.gen_params.get('stop_words')) return res @@ -170,9 +166,9 @@ def stream_chat(self, else: yield self.state_map.get(status), res, _ if status.value == 0: - self.chatbot._session.histories = \ - self.chatbot._session.histories + self.chatbot._session.prompt + \ - self.chatbot._session.response + self.chatbot._session.histories = ( + self.chatbot._session.histories + + self.chatbot._session.prompt + self.chatbot._session.response) yield self.state_map.get(status), res, _ else: return '' @@ -181,7 +177,8 @@ def _update_gen_params(self, **kwargs): import mmengine new_gen_params = self.update_gen_params(**kwargs) self.gen_params['stop_words'] = new_gen_params.pop('stop_words') - stop_words = self.chatbot._stop_words(self.gen_params.get('stop_words')) + stop_words = self.chatbot._stop_words( + self.gen_params.get('stop_words')) cfg = mmengine.Config( dict( session_len=self.chatbot.model.session_len, @@ -198,8 +195,8 @@ class LMDeployPipeline(BaseModel): path (str): The path to the model. It could be one of the following options: - i) A local directory path of a turbomind model which is - converted by `lmdeploy convert` command or download from - ii) and iii). + converted by `lmdeploy convert` command or download + from ii) and iii). 
- ii) The model_id of a lmdeploy-quantized model hosted inside a model repo on huggingface.co, such as "InternLM/internlm-chat-20b-4bit", @@ -270,20 +267,19 @@ class LMDeployServer(BaseModel): server_name (str): host ip for serving server_port (int): server port tp (int): - log_level (str): set log level whose value among [CRITICAL, ERROR, WARNING, INFO, DEBUG] + log_level (str): set log level whose value among + [CRITICAL, ERROR, WARNING, INFO, DEBUG] """ - def __init__( - self, - path: str, - model_name: Optional[str] = None, - server_name: str = '0.0.0.0', - server_port: int = 23333, - tp: int = 1, - log_level: str = 'WARNING', - serve_cfg=dict(), - **kwargs - ): + def __init__(self, + path: str, + model_name: Optional[str] = None, + server_name: str = '0.0.0.0', + server_port: int = 23333, + tp: int = 1, + log_level: str = 'WARNING', + serve_cfg=dict(), + **kwargs): super().__init__(path=path, **kwargs) # TODO get_logger issue in multi processing import lmdeploy @@ -296,15 +292,14 @@ def __init__( log_level=log_level, **serve_cfg) - def generate( - self, - inputs: Union[str, List[str]], - session_id: int = 2967, - sequence_start: bool = True, - sequence_end: bool = True, - ignore_eos: bool = False, - timeout: int = 30, - **kwargs) -> List[str]: + def generate(self, + inputs: Union[str, List[str]], + session_id: int = 2967, + sequence_start: bool = True, + sequence_end: bool = True, + ignore_eos: bool = False, + timeout: int = 30, + **kwargs) -> List[str]: batched = True if isinstance(inputs, str): inputs = [inputs] @@ -314,16 +309,15 @@ def generate( resp = [''] * len(inputs) for text in self.client.completions_v1( - self.path, - inputs, - session_id=session_id, - sequence_start=sequence_start, - sequence_end=sequence_end, - stream=False, - ignore_eos=ignore_eos, - timeout=timeout, - **gen_params - ): + self.path, + inputs, + session_id=session_id, + sequence_start=sequence_start, + sequence_end=sequence_end, + stream=False, + ignore_eos=ignore_eos, + timeout=timeout, + **gen_params): resp = [ resp[i] + item['text'] for i, item in enumerate(text['choices']) @@ -345,13 +339,14 @@ def stream_chat(self, **kwargs): gen_params = self.update_gen_params(**kwargs) + prompt = self.template_parser(inputs) resp = '' finished = False stop_words = self.gen_params.get('stop_words') for text in self.client.completions_v1( self.path, - inputs, + prompt, session_id=session_id, sequence_start=sequence_start, sequence_end=sequence_end, From 39e00f25dff38b0855889914c310d8bb9ccc085c Mon Sep 17 00:00:00 2001 From: BraisedPork <46232992+braisedpork1964@users.noreply.github.com> Date: Mon, 29 Jan 2024 17:43:53 +0800 Subject: [PATCH 18/20] [Docs] Add action tutorials (#107) * add `get_started` chapter * fix docstrings * add action.md * add zh docs --------- Co-authored-by: wangzy --- docs/en/get_started/install.md | 19 ++ docs/en/get_started/overview.md | 2 +- docs/en/get_started/quickstart.md | 89 ++++++ docs/en/index.rst | 8 +- docs/en/tutorials/action.md | 396 +++++++++++++++++++++++++++ docs/zh_cn/get_started/install.md | 19 ++ docs/zh_cn/get_started/quickstart.md | 87 ++++++ docs/zh_cn/index.rst | 8 + docs/zh_cn/tutorials/action.md | 394 ++++++++++++++++++++++++++ lagent/actions/base_action.py | 6 +- 10 files changed, 1023 insertions(+), 5 deletions(-) create mode 100644 docs/en/get_started/install.md create mode 100644 docs/en/get_started/quickstart.md create mode 100644 docs/en/tutorials/action.md create mode 100644 docs/zh_cn/get_started/install.md create mode 100644 
docs/zh_cn/get_started/quickstart.md create mode 100644 docs/zh_cn/tutorials/action.md diff --git a/docs/en/get_started/install.md b/docs/en/get_started/install.md new file mode 100644 index 00000000..844bd19e --- /dev/null +++ b/docs/en/get_started/install.md @@ -0,0 +1,19 @@ +# Installation + +## With pip + +Install with pip (Recommended). + +```bash +pip install lagent +``` + +## From source + +Optionally, you could also build Lagent from source in case you want to modify the code: + +```bash +git clone https://github.com/InternLM/lagent.git +cd lagent +pip install -e . +``` diff --git a/docs/en/get_started/overview.md b/docs/en/get_started/overview.md index c22f63b8..370d32c6 100644 --- a/docs/en/get_started/overview.md +++ b/docs/en/get_started/overview.md @@ -1,4 +1,4 @@ -# OVERVIEW +# Overview This chapter introduces you to the framework of Lagent, and provides links to detailed tutorials about Lagent. diff --git a/docs/en/get_started/quickstart.md b/docs/en/get_started/quickstart.md new file mode 100644 index 00000000..e80ae492 --- /dev/null +++ b/docs/en/get_started/quickstart.md @@ -0,0 +1,89 @@ +# Quickstart + +Using Lagent, you can easily build agents with just a few lines of code. + +## Run a ReWOO agent with GPT-3.5 + +Below is an example of running ReWOO with GPT-3.5 + +```python +# Import necessary modules and classes from the "lagent" library. +from lagent.agents import ReWOO +from lagent.actions import ActionExecutor, GoogleSearch +from lagent.llms import GPTAPI + +# Initialize the Language Model (llm) and provide your API key. +llm = GPTAPI(model_type='gpt-3.5-turbo', key=['Your OPENAI_API_KEY']) + +# Initialize the Google Search tool and provide your API key. +search_tool = GoogleSearch(api_key='Your SERPER_API_KEY') + +# Create a chatbot by configuring the ReWOO agent. +chatbot = ReWOO( + llm=llm, # Provide the Language Model instance. + action_executor=ActionExecutor( + actions=[search_tool] # Specify the actions the chatbot can perform. + ), +) + +# Ask the chatbot a question and store the response. +response = chatbot.chat('What profession does Nicholas Ray and Elia Kazan have in common') + +# Print the chatbot's response. +print(response.response) # Output the response generated by the chatbot. +``` + +```python +>>> Film director. +``` + +## Run a ReAct agent with InternLM + +NOTE: If you want to run a HuggingFace model, please run `pip install -e .[all]` first. + +```python +# Import necessary modules and classes from the "lagent" library. +from lagent.agents import ReAct +from lagent.actions import ActionExecutor, GoogleSearch, PythonInterpreter +from lagent.llms import HFTransformer + +from lagent.llms.meta_template import INTERNLM2_META as META + +# Initialize the HFTransformer-based Language Model (llm) and +# provide the model name. +llm = HFTransformer(path='internlm/internlm2-chat-7b', meta_template=META) + +# Initialize the Google Search tool and provide your API key. +search_tool = GoogleSearch(api_key='Your SERPER_API_KEY') + +# Initialize the Python Interpreter tool. +python_interpreter = PythonInterpreter() + +# Create a chatbot by configuring the ReAct agent. +# Specify the actions the chatbot can perform. +chatbot = ReAct( + llm=llm, # Provide the Language Model instance. + action_executor=ActionExecutor( + actions=[search_tool, python_interpreter]), +) +# Ask the chatbot a mathematical question in LaTeX format. +response = chatbot.chat('若$z=-1+\sqrt{3}i$,则$\frac{z}{{z\overline{z}-1}}=\left(\ \ \right)$') + +# Print the chatbot's response. 
+print(response.response)  # Output the response generated by the chatbot.
+```
+
+```python
+>>> $-\\frac{1}{3}+\\frac{{\\sqrt{3}}}{3}i$
+```
+
+## Run ReAct Web Demo
+
+```python
+# You need to install streamlit first
+# pip install streamlit
+streamlit run examples/react_web_demo.py
+```
+
+Then you can chat through the UI shown below
+![image](https://github.com/InternLM/lagent/assets/24622904/3aebb8b4-07d1-42a2-9da3-46080c556f68)
diff --git a/docs/en/index.rst b/docs/en/index.rst
index ab272ae7..f74c594d 100644
--- a/docs/en/index.rst
+++ b/docs/en/index.rst
@@ -8,8 +8,14 @@ You can switch between English and Chinese in the lower-left corner of the layou
    :caption: Get Started

    get_started/overview.md
-   get_started/action.md
+   get_started/install.md
+   get_started/quickstart.md

+.. toctree::
+   :maxdepth: 2
+   :caption: Tutorials
+
+   tutorials/action.md

.. toctree::
   :caption: Switch Language
diff --git a/docs/en/tutorials/action.md b/docs/en/tutorials/action.md
new file mode 100644
index 00000000..36c07358
--- /dev/null
+++ b/docs/en/tutorials/action.md
@@ -0,0 +1,396 @@
+# Action
+
+Actions, also called **tools**, provide a suite of functions LLM-driven agents can use to interact with the real world and perform complex tasks.
+
+## Basic Concepts
+
+### Tool & Toolkit
+
+There are two categories of tools:
+
+* tool: provides only one API to call.
+* toolkit: implements multiple APIs that undertake different sub-tasks.
+
+### Tool Description
+
+In Lagent, the tool description is a dictionary containing the action's core usage information, observed by LLMs for decision-making.
+
+For simple tools, the description can be created as follows:
+
+```python
+TOOL_DESCRIPTION = {
+    'name': 'bold',  # name of the tool
+    'description': 'a function used to make text bold',  # introduce the tool's function
+    'parameters': [  # a list of parameters the tool takes
+        {
+            'name': 'text', 'type': 'STRING', 'description': 'input content'
+        }
+    ],
+    'required': ['text'],  # specify names of parameters required
+}
+```
+
+In some situations there may be optional `return_data`, `parameter_description` keys describing the returns and argument passing format respectively.
+
+```{attention}
+`parameter_description` is usually inserted into the tool description automatically by the action's parser. It will be introduced in [Interface Design](#interface-design).
+```
+
+
+For toolkits, the description is very similar but nests submethods:
+
+```python
+TOOL_DESCRIPTION = {
+    'name': 'PhraseEmphasis',  # name of the toolkit
+    'description': 'a toolkit which provides different styles of text emphasis',  # introduce the tool's function
+    'api_list': [
+        {
+            'name': 'bold',
+            'description': 'make text bold',
+            'parameters': [
+                {
+                    'name': 'text', 'type': 'STRING', 'description': 'input content'
+                }
+            ],
+            'required': ['text']
+        },
+        {
+            'name': 'italic',
+            'description': 'make text italic',
+            'parameters': [
+                {
+                    'name': 'text', 'type': 'STRING', 'description': 'input content'
+                }
+            ],
+            'required': ['text']
+        }
+    ]
+}
+```
+
+## Make Functions Tools
+
+It's not necessary to prepare an extra description for a defined function. In Lagent we provide a decorator `tool_api` which can conveniently turn a function into a tool by automatically parsing the function's typehints and docstrings to generate the description dictionary and bind it to an attribute `api_description`.
+
+```python
+from lagent import tool_api
+
+@tool_api
+def bold(text: str) -> str:
+    """make text bold
+
+    Args:
+        text (str): input text
+
+    Returns:
+        str: bold text
+    """
+    return '**' + text + '**'
+
+
+bold.api_description
+```
+
+```python
+{'name': 'bold',
+ 'description': 'make text bold',
+ 'parameters': [{'name': 'text',
+   'type': 'STRING',
+   'description': 'input text'}],
+ 'required': ['text']}
+```
+
+Once `returns_named_value` is enabled you should declare the name of the return data, which will be processed to form a new field `return_data`:
+
+```python
+@tool_api(returns_named_value=True)
+def bold(text: str) -> str:
+    """make text bold
+
+    Args:
+        text (str): input text
+
+    Returns:
+        bold_text (str): bold text
+    """
+    return '**' + text + '**'
+
+bold.api_description
+```
+
+```python
+{'name': 'bold',
+ 'description': 'make text bold',
+ 'parameters': [{'name': 'text',
+   'type': 'STRING',
+   'description': 'input text'}],
+ 'required': ['text'],
+ 'return_data': [{'name': 'bold_text',
+   'description': 'bold text',
+   'type': 'STRING'}]}
+```
+
+Sometimes the tool may return a `dict` or `tuple`, and you want to elaborate each member in `return_data` rather than take them as a whole. Set `explode_return=True` and list them in the return part of docstrings.
+
+```python
+@tool_api(explode_return=True)
+def list_args(a: str, b: int, c: float = 0.0) -> dict:
+    """Return arguments in dict format
+
+    Args:
+        a (str): a
+        b (int): b
+        c (float): c
+
+    Returns:
+        dict: input arguments
+            - a (str): a
+            - b (int): b
+            - c: c
+    """
+    return {'a': a, 'b': b, 'c': c}
+```
+
+```python
+{'name': 'list_args',
+ 'description': 'Return arguments in dict format',
+ 'parameters': [{'name': 'a', 'type': 'STRING', 'description': 'a'},
+  {'name': 'b', 'type': 'NUMBER', 'description': 'b'},
+  {'name': 'c', 'type': 'FLOAT', 'description': 'c'}],
+ 'required': ['a', 'b'],
+ 'return_data': [{'name': 'a', 'description': 'a', 'type': 'STRING'},
+  {'name': 'b', 'description': 'b', 'type': 'NUMBER'},
+  {'name': 'c', 'description': 'c'}]}
+```
+
+```{warning}
+Only Google style Python docstrings are currently supported.
+```
+
+## Interface Design
+
+`BaseAction(description=None, parser=JsonParser, enable=True)` is the base class all actions should inherit from. It takes three initialization arguments:
+
+* **description**: a tool description dictionary, used to set the instance attribute `description`. Mostly you don't need to explicitly pass this argument since the meta class of `BaseAction` will search methods decorated by `tool_api` and assemble their `api_description` as a class attribute `__tool_description__`, and if the initial `description` is left null, then `__tool_description__` will be copied as `description`.
+* **parser**: `BaseParser` class. It will instantiate a parser used to validate the arguments of APIs in `description`.
+
+  For example, `JsonParser` requires arguments passed in the format of JSON or `dict`. To make LLMs aware of this, it inserts a field `parameter_description` into the `description`.
+
+  ```python
+  from lagent import BaseAction
+
+  action = BaseAction(
+      {
+          'name': 'bold',
+          'description': 'a function used to make text bold',
+          'parameters': [
+              {
+                  'name': 'text', 'type': 'STRING', 'description': 'input content'
+              }
+          ],
+          'required': ['text']
+      }
+  )
+  action.description
+  ```
+
+  ```python
+  {'name': 'bold',
+   'description': 'a function used to make text bold',
+   'parameters': [{'name': 'text',
+     'type': 'STRING',
+     'description': 'input content'}],
+   'required': ['text'],
+   'parameter_description': '如果调用该工具,你必须使用Json格式 {key: value} 传参,其中key为参数名称'}
+  ```
+* **enable**: specify whether the tool is available.
+
+### Custom Action
+
+A simple tool must have its `run` method implemented, while APIs of toolkits should avoid naming conflicts with this reserved word.
+
+```python
+class Bold(BaseAction):
+
+    @tool_api
+    def run(self, text: str):
+        """make text bold
+
+        Args:
+            text (str): input text
+
+        Returns:
+            str: bold text
+        """
+        return '**' + text + '**'
+
+class PhraseEmphasis(BaseAction):
+    """a toolkit which provides different styles of text emphasis"""
+
+    @tool_api
+    def bold(self, text):
+        """make text bold
+
+        Args:
+            text (str): input text
+
+        Returns:
+            str: bold text
+        """
+        return '**' + text + '**'
+
+    @tool_api
+    def italic(self, text):
+        """make text italic
+
+        Args:
+            text (str): input text
+
+        Returns:
+            str: italic text
+        """
+        return '*' + text + '*'
+
+# Inspect the default description
+# Bold.__tool_description__, PhraseEmphasis.__tool_description__
+```
+
+### Auto-registration
+
+Any subclass of `BaseAction` will be registered automatically. You can use `list_tools()` and `get_tool()` to view all tools and initialize by name.
+
+```python
+from lagent import list_tools, get_tool
+
+list_tools()
+```
+
+```python
+['BaseAction',
+ 'InvalidAction',
+ 'NoAction',
+ 'FinishAction',
+ 'ArxivSearch',
+ 'BINGMap',
+ 'GoogleScholar',
+ 'GoogleSearch',
+ 'IPythonInterpreter',
+ 'PPT',
+ 'PythonInterpreter',
+ 'Bold',
+ 'PhraseEmphasis']
+```
+
+Create a `PhraseEmphasis` object:
+
+```python
+action = get_tool('PhraseEmphasis')
+action.description
+```
+
+```python
+{'name': 'PhraseEmphasis',
+ 'description': 'a toolkit which provides different styles of text emphasis',
+ 'api_list': [{'name': 'bold',
+   'description': 'make text bold',
+   'parameters': [{'name': 'text',
+     'type': 'STRING',
+     'description': 'input text'}],
+   'required': ['text'],
+   'parameter_description': '如果调用该工具,你必须使用Json格式 {key: value} 传参,其中key为参数名称'},
+  {'name': 'italic',
+   'description': 'make text italic',
+   'parameters': [{'name': 'text',
+     'type': 'STRING',
+     'description': 'input text'}],
+   'required': ['text'],
+   'parameter_description': '如果调用该工具,你必须使用Json格式 {key: value} 传参,其中key为参数名称'}]}
+```
+
+
+## Tool Calling
+
+### Run a Tool
+
+The `__call__` method of `Action` takes two arguments:
+
+* `inputs`: It depends on the action's parser. Often a string in specific formats generated by LLMs.
+  + `JsonParser`: Allow passing arguments in the format of a JSON string or Python `dict`.
+  + `TupleParser`: Allow passing arguments in the format of a tuple-literal string or Python `tuple`.
+* `name`: Which API to call. Default is `run`.
+
+It returns an `ActionReturn` object which encapsulates calling details:
+
+* `args`: Dictionary of action inputs.
+* `type`: Action name.
+* `result`: List of dicts. Each contains two keys: 'type' and 'content'. When errors occur, it is `None`.
+* `errmsg`: Error message. Default is `None`.
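+
+A minimal sketch of inspecting these fields after a call (the exact `result`
+content depends on the tool and inputs, so treat the values below as
+illustrative):
+
+```python
+from lagent import IPythonInterpreter
+
+action = IPythonInterpreter()
+ret = action('{"command": "1 + 1"}')  # `name` defaults to 'run'
+print(ret.type)    # 'IPythonInterpreter'
+print(ret.args)    # the recorded inputs
+print(ret.result)  # e.g. [{'type': 'text', 'content': '2'}] on success
+print(ret.errmsg)  # None unless the call failed
+```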
+ +Below is an example + +```python +from lagent import IPythonInterpreter, TupleParser + +action1 = IPythonInterpreter() +ret = action1('{"command": "import math;math.sqrt(100)"}') +print(ret.result) +ret = action1({'command': 'import math;math.sqrt(100)'}) +print(ret.result) + +action2 = IPythonInterpreter(parser=TupleParser) +ret = action2('("import math;math.sqrt(100)", )') +print(ret.result) +ret = action2(('import math;math.sqrt(100)',)) +print(ret.result) +``` + +```python +[{'type': 'text', 'content': '10.0'}] +[{'type': 'text', 'content': '10.0'}] +[{'type': 'text', 'content': '10.0'}] +[{'type': 'text', 'content': '10.0'}] +``` + +### Dynamic Invocation + +Lagent provides an `ActionExecutor` to manage multiple tools. It will flatten `api_list` of toolkits and rename each `{tool_name}.{api_name}`. + +```python +from lagent import ActionExecutor, ArxivSearch, IPythonInterpreter + +executor = ActionExecutor(actions=[ArxivSearch(), IPythonInterpreter()]) +executor.get_actions_info() # This information is fed to LLMs as the tool meta prompt +``` + +```python +[{'name': 'ArxivSearch.get_arxiv_article_information', + 'description': 'Run Arxiv search and get the article meta information.', + 'parameters': [{'name': 'query', + 'type': 'STRING', + 'description': 'the content of search query'}], + 'required': ['query'], + 'return_data': [{'name': 'content', + 'description': 'a list of 3 arxiv search papers', + 'type': 'STRING'}], + 'parameter_description': '如果调用该工具,你必须使用Json格式 {key: value} 传参,其中key为参数名称'}, + {'name': 'IPythonInterpreter', + 'description': "When you send a message containing Python code to python, it will be executed in a stateful Jupyter notebook environment. python will respond with the output of the execution or time out after 60.0 seconds. The drive at '/mnt/data' can be used to save and persist user files. Internet access for this session is disabled. Do not make external web requests or API calls as they will fail.", + 'parameters': [{'name': 'command', + 'type': 'STRING', + 'description': 'Python code'}, + {'name': 'timeout', + 'type': 'NUMBER', + 'description': 'Upper bound of waiting time for Python script execution.'}], + 'required': ['command'], + 'parameter_description': '如果调用该工具,你必须使用Json格式 {key: value} 传参,其中key为参数名称'}] +``` + +Trigger an action through the executor + +```python +ret = executor('IPythonInterpreter', '{"command": "import math;math.sqrt(100)"}') +ret.result +``` + +```python +[{'type': 'text', 'content': '10.0'}] +``` diff --git a/docs/zh_cn/get_started/install.md b/docs/zh_cn/get_started/install.md new file mode 100644 index 00000000..2e844f90 --- /dev/null +++ b/docs/zh_cn/get_started/install.md @@ -0,0 +1,19 @@ +# 安装方式 + +## pip安装 + +推荐使用 pip 安装 + +```bash +pip install lagent +``` + +## 源码安装 + +如需修改部分功能,可以从源码构建 Lagent + +```bash +git clone https://github.com/InternLM/lagent.git +cd lagent +pip install -e . 
+```
diff --git a/docs/zh_cn/get_started/quickstart.md b/docs/zh_cn/get_started/quickstart.md
new file mode 100644
index 00000000..ae51f57b
--- /dev/null
+++ b/docs/zh_cn/get_started/quickstart.md
@@ -0,0 +1,87 @@
+# 快速上手
+
+借助 Lagent 仅需几行代码就能构建大语言模型智能体。
+
+## GPT-3.5 驱动的 ReWOO 智能体
+
+下面是使用 GPT-3.5 运行 ReWOO 的示例
+
+```python
+# 从 Lagent 导入必要的模块和类
+from lagent.agents import ReWOO
+from lagent.actions import ActionExecutor, GoogleSearch
+from lagent.llms import GPTAPI
+
+# 初始化 LLM,你可能需要提供 API 密钥
+llm = GPTAPI(model_type='gpt-3.5-turbo', key=['Your OPENAI_API_KEY'])
+
+# 初始化 Google 搜索工具,你可能需要提供 API 密钥
+search_tool = GoogleSearch(api_key='Your SERPER_API_KEY')
+
+# 配置 ReWOO 智能体,创建聊天机器人
+chatbot = ReWOO(
+    llm=llm,  # 大语言模型实例
+    action_executor=ActionExecutor(
+        actions=[search_tool]  # 指定智能体可以调用的工具
+    ),
+)
+
+# 询问问题并获取回复
+response = chatbot.chat('What profession does Nicholas Ray and Elia Kazan have in common')
+
+# 打印回复
+print(response.response)
+```
+
+```python
+>>> Film director.
+```
+
+## InternLM 驱动的 ReAct 智能体
+
+注意,如果你想使用 HuggingFace 模型,请先运行 `pip install -e .[all]`
+
+```python
+# 从 Lagent 导入必要的模块和类
+from lagent.agents import ReAct
+from lagent.actions import ActionExecutor, GoogleSearch, PythonInterpreter
+from lagent.llms import HFTransformer
+
+from lagent.llms.meta_template import INTERNLM2_META as META
+
+# 初始化 HFTransformer 模型
+llm = HFTransformer(path='internlm/internlm2-chat-7b', meta_template=META)
+
+# 初始化 Google 搜索工具,你可能需要提供 API 密钥
+search_tool = GoogleSearch(api_key='Your SERPER_API_KEY')
+
+# 初始化 Python 代码解释器
+python_interpreter = PythonInterpreter()
+
+# 配置 ReAct 智能体,创建聊天机器人
+chatbot = ReAct(
+    llm=llm,  # 大语言模型实例
+    action_executor=ActionExecutor(
+        actions=[search_tool, python_interpreter]),  # 指定智能体可以调用的工具
+)
+# 询问LaTeX格式的数学问题
+response = chatbot.chat('若$z=-1+\sqrt{3}i$,则$\frac{z}{{z\overline{z}-1}}=\left(\ \ \right)$')
+
+# 打印回复
+print(response.response)
+```
+
+```python
+>>> $-\\frac{1}{3}+\\frac{{\\sqrt{3}}}{3}i$
+```
+
+## 启动 ReAct 网页 App
+
+```python
+# 你需要先安装 streamlit
+# pip install streamlit
+streamlit run examples/react_web_demo.py
+```
+
+然后你可以通过下图所示UI界面进行对话
+![image](https://github.com/InternLM/lagent/assets/24622904/3aebb8b4-07d1-42a2-9da3-46080c556f68)
diff --git a/docs/zh_cn/index.rst b/docs/zh_cn/index.rst
index ac7141dc..3089e209 100644
--- a/docs/zh_cn/index.rst
+++ b/docs/zh_cn/index.rst
@@ -8,6 +8,14 @@
   :caption: 新手入门

   get_started/overview.md
+  get_started/install.md
+  get_started/quickstart.md
+
+.. toctree::
+  :maxdepth: 2
+  :caption: 教程
+
+  tutorials/action.md

..
toctree:: :caption: 切换语言 diff --git a/docs/zh_cn/tutorials/action.md b/docs/zh_cn/tutorials/action.md new file mode 100644 index 00000000..d816648c --- /dev/null +++ b/docs/zh_cn/tutorials/action.md @@ -0,0 +1,394 @@ +# 动作 + +动作,也被称为工具,提供了一套LLM驱动的智能体用来与真实世界交互并执行复杂任务的函数。 + +## 基本概念 + +### 工具 & 工具包 + +有两种类型的工具: + +* 简单工具: 只提供一个API接口供调用。 +* 工具包: 实现多个API接口,承担不同的子任务。 + +### 工具描述 + +在Lagent中,工具描述是一个刻画工具调用方式的字典,能够被LLM观察并用于决策。 + +对于简单工具,描述可按如下格式声明: + +```python +TOOL_DESCRIPTION = { + 'name': 'bold', # 工具名称 + 'description': 'a function used to make text bold', # 介绍工具的功能 + 'parameters': [ # 这个工具所需要的参数列表 + { + 'name': 'text', 'type': 'STRING', 'description': 'input content' + } + ], + 'required': ['text'], # 指定必需的参数名 +} +``` +在某些情况下,可能还包含 `return_data`,`parameter_description` 字段,分别描述返回内容及参数传递格式。 + +```{attention} +`parameter_description` 通常被动作的解析器自动插入到工具描述中,这部分将在[接口设计](#id6)中进行介绍。 +``` + +对于工具包,描述非常相似,但嵌套了子方法 + +```python +TOOL_DESCRIPTION = { + 'name': 'PhraseEmphasis', # 工具包的名字 + 'description': 'a toolkit which provides different styles of text emphasis', # 介绍工具包的功能 + 'api_list': [ + { + 'name': 'bold', + 'description': 'make text bold', + 'parameters': [ + { + 'name': 'text', 'type': 'STRING', 'description': 'input content' + } + ], + 'required': ['text'] + }, + { + 'name': 'italic', + 'description': 'make text italic', + 'parameters': [ + { + 'name': 'text', 'type': 'STRING', 'description': 'input content' + } + ], + 'required': ['text'] + } + ] +} +``` + +## 将函数转换为工具 + +对于已定义好的函数,无需人工添加额外的描述。在 Lagent 中,我们提供了一个修饰器 `tool_api`,它可以通过自动解析函数的类型提示和文档字符串来生成描述字典,并将其绑定到属性 `api_description`。 + +```python +from lagent import tool_api + +@tool_api +def bold(text: str) -> str: + """make text bold + + Args: + text (str): input text + + Returns: + str: bold text + """ + return '**' + text + '**' + + +bold.api_description +``` + +```python +{'name': 'bold', + 'description': 'make text bold', + 'parameters': [{'name': 'text', + 'type': 'STRING', + 'description': 'input text'}], + 'required': ['text']} +``` + +一旦启用 `returns_named_value`,您应当声明返回值的名称,这将被处理成一个新的字段 `return_data`: + +```python +@tool_api(returns_named_value=True) +def bold(text: str) -> str: + """make text bold + + Args: + text (str): input text + + Returns: + bold_text (str): bold text + """ + return '**' + text + '**' + +bold.api_description +``` + +```python +{'name': 'bold', + 'description': 'make text bold', + 'parameters': [{'name': 'text', + 'type': 'STRING', + 'description': 'input text'}], + 'required': ['text'], + 'return_data': [{'name': 'bold_text', + 'description': 'bold text', + 'type': 'STRING'}]} +``` + +有时工具可能返回一个 `dict` 或 `tuple`,如果你想在 `return_data` 中详细说明每个成员的含义而不是把它们当作一个整体,设置 `explode_return=True` 并在文档字符串的 Returns 部分中罗列它们。 + +```python +@tool_api(explode_return=True) +def list_args(a: str, b: int, c: float = 0.0) -> dict: + """Return arguments in dict format + + Args: + a (str): a + b (int): b + c (float): c + + Returns: + dict: input arguments + - a (str): a + - b (int): b + - c: c + """ + return {'a': a, 'b': b, 'c': c} +``` + +```python +{'name': 'list_args', + 'description': 'Return arguments in dict format', + 'parameters': [{'name': 'a', 'type': 'STRING', 'description': 'a'}, + {'name': 'b', 'type': 'NUMBER', 'description': 'b'}, + {'name': 'c', 'type': 'FLOAT', 'description': 'c'}], + 'required': ['a', 'b'], + 'return_data': [{'name': 'a', 'description': 'a', 'type': 'STRING'}, + {'name': 'b', 'description': 'b', 'type': 'NUMBER'}, + {'name': 'c', 'description': 'c'}]} +``` + +```{warning} +目前仅支持 Google 格式的 Python 文档字符串。 +``` 
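+
+下面是一个简短的调用示意(沿用上文的 `bold` 函数):被 `tool_api` 装饰后,函数本身仍可直接调用,工具描述只是附加在 `api_description` 属性上。
+
+```python
+print(bold('lagent'))                # -> **lagent**
+print(bold.api_description['name'])  # -> bold
+```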
+ +## 接口设计 + +`BaseAction(description=None, parser=JsonParser, enable=True)` 是所有动作应该继承的基类,它接收三个初始化参数: + +* **description**:一个工具描述的字典,用于设置实例属性 `description`。通常不需要显式地传递这个参数,因为 `BaseAction` 的元类将查找被 `tool_api` 装饰的方法,并组装它们的 `api_description` 构造一个类属性 `__tool_description__`,如果实例化时 `description` 为空,那么该实例属性将置为 `__tool_description__`。 +* **parser**:`BaseParser` 类,用于实例化一个动作解析器校验 `description` 所描述的工具的参数。例如,`JsonParser` 会要求模型在调用工具时传入一个 JSON 格式字符串或者 Python 字典,为了让 LLM 感知到该指令,它会在 `description` 中插入一个 `parameter_description` 字段。 + + ```python + from lagent import BaseAction + + action = BaseAction( + { + 'name': 'bold', + 'description': 'a function used to make text bold', + 'parameters': [ + { + 'name': 'text', 'type': 'STRING', 'description': 'input content' + } + ], + 'required': ['text'] + } + ) + action.description + ``` + + ```python + {'name': 'bold', + 'description': 'a function used to make text bold', + 'parameters': [{'name': 'text', + 'type': 'STRING', + 'description': 'input content'}], + 'required': ['text'], + 'parameter_description': '如果调用该工具,你必须使用Json格式 {key: value} 传参,其中key为参数名称'} + ``` + +* **enable**: 指明该动作是否生效。 + +### 自定义动作 + +一个简单工具必须实现 `run` 方法,而工具包则应当避免将各子API名称定义为该保留字段。 + +```python +class Bold(BaseAction): + + @tool_api + def run(self, text: str): + """make text bold + + Args: + text (str): input text + + Returns: + str: bold text + """ + return '**' + text + '**' + +class PhraseEmphasis(BaseAction): + """a toolkit which provides different styles of text emphasis""" + + @tool_api + def bold(self, text): + """make text bold + + Args: + text (str): input text + + Returns: + str: bold text + """ + return '**' + text + '**' + + @tool_api + def italic(self, text): + """make text italic + + Args: + text (str): input text + + Returns: + str: italic text + """ + return '*' + text + '*' + +# 查看默认工具描述 +# Bold.__tool_description__, PhraseEmphasis.__tool_description__ +``` + +### 自动注册 + +任何 `BaseAction` 的子类都会自动被注册。你可以使用 `list_tools()` 和 `get_tool()` 来查看所有工具类并通过工具名进行初始化。 + +```python +from lagent import list_tools, get_tool + +list_tools() +``` + +```python +['BaseAction', + 'InvalidAction', + 'NoAction', + 'FinishAction', + 'ArxivSearch', + 'BINGMap', + 'GoogleScholar', + 'GoogleSearch', + 'IPythonInterpreter', + 'PPT', + 'PythonInterpreter', + 'Bold', + 'PhraseEmphasis'] +``` + +创建一个 `PhraseEmphasis` 对象。 + +```python +action = get_tool('PhraseEmphasis') +action.description +``` + +```python +{'name': 'PhraseEmphasis', + 'description': 'a toolkit which provides different styles of text emphasis', + 'api_list': [{'name': 'bold', + 'description': 'make text bold', + 'parameters': [{'name': 'text', + 'type': 'STRING', + 'description': 'input text'}], + 'required': ['text'], + 'parameter_description': '如果调用该工具,你必须使用Json格式 {key: value} 传参,其中key为参数名称'}, + {'name': 'italic', + 'description': 'make text italic', + 'parameters': [{'name': 'text', + 'type': 'STRING', + 'description': 'input text'}], + 'required': ['text'], + 'parameter_description': '如果调用该工具,你必须使用Json格式 {key: value} 传参,其中key为参数名称'}]} +``` + + +## 工具调用 + +### 执行工具 + +`Action` 的 `__call__` 方法需要传入两个参数 + +* `inputs`: 其类型与动作绑定的 `BaseParser` 相关,通常是由大语言模型生成的字符串。 + + `JsonParser`: 允许传入 JSON 格式字符串或 Python 字典。 + + `TupleParser`: 允许传入字面量为元组的字符串或 Python 元组。 +* `name`: 调用哪个 API,默认为 `run`。 + +工具会返回一个封装了调用细节的 `ActionReturn` 对象。 + +* `args`: 一个字典,表示该动作的入参。 +* `type`: 动作名称。 +* `result`: 以字典为成员的列表,每个字典包含两个键——'type' 和 'content',发生异常时该字段为 `None`。 +* `errmsg`: 错误信息,默认为 `None`。 + +以下是一个例子: + +```python +from lagent import IPythonInterpreter, TupleParser + 
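+
+# JsonParser 为默认解析器,因此 action1 无需显式指定;TupleParser 则需要手动传入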
action1 = IPythonInterpreter()
ret = action1('{"command": "import math;math.sqrt(100)"}')
print(ret.result)
ret = action1({'command': 'import math;math.sqrt(100)'})
print(ret.result)

action2 = IPythonInterpreter(parser=TupleParser)
ret = action2('("import math;math.sqrt(100)", )')
print(ret.result)
ret = action2(('import math;math.sqrt(100)',))
print(ret.result)
```

```python
[{'type': 'text', 'content': '10.0'}]
[{'type': 'text', 'content': '10.0'}]
[{'type': 'text', 'content': '10.0'}]
[{'type': 'text', 'content': '10.0'}]
```

### Dynamic Invocation

Lagent provides the `ActionExecutor` interface to manage multiple tools. It flattens the `api_list` of each toolkit and renames every API to `{tool_name}.{api_name}`.

```python
from lagent import ActionExecutor, ArxivSearch, IPythonInterpreter

executor = ActionExecutor(actions=[ArxivSearch(), IPythonInterpreter()])
executor.get_actions_info()  # this result becomes part of the LLM system prompt
```

```python
[{'name': 'ArxivSearch.get_arxiv_article_information',
  'description': 'Run Arxiv search and get the article meta information.',
  'parameters': [{'name': 'query',
    'type': 'STRING',
    'description': 'the content of search query'}],
  'required': ['query'],
  'return_data': [{'name': 'content',
    'description': 'a list of 3 arxiv search papers',
    'type': 'STRING'}],
  'parameter_description': '如果调用该工具,你必须使用Json格式 {key: value} 传参,其中key为参数名称'},
 {'name': 'IPythonInterpreter',
  'description': "When you send a message containing Python code to python, it will be executed in a stateful Jupyter notebook environment. python will respond with the output of the execution or time out after 60.0 seconds. The drive at '/mnt/data' can be used to save and persist user files. Internet access for this session is disabled. Do not make external web requests or API calls as they will fail.",
  'parameters': [{'name': 'command',
    'type': 'STRING',
    'description': 'Python code'},
   {'name': 'timeout',
    'type': 'NUMBER',
    'description': 'Upper bound of waiting time for Python script execution.'}],
  'required': ['command'],
  'parameter_description': '如果调用该工具,你必须使用Json格式 {key: value} 传参,其中key为参数名称'}]
```

Trigger a tool through the action executor:

```python
ret = executor('IPythonInterpreter', '{"command": "import math;math.sqrt(100)"}')
ret.result
```

```python
[{'type': 'text', 'content': '10.0'}]
```

diff --git a/lagent/actions/base_action.py b/lagent/actions/base_action.py
index 9c7e0dca..c7a93ccf 100644
--- a/lagent/actions/base_action.py
+++ b/lagent/actions/base_action.py
@@ -77,12 +77,12 @@ def foo(a, b):
         ``(int): Description``. When false, parentheses are optional but
         the items cannot be named: ``int: Description``. Defaults to
         ``False``.

+    Returns:
+        Callable: wrapped function or partial decorator
+
     Important:
         ``return_data`` field will be added to ``api_description`` only
         when ``explode_return`` or ``returns_named_value`` is enabled.
-
-    Returns:
-        Callable: wrapped function or partial decorator
     """

     def _detect_type(string):

From b5533b036dd90135d70de1b3e8c22d3575db579d Mon Sep 17 00:00:00 2001
From: BraisedPork <46232992+braisedpork1964@users.noreply.github.com>
Date: Mon, 29 Jan 2024 20:16:35 +0800
Subject: [PATCH 19/20] Fix returns of OpenAI interface (#108)

fix `BaseAPIModel` chat returns

Co-authored-by: wangzy <wangziyi@pjlab.org.cn>
---
 lagent/llms/openai.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/lagent/llms/openai.py b/lagent/llms/openai.py
index 344c79b3..90bd6d8f 100644
--- a/lagent/llms/openai.py
+++ b/lagent/llms/openai.py
@@ -99,7 +99,7 @@ def chat(
         self,
         inputs: Union[List[dict], List[List[dict]]],
         **gen_params,
-    ) -> List[str]:
+    ) -> Union[str, List[str]]:
         """Generate responses given the contexts.

         Args:
@@ -108,7 +108,7 @@
             gen_params: additional generation configuration

         Returns:
-            List[str]: A list of generated strings.
+            Union[str, List[str]]: generated string(s)
         """
         assert isinstance(inputs, list)
         if isinstance(inputs[0], dict):
@@ -120,7 +120,8 @@
             for messages in inputs
         ]
         wait(tasks)
-        return [task.result() for task in tasks]
+        ret = [task.result() for task in tasks]
+        return ret[0] if isinstance(inputs[0], dict) else ret

     def _chat(self, messages: List[dict], **gen_params) -> str:
         """Generate completion from a list of templates.

From 1eb10a34678f45e72c085bebd7db88de5d109b2f Mon Sep 17 00:00:00 2001
From: liujiangning30 <147385819+liujiangning30@users.noreply.github.com>
Date: Tue, 30 Jan 2024 11:25:37 +0800
Subject: [PATCH 20/20] Feat: add warn for func 'generate_from_template' (#109)

* add warn for func 'generate_from_template'
* clearer alerts for deprecation
* clearer alerts for deprecation

---
 lagent/llms/base_llm.py | 12 ++++++++++++
 1 file changed, 12 insertions(+)

diff --git a/lagent/llms/base_llm.py b/lagent/llms/base_llm.py
index e350027f..0dd8cb8a 100644
--- a/lagent/llms/base_llm.py
+++ b/lagent/llms/base_llm.py
@@ -1,6 +1,7 @@
 from abc import abstractclassmethod
 from copy import copy
 from typing import Dict, List, Optional, Tuple, Union
+from warnings import warn


 class LMTemplateParser:
@@ -188,6 +189,17 @@ def chat(self, inputs: Union[List[dict], List[List[dict]]], **gen_params):
         inputs = self.template_parser(inputs)
         return self.generate(inputs, **gen_params)

+    def generate_from_template(
+        self,
+        inputs: Union[List[dict], List[List[dict]]],
+        **gen_params
+    ):
+        warn(
+            "This function will be deprecated after three months and will be replaced. "
+            "Please use `.chat()`",
+            DeprecationWarning, 2)
+        return self.chat(inputs, **gen_params)
+
     def stream_chat(self, inputs: List[dict], **gen_params):
         """Generate results as streaming given a list of templates.
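Taken together, the last patch keeps existing callers working while nudging them toward the new API. A minimal sketch of the caller-side effect, assuming a hypothetical `DummyLLM` stub (not part of the library) and that `LMTemplateParser` turns the message list into a prompt string:

```python
import warnings

from lagent.llms.base_llm import BaseModel


class DummyLLM(BaseModel):
    """Hypothetical stub used only to exercise the deprecation shim."""

    def generate(self, inputs, **gen_params):
        # Echo the parsed prompt instead of querying a real model.
        return inputs


llm = DummyLLM(path='dummy')
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter('always')
    out = llm.generate_from_template([{'role': 'user', 'content': 'hi'}])

# The old entry point now forwards to `chat`, emitting a DeprecationWarning.
assert any(issubclass(w.category, DeprecationWarning) for w in caught)
```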