Skip to content

Commit

Permalink
consolidate and remove unused fixtures
Browse files Browse the repository at this point in the history
  • Loading branch information
nadijagraca committed Aug 5, 2024
1 parent 97e7b35 commit 48789b9
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 40 deletions.
32 changes: 10 additions & 22 deletions vizro-ai/tests/unit/vizro-ai/dashboard/_response_models/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,29 +33,13 @@ def fake_llm_layout():

@pytest.fixture
def fake_llm_filter():
response = ['{"column": "a", "targets": ["gdp_chart"]}']
response = ['{"column": "a", "targets": ["bar_chart"]}']
return MockStructuredOutputLLM(responses=response)


@pytest.fixture
def fake_llm_filter_1():
response = ['{"column": "country", "targets": ["gdp_chart"]}']
return MockStructuredOutputLLM(responses=response)


@pytest.fixture
def df_cols():
return ["continent", "country", "population", "gdp"]


@pytest.fixture
def df_schema_1():
return {"continent": "object", "country": "object", "population": "int64", "gdp": "int64"}


@pytest.fixture
def controllable_components():
return ["gdp_chart"]
return ["bar_chart"]


@pytest.fixture
Expand All @@ -69,8 +53,12 @@ def df():


@pytest.fixture
def df_sample():
df = pd.DataFrame({"a": [1, 2, 3, 4, 5], "b": [4, 5, 6, 7, 8]})
def df_cols():
return ["a", "b"]


@pytest.fixture
def df_sample(df):
return df.sample(5, replace=True, random_state=19)


Expand All @@ -82,7 +70,7 @@ def df_schema():
@pytest.fixture
def df_metadata(df, df_schema, df_sample):
df_metadata = AllDfMetadata({})
df_metadata.all_df_metadata["gdp_chart"] = DfMetadata(
df_metadata.all_df_metadata["bar_chart"] = DfMetadata(
df_schema=df_schema,
df=df,
df_sample=df_sample,
Expand Down Expand Up @@ -118,6 +106,6 @@ def page_plan(component_card):
@pytest.fixture
def filter_prompt():
return """
Create a filter from the following instructions: Filter the gdp chart by country.
Create a filter from the following instructions: Filter the bar chart by column `a`.
Do not make up things that are optional and DO NOT configure actions, action triggers or action chains.
If no options are specified, leave them out."""
Original file line number Diff line number Diff line change
Expand Up @@ -10,32 +10,32 @@
from pydantic import ValidationError

# Needed for testing control creation.
model_manager.__setitem__("gdp_chart", VizroBaseModel)
model_manager.__setitem__("bar_chart", VizroBaseModel)


class TestFilterProxyCreate:
"""Tests filter proxy creation."""

def test_create_filter_proxy_validate_targets(self, df_cols, df_schema, controllable_components):
actual = _create_filter_proxy(df_cols, df_schema, controllable_components)
filter_proxy = _create_filter_proxy(df_cols, df_schema, controllable_components)
with pytest.raises(ValidationError, match="targets must be one of"):
actual(targets=["population_chart"], column="gdp")
filter_proxy(targets=["population_chart"], column="a")

def test_create_filter_proxy_validate_targets_not_empty(self, df_cols, df_schema, controllable_components):
actual = _create_filter_proxy(df_cols=df_cols, df_schema=df_schema, controllable_components=[])
filter_proxy = _create_filter_proxy(df_cols=df_cols, df_schema=df_schema, controllable_components=[])
with pytest.raises(ValidationError):
actual(targets=[], column="gdp")
filter_proxy(targets=[], column="a")

def test_create_filter_proxy_validate_columns(self, df_cols, df_schema, controllable_components):
actual = _create_filter_proxy(df_cols, df_schema, controllable_components)
filter_proxy = _create_filter_proxy(df_cols, df_schema, controllable_components)
with pytest.raises(ValidationError, match="column must be one of"):
actual(targets=["gdp_chart"], column="x")
filter_proxy(targets=["bar_chart"], column="x")

def test_create_filter_proxy(self, df_cols, df_schema, controllable_components):
actual = _create_filter_proxy(df_cols, df_schema, controllable_components)
actual_filter = actual(targets=["gdp_chart"], column="gdp")
filter_proxy = _create_filter_proxy(df_cols, df_schema, controllable_components)
actual_filter = filter_proxy(targets=["bar_chart"], column="a")

assert actual_filter.dict(exclude={"id": True}) == vm.Filter(targets=["gdp_chart"], column="gdp").dict(
assert actual_filter.dict(exclude={"id": True}) == vm.Filter(targets=["bar_chart"], column="a").dict(
exclude={"id": True}
)

Expand All @@ -50,7 +50,7 @@ def test_control_plan_invalid_df_name(self, fake_llm_filter, df_metadata):
df_name="population_chart",
)
default_control = control_plan.create(
model=fake_llm_filter, controllable_components=["gdp_chart"], all_df_metadata=df_metadata
model=fake_llm_filter, controllable_components=["bar_chart"], all_df_metadata=df_metadata
)
assert default_control is None

Expand All @@ -59,19 +59,18 @@ def test_control_plan_invalid_type(self, fake_llm_filter, df_metadata):
ControlPlan(
control_type="parameter",
control_description="Create a parameter that targets the data based on the column 'a'.",
df_name="gdp_chart",
df_name="bar_chart",
)


def test_create_filter(filter_prompt, fake_llm_filter_1, df_cols, df_schema_1, controllable_components):

actual = _create_filter(
def test_create_filter(filter_prompt, fake_llm_filter, df_cols, df_schema, controllable_components):
actual_filter = _create_filter(
filter_prompt=filter_prompt,
model=fake_llm_filter_1,
model=fake_llm_filter,
df_cols=df_cols,
df_schema=df_schema_1,
df_schema=df_schema,
controllable_components=controllable_components,
)
assert actual.dict(exclude={"id": True}) == vm.Filter(targets=["gdp_chart"], column="country").dict(
assert actual_filter.dict(exclude={"id": True}) == vm.Filter(targets=["bar_chart"], column="a").dict(
exclude={"id": True}
)

0 comments on commit 48789b9

Please sign in to comment.