Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

AGGrid implementation #260

Closed
wants to merge 8 commits into from
Closed
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
129 changes: 129 additions & 0 deletions vizro-core/examples/default/app_dev.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
"""Example to show dashboard configuration."""
import pandas as pd

import vizro.models as vm
import vizro.plotly.express as px
from vizro import Vizro
from vizro.actions import export_data, filter_interaction
from vizro.tables import dash_ag_grid, dash_data_table

# Raw gapminder data; every other frame in this module is derived from it.
df = px.data.gapminder()

# Keys that define one "continent in a given year" group.
_group_keys = ["continent", "year"]

# One row per (continent, year) holding the mean of each metric.
df_mean = df.groupby(by=_group_keys).agg({"lifeExp": "mean", "pop": "mean", "gdpPercap": "mean"}).reset_index()

# Same shape as `df`, but each metric is replaced by its per-(continent, year) aggregate:
# the mean for lifeExp and gdpPercap, the sum for pop.
_per_continent_year = df.groupby(by=_group_keys)
df_transformed = df.assign(
    lifeExp=_per_continent_year["lifeExp"].transform("mean"),
    gdpPercap=_per_continent_year["gdpPercap"].transform("mean"),
    pop=_per_continent_year["pop"].transform("sum"),
)

# Stack the continent-level rows on top of the country-level rows; the `color` column records
# which of the two each row came from.
_continent_rows = df_transformed.assign(color="Continent Avg.")
_country_rows = df.assign(color="Country")
df_concat = pd.concat([_continent_rows, _country_rows], ignore_index=True)


def create_benchmark_analysis():
    """Build the "Benchmark Analysis" page.

    The page holds an AG Grid table and a Dash DataTable (both over ``df``), a line chart over
    ``df_concat`` and an export button. Both tables carry a ``filter_interaction`` action that
    targets the line chart.

    Returns:
        The configured ``vm.Page``.
    """
    # Column definitions for the Dash DataTable, including numeric formatting.
    columns = [
        {"id": "country", "name": "country"},
        {"id": "continent", "name": "continent"},
        {"id": "year", "name": "year"},
        {"id": "lifeExp", "name": "lifeExp", "type": "numeric", "format": {"specifier": ",.1f"}},
        {"id": "gdpPercap", "name": "gdpPercap", "type": "numeric", "format": {"specifier": "$,.2f"}},
        {"id": "pop", "name": "pop", "type": "numeric", "format": {"specifier": ",d"}},
    ]

    page_country = vm.Page(
        title="Benchmark Analysis",
        # description="Discovering how the metrics differ for each country and export data for further investigation",
        # layout=vm.Layout(grid=[[0, 1]] * 5 + [[2, -1]], col_gap="32px", row_gap="60px"),
        components=[
            # AG Grid flavour of the table.
            vm.Table(
                id="table_country_new",
                title="Click on a cell in country column:",
                figure=dash_ag_grid(
                    data_frame=df,
                ),
                actions=[vm.Action(function=filter_interaction(targets=["line_country"]))],
            ),
            # Dash DataTable flavour of the same table, with gdpPercap cells coloured by value band.
            vm.Table(
                id="table_country",
                title="Click on a cell in country column:",
                figure=dash_data_table(
                    id="dash_data_table_country",
                    data_frame=df,
                    columns=columns,
                    style_data_conditional=[
                        {
                            "if": {"filter_query": "{gdpPercap} < 1045", "column_id": "gdpPercap"},
                            "backgroundColor": "#ff9222",
                        },
                        {
                            "if": {
                                "filter_query": "{gdpPercap} >= 1045 && {gdpPercap} <= 4095",
                                "column_id": "gdpPercap",
                            },
                            "backgroundColor": "#de9e75",
                        },
                        {
                            "if": {
                                "filter_query": "{gdpPercap} > 4095 && {gdpPercap} <= 12695",
                                "column_id": "gdpPercap",
                            },
                            "backgroundColor": "#aaa9ba",
                        },
                        {
                            "if": {"filter_query": "{gdpPercap} > 12695", "column_id": "gdpPercap"},
                            "backgroundColor": "#00b4ff",
                        },
                    ],
                    sort_action="native",
                    style_cell={"textAlign": "left"},
                ),
                actions=[vm.Action(function=filter_interaction(targets=["line_country"]))],
            ),
            vm.Graph(
                id="line_country",
                figure=px.line(
                    df_concat,
                    title="Country vs. Continent",
                    x="year",
                    y="gdpPercap",
                    color="color",
                    # NOTE(review): "data" matches no column of df_concat; presumably the key was
                    # meant to be "color" (the legend column) - confirm before changing.
                    labels={"year": "Year", "data": "Data", "gdpPercap": "GDP per capita"},
                    # Keys must equal the values stored in the `color` column of df_concat
                    # ("Continent Avg." and "Country"), otherwise the mapping is not applied.
                    color_discrete_map={"Country": "#afe7f9", "Continent Avg.": "#003875"},
                    markers=True,
                    hover_name="country",
                ),
            ),
            vm.Button(
                text="Export data",
                actions=[
                    vm.Action(
                        function=export_data(
                            targets=["line_country"],
                        )
                    ),
                ],
            ),
        ],
        controls=[
            vm.Filter(column="continent", selector=vm.Dropdown(value="Europe", multi=False, title="Select continent")),
            vm.Filter(column="year", selector=vm.RangeSlider(title="Select timeframe", step=1, marks=None)),
            # Lets the user switch which metric the line chart plots on its y-axis.
            vm.Parameter(
                targets=["line_country.y"],
                selector=vm.Dropdown(
                    options=["lifeExp", "gdpPercap", "pop"], multi=False, value="gdpPercap", title="Choose y-axis"
                ),
            ),
        ],
    )
    return page_country


# Single-page dashboard made of the benchmark analysis page.
dashboard = vm.Dashboard(pages=[create_benchmark_analysis()])

if __name__ == "__main__":
    # Build the app with assets taken from the sibling assets folder, then start it.
    app = Vizro(assets_folder="../assets")
    app.build(dashboard).run()
3 changes: 2 additions & 1 deletion vizro-core/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ dependencies = [
"tornado>=6.3.2", # not directly required, pinned by Snyk to avoid a vulnerability: https://security.snyk.io/vuln/SNYK-PYTHON-TORNADO-5537286
"setuptools>=65.5.1", # not directly required, pinned by Snyk to avoid a vulnerability: https://security.snyk.io/vuln/SNYK-PYTHON-SETUPTOOLS-3180412
"werkzeug>=3.0.1", # not directly required, pinned by Snyk to avoid a vulnerability: https://security.snyk.io/vuln/SNYK-PYTHON-WERKZEUG-6035177
"MarkupSafe" # required to sanitize user input
"MarkupSafe", # required to sanitize user input,
"dash-ag-grid"
]
description = "Vizro is a package to facilitate visual analytics."
dynamic = ["version"]
Expand Down
31 changes: 29 additions & 2 deletions vizro-core/src/vizro/actions/_actions_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def _get_parent_vizro_model(_underlying_callable_object_id: str) -> VizroBaseMod
)


def _apply_table_filter_interaction(
def _apply_dashtable_filter_interaction(
data_frame: pd.DataFrame, target: str, ctd_filter_interaction: Dict[str, CallbackTriggerDict]
) -> pd.DataFrame:
ctd_active_cell = ctd_filter_interaction["active_cell"]
Expand All @@ -133,6 +133,26 @@ def _apply_table_filter_interaction(
return data_frame


def _apply_aggrid_filter_interaction(
    data_frame: pd.DataFrame, target: str, ctd_filter_interaction: Dict[str, CallbackTriggerDict]
) -> pd.DataFrame:
    """Filter ``data_frame`` by the cell reported in the ``cellClicked`` entry.

    Args:
        data_frame: Data of the component that is being filtered.
        target: ID of the component being filtered; only ``filter_interaction`` actions
            that list it in their ``targets`` are applied.
        ctd_filter_interaction: Callback trigger dicts of the source table; must contain
            a ``"cellClicked"`` entry.

    Returns:
        ``data_frame`` unchanged when no cell value is present, otherwise the rows whose
        clicked column equals the clicked value (for every matching action).
    """
    cell_clicked = ctd_filter_interaction["cellClicked"]
    clicked_cell = cell_clicked["value"]
    if not clicked_cell:
        return data_frame

    # cell_clicked["id"] is the id of the underlying grid component, so the actions have to be
    # looked up on its parent Vizro Table model.
    parent_model = _get_parent_vizro_model(cell_clicked["id"])

    for action in _get_component_actions(parent_model):
        is_filter_interaction = action.function._function.__name__ == "filter_interaction"
        if is_filter_interaction and target in action.function["targets"]:
            column = clicked_cell["colId"]
            clicked_data = clicked_cell["value"]
            data_frame = data_frame[data_frame[column].isin([clicked_data])]

    return data_frame


def _apply_filter_interaction(
data_frame: pd.DataFrame,
ctds_filter_interaction: List[Dict[str, CallbackTriggerDict]],
Expand All @@ -147,7 +167,14 @@ def _apply_filter_interaction(
)

if "active_cell" in ctd_filter_interaction and "derived_viewport_data" in ctd_filter_interaction:
data_frame = _apply_table_filter_interaction(
data_frame = _apply_dashtable_filter_interaction(
data_frame=data_frame,
target=target,
ctd_filter_interaction=ctd_filter_interaction,
)

if "cellClicked" in ctd_filter_interaction:
data_frame = _apply_aggrid_filter_interaction(
data_frame=data_frame,
target=target,
ctd_filter_interaction=ctd_filter_interaction,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from vizro.actions import _on_page_load, _parameter, export_data, filter_interaction
from vizro.managers import data_manager, model_manager
from vizro.managers._model_manager import ModelID
from vizro.models import Action, Page, Table, VizroBaseModel
from vizro.models import Action, Page, VizroBaseModel
from vizro.models._action._actions_chain import ActionsChain
from vizro.models._controls import Filter, Parameter
from vizro.models.types import ControlType
Expand Down Expand Up @@ -108,18 +108,27 @@ def _get_inputs_of_figure_interactions(
for action in figure_interactions_on_page:
# TODO: Consider do we want to move the following logic into Model implementation
triggered_model = _get_triggered_model(action_id=ModelID(str(action.id)))
if isinstance(triggered_model, Table):
inputs.append(
{
"active_cell": State(
component_id=triggered_model._callable_object_id, component_property="active_cell"
),
"derived_viewport_data": State(
component_id=triggered_model._callable_object_id,
component_property="derived_viewport_data",
),
}
)
if hasattr(triggered_model, "_table_type"): # not check this, put this configuration inside the models
if triggered_model._table_type == "DataTable":
inputs.append(
{
"active_cell": State(
component_id=triggered_model._callable_object_id, component_property="active_cell"
),
"derived_viewport_data": State(
component_id=triggered_model._callable_object_id,
component_property="derived_viewport_data",
),
}
)
elif triggered_model._table_type == "AgGrid":
inputs.append(
{
"cellClicked": State(
component_id=triggered_model._callable_object_id, component_property="cellClicked"
),
}
)
else:
inputs.append(
{
Expand Down
40 changes: 31 additions & 9 deletions vizro-core/src/vizro/models/_components/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,26 @@
import vizro.tables as vt
from vizro.managers import data_manager
from vizro.models import Action, VizroBaseModel
from vizro.models._action._actions_chain import _action_validator_factory
from vizro.models._action._actions_chain import _set_actions
from vizro.models._components._components_utils import _process_callable_data_frame
from vizro.models._models_utils import _log_call
from vizro.models.types import CapturedCallable

logger = logging.getLogger(__name__)


def _get_table_type(figure): # this function can be applied also in pre-build
kwargs = figure._arguments.copy()

# This workaround is needed because the underlying table object requires a data_frame
kwargs["data_frame"] = DataFrame()

# The underlying table object is pre-built, so we can fetch its ID.
underlying_table_object = figure._function(**kwargs)
Comment on lines +24 to +30
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
kwargs = figure._arguments.copy()
# This workaround is needed because the underlying table object requires a data_frame
kwargs["data_frame"] = DataFrame()
# The underlying table object is pre-built, so we can fetch its ID.
underlying_table_object = figure._function(**kwargs)
underlying_table_object = figure(data_frame=pd.DataFrame())

This should work by itself I think?

Probably we should do some try/except here to give a clear error message in the case that the function call fails for some reason.

I'd also like to understand why this evaluation of figure to extract id is needed (not changed by you here).

Copy link
Contributor Author

@maxschulz-COL maxschulz-COL Jan 17, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Posting the answer @petar-qb gave to me:

Graphs and Tables in Dash (and therefore in Vizro too) are handled differently.

dcc.Graph has: id, figure attributes where figure attribute is any plotly chart.
If this graph has to be changed as Output of the callback we target it as - dash.Output(dcc_graph_id, "figure")
If this graph has to be an Input (let's say clickData property) of the callback we defined it as - dash.Input(dcc_graph_id, "clickData")
It means that we access to graph properties by accessing the outer wrapper dcc.Graph component.

The problem is that there's no outer dcc component wrapper for tables 😕. So, there is nothing like dcc.Table Dash inbuilt component that has id and figure attributes. Table ID is written directly inside its "figure" callable (e.g. inside the dash_table.DataTable()). This is different than graphs because plotly graph doesn't contain the ID (you cannot put the ID inside px.box(...)), but its outer component dcc.Graph does.

Yes, Vizro has created a kind of wrapper, vm.Table, that has id and figure, where figure is a callable that can return dash_table.DataTable or AgGrid. Still, this doesn't solve the problem, because we can't (or at least, we decided not to) propagate the vm.Table.id into the underlying table component.

Now, let's give an example on how callback inputs and outputs are created in the case of Tables.
If a Vizro Table has to be changed as the Output of a callback, we target it with - dash.Output(vm_table_id, "children") - We can re-render dash_table.DataTable only if we change the "children" of the outer Div component.
If a Vizro Table has to be the Input (let's say the active_cell property) of a callback, we define it as - dash.Input(underlying_table_id, "active_cell") - So we need to fetch this ID in the case that filter_interaction is defined on the Table.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@petar-qb this half makes sense to me. The half that doesn't make sense is:

  1. why can't we also use self._callable_object_id in callback outputs as well? Is the same true for AG Grid or just Dash datatable?
  2. why not set self._callable_object_id = self.id - you say we decided not to? I guess the answer is that if my above question is not possible you'll get duplicate ids for the table and the containing div
  3. do we keep on setting _callable_object_id somewhere outside Table.build?

Ideally what I'd like to do is this:

class Table(VizroBaseModel):
    def __call__(self, **kwargs):
        kwargs.setdefault("data_frame", data_manager._get_component_data(self.id))
        return self.figure(id=self.id, **kwargs)

    # no need for pre_build at all

so that the id is always injected automatically. But I'm guessing this will not be possible.

No need to resolve this conversation as part of this PR because it's outside the scope of the PR, but it would be great to have a chat about it - let me put something in the calendar 🙂

table_type = underlying_table_object.__class__.__name__
return underlying_table_object, table_type


class Table(VizroBaseModel):
"""Wrapper for table components to visualize in dashboard.

Expand All @@ -37,14 +49,26 @@ class Table(VizroBaseModel):
actions: List[Action] = []

_callable_object_id: str = PrivateAttr()
_table_type: str = (
PrivateAttr()
) # Ideally we would be able to use the populated content of this field in the `set_actions` validator.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Definitely the current code feels a bit tangled here, but I appreciate it's hard to get these things with private properties and validators working exactly as you'd like.

Copy link
Contributor Author

@maxschulz-COL maxschulz-COL Jan 17, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed. I spent a while getting this to work, but also didn't hunt for a better solution once it was working (except removing the super().__init__) as there were other bigger questions. We should definitely revisit this once we have agreed on an implementation approach.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good. Indeed there's no point spending a long time perfecting this if we don't need it at all in the end.


# Component properties for actions and interactions
_output_property: str = PrivateAttr("children")

# validator
set_actions = _action_validator_factory("active_cell")
_validate_callable = validator("figure", allow_reuse=True, always=True)(_process_callable_data_frame)

@validator("actions")
def set_actions(cls, v, values):
_, table_type = _get_table_type(values["figure"])
if table_type == "DataTable":
Comment on lines +64 to +65
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we do this by something other than string comparison? e.g. change _get_table_type to return the class and then use isinstance.

return _set_actions(v, values, "active_cell")
elif table_type == "AgGrid":
return _set_actions(v, values, "cellClicked")
else:
raise ValueError(f"Table type {table_type} not supported.")

# Convenience wrapper/syntactic sugar.
def __call__(self, **kwargs):
kwargs.setdefault("data_frame", data_manager._get_component_data(self.id))
Expand All @@ -61,13 +85,7 @@ def __getitem__(self, arg_name: str):
@_log_call
def pre_build(self):
if self.actions:
kwargs = self.figure._arguments.copy()

# This workaround is needed because the underlying table object requires a data_frame
kwargs["data_frame"] = DataFrame()

# The underlying table object is pre-built, so we can fetch its ID.
underlying_table_object = self.figure._function(**kwargs)
underlying_table_object, table_type = _get_table_type(self.figure)

if not hasattr(underlying_table_object, "id"):
raise ValueError(
Expand All @@ -76,6 +94,10 @@ def pre_build(self):
)

self._callable_object_id = underlying_table_object.id
self._table_type = table_type
# Idea: fetch it from the functions attributes? Or just hard-code it here?
# Can check difference between AGGrid and dashtable because we call it already
# Once we recognize, two ways to go: 1) slightly change model properties 2) inject dash dependencies,

def build(self):
return dcc.Loading(
Expand Down
3 changes: 2 additions & 1 deletion vizro-core/src/vizro/tables/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from vizro.tables.dash_aggrid import dash_ag_grid
from vizro.tables.dash_table import dash_data_table

# Please keep alphabetically ordered
__all__ = ["dash_data_table"]
__all__ = ["dash_ag_grid", "dash_data_table"]
13 changes: 13 additions & 0 deletions vizro-core/src/vizro/tables/dash_aggrid.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import dash_ag_grid as dag

from vizro.models.types import capture


@capture("action")
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
@capture("action")
@capture("table")

def dash_ag_grid(data_frame=None, **kwargs):
    """Build a Dash AG Grid showing ``data_frame``.

    Args:
        data_frame: Frame to display. ``None`` is treated as "no data" and yields a grid
            with no rows and no columns instead of failing.
        **kwargs: Extra keyword arguments forwarded to ``dag.AgGrid``. They take precedence
            over the defaults set here (``id``, ``rowData``, ``columnDefs``), so callers can,
            for example, give each grid its own ``id``.

    Returns:
        The configured ``dag.AgGrid`` component.
    """
    if data_frame is None:
        row_data, column_defs = [], []
    else:
        row_data = data_frame.to_dict("records")
        column_defs = [{"field": col} for col in data_frame.columns]

    defaults = {
        # Kept as the default so existing callers that rely on this id keep working.
        "id": "get-started-example-basic",
        "rowData": row_data,
        "columnDefs": column_defs,
    }
    # Later keys win, so anything passed by the caller overrides the defaults above.
    return dag.AgGrid(**{**defaults, **kwargs})
Loading