Skip to content

Commit

Permalink
Add unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
maxschulz-COL committed Oct 23, 2023
1 parent fa055a5 commit 5548fa2
Show file tree
Hide file tree
Showing 6 changed files with 143 additions and 15 deletions.
2 changes: 1 addition & 1 deletion vizro-core/examples/default/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@
import vizro.plotly.express as px
from vizro import Vizro
from vizro.actions import filter_interaction
from vizro.tables import dash_data_table
from vizro.managers import data_manager
from vizro.models.types import capture
from vizro.tables import dash_data_table


def retrieve_table_data():
Expand Down
4 changes: 1 addition & 3 deletions vizro-core/src/vizro/actions/_actions_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,9 +120,7 @@ def _get_parametrized_config(
for target in targets:
# TODO - avoid calling _captured_callable. Once we have done this we can remove _arguments from
# CapturedCallable entirely.
graph_config = deepcopy(
model_manager[target].figure._arguments # type: ignore[attr-defined]
)
graph_config = deepcopy(model_manager[target].figure._arguments) # type: ignore[attr-defined]
if "data_frame" in graph_config:
graph_config.pop("data_frame")

Expand Down
10 changes: 4 additions & 6 deletions vizro-core/src/vizro/models/_components/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,13 @@
from dash import dash_table, html
from pydantic import Field, PrivateAttr, validator

import vizro.tables as vt
from vizro.managers import data_manager
from vizro.models import Action, VizroBaseModel
from vizro.models._action._actions_chain import _action_validator_factory
from vizro.models._components._components_utils import _process_callable_data_frame
from vizro.models._models_utils import _log_call
from vizro.models.types import CapturedCallable
import vizro.tables as vt

logger = logging.getLogger(__name__)

Expand All @@ -27,7 +27,7 @@ class Table(VizroBaseModel):
"""

type: Literal["table"] = "table"
figure: CapturedCallable = Field(..., import_path = vt, description="Table to be visualized on dashboard")
figure: CapturedCallable = Field(..., import_path=vt, description="Table to be visualized on dashboard")
actions: List[Action] = []

# Component properties for actions and interactions
Expand All @@ -48,10 +48,8 @@ def __getitem__(self, arg_name: str):
# explicitly redirect it to the correct attribute.
if arg_name == "type":
return self.type
return self.table[arg_name]
return self.figure[arg_name]

@_log_call
def build(self):
return html.Div(
dash_table.DataTable(pd.DataFrame().to_dict("records"), []), id=self.id
)
return html.Div(dash_table.DataTable(pd.DataFrame().to_dict("records"), []), id=self.id)
6 changes: 6 additions & 0 deletions vizro-core/tests/unit/vizro/conftest.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Fixtures to be shared across several tests."""

import dash
import plotly.graph_objects as go
import pytest

import vizro.models as vm
Expand All @@ -25,6 +26,11 @@ def standard_px_chart(gapminder):
)


@pytest.fixture
def standard_go_chart(gapminder):
return go.Figure(data=go.Scatter(x=gapminder["gdpPercap"], y=gapminder["lifeExp"], mode="markers"))


@pytest.fixture()
def page1():
return vm.Page(title="Page 1", components=[vm.Button(), vm.Button()])
Expand Down
5 changes: 0 additions & 5 deletions vizro-core/tests/unit/vizro/models/_components/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,6 @@
from vizro.models._components.graph import create_empty_fig


@pytest.fixture
def standard_go_chart(gapminder):
return go.Figure(data=go.Scatter(x=gapminder["gdpPercap"], y=gapminder["lifeExp"], mode="markers"))


@pytest.fixture
def standard_px_chart_with_str_dataframe():
return px.scatter(
Expand Down
131 changes: 131 additions & 0 deletions vizro-core/tests/unit/vizro/models/_components/test_table.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
"""Unit tests for vizro.models.Table."""
import json

import pytest
from dash import dcc
from pydantic import ValidationError

import vizro.models as vm
import vizro.plotly.express as px
from vizro.managers import data_manager
from vizro.models._action._action import Action
from vizro.tables import dash_data_table
import plotly
from dash import dash_table, html

import pandas as pd


@pytest.fixture
def standard_dash_table():
return dash_data_table(data_frame=px.data.gapminder())

@pytest.fixture
def dash_table_with_arguments():
return dash_data_table(data_frame=px.data.gapminder(), style_header={"border": "1px solid green"})

@pytest.fixture
def dash_table_with_str_dataframe():
return dash_data_table(data_frame="gapminder")

@pytest.fixture
def expected_table():
return html.Div(
dash_table.DataTable(pd.DataFrame().to_dict("records"), []), id="text_table"
)

class TestDunderMethodsTable:
def test_create_graph_mandatory_only(self, standard_dash_table):
table = vm.Table(figure=standard_dash_table)

assert hasattr(table, "id")
assert table.type == "table"
assert table.figure == standard_dash_table
assert table.actions == []

@pytest.mark.parametrize("id", ["id_1", "id_2"])
def test_create_table_mandatory_and_optional(self, standard_dash_table, id):
table = vm.Table(
figure=standard_dash_table,
id=id,
actions=[],
)

assert table.id == id
assert table.type == "table"
assert table.figure == standard_dash_table

def test_mandatory_figure_missing(self):
with pytest.raises(ValidationError, match="field required"):
vm.Table()

def test_failed_graph_with_no_captured_callable(self, standard_go_chart):
with pytest.raises(ValidationError, match="must provide a valid CapturedCallable object"):
vm.Table(
figure=standard_go_chart,
)

@pytest.mark.xfail(reason="This test is failing as we are not yet detecting different types of captured callables")
def test_failed_graph_with_no_captured_callable(self, standard_px_chart):
with pytest.raises(ValidationError, match="must provide a valid table function vm.Table"):
vm.Table(
figure=standard_px_chart,
)

def test_getitem_known_args(self, dash_table_with_arguments):
table = vm.Table(figure=dash_table_with_arguments)
assert table["style_header"] == {"border": "1px solid green"}
assert table["type"] == "table"

def test_getitem_unknown_args(self, standard_dash_table):
table = vm.Table(figure=standard_dash_table)
with pytest.raises(KeyError):
table["unknown_args"]

# @pytest.mark.parametrize("title, expected", [(None, 24), ("Test", None)])
# def test_title_margin_adjustment(self, gapminder, title, expected):
# figure = vm.Graph(figure=px.bar(data_frame=gapminder, x="year", y="pop", title=title)).__call__()

# assert figure.layout.margin.t == expected
# assert figure.layout.template.layout.margin.t == 64
# assert figure.layout.template.layout.margin.l == 80
# assert figure.layout.template.layout.margin.b == 64
# assert figure.layout.template.layout.margin.r == 12

def test_set_action_via_validator(self, standard_dash_table, test_action_function):
table = vm.Table(figure=standard_dash_table, actions=[Action(function=test_action_function)])
actions_chain = table.actions[0]
assert actions_chain.trigger.component_property == "active_cell"


class TestProcessTableDataFrame:
def test_process_figure_data_frame_str_df(self, dash_table_with_str_dataframe, gapminder):
data_manager["gapminder"] = gapminder
table_with_str_df = vm.Table(
id="table",
figure=dash_table_with_str_dataframe,
)
assert data_manager._get_component_data("table").equals(gapminder)
assert table_with_str_df["data_frame"] == "gapminder"

def test_process_figure_data_frame_df(self, standard_dash_table, gapminder):
table_with_str_df = vm.Table(
id="table",
figure=standard_dash_table,
)
assert data_manager._get_component_data("table").equals(gapminder)
with pytest.raises(KeyError, match="'data_frame'"):
table_with_str_df.figure["data_frame"]



class TestBuildTable:
def test_graph_build(self, standard_dash_table, expected_table):
table = vm.Table(
id="text_table",
figure=standard_dash_table,
)

result = json.loads(json.dumps(table.build(), cls=plotly.utils.PlotlyJSONEncoder))
expected = json.loads(json.dumps(expected_table, cls=plotly.utils.PlotlyJSONEncoder))
assert result == expected

0 comments on commit 5548fa2

Please sign in to comment.