From 1063f52a26bfe258c7d808649b0955b1688f26a5 Mon Sep 17 00:00:00 2001 From: lwang998 <729594750@qq.com> Date: Sat, 19 Oct 2024 11:04:39 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E5=A2=9E=E5=8A=A0coze=E6=94=AF?= =?UTF-8?q?=E6=8C=81=E5=9B=BE=E7=89=87=E6=96=87=E4=BB=B6=E4=B8=8A=E4=BC=A0?= =?UTF-8?q?=E5=8A=9F=E8=83=BD=20(#117)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * 增加coze支持图片文件上传功能 * 合并到bytedance文件夹 * 合并到bytedance文件夹 * 恢复nigore * 修复没有群名导致的bug * cozepy升级到0.6版本 --- .gitignore | 2 +- bot/bytedance/bytedance_coze_bot.py | 310 ++++++++++++++++++---------- bot/bytedance/coze_client.py | 56 +++++ bot/bytedance/coze_session.py | 119 +++++++++++ config.py | 3 +- requirements.txt | 1 + 6 files changed, 382 insertions(+), 109 deletions(-) create mode 100644 bot/bytedance/coze_client.py create mode 100644 bot/bytedance/coze_session.py diff --git a/.gitignore b/.gitignore index 7dcfe9d81..bb47cac3d 100644 --- a/.gitignore +++ b/.gitignore @@ -46,4 +46,4 @@ output.txt reference-context # test bot -test_dify.py +test_dify.py \ No newline at end of file diff --git a/bot/bytedance/bytedance_coze_bot.py b/bot/bytedance/bytedance_coze_bot.py index bb791b2cb..49893c7c7 100644 --- a/bot/bytedance/bytedance_coze_bot.py +++ b/bot/bytedance/bytedance_coze_bot.py @@ -1,135 +1,156 @@ # encoding:utf-8 - -import time -from typing import List, Tuple - +import io +import os +from os.path import isfile import requests -from requests import Response - +from urllib.parse import urlparse, unquote from bot.bot import Bot -from bot.chatgpt.chat_gpt_session import ChatGPTSession -from bot.session_manager import SessionManager -from bridge.context import ContextType +from bot.bytedance.coze_client import CozeClient +from bot.bytedance.coze_session import CozeSession, CozeSessionManager +from bridge.context import ContextType, Context from bridge.reply import Reply, ReplyType from common.log import logger from config import conf - +from common import memory +from common.utils import parse_markdown_text +from common.tmp_dir import TmpDir +from cozepy import MessageType,Message class ByteDanceCozeBot(Bot): def __init__(self): super().__init__() - self.sessions = SessionManager(ChatGPTSession, model=conf().get("model") or "coze") + self.sessions = CozeSessionManager(CozeSession) + self.coze_api_base = conf().get("coze_api_base", "https://api.coze.cn/") + self.coze_api_key = conf().get('coze_api_key', '') + if conf().get('coze_return_show_img', False): + self.show_img_file = True + else: + self.show_img_file = False + coze_bot_id = conf().get('coze_bot_id', '') + coze_bot_id = str(coze_bot_id) + if not coze_bot_id: + logger.error("[COZE] coze_bot_id is not set") + raise Exception("coze_bot_id is not set") + self.coze_bot_id = coze_bot_id - def reply(self, query, context=None): + def reply(self, query, context: Context = None): # acquire reply content - if context.type == ContextType.TEXT: + if context.type == ContextType.TEXT or context.type == ContextType.IMAGE_CREATE: + if context.type == ContextType.IMAGE_CREATE: + query = conf().get('image_create_prefix', ['画'])[0] + query logger.info("[COZE] query={}".format(query)) - + channel_type = conf().get("channel_type", "wx") + user_id = None + if channel_type in ["wx", "wework", "gewechat"]: + user_id = context["msg"].other_user_nickname + if user_id is None or user_id == '': + user_id = context["msg"].actual_user_nickname + elif channel_type in ["wechatcom_app", "wechatmp", "wechatmp_service", "wechatcom_service"]: + user_id = context["msg"].other_user_id + if user_id is None or user_id == '': + user_id = "default" + else: + return Reply(ReplyType.ERROR, f"unsupported channel type: {channel_type}, now coze only support wx, wechatcom_app, wechatmp, wechatmp_service channel") + logger.debug(f"[COZE] user_id={user_id}") session_id = context["session_id"] - session = self.sessions.session_query(query, session_id) - logger.debug("[COZE] session query={}".format(session.messages)) - reply_content, err = self._reply_text(session_id, session) - if err is not None: - logger.error("[COZE] reply error={}".format(err)) - return Reply(ReplyType.ERROR, "我暂时遇到了一些问题,请您稍后重试~") - logger.debug( - "[COZE] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format( - session.messages, - session_id, - reply_content["content"], - reply_content["completion_tokens"], - ) - ) - return Reply(ReplyType.TEXT, reply_content["content"]) + session = self.sessions.session_query(query, user_id, session_id) + logger.debug(f"[COZE] session={session} query={query}") + reply, err = self._reply(query, session, context) + if err != None: + error_msg = conf().get("error_reply", "我暂时遇到了一些问题,请您稍后重试~") + reply = Reply(ReplyType.TEXT, error_msg) + return reply else: reply = Reply(ReplyType.ERROR, "Bot不支持处理{}类型的消息".format(context.type)) return reply - def _get_api_base_url(self): - return conf().get("coze_api_base", "https://api.coze.cn/open_api/v2") + def _reply(self, query, session: CozeSession, context: Context): + chat_client = CozeClient(self.coze_api_key, self.coze_api_base) + additional_messages = self._get_upload_files(session) + messages = chat_client.create_chat_message( + bot_id=self.coze_bot_id, + query=query, + additional_messages=additional_messages, + session=session + ) + if self.show_img_file: + return self.get_parsed_reply(messages, context) + else: + return self.get_text_reply(messages, session) - def _get_headers(self): - return { - 'Authorization': f"Bearer {conf().get('coze_api_key', '')}" - } + def get_parsed_reply(self, messages: list[Message], context: Context = None): + parsed_content = None + for message in messages: + if message.type == MessageType.ANSWER: + conte = parse_markdown_text(message.content) + if parsed_content is None: + parsed_content = conte + else: + parsed_content.append(conte) - def _get_payload(self, user: str, query: str, chat_history: List[dict]): - coze_bot_id = conf().get('coze_bot_id', '') - coze_bot_id = str(coze_bot_id) - if not coze_bot_id: - logger.error("[COZE] coze_bot_id is not set") - raise Exception("coze_bot_id is not set") - return { - 'bot_id': coze_bot_id, - "user": user, - "query": query, - "chat_history": chat_history, - "stream": False - } + # {"answer": "![image](/files/tools/dbf9cd7c-2110-4383-9ba8-50d9fd1a4815.png?timestamp=1713970391&nonce=0d5badf2e39466042113a4ba9fd9bf83&sign=OVmdCxCEuEYwc9add3YNFFdUpn4VdFKgl84Cg54iLnU=)"} + at_prefix = "" + channel = context.get("channel") + is_group = context.get("isgroup", False) + if is_group: + at_prefix = "@" + context["msg"].actual_user_nickname + "\n" + for item in parsed_content[:-1]: + reply = None + if item['type'] == 'text': + content = at_prefix + item['content'] + reply = Reply(ReplyType.TEXT, content) + elif item['type'] == 'image': + image_url = self._fill_file_base_url(item['content']) + image = self._download_image(image_url) + if image: + reply = Reply(ReplyType.IMAGE, image) + else: + reply = Reply(ReplyType.TEXT, f"图片链接:{image_url}") + elif item['type'] == 'file': + file_url = self._fill_file_base_url(item['content']) + if isfile(file_url): + file_path = self._download_file(file_url) + if file_path: + reply = Reply(ReplyType.FILE, file_path) + else: + reply = Reply(ReplyType.TEXT, f"链接:{file_url}") + logger.debug(f"[COZE] reply={reply}") + if reply and channel: + channel.send(reply, context) - def _reply_text(self, session_id: str, session: ChatGPTSession, retry_count=0): - try: - query, chat_history = self._convert_messages_format(session.messages) - base_url = self._get_api_base_url() - chat_url = f'{base_url}/chat' - headers = self._get_headers() - payload = self._get_payload(session.session_id, query, chat_history) - logger.debug("[COZE] headers={}, payload={}".format(headers, payload)) - response = requests.post(chat_url, headers=headers, json=payload) - if response.status_code != 200: - error_info = f"[COZE] response text={response.text} status_code={response.status_code}" - logger.warn(error_info) - return None, error_info - answer, err = self._get_completion_content(response) - if err is not None: - return None, err - completion_tokens, total_tokens = self._calc_tokens(session.messages, answer) - return { - "total_tokens": total_tokens, - "completion_tokens": completion_tokens, - "content": answer - }, None - except Exception as e: - if retry_count < 2: - time.sleep(3) - logger.warn(f"[COZE] Exception: {repr(e)} 第{retry_count + 1}次重试") - return self._reply_text(session_id, session, retry_count + 1) + final_item = parsed_content[-1] + final_reply = None + if final_item['type'] == 'text': + content = final_item['content'] + if is_group: + at_prefix = "@" + context["msg"].actual_user_nickname + "\n" + content = at_prefix + content + final_reply = Reply(ReplyType.TEXT, final_item['content']) + elif final_item['type'] == 'image': + image_url = self._fill_file_base_url(final_item['content']) + image = self._download_image(image_url) + if image: + final_reply = Reply(ReplyType.IMAGE, image) + else: + final_reply = Reply(ReplyType.TEXT, f"图片链接:{image_url}") + elif final_item['type'] == 'file': + file_url = self._fill_file_base_url(final_item['content']) + if isfile(file_url): + file_path = self._download_file(file_url) + if file_path: + final_reply = Reply(ReplyType.FILE, file_path) else: - return None, f"[COZE] Exception: {repr(e)} 超过最大重试次数" + final_reply = Reply(ReplyType.TEXT, f"链接:{file_url}") + return final_reply, None - def _convert_messages_format(self, messages) -> Tuple[str, List[dict]]: - # [ - # {"role":"user","content":"你好","content_type":"text"}, - # {"role":"assistant","type":"answer","content":"你好,请问有什么可以帮助你的吗?","content_type":"text"} - # ] - chat_history = [] - for message in messages: - role = message.get('role') - if role == 'user': - content = message.get('content') - chat_history.append({"role": "user", "content": content, "content_type": "text"}) - elif role == 'assistant': - content = message.get('content') - chat_history.append({"role": "assistant", "type": "answer", "content": content, "content_type": "text"}) - elif role == 'system': - # TODO: deal system message - pass - user_message = chat_history.pop() - if user_message.get('role') != 'user' or user_message.get('content', '') == '': - raise Exception('no user message') - query = user_message.get('content') - logger.debug("[COZE] converted coze messages: {}".format([item for item in chat_history])) - logger.debug("[COZE] user content as query: {}".format(query)) - return query, chat_history + # def _get_api_base_url(self): + # return conf().get("coze_api_base", "https://api.coze.cn/open_api/v2") - def _get_completion_content(self, response: Response): - json_response = response.json() - if json_response['msg'] != 'success': - return None, f"[COZE] Error: {json_response['msg']}" + def _get_completion_content(self, messages: list): answer = None - for message in json_response['messages']: - if message.get('type') == 'answer': - answer = message.get('content') + for message in messages: + if message.type == MessageType.ANSWER: + answer = message.content break if not answer: return None, "[COZE] Error: empty answer" @@ -142,3 +163,78 @@ def _calc_tokens(self, messages, answer): for message in messages: prompt_tokens += len(message["content"]) return completion_tokens, prompt_tokens + completion_tokens + + def _get_upload_files(self, session: CozeSession): + session_id = session.get_session_id() + img_cache = memory.USER_IMAGE_CACHE.get(session_id) + if not img_cache or not conf().get("image_recognition"): + return None + coze_client = CozeClient(self.coze_api_key, self.coze_api_base) + msg = img_cache.get("msg") + path = img_cache.get("path") + msg.prepare() + file = coze_client.file_upload(path) + # 清理图片缓存 + memory.USER_IMAGE_CACHE[session_id] = None + + additional_messages = [] + additional_messages.append(coze_client.create_message(file)) + return additional_messages + + def _fill_file_base_url(self, url: str): + if url.startswith("https://") or url.startswith("http://"): + return url + return self.coze_api_base + url + + def _download_image(self, url): + try: + pic_res = requests.get(url, stream=True) + pic_res.raise_for_status() + image_storage = io.BytesIO() + size = 0 + for block in pic_res.iter_content(1024): + size += len(block) + image_storage.write(block) + logger.debug(f"[WX] download image success, size={size}, img_url={url}") + image_storage.seek(0) + return image_storage + except Exception as e: + logger.error(f"Error downloading {url}: {e}") + return None + + def _download_file(self, url): + try: + response = requests.get(url) + response.raise_for_status() + parsed_url = urlparse(url) + logger.debug(f"Downloading file from {url}") + url_path = unquote(parsed_url.path) + # 从路径中提取文件名 + file_name = url_path.split('/')[-1] + logger.debug(f"Saving file as {file_name}") + file_path = os.path.join(TmpDir().path(), file_name) + with open(file_path, 'wb') as file: + file.write(response.content) + return file_path + except Exception as e: + logger.error(f"Error downloading {url}: {e}") + return None + + def get_text_reply(self, messages, session: CozeSession): + answer, err = self._get_completion_content(messages) + if err is not None: + return None, err + completion_tokens, total_tokens = self._calc_tokens(session.messages, answer) + Reply(ReplyType.TEXT, answer) + if err is not None: + logger.error("[COZE] reply error={}".format(err)) + return Reply(ReplyType.ERROR, "我暂时遇到了一些问题,请您稍后重试~") + logger.debug( + "[COZE] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format( + session.messages, + session.get_session_id(), + answer, + completion_tokens, + ) + ) + return Reply(ReplyType.TEXT, answer), None diff --git a/bot/bytedance/coze_client.py b/bot/bytedance/coze_client.py new file mode 100644 index 000000000..9eb2fac29 --- /dev/null +++ b/bot/bytedance/coze_client.py @@ -0,0 +1,56 @@ +import logging +import os +import time +from typing import List +from bot.bytedance.coze_session import CozeSession +from pathlib import Path +from cozepy import Coze, TokenAuth, Message, File, MessageContentType, MessageRole, MessageObjectString, \ + MessageObjectStringType + + +class CozeClient(object): + def __init__(self, coze_api_key, base_url: str): + self.coze_api_key = coze_api_key + self.base_url = base_url + self.coze = Coze(base_url=base_url, + auth=TokenAuth(token=coze_api_key)) + + def file_upload(self, path: str) -> File: + return self.coze.files.upload(file=Path(path)) + + def _send_chat(self, bot_id: str, + user_id: str, additional_messages: List[Message], session: CozeSession): + conversation_id = None + if session.get_conversation_id() is not None: + conversation_id = session.get_conversation_id() + chat_poll = self.coze.chat.create_and_poll( + bot_id=bot_id, + user_id=user_id, + conversation_id=conversation_id, + additional_messages=additional_messages + ) + message_list = chat_poll.messages + for message in message_list: + logging.debug('got message:', message.content) + return message_list + + def create_chat_message(self, bot_id: str, query: str, additional_messages: List[Message], session: CozeSession): + if additional_messages is None: + additional_messages = [Message.build_user_question_text(query)] + else: + additional_messages.append(Message.build_user_question_text(query)) + return self._send_chat(bot_id, session.get_user_id(), additional_messages, session) + + def create_message(self, file: File) -> Message: + + message_object_string = None + if self.is_image(file.file_name): + message_object_string = MessageObjectString.build_image(file.id) + else: + message_object_string = MessageObjectString.build_file(file.id) + return Message.build_user_question_objects([message_object_string]) + + def is_image(self, filepath: str): + valid_extensions = ['.jpg', '.jpeg', '.png', '.gif', '.bmp'] + extension = os.path.splitext(filepath)[1].lower() + return extension in valid_extensions diff --git a/bot/bytedance/coze_session.py b/bot/bytedance/coze_session.py new file mode 100644 index 000000000..bb7dc8ef9 --- /dev/null +++ b/bot/bytedance/coze_session.py @@ -0,0 +1,119 @@ +from common.expired_dict import ExpiredDict +from config import conf +from common.log import logger + + +class CozeSession(object): + def __init__(self, session_id: str, user_id: str, conversation_id=None, system_prompt=None): + self.__session_id = session_id + self.messages = [] + self.__user_id = user_id + self.__user_message_counter = 0 + self.__conversation_id = conversation_id + if system_prompt is None: + self.system_prompt = conf().get("character_desc", "") + else: + self.system_prompt = system_prompt + + def add_query(self, query): + user_item = {"role": "user", "content": query} + self.messages.append(user_item) + + def add_reply(self, reply): + assistant_item = {"role": "assistant", "content": reply} + self.messages.append(assistant_item) + + def get_session_id(self): + return self.__session_id + + def get_user_id(self): + return self.__user_id + + def get_conversation_id(self): + return self.__conversation_id + + def set_conversation_id(self, conversation_id): + self.__conversation_id = conversation_id + + def get_session(self, session_id, user): + session = self._build_session(session_id, user) + return session + + def _build_session(self, session_id: str, user: str): + """ + 如果session_id不在sessions中,创建一个新的session并添加到sessions中 + """ + if session_id is None: + return self.sessioncls(session_id, user) + + if session_id not in self.sessions: + self.sessions[session_id] = self.sessioncls(session_id, user) + session = self.sessions[session_id] + return session + + def count_user_message(self): + if conf().get("coze_conversation_max_messages", 5) <= 0: + # 当设置的最大消息数小于等于0,则不限制 + return + if self.__user_message_counter >= conf().get("coze_conversation_max_messages", 5): + self.__user_message_counter = 0 + # FIXME: coze目前不支持设置历史消息长度,暂时使用超过5条清空会话的策略,缺点是没有滑动窗口,会突然丢失历史消息 + self.__conversation_id = '' + + self.__user_message_counter += 1 + + +class CozeSessionManager(object): + def __init__(self, sessioncls, **session_args): + if conf().get("expires_in_seconds"): + sessions = ExpiredDict(conf().get("expires_in_seconds")) + else: + sessions = dict() + self.sessions = sessions + self.sessioncls = sessioncls + self.session_args = session_args + + def _build_session(self, session_id: str, user_id: str, system_prompt=None): + """ + 如果session_id不在sessions中,创建一个新的session并添加到sessions中 + """ + if session_id is None: + return self.sessioncls(session_id, user_id, system_prompt, **self.session_args) + + if session_id not in self.sessions: + self.sessions[session_id] = self.sessioncls(session_id, user_id, system_prompt, **self.session_args) + session = self.sessions[session_id] + return session + + def session_query(self, query, user_id, session_id): + session = self._build_session(session_id, user_id) + session.add_query(query) + # try: + # max_tokens = conf().get("conversation_max_tokens", 1000) + # total_tokens = session.discard_exceeding(max_tokens, None) + # logger.debug("prompt tokens used={}".format(total_tokens)) + # except Exception as e: + # logger.warning("Exception when counting tokens precisely for prompt: {}".format(str(e))) + return session + + def session_reply(self, reply, user_id, session_id, total_tokens=None): + session = self._build_session(session_id, user_id) + session.add_reply(reply) + try: + max_tokens = conf().get("conversation_max_tokens", 1000) + tokens_cnt = session.discard_exceeding(max_tokens, total_tokens) + logger.debug("raw total_tokens={}, savesession tokens={}".format(total_tokens, tokens_cnt)) + except Exception as e: + logger.warning("Exception when counting tokens precisely for session: {}".format(str(e))) + return session + + # def get_session(self, session_id, user_id): + # session = self._build_session(session_id, user_id) + # return session + + def clear_session(self, session_id): + if session_id in self.sessions: + del self.sessions[session_id] + + def clear_all_session(self): + self.sessions.clear() diff --git a/config.py b/config.py index 697b5b785..a920695ca 100644 --- a/config.py +++ b/config.py @@ -107,9 +107,10 @@ "dify_app_type": "chatbot", # dify助手类型 chatbot(对应聊天助手)/agent(对应Agent)/workflow(对应工作流),默认为chatbot "dify_conversation_max_messages": 5, # dify目前不支持设置历史消息长度,暂时使用超过最大消息数清空会话的策略,缺点是没有滑动窗口,会突然丢失历史消息,当设置的值小于等于0,则不限制历史消息长度 # coze配置 - "coze_api_base": "https://api.coze.cn/open_api/v2", + "coze_api_base": "https://api.coze.cn", "coze_api_key": "xxx", "coze_bot_id": "xxx", + "coze_return_show_img": "false", # wework的通用配置 "wework_smart": True, # 配置wework是否使用已登录的企业微信,False为多开 # 语音设置 diff --git a/requirements.txt b/requirements.txt index 45d09cc1c..94b97d018 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,3 +8,4 @@ Pillow pre-commit web.py linkai>=0.0.6.0 +cozepy==0.6