Skip to content

Commit

Permalink
compute stats for datetimes
Browse files Browse the repository at this point in the history
  • Loading branch information
polinaeterna committed Jul 31, 2024
1 parent c4bc5e7 commit 79790e0
Showing 1 changed file with 107 additions and 2 deletions.
109 changes: 107 additions & 2 deletions services/worker/src/worker/statistics_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# Copyright 2024 The HuggingFace Authors.
import datetime
import enum
import io
import logging
Expand Down Expand Up @@ -50,11 +51,12 @@ class ColumnType(str, enum.Enum):
STRING_TEXT = "string_text"
AUDIO = "audio"
IMAGE = "image"
DATETIME = "datetime"


class Histogram(TypedDict):
    """Binned distribution of a column's values.

    `hist` holds the number of values per bin and `bin_edges` the bin
    boundaries. Edges are numbers for numerical columns and may be strings
    when the boundaries are rendered as text (e.g. formatted datetimes).
    """

    hist: list[int]
    # A single annotation: the earlier `list[Union[int, float]]` declaration was
    # shadowed by this wider one and has been dropped.
    bin_edges: list[Union[int, float, str]]


class NumericalStatisticsItem(TypedDict):
Expand All @@ -68,6 +70,17 @@ class NumericalStatisticsItem(TypedDict):
histogram: Optional[Histogram]


class DatetimeStatisticsItem(TypedDict):
    """Statistics payload for a datetime column.

    All datetime-valued statistics are carried as strings. Every optional
    field may be None in the very rare case where the whole column consists
    of None values only.
    """

    # Missing-value bookkeeping.
    nan_count: int
    nan_proportion: float

    # Datetime statistics, rendered as strings.
    min: Optional[str]
    max: Optional[str]
    mean: Optional[str]
    median: Optional[str]

    # Spread of the values: string representation of a timedelta.
    std: Optional[str]

    # Distribution of the values; see Histogram.
    histogram: Optional[Histogram]


class CategoricalStatisticsItem(TypedDict):
nan_count: int
nan_proportion: float
Expand All @@ -83,7 +96,9 @@ class BoolStatisticsItem(TypedDict):
frequencies: dict[str, int]


# Every per-column statistics payload the module can produce. The alias is
# assigned exactly once; the narrower three-member assignment that used to
# precede it was immediately overwritten and has been removed.
SupportedStatistics = Union[
    NumericalStatisticsItem, CategoricalStatisticsItem, BoolStatisticsItem, DatetimeStatisticsItem
]


class StatisticsPerColumnItem(TypedDict):
Expand Down Expand Up @@ -699,3 +714,93 @@ def get_shape(example: Optional[Union[bytes, dict[str, Any]]]) -> Union[tuple[No
@classmethod
def transform(cls, example: Optional[Union[bytes, dict[str, Any]]]) -> Optional[int]:
    """Map one example to the integer value that statistics are computed on.

    Delegates to `cls.get_width` and returns its result unchanged.
    """
    width = cls.get_width(example)
    return width


class DatetimeColumn(Column):
    """Statistics for a datetime column.

    Datetimes are converted to the number of seconds elapsed since the
    column's minimum value, numerical statistics are computed on those offsets
    by `transform_column`, and the results are shifted back by the minimum and
    rendered as strings.
    """

    transform_column = IntColumn

    @classmethod
    def compute_transformed_data(
        cls,
        data: pl.DataFrame,
        column_name: str,
        transformed_column_name: str,
        min_date: datetime.datetime,
    ) -> pl.DataFrame:
        """Return a one-column frame with each value's offset from `min_date` in seconds.

        Args:
            data: Frame containing the datetime column.
            column_name: Name of the datetime column in `data`.
            transformed_column_name: Name given to the resulting offsets column.
            min_date: Reference datetime subtracted from every value.

        Returns:
            A frame whose only column is `transformed_column_name`.
        """
        seconds_since_min = (pl.col(column_name) - min_date).dt.total_seconds()
        return data.select(seconds_since_min.alias(transformed_column_name))

    @staticmethod
    def shift_and_convert_to_string(min_date: datetime.datetime, seconds: Union[int, float]) -> str:
        """Return `min_date + seconds` formatted with `datetime_to_string`."""
        return datetime_to_string(min_date + datetime.timedelta(seconds=seconds))

    @staticmethod
    def _std_to_string(std_seconds: Optional[Union[int, float]]) -> Optional[str]:
        """Render a standard deviation given in seconds as a timedelta string."""
        if std_seconds is None:
            return None
        try:
            return str(datetime.timedelta(seconds=std_seconds))
        except (ValueError, OverflowError):
            # NaN / infinite / out-of-range values cannot be expressed as a
            # timedelta; fall back to the plain number rather than failing.
            return str(std_seconds)

    @classmethod
    def _compute_statistics(
        cls,
        data: pl.DataFrame,
        column_name: str,
        n_samples: int,
    ) -> DatetimeStatisticsItem:
        """Compute min/max/mean/median/std and a histogram for a datetime column.

        Args:
            data: Frame containing the datetime column.
            column_name: Name of the datetime column in `data`.
            n_samples: Total number of samples, used for the None-value proportion.

        Returns:
            The statistics, with datetimes rendered as strings and `std`
            rendered as a timedelta string. Every statistic is None when the
            column holds only None values.
        """
        nan_count, nan_proportion = nan_count_proportion(data, column_name, n_samples)
        if nan_count == n_samples:  # all values are None
            return DatetimeStatisticsItem(
                nan_count=n_samples,
                nan_proportion=1.0,
                min=None,
                max=None,
                mean=None,
                median=None,
                std=None,
                histogram=None,
            )

        # NOTE(review): presumably a datetime.datetime for a datetime column — confirm.
        min_date = data[column_name].min()
        timedelta_column_name = f"{column_name}_timedelta"
        # compute distribution of time passed from min date in **seconds**
        timedelta_df = cls.compute_transformed_data(data, column_name, timedelta_column_name, min_date)
        timedelta_stats: NumericalStatisticsItem = cls.transform_column.compute_statistics(
            timedelta_df,
            column_name=timedelta_column_name,
            n_samples=n_samples,
        )

        def to_date_string(seconds: Optional[Union[int, float]]) -> Optional[str]:
            # Offsets are relative to min_date, so shift back before formatting.
            if seconds is None:
                return None
            return cls.shift_and_convert_to_string(min_date, seconds)

        # Build the string results separately instead of writing them back into
        # `timedelta_stats`, which is declared to hold numerical values.
        seconds_histogram = timedelta_stats["histogram"]
        if seconds_histogram is None:
            histogram = None
        else:
            histogram = Histogram(
                hist=seconds_histogram["hist"],
                bin_edges=[to_date_string(edge) for edge in seconds_histogram["bin_edges"]],
            )

        return DatetimeStatisticsItem(
            nan_count=nan_count,
            nan_proportion=nan_proportion,
            min=datetime_to_string(min_date),
            max=to_date_string(timedelta_stats["max"]),
            mean=to_date_string(timedelta_stats["mean"]),
            median=to_date_string(timedelta_stats["median"]),
            # std is a duration, not a point in time: render it as a timedelta.
            std=cls._std_to_string(timedelta_stats["std"]),
            histogram=histogram,
        )

    def compute_and_prepare_response(self, data: pl.DataFrame) -> StatisticsPerColumnItem:
        """Compute this column's statistics and wrap them in a response item tagged DATETIME."""
        stats = self.compute_statistics(data, column_name=self.name, n_samples=self.n_samples)
        return StatisticsPerColumnItem(
            column_name=self.name,
            column_type=ColumnType.DATETIME,
            column_statistics=stats,
        )


def datetime_to_string(dt: datetime.datetime, format: str = "%Y-%m-%d %H:%M:%S") -> str:
    """
    Render a datetime as text using a strftime pattern.

    Args:
        dt (datetime.datetime): The value to render.
        format (str, optional): The strftime pattern to apply.
            Defaults to "%Y-%m-%d %H:%M:%S".

    Returns:
        str: `dt` formatted according to `format`.
    """
    rendered: str = dt.strftime(format)
    return rendered

0 comments on commit 79790e0

Please sign in to comment.