Skip to content

Commit

Permalink
updated sql node and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
wesmadrigal committed Nov 9, 2024
1 parent eff5697 commit f6616d6
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 10 deletions.
45 changes: 38 additions & 7 deletions graphreduce/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -535,20 +535,48 @@ def sql_auto_features (
data type inference libraries.
"""
agg_funcs = []
for col, _type in dict(table_df_sample.dtypes).items():
_type = str(_type)
if type_func_map.get(_type):
if not self._stypes:
self._stypes = infer_df_stype(table_df_sample)
for col, stype in self._stypes.items():
_type = str(stype)
if self._is_identifier(col) and col != reduce_key:
# We only perform counts for identifiers.
func = "count"
col_new = f"{col}_{func}"
agg_funcs.append(
sqlop(
optype=SQLOpType.aggfunc,
opval=f"{func}" + f"({col}) as {col_new}"
)
)
elif self._is_identifier(col) and col == reduce_key:
continue
elif type_func_map.get(_type):
for func in type_func_map[_type]:
# There should be a better top-level mapping
# but for now this will do. SQL engines typically
# don't have 'median' and 'mean'. 'mean' is typically
# just called 'avg'.
if (_type == 'numerical' or 'timestamp') and dict(table_df_sample)[col].__str__() == 'object' and func in ['min','max','mean', 'median']:
logger.info(f'skipped aggregation on {col} because semantic numerical but physical object')
continue
if func in self.FUNCTION_MAPPING:
func = self.FUNCTION_MAPPING.get(func)

if not func:
continue
col_new = f"{col}_{func}"
agg_funcs.append(
sqlop(
optype=SQLOpType.aggfunc,
opval=f"{func}" + f"({col}) as {col_new}"
)
)
# Need the aggregation and time-based filtering.
if not len(agg_funcs):
logger.info(f'No aggregations for {self}')
return self.df
agg = sqlop(optype=SQLOpType.agg, opval=f"{self.colabbr(reduce_key)}")

# Need the aggregation and time-based filtering.
tfilt = self.prep_for_features() if self.prep_for_features() else []

return tfilt + agg_funcs + [agg]
Expand Down Expand Up @@ -966,8 +994,12 @@ class SQLNode(GraphReduceNode):
AWS Athena, which requires additional params.
Subclasses should simply extend the `SQLNode` interface:
"""
# Maps dataframe-style aggregation names to their SQL equivalents.
# 'mean' is called 'avg' in SQL engines; a value of None marks a
# function with no SQL counterpart — callers treat a falsy mapped
# value as "skip this aggregation".
FUNCTION_MAPPING = {
    'mean': 'avg',
    'median': None,
    'nunique': None,
}
def __init__ (
self,
*args,
Expand Down Expand Up @@ -1091,7 +1123,6 @@ def create_temp_view (
Create a view with the results of
the query.
"""

try:
sql = f"""
CREATE VIEW {view_name} AS
Expand Down
11 changes: 8 additions & 3 deletions tests/test_graph_reduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,14 +151,16 @@ def test_multi_node():
fpath=os.path.join(data_path, 'cust.csv'),
fmt='csv',
prefix='cust',
date_key=None
date_key=None,
pk='id',
)

order_node = DynamicNode(
fpath=os.path.join(data_path, 'orders.csv'),
fmt='csv',
prefix='ord',
date_key='ts'
date_key='ts',
pk='id',
)

gr = GraphReduce(
Expand Down Expand Up @@ -308,26 +310,30 @@ def test_sql_graph_transform():
def test_sql_graph_auto_fe():
conn = _setup_sqlite()
cust = SQLNode(fpath='cust',
pk='id',
prefix='cust',
client=conn,
compute_layer=ComputeLayerEnum.sqlite,
columns=['id','name'])

notif = SQLNode(fpath='notifications',
prefix='not',
pk='id',
client=conn,
compute_layer=ComputeLayerEnum.sqlite,
columns=['id','customer_id','ts'],
date_key='ts')

ni = SQLNode(fpath='notification_interactions',
prefix='ni',
pk='id',
client=conn,
compute_layer=ComputeLayerEnum.sqlite,
columns=['id','notification_id','interaction_type_id','ts'],
date_key='ts')

order = SQLNode(fpath='orders',
pk='id',
prefix='ord',
client=conn,
compute_layer=ComputeLayerEnum.sqlite,
Expand All @@ -351,7 +357,6 @@ def test_sql_graph_auto_fe():
compute_layer=ComputeLayerEnum.sqlite,
use_temp_tables=True,
lazy_execution=False,

# Auto feature engineering params.
auto_features=True,
auto_feature_hops_back=3,
Expand Down

0 comments on commit f6616d6

Please sign in to comment.