From a58c914616dbedaf5caae912dd617aa5c508112e Mon Sep 17 00:00:00 2001
From: liukuikun <24622904+Harold-lkk@users.noreply.github.com>
Date: Sat, 21 Dec 2024 13:33:04 +0800
Subject: [PATCH] support o1 and claude (#285)

---
 lagent/llms/anthropic_llm.py | 5 ++++-
 lagent/llms/openai.py        | 3 +++
 2 files changed, 7 insertions(+), 1 deletion(-)

diff --git a/lagent/llms/anthropic_llm.py b/lagent/llms/anthropic_llm.py
index fe50bfa..8aac28c 100644
--- a/lagent/llms/anthropic_llm.py
+++ b/lagent/llms/anthropic_llm.py
@@ -320,7 +320,8 @@ def generate_request_data(self, model_type, messages, gen_params):
         if messages[0]['role'] == 'system':
             system = messages.pop(0)
             system = system['content']
-
+        for message in messages:
+            message.pop('name', None)
         data = {'model': model_type, 'messages': messages, **gen_params}
         if system:
             data['system'] = system
@@ -389,6 +390,8 @@ async def _chat(self, messages: List[dict], **gen_params) -> str:
                 ):
                     self.invalid_keys.add(key)
                     print(f'API has no quota: {key}, Valid keys: {len(self.keys) - len(self.invalid_keys)}')
+                else:
+                    raise error
             max_num_retries += 1
 
         raise RuntimeError(
diff --git a/lagent/llms/openai.py b/lagent/llms/openai.py
index a1ac34b..0d75308 100644
--- a/lagent/llms/openai.py
+++ b/lagent/llms/openai.py
@@ -804,6 +804,7 @@ def generate_request_data(self, model_type, messages, gen_params, json_mode=Fals
                 gen_params.pop('top_k')
             gen_params.pop('skip_special_tokens', None)
             gen_params.pop('session_id', None)
+
             data = {'model': model_type, 'messages': messages, 'n': 1, **gen_params}
             if json_mode:
                 data['response_format'] = {'type': 'json_object'}
@@ -819,6 +820,8 @@ def generate_request_data(self, model_type, messages, gen_params, json_mode=Fals
             gen_params['repetition_penalty'] = gen_params.pop('frequency_penalty')
             gen_params['result_format'] = 'message'
             data = {'model': model_type, 'input': {'messages': messages}, 'parameters': {**gen_params}}
+        elif model_type.lower().startswith('o1'):
+            data = {'model': model_type, 'messages': messages, 'n': 1}
         else:
             raise NotImplementedError(f'Model type {model_type} is not supported')