From 947852234a19c85a365e54be42ebe60590593460 Mon Sep 17 00:00:00 2001 From: Cassius0924 <2670226747@qq.com> Date: Sun, 18 Feb 2024 19:05:50 +0800 Subject: [PATCH] =?UTF-8?q?refactor:=20=E4=BD=BF=E7=94=A8sqlalchemy?= =?UTF-8?q?=E9=87=8D=E6=9E=84=E6=95=B0=E6=8D=AE=E5=BA=93=E6=A8=A1=E5=9D=97?= =?UTF-8?q?=EF=BC=8C=E4=BC=98=E5=8C=96=E8=A1=A8=E7=B1=BB=E5=92=8C=E6=95=B0?= =?UTF-8?q?=E6=8D=AE=E7=B1=BB=E4=B9=8B=E9=97=B4=E7=9A=84=E5=85=B3=E7=B3=BB?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- wechatter/app/routers/wechat.py | 25 +- wechatter/commands/_commands/copilot_gpt4.py | 460 +++++++++--------- wechatter/database/tables/gpt_chat_info.py | 61 ++- wechatter/database/tables/gpt_chat_message.py | 34 +- wechatter/database/tables/group.py | 20 +- wechatter/database/tables/message.py | 33 +- wechatter/database/tables/person.py | 37 +- wechatter/message/message_parser.py | 2 +- .../message_forwarder/message_forwarder.py | 15 +- wechatter/models/__init__.py | 6 + wechatter/models/github/pr_webhook.py | 4 +- wechatter/models/gpt/__init__.py | 4 + wechatter/models/gpt/gpt_chat_info.py | 36 +- wechatter/models/gpt/gpt_chat_message.py | 24 +- wechatter/models/wechat/__init__.py | 7 +- wechatter/models/wechat/group.py | 5 +- wechatter/models/wechat/message.py | 105 ++-- wechatter/models/wechat/send_to.py | 16 +- 18 files changed, 480 insertions(+), 414 deletions(-) diff --git a/wechatter/app/routers/wechat.py b/wechatter/app/routers/wechat.py index 3b9e44f..d526eee 100644 --- a/wechatter/app/routers/wechat.py +++ b/wechatter/app/routers/wechat.py @@ -51,24 +51,22 @@ async def recv_wechat_msg( # 解析命令 # 构造消息对象 - message = Message( + message = Message.from_api_msg( type=type, - content_=content, - source_=source, - is_mentioned_=is_mentioned, + content=content, + source=source, + is_mentioned=is_mentioned, ) # 向群组表中添加该群组 - add_group(message.source.g_info) + add_group(message.group) # 向用户表中添加该用户 - add_person(message.source.p_info) + add_person(message.person) # 向消息表中添加该消息 - add_message(message) + message.id = add_message(message) # TODO: 添加自己发送的消息,等待 wechatbot-webhook 支持 # DEBUG - print("==" * 20) print(str(message)) - print("==" * 20) if config.message_forwarding_enabled: MessageForwarder(config.message_forwarding_rule_list).forward_message(message) @@ -109,7 +107,7 @@ def add_group(group: Group) -> None: with make_db_session() as session: _group = session.query(DbGroup).filter(DbGroup.id == group.id).first() if _group is None: - _group = DbGroup.from_group_model(group) + _group = DbGroup.from_model(group) session.add(_group) # 逐个添加群组成员,若存在则更新 for member in group.member_list: @@ -142,7 +140,7 @@ def add_person(person: Person) -> None: with make_db_session() as session: _person = session.query(DbPerson).filter(DbPerson.id == person.id).first() if _person is None: - _person = DbPerson.from_person_model(person) + _person = DbPerson.from_model(person) session.add(_person) session.commit() logger.info(f"用户 {person.name} 已添加到数据库") @@ -152,12 +150,13 @@ def add_person(person: Person) -> None: session.commit() -def add_message(message: Message) -> None: +def add_message(message: Message) -> int: """ 添加消息到消息表 """ with make_db_session() as session: - _message = DbMessage.from_message_model(message) + _message = DbMessage.from_model(message) session.add(_message) session.commit() logger.info(f"消息 {_message.id} 已添加到数据库") + return _message.id diff --git a/wechatter/commands/_commands/copilot_gpt4.py b/wechatter/commands/_commands/copilot_gpt4.py index 43f0f54..2f722e4 
100644 --- a/wechatter/commands/_commands/copilot_gpt4.py +++ b/wechatter/commands/_commands/copilot_gpt4.py @@ -7,19 +7,19 @@ import wechatter.utils.path_manager as pm from wechatter.commands.handlers import command from wechatter.database import ( - GptChatInfo, - GptChatMessage, - Message as DbMessage, + GptChatInfo as DbGptChatInfo, + GptChatMessage as DbGptChatMessage, make_db_session, ) -from wechatter.models.wechat import SendTo +from wechatter.models.gpt import GptChatInfo +from wechatter.models.wechat import Person, SendTo from wechatter.sender import sender from wechatter.utils import post_request_json DEFAULT_TOPIC = "(对话进行中*)" # DEFAULT_MODEL = "gpt-4" # TODO: 初始化对话,Prompt选择 -DEFAULT_CONVERSATIONS = [{"role": "system", "content": "你是一位乐于助人的助手"}] +DEFAULT_CONVERSATION = [{"role": "system", "content": "你是一位乐于助人的助手"}] @command( @@ -27,8 +27,8 @@ keys=["gpt"], desc="使用GPT3.5进行对话。", ) -def gpt35_command_handler(to: SendTo, message: str = "") -> None: - _gptx("gpt-3.5-turbo", to, message) +def gpt35_command_handler(to: SendTo, message: str = "", message_obj=None) -> None: + _gptx("gpt-3.5-turbo", to, message, message_obj) @command( @@ -36,8 +36,10 @@ def gpt35_command_handler(to: SendTo, message: str = "") -> None: keys=["gpt-chats", "gpt对话记录"], desc="列出GPT3.5对话记录。", ) -def gpt35_chats_command_handler(to: SendTo, message: str = "") -> None: - _gptx_chats("gpt-3.5-turbo", to, message) +def gpt35_chats_command_handler( + to: SendTo, message: str = "", message_obj=None +) -> None: + _gptx_chats("gpt-3.5-turbo", to, message, message_obj) @command( @@ -45,7 +47,9 @@ def gpt35_chats_command_handler(to: SendTo, message: str = "") -> None: keys=["gpt-record", "gpt记录"], desc="获取GPT3.5对话记录。", ) -def gpt35_record_command_handler(to: SendTo, message: str = "") -> None: +def gpt35_record_command_handler( + to: SendTo, message: str = "", message_obj=None +) -> None: _gptx_record("gpt-3.5-turbo", to, message) @@ -54,7 +58,9 @@ def gpt35_record_command_handler(to: SendTo, message: str = "") -> None: keys=["gpt-continue", "gpt继续"], desc="继续GPT3.5对话。", ) -def gpt35_continue_command_handler(to: SendTo, message: str = "") -> None: +def gpt35_continue_command_handler( + to: SendTo, message: str = "", message_obj=None +) -> None: _gptx_continue("gpt-3.5-turbo", to, message) @@ -63,8 +69,8 @@ def gpt35_continue_command_handler(to: SendTo, message: str = "") -> None: keys=["gpt4"], desc="使用GPT4进行对话。", ) -def gpt4_command_handler(to: SendTo, message: str = "") -> None: - _gptx("gpt-4", to, message) +def gpt4_command_handler(to: SendTo, message: str = "", message_obj=None) -> None: + _gptx("gpt-4", to, message, message_obj) @command( @@ -72,8 +78,8 @@ def gpt4_command_handler(to: SendTo, message: str = "") -> None: keys=["gpt4-chats", "gpt4对话记录"], desc="列出GPT4对话记录。", ) -def gpt4_chats_command_handler(to: SendTo, message: str = "") -> None: - _gptx_chats("gpt-4", to, message) +def gpt4_chats_command_handler(to: SendTo, message: str = "", message_obj=None) -> None: + _gptx_chats("gpt-4", to, message, message_obj) @command( @@ -81,7 +87,9 @@ def gpt4_chats_command_handler(to: SendTo, message: str = "") -> None: keys=["gpt4-record", "gpt4记录"], desc="获取GPT4对话记录。", ) -def gpt4_record_command_handler(to: SendTo, message: str = "") -> None: +def gpt4_record_command_handler( + to: SendTo, message: str = "", message_obj=None +) -> None: _gptx_record("gpt-4", to, message) @@ -90,25 +98,29 @@ def gpt4_record_command_handler(to: SendTo, message: str = "") -> None: keys=["gpt4-continue", "gpt4继续"], desc="继续GPT4对话。", ) -def 
gpt4_continue_command_handler(to: SendTo, message: str = "") -> None: +def gpt4_continue_command_handler( + to: SendTo, message: str = "", message_obj=None +) -> None: _gptx_continue("gpt-4", to, message) # TODO: # 命令:/gpt4-remove -def gpt4_remove_command_handler(to: SendTo, message: str = "") -> None: +def gpt4_remove_command_handler( + to: SendTo, message: str = "", message_obj=None +) -> None: pass -def _gptx(model: str, to: SendTo, message: str = "") -> None: - wx_id = to.p_id +def _gptx(model: str, to: SendTo, message: str = "", message_obj=None) -> None: + person = to.person # 获取文件夹下最新的对话记录 - id = CopilotGPT4.get_chatting_chat_info(wx_id, model) + chat_info = CopilotGPT4.get_chatting_chat_info(person, model) if message == "": # /gpt4 # 判断对话是否有效 sender.send_msg(to, "正在创建新对话...") - if id is None or CopilotGPT4.is_chat_valid(id): - CopilotGPT4.create_chat(wx_id=wx_id, model=model) + if chat_info is None or CopilotGPT4._is_chat_valid(chat_info): + CopilotGPT4.create_chat(person, model) logger.info("创建新对话成功") sender.send_msg(to, "创建新对话成功") return @@ -117,12 +129,14 @@ def _gptx(model: str, to: SendTo, message: str = "") -> None: else: # /gpt4 # 如果没有对话记录,则创建新对话 sender.send_msg(to, f"正在调用 {model} 进行对话...") - if id is None: - id = CopilotGPT4.create_chat(wx_id=wx_id, model=model) + if chat_info is None: + chat_info = CopilotGPT4.create_chat(person, model) logger.info("无历史对话记录,创建新对话成功") sender.send_msg(to, "无历史对话记录,创建新对话成功") try: - response = CopilotGPT4.chat(id, message) + response = CopilotGPT4.chat( + chat_info, message=message, message_obj=message_obj + ) logger.info(response) sender.send_msg(to, response) except Exception as e: @@ -131,31 +145,30 @@ def _gptx(model: str, to: SendTo, message: str = "") -> None: sender.send_msg(to, error_message) -def _gptx_chats(model: str, to: SendTo, message: str = "") -> None: - response = CopilotGPT4.get_chat_list_str(to.p_id, model) +def _gptx_chats(model: str, to: SendTo, message: str = "", message_obj=None) -> None: + response = CopilotGPT4.get_chat_list_str(to.person, model) sender.send_msg(to, response) -def _gptx_record(model: str, to: SendTo, message: str = "") -> None: - wx_id = to.p_id - id = None +def _gptx_record(model: str, to: SendTo, message: str = ""): + person = to.person if message == "": # 获取当前对话的对话记录 - id = CopilotGPT4.get_chatting_chat_info(wx_id, model) + chat_info = CopilotGPT4.get_chatting_chat_info(person, model) else: # 获取指定对话的对话记录 - id = CopilotGPT4.get_chat_info(wx_id, model, int(message)) - if id is None: + chat_info = CopilotGPT4.get_chat_info(person, model, int(message)) + if chat_info is None: logger.warning("对话不存在") sender.send_msg(to, "对话不存在") return - response = CopilotGPT4.get_brief_conversation_str(id) + response = CopilotGPT4.get_brief_conversation_str(chat_info) logger.info(response) sender.send_msg(to, response) def _gptx_continue(model: str, to: SendTo, message: str = "") -> None: - wx_id = to.p_id + person = to.person # 判断message是否为数字 if not message.isdigit(): logger.info("请输入对话记录编号") @@ -163,7 +176,7 @@ def _gptx_continue(model: str, to: SendTo, message: str = "") -> None: return sender.send_msg(to, f"正在切换到对话记录 {message}...") chat_info = CopilotGPT4.continue_chat( - wx_id=wx_id, model=model, chat_index=int(message) + person=person, model=model, chat_index=int(message) ) if chat_info is None: warning_message = "选择历史对话失败,对话不存在" @@ -183,163 +196,163 @@ class CopilotGPT4: save_path = pm.get_abs_path("data/copilot_gpt4/chats/") @staticmethod - def create_chat(wx_id: str, model: str) -> int: + def 
create_chat(person: Person, model: str) -> GptChatInfo: """ 创建一个新的对话 + :param person: 用户 + :param model: 模型 + :return: 新的对话信息 """ # 生成上一次对话的主题 - CopilotGPT4._save_chatting_chat_topic(wx_id, model) - CopilotGPT4._set_all_chats_unchatting(wx_id, model) + CopilotGPT4._save_chatting_chat_topic(person, model) + CopilotGPT4._set_all_chats_not_chatting(person, model) gpt_chat_info = GptChatInfo( - person_id=wx_id, + person=person, model=model, - talk_time=datetime.now(), topic=DEFAULT_TOPIC, is_chatting=True, - gpt_chat_messages=[], ) with make_db_session() as session: - session.add(gpt_chat_info) + _gpt_chat_info = DbGptChatInfo.from_model(gpt_chat_info) + session.add(_gpt_chat_info) session.commit() # 获取 SQLite 自动生成的 chat_id - session.refresh(gpt_chat_info) - return gpt_chat_info.id + session.refresh(_gpt_chat_info) + gpt_chat_info = _gpt_chat_info.to_model() + return gpt_chat_info @staticmethod - def continue_chat(wx_id: str, model: str, chat_index: int) -> Union[int, None]: + def continue_chat( + person: Person, model: str, chat_index: int + ) -> Union[GptChatInfo, None]: """ - 继续对话,从对话记录文件中读取对话记录 - :param wx_id: 微信用户ID + 继续对话,选择历史对话 + :param person: 用户 + :param model: 模型 :param chat_index: 对话记录索引(从1开始) - :return: 简略的对话记录 + :return: 对话信息 """ # 读取对话记录文件 - id = CopilotGPT4.get_chat_info(wx_id, model, chat_index) - if id is None: + chat_info = CopilotGPT4.get_chat_info(person, model, chat_index) + if chat_info is None: return None - chatting_chat_info = CopilotGPT4.get_chatting_chat_info(wx_id, model) - if not CopilotGPT4.is_chat_valid(chatting_chat_info): - # 如果对话无效,则删除该对话记录后再继续对话 - CopilotGPT4._delete_chat(wx_id, chatting_chat_info.id) - else: - # 生成上一次对话的主题 - CopilotGPT4._save_chatting_chat_topic(wx_id, model) - CopilotGPT4._set_chatting_chat(wx_id, model, id) - return id + chatting_chat_info = CopilotGPT4.get_chatting_chat_info(person, model) + if chatting_chat_info: + if not CopilotGPT4._is_chat_valid(chatting_chat_info): + # 如果对话无效,则删除该对话记录后再继续对话 + CopilotGPT4._delete_chat(chatting_chat_info) + else: + # 生成上一次对话的主题 + CopilotGPT4._save_chatting_chat_topic(person, model) + CopilotGPT4._set_chatting_chat(person, model, chat_info) + return chat_info @staticmethod - def _set_chatting_chat(wx_id: str, model: str, chat_id: int) -> None: + def _set_chatting_chat(person: Person, model: str, chat_info: GptChatInfo) -> None: """ 设置正在进行中的对话记录 """ - # 先将所有对话记录的 is_chating 字段设置为 False - CopilotGPT4._set_all_chats_unchatting(wx_id, model) + # 先将所有对话记录的 is_chatting 字段设置为 False + CopilotGPT4._set_all_chats_not_chatting(person, model) with make_db_session() as session: - chat_info = session.query(GptChatInfo).filter_by(id=chat_id).first() + chat_info = session.query(DbGptChatInfo).filter_by(id=chat_info.id).first() + if chat_info is None: + logger.error("对话记录不存在") + raise ValueError("对话记录不存在") chat_info.is_chatting = True session.commit() @staticmethod - def _delete_chat(wx_id: str, chat_id: int) -> None: + def _delete_chat(chat_info: GptChatInfo) -> None: """ 删除对话记录 """ with make_db_session() as session: - session.query(GptChatMessage).filter_by(gpt_chat_id=chat_id).delete() - session.query(GptChatInfo).filter_by(id=chat_id).delete() + session.query(DbGptChatMessage).filter_by(gpt_chat_id=chat_info.id).delete() + session.query(DbGptChatInfo).filter_by(id=chat_info.id).delete() session.commit() @staticmethod - def get_brief_conversation_str(chat_info_id: int) -> str: + def get_brief_conversation_str(chat_info: GptChatInfo) -> str: """ 获取对话记录的字符串 + :param chat_info: 对话记录 + :return: 对话记录字符串 """ with 
make_db_session() as session: - chat_info = session.query(GptChatInfo).filter_by(id=chat_info_id).first() - conversation_str = f"✨==={chat_info.topic}===✨\n" + chat_info = session.query(DbGptChatInfo).filter_by(id=chat_info.id).first() if chat_info is None: - conversation_str += "无对话记录" + logger.error("对话记录不存在") + raise ValueError("对话记录不存在") + conversation_str = f"✨==={chat_info.topic}===✨\n" + if not chat_info.gpt_chat_messages: + conversation_str += " 无对话记录" return conversation_str for msg in chat_info.gpt_chat_messages: - content = msg.message.content[:30] + content: str = msg.message.content + # 合并成一行,提升观感 + content = content.replace("\n", "") + # 去掉命令前缀和命令关键词 + content = content[content.find(" ") + 1 :][:30] + response = msg.gpt_response[:30] + response = response.replace("\n", "") if len(msg.message.content) > 30: content += "..." - if msg.role.value == "system": - conversation_str += f"⭐️:{content}\n" - elif msg.role.value == "assistant": - conversation_str += f"🤖:{content}\n" - elif msg.role.value == "user": - conversation_str += f"💬:{content}\n" + if len(msg.gpt_response) > 30: + response += "..." + conversation_str += f"💬:{content}\n" + conversation_str += f"🤖:{response}\n" return conversation_str - # TODO: 删掉 - @staticmethod - def _get_brief_conversation_content(conversation: List) -> List: - """ - 获取简略的对话记录的内容 - """ - content_list = [] - for conv in conversation[1:]: - if len(conv["content"]) > 20: - conv["content"] = conv["content"][:20] + "..." - content_list.append(conv["content"]) - return content_list - @staticmethod - def _set_all_chats_unchatting(wx_id: str, model: str) -> None: + def _set_all_chats_not_chatting(person: Person, model: str) -> None: """ 将所有对话记录的 is_chatting 字段设置为 False """ with make_db_session() as session: - session.query(GptChatInfo).filter_by(person_id=wx_id, model=model).update( - {"is_chatting": False} - ) + session.query(DbGptChatInfo).filter_by( + person_id=person.id, model=model + ).update({"is_chatting": False}) session.commit() @staticmethod - def is_chat_valid(chat_info_id: int) -> bool: - """ - 判断对话是否有效 - """ - # 通过 conversation 长度判断对话是否有效 - with make_db_session() as session: - chat_info = session.query(GptChatInfo).filter_by(id=chat_info_id).first() - if len(chat_info.gpt_chat_messages) <= 1: - return False - return True - - @staticmethod - def _list_chat_info(wx_id: str, model: str) -> List: + def _list_chat_info(person: Person, model: str) -> List: """ 列出用户的所有对话记录 """ - # 取出id,按照 chat_talk_time 字段倒序排序,取前20个 + # 按照 chat_talk_time 字段倒序排序,取前20个 with make_db_session() as session: chat_info_list = ( - session.query(GptChatInfo.id) - .filter_by(person_id=wx_id, model=model) + session.query(DbGptChatInfo) + .filter_by(person_id=person.id, model=model) .order_by( - GptChatInfo.is_chatting.desc(), - GptChatInfo.talk_time.desc(), + DbGptChatInfo.is_chatting.desc(), + DbGptChatInfo.talk_time.desc(), ) .limit(20) .all() ) - return [chat_info[0] for chat_info in chat_info_list] + _chat_info_list = [] + for chat_info in chat_info_list: + _chat_info_list.append(chat_info.to_model()) + return _chat_info_list @staticmethod - def get_chat_list_str(wx_id: str, model: str) -> str: + def get_chat_list_str(person: Person, model: str) -> str: """ 获取用户的所有对话记录 + :param person: 用户 + :param model: 模型 + :return: 对话记录 """ - chat_info_list = CopilotGPT4._list_chat_info(wx_id, model) - chat_info_list_str = "✨===GPT4对话记录===✨\n" - if chat_info_list == []: + chat_info_list = CopilotGPT4._list_chat_info(person, model) + chat_info_list_str = f"✨==={model}对话记录===✨\n" + if not 
chat_info_list: chat_info_list_str += " 📭 无对话记录" return chat_info_list_str with make_db_session() as session: - for i, id in enumerate(chat_info_list): - chat = session.query(GptChatInfo).filter_by(id=id).first() + for i, chat_info in enumerate(chat_info_list): + chat = session.query(DbGptChatInfo).filter_by(id=chat_info.id).first() if chat.is_chatting: chat_info_list_str += f"{i + 1}. 💬{chat.topic}\n" else: @@ -347,176 +360,153 @@ def get_chat_list_str(wx_id: str, model: str) -> str: return chat_info_list_str @staticmethod - def _update_chat(chat_info: GptChatInfo, newconv: List = []) -> None: - """保存对话记录 - :param chat_info: 对话记录数据 - :param newconv: 新增对话记录 - """ - # 对话记录格式 - with make_db_session() as session: - chat_info.talk_time = datetime.now() - # session.commit() - for conv in newconv: - wx_message = DbMessage( - person_id=chat_info.person_id, - type="text", - content=conv["content"], - ) - chat_message = GptChatMessage( - gpt_chat_id=chat_info.id, - role=conv["role"], - message=wx_message, - ) - session.add(chat_message) - session.commit() - - @staticmethod - def get_chat_info(wx_id: str, model: str, chat_index: int) -> Union[int, None]: + def get_chat_info( + person: Person, model: str, chat_index: int + ) -> Union[GptChatInfo, None]: """ 获取用户的对话信息 + :param person: 用户 + :param model: 模型 + :param chat_index: 对话记录索引(从1开始) + :return: 对话信息 """ - chat_info_id_list = CopilotGPT4._list_chat_info(wx_id, model) - if chat_info_id_list == []: + chat_info_id_list = CopilotGPT4._list_chat_info(person, model) + if not chat_info_id_list: return None if chat_index <= 0 or chat_index > len(chat_info_id_list): return None return chat_info_id_list[chat_index - 1] @staticmethod - def _get_chat_conversations(chat_id: int) -> List[GptChatMessage]: - """ - 获取对话记录 - """ - with make_db_session() as session: - chat_info = session.query(GptChatInfo).filter_by(id=chat_id).first() - return chat_info.gpt_chat_messages - - @staticmethod - def get_chatting_chat_info(wx_id: str, model: str) -> Union[int, None]: + def get_chatting_chat_info(person: Person, model: str) -> Union[GptChatInfo, None]: """ 获取正在进行中的对话信息 + :param person: 用户 + :param model: 模型 + :return: 对话信息 """ - # 获取对话元信息 with make_db_session() as session: - chat_info_id = ( - session.query(GptChatInfo.id) - .filter_by(person_id=wx_id, model=model, is_chatting=True) + chat_info = ( + session.query(DbGptChatInfo) + .filter_by(person_id=person.id, model=model, is_chatting=True) .first() ) - if chat_info_id is None: + if not chat_info: return None - return chat_info_id[0] + return chat_info.to_model() @staticmethod - def chat(chat_info_id: int, message: str) -> str: + def chat(chat_info: GptChatInfo, message: str, message_obj) -> str: """ 持续对话 + :param chat_info: 对话信息 + :param message: 用户消息 + :param message_obj: 消息对象 + :return: GPT 回复 """ # 对外暴露的对话方法,必须保存对话记录 return CopilotGPT4._chat( - chat_info_id=chat_info_id, message=message, is_save=True + chat_info=chat_info, message=message, message_obj=message_obj, is_save=True ) @staticmethod - def _chat(chat_info_id: int, message: str, is_save: bool) -> str: - """持续对话 + def _chat(chat_info: GptChatInfo, message: str, message_obj, is_save: bool) -> str: + """ + 持续对话 + :param chat_info: 对话信息 :param message: 用户消息 + :param message_obj: 消息对象 :param is_save: 是否保存此轮对话记录 - """ - with make_db_session() as session: - chat_info = session.query(GptChatInfo).filter_by(id=chat_info_id).first() - newconv = [] - newconv.append({"role": "user", "content": message}) - - # 发送请求 - headers = { - "Authorization": 
CopilotGPT4.bearer_token, - "Content-Type": "application/json", - } - json = { - "model": chat_info.model, - "messages": DEFAULT_CONVERSATIONS - + chat_info.get_conversations() - + newconv, - } - r_json = post_request_json( - url=CopilotGPT4.api, headers=headers, json=json, timeout=60 - ) - - # 判断是否有 error 或 code 字段 - if "error" in r_json or "code" in r_json: - raise ValueError("Copilot-GPT4-Server返回值错误") - - msg = r_json["choices"][0]["message"] - msg_content = msg.get("content", "调用Copilot-GPT4-Server失败") - # 将返回的 assistant 回复添加到对话记录中 - if is_save is True: - newconv.append({"role": "assistant", "content": msg_content}) - chat_info.extend_conversations(newconv) - # CopilotGPT4._update_chat(chat_info, newconv) - chat_info.talk_time = datetime.now() - with make_db_session() as session: - # TODO: ^^^^ - for conv in newconv: - wx_message = DbMessage( - person_id=chat_info.person_id, - type="text", - content=conv["content"], - ) - chat_message = GptChatMessage( - gpt_chat_id=chat_info.id, - role=conv["role"], - message=wx_message, - ) - session.add(chat_message) - session.commit() - return msg_content + :return: GPT 回复 + """ + newconv = [{"role": "user", "content": message}] + # 发送请求 + headers = { + "Authorization": CopilotGPT4.bearer_token, + "Content-Type": "application/json", + } + json = { + "model": chat_info.model, + "messages": DEFAULT_CONVERSATION + chat_info.get_conversation() + newconv, + } + r_json = post_request_json( + url=CopilotGPT4.api, headers=headers, json=json, timeout=60 + ) - @staticmethod - def _has_topic(chat_info_id: int) -> bool: - """ - 判断对话是否有主题 - """ - with make_db_session() as session: - chat_info = session.query(GptChatInfo).filter_by(id=chat_info_id).first() - return chat_info.topic != DEFAULT_TOPIC + # 判断是否有 error 或 code 字段 + if "error" in r_json or "code" in r_json: + raise ValueError("Copilot-GPT4-Server返回值错误") + + msg = r_json["choices"][0]["message"] + msg_content = msg.get("content", "调用Copilot-GPT4-Server失败") + # 将返回的 assistant 回复添加到对话记录中 + if is_save is True: + newconv.append({"role": "assistant", "content": msg_content}) + chat_info.extend_conversation(newconv) + with make_db_session() as session: + _chat_info = ( + session.query(DbGptChatInfo).filter_by(id=chat_info.id).first() + ) + _chat_info.talk_time = datetime.now() + for chat_message in chat_info.gpt_chat_messages[-len(newconv) // 2 :]: + _chat_message = DbGptChatMessage.from_model(chat_message) + _chat_message.message_id = message_obj.id + _chat_info.gpt_chat_messages.append(_chat_message) + session.commit() + return msg_content @staticmethod - def _save_chatting_chat_topic(wx_id: str, model: str) -> None: + def _save_chatting_chat_topic(person: Person, model: str) -> None: """ 生成正在进行的对话的主题 """ - id = CopilotGPT4.get_chatting_chat_info(wx_id, model) - if id is None or CopilotGPT4._has_topic(id): + chat_info = CopilotGPT4.get_chatting_chat_info(person, model) + if chat_info is None or CopilotGPT4._has_topic(chat_info): return # 生成对话主题 - if not CopilotGPT4.is_chat_valid(id): + if not CopilotGPT4._is_chat_valid(chat_info): logger.error("对话记录长度小于1") return - topic = CopilotGPT4._generate_chat_topic(id) - if topic == "": + topic = CopilotGPT4._generate_chat_topic(chat_info) + if not topic: logger.error("生成对话主题失败") raise ValueError("生成对话主题失败") # 更新对话主题 with make_db_session() as session: - chat_info = session.query(GptChatInfo).filter_by(id=id).first() + chat_info = session.query(DbGptChatInfo).filter_by(id=chat_info.id).first() chat_info.topic = topic session.commit() @staticmethod - def 
_generate_chat_topic(chat_info_id: int) -> str: + def _generate_chat_topic(chat_info: GptChatInfo) -> str: """ 生成对话主题,用于保存对话记录 """ - assert CopilotGPT4.is_chat_valid(chat_info_id) + assert CopilotGPT4._is_chat_valid(chat_info) # 通过一次对话生成对话主题,但这次对话不保存到对话记录中 prompt = "请用10个字以内总结一下这次对话的主题,不带任何标点符号" topic = CopilotGPT4._chat( - chat_info_id=chat_info_id, message=prompt, is_save=False + chat_info=chat_info, message=prompt, message_obj=None, is_save=False ) # 限制主题长度 if len(topic) > 21: topic = topic[:21] + "..." logger.info(f"生成对话主题:{topic}") return topic + + @staticmethod + def _has_topic(chat_info: GptChatInfo) -> bool: + """ + 判断对话是否有主题 + """ + return chat_info.topic != DEFAULT_TOPIC + + @staticmethod + def _is_chat_valid(chat_info: GptChatInfo) -> bool: + """ + 判断对话是否有效 + """ + if chat_info.gpt_chat_messages: + return True + return False diff --git a/wechatter/database/tables/gpt_chat_info.py b/wechatter/database/tables/gpt_chat_info.py index afced82..60a4c9d 100644 --- a/wechatter/database/tables/gpt_chat_info.py +++ b/wechatter/database/tables/gpt_chat_info.py @@ -6,7 +6,10 @@ from wechatter.database.tables import Base from wechatter.database.tables.gpt_chat_message import GptChatMessage -from wechatter.database.tables.message import Message +from wechatter.models.gpt import ( + GptChatInfo as GptChatInfoModel, + GptChatMessage as GptChatMessageModel, +) if TYPE_CHECKING: from wechatter.database.tables.person import Person @@ -22,14 +25,13 @@ class GptChatInfo(Base): id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) person_id: Mapped[str] = mapped_column(String, ForeignKey("person.id")) topic: Mapped[str] + model: Mapped[str] created_time: Mapped[datetime] = mapped_column( DateTime(timezone=True), default=datetime.now() ) - # 改名为 updated_time talk_time: Mapped[datetime] = mapped_column( DateTime(timezone=True), onupdate=datetime.now() ) - model: Mapped[str] is_chatting: Mapped[bool] = mapped_column(Boolean, default=True) person: Mapped["Person"] = relationship("Person", back_populates="gpt_chat_infos") @@ -37,18 +39,45 @@ class GptChatInfo(Base): "GptChatMessage", back_populates="gpt_chat_info" ) - def get_conversations(self): - return [message.to_conversation() for message in self.gpt_chat_messages] + @classmethod + def from_model(cls, gpt_chat_info_model: GptChatInfoModel): + gpt_chat_messages = [] + for message in gpt_chat_info_model.gpt_chat_messages: + gpt_chat_messages.append(GptChatMessage.from_model(message)) + + return cls( + id=gpt_chat_info_model.id, + person_id=gpt_chat_info_model.person.id, + topic=gpt_chat_info_model.topic, + model=gpt_chat_info_model.model, + created_time=gpt_chat_info_model.created_time, + talk_time=gpt_chat_info_model.talk_time, + is_chatting=gpt_chat_info_model.is_chatting, + gpt_chat_messages=gpt_chat_messages, + ) - def extend_conversations(self, conversations: List): - conv = [ - GptChatMessage( - gpt_chat_id=self.id, - role=conversation["role"], - message=Message(type="text", content=conversation["content"]), - gpt_chat_info=self, + def to_model(self) -> GptChatInfoModel: + gpt_chat_info = GptChatInfoModel( + id=self.id, + person=self.person.to_model(), + topic=self.topic, + model=self.model, + created_time=self.created_time, + talk_time=self.talk_time, + is_chatting=self.is_chatting, + ) + + gpt_chat_messages = [] + for message in self.gpt_chat_messages: + gpt_chat_messages.append( + GptChatMessageModel( + id=message.id, + message=message.message.to_model(), + gpt_chat_info=gpt_chat_info, + 
gpt_response=message.gpt_response, + # role=message.role.value, + ) ) - for conversation in conversations - ] - self.gpt_chat_messages.extend(conv) - return self + gpt_chat_info.gpt_chat_messages = gpt_chat_messages + + return gpt_chat_info diff --git a/wechatter/database/tables/gpt_chat_message.py b/wechatter/database/tables/gpt_chat_message.py index 33f49e1..5062a90 100644 --- a/wechatter/database/tables/gpt_chat_message.py +++ b/wechatter/database/tables/gpt_chat_message.py @@ -1,26 +1,16 @@ -import enum from typing import TYPE_CHECKING from sqlalchemy import ForeignKey, Integer from sqlalchemy.orm import Mapped, mapped_column, relationship from wechatter.database.tables import Base +from wechatter.models.gpt import GptChatMessage as GptChatMessageModel if TYPE_CHECKING: from wechatter.database.tables.gpt_chat_info import GptChatInfo from wechatter.database.tables.message import Message -class GptChatRole(enum.Enum): - """ - GPT聊天角色 - """ - - system = "system" - user = "user" - assistant = "assistant" - - class GptChatMessage(Base): """ GPT对话消息表 @@ -33,7 +23,7 @@ class GptChatMessage(Base): Integer, ForeignKey("message.id"), unique=True ) gpt_chat_id: Mapped[int] = mapped_column(Integer, ForeignKey("gpt_chat_info.id")) - role: Mapped[GptChatRole] + gpt_response: Mapped[str] message: Mapped["Message"] = relationship( "Message", back_populates="gpt_chat_message" @@ -42,8 +32,18 @@ class GptChatMessage(Base): "GptChatInfo", back_populates="gpt_chat_messages" ) - def to_conversation(self): - return { - "role": self.role.value, - "content": self.message.content, - } + @classmethod + def from_model(cls, gpt_chat_message_model: GptChatMessageModel): + return cls( + message_id=gpt_chat_message_model.message.id, + gpt_chat_id=gpt_chat_message_model.gpt_chat_info.id, + gpt_response=gpt_chat_message_model.gpt_response, + ) + + def to_model(self) -> GptChatMessageModel: + return GptChatMessageModel( + id=self.id, + message=self.message.to_model(), + gpt_chat_info=self.gpt_chat_info.to_model(), + gpt_response=self.gpt_response, + ) diff --git a/wechatter/database/tables/group.py b/wechatter/database/tables/group.py index a4d8b4f..f7b2c14 100644 --- a/wechatter/database/tables/group.py +++ b/wechatter/database/tables/group.py @@ -4,11 +4,11 @@ from sqlalchemy.orm import Mapped, mapped_column, relationship from wechatter.database.tables import Base +from wechatter.models.wechat import Group as GroupModel if TYPE_CHECKING: from wechatter.database.tables.message import Message from wechatter.database.tables.person import Person - from wechatter.models.wechat import Group as GroupModel class Group(Base): @@ -30,11 +30,26 @@ class Group(Base): messages: Mapped[List["Message"]] = relationship("Message", back_populates="group") @classmethod - def from_model(cls, group_model: "GroupModel"): + def from_model(cls, group_model: GroupModel): return cls( id=group_model.id, name=group_model.name, ) - def update(self, group_model: "GroupModel"): + def to_model(self) -> GroupModel: + member_list = [] + for member in self.members: + member_list.append(member.to_model()) + return GroupModel( + id=self.id, + name=self.name, + member_list=member_list, + ) + + def update(self, group_model: GroupModel): + from wechatter.database.tables.person import Person self.name = group_model.name + member_list = [] + for member in group_model.member_list: + member_list.append(Person.from_member_model(member)) + self.members = member_list diff --git a/wechatter/database/tables/message.py b/wechatter/database/tables/message.py index 80da11c..367fc8d 100644 --- a/wechatter/database/tables/message.py
+++ b/wechatter/database/tables/message.py @@ -1,4 +1,3 @@ -import enum from datetime import datetime from typing import TYPE_CHECKING, Union @@ -7,22 +6,11 @@ from wechatter.database.tables import Base from wechatter.database.tables.gpt_chat_message import GptChatMessage +from wechatter.models.wechat import Message as MessageModel, MessageType if TYPE_CHECKING: from wechatter.database.tables.group import Group from wechatter.database.tables.person import Person - from wechatter.models.wechat import Message as MessageModel - - -class MessageType(enum.Enum): - """ - 消息类型 - """ - - text = "text" - file = "file" - urlLink = "urlLink" - friendship = "friendship" class Message(Base): @@ -54,13 +42,14 @@ class Message(Base): ) @classmethod - def from_model(cls, message_model: "MessageModel"): + def from_model(cls, message_model: MessageModel): group_id = None if message_model.is_group: - group_id = message_model.source.g_info.id + group_id = message_model.group.id return cls( - person_id=message_model.source.p_info.id, + id=message_model.id, + person_id=message_model.person.id, group_id=group_id, type=message_model.type.value, content=message_model.content, @@ -68,6 +57,12 @@ def from_model(cls, message_model: "MessageModel"): is_quoted=message_model.is_quoted, ) - # @classmethod - # def create_gpt_chat_message( - # cls, + def to_model(self) -> MessageModel: + return MessageModel( + id=self.id, + type=self.type, + person=self.person.to_model(), + group=self.group.to_model() if self.group else None, + content=self.content, + is_mentioned=self.is_mentioned, + ) diff --git a/wechatter/database/tables/person.py b/wechatter/database/tables/person.py index 2613f02..71d9b8d 100644 --- a/wechatter/database/tables/person.py +++ b/wechatter/database/tables/person.py @@ -1,27 +1,15 @@ -import enum from typing import TYPE_CHECKING, List, Union from sqlalchemy import Boolean, String from sqlalchemy.orm import Mapped, mapped_column, relationship from wechatter.database.tables import Base +from wechatter.models.wechat import Gender, GroupMember, Person as PersonModel if TYPE_CHECKING: from wechatter.database.tables.gpt_chat_info import GptChatInfo from wechatter.database.tables.group import Group from wechatter.database.tables.message import Message - from wechatter.models.wechat import GroupMember, Person as PersonModel - - -class Gender(enum.Enum): - """ - 性别表 - """ - - # 用命名小写,否则sqlalchemy会报错 - male = "male" - female = "female" - unknown = "unknown" class Person(Base): @@ -34,9 +22,9 @@ class Person(Base): id: Mapped[str] = mapped_column(String(100), primary_key=True) name: Mapped[str] alias: Mapped[Union[str, None]] = mapped_column(String, nullable=True) - gender: Mapped[Union[Gender, None]] - province: Mapped[Union[str, None]] - city: Mapped[Union[str, None]] + gender: Mapped[Union[Gender, None]] = mapped_column(String, nullable=True) + province: Mapped[Union[str, None]] = mapped_column(String, nullable=True) + city: Mapped[Union[str, None]] = mapped_column(String, nullable=True) # phone: Mapped[Union[str, None]] = mapped_column(String, nullable=True) is_star: Mapped[bool] = mapped_column(Boolean, default=False) is_friend: Mapped[bool] = mapped_column(Boolean, default=False) @@ -52,7 +40,7 @@ class Person(Base): ) @classmethod - def from_model(cls, person_model: "PersonModel"): + def from_model(cls, person_model: PersonModel): return cls( id=person_model.id, name=person_model.name, @@ -72,7 +60,20 @@ def from_member_model(cls, member_model: "GroupMember"): alias=member_model.alias, ) - def update(self, 
person_model: "PersonModel"): + def to_model(self) -> PersonModel: + return PersonModel( + id=self.id, + name=self.name, + alias=self.alias, + gender=self.gender, + signature="", + province=self.province, + city=self.city, + is_star=self.is_star, + is_friend=self.is_friend, + ) + + def update(self, person_model: PersonModel): self.name = person_model.name self.alias = person_model.alias self.gender = person_model.gender.value diff --git a/wechatter/message/message_parser.py b/wechatter/message/message_parser.py index cbb201b..963ccf6 100644 --- a/wechatter/message/message_parser.py +++ b/wechatter/message/message_parser.py @@ -41,7 +41,7 @@ def handle_message(self, message: Message) -> None: logger.debug("该消息为群消息,但未@机器人,不处理") return - to = SendTo.from_message_source(message.source) + to = SendTo(person=message.person, group=message.group) # 是命令消息 # 开始处理命令 diff --git a/wechatter/message_forwarder/message_forwarder.py b/wechatter/message_forwarder/message_forwarder.py index 52dcd61..f4e686f 100644 --- a/wechatter/message_forwarder/message_forwarder.py +++ b/wechatter/message_forwarder/message_forwarder.py @@ -15,15 +15,8 @@ def __init__(self, rule_list: List): def forward_message(self, message: Message): """消息转发""" - # 判断消息来源 - from_name = "" - if message.is_group: - from_name = message.source.g_info.name - else: - from_name = message.source.p_info.name - # TODO: 转发文件 - + from_name = message.sender_name # 判断消息是否符合转发规则 for rule in self.rule_list: # 判断消息来源是否符合转发规则 @@ -31,7 +24,7 @@ def forward_message(self, message: Message): # 构造转发消息 msg = self.__construct_forwarding_message(message) logger.info( - f"转发消息:{from_name} -> {rule['to_persons']}\n" + f"转发消息:{from_name} -> {rule['to_persons']};" f"转发消息:{from_name} -> {rule['to_groups']}" ) sender.mass_send_msg(rule["to_persons"], msg) @@ -42,13 +35,13 @@ def __construct_forwarding_message(self, message: Message) -> str: content = message.content if message.is_group: content = ( - f"⤴️ {message.source.p_info.name}在{message.source.g_info.name}中说:\n" + f"⤴️ {message.person.name}在{message.group.name}中说:\n" f"-------------------------\n" f"{content}" ) else: content = ( - f"⤴️ {message.source.p_info.name}说:\n" + f"⤴️ {message.person.name}说:\n" f"-------------------------\n" f"{content}" ) diff --git a/wechatter/models/__init__.py b/wechatter/models/__init__.py index e69de29..4d8af86 100644 --- a/wechatter/models/__init__.py +++ b/wechatter/models/__init__.py @@ -0,0 +1,6 @@ +from .gpt import GptChatInfo, GptChatMessage # noqa: F401 +from .wechat import Message, Person # noqa: F401 + +GptChatInfo.model_rebuild() + +__all__ = [] diff --git a/wechatter/models/github/pr_webhook.py b/wechatter/models/github/pr_webhook.py index 390d4b5..d83cb17 100644 --- a/wechatter/models/github/pr_webhook.py +++ b/wechatter/models/github/pr_webhook.py @@ -34,11 +34,11 @@ class PullRequest(BaseModel): state: str title: str user: User - body: Optional[str] + body: Optional[str] = None base: PrBranch head: PrBranch merged: bool - merged_by: Optional[User] + merged_by: Optional[User] = None class GithubPrWebhook(BaseModel): diff --git a/wechatter/models/gpt/__init__.py b/wechatter/models/gpt/__init__.py index e69de29..2c96029 100644 --- a/wechatter/models/gpt/__init__.py +++ b/wechatter/models/gpt/__init__.py @@ -0,0 +1,4 @@ +from .gpt_chat_info import GptChatInfo +from .gpt_chat_message import GptChatMessage + +__all__ = ["GptChatMessage", "GptChatInfo"] diff --git a/wechatter/models/gpt/gpt_chat_info.py b/wechatter/models/gpt/gpt_chat_info.py index 8aacdb4..b661b94 100644 
--- a/wechatter/models/gpt/gpt_chat_info.py +++ b/wechatter/models/gpt/gpt_chat_info.py @@ -1,16 +1,44 @@ from datetime import datetime -from typing import List +from typing import TYPE_CHECKING, List, Optional from pydantic import BaseModel from wechatter.models.gpt.gpt_chat_message import GptChatMessage +from wechatter.models.wechat import Message + +if TYPE_CHECKING: + from wechatter.models.wechat import Person class GptChatInfo(BaseModel): + id: Optional[int] = None + person: "Person" topic: str + model: str created_time: datetime = datetime.now() talk_time: datetime = datetime.now() - model: str is_chatting: bool = True - # user: User - gpt_chat_messages: List[GptChatMessage] + gpt_chat_messages: List[GptChatMessage] = [] + + def get_conversation(self) -> List: + conversation = [] + for message in self.gpt_chat_messages: + conversation.extend(message.to_turn()) + return conversation + + def extend_conversation(self, conversation: List): + conv = [] + for i in range(0, len(conversation) - 1, 2): + conv.append( + GptChatMessage( + message=Message( + type="text", + person=self.person, + content=conversation[i]["content"], + ), + gpt_chat_info=self, + gpt_response=conversation[i + 1]["content"], + ) + ) + self.gpt_chat_messages.extend(conv) + return self diff --git a/wechatter/models/gpt/gpt_chat_message.py b/wechatter/models/gpt/gpt_chat_message.py index 929e81e..b393b7b 100644 --- a/wechatter/models/gpt/gpt_chat_message.py +++ b/wechatter/models/gpt/gpt_chat_message.py @@ -1,8 +1,11 @@ import enum +from typing import TYPE_CHECKING, Optional from pydantic import BaseModel -from wechatter.models.wechat.message import Message +if TYPE_CHECKING: + from wechatter.models.gpt.gpt_chat_info import GptChatInfo + from wechatter.models.wechat import Message class GptChatRole(enum.Enum): @@ -12,6 +15,19 @@ class GptChatRole(enum.Enum): class GptChatMessage(BaseModel): - gpt_chat_id: int - role: GptChatRole - message: Message + id: Optional[int] = None + message: "Message" + gpt_chat_info: "GptChatInfo" + gpt_response: str + + def to_turn(self): + return [ + { + "role": GptChatRole.user.value, + "content": self.message.content, + }, + { + "role": GptChatRole.assistant.value, + "content": self.gpt_response, + }, + ] diff --git a/wechatter/models/wechat/__init__.py b/wechatter/models/wechat/__init__.py index 985c636..571e673 100644 --- a/wechatter/models/wechat/__init__.py +++ b/wechatter/models/wechat/__init__.py @@ -1,13 +1,14 @@ from .group import Group, GroupMember -from .message import Message, MessageSource -from .person import Person +from .message import Message, MessageType +from .person import Gender, Person from .send_to import SendTo __all__ = [ "Message", - "MessageSource", + "MessageType", "SendTo", "Group", "GroupMember", "Person", + "Gender", ] diff --git a/wechatter/models/wechat/group.py b/wechatter/models/wechat/group.py index 33f0134..fd1b8a1 100644 --- a/wechatter/models/wechat/group.py +++ b/wechatter/models/wechat/group.py @@ -1,4 +1,4 @@ -from typing import List +from typing import List, Optional from pydantic import BaseModel @@ -20,5 +20,6 @@ class Group(BaseModel): id: str name: str - admin_id_list: List[str] + # alias: str 目前上游不支持 + admin_id_list: Optional[List[str]] = None member_list: List[GroupMember] diff --git a/wechatter/models/wechat/message.py b/wechatter/models/wechat/message.py index 6594053..d9e771c 100644 --- a/wechatter/models/wechat/message.py +++ b/wechatter/models/wechat/message.py @@ -3,7 +3,7 @@ import json import re from functools import cached_property 
-from typing import Union +from typing import Optional from loguru import logger from pydantic import BaseModel, computed_field @@ -35,48 +35,31 @@ class MessageSenderType(enum.Enum): # ARTICLE = 2 -class MessageSource(BaseModel): - """ - 消息来源类 - """ - - p_info: Person - g_info: Union[Group, None] = None - - def __str__(self) -> str: - result = "" - if self.g_info is not None: - result += str(self.g_info) - result += str(self.p_info) - return result - - class Message(BaseModel): """ 微信消息类(消息接收) - :property content: 消息内容 - :property source: 消息来源 - :property is_mentioned: 是否@机器人 - :property is_quoted: 是否引用机器人消息 - :property is_group: 是否是群消息 """ type: MessageType - content_: str - source_: str - is_mentioned_: str - - @computed_field - @cached_property - def content(self) -> str: - # 对于 iPad、手机端的微信,@名称后面会跟着一个未解码的空格的Unicode编码:"@Cassius\u2005/help" - return self.content_.replace("\u2005", " ", 1) - - @computed_field - @cached_property - def source(self) -> MessageSource: + person: Person + group: Optional[Group] = None + content: str + is_mentioned: bool = False + id: Optional[int] = None + + @classmethod + def from_api_msg( + cls, + type: MessageType, + content: str, + source: str, + is_mentioned: str, + ): + """ + 从API接口创建消息对象 + """ try: - source_json = json.loads(self.source_) + source_json = json.loads(source) except json.JSONDecodeError as e: logger.error("消息来源解析失败") raise e @@ -89,7 +72,7 @@ def source(self) -> MessageSource: g = "male" elif gender == 0: g = "female" - p_info = Person( + _person = Person( id=payload.get("id", ""), name=payload.get("name", ""), alias=payload.get("alias", ""), @@ -101,37 +84,37 @@ def source(self) -> MessageSource: is_star=payload.get("star", ""), is_friend=payload.get("friend", ""), ) - message_source = MessageSource(p_info=p_info) + _group = None # room为群信息,只有群消息才有room if source_json["room"] != "": g_data = source_json["room"] payload = g_data.get("payload", {}) - message_source.g_info = Group( + _group = Group( id=g_data.get("id", ""), name=payload.get("topic", ""), admin_id_list=payload.get("adminIdList", []), member_list=payload.get("memberList", []), ) - return message_source - - @computed_field - @cached_property - def is_mentioned(self) -> bool: - """ - 是否@机器人 - """ - if self.is_mentioned_ == "1": - return True - return False + _content = content.replace("\u2005", " ", 1) + _is_mentioned = False + if is_mentioned == "1": + _is_mentioned = True + return cls( + type=type, + person=_person, + group=_group, + content=_content, + is_mentioned=_is_mentioned, + ) @computed_field - @cached_property + @property def is_group(self) -> bool: """ 是否是群消息 """ - return bool(self.source.g_info) + return self.group is not None @computed_field @cached_property @@ -143,14 +126,28 @@ def is_quoted(self) -> bool: quote_pattern = r"(?s)「(.*?)」\n- - - - - - - - - - - - - - -" match_result = re.match(quote_pattern, self.content) # 判断是否为引用机器人消息 - if bool(match_result) and self.content.startswith(f"「{config.bot_name}"): + if match_result and self.content.startswith(f"「{config.bot_name}"): return True return False + # TODO: 判断所有的引用消息,不仅仅是机器人消息 + # 待解决:在群中如果有人设置了自己的群中名称,那么引用内容的名字会变化,导致无法匹配到用户 + + @computed_field + @property + def sender_name(self) -> str: + """ + 返回消息发送对象名,如果是群则返回群名,如果不是则返回人名 + """ + return self.group.name if self.is_group else self.person.name + def __str__(self) -> str: + source = self.person + if self.is_group: + source = self.group return ( f"消息内容:{self.content}\n" - f"消息来源:\n{self.source}\n" + f"消息来源:{source}\n" f"是否@:{self.is_mentioned}\n"
f"是否引用:{self.is_quoted}" ) diff --git a/wechatter/models/wechat/send_to.py b/wechatter/models/wechat/send_to.py index 8f902d3..f5cc151 100644 --- a/wechatter/models/wechat/send_to.py +++ b/wechatter/models/wechat/send_to.py @@ -1,10 +1,9 @@ -from typing import Union +from typing import Optional from loguru import logger from pydantic import BaseModel, computed_field from wechatter.models.wechat.group import Group -from wechatter.models.wechat.message import MessageSource from wechatter.models.wechat.person import Person @@ -14,7 +13,7 @@ class SendTo(BaseModel): """ person: Person - group: Union[Group, None] + group: Optional[Group] = None @computed_field @property @@ -33,7 +32,7 @@ def p_alias(self) -> str: @computed_field @property - def g_id(self) -> Union[str, None]: + def g_id(self) -> Optional[str]: try: return self.group.id except AttributeError: @@ -42,16 +41,9 @@ def g_id(self) -> Union[str, None]: @computed_field @property - def g_name(self) -> Union[str, None]: + def g_name(self) -> Optional[str]: try: return self.group.name except AttributeError: logger.warning("此发送对象不是群聊") return None - - @classmethod - def from_message_source(cls, source: MessageSource): - return cls( - person=source.p_info, - group=source.g_info, - )