Skip to content

Commit

Permalink
feat(#359): handle fallbacks (#370)
Browse files Browse the repository at this point in the history
- [ ] I have assigned ranges (e.g. `>=0.1, <0.2`) to all new dependencies (allows dependabot to keep dependency ranges wide for better compatibility)

Fixes #359.

## Notes for reviewers
Reviewers can skip the test scaffolding changes; pay attention to how the `fallback` argument is threaded from the specs through `_aggregate_within_slice` and applied via `fill_nulls`.
  • Loading branch information
MartinBernstorff authored Feb 12, 2024
2 parents eb99ff2 + e72a90b commit c8b2d5e
Show file tree
Hide file tree
Showing 4 changed files with 100 additions and 29 deletions.
18 changes: 18 additions & 0 deletions src/timeseriesflattenerv2/aggregators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from dataclasses import dataclass

import polars as pl

from .feature_specs import AggregatedValueFrame, Aggregator, SlicedFrame


@dataclass
class MeanAggregator(Aggregator):
    """Aggregator that reduces each prediction-time slice to the mean of its values."""

    name: str = "mean"

    def apply(self, sliced_frame: SlicedFrame, column_name: str) -> AggregatedValueFrame:
        """Group by the pred-time uuid column and mean-reduce *column_name*."""
        grouped = sliced_frame.df.group_by(
            sliced_frame.pred_time_uuid_col_name, maintain_order=True
        )
        # TODO: Figure out how to standardise the output column names
        aggregated_df = grouped.agg(pl.col(column_name).mean())
        return AggregatedValueFrame(df=aggregated_df)
19 changes: 15 additions & 4 deletions src/timeseriesflattenerv2/feature_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import polars as pl

Fallback = Union[int, float, str]
ValueType = Union[int, float, str]
LookDistance = dt.timedelta

# TODO: Add validation that all entity_id and timestamp columns are the same
Expand All @@ -25,7 +25,9 @@ def __post_init__(self):
self.df = self.df.with_columns(
pl.concat_str(
pl.col(self.entity_id_col_name), pl.lit("-"), pl.col(self.timestamp_col_name)
).alias(self.pred_time_uuid_col_name)
)
.str.strip_chars()
.alias(self.pred_time_uuid_col_name)
)

def to_lazyframe_with_uuid(self) -> pl.LazyFrame:
Expand Down Expand Up @@ -57,6 +59,15 @@ class AggregatedValueFrame:
pred_time_uuid_col_name: str = default_pred_time_uuid_col_name
value_col_name: str = "value"

def fill_nulls(self, fallback: ValueType) -> "SlicedFrame":
    """Return a SlicedFrame whose value-column nulls are replaced by *fallback*."""
    df_without_nulls = self.df.with_columns(
        pl.col(self.value_col_name).fill_null(fallback)
    )
    return SlicedFrame(
        df=df_without_nulls,
        pred_time_uuid_col_name=self.pred_time_uuid_col_name,
        value_col_name=self.value_col_name,
    )


class Aggregator(Protocol):
name: str
Expand All @@ -70,15 +81,15 @@ class PredictorSpec:
value_frame: ValueFrame
lookbehind_distances: Sequence[LookDistance]
aggregators: Sequence[Aggregator]
fallbacks: Sequence[Fallback]
fallback: ValueType


@dataclass(frozen=True)
class OutcomeSpec:
value_frame: ValueFrame
lookahead_distances: Sequence[LookDistance]
aggregators: Sequence[Aggregator]
fallbacks: Sequence[Fallback]
fallback: ValueType


@dataclass(frozen=True)
Expand Down
31 changes: 20 additions & 11 deletions src/timeseriesflattenerv2/flattener.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,25 +15,28 @@
TimedeltaFrame,
ValueFrame,
ValueSpecification,
ValueType,
)


def _aggregate_within_slice(
sliced_frame: SlicedFrame, aggregators: Sequence[Aggregator]
) -> Iter[AggregatedValueFrame]:
sliced_frame: SlicedFrame, aggregators: Sequence[Aggregator], fallback: ValueType
) -> Sequence[AggregatedValueFrame]:
aggregated_value_frames = [
aggregator.apply(SlicedFrame(sliced_frame.df), column_name=sliced_frame.value_col_name)
for aggregator in aggregators
agg.apply(SlicedFrame(sliced_frame.df), column_name=sliced_frame.value_col_name)
for agg in aggregators
]

return Iter(
with_fallback = [frame.fill_nulls(fallback=fallback) for frame in aggregated_value_frames]

return [
AggregatedValueFrame(
df=frame.df,
pred_time_uuid_col_name=sliced_frame.pred_time_uuid_col_name,
value_col_name=sliced_frame.value_col_name,
)
for frame in aggregated_value_frames
)
for frame in with_fallback
]


def _slice_frame(timedelta_frame: TimedeltaFrame, distance: LookDistance) -> SlicedFrame:
Expand All @@ -47,10 +50,13 @@ def _slice_frame(timedelta_frame: TimedeltaFrame, distance: LookDistance) -> Sli


def _slice_and_aggregate_spec(
timedelta_frame: TimedeltaFrame, distance: LookDistance, aggregators: Sequence[Aggregator]
) -> Iter[AggregatedValueFrame]:
timedelta_frame: TimedeltaFrame,
distance: LookDistance,
aggregators: Sequence[Aggregator],
fallback: ValueType,
) -> Sequence[AggregatedValueFrame]:
sliced_frame = _slice_frame(timedelta_frame, distance)
return _aggregate_within_slice(sliced_frame, aggregators)
return _aggregate_within_slice(sliced_frame, aggregators, fallback=fallback)


def _normalise_lookdistances(spec: ValueSpecification) -> Sequence[LookDistance]:
Expand Down Expand Up @@ -99,7 +105,10 @@ def _process_spec(
Iter(lookdistances)
.map(
lambda distance: _slice_and_aggregate_spec(
timedelta_frame=timedelta_frame, distance=distance, aggregators=spec.aggregators
timedelta_frame=timedelta_frame,
distance=distance,
aggregators=spec.aggregators,
fallback=spec.fallback,
)
)
.flatten()
Expand Down
61 changes: 47 additions & 14 deletions src/timeseriesflattenerv2/test_flattener.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,23 @@
import datetime as dt
from dataclasses import dataclass

import polars as pl
import polars.testing as polars_testing
from timeseriesflattener.testing.utils_for_testing import str_to_pl_df

from timeseriesflattenerv2.aggregators import MeanAggregator

from . import flattener
from .feature_specs import (
AggregatedValueFrame,
Aggregator,
PredictionTimeFrame,
PredictorSpec,
SlicedFrame,
ValueFrame,
)


@dataclass
class MeanAggregator(Aggregator):
name: str = "mean"

def apply(self, sliced_frame: SlicedFrame, column_name: str) -> AggregatedValueFrame:
df = sliced_frame.df.group_by(pl.col(sliced_frame.pred_time_uuid_col_name)).agg(
pl.col(column_name).mean().alias(column_name)
)
# TODO: Figure out how to standardise the output column names

return AggregatedValueFrame(df=df)
def assert_frame_equal(left: pl.DataFrame, right: pl.DataFrame):
    """Assert two frames are equal, ignoring dtypes and column ordering."""
    polars_testing.assert_frame_equal(
        left, right, check_dtype=False, check_column_order=False
    )


def test_flattener():
Expand All @@ -49,7 +41,7 @@ def test_flattener():
value_frame=ValueFrame(df=value_frame.lazy(), value_type="test_value"),
lookbehind_distances=[dt.timedelta(days=1)],
aggregators=[MeanAggregator()],
fallbacks=["NaN"],
fallback="NaN",
)
]
)
Expand Down Expand Up @@ -78,3 +70,44 @@ def test_get_timedelta_frame():
)

assert result.get_timedeltas() == expected_timedeltas


def test_aggregate_within_slice():
    """Mean aggregation collapses each pred_time_uuid to a single averaged row."""
    input_frame = SlicedFrame(
        df=str_to_pl_df(
            """pred_time_uuid,value
1-2021-01-03,1
1-2021-01-03,2
2-2021-01-03,2
2-2021-01-03,4"""
        ).lazy()
    )

    result = flattener._aggregate_within_slice(
        sliced_frame=input_frame, aggregators=[MeanAggregator()], fallback=0
    )

    expected_df = str_to_pl_df(
        """pred_time_uuid,value
1-2021-01-03,1.5
2-2021-01-03,3"""
    )

    assert_frame_equal(result[0].df.collect(), expected_df)


def test_aggregate_over_fallback():
    """When every value in a slice is null, the fallback fills the aggregated output."""
    input_frame = SlicedFrame(
        df=pl.LazyFrame(
            {"pred_time_uuid": ["1-2021-01-03", "1-2021-01-03"], "value": [None, None]}
        )
    )

    result = flattener._aggregate_within_slice(
        sliced_frame=input_frame, aggregators=[MeanAggregator()], fallback=0
    )

    expected_df = str_to_pl_df(
        """pred_time_uuid,value
1-2021-01-03,0"""
    )

    assert_frame_equal(result[0].df.collect(), expected_df)

0 comments on commit c8b2d5e

Please sign in to comment.