Skip to content

Commit

Permalink
automatic year slicing
Browse files: browse the repository at this point in the history
  • Loading branch information
sarakolding committed Apr 9, 2024
1 parent e9d6478 commit 61a2625
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 30 deletions.
40 changes: 32 additions & 8 deletions src/timeseriesflattener/spec_processors/temporal.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,13 @@
from ..feature_specs.predictor import PredictorSpec
from ..frame_utilities._horisontally_concat import horizontally_concatenate_dfs
from ..feature_specs.prediction_times import PredictionTimeFrame
from ..feature_specs.meta import ValueFrame

if TYPE_CHECKING:
from collections.abc import Sequence

from ..aggregators import Aggregator
from ..feature_specs.meta import LookPeriod, ValueFrame, ValueType
from ..feature_specs.meta import LookPeriod, ValueType


def _get_timedelta_frame(
Expand Down Expand Up @@ -149,25 +150,48 @@ def _slice_and_aggregate_spec(


def process_temporal_spec(
spec: TemporalSpec,
predictiontime_frame: PredictionTimeFrame,
year_start: int = 2011,
year_end: int = 2022,
spec: TemporalSpec, predictiontime_frame: PredictionTimeFrame
) -> ProcessedFrame:
aggregated_value_frames = list()

year_start = (
predictiontime_frame.df.select(pl.col(predictiontime_frame.timestamp_col_name).min())
.collect()
.item()
.year
)
year_end = (
predictiontime_frame.df.select(pl.col(predictiontime_frame.timestamp_col_name).max())
.collect()
.item()
.year
)

lookperiod_years = int(
(spec.normalised_lookperiod[0].first - spec.normalised_lookperiod[0].last).days / 365
)

for year in range(year_start, year_end + 1):
year_df = predictiontime_frame.df.filter(
year_predictiontime_df = predictiontime_frame.df.filter(
pl.col(predictiontime_frame.timestamp_col_name).dt.year() == year
)
year_frame = PredictionTimeFrame(init_df=year_df)
year_predictiontime_frame = PredictionTimeFrame(init_df=year_predictiontime_df)

year_value_df = spec.value_frame.df.filter(
(pl.col(spec.value_frame.value_timestamp_col_name).dt.year() == year)
| (
pl.col(spec.value_frame.value_timestamp_col_name).dt.year()
== year + lookperiod_years
)
)
year_value_frame = ValueFrame(year_value_df)

aggregated_value_frames += (
Iter(spec.normalised_lookperiod)
.map(
lambda lookperiod: _slice_and_aggregate_spec(
timedelta_frame=_get_timedelta_frame(
predictiontime_frame=year_frame, value_frame=spec.value_frame
predictiontime_frame=year_predictiontime_frame, value_frame=year_value_frame
),
masked_aggregator=lambda sliced_frame: _aggregate_masked_frame(
aggregators=spec.aggregators,
Expand Down
33 changes: 11 additions & 22 deletions src/timeseriesflattener/spec_processors/test_temporal.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,15 +295,10 @@ def test_sliding_window():
pred_frame = str_to_pl_df(
"""entity_id,pred_timestamp
1,2011-01-01,
1,2013-01-01,
1,2014-01-01,
1,2015-01-01,
1,2016-01-01,
1,2017-01-01,
1,2018-01-01,
1,2019-01-01,
1,2020-01-01,
1,2021-01-01,
1,2022-01-01,""" # 2012 year without prediction times
)

Expand All @@ -315,35 +310,29 @@ def test_sliding_window():
1,2014-01-01,4
1,2015-01-01,5
1,2016-01-01,6
1,2017-01-01,7
1,2018-01-01,8
1,2019-01-01,9
1,2020-01-01,10
1,2021-01-01,11
1,2021-01-01,12""" # 2021 year with multiple values
) # 2022 year with no values

result = process_spec.process_temporal_spec(
spec=PredictorSpec(
value_frame=ValueFrame(init_df=value_frame.lazy()),
lookbehind_distances=[dt.timedelta(days=1)],
lookbehind_distances=[dt.timedelta(days=365)],
aggregators=[MeanAggregator()],
fallback=np.nan,
fallback=0,
),
predictiontime_frame=PredictionTimeFrame(init_df=pred_frame.lazy()),
)

expected = str_to_pl_df(
"""pred_time_uuid,pred_value_within_0_to_1_days_mean_fallback_nan
1,2011-01-01,1
1,2013-01-01,3
1,2014-01-01,4
1,2015-01-01,5
1,2016-01-01,6
1,2017-01-01,7
1,2018-01-01,8
1,2019-01-01,9
1,2020-01-01,10
1,2021-01-01,11.5
1,2022-01-01,nan"""
"""pred_time_uuid,pred_value_within_0_to_365_days_mean_fallback_0
1-2011-01-01 00:00:00.000000,1
1-2014-01-01 00:00:00.000000,3.5
1-2016-01-01 00:00:00.000000,5.5
1-2018-01-01 00:00:00.000000,0
1-2020-01-01 00:00:00.000000,9
1-2022-01-01 00:00:00.000000,11.5"""
)

assert_frame_equal(result.collect(), expected)

0 comments on commit 61a2625

Please sign in to comment.