diff --git a/vizro-ai/tests/unit/vizro-ai/dashboard/_response_models/conftest.py b/vizro-ai/tests/unit/vizro-ai/dashboard/_response_models/conftest.py index 2ac65ee70..90a1b613a 100644 --- a/vizro-ai/tests/unit/vizro-ai/dashboard/_response_models/conftest.py +++ b/vizro-ai/tests/unit/vizro-ai/dashboard/_response_models/conftest.py @@ -33,29 +33,13 @@ def fake_llm_layout(): @pytest.fixture def fake_llm_filter(): - response = ['{"column": "a", "targets": ["gdp_chart"]}'] + response = ['{"column": "a", "targets": ["bar_chart"]}'] return MockStructuredOutputLLM(responses=response) -@pytest.fixture -def fake_llm_filter_1(): - response = ['{"column": "country", "targets": ["gdp_chart"]}'] - return MockStructuredOutputLLM(responses=response) - - -@pytest.fixture -def df_cols(): - return ["continent", "country", "population", "gdp"] - - -@pytest.fixture -def df_schema_1(): - return {"continent": "object", "country": "object", "population": "int64", "gdp": "int64"} - - @pytest.fixture def controllable_components(): - return ["gdp_chart"] + return ["bar_chart"] @pytest.fixture @@ -69,8 +53,12 @@ def df(): @pytest.fixture -def df_sample(): - df = pd.DataFrame({"a": [1, 2, 3, 4, 5], "b": [4, 5, 6, 7, 8]}) +def df_cols(): + return ["a", "b"] + + +@pytest.fixture +def df_sample(df): return df.sample(5, replace=True, random_state=19) @@ -82,7 +70,7 @@ def df_schema(): @pytest.fixture def df_metadata(df, df_schema, df_sample): df_metadata = AllDfMetadata({}) - df_metadata.all_df_metadata["gdp_chart"] = DfMetadata( + df_metadata.all_df_metadata["bar_chart"] = DfMetadata( df_schema=df_schema, df=df, df_sample=df_sample, @@ -118,6 +106,6 @@ def page_plan(component_card): @pytest.fixture def filter_prompt(): return """ - Create a filter from the following instructions: Filter the gdp chart by country. + Create a filter from the following instructions: Filter the bar chart by column `a`. Do not make up things that are optional and DO NOT configure actions, action triggers or action chains. If no options are specified, leave them out.""" diff --git a/vizro-ai/tests/unit/vizro-ai/dashboard/_response_models/test_controls.py b/vizro-ai/tests/unit/vizro-ai/dashboard/_response_models/test_controls.py index fc8a9e008..60bf7c5c4 100644 --- a/vizro-ai/tests/unit/vizro-ai/dashboard/_response_models/test_controls.py +++ b/vizro-ai/tests/unit/vizro-ai/dashboard/_response_models/test_controls.py @@ -10,32 +10,32 @@ from pydantic import ValidationError # Needed for testing control creation. -model_manager.__setitem__("gdp_chart", VizroBaseModel) +model_manager.__setitem__("bar_chart", VizroBaseModel) class TestFilterProxyCreate: """Tests filter proxy creation.""" def test_create_filter_proxy_validate_targets(self, df_cols, df_schema, controllable_components): - actual = _create_filter_proxy(df_cols, df_schema, controllable_components) + filter_proxy = _create_filter_proxy(df_cols, df_schema, controllable_components) with pytest.raises(ValidationError, match="targets must be one of"): - actual(targets=["population_chart"], column="gdp") + filter_proxy(targets=["population_chart"], column="a") def test_create_filter_proxy_validate_targets_not_empty(self, df_cols, df_schema, controllable_components): - actual = _create_filter_proxy(df_cols=df_cols, df_schema=df_schema, controllable_components=[]) + filter_proxy = _create_filter_proxy(df_cols=df_cols, df_schema=df_schema, controllable_components=[]) with pytest.raises(ValidationError): - actual(targets=[], column="gdp") + filter_proxy(targets=[], column="a") def test_create_filter_proxy_validate_columns(self, df_cols, df_schema, controllable_components): - actual = _create_filter_proxy(df_cols, df_schema, controllable_components) + filter_proxy = _create_filter_proxy(df_cols, df_schema, controllable_components) with pytest.raises(ValidationError, match="column must be one of"): - actual(targets=["gdp_chart"], column="x") + filter_proxy(targets=["bar_chart"], column="x") def test_create_filter_proxy(self, df_cols, df_schema, controllable_components): - actual = _create_filter_proxy(df_cols, df_schema, controllable_components) - actual_filter = actual(targets=["gdp_chart"], column="gdp") + filter_proxy = _create_filter_proxy(df_cols, df_schema, controllable_components) + actual_filter = filter_proxy(targets=["bar_chart"], column="a") - assert actual_filter.dict(exclude={"id": True}) == vm.Filter(targets=["gdp_chart"], column="gdp").dict( + assert actual_filter.dict(exclude={"id": True}) == vm.Filter(targets=["bar_chart"], column="a").dict( exclude={"id": True} ) @@ -50,7 +50,7 @@ def test_control_plan_invalid_df_name(self, fake_llm_filter, df_metadata): df_name="population_chart", ) default_control = control_plan.create( - model=fake_llm_filter, controllable_components=["gdp_chart"], all_df_metadata=df_metadata + model=fake_llm_filter, controllable_components=["bar_chart"], all_df_metadata=df_metadata ) assert default_control is None @@ -59,19 +59,18 @@ def test_control_plan_invalid_type(self, fake_llm_filter, df_metadata): ControlPlan( control_type="parameter", control_description="Create a parameter that targets the data based on the column 'a'.", - df_name="gdp_chart", + df_name="bar_chart", ) -def test_create_filter(filter_prompt, fake_llm_filter_1, df_cols, df_schema_1, controllable_components): - - actual = _create_filter( +def test_create_filter(filter_prompt, fake_llm_filter, df_cols, df_schema, controllable_components): + actual_filter = _create_filter( filter_prompt=filter_prompt, - model=fake_llm_filter_1, + model=fake_llm_filter, df_cols=df_cols, - df_schema=df_schema_1, + df_schema=df_schema, controllable_components=controllable_components, ) - assert actual.dict(exclude={"id": True}) == vm.Filter(targets=["gdp_chart"], column="country").dict( + assert actual_filter.dict(exclude={"id": True}) == vm.Filter(targets=["bar_chart"], column="a").dict( exclude={"id": True} )