Skip to content

Commit

Permalink
Added support for the new dir structure in the dashboard.
Browse files Browse the repository at this point in the history
Signed-off-by: TheRootOf3 <[email protected]>
  • Loading branch information
TheRootOf3 committed Sep 19, 2024
1 parent 68929fa commit a64f2a5
Showing 1 changed file with 118 additions and 62 deletions.
180 changes: 118 additions & 62 deletions results_viz/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,29 +33,60 @@ def parse_args():
f"Experiment: {args.experiment_path}",
style={"textAlign": "center"},
),
html.P("x axis/further filter:", style={"font-weight": "bold"}),
html.P("x axis", style={"font-weight": "bold"}),
dcc.RadioItems(
id="column-radios",
id="x-axis-radios",
options=[
{"label": "sample_size/splits", "value": 1},
{"label": "splits/sample_size", "value": 2},
{"label": "sample size", "value": 0},
{"label": "splits", "value": 1},
{"label": "epoch count", "value": 2},
],
value=1,
value=0,
labelStyle={"display": "block"},
),
html.P("further filter value:", style={"font-weight": "bold"}),
html.P("sample size", style={"font-weight": "bold"}),
dcc.Checklist(
id="further-checkbox",
options=["Show baselines", "Show all"],
id="sample-size-checkbox",
options=["Show all"],
value=[],
labelStyle={"width": "50%", "display": "block"},
),
dcc.Dropdown(
id="further-value-dd",
id="sample-size-dd",
options=df.sample_size.unique(),
value=df.sample_size.unique()[0],
),
html.P("splits", style={"font-weight": "bold"}),
dcc.Checklist(
id="splits-checkbox",
options=["Show all"],
value=[],
labelStyle={"width": "50%", "display": "block"},
),
dcc.Dropdown(
id="splits-dd",
options=df.splits.unique(),
value=df.splits.unique()[0],
),
html.P("epoch count", style={"font-weight": "bold"}),
dcc.Checklist(
id="epoch-count-checkbox",
options=["Show all"],
value=[],
labelStyle={"width": "50%", "display": "block"},
),
dcc.Dropdown(
id="epoch-count-dd",
options=df.epoch_count.unique(),
value=df.epoch_count.unique()[0],
),
html.P("series to display:", style={"font-weight": "bold"}),
dcc.Checklist(
id="baselines-checkbox",
options=["Show baselines"],
value=[],
labelStyle={"width": "50%", "display": "block"},
),
html.Button("Clear All", id="clear-button"), # Clear all button
html.Button("All benchmarks", id="benchmarks-button"), # Set all benchmarks button
dcc.Checklist(
Expand All @@ -69,13 +100,34 @@ def parse_args():


@app.callback(
Output("further-value-dd", "disabled"),
Input("further-checkbox", "value"),
Output("sample-size-dd", "disabled"),
Output("splits-dd", "disabled"),
Output("epoch-count-dd", "disabled"),
Output("sample-size-checkbox", "options"),
Output("splits-checkbox", "options"),
Output("epoch-count-checkbox", "options"),
Input("x-axis-radios", "value"),
Input("sample-size-checkbox", component_property="value"),
Input("splits-checkbox", component_property="value"),
Input("epoch-count-checkbox", component_property="value"),
)
def use_further_checkbox(value):
if "Show all" in value:
return True
return False
def select_x_axis(x_axis, sample_size_cbox, splits_cbox, epoch_count_cbox):
output_dd = [False, False, False]
output_cbox = [False, False, False]
output_cbox[x_axis] = True
output_dd[x_axis] = True
if "Show all" in sample_size_cbox:
output_dd[0] = True
if "Show all" in splits_cbox:
output_dd[1] = True
if "Show all" in epoch_count_cbox:
output_dd[2] = True
return (
*output_dd,
[{"label": "Show all", "value": "Show all", "disabled": output_cbox[0]}],
[{"label": "Show all", "value": "Show all", "disabled": output_cbox[1]}],
[{"label": "Show all", "value": "Show all", "disabled": output_cbox[2]}],
)


@app.callback(
Expand All @@ -95,67 +147,71 @@ def clear_checklist(but1, but2):
return df.columns[:12]


@app.callback(
Output("further-value-dd", "options"),
Input("column-radios", "value"),
)
def clear_checklist(value):
if value == 1:
return df["splits"].unique()
return df["sample_size"].unique()


@callback(
Output("graph-content", "figure"),
[
Input("column-radios", "value"),
Input("further-value-dd", "value"),
Input("further-checkbox", "value"),
Input("x-axis-radios", "value"),
Input("sample-size-dd", "value"),
Input("splits-dd", "value"),
Input("epoch-count-dd", "value"),
Input("sample-size-checkbox", "value"),
Input("splits-checkbox", "value"),
Input("epoch-count-checkbox", "value"),
Input("baselines-checkbox", "value"),
Input("column-checkboxes", "value"),
],
)
def update_graph(
x_axis_further, further_filter_value, further_checkbox_value, series_names
x_axis, ss_dd, s_dd, ec_dd, ss_cbox, s_cbox, ec_cbox, baselines_cbox, series_names
):
x_axis, further_filter_name = "sample_size", "splits"
filters_dict = {}

if x_axis_further == 2:
x_axis, further_filter_name = "splits", "sample_size"
if x_axis == 0:
x_axis_name = "sample_size"
elif x_axis == 1:
x_axis_name = "splits"
else:
x_axis_name = "epoch_count"

further_series_values = (
df[further_filter_name].unique()
if "Show all" in further_checkbox_value
else [further_filter_value]
filters_dict["sample_size"] = (
df.sample_size.unique() if "Show all" in ss_cbox else [ss_dd]
)
filters_dict["splits"] = df.splits.unique() if "Show all" in s_cbox else [s_dd]
filters_dict["epoch_count"] = (
df.epoch_count.unique() if "Show all" in ec_cbox else [ec_dd]
)

fig = go.Figure()

for further_series_value in further_series_values:
dff = df[df[further_filter_name] == further_series_value].sort_values(by=x_axis)
del filters_dict[x_axis_name]

for series_name in series_names:
fig.add_trace(
go.Scatter(
x=dff[x_axis],
y=(
dff[series_name]
if series_name != "squadv2_f1"
else dff[series_name] / 100
),
mode="lines+markers",
marker=dict(
size=10,
symbol="circle",
),
name=(
f"{series_name}_{further_series_value}"
if len(further_series_values) > 1
else series_name
),
fig = go.Figure()
filter1_name, filter1_values = list(filters_dict.items())[0]
filter2_name, filter2_values = list(filters_dict.items())[1]

for val1 in filter1_values:
for val2 in filter2_values:
dff = df[
(df[filter1_name] == val1) & (df[filter2_name] == val2)
].sort_values(by=x_axis_name)

for series_name in series_names:
fig.add_trace(
go.Scatter(
x=dff[x_axis_name],
y=(
dff[series_name]
if series_name != "squadv2_f1"
else dff[series_name] / 100
),
mode="lines+markers",
marker=dict(
size=10,
symbol="circle",
),
name=(f"{series_name}_{val1}_{val2}"),
)
)
)

if "Show baselines" in further_checkbox_value:
if "Show baselines" in baselines_cbox:
for series_name in series_names:
fig.add_hline(
y=(
Expand All @@ -167,7 +223,7 @@ def update_graph(
line_dash="dash",
)

fig.update_layout(xaxis=dict(tickvals=dff[x_axis].unique()))
fig.update_layout(xaxis=dict(tickvals=dff[x_axis_name].unique()))
return fig


Expand Down

0 comments on commit a64f2a5

Please sign in to comment.