From e72a90bfab75600ea50901f1812ac9aac2c03b93 Mon Sep 17 00:00:00 2001 From: Martin Bernstorff Date: Fri, 9 Feb 2024 09:45:36 +0000 Subject: [PATCH] feat(#359): handle fallbacks Fixes #359 --- src/timeseriesflattenerv2/aggregators.py | 18 ++++++ src/timeseriesflattenerv2/feature_specs.py | 19 +++++-- src/timeseriesflattenerv2/flattener.py | 31 +++++++---- src/timeseriesflattenerv2/test_flattener.py | 61 ++++++++++++++++----- 4 files changed, 100 insertions(+), 29 deletions(-) create mode 100644 src/timeseriesflattenerv2/aggregators.py diff --git a/src/timeseriesflattenerv2/aggregators.py b/src/timeseriesflattenerv2/aggregators.py new file mode 100644 index 00000000..03038e5f --- /dev/null +++ b/src/timeseriesflattenerv2/aggregators.py @@ -0,0 +1,18 @@ +from dataclasses import dataclass + +import polars as pl + +from .feature_specs import AggregatedValueFrame, Aggregator, SlicedFrame + + +@dataclass +class MeanAggregator(Aggregator): + name: str = "mean" + + def apply(self, sliced_frame: SlicedFrame, column_name: str) -> AggregatedValueFrame: + df = sliced_frame.df.group_by( + sliced_frame.pred_time_uuid_col_name, maintain_order=True + ).agg(pl.col(column_name).mean()) + # TODO: Figure out how to standardise the output column names + + return AggregatedValueFrame(df=df) diff --git a/src/timeseriesflattenerv2/feature_specs.py b/src/timeseriesflattenerv2/feature_specs.py index abb4ec24..dad1737f 100644 --- a/src/timeseriesflattenerv2/feature_specs.py +++ b/src/timeseriesflattenerv2/feature_specs.py @@ -4,7 +4,7 @@ import polars as pl -Fallback = Union[int, float, str] +ValueType = Union[int, float, str] LookDistance = dt.timedelta # TODO: Add validation that all entity_id and timestamp columns are the same @@ -25,7 +25,9 @@ def __post_init__(self): self.df = self.df.with_columns( pl.concat_str( pl.col(self.entity_id_col_name), pl.lit("-"), pl.col(self.timestamp_col_name) - ).alias(self.pred_time_uuid_col_name) + ) + .str.strip_chars() + .alias(self.pred_time_uuid_col_name) ) def to_lazyframe_with_uuid(self) -> pl.LazyFrame: @@ -57,6 +59,15 @@ class AggregatedValueFrame: pred_time_uuid_col_name: str = default_pred_time_uuid_col_name value_col_name: str = "value" + def fill_nulls(self, fallback: ValueType) -> "SlicedFrame": + filled = self.df.with_columns(pl.col(self.value_col_name).fill_null(fallback)) + + return SlicedFrame( + df=filled, + pred_time_uuid_col_name=self.pred_time_uuid_col_name, + value_col_name=self.value_col_name, + ) + class Aggregator(Protocol): name: str @@ -70,7 +81,7 @@ class PredictorSpec: value_frame: ValueFrame lookbehind_distances: Sequence[LookDistance] aggregators: Sequence[Aggregator] - fallbacks: Sequence[Fallback] + fallback: ValueType @dataclass(frozen=True) @@ -78,7 +89,7 @@ class OutcomeSpec: value_frame: ValueFrame lookahead_distances: Sequence[LookDistance] aggregators: Sequence[Aggregator] - fallbacks: Sequence[Fallback] + fallback: ValueType @dataclass(frozen=True) diff --git a/src/timeseriesflattenerv2/flattener.py b/src/timeseriesflattenerv2/flattener.py index 63124dbf..fa8f7e5a 100644 --- a/src/timeseriesflattenerv2/flattener.py +++ b/src/timeseriesflattenerv2/flattener.py @@ -15,25 +15,28 @@ TimedeltaFrame, ValueFrame, ValueSpecification, + ValueType, ) def _aggregate_within_slice( - sliced_frame: SlicedFrame, aggregators: Sequence[Aggregator] -) -> Iter[AggregatedValueFrame]: + sliced_frame: SlicedFrame, aggregators: Sequence[Aggregator], fallback: ValueType +) -> Sequence[AggregatedValueFrame]: aggregated_value_frames = [ - aggregator.apply(SlicedFrame(sliced_frame.df), column_name=sliced_frame.value_col_name) - for aggregator in aggregators + agg.apply(SlicedFrame(sliced_frame.df), column_name=sliced_frame.value_col_name) + for agg in aggregators ] - return Iter( + with_fallback = [frame.fill_nulls(fallback=fallback) for frame in aggregated_value_frames] + + return [ AggregatedValueFrame( df=frame.df, pred_time_uuid_col_name=sliced_frame.pred_time_uuid_col_name, value_col_name=sliced_frame.value_col_name, ) - for frame in aggregated_value_frames - ) + for frame in with_fallback + ] def _slice_frame(timedelta_frame: TimedeltaFrame, distance: LookDistance) -> SlicedFrame: @@ -47,10 +50,13 @@ def _slice_frame(timedelta_frame: TimedeltaFrame, distance: LookDistance) -> Sli def _slice_and_aggregate_spec( - timedelta_frame: TimedeltaFrame, distance: LookDistance, aggregators: Sequence[Aggregator] -) -> Iter[AggregatedValueFrame]: + timedelta_frame: TimedeltaFrame, + distance: LookDistance, + aggregators: Sequence[Aggregator], + fallback: ValueType, +) -> Sequence[AggregatedValueFrame]: sliced_frame = _slice_frame(timedelta_frame, distance) - return _aggregate_within_slice(sliced_frame, aggregators) + return _aggregate_within_slice(sliced_frame, aggregators, fallback=fallback) def _normalise_lookdistances(spec: ValueSpecification) -> Sequence[LookDistance]: @@ -99,7 +105,10 @@ def _process_spec( Iter(lookdistances) .map( lambda distance: _slice_and_aggregate_spec( - timedelta_frame=timedelta_frame, distance=distance, aggregators=spec.aggregators + timedelta_frame=timedelta_frame, + distance=distance, + aggregators=spec.aggregators, + fallback=spec.fallback, ) ) .flatten() diff --git a/src/timeseriesflattenerv2/test_flattener.py b/src/timeseriesflattenerv2/test_flattener.py index 9689c66c..2057694b 100644 --- a/src/timeseriesflattenerv2/test_flattener.py +++ b/src/timeseriesflattenerv2/test_flattener.py @@ -1,13 +1,14 @@ import datetime as dt -from dataclasses import dataclass import polars as pl +import polars.testing as polars_testing from timeseriesflattener.testing.utils_for_testing import str_to_pl_df +from timeseriesflattenerv2.aggregators import MeanAggregator + from . import flattener from .feature_specs import ( AggregatedValueFrame, - Aggregator, PredictionTimeFrame, PredictorSpec, SlicedFrame, @@ -15,17 +16,8 @@ ) -@dataclass -class MeanAggregator(Aggregator): - name: str = "mean" - - def apply(self, sliced_frame: SlicedFrame, column_name: str) -> AggregatedValueFrame: - df = sliced_frame.df.group_by(pl.col(sliced_frame.pred_time_uuid_col_name)).agg( - pl.col(column_name).mean().alias(column_name) - ) - # TODO: Figure out how to standardise the output column names - - return AggregatedValueFrame(df=df) +def assert_frame_equal(left: pl.DataFrame, right: pl.DataFrame): + polars_testing.assert_frame_equal(left, right, check_dtype=False, check_column_order=False) def test_flattener(): @@ -49,7 +41,7 @@ def test_flattener(): value_frame=ValueFrame(df=value_frame.lazy(), value_type="test_value"), lookbehind_distances=[dt.timedelta(days=1)], aggregators=[MeanAggregator()], - fallbacks=["NaN"], + fallback="NaN", ) ] ) @@ -78,3 +70,44 @@ def test_get_timedelta_frame(): ) assert result.get_timedeltas() == expected_timedeltas + + +def test_aggregate_within_slice(): + sliced_frame = SlicedFrame( + df=str_to_pl_df( + """pred_time_uuid,value +1-2021-01-03,1 +1-2021-01-03,2 +2-2021-01-03,2 +2-2021-01-03,4""" + ).lazy() + ) + + aggregated_values = flattener._aggregate_within_slice( + sliced_frame=sliced_frame, aggregators=[MeanAggregator()], fallback=0 + ) + + expected = str_to_pl_df( + """pred_time_uuid,value +1-2021-01-03,1.5 +2-2021-01-03,3""" + ) + + assert_frame_equal(aggregated_values[0].df.collect(), expected) + + +def test_aggregate_over_fallback(): + sliced_frame = SlicedFrame( + df=pl.LazyFrame({"pred_time_uuid": ["1-2021-01-03", "1-2021-01-03"], "value": [None, None]}) + ) + + aggregated_values = flattener._aggregate_within_slice( + sliced_frame=sliced_frame, aggregators=[MeanAggregator()], fallback=0 + ) + + expected = str_to_pl_df( + """pred_time_uuid,value +1-2021-01-03,0""" + ) + + assert_frame_equal(aggregated_values[0].df.collect(), expected)