Skip to content

Commit

Permalink
more tests for COT. This time for the structured path.
Browse files Browse the repository at this point in the history
  • Loading branch information
scosman committed Nov 7, 2024
1 parent 4a6d036 commit 66151a8
Showing 1 changed file with 30 additions and 2 deletions.
32 changes: 30 additions & 2 deletions libs/core/kiln_ai/adapters/test_structured_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@
built_in_models,
ollama_online,
)
from kiln_ai.adapters.prompt_builders import (
BasePromptBuilder,
SimpleChainOfThoughtPromptBuilder,
)
from kiln_ai.adapters.test_prompt_adaptors import get_all_models_and_providers
from kiln_ai.datamodel.test_json_schema import json_joke_schema, json_triangle_schema


Expand Down Expand Up @@ -190,7 +195,18 @@ def build_structured_input_test_task(tmp_path: Path):

async def run_structured_input_test(tmp_path: Path, model_name: str, provider: str):
task = build_structured_input_test_task(tmp_path)
a = LangChainPromptAdapter(task, model_name=model_name, provider=provider)
await run_structured_input_task(task, model_name, provider)


async def run_structured_input_task(
task: datamodel.Task,
model_name: str,
provider: str,
pb: BasePromptBuilder | None = None,
):
a = LangChainPromptAdapter(
task, model_name=model_name, provider=provider, prompt_builder=pb
)
with pytest.raises(ValueError):
# not structured input in dictionary
await a.invoke("a=1, b=2, c=3")
Expand All @@ -203,7 +219,10 @@ async def run_structured_input_test(tmp_path: Path, model_name: str, provider: s
assert isinstance(response, str)
assert "[[equilateral]]" in response
adapter_info = a.adapter_info()
assert adapter_info.prompt_builder_name == "SimplePromptBuilder"
expected_pb_name = "simple_prompt_builder"
if pb is not None:
expected_pb_name = pb.__class__.prompt_builder_name()
assert adapter_info.prompt_builder_name == expected_pb_name
assert adapter_info.model_name == model_name
assert adapter_info.model_provider == provider
assert adapter_info.adapter_name == "kiln_langchain_adapter"
Expand All @@ -224,3 +243,12 @@ async def test_all_built_in_models_structured_input(tmp_path):
await run_structured_input_test(tmp_path, model.name, provider.name)
except Exception as e:
raise RuntimeError(f"Error running {model.name} {provider}") from e


@pytest.mark.paid
@pytest.mark.ollama
@pytest.mark.parametrize("model_name,provider_name", get_all_models_and_providers())
async def test_structured_cot_prompt_builder(tmp_path, model_name, provider_name):
task = build_structured_input_test_task(tmp_path)
pb = SimpleChainOfThoughtPromptBuilder(task)
await run_structured_input_task(task, model_name, provider_name, pb)

0 comments on commit 66151a8

Please sign in to comment.