updates for snowflake and unity catalog #20

Merged 4 commits on Jan 25, 2025
1 change: 1 addition & 0 deletions graphreduce/enum.py
@@ -34,6 +34,7 @@ class StorageFormatEnum(enum.Enum):
parquet = 'parquet'
tsv = 'tsv'
delta = 'delta'
iceberg = 'iceberg'

class ProviderEnum(enum.Enum):
local = 'local'
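With iceberg added alongside delta, callers can branch on StorageFormatEnum when deciding how a checkpoint should be persisted. A minimal sketch of that kind of dispatch (the storage_suffix helper is illustrative, not part of this PR):

from graphreduce.enum import StorageFormatEnum

def storage_suffix(fmt: StorageFormatEnum) -> str:
    # Hypothetical helper: table formats (delta, iceberg) are addressed by
    # catalog/table name, while file formats get a file extension.
    if fmt in (StorageFormatEnum.delta, StorageFormatEnum.iceberg):
        return ''
    return f'.{fmt.value}'

storage_suffix(StorageFormatEnum.iceberg)   # ''
storage_suffix(StorageFormatEnum.parquet)   # '.parquet'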
58 changes: 44 additions & 14 deletions graphreduce/graph_reduce.py
@@ -69,7 +69,8 @@ def __init__(
label_period_unit : typing.Optional[PeriodUnit] = None,
spark_sqlctx : pyspark.sql.SQLContext = None,
storage_client: typing.Optional[StorageClient] = None,

catalog_client: typing.Any = None,
sql_client: typing.Any = None,
# Only for SQL engines.
lazy_execution: bool = False,
# Debug
@@ -98,6 +99,8 @@ def __init__(
label_node: optional GraphReduceNode for the label
label_operation: optional str or callable operation to call to compute the label
label_field: optional str field to compute the label
storage_client: optional `graphreduce.storage.StorageClient` instance to checkpoint compute graphs
catalog_client: optional Unity or Polaris catalog client instance
debug: bool whether to run debug logging
"""
super(GraphReduce, self).__init__(*args, **kwargs)
@@ -133,6 +136,11 @@ def __init__(
# if using Spark
self.spark_sqlctx = spark_sqlctx
self._storage_client = storage_client

# Catalogs.
self._catalog_client = catalog_client
# SQL engine client.
self._sql_client = sql_client

self.debug = debug
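For orientation, the two new constructor parameters would be supplied when the graph is built; a hedged sketch assuming a Snowflake connection as the SQL engine client (everything other than catalog_client and sql_client is illustrative, and other GraphReduce construction arguments are omitted for brevity):

import snowflake.connector
from graphreduce.graph_reduce import GraphReduce

# Assumed clients: any objects with the interface GraphReduce expects would do.
sql_client = snowflake.connector.connect(
    account='my_account', user='my_user', password='***'
)
catalog_client = None  # e.g. a Unity or Polaris catalog client instance

gr = GraphReduce(
    catalog_client=catalog_client,  # new in this PR
    sql_client=sql_client,          # new in this PR
)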

@@ -203,6 +211,8 @@ def hydrate_graph_attrs (
'spark_sqlctx',
'_storage_client',
'_lazy_execution',
'_catalog_client',
'_sql_client'
]
):
"""
@@ -440,8 +450,7 @@ def join (
else:
raise Exception(f"Cannot use spark on dataframe of type: {type(parent_node.df)}")
else:
logger.error('no valid compute layer')
return None


@@ -468,9 +477,7 @@ def join_sql (

if not meta:
meta = self.get_edge_data(relation_node, parent_node)

raise Exception(f"no edge metadata for {parent_node} and {relation_node}")
if meta.get('keys'):
meta = meta['keys']

@@ -483,12 +490,34 @@

parent_table = parent_node._cur_data_ref if parent_node._cur_data_ref else parent_node.fpath
relation_table = relation_node._cur_data_ref if relation_node._cur_data_ref else relation_node.fpath
JOIN_SQL = f"""
SELECT parent.*, relation.*
FROM {parent_table} parent
LEFT JOIN {relation_table} relation
ON parent.{parent_node.prefix}_{parent_pk} = relation.{relation_node.prefix}_{relation_fk}
"""
# Check if the relation foreign key is already
# in the parent and, if so, rename it.
parent_samp = parent_node.get_sample()
relation_samp = relation_node.get_sample()
logger.info(f"parent columns: {parent_samp.columns}")
logger.info(f"relation columns: {relation_samp.columns}")
relation_fk = f"{relation_node.prefix}_{relation_fk}"
if relation_fk in parent_samp.columns or relation_fk.lower() in [_x.lower() for _x in relation_samp.columns]:
logger.info(f"removing duplicate column {relation_fk} on join")
relation_cols = [
f"relation.{c}"
for c in relation_samp.columns
if c.lower() != relation_fk.lower()
]
sel = ",".join(relation_cols)
JOIN_SQL = f"""
SELECT parent.*, {sel}
FROM {parent_table} parent
LEFT JOIN {relation_table} relation
ON parent.{parent_node.prefix}_{parent_pk} = relation.{relation_fk}
"""
else:
JOIN_SQL = f"""
SELECT parent.*, relation.*
FROM {parent_table} parent
LEFT JOIN {relation_table} relation
ON parent.{parent_node.prefix}_{parent_pk} = relation.{relation_node.prefix}_{relation_fk}
"""
# Always overwrite the join reference.
parent_node.create_ref(JOIN_SQL, 'join', overwrite=True)
self._mark_merged(parent_node, relation_node)
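The new branch keeps the relation's foreign-key column out of the SELECT list when the parent side already carries it, avoiding duplicate column names in the joined result. A standalone sketch of the same column-pruning idea (table and column names are made up for illustration):

def build_join_sql(parent_table, relation_table, parent_cols, relation_cols,
                   parent_key, relation_fk):
    # Drop the relation-side foreign key when a column of the same
    # (case-insensitive) name already exists on the parent.
    if relation_fk.lower() in [c.lower() for c in parent_cols]:
        rel_select = ', '.join(
            f'relation.{c}' for c in relation_cols if c.lower() != relation_fk.lower()
        )
    else:
        rel_select = 'relation.*'
    return (
        f'SELECT parent.*, {rel_select} '
        f'FROM {parent_table} parent '
        f'LEFT JOIN {relation_table} relation '
        f'ON parent.{parent_key} = relation.{relation_fk}'
    )

build_join_sql(
    'orders', 'customers',
    parent_cols=['ord_id', 'ord_cust_id', 'cust_cust_id'],  # fk already present
    relation_cols=['cust_cust_id', 'cust_name'],
    parent_key='ord_cust_id',
    relation_fk='cust_cust_id',
)
# SELECT parent.*, relation.cust_name FROM orders parent
#   LEFT JOIN customers relation ON parent.ord_cust_id = relation.cust_cust_id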
@@ -629,6 +658,8 @@ def do_transformations_sql(self):
ops = node.do_data()
if not ops:
raise Exception(f"{node.__class__.__name__}.do_data must be implemented")

logger.debug(f"do data: {node.build_query(ops)}")
node.create_ref(node.build_query(ops), node.do_data)

node.create_ref(node.build_query(node.do_annotate()), node.do_annotate)
@@ -755,7 +786,6 @@ def do_transformations_sql(self):
pjf_ref = parent_node.create_ref(pjf_sql, parent_node.do_post_join_filters)



def do_transformations(self):
"""
Perform all graph transformations
@@ -823,7 +853,6 @@ def do_transformations(self):
type_func_map=self.feature_stype_map,
compute_layer=self.compute_layer
)


# NOTE: this is pandas specific and will break
# on other compute layers for now
@@ -905,3 +934,4 @@ def do_transformations(self):
# post-join aggregation
if edge_data['reduce_after_join']:
parent_node.do_post_join_reduce(edge_data['relation_key'], type_func_map=self.feature_stype_map)
