Skip to content

Commit

Permalink
make title fit better
Browse files Browse the repository at this point in the history
  • Loading branch information
dennisbader committed Jan 8, 2025
1 parent 78baeec commit 95dd225
Show file tree
Hide file tree
Showing 5 changed files with 34 additions and 39 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ but cannot always guarantee backwards compatibility. Changes that may **break co
**Improved**

- New model: `StatsForecastAutoTBATS`. This model offers the [AutoTBATS](https://nixtlaverse.nixtla.io/statsforecast/src/core/models.html#autotbats) model from Nixtla's `statsforecasts` library. [#2611](https://github.com/unit8co/darts/pull/2611) by [He Weilin](https://github.com/cnhwl).
- Added `multivariate_plot` parameter in `show_anomalies()` to separately plot each component in multivariate series. [#2544](https://github.com/unit8co/darts/pull/2544) by [He Weilin](https://github.com/cnhwl).
- Added parameter `component_wise` to `show_anomalies()` to separately plot each component in multivariate series. [#2544](https://github.com/unit8co/darts/pull/2544) by [He Weilin](https://github.com/cnhwl).

**Fixed**
- Fixed a bug when performing optimized historical forecasts with `stride=1` using a `RegressionModel` with `output_chunk_shift>=1` and `output_chunk_length=1`, where the forecast time index was not properly shifted. [#2634](https://github.com/unit8co/darts/pull/2634) by [Mattias De Charleroy](https://github.com/MattiasDC).
Expand Down
2 changes: 1 addition & 1 deletion darts/ad/anomaly_model/anomaly_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,7 @@ def show_anomalies(
score_kwargs
parameters for the `score()` method.
component_wise
If True, it will separately plot each component in multivariate series.
If True, will separately plot each component in case of multivariate anomaly detection.
"""
series = _check_input(series, name="series", num_series_expected=1)[0]
predict_kwargs = predict_kwargs if predict_kwargs is not None else {}
Expand Down
2 changes: 1 addition & 1 deletion darts/ad/anomaly_model/forecasting_am.py
Original file line number Diff line number Diff line change
Expand Up @@ -508,7 +508,7 @@ def show_anomalies(
Receiver Operating Characteristic Curve) and "AUC_PR" (Average Precision from scores).
Default: "AUC_ROC".
component_wise
If True, it will separately plot each component in multivariate series.
If True, will separately plot each component in case of multivariate anomaly detection.
score_kwargs
parameters for the `score()` method.
"""
Expand Down
4 changes: 2 additions & 2 deletions darts/ad/scorers/scorers.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ def show_anomalies_from_prediction(
Receiver Operating Characteristic Curve) and "AUC_PR" (Average Precision from scores).
Default: "AUC_ROC".
component_wise
If True, it will separately plot each component in multivariate series.
If True, will separately plot each component in case of multivariate anomaly detection.
"""
series = _check_input(series, name="series", num_series_expected=1)[0]
pred_series = _check_input(
Expand Down Expand Up @@ -616,7 +616,7 @@ def show_anomalies(
Receiver Operating Characteristic Curve) and "AUC_PR" (Average Precision from scores).
Default: "AUC_ROC".
component_wise
If True, it will separately plot each component in multivariate series.
If True, will separately plot each component in case of multivariate anomaly detection.
"""
series = _check_input(series, name="series", num_series_expected=1)[0]
pred_scores = self.score(series)
Expand Down
63 changes: 29 additions & 34 deletions darts/ad/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,7 +354,7 @@ def show_anomalies_from_scores(
Only effective when `pred_scores` is not `None`.
Default: "AUC_ROC".
component_wise
If True, it will separately plot each component in multivariate series.
If True, will separately plot each component in case of multivariate anomaly detection.
"""
series = _check_input(
series,
Expand Down Expand Up @@ -457,15 +457,13 @@ def show_anomalies_from_scores(
)[0]

plots_per_ts = nbr_plots * series_width if component_wise else nbr_plots
height_ratios = ([2] + [1] * (nbr_plots - 1)) * (plots_per_ts // nbr_plots)
height_total = 2 * sum(height_ratios)
fig, axs = plt.subplots(
plots_per_ts,
figsize=(8, 4 * (plots_per_ts // nbr_plots) + 2 * (nbr_plots - 1)),
nrows=plots_per_ts,
figsize=(8, height_total),
sharex=True,
gridspec_kw={
"height_ratios": ([2] + [1] * (nbr_plots - 1)) * (plots_per_ts // nbr_plots)
},
squeeze=False,
layout="constrained",
gridspec_kw={"height_ratios": height_ratios},
)

for i in range(series_width if component_wise else 1):
Expand Down Expand Up @@ -500,9 +498,13 @@ def show_anomalies_from_scores(
metric=metric,
axs=axs,
index_ax=i * nbr_plots,
nbr_plots=nbr_plots,
)
fig.suptitle(title)
# make title fit nicely on plot
title_height = 0.1
title_y = 1 - title_height / height_total

fig.suptitle(title, y=title_y)
fig.tight_layout()


def _assert_binary(series: TimeSeries, name: str):
Expand Down Expand Up @@ -774,7 +776,6 @@ def _plot_series_and_anomalies(
metric: str,
axs: plt.Axes,
index_ax: int,
nbr_plots: int,
):
"""Helper function to plot series and anomalies.
Expand All @@ -798,25 +799,23 @@ def _plot_series_and_anomalies(
The axes to plot on.
index_ax
The index of the current axis.
nbr_plots
The number of plots.
"""
_plot_series(series=series, ax_id=axs[index_ax][0], linewidth=0.5, label_name="")
_plot_series(series=series, ax_id=axs[index_ax], linewidth=0.5, label_name="")

if pred_series is not None:
_plot_series(
series=pred_series,
ax_id=axs[index_ax][0],
ax_id=axs[index_ax],
linewidth=0.5,
label_name="model output",
)

axs[index_ax][0].set_title("")
axs[index_ax].set_title("")

if anomalies is not None or pred_scores is not None:
axs[index_ax][0].set_xlabel("")
axs[index_ax].set_xlabel("")

axs[index_ax][0].legend(loc="upper center", bbox_to_anchor=(0.5, 1.1), ncol=2)
axs[index_ax].legend(loc="upper center", bbox_to_anchor=(0.5, 1.1), ncol=2)

if pred_scores is not None:
dict_input = {}
Expand Down Expand Up @@ -858,33 +857,29 @@ def _plot_series_and_anomalies(

_plot_series(
series=elem[1]["series_score"],
ax_id=axs[index_ax][0],
ax_id=axs[index_ax],
linewidth=0.5,
label_name=label,
)

axs[index_ax][0].legend(
loc="upper center", bbox_to_anchor=(0.5, 1.19), ncol=2
)
axs[index_ax][0].set_title(f"Window: {str(w)}", loc="left")
axs[index_ax][0].set_title("")
axs[index_ax][0].set_xlabel("")
axs[index_ax].legend(loc="upper center", bbox_to_anchor=(0.5, 1.19), ncol=2)
axs[index_ax].set_title(f"Window: {str(w)}", loc="left")
axs[index_ax].set_title("")
axs[index_ax].set_xlabel("")

if anomalies is not None:
_plot_series(
series=anomalies,
ax_id=axs[index_ax + 1][0],
ax_id=axs[index_ax + 1],
linewidth=1,
label_name="anomalies",
color="red",
)

axs[index_ax + 1][0].set_title("")
axs[index_ax + 1][0].set_ylim([-0.1, 1.1])
axs[index_ax + 1][0].set_yticks([0, 1])
axs[index_ax + 1][0].set_yticklabels(["no", "yes"])
axs[index_ax + 1][0].legend(
loc="upper center", bbox_to_anchor=(0.5, 1.2), ncol=2
)
axs[index_ax + 1].set_title("")
axs[index_ax + 1].set_ylim([-0.1, 1.1])
axs[index_ax + 1].set_yticks([0, 1])
axs[index_ax + 1].set_yticklabels(["no", "yes"])
axs[index_ax + 1].legend(loc="upper center", bbox_to_anchor=(0.5, 1.2), ncol=2)
else:
axs[index_ax][0].set_xlabel("timestamp")
axs[index_ax].set_xlabel("timestamp")

0 comments on commit 95dd225

Please sign in to comment.