Add tests
scosman committed Dec 14, 2024
1 parent 1c9d767 commit 13ebbee
Showing 1 changed file with 203 additions and 1 deletion.
libs/core/kiln_ai/adapters/test_provider_tools.py
@@ -1,21 +1,33 @@
-from unittest.mock import AsyncMock, patch
+from unittest.mock import AsyncMock, Mock, patch

import pytest

from kiln_ai.adapters.ml_model_list import (
KilnModel,
ModelName,
ModelProviderName,
)
from kiln_ai.adapters.ollama_tools import OllamaConnection
from kiln_ai.adapters.provider_tools import (
builtin_model_from,
check_provider_warnings,
finetune_cache,
finetune_provider_model,
get_model_and_provider,
kiln_model_provider_from,
provider_enabled,
provider_name_from_id,
provider_options_for_custom_model,
provider_warnings,
)
from kiln_ai.datamodel import Finetune, Task


@pytest.fixture(autouse=True)
def clear_finetune_cache():
"""Clear the finetune provider model cache before each test"""
finetune_cache.clear()
yield


@pytest.fixture
@@ -24,6 +36,34 @@ def mock_config():
yield mock


@pytest.fixture
def mock_project():
with patch("kiln_ai.adapters.provider_tools.project_from_id") as mock:
project = Mock()
project.path = "/fake/path"
mock.return_value = project
yield mock


@pytest.fixture
def mock_task():
with patch("kiln_ai.datamodel.Task.from_id_and_parent_path") as mock:
task = Mock(spec=Task)
task.path = "/fake/path/task"
mock.return_value = task
yield mock


@pytest.fixture
def mock_finetune():
with patch("kiln_ai.datamodel.Finetune.from_id_and_parent_path") as mock:
finetune = Mock(spec=Finetune)
finetune.provider = ModelProviderName.openai
finetune.fine_tune_model_id = "ft:gpt-3.5-turbo:custom:model-123"
mock.return_value = finetune
yield mock


def test_check_provider_warnings_no_warning(mock_config):
mock_config.return_value = "some_value"

@@ -103,6 +143,8 @@ def test_provider_name_from_id_case_sensitivity():
(ModelProviderName.ollama, "Ollama"),
(ModelProviderName.openai, "OpenAI"),
(ModelProviderName.fireworks_ai, "Fireworks AI"),
(ModelProviderName.kiln_fine_tune, "Fine Tuned Models"),
(ModelProviderName.kiln_custom_registry, "Custom Models"),
],
)
def test_provider_name_from_id_parametrized(provider_id, expected_name):
@@ -327,3 +369,163 @@ async def test_kiln_model_provider_from_custom_registry(mock_config):
assert provider.supports_data_gen is False
assert provider.untested_model is True
assert provider.provider_options == {"model": "gpt-4-turbo"}


@pytest.mark.asyncio
async def test_builtin_model_from_invalid_model():
"""Test that an invalid model name returns None"""
result = await builtin_model_from("non_existent_model")
assert result is None


@pytest.mark.asyncio
async def test_builtin_model_from_valid_model_default_provider(mock_config):
"""Test getting a valid model with default provider"""
mock_config.return_value = "fake-api-key"

provider = await builtin_model_from(ModelName.phi_3_5)

assert provider is not None
assert provider.name == ModelProviderName.ollama
assert provider.provider_options["model"] == "phi3.5"


@pytest.mark.asyncio
async def test_builtin_model_from_valid_model_specific_provider(mock_config):
"""Test getting a valid model with specific provider"""
mock_config.return_value = "fake-api-key"

provider = await builtin_model_from(
ModelName.llama_3_1_70b, provider_name=ModelProviderName.groq
)

assert provider is not None
assert provider.name == ModelProviderName.groq
assert provider.provider_options["model"] == "llama-3.1-70b-versatile"


@pytest.mark.asyncio
async def test_builtin_model_from_invalid_provider(mock_config):
"""Test that requesting an invalid provider returns None"""
mock_config.return_value = "fake-api-key"

provider = await builtin_model_from(
ModelName.phi_3_5, provider_name="invalid_provider"
)

assert provider is None


@pytest.mark.asyncio
async def test_builtin_model_from_model_no_providers():
"""Test handling of a model with no providers"""
with patch("kiln_ai.adapters.provider_tools.built_in_models") as mock_models:
# Create a mock model with no providers
mock_model = KilnModel(
name=ModelName.phi_3_5,
friendly_name="Test Model",
providers=[],
family="test_family",
)
mock_models.__iter__.return_value = [mock_model]

with pytest.raises(ValueError) as exc_info:
await builtin_model_from(ModelName.phi_3_5)

assert str(exc_info.value) == f"Model {ModelName.phi_3_5} has no providers"


@pytest.mark.asyncio
async def test_builtin_model_from_provider_warning_check(mock_config):
"""Test that provider warnings are checked"""
# Make the config check fail
mock_config.return_value = None

with pytest.raises(ValueError) as exc_info:
await builtin_model_from(ModelName.llama_3_1_70b, ModelProviderName.groq)

assert provider_warnings[ModelProviderName.groq].message in str(exc_info.value)
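
# A minimal sketch of the resolution order the builtin_model_from tests above
# pin down. Illustrative only, inferred from the assertions; in particular
# "first provider is the default" is an assumption, not confirmed here:
#
#     model = next((m for m in built_in_models if m.name == name), None)
#     if model is None:
#         return None  # unknown model name
#     if not model.providers:
#         raise ValueError(f"Model {name} has no providers")
#     if provider_name is None:
#         provider = model.providers[0]  # assumed default: first listed
#     else:
#         provider = next(
#             (p for p in model.providers if p.name == provider_name), None
#         )
#         if provider is None:
#             return None  # unknown provider name
#     check_provider_warnings(provider.name)  # may raise ValueError
#     return provider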

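# The finetune_provider_model tests below use model IDs of the form
# "<project_id>::<task_id>::<finetune_id>", resolved through the chain
# project_from_id -> Task.from_id_and_parent_path ->
# Finetune.from_id_and_parent_path mocked by the fixtures above.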

def test_finetune_provider_model_success(mock_project, mock_task, mock_finetune):
"""Test successful creation of a fine-tuned model provider"""
model_id = "project-123::task-456::finetune-789"

provider = finetune_provider_model(model_id)

assert provider.name == ModelProviderName.openai
assert provider.provider_options == {"model": "ft:gpt-3.5-turbo:custom:model-123"}

# Test cache
cached_provider = finetune_provider_model(model_id)
assert cached_provider is provider


def test_finetune_provider_model_invalid_id():
"""Test handling of invalid model ID format"""
with pytest.raises(ValueError) as exc_info:
finetune_provider_model("invalid-id-format")
assert str(exc_info.value) == "Invalid fine tune ID: invalid-id-format"


def test_finetune_provider_model_project_not_found(mock_project):
"""Test handling of non-existent project"""
mock_project.return_value = None

with pytest.raises(ValueError) as exc_info:
finetune_provider_model("project-123::task-456::finetune-789")
assert str(exc_info.value) == "Project project-123 not found"


def test_finetune_provider_model_task_not_found(mock_project, mock_task):
"""Test handling of non-existent task"""
mock_task.return_value = None

with pytest.raises(ValueError) as exc_info:
finetune_provider_model("project-123::task-456::finetune-789")
assert str(exc_info.value) == "Task task-456 not found"


def test_finetune_provider_model_finetune_not_found(
mock_project, mock_task, mock_finetune
):
"""Test handling of non-existent fine-tune"""
mock_finetune.return_value = None

with pytest.raises(ValueError) as exc_info:
finetune_provider_model("project-123::task-456::finetune-789")
assert str(exc_info.value) == "Fine tune finetune-789 not found"


def test_finetune_provider_model_incomplete_finetune(
mock_project, mock_task, mock_finetune
):
"""Test handling of incomplete fine-tune"""
finetune = Mock(spec=Finetune)
finetune.fine_tune_model_id = None
mock_finetune.return_value = finetune

with pytest.raises(ValueError) as exc_info:
finetune_provider_model("project-123::task-456::finetune-789")
assert (
str(exc_info.value)
== "Fine tune finetune-789 not completed. Refresh it's status in the fine-tune tab."
)


def test_finetune_provider_model_fireworks_provider(
mock_project, mock_task, mock_finetune
):
"""Test creation of Fireworks AI provider with specific adapter options"""
finetune = Mock(spec=Finetune)
finetune.provider = ModelProviderName.fireworks_ai
finetune.fine_tune_model_id = "fireworks-model-123"
mock_finetune.return_value = finetune

provider = finetune_provider_model("project-123::task-456::finetune-789")

assert provider.name == ModelProviderName.fireworks_ai
assert provider.provider_options == {"model": "fireworks-model-123"}
assert provider.adapter_options == {
"langchain": {"with_structured_output_options": {"method": "json_mode"}}
}
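
# For reference, a minimal sketch of the behavior the finetune_provider_model
# tests above pin down. Illustrative only: the helper signatures and the
# KilnModelProvider constructor are inferred from the mocks and assertions,
# not copied from the real implementation; error strings (including the
# "it's" typo) match the assertions verbatim.
#
#     def finetune_provider_model(model_id: str) -> KilnModelProvider:
#         if model_id in finetune_cache:
#             return finetune_cache[model_id]
#         parts = model_id.split("::")
#         if len(parts) != 3:
#             raise ValueError(f"Invalid fine tune ID: {model_id}")
#         project_id, task_id, finetune_id = parts
#         project = project_from_id(project_id)
#         if project is None:
#             raise ValueError(f"Project {project_id} not found")
#         task = Task.from_id_and_parent_path(task_id, project.path)
#         if task is None:
#             raise ValueError(f"Task {task_id} not found")
#         finetune = Finetune.from_id_and_parent_path(finetune_id, task.path)
#         if finetune is None:
#             raise ValueError(f"Fine tune {finetune_id} not found")
#         if finetune.fine_tune_model_id is None:
#             raise ValueError(
#                 f"Fine tune {finetune_id} not completed. "
#                 "Refresh it's status in the fine-tune tab."
#             )
#         adapter_options = {}
#         if finetune.provider == ModelProviderName.fireworks_ai:
#             adapter_options = {
#                 "langchain": {
#                     "with_structured_output_options": {"method": "json_mode"}
#                 }
#             }
#         provider = KilnModelProvider(
#             name=finetune.provider,
#             provider_options={"model": finetune.fine_tune_model_id},
#             adapter_options=adapter_options,
#         )
#         finetune_cache[model_id] = provider
#         return provider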
