Skip to content

Commit

Permalink
Merge pull request #564 from Aarhus-Psychiatry-Research/fix/514/simpl…
Browse files Browse the repository at this point in the history
…ify_public-facing_types_by_inlining_type-aliases

fix(#514): simplify public-facing types by inlining type-aliases
  • Loading branch information
HLasse authored May 22, 2024
2 parents 60d1987 + 22c3ce6 commit 0106a84
Show file tree
Hide file tree
Showing 19 changed files with 80 additions and 108 deletions.
4 changes: 2 additions & 2 deletions src/test_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import pytest
from iterpy.iter import Iter
from timeseriesflattener.aggregators import Aggregator, MaxAggregator, MeanAggregator
from timeseriesflattener.feature_specs.meta import LookDistance, ValueFrame
from timeseriesflattener.feature_specs.meta import ValueFrame
from timeseriesflattener.feature_specs.prediction_times import PredictionTimeFrame
from timeseriesflattener.feature_specs.predictor import PredictorSpec
from timeseriesflattener.flattener import Flattener
Expand Down Expand Up @@ -50,7 +50,7 @@ def _generate_benchmark_dataset(
n_features: int,
n_observations_per_pred_time: int,
aggregations: Sequence[Literal["max", "mean"]],
lookbehinds: Sequence[LookDistance | tuple[LookDistance, LookDistance]],
lookbehinds: Sequence[dt.timedelta | tuple[dt.timedelta, dt.timedelta]],
) -> BenchmarkDataset:
pred_time_df = PredictionTimeFrame(
init_df=pl.LazyFrame(
Expand Down
16 changes: 5 additions & 11 deletions src/timeseriesflattener/_intermediary_frames.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,11 @@
import polars as pl

from ._frame_validator import _validate_col_name_columns_exist
from .feature_specs.default_column_names import (
default_prediction_time_uuid_col_name,
default_timestamp_col_name,
)
from .frame_utilities.anyframe_to_lazyframe import _anyframe_to_lazyframe

if TYPE_CHECKING:
from collections.abc import Sequence

from .feature_specs.meta import ValueType

if TYPE_CHECKING:
import datetime as dt

Expand All @@ -27,8 +21,8 @@ class TimeMaskedFrame:

init_df: pl.LazyFrame
value_col_names: Sequence[str]
timestamp_col_name: str = default_timestamp_col_name
prediction_time_uuid_col_name: str = default_prediction_time_uuid_col_name
timestamp_col_name: str = "timestamp"
prediction_time_uuid_col_name: str = "prediction_time_uuid"
validate_cols_exist: bool = True

def __post_init__(self):
Expand All @@ -47,12 +41,12 @@ def collect(self) -> pl.DataFrame:
class AggregatedValueFrame:
df: pl.LazyFrame
value_col_name: str
prediction_time_uuid_col_name: str = default_prediction_time_uuid_col_name
prediction_time_uuid_col_name: str = "prediction_time_uuid"

def __post_init__(self):
_validate_col_name_columns_exist(obj=self)

def fill_nulls(self, fallback: ValueType) -> AggregatedValueFrame:
def fill_nulls(self, fallback: int | float | str | None) -> AggregatedValueFrame:
filled = self.df.with_columns(
pl.col(self.value_col_name)
.fill_null(fallback)
Expand All @@ -76,7 +70,7 @@ class TimeDeltaFrame:
df: pl.LazyFrame
value_col_names: Sequence[str]
value_timestamp_col_name: str
prediction_time_uuid_col_name: str = default_prediction_time_uuid_col_name
prediction_time_uuid_col_name: str = "prediction_time_uuid"
timedelta_col_name: str = "time_from_prediction_to_value"

def __post_init__(self):
Expand Down
6 changes: 0 additions & 6 deletions src/timeseriesflattener/feature_specs/default_column_names.py

This file was deleted.

34 changes: 10 additions & 24 deletions src/timeseriesflattener/feature_specs/meta.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,15 @@
from __future__ import annotations

import datetime as dt
from collections.abc import Sequence
from dataclasses import InitVar, dataclass
from typing import TYPE_CHECKING, Literal, Union
from typing import Literal

import pandas as pd
import polars as pl

from timeseriesflattener.feature_specs.default_column_names import default_entity_id_col_name

from .._frame_validator import _validate_col_name_columns_exist
from ..frame_utilities.anyframe_to_lazyframe import _anyframe_to_lazyframe

if TYPE_CHECKING:
from typing_extensions import TypeAlias


ValueType = Union[int, float, str, None]
InitDF_T = Union[pl.LazyFrame, pl.DataFrame, pd.DataFrame]


LookDistance = dt.timedelta


LookDistances: TypeAlias = Sequence[Union[LookDistance, tuple[LookDistance, LookDistance]]]


@dataclass
class ValueFrame:
Expand All @@ -37,12 +21,14 @@ class ValueFrame:
Additional columns containing the values of the time series. The name of the columns will be used for feature naming.
"""

init_df: InitVar[InitDF_T]
entity_id_col_name: str = default_entity_id_col_name
init_df: InitVar[pl.LazyFrame | pl.DataFrame | pd.DataFrame]
entity_id_col_name: str = "entity_id"
value_timestamp_col_name: str = "timestamp"
coerce_to_lazy: InitVar[bool] = True

def __post_init__(self, init_df: InitDF_T, coerce_to_lazy: bool):
def __post_init__(
self, init_df: pl.LazyFrame | pl.DataFrame | pd.DataFrame, coerce_to_lazy: bool
):
if coerce_to_lazy:
self.df = _anyframe_to_lazyframe(init_df)
else:
Expand All @@ -63,8 +49,8 @@ def collect(self) -> pl.DataFrame:

@dataclass(frozen=True)
class LookPeriod:
first: LookDistance
last: LookDistance
first: dt.timedelta
last: dt.timedelta

def __post_init__(self):
if self.first >= self.last:
Expand All @@ -74,11 +60,11 @@ def __post_init__(self):


def _lookdistance_to_normalised_lookperiod(
lookdistance: LookDistance | tuple[LookDistance, LookDistance],
lookdistance: dt.timedelta | tuple[dt.timedelta, dt.timedelta],
direction: Literal["ahead", "behind"],
) -> LookPeriod:
is_ahead = direction == "ahead"
if isinstance(lookdistance, LookDistance):
if isinstance(lookdistance, dt.timedelta):
return LookPeriod(
first=dt.timedelta(days=0) if is_ahead else -lookdistance,
last=lookdistance if is_ahead else dt.timedelta(0),
Expand Down
13 changes: 8 additions & 5 deletions src/timeseriesflattener/feature_specs/outcome.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
from __future__ import annotations

import datetime as dt
from dataclasses import InitVar, dataclass
from typing import TYPE_CHECKING

import polars as pl

from .._frame_validator import _validate_col_name_columns_exist
from .meta import LookDistances, ValueFrame, ValueType, _lookdistance_to_normalised_lookperiod
from .meta import ValueFrame, _lookdistance_to_normalised_lookperiod

if TYPE_CHECKING:
from collections.abc import Sequence
Expand All @@ -20,12 +21,14 @@ class OutcomeSpec:
"""Specification for an outcome. If your outcome is binary/boolean, you can use BooleanOutcomeSpec instead."""

value_frame: ValueFrame
lookahead_distances: InitVar[LookDistances]
lookahead_distances: InitVar[Sequence[dt.timedelta | tuple[dt.timedelta, dt.timedelta]]]
aggregators: Sequence[Aggregator]
fallback: ValueType
fallback: int | float | str | None
column_prefix: str = "outc"

def __post_init__(self, lookahead_distances: LookDistances):
def __post_init__(
self, lookahead_distances: Sequence[dt.timedelta | tuple[dt.timedelta, dt.timedelta]]
):
self.normalised_lookperiod = [
_lookdistance_to_normalised_lookperiod(lookdistance=lookdistance, direction="ahead")
for lookdistance in lookahead_distances
Expand All @@ -47,7 +50,7 @@ class BooleanOutcomeSpec:
"""

init_frame: InitVar[TimestampValueFrame]
lookahead_distances: LookDistances
lookahead_distances: Sequence[dt.timedelta | tuple[dt.timedelta, dt.timedelta]]
aggregators: Sequence[Aggregator]
output_name: str
column_prefix: str = "outc"
Expand Down
20 changes: 8 additions & 12 deletions src/timeseriesflattener/feature_specs/prediction_times.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,15 @@
from dataclasses import InitVar, dataclass
from typing import TYPE_CHECKING

import pandas as pd
import polars as pl

from .._frame_validator import _validate_col_name_columns_exist
from ..frame_utilities.anyframe_to_lazyframe import _anyframe_to_lazyframe
from .default_column_names import (
default_entity_id_col_name,
default_pred_time_col_name,
default_prediction_time_uuid_col_name,
)

if TYPE_CHECKING:
from collections.abc import Sequence

from .meta import InitDF_T


@dataclass
class PredictionTimeFrame:
Expand All @@ -28,13 +22,15 @@ class PredictionTimeFrame:
timestamp_col_name: The name of the column containing the timestamps for when to make a prediction.
"""

init_df: InitVar[InitDF_T]
entity_id_col_name: str = default_entity_id_col_name
timestamp_col_name: str = default_pred_time_col_name
prediction_time_uuid_col_name: str = default_prediction_time_uuid_col_name
init_df: InitVar[pl.LazyFrame | pl.DataFrame | pd.DataFrame]
entity_id_col_name: str = "entity_id"
timestamp_col_name: str = "pred_timestamp"
prediction_time_uuid_col_name: str = "prediction_time_uuid"
coerce_to_lazy: InitVar[bool] = True

def __post_init__(self, init_df: InitDF_T, coerce_to_lazy: bool):
def __post_init__(
self, init_df: pl.LazyFrame | pl.DataFrame | pd.DataFrame, coerce_to_lazy: bool
):
if coerce_to_lazy:
self.df = _anyframe_to_lazyframe(init_df)
else:
Expand Down
11 changes: 7 additions & 4 deletions src/timeseriesflattener/feature_specs/predictor.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from __future__ import annotations

import datetime as dt
from dataclasses import InitVar, dataclass
from typing import TYPE_CHECKING

from .._frame_validator import _validate_col_name_columns_exist
from .meta import LookDistances, ValueFrame, ValueType, _lookdistance_to_normalised_lookperiod
from .meta import ValueFrame, _lookdistance_to_normalised_lookperiod

if TYPE_CHECKING:
from collections.abc import Sequence
Expand All @@ -25,12 +26,14 @@ class PredictorSpec:
"""

value_frame: ValueFrame
lookbehind_distances: InitVar[LookDistances]
lookbehind_distances: InitVar[Sequence[dt.timedelta | tuple[dt.timedelta, dt.timedelta]]]
aggregators: Sequence[Aggregator]
fallback: ValueType
fallback: int | float | str | None
column_prefix: str = "pred"

def __post_init__(self, lookbehind_distances: LookDistances):
def __post_init__(
self, lookbehind_distances: Sequence[dt.timedelta | tuple[dt.timedelta, dt.timedelta]]
):
self.normalised_lookperiod = [
_lookdistance_to_normalised_lookperiod(lookdistance=lookdistance, direction="behind")
for lookdistance in lookbehind_distances
Expand Down
14 changes: 5 additions & 9 deletions src/timeseriesflattener/feature_specs/static.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,21 @@
from __future__ import annotations

from dataclasses import InitVar, dataclass
from typing import TYPE_CHECKING

import pandas as pd
import polars as pl

from .._frame_validator import _validate_col_name_columns_exist
from ..frame_utilities.anyframe_to_lazyframe import _anyframe_to_lazyframe
from .default_column_names import default_entity_id_col_name

if TYPE_CHECKING:
from .meta import InitDF_T, ValueType


@dataclass
class StaticFrame:
init_df: InitVar[InitDF_T]
init_df: InitVar[pl.LazyFrame | pl.DataFrame | pd.DataFrame]

entity_id_col_name: str = default_entity_id_col_name
entity_id_col_name: str = "entity_id"

def __post_init__(self, init_df: InitDF_T):
def __post_init__(self, init_df: pl.LazyFrame | pl.DataFrame | pd.DataFrame):
self.df = _anyframe_to_lazyframe(init_df)
_validate_col_name_columns_exist(obj=self)
self.value_col_names = [col for col in self.df.columns if col != self.entity_id_col_name]
Expand All @@ -41,4 +37,4 @@ class StaticSpec:

value_frame: StaticFrame
column_prefix: str
fallback: ValueType
fallback: int | float | str | None
4 changes: 2 additions & 2 deletions src/timeseriesflattener/feature_specs/timedelta.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typing import TYPE_CHECKING, Literal

from .._frame_validator import _validate_col_name_columns_exist
from .meta import ValueFrame, ValueType
from .meta import ValueFrame

if TYPE_CHECKING:
import polars as pl
Expand All @@ -15,7 +15,7 @@
@dataclass
class TimeDeltaSpec:
init_frame: TimestampValueFrame
fallback: ValueType
fallback: int | float | str | None
output_name: str
column_prefix: str = "pred"
time_format: Literal["seconds", "minutes", "hours", "days", "years"] = "days"
Expand Down
12 changes: 4 additions & 8 deletions src/timeseriesflattener/feature_specs/timestamp_frame.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,12 @@
from __future__ import annotations

from dataclasses import InitVar, dataclass
from typing import TYPE_CHECKING

import pandas as pd
import polars as pl

from .._frame_validator import _validate_col_name_columns_exist
from ..frame_utilities.anyframe_to_lazyframe import _anyframe_to_lazyframe
from .default_column_names import default_entity_id_col_name

if TYPE_CHECKING:
from .meta import InitDF_T


@dataclass
Expand All @@ -22,11 +18,11 @@ class TimestampValueFrame:
value_timestamp_col_name: The name of the column containing the timestamps. Must be a string, and the column's values must be datetimes.
"""

init_df: InitVar[InitDF_T]
entity_id_col_name: str = default_entity_id_col_name
init_df: InitVar[pl.LazyFrame | pl.DataFrame | pd.DataFrame]
entity_id_col_name: str = "entity_id"
value_timestamp_col_name: str = "timestamp"

def __post_init__(self, init_df: InitDF_T):
def __post_init__(self, init_df: pl.LazyFrame | pl.DataFrame | pd.DataFrame):
self.df = _anyframe_to_lazyframe(init_df)
_validate_col_name_columns_exist(obj=self)

Expand Down
Original file line number Diff line number Diff line change
@@ -1,15 +1,11 @@
from __future__ import annotations

from typing import TYPE_CHECKING

import pandas as pd
import polars as pl

if TYPE_CHECKING:
from ..feature_specs.meta import InitDF_T


def _anyframe_to_lazyframe(init_df: InitDF_T) -> pl.LazyFrame:
def _anyframe_to_lazyframe(init_df: pl.LazyFrame | pl.DataFrame | pd.DataFrame) -> pl.LazyFrame:
if isinstance(init_df, pl.LazyFrame):
return init_df
if isinstance(init_df, pl.DataFrame):
Expand All @@ -19,5 +15,5 @@ def _anyframe_to_lazyframe(init_df: InitDF_T) -> pl.LazyFrame:
raise ValueError(f"Unsupported type: {type(init_df)}.")


def _anyframe_to_eagerframe(init_df: InitDF_T) -> pl.DataFrame:
def _anyframe_to_eagerframe(init_df: pl.LazyFrame | pl.DataFrame | pd.DataFrame) -> pl.DataFrame:
return _anyframe_to_lazyframe(init_df).collect()
2 changes: 1 addition & 1 deletion src/timeseriesflattener/process_spec.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Union
from typing import TYPE_CHECKING
import datetime as dt

from .feature_specs.static import StaticSpec
Expand Down
Loading

0 comments on commit 0106a84

Please sign in to comment.