Skip to content

Commit

Permalink
Improve code in utils.py
Browse files Browse the repository at this point in the history
  • Loading branch information
cnhwl committed Jan 2, 2025
1 parent fb343e9 commit da1e644
Showing 1 changed file with 33 additions and 39 deletions.
72 changes: 33 additions & 39 deletions darts/ad/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,25 +357,17 @@ def show_anomalies_from_scores(
If True, it will separately plot each component in multivariate series.
"""

series = (
_check_input(
series,
name="series",
check_multivariate=True,
)[0]
if multivariate_plot
else _check_input(
series,
name="series",
num_series_expected=1,
)[0]
)
series = _check_input(
series,
name="series",
num_series_expected=1,
check_multivariate=multivariate_plot,
)[0]

if title is None and pred_scores is not None:
title = "Anomaly results"

nbr_plots = 1

if anomalies is not None:
nbr_plots = nbr_plots + 1
elif metric is not None:
Expand Down Expand Up @@ -433,7 +425,7 @@ def show_anomalies_from_scores(
logger=logger,
)

nbr_plots = nbr_plots + len(set(window))
nbr_plots += len(set(window))
series_width = series.n_components
plots_per_ts = nbr_plots * series_width if multivariate_plot else nbr_plots
fig, axs = plt.subplots(
Expand All @@ -444,33 +436,36 @@ def show_anomalies_from_scores(
squeeze=False,
)

if multivariate_plot:
if pred_series is not None:
pred_series = _check_input(
pred_series,
name="pred_series",
width_expected=series.width,
check_multivariate=True,
)[0]
if pred_series is not None:
pred_series = _check_input(
pred_series,
name="pred_series",
width_expected=series.width,
num_series_expected=1,
check_multivariate=True,
)[0]

if anomalies is not None:
anomalies = _check_input(
anomalies,
name="anomalies",
if anomalies is not None:
anomalies = _check_input(
anomalies,
name="anomalies",
width_expected=series.width,
num_series_expected=1,
check_binary=True,
check_multivariate=True,
)[0]

if pred_scores is not None:
for pred_score in pred_scores:
pred_score = _check_input(
pred_score,
name="pred_score",
width_expected=series.width,
check_binary=True,
num_series_expected=1,
check_multivariate=True,
)[0]

if pred_scores is not None:
for pred_score in pred_scores:
pred_score = _check_input(
pred_score,
name="pred_score",
width_expected=series.width,
check_multivariate=True,
)[0]

if multivariate_plot:
for i in range(series_width):
_plot_series_and_anomalies(
series=series[series.components[i]],
Expand All @@ -489,7 +484,6 @@ def show_anomalies_from_scores(
nbr_plots=nbr_plots,
)

fig.suptitle(title)
else:
_plot_series_and_anomalies(
series=series,
Expand All @@ -504,7 +498,7 @@ def show_anomalies_from_scores(
nbr_plots=nbr_plots,
)

fig.suptitle(title)
fig.suptitle(title)


def _assert_binary(series: TimeSeries, name: str):
Expand Down

0 comments on commit da1e644

Please sign in to comment.