From 10efc14dd246b39edac114d0ca2b34df2f4743bb Mon Sep 17 00:00:00 2001 From: Hansimov <591172499@qq.com> Date: Sun, 10 Dec 2023 00:20:56 +0800 Subject: [PATCH] :gem: [Feature] ConversationStyle: Use Enum for more intuitive, and set default styles --- conversations/__init__.py | 1 + conversations/conversation_connector.py | 9 +++++++-- conversations/conversation_style.py | 10 ++++++++++ networks/chathub_request_payload_constructor.py | 3 ++- 4 files changed, 20 insertions(+), 3 deletions(-) create mode 100644 conversations/conversation_style.py diff --git a/conversations/__init__.py b/conversations/__init__.py index b2a4658..7289e9e 100644 --- a/conversations/__init__.py +++ b/conversations/__init__.py @@ -1,3 +1,4 @@ +from .conversation_style import ConversationStyle from .conversation_connector import ConversationConnector from .conversation_creator import ConversationCreator from .conversation_session import ConversationSession diff --git a/conversations/conversation_connector.py b/conversations/conversation_connector.py index 2b6accc..c78d84f 100644 --- a/conversations/conversation_connector.py +++ b/conversations/conversation_connector.py @@ -8,6 +8,7 @@ MessageParser, OpenaiStreamOutputer, ) +from conversations import ConversationStyle from utils.logger import logger from utils.enver import enver @@ -26,14 +27,18 @@ class ConversationConnector: def __init__( self, - conversation_style: str = "precise", + conversation_style: ConversationStyle = "precise", sec_access_token: str = "", client_id: str = "", conversation_id: str = "", invocation_id: int = 0, cookies={}, ): - self.conversation_style = conversation_style + if conversation_style.lower() not in ConversationStyle.__members__: + self.conversation_style = ConversationStyle.PRECISE.value + else: + self.conversation_style = conversation_style.lower() + self.sec_access_token = sec_access_token self.client_id = client_id self.conversation_id = conversation_id diff --git a/conversations/conversation_style.py b/conversations/conversation_style.py new file mode 100644 index 0000000..18f8ea1 --- /dev/null +++ b/conversations/conversation_style.py @@ -0,0 +1,10 @@ +from enum import Enum + + +class ConversationStyle(Enum): + PRECISE: str = "precise" + BALANCED: str = "balanced" + CREATIVE: str = "creative" + PRECISE_OFFLINE: str = "precise-offline" + BALANCED_OFFLINE: str = "balanced-offline" + CREATIVE_OFFLINE: str = "creative-offline" diff --git a/networks/chathub_request_payload_constructor.py b/networks/chathub_request_payload_constructor.py index aff6414..2f4a9c7 100644 --- a/networks/chathub_request_payload_constructor.py +++ b/networks/chathub_request_payload_constructor.py @@ -1,5 +1,6 @@ import random import uuid +from conversations import ConversationStyle class ChathubRequestPayloadConstructor: @@ -9,7 +10,7 @@ def __init__( client_id: str, conversation_id: str, invocation_id: int = 0, - conversation_style: str = "precise", + conversation_style: ConversationStyle = "precise", ): self.prompt = prompt self.client_id = client_id