Commit

Merge pull request #15 from wesmadrigal/semantic-types
Semantic types
wesmadrigal authored Nov 27, 2024
2 parents ad4dd10 + a23dce1 commit 39a5f8e
Showing 2 changed files with 24 additions and 9 deletions.
6 changes: 4 additions & 2 deletions graphreduce/graph_reduce.py
@@ -313,7 +313,7 @@ def join_any (
if isinstance(to_node.df, pyspark.sql.dataframe.DataFrame) and isinstance(from_node.df, pyspark.sql.dataframe.DataFrame):
joined = to_node.df.join(
from_node.df,
on=to_node.df[f"{to_node.prefix}_{to_node_key}"] == from_node.df[f"{relation_node.prefix}_{from_node_key}"],
on=to_node.df[f"{to_node.prefix}_{to_node_key}"] == from_node.df[f"{from_node.prefix}_{from_node_key}"],
how="left"
)
self._mark_merged(to_node, from_node)
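
A note on the hunk above: the join condition must use the prefix of the node that owns each dataframe, so the from-side lookup now reads from_node.prefix instead of relation_node.prefix. A minimal standalone sketch of the pattern, with hypothetical prefixes and column names rather than graphreduce's API:

from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[1]").appName("prefix-join-sketch").getOrCreate()

# Hypothetical prefixes: "cust" for the to-node, "ord" for the from-node.
to_df = spark.createDataFrame([(1, "a"), (2, "b")], ["cust_id", "cust_name"])
from_df = spark.createDataFrame([(1, 10.0), (1, 5.0)], ["ord_customer_id", "ord_amount"])

joined = to_df.join(
    from_df,
    # Each side of the condition uses the prefix of the dataframe it lives on.
    on=to_df["cust_id"] == from_df["ord_customer_id"],
    how="left",
)
joined.show()
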
@@ -385,6 +385,7 @@ def join (
if isinstance(relation_df, pyspark.sql.dataframe.DataFrame) and isinstance(parent_node.df, pyspark.sql.dataframe.DataFrame):
original = f"{relation_node.prefix}_{relation_fk}"
new = f"{original}_dupe"

relation_df = relation_df.withColumnRenamed(original, new)
joined = parent_node.df.join(
relation_df,
@@ -397,7 +398,8 @@
elif isinstance(parent_node.df, pyspark.sql.dataframe.DataFrame) and isinstance(relation_node.df, pyspark.sql.dataframe.DataFrame):
original = f"{relation_node.prefix}_{relation_fk}"
new = f"{original}_dupe"
relation_df = relation_df.withColumnRenamed(original, new)
relation_node.df = relation_node.df.withColumnRenamed(original, new)

joined = parent_node.df.join(
relation_node.df,
on=parent_node.df[f"{parent_node.prefix}_{parent_pk}"] == relation_node.df[new],
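
A note on the hunk above: the duplicate-column rename now targets relation_node.df, the frame actually used in the join, rather than the stale local relation_df. A minimal sketch of the rename-before-join pattern, with hypothetical column names:

from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[1]").appName("dupe-rename-sketch").getOrCreate()

parent = spark.createDataFrame([(1, "a")], ["cust_id", "cust_name"])
relation = spark.createDataFrame([(1, 9.99)], ["cust_id", "order_total"])

# Rename the relation's key so the joined frame carries no duplicate column name,
# and make sure the renamed frame is the one passed to join().
relation = relation.withColumnRenamed("cust_id", "cust_id_dupe")

joined = parent.join(
    relation,
    on=parent["cust_id"] == relation["cust_id_dupe"],
    how="left",
)
joined.show()
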
27 changes: 20 additions & 7 deletions graphreduce/node.py
@@ -397,6 +397,8 @@ def pandas_auto_features (
"""
agg_funcs = {}

self._stypes = infer_df_stype(self.get_sample().head())

ts_data = self.is_ts_data(reduce_key)
if ts_data:
# Make sure the dates are cleaned.
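
A note on the hunk above: semantic types are now inferred from a small sample before the aggregation functions are assembled. A minimal sketch of the same idea, assuming infer_df_stype is PyTorch Frame's torch_frame.utils.infer_df_stype; that import path is an assumption, not something shown in this diff:

import pandas as pd
from torch_frame.utils import infer_df_stype  # assumed source of infer_df_stype

sample = pd.DataFrame({
    "customer_id": [1, 2, 3],
    "signup_date": pd.to_datetime(["2024-01-01", "2024-02-01", "2024-03-01"]),
    "spend": [10.0, 20.5, 7.25],
})
# Maps each column to an inferred semantic type, which later drives
# which aggregation functions get applied to it.
stypes = infer_df_stype(sample.head())
print(stypes)
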
@@ -483,6 +485,7 @@ def dask_auto_features (
upward through the graph from child nodes with no feature
definitions.
"""
self._stypes = infer_df_stype(self.get_sample().head())
agg_funcs = {}
for col, stype in self._stypes.items():
_type = str(stype)
@@ -507,17 +510,22 @@ def spark_auto_features (
upward through the graph from child nodes with no feature
definitions.
"""

self._stypes = infer_df_stype(self.df.sample(0.5).limit(10).toPandas())
agg_funcs = []
ts_data = self.is_ts_data(reduce_key)
if ts_data:
logger.info(f"{self} is time-series data")
for col, stype in self._stypes.items():
_type = str(stype)
#for field in self.df.schema.fields:
# field_meta = json.loads(field.json())
# col = field_meta['name']
# _type = field_meta['type']
if type_func_map.get(_type):

if self._is_identifier(col) and col != reduce_key:
func = 'count'
col_new = f"{col}_{func}"
agg_funcs.append(F.count(F.col(col)).alias(col_new))
elif self._is_identifier(col) and col == reduce_key:
continue
elif type_func_map.get(_type):
for func in type_func_map[_type]:
if func == 'nunique':
func = 'count_distinct'
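
A note on the hunk above: identifier columns are now simply counted (or skipped when they are the reduce key), while the remaining columns pull their aggregations from the type-to-function map, and every expression is collected into one list for a single groupBy().agg(*...). A minimal standalone sketch with hypothetical columns:

from pyspark.sql import SparkSession, functions as F

spark = SparkSession.builder.master("local[1]").appName("auto-agg-sketch").getOrCreate()
df = spark.createDataFrame(
    [(1, 101, 5.0), (1, 102, 7.5), (2, 103, 2.0)],
    ["customer_id", "order_id", "amount"],
)

agg_funcs = [
    F.count(F.col("order_id")).alias("order_id_count"),               # identifier -> count
    F.sum(F.col("amount")).alias("amount_sum"),                       # numeric -> sum
    F.countDistinct(F.col("amount")).alias("amount_count_distinct"),  # nunique -> count distinct
]
df.groupBy("customer_id").agg(*agg_funcs).show()
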
@@ -529,6 +537,11 @@
# If we have time-series data take the time
# since the last event and the cut date.
if ts_data:
# convert the date key to a timestamp
date_key_field = [x for x in self.df.schema.fields if x.name == self.colabbr(self.date_key)][0]
if date_key_field.dataType not in [T.TimestampType(), T.DateType()]:
logger.info(f'{self} date key was {date_key_field.dataType} - converting to Timestamp')
self.df = self.df.withColumn(self.colabbr(self.date_key), F.to_timestamp(F.col(self.colabbr(self.date_key))))
logger.info(f'computed post-aggregation features for {self}')
spark_datetime = self.spark_sqlctx.sql(f"SELECT TO_DATE('{self.cut_date.strftime('%Y-%m-%d')}') as cut_date")
if 'cut_date' not in grouped.columns:
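
A note on the hunk above: before any time arithmetic, the date key is cast to a timestamp when it is not already a date or timestamp type. A minimal sketch of that guard, with a hypothetical column name:

from pyspark.sql import SparkSession, functions as F, types as T

spark = SparkSession.builder.master("local[1]").appName("ts-cast-sketch").getOrCreate()
df = spark.createDataFrame([("2024-11-01",), ("2024-11-15",)], ["event_date"])

# Look up the field's type and cast string dates to a proper timestamp.
field = [f for f in df.schema.fields if f.name == "event_date"][0]
if not isinstance(field.dataType, (T.TimestampType, T.DateType)):
    df = df.withColumn("event_date", F.to_timestamp(F.col("event_date")))
df.printSchema()
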
@@ -549,12 +562,12 @@
feat_prepped = self.prep_for_features()
feat_prepped = feat_prepped.withColumn(
self.colabbr('time_since_cut'),
F.unix_timestamp(F.col('cut_date')) - F.unix_timestamp(self.colabbr(self.date_key))
F.unix_timestamp(F.col('cut_date')) - F.unix_timestamp(F.col(self.colabbr(self.date_key)))
).drop(F.col('cut_date'))
sub = feat_prepped.filter(
(feat_prepped[self.colabbr('time_since_cut')] >= 0)
&
(feat_prepped[self.colabbr('time_since_cut')] <= d)
(feat_prepped[self.colabbr('time_since_cut')] <= (d*86400))
)
days_group = sub.groupBy(self.colabbr(reduce_key)).agg(
F.count(self.colabbr(self.pk)).alias(self.colabbr(f'{d}d_num_events'))
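
A note on the hunk above: unix_timestamp differences are expressed in seconds, so a d-day lookback window has to be compared against d * 86400 rather than d. A minimal standalone sketch of that filter:

from pyspark.sql import SparkSession, functions as F

spark = SparkSession.builder.master("local[1]").appName("window-sketch").getOrCreate()
df = spark.createDataFrame(
    [("2024-11-20 00:00:00",), ("2024-10-01 00:00:00",)], ["event_ts"]
).withColumn("event_ts", F.to_timestamp("event_ts"))

d = 30  # hypothetical lookback window, in days
df = df.withColumn(
    "time_since_cut",
    F.unix_timestamp(F.lit("2024-11-27 00:00:00").cast("timestamp")) - F.unix_timestamp("event_ts"),
)
# Keep only events within the last d days: the threshold is in seconds.
df.filter((F.col("time_since_cut") >= 0) & (F.col("time_since_cut") <= d * 86400)).show()
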
