diff --git a/vizro-core/changelog.d/20231115_102949_antony.milne_update_graph_theme.md b/vizro-core/changelog.d/20231115_102949_antony.milne_update_graph_theme.md new file mode 100644 index 000000000..f1f65e73c --- /dev/null +++ b/vizro-core/changelog.d/20231115_102949_antony.milne_update_graph_theme.md @@ -0,0 +1,48 @@ + + + + + + + + + diff --git a/vizro-core/src/vizro/actions/_actions_utils.py b/vizro-core/src/vizro/actions/_actions_utils.py index e82756e03..4f4363203 100644 --- a/vizro-core/src/vizro/actions/_actions_utils.py +++ b/vizro-core/src/vizro/actions/_actions_utils.py @@ -246,7 +246,6 @@ def _get_modified_page_figures( ctds_filter: List[CallbackTriggerDict], ctds_filter_interaction: List[Dict[str, CallbackTriggerDict]], ctds_parameters: List[CallbackTriggerDict], - ctd_theme: CallbackTriggerDict, targets: Optional[List[ModelID]] = None, ) -> Dict[ModelID, Any]: if not targets: @@ -267,7 +266,5 @@ def _get_modified_page_figures( outputs[target] = model_manager[target]( # type: ignore[operator] data_frame=filtered_data[target], **parameterized_config[target] ) - if hasattr(outputs[target], "update_layout"): - outputs[target].update_layout(template="vizro_dark" if ctd_theme["value"] else "vizro_light") return outputs diff --git a/vizro-core/src/vizro/actions/_callback_mapping/_callback_mapping_utils.py b/vizro-core/src/vizro/actions/_callback_mapping/_callback_mapping_utils.py index da6620214..2e55f12da 100644 --- a/vizro-core/src/vizro/actions/_callback_mapping/_callback_mapping_utils.py +++ b/vizro-core/src/vizro/actions/_callback_mapping/_callback_mapping_utils.py @@ -154,7 +154,7 @@ def _get_action_callback_inputs(action_id: ModelID) -> Dict[str, Any]: if "filter_interaction" in include_inputs else [] ), - "theme_selector": (State("theme_selector", "on") if "theme_selector" in include_inputs else []), + "theme_selector": State("theme_selector", "on") if "theme_selector" in include_inputs else [], } return action_input_mapping diff --git a/vizro-core/src/vizro/actions/_filter_action.py b/vizro-core/src/vizro/actions/_filter_action.py index f35daa7b4..bd3a2b596 100644 --- a/vizro-core/src/vizro/actions/_filter_action.py +++ b/vizro-core/src/vizro/actions/_filter_action.py @@ -36,5 +36,4 @@ def _filter( ctds_filter=ctx.args_grouping["filters"], ctds_filter_interaction=ctx.args_grouping["filter_interaction"], ctds_parameters=ctx.args_grouping["parameters"], - ctd_theme=ctx.args_grouping["theme_selector"], ) diff --git a/vizro-core/src/vizro/actions/_on_page_load_action.py b/vizro-core/src/vizro/actions/_on_page_load_action.py index 16c6279f5..ca492907a 100644 --- a/vizro-core/src/vizro/actions/_on_page_load_action.py +++ b/vizro-core/src/vizro/actions/_on_page_load_action.py @@ -31,5 +31,4 @@ def _on_page_load(page_id: ModelID, **inputs: Dict[str, Any]) -> Dict[ModelID, A ctds_filter=ctx.args_grouping["filters"], ctds_filter_interaction=ctx.args_grouping["filter_interaction"], ctds_parameters=ctx.args_grouping["parameters"], - ctd_theme=ctx.args_grouping["theme_selector"], ) diff --git a/vizro-core/src/vizro/actions/_parameter_action.py b/vizro-core/src/vizro/actions/_parameter_action.py index ba77a7762..defc5d5aa 100644 --- a/vizro-core/src/vizro/actions/_parameter_action.py +++ b/vizro-core/src/vizro/actions/_parameter_action.py @@ -30,5 +30,4 @@ def _parameter(targets: List[str], **inputs: Dict[str, Any]) -> Dict[ModelID, An ctds_filter=ctx.args_grouping["filters"], ctds_filter_interaction=ctx.args_grouping["filter_interaction"], ctds_parameters=ctx.args_grouping["parameters"], - ctd_theme=ctx.args_grouping["theme_selector"], ) diff --git a/vizro-core/src/vizro/actions/filter_interaction_action.py b/vizro-core/src/vizro/actions/filter_interaction_action.py index d3a04f3a8..fd2392617 100644 --- a/vizro-core/src/vizro/actions/filter_interaction_action.py +++ b/vizro-core/src/vizro/actions/filter_interaction_action.py @@ -37,5 +37,4 @@ def filter_interaction( ctds_filter=ctx.args_grouping["filters"], ctds_filter_interaction=ctx.args_grouping["filter_interaction"], ctds_parameters=ctx.args_grouping["parameters"], - ctd_theme=ctx.args_grouping["theme_selector"], ) diff --git a/vizro-core/src/vizro/models/_components/graph.py b/vizro-core/src/vizro/models/_components/graph.py index 220e835bf..33b5dbe09 100644 --- a/vizro-core/src/vizro/models/_components/graph.py +++ b/vizro-core/src/vizro/models/_components/graph.py @@ -1,11 +1,13 @@ import logging from typing import List, Literal -from dash import dcc +from dash import ctx, dcc +from dash.exceptions import MissingCallbackContextException from plotly import graph_objects as go from pydantic import Field, PrivateAttr, validator import vizro.plotly.express as px +from vizro import _themes as themes from vizro.managers import data_manager from vizro.models import Action, VizroBaseModel from vizro.models._action._actions_chain import _action_validator_factory @@ -44,6 +46,16 @@ def __call__(self, **kwargs): # Remove top margin if title is provided if fig.layout.title.text is None: fig.update_layout(margin_t=24) + + # Possibly we should enforce that __call__ can only be used within the context of a callback, but it's easy + # to just swallow up the error here as it doesn't cause any problems. + try: + # At the moment theme_selector is always present so this if statement is redundant, but possibly in + # future we'll have callbacks that do Graph.__call__() without theme_selector set. + if "theme_selector" in ctx.args_grouping: + fig = self._update_theme(fig, ctx.args_grouping["theme_selector"]["value"]) + except MissingCallbackContextException: + logger.info("fig.update_layout called outside of callback context.") return fig # Convenience wrapper/syntactic sugar. @@ -81,3 +93,10 @@ def build(self): color="grey", parent_className="loading-container", ) + + @staticmethod + def _update_theme(fig: go.Figure, theme_selector: bool): + # Basically the same as doing fig.update_layout(template="vizro_light/dark") but works for both the call in + # self.__call__ and in the update_graph_theme callback. + fig["layout"]["template"] = themes.dark if theme_selector else themes.light + return fig diff --git a/vizro-core/src/vizro/models/_page.py b/vizro-core/src/vizro/models/_page.py index acefed092..641ddfb18 100644 --- a/vizro-core/src/vizro/models/_page.py +++ b/vizro-core/src/vizro/models/_page.py @@ -5,11 +5,10 @@ from dash import Input, Output, Patch, callback, dcc, html from pydantic import Field, root_validator, validator -import vizro._themes as themes from vizro._constants import ON_PAGE_LOAD_ACTION_PREFIX from vizro.actions import _on_page_load from vizro.managers._model_manager import DuplicateIDError -from vizro.models import Action, Graph, Layout, VizroBaseModel +from vizro.models import Action, Layout, VizroBaseModel from vizro.models._action._actions_chain import ActionsChain, Trigger from vizro.models._models_utils import _log_call, get_unique_grid_component_ids @@ -139,24 +138,26 @@ def _update_graph_theme(self): # The obvious way to do this would be to alter pio.templates.default, but this changes global state and so is # not good. # Putting graphs as inputs here would be a nice way to trigger the theme change automatically so that we don't - # need the update_layout call in Graph.__call__, but this results in an extra callback and the graph + + # need the call to _update_theme inside Graph.__call__ also, but this results in an extra callback and the graph # flickering. - # TODO: consider making this clientside callback and then possibly we can remove the update_layout in + # The code is written to be generic and extensible so that it runs _update_theme on any component with such a + # method defined. But at the moment this just means Graphs. + # TODO: consider making this clientside callback and then possibly we can remove the call to _update_theme in # Graph.__call__ without any flickering. - # TODO: consider putting the Graph-specific logic here in the Graph model itself (whether clientside or - # serverside) to keep the code here abstract. - outputs = [ - Output(component.id, "figure", allow_duplicate=True) - for component in self.components - if isinstance(component, Graph) - ] - if outputs: + # TODO: if we do this then we should *consider* defining the callback in Graph itself rather than at Page + # level. This would mean multiple callbacks on one page but if it's clientside that probably doesn't matter. - @callback(outputs, Input("theme_selector", "on"), prevent_initial_call="initial_duplicate") - def update_graph_theme(theme_selector_on: bool): - patched_figure = Patch() - patched_figure["layout"]["template"] = themes.dark if theme_selector_on else themes.light - return [patched_figure] * len(outputs) + themed_components = [component for component in self.components if hasattr(component, "_update_theme")] + if themed_components: + + @callback( + [Output(component.id, "figure", allow_duplicate=True) for component in themed_components], + Input("theme_selector", "on"), + prevent_initial_call="initial_duplicate", + ) + def update_graph_theme(theme_selector: bool): + return [component._update_theme(Patch(), theme_selector) for component in themed_components] def _create_component_container(self, components_content): component_container = html.Div( diff --git a/vizro-core/tests/unit/vizro/actions/test_on_page_load_action.py b/vizro-core/tests/unit/vizro/actions/test_on_page_load_action.py index 6e402a095..f9e3b70d1 100644 --- a/vizro-core/tests/unit/vizro/actions/test_on_page_load_action.py +++ b/vizro-core/tests/unit/vizro/actions/test_on_page_load_action.py @@ -80,7 +80,7 @@ def callback_context_on_page_load(request): "theme_selector": CallbackTriggerDict( id="theme_selector", property="on", - value=True if template == "vizro_dark" else False, + value=template == "vizro_dark", str_id="theme_selector", triggered=False, ), diff --git a/vizro-core/tests/unit/vizro/models/_components/test_graph.py b/vizro-core/tests/unit/vizro/models/_components/test_graph.py index 9694febe9..6f3e190f9 100644 --- a/vizro-core/tests/unit/vizro/models/_components/test_graph.py +++ b/vizro-core/tests/unit/vizro/models/_components/test_graph.py @@ -5,10 +5,13 @@ import plotly.graph_objects as go import pytest from dash import dcc +from dash._callback_context import context_value +from dash._utils import AttributeDict from pydantic import ValidationError import vizro.models as vm import vizro.plotly.express as px +from vizro.actions._actions_utils import CallbackTriggerDict from vizro.managers import data_manager from vizro.models._action._action import Action @@ -26,18 +29,6 @@ def standard_px_chart_with_str_dataframe(): ) -@pytest.fixture -def expected_empty_chart(): - figure = go.Figure() - figure.add_trace(go.Scatter(x=[None], y=[None], showlegend=False, hoverinfo="none")) - figure.update_layout( - xaxis={"visible": False}, - yaxis={"visible": False}, - annotations=[{"text": "NO DATA", "showarrow": False, "font": {"size": 16}}], - ) - return figure - - @pytest.fixture def expected_graph(): return dcc.Loading( @@ -74,11 +65,7 @@ def test_create_graph_mandatory_only(self, standard_px_chart): @pytest.mark.parametrize("id", ["id_1", "id_2"]) def test_create_graph_mandatory_and_optional(self, standard_px_chart, id): - graph = vm.Graph( - figure=standard_px_chart, - id=id, - actions=[], - ) + graph = vm.Graph(figure=standard_px_chart, id=id, actions=[]) assert graph.id == id assert graph.type == "graph" @@ -90,9 +77,7 @@ def test_mandatory_figure_missing(self): def test_failed_graph_with_wrong_figure(self, standard_go_chart): with pytest.raises(ValidationError, match="must provide a valid CapturedCallable object"): - vm.Graph( - figure=standard_go_chart, - ) + vm.Graph(figure=standard_go_chart) def test_getitem_known_args(self, standard_px_chart): graph = vm.Graph(figure=standard_px_chart) @@ -107,13 +92,34 @@ def test_getitem_unknown_args(self, standard_px_chart): @pytest.mark.parametrize("title, expected", [(None, 24), ("Test", None)]) def test_title_margin_adjustment(self, gapminder, title, expected): - figure = vm.Graph(figure=px.bar(data_frame=gapminder, x="year", y="pop", title=title)).__call__() - - assert figure.layout.margin.t == expected - assert figure.layout.template.layout.margin.t == 64 - assert figure.layout.template.layout.margin.l == 80 - assert figure.layout.template.layout.margin.b == 64 - assert figure.layout.template.layout.margin.r == 12 + graph = vm.Graph(figure=px.bar(data_frame=gapminder, x="year", y="pop", title=title)).__call__() + + assert graph.layout.margin.t == expected + assert graph.layout.template.layout.margin.t == 64 + assert graph.layout.template.layout.margin.l == 80 + assert graph.layout.template.layout.margin.b == 64 + assert graph.layout.template.layout.margin.r == 12 + + def test_update_theme_outside_callback(self, standard_px_chart): + graph = vm.Graph(figure=standard_px_chart).__call__() + assert graph == standard_px_chart.update_layout(margin_t=24, template="vizro_dark") + + @pytest.mark.parametrize("template", ["vizro_dark", "vizro_light"]) + def test_update_theme_inside_callback(self, standard_px_chart, template): + mock_callback_context = { + "args_grouping": { + "theme_selector": CallbackTriggerDict( + id="theme_selector", + property="on", + value=template == "vizro_dark", + str_id="theme_selector", + triggered=False, + ) + } + } + context_value.set(AttributeDict(**mock_callback_context)) + graph = vm.Graph(figure=standard_px_chart).__call__() + assert graph == standard_px_chart.update_layout(margin_t=24, template=template) def test_set_action_via_validator(self, standard_px_chart, test_action_function): graph = vm.Graph(figure=standard_px_chart, actions=[Action(function=test_action_function)]) @@ -124,18 +130,12 @@ def test_set_action_via_validator(self, standard_px_chart, test_action_function) class TestProcessFigureDataFrame: def test_process_figure_data_frame_str_df(self, standard_px_chart_with_str_dataframe, gapminder): data_manager["gapminder"] = gapminder - graph_with_str_df = vm.Graph( - id="text_graph", - figure=standard_px_chart_with_str_dataframe, - ) + graph_with_str_df = vm.Graph(id="text_graph", figure=standard_px_chart_with_str_dataframe) assert data_manager._get_component_data("text_graph").equals(gapminder) assert graph_with_str_df["data_frame"] == "gapminder" def test_process_figure_data_frame_df(self, standard_px_chart, gapminder): - graph_with_df = vm.Graph( - id="text_graph", - figure=standard_px_chart, - ) + graph_with_df = vm.Graph(id="text_graph", figure=standard_px_chart) assert data_manager._get_component_data("text_graph").equals(gapminder) with pytest.raises(KeyError, match="'data_frame'"): graph_with_df.figure["data_frame"]