Skip to content

Commit

Permalink
addressing pr comments
Browse files Browse the repository at this point in the history
  • Loading branch information
nadijagraca committed Aug 5, 2024
1 parent aa36f85 commit 97e7b35
Show file tree
Hide file tree
Showing 6 changed files with 68 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from vizro_ai.dashboard.utils import AllDfMetadata, DfMetadata


class FakeListLLM(FakeListLLM):
class MockStructuredOutputLLM(FakeListLLM):
def bind_tools(self, tools: List[Any]):
return super().bind(tools=tools)

Expand All @@ -22,26 +22,37 @@ def with_structured_output(self, schema):
@pytest.fixture
def fake_llm_card():
response = ['{"text":"this is a card","href":""}']
return FakeListLLM(responses=response)
return MockStructuredOutputLLM(responses=response)


@pytest.fixture
def fake_llm_layout():
response = ['{"grid":[[0,1]]}']
return FakeListLLM(responses=response)
return MockStructuredOutputLLM(responses=response)


@pytest.fixture
def fake_llm_filter():
response = ['{"column": "a", "targets": ["gdp_chart"]}']
return FakeListLLM(responses=response)
return MockStructuredOutputLLM(responses=response)


@pytest.fixture
def fake_llm_filter_1():
response = ['{"column": "country", "targets": ["gdp_chart"]}']
return MockStructuredOutputLLM(responses=response)


@pytest.fixture
def df_cols():
return ["continent", "country", "population", "gdp"]


@pytest.fixture
def df_schema_1():
return {"continent": "object", "country": "object", "population": "int64", "gdp": "int64"}


@pytest.fixture
def controllable_components():
return ["gdp_chart"]
Expand All @@ -59,7 +70,8 @@ def df():

@pytest.fixture
def df_sample():
return pd.DataFrame({"a": [1, 2, 3, 4, 5], "b": [4, 5, 6, 7, 8]})
df = pd.DataFrame({"a": [1, 2, 3, 4, 5], "b": [4, 5, 6, 7, 8]})
return df.sample(5, replace=True, random_state=19)


@pytest.fixture
Expand Down Expand Up @@ -101,3 +113,11 @@ def component_card_2():
@pytest.fixture
def page_plan(component_card):
return PagePlan(title="Test Page", components_plan=[component_card], controls_plan=[], layout_plan=None)


@pytest.fixture
def filter_prompt():
return """
Create a filter from the following instructions: Filter the gdp chart by country.
Do not make up things that are optional and DO NOT configure actions, action triggers or action chains.
If no options are specified, leave them out."""
Original file line number Diff line number Diff line change
@@ -1,14 +1,20 @@
import pytest
from vizro_ai.dashboard._response_models.controls import ControlPlan, _create_filter_proxy
import vizro.models as vm
from vizro.managers import model_manager
from vizro.models import VizroBaseModel
from vizro_ai.dashboard._response_models.controls import ControlPlan, _create_filter, _create_filter_proxy

try:
from pydantic.v1 import ValidationError
except ImportError: # pragma: no cov
from pydantic import ValidationError

# Needed for testing control creation.
model_manager.__setitem__("gdp_chart", VizroBaseModel)

class TestControlCreate:
"""Tests control creation."""

class TestFilterProxyCreate:
"""Tests filter proxy creation."""

def test_create_filter_proxy_validate_targets(self, df_cols, df_schema, controllable_components):
actual = _create_filter_proxy(df_cols, df_schema, controllable_components)
Expand All @@ -25,6 +31,14 @@ def test_create_filter_proxy_validate_columns(self, df_cols, df_schema, controll
with pytest.raises(ValidationError, match="column must be one of"):
actual(targets=["gdp_chart"], column="x")

def test_create_filter_proxy(self, df_cols, df_schema, controllable_components):
actual = _create_filter_proxy(df_cols, df_schema, controllable_components)
actual_filter = actual(targets=["gdp_chart"], column="gdp")

assert actual_filter.dict(exclude={"id": True}) == vm.Filter(targets=["gdp_chart"], column="gdp").dict(
exclude={"id": True}
)


class TestControlPlan:
"""Test control plan."""
Expand All @@ -47,3 +61,17 @@ def test_control_plan_invalid_type(self, fake_llm_filter, df_metadata):
control_description="Create a parameter that targets the data based on the column 'a'.",
df_name="gdp_chart",
)


def test_create_filter(filter_prompt, fake_llm_filter_1, df_cols, df_schema_1, controllable_components):

actual = _create_filter(
filter_prompt=filter_prompt,
model=fake_llm_filter_1,
df_cols=df_cols,
df_schema=df_schema_1,
controllable_components=controllable_components,
)
assert actual.dict(exclude={"id": True}) == vm.Filter(targets=["gdp_chart"], column="country").dict(
exclude={"id": True}
)
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from pandas.testing import assert_frame_equal
from vizro_ai.dashboard._response_models.df_info import _get_df_info


def test_get_df_info(df, df_schema):
actual_df_schema, _ = _get_df_info(df=df)
def test_get_df_info(df, df_schema, df_sample):
actual_df_schema, actual_df_sample = _get_df_info(df=df)

assert actual_df_schema == df_schema
assert_frame_equal(actual_df_sample, df_sample)
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,11 @@ def test_layout_plan(self):
["card_1", "scatter_plot"],
[],
),
(
["card_1 scatter_plot scatter_plot", ". scatter_plot scatter_plot"],
["card_1", "scatter_plot"],
[[0, 1, 1], [-1, 1, 1]],
),
],
)
def test_convert_to_grid(layout_grid_template_areas, component_ids, grid):
Expand Down
4 changes: 2 additions & 2 deletions vizro-ai/tests/unit/vizro-ai/dashboard/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from langchain_core.messages import HumanMessage


class FakeListLLM(FakeListLLM):
class MockStructuredOutputLLM(FakeListLLM):
def bind_tools(self, tools: List[Any]):
return super().bind(tools=tools)

Expand All @@ -19,7 +19,7 @@ def with_structured_output(self, schema):
@pytest.fixture
def fake_llm():
response = ['{"text":"this is a card","href":""}']
return FakeListLLM(responses=response)
return MockStructuredOutputLLM(responses=response)


@pytest.fixture
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from vizro_ai.dashboard._pydantic_output import _create_message_content, _create_prompt_template, _get_pydantic_model


def test_get_pydantic_output(component_description, fake_llm):
def test_get_pydantic_model(component_description, fake_llm):
pydantic_output = _get_pydantic_model(
query=component_description, llm_model=fake_llm, response_model=vm.Card, df_info=None
)
Expand Down

0 comments on commit 97e7b35

Please sign in to comment.