diff --git a/soda/core/soda/cli/cli.py b/soda/core/soda/cli/cli.py
index f83d11c94..a42accd80 100644
--- a/soda/core/soda/cli/cli.py
+++ b/soda/core/soda/cli/cli.py
@@ -352,7 +352,7 @@ def update_dro(
     if data_source_scan:
         if distribution_type == "categorical":
-            query = f"SELECT {column_name}, COUNT(*) FROM {dataset_name} {filter_clause} GROUP BY {column_name} ORDER BY 2 DESC"
+            query = f"SELECT {column_name}, {data_source_scan.data_source.expr_count_all()} FROM {dataset_name} {filter_clause} GROUP BY {column_name} ORDER BY 2 DESC"
         else:
             query = f"SELECT {column_name} FROM {dataset_name} {filter_clause}"
         logging.info(f"Querying column values to build distribution reference:\n{query}")
diff --git a/soda/core/soda/execution/data_source.py b/soda/core/soda/execution/data_source.py
index a9778accd..d9924f7d7 100644
--- a/soda/core/soda/execution/data_source.py
+++ b/soda/core/soda/execution/data_source.py
@@ -717,11 +717,11 @@ def sql_get_duplicates_count(
         sql = dedent(
             f"""
             WITH frequencies AS (
-                SELECT COUNT(*) AS frequency
+                SELECT {self.expr_count_all()} AS frequency
                 FROM {table_name}
                 WHERE {filter}
                 GROUP BY {column_names})
-            SELECT count(*)
+            SELECT {self.expr_count_all()}
             FROM frequencies
             WHERE frequency > 1"""
         )
@@ -741,7 +741,7 @@ def sql_get_duplicates_aggregated(
         sql = dedent(
             f"""
             WITH frequencies AS (
-                SELECT {column_names}, COUNT(*) AS frequency
+                SELECT {column_names}, {self.expr_count_all()} AS frequency
                 FROM {table_name}
                 WHERE {filter}
                 GROUP BY {column_names})
@@ -778,7 +778,7 @@ def sql_get_duplicates(
                 FROM {table_name}
                 WHERE {filter}
                 GROUP BY {column_names}
-                HAVING count(*) {'<=' if invert_condition else '>'} 1)
+                HAVING {self.expr_count_all()} {'<=' if invert_condition else '>'} 1)
             SELECT {main_query_columns}
             FROM {table_name} main
             JOIN frequencies ON {join}
@@ -894,7 +894,7 @@ def profiling_sql_value_frequencies_cte(self, table_name: str, column_name: str)
         quoted_column_name = self.quote_column(column_name)
         qualified_table_name = self.qualified_table_name(table_name)
         return f"""value_frequencies AS (
-                    SELECT {quoted_column_name} AS value_, count(*) AS frequency_
+                    SELECT {quoted_column_name} AS value_, {self.expr_count_all()} AS frequency_
                     FROM {qualified_table_name}
                     WHERE {quoted_column_name} IS NOT NULL
                     GROUP BY {quoted_column_name}
@@ -910,7 +910,7 @@ def profiling_sql_aggregates_numeric(self, table_name: str, column_name: str) ->
                 , sum({column_name}) as sum
                 , var_samp({column_name}) as variance
                 , stddev_samp({column_name}) as standard_deviation
-                , count(distinct({column_name})) as distinct_values
+                , {self.expr_count(f'distinct({column_name})')} as distinct_values
                 , sum(case when {column_name} is null then 1 else 0 end) as missing_values
             FROM {qualified_table_name}
             """
@@ -922,7 +922,7 @@ def profiling_sql_aggregates_text(self, table_name: str, column_name: str) -> st
         return dedent(
             f"""
             SELECT
-                count(distinct({column_name})) as distinct_values
+                {self.expr_count(f'distinct({column_name})')} as distinct_values
                 , sum(case when {column_name} is null then 1 else 0 end) as missing_values
                 , avg(length({column_name})) as avg_length
                 , min(length({column_name})) as min_length
@@ -1178,10 +1178,10 @@ def literal_boolean(self, boolean: bool):
         return "TRUE" if boolean is True else "FALSE"

     def expr_count_all(self) -> str:
-        return "COUNT(*)"
+        return self.expr_count("*")

     def expr_count_conditional(self, condition: str):
-        return f"COUNT(CASE WHEN {condition} THEN 1 END)"
+        return self.expr_count(self.expr_conditional(condition, "1"))

     def expr_conditional(self, condition: str, expr: str):
         return f"CASE WHEN {condition} THEN {expr} END"
@@ -1380,7 +1380,7 @@ def sql_groupby_count_categorical_column(
                 )
                 SELECT
                     {column_name}
-                    , COUNT(*) AS frequency
+                    , {self.expr_count_all()} AS frequency
                 FROM processed_table
                 GROUP BY {column_name}
             """
diff --git a/soda/core/soda/execution/query/reference_query.py b/soda/core/soda/execution/query/reference_query.py
index ced5b1839..2a4f19590 100644
--- a/soda/core/soda/execution/query/reference_query.py
+++ b/soda/core/soda/execution/query/reference_query.py
@@ -84,7 +84,7 @@ def __init__(
         self.sql = jinja_resolve(
             data_source.sql_reference_query(
-                "count(*)", source_table_name, target_table_name, join_condition, where_condition
+                data_source.expr_count_all(), source_table_name, target_table_name, join_condition, where_condition
             )
         )
diff --git a/soda/core/tests/data_source/test_duplicates.py b/soda/core/tests/data_source/test_duplicates.py
index c2b1ee062..dedc279c3 100644
--- a/soda/core/tests/data_source/test_duplicates.py
+++ b/soda/core/tests/data_source/test_duplicates.py
@@ -17,7 +17,7 @@ def test_duplicates_single_column(data_source_fixture: DataSourceFixture):
     scan.assert_all_checks_pass()

     # This is a simple use case, verify that * is used in the main query.
-    scan.assert_log("count(*)")
+    scan.assert_log(data_source_fixture.data_source.expr_count_all())


 def test_duplicates_multiple_columns(data_source_fixture: DataSourceFixture):
diff --git a/soda/sqlserver/soda/data_sources/sqlserver_data_source.py b/soda/sqlserver/soda/data_sources/sqlserver_data_source.py
index 9ef64d373..a1168d8cd 100644
--- a/soda/sqlserver/soda/data_sources/sqlserver_data_source.py
+++ b/soda/sqlserver/soda/data_sources/sqlserver_data_source.py
@@ -177,7 +177,7 @@ def profiling_sql_aggregates_numeric(self, table_name: str, column_name: str) ->
                 , sum({column_name}) as sum
                 , var({column_name}) as variance
                 , stdev({column_name}) as standard_deviation
-                , count(distinct({column_name})) as distinct_values
+                , {self.expr_count(f'distinct({column_name})')} as distinct_values
                 , sum(case when {column_name} is null then 1 else 0 end) as missing_values
             FROM {qualified_table_name}
             """
@@ -260,7 +260,7 @@ def profiling_sql_aggregates_text(self, table_name: str, column_name: str) -> st
         return dedent(
             f"""
             SELECT
-                count(distinct({column_name})) as distinct_values
+                {self.expr_count(f'distinct({column_name})')} as distinct_values
                 , sum(case when {column_name} is null then 1 else 0 end) as missing_values
                 , avg(len({column_name})) as avg_length
                 , min(len({column_name})) as min_length
@@ -328,7 +328,7 @@ def sql_groupby_count_categorical_column(
                 )
                 SELECT {top_limit}
                     {column_name}
-                    , COUNT(*) AS frequency
+                    , {self.expr_count_all()} AS frequency
                 FROM processed_table
                 GROUP BY {column_name}
             """
@@ -356,7 +356,7 @@ def sql_get_duplicates_aggregated(
         sql = dedent(
             f"""
             WITH frequencies AS (
-                SELECT {column_names}, COUNT(*) AS frequency
+                SELECT {column_names}, {self.expr_count_all()} AS frequency
                 FROM {table_name}
                 WHERE {filter}
                 GROUP BY {column_names})
@@ -394,7 +394,7 @@ def sql_get_duplicates(
                 FROM {table_name}
                 WHERE {filter}
                 GROUP BY {column_names}
-                HAVING count(*) {'<=' if invert_condition else '>'} 1)
+                HAVING {self.expr_count_all()} {'<=' if invert_condition else '>'} 1)
             SELECT {limit_sql} {main_query_columns}
             FROM {table_name} main
             JOIN frequencies ON {join}
diff --git a/soda/teradata/soda/data_sources/teradata_data_source.py b/soda/teradata/soda/data_sources/teradata_data_source.py
index ba6ff3860..419b164ba 100644
--- a/soda/teradata/soda/data_sources/teradata_data_source.py
+++ b/soda/teradata/soda/data_sources/teradata_data_source.py
@@ -260,7 +260,7 @@ def sql_get_duplicates_aggregated(
         sql = dedent(
             f"""
             WITH frequencies AS (
-                SELECT {column_names}, COUNT(*) AS frequency
+                SELECT {column_names}, {self.expr_count_all()} AS frequency
                 FROM {table_name}
                 WHERE {filter}
                 GROUP BY {column_names})
@@ -297,7 +297,7 @@ def sql_get_duplicates(
                 FROM {table_name}
                 WHERE {filter}
                 GROUP BY {column_names}
-                HAVING count(*) {'<=' if invert_condition else '>'} 1)
+                HAVING {self.expr_count_all()} {'<=' if invert_condition else '>'} 1)
             SELECT {limit_sql} {main_query_columns}
             FROM {table_name} main
             JOIN frequencies ON {join}
@@ -528,7 +528,7 @@ def profiling_sql_aggregates_numeric(self, table_name: str, column_name: str) ->
                 , sum({column_name}) as "sum"
                 , var_samp({column_name}) as variance
                 , stddev_samp({column_name}) as standard_deviation
-                , count(distinct({column_name})) as distinct_values
+                , {self.expr_count(f'distinct({column_name})')} as distinct_values
                 , sum(case when {column_name} is null then 1 else 0 end) as missing_values
             FROM {qualified_table_name}
             """
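
Taken together, these hunks route every literal COUNT(*) / count(*) / count(distinct(...)) in generated SQL through expr_count_all() / expr_count(...), so a data source dialect only has to override one method to change how counts are rendered. The diff calls self.expr_count but never includes its definition, so the base implementation in the sketch below is an assumption, and BigCountDataSource is a hypothetical dialect invented purely to illustrate the override point:

    class DataSource:
        # Assumed base implementation; the diff invokes expr_count but does
        # not show its body.
        def expr_count(self, expr: str) -> str:
            return f"COUNT({expr})"

        # Delegating helpers, as introduced by the @@ -1178 hunk above.
        def expr_count_all(self) -> str:
            return self.expr_count("*")

        def expr_conditional(self, condition: str, expr: str) -> str:
            return f"CASE WHEN {condition} THEN {expr} END"

        def expr_count_conditional(self, condition: str) -> str:
            return self.expr_count(self.expr_conditional(condition, "1"))


    class BigCountDataSource(DataSource):
        # Hypothetical dialect: renders counts with SQL Server's 64-bit
        # COUNT_BIG instead of plain COUNT.
        def expr_count(self, expr: str) -> str:
            return f"COUNT_BIG({expr})"


    ds = BigCountDataSource()
    print(ds.expr_count_all())                 # COUNT_BIG(*)
    print(ds.expr_count_conditional("x > 0"))  # COUNT_BIG(CASE WHEN x > 0 THEN 1 END)

With this shape, every query builder touched by the diff (duplicates, profiling aggregates, categorical group-by, reference checks) picks up a dialect's counting function automatically, which is also why the test now asserts on expr_count_all() rather than the literal string "count(*)".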