Skip to content

Commit

Permalink
Got bar charts working with error bars
Browse files Browse the repository at this point in the history
  • Loading branch information
EdmundGoodman committed Mar 3, 2024
1 parent 2d0f288 commit 056b82c
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 38 deletions.
90 changes: 53 additions & 37 deletions src/hpc_multibench/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# -*- coding: utf-8 -*-
"""A set of functions to analyse the results of a test bench run."""

from collections.abc import Iterator
from enum import Enum, auto

from hpc_multibench.roofline_model import RooflineDataModel
Expand Down Expand Up @@ -29,12 +30,21 @@ class PlotStyle(Enum):
# from labellines import labelLines

if PLOT_STYLE == PlotStyle.SEABORN:
import pandas as pd
import seaborn as sns

sns.set_theme()


def get_metrics_uncertainties_iterator(
run_metrics: list[tuple[RunConfiguration, dict[str, str | float]]],
run_uncertainties: list[tuple[RunConfiguration, dict[str, float | None]]],
) -> Iterator[tuple[RunConfiguration, dict[str, str | float], dict[str, float | None]]]:
"""Get an iterator of metrics and uncertainties in a helpful shape."""
zipped_data = zip(run_metrics, run_uncertainties, strict=True)
for (run_configuration, metrics), (_, uncertainties) in zipped_data:
yield (run_configuration, metrics, uncertainties)


def get_line_plot_data(
plot: LinePlotModel,
run_metrics: list[tuple[RunConfiguration, dict[str, str | float]]],
Expand All @@ -43,8 +53,8 @@ def get_line_plot_data(
"""Get the data needed to plot a specified line plot for a set of runs."""
# Reshape the metrics data from multiple runs into groups of points
data: dict[str, list[tuple[float, float, float | None, float | None]]] = {}
for (run_configuration, metrics), (_, uncertainties) in zip(
run_metrics, run_uncertainties, strict=True
for run_configuration, metrics, uncertainties in get_metrics_uncertainties_iterator(
run_metrics, run_uncertainties
):
split_names: list[str] = [
f"{split_metric}={metrics[split_metric]}"
Expand Down Expand Up @@ -123,12 +133,17 @@ def draw_line_plot(
def get_bar_chart_data(
plot: BarChartModel,
run_metrics: list[tuple[RunConfiguration, dict[str, str | float]]],
) -> dict[tuple[str, ...], float]:
run_uncertainties: list[tuple[RunConfiguration, dict[str, float | None]]],
) -> dict[str, tuple[float, float | None, int]]:
"""Get the data needed to plot a specified bar chart for a set of runs."""
data: dict[tuple[str, ...], float] = {} # {("a", "b"): 1.0, ("a", "c"): 2.0}
data: dict[str, tuple[float, float | None, int]] = {}

# Extract the outputs into the data format needed for the line plot
for run_configuration, metrics in run_metrics:
hue_index_lookup: dict[str, int] = {}
new_hue_index = 0
for run_configuration, metrics, uncertainties in get_metrics_uncertainties_iterator(
run_metrics, run_uncertainties
):
split_names: list[str] = [
f"{split_metric}={metrics[split_metric]}"
for split_metric in plot.split_metrics
Expand All @@ -137,57 +152,58 @@ def get_bar_chart_data(
fix_names: list[str] = [
f"{metric}={value}" for metric, value in plot.fix_metrics.items()
]
series_name = (run_configuration.name, *fix_names, *split_names)
series_name = ", ".join([run_configuration.name, *fix_names, *split_names])
if any(
metrics[metric] != str(value) for metric, value in plot.fix_metrics.items()
):
continue

data[series_name] = float(metrics[plot.y])
if run_configuration.name not in hue_index_lookup:
hue_index_lookup[run_configuration.name] = new_hue_index
new_hue_index += 1

data[series_name] = (
float(metrics[plot.y]),
uncertainties[plot.y],
hue_index_lookup[run_configuration.name],
)
return data


def draw_bar_chart(
plot: BarChartModel,
run_metrics: list[tuple[RunConfiguration, dict[str, str | float]]],
run_uncertainties: list[tuple[RunConfiguration, dict[str, float | None]]],
) -> None:
"""Draw a specified bar chart for a set of run outputs."""
data = get_bar_chart_data(plot, run_metrics)
data = get_bar_chart_data(plot, run_metrics, run_uncertainties)

if PLOT_STYLE == PlotStyle.SEABORN:
dataframe = pd.DataFrame(
{
"Run Configuration": [",\n".join(key) for key in data],
plot.y: list(data.values()),
"Run Type": [key[0] for key in data],
}
)
sns.barplot(
data=dataframe.sort_values(plot.y),
x="Run Configuration",
y=plot.y,
hue="Run Type",
)
plt.xticks(rotation=45, ha="right")
plt.gcf().subplots_adjust(bottom=0.25)
elif PLOT_STYLE == PlotStyle.PLOTEXT:
if PLOT_STYLE == PlotStyle.PLOTEXT:
plt.clear_figure()
shaped_data: list[tuple[str, float]] = sorted(
[(", ".join(name), metric) for name, metric in data.items()],
key=lambda x: x[1],
# NOTE: Plotext cannot render error bars!
plt.bar(
data.keys(),
[metric for metric, _, _ in data.values()],
orientation="horizontal",
width=3 / 5,
)
plt.bar(*zip(*shaped_data, strict=True), orientation="horizontal", width=3 / 5)
plt.ylabel(plot.y)
plt.theme(PLOTEXT_THEME)
else:
newline_shaped_data: list[tuple[str, float]] = sorted(
[(",\n".join(name), metric) for name, metric in data.items()],
key=lambda x: x[1],
palette = (
sns.color_palette()
if PLOT_STYLE == PlotStyle.SEABORN
else plt.rcParams["axes.prop_cycle"].by_key()["color"]
)
plt.bar(*zip(*newline_shaped_data, strict=True))
plt.ylabel(plot.y)
plt.xticks(rotation=45, ha="right")
plt.gcf().subplots_adjust(bottom=0.25)
plt.barh(
list(data.keys()),
[metric for metric, _, _ in data.values()],
xerr=[uncertainty for _, uncertainty, _ in data.values()],
color=[palette[hue] for _, _, hue in data.values()],
ecolor="black",
)
plt.xlabel(plot.y)
plt.gcf().subplots_adjust(left=0.25)

plt.title(plot.title)
plt.show()
Expand Down
2 changes: 1 addition & 1 deletion src/hpc_multibench/test_bench.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,7 +355,7 @@ def report(self) -> None: # noqa: C901, PLR0912, PLR0915
draw_line_plot(line_plot, run_metrics, run_uncertainties)

for bar_chart in self.bench_model.analysis.bar_charts:
draw_bar_chart(bar_chart, run_metrics)
draw_bar_chart(bar_chart, run_metrics, run_uncertainties)

for roofline_plot in self.bench_model.analysis.roofline_plots:
draw_roofline_plot(roofline_plot, run_metrics)

0 comments on commit 056b82c

Please sign in to comment.