Skip to content

Commit

Permalink
Add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
scosman committed Dec 14, 2024
1 parent 13ebbee commit 4bd5cb3
Show file tree
Hide file tree
Showing 2 changed files with 184 additions and 0 deletions.
1 change: 1 addition & 0 deletions .coveragerc
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@
omit =
**/test_*.py
libs/core/kiln_ai/adapters/ml_model_list.py
conftest.py
183 changes: 183 additions & 0 deletions libs/core/kiln_ai/adapters/test_langchain_adapter.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,20 @@
import os
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
from langchain_aws import ChatBedrockConverse
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
from langchain_fireworks import ChatFireworks
from langchain_groq import ChatGroq
from langchain_ollama import ChatOllama
from langchain_openai import ChatOpenAI

from kiln_ai.adapters.langchain_adapters import (
LangchainAdapter,
get_structured_output_options,
langchain_model_from_provider,
)
from kiln_ai.adapters.ml_model_list import KilnModelProvider, ModelProviderName
from kiln_ai.adapters.prompt_builders import SimpleChainOfThoughtPromptBuilder
from kiln_ai.adapters.test_prompt_adaptors import build_test_task

Expand Down Expand Up @@ -150,3 +158,178 @@ async def test_get_structured_output_options():
):
options = await get_structured_output_options("model_name", "provider")
assert options == {}


@pytest.mark.asyncio
async def test_langchain_model_from_provider_openai():
    """The OpenAI provider should produce a ChatOpenAI bound to the requested model."""
    openai_provider = KilnModelProvider(
        name=ModelProviderName.openai,
        provider_options={"model": "gpt-4"},
    )

    with patch("kiln_ai.adapters.langchain_adapters.Config.shared") as mock_config:
        # The factory reads the API key from shared config; supply a fake one.
        mock_config.return_value.open_ai_api_key = "test_key"
        built = await langchain_model_from_provider(openai_provider, "gpt-4")

    assert isinstance(built, ChatOpenAI)
    assert built.model_name == "gpt-4"


@pytest.mark.asyncio
async def test_langchain_model_from_provider_groq():
    """The Groq provider should produce a ChatGroq bound to the requested model."""
    groq_provider = KilnModelProvider(
        name=ModelProviderName.groq,
        provider_options={"model": "mixtral-8x7b"},
    )

    with patch("kiln_ai.adapters.langchain_adapters.Config.shared") as mock_config:
        # The factory reads the API key from shared config; supply a fake one.
        mock_config.return_value.groq_api_key = "test_key"
        built = await langchain_model_from_provider(groq_provider, "mixtral-8x7b")

    assert isinstance(built, ChatGroq)
    assert built.model_name == "mixtral-8x7b"


@pytest.mark.asyncio
async def test_langchain_model_from_provider_bedrock():
    """The Bedrock provider should build a ChatBedrockConverse and export AWS creds.

    The factory publishes the configured credentials via AWS_* environment
    variables (asserted below). Wrap the call in ``patch.dict`` so those
    variables are restored afterwards — previously this test leaked fake
    credentials into the process environment, polluting later tests.
    """
    provider = KilnModelProvider(
        name=ModelProviderName.amazon_bedrock,
        provider_options={"model": "anthropic.claude-v2", "region_name": "us-east-1"},
    )

    with (
        # Snapshot os.environ; any AWS_* vars set by the code under test are
        # rolled back when this context exits.
        patch.dict(os.environ, {}, clear=False),
        patch("kiln_ai.adapters.langchain_adapters.Config.shared") as mock_config,
    ):
        mock_config.return_value.bedrock_access_key = "test_access"
        mock_config.return_value.bedrock_secret_key = "test_secret"
        model = await langchain_model_from_provider(provider, "anthropic.claude-v2")
        assert isinstance(model, ChatBedrockConverse)
        # The factory must export credentials where boto3 can find them.
        assert os.environ.get("AWS_ACCESS_KEY_ID") == "test_access"
        assert os.environ.get("AWS_SECRET_ACCESS_KEY") == "test_secret"


@pytest.mark.asyncio
async def test_langchain_model_from_provider_fireworks():
    """The Fireworks provider should produce a ChatFireworks instance."""
    fireworks_provider = KilnModelProvider(
        name=ModelProviderName.fireworks_ai,
        provider_options={"model": "mixtral-8x7b"},
    )

    with patch("kiln_ai.adapters.langchain_adapters.Config.shared") as mock_config:
        # The factory reads the API key from shared config; supply a fake one.
        mock_config.return_value.fireworks_api_key = "test_key"
        built = await langchain_model_from_provider(fireworks_provider, "mixtral-8x7b")

    assert isinstance(built, ChatFireworks)


@pytest.mark.asyncio
async def test_langchain_model_from_provider_ollama():
    """The Ollama provider should produce a ChatOllama when the model is installed.

    Connection discovery, install check, and base-URL lookup are all patched so
    no local Ollama server is required.
    """
    ollama_provider = KilnModelProvider(
        name=ModelProviderName.ollama,
        provider_options={"model": "llama2", "model_aliases": ["llama2-uncensored"]},
    )

    fake_connection = MagicMock()
    # NOTE(review): the connection getter is replaced so that calling it yields
    # an AsyncMock wrapping fake_connection — mirrors the original test setup.
    conn_patch = patch(
        "kiln_ai.adapters.langchain_adapters.get_ollama_connection",
        return_value=AsyncMock(return_value=fake_connection),
    )
    installed_patch = patch(
        "kiln_ai.adapters.langchain_adapters.ollama_model_installed",
        return_value=True,
    )
    url_patch = patch(
        "kiln_ai.adapters.langchain_adapters.ollama_base_url",
        return_value="http://localhost:11434",
    )

    with conn_patch, installed_patch, url_patch:
        built = await langchain_model_from_provider(ollama_provider, "llama2")
        assert isinstance(built, ChatOllama)
        assert built.model == "llama2"


@pytest.mark.asyncio
async def test_langchain_model_from_provider_invalid():
    """An unknown provider name should raise a ValueError.

    ``model_construct`` bypasses pydantic validation so the invalid name
    reaches the factory's own dispatch logic.
    """
    bogus_provider = KilnModelProvider.model_construct(
        name="invalid_provider", provider_options={}
    )

    with pytest.raises(ValueError, match="Invalid model or provider"):
        await langchain_model_from_provider(bogus_provider, "test_model")


@pytest.mark.asyncio
async def test_langchain_adapter_model_caching(tmp_path):
    """A custom model passed to the adapter is returned as-is and cached.

    Both calls to ``adapter.model()`` must hand back the very same object that
    was injected at construction time.
    """
    task = build_test_task(tmp_path)
    injected = ChatGroq(model="mixtral-8x7b", groq_api_key="test")
    adapter = LangchainAdapter(kiln_task=task, custom_model=injected)

    first = await adapter.model()
    second = await adapter.model()

    # Identity, not equality: the cached instance itself is reused.
    assert first is injected
    assert second is first


@pytest.mark.asyncio
async def test_langchain_adapter_model_structured_output(tmp_path):
    """When the task has an output schema, the adapter wires up structured output.

    The underlying model factory and option lookup are patched; the test checks
    that ``with_structured_output`` receives the parsed schema (plus the title
    and description the adapter injects), ``include_raw=True``, and the provider
    options, and that its return value is what ``adapter.model()`` yields.
    """
    task = build_test_task(tmp_path)
    task.output_json_schema = """
        {
            "type": "object",
            "properties": {
                "count": {"type": "integer"}
            }
        }
        """

    base_model = MagicMock()
    base_model.with_structured_output = MagicMock(return_value="structured_model")

    adapter = LangchainAdapter(
        kiln_task=task, model_name="test_model", provider="test_provider"
    )

    model_patch = patch(
        "kiln_ai.adapters.langchain_adapters.langchain_model_from",
        AsyncMock(return_value=base_model),
    )
    options_patch = patch(
        "kiln_ai.adapters.langchain_adapters.get_structured_output_options",
        AsyncMock(return_value={"option1": "value1"}),
    )
    with model_patch, options_patch:
        result = await adapter.model()

    expected_schema = {
        "type": "object",
        "properties": {"count": {"type": "integer"}},
        "title": "task_response",
        "description": "A response from the task",
    }
    base_model.with_structured_output.assert_called_once_with(
        expected_schema,
        include_raw=True,
        option1="value1",
    )
    assert result == "structured_model"


@pytest.mark.asyncio
async def test_langchain_adapter_model_no_structured_output_support(tmp_path):
    """A structured-output task must fail if the model can't emit structured output.

    Deleting ``with_structured_output`` from the mock makes attribute access
    raise, simulating a model without structured-output support; the adapter is
    expected to surface that as a ValueError.
    """
    task = build_test_task(tmp_path)
    task.output_json_schema = (
        '{"type": "object", "properties": {"count": {"type": "integer"}}}'
    )

    bare_model = MagicMock()
    del bare_model.with_structured_output  # model lacks structured-output API

    adapter = LangchainAdapter(
        kiln_task=task, model_name="test_model", provider="test_provider"
    )

    with (
        patch(
            "kiln_ai.adapters.langchain_adapters.langchain_model_from",
            AsyncMock(return_value=bare_model),
        ),
        pytest.raises(ValueError, match="does not support structured output"),
    ):
        await adapter.model()

0 comments on commit 4bd5cb3

Please sign in to comment.