Skip to content

Commit

Permalink
Minimal viable AGGrid implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
maxschulz-COL committed Jan 11, 2024
1 parent 69a6d37 commit 96b6259
Show file tree
Hide file tree
Showing 5 changed files with 248 additions and 19 deletions.
151 changes: 151 additions & 0 deletions vizro-core/examples/default/app_dev.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
"""Example to show dashboard configuration."""

import sys

import dash_ag_grid as dag
import pandas as pd

import vizro.models as vm
import vizro.plotly.express as px
from vizro import Vizro
from vizro.actions import export_data, filter_interaction

#######################################
from vizro.models.types import capture
from vizro.tables import dash_data_table

print("PYTHON EXECUTABLE", sys.executable)


@capture("action")
def AgGrid(data_frame=None):
"""Custom AgGrid."""
return dag.AgGrid(

This comment has been minimized.

Copy link
@antonymilne

antonymilne Jan 16, 2024

Contributor

Are you intending to move this to vizro.tables like we have dash_data_table?

This comment has been minimized.

Copy link
@maxschulz-COL

maxschulz-COL Jan 16, 2024

Author Contributor

Yes, in later iterations it already is

id="get-started-example-basic",
rowData=data_frame.to_dict("records"),
columnDefs=[{"field": col} for col in data_frame.columns],
)


#######################################

df = px.data.gapminder()
df_mean = (
df.groupby(by=["continent", "year"]).agg({"lifeExp": "mean", "pop": "mean", "gdpPercap": "mean"}).reset_index()
)

df_transformed = df.copy()
df_transformed["lifeExp"] = df.groupby(by=["continent", "year"])["lifeExp"].transform("mean")
df_transformed["gdpPercap"] = df.groupby(by=["continent", "year"])["gdpPercap"].transform("mean")
df_transformed["pop"] = df.groupby(by=["continent", "year"])["pop"].transform("sum")
df_concat = pd.concat([df_transformed.assign(color="Continent Avg."), df.assign(color="Country")], ignore_index=True)


def create_benchmark_analysis():
"""Function returns a page to perform analysis on country level."""
# Apply formatting to table columns
columns = [
{"id": "country", "name": "country"},
{"id": "continent", "name": "continent"},
{"id": "year", "name": "year"},
{"id": "lifeExp", "name": "lifeExp", "type": "numeric", "format": {"specifier": ",.1f"}},
{"id": "gdpPercap", "name": "gdpPercap", "type": "numeric", "format": {"specifier": "$,.2f"}},
{"id": "pop", "name": "pop", "type": "numeric", "format": {"specifier": ",d"}},
]

page_country = vm.Page(
title="Benchmark Analysis",
# description="Discovering how the metrics differ for each country and export data for further investigation",
# layout=vm.Layout(grid=[[0, 1]] * 5 + [[2, -1]], col_gap="32px", row_gap="60px"),
components=[
vm.Table(
id="table_country_new",
title="Click on a cell in country column:",
figure=AgGrid(
data_frame=df,
),
actions=[vm.Action(function=filter_interaction(targets=["line_country"]))],
),
vm.Table(
id="table_country",
title="Click on a cell in country column:",
figure=dash_data_table(
id="dash_data_table_country",
data_frame=df,
columns=columns,
style_data_conditional=[
{
"if": {"filter_query": "{gdpPercap} < 1045", "column_id": "gdpPercap"},
"backgroundColor": "#ff9222",
},
{
"if": {
"filter_query": "{gdpPercap} >= 1045 && {gdpPercap} <= 4095",
"column_id": "gdpPercap",
},
"backgroundColor": "#de9e75",
},
{
"if": {
"filter_query": "{gdpPercap} > 4095 && {gdpPercap} <= 12695",
"column_id": "gdpPercap",
},
"backgroundColor": "#aaa9ba",
},
{
"if": {"filter_query": "{gdpPercap} > 12695", "column_id": "gdpPercap"},
"backgroundColor": "#00b4ff",
},
],
sort_action="native",
style_cell={"textAlign": "left"},
),
actions=[vm.Action(function=filter_interaction(targets=["line_country"]))],
),
vm.Graph(
id="line_country",
figure=px.line(
df_concat,
title="Country vs. Continent",
x="year",
y="gdpPercap",
color="color",
labels={"year": "Year", "data": "Data", "gdpPercap": "GDP per capita"},
color_discrete_map={"Country": "#afe7f9", "Continent": "#003875"},
markers=True,
hover_name="country",
),
),
vm.Button(
text="Export data",
actions=[
vm.Action(
function=export_data(
targets=["line_country"],
)
),
],
),
],
controls=[
vm.Filter(column="continent", selector=vm.Dropdown(value="Europe", multi=False, title="Select continent")),
vm.Filter(column="year", selector=vm.RangeSlider(title="Select timeframe", step=1, marks=None)),
vm.Parameter(
targets=["line_country.y"],
selector=vm.Dropdown(
options=["lifeExp", "gdpPercap", "pop"], multi=False, value="gdpPercap", title="Choose y-axis"
),
),
],
)
return page_country


dashboard = vm.Dashboard(
pages=[
create_benchmark_analysis(),
],
)

if __name__ == "__main__":
Vizro(assets_folder="../assets").build(dashboard).run()
3 changes: 2 additions & 1 deletion vizro-core/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ dependencies = [
"tornado>=6.3.2", # not directly required, pinned by Snyk to avoid a vulnerability: https://security.snyk.io/vuln/SNYK-PYTHON-TORNADO-5537286
"setuptools>=65.5.1", # not directly required, pinned by Snyk to avoid a vulnerability: https://security.snyk.io/vuln/SNYK-PYTHON-SETUPTOOLS-3180412
"werkzeug>=3.0.1", # not directly required, pinned by Snyk to avoid a vulnerability: https://security.snyk.io/vuln/SNYK-PYTHON-WERKZEUG-6035177
"MarkupSafe" # required to sanitize user input
"MarkupSafe", # required to sanitize user input,
"dash-ag-grid"
]
description = "Vizro is a package to facilitate visual analytics."
dynamic = ["version"]
Expand Down
31 changes: 29 additions & 2 deletions vizro-core/src/vizro/actions/_actions_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def _get_parent_vizro_model(_underlying_callable_object_id: str) -> VizroBaseMod
)


def _apply_table_filter_interaction(
def _apply_dashtable_filter_interaction(
data_frame: pd.DataFrame, target: str, ctd_filter_interaction: Dict[str, CallbackTriggerDict]
) -> pd.DataFrame:
ctd_active_cell = ctd_filter_interaction["active_cell"]
Expand All @@ -133,6 +133,26 @@ def _apply_table_filter_interaction(
return data_frame


def _apply_aggrid_filter_interaction(
data_frame: pd.DataFrame, target: str, ctd_filter_interaction: Dict[str, CallbackTriggerDict]
) -> pd.DataFrame:
ctd_cellClicked = ctd_filter_interaction["cellClicked"]
if not ctd_cellClicked["value"]:
return data_frame

# ctd_active_cell["id"] represents the underlying table id, so we need to fetch its parent Vizro Table actions.
source_table_actions = _get_component_actions(_get_parent_vizro_model(ctd_cellClicked["id"]))

for action in source_table_actions:
if action.function._function.__name__ != "filter_interaction" or target not in action.function["targets"]:
continue
column = ctd_cellClicked["value"]["colId"]
clicked_data = ctd_cellClicked["value"]["value"]
data_frame = data_frame[data_frame[column].isin([clicked_data])]

return data_frame


def _apply_filter_interaction(
data_frame: pd.DataFrame,
ctds_filter_interaction: List[Dict[str, CallbackTriggerDict]],
Expand All @@ -147,7 +167,14 @@ def _apply_filter_interaction(
)

if "active_cell" in ctd_filter_interaction and "derived_viewport_data" in ctd_filter_interaction:
data_frame = _apply_table_filter_interaction(
data_frame = _apply_dashtable_filter_interaction(
data_frame=data_frame,
target=target,
ctd_filter_interaction=ctd_filter_interaction,
)

if "cellClicked" in ctd_filter_interaction:
data_frame = _apply_aggrid_filter_interaction(
data_frame=data_frame,
target=target,
ctd_filter_interaction=ctd_filter_interaction,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from vizro.actions import _on_page_load, _parameter, export_data, filter_interaction
from vizro.managers import data_manager, model_manager
from vizro.managers._model_manager import ModelID
from vizro.models import Action, Page, Table, VizroBaseModel
from vizro.models import Action, Page, VizroBaseModel
from vizro.models._action._actions_chain import ActionsChain
from vizro.models._controls import Filter, Parameter
from vizro.models.types import ControlType
Expand Down Expand Up @@ -108,18 +108,27 @@ def _get_inputs_of_figure_interactions(
for action in figure_interactions_on_page:
# TODO: Consider do we want to move the following logic into Model implementation
triggered_model = _get_triggered_model(action_id=ModelID(str(action.id)))
if isinstance(triggered_model, Table):
inputs.append(
{
"active_cell": State(
component_id=triggered_model._callable_object_id, component_property="active_cell"
),
"derived_viewport_data": State(
component_id=triggered_model._callable_object_id,
component_property="derived_viewport_data",
),
}
)
if hasattr(triggered_model, "table_type"): # not check this, put this configuration inside the models
if triggered_model.table_type == "DataTable":
inputs.append(
{
"active_cell": State(
component_id=triggered_model._callable_object_id, component_property="active_cell"
),
"derived_viewport_data": State(
component_id=triggered_model._callable_object_id,
component_property="derived_viewport_data",
),
}
)
elif triggered_model.table_type == "AgGrid":
inputs.append(
{
"cellClicked": State(
component_id=triggered_model._callable_object_id, component_property="cellClicked"
),
}
)
else:
inputs.append(
{
Expand Down
47 changes: 44 additions & 3 deletions vizro-core/src/vizro/models/_components/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,21 +5,33 @@
from pandas import DataFrame

try:
from pydantic.v1 import Field, PrivateAttr, validator
from pydantic.v1 import Field, PrivateAttr, root_validator, validator
except ImportError: # pragma: no cov
from pydantic import Field, PrivateAttr, validator

import vizro.tables as vt
from vizro.managers import data_manager
from vizro.models import Action, VizroBaseModel
from vizro.models._action._actions_chain import _action_validator_factory
from vizro.models._action._actions_chain import _set_actions
from vizro.models._components._components_utils import _process_callable_data_frame
from vizro.models._models_utils import _log_call
from vizro.models.types import CapturedCallable

logger = logging.getLogger(__name__)


def _get_table_type(figure): # this function can be applied also in pre-build
kwargs = figure._arguments.copy()

# This workaround is needed because the underlying table object requires a data_frame
kwargs["data_frame"] = DataFrame()

# The underlying table object is pre-built, so we can fetch its ID.
underlying_table_object = figure._function(**kwargs)
table_type = underlying_table_object.__class__.__name__
return table_type


class Table(VizroBaseModel):
"""Wrapper for table components to visualize in dashboard.
Expand All @@ -34,17 +46,42 @@ class Table(VizroBaseModel):
type: Literal["table"] = "table"
figure: CapturedCallable = Field(..., import_path=vt, description="Table to be visualized on dashboard")
title: str = Field("", description="Title of the table")
# foo: str = Field(None, exclude=True)
actions: List[Action] = []

_callable_object_id: str = PrivateAttr()
_table_type: str = (
PrivateAttr()
) # Ideally we would be able to use the populated content of this field in the `set_actions` validator.

# Component properties for actions and interactions
_output_property: str = PrivateAttr("children")

# validator
set_actions = _action_validator_factory("active_cell")
_validate_callable = validator("figure", allow_reuse=True, always=True)(_process_callable_data_frame)

@validator("actions")
def set_actions(cls, v, values):
table_type = _get_table_type(values["figure"])
if table_type == "DataTable":
return _set_actions(v, values, "active_cell")
elif table_type == "AgGrid":
return _set_actions(v, values, "cellClicked")
else:
raise ValueError(f"Table type {table_type} not supported.")

# set_actions = _action_validator_factory("cellClicked") # Need to make this sit with the captured callable

# Approach similar to layout model - need to confirm if we can do without __init__ and populate at another time
def __init__(self, **data):
super().__init__(**data)
self._table_type = _get_table_type(self.figure)

@property
def table_type(self):
return self._table_type


# Convenience wrapper/syntactic sugar.
def __call__(self, **kwargs):
kwargs.setdefault("data_frame", data_manager._get_component_data(self.id))
Expand All @@ -68,6 +105,7 @@ def pre_build(self):

# The underlying table object is pre-built, so we can fetch its ID.
underlying_table_object = self.figure._function(**kwargs)
table_type = underlying_table_object.__class__.__name__

if not hasattr(underlying_table_object, "id"):
raise ValueError(
Expand All @@ -76,6 +114,9 @@ def pre_build(self):
)

self._callable_object_id = underlying_table_object.id
self._table_type = table_type
# Idea: fetch it from the functions attributes? Or just hard-code it here? Can check difference between AGGrid and dashtable because we call it already
# Once we recognise, two ways to go: 1) slightly change model properties 2) inject dash dependencies,

def build(self):
return dcc.Loading(
Expand Down

0 comments on commit 96b6259

Please sign in to comment.