Skip to content

Commit

Permalink
Rename model_*, pydantic reserves that namespace
Browse files Browse the repository at this point in the history
  • Loading branch information
scosman committed Sep 3, 2024
1 parent 4e56611 commit 396e6bd
Showing 1 changed file with 23 additions and 25 deletions.
48 changes: 23 additions & 25 deletions libs/core/kiln_ai/adapters/ml_model_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,17 +42,17 @@ class KilnModelProvider(BaseModel):


class KilnModel(BaseModel):
    """A built-in model definition: its family, canonical name, and the providers that serve it.

    Fields are named `family`/`name` (not `model_family`/`model_name`) because
    pydantic reserves the `model_` attribute namespace for its own use.
    """

    # Model family the entry belongs to (e.g. gpt, llama, mistral, phi).
    family: str
    # Canonical model identifier, expected to match a ModelName enum member.
    name: str
    # Providers able to serve this model; at least one is expected by callers.
    providers: List[KilnModelProvider]
    # Whether the model reliably produces structured output (default: True).
    supports_structured_output: bool = True


built_in_models: List[KilnModel] = [
# GPT 4o Mini
KilnModel(
model_family=ModelFamily.gpt,
model_name=ModelName.gpt_4o_mini,
family=ModelFamily.gpt,
name=ModelName.gpt_4o_mini,
providers=[
KilnModelProvider(
name=ModelProviderName.openai,
Expand All @@ -62,8 +62,8 @@ class KilnModel(BaseModel):
),
# GPT 4o
KilnModel(
model_family=ModelFamily.gpt,
model_name=ModelName.gpt_4o,
family=ModelFamily.gpt,
name=ModelName.gpt_4o,
providers=[
KilnModelProvider(
name=ModelProviderName.openai,
Expand All @@ -73,8 +73,8 @@ class KilnModel(BaseModel):
),
# Llama 3.1-8b
KilnModel(
model_family=ModelFamily.llama,
model_name=ModelName.llama_3_1_8b,
family=ModelFamily.llama,
name=ModelName.llama_3_1_8b,
providers=[
KilnModelProvider(
name=ModelProviderName.groq,
Expand All @@ -97,8 +97,8 @@ class KilnModel(BaseModel):
),
# Llama 3.1 70b
KilnModel(
model_family=ModelFamily.llama,
model_name=ModelName.llama_3_1_70b,
family=ModelFamily.llama,
name=ModelName.llama_3_1_70b,
providers=[
KilnModelProvider(
name=ModelProviderName.groq,
Expand All @@ -122,8 +122,8 @@ class KilnModel(BaseModel):
),
# Mistral Large
KilnModel(
model_family=ModelFamily.mistral,
model_name=ModelName.mistral_large,
family=ModelFamily.mistral,
name=ModelName.mistral_large,
providers=[
KilnModelProvider(
name=ModelProviderName.amazon_bedrock,
Expand All @@ -141,8 +141,8 @@ class KilnModel(BaseModel):
),
# Phi 3.5
KilnModel(
model_family=ModelFamily.phi,
model_name=ModelName.phi_3_5,
family=ModelFamily.phi,
name=ModelName.phi_3_5,
supports_structured_output=False,
providers=[
KilnModelProvider(
Expand All @@ -154,21 +154,19 @@ class KilnModel(BaseModel):
]


def langchain_model_from(
model_name: str, provider_name: str | None = None
) -> BaseChatModel:
if model_name not in ModelName.__members__:
raise ValueError(f"Invalid model_name: {model_name}")
def langchain_model_from(name: str, provider_name: str | None = None) -> BaseChatModel:
if name not in ModelName.__members__:
raise ValueError(f"Invalid name: {name}")

# Select the model from built_in_models using the model_name
model = next(filter(lambda m: m.model_name == model_name, built_in_models))
# Select the model from built_in_models using the name
model = next(filter(lambda m: m.name == name, built_in_models))
if model is None:
raise ValueError(f"Model {model_name} not found")
raise ValueError(f"Model {name} not found")

# If a provider is provided, select the provider from the model's provider_config
provider: KilnModelProvider | None = None
if model.providers is None or len(model.providers) == 0:
raise ValueError(f"Model {model_name} has no providers")
raise ValueError(f"Model {name} has no providers")
elif provider_name is None:
# TODO: priority order
provider = model.providers[0]
Expand All @@ -177,7 +175,7 @@ def langchain_model_from(
filter(lambda p: p.name == provider_name, model.providers), None
)
if provider is None:
raise ValueError(f"Provider {provider_name} not found for model {model_name}")
raise ValueError(f"Provider {provider_name} not found for model {name}")

if provider.name == ModelProviderName.openai:
return ChatOpenAI(**provider.provider_options)
Expand All @@ -188,7 +186,7 @@ def langchain_model_from(
elif provider.name == ModelProviderName.ollama:
return ChatOllama(**provider.provider_options, base_url=ollama_base_url())
else:
raise ValueError(f"Invalid model or provider: {model_name} - {provider_name}")
raise ValueError(f"Invalid model or provider: {name} - {provider_name}")


def ollama_base_url():
Expand Down

0 comments on commit 396e6bd

Please sign in to comment.