diff --git a/robot/AI.py b/robot/AI.py index aca9ba9d..0730d2b7 100644 --- a/robot/AI.py +++ b/robot/AI.py @@ -226,7 +226,6 @@ def chat(self, texts, parsed): logger.critical("AnyQ robot failed to response for %r", msg, exc_info=True) return "抱歉, AnyQ回答失败" - class OPENAIRobot(AbstractRobot): SLUG = "openai" @@ -235,6 +234,8 @@ def __init__( self, openai_api_key, model, + provider, + api_version, temperature, max_tokens, top_p, @@ -267,6 +268,8 @@ def __init__( logger.critical("OpenAI 初始化失败,请升级 Python 版本至 > 3.6") self.model = model self.prefix = prefix + self.provider = provider + self.api_version = api_version self.temperature = temperature self.max_tokens = max_tokens self.top_p = top_p @@ -295,12 +298,20 @@ def stream_chat(self, texts): header = { "Content-Type": "application/json", - "Authorization": "Bearer " + self.openai.api_key, + # "Authorization": "Bearer " + self.openai.api_key, } + if self.provider == 'openai': + header['Authorization'] = "Bearer " + self.openai.api_key, + elif self.provider == 'azure': + header['api-key'] = self.openai.api_key + else: + raise ValueError("Please check your config file, OpenAiRobot's provider should be openai or azure.") data = {"model": self.model, "messages": self.context, "stream": True} logger.info(f"使用模型:{self.model},开始流式请求") url = self.api_base + "/completions" + if self.provider == 'azure': + url = f"{self.api_base}/openai/deployments/{self.model}/chat/completions?api-version={self.api_version}" # 请求接收流式数据 try: response = requests.request( @@ -368,17 +379,29 @@ def chat(self, texts, parsed): try: respond = "" self.context.append({"role": "user", "content": msg}) - response = self.openai.Completion.create( - model=self.model, - messages=self.context, - temperature=self.temperature, - max_tokens=self.max_tokens, - top_p=self.top_p, - frequency_penalty=self.frequency_penalty, - presence_penalty=self.presence_penalty, - stop=self.stop_ai, - api_base=self.api_base - ) + if self.provider == "openai": + response = self.openai.Completion.create( + model=self.model, + messages=self.context, + temperature=self.temperature, + max_tokens=self.max_tokens, + top_p=self.top_p, + frequency_penalty=self.frequency_penalty, + presence_penalty=self.presence_penalty, + stop=self.stop_ai, + api_base=self.api_base + ) + else: + from openai import AzureOpenAI + client = AzureOpenAI( + azure_endpoint = self.api_base, + api_key=self.openai_api_key, + api_version=self.api_version + ) + response = client.chat.completions.create( + model=self.model, + messages=self.context + ) message = response.choices[0].message respond = message.content self.context.append(message) diff --git a/static/default.yml b/static/default.yml index 8a2f1db9..686c68b1 100755 --- a/static/default.yml +++ b/static/default.yml @@ -284,6 +284,8 @@ tuling: # 注册一个账号,获得 openai_api_key 后填到下面的配置中即可 openai: openai_api_key: 'sk-xxxxxxxxxxxxxxxxxxxxxxxxxx' + provider: 'azure' # openai的接口填写openai, azure的填写azure + api_version: '2023-05-15' # 如果是openai的,留空就行,azure的需填写对应的api_version,参考官方文档 # 参数指定将生成文本的模型类型。目前支持 gpt-3.5-turbo 和 gpt-3.5-turbo-0301 两种选择 model: 'gpt-3.5-turbo' # 在前面加的一段前缀