From c4523f02b74a02f61abc59179d7f9627f60871b3 Mon Sep 17 00:00:00 2001 From: liukuikun <24622904+Harold-lkk@users.noreply.github.com> Date: Fri, 20 Dec 2024 18:13:29 +0800 Subject: [PATCH] [Model] support Claude (#283) * [Model] support Claude * fix openai --- .pre-commit-config.yaml | 14 +- lagent/actions/__init__.py | 40 +++- lagent/hooks/logger.py | 15 +- lagent/llms/__init__.py | 14 +- lagent/llms/anthropic_llm.py | 392 +++++++++++++++++++++++++++++++++++ lagent/llms/openai.py | 294 ++++++++++---------------- lagent/utils/util.py | 20 +- requirements/runtime.txt | 3 +- 8 files changed, 562 insertions(+), 230 deletions(-) create mode 100644 lagent/llms/anthropic_llm.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 5066bc49..1f51fcc6 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,20 +1,22 @@ exclude: ^(tests/data|scripts|ftdp/protocols|ftdp/template_configs|ftdp/tool_dicts)/ repos: - repo: https://github.com/PyCQA/flake8 - rev: 7.0.0 + rev: 7.1.1 hooks: - id: flake8 - repo: https://github.com/PyCQA/isort rev: 5.13.2 hooks: - id: isort + args: ["--profile", "black", "--filter-files", "--line-width", "119"] - repo: https://github.com/psf/black - rev: 22.8.0 + rev: 24.10.0 hooks: - id: black args: ["--line-length", "119", "--skip-string-normalization"] + - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.5.0 + rev: v5.0.0 hooks: - id: trailing-whitespace - id: check-yaml @@ -27,7 +29,7 @@ repos: - id: mixed-line-ending args: ["--fix=lf"] - repo: https://github.com/executablebooks/mdformat - rev: 0.7.17 + rev: 0.7.21 hooks: - id: mdformat args: ["--number"] @@ -36,11 +38,11 @@ repos: - mdformat_frontmatter - linkify-it-py - repo: https://github.com/codespell-project/codespell - rev: v2.2.6 + rev: v2.3.0 hooks: - id: codespell - repo: https://github.com/asottile/pyupgrade - rev: v3.15.0 + rev: v3.19.1 hooks: - id: pyupgrade args: ["--py36-plus"] diff --git a/lagent/actions/__init__.py b/lagent/actions/__init__.py index 4f73baa4..b75a2262 100644 --- a/lagent/actions/__init__.py +++ b/lagent/actions/__init__.py @@ -1,6 +1,6 @@ from .action_executor import ActionExecutor, AsyncActionExecutor from .arxiv_search import ArxivSearch, AsyncArxivSearch -from .base_action import BaseAction, tool_api +from .base_action import AsyncActionMixin, BaseAction, tool_api from .bing_map import AsyncBINGMap, BINGMap from .builtin_actions import FinishAction, InvalidAction, NoAction from .google_scholar_search import AsyncGoogleScholar, GoogleScholar @@ -14,12 +14,34 @@ from .web_browser import AsyncWebBrowser, WebBrowser __all__ = [ - 'BaseAction', 'ActionExecutor', 'AsyncActionExecutor', 'InvalidAction', - 'FinishAction', 'NoAction', 'BINGMap', 'AsyncBINGMap', 'ArxivSearch', - 'AsyncArxivSearch', 'GoogleSearch', 'AsyncGoogleSearch', 'GoogleScholar', - 'AsyncGoogleScholar', 'IPythonInterpreter', 'AsyncIPythonInterpreter', - 'IPythonInteractive', 'AsyncIPythonInteractive', - 'IPythonInteractiveManager', 'PythonInterpreter', 'AsyncPythonInterpreter', - 'PPT', 'AsyncPPT', 'WebBrowser', 'AsyncWebBrowser', 'BaseParser', - 'JsonParser', 'TupleParser', 'tool_api' + 'BaseAction', + 'ActionExecutor', + 'AsyncActionExecutor', + 'InvalidAction', + 'FinishAction', + 'NoAction', + 'BINGMap', + 'AsyncBINGMap', + 'ArxivSearch', + 'AsyncArxivSearch', + 'GoogleSearch', + 'AsyncGoogleSearch', + 'GoogleScholar', + 'AsyncGoogleScholar', + 'IPythonInterpreter', + 'AsyncIPythonInterpreter', + 'IPythonInteractive', + 'AsyncIPythonInteractive', + 
'IPythonInteractiveManager', + 'PythonInterpreter', + 'AsyncPythonInterpreter', + 'PPT', + 'AsyncPPT', + 'WebBrowser', + 'AsyncWebBrowser', + 'BaseParser', + 'JsonParser', + 'TupleParser', + 'tool_api', + 'AsyncActionMixin', ] diff --git a/lagent/hooks/logger.py b/lagent/hooks/logger.py index 50224e43..ccdb8012 100644 --- a/lagent/hooks/logger.py +++ b/lagent/hooks/logger.py @@ -1,5 +1,4 @@ import random -from typing import Optional from termcolor import COLORS, colored @@ -8,10 +7,10 @@ class MessageLogger(Hook): - - def __init__(self, name: str = 'lagent'): + def __init__(self, name: str = 'lagent', add_file_handler: bool = False): self.logger = get_logger( - name, 'info', '%(asctime)s %(levelname)8s %(name)8s - %(message)s') + name, 'info', '%(asctime)s %(levelname)8s %(name)8s - %(message)s', add_file_handler=add_file_handler + ) self.sender2color = {} def before_agent(self, agent, messages, session_id): @@ -29,9 +28,5 @@ def after_action(self, executor, message, session_id): def _process_message(self, message, session_id): sender = message.sender - color = self.sender2color.setdefault(sender, - random.choice(list(COLORS))) - self.logger.info( - colored( - f'session id: {session_id}, message sender: {sender}\n' - f'{message.content}', color)) + color = self.sender2color.setdefault(sender, random.choice(list(COLORS))) + self.logger.info(colored(f'session id: {session_id}, message sender: {sender}\n' f'{message.content}', color)) diff --git a/lagent/llms/__init__.py b/lagent/llms/__init__.py index fcbbd07d..95679b15 100644 --- a/lagent/llms/__init__.py +++ b/lagent/llms/__init__.py @@ -1,9 +1,15 @@ +from .anthropic_llm import AsyncClaudeAPI, ClaudeAPI from .base_api import AsyncBaseAPILLM, BaseAPILLM from .base_llm import AsyncBaseLLM, BaseLLM from .huggingface import HFTransformer, HFTransformerCasualLM, HFTransformerChat -from .lmdeploy_wrapper import (AsyncLMDeployClient, AsyncLMDeployPipeline, - AsyncLMDeployServer, LMDeployClient, - LMDeployPipeline, LMDeployServer) +from .lmdeploy_wrapper import ( + AsyncLMDeployClient, + AsyncLMDeployPipeline, + AsyncLMDeployServer, + LMDeployClient, + LMDeployPipeline, + LMDeployServer, +) from .meta_template import INTERNLM2_META from .openai import GPTAPI, AsyncGPTAPI from .sensenova import SensenovaAPI @@ -29,4 +35,6 @@ 'VllmModel', 'AsyncVllmModel', 'SensenovaAPI', + 'AsyncClaudeAPI', + 'ClaudeAPI', ] diff --git a/lagent/llms/anthropic_llm.py b/lagent/llms/anthropic_llm.py new file mode 100644 index 00000000..8cd7802d --- /dev/null +++ b/lagent/llms/anthropic_llm.py @@ -0,0 +1,392 @@ +import asyncio +import json +import os +from typing import Dict, List, Optional, Union + +import anthropic +import httpcore +import httpx +from anthropic import NOT_GIVEN +from requests.exceptions import ProxyError + +from lagent.llms import AsyncBaseAPILLM, BaseAPILLM + + +class ClaudeAPI(BaseAPILLM): + + is_api: bool = True + + def __init__( + self, + model_type: str = 'claude-3-5-sonnet-20241022', + retry: int = 5, + key: Union[str, List[str]] = 'ENV', + proxies: Optional[Dict] = None, + meta_template: Optional[Dict] = [ + dict(role='system', api_role='system'), + dict(role='user', api_role='user'), + dict(role='assistant', api_role='assistant'), + dict(role='environment', api_role='user'), + ], + temperature: float = NOT_GIVEN, + max_new_tokens: int = 512, + top_p: float = NOT_GIVEN, + top_k: int = NOT_GIVEN, + repetition_penalty: float = 0.0, + stop_words: Union[List[str], str] = None, + ): + + super().__init__( + meta_template=meta_template, + 
model_type=model_type, + retry=retry, + temperature=temperature, + max_new_tokens=max_new_tokens, + top_p=top_p, + top_k=top_k, + repetition_penalty=repetition_penalty, + stop_words=stop_words, + ) + + key = os.getenv('Claude_API_KEY') if key == 'ENV' else key + + if isinstance(key, str): + self.keys = [key] + else: + self.keys = list(set(key)) + self.clients = {key: anthropic.AsyncAnthropic(proxies=proxies, api_key=key) for key in self.keys} + + # record invalid keys and skip them when requesting API + # - keys have insufficient_quota + self.invalid_keys = set() + + self.key_ctr = 0 + self.model_type = model_type + self.proxies = proxies + + def chat( + self, + inputs: Union[List[dict], List[List[dict]]], + session_ids: Union[int, List[int]] = None, + **gen_params, + ) -> Union[str, List[str]]: + """Generate responses given the contexts. + + Args: + inputs (Union[List[dict], List[List[dict]]]): a list of messages + or list of lists of messages + gen_params: additional generation configuration + + Returns: + Union[str, List[str]]: generated string(s) + """ + assert isinstance(inputs, list) + gen_params = {**self.gen_params, **gen_params} + import nest_asyncio + + nest_asyncio.apply() + + async def run_async_tasks(): + tasks = [ + self._chat(self.template_parser(messages), **gen_params) + for messages in ([inputs] if isinstance(inputs[0], dict) else inputs) + ] + return await asyncio.gather(*tasks) + + try: + loop = asyncio.get_running_loop() + # If the event loop is already running, schedule the task + future = asyncio.ensure_future(run_async_tasks()) + ret = loop.run_until_complete(future) + except RuntimeError: + # If no running event loop, start a new one + ret = asyncio.run(run_async_tasks()) + return ret[0] if isinstance(inputs[0], dict) else ret + + def generate_request_data(self, model_type, messages, gen_params): + """ + Generates the request data for different model types. + + Args: + model_type (str): The type of the model (e.g., 'gpt', 'internlm', 'qwen'). + messages (list): The list of messages to be sent to the model. + gen_params (dict): The generation parameters. + json_mode (bool): Flag to determine if the response format should be JSON. + + Returns: + tuple: A tuple containing the header and the request data. + """ + # Copy generation parameters to avoid modifying the original dictionary + gen_params = gen_params.copy() + + # Hold out 100 tokens due to potential errors in token calculation + max_tokens = min(gen_params.pop('max_new_tokens'), 4096) + if max_tokens <= 0: + return '', '' + gen_params.pop('repetition_penalty') + if 'stop_words' in gen_params: + gen_params['stop_sequences'] = gen_params.pop('stop_words') + # Common parameters processing + gen_params['max_tokens'] = max_tokens + gen_params.pop('skip_special_tokens', None) + gen_params.pop('session_id', None) + + system = None + if messages[0]['role'] == 'system': + system = messages.pop(0) + system = system['content'] + for message in messages: + message.pop('name', None) + data = {'model': model_type, 'messages': messages, **gen_params} + if system: + data['system'] = system + return data + + async def _chat(self, messages: List[dict], **gen_params) -> str: + """Generate completion from a list of templates. + + Args: + messages (List[dict]): a list of prompt dictionaries + gen_params: additional generation configuration + + Returns: + str: The generated string. 
+ """ + assert isinstance(messages, list) + + data = self.generate_request_data(model_type=self.model_type, messages=messages, gen_params=gen_params) + max_num_retries = 0 + + while max_num_retries < self.retry: + if len(self.invalid_keys) == len(self.keys): + raise RuntimeError('All keys have insufficient quota.') + # find the next valid key + while True: + self.key_ctr += 1 + if self.key_ctr == len(self.keys): + self.key_ctr = 0 + + if self.keys[self.key_ctr] not in self.invalid_keys: + break + + key = self.keys[self.key_ctr] + client = self.clients[key] + + try: + response = await client.messages.create(**data) + response = json.loads(response.json()) + return response['content'][0]['text'].strip() + except (anthropic.RateLimitError, anthropic.APIConnectionError) as e: + print(f'API请求错误: {e}') + await asyncio.sleep(5) + + except (httpcore.ProxyError, ProxyError) as e: + + print(f'代理服务器错误: {e}') + await asyncio.sleep(5) + except httpx.TimeoutException as e: + print(f'请求超时: {e}') + await asyncio.sleep(5) + + except KeyboardInterrupt: + raise + + except Exception as error: + if error.body['error']['message'] == 'invalid x-api-key': + self.invalid_keys.add(key) + self.logger.warn(f'invalid key: {key}') + elif error.body['error']['type'] == 'overloaded_error': + await asyncio.sleep(5) + elif error.body['error']['message'] == 'Internal server error': + await asyncio.sleep(5) + elif error.body['error']['message'] == ( + 'Your credit balance is too low to access the Anthropic API. Please go to Plans & Billing to ' + 'upgrade or purchase credits.' + ): + self.invalid_keys.add(key) + print(f'API has no quota: {key}, Valid keys: {len(self.keys) - len(self.invalid_keys)}') + max_num_retries += 1 + + raise RuntimeError( + 'Calling Claude failed after retrying for ' f'{max_num_retries} times. Check the logs for ' 'details.' + ) + + +class AsyncClaudeAPI(AsyncBaseAPILLM): + + is_api: bool = True + + def __init__( + self, + model_type: str = 'claude-3-5-sonnet-20241022', + retry: int = 5, + key: Union[str, List[str]] = 'ENV', + proxies: Optional[Dict] = None, + meta_template: Optional[Dict] = None, + temperature: float = NOT_GIVEN, + max_new_tokens: int = 512, + top_p: float = NOT_GIVEN, + top_k: int = NOT_GIVEN, + repetition_penalty: float = 0.0, + stop_words: Union[List[str], str] = None, + ): + + super().__init__( + model_type=model_type, + retry=retry, + meta_template=meta_template, + temperature=temperature, + max_new_tokens=max_new_tokens, + top_p=top_p, + top_k=top_k, + repetition_penalty=repetition_penalty, + stop_words=stop_words, + ) + + key = os.getenv('Claude_API_KEY') if key == 'ENV' else key + + if isinstance(key, str): + self.keys = [key] + else: + self.keys = list(set(key)) + self.clients = {key: anthropic.AsyncAnthropic(proxies=proxies, api_key=key) for key in self.keys} + + # record invalid keys and skip them when requesting API + # - keys have insufficient_quota + self.invalid_keys = set() + + self.key_ctr = 0 + self.model_type = model_type + self.proxies = proxies + + async def chat( + self, + inputs: Union[List[dict], List[List[dict]]], + session_ids: Union[int, List[int]] = None, + **gen_params, + ) -> Union[str, List[str]]: + """Generate responses given the contexts. 
+ + Args: + inputs (Union[List[dict], List[List[dict]]]): a list of messages + or list of lists of messages + gen_params: additional generation configuration + + Returns: + Union[str, List[str]]: generated string(s) + """ + assert isinstance(inputs, list) + gen_params = {**self.gen_params, **gen_params} + tasks = [ + self._chat(self.template_parser(messages), **gen_params) + for messages in ([inputs] if isinstance(inputs[0], dict) else inputs) + ] + ret = await asyncio.gather(*tasks) + return ret[0] if isinstance(inputs[0], dict) else ret + + def generate_request_data(self, model_type, messages, gen_params): + """ + Generates the request data for different model types. + + Args: + model_type (str): The type of the model (e.g., 'gpt', 'internlm', 'qwen'). + messages (list): The list of messages to be sent to the model. + gen_params (dict): The generation parameters. + json_mode (bool): Flag to determine if the response format should be JSON. + + Returns: + tuple: A tuple containing the header and the request data. + """ + # Copy generation parameters to avoid modifying the original dictionary + gen_params = gen_params.copy() + + # Hold out 100 tokens due to potential errors in token calculation + max_tokens = min(gen_params.pop('max_new_tokens'), 4096) + if max_tokens <= 0: + return '', '' + gen_params.pop('repetition_penalty') + if 'stop_words' in gen_params: + gen_params['stop_sequences'] = gen_params.pop('stop_words') + # Common parameters processing + gen_params['max_tokens'] = max_tokens + gen_params.pop('skip_special_tokens', None) + gen_params.pop('session_id', None) + + system = None + if messages[0]['role'] == 'system': + system = messages.pop(0) + system = system['content'] + + data = {'model': model_type, 'messages': messages, **gen_params} + if system: + data['system'] = system + return data + + async def _chat(self, messages: List[dict], **gen_params) -> str: + """Generate completion from a list of templates. + + Args: + messages (List[dict]): a list of prompt dictionaries + gen_params: additional generation configuration + + Returns: + str: The generated string. 
+ """ + assert isinstance(messages, list) + + data = self.generate_request_data(model_type=self.model_type, messages=messages, gen_params=gen_params) + max_num_retries = 0 + + while max_num_retries < self.retry: + if len(self.invalid_keys) == len(self.keys): + raise RuntimeError('All keys have insufficient quota.') + # find the next valid key + while True: + self.key_ctr += 1 + if self.key_ctr == len(self.keys): + self.key_ctr = 0 + + if self.keys[self.key_ctr] not in self.invalid_keys: + break + + key = self.keys[self.key_ctr] + client = self.clients[key] + + try: + response = await client.messages.create(**data) + response = json.loads(response.json()) + return response['content'][0]['text'].strip() + except (anthropic.RateLimitError, anthropic.APIConnectionError) as e: + print(f'API请求错误: {e}') + await asyncio.sleep(5) + + except (httpcore.ProxyError, ProxyError) as e: + + print(f'代理服务器错误: {e}') + await asyncio.sleep(5) + except httpx.TimeoutException as e: + print(f'请求超时: {e}') + await asyncio.sleep(5) + + except KeyboardInterrupt: + raise + + except Exception as error: + if error.body['error']['message'] == 'invalid x-api-key': + self.invalid_keys.add(key) + self.logger.warn(f'invalid key: {key}') + elif error.body['error']['type'] == 'overloaded_error': + await asyncio.sleep(5) + elif error.body['error']['message'] == 'Internal server error': + await asyncio.sleep(5) + elif error.body['error']['message'] == ( + 'Your credit balance is too low to access the Anthropic API. Please go to Plans & Billing to' + ' upgrade or purchase credits.' + ): + self.invalid_keys.add(key) + print(f'API has no quota: {key}, Valid keys: {len(self.keys) - len(self.invalid_keys)}') + max_num_retries += 1 + + raise RuntimeError( + 'Calling Claude failed after retrying for ' f'{max_num_retries} times. Check the logs for ' 'details.' 
+ ) diff --git a/lagent/llms/openai.py b/lagent/llms/openai.py index ffbd1b3d..a1ac34ba 100644 --- a/lagent/llms/openai.py +++ b/lagent/llms/openai.py @@ -47,30 +47,27 @@ class GPTAPI(BaseAPILLM): is_api: bool = True - def __init__(self, - model_type: str = 'gpt-3.5-turbo', - retry: int = 2, - json_mode: bool = False, - key: Union[str, List[str]] = 'ENV', - org: Optional[Union[str, List[str]]] = None, - meta_template: Optional[Dict] = [ - dict(role='system', api_role='system'), - dict(role='user', api_role='user'), - dict(role='assistant', api_role='assistant'), - dict(role='environment', api_role='system') - ], - api_base: str = OPENAI_API_BASE, - proxies: Optional[Dict] = None, - **gen_params): + def __init__( + self, + model_type: str = 'gpt-3.5-turbo', + retry: int = 2, + json_mode: bool = False, + key: Union[str, List[str]] = 'ENV', + org: Optional[Union[str, List[str]]] = None, + meta_template: Optional[Dict] = [ + dict(role='system', api_role='system'), + dict(role='user', api_role='user'), + dict(role='assistant', api_role='assistant'), + dict(role='environment', api_role='system'), + ], + api_base: str = OPENAI_API_BASE, + proxies: Optional[Dict] = None, + **gen_params, + ): if 'top_k' in gen_params: - warnings.warn('`top_k` parameter is deprecated in OpenAI APIs.', - DeprecationWarning) + warnings.warn('`top_k` parameter is deprecated in OpenAI APIs.', DeprecationWarning) gen_params.pop('top_k') - super().__init__( - model_type=model_type, - meta_template=meta_template, - retry=retry, - **gen_params) + super().__init__(model_type=model_type, meta_template=meta_template, retry=retry, **gen_params) self.gen_params.pop('top_k') self.logger = getLogger(__name__) @@ -115,11 +112,8 @@ def chat( gen_params = {**self.gen_params, **gen_params} with ThreadPoolExecutor(max_workers=20) as executor: tasks = [ - executor.submit(self._chat, - self.template_parser._prompt2api(messages), - **gen_params) - for messages in ( - [inputs] if isinstance(inputs[0], dict) else inputs) + executor.submit(self._chat, messages, **gen_params) + for messages in ([inputs] if isinstance(inputs[0], dict) else inputs) ] ret = [task.result() for task in tasks] return ret[0] if isinstance(inputs[0], dict) else ret @@ -150,7 +144,7 @@ def stream_chat( if stop_words is None: stop_words = [] # mapping to role that openai supports - messages = self.template_parser._prompt2api(inputs) + messages = self.template_parser(inputs) for text in self._stream_chat(messages, **gen_params): if self.model_type.lower().startswith('qwen'): resp = text @@ -180,12 +174,10 @@ def _chat(self, messages: List[dict], **gen_params) -> str: str: The generated string. 
""" assert isinstance(messages, list) - + messages = self.template_parser(messages) header, data = self.generate_request_data( - model_type=self.model_type, - messages=messages, - gen_params=gen_params, - json_mode=self.json_mode) + model_type=self.model_type, messages=messages, gen_params=gen_params, json_mode=self.json_mode + ) max_num_retries, errmsg = 0, '' while max_num_retries < self.retry: @@ -214,11 +206,7 @@ def _chat(self, messages: List[dict], **gen_params) -> str: response = dict() try: - raw_response = requests.post( - self.url, - headers=header, - data=json.dumps(data), - proxies=self.proxies) + raw_response = requests.post(self.url, headers=header, data=json.dumps(data), proxies=self.proxies) response = raw_response.json() return response['choices'][0]['message']['content'].strip() except requests.ConnectionError: @@ -239,17 +227,18 @@ def _chat(self, messages: List[dict], **gen_params) -> str: self.logger.warn(f'insufficient_quota key: {key}') continue - errmsg = 'Find error message in response: ' + str( - response['error']) + errmsg = 'Find error message in response: ' + str(response['error']) self.logger.error(errmsg) except Exception as error: errmsg = str(error) + '\n' + str(traceback.format_exc()) self.logger.error(errmsg) max_num_retries += 1 - raise RuntimeError('Calling OpenAI failed after retrying for ' - f'{max_num_retries} times. Check the logs for ' - f'details. errmsg: {errmsg}') + raise RuntimeError( + 'Calling OpenAI failed after retrying for ' + f'{max_num_retries} times. Check the logs for ' + f'details. errmsg: {errmsg}' + ) def _stream_chat(self, messages: List[dict], **gen_params) -> str: """Generate completion from a list of templates. @@ -263,8 +252,7 @@ def _stream_chat(self, messages: List[dict], **gen_params) -> str: """ def streaming(raw_response): - for chunk in raw_response.iter_lines( - chunk_size=8192, decode_unicode=False, delimiter=b'\n'): + for chunk in raw_response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b'\n'): if chunk: decoded = chunk.decode('utf-8') if decoded.startswith('data: [DONE]'): @@ -300,10 +288,8 @@ def streaming(raw_response): assert isinstance(messages, list) header, data = self.generate_request_data( - model_type=self.model_type, - messages=messages, - gen_params=gen_params, - json_mode=self.json_mode) + model_type=self.model_type, messages=messages, gen_params=gen_params, json_mode=self.json_mode + ) max_num_retries, errmsg = 0, '' while max_num_retries < self.retry: @@ -330,11 +316,7 @@ def streaming(raw_response): response = dict() try: - raw_response = requests.post( - self.url, - headers=header, - data=json.dumps(data), - proxies=self.proxies) + raw_response = requests.post(self.url, headers=header, data=json.dumps(data), proxies=self.proxies) return streaming(raw_response) except requests.ConnectionError: errmsg = 'Got connection error ' + str(traceback.format_exc()) @@ -354,23 +336,20 @@ def streaming(raw_response): self.logger.warn(f'insufficient_quota key: {key}') continue - errmsg = 'Find error message in response: ' + str( - response['error']) + errmsg = 'Find error message in response: ' + str(response['error']) self.logger.error(errmsg) except Exception as error: errmsg = str(error) + '\n' + str(traceback.format_exc()) self.logger.error(errmsg) max_num_retries += 1 - raise RuntimeError('Calling OpenAI failed after retrying for ' - f'{max_num_retries} times. Check the logs for ' - f'details. 
errmsg: {errmsg}') + raise RuntimeError( + 'Calling OpenAI failed after retrying for ' + f'{max_num_retries} times. Check the logs for ' + f'details. errmsg: {errmsg}' + ) - def generate_request_data(self, - model_type, - messages, - gen_params, - json_mode=False): + def generate_request_data(self, model_type, messages, gen_params, json_mode=False): """ Generates the request data for different model types. @@ -401,34 +380,21 @@ def generate_request_data(self, if 'stop_words' in gen_params: gen_params['stop'] = gen_params.pop('stop_words') if 'repetition_penalty' in gen_params: - gen_params['frequency_penalty'] = gen_params.pop( - 'repetition_penalty') + gen_params['frequency_penalty'] = gen_params.pop('repetition_penalty') # Model-specific processing data = {} if model_type.lower().startswith('gpt'): if 'top_k' in gen_params: - warnings.warn( - '`top_k` parameter is deprecated in OpenAI APIs.', - DeprecationWarning) + warnings.warn('`top_k` parameter is deprecated in OpenAI APIs.', DeprecationWarning) gen_params.pop('top_k') gen_params.pop('skip_special_tokens', None) gen_params.pop('session_id', None) - data = { - 'model': model_type, - 'messages': messages, - 'n': 1, - **gen_params - } + data = {'model': model_type, 'messages': messages, 'n': 1, **gen_params} if json_mode: data['response_format'] = {'type': 'json_object'} elif model_type.lower().startswith('internlm'): - data = { - 'model': model_type, - 'messages': messages, - 'n': 1, - **gen_params - } + data = {'model': model_type, 'messages': messages, 'n': 1, **gen_params} if json_mode: data['response_format'] = {'type': 'json_object'} elif model_type.lower().startswith('qwen'): @@ -436,21 +402,11 @@ def generate_request_data(self, gen_params.pop('skip_special_tokens', None) gen_params.pop('session_id', None) if 'frequency_penalty' in gen_params: - gen_params['repetition_penalty'] = gen_params.pop( - 'frequency_penalty') + gen_params['repetition_penalty'] = gen_params.pop('frequency_penalty') gen_params['result_format'] = 'message' - data = { - 'model': model_type, - 'input': { - 'messages': messages - }, - 'parameters': { - **gen_params - } - } + data = {'model': model_type, 'input': {'messages': messages}, 'parameters': {**gen_params}} else: - raise NotImplementedError( - f'Model type {model_type} is not supported') + raise NotImplementedError(f'Model type {model_type} is not supported') return header, data @@ -464,6 +420,7 @@ def tokenize(self, prompt: str) -> list: list: token ids """ import tiktoken + self.tiktoken = tiktoken enc = self.tiktoken.encoding_for_model(self.model_type) return enc.encode(prompt) @@ -495,29 +452,27 @@ class AsyncGPTAPI(AsyncBaseAPILLM): is_api: bool = True - def __init__(self, - model_type: str = 'gpt-3.5-turbo', - retry: int = 2, - json_mode: bool = False, - key: Union[str, List[str]] = 'ENV', - org: Optional[Union[str, List[str]]] = None, - meta_template: Optional[Dict] = [ - dict(role='system', api_role='system'), - dict(role='user', api_role='user'), - dict(role='assistant', api_role='assistant') - ], - api_base: str = OPENAI_API_BASE, - proxies: Optional[Dict] = None, - **gen_params): + def __init__( + self, + model_type: str = 'gpt-3.5-turbo', + retry: int = 2, + json_mode: bool = False, + key: Union[str, List[str]] = 'ENV', + org: Optional[Union[str, List[str]]] = None, + meta_template: Optional[Dict] = [ + dict(role='system', api_role='system'), + dict(role='user', api_role='user'), + dict(role='assistant', api_role='assistant'), + dict(role='environment', api_role='system'), + ], + api_base: 
str = OPENAI_API_BASE, + proxies: Optional[Dict] = None, + **gen_params, + ): if 'top_k' in gen_params: - warnings.warn('`top_k` parameter is deprecated in OpenAI APIs.', - DeprecationWarning) + warnings.warn('`top_k` parameter is deprecated in OpenAI APIs.', DeprecationWarning) gen_params.pop('top_k') - super().__init__( - model_type=model_type, - meta_template=meta_template, - retry=retry, - **gen_params) + super().__init__(model_type=model_type, meta_template=meta_template, retry=retry, **gen_params) self.gen_params.pop('top_k') self.logger = getLogger(__name__) @@ -562,8 +517,7 @@ async def chat( raise NotImplementedError('unsupported parameter: max_tokens') gen_params = {**self.gen_params, **gen_params} tasks = [ - self._chat(messages, **gen_params) for messages in ( - [inputs] if isinstance(inputs[0], dict) else inputs) + self._chat(messages, **gen_params) for messages in ([inputs] if isinstance(inputs[0], dict) else inputs) ] ret = await asyncio.gather(*tasks) return ret[0] if isinstance(inputs[0], dict) else ret @@ -594,7 +548,7 @@ async def stream_chat( if stop_words is None: stop_words = [] # mapping to role that openai supports - messages = self.template_parser._prompt2api(inputs) + messages = self.template_parser(inputs) async for text in self._stream_chat(messages, **gen_params): if self.model_type.lower().startswith('qwen'): resp = text @@ -624,12 +578,10 @@ async def _chat(self, messages: List[dict], **gen_params) -> str: str: The generated string. """ assert isinstance(messages, list) - + messages = self.template_parser(messages) header, data = self.generate_request_data( - model_type=self.model_type, - messages=messages, - gen_params=gen_params, - json_mode=self.json_mode) + model_type=self.model_type, messages=messages, gen_params=gen_params, json_mode=self.json_mode + ) max_num_retries, errmsg = 0, '' while max_num_retries < self.retry: @@ -658,14 +610,10 @@ async def _chat(self, messages: List[dict], **gen_params) -> str: try: async with aiohttp.ClientSession() as session: async with session.post( - self.url, - headers=header, - json=data, - proxy=self.proxies.get( - 'https', self.proxies.get('http'))) as resp: + self.url, headers=header, json=data, proxy=self.proxies.get('https', self.proxies.get('http')) + ) as resp: response = await resp.json() - return response['choices'][0]['message'][ - 'content'].strip() + return response['choices'][0]['message']['content'].strip() except aiohttp.ClientConnectionError: errmsg = 'Got connection error ' + str(traceback.format_exc()) self.logger.error(errmsg) @@ -675,8 +623,7 @@ async def _chat(self, messages: List[dict], **gen_params) -> str: self.logger.error(errmsg) continue except json.JSONDecodeError: - errmsg = 'JsonDecode error, got ' + (await resp.text( - errors='replace')) + errmsg = 'JsonDecode error, got ' + (await resp.text(errors='replace')) self.logger.error(errmsg) continue except KeyError: @@ -689,20 +636,20 @@ async def _chat(self, messages: List[dict], **gen_params) -> str: self.logger.warn(f'insufficient_quota key: {key}') continue - errmsg = 'Find error message in response: ' + str( - response['error']) + errmsg = 'Find error message in response: ' + str(response['error']) self.logger.error(errmsg) except Exception as error: errmsg = str(error) + '\n' + str(traceback.format_exc()) self.logger.error(errmsg) max_num_retries += 1 - raise RuntimeError('Calling OpenAI failed after retrying for ' - f'{max_num_retries} times. Check the logs for ' - f'details. 
errmsg: {errmsg}') + raise RuntimeError( + 'Calling OpenAI failed after retrying for ' + f'{max_num_retries} times. Check the logs for ' + f'details. errmsg: {errmsg}' + ) - async def _stream_chat(self, messages: List[dict], - **gen_params) -> AsyncGenerator[str, None]: + async def _stream_chat(self, messages: List[dict], **gen_params) -> AsyncGenerator[str, None]: """Generate completion from a list of templates. Args: @@ -750,10 +697,8 @@ async def streaming(raw_response): assert isinstance(messages, list) header, data = self.generate_request_data( - model_type=self.model_type, - messages=messages, - gen_params=gen_params, - json_mode=self.json_mode) + model_type=self.model_type, messages=messages, gen_params=gen_params, json_mode=self.json_mode + ) max_num_retries, errmsg = 0, '' while max_num_retries < self.retry: @@ -782,12 +727,8 @@ async def streaming(raw_response): try: async with aiohttp.ClientSession() as session: async with session.post( - self.url, - headers=header, - json=data, - proxy=self.proxies.get( - 'https', - self.proxies.get('http'))) as raw_response: + self.url, headers=header, json=data, proxy=self.proxies.get('https', self.proxies.get('http')) + ) as raw_response: async for msg in streaming(raw_response): yield msg return @@ -809,23 +750,20 @@ async def streaming(raw_response): self.logger.warn(f'insufficient_quota key: {key}') continue - errmsg = 'Find error message in response: ' + str( - response['error']) + errmsg = 'Find error message in response: ' + str(response['error']) self.logger.error(errmsg) except Exception as error: errmsg = str(error) + '\n' + str(traceback.format_exc()) self.logger.error(errmsg) max_num_retries += 1 - raise RuntimeError('Calling OpenAI failed after retrying for ' - f'{max_num_retries} times. Check the logs for ' - f'details. errmsg: {errmsg}') + raise RuntimeError( + 'Calling OpenAI failed after retrying for ' + f'{max_num_retries} times. Check the logs for ' + f'details. errmsg: {errmsg}' + ) - def generate_request_data(self, - model_type, - messages, - gen_params, - json_mode=False): + def generate_request_data(self, model_type, messages, gen_params, json_mode=False): """ Generates the request data for different model types. 
@@ -856,34 +794,21 @@ def generate_request_data(self, if 'stop_words' in gen_params: gen_params['stop'] = gen_params.pop('stop_words') if 'repetition_penalty' in gen_params: - gen_params['frequency_penalty'] = gen_params.pop( - 'repetition_penalty') + gen_params['frequency_penalty'] = gen_params.pop('repetition_penalty') # Model-specific processing data = {} if model_type.lower().startswith('gpt'): if 'top_k' in gen_params: - warnings.warn( - '`top_k` parameter is deprecated in OpenAI APIs.', - DeprecationWarning) + warnings.warn('`top_k` parameter is deprecated in OpenAI APIs.', DeprecationWarning) gen_params.pop('top_k') gen_params.pop('skip_special_tokens', None) gen_params.pop('session_id', None) - data = { - 'model': model_type, - 'messages': messages, - 'n': 1, - **gen_params - } + data = {'model': model_type, 'messages': messages, 'n': 1, **gen_params} if json_mode: data['response_format'] = {'type': 'json_object'} elif model_type.lower().startswith('internlm'): - data = { - 'model': model_type, - 'messages': messages, - 'n': 1, - **gen_params - } + data = {'model': model_type, 'messages': messages, 'n': 1, **gen_params} if json_mode: data['response_format'] = {'type': 'json_object'} elif model_type.lower().startswith('qwen'): @@ -891,21 +816,11 @@ def generate_request_data(self, gen_params.pop('skip_special_tokens', None) gen_params.pop('session_id', None) if 'frequency_penalty' in gen_params: - gen_params['repetition_penalty'] = gen_params.pop( - 'frequency_penalty') + gen_params['repetition_penalty'] = gen_params.pop('frequency_penalty') gen_params['result_format'] = 'message' - data = { - 'model': model_type, - 'input': { - 'messages': messages - }, - 'parameters': { - **gen_params - } - } + data = {'model': model_type, 'input': {'messages': messages}, 'parameters': {**gen_params}} else: - raise NotImplementedError( - f'Model type {model_type} is not supported') + raise NotImplementedError(f'Model type {model_type} is not supported') return header, data @@ -919,6 +834,7 @@ def tokenize(self, prompt: str) -> list: list: token ids """ import tiktoken + self.tiktoken = tiktoken enc = self.tiktoken.encoding_for_model(self.model_type) return enc.encode(prompt) diff --git a/lagent/utils/util.py b/lagent/utils/util.py index a40482b5..609382ec 100644 --- a/lagent/utils/util.py +++ b/lagent/utils/util.py @@ -29,8 +29,8 @@ def load_class_from_string(class_path: str, path=None): def create_object(config: Union[Dict, Any] = None): - """Create an instance based on the configuration where 'type' is a - preserved key to indicate the class (path). When accepting non-dictionary + """Create an instance based on the configuration where 'type' is a + preserved key to indicate the class (path). When accepting non-dictionary input, the function degenerates to an identity. """ if config is None or not isinstance(config, dict): @@ -62,8 +62,7 @@ async def async_as_completed(futures: Iterable[asyncio.Future]): yield await next_completed -def filter_suffix(response: Union[str, List[str]], - suffixes: Optional[List[str]] = None) -> str: +def filter_suffix(response: Union[str, List[str]], suffixes: Optional[List[str]] = None) -> str: """Filter response with suffixes. 
Args: @@ -95,12 +94,11 @@ def filter_suffix(response: Union[str, List[str]], def get_logger( name: str = 'lagent', level: str = 'debug', - fmt: - str = '%(asctime)s %(levelname)8s %(filename)20s %(lineno)4s - %(message)s', + fmt: str = '%(asctime)s %(levelname)8s %(filename)20s %(lineno)4s - %(message)s', add_file_handler: bool = False, log_dir: str = 'log', - log_file: str = time.strftime('%Y-%m-%d.log', time.localtime()), - max_bytes: int = 5 * 1024 * 1024, + log_file: str = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime()) + '.log', + max_bytes: int = 50 * 1024 * 1024, backup_count: int = 3, ): logger = logging.getLogger(name) @@ -117,10 +115,8 @@ def get_logger( os.makedirs(log_dir) log_file_path = osp.join(log_dir, log_file) file_handler = RotatingFileHandler( - log_file_path, - maxBytes=max_bytes, - backupCount=backup_count, - encoding='utf-8') + log_file_path, maxBytes=max_bytes, backupCount=backup_count, encoding='utf-8' + ) file_handler.setFormatter(formatter) logger.addHandler(file_handler) diff --git a/requirements/runtime.txt b/requirements/runtime.txt index 68446dbd..bb28b273 100644 --- a/requirements/runtime.txt +++ b/requirements/runtime.txt @@ -1,4 +1,5 @@ aiohttp +anthropic arxiv asyncache asyncer @@ -14,8 +15,8 @@ jupyter_client==8.6.2 jupyter_core==5.7.2 pydantic==2.6.4 requests +tenacity termcolor tiktoken timeout-decorator typing-extensions -tenacity
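
Below is a minimal usage sketch of the ClaudeAPI class added in this patch. It relies only on the defaults defined above (model 'claude-3-5-sonnet-20241022', key read from the Claude_API_KEY environment variable when key='ENV', max_new_tokens=512); the prompt text is a placeholder, and AsyncClaudeAPI exposes the same chat interface as a coroutine.

    from lagent.llms import ClaudeAPI

    # With key='ENV' (the default) the key is read from the Claude_API_KEY environment
    # variable; a single key or a list of keys (rotated when one runs out of quota)
    # can be passed instead.
    llm = ClaudeAPI(model_type='claude-3-5-sonnet-20241022', key='ENV', max_new_tokens=512)

    # chat() accepts one conversation (a list of role/content dicts) and returns a string,
    # or a batch of conversations and returns a list of strings.
    reply = llm.chat([dict(role='user', content='Briefly introduce yourself.')])
    print(reply)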