diff --git a/src/timeseriesflattener/aggregators.py b/src/timeseriesflattener/aggregators.py index 659a4724..041af936 100644 --- a/src/timeseriesflattener/aggregators.py +++ b/src/timeseriesflattener/aggregators.py @@ -2,9 +2,11 @@ import datetime as dt from abc import ABC, abstractmethod +from typing import Literal, Sequence import polars as pl from attr import dataclass +from timeseriesflattener.specs import timestamp def validate_compatible_fallback_type_for_aggregator( @@ -18,6 +20,46 @@ def validate_compatible_fallback_type_for_aggregator( ) +AggregatorName = Literal[ + "max", + "min", + "mean", + "sum", + "count", + "variance", + "bool", + "change_per_day", + "slope", + "has_values", +] + + +def strings_to_aggregators( + aggregator_names: Sequence[AggregatorName], timestamp_col_name: str +) -> Sequence[Aggregator]: + return [ + string_to_aggregator(name, timestamp_col_name=timestamp_col_name) + for name in aggregator_names + ] + + +def string_to_aggregator(aggregator_name: AggregatorName, timestamp_col_name: str) -> Aggregator: + str2aggr: dict[AggregatorName, Aggregator] = { + "max": MaxAggregator(), + "min": MinAggregator(), + "mean": MeanAggregator(), + "sum": SumAggregator(), + "count": CountAggregator(), + "variance": VarianceAggregator(), + "bool": HasValuesAggregator(), + "change_per_day": SlopeAggregator(timestamp_col_name=timestamp_col_name), + "slope": SlopeAggregator(timestamp_col_name=timestamp_col_name), + "has_values": HasValuesAggregator(), + } + + return str2aggr[aggregator_name] + + class Aggregator(ABC): name: str output_type: type[float | int | bool] @@ -125,7 +167,7 @@ def __call__(self, column_name: str) -> pl.Expr: class HasValuesAggregator(Aggregator): - """Examines whether any values exist in the column. If so, returns True, else False.""" + """Examines whether any values exist in the look window. If so, returns True, else False.""" name: str = "bool" output_type = bool diff --git a/src/timeseriesflattener/specs/outcome.py b/src/timeseriesflattener/specs/outcome.py index 18622081..d33c20fe 100644 --- a/src/timeseriesflattener/specs/outcome.py +++ b/src/timeseriesflattener/specs/outcome.py @@ -8,7 +8,11 @@ from ..validators import validate_col_name_columns_exist from .value import ValueFrame, lookdistance_to_normalised_lookperiod -from ..aggregators import validate_compatible_fallback_type_for_aggregator +from ..aggregators import ( + AggregatorName, + strings_to_aggregators, + validate_compatible_fallback_type_for_aggregator, +) if TYPE_CHECKING: from collections.abc import Sequence @@ -17,6 +21,14 @@ from .timestamp import TimestampValueFrame +def _lookdistance_to_timedelta( + lookdistance: float | tuple[float, float], +) -> tuple[dt.timedelta, dt.timedelta]: + if isinstance(lookdistance, tuple): + return (dt.timedelta(days=lookdistance[0]), dt.timedelta(days=lookdistance[1])) + return (dt.timedelta(days=0), dt.timedelta(days=lookdistance)) + + @dataclass class OutcomeSpec: """Specification for an outcome. If your outcome is binary/boolean, you can use BooleanOutcomeSpec instead.""" @@ -44,6 +56,30 @@ def __post_init__( def df(self) -> pl.DataFrame: return self.value_frame.df + @staticmethod + def from_primitives( + df: pl.DataFrame, + entity_id_col_name: str, + lookahead_days: Sequence[float | tuple[float, float]], + aggregators: Sequence[AggregatorName], + value_timestamp_col_name: str = "timestamp", + column_prefix: str = "outc", + ) -> OutcomeSpec: + """Create an OutcomeSpec from primitives.""" + lookahead_distances = [_lookdistance_to_timedelta(d) for d in lookahead_days] + + return OutcomeSpec( + value_frame=ValueFrame( + init_df=df, + entity_id_col_name=entity_id_col_name, + value_timestamp_col_name=value_timestamp_col_name, + ), + lookahead_distances=lookahead_distances, + aggregators=strings_to_aggregators(aggregators, value_timestamp_col_name), + fallback=0, + column_prefix=column_prefix, + ) + @dataclass class BooleanOutcomeSpec: @@ -81,3 +117,27 @@ def __post_init__(self, init_frame: TimestampValueFrame): @property def df(self) -> pl.DataFrame: return self.value_frame.df + + @staticmethod + def from_primitives( + df: pl.DataFrame, + entity_id_col_name: str, + lookahead_days: Sequence[float | tuple[float, float]], + aggregators: Sequence[AggregatorName], + value_timestamp_col_name: str = "timestamp", + column_prefix: str = "outc", + ) -> BooleanOutcomeSpec: + """Create an OutcomeSpec from primitives.""" + lookahead_distances = [_lookdistance_to_timedelta(d) for d in lookahead_days] + + return BooleanOutcomeSpec( + init_frame=TimestampValueFrame( + init_df=df, + value_timestamp_col_name=value_timestamp_col_name, + entity_id_col_name=entity_id_col_name, + ), + lookahead_distances=lookahead_distances, + aggregators=strings_to_aggregators(aggregators, value_timestamp_col_name), + output_name=column_prefix, + column_prefix=column_prefix, + ) diff --git a/src/timeseriesflattener/specs/static.py b/src/timeseriesflattener/specs/static.py index 63f53c6a..15cab5d0 100644 --- a/src/timeseriesflattener/specs/static.py +++ b/src/timeseriesflattener/specs/static.py @@ -38,3 +38,17 @@ class StaticSpec: value_frame: StaticFrame column_prefix: str fallback: int | float | str | None + + @staticmethod + def from_primitives( + df: pl.DataFrame, + entity_id_col_name: str, + column_prefix: str, + fallback: int | float | str | None, + ) -> StaticSpec: + """Create a StaticSpec from primitives.""" + return StaticSpec( + value_frame=StaticFrame(init_df=df, entity_id_col_name=entity_id_col_name), + column_prefix=column_prefix, + fallback=fallback, + ) diff --git a/src/timeseriesflattener/specs/temporal.py b/src/timeseriesflattener/specs/temporal.py index 829aeaeb..7e9ca930 100644 --- a/src/timeseriesflattener/specs/temporal.py +++ b/src/timeseriesflattener/specs/temporal.py @@ -4,9 +4,15 @@ from dataclasses import InitVar, dataclass from typing import TYPE_CHECKING +from timeseriesflattener.specs.outcome import _lookdistance_to_timedelta + from ..validators import validate_col_name_columns_exist from .value import ValueFrame, lookdistance_to_normalised_lookperiod -from ..aggregators import validate_compatible_fallback_type_for_aggregator +from ..aggregators import ( + AggregatorName, + strings_to_aggregators, + validate_compatible_fallback_type_for_aggregator, +) if TYPE_CHECKING: from collections.abc import Sequence @@ -48,3 +54,27 @@ def __post_init__( @property def df(self) -> pl.DataFrame: return self.value_frame.df + + @staticmethod + def from_primitives( + df: pl.DataFrame, + lookbehind_days: Sequence[float | tuple[float, float]], + aggregators: Sequence[AggregatorName], + value_timestamp_col_name: str = "timestamp", + column_prefix: str = "pred", + fallback: int | float | str | None = 0, + ) -> PredictorSpec: + """Create a PredictorSpec from primitives.""" + lookbehind_distances = [_lookdistance_to_timedelta(d) for d in lookbehind_days] + + return PredictorSpec( + value_frame=ValueFrame( + init_df=df, + entity_id_col_name=df.get_column(df.columns[0]).name, + value_timestamp_col_name=value_timestamp_col_name, + ), + lookbehind_distances=lookbehind_distances, + aggregators=strings_to_aggregators(aggregators, value_timestamp_col_name), + fallback=fallback, + column_prefix=column_prefix, + ) diff --git a/src/timeseriesflattener/specs/timedelta.py b/src/timeseriesflattener/specs/timedelta.py index 21459f54..427ddd55 100644 --- a/src/timeseriesflattener/specs/timedelta.py +++ b/src/timeseriesflattener/specs/timedelta.py @@ -52,3 +52,24 @@ def __post_init__(self): @property def df(self) -> pl.DataFrame: return self.value_frame.df + + @staticmethod + def from_primitives( + df: pl.DataFrame, + entity_id_col_name: str, + output_name: str, + value_timestamp_col_name: str = "timestamp", + column_prefix: str = "pred", + fallback: int | float | str | None = 0, + ) -> TimeDeltaSpec: + """Create a TimeDeltaSpec from primitives.""" + return TimeDeltaSpec( + init_frame=TimestampValueFrame( + init_df=df, + value_timestamp_col_name=value_timestamp_col_name, + entity_id_col_name=entity_id_col_name, + ), + fallback=fallback, + output_name=output_name, + column_prefix=column_prefix, + )