From af3f62ba38e904088fd29000415b322d92cb03ad Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Thu, 13 Jul 2023 10:45:11 -0400 Subject: [PATCH 1/3] Allow better vectorization in accumulate functions --- .../groups_accumulator/accumulate.rs | 43 ++++++++++++++++--- 1 file changed, 38 insertions(+), 5 deletions(-) diff --git a/datafusion/physical-expr/src/aggregate/groups_accumulator/accumulate.rs b/datafusion/physical-expr/src/aggregate/groups_accumulator/accumulate.rs index bcc9d30bedd8..b8b58b72ceb5 100644 --- a/datafusion/physical-expr/src/aggregate/groups_accumulator/accumulate.rs +++ b/datafusion/physical-expr/src/aggregate/groups_accumulator/accumulate.rs @@ -139,10 +139,14 @@ impl NullState { // no nulls, no filter, (false, None) => { let iter = group_indices.iter().zip(data.iter()); + for (&group_index, &new_value) in iter { - seen_values.set_bit(group_index, true); value_fn(group_index, new_value); } + // update seen values in separate loop + for &group_index in group_indices.iter() { + seen_values.set_bit(group_index, true); + } } // nulls, no filter (true, None) => { @@ -157,6 +161,7 @@ impl NullState { let data_remainder = data_chunks.remainder(); group_indices_chunks + .clone() .zip(data_chunks) .zip(bit_chunks.iter()) .for_each(|((group_index_chunk, data_chunk), mask)| { @@ -167,7 +172,6 @@ impl NullState { // valid bit was set, real value let is_valid = (mask & index_mask) != 0; if is_valid { - seen_values.set_bit(group_index, true); value_fn(group_index, new_value); } index_mask <<= 1; @@ -175,6 +179,21 @@ impl NullState { ) }); + group_indices_chunks.zip(bit_chunks.iter()).for_each( + |(group_index_chunk, mask)| { + // index_mask has value 1 << i in the loop + let mut index_mask = 1; + group_index_chunk.iter().for_each(|&group_index| { + // valid bit was set, real value + let is_valid = (mask & index_mask) != 0; + if is_valid { + seen_values.set_bit(group_index, true); + } + index_mask <<= 1; + }) + }, + ); + // handle any remaining bits (after the 
initial 64) let remainder_bits = bit_chunks.remainder_bits(); group_indices_remainder @@ -184,10 +203,17 @@ impl NullState { .for_each(|(i, (&group_index, &new_value))| { let is_valid = remainder_bits & (1 << i) != 0; if is_valid { - seen_values.set_bit(group_index, true); value_fn(group_index, new_value); } }); + group_indices_remainder.iter().enumerate().for_each( + |(i, &group_index)| { + let is_valid = remainder_bits & (1 << i) != 0; + if is_valid { + seen_values.set_bit(group_index, true); + } + }, + ); } // no nulls, but a filter (false, Some(filter)) => { @@ -201,10 +227,17 @@ impl NullState { .zip(filter.iter()) .for_each(|((&group_index, &new_value), filter_value)| { if let Some(true) = filter_value { - seen_values.set_bit(group_index, true); value_fn(group_index, new_value); } - }) + }); + + group_indices.iter().zip(filter.iter()).for_each( + |(&group_index, filter_value)| { + if let Some(true) = filter_value { + seen_values.set_bit(group_index, true); + } + }, + ) } // both null values and filters (true, Some(filter)) => { From e15f5eed8b5a36bc2af7c6c0679107d5dbc63ac4 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Thu, 13 Jul 2023 15:01:52 -0400 Subject: [PATCH 2/3] try and special case nulls --- .../physical-expr/src/aggregate/average.rs | 41 +++--- .../groups_accumulator/accumulate.rs | 124 ++++++++++-------- .../aggregate/groups_accumulator/bool_op.rs | 2 +- .../aggregate/groups_accumulator/prim_op.rs | 2 +- .../physical-expr/src/aggregate/min_max.rs | 4 +- 5 files changed, 95 insertions(+), 78 deletions(-) diff --git a/datafusion/physical-expr/src/aggregate/average.rs b/datafusion/physical-expr/src/aggregate/average.rs index e95e9fcf877a..843401de2368 100644 --- a/datafusion/physical-expr/src/aggregate/average.rs +++ b/datafusion/physical-expr/src/aggregate/average.rs @@ -565,30 +565,33 @@ where let sums = std::mem::take(&mut self.sums); let nulls = self.null_state.build(); - assert_eq!(nulls.len(), sums.len()); assert_eq!(counts.len(), 
sums.len()); // don't evaluate averages with null inputs to avoid errors on null values - let array: PrimitiveArray = if nulls.null_count() > 0 { - let mut builder = PrimitiveBuilder::::with_capacity(nulls.len()); - let iter = sums.into_iter().zip(counts.into_iter()).zip(nulls.iter()); - - for ((sum, count), is_valid) in iter { - if is_valid { - builder.append_value((self.avg_fn)(sum, count)?) - } else { - builder.append_null(); + let array: PrimitiveArray = match nulls { + Some(nulls) if nulls.null_count() > 0 => { + assert_eq!(nulls.len(), sums.len()); + let mut builder = PrimitiveBuilder::::with_capacity(nulls.len()); + let iter = sums.into_iter().zip(counts.into_iter()).zip(nulls.iter()); + + for ((sum, count), is_valid) in iter { + if is_valid { + builder.append_value((self.avg_fn)(sum, count)?) + } else { + builder.append_null(); + } } + builder.finish() + } + _ => { + let averages: Vec = sums + .into_iter() + .zip(counts.into_iter()) + .map(|(sum, count)| (self.avg_fn)(sum, count)) + .collect::>>()?; + PrimitiveArray::new(averages.into(), nulls) // no copy } - builder.finish() - } else { - let averages: Vec = sums - .into_iter() - .zip(counts.into_iter()) - .map(|(sum, count)| (self.avg_fn)(sum, count)) - .collect::>>()?; - PrimitiveArray::new(averages.into(), Some(nulls)) // no copy }; // fix up decimal precision and scale for decimals @@ -599,7 +602,7 @@ where // return arrays for sums and counts fn state(&mut self) -> Result> { - let nulls = Some(self.null_state.build()); + let nulls = self.null_state.build(); let counts = std::mem::take(&mut self.counts); let counts = UInt64Array::new(counts.into(), nulls.clone()); // zero copy diff --git a/datafusion/physical-expr/src/aggregate/groups_accumulator/accumulate.rs b/datafusion/physical-expr/src/aggregate/groups_accumulator/accumulate.rs index b8b58b72ceb5..26ab255c76bd 100644 --- a/datafusion/physical-expr/src/aggregate/groups_accumulator/accumulate.rs +++ 
b/datafusion/physical-expr/src/aggregate/groups_accumulator/accumulate.rs @@ -57,20 +57,21 @@ pub struct NullState { /// /// If `seen_values[i]` is false, have not seen any values that /// pass the filter yet for group `i` - seen_values: BooleanBufferBuilder, + seen_values: Option, } impl NullState { pub fn new() -> Self { - Self { - seen_values: BooleanBufferBuilder::new(0), - } + Self { seen_values: None } } /// return the size of all buffers allocated by this null state, not including self pub fn size(&self) -> usize { // capacity is in bits, so convert to bytes - self.seen_values.capacity() / 8 + self.seen_values + .as_ref() + .map(|seen_values| seen_values.capacity() / 8) + .unwrap_or(0) } /// Invokes `value_fn(group_index, value)` for each non null, non @@ -130,8 +131,27 @@ impl NullState { let data: &[T::Native] = values.values(); assert_eq!(data.len(), group_indices.len()); - // ensure the seen_values is big enough (start everything at - // "not seen" valid) + let input_has_nulls = values.null_count() > 0; + + // avoid tracking null state if possible + if !input_has_nulls && + self.seen_values.is_none()&& // we have seen values for all previous groups + opt_filter.is_none() + // we will be looking at all values + { + // since we know all groups have at least one non-null + // value, there is no need to track seen_values + // individually + let iter = group_indices.iter().zip(data.iter()); + + for (&group_index, &new_value) in iter { + value_fn(group_index, new_value); + } + return; + } + + // have been tracking values previously, so we still need to + // track them here let seen_values = initialize_builder(&mut self.seen_values, total_num_groups, false); @@ -139,13 +159,9 @@ impl NullState { // no nulls, no filter, (false, None) => { let iter = group_indices.iter().zip(data.iter()); - for (&group_index, &new_value) in iter { - value_fn(group_index, new_value); - } - // update seen values in separate loop - for &group_index in group_indices.iter() { 
seen_values.set_bit(group_index, true); + value_fn(group_index, new_value); } } // nulls, no filter @@ -161,7 +177,6 @@ impl NullState { let data_remainder = data_chunks.remainder(); group_indices_chunks - .clone() .zip(data_chunks) .zip(bit_chunks.iter()) .for_each(|((group_index_chunk, data_chunk), mask)| { @@ -172,6 +187,7 @@ impl NullState { // valid bit was set, real value let is_valid = (mask & index_mask) != 0; if is_valid { + seen_values.set_bit(group_index, true); value_fn(group_index, new_value); } index_mask <<= 1; @@ -179,21 +195,6 @@ impl NullState { ) }); - group_indices_chunks.zip(bit_chunks.iter()).for_each( - |(group_index_chunk, mask)| { - // index_mask has value 1 << i in the loop - let mut index_mask = 1; - group_index_chunk.iter().for_each(|&group_index| { - // valid bit was set, real value - let is_valid = (mask & index_mask) != 0; - if is_valid { - seen_values.set_bit(group_index, true); - } - index_mask <<= 1; - }) - }, - ); - // handle any remaining bits (after the initial 64) let remainder_bits = bit_chunks.remainder_bits(); group_indices_remainder @@ -203,17 +204,10 @@ impl NullState { .for_each(|(i, (&group_index, &new_value))| { let is_valid = remainder_bits & (1 << i) != 0; if is_valid { + seen_values.set_bit(group_index, true); value_fn(group_index, new_value); } }); - group_indices_remainder.iter().enumerate().for_each( - |(i, &group_index)| { - let is_valid = remainder_bits & (1 << i) != 0; - if is_valid { - seen_values.set_bit(group_index, true); - } - }, - ); } // no nulls, but a filter (false, Some(filter)) => { @@ -226,18 +220,11 @@ impl NullState { .zip(data.iter()) .zip(filter.iter()) .for_each(|((&group_index, &new_value), filter_value)| { - if let Some(true) = filter_value { - value_fn(group_index, new_value); - } - }); - - group_indices.iter().zip(filter.iter()).for_each( - |(&group_index, filter_value)| { if let Some(true) = filter_value { seen_values.set_bit(group_index, true); + value_fn(group_index, new_value); } - }, - 
) + }) } // both null values and filters (true, Some(filter)) => { @@ -354,9 +341,20 @@ impl NullState { /// group_indices should have null values (because they never saw /// any values) /// + /// If all groups are valid, returns None (no NullBuffer) + /// /// resets the internal state to empty - pub fn build(&mut self) -> NullBuffer { - NullBuffer::new(self.seen_values.finish()) + pub fn build(&mut self) -> Option { + let Some(seen_values) = self.seen_values.as_mut() else { + return None; + }; + + let nulls = NullBuffer::new(seen_values.finish()); + if nulls.null_count() > 0 { + Some(nulls) + } else { + None + } } } @@ -456,10 +454,15 @@ pub fn accumulate_indices( /// /// All new entries are initialized to `default_value` fn initialize_builder( - builder: &mut BooleanBufferBuilder, + builder: &mut Option, total_num_groups: usize, default_value: bool, ) -> &mut BooleanBufferBuilder { + if builder.is_none() { + *builder = Some(BooleanBufferBuilder::new(total_num_groups)); + } + let builder = builder.as_mut().unwrap(); + if builder.len() < total_num_groups { let new_groups = total_num_groups - builder.len(); builder.append_n(new_groups, default_value); @@ -716,8 +719,11 @@ mod test { assert_eq!(accumulated_values, expected_values, "\n\naccumulated_values:{accumulated_values:#?}\n\nexpected_values:{expected_values:#?}"); - let seen_values = null_state.seen_values.finish_cloned(); - mock.validate_seen_values(&seen_values); + + if let Some(seen_values) = null_state.seen_values.as_ref() { + let seen_values = seen_values.finish_cloned(); + mock.validate_seen_values(&seen_values); + } // Validate the final buffer (one value per group) let expected_null_buffer = mock.expected_null_buffer(total_num_groups); @@ -833,8 +839,10 @@ mod test { assert_eq!(accumulated_values, expected_values, "\n\naccumulated_values:{accumulated_values:#?}\n\nexpected_values:{expected_values:#?}"); - let seen_values = null_state.seen_values.finish_cloned(); - 
mock.validate_seen_values(&seen_values); + if let Some(seen_values) = null_state.seen_values.as_ref() { + let seen_values = seen_values.finish_cloned(); + mock.validate_seen_values(&seen_values); + } // Validate the final buffer (one value per group) let expected_null_buffer = mock.expected_null_buffer(total_num_groups); @@ -878,10 +886,16 @@ mod test { } /// Create the expected null buffer based on if the input had nulls and a filter - fn expected_null_buffer(&self, total_num_groups: usize) -> NullBuffer { - (0..total_num_groups) + fn expected_null_buffer(&self, total_num_groups: usize) -> Option { + let nulls: NullBuffer = (0..total_num_groups) .map(|group_index| self.expected_seen(group_index)) - .collect() + .collect(); + + if nulls.null_count() > 0 { + Some(nulls) + } else { + None + } } } } diff --git a/datafusion/physical-expr/src/aggregate/groups_accumulator/bool_op.rs b/datafusion/physical-expr/src/aggregate/groups_accumulator/bool_op.rs index 83ffc3717b44..01efdb57efcb 100644 --- a/datafusion/physical-expr/src/aggregate/groups_accumulator/bool_op.rs +++ b/datafusion/physical-expr/src/aggregate/groups_accumulator/bool_op.rs @@ -101,7 +101,7 @@ where fn evaluate(&mut self) -> Result { let values = self.values.finish(); let nulls = self.null_state.build(); - let values = BooleanArray::new(values, Some(nulls)); + let values = BooleanArray::new(values, nulls); Ok(Arc::new(values)) } diff --git a/datafusion/physical-expr/src/aggregate/groups_accumulator/prim_op.rs b/datafusion/physical-expr/src/aggregate/groups_accumulator/prim_op.rs index a49651a5e3fa..02f1ff3e3f3f 100644 --- a/datafusion/physical-expr/src/aggregate/groups_accumulator/prim_op.rs +++ b/datafusion/physical-expr/src/aggregate/groups_accumulator/prim_op.rs @@ -115,7 +115,7 @@ where fn evaluate(&mut self) -> Result { let values = std::mem::take(&mut self.values); let nulls = self.null_state.build(); - let values = PrimitiveArray::::new(values.into(), Some(nulls)); // no copy + let values = 
PrimitiveArray::::new(values.into(), nulls); // no copy adjust_output_array(&self.data_type, Arc::new(values)) } diff --git a/datafusion/physical-expr/src/aggregate/min_max.rs b/datafusion/physical-expr/src/aggregate/min_max.rs index ebf317e6d0f3..8fecb888dcaf 100644 --- a/datafusion/physical-expr/src/aggregate/min_max.rs +++ b/datafusion/physical-expr/src/aggregate/min_max.rs @@ -1407,7 +1407,7 @@ where let min_max = std::mem::take(&mut self.min_max); let nulls = self.null_state.build(); - let min_max = PrimitiveArray::::new(min_max.into(), Some(nulls)); // no copy + let min_max = PrimitiveArray::::new(min_max.into(), nulls); // no copy let min_max = adjust_output_array(&self.data_type, Arc::new(min_max))?; Ok(Arc::new(min_max)) @@ -1418,7 +1418,7 @@ where let nulls = self.null_state.build(); let min_max = std::mem::take(&mut self.min_max); - let min_max = PrimitiveArray::::new(min_max.into(), Some(nulls)); // zero copy + let min_max = PrimitiveArray::::new(min_max.into(), nulls); // zero copy let min_max = adjust_output_array(&self.data_type, Arc::new(min_max))?; From 67c9afec06038985427a413af475eb300d85e33c Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Thu, 13 Jul 2023 15:16:19 -0400 Subject: [PATCH 3/3] fix test --- .../src/aggregate/groups_accumulator/accumulate.rs | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/datafusion/physical-expr/src/aggregate/groups_accumulator/accumulate.rs b/datafusion/physical-expr/src/aggregate/groups_accumulator/accumulate.rs index 26ab255c76bd..711fe5f777ba 100644 --- a/datafusion/physical-expr/src/aggregate/groups_accumulator/accumulate.rs +++ b/datafusion/physical-expr/src/aggregate/groups_accumulator/accumulate.rs @@ -150,8 +150,8 @@ impl NullState { return; } - // have been tracking values previously, so we still need to - // track them here + // have been tracking seen values previously, so we still need + // to track them here let seen_values = initialize_builder(&mut self.seen_values, 
total_num_groups, false); @@ -726,7 +726,14 @@ mod test { } // Validate the final buffer (one value per group) - let expected_null_buffer = mock.expected_null_buffer(total_num_groups); + let expected_null_buffer = if values.null_count() > 0 || opt_filter.is_some() { + mock.expected_null_buffer(total_num_groups) + } else { + // the test data doesn't always pass all group indices + // unlike the real hash grouper, so only build a null + // buffer if it would have made one + None + }; let null_buffer = null_state.build();