From 713e6347007d7957e8746e3df836570aa974911d Mon Sep 17 00:00:00 2001 From: tuantran0910 Date: Wed, 8 Jan 2025 01:29:40 +0700 Subject: [PATCH] refactor(backend): refactor provider model and add temperature validation --- chatbot-core/backend/app/models/provider.py | 22 +++++++++++++++++++ chatbot-core/backend/app/services/provider.py | 13 +++++++++++ .../backend/app/utils/api/error_handler.py | 1 + 3 files changed, 36 insertions(+) diff --git a/chatbot-core/backend/app/models/provider.py b/chatbot-core/backend/app/models/provider.py index 837cdb3..21ea66c 100644 --- a/chatbot-core/backend/app/models/provider.py +++ b/chatbot-core/backend/app/models/provider.py @@ -11,6 +11,7 @@ from pydantic import BaseModel from pydantic import Field +from pydantic import field_validator from sqlalchemy import Boolean from sqlalchemy import DateTime from sqlalchemy import Enum as SQLAlchemyEnum @@ -135,6 +136,7 @@ class BaseProviderRequest(BaseModel): Defines the structure of provider data received from the client. """ + name: ProviderType = Field(ProviderType.OPENAI, description="The name of the provider") api_key: Optional[str] = Field(None, description="API key for the provider") models: Optional[List[str]] = Field( default_factory=list, description="The models of the provider." @@ -206,6 +208,26 @@ class LLMProviderRequest(BaseProviderRequest): 0.7, description="The temperature of the LLM provider.", ge=0.0, le=2.0 ) + @field_validator("temperature") + def validate_temperature(cls, value: float) -> float: + """ + Validate that the temperature is between 0.0 and 1.0. + + Args: + value (float): The value of the attribute. + + Returns: + float: The validated temperature + """ + if value < 0.0: + raise ValueError("Temperature must be greater than or equal to 0.0.") + if (cls.name == ProviderType.COHERE or cls.name == ProviderType.OPENAI) and value > 1.0: + raise ValueError("Temperature must be less than or equal to 1.0.") + if cls.name == ProviderType.GEMINI and value > 2.0: + raise ValueError("Temperature must be less than or equal to 2.0.") + + return value + class LLMProviderResponse(BaseProviderResponse): """ diff --git a/chatbot-core/backend/app/services/provider.py b/chatbot-core/backend/app/services/provider.py index 007feb3..cc6a6ae 100644 --- a/chatbot-core/backend/app/services/provider.py +++ b/chatbot-core/backend/app/services/provider.py @@ -14,6 +14,7 @@ from app.services.base import BaseService from app.settings import Secrets from app.utils.api.api_response import APIError +from app.utils.api.error_handler import ErrorCodesMappingNumber from app.utils.api.helpers import get_logger from app.utils.llm.helpers import handle_current_embedding_model from app.utils.llm.helpers import handle_current_llm_model @@ -135,6 +136,12 @@ def update_embedding_provider( with self._transaction(): # Define to-be-updated embedding provider embedding_provider = embedding_provider_request.model_dump(exclude_unset=True) + + # Check whether the provider is changed (we do not allow changing the provider type) + if embedding_provider.get("name") != existing_embedding_provider.name: + return APIError(kind=ErrorCodesMappingNumber.PROVIDER_TYPE_CHANGE_NOT_ALLOWED.value) + + # Parse the models field if embedding_provider.get("models"): embedding_provider["models"] = json.dumps(embedding_provider["models"]) @@ -194,6 +201,12 @@ def update_llm_provider( with self._transaction(): # Define to-be-updated LLM provider llm_provider = llm_provider_request.model_dump(exclude_unset=True) + + # Check whether the provider is changed (we do not allow changing the provider type) + if llm_provider.get("name") != existing_llm_provider.name: + return APIError(kind=ErrorCodesMappingNumber.PROVIDER_TYPE_CHANGE_NOT_ALLOWED.value) + + # Parse the models field if llm_provider.get("models"): llm_provider["models"] = json.dumps(llm_provider["models"]) diff --git a/chatbot-core/backend/app/utils/api/error_handler.py b/chatbot-core/backend/app/utils/api/error_handler.py index 2b5b3c2..a3622a6 100644 --- a/chatbot-core/backend/app/utils/api/error_handler.py +++ b/chatbot-core/backend/app/utils/api/error_handler.py @@ -22,6 +22,7 @@ class ErrorCodesMappingNumber(Enum): LLM_PROVIDER_NOT_FOUND = (422, "LLM provider not found") USER_SETTING_NOT_FOUND = (422, "User setting not found") EMBEDDING_PROVIDER_NOT_FOUND = (422, "Embedding provider not found") + PROVIDER_TYPE_CHANGE_NOT_ALLOWED = (422, "Provider type change not allowed") class BaseException(Exception):