Skip to content

Commit

Permalink
Infer count() aggregation is not null
Browse files Browse the repository at this point in the history
`count([DISTINCT [expr]])` aggregate function never returns null. Infer
non-nullness of such aggregate expression. This allows elimination of
the HAVING filter for a query such as

    SELECT ... count(*) AS c
    FROM ...
    GROUP BY ...
    HAVING c IS NOT NULL
  • Loading branch information
findepi committed Jul 4, 2024
1 parent 12b1db6 commit 24f4ff1
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 27 deletions.
8 changes: 7 additions & 1 deletion datafusion/expr/src/expr_schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -322,10 +322,16 @@ impl ExprSchemable for Expr {
}
}
Expr::Cast(Cast { expr, .. }) => expr.nullable(input_schema),
Expr::AggregateFunction(AggregateFunction { func_def, .. }) => {
Expr::AggregateFunction(AggregateFunction { func_def, args, .. }) => {
match func_def {
// TODO: min/max over non-nullable input should be recognized as non-nullable
AggregateFunctionDefinition::BuiltIn(fun) => fun.nullable(),
// TODO: UDF should be able to customize nullability
AggregateFunctionDefinition::UDF(udf)
if udf.name() == "count" && args.len() <= 1 =>
{
Ok(false)
}
AggregateFunctionDefinition::UDF(_) => Ok(true),
}
}
Expand Down
52 changes: 26 additions & 26 deletions datafusion/optimizer/src/single_distinct_to_groupby.rs
Original file line number Diff line number Diff line change
Expand Up @@ -396,8 +396,8 @@ mod tests {
.build()?;

// Should work
let expected = "Projection: count(alias1) AS count(DISTINCT test.b) [count(DISTINCT test.b):Int64;N]\
\n Aggregate: groupBy=[[]], aggr=[[count(alias1)]] [count(alias1):Int64;N]\
let expected = "Projection: count(alias1) AS count(DISTINCT test.b) [count(DISTINCT test.b):Int64]\
\n Aggregate: groupBy=[[]], aggr=[[count(alias1)]] [count(alias1):Int64]\
\n Aggregate: groupBy=[[test.b AS alias1]], aggr=[[]] [alias1:UInt32]\
\n TableScan: test [a:UInt32, b:UInt32, c:UInt32]";

Expand All @@ -419,7 +419,7 @@ mod tests {
.build()?;

// Should not be optimized
let expected = "Aggregate: groupBy=[[GROUPING SETS ((test.a), (test.b))]], aggr=[[count(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, count(DISTINCT test.c):Int64;N]\
let expected = "Aggregate: groupBy=[[GROUPING SETS ((test.a), (test.b))]], aggr=[[count(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, count(DISTINCT test.c):Int64]\
\n TableScan: test [a:UInt32, b:UInt32, c:UInt32]";

assert_optimized_plan_equal(plan, expected)
Expand All @@ -437,7 +437,7 @@ mod tests {
.build()?;

// Should not be optimized
let expected = "Aggregate: groupBy=[[CUBE (test.a, test.b)]], aggr=[[count(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, count(DISTINCT test.c):Int64;N]\
let expected = "Aggregate: groupBy=[[CUBE (test.a, test.b)]], aggr=[[count(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, count(DISTINCT test.c):Int64]\
\n TableScan: test [a:UInt32, b:UInt32, c:UInt32]";

assert_optimized_plan_equal(plan, expected)
Expand All @@ -456,7 +456,7 @@ mod tests {
.build()?;

// Should not be optimized
let expected = "Aggregate: groupBy=[[ROLLUP (test.a, test.b)]], aggr=[[count(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, count(DISTINCT test.c):Int64;N]\
let expected = "Aggregate: groupBy=[[ROLLUP (test.a, test.b)]], aggr=[[count(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, count(DISTINCT test.c):Int64]\
\n TableScan: test [a:UInt32, b:UInt32, c:UInt32]";

assert_optimized_plan_equal(plan, expected)
Expand All @@ -470,8 +470,8 @@ mod tests {
.aggregate(Vec::<Expr>::new(), vec![count_distinct(lit(2) * col("b"))])?
.build()?;

let expected = "Projection: count(alias1) AS count(DISTINCT Int32(2) * test.b) [count(DISTINCT Int32(2) * test.b):Int64;N]\
\n Aggregate: groupBy=[[]], aggr=[[count(alias1)]] [count(alias1):Int64;N]\
let expected = "Projection: count(alias1) AS count(DISTINCT Int32(2) * test.b) [count(DISTINCT Int32(2) * test.b):Int64]\
\n Aggregate: groupBy=[[]], aggr=[[count(alias1)]] [count(alias1):Int64]\
\n Aggregate: groupBy=[[Int32(2) * test.b AS alias1]], aggr=[[]] [alias1:Int32]\
\n TableScan: test [a:UInt32, b:UInt32, c:UInt32]";

Expand All @@ -487,8 +487,8 @@ mod tests {
.build()?;

// Should work
let expected = "Projection: test.a, count(alias1) AS count(DISTINCT test.b) [a:UInt32, count(DISTINCT test.b):Int64;N]\
\n Aggregate: groupBy=[[test.a]], aggr=[[count(alias1)]] [a:UInt32, count(alias1):Int64;N]\
let expected = "Projection: test.a, count(alias1) AS count(DISTINCT test.b) [a:UInt32, count(DISTINCT test.b):Int64]\
\n Aggregate: groupBy=[[test.a]], aggr=[[count(alias1)]] [a:UInt32, count(alias1):Int64]\
\n Aggregate: groupBy=[[test.a, test.b AS alias1]], aggr=[[]] [a:UInt32, alias1:UInt32]\
\n TableScan: test [a:UInt32, b:UInt32, c:UInt32]";

Expand All @@ -507,7 +507,7 @@ mod tests {
.build()?;

// Do nothing
let expected = "Aggregate: groupBy=[[test.a]], aggr=[[count(DISTINCT test.b), count(DISTINCT test.c)]] [a:UInt32, count(DISTINCT test.b):Int64;N, count(DISTINCT test.c):Int64;N]\
let expected = "Aggregate: groupBy=[[test.a]], aggr=[[count(DISTINCT test.b), count(DISTINCT test.c)]] [a:UInt32, count(DISTINCT test.b):Int64, count(DISTINCT test.c):Int64]\
\n TableScan: test [a:UInt32, b:UInt32, c:UInt32]";

assert_optimized_plan_equal(plan, expected)
Expand All @@ -534,8 +534,8 @@ mod tests {
)?
.build()?;
// Should work
let expected = "Projection: test.a, count(alias1) AS count(DISTINCT test.b), MAX(alias1) AS MAX(DISTINCT test.b) [a:UInt32, count(DISTINCT test.b):Int64;N, MAX(DISTINCT test.b):UInt32;N]\
\n Aggregate: groupBy=[[test.a]], aggr=[[count(alias1), MAX(alias1)]] [a:UInt32, count(alias1):Int64;N, MAX(alias1):UInt32;N]\
let expected = "Projection: test.a, count(alias1) AS count(DISTINCT test.b), MAX(alias1) AS MAX(DISTINCT test.b) [a:UInt32, count(DISTINCT test.b):Int64, MAX(DISTINCT test.b):UInt32;N]\
\n Aggregate: groupBy=[[test.a]], aggr=[[count(alias1), MAX(alias1)]] [a:UInt32, count(alias1):Int64, MAX(alias1):UInt32;N]\
\n Aggregate: groupBy=[[test.a, test.b AS alias1]], aggr=[[]] [a:UInt32, alias1:UInt32]\
\n TableScan: test [a:UInt32, b:UInt32, c:UInt32]";

Expand All @@ -554,7 +554,7 @@ mod tests {
.build()?;

// Do nothing
let expected = "Aggregate: groupBy=[[test.a]], aggr=[[count(DISTINCT test.b), count(test.c)]] [a:UInt32, count(DISTINCT test.b):Int64;N, count(test.c):Int64;N]\
let expected = "Aggregate: groupBy=[[test.a]], aggr=[[count(DISTINCT test.b), count(test.c)]] [a:UInt32, count(DISTINCT test.b):Int64, count(test.c):Int64]\
\n TableScan: test [a:UInt32, b:UInt32, c:UInt32]";

assert_optimized_plan_equal(plan, expected)
Expand All @@ -569,8 +569,8 @@ mod tests {
.build()?;

// Should work
let expected = "Projection: group_alias_0 AS test.a + Int32(1), count(alias1) AS count(DISTINCT test.c) [test.a + Int32(1):Int32, count(DISTINCT test.c):Int64;N]\
\n Aggregate: groupBy=[[group_alias_0]], aggr=[[count(alias1)]] [group_alias_0:Int32, count(alias1):Int64;N]\
let expected = "Projection: group_alias_0 AS test.a + Int32(1), count(alias1) AS count(DISTINCT test.c) [test.a + Int32(1):Int32, count(DISTINCT test.c):Int64]\
\n Aggregate: groupBy=[[group_alias_0]], aggr=[[count(alias1)]] [group_alias_0:Int32, count(alias1):Int64]\
\n Aggregate: groupBy=[[test.a + Int32(1) AS group_alias_0, test.c AS alias1]], aggr=[[]] [group_alias_0:Int32, alias1:UInt32]\
\n TableScan: test [a:UInt32, b:UInt32, c:UInt32]";

Expand Down Expand Up @@ -599,8 +599,8 @@ mod tests {
)?
.build()?;
// Should work
let expected = "Projection: test.a, sum(alias2) AS sum(test.c), count(alias1) AS count(DISTINCT test.b), MAX(alias1) AS MAX(DISTINCT test.b) [a:UInt32, sum(test.c):UInt64;N, count(DISTINCT test.b):Int64;N, MAX(DISTINCT test.b):UInt32;N]\
\n Aggregate: groupBy=[[test.a]], aggr=[[sum(alias2), count(alias1), MAX(alias1)]] [a:UInt32, sum(alias2):UInt64;N, count(alias1):Int64;N, MAX(alias1):UInt32;N]\
let expected = "Projection: test.a, sum(alias2) AS sum(test.c), count(alias1) AS count(DISTINCT test.b), MAX(alias1) AS MAX(DISTINCT test.b) [a:UInt32, sum(test.c):UInt64;N, count(DISTINCT test.b):Int64, MAX(DISTINCT test.b):UInt32;N]\
\n Aggregate: groupBy=[[test.a]], aggr=[[sum(alias2), count(alias1), MAX(alias1)]] [a:UInt32, sum(alias2):UInt64;N, count(alias1):Int64, MAX(alias1):UInt32;N]\
\n Aggregate: groupBy=[[test.a, test.b AS alias1]], aggr=[[sum(test.c) AS alias2]] [a:UInt32, alias1:UInt32, alias2:UInt64;N]\
\n TableScan: test [a:UInt32, b:UInt32, c:UInt32]";

Expand All @@ -618,8 +618,8 @@ mod tests {
)?
.build()?;
// Should work
let expected = "Projection: test.a, sum(alias2) AS sum(test.c), MAX(alias3) AS MAX(test.c), count(alias1) AS count(DISTINCT test.b) [a:UInt32, sum(test.c):UInt64;N, MAX(test.c):UInt32;N, count(DISTINCT test.b):Int64;N]\
\n Aggregate: groupBy=[[test.a]], aggr=[[sum(alias2), MAX(alias3), count(alias1)]] [a:UInt32, sum(alias2):UInt64;N, MAX(alias3):UInt32;N, count(alias1):Int64;N]\
let expected = "Projection: test.a, sum(alias2) AS sum(test.c), MAX(alias3) AS MAX(test.c), count(alias1) AS count(DISTINCT test.b) [a:UInt32, sum(test.c):UInt64;N, MAX(test.c):UInt32;N, count(DISTINCT test.b):Int64]\
\n Aggregate: groupBy=[[test.a]], aggr=[[sum(alias2), MAX(alias3), count(alias1)]] [a:UInt32, sum(alias2):UInt64;N, MAX(alias3):UInt32;N, count(alias1):Int64]\
\n Aggregate: groupBy=[[test.a, test.b AS alias1]], aggr=[[sum(test.c) AS alias2, MAX(test.c) AS alias3]] [a:UInt32, alias1:UInt32, alias2:UInt64;N, alias3:UInt32;N]\
\n TableScan: test [a:UInt32, b:UInt32, c:UInt32]";

Expand All @@ -637,8 +637,8 @@ mod tests {
)?
.build()?;
// Should work
let expected = "Projection: test.c, MIN(alias2) AS MIN(test.a), count(alias1) AS count(DISTINCT test.b) [c:UInt32, MIN(test.a):UInt32;N, count(DISTINCT test.b):Int64;N]\
\n Aggregate: groupBy=[[test.c]], aggr=[[MIN(alias2), count(alias1)]] [c:UInt32, MIN(alias2):UInt32;N, count(alias1):Int64;N]\
let expected = "Projection: test.c, MIN(alias2) AS MIN(test.a), count(alias1) AS count(DISTINCT test.b) [c:UInt32, MIN(test.a):UInt32;N, count(DISTINCT test.b):Int64]\
\n Aggregate: groupBy=[[test.c]], aggr=[[MIN(alias2), count(alias1)]] [c:UInt32, MIN(alias2):UInt32;N, count(alias1):Int64]\
\n Aggregate: groupBy=[[test.c, test.b AS alias1]], aggr=[[MIN(test.a) AS alias2]] [c:UInt32, alias1:UInt32, alias2:UInt32;N]\
\n TableScan: test [a:UInt32, b:UInt32, c:UInt32]";

Expand All @@ -662,7 +662,7 @@ mod tests {
.aggregate(vec![col("c")], vec![expr, count_distinct(col("b"))])?
.build()?;
// Do nothing
let expected = "Aggregate: groupBy=[[test.c]], aggr=[[sum(test.a) FILTER (WHERE test.a > Int32(5)), count(DISTINCT test.b)]] [c:UInt32, sum(test.a) FILTER (WHERE test.a > Int32(5)):UInt64;N, count(DISTINCT test.b):Int64;N]\
let expected = "Aggregate: groupBy=[[test.c]], aggr=[[sum(test.a) FILTER (WHERE test.a > Int32(5)), count(DISTINCT test.b)]] [c:UInt32, sum(test.a) FILTER (WHERE test.a > Int32(5)):UInt64;N, count(DISTINCT test.b):Int64]\
\n TableScan: test [a:UInt32, b:UInt32, c:UInt32]";

assert_optimized_plan_equal(plan, expected)
Expand All @@ -682,7 +682,7 @@ mod tests {
.aggregate(vec![col("c")], vec![sum(col("a")), expr])?
.build()?;
// Do nothing
let expected = "Aggregate: groupBy=[[test.c]], aggr=[[sum(test.a), count(DISTINCT test.a) FILTER (WHERE test.a > Int32(5))]] [c:UInt32, sum(test.a):UInt64;N, count(DISTINCT test.a) FILTER (WHERE test.a > Int32(5)):Int64;N]\
let expected = "Aggregate: groupBy=[[test.c]], aggr=[[sum(test.a), count(DISTINCT test.a) FILTER (WHERE test.a > Int32(5))]] [c:UInt32, sum(test.a):UInt64;N, count(DISTINCT test.a) FILTER (WHERE test.a > Int32(5)):Int64]\
\n TableScan: test [a:UInt32, b:UInt32, c:UInt32]";

assert_optimized_plan_equal(plan, expected)
Expand All @@ -705,7 +705,7 @@ mod tests {
.aggregate(vec![col("c")], vec![expr, count_distinct(col("b"))])?
.build()?;
// Do nothing
let expected = "Aggregate: groupBy=[[test.c]], aggr=[[sum(test.a) ORDER BY [test.a], count(DISTINCT test.b)]] [c:UInt32, sum(test.a) ORDER BY [test.a]:UInt64;N, count(DISTINCT test.b):Int64;N]\
let expected = "Aggregate: groupBy=[[test.c]], aggr=[[sum(test.a) ORDER BY [test.a], count(DISTINCT test.b)]] [c:UInt32, sum(test.a) ORDER BY [test.a]:UInt64;N, count(DISTINCT test.b):Int64]\
\n TableScan: test [a:UInt32, b:UInt32, c:UInt32]";

assert_optimized_plan_equal(plan, expected)
Expand All @@ -725,7 +725,7 @@ mod tests {
.aggregate(vec![col("c")], vec![sum(col("a")), expr])?
.build()?;
// Do nothing
let expected = "Aggregate: groupBy=[[test.c]], aggr=[[sum(test.a), count(DISTINCT test.a) ORDER BY [test.a ASC NULLS LAST]]] [c:UInt32, sum(test.a):UInt64;N, count(DISTINCT test.a) ORDER BY [test.a ASC NULLS LAST]:Int64;N]\
let expected = "Aggregate: groupBy=[[test.c]], aggr=[[sum(test.a), count(DISTINCT test.a) ORDER BY [test.a ASC NULLS LAST]]] [c:UInt32, sum(test.a):UInt64;N, count(DISTINCT test.a) ORDER BY [test.a ASC NULLS LAST]:Int64]\
\n TableScan: test [a:UInt32, b:UInt32, c:UInt32]";

assert_optimized_plan_equal(plan, expected)
Expand All @@ -746,7 +746,7 @@ mod tests {
.aggregate(vec![col("c")], vec![sum(col("a")), expr])?
.build()?;
// Do nothing
let expected = "Aggregate: groupBy=[[test.c]], aggr=[[sum(test.a), count(DISTINCT test.a) FILTER (WHERE test.a > Int32(5)) ORDER BY [test.a ASC NULLS LAST]]] [c:UInt32, sum(test.a):UInt64;N, count(DISTINCT test.a) FILTER (WHERE test.a > Int32(5)) ORDER BY [test.a ASC NULLS LAST]:Int64;N]\
let expected = "Aggregate: groupBy=[[test.c]], aggr=[[sum(test.a), count(DISTINCT test.a) FILTER (WHERE test.a > Int32(5)) ORDER BY [test.a ASC NULLS LAST]]] [c:UInt32, sum(test.a):UInt64;N, count(DISTINCT test.a) FILTER (WHERE test.a > Int32(5)) ORDER BY [test.a ASC NULLS LAST]:Int64]\
\n TableScan: test [a:UInt32, b:UInt32, c:UInt32]";

assert_optimized_plan_equal(plan, expected)
Expand Down
15 changes: 15 additions & 0 deletions datafusion/optimizer/tests/optimizer_integration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,21 @@ fn eliminate_nested_filters() {
assert_eq!(expected, format!("{plan:?}"));
}

#[test]
fn eliminate_redundant_null_check_on_count() {
let sql = "\
SELECT col_int32, count(*) c
FROM test
GROUP BY col_int32
HAVING c IS NOT NULL";
let plan = test_sql(sql).unwrap();
let expected = "\
Projection: test.col_int32, count(*) AS c\
\n Aggregate: groupBy=[[test.col_int32]], aggr=[[count(Int64(1)) AS count(*)]]\
\n TableScan: test projection=[col_int32]";
assert_eq!(expected, format!("{plan:?}"));
}

#[test]
fn test_propagate_empty_relation_inner_join_and_unions() {
let sql = "\
Expand Down

0 comments on commit 24f4ff1

Please sign in to comment.