Merge pull request #14 from wesmadrigal/semantic-types
added feature engineering enhancements for spark SQL with semantic ty…
wesmadrigal authored Nov 11, 2024
2 parents 4d79790 + d308690 commit ad4dd10
Showing 2 changed files with 108 additions and 30 deletions.
21 changes: 15 additions & 6 deletions graphreduce/graph_reduce.py
@@ -15,6 +15,7 @@
from structlog import get_logger
import pyspark
import pyvis
from pyspark.sql import functions as F

# internal
from graphreduce.node import GraphReduceNode, DynamicNode, SQLNode
@@ -129,12 +130,12 @@ def __init__(
self._lazy_execution = lazy_execution

# if using Spark
self._sqlctx = spark_sqlctx
self.spark_sqlctx = spark_sqlctx
self._storage_client = storage_client

self.debug = debug

if self.compute_layer == ComputeLayerEnum.spark and self._sqlctx is None:
if self.compute_layer == ComputeLayerEnum.spark and self.spark_sqlctx is None:
raise Exception(f"Must provide a `spark_sqlctx` kwarg if using {self.compute_layer.value} as compute layer")

if self.label_node and (self.label_period_val is None or self.label_period_unit is None):
@@ -382,19 +383,27 @@ def join (

elif self.compute_layer == ComputeLayerEnum.spark:
if isinstance(relation_df, pyspark.sql.dataframe.DataFrame) and isinstance(parent_node.df, pyspark.sql.dataframe.DataFrame):
original = f"{relation_node.prefix}_{relation_fk}"
new = f"{original}_dupe"
relation_df = relation_df.withColumnRenamed(original, new)
joined = parent_node.df.join(
relation_df,
on=parent_node.df[f"{parent_node.prefix}_{parent_pk}"] == relation_df[f"{relation_node.prefix}_{relation_fk}"],
on=parent_node.df[f"{parent_node.prefix}_{parent_pk}"] == relation_df[new],
how="left"
)
).drop(F.col(new))

self._mark_merged(parent_node, relation_node)
return joined
elif isinstance(parent_node.df, pyspark.sql.dataframe.DataFrame) and isinstance(relation_node.df, pyspark.sql.dataframe.DataFrame):
original = f"{relation_node.prefix}_{relation_fk}"
new = f"{original}_dupe"
relation_df = relation_df.withColumnRenamed(original, new)
joined = parent_node.df.join(
relation_node.df,
on=parent_node.df[f"{parent_node.prefix}_{parent_pk}"] == relation_node.df[f"{relation_node.prefix}_{relation_fk}"],
on=parent_node.df[f"{parent_node.prefix}_{parent_pk}"] == relation_node.df[new],
how="left"
)
).drop(F.col(new))

self._mark_merged(parent_node, relation_node)
return joined
else:
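The join change above renames the relation's foreign key to a temporary `_dupe` column, joins on it, and drops it afterwards so the merged frame keeps a single, unambiguous key. A minimal sketch of the same rename-join-drop pattern (frames and column names are hypothetical):

```python
# Minimal sketch of the rename-join-drop pattern from the hunk above
# (hypothetical frames and column names).
from pyspark.sql import SparkSession, functions as F

spark = SparkSession.builder.getOrCreate()

cust = spark.createDataFrame([(1, "a"), (2, "b")], ["cust_id", "cust_name"])
orders = spark.createDataFrame([(1, 9.5), (1, 3.0), (3, 4.0)], ["ord_cust_id", "ord_amount"])

# Rename the relation's foreign key before joining, then drop the duplicate
# column so only the parent's key survives in the joined frame.
orders = orders.withColumnRenamed("ord_cust_id", "ord_cust_id_dupe")
joined = cust.join(
    orders,
    on=cust["cust_id"] == orders["ord_cust_id_dupe"],
    how="left",
).drop(F.col("ord_cust_id_dupe"))
joined.show()
```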
117 changes: 93 additions & 24 deletions graphreduce/node.py
@@ -12,9 +12,9 @@
import pandas as pd
from dask import dataframe as dd
import pyspark
from pyspark.sql import functions as F, types as T
from structlog import get_logger
from dateutil.parser import parse as date_parse
from torch_frame import stype
from torch_frame.utils import infer_df_stype

# internal
@@ -112,7 +112,6 @@ def __init__ (
self.fpath = fpath
self.fmt = fmt
self.compute_layer = compute_layer
self.dialect = dialect
self.cut_date = cut_date
self.compute_period_val = compute_period_val
self.compute_period_unit = compute_period_unit
@@ -197,6 +196,9 @@ def is_ts_data (
n = self.df.count()
if float(grouped) / float(n) < 0.9:
return True
#TODO(wes): define the SQL logic.
elif self.compute_layer in [ComputeLayerEnum.sqlite]:
pass
return False
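The heuristic above treats a node as time-series data when distinct reduce keys account for well under 90% of rows, i.e. many events per entity. A small pandas illustration of the same ratio check (data is made up):

```python
# Sketch of the time-series heuristic: distinct keys / total rows < 0.9
# means many rows per key, so the node is treated as event data.
import pandas as pd

df = pd.DataFrame({"cust_id": [1, 1, 1, 2, 2, 3],
                   "amount": [5.0, 1.0, 2.5, 3.0, 4.0, 9.0]})

grouped = df.groupby("cust_id").ngroups   # 3 distinct keys
n = len(df)                               # 6 rows
is_ts = float(grouped) / float(n) < 0.9   # 0.5 < 0.9 -> True
```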


@@ -257,9 +259,9 @@ def do_data (
self._stypes = infer_df_stype(self.df.head())
elif self.compute_layer.value == 'spark':
if not hasattr(self, 'df') or (hasattr(self, 'df') and not isinstance(self.df, pyspark.sql.DataFrame)):
if self.dialect == 'python':
if self.fmt != 'sql':
self.df = getattr(self.spark_sqlctx.read, f"{self.fmt}")(self.fpath)
elif self.dialect == 'sql':
elif self.fmt == 'sql':
self.df = self.spark_sqlctx.sql(f"select * from {self.fpath}")

if self.columns:
@@ -268,7 +270,7 @@
self.df = self.df.withColumnRenamed(c, f"{self.prefix}_{c}")

# Infer the semantic type with `torch_frame`.
self._stypes = infer_df_stype(self.df.head(100).toPandas())
self._stypes = infer_df_stype(self.df.sample(0.5).limit(10).toPandas())
# at this point of connectors we may want to try integrating
# with something like fugue: https://github.com/fugue-project/fugue
elif self.compute_layer.value == 'ray':
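`infer_df_stype` operates on pandas frames, so the Spark branch now converts a small random sample to pandas before inferring semantic types, rather than just the head of the frame. A sketch of what the inference returns (columns and values are hypothetical):

```python
# Sketch: torch_frame semantic-type inference on a small pandas sample
# (hypothetical columns and values).
import pandas as pd
from torch_frame.utils import infer_df_stype

sample = pd.DataFrame({
    "cust_id": [1, 2, 3],
    "ord_status": ["paid", "open", "paid"],
    "ord_amount": [10.0, 5.5, 7.25],
})

# Returns a dict of column name -> semantic type (stype), the same shape
# the *_auto_features and *_auto_labels loops below iterate over.
stypes = infer_df_stype(sample)
print({col: str(st) for col, st in stypes.items()})
```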
@@ -482,8 +484,8 @@ def dask_auto_features (
definitions.
"""
agg_funcs = {}
for col, _type in dict(self.df.dtypes).items():
_type = str(_type)
for col, stype in self._stypes.items():
_type = str(stype)
if type_func_map.get(_type):
for func in type_func_map[_type]:
col_new = f"{col}_{func}"
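The auto-feature loops now key aggregations off the inferred stypes instead of raw dtypes. A hedged sketch of how a stype-to-function map can drive a grouped aggregation in pandas (the shape of `type_func_map` is assumed here, not taken from the library):

```python
# Sketch: driving aggregations from semantic types
# (assumed map shape, hypothetical columns).
import pandas as pd

type_func_map = {                  # assumed: stype name -> aggregation names
    "numerical": ["min", "max", "sum"],
    "categorical": ["count", "nunique"],
}
stypes = {"ord_amount": "numerical", "ord_status": "categorical"}

agg_funcs = {}
for col, stype in stypes.items():
    for func in type_func_map.get(str(stype), []):
        agg_funcs[f"{col}_{func}"] = pd.NamedAgg(column=col, aggfunc=func)

df = pd.DataFrame({"cust_id": [1, 1, 2],
                   "ord_amount": [10.0, 5.0, 7.5],
                   "ord_status": ["paid", "paid", "open"]})
print(df.groupby("cust_id").agg(**agg_funcs))
```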
@@ -506,17 +508,67 @@ def spark_auto_features (
definitions.
"""
agg_funcs = []
for field in self.df.schema.fields:
field_meta = json.loads(field.json())
col = field_meta['name']
_type = field_meta['type']
ts_data = self.is_ts_data(reduce_key)
if ts_data:
logger.info(f"{self} is time-series data")
for col, stype in self._stypes.items():
_type = str(stype)
#for field in self.df.schema.fields:
# field_meta = json.loads(field.json())
# col = field_meta['name']
# _type = field_meta['type']
if type_func_map.get(_type):
for func in type_func_map[_type]:
if func == 'nunique':
func = 'count_distinct'
col_new = f"{col}_{func}"
agg_funcs.append(getattr(F, func)(F.col(col)).alias(col_new))
return self.prep_for_features().groupby(self.colabbr(reduce_key)).agg(
grouped = self.prep_for_features().groupby(self.colabbr(reduce_key)).agg(
*agg_funcs
)
# If we have time-series data, take the time
# between the last event and the cut date.
if ts_data:
logger.info(f'computed post-aggregation features for {self}')
spark_datetime = self.spark_sqlctx.sql(f"SELECT TO_DATE('{self.cut_date.strftime('%Y-%m-%d')}') as cut_date")
if 'cut_date' not in grouped.columns:
grouped = grouped.crossJoin(spark_datetime)

grouped = grouped.withColumn(
self.colabbr('time_since_last_event'),
F.unix_timestamp(F.col('cut_date')) - F.unix_timestamp(F.col(f'{self.colabbr(self.date_key)}_max'))
).drop(F.col('cut_date'))
if 'cut_date' not in self.df.columns:
self.df = self.df.crossJoin(spark_datetime)

# Number of events in last strata of time
days = [30, 60, 90, 365, 730]
for d in days:
if d > self.compute_period_val:
continue
feat_prepped = self.prep_for_features()
feat_prepped = feat_prepped.withColumn(
self.colabbr('time_since_cut'),
F.unix_timestamp(F.col('cut_date')) - F.unix_timestamp(self.colabbr(self.date_key))
).drop(F.col('cut_date'))
sub = feat_prepped.filter(
(feat_prepped[self.colabbr('time_since_cut')] >= 0)
&
(feat_prepped[self.colabbr('time_since_cut')] <= d * 86400)
)
days_group = sub.groupBy(self.colabbr(reduce_key)).agg(
F.count(self.colabbr(self.pk)).alias(self.colabbr(f'{d}d_num_events'))
)
# join this back to the main dataset.
grouped = grouped.join(
days_group,
on=self.colabbr(reduce_key),
how='left'
)
logger.info(f'merged all ts groupings to {self}')
if 'cut_date' in grouped.columns:
grouped = grouped.drop(F.col('cut_date'))
return grouped


def sql_auto_features (
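The new time-series branch cross-joins a one-row frame holding the cut date, derives the seconds since each entity's last event, and counts events inside trailing windows before joining those counts back. A condensed sketch of both steps (schema and dates are hypothetical; the day windows are converted to seconds):

```python
# Sketch: time-since-last-event and trailing-window event counts
# (hypothetical schema; note the day windows are converted to seconds).
import datetime
from pyspark.sql import SparkSession, functions as F

spark = SparkSession.builder.getOrCreate()
cut_date = datetime.date(2024, 11, 1)

events = spark.createDataFrame(
    [(1, "2024-10-30"), (1, "2024-08-01"), (2, "2024-10-15")],
    ["cust_id", "event_date"],
).withColumn("event_date", F.to_date("event_date"))

grouped = events.groupBy("cust_id").agg(F.max("event_date").alias("event_date_max"))

# One-row frame holding the cut date, cross-joined so every row can see it.
cut = spark.sql(f"SELECT TO_DATE('{cut_date:%Y-%m-%d}') AS cut_date")
grouped = (
    grouped.crossJoin(cut)
    .withColumn(
        "time_since_last_event",
        F.unix_timestamp("cut_date") - F.unix_timestamp("event_date_max"),
    )
    .drop("cut_date")
)

# Events falling inside each trailing window before the cut date.
for d in [30, 90]:
    windowed = (
        events.crossJoin(cut)
        .withColumn("since_cut",
                    F.unix_timestamp("cut_date") - F.unix_timestamp("event_date"))
        .filter((F.col("since_cut") >= 0) & (F.col("since_cut") <= d * 86400))
        .groupBy("cust_id")
        .agg(F.count("*").alias(f"{d}d_num_events"))
    )
    grouped = grouped.join(windowed, on="cust_id", how="left")
grouped.show()
```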
Expand Down Expand Up @@ -593,9 +645,11 @@ def sql_auto_labels (
provided columns.
"""
agg_funcs = {}
for col, _type in dict(table_df_sample.dtypes).items():
if not self._stypes:
self._stypes = infer_df_stype(table_df_sample)
for col, stype in self._stypes.items():
if col.endswith('_label'):
_type = str(_type)
_type = str(stype)
if type_func_map.get(_type):
for func in type_func_map[_type]:
col_new = f"{col}_{func}_label"
@@ -621,9 +675,10 @@ def pandas_auto_labels (
provided columns.
"""
agg_funcs = {}
for col, _type in dict(self.df.dtypes).items():

for col, stype in self._stypes.items():
_type = str(stype)
if col.endswith('_label') or col == self.label_field or col == f'{self.colabbr(self.label_field)}':
_type = str(_type)
if type_func_map.get(_type):
for func in type_func_map[_type]:
col_new = f"{col}_{func}_label"
@@ -643,9 +698,9 @@ def dask_auto_labels (
provided columns.
"""
agg_funcs = {}
for col, _type in dict(self.df.dtypes).items():
for col, stype in self._stypes.items():
if col.endswith('_label'):
_type = str(_type)
_type = str(stype)
if type_func_map.get(_type):
for func in type_func_map[_type]:
col_new = f"{col}_{func}_label"
@@ -665,13 +720,17 @@ def spark_auto_labels (
provided columns.
"""
agg_funcs = []
for field in self.df.schema.fields:
field_meta = json.loads(field.json())
col = field_meta['name']
_type = field_meta['type']
#for field in self.df.schema.fields:
# field_meta = json.loads(field.json())
# col = field_meta['name']
# _type = field_meta['type']
for col, stype in self._stypes.items():
_type = str(stype)
if col.endswith('_label'):
if type_func_map.get(_type):
for func in type_func_map[_type]:
if func == 'nunique':
func = 'count_distinct'
col_new = f"{col}_{func}_label"
agg_funcs.append(getattr(F, func)(F.col(col)).alias(col_new))
return self.prep_for_labels().groupby(self.colabbr(reduce_key)).agg(
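Because `pyspark.sql.functions` has no `nunique`, the Spark auto-feature and auto-label loops above swap it for `count_distinct` before dispatching with `getattr`. A short illustration on toy data:

```python
# Sketch: the pandas-style "nunique" aggregation expressed in Spark.
from pyspark.sql import SparkSession, functions as F

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([(1, "a"), (1, "b"), (2, "a")], ["cust_id", "ord_status"])

func = "count_distinct"            # substituted for "nunique" in the loops above
df.groupBy("cust_id").agg(
    getattr(F, func)(F.col("ord_status")).alias("ord_status_nunique")
).show()
```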
@@ -904,7 +963,12 @@ def default_label (
return label_df[[self.colabbr(self.pk), self.colabbr(field)+'_label']]

elif self.compute_layer == ComputeLayerEnum.spark:
pass
if self.reduce:
return self.prep_for_labels().groupBy(self.colabbr(reduce_key)).agg(
getattr(F, op)(F.col(self.colabbr(field))).alias(f'{self.colabbr(field)}_label')
)
else:
pass
elif self.compute_layer in [ComputeLayerEnum.snowflake, ComputeLayerEnum.sqlite, ComputeLayerEnum.mysql, ComputeLayerEnum.postgres, ComputeLayerEnum.redshift, ComputeLayerEnum.athena, ComputeLayerEnum.databricks]:
if self.reduce:
return self.prep_for_labels() + [
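The new Spark branch of `default_label` reduces the label field with whatever aggregate name `op` carries, dispatched through `getattr(F, op)`. A minimal sketch (entity and field names are hypothetical):

```python
# Sketch: a default label built by aggregating a field per reduce key
# (hypothetical names; `op` can be any aggregate on pyspark.sql.functions).
from pyspark.sql import SparkSession, functions as F

spark = SparkSession.builder.getOrCreate()
orders = spark.createDataFrame([(1, 10.0), (1, 5.0), (2, 0.0)], ["cust_id", "amount"])

op = "sum"
labels = orders.groupBy("cust_id").agg(
    getattr(F, op)(F.col("amount")).alias("amount_label")
)
labels.show()
```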
@@ -1136,6 +1200,8 @@ def create_temp_view (
return None


#TODO(wes): optimize by storing previously
# fetched samples.
def get_sample (
self,
n: int = 100,
@@ -1160,7 +1226,10 @@ def get_sample (
table=table,
n=n
)
return self.execute_query(qry)
samp = self.execute_query(qry)
if not self._stypes:
self._stypes = infer_df_stype(samp)
return samp


def build_query (
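`get_sample` now seeds `self._stypes` from the first SQL sample it pulls, and the new TODO notes the samples themselves could be cached. A hedged sketch of the sample-then-infer flow against an in-memory SQLite table (the table and query shape are made up for illustration):

```python
# Sketch: pull a LIMIT-n sample from SQL, infer semantic types once, reuse after
# (hypothetical table; query shape assumed).
import sqlite3
import pandas as pd
from torch_frame.utils import infer_df_stype

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE orders (cust_id INTEGER, amount REAL, status TEXT)")
conn.executemany("INSERT INTO orders VALUES (?, ?, ?)",
                 [(1, 10.0, "paid"), (2, 5.5, "open")])

_stypes = None
n = 100
samp = pd.read_sql_query(f"SELECT * FROM orders LIMIT {n}", conn)
if not _stypes:                      # only infer on the first sample pulled
    _stypes = infer_df_stype(samp)
```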
