feat: validate fallback and aggregator output is compatible #573

Merged
merged 7 commits into from
May 23, 2024
Changes from all commits
5 changes: 2 additions & 3 deletions docs/tutorials/04_from_legacy.ipynb
@@ -41,7 +41,6 @@
"\n",
"import pandas as pd\n",
"from timeseriesflattener.v1.aggregation_fns import (\n",
" boolean,\n",
" change_per_day,\n",
" count,\n",
" earliest,\n",
@@ -84,7 +83,7 @@
" summed,\n",
" count,\n",
" variance,\n",
" boolean,\n",
" # boolean, requires the fallback to be a bool\n",
" change_per_day,\n",
" ],\n",
" fallback=[0],\n",
@@ -144,7 +143,7 @@
" summed,\n",
" count,\n",
" variance,\n",
" boolean,\n",
" # boolean, requires the fallback to be a bool\n",
" change_per_day,\n",
" ],\n",
" fallback=[0],\n",
24 changes: 23 additions & 1 deletion src/timeseriesflattener/aggregators.py
@@ -7,8 +7,20 @@
from attr import dataclass


def _validate_compatible_fallback_type_for_aggregator(
    aggregator: Aggregator, fallback: str | int | float | None
) -> None:
    try:
        pl.Series([aggregator.output_type()]).fill_null(fallback)
    except Exception as e:
        raise ValueError(
            f"Invalid fallback value {fallback} for aggregator {aggregator.__class__.__name__}. Fallback of type {type(fallback)} is not compatible with the aggregator's output type of {aggregator.output_type}."
        ) from e


class Aggregator(ABC):
name: str
output_type: type[float | int | bool]

@abstractmethod
def __call__(self, column_name: str) -> pl.Expr:
@@ -22,6 +34,7 @@ class MinAggregator(Aggregator):
"""Returns the minimum value in the look window."""

name: str = "min"
output_type = float

def __call__(self, column_name: str) -> pl.Expr:
return pl.col(column_name).min().alias(self.new_col_name(column_name))
@@ -31,6 +44,7 @@ class MaxAggregator(Aggregator):
"""Returns the maximum value in the look window."""

name: str = "max"
output_type = float

def __call__(self, column_name: str) -> pl.Expr:
return pl.col(column_name).max().alias(self.new_col_name(column_name))
@@ -40,6 +54,7 @@ class MeanAggregator(Aggregator):
"""Returns the mean value in the look window."""

name: str = "mean"
output_type = float

def __call__(self, column_name: str) -> pl.Expr:
return pl.col(column_name).mean().alias(self.new_col_name(column_name))
@@ -49,6 +64,7 @@ class CountAggregator(Aggregator):
"""Returns the count of non-null values in the look window."""

name: str = "count"
output_type = int

def __call__(self, column_name: str) -> pl.Expr:
return pl.col(column_name).count().alias(self.new_col_name(column_name))
@@ -60,6 +76,7 @@ class EarliestAggregator(Aggregator):

timestamp_col_name: str
name: str = "earliest"
output_type = float

def __call__(self, column_name: str) -> pl.Expr:
return (
@@ -76,6 +93,7 @@ class LatestAggregator(Aggregator):

timestamp_col_name: str
name: str = "latest"
output_type = float

def __call__(self, column_name: str) -> pl.Expr:
return (
@@ -90,6 +108,7 @@ class SumAggregator(Aggregator):
"""Returns the sum of all values in the look window."""

name: str = "sum"
output_type = float

def __call__(self, column_name: str) -> pl.Expr:
return pl.col(column_name).sum().alias(self.new_col_name(column_name))
@@ -99,15 +118,17 @@ class VarianceAggregator(Aggregator):
"""Returns the variance of the values in the look window"""

name: str = "var"
output_type = float

def __call__(self, column_name: str) -> pl.Expr:
return pl.col(column_name).var().alias(self.new_col_name(column_name))


class HasValuesAggregator(Aggregator):
"""Examines whether any values exist in the column. If so, returns 1, else 0."""
"""Examines whether any values exist in the column. If so, returns True, else False."""

name: str = "bool"
output_type = bool

def __call__(self, column_name: str) -> pl.Expr:
return (
@@ -126,6 +147,7 @@ class SlopeAggregator(Aggregator):

timestamp_col_name: str
name: str = "slope"
output_type = float

def __call__(self, column_name: str) -> pl.Expr:
# Convert to days for the slope. Arbitrarily chosen to be the number of days since 1970-01-01.
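For orientation, here is a minimal sketch of how the new helper behaves, mirroring the cases exercised by the new parametrized test further down. It assumes these names are importable from timeseriesflattener.aggregators, as they are in the test module:

import numpy as np

from timeseriesflattener.aggregators import (
    HasValuesAggregator,
    MeanAggregator,
    _validate_compatible_fallback_type_for_aggregator,
)

# MeanAggregator outputs float, so a float fallback (including NaN) passes silently.
_validate_compatible_fallback_type_for_aggregator(aggregator=MeanAggregator(), fallback=np.nan)

# HasValuesAggregator outputs bool, so a numeric fallback is rejected.
try:
    _validate_compatible_fallback_type_for_aggregator(aggregator=HasValuesAggregator(), fallback=1)
except ValueError as err:
    print(err)  # reports that the fallback type is incompatible with the aggregator's output type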
9 changes: 9 additions & 0 deletions src/timeseriesflattener/feature_specs/outcome.py
@@ -8,6 +8,7 @@

from .._frame_validator import _validate_col_name_columns_exist
from .meta import ValueFrame, _lookdistance_to_normalised_lookperiod
from ..aggregators import _validate_compatible_fallback_type_for_aggregator

if TYPE_CHECKING:
from collections.abc import Sequence
@@ -34,6 +35,10 @@ def __post_init__(
for lookdistance in lookahead_distances
]
_validate_col_name_columns_exist(obj=self)
for aggregator in self.aggregators:
_validate_compatible_fallback_type_for_aggregator(
aggregator=aggregator, fallback=self.fallback
)

@property
def df(self) -> pl.LazyFrame:
@@ -62,6 +67,10 @@ def __post_init__(self, init_frame: TimestampValueFrame):
]

self.fallback = 0
for aggregator in self.aggregators:
_validate_compatible_fallback_type_for_aggregator(
aggregator=aggregator, fallback=self.fallback
)

self.value_frame = ValueFrame(
init_df=init_frame.df.with_columns((pl.lit(1)).alias(self.output_name)),
5 changes: 5 additions & 0 deletions src/timeseriesflattener/feature_specs/predictor.py
@@ -6,6 +6,7 @@

from .._frame_validator import _validate_col_name_columns_exist
from .meta import ValueFrame, _lookdistance_to_normalised_lookperiod
from ..aggregators import _validate_compatible_fallback_type_for_aggregator

if TYPE_CHECKING:
from collections.abc import Sequence
@@ -39,6 +40,10 @@ def __post_init__(
for lookdistance in lookbehind_distances
]
_validate_col_name_columns_exist(obj=self)
for aggregator in self.aggregators:
_validate_compatible_fallback_type_for_aggregator(
aggregator=aggregator, fallback=self.fallback
)

@property
def df(self) -> pl.LazyFrame:
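To illustrate the user-facing effect, here is a hedged sketch of a spec that the new __post_init__ check would reject at construction time. PredictorSpec, ValueFrame, and the keyword arguments used below appear in this diff, but the column names and the exact construction are illustrative assumptions rather than code from this PR:

import datetime as dt

import polars as pl

from timeseriesflattener import PredictorSpec, ValueFrame  # assumed public import path
from timeseriesflattener.aggregators import HasValuesAggregator

values = pl.DataFrame(
    {
        "entity_id": [1, 1],  # column names are illustrative
        "timestamp": ["2020-01-01", "2020-02-01"],
        "value": [1.0, 2.0],
    }
)

# Raises ValueError during __post_init__: HasValuesAggregator outputs bool,
# so the numeric fallback 0 is rejected before any aggregation is run.
PredictorSpec(
    value_frame=ValueFrame(init_df=values),
    lookbehind_distances=[dt.timedelta(days=30)],
    aggregators=[HasValuesAggregator()],
    fallback=0,
)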
2 changes: 1 addition & 1 deletion src/timeseriesflattener/feature_specs/static.py
@@ -32,7 +32,7 @@ class StaticSpec:

The value_frame must contain columns:
entity_id_col_name: The name of the column containing the entity ids. Must be a string, and the column's values must be strings which are unique.
additional columns containing the values of the static feature. The name of the columns will be used for feature naming.
additional columns containing the values of the static feature. The names of the columns will be used for feature naming.
"""

value_frame: StaticFrame
1 change: 0 additions & 1 deletion src/timeseriesflattener/feature_specs/test_from_legacy.py
@@ -41,7 +41,6 @@ def test_create_predictorspec_from_legacy():
summed,
count,
variance,
boolean,
change_per_day,
],
fallback=[0],
6 changes: 3 additions & 3 deletions src/timeseriesflattener/flattener.py
@@ -110,11 +110,11 @@ def aggregate_timeseries(
self, specs: Sequence[ValueSpecification], step_size: dt.timedelta | None = None
) -> AggregatedFrame:
"""Perform the aggregation/flattening.

Args:
specs: The specifications for the features to be created.
step_size: The step size for the aggregation.
If not None, will aggregate prediction times in chunks of step_size.
step_size: The step size for the aggregation.
If not None, will aggregate prediction times in chunks of step_size.
Reduce if you encounter memory issues."""
if self.compute_lazily:
print(
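As a usage note for the step_size argument documented above, a small sketch; it assumes an existing Flattener instance and a prepared list of specs, with variable names chosen purely for illustration:

import datetime as dt

# `flattener` is assumed to be an existing Flattener instance and `specs` a sequence of
# value specifications, built as elsewhere in this package.
aggregated_frame = flattener.aggregate_timeseries(
    specs=specs,
    step_size=dt.timedelta(days=365),  # aggregate prediction times in one-year chunks to limit peak memory
)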
33 changes: 32 additions & 1 deletion src/timeseriesflattener/test_aggregators.py
@@ -20,6 +20,7 @@
SlopeAggregator,
SumAggregator,
VarianceAggregator,
_validate_compatible_fallback_type_for_aggregator,
)
from .spec_processors.temporal import _aggregate_masked_frame
from .test_flattener import assert_frame_equal
@@ -91,7 +92,10 @@ def expected_output(self) -> pl.DataFrame:
aggregator=VarianceAggregator(), input_values=[1, 2], expected_output_values=[0.5]
),
SingleVarAggregatorExample(
aggregator=HasValuesAggregator(), input_values=[1, 2], expected_output_values=[1], fallback_str="False"
aggregator=HasValuesAggregator(),
input_values=[1, 2],
expected_output_values=[1],
fallback_str="False",
),
SingleVarAggregatorExample(
aggregator=HasValuesAggregator(),
@@ -159,3 +163,30 @@ def test_aggregator(example: AggregatorExampleType):
)

assert_frame_equal(result.collect(), example.expected_output)


@pytest.mark.parametrize(
    ("aggregator", "fallback", "valid_fallback"),
    [
        (MeanAggregator(), 1, True),
        (MeanAggregator(), np.nan, True),
        (HasValuesAggregator(), np.nan, False),
        (HasValuesAggregator(), False, True),
        (HasValuesAggregator(), 1, False),
    ],
)
def test_valid_fallback_for_aggregator(
    aggregator: Aggregator, fallback: float | int | bool | None, valid_fallback: bool
):
    if valid_fallback:
        assert (
            _validate_compatible_fallback_type_for_aggregator(
                aggregator=aggregator, fallback=fallback
            )
            is None
        )
    else:
        with pytest.raises(ValueError):
            _validate_compatible_fallback_type_for_aggregator(
                aggregator=aggregator, fallback=fallback
            )