Always use datasource-specific COUNT expression (#2003)
m1n0 authored Jan 29, 2024
1 parent 594d026 commit 339309f
Showing 6 changed files with 21 additions and 21 deletions.
2 changes: 1 addition & 1 deletion soda/core/soda/cli/cli.py
@@ -352,7 +352,7 @@ def update_dro(

     if data_source_scan:
         if distribution_type == "categorical":
-            query = f"SELECT {column_name}, COUNT(*) FROM {dataset_name} {filter_clause} GROUP BY {column_name} ORDER BY 2 DESC"
+            query = f"SELECT {column_name}, {data_source_scan.data_source.expr_count_all()} FROM {dataset_name} {filter_clause} GROUP BY {column_name} ORDER BY 2 DESC"
         else:
             query = f"SELECT {column_name} FROM {dataset_name} {filter_clause}"
         logging.info(f"Querying column values to build distribution reference:\n{query}")
20 changes: 10 additions & 10 deletions soda/core/soda/execution/data_source.py
@@ -717,11 +717,11 @@ def sql_get_duplicates_count(
         sql = dedent(
             f"""
             WITH frequencies AS (
-                SELECT COUNT(*) AS frequency
+                SELECT {self.expr_count_all()} AS frequency
                 FROM {table_name}
                 WHERE {filter}
                 GROUP BY {column_names})
-            SELECT count(*)
+            SELECT {self.expr_count_all()}
             FROM frequencies
             WHERE frequency > 1"""
         )
@@ -741,7 +741,7 @@ def sql_get_duplicates_aggregated(
         sql = dedent(
             f"""
             WITH frequencies AS (
-                SELECT {column_names}, COUNT(*) AS frequency
+                SELECT {column_names}, {self.expr_count_all()} AS frequency
                 FROM {table_name}
                 WHERE {filter}
                 GROUP BY {column_names})
@@ -778,7 +778,7 @@ def sql_get_duplicates(
                 FROM {table_name}
                 WHERE {filter}
                 GROUP BY {column_names}
-                HAVING count(*) {'<=' if invert_condition else '>'} 1)
+                HAVING {self.expr_count_all()} {'<=' if invert_condition else '>'} 1)
             SELECT {main_query_columns}
             FROM {table_name} main
             JOIN frequencies ON {join}
@@ -894,7 +894,7 @@ def profiling_sql_value_frequencies_cte(self, table_name: str, column_name: str)
         quoted_column_name = self.quote_column(column_name)
         qualified_table_name = self.qualified_table_name(table_name)
         return f"""value_frequencies AS (
-            SELECT {quoted_column_name} AS value_, count(*) AS frequency_
+            SELECT {quoted_column_name} AS value_, {self.expr_count_all()} AS frequency_
             FROM {qualified_table_name}
             WHERE {quoted_column_name} IS NOT NULL
             GROUP BY {quoted_column_name}
@@ -910,7 +910,7 @@ def profiling_sql_aggregates_numeric(self, table_name: str, column_name: str) ->
                 , sum({column_name}) as sum
                 , var_samp({column_name}) as variance
                 , stddev_samp({column_name}) as standard_deviation
-                , count(distinct({column_name})) as distinct_values
+                , {self.expr_count(f'distinct({column_name})')} as distinct_values
                 , sum(case when {column_name} is null then 1 else 0 end) as missing_values
             FROM {qualified_table_name}
             """
@@ -922,7 +922,7 @@ def profiling_sql_aggregates_text(self, table_name: str, column_name: str) -> st
         return dedent(
             f"""
             SELECT
-                count(distinct({column_name})) as distinct_values
+                {self.expr_count(f'distinct({column_name})')} as distinct_values
                 , sum(case when {column_name} is null then 1 else 0 end) as missing_values
                 , avg(length({column_name})) as avg_length
                 , min(length({column_name})) as min_length
@@ -1178,10 +1178,10 @@ def literal_boolean(self, boolean: bool):
         return "TRUE" if boolean is True else "FALSE"

     def expr_count_all(self) -> str:
-        return "COUNT(*)"
+        return self.expr_count("*")

     def expr_count_conditional(self, condition: str):
-        return f"COUNT(CASE WHEN {condition} THEN 1 END)"
+        return self.expr_count(self.expr_conditional(condition, "1"))

     def expr_conditional(self, condition: str, expr: str):
         return f"CASE WHEN {condition} THEN {expr} END"
@@ -1380,7 +1380,7 @@ def sql_groupby_count_categorical_column(
             )
             SELECT
                 {column_name}
-                , COUNT(*) AS frequency
+                , {self.expr_count_all()} AS frequency
             FROM processed_table
             GROUP BY {column_name}
             """
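The thrust of the data_source.py changes above: every COUNT in generated SQL now flows through the data source's own expr_count helper, so a dialect that needs a different expression overrides one method instead of patching each query template. A minimal sketch of the pattern, assuming a base expr_count roughly like the one below (only expr_count_all, expr_count_conditional, and expr_conditional appear in this diff; the base expr_count body and the COUNT_BIG subclass are illustrative assumptions, not part of the commit):

    class BaseDataSource:
        def expr_count(self, expr: str) -> str:
            # Assumed default: plain ANSI COUNT.
            return f"COUNT({expr})"

        def expr_count_all(self) -> str:
            # After this commit: delegates instead of hard-coding "COUNT(*)".
            return self.expr_count("*")

        def expr_conditional(self, condition: str, expr: str) -> str:
            return f"CASE WHEN {condition} THEN {expr} END"

        def expr_count_conditional(self, condition: str) -> str:
            # Conditional count built from the same two primitives.
            return self.expr_count(self.expr_conditional(condition, "1"))

    class HypotheticalBigCountDataSource(BaseDataSource):
        # Hypothetical dialect override, e.g. T-SQL's COUNT_BIG for 64-bit counts.
        def expr_count(self, expr: str) -> str:
            return f"COUNT_BIG({expr})"

    ds = HypotheticalBigCountDataSource()
    print(ds.expr_count_all())                 # COUNT_BIG(*)
    print(ds.expr_count_conditional("x > 0"))  # COUNT_BIG(CASE WHEN x > 0 THEN 1 END)

With a single override in place, every call site touched in this commit (duplicate checks, profiling aggregates, frequency CTEs, reference queries) picks up the dialect's COUNT expression automatically.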
2 changes: 1 addition & 1 deletion soda/core/soda/execution/query/reference_query.py
@@ -84,7 +84,7 @@ def __init__(

         self.sql = jinja_resolve(
             data_source.sql_reference_query(
-                "count(*)", source_table_name, target_table_name, join_condition, where_condition
+                data_source.expr_count_all(), source_table_name, target_table_name, join_condition, where_condition
             )
         )

2 changes: 1 addition & 1 deletion soda/core/tests/data_source/test_duplicates.py
@@ -17,7 +17,7 @@ def test_duplicates_single_column(data_source_fixture: DataSourceFixture):
     scan.assert_all_checks_pass()

     # This is a simple use case, verify that * is used in the main query.
-    scan.assert_log("count(*)")
+    scan.assert_log(data_source_fixture.data_source.expr_count_all())


 def test_duplicates_multiple_columns(data_source_fixture: DataSourceFixture):
10 changes: 5 additions & 5 deletions soda/sqlserver/soda/data_sources/sqlserver_data_source.py
@@ -177,7 +177,7 @@ def profiling_sql_aggregates_numeric(self, table_name: str, column_name: str) ->
                 , sum({column_name}) as sum
                 , var({column_name}) as variance
                 , stdev({column_name}) as standard_deviation
-                , count(distinct({column_name})) as distinct_values
+                , {self.expr_count(f'distinct({column_name})')} as distinct_values
                 , sum(case when {column_name} is null then 1 else 0 end) as missing_values
             FROM {qualified_table_name}
             """
@@ -260,7 +260,7 @@ def profiling_sql_aggregates_text(self, table_name: str, column_name: str) -> st
         return dedent(
             f"""
             SELECT
-                count(distinct({column_name})) as distinct_values
+                {self.expr_count(f'distinct({column_name})')} as distinct_values
                 , sum(case when {column_name} is null then 1 else 0 end) as missing_values
                 , avg(len({column_name})) as avg_length
                 , min(len({column_name})) as min_length
@@ -328,7 +328,7 @@ def sql_groupby_count_categorical_column(
             )
             SELECT {top_limit}
                 {column_name}
-                , COUNT(*) AS frequency
+                , {self.expr_count_all()} AS frequency
             FROM processed_table
             GROUP BY {column_name}
             """
@@ -356,7 +356,7 @@ def sql_get_duplicates_aggregated(
         sql = dedent(
             f"""
             WITH frequencies AS (
-                SELECT {column_names}, COUNT(*) AS frequency
+                SELECT {column_names}, {self.expr_count_all()} AS frequency
                 FROM {table_name}
                 WHERE {filter}
                 GROUP BY {column_names})
@@ -394,7 +394,7 @@ def sql_get_duplicates(
                 FROM {table_name}
                 WHERE {filter}
                 GROUP BY {column_names}
-                HAVING count(*) {'<=' if invert_condition else '>'} 1)
+                HAVING {self.expr_count_all()} {'<=' if invert_condition else '>'} 1)
             SELECT {limit_sql} {main_query_columns}
             FROM {table_name} main
             JOIN frequencies ON {join}
6 changes: 3 additions & 3 deletions soda/teradata/soda/data_sources/teradata_data_source.py
@@ -260,7 +260,7 @@ def sql_get_duplicates_aggregated(
         sql = dedent(
             f"""
             WITH frequencies AS (
-                SELECT {column_names}, COUNT(*) AS frequency
+                SELECT {column_names}, {self.expr_count_all()} AS frequency
                 FROM {table_name}
                 WHERE {filter}
                 GROUP BY {column_names})
@@ -297,7 +297,7 @@ def sql_get_duplicates(
                 FROM {table_name}
                 WHERE {filter}
                 GROUP BY {column_names}
-                HAVING count(*) {'<=' if invert_condition else '>'} 1)
+                HAVING {self.expr_count_all()} {'<=' if invert_condition else '>'} 1)
             SELECT {limit_sql} {main_query_columns}
             FROM {table_name} main
             JOIN frequencies ON {join}
@@ -528,7 +528,7 @@ def profiling_sql_aggregates_numeric(self, table_name: str, column_name: str) ->
                 , sum({column_name}) as "sum"
                 , var_samp({column_name}) as variance
                 , stddev_samp({column_name}) as standard_deviation
-                , count(distinct({column_name})) as distinct_values
+                , {self.expr_count(f'distinct({column_name})')} as distinct_values
                 , sum(case when {column_name} is null then 1 else 0 end) as missing_values
             FROM {qualified_table_name}
             """
