Skip to content

Commit

Permalink
Move update of graph theme to Graph.__call__ (#174)
Browse files Browse the repository at this point in the history
  • Loading branch information
antonymilne authored Nov 15, 2023
1 parent 529c26b commit 4160de8
Show file tree
Hide file tree
Showing 11 changed files with 123 additions and 62 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
<!--
A new scriv changelog fragment.
Uncomment the section that is right (remove the HTML comment wrapper).
-->

<!--
### Highlights ✨
- A bullet item for the Highlights ✨ category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))
-->
<!--
### Removed
- A bullet item for the Removed category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))
-->
<!--
### Added
- A bullet item for the Added category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))
-->
<!--
### Changed
- A bullet item for the Changed category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))
-->
<!--
### Deprecated
- A bullet item for the Deprecated category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))
-->
<!--
### Fixed
- A bullet item for the Fixed category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))
-->
<!--
### Security
- A bullet item for the Security category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))
-->
3 changes: 0 additions & 3 deletions vizro-core/src/vizro/actions/_actions_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,6 @@ def _get_modified_page_figures(
ctds_filter: List[CallbackTriggerDict],
ctds_filter_interaction: List[Dict[str, CallbackTriggerDict]],
ctds_parameters: List[CallbackTriggerDict],
ctd_theme: CallbackTriggerDict,
targets: Optional[List[ModelID]] = None,
) -> Dict[ModelID, Any]:
if not targets:
Expand All @@ -267,7 +266,5 @@ def _get_modified_page_figures(
outputs[target] = model_manager[target]( # type: ignore[operator]
data_frame=filtered_data[target], **parameterized_config[target]
)
if hasattr(outputs[target], "update_layout"):
outputs[target].update_layout(template="vizro_dark" if ctd_theme["value"] else "vizro_light")

return outputs
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ def _get_action_callback_inputs(action_id: ModelID) -> Dict[str, Any]:
if "filter_interaction" in include_inputs
else []
),
"theme_selector": (State("theme_selector", "on") if "theme_selector" in include_inputs else []),
"theme_selector": State("theme_selector", "on") if "theme_selector" in include_inputs else [],
}
return action_input_mapping

Expand Down
1 change: 0 additions & 1 deletion vizro-core/src/vizro/actions/_filter_action.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,5 +36,4 @@ def _filter(
ctds_filter=ctx.args_grouping["filters"],
ctds_filter_interaction=ctx.args_grouping["filter_interaction"],
ctds_parameters=ctx.args_grouping["parameters"],
ctd_theme=ctx.args_grouping["theme_selector"],
)
1 change: 0 additions & 1 deletion vizro-core/src/vizro/actions/_on_page_load_action.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,5 +31,4 @@ def _on_page_load(page_id: ModelID, **inputs: Dict[str, Any]) -> Dict[ModelID, A
ctds_filter=ctx.args_grouping["filters"],
ctds_filter_interaction=ctx.args_grouping["filter_interaction"],
ctds_parameters=ctx.args_grouping["parameters"],
ctd_theme=ctx.args_grouping["theme_selector"],
)
1 change: 0 additions & 1 deletion vizro-core/src/vizro/actions/_parameter_action.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,5 +30,4 @@ def _parameter(targets: List[str], **inputs: Dict[str, Any]) -> Dict[ModelID, An
ctds_filter=ctx.args_grouping["filters"],
ctds_filter_interaction=ctx.args_grouping["filter_interaction"],
ctds_parameters=ctx.args_grouping["parameters"],
ctd_theme=ctx.args_grouping["theme_selector"],
)
1 change: 0 additions & 1 deletion vizro-core/src/vizro/actions/filter_interaction_action.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,5 +37,4 @@ def filter_interaction(
ctds_filter=ctx.args_grouping["filters"],
ctds_filter_interaction=ctx.args_grouping["filter_interaction"],
ctds_parameters=ctx.args_grouping["parameters"],
ctd_theme=ctx.args_grouping["theme_selector"],
)
21 changes: 20 additions & 1 deletion vizro-core/src/vizro/models/_components/graph.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import logging
from typing import List, Literal

from dash import dcc
from dash import ctx, dcc
from dash.exceptions import MissingCallbackContextException
from plotly import graph_objects as go
from pydantic import Field, PrivateAttr, validator

import vizro.plotly.express as px
from vizro import _themes as themes
from vizro.managers import data_manager
from vizro.models import Action, VizroBaseModel
from vizro.models._action._actions_chain import _action_validator_factory
Expand Down Expand Up @@ -44,6 +46,16 @@ def __call__(self, **kwargs):
# Remove top margin if title is provided
if fig.layout.title.text is None:
fig.update_layout(margin_t=24)

# Possibly we should enforce that __call__ can only be used within the context of a callback, but it's easy
# to just swallow up the error here as it doesn't cause any problems.
try:
# At the moment theme_selector is always present so this if statement is redundant, but possibly in
# future we'll have callbacks that do Graph.__call__() without theme_selector set.
if "theme_selector" in ctx.args_grouping:
fig = self._update_theme(fig, ctx.args_grouping["theme_selector"]["value"])
except MissingCallbackContextException:
logger.info("fig.update_layout called outside of callback context.")
return fig

# Convenience wrapper/syntactic sugar.
Expand Down Expand Up @@ -81,3 +93,10 @@ def build(self):
color="grey",
parent_className="loading-container",
)

@staticmethod
def _update_theme(fig: go.Figure, theme_selector: bool):
# Basically the same as doing fig.update_layout(template="vizro_light/dark") but works for both the call in
# self.__call__ and in the update_graph_theme callback.
fig["layout"]["template"] = themes.dark if theme_selector else themes.light
return fig
35 changes: 18 additions & 17 deletions vizro-core/src/vizro/models/_page.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,10 @@
from dash import Input, Output, Patch, callback, dcc, html
from pydantic import Field, root_validator, validator

import vizro._themes as themes
from vizro._constants import ON_PAGE_LOAD_ACTION_PREFIX
from vizro.actions import _on_page_load
from vizro.managers._model_manager import DuplicateIDError
from vizro.models import Action, Graph, Layout, VizroBaseModel
from vizro.models import Action, Layout, VizroBaseModel
from vizro.models._action._actions_chain import ActionsChain, Trigger
from vizro.models._models_utils import _log_call, get_unique_grid_component_ids

Expand Down Expand Up @@ -139,24 +138,26 @@ def _update_graph_theme(self):
# The obvious way to do this would be to alter pio.templates.default, but this changes global state and so is
# not good.
# Putting graphs as inputs here would be a nice way to trigger the theme change automatically so that we don't
# need the update_layout call in Graph.__call__, but this results in an extra callback and the graph

# need the call to _update_theme inside Graph.__call__ also, but this results in an extra callback and the graph
# flickering.
# TODO: consider making this clientside callback and then possibly we can remove the update_layout in
# The code is written to be generic and extensible so that it runs _update_theme on any component with such a
# method defined. But at the moment this just means Graphs.
# TODO: consider making this clientside callback and then possibly we can remove the call to _update_theme in
# Graph.__call__ without any flickering.
# TODO: consider putting the Graph-specific logic here in the Graph model itself (whether clientside or
# serverside) to keep the code here abstract.
outputs = [
Output(component.id, "figure", allow_duplicate=True)
for component in self.components
if isinstance(component, Graph)
]
if outputs:
# TODO: if we do this then we should *consider* defining the callback in Graph itself rather than at Page
# level. This would mean multiple callbacks on one page but if it's clientside that probably doesn't matter.

@callback(outputs, Input("theme_selector", "on"), prevent_initial_call="initial_duplicate")
def update_graph_theme(theme_selector_on: bool):
patched_figure = Patch()
patched_figure["layout"]["template"] = themes.dark if theme_selector_on else themes.light
return [patched_figure] * len(outputs)
themed_components = [component for component in self.components if hasattr(component, "_update_theme")]
if themed_components:

@callback(
[Output(component.id, "figure", allow_duplicate=True) for component in themed_components],
Input("theme_selector", "on"),
prevent_initial_call="initial_duplicate",
)
def update_graph_theme(theme_selector: bool):
return [component._update_theme(Patch(), theme_selector) for component in themed_components]

def _create_component_container(self, components_content):
component_container = html.Div(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def callback_context_on_page_load(request):
"theme_selector": CallbackTriggerDict(
id="theme_selector",
property="on",
value=True if template == "vizro_dark" else False,
value=template == "vizro_dark",
str_id="theme_selector",
triggered=False,
),
Expand Down
70 changes: 35 additions & 35 deletions vizro-core/tests/unit/vizro/models/_components/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,13 @@
import plotly.graph_objects as go
import pytest
from dash import dcc
from dash._callback_context import context_value
from dash._utils import AttributeDict
from pydantic import ValidationError

import vizro.models as vm
import vizro.plotly.express as px
from vizro.actions._actions_utils import CallbackTriggerDict
from vizro.managers import data_manager
from vizro.models._action._action import Action

Expand All @@ -26,18 +29,6 @@ def standard_px_chart_with_str_dataframe():
)


@pytest.fixture
def expected_empty_chart():
figure = go.Figure()
figure.add_trace(go.Scatter(x=[None], y=[None], showlegend=False, hoverinfo="none"))
figure.update_layout(
xaxis={"visible": False},
yaxis={"visible": False},
annotations=[{"text": "NO DATA", "showarrow": False, "font": {"size": 16}}],
)
return figure


@pytest.fixture
def expected_graph():
return dcc.Loading(
Expand Down Expand Up @@ -74,11 +65,7 @@ def test_create_graph_mandatory_only(self, standard_px_chart):

@pytest.mark.parametrize("id", ["id_1", "id_2"])
def test_create_graph_mandatory_and_optional(self, standard_px_chart, id):
graph = vm.Graph(
figure=standard_px_chart,
id=id,
actions=[],
)
graph = vm.Graph(figure=standard_px_chart, id=id, actions=[])

assert graph.id == id
assert graph.type == "graph"
Expand All @@ -90,9 +77,7 @@ def test_mandatory_figure_missing(self):

def test_failed_graph_with_wrong_figure(self, standard_go_chart):
with pytest.raises(ValidationError, match="must provide a valid CapturedCallable object"):
vm.Graph(
figure=standard_go_chart,
)
vm.Graph(figure=standard_go_chart)

def test_getitem_known_args(self, standard_px_chart):
graph = vm.Graph(figure=standard_px_chart)
Expand All @@ -107,13 +92,34 @@ def test_getitem_unknown_args(self, standard_px_chart):

@pytest.mark.parametrize("title, expected", [(None, 24), ("Test", None)])
def test_title_margin_adjustment(self, gapminder, title, expected):
figure = vm.Graph(figure=px.bar(data_frame=gapminder, x="year", y="pop", title=title)).__call__()

assert figure.layout.margin.t == expected
assert figure.layout.template.layout.margin.t == 64
assert figure.layout.template.layout.margin.l == 80
assert figure.layout.template.layout.margin.b == 64
assert figure.layout.template.layout.margin.r == 12
graph = vm.Graph(figure=px.bar(data_frame=gapminder, x="year", y="pop", title=title)).__call__()

assert graph.layout.margin.t == expected
assert graph.layout.template.layout.margin.t == 64
assert graph.layout.template.layout.margin.l == 80
assert graph.layout.template.layout.margin.b == 64
assert graph.layout.template.layout.margin.r == 12

def test_update_theme_outside_callback(self, standard_px_chart):
graph = vm.Graph(figure=standard_px_chart).__call__()
assert graph == standard_px_chart.update_layout(margin_t=24, template="vizro_dark")

@pytest.mark.parametrize("template", ["vizro_dark", "vizro_light"])
def test_update_theme_inside_callback(self, standard_px_chart, template):
mock_callback_context = {
"args_grouping": {
"theme_selector": CallbackTriggerDict(
id="theme_selector",
property="on",
value=template == "vizro_dark",
str_id="theme_selector",
triggered=False,
)
}
}
context_value.set(AttributeDict(**mock_callback_context))
graph = vm.Graph(figure=standard_px_chart).__call__()
assert graph == standard_px_chart.update_layout(margin_t=24, template=template)

def test_set_action_via_validator(self, standard_px_chart, test_action_function):
graph = vm.Graph(figure=standard_px_chart, actions=[Action(function=test_action_function)])
Expand All @@ -124,18 +130,12 @@ def test_set_action_via_validator(self, standard_px_chart, test_action_function)
class TestProcessFigureDataFrame:
def test_process_figure_data_frame_str_df(self, standard_px_chart_with_str_dataframe, gapminder):
data_manager["gapminder"] = gapminder
graph_with_str_df = vm.Graph(
id="text_graph",
figure=standard_px_chart_with_str_dataframe,
)
graph_with_str_df = vm.Graph(id="text_graph", figure=standard_px_chart_with_str_dataframe)
assert data_manager._get_component_data("text_graph").equals(gapminder)
assert graph_with_str_df["data_frame"] == "gapminder"

def test_process_figure_data_frame_df(self, standard_px_chart, gapminder):
graph_with_df = vm.Graph(
id="text_graph",
figure=standard_px_chart,
)
graph_with_df = vm.Graph(id="text_graph", figure=standard_px_chart)
assert data_manager._get_component_data("text_graph").equals(gapminder)
with pytest.raises(KeyError, match="'data_frame'"):
graph_with_df.figure["data_frame"]
Expand Down

0 comments on commit 4160de8

Please sign in to comment.