From 940aa25e3ec7707aadc1b5e41b56dc84f13d1863 Mon Sep 17 00:00:00 2001
From: Maciej Majek
Date: Wed, 4 Sep 2024 20:47:07 +0200
Subject: [PATCH] refactor: simplify config, use dataclasses

---
 config.toml                               | 31 ++++----
 src/rai/rai/utils/model_initialization.py | 89 ++++++++++++++++-------
 2 files changed, 81 insertions(+), 39 deletions(-)

diff --git a/config.toml b/config.toml
index cddf8b14f..579bbe5e2 100644
--- a/config.toml
+++ b/config.toml
@@ -1,14 +1,19 @@
-[llm]
-# Choose the LLM provider: openai, aws, or ollama
-vendor = "openai"
-simple_model = "gpt-4o-mini" # Faster, less expensive model for straightforward tasks
-complex_model = "gpt-4o" # More powerful model for complex reasoning or generation
-region_name = "us-east-1" # AWS region (only applicable when vendor is "aws")
-base_url = "http://localhost:11434" # Ollama API endpoint (only applicable when vendor is "ollama")
+[vendor]
+name = "openai" # openai, aws, ollama
 
-[embeddings]
-# Choose the embeddings provider: openai, aws, or ollama
-vendor = "openai"
-model = "text-embedding-ada-002" # Embedding model to use
-region_name = "us-east-1" # AWS region (only applicable when vendor is "aws")
-base_url = "http://localhost:11434" # Ollama API endpoint (only applicable when vendor is "ollama")
+[aws]
+simple_model = "anthropic.claude-3-haiku-20240307-v1:0"
+complex_model = "anthropic.claude-3-5-sonnet-20240620-v1:0"
+embeddings_model = "amazon.titan-embed-text-v1"
+region_name = "us-east-1"
+
+[openai]
+simple_model = "gpt-4o-mini"
+complex_model = "gpt-4o"
+embeddings_model = "text-embedding-ada-002"
+
+[ollama]
+simple_model = "llama3.1"
+complex_model = "llama3.1:70b"
+embeddings_model = "llama3.1"
+base_url = "http://localhost:11434"
diff --git a/src/rai/rai/utils/model_initialization.py b/src/rai/rai/utils/model_initialization.py
index 355a0e6a0..bdeede0ca 100644
--- a/src/rai/rai/utils/model_initialization.py
+++ b/src/rai/rai/utils/model_initialization.py
@@ -12,62 +12,99 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Literal, TypedDict
+from dataclasses import dataclass
+from typing import Literal
 
 import tomli
 
 
-class LLMConfig(TypedDict):
-    vendor: str
+@dataclass
+class VendorConfig:
+    name: str
+
+
+@dataclass
+class ModelConfig:
     simple_model: str
     complex_model: str
-    region_name: str
-    base_url: str
+    embeddings_model: str
 
 
-class EmbeddingsConfig(TypedDict):
-    vendor: str
-    model: str
+@dataclass
+class AWSConfig(ModelConfig):
     region_name: str
+
+
+@dataclass
+class OllamaConfig(ModelConfig):
     base_url: str
 
 
+@dataclass
+class RAIConfig:
+    vendor: VendorConfig
+    aws: AWSConfig
+    openai: ModelConfig
+    ollama: OllamaConfig
+
+
+def load_config() -> RAIConfig:
+    with open("config.toml", "rb") as f:
+        config_dict = tomli.load(f)
+    return RAIConfig(
+        vendor=VendorConfig(**config_dict["vendor"]),
+        aws=AWSConfig(**config_dict["aws"]),
+        openai=ModelConfig(**config_dict["openai"]),
+        ollama=OllamaConfig(**config_dict["ollama"]),
+    )
+
+
 def get_llm_model(model_type: Literal["simple_model", "complex_model"]):
-    rai_config = tomli.load(open("config.toml", "rb"))
-    llm_config = LLMConfig(**rai_config["llm"])
-    if llm_config["vendor"] == "openai":
+    config = load_config()
+    vendor = config.vendor.name
+    model_config = getattr(config, vendor)
+
+    if vendor == "openai":
         from langchain_openai import ChatOpenAI
 
-        return ChatOpenAI(model=llm_config[model_type])
-    elif llm_config["vendor"] == "aws":
+        return ChatOpenAI(model=getattr(model_config, model_type))
+    elif vendor == "aws":
         from langchain_aws import ChatBedrock
 
-        return ChatBedrock(model_id=llm_config[model_type], region_name=llm_config["region_name"])  # type: ignore
-    elif llm_config["vendor"] == "ollama":
+        return ChatBedrock(
+            model_id=getattr(model_config, model_type),
+            region_name=model_config.region_name,
+        )
+    elif vendor == "ollama":
         from langchain_ollama import ChatOllama
 
-        return ChatOllama(model=llm_config[model_type], base_url=llm_config["base_url"])
+        return ChatOllama(
+            model=getattr(model_config, model_type), base_url=model_config.base_url
+        )
     else:
-        raise ValueError(f"Unknown LLM vendor: {llm_config['vendor']}")
+        raise ValueError(f"Unknown LLM vendor: {vendor}")
 
 
 def get_embeddings_model():
-    rai_config = tomli.load(open("config.toml", "rb"))
-    embeddings_config = EmbeddingsConfig(**rai_config["embeddings"])
+    config = load_config()
+    vendor = config.vendor.name
+    model_config = getattr(config, vendor)
 
-    if embeddings_config["vendor"] == "openai":
+    if vendor == "openai":
         from langchain_openai import OpenAIEmbeddings
 
-        return OpenAIEmbeddings()
-    elif embeddings_config["vendor"] == "aws":
+        return OpenAIEmbeddings(model=model_config.embeddings_model)
+    elif vendor == "aws":
         from langchain_aws import BedrockEmbeddings
 
-        return BedrockEmbeddings(model_id=embeddings_config["model"], region_name=embeddings_config["region_name"])  # type: ignore
-    elif embeddings_config["vendor"] == "ollama":
+        return BedrockEmbeddings(
+            model_id=model_config.embeddings_model, region_name=model_config.region_name
+        )
+    elif vendor == "ollama":
         from langchain_ollama import OllamaEmbeddings
 
         return OllamaEmbeddings(
-            model=embeddings_config["model"], base_url=embeddings_config["base_url"]
+            model=model_config.embeddings_model, base_url=model_config.base_url
        )
     else:
-        raise ValueError(f"Unknown embeddings vendor: {embeddings_config['vendor']}")
+        raise ValueError(f"Unknown embeddings vendor: {vendor}")
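
---

Usage sketch (reviewer note, not part of the patch): a minimal example of how the
refactored helpers would be exercised, assuming the module is importable as
rai.utils.model_initialization (per the src/rai/rai layout above), a config.toml
like the one above sits in the working directory, and the selected vendor's
langchain package is installed. The prompt strings are illustrative only.

    from rai.utils.model_initialization import get_embeddings_model, get_llm_model

    # The vendor is resolved once from [vendor].name in config.toml, so the
    # same call works unchanged for openai, aws, or ollama. model_type must be
    # "simple_model" or "complex_model", matching the per-vendor table keys.
    llm = get_llm_model("simple_model")
    print(llm.invoke("Say hello.").content)

    embeddings = get_embeddings_model()
    vector = embeddings.embed_query("Say hello.")
    print(len(vector))  # dimensionality depends on the configured embeddings_model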