# Reconstructed from git patch 4bd5cb309a793e868da2f3eed31c03877f8a822f
# ("Add tests", scosman, Sat 14 Dec 2024). The original artifact was a
# format-patch email whose newlines were stripped; this restores the Python
# test additions it made to libs/core/kiln_ai/adapters/test_langchain_adapter.py.
#
# The same patch also made one non-Python change, preserved here for the record:
#   .coveragerc: added "conftest.py" to the [run] omit list.
#
# Pre-existing file content (original lines ~20-157, including the body of
# test_get_structured_output_options) was not part of the patch and is elided;
# some imports below (AIMessage, SimpleChainOfThoughtPromptBuilder, ...) exist
# for those elided tests.

import os
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
from langchain_aws import ChatBedrockConverse
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
from langchain_fireworks import ChatFireworks
from langchain_groq import ChatGroq
from langchain_ollama import ChatOllama
from langchain_openai import ChatOpenAI

from kiln_ai.adapters.langchain_adapters import (
    LangchainAdapter,
    get_structured_output_options,
    langchain_model_from_provider,
)
from kiln_ai.adapters.ml_model_list import KilnModelProvider, ModelProviderName
from kiln_ai.adapters.prompt_builders import SimpleChainOfThoughtPromptBuilder
from kiln_ai.adapters.test_prompt_adaptors import build_test_task

# ... pre-existing tests elided (not part of this patch) ...


@pytest.mark.asyncio
async def test_langchain_model_from_provider_openai():
    """OpenAI provider builds a ChatOpenAI with the requested model name."""
    provider = KilnModelProvider(
        name=ModelProviderName.openai, provider_options={"model": "gpt-4"}
    )

    with patch("kiln_ai.adapters.langchain_adapters.Config.shared") as mock_config:
        mock_config.return_value.open_ai_api_key = "test_key"
        model = await langchain_model_from_provider(provider, "gpt-4")
        assert isinstance(model, ChatOpenAI)
        assert model.model_name == "gpt-4"


@pytest.mark.asyncio
async def test_langchain_model_from_provider_groq():
    """Groq provider builds a ChatGroq with the requested model name."""
    provider = KilnModelProvider(
        name=ModelProviderName.groq, provider_options={"model": "mixtral-8x7b"}
    )

    with patch("kiln_ai.adapters.langchain_adapters.Config.shared") as mock_config:
        mock_config.return_value.groq_api_key = "test_key"
        model = await langchain_model_from_provider(provider, "mixtral-8x7b")
        assert isinstance(model, ChatGroq)
        assert model.model_name == "mixtral-8x7b"


@pytest.mark.asyncio
async def test_langchain_model_from_provider_bedrock():
    """Bedrock provider builds ChatBedrockConverse and exports AWS creds.

    NOTE(review): this asserts a process-wide side effect on os.environ —
    presumably langchain_model_from_provider sets these; confirm, and be aware
    it can leak into other tests run in the same process.
    """
    provider = KilnModelProvider(
        name=ModelProviderName.amazon_bedrock,
        provider_options={"model": "anthropic.claude-v2", "region_name": "us-east-1"},
    )

    with patch("kiln_ai.adapters.langchain_adapters.Config.shared") as mock_config:
        mock_config.return_value.bedrock_access_key = "test_access"
        mock_config.return_value.bedrock_secret_key = "test_secret"
        model = await langchain_model_from_provider(provider, "anthropic.claude-v2")
        assert isinstance(model, ChatBedrockConverse)
        assert os.environ.get("AWS_ACCESS_KEY_ID") == "test_access"
        assert os.environ.get("AWS_SECRET_ACCESS_KEY") == "test_secret"


@pytest.mark.asyncio
async def test_langchain_model_from_provider_fireworks():
    """Fireworks provider builds a ChatFireworks instance."""
    provider = KilnModelProvider(
        name=ModelProviderName.fireworks_ai, provider_options={"model": "mixtral-8x7b"}
    )

    with patch("kiln_ai.adapters.langchain_adapters.Config.shared") as mock_config:
        mock_config.return_value.fireworks_api_key = "test_key"
        model = await langchain_model_from_provider(provider, "mixtral-8x7b")
        assert isinstance(model, ChatFireworks)


@pytest.mark.asyncio
async def test_langchain_model_from_provider_ollama():
    """Ollama provider builds a ChatOllama when the model is reported installed."""
    provider = KilnModelProvider(
        name=ModelProviderName.ollama,
        provider_options={"model": "llama2", "model_aliases": ["llama2-uncensored"]},
    )

    mock_connection = MagicMock()
    # NOTE(review): return_value=AsyncMock(return_value=mock_connection) means
    # the awaited connection is the outer AsyncMock, not mock_connection itself
    # (patch auto-detects async targets and already returns awaitable results).
    # It works here only because ollama_model_installed is also patched True —
    # consider simplifying to return_value=mock_connection; verify against the
    # adapter's get_ollama_connection usage before changing.
    with (
        patch(
            "kiln_ai.adapters.langchain_adapters.get_ollama_connection",
            return_value=AsyncMock(return_value=mock_connection),
        ),
        patch(
            "kiln_ai.adapters.langchain_adapters.ollama_model_installed",
            return_value=True,
        ),
        patch(
            "kiln_ai.adapters.langchain_adapters.ollama_base_url",
            return_value="http://localhost:11434",
        ),
    ):
        model = await langchain_model_from_provider(provider, "llama2")
        assert isinstance(model, ChatOllama)
        assert model.model == "llama2"


@pytest.mark.asyncio
async def test_langchain_model_from_provider_invalid():
    """An unknown provider name raises ValueError.

    model_construct bypasses pydantic validation so the invalid name reaches
    langchain_model_from_provider instead of failing at construction time.
    """
    provider = KilnModelProvider.model_construct(
        name="invalid_provider", provider_options={}
    )

    with pytest.raises(ValueError, match="Invalid model or provider"):
        await langchain_model_from_provider(provider, "test_model")


@pytest.mark.asyncio
async def test_langchain_adapter_model_caching(tmp_path):
    """A custom model passed to the adapter is returned as-is and cached."""
    task = build_test_task(tmp_path)
    custom_model = ChatGroq(model="mixtral-8x7b", groq_api_key="test")

    adapter = LangchainAdapter(kiln_task=task, custom_model=custom_model)

    # First call should return the cached model
    model1 = await adapter.model()
    assert model1 is custom_model

    # Second call should return the same cached instance
    model2 = await adapter.model()
    assert model2 is model1


@pytest.mark.asyncio
async def test_langchain_adapter_model_structured_output(tmp_path):
    """Tasks with an output schema get with_structured_output() applied,
    with the schema augmented by a title/description and include_raw=True,
    plus any provider-specific options."""
    task = build_test_task(tmp_path)
    task.output_json_schema = """
        {
            "type": "object",
            "properties": {
                "count": {"type": "integer"}
            }
        }
        """

    mock_model = MagicMock()
    mock_model.with_structured_output = MagicMock(return_value="structured_model")

    adapter = LangchainAdapter(
        kiln_task=task, model_name="test_model", provider="test_provider"
    )

    with (
        patch(
            "kiln_ai.adapters.langchain_adapters.langchain_model_from",
            AsyncMock(return_value=mock_model),
        ),
        patch(
            "kiln_ai.adapters.langchain_adapters.get_structured_output_options",
            AsyncMock(return_value={"option1": "value1"}),
        ),
    ):
        model = await adapter.model()

        # Verify the model was configured with structured output
        mock_model.with_structured_output.assert_called_once_with(
            {
                "type": "object",
                "properties": {"count": {"type": "integer"}},
                "title": "task_response",
                "description": "A response from the task",
            },
            include_raw=True,
            option1="value1",
        )
        assert model == "structured_model"


@pytest.mark.asyncio
async def test_langchain_adapter_model_no_structured_output_support(tmp_path):
    """A model lacking with_structured_output() raises for schema'd tasks."""
    task = build_test_task(tmp_path)
    task.output_json_schema = (
        '{"type": "object", "properties": {"count": {"type": "integer"}}}'
    )

    mock_model = MagicMock()
    # Remove with_structured_output method so hasattr() checks fail
    del mock_model.with_structured_output

    adapter = LangchainAdapter(
        kiln_task=task, model_name="test_model", provider="test_provider"
    )

    with patch(
        "kiln_ai.adapters.langchain_adapters.langchain_model_from",
        AsyncMock(return_value=mock_model),
    ):
        with pytest.raises(ValueError, match="does not support structured output"):
            await adapter.model()