Skip to content

Commit

Permalink
update 5 files
Browse files Browse the repository at this point in the history
  • Loading branch information
MartinBernstorff committed Aug 30, 2024
1 parent 38cbedf commit b652623
Show file tree
Hide file tree
Showing 5 changed files with 170 additions and 3 deletions.
44 changes: 43 additions & 1 deletion src/timeseriesflattener/aggregators.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@

import datetime as dt
from abc import ABC, abstractmethod
from typing import Literal, Sequence

import polars as pl
from attr import dataclass
from timeseriesflattener.specs import timestamp


def validate_compatible_fallback_type_for_aggregator(
Expand All @@ -18,6 +20,46 @@ def validate_compatible_fallback_type_for_aggregator(
)


AggregatorName = Literal[
"max",
"min",
"mean",
"sum",
"count",
"variance",
"bool",
"change_per_day",
"slope",
"has_values",
]


def strings_to_aggregators(
aggregator_names: Sequence[AggregatorName], timestamp_col_name: str
) -> Sequence[Aggregator]:
return [
string_to_aggregator(name, timestamp_col_name=timestamp_col_name)
for name in aggregator_names
]


def string_to_aggregator(aggregator_name: AggregatorName, timestamp_col_name: str) -> Aggregator:
str2aggr: dict[AggregatorName, Aggregator] = {
"max": MaxAggregator(),
"min": MinAggregator(),
"mean": MeanAggregator(),
"sum": SumAggregator(),
"count": CountAggregator(),
"variance": VarianceAggregator(),
"bool": HasValuesAggregator(),
"change_per_day": SlopeAggregator(timestamp_col_name=timestamp_col_name),
"slope": SlopeAggregator(timestamp_col_name=timestamp_col_name),
"has_values": HasValuesAggregator(),
}

return str2aggr[aggregator_name]


class Aggregator(ABC):
name: str
output_type: type[float | int | bool]
Expand Down Expand Up @@ -125,7 +167,7 @@ def __call__(self, column_name: str) -> pl.Expr:


class HasValuesAggregator(Aggregator):
"""Examines whether any values exist in the column. If so, returns True, else False."""
"""Examines whether any values exist in the look window. If so, returns True, else False."""

name: str = "bool"
output_type = bool
Expand Down
62 changes: 61 additions & 1 deletion src/timeseriesflattener/specs/outcome.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,11 @@

from ..validators import validate_col_name_columns_exist
from .value import ValueFrame, lookdistance_to_normalised_lookperiod
from ..aggregators import validate_compatible_fallback_type_for_aggregator
from ..aggregators import (
AggregatorName,
strings_to_aggregators,
validate_compatible_fallback_type_for_aggregator,
)

if TYPE_CHECKING:
from collections.abc import Sequence
Expand All @@ -17,6 +21,14 @@
from .timestamp import TimestampValueFrame


def _lookdistance_to_timedelta(
lookdistance: float | tuple[float, float],
) -> tuple[dt.timedelta, dt.timedelta]:
if isinstance(lookdistance, tuple):
return (dt.timedelta(days=lookdistance[0]), dt.timedelta(days=lookdistance[1]))
return (dt.timedelta(days=0), dt.timedelta(days=lookdistance))


@dataclass
class OutcomeSpec:
"""Specification for an outcome. If your outcome is binary/boolean, you can use BooleanOutcomeSpec instead."""
Expand Down Expand Up @@ -44,6 +56,30 @@ def __post_init__(
def df(self) -> pl.DataFrame:
return self.value_frame.df

@staticmethod
def from_primitives(
df: pl.DataFrame,
entity_id_col_name: str,
lookahead_days: Sequence[float | tuple[float, float]],
aggregators: Sequence[AggregatorName],
value_timestamp_col_name: str = "timestamp",
column_prefix: str = "outc",
) -> OutcomeSpec:
"""Create an OutcomeSpec from primitives."""
lookahead_distances = [_lookdistance_to_timedelta(d) for d in lookahead_days]

return OutcomeSpec(
value_frame=ValueFrame(
init_df=df,
entity_id_col_name=entity_id_col_name,
value_timestamp_col_name=value_timestamp_col_name,
),
lookahead_distances=lookahead_distances,
aggregators=strings_to_aggregators(aggregators, value_timestamp_col_name),
fallback=0,
column_prefix=column_prefix,
)


@dataclass
class BooleanOutcomeSpec:
Expand Down Expand Up @@ -81,3 +117,27 @@ def __post_init__(self, init_frame: TimestampValueFrame):
@property
def df(self) -> pl.DataFrame:
return self.value_frame.df

@staticmethod
def from_primitives(
df: pl.DataFrame,
entity_id_col_name: str,
lookahead_days: Sequence[float | tuple[float, float]],
aggregators: Sequence[AggregatorName],
value_timestamp_col_name: str = "timestamp",
column_prefix: str = "outc",
) -> BooleanOutcomeSpec:
"""Create an OutcomeSpec from primitives."""
lookahead_distances = [_lookdistance_to_timedelta(d) for d in lookahead_days]

return BooleanOutcomeSpec(
init_frame=TimestampValueFrame(
init_df=df,
value_timestamp_col_name=value_timestamp_col_name,
entity_id_col_name=entity_id_col_name,
),
lookahead_distances=lookahead_distances,
aggregators=strings_to_aggregators(aggregators, value_timestamp_col_name),
output_name=column_prefix,
column_prefix=column_prefix,
)
14 changes: 14 additions & 0 deletions src/timeseriesflattener/specs/static.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,17 @@ class StaticSpec:
value_frame: StaticFrame
column_prefix: str
fallback: int | float | str | None

@staticmethod
def from_primitives(
df: pl.DataFrame,
entity_id_col_name: str,
column_prefix: str,
fallback: int | float | str | None,
) -> StaticSpec:
"""Create a StaticSpec from primitives."""
return StaticSpec(
value_frame=StaticFrame(init_df=df, entity_id_col_name=entity_id_col_name),
column_prefix=column_prefix,
fallback=fallback,
)
32 changes: 31 additions & 1 deletion src/timeseriesflattener/specs/temporal.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,15 @@
from dataclasses import InitVar, dataclass
from typing import TYPE_CHECKING

from timeseriesflattener.specs.outcome import _lookdistance_to_timedelta

from ..validators import validate_col_name_columns_exist
from .value import ValueFrame, lookdistance_to_normalised_lookperiod
from ..aggregators import validate_compatible_fallback_type_for_aggregator
from ..aggregators import (
AggregatorName,
strings_to_aggregators,
validate_compatible_fallback_type_for_aggregator,
)

if TYPE_CHECKING:
from collections.abc import Sequence
Expand Down Expand Up @@ -48,3 +54,27 @@ def __post_init__(
@property
def df(self) -> pl.DataFrame:
return self.value_frame.df

@staticmethod
def from_primitives(
df: pl.DataFrame,
lookbehind_days: Sequence[float | tuple[float, float]],
aggregators: Sequence[AggregatorName],
value_timestamp_col_name: str = "timestamp",
column_prefix: str = "pred",
fallback: int | float | str | None = 0,
) -> PredictorSpec:
"""Create a PredictorSpec from primitives."""
lookbehind_distances = [_lookdistance_to_timedelta(d) for d in lookbehind_days]

return PredictorSpec(
value_frame=ValueFrame(
init_df=df,
entity_id_col_name=df.get_column(df.columns[0]).name,
value_timestamp_col_name=value_timestamp_col_name,
),
lookbehind_distances=lookbehind_distances,
aggregators=strings_to_aggregators(aggregators, value_timestamp_col_name),
fallback=fallback,
column_prefix=column_prefix,
)
21 changes: 21 additions & 0 deletions src/timeseriesflattener/specs/timedelta.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,3 +52,24 @@ def __post_init__(self):
@property
def df(self) -> pl.DataFrame:
return self.value_frame.df

@staticmethod
def from_primitives(
df: pl.DataFrame,
entity_id_col_name: str,
output_name: str,
value_timestamp_col_name: str = "timestamp",
column_prefix: str = "pred",
fallback: int | float | str | None = 0,
) -> TimeDeltaSpec:
"""Create a TimeDeltaSpec from primitives."""
return TimeDeltaSpec(
init_frame=TimestampValueFrame(
init_df=df,
value_timestamp_col_name=value_timestamp_col_name,
entity_id_col_name=entity_id_col_name,
),
fallback=fallback,
output_name=output_name,
column_prefix=column_prefix,
)

0 comments on commit b652623

Please sign in to comment.