Skip to content

Commit

Permalink
support merge batch for distinct array aggregate (apache#10526)
Browse files Browse the repository at this point in the history
Signed-off-by: jayzhan211 <[email protected]>
  • Loading branch information
jayzhan211 authored May 15, 2024
1 parent eddec8e commit 626c6bc
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 6 deletions.
11 changes: 5 additions & 6 deletions datafusion/physical-expr/src/aggregate/array_agg_distinct.rs
Original file line number Diff line number Diff line change
Expand Up @@ -153,12 +153,11 @@ impl Accumulator for DistinctArrayAggAccumulator {
return Ok(());
}

let array = &states[0];

assert_eq!(array.len(), 1, "state array should only include 1 row!");
// Unwrap outer ListArray then do update batch
let inner_array = array.as_list::<i32>().value(0);
self.update_batch(&[inner_array])
states[0]
.as_list::<i32>()
.iter()
.flatten()
.try_for_each(|val| self.update_batch(&[val]))
}

fn evaluate(&mut self) -> Result<ScalarValue> {
Expand Down
67 changes: 67 additions & 0 deletions datafusion/sqllogictest/test_files/aggregate.slt
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,73 @@ statement error This feature is not implemented: LIMIT not supported in ARRAY_AG
SELECT array_agg(c13 LIMIT 1) FROM aggregate_test_100


# Test distinct aggregate function with merge batch
query II
with A as (
select 1 as id, 2 as foo
UNION ALL
select 1, null
UNION ALL
select 1, null
UNION ALL
select 1, 3
UNION ALL
select 1, 2
---- The order is non-deterministic, verify with length
) select array_length(array_agg(distinct a.foo)), sum(distinct 1) from A a group by a.id;
----
3 1

# The physical plan contains an AggregateExec in FinalPartitioned mode, so `merge_batch` is exercised.
# If the plan changes, re-verify that `merge_batch` is still used so that this test keeps its coverage.
query TT
explain with A as (
select 1 as id, 2 as foo
UNION ALL
select 1, null
UNION ALL
select 1, null
UNION ALL
select 1, 3
UNION ALL
select 1, 2
) select array_length(array_agg(distinct a.foo)), sum(distinct 1) from A a group by a.id;
----
logical_plan
01)Projection: array_length(ARRAY_AGG(DISTINCT a.foo)), SUM(DISTINCT Int64(1))
02)--Aggregate: groupBy=[[a.id]], aggr=[[ARRAY_AGG(DISTINCT a.foo), SUM(DISTINCT Int64(1))]]
03)----SubqueryAlias: a
04)------SubqueryAlias: a
05)--------Union
06)----------Projection: Int64(1) AS id, Int64(2) AS foo
07)------------EmptyRelation
08)----------Projection: Int64(1) AS id, Int64(NULL) AS foo
09)------------EmptyRelation
10)----------Projection: Int64(1) AS id, Int64(NULL) AS foo
11)------------EmptyRelation
12)----------Projection: Int64(1) AS id, Int64(3) AS foo
13)------------EmptyRelation
14)----------Projection: Int64(1) AS id, Int64(2) AS foo
15)------------EmptyRelation
physical_plan
01)ProjectionExec: expr=[array_length(ARRAY_AGG(DISTINCT a.foo)@1) as array_length(ARRAY_AGG(DISTINCT a.foo)), SUM(DISTINCT Int64(1))@2 as SUM(DISTINCT Int64(1))]
02)--AggregateExec: mode=FinalPartitioned, gby=[id@0 as id], aggr=[ARRAY_AGG(DISTINCT a.foo), SUM(DISTINCT Int64(1))]
03)----CoalesceBatchesExec: target_batch_size=8192
04)------RepartitionExec: partitioning=Hash([id@0], 4), input_partitions=5
05)--------AggregateExec: mode=Partial, gby=[id@0 as id], aggr=[ARRAY_AGG(DISTINCT a.foo), SUM(DISTINCT Int64(1))]
06)----------UnionExec
07)------------ProjectionExec: expr=[1 as id, 2 as foo]
08)--------------PlaceholderRowExec
09)------------ProjectionExec: expr=[1 as id, NULL as foo]
10)--------------PlaceholderRowExec
11)------------ProjectionExec: expr=[1 as id, NULL as foo]
12)--------------PlaceholderRowExec
13)------------ProjectionExec: expr=[1 as id, 3 as foo]
14)--------------PlaceholderRowExec
15)------------ProjectionExec: expr=[1 as id, 2 as foo]
16)--------------PlaceholderRowExec


# FIX: custom absolute values
# csv_query_avg_multi_batch

Expand Down

0 comments on commit 626c6bc

Please sign in to comment.