From 3cc1bf8f21f06945e7bda38d767a84349b93d396 Mon Sep 17 00:00:00 2001 From: Wes Date: Fri, 24 Jan 2025 17:13:54 -0500 Subject: [PATCH] updates for snowflake and unity catalog --- graphreduce/enum.py | 1 + graphreduce/graph_reduce.py | 58 +++++++++---- graphreduce/node.py | 158 +++++++++++++++++++++++++++++------- graphreduce/storage.py | 1 + requirements.txt | 5 +- setup.py | 4 +- 6 files changed, 182 insertions(+), 45 deletions(-) diff --git a/graphreduce/enum.py b/graphreduce/enum.py index a8bced1..07e1901 100644 --- a/graphreduce/enum.py +++ b/graphreduce/enum.py @@ -34,6 +34,7 @@ class StorageFormatEnum(enum.Enum): parquet = 'parquet' tsv = 'tsv' delta = 'delta' + iceberg = 'iceberg' class ProviderEnum(enum.Enum): local = 'local' diff --git a/graphreduce/graph_reduce.py b/graphreduce/graph_reduce.py index 823cc6c..f3cd850 100644 --- a/graphreduce/graph_reduce.py +++ b/graphreduce/graph_reduce.py @@ -69,7 +69,8 @@ def __init__( label_period_unit : typing.Optional[PeriodUnit] = None, spark_sqlctx : pyspark.sql.SQLContext = None, storage_client: typing.Optional[StorageClient] = None, - + catalog_client: typing.Any = None, + sql_client: typing.Any = None, # Only for SQL engines. lazy_execution: bool = False, # Debug @@ -98,6 +99,8 @@ def __init__( label_node: optionl GraphReduceNode for the label label_operation: optional str or callable operation to call to compute the label label_field: optional str field to compute the label + storage_client: optional `graphreduce.storage.StorageClient` instance to checkpoint compute graphs + catalog_client: optional Unity or Polaris catalog client instance debug: bool whether to run debug logging """ super(GraphReduce, self).__init__(*args, **kwargs) @@ -133,6 +136,11 @@ def __init__( # if using Spark self.spark_sqlctx = spark_sqlctx self._storage_client = storage_client + + # Catalogs. + self._catalog_client = catalog_client + # SQL engine client. + self._sql_client = sql_client self.debug = debug @@ -203,6 +211,8 @@ def hydrate_graph_attrs ( 'spark_sqlctx', '_storage_client', '_lazy_execution', + '_catalog_client', + '_sql_client' ] ): """ @@ -440,8 +450,7 @@ def join ( else: raise Exception(f"Cannot use spark on dataframe of type: {type(parent_node.df)}") else: - logger.error('no valid compute layer') - + logger.error('no valid compute layer') return None @@ -468,9 +477,7 @@ def join_sql ( if not meta: meta = self.get_edge_data(relation_node, parent_node) - - raise Exception(f"no edge metadata for {parent_node} and {relation_node}") - + raise Exception(f"no edge metadata for {parent_node} and {relation_node}") if meta.get('keys'): meta = meta['keys'] @@ -483,12 +490,34 @@ def join_sql ( parent_table = parent_node._cur_data_ref if parent_node._cur_data_ref else parent_node.fpath relation_table = relation_node._cur_data_ref if relation_node._cur_data_ref else relation_node.fpath - JOIN_SQL = f""" - SELECT parent.*, relation.* - FROM {parent_table} parent - LEFT JOIN {relation_table} relation - ON parent.{parent_node.prefix}_{parent_pk} = relation.{relation_node.prefix}_{relation_fk} - """ + # Check if the relation foreign key is already + # in the parent and, if so, rename it. 
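+        # (In the branch below the duplicate is dropped from the relation-side
+        # SELECT list rather than renamed, so the joined output keeps a single
+        # copy of the key column.)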
+ parent_samp = parent_node.get_sample() + relation_samp = relation_node.get_sample() + logger.info(f"parent columns: {parent_samp.columns}") + logger.info(f"relation columns: {relation_samp.columns}") + relation_fk = f"{relation_node.prefix}_{relation_fk}" + if relation_fk in parent_samp.columns or relation_fk.lower() in [_x.lower() for _x in relation_samp.columns]: + logger.info(f"removing duplicate column {relation_fk} on join") + relation_cols = [ + f"relation.{c}" + for c in relation_samp.columns + if c.lower() != relation_fk.lower() + ] + sel = ",".join(relation_cols) + JOIN_SQL = f""" + SELECT parent.*, {sel} + FROM {parent_table} parent + LEFT JOIN {relation_table} relation + ON parent.{parent_node.prefix}_{parent_pk} = relation.{relation_fk} + """ + else: + JOIN_SQL = f""" + SELECT parent.*, relation.* + FROM {parent_table} parent + LEFT JOIN {relation_table} relation + ON parent.{parent_node.prefix}_{parent_pk} = relation.{relation_node.prefix}_{relation_fk} + """ # Always overwrite the join reference. parent_node.create_ref(JOIN_SQL, 'join', overwrite=True) self._mark_merged(parent_node, relation_node) @@ -629,6 +658,8 @@ def do_transformations_sql(self): ops = node.do_data() if not ops: raise Exception(f"{node.__class__.__name__}.do_data must be implemented") + + logger.debug(f"do data: {node.build_query(ops)}") node.create_ref(node.build_query(ops), node.do_data) node.create_ref(node.build_query(node.do_annotate()), node.do_annotate) @@ -755,7 +786,6 @@ def do_transformations_sql(self): pjf_ref = parent_node.create_ref(pjf_sql, parent_node.do_post_join_filters) - def do_transformations(self): """ Perform all graph transformations @@ -823,7 +853,6 @@ def do_transformations(self): type_func_map=self.feature_stype_map, compute_layer=self.compute_layer ) - # NOTE: this is pandas specific and will break # on other compute layers for now @@ -905,3 +934,4 @@ def do_transformations(self): # post-join aggregation if edge_data['reduce_after_join']: parent_node.do_post_join_reduce(edge_data['relation_key'], type_func_map=self.feature_stype_map) + diff --git a/graphreduce/node.py b/graphreduce/node.py index 0858bf3..402bb1d 100644 --- a/graphreduce/node.py +++ b/graphreduce/node.py @@ -17,6 +17,9 @@ from dateutil.parser import parse as date_parse from torch_frame.utils import infer_df_stype import daft +from daft.unity_catalog import UnityCatalog +from pyiceberg.catalog.rest import RestCatalog + # internal from graphreduce.enum import ComputeLayerEnum, PeriodUnit, SQLOpType @@ -43,7 +46,7 @@ class GraphReduceNode(metaclass=abc.ABCMeta): be necessary to implement an engine-specific methods (e.g., `do_data` to get data from Snowflake) -The classes `do_annotate`, `do_filters`, +The methods `do_annotate`, `do_filters`, `do_normalize`, `do_reduce`, `do_labels`, `do_post_join_annotate`, and `do_post_join_filters` are abstractmethods which must be defined. 
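# Illustrative sketch only (not part of this patch): a minimal concrete node
# implementing the abstract methods named in the docstring above. The class
# name, column handling, and the pandas-style reduce logic are placeholder
# assumptions.
class CustomerNode(GraphReduceNode):
    def do_filters(self):
        pass

    def do_annotate(self):
        pass

    def do_normalize(self):
        pass

    def do_reduce(self, reduce_key):
        # Aggregate rows up to the reduce key within the compute window.
        return self.prep_for_features().groupby(self.colabbr(reduce_key)).agg(
            **{self.colabbr('num_rows'): (self.colabbr(self.pk), 'count')}
        ).reset_index()

    def do_labels(self, reduce_key):
        pass

    def do_post_join_annotate(self):
        pass

    def do_post_join_filters(self):
        pass
# GraphReduce.do_transformations() / do_transformations_sql() invoke these
# hooks on each node as the graph is traversed.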
@@ -97,6 +100,7 @@ def __init__ ( delimiter: str = None, encoding: str = None, ts_data: bool = False, + catalog_client: typing.Any = None ): """ Constructor @@ -122,7 +126,7 @@ def __init__ ( self.label_field = label_field self.label_operation = label_operation self.spark_sqlctx = spark_sqlctx - self.columns = columns + self.columns = columns # Read options self.delimiter = delimiter if delimiter else ',' @@ -143,6 +147,8 @@ def __init__ ( if not self.date_key: logger.warning(f"no `date_key` set for {self}") + self._catalog_client = catalog_client + def __repr__ ( self @@ -178,6 +184,8 @@ def _is_identifier ( return True elif col.lower() == 'identifier': return True + elif col.lower().endswith('key'): + return True def is_ts_data ( @@ -198,8 +206,27 @@ def is_ts_data ( if float(grouped) / float(n) < 0.9: return True #TODO(wes): define the SQL logic. - elif self.compute_layer in [ComputeLayerEnum.sqlite]: - pass + elif self.compute_layer in [ComputeLayerEnum.sqlite, ComputeLayerEnum.snowflake, ComputeLayerEnum.databricks, ComputeLayerEnum.athena]: + # run a group by and get the value. + grp_qry = f""" + select count(*) as grouped_rows + from ( + select {reduce_key}, count({self.pk}) + FROM {self.fpath} + group by {reduce_key} + ) t; + """ + row_qry = f""" + select count(*) as row_count from {self.fpath} + """ + grp_df = self.execute_query(grp_qry) + grp_df.columns = [c.lower() for c in grp_df.columns] + row_df = self.execute_query(row_qry) + row_df.columns = [c.lower() for c in row_df.columns] + grp_count = grp_df['grouped_rows'].values[0] + row_count = row_df['row_count'].values[0] + if float(grp_count)/float(row_count) < 0.9: + return True elif self.compute_layer == ComputeLayerEnum.daft: grouped = self.df.groupby(self.colabbr(reduce_key)).agg(self.df[self.colabbr(self.pk)].count()).count_rows() n = self.df.count_rows() @@ -279,7 +306,17 @@ def do_data ( self._stypes = infer_df_stype(self.df.sample(0.5).limit(10).toPandas()) elif self.compute_layer.value == 'daft': if not hasattr(self, 'df') or (hasattr(self, 'df') and not isinstance(self.df, daft.dataframe.dataframe.DataFrame)): - self.df = getattr(daft, f"read_{self.fmt}")(self.fpath) + # Iceberg. + if self._catalog_client: + if isinstance(self._catalog_client, RestCatalog): + tbl = self._catalog_client.load_table(self.fpath) + self.df = daft.read_iceberg(tbl) + elif isinstance(self._catalog_client, UnityCatalog): + tbl = self._catalog_client.load_table(self.fpath) + #TODO(wes): support more than just deltalake. 
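+                    # Assumes the Unity Catalog table is Delta-backed; tables
+                    # registered in Unity under other formats would need a
+                    # different daft reader.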
+ self.df = daft.read_deltalake(tbl) + else: + self.df = getattr(daft, f"read_{self.fmt}")(self.fpath) self.columns = [c.name() for c in self.df.columns] for col in self.df.columns: self.df = self.df.with_column(f"{self.prefix}_{col.name()}", col) @@ -674,9 +711,9 @@ def sql_auto_features ( logger.info(f'skipped aggregation on {col} because semantic numerical but physical object') continue if func in self.FUNCTION_MAPPING: - func = self.FUNCTION_MAPPING.get(func) + func = self.FUNCTION_MAPPING.get(func) - if not func: + if not func or func == 'nunique': continue col_new = f"{col}_{func}" agg_funcs.append( @@ -926,9 +963,9 @@ def prep_for_features ( elif isinstance(self.df, daft.dataframe.dataframe.DataFrame): return self.df.filter( ( - (self.df[self.colabbr(self.date_key)] < self.cut_date) + (self.df[self.colabbr(self.date_key)] < str(self.cut_date)) & - (self.df[self.colabbr(self.date_key)] > (self.cut_date - datetime.timedelta(minutes=self.compute_period_minutes()))) + (self.df[self.colabbr(self.date_key)] > str((self.cut_date - datetime.timedelta(minutes=self.compute_period_minutes())))) ) | (self.df[self.colabbr(self.date_key)].is_null()) @@ -1001,9 +1038,9 @@ def prep_for_labels ( ) elif isinstance(self.df, daft.dataframe.dataframe.DataFrame): return self.df.filter( - (self.df[self.colabbr(self.date_key)] > self.cut_date) + (self.df[self.colabbr(self.date_key)] > str(self.cut_date)) & - (self.df[self.colabbr(self.date_key)] < (self.cut_date + datetime.timedelta(minutes=self.label_period_minutes()))) + (self.df[self.colabbr(self.date_key)] < str(self.cut_date + datetime.timedelta(minutes=self.label_period_minutes()))) ) else: # Using a SQL engine so need to return `sqlop` instances. @@ -1049,8 +1086,8 @@ def default_label ( field: str label field to call operation on reduce: bool whether or not to reduce """ - if hasattr(self, 'df') and self.colabbr(field) in self.df.columns: - if self.compute_layer in [ComputeLayerEnum.pandas, ComputeLayerEnum.dask]: + if hasattr(self, 'df'): + if self.compute_layer in [ComputeLayerEnum.pandas, ComputeLayerEnum.dask] and self.colabbr(field) in self.df.columns: if self.reduce: if callable(op): return self.prep_for_labels().groupby(self.colabbr(reduce_key)).agg(**{ @@ -1068,14 +1105,14 @@ def default_label ( label_df[self.colabbr(field)+'_label'] = label_df[self.colabbr(field)].apply(lambda x: getattr(x, op)()) return label_df[[self.colabbr(self.pk), self.colabbr(field)+'_label']] - elif self.compute_layer == ComputeLayerEnum.spark: + elif self.compute_layer == ComputeLayerEnum.spark and self.colabbr(field) in self.df.columns: if self.reduce: return self.prep_for_labels().groupBy(self.colabbr(reduce_key)).agg( getattr(F, op)(F.col(self.colabbr(field))).alias(f'{self.colabbr(field)}_label') ) else: pass - elif self.compute_layer == ComputeLayerEnum.daft: + elif self.compute_layer == ComputeLayerEnum.daft and self.colabbr(field) in self.df.column_names: if self.reduce: aggcol = daft.col(self.colabbr(field)) return self.prep_for_labels().groupby(self.colabbr(reduce_key)).agg( @@ -1085,14 +1122,18 @@ def default_label ( pass elif self.compute_layer in [ComputeLayerEnum.snowflake, ComputeLayerEnum.sqlite, ComputeLayerEnum.mysql, ComputeLayerEnum.postgres, ComputeLayerEnum.redshift, ComputeLayerEnum.athena, ComputeLayerEnum.databricks]: if self.reduce: - return self.prep_for_labels() + [ + label_query = self.prep_for_labels() + [ sqlop(optype=SQLOpType.agg, opval=f"{self.colabbr(reduce_key)}"), sqlop(optype=SQLOpType.aggfunc, opval=f"{op}"+ 
f"({self.colabbr(field)}) as {self.colabbr(field)}_label") ] + logger.info(self.build_query(label_query)) + return label_query else: - return self.prep_for_labels() + [ + label_query = self.prep_for_labels() + [ sqlop(optype=SQLOpType.select, opval=f"{op}" + f"({self.colabbr(field)}) as {self.colabbr(field)}_label") ] + logger.info(self.build_query(label_query)) + return label_query else: pass @@ -1188,7 +1229,7 @@ def __init__ ( """ Constructor. """ - self.client = client + self._sql_client = client self.lazy_execution = lazy_execution # The current data ref. @@ -1206,10 +1247,13 @@ def _clean_refs(self): """ for k, v in self._temp_refs.items(): if v not in self._removed_refs: - sql = f"DROP VIEW {v}" - self.execute_query(sql) - self._removed_refs.append(v) - logger.info(f"dropped {v}") + try: + sql = f"DROP VIEW {v}" + self.execute_query(sql) + self._removed_refs.append(v) + logger.info(f"dropped {v}") + except Exception as e: + continue def get_ref_name ( @@ -1324,12 +1368,13 @@ def get_sample ( """ Gets a sample of rows for the current table or a parameterized table. - """ + """ samp_query = """ - SELECT * - FROM {table} - LIMIT {n} - """ + SELECT * + FROM {table} + LIMIT {n} + """ + if not table: qry = samp_query.format( table=self._cur_data_ref if self._cur_data_ref else self.fpath, @@ -1442,7 +1487,7 @@ def build_query ( def get_client(self) -> typing.Any: - return self.client + return self._sql_client def execute_query ( @@ -1708,3 +1753,60 @@ def create_temp_view ( except Exception as e: logger.error(e) return None + + +class SnowflakeNode(SQLNode): + def __init__( + self, + *args, + **kwargs): + """ +Constructor. + """ + super().__init__(*args, **kwargs) + # Use an available database. + + + def _clean_refs(self): + # Get all views and find the ones + # in temp refs that are still active. + views = self.execute_query("show views") + active_views = [row['name'] for ix,row in views.iterrows()] + for k, v in self._temp_refs.items(): + if v not in self._removed_refs and v in active_views: + sql = f"DROP VIEW {v}" + self.execute_query(sql, ret_df=False) + self._removed_refs.append(v) + logger.info(f"dropped {v}") + + def use_db ( + self, + db: str, + ) -> bool: + try: + res = self.execute_query(f"use database {db}") + return True + except Exception as e: + return False + + + def create_temp_view ( + self, + qry: str, + view_name: str, + ) -> str: + """ +Create a view with the results +of the query. 
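+Returns the view name on success (and points `_cur_data_ref` at it),
+or None if view creation fails.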
+ """ + try: + sql = f""" + CREATE VIEW {view_name} + AS {qry} + """ + self.execute_query(sql, ret_df=False) + self._cur_data_ref = view_name + return view_name + except Exception as e: + logger.error(e) + return None diff --git a/graphreduce/storage.py b/graphreduce/storage.py index 70d06bf..5785e3b 100644 --- a/graphreduce/storage.py +++ b/graphreduce/storage.py @@ -14,6 +14,7 @@ import dask.dataframe as dd import pandas as pd import pyspark +import daft class StorageClient(object): diff --git a/requirements.txt b/requirements.txt index b10f22a..e1fe4f0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,8 +1,9 @@ abstract.jwrotator>=0.3 dask dask[dataframe] -deltalake==0.23.1 -getdaft==0.4.1 +deltalake==0.20.1 +getdaft[unity]==0.4.1 +httpx==0.27.0 icecream==2.1.3 networkx>=2.6.3 numpy>=1.15,<2 diff --git a/setup.py b/setup.py index 7e56aae..3111976 100644 --- a/setup.py +++ b/setup.py @@ -15,7 +15,7 @@ setuptools.setup( name="graphreduce", - version = "1.7.5", + version = "1.7.6", url="https://github.com/wesmadrigal/graphreduce", packages = setuptools.find_packages(exclude=[ "docs", "examples" ]), install_requires = [ @@ -23,6 +23,8 @@ "getdaft[unity]==0.4.1", "dask", "dask[dataframe]", + "deltalake==0.20.1", + "httpx==0.27.0", "icecream", "networkx>=2.6.3", "numpy>=1.15,<2",