Skip to content

Commit

Permalink
refactor(backend): refactor provider model and add temperature valida…
Browse files Browse the repository at this point in the history
…tion
  • Loading branch information
tuantran0910 committed Jan 7, 2025
1 parent 2a5be47 commit 713e634
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 0 deletions.
22 changes: 22 additions & 0 deletions chatbot-core/backend/app/models/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

from pydantic import BaseModel
from pydantic import Field
from pydantic import field_validator
from sqlalchemy import Boolean
from sqlalchemy import DateTime
from sqlalchemy import Enum as SQLAlchemyEnum
Expand Down Expand Up @@ -135,6 +136,7 @@ class BaseProviderRequest(BaseModel):
Defines the structure of provider data received from the client.
"""

name: ProviderType = Field(ProviderType.OPENAI, description="The name of the provider")
api_key: Optional[str] = Field(None, description="API key for the provider")
models: Optional[List[str]] = Field(
default_factory=list, description="The models of the provider."
Expand Down Expand Up @@ -206,6 +208,26 @@ class LLMProviderRequest(BaseProviderRequest):
0.7, description="The temperature of the LLM provider.", ge=0.0, le=2.0
)

@field_validator("temperature")
def validate_temperature(cls, value: float) -> float:
"""
Validate that the temperature is between 0.0 and 1.0.
Args:
value (float): The value of the attribute.
Returns:
float: The validated temperature
"""
if value < 0.0:
raise ValueError("Temperature must be greater than or equal to 0.0.")
if (cls.name == ProviderType.COHERE or cls.name == ProviderType.OPENAI) and value > 1.0:
raise ValueError("Temperature must be less than or equal to 1.0.")
if cls.name == ProviderType.GEMINI and value > 2.0:
raise ValueError("Temperature must be less than or equal to 2.0.")

return value


class LLMProviderResponse(BaseProviderResponse):
"""
Expand Down
13 changes: 13 additions & 0 deletions chatbot-core/backend/app/services/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from app.services.base import BaseService
from app.settings import Secrets
from app.utils.api.api_response import APIError
from app.utils.api.error_handler import ErrorCodesMappingNumber
from app.utils.api.helpers import get_logger
from app.utils.llm.helpers import handle_current_embedding_model
from app.utils.llm.helpers import handle_current_llm_model
Expand Down Expand Up @@ -135,6 +136,12 @@ def update_embedding_provider(
with self._transaction():
# Define to-be-updated embedding provider
embedding_provider = embedding_provider_request.model_dump(exclude_unset=True)

# Check whether the provider is changed (we do not allow changing the provider type)
if embedding_provider.get("name") != existing_embedding_provider.name:
return APIError(kind=ErrorCodesMappingNumber.PROVIDER_TYPE_CHANGE_NOT_ALLOWED.value)

# Parse the models field
if embedding_provider.get("models"):
embedding_provider["models"] = json.dumps(embedding_provider["models"])

Expand Down Expand Up @@ -194,6 +201,12 @@ def update_llm_provider(
with self._transaction():
# Define to-be-updated LLM provider
llm_provider = llm_provider_request.model_dump(exclude_unset=True)

# Check whether the provider is changed (we do not allow changing the provider type)
if llm_provider.get("name") != existing_llm_provider.name:
return APIError(kind=ErrorCodesMappingNumber.PROVIDER_TYPE_CHANGE_NOT_ALLOWED.value)

# Parse the models field
if llm_provider.get("models"):
llm_provider["models"] = json.dumps(llm_provider["models"])

Expand Down
1 change: 1 addition & 0 deletions chatbot-core/backend/app/utils/api/error_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ class ErrorCodesMappingNumber(Enum):
LLM_PROVIDER_NOT_FOUND = (422, "LLM provider not found")
USER_SETTING_NOT_FOUND = (422, "User setting not found")
EMBEDDING_PROVIDER_NOT_FOUND = (422, "Embedding provider not found")
PROVIDER_TYPE_CHANGE_NOT_ALLOWED = (422, "Provider type change not allowed")


class BaseException(Exception):
Expand Down

0 comments on commit 713e634

Please sign in to comment.