diff --git a/opencompass/models/openai_api.py b/opencompass/models/openai_api.py index 9d601ca05..05eb4d105 100644 --- a/opencompass/models/openai_api.py +++ b/opencompass/models/openai_api.py @@ -1,6 +1,7 @@ import json import os import time +from litellm import completion from concurrent.futures import ThreadPoolExecutor from threading import Lock from typing import Dict, List, Optional, Union @@ -206,18 +207,12 @@ def _generate(self, input: str or PromptList, max_out_len: int, stop=None, temperature=temperature, ) - raw_response = requests.post(self.url, - headers=header, - data=json.dumps(data)) + response = completion(**data, api_key=key, org=self.orgs[self.org_ctr]) + except requests.ConnectionError: self.logger.error('Got connection error, retrying...') continue - try: - response = raw_response.json() - except requests.JSONDecodeError: - self.logger.error('JsonDecode error, got', - str(raw_response.content)) - continue + try: return response['choices'][0]['message']['content'].strip() except KeyError: diff --git a/requirements.txt b/requirements.txt index cbd8d22f1..3b8492cb4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -24,3 +24,4 @@ tokenizers>=0.13.3 torch>=1.13.1 tqdm==4.64.1 transformers>=4.29.1 +litellm \ No newline at end of file