diff --git a/CHANGELOG.md b/CHANGELOG.md index 500c10a09f..1e2ac21663 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,7 +12,7 @@ but cannot always guarantee backwards compatibility. Changes that may **break co **Improved** - New model: `StatsForecastAutoTBATS`. This model offers the [AutoTBATS](https://nixtlaverse.nixtla.io/statsforecast/src/core/models.html#autotbats) model from Nixtla's `statsforecasts` library. [#2611](https://github.com/unit8co/darts/pull/2611) by [He Weilin](https://github.com/cnhwl). -- Added `multivariate_plot` parameter in `show_anomalies()` to separately plot each component in multivariate series. [#2544](https://github.com/unit8co/darts/pull/2544) by [He Weilin](https://github.com/cnhwl). +- Added parameter `component_wise` to `show_anomalies()` to separately plot each component in multivariate series. [#2544](https://github.com/unit8co/darts/pull/2544) by [He Weilin](https://github.com/cnhwl). **Fixed** - Fixed a bug when performing optimized historical forecasts with `stride=1` using a `RegressionModel` with `output_chunk_shift>=1` and `output_chunk_length=1`, where the forecast time index was not properly shifted. [#2634](https://github.com/unit8co/darts/pull/2634) by [Mattias De Charleroy](https://github.com/MattiasDC). diff --git a/darts/ad/anomaly_model/anomaly_model.py b/darts/ad/anomaly_model/anomaly_model.py index 1b13d0553a..63655db40c 100644 --- a/darts/ad/anomaly_model/anomaly_model.py +++ b/darts/ad/anomaly_model/anomaly_model.py @@ -285,7 +285,7 @@ def show_anomalies( score_kwargs parameters for the `score()` method. component_wise - If True, it will separately plot each component in multivariate series. + If True, will separately plot each component in case of multivariate anomaly detection. """ series = _check_input(series, name="series", num_series_expected=1)[0] predict_kwargs = predict_kwargs if predict_kwargs is not None else {} diff --git a/darts/ad/anomaly_model/forecasting_am.py b/darts/ad/anomaly_model/forecasting_am.py index 318fe3361a..8b4339cd9c 100644 --- a/darts/ad/anomaly_model/forecasting_am.py +++ b/darts/ad/anomaly_model/forecasting_am.py @@ -508,7 +508,7 @@ def show_anomalies( Receiver Operating Characteristic Curve) and "AUC_PR" (Average Precision from scores). Default: "AUC_ROC". component_wise - If True, it will separately plot each component in multivariate series. + If True, will separately plot each component in case of multivariate anomaly detection. score_kwargs parameters for the `score()` method. """ diff --git a/darts/ad/scorers/scorers.py b/darts/ad/scorers/scorers.py index f5887c8314..3fadee463a 100644 --- a/darts/ad/scorers/scorers.py +++ b/darts/ad/scorers/scorers.py @@ -210,7 +210,7 @@ def show_anomalies_from_prediction( Receiver Operating Characteristic Curve) and "AUC_PR" (Average Precision from scores). Default: "AUC_ROC". component_wise - If True, it will separately plot each component in multivariate series. + If True, will separately plot each component in case of multivariate anomaly detection. """ series = _check_input(series, name="series", num_series_expected=1)[0] pred_series = _check_input( @@ -616,7 +616,7 @@ def show_anomalies( Receiver Operating Characteristic Curve) and "AUC_PR" (Average Precision from scores). Default: "AUC_ROC". component_wise - If True, it will separately plot each component in multivariate series. + If True, will separately plot each component in case of multivariate anomaly detection. """ series = _check_input(series, name="series", num_series_expected=1)[0] pred_scores = self.score(series) diff --git a/darts/ad/utils.py b/darts/ad/utils.py index f338d54b79..4395afdfeb 100644 --- a/darts/ad/utils.py +++ b/darts/ad/utils.py @@ -354,7 +354,7 @@ def show_anomalies_from_scores( Only effective when `pred_scores` is not `None`. Default: "AUC_ROC". component_wise - If True, it will separately plot each component in multivariate series. + If True, will separately plot each component in case of multivariate anomaly detection. """ series = _check_input( series, @@ -457,15 +457,13 @@ def show_anomalies_from_scores( )[0] plots_per_ts = nbr_plots * series_width if component_wise else nbr_plots + height_ratios = ([2] + [1] * (nbr_plots - 1)) * (plots_per_ts // nbr_plots) + height_total = 2 * sum(height_ratios) fig, axs = plt.subplots( - plots_per_ts, - figsize=(8, 4 * (plots_per_ts // nbr_plots) + 2 * (nbr_plots - 1)), + nrows=plots_per_ts, + figsize=(8, height_total), sharex=True, - gridspec_kw={ - "height_ratios": ([2] + [1] * (nbr_plots - 1)) * (plots_per_ts // nbr_plots) - }, - squeeze=False, - layout="constrained", + gridspec_kw={"height_ratios": height_ratios}, ) for i in range(series_width if component_wise else 1): @@ -500,9 +498,13 @@ def show_anomalies_from_scores( metric=metric, axs=axs, index_ax=i * nbr_plots, - nbr_plots=nbr_plots, ) - fig.suptitle(title) + # make title fit nicely on plot + title_height = 0.1 + title_y = 1 - title_height / height_total + + fig.suptitle(title, y=title_y) + fig.tight_layout() def _assert_binary(series: TimeSeries, name: str): @@ -774,7 +776,6 @@ def _plot_series_and_anomalies( metric: str, axs: plt.Axes, index_ax: int, - nbr_plots: int, ): """Helper function to plot series and anomalies. @@ -798,25 +799,23 @@ def _plot_series_and_anomalies( The axes to plot on. index_ax The index of the current axis. - nbr_plots - The number of plots. """ - _plot_series(series=series, ax_id=axs[index_ax][0], linewidth=0.5, label_name="") + _plot_series(series=series, ax_id=axs[index_ax], linewidth=0.5, label_name="") if pred_series is not None: _plot_series( series=pred_series, - ax_id=axs[index_ax][0], + ax_id=axs[index_ax], linewidth=0.5, label_name="model output", ) - axs[index_ax][0].set_title("") + axs[index_ax].set_title("") if anomalies is not None or pred_scores is not None: - axs[index_ax][0].set_xlabel("") + axs[index_ax].set_xlabel("") - axs[index_ax][0].legend(loc="upper center", bbox_to_anchor=(0.5, 1.1), ncol=2) + axs[index_ax].legend(loc="upper center", bbox_to_anchor=(0.5, 1.1), ncol=2) if pred_scores is not None: dict_input = {} @@ -858,33 +857,29 @@ def _plot_series_and_anomalies( _plot_series( series=elem[1]["series_score"], - ax_id=axs[index_ax][0], + ax_id=axs[index_ax], linewidth=0.5, label_name=label, ) - axs[index_ax][0].legend( - loc="upper center", bbox_to_anchor=(0.5, 1.19), ncol=2 - ) - axs[index_ax][0].set_title(f"Window: {str(w)}", loc="left") - axs[index_ax][0].set_title("") - axs[index_ax][0].set_xlabel("") + axs[index_ax].legend(loc="upper center", bbox_to_anchor=(0.5, 1.19), ncol=2) + axs[index_ax].set_title(f"Window: {str(w)}", loc="left") + axs[index_ax].set_title("") + axs[index_ax].set_xlabel("") if anomalies is not None: _plot_series( series=anomalies, - ax_id=axs[index_ax + 1][0], + ax_id=axs[index_ax + 1], linewidth=1, label_name="anomalies", color="red", ) - axs[index_ax + 1][0].set_title("") - axs[index_ax + 1][0].set_ylim([-0.1, 1.1]) - axs[index_ax + 1][0].set_yticks([0, 1]) - axs[index_ax + 1][0].set_yticklabels(["no", "yes"]) - axs[index_ax + 1][0].legend( - loc="upper center", bbox_to_anchor=(0.5, 1.2), ncol=2 - ) + axs[index_ax + 1].set_title("") + axs[index_ax + 1].set_ylim([-0.1, 1.1]) + axs[index_ax + 1].set_yticks([0, 1]) + axs[index_ax + 1].set_yticklabels(["no", "yes"]) + axs[index_ax + 1].legend(loc="upper center", bbox_to_anchor=(0.5, 1.2), ncol=2) else: - axs[index_ax][0].set_xlabel("timestamp") + axs[index_ax].set_xlabel("timestamp")