diff --git a/src/timeseriesflattenerv2/feature_specs.py b/src/timeseriesflattenerv2/feature_specs.py index dad1737f..0192f0c6 100644 --- a/src/timeseriesflattenerv2/feature_specs.py +++ b/src/timeseriesflattenerv2/feature_specs.py @@ -2,6 +2,7 @@ from dataclasses import dataclass from typing import Protocol, Sequence, Union +import pandas as pd import polars as pl ValueType = Union[int, float, str] @@ -34,15 +35,23 @@ def to_lazyframe_with_uuid(self) -> pl.LazyFrame: return self.df -@dataclass(frozen=True) +@dataclass class ValueFrame: """A frame that contains the values of a time series.""" - df: pl.LazyFrame + df: pl.LazyFrame | pd.DataFrame value_type: str entity_id_col_name: str = default_entity_id_col_name value_timestamp_col_name: str = "value_timestamp" + @property + def lazyframe(self) -> pl.LazyFrame: + return self.df if isinstance(self.df, pl.LazyFrame) else pl.from_pandas(self.df).lazy() + + @property + def eagerframe(self) -> pl.DataFrame: + return self.df.collect() if isinstance(self.df, pl.LazyFrame) else pl.from_pandas(self.df) + @dataclass(frozen=True) class SlicedFrame: diff --git a/src/timeseriesflattenerv2/flattener.py b/src/timeseriesflattenerv2/flattener.py index fa8f7e5a..48961eac 100644 --- a/src/timeseriesflattenerv2/flattener.py +++ b/src/timeseriesflattenerv2/flattener.py @@ -79,7 +79,7 @@ def _get_timedelta_frame( ) -> TimedeltaFrame: # Join the prediction time dataframe joined_frame = predictiontime_frame.to_lazyframe_with_uuid().join( - value_frame.df, on=predictiontime_frame.entity_id_col_name + value_frame.lazyframe, on=predictiontime_frame.entity_id_col_name ) # Get timedelta @@ -134,7 +134,7 @@ def aggregate_timeseries(self, specs: Sequence[ValueSpecification]) -> Aggregate predictiontime_frame=self.predictiontime_frame, spec=spec ) ) - .map(lambda x: x.df) + .map(lambda x: x.lazyframe) .to_list() ) return AggregatedValueFrame(df=_horizontally_concatenate_dfs(dfs)) diff --git a/src/timeseriesflattenerv2/test_flattener.py b/src/timeseriesflattenerv2/test_flattener.py index 2057694b..8e752543 100644 --- a/src/timeseriesflattenerv2/test_flattener.py +++ b/src/timeseriesflattenerv2/test_flattener.py @@ -17,7 +17,9 @@ def assert_frame_equal(left: pl.DataFrame, right: pl.DataFrame): - polars_testing.assert_frame_equal(left, right, check_dtype=False, check_column_order=False) + polars_testing.assert_frame_equal( + left, right, check_dtype=False, check_column_order=False, check_row_order=False + ) def test_flattener():