check expected time index for historical forecasts
dennisbader committed Dec 31, 2024
1 parent eedc2e2 commit d37c0d5
Showing 1 changed file with 37 additions and 30 deletions.
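The added assertions check not just how many historical forecasts are produced, but also that their time index lines up with the expected tail of the target series. Below is a minimal sketch of that property, outside the test suite; the model choice, lag value, and series length are illustrative and not taken from the commit.

```python
from darts.models import LinearRegressionModel
from darts.utils.timeseries_generation import linear_timeseries

# Toy setup (hypothetical values, not the test fixtures).
series = linear_timeseries(length=100)
model = LinearRegressionModel(lags=12)

forecast_horizon = 8
forecasts = model.historical_forecasts(
    series,
    forecast_horizon=forecast_horizon,
    retrain=True,
    overlap_end=False,
    last_points_only=True,  # one TimeSeries built from each forecast's last point
)

# With overlap_end=False, stride=1, and last_points_only=True, the last forecast
# point falls on the series end, so the forecast index should equal the last
# len(forecasts) timestamps of the input series.
assert forecasts.time_index.equals(series.time_index[-len(forecasts):])
```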
67 changes: 37 additions & 30 deletions darts/tests/utils/historical_forecasts/test_historical_forecasts.py
@@ -663,6 +663,7 @@ def test_historical_forecasts_negative_rangeindex(self):

@pytest.mark.parametrize("config", models_reg_no_cov_cls_kwargs)
def test_historical_forecasts(self, config):
"""Tests historical forecasts with retraining for expected forecast lengths and times"""
forecast_horizon = 8
# if no fit and retrain=False, should fit at first iteration
model_cls, kwargs, model_kwarg, bounds = config
@@ -708,19 +709,20 @@ def test_historical_forecasts(self, config):
)

assert len(forecasts_no_train_length) == len(forecasts)

theorical_forecast_length = (
self.ts_val_length
- train_length # because we train
- forecast_horizon # because we have overlap_end = False
+ 1 # because we include the first element
)

assert len(forecasts) == theorical_forecast_length, (
f"Model {model_cls.__name__} does not return the right number of historical forecasts in the case "
f"of retrain=True and overlap_end=False, and a time index of type DateTimeIndex. "
f"Expected {theorical_forecast_length}, got {len(forecasts)}"
)
assert forecasts.time_index.equals(
self.ts_pass_val.time_index[-theorical_forecast_length:]
)

# range index
forecasts = model.historical_forecasts(
@@ -737,6 +739,10 @@ def test_historical_forecasts(self, config):
f"of retrain=True, overlap_end=False, and a time index of type RangeIndex."
f"Expected {theorical_forecast_length}, got {len(forecasts)}"
)
assert forecasts.time_index.equals(
self.ts_pass_val_range.time_index[-theorical_forecast_length:]
)
start_idx = self.ts_pass_val_range.get_index_at_point(forecasts.start_time())

# stride 2
forecasts = model.historical_forecasts(
@@ -748,30 +754,30 @@ def test_historical_forecasts(self, config):
overlap_end=False,
)

- theorical_forecast_length = np.floor(
-     (
-         (
-             self.ts_val_length
-             - max([
-                 (
-                     bounds[0] + bounds[1] + 1
-                 ),  # +1 as sklearn models require min 2 train samples
-                 train_length,
-             ])  # because we train
-             - forecast_horizon  # because we have overlap_end = False
-             + 1  # because we include the first element
-         )
-         - 1
-     )
-     / 2
-     + 1  # because of stride
- )  # if odd number of elements, we keep the floor
+ theorical_forecast_length = int(
+     np.floor(
+         (
+             (
+                 self.ts_val_length
+                 - train_length  # because we train
+                 - forecast_horizon  # because we have overlap_end = False
+                 + 1  # because we include the first element
+             )
+             - 1
+         )
+         / 2
+         + 1  # because of stride
+     )  # if odd number of elements, we keep the floor
+ )

assert len(forecasts) == theorical_forecast_length, (
f"Model {model_cls.__name__} does not return the right number of historical forecasts in the case "
f"of retrain=True and overlap_end=False and stride=2. "
f"Expected {theorical_forecast_length}, got {len(forecasts)}"
)
assert forecasts.time_index.equals(
self.ts_pass_val_range.time_index[start_idx::2]
)

# stride 3
forecasts = model.historical_forecasts(
@@ -787,12 +793,7 @@ def test_historical_forecasts(self, config):
(
(
self.ts_val_length
- - max([
-     (
-         bounds[0] + bounds[1] + 1
-     ),  # +1 as sklearn models require min 2 train samples
-     train_length,
- ])  # because we train
+ - train_length  # because we train
- forecast_horizon # because we have overlap_end = False
+ 1 # because we include the first element
)
@@ -808,6 +809,9 @@ def test_historical_forecasts(self, config):
f"of retrain=True and overlap_end=False and stride=3. "
f"Expected {theorical_forecast_length}, got {len(forecasts)}"
)
assert forecasts.time_index.equals(
self.ts_pass_val_range.time_index[start_idx::3]
)

# last points only False
forecasts = model.historical_forecasts(
@@ -822,12 +826,7 @@ def test_historical_forecasts(self, config):

theorical_forecast_length = (
self.ts_val_length
- - max([
-     (
-         bounds[0] + bounds[1] + 1
-     ),  # +1 as sklearn models require min 2 train samples
-     train_length,
- ])  # because we train
+ - train_length  # because we train
- forecast_horizon # because we have overlap_end = False
+ 1 # because we include the first element
)
@@ -842,6 +841,11 @@ def test_historical_forecasts(self, config):
f"Model {model_cls} does not return forecast_horizon points per historical forecast in the case of "
f"retrain=True and overlap_end=False, and last_points_only=False"
)
last_points_times = np.array([fc.end_time() for fc in forecasts])
np.testing.assert_equal(
last_points_times,
self.ts_pass_val_range.time_index[-theorical_forecast_length:].values,
)

if not model.supports_past_covariates:
with pytest.raises(ValueError) as msg:
@@ -1233,6 +1237,9 @@ def test_regression_auto_start_multiple_no_cov(self, config):
f"of retrain=True and overlap_end=False, and a time index of type DateTimeIndex. "
f"Expected {theorical_forecast_length}, got {len(forecasts[0])} and {len(forecasts[1])}"
)
assert forecasts[0].time_index.equals(forecasts[1].time_index) and forecasts[
0
].time_index.equals(self.ts_pass_val.time_index[-theorical_forecast_length:])

@pytest.mark.slow
@pytest.mark.parametrize(

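The stride formulas in the test reduce to a simple count: with n = ts_val_length - train_length - forecast_horizon + 1 forecastable points, a stride of s yields floor((n - 1) / s) + 1 forecasts. A small worked example with hypothetical sizes (none of these numbers come from the test fixtures):

```python
import numpy as np

# Hypothetical sizes, only to illustrate the expected-length arithmetic.
ts_val_length = 72
train_length = 48
forecast_horizon = 8

# Forecastable points with stride=1 and overlap_end=False.
n = ts_val_length - train_length - forecast_horizon + 1  # 17

for stride in (1, 2, 3):
    expected = int(np.floor((n - 1) / stride) + 1)
    assert expected == (n - 1) // stride + 1  # same count in integer arithmetic
    print(stride, expected)  # 1 -> 17, 2 -> 9, 3 -> 6
```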
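For the last_points_only=False case, the new check collects fc.end_time() from every returned forecast and compares those timestamps against the tail of the series index. A sketch of the same shape, again with an illustrative model and series rather than the test fixtures:

```python
import numpy as np

from darts.models import LinearRegressionModel
from darts.utils.timeseries_generation import linear_timeseries

# Toy setup (hypothetical values, not the test fixtures).
series = linear_timeseries(length=100)
model = LinearRegressionModel(lags=12)

forecast_horizon = 8
forecasts = model.historical_forecasts(
    series,
    forecast_horizon=forecast_horizon,
    retrain=True,
    overlap_end=False,
    last_points_only=False,  # a list with one TimeSeries per historical forecast
)

# Each forecast spans forecast_horizon points, and consecutive forecasts end one
# step apart, so the end times should reproduce the last len(forecasts) timestamps.
assert all(len(fc) == forecast_horizon for fc in forecasts)
last_points_times = np.array([fc.end_time() for fc in forecasts])
np.testing.assert_equal(
    last_points_times, series.time_index[-len(forecasts):].values
)
```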