diff --git a/vizro-ai/tests/unit/vizro-ai/dashboard/_response_models/conftest.py b/vizro-ai/tests/unit/vizro-ai/dashboard/_response_models/conftest.py
index 7533749c7..2ac65ee70 100644
--- a/vizro-ai/tests/unit/vizro-ai/dashboard/_response_models/conftest.py
+++ b/vizro-ai/tests/unit/vizro-ai/dashboard/_response_models/conftest.py
@@ -9,7 +9,7 @@ from vizro_ai.dashboard.utils import AllDfMetadata, DfMetadata
 
 
-class FakeListLLM(FakeListLLM):
+class MockStructuredOutputLLM(FakeListLLM):
     def bind_tools(self, tools: List[Any]):
         return super().bind(tools=tools)
 
@@ -22,19 +22,25 @@ def with_structured_output(self, schema):
 @pytest.fixture
 def fake_llm_card():
     response = ['{"text":"this is a card","href":""}']
-    return FakeListLLM(responses=response)
+    return MockStructuredOutputLLM(responses=response)
 
 
 @pytest.fixture
 def fake_llm_layout():
     response = ['{"grid":[[0,1]]}']
-    return FakeListLLM(responses=response)
+    return MockStructuredOutputLLM(responses=response)
 
 
 @pytest.fixture
 def fake_llm_filter():
     response = ['{"column": "a", "targets": ["gdp_chart"]}']
-    return FakeListLLM(responses=response)
+    return MockStructuredOutputLLM(responses=response)
+
+
+@pytest.fixture
+def fake_llm_filter_1():
+    response = ['{"column": "country", "targets": ["gdp_chart"]}']
+    return MockStructuredOutputLLM(responses=response)
 
 
 @pytest.fixture
@@ -42,6 +48,11 @@ def df_cols():
     return ["continent", "country", "population", "gdp"]
 
 
+@pytest.fixture
+def df_schema_1():
+    return {"continent": "object", "country": "object", "population": "int64", "gdp": "int64"}
+
+
 @pytest.fixture
 def controllable_components():
     return ["gdp_chart"]
@@ -59,7 +70,8 @@ def df():
 
 @pytest.fixture
 def df_sample():
-    return pd.DataFrame({"a": [1, 2, 3, 4, 5], "b": [4, 5, 6, 7, 8]})
+    df = pd.DataFrame({"a": [1, 2, 3, 4, 5], "b": [4, 5, 6, 7, 8]})
+    return df.sample(5, replace=True, random_state=19)
 
 
 @pytest.fixture
@@ -101,3 +113,11 @@ def component_card_2():
 @pytest.fixture
 def page_plan(component_card):
     return PagePlan(title="Test Page", components_plan=[component_card], controls_plan=[], layout_plan=None)
+
+
+@pytest.fixture
+def filter_prompt():
+    return """
+    Create a filter from the following instructions: Filter the gdp chart by country.
+    Do not make up things that are optional and DO NOT configure actions, action triggers or action chains.
+    If no options are specified, leave them out."""
diff --git a/vizro-ai/tests/unit/vizro-ai/dashboard/_response_models/test_controls.py b/vizro-ai/tests/unit/vizro-ai/dashboard/_response_models/test_controls.py
index fe08e75ed..fc8a9e008 100644
--- a/vizro-ai/tests/unit/vizro-ai/dashboard/_response_models/test_controls.py
+++ b/vizro-ai/tests/unit/vizro-ai/dashboard/_response_models/test_controls.py
@@ -1,14 +1,20 @@
 import pytest
-from vizro_ai.dashboard._response_models.controls import ControlPlan, _create_filter_proxy
+import vizro.models as vm
+from vizro.managers import model_manager
+from vizro.models import VizroBaseModel
+from vizro_ai.dashboard._response_models.controls import ControlPlan, _create_filter, _create_filter_proxy
 
 try:
     from pydantic.v1 import ValidationError
 except ImportError:  # pragma: no cov
     from pydantic import ValidationError
 
+# Needed for testing control creation.
+model_manager.__setitem__("gdp_chart", VizroBaseModel)
 
-class TestControlCreate:
-    """Tests control creation."""
+
+class TestFilterProxyCreate:
+    """Tests filter proxy creation."""
 
     def test_create_filter_proxy_validate_targets(self, df_cols, df_schema, controllable_components):
         actual = _create_filter_proxy(df_cols, df_schema, controllable_components)
@@ -25,6 +31,14 @@ def test_create_filter_proxy_validate_columns(self, df_cols, df_schema, controll
         with pytest.raises(ValidationError, match="column must be one of"):
             actual(targets=["gdp_chart"], column="x")
 
+    def test_create_filter_proxy(self, df_cols, df_schema, controllable_components):
+        actual = _create_filter_proxy(df_cols, df_schema, controllable_components)
+        actual_filter = actual(targets=["gdp_chart"], column="gdp")
+
+        assert actual_filter.dict(exclude={"id": True}) == vm.Filter(targets=["gdp_chart"], column="gdp").dict(
+            exclude={"id": True}
+        )
+
 
 class TestControlPlan:
     """Test control plan."""
@@ -47,3 +61,17 @@ def test_control_plan_invalid_type(self, fake_llm_filter, df_metadata):
             control_description="Create a parameter that targets the data based on the column 'a'.",
             df_name="gdp_chart",
         )
+
+
+def test_create_filter(filter_prompt, fake_llm_filter_1, df_cols, df_schema_1, controllable_components):
+
+    actual = _create_filter(
+        filter_prompt=filter_prompt,
+        model=fake_llm_filter_1,
+        df_cols=df_cols,
+        df_schema=df_schema_1,
+        controllable_components=controllable_components,
+    )
+    assert actual.dict(exclude={"id": True}) == vm.Filter(targets=["gdp_chart"], column="country").dict(
+        exclude={"id": True}
+    )
diff --git a/vizro-ai/tests/unit/vizro-ai/dashboard/_response_models/test_df_info.py b/vizro-ai/tests/unit/vizro-ai/dashboard/_response_models/test_df_info.py
index e6eb82a4f..1483a270c 100644
--- a/vizro-ai/tests/unit/vizro-ai/dashboard/_response_models/test_df_info.py
+++ b/vizro-ai/tests/unit/vizro-ai/dashboard/_response_models/test_df_info.py
@@ -1,7 +1,9 @@
+from pandas.testing import assert_frame_equal
 from vizro_ai.dashboard._response_models.df_info import _get_df_info
 
 
-def test_get_df_info(df, df_schema):
-    actual_df_schema, _ = _get_df_info(df=df)
+def test_get_df_info(df, df_schema, df_sample):
+    actual_df_schema, actual_df_sample = _get_df_info(df=df)
 
     assert actual_df_schema == df_schema
+    assert_frame_equal(actual_df_sample, df_sample)
diff --git a/vizro-ai/tests/unit/vizro-ai/dashboard/_response_models/test_layout.py b/vizro-ai/tests/unit/vizro-ai/dashboard/_response_models/test_layout.py
index 823887451..4156ebc06 100644
--- a/vizro-ai/tests/unit/vizro-ai/dashboard/_response_models/test_layout.py
+++ b/vizro-ai/tests/unit/vizro-ai/dashboard/_response_models/test_layout.py
@@ -35,6 +35,11 @@ def test_layout_plan(self):
             ["card_1", "scatter_plot"],
             [],
         ),
+        (
+            ["card_1 scatter_plot scatter_plot", ". scatter_plot scatter_plot"],
+            ["card_1", "scatter_plot"],
+            [[0, 1, 1], [-1, 1, 1]],
+        ),
     ],
 )
 def test_convert_to_grid(layout_grid_template_areas, component_ids, grid):
diff --git a/vizro-ai/tests/unit/vizro-ai/dashboard/conftest.py b/vizro-ai/tests/unit/vizro-ai/dashboard/conftest.py
index 1ca1a2e1c..49344675d 100644
--- a/vizro-ai/tests/unit/vizro-ai/dashboard/conftest.py
+++ b/vizro-ai/tests/unit/vizro-ai/dashboard/conftest.py
@@ -6,7 +6,7 @@ from langchain_core.messages import HumanMessage
 
 
-class FakeListLLM(FakeListLLM):
+class MockStructuredOutputLLM(FakeListLLM):
     def bind_tools(self, tools: List[Any]):
         return super().bind(tools=tools)
 
@@ -19,7 +19,7 @@ def with_structured_output(self, schema):
 @pytest.fixture
 def fake_llm():
     response = ['{"text":"this is a card","href":""}']
-    return FakeListLLM(responses=response)
+    return MockStructuredOutputLLM(responses=response)
 
 
 @pytest.fixture
diff --git a/vizro-ai/tests/unit/vizro-ai/dashboard/test_pydantic_output.py b/vizro-ai/tests/unit/vizro-ai/dashboard/test_pydantic_output.py
index ee34210fb..008036ded 100644
--- a/vizro-ai/tests/unit/vizro-ai/dashboard/test_pydantic_output.py
+++ b/vizro-ai/tests/unit/vizro-ai/dashboard/test_pydantic_output.py
@@ -2,7 +2,7 @@
 from vizro_ai.dashboard._pydantic_output import _create_message_content, _create_prompt_template, _get_pydantic_model
 
 
-def test_get_pydantic_output(component_description, fake_llm):
+def test_get_pydantic_model(component_description, fake_llm):
     pydantic_output = _get_pydantic_model(
         query=component_description, llm_model=fake_llm, response_model=vm.Card, df_info=None
     )