[AIC-py][eval] small refactor to ease defining metrics, add tests
- Add decorators for metric creation and reimplement some existing ones
- Add a unit test


Test plan:

Existing and new unit tests.
jonathanlastmileai committed Jan 20, 2024
1 parent 204ee2f commit 4f024fa
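
The core of the change is a new @metric decorator (and an async twin, metric_async) in metrics.py that turns a plain scoring function into a factory for Metric objects, so defining a metric no longer requires hand-writing a Metric(...) / EvaluationMetricMetadata(...) block. A minimal sketch of a user-defined metric on top of this API; the exact_match name and its expected parameter are illustrative, not part of the commit:

from aiconfig.eval.metrics import metric

@metric(
    description="True (pass) if the output equals the expected string",
    best_value=True,
    worst_value=False,
)
def exact_match(datum: str, expected: str) -> bool:
    # The first parameter is the evaluated output; any remaining parameters
    # become arguments of the generated factory.
    return datum == expected

# Calling the decorated name with the trailing parameters returns a Metric,
# ready to use alongside brevity and substring_match(...).
my_metric = exact_match("Hello, world!")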
Showing 4 changed files with 199 additions and 64 deletions.
python/src/aiconfig/eval/lib.py (2 changes: 1 addition & 1 deletion)
@@ -162,7 +162,7 @@ def _extract_value(input_text_datum: TextBasedInputDatum | None) -> str | None:
case str(input_text):
return input_text
case frozendict():
return json.dumps(input_text_datum.value)
return json.dumps(input_text_datum.value, sort_keys=True)

return [
SampleEvaluationResult(
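The sort_keys=True addition makes the serialized form of dict-valued inputs deterministic; the new test below compares the DataFrame's "input" column against json.dumps(frozendict(...), sort_keys=True), which only matches reliably if both sides sort their keys. A standalone illustration (not repo code):

import json

# Two equal dicts built in different insertion orders serialize differently by default...
a = {"x": "1", "y": "2"}
b = {"y": "2", "x": "1"}
assert json.dumps(a) != json.dumps(b)

# ...but identically once keys are sorted, so string and set comparisons line up.
assert json.dumps(a, sort_keys=True) == json.dumps(b, sort_keys=True)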
python/src/aiconfig/eval/metrics.py (125 changes: 84 additions & 41 deletions)
@@ -3,7 +3,7 @@
from abc import abstractmethod
from dataclasses import dataclass
from functools import partial, total_ordering
from typing import Any, Callable, Generic, Protocol, Type
from typing import Any, Awaitable, Callable, Concatenate, Generic, ParamSpec, Protocol, Type

import lastmile_utils.lib.core.api as core_utils
import nltk
@@ -29,22 +29,65 @@ async def __call__(self, datum: common.T_Evaluable) -> common.T_MetricValue:
return await self.evaluation_fn(datum)


def _check_substring(
output_datum: str,
substring: str,
#
case_sensitive: bool,
) -> bool:
if case_sensitive:
return substring in output_datum
else:
return substring.lower() in output_datum.lower()
PS = ParamSpec("PS")


async def _calculate_brevity(datum: str) -> int:
if len(datum) == 0:
raise ValueError("Brevity is meaningless for empty string.")
return len(datum)
@core_utils.parametrized
def metric(
parametrized_evaluation_fn: Callable[Concatenate[common.T_Evaluable, PS], common.T_MetricValue],
name: str | None = None,
description: str | None = None,
best_value: common.T_MetricValue | None = None,
worst_value: common.T_MetricValue | None = None,
) -> Callable[PS, Metric[common.T_Evaluable, common.T_MetricValue]]:
name_ = name or parametrized_evaluation_fn.__name__
description_ = description or name_

def _construct(*args: PS.args, **kwargs: PS.kwargs) -> Metric[common.T_Evaluable, common.T_MetricValue]:
async def evaluation_fn(datum: common.T_Evaluable) -> common.T_MetricValue:
return parametrized_evaluation_fn(datum, *args, **kwargs)

return Metric(
evaluation_fn=evaluation_fn,
metric_metadata=common.EvaluationMetricMetadata(
name=name_,
description=description_,
best_value=best_value,
worst_value=worst_value,
extra_metadata=dict(args=args, **kwargs),
),
)

return _construct


@core_utils.parametrized
def metric_async(
parametrized_evaluation_fn: Callable[Concatenate[common.T_Evaluable, PS], Awaitable[common.T_MetricValue]],
name: str | None = None,
description: str | None = None,
best_value: common.T_MetricValue | None = None,
worst_value: common.T_MetricValue | None = None,
) -> Callable[PS, Metric[common.T_Evaluable, common.T_MetricValue]]:
name_ = name or parametrized_evaluation_fn.__name__
description_ = description or name_

def _construct(*args: PS.args, **kwargs: PS.kwargs) -> Metric[common.T_Evaluable, common.T_MetricValue]:
async def evaluation_fn(datum: common.T_Evaluable) -> common.T_MetricValue:
return await parametrized_evaluation_fn(datum, *args, **kwargs)

return Metric(
evaluation_fn=evaluation_fn,
metric_metadata=common.EvaluationMetricMetadata(
name=name_,
description=description_,
best_value=best_value,
worst_value=worst_value,
extra_metadata=dict(args=args, **kwargs),
),
)

return _construct


@dataclass(frozen=True)
@@ -261,37 +304,37 @@ def make_openai_structured_llm_metric(
raise ValueError(f"Error making metric: {e}")


def substring_match(substring: str, case_sensitive: bool = True) -> Metric[str, bool]:
async def _fn(datum: str) -> bool:
return _check_substring(
output_datum=datum,
substring=substring,
case_sensitive=case_sensitive,
)
# 2. literal metrics

return Metric(
evaluation_fn=_fn,
metric_metadata=common.EvaluationMetricMetadata(
name="substring_match",
description="True (pass) if contains given substring",
best_value=True,
worst_value=False,
extra_metadata=dict(substring=substring, case_sensitive=case_sensitive),
),
)

@metric(
#
description="True (pass) if contains given substring",
best_value=True,
worst_value=False,
)
def substring_match(datum: str, substring: str, case_sensitive: bool = True) -> bool:
if case_sensitive:
return substring in datum
else:
return substring.lower() in datum.lower()

# 2. literal metrics

brevity: Metric[str, int] = Metric(
evaluation_fn=_calculate_brevity,
metric_metadata=common.EvaluationMetricMetadata(
name="brevity",
description="Absolute text length",
best_value=1,
worst_value=sys.maxsize,
),
@metric(
#
description="Absolute text length",
name="brevity",
best_value=1,
worst_value=sys.maxsize,
)
def make_brevity(datum: str):
if len(datum) == 0:
raise ValueError("Brevity is meaningless for empty string.")
return len(datum)


# For backwards-compatibility
brevity = make_brevity()


gpt3_5_text_ratings = make_openai_structured_llm_metric(
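With the decorator in place, calling the decorated name with its trailing parameters builds the Metric: substring_match("hello", case_sensitive=False) returns a Metric[str, bool], make_brevity() returns the length metric, and the module-level brevity constant is kept for backwards compatibility. For scoring functions that must await something, metric_async mirrors metric around a coroutine; a sketch under the same assumptions (stripped_length is illustrative, not part of the commit):

import sys

from aiconfig.eval.metrics import metric_async

@metric_async(
    description="Length of the output after stripping surrounding whitespace",
    best_value=1,
    worst_value=sys.maxsize,
)
async def stripped_length(datum: str) -> int:
    # A real async metric would typically await an external call (an API, a model) here.
    return len(datum.strip())

# As with @metric, a parameterless scoring function yields a zero-argument factory.
stripped_length_metric = stripped_length()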
python/tests/mocks.py (58 changes: 37 additions & 21 deletions)
@@ -1,28 +1,44 @@
from typing import Any
from abc import abstractmethod
from typing import Any, Protocol

from aiconfig.Config import AIConfigRuntime
from aiconfig.model_parser import InferenceOptions


class MockAIConfigRuntime(AIConfigRuntime):
def __init__(self):
class MockRunTextToText(Protocol):
@abstractmethod
async def __call__(self, prompt_name: str, params: dict[str, str]) -> str:
pass

async def run_and_get_output_text(
self,
prompt_name: str,
params: dict[Any, Any] | None = None,
options: InferenceOptions | None = None,
**kwargs, # type: ignore
) -> str:
"""
This overrides the real method for mocking, but the output doesn't matter very much.
We're currently not really testing properties of the output.
We just have to return a string so the tests work.
Real method: https://github.com/lastmile-ai/aiconfig/blob/a4376d1f951e19776633d397a3cda7fa85506eef/python/src/aiconfig/Config.py#L277
"""
params_ = params or {}
assert params_.keys() == {"the_query"}, 'For eval, AIConfig params must have just the key "the_query".'
the_query = params_["the_query"]
return f"output_for_{prompt_name}_the_query_{the_query}"

def make_mock_aiconfig_runtime(mock_run_text_to_text: MockRunTextToText | None = None) -> AIConfigRuntime:
async def _default_mock_run_text_to_text(prompt_name: str, params: dict[str, str]) -> str:
return f"output_for_{prompt_name}_the_query_{params['the_query']}"

mock_run_text_to_text_impl = _default_mock_run_text_to_text if mock_run_text_to_text is None else mock_run_text_to_text

class _MockAIConfigRuntime(AIConfigRuntime):
def __init__(self):
pass

async def run_and_get_output_text(
self,
prompt_name: str,
params: dict[Any, Any] | None = None,
options: InferenceOptions | None = None,
**kwargs, # type: ignore
) -> str:
"""
This overrides the real method for mocking, but the output doesn't matter very much.
We're currently not really testing properties of the output.
We just have to return a string so the tests work.
Real method: https://github.com/lastmile-ai/aiconfig/blob/a4376d1f951e19776633d397a3cda7fa85506eef/python/src/aiconfig/Config.py#L277
"""
params_ = params or {}
return await mock_run_text_to_text_impl(prompt_name, params_)

return _MockAIConfigRuntime()


MockAIConfigRuntime = lambda: make_mock_aiconfig_runtime()
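
The mock is now a factory, so a test can inject its own text-to-text behavior while existing tests keep calling MockAIConfigRuntime() unchanged. A usage sketch; the canned_output function is illustrative, and the import is written however the test module actually imports mocks:

import mocks  # as done in python/tests/test_eval.py

async def canned_output(prompt_name: str, params: dict[str, str]) -> str:
    return f"{prompt_name}: " + ", ".join(f"{k}={v}" for k, v in params.items())

# Inject custom behavior for a test that passes arbitrary params...
runtime = mocks.make_mock_aiconfig_runtime(canned_output)

# ...or take the default, which (like the old mock) expects a single "the_query" param.
default_runtime = mocks.make_mock_aiconfig_runtime()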
python/tests/test_eval.py (78 changes: 77 additions & 1 deletion)
@@ -1,7 +1,9 @@
import itertools
import json
import logging
import os
from typing import Any
from frozendict import frozendict

import hypothesis
import hypothesis.strategies as st
@@ -188,6 +190,79 @@ async def test_run_test_suite_with_inputs(data: st.DataObject):
assert False, f"expected Ok, got Err({e})"


@hypothesis.given(st.data())
@pytest.mark.asyncio
async def test_run_test_suite_with_inputs_general_params(data: st.DataObject):
"""In test_run_test_suite_outputs_only, we test the user-facing function (e2e)
In this case that's harder because run_test_suite_with_inputs takes
an aiconfig path, not object.
In order to test with a mock AIConfig object, in this test we go one level down
and test run_test_suite_helper().
Also see test_run_with_inputs_sanity_check.
"""
metrics_list = [brevity, substring_match("hello")]
inputs = st.dictionaries(st.text(min_size=1), st.text(min_size=1), min_size=0, max_size=2)
test_pairs = st.tuples(inputs, st.sampled_from(metrics_list))
user_test_suite_with_inputs = data.draw(
st.lists(
test_pairs,
min_size=1,
)
)

async def mock_run_text_to_text(prompt_name: str, params: dict[str, str]) -> str:
return f"{prompt_name}_output." + ",".join(f"{key=};{value=}" for key, value in params.items())

mock_aiconfig = mocks.make_mock_aiconfig_runtime(mock_run_text_to_text)

out = await run_test_suite_helper(
TestSuiteWithInputsSpec(
test_suite=user_test_suite_with_inputs, prompt_name="prompt0", aiconfig=mock_aiconfig, general_settings=TestSuiteGeneralSettings()
)
)

df_out = out.map(text_eval_res_to_df)

match df_out:
case Ok(df):
assert isinstance(df, pd.DataFrame)
assert df.shape[0] == (len(user_test_suite_with_inputs))
assert df.columns.tolist() == [
"input",
"aiconfig_output",
"value",
"metric_id",
"metric_name",
"metric_description",
"best_possible_value",
"worst_possible_value",
]

input_pairs = {
(
#
json.dumps(frozendict(input_datum), sort_keys=True),
metric.metric_metadata.id,
)
for input_datum, metric in user_test_suite_with_inputs
}
result_pairs = set( # type: ignore[no-untyped-call]
df[["input", "metric_id"]].itertuples(index=False, name=None) # type: ignore[no-untyped-call]
)

assert input_pairs == result_pairs, f"fail: {input_pairs=}, {result_pairs=}"

df_brevity = df[df["metric_name"] == "brevity"] # type: ignore
assert (df_brevity["aiconfig_output"].apply(len) == df_brevity["value"]).all() # type: ignore

df_substring = df[df["metric_name"] == "substring_match"] # type: ignore
assert (df_substring["value"].apply(lambda x: x in {False, True})).all() # type: ignore

case Err(e):
assert False, f"expected Ok, got Err({e})"


def _make_mock_nltk_metrics() -> MetricList[str]:
def _mock_get_nltk_polarity_scores(text: str) -> dict[str, float]:
return MOCK_NLTK_SENTIMENT_SCORE_MAPPING[text]
@@ -246,11 +321,12 @@ async def test_exception_metric(caplog: pytest.LogCaptureFixture):
user_test_suite_outputs_only = list(
itertools.product(
["Hundred Acre Wood", ""],
[metrics.brevity],
[brevity],
)
)
with caplog.at_level(logging.ERROR):
df = await run_test_suite_outputs_only(user_test_suite_outputs_only)
print(df[["metric_name"]])
mapping: dict[str, Any] = df.query("metric_name=='brevity'").set_index("aiconfig_output").value.to_dict() # type: ignore
assert mapping["Hundred Acre Wood"] == 17.0
assert pd.isnull(mapping[""]) # type: ignore
