diff --git a/datafusion/physical-expr/src/aggregate/average.rs b/datafusion/physical-expr/src/aggregate/average.rs index e95e9fcf877a..843401de2368 100644 --- a/datafusion/physical-expr/src/aggregate/average.rs +++ b/datafusion/physical-expr/src/aggregate/average.rs @@ -565,30 +565,33 @@ where let sums = std::mem::take(&mut self.sums); let nulls = self.null_state.build(); - assert_eq!(nulls.len(), sums.len()); assert_eq!(counts.len(), sums.len()); // don't evaluate averages with null inputs to avoid errors on null values - let array: PrimitiveArray = if nulls.null_count() > 0 { - let mut builder = PrimitiveBuilder::::with_capacity(nulls.len()); - let iter = sums.into_iter().zip(counts.into_iter()).zip(nulls.iter()); - - for ((sum, count), is_valid) in iter { - if is_valid { - builder.append_value((self.avg_fn)(sum, count)?) - } else { - builder.append_null(); + let array: PrimitiveArray = match nulls { + Some(nulls) if nulls.null_count() > 0 => { + assert_eq!(nulls.len(), sums.len()); + let mut builder = PrimitiveBuilder::::with_capacity(nulls.len()); + let iter = sums.into_iter().zip(counts.into_iter()).zip(nulls.iter()); + + for ((sum, count), is_valid) in iter { + if is_valid { + builder.append_value((self.avg_fn)(sum, count)?) + } else { + builder.append_null(); + } } + builder.finish() + } + _ => { + let averages: Vec = sums + .into_iter() + .zip(counts.into_iter()) + .map(|(sum, count)| (self.avg_fn)(sum, count)) + .collect::>>()?; + PrimitiveArray::new(averages.into(), nulls) // no copy } - builder.finish() - } else { - let averages: Vec = sums - .into_iter() - .zip(counts.into_iter()) - .map(|(sum, count)| (self.avg_fn)(sum, count)) - .collect::>>()?; - PrimitiveArray::new(averages.into(), Some(nulls)) // no copy }; // fix up decimal precision and scale for decimals @@ -599,7 +602,7 @@ where // return arrays for sums and counts fn state(&mut self) -> Result> { - let nulls = Some(self.null_state.build()); + let nulls = self.null_state.build(); let counts = std::mem::take(&mut self.counts); let counts = UInt64Array::new(counts.into(), nulls.clone()); // zero copy diff --git a/datafusion/physical-expr/src/aggregate/groups_accumulator/accumulate.rs b/datafusion/physical-expr/src/aggregate/groups_accumulator/accumulate.rs index bcc9d30bedd8..711fe5f777ba 100644 --- a/datafusion/physical-expr/src/aggregate/groups_accumulator/accumulate.rs +++ b/datafusion/physical-expr/src/aggregate/groups_accumulator/accumulate.rs @@ -57,20 +57,21 @@ pub struct NullState { /// /// If `seen_values[i]` is false, have not seen any values that /// pass the filter yet for group `i` - seen_values: BooleanBufferBuilder, + seen_values: Option, } impl NullState { pub fn new() -> Self { - Self { - seen_values: BooleanBufferBuilder::new(0), - } + Self { seen_values: None } } /// return the size of all buffers allocated by this null state, not including self pub fn size(&self) -> usize { // capacity is in bits, so convert to bytes - self.seen_values.capacity() / 8 + self.seen_values + .as_ref() + .map(|seen_values| seen_values.capacity() / 8) + .unwrap_or(0) } /// Invokes `value_fn(group_index, value)` for each non null, non @@ -130,8 +131,27 @@ impl NullState { let data: &[T::Native] = values.values(); assert_eq!(data.len(), group_indices.len()); - // ensure the seen_values is big enough (start everything at - // "not seen" valid) + let input_has_nulls = values.null_count() > 0; + + // avoid tracking null state if possible + if !input_has_nulls && + self.seen_values.is_none()&& // we have seen values for all previous groups + opt_filter.is_none() + // we will be looking at all values + { + // since we know all groups have at least one non + // value, there is no need to track seen_values + // individually + let iter = group_indices.iter().zip(data.iter()); + + for (&group_index, &new_value) in iter { + value_fn(group_index, new_value); + } + return; + } + + // have been tracking seen values previously, so we still need + // to track them here let seen_values = initialize_builder(&mut self.seen_values, total_num_groups, false); @@ -321,9 +341,20 @@ impl NullState { /// group_indices should have null values (because they never saw /// any values) /// + /// If all groups are valid, returns None (no NullBuffer) + /// /// resets the internal state to empty - pub fn build(&mut self) -> NullBuffer { - NullBuffer::new(self.seen_values.finish()) + pub fn build(&mut self) -> Option { + let Some(seen_values) = self.seen_values.as_mut() else { + return None; + }; + + let nulls = NullBuffer::new(seen_values.finish()); + if nulls.null_count() > 0 { + Some(nulls) + } else { + None + } } } @@ -423,10 +454,15 @@ pub fn accumulate_indices( /// /// All new entries are initialized to `default_value` fn initialize_builder( - builder: &mut BooleanBufferBuilder, + builder: &mut Option, total_num_groups: usize, default_value: bool, ) -> &mut BooleanBufferBuilder { + if builder.is_none() { + *builder = Some(BooleanBufferBuilder::new(total_num_groups)); + } + let builder = builder.as_mut().unwrap(); + if builder.len() < total_num_groups { let new_groups = total_num_groups - builder.len(); builder.append_n(new_groups, default_value); @@ -683,11 +719,21 @@ mod test { assert_eq!(accumulated_values, expected_values, "\n\naccumulated_values:{accumulated_values:#?}\n\nexpected_values:{expected_values:#?}"); - let seen_values = null_state.seen_values.finish_cloned(); - mock.validate_seen_values(&seen_values); + + if let Some(seen_values) = null_state.seen_values.as_ref() { + let seen_values = seen_values.finish_cloned(); + mock.validate_seen_values(&seen_values); + } // Validate the final buffer (one value per group) - let expected_null_buffer = mock.expected_null_buffer(total_num_groups); + let expected_null_buffer = if values.null_count() > 0 || opt_filter.is_some() { + mock.expected_null_buffer(total_num_groups) + } else { + // the test data doesn't always pass all group indices + // unlike the real hash grouper, so only build a null + // buffer if it would have made one + None + }; let null_buffer = null_state.build(); @@ -800,8 +846,10 @@ mod test { assert_eq!(accumulated_values, expected_values, "\n\naccumulated_values:{accumulated_values:#?}\n\nexpected_values:{expected_values:#?}"); - let seen_values = null_state.seen_values.finish_cloned(); - mock.validate_seen_values(&seen_values); + if let Some(seen_values) = null_state.seen_values.as_ref() { + let seen_values = seen_values.finish_cloned(); + mock.validate_seen_values(&seen_values); + } // Validate the final buffer (one value per group) let expected_null_buffer = mock.expected_null_buffer(total_num_groups); @@ -845,10 +893,16 @@ mod test { } /// Create the expected null buffer based on if the input had nulls and a filter - fn expected_null_buffer(&self, total_num_groups: usize) -> NullBuffer { - (0..total_num_groups) + fn expected_null_buffer(&self, total_num_groups: usize) -> Option { + let nulls: NullBuffer = (0..total_num_groups) .map(|group_index| self.expected_seen(group_index)) - .collect() + .collect(); + + if nulls.null_count() > 0 { + Some(nulls) + } else { + None + } } } } diff --git a/datafusion/physical-expr/src/aggregate/groups_accumulator/bool_op.rs b/datafusion/physical-expr/src/aggregate/groups_accumulator/bool_op.rs index 83ffc3717b44..01efdb57efcb 100644 --- a/datafusion/physical-expr/src/aggregate/groups_accumulator/bool_op.rs +++ b/datafusion/physical-expr/src/aggregate/groups_accumulator/bool_op.rs @@ -101,7 +101,7 @@ where fn evaluate(&mut self) -> Result { let values = self.values.finish(); let nulls = self.null_state.build(); - let values = BooleanArray::new(values, Some(nulls)); + let values = BooleanArray::new(values, nulls); Ok(Arc::new(values)) } diff --git a/datafusion/physical-expr/src/aggregate/groups_accumulator/prim_op.rs b/datafusion/physical-expr/src/aggregate/groups_accumulator/prim_op.rs index a49651a5e3fa..02f1ff3e3f3f 100644 --- a/datafusion/physical-expr/src/aggregate/groups_accumulator/prim_op.rs +++ b/datafusion/physical-expr/src/aggregate/groups_accumulator/prim_op.rs @@ -115,7 +115,7 @@ where fn evaluate(&mut self) -> Result { let values = std::mem::take(&mut self.values); let nulls = self.null_state.build(); - let values = PrimitiveArray::::new(values.into(), Some(nulls)); // no copy + let values = PrimitiveArray::::new(values.into(), nulls); // no copy adjust_output_array(&self.data_type, Arc::new(values)) } diff --git a/datafusion/physical-expr/src/aggregate/min_max.rs b/datafusion/physical-expr/src/aggregate/min_max.rs index ebf317e6d0f3..8fecb888dcaf 100644 --- a/datafusion/physical-expr/src/aggregate/min_max.rs +++ b/datafusion/physical-expr/src/aggregate/min_max.rs @@ -1407,7 +1407,7 @@ where let min_max = std::mem::take(&mut self.min_max); let nulls = self.null_state.build(); - let min_max = PrimitiveArray::::new(min_max.into(), Some(nulls)); // no copy + let min_max = PrimitiveArray::::new(min_max.into(), nulls); // no copy let min_max = adjust_output_array(&self.data_type, Arc::new(min_max))?; Ok(Arc::new(min_max)) @@ -1418,7 +1418,7 @@ where let nulls = self.null_state.build(); let min_max = std::mem::take(&mut self.min_max); - let min_max = PrimitiveArray::::new(min_max.into(), Some(nulls)); // zero copy + let min_max = PrimitiveArray::::new(min_max.into(), nulls); // zero copy let min_max = adjust_output_array(&self.data_type, Arc::new(min_max))?;