updates for snowflake and unity catalog
wesmadrigal committed Jan 24, 2025
1 parent 0fa6f93 commit 3cc1bf8
Showing 6 changed files with 182 additions and 45 deletions.
1 change: 1 addition & 0 deletions graphreduce/enum.py
@@ -34,6 +34,7 @@ class StorageFormatEnum(enum.Enum):
parquet = 'parquet'
tsv = 'tsv'
delta = 'delta'
iceberg = 'iceberg'

class ProviderEnum(enum.Enum):
local = 'local'
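This change registers Iceberg as a recognized storage format. A minimal sketch of referencing the new member; usage outside the enum definition is an assumption, not shown in this diff:

from graphreduce.enum import StorageFormatEnum

# The member added by this commit; its string value is 'iceberg'.
fmt = StorageFormatEnum.iceberg
assert fmt.value == 'iceberg'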
58 changes: 44 additions & 14 deletions graphreduce/graph_reduce.py
@@ -69,7 +69,8 @@ def __init__(
label_period_unit : typing.Optional[PeriodUnit] = None,
spark_sqlctx : pyspark.sql.SQLContext = None,
storage_client: typing.Optional[StorageClient] = None,

catalog_client: typing.Any = None,
sql_client: typing.Any = None,
# Only for SQL engines.
lazy_execution: bool = False,
# Debug
@@ -98,6 +99,8 @@ def __init__(
label_node: optional GraphReduceNode for the label
label_operation: optional str or callable operation to call to compute the label
label_field: optional str field to compute the label
storage_client: optional `graphreduce.storage.StorageClient` instance to checkpoint compute graphs
catalog_client: optional Unity or Polaris catalog client instance
debug: bool whether to run debug logging
"""
super(GraphReduce, self).__init__(*args, **kwargs)
@@ -133,6 +136,11 @@ def __init__(
# if using Spark
self.spark_sqlctx = spark_sqlctx
self._storage_client = storage_client

# Catalogs.
self._catalog_client = catalog_client
# SQL engine client.
self._sql_client = sql_client

self.debug = debug
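
The constructor now accepts catalog_client (a Unity or Polaris catalog client) and sql_client (a SQL engine client, e.g. for Snowflake). A hedged sketch of wiring them in; the client objects and the omitted constructor arguments below are placeholders, not part of this diff:

from graphreduce.graph_reduce import GraphReduce

# Placeholders: in practice these might be, e.g., a Databricks Unity Catalog
# workspace client and a snowflake.connector connection (assumptions).
catalog_client = None
sql_client = None

gr = GraphReduce(
    # ...nodes, cut date, compute layer, and other arguments omitted...
    catalog_client=catalog_client,
    sql_client=sql_client,
)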

@@ -203,6 +211,8 @@ def hydrate_graph_attrs (
'spark_sqlctx',
'_storage_client',
'_lazy_execution',
'_catalog_client',
'_sql_client'
]
):
"""
@@ -440,8 +450,7 @@ def join (
else:
raise Exception(f"Cannot use spark on dataframe of type: {type(parent_node.df)}")
else:
logger.error('no valid compute layer')

logger.error('no valid compute layer')
return None


@@ -468,9 +477,7 @@ def join_sql (

if not meta:
meta = self.get_edge_data(relation_node, parent_node)

raise Exception(f"no edge metadata for {parent_node} and {relation_node}")

raise Exception(f"no edge metadata for {parent_node} and {relation_node}")
if meta.get('keys'):
meta = meta['keys']

@@ -483,12 +490,34 @@ def join_sql (

parent_table = parent_node._cur_data_ref if parent_node._cur_data_ref else parent_node.fpath
relation_table = relation_node._cur_data_ref if relation_node._cur_data_ref else relation_node.fpath
JOIN_SQL = f"""
SELECT parent.*, relation.*
FROM {parent_table} parent
LEFT JOIN {relation_table} relation
ON parent.{parent_node.prefix}_{parent_pk} = relation.{relation_node.prefix}_{relation_fk}
"""
# Check if the relation foreign key is already
# in the parent and, if so, drop it from the relation's select list.
parent_samp = parent_node.get_sample()
relation_samp = relation_node.get_sample()
logger.info(f"parent columns: {parent_samp.columns}")
logger.info(f"relation columns: {relation_samp.columns}")
relation_fk = f"{relation_node.prefix}_{relation_fk}"
if relation_fk in parent_samp.columns or relation_fk.lower() in [_x.lower() for _x in relation_samp.columns]:
logger.info(f"removing duplicate column {relation_fk} on join")
relation_cols = [
f"relation.{c}"
for c in relation_samp.columns
if c.lower() != relation_fk.lower()
]
sel = ",".join(relation_cols)
JOIN_SQL = f"""
SELECT parent.*, {sel}
FROM {parent_table} parent
LEFT JOIN {relation_table} relation
ON parent.{parent_node.prefix}_{parent_pk} = relation.{relation_fk}
"""
else:
JOIN_SQL = f"""
SELECT parent.*, relation.*
FROM {parent_table} parent
LEFT JOIN {relation_table} relation
ON parent.{parent_node.prefix}_{parent_pk} = relation.{relation_node.prefix}_{relation_fk}
"""
# Always overwrite the join reference.
parent_node.create_ref(JOIN_SQL, 'join', overwrite=True)
self._mark_merged(parent_node, relation_node)
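
A standalone sketch of the SQL the new duplicate-key branch renders when the relation's prefixed foreign key already exists on the parent side; all table and column names below are hypothetical stand-ins for the node samples used above:

# Hypothetical inputs standing in for the parent/relation sample metadata.
parent_table = "cust_joined"
relation_table = "orders"
parent_prefix, parent_pk = "cust", "id"
relation_prefix, relation_fk = "ord", "customer_id"
relation_columns = ["ord_customer_id", "ord_amount", "ord_ts"]

# Mirror the dedup logic: drop the relation's prefixed FK from the select list.
relation_fk_prefixed = f"{relation_prefix}_{relation_fk}"
relation_cols = [
    f"relation.{c}"
    for c in relation_columns
    if c.lower() != relation_fk_prefixed.lower()
]
sel = ",".join(relation_cols)

JOIN_SQL = f"""
SELECT parent.*, {sel}
FROM {parent_table} parent
LEFT JOIN {relation_table} relation
ON parent.{parent_prefix}_{parent_pk} = relation.{relation_fk_prefixed}
"""
print(JOIN_SQL)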
@@ -629,6 +658,8 @@ def do_transformations_sql(self):
ops = node.do_data()
if not ops:
raise Exception(f"{node.__class__.__name__}.do_data must be implemented")

logger.debug(f"do data: {node.build_query(ops)}")
node.create_ref(node.build_query(ops), node.do_data)

node.create_ref(node.build_query(node.do_annotate()), node.do_annotate)
@@ -755,7 +786,6 @@ def do_transformations_sql(self):
pjf_ref = parent_node.create_ref(pjf_sql, parent_node.do_post_join_filters)



def do_transformations(self):
"""
Perform all graph transformations
@@ -823,7 +853,6 @@ def do_transformations(self):
type_func_map=self.feature_stype_map,
compute_layer=self.compute_layer
)


# NOTE: this is pandas specific and will break
# on other compute layers for now
@@ -905,3 +934,4 @@ def do_transformations(self):
# post-join aggregation
if edge_data['reduce_after_join']:
parent_node.do_post_join_reduce(edge_data['relation_key'], type_func_map=self.feature_stype_map)
