Skip to content

Commit

Permalink
refactor: simplify config, use dataclasses
Browse files Browse the repository at this point in the history
  • Loading branch information
maciejmajek committed Sep 4, 2024
1 parent 1ef1e91 commit 940aa25
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 39 deletions.
31 changes: 18 additions & 13 deletions config.toml
Original file line number Diff line number Diff line change
@@ -1,14 +1,19 @@
[llm]
# Choose the LLM provider: openai, aws, or ollama
vendor = "openai"
simple_model = "gpt-4o-mini" # Faster, less expensive model for straightforward tasks
complex_model = "gpt-4o" # More powerful model for complex reasoning or generation
region_name = "us-east-1" # AWS region (only applicable when vendor is "aws")
base_url = "http://localhost:11434" # Ollama API endpoint (only applicable when vendor is "ollama")
[vendor]
name = "openai" # openai, aws, ollama

[embeddings]
# Choose the embeddings provider: openai, aws, or ollama
vendor = "openai"
model = "text-embedding-ada-002" # Embedding model to use
region_name = "us-east-1" # AWS region (only applicable when vendor is "aws")
base_url = "http://localhost:11434" # Ollama API endpoint (only applicable when vendor is "ollama")
[aws]
simple_model = "anthropic.claude-3-haiku-20240307-v1:0"
complex_model = "anthropic.claude-3-5-sonnet-20240620-v1:0"
embeddings_model = "amazon.titan-embed-text-v1"
region_name = "us-east-1"

[openai]
simple_model = "gpt-4o-mini"
complex_model = "gpt-4o"
embeddings_model = "text-embedding-ada-002"

[ollama]
simple_model = "llama3.1"
complex_model = "llama3.1:70b"
embeddings_model = "llama3.1"
base_url = "http://localhost:11434"
89 changes: 63 additions & 26 deletions src/rai/rai/utils/model_initialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,62 +12,99 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Literal, TypedDict
from dataclasses import dataclass
from typing import Literal

import tomli


@dataclass
class VendorConfig:
    """Selects which model vendor to use for chat and embeddings models."""

    # Vendor identifier; must match one of the vendor section names in
    # config.toml ("openai", "aws", or "ollama").
    name: str


@dataclass
class ModelConfig:
    """Per-vendor model names; base for every vendor section in config.toml."""

    # Faster, less expensive model for straightforward tasks.
    simple_model: str
    # More powerful model for complex reasoning or generation.
    complex_model: str
    # Model used to compute text embeddings.
    embeddings_model: str


@dataclass
class AWSConfig(ModelConfig):
    """AWS Bedrock settings: the ModelConfig fields plus the AWS region."""

    # AWS region hosting the Bedrock models, e.g. "us-east-1".
    region_name: str


@dataclass
class OllamaConfig(ModelConfig):
    """Ollama settings: the ModelConfig fields plus the API endpoint URL."""

    # Base URL of the Ollama HTTP API, e.g. "http://localhost:11434".
    base_url: str


@dataclass
class RAIConfig:
    """Root configuration mirroring config.toml: one typed field per section."""

    # [vendor] section: which vendor the model factories should use.
    vendor: VendorConfig
    # [aws] section: Bedrock model names and region.
    aws: AWSConfig
    # [openai] section: OpenAI model names.
    openai: ModelConfig
    # [ollama] section: Ollama model names and endpoint.
    ollama: OllamaConfig


def load_config(config_path: str = "config.toml") -> RAIConfig:
    """Parse a TOML configuration file into a RAIConfig.

    Args:
        config_path: Path to the TOML configuration file. Defaults to
            "config.toml" in the current working directory, preserving the
            previous hard-coded behavior.

    Returns:
        A RAIConfig with one typed dataclass per TOML section.

    Raises:
        FileNotFoundError: If the file does not exist.
        KeyError: If a required section ([vendor], [aws], [openai], [ollama])
            is missing.
        TypeError: If a section has missing or unexpected keys.
    """
    with open(config_path, "rb") as f:
        config_dict = tomli.load(f)
    return RAIConfig(
        vendor=VendorConfig(**config_dict["vendor"]),
        aws=AWSConfig(**config_dict["aws"]),
        openai=ModelConfig(**config_dict["openai"]),
        ollama=OllamaConfig(**config_dict["ollama"]),
    )


def get_llm_model(model_type: Literal["simple_model", "complex_model"]):
    """Instantiate a chat model for the vendor configured in config.toml.

    Args:
        model_type: Which configured model to use — "simple_model" (faster,
            cheaper) or "complex_model" (more capable).

    Returns:
        A LangChain chat model: ChatOpenAI, ChatBedrock, or ChatOllama,
        depending on the configured vendor.

    Raises:
        ValueError: If the configured vendor is not one of "openai", "aws",
            or "ollama".
    """
    config = load_config()
    vendor = config.vendor.name
    # Each supported vendor has a matching attribute on RAIConfig holding its
    # model names. Fail fast with a clear ValueError instead of letting
    # getattr raise AttributeError for an unknown vendor.
    model_config = getattr(config, vendor, None)
    if model_config is None:
        raise ValueError(f"Unknown LLM vendor: {vendor}")

    if vendor == "openai":
        # Vendor SDKs are imported lazily so only the selected one must be
        # installed.
        from langchain_openai import ChatOpenAI

        return ChatOpenAI(model=getattr(model_config, model_type))
    elif vendor == "aws":
        from langchain_aws import ChatBedrock

        return ChatBedrock(
            model_id=getattr(model_config, model_type),
            region_name=model_config.region_name,
        )
    elif vendor == "ollama":
        from langchain_ollama import ChatOllama

        return ChatOllama(
            model=getattr(model_config, model_type), base_url=model_config.base_url
        )
    else:
        raise ValueError(f"Unknown LLM vendor: {vendor}")


def get_embeddings_model():
    """Instantiate an embeddings model for the vendor configured in config.toml.

    Returns:
        A LangChain embeddings model: OpenAIEmbeddings, BedrockEmbeddings, or
        OllamaEmbeddings, using the vendor's configured embeddings_model.

    Raises:
        ValueError: If the configured vendor is not one of "openai", "aws",
            or "ollama".
    """
    config = load_config()
    vendor = config.vendor.name
    # Fail fast with a clear ValueError instead of letting getattr raise
    # AttributeError for an unknown vendor.
    model_config = getattr(config, vendor, None)
    if model_config is None:
        raise ValueError(f"Unknown embeddings vendor: {vendor}")

    if vendor == "openai":
        # Vendor SDKs are imported lazily so only the selected one must be
        # installed.
        from langchain_openai import OpenAIEmbeddings

        return OpenAIEmbeddings(model=model_config.embeddings_model)
    elif vendor == "aws":
        from langchain_aws import BedrockEmbeddings

        return BedrockEmbeddings(
            model_id=model_config.embeddings_model,
            region_name=model_config.region_name,
        )
    elif vendor == "ollama":
        from langchain_ollama import OllamaEmbeddings

        return OllamaEmbeddings(
            model=model_config.embeddings_model, base_url=model_config.base_url
        )
    else:
        raise ValueError(f"Unknown embeddings vendor: {vendor}")

0 comments on commit 940aa25

Please sign in to comment.