diff --git a/datafusion/physical-expr/src/aggregate/count_distinct.rs b/datafusion/physical-expr/src/aggregate/count_distinct.rs
index ae940a4713453..8db27794c1fdb 100644
--- a/datafusion/physical-expr/src/aggregate/count_distinct.rs
+++ b/datafusion/physical-expr/src/aggregate/count_distinct.rs
@@ -439,11 +439,13 @@ where
 #[cfg(test)]
 mod tests {
     use crate::expressions::NoOp;
+    use crate::EmitTo;
 
     use super::*;
     use arrow::array::{
         ArrayRef, BooleanArray, Float32Array, Float64Array, Int16Array, Int32Array,
-        Int64Array, Int8Array, UInt16Array, UInt32Array, UInt64Array, UInt8Array,
+        Int64Array, Int8Array, PrimitiveArray, UInt16Array, UInt32Array, UInt64Array,
+        UInt8Array,
     };
     use arrow::datatypes::DataType;
     use arrow::datatypes::{
@@ -806,4 +808,91 @@ mod tests {
         assert_eq!(result, ScalarValue::Int64(Some(2)));
         Ok(())
     }
+
+    // Exercises `DistinctCountGroupsAccumulator::evaluate` for one primitive type.
+    //
+    // The two-argument arm expands to the three-argument arm twice: once with
+    // `EmitTo::All` and once with `EmitTo::First(3)`.
+    macro_rules! test_count_distinct_groups_evaluate {
+        ($DATA_TYPE:ident, $PRIM_TYPE:ty) => {
+            test_count_distinct_groups_evaluate!($DATA_TYPE, $PRIM_TYPE, EmitTo::All);
+            test_count_distinct_groups_evaluate!($DATA_TYPE, $PRIM_TYPE, EmitTo::First(3));
+        };
+
+        ($DATA_TYPE:ident, $PRIM_TYPE:ty, $EMIT_TO:expr) => {
+            // Nine rows spread over five groups (0..=4); nulls must not be counted.
+            let group_indices = vec![0, 1, 1, 2, 0, 3, 3, 3, 4];
+            let input_values: Vec<Option<$PRIM_TYPE>> = vec![
+                Some(1),
+                Some(1),
+                Some(1),
+                None,
+                Some(2),
+                Some(4),
+                Some(5),
+                None,
+                Some(7)
+            ];
+            let values = Arc::new(PrimitiveArray::<$DATA_TYPE>::from(input_values));
+
+            let mut accumulator = DistinctCountGroupsAccumulator::<$DATA_TYPE>::new();
+            // NOTE(review): `None` and `5` are presumably the filter and the total number of groups.
+            accumulator.update_batch(&[values], &group_indices, None, 5).unwrap();
+
+            // Distinct non-null values per group: {1, 2}, {1}, {}, {4, 5}, {7}.
+            // `take_needed` presumably trims the expectation to the groups `$EMIT_TO` emits.
+            let mut expected_values = vec![2, 1, 0, 2, 1];
+            let expected = Int64Array::from($EMIT_TO.take_needed(&mut expected_values));
+
+            let evaluated = accumulator.evaluate($EMIT_TO).unwrap();
+            let actual = evaluated.as_primitive::<Int64Type>();
+            assert_eq!(
+                expected,
+                *actual,
+                "DistinctCountGroups::evaluate() test failed for data type {} and emit_to {:?}",
+                stringify!($DATA_TYPE),
+                $EMIT_TO
+            );
+        };
+    }
+
+    #[test]
+    fn count_distinct_groups_evaluate_i8() {
+        test_count_distinct_groups_evaluate!(Int8Type, i8);
+    }
+
+    #[test]
+    fn count_distinct_groups_evaluate_i16() {
+        test_count_distinct_groups_evaluate!(Int16Type, i16);
+    }
+
+    #[test]
+    fn count_distinct_groups_evaluate_i32() {
+        test_count_distinct_groups_evaluate!(Int32Type, i32);
+    }
+
+    #[test]
+    fn count_distinct_groups_evaluate_i64() {
+        test_count_distinct_groups_evaluate!(Int64Type, i64);
+    }
+
+    #[test]
+    fn count_distinct_groups_evaluate_u8() {
+        test_count_distinct_groups_evaluate!(UInt8Type, u8);
+    }
+
+    #[test]
+    fn count_distinct_groups_evaluate_u16() {
+        test_count_distinct_groups_evaluate!(UInt16Type, u16);
+    }
+
+    #[test]
+    fn count_distinct_groups_evaluate_u32() {
+        test_count_distinct_groups_evaluate!(UInt32Type, u32);
+    }
+
+    #[test]
+    fn count_distinct_groups_evaluate_u64() {
+        test_count_distinct_groups_evaluate!(UInt64Type, u64);
+    }
 }