From 6786203ea8c9c4cfe604f74a5cece8e829fb56c8 Mon Sep 17 00:00:00 2001 From: Yijie Shen Date: Sun, 8 May 2022 00:00:31 +0800 Subject: [PATCH] Grouped Aggregate in row format (#2375) * first move: re-group aggregates functionalities in core/physical_p/aggregates * basic accumulators * main updating procedure * output as record batch * aggregate with row state * make row non-optional * address comments, add docs, part fix #2455 * Apply suggestions from code review Co-authored-by: Andrew Lamb Co-authored-by: Andrew Lamb --- datafusion/core/Cargo.toml | 7 +- .../core/benches/aggregate_query_sql.rs | 10 + datafusion/core/src/lib.rs | 1 - .../core/src/physical_plan/aggregates/hash.rs | 25 +- .../core/src/physical_plan/aggregates/mod.rs | 67 +++ .../src/physical_plan/aggregates/row_hash.rs | 484 ++++++++++++++++++ .../core/src/physical_plan/hash_utils.rs | 34 +- datafusion/core/tests/sql/aggregates.rs | 22 + datafusion/core/tests/sql/functions.rs | 4 +- datafusion/physical-expr/Cargo.toml | 1 + .../physical-expr/src/aggregate/average.rs | 113 +++- .../physical-expr/src/aggregate/count.rs | 60 +++ .../physical-expr/src/aggregate/min_max.rs | 202 ++++++++ datafusion/physical-expr/src/aggregate/mod.rs | 19 + .../src/aggregate/row_accumulator.rs | 65 +++ datafusion/physical-expr/src/aggregate/sum.rs | 177 ++++++- datafusion/row/src/accessor.rs | 296 +++++++++++ datafusion/row/src/layout.rs | 19 +- datafusion/row/src/lib.rs | 18 +- datafusion/row/src/reader.rs | 5 +- datafusion/row/src/writer.rs | 2 + 21 files changed, 1581 insertions(+), 50 deletions(-) create mode 100644 datafusion/core/src/physical_plan/aggregates/row_hash.rs create mode 100644 datafusion/physical-expr/src/aggregate/row_accumulator.rs create mode 100644 datafusion/row/src/accessor.rs diff --git a/datafusion/core/Cargo.toml b/datafusion/core/Cargo.toml index 6dc8acb3d4d4..e11e02e95bdf 100644 --- a/datafusion/core/Cargo.toml +++ b/datafusion/core/Cargo.toml @@ -48,8 +48,6 @@ force_hash_collisions = [] jit = ["datafusion-jit"] pyarrow = ["pyo3", "arrow/pyarrow", "datafusion-common/pyarrow"] regex_expressions = ["datafusion-physical-expr/regex_expressions"] -# Used to enable row format experiment -row = ["datafusion-row"] # Used to enable scheduler scheduler = ["rayon"] simd = ["arrow/simd"] @@ -66,7 +64,7 @@ datafusion-data-access = { path = "../../data-access", version = "1.0.0" } datafusion-expr = { path = "../expr", version = "7.0.0" } datafusion-jit = { path = "../jit", version = "7.0.0", optional = true } datafusion-physical-expr = { path = "../physical-expr", version = "7.0.0" } -datafusion-row = { path = "../row", version = "7.0.0", optional = true } +datafusion-row = { path = "../row", version = "7.0.0" } futures = "0.3" hashbrown = { version = "0.12", features = ["raw"] } lazy_static = { version = "^1.4.0" } @@ -134,8 +132,7 @@ name = "sql_planner" [[bench]] harness = false name = "jit" -required-features = ["row", "jit"] +required-features = ["jit"] [[test]] name = "row" -required-features = ["row"] diff --git a/datafusion/core/benches/aggregate_query_sql.rs b/datafusion/core/benches/aggregate_query_sql.rs index 807e64ff5e27..8570f81700c5 100644 --- a/datafusion/core/benches/aggregate_query_sql.rs +++ b/datafusion/core/benches/aggregate_query_sql.rs @@ -133,6 +133,16 @@ fn criterion_benchmark(c: &mut Criterion) { ) }) }); + + c.bench_function("aggregate_query_group_by_u64_multiple_keys", |b| { + b.iter(|| { + query( + ctx.clone(), + "SELECT u64_wide, utf8, MIN(f64), AVG(f64), COUNT(f64) \ + FROM t GROUP BY u64_wide, utf8", + 
) + }) + }); } criterion_group!(benches, criterion_benchmark); diff --git a/datafusion/core/src/lib.rs b/datafusion/core/src/lib.rs index c598d9a33cef..b553c0ed84b5 100644 --- a/datafusion/core/src/lib.rs +++ b/datafusion/core/src/lib.rs @@ -233,7 +233,6 @@ pub use datafusion_data_access; pub use datafusion_expr as logical_expr; pub use datafusion_physical_expr as physical_expr; -#[cfg(feature = "row")] pub use datafusion_row as row; #[cfg(feature = "jit")] diff --git a/datafusion/core/src/physical_plan/aggregates/hash.rs b/datafusion/core/src/physical_plan/aggregates/hash.rs index d5a3253678c6..45719260ccf5 100644 --- a/datafusion/core/src/physical_plan/aggregates/hash.rs +++ b/datafusion/core/src/physical_plan/aggregates/hash.rs @@ -28,7 +28,9 @@ use futures::{ }; use crate::error::Result; -use crate::physical_plan::aggregates::{AccumulatorItem, AggregateMode}; +use crate::physical_plan::aggregates::{ + evaluate, evaluate_many, AccumulatorItem, AggregateMode, +}; use crate::physical_plan::hash_utils::create_hashes; use crate::physical_plan::metrics::{BaselineMetrics, RecordOutput}; use crate::physical_plan::{aggregates, AggregateExpr, PhysicalExpr}; @@ -380,27 +382,6 @@ impl std::fmt::Debug for Accumulators { } } -/// Evaluates expressions against a record batch. -fn evaluate( - expr: &[Arc], - batch: &RecordBatch, -) -> Result> { - expr.iter() - .map(|expr| expr.evaluate(batch)) - .map(|r| r.map(|v| v.into_array(batch.num_rows()))) - .collect::>>() -} - -/// Evaluates expressions against a record batch. -fn evaluate_many( - expr: &[Vec>], - batch: &RecordBatch, -) -> Result>> { - expr.iter() - .map(|expr| evaluate(expr, batch)) - .collect::>>() -} - /// Create a RecordBatch with all group keys and accumulator' states or values. fn create_batch_from_map( mode: &AggregateMode, diff --git a/datafusion/core/src/physical_plan/aggregates/mod.rs b/datafusion/core/src/physical_plan/aggregates/mod.rs index 3682ec6eb8e9..abe20cdcbc94 100644 --- a/datafusion/core/src/physical_plan/aggregates/mod.rs +++ b/datafusion/core/src/physical_plan/aggregates/mod.rs @@ -29,6 +29,7 @@ use crate::physical_plan::{ }; use arrow::array::ArrayRef; use arrow::datatypes::{Field, Schema, SchemaRef}; +use arrow::record_batch::RecordBatch; use datafusion_common::Result; use datafusion_expr::Accumulator; use datafusion_physical_expr::expressions::Column; @@ -40,9 +41,13 @@ use std::sync::Arc; mod hash; mod no_grouping; +mod row_hash; +use crate::physical_plan::aggregates::row_hash::GroupedHashAggregateStreamV2; pub use datafusion_expr::AggregateFunction; +use datafusion_physical_expr::aggregate::row_accumulator::RowAccumulator; pub use datafusion_physical_expr::expressions::create_aggregate_expr; +use datafusion_row::{row_supported, RowType}; /// Hash aggregate modes #[derive(Debug, Copy, Clone, PartialEq, Eq)] @@ -142,6 +147,12 @@ impl AggregateExec { pub fn input_schema(&self) -> SchemaRef { self.input_schema.clone() } + + fn row_aggregate_supported(&self) -> bool { + let group_schema = group_schema(&self.schema, self.group_expr.len()); + row_supported(&group_schema, RowType::Compact) + && accumulator_v2_supported(&self.aggr_expr) + } } impl ExecutionPlan for AggregateExec { @@ -212,6 +223,15 @@ impl ExecutionPlan for AggregateExec { input, baseline_metrics, )?)) + } else if self.row_aggregate_supported() { + Ok(Box::pin(GroupedHashAggregateStreamV2::new( + self.mode, + self.schema.clone(), + group_expr, + self.aggr_expr.clone(), + input, + baseline_metrics, + )?)) } else { Ok(Box::pin(GroupedHashAggregateStream::new( 
self.mode, @@ -315,6 +335,11 @@ fn create_schema( Ok(Schema::new(fields)) } +fn group_schema(schema: &Schema, group_count: usize) -> SchemaRef { + let group_fields = schema.fields()[0..group_count].to_vec(); + Arc::new(Schema::new(group_fields)) +} + /// returns physical expressions to evaluate against a batch /// The expressions are different depending on `mode`: /// * Partial: AggregateExpr::expressions @@ -362,6 +387,7 @@ fn merge_expressions( } pub(crate) type AccumulatorItem = Box; +pub(crate) type AccumulatorItemV2 = Box; fn create_accumulators( aggr_expr: &[Arc], @@ -372,6 +398,26 @@ fn create_accumulators( .collect::>>() } +fn accumulator_v2_supported(aggr_expr: &[Arc]) -> bool { + aggr_expr + .iter() + .all(|expr| expr.row_accumulator_supported()) +} + +fn create_accumulators_v2( + aggr_expr: &[Arc], +) -> datafusion_common::Result> { + let mut state_index = 0; + aggr_expr + .iter() + .map(|expr| { + let result = expr.create_row_accumulator(state_index); + state_index += expr.state_fields().unwrap().len(); + result + }) + .collect::>>() +} + /// returns a vector of ArrayRefs, where each entry corresponds to either the /// final value (mode = Final) or states (mode = Partial) fn finalize_aggregation( @@ -402,6 +448,27 @@ fn finalize_aggregation( } } +/// Evaluates expressions against a record batch. +fn evaluate( + expr: &[Arc], + batch: &RecordBatch, +) -> Result> { + expr.iter() + .map(|expr| expr.evaluate(batch)) + .map(|r| r.map(|v| v.into_array(batch.num_rows()))) + .collect::>>() +} + +/// Evaluates expressions against a record batch. +fn evaluate_many( + expr: &[Vec>], + batch: &RecordBatch, +) -> Result>> { + expr.iter() + .map(|expr| evaluate(expr, batch)) + .collect::>>() +} + #[cfg(test)] mod tests { use crate::execution::context::TaskContext; diff --git a/datafusion/core/src/physical_plan/aggregates/row_hash.rs b/datafusion/core/src/physical_plan/aggregates/row_hash.rs new file mode 100644 index 000000000000..e364048e75fd --- /dev/null +++ b/datafusion/core/src/physical_plan/aggregates/row_hash.rs @@ -0,0 +1,484 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! 
Hash aggregation through row format
+
+use std::sync::Arc;
+use std::task::{Context, Poll};
+use std::vec;
+
+use ahash::RandomState;
+use futures::{
+    ready,
+    stream::{Stream, StreamExt},
+};
+
+use crate::error::Result;
+use crate::physical_plan::aggregates::{
+    evaluate, evaluate_many, group_schema, AccumulatorItemV2, AggregateMode,
+};
+use crate::physical_plan::hash_utils::create_row_hashes;
+use crate::physical_plan::metrics::{BaselineMetrics, RecordOutput};
+use crate::physical_plan::{aggregates, AggregateExpr, PhysicalExpr};
+use crate::physical_plan::{RecordBatchStream, SendableRecordBatchStream};
+
+use arrow::compute::cast;
+use arrow::datatypes::Schema;
+use arrow::{array::ArrayRef, compute};
+use arrow::{
+    array::{Array, UInt32Builder},
+    error::{ArrowError, Result as ArrowResult},
+};
+use arrow::{datatypes::SchemaRef, record_batch::RecordBatch};
+use datafusion_common::ScalarValue;
+use datafusion_row::accessor::RowAccessor;
+use datafusion_row::layout::RowLayout;
+use datafusion_row::reader::{read_row, RowReader};
+use datafusion_row::writer::{write_row, RowWriter};
+use datafusion_row::{MutableRecordBatch, RowType};
+use hashbrown::raw::RawTable;
+
+/// Grouping aggregate with row-format aggregation states inside.
+///
+/// For each aggregation entry, we use:
+/// - a [Compact] row to represent the grouping keys, for fast hash computation and comparison directly on the raw bytes.
+/// - a [WordAligned] row to store the aggregation state, designed to be CPU-friendly when every field is updated frequently.
+///
+/// The architecture is the following:
+///
+/// 1. For each input RecordBatch, update the aggregation states for every grouping key that appears in the batch.
+/// 2. At the end of the aggregation (e.g. the end of batches in a partition), the accumulator converts its state to a RecordBatch of a single row.
+/// 3. The RecordBatches of all accumulators are merged (`concatenate` in `rust/arrow`) together into a single RecordBatch.
+/// 4. The state's RecordBatch is `merge`d into a new state.
+/// 5. The state is mapped to the final value.
+///
+/// [Compact]: datafusion_row::layout::RowType::Compact
+/// [WordAligned]: datafusion_row::layout::RowType::WordAligned
+pub(crate) struct GroupedHashAggregateStreamV2 {
+    schema: SchemaRef,
+    input: SendableRecordBatchStream,
+    mode: AggregateMode,
+    aggr_state: AggregationState,
+    aggregate_expressions: Vec<Vec<Arc<dyn PhysicalExpr>>>,
+
+    group_expr: Vec<Arc<dyn PhysicalExpr>>,
+    accumulators: Vec<AccumulatorItemV2>,
+
+    group_schema: SchemaRef,
+    aggr_schema: SchemaRef,
+    aggr_layout: Arc<RowLayout>,
+
+    baseline_metrics: BaselineMetrics,
+    random_state: RandomState,
+    finished: bool,
+}
+
+fn aggr_state_schema(aggr_expr: &[Arc<dyn AggregateExpr>]) -> Result<SchemaRef> {
+    let fields = aggr_expr
+        .iter()
+        .flat_map(|expr| expr.state_fields().unwrap().into_iter())
+        .collect::<Vec<_>>();
+    Ok(Arc::new(Schema::new(fields)))
+}
+
+impl GroupedHashAggregateStreamV2 {
+    /// Create a new GroupedHashAggregateStreamV2
+    pub fn new(
+        mode: AggregateMode,
+        schema: SchemaRef,
+        group_expr: Vec<Arc<dyn PhysicalExpr>>,
+        aggr_expr: Vec<Arc<dyn AggregateExpr>>,
+        input: SendableRecordBatchStream,
+        baseline_metrics: BaselineMetrics,
+    ) -> Result<Self> {
+        let timer = baseline_metrics.elapsed_compute().timer();
+
+        // The expressions to evaluate the batch, one vec of expressions per aggregation.
+        // Assume create_schema() always puts group columns in front of aggr columns; we set
+        // col_idx_base to the group expression count.
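+        // Illustrative example (an assumption for exposition, not part of the
+        // original patch): for `SELECT a, b, AVG(c) FROM t GROUP BY a, b` in Final
+        // mode, the input carries [a, b, count(c), sum(c)], so the merge
+        // expressions for AVG(c) are resolved as Columns starting at
+        // col_idx_base = 2.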
+ let aggregate_expressions = + aggregates::aggregate_expressions(&aggr_expr, &mode, group_expr.len())?; + + let accumulators = aggregates::create_accumulators_v2(&aggr_expr)?; + + let group_schema = group_schema(&schema, group_expr.len()); + let aggr_schema = aggr_state_schema(&aggr_expr)?; + + let aggr_layout = Arc::new(RowLayout::new(&aggr_schema, RowType::WordAligned)); + timer.done(); + + Ok(Self { + schema, + mode, + input, + group_expr, + accumulators, + group_schema, + aggr_schema, + aggr_layout, + baseline_metrics, + aggregate_expressions, + aggr_state: Default::default(), + random_state: Default::default(), + finished: false, + }) + } +} + +impl Stream for GroupedHashAggregateStreamV2 { + type Item = ArrowResult; + + fn poll_next( + mut self: std::pin::Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + let this = &mut *self; + if this.finished { + return Poll::Ready(None); + } + + let elapsed_compute = this.baseline_metrics.elapsed_compute(); + + loop { + let result = match ready!(this.input.poll_next_unpin(cx)) { + Some(Ok(batch)) => { + let timer = elapsed_compute.timer(); + let result = group_aggregate_batch( + &this.mode, + &this.random_state, + &this.group_expr, + &mut this.accumulators, + &this.group_schema, + this.aggr_layout.clone(), + batch, + &mut this.aggr_state, + &this.aggregate_expressions, + ); + + timer.done(); + + match result { + Ok(_) => continue, + Err(e) => Err(ArrowError::ExternalError(Box::new(e))), + } + } + Some(Err(e)) => Err(e), + None => { + this.finished = true; + let timer = this.baseline_metrics.elapsed_compute().timer(); + let result = create_batch_from_map( + &this.mode, + &this.group_schema, + &this.aggr_schema, + &mut this.aggr_state, + &mut this.accumulators, + &this.schema, + ) + .record_output(&this.baseline_metrics); + + timer.done(); + result + } + }; + + this.finished = true; + return Poll::Ready(Some(result)); + } + } +} + +impl RecordBatchStream for GroupedHashAggregateStreamV2 { + fn schema(&self) -> SchemaRef { + self.schema.clone() + } +} + +/// TODO: Make this a member function of [`GroupedHashAggregateStreamV2`] +#[allow(clippy::too_many_arguments)] +fn group_aggregate_batch( + mode: &AggregateMode, + random_state: &RandomState, + group_expr: &[Arc], + accumulators: &mut [AccumulatorItemV2], + group_schema: &Schema, + state_layout: Arc, + batch: RecordBatch, + aggr_state: &mut AggregationState, + aggregate_expressions: &[Vec>], +) -> Result<()> { + // evaluate the grouping expressions + let group_values = evaluate(group_expr, &batch)?; + let group_rows: Vec> = create_group_rows(group_values, group_schema); + + // evaluate the aggregation expressions. + // We could evaluate them after the `take`, but since we need to evaluate all + // of them anyways, it is more performant to do it while they are together. 
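+    // Shape note (illustrative): `aggr_input_values[i][j]` is the evaluated j-th
+    // argument of the i-th aggregate expression, materialized as one column over
+    // the whole input batch.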
+ let aggr_input_values = evaluate_many(aggregate_expressions, &batch)?; + + // 1.1 construct the key from the group values + // 1.2 construct the mapping key if it does not exist + // 1.3 add the row' index to `indices` + + // track which entries in `aggr_state` have rows in this batch to aggregate + let mut groups_with_rows = vec![]; + + // 1.1 Calculate the group keys for the group values + let mut batch_hashes = vec![0; batch.num_rows()]; + create_row_hashes(&group_rows, random_state, &mut batch_hashes)?; + + for (row, hash) in batch_hashes.into_iter().enumerate() { + let AggregationState { map, group_states } = aggr_state; + + let entry = map.get_mut(hash, |(_hash, group_idx)| { + // verify that a group that we are inserting with hash is + // actually the same key value as the group in + // existing_idx (aka group_values @ row) + let group_state = &group_states[*group_idx]; + group_rows[row] == group_state.group_by_values + }); + + match entry { + // Existing entry for this group value + Some((_hash, group_idx)) => { + let group_state = &mut group_states[*group_idx]; + // 1.3 + if group_state.indices.is_empty() { + groups_with_rows.push(*group_idx); + }; + group_state.indices.push(row as u32); // remember this row + } + // 1.2 Need to create new entry + None => { + // Add new entry to group_states and save newly created index + let group_state = RowGroupState { + group_by_values: group_rows[row].clone(), + aggregation_buffer: vec![0; state_layout.fixed_part_width()], + indices: vec![row as u32], // 1.3 + }; + let group_idx = group_states.len(); + group_states.push(group_state); + groups_with_rows.push(group_idx); + + // for hasher function, use precomputed hash value + map.insert(hash, (hash, group_idx), |(hash, _group_idx)| *hash); + } + }; + } + + // Collect all indices + offsets based on keys in this vec + let mut batch_indices: UInt32Builder = UInt32Builder::new(0); + let mut offsets = vec![0]; + let mut offset_so_far = 0; + for group_idx in groups_with_rows.iter() { + let indices = &aggr_state.group_states[*group_idx].indices; + batch_indices.append_slice(indices)?; + offset_so_far += indices.len(); + offsets.push(offset_so_far); + } + let batch_indices = batch_indices.finish(); + + // `Take` all values based on indices into Arrays + let values: Vec>> = aggr_input_values + .iter() + .map(|array| { + array + .iter() + .map(|array| { + compute::take( + array.as_ref(), + &batch_indices, + None, // None: no index check + ) + .unwrap() + }) + .collect() + // 2.3 + }) + .collect(); + + // 2.1 for each key in this batch + // 2.2 for each aggregation + // 2.3 `slice` from each of its arrays the keys' values + // 2.4 update / merge the accumulator with the values + // 2.5 clear indices + groups_with_rows + .iter() + .zip(offsets.windows(2)) + .try_for_each(|(group_idx, offsets)| { + let group_state = &mut aggr_state.group_states[*group_idx]; + // 2.2 + accumulators + .iter_mut() + .zip(values.iter()) + .map(|(accumulator, aggr_array)| { + ( + accumulator, + aggr_array + .iter() + .map(|array| { + // 2.3 + array.slice(offsets[0], offsets[1] - offsets[0]) + }) + .collect::>(), + ) + }) + .try_for_each(|(accumulator, values)| { + let mut state_accessor = + RowAccessor::new_from_layout(state_layout.clone()); + state_accessor + .point_to(0, group_state.aggregation_buffer.as_mut_slice()); + match mode { + AggregateMode::Partial => { + accumulator.update_batch(&values, &mut state_accessor) + } + AggregateMode::FinalPartitioned | AggregateMode::Final => { + // note: the aggregation here is over 
states, not values, thus the merge + accumulator.merge_batch(&values, &mut state_accessor) + } + } + }) + // 2.5 + .and({ + group_state.indices.clear(); + Ok(()) + }) + })?; + + Ok(()) +} + +/// The state that is built for each output group. +#[derive(Debug)] +struct RowGroupState { + /// The actual group by values, stored sequentially + group_by_values: Vec, + + // Accumulator state, stored sequentially + aggregation_buffer: Vec, + + /// scratch space used to collect indices for input rows in a + /// bach that have values to aggregate. Reset on each batch + indices: Vec, +} + +/// The state of all the groups +#[derive(Default)] +struct AggregationState { + /// Logically maps group values to an index in `group_states` + /// + /// Uses the raw API of hashbrown to avoid actually storing the + /// keys in the table + /// + /// keys: u64 hashes of the GroupValue + /// values: (hash, index into `group_states`) + map: RawTable<(u64, usize)>, + + /// State for each group + group_states: Vec, +} + +impl std::fmt::Debug for AggregationState { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + // hashes are not store inline, so could only get values + let map_string = "RawTable"; + f.debug_struct("AggregationState") + .field("map", &map_string) + .field("group_states", &self.group_states) + .finish() + } +} + +/// Create grouping rows +fn create_group_rows(arrays: Vec, schema: &Schema) -> Vec> { + let mut writer = RowWriter::new(schema, RowType::Compact); + let mut results = vec![]; + for cur_row in 0..arrays[0].len() { + write_row(&mut writer, cur_row, schema, &arrays); + results.push(writer.get_row().to_vec()); + writer.reset() + } + results +} + +/// Create a RecordBatch with all group keys and accumulator' states or values. +fn create_batch_from_map( + mode: &AggregateMode, + group_schema: &Schema, + aggr_schema: &Schema, + aggr_state: &mut AggregationState, + accumulators: &mut [AccumulatorItemV2], + output_schema: &Schema, +) -> ArrowResult { + if aggr_state.group_states.is_empty() { + return Ok(RecordBatch::new_empty(Arc::new(output_schema.to_owned()))); + } + + let mut state_accessor = RowAccessor::new(aggr_schema, RowType::WordAligned); + + let (group_buffers, mut state_buffers): (Vec<_>, Vec<_>) = aggr_state + .group_states + .iter() + .map(|gs| (gs.group_by_values.clone(), gs.aggregation_buffer.clone())) + .unzip(); + + let mut columns: Vec = + read_as_batch(&group_buffers, group_schema, RowType::Compact); + + match mode { + AggregateMode::Partial => columns.extend(read_as_batch( + &state_buffers, + aggr_schema, + RowType::WordAligned, + )), + AggregateMode::Final | AggregateMode::FinalPartitioned => { + let mut results: Vec> = vec![vec![]; accumulators.len()]; + for buffer in state_buffers.iter_mut() { + state_accessor.point_to(0, buffer); + for (i, acc) in accumulators.iter().enumerate() { + results[i].push(acc.evaluate(&state_accessor).unwrap()); + } + } + for scalars in results { + columns.push(ScalarValue::iter_to_array(scalars)?); + } + } + } + + // cast output if needed (e.g. 
for types like Dictionary where + // the intermediate GroupByScalar type was not the same as the + // output + let columns = columns + .iter() + .zip(output_schema.fields().iter()) + .map(|(col, desired_field)| cast(col, desired_field.data_type())) + .collect::>>()?; + + RecordBatch::try_new(Arc::new(output_schema.to_owned()), columns) +} + +fn read_as_batch(rows: &[Vec], schema: &Schema, row_type: RowType) -> Vec { + let row_num = rows.len(); + let mut output = MutableRecordBatch::new(row_num, Arc::new(schema.clone())); + let mut row = RowReader::new(schema, row_type); + + for data in rows { + row.point_to(0, data); + read_row(&row, &mut output, schema); + } + + output.output_as_columns() +} diff --git a/datafusion/core/src/physical_plan/hash_utils.rs b/datafusion/core/src/physical_plan/hash_utils.rs index 2ca1fa3df9d1..92598afbd9ce 100644 --- a/datafusion/core/src/physical_plan/hash_utils.rs +++ b/datafusion/core/src/physical_plan/hash_utils.rs @@ -278,7 +278,39 @@ pub fn create_hashes<'a>( for hash in hashes_buffer.iter_mut() { *hash = 0 } - return Ok(hashes_buffer); + Ok(hashes_buffer) +} + +/// Test version of `create_row_hashes` that produces the same value for +/// all hashes (to test collisions) +/// +/// See comments on `hashes_buffer` for more details +#[cfg(feature = "force_hash_collisions")] +pub fn create_row_hashes<'a>( + _rows: &[Vec], + _random_state: &RandomState, + hashes_buffer: &'a mut Vec, +) -> Result<&'a mut Vec> { + for hash in hashes_buffer.iter_mut() { + *hash = 0 + } + Ok(hashes_buffer) +} + +/// Creates hash values for every row, based on their raw bytes. +#[cfg(not(feature = "force_hash_collisions"))] +pub fn create_row_hashes<'a>( + rows: &[Vec], + random_state: &RandomState, + hashes_buffer: &'a mut Vec, +) -> Result<&'a mut Vec> { + for hash in hashes_buffer.iter_mut() { + *hash = 0 + } + for (i, hash) in hashes_buffer.iter_mut().enumerate() { + *hash = >::get_hash(&rows[i], random_state); + } + Ok(hashes_buffer) } /// Creates hash values for every row, based on the values in the diff --git a/datafusion/core/tests/sql/aggregates.rs b/datafusion/core/tests/sql/aggregates.rs index acead8766a4e..42999a743bbd 100644 --- a/datafusion/core/tests/sql/aggregates.rs +++ b/datafusion/core/tests/sql/aggregates.rs @@ -652,6 +652,28 @@ async fn csv_query_array_agg_one() -> Result<()> { Ok(()) } +#[tokio::test] +async fn csv_query_array_agg_with_overflow() -> Result<()> { + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; + let sql = + "select c2, sum(c3) sum_c3, avg(c3) avg_c3, max(c3) max_c3, min(c3) min_c3, count(c3) count_c3 from aggregate_test_100 group by c2 order by c2"; + let actual = execute_to_batches(&ctx, sql).await; + let expected = vec![ + "+----+--------+---------------------+--------+--------+----------+", + "| c2 | sum_c3 | avg_c3 | max_c3 | min_c3 | count_c3 |", + "+----+--------+---------------------+--------+--------+----------+", + "| 1 | 367 | 16.681818181818183 | 125 | -99 | 22 |", + "| 2 | 184 | 8.363636363636363 | 122 | -117 | 22 |", + "| 3 | 395 | 20.789473684210527 | 123 | -101 | 19 |", + "| 4 | 29 | 1.2608695652173914 | 123 | -117 | 23 |", + "| 5 | -194 | -13.857142857142858 | 118 | -101 | 14 |", + "+----+--------+---------------------+--------+--------+----------+", + ]; + assert_batches_eq!(expected, &actual); + Ok(()) +} + #[tokio::test] async fn csv_query_array_agg_distinct() -> Result<()> { let ctx = SessionContext::new(); diff --git a/datafusion/core/tests/sql/functions.rs b/datafusion/core/tests/sql/functions.rs 
index 396bd11940c1..1710e9af49f3 100644 --- a/datafusion/core/tests/sql/functions.rs +++ b/datafusion/core/tests/sql/functions.rs @@ -17,7 +17,6 @@ use super::*; -/// sqrt(f32) is slightly different than sqrt(CAST(f32 AS double))) #[tokio::test] async fn sqrt_f32_vs_f64() -> Result<()> { let ctx = create_ctx()?; @@ -25,7 +24,8 @@ async fn sqrt_f32_vs_f64() -> Result<()> { // sqrt(f32)'s plan passes let sql = "SELECT avg(sqrt(c11)) FROM aggregate_test_100"; let actual = execute(&ctx, sql).await; - let expected = vec![vec!["0.6584407806396484"]]; + let sql = "SELECT avg(CAST(sqrt(c11) AS double)) FROM aggregate_test_100"; + let expected = execute(&ctx, sql).await; assert_eq!(actual, expected); let sql = "SELECT avg(sqrt(CAST(c11 AS double))) FROM aggregate_test_100"; diff --git a/datafusion/physical-expr/Cargo.toml b/datafusion/physical-expr/Cargo.toml index ef4ee8cbf492..d64ecb07b714 100644 --- a/datafusion/physical-expr/Cargo.toml +++ b/datafusion/physical-expr/Cargo.toml @@ -46,6 +46,7 @@ blake3 = { version = "1.0", optional = true } chrono = { version = "0.4", default-features = false } datafusion-common = { path = "../common", version = "7.0.0" } datafusion-expr = { path = "../expr", version = "7.0.0" } +datafusion-row = { path = "../row", version = "7.0.0" } hashbrown = { version = "0.12", features = ["raw"] } lazy_static = { version = "^1.4.0" } md-5 = { version = "^0.10.0", optional = true } diff --git a/datafusion/physical-expr/src/aggregate/average.rs b/datafusion/physical-expr/src/aggregate/average.rs index ccb73eed5369..3eee84bb5f50 100644 --- a/datafusion/physical-expr/src/aggregate/average.rs +++ b/datafusion/physical-expr/src/aggregate/average.rs @@ -21,6 +21,7 @@ use std::any::Any; use std::convert::TryFrom; use std::sync::Arc; +use crate::aggregate::row_accumulator::RowAccumulator; use crate::aggregate::sum; use crate::expressions::format_state_name; use crate::{AggregateExpr, PhysicalExpr}; @@ -33,6 +34,7 @@ use arrow::{ use datafusion_common::ScalarValue; use datafusion_common::{DataFusionError, Result}; use datafusion_expr::Accumulator; +use datafusion_row::accessor::RowAccessor; /// AVG aggregate expression #[derive(Debug)] @@ -101,6 +103,32 @@ impl AggregateExpr for Avg { fn name(&self) -> &str { &self.name } + + fn row_accumulator_supported(&self) -> bool { + matches!( + self.data_type, + DataType::UInt8 + | DataType::UInt16 + | DataType::UInt32 + | DataType::UInt64 + | DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::Float32 + | DataType::Float64 + ) + } + + fn create_row_accumulator( + &self, + start_index: usize, + ) -> Result> { + Ok(Box::new(AvgRowAccumulator::new( + start_index, + self.data_type.clone(), + ))) + } } /// An accumulator to compute the average @@ -130,7 +158,10 @@ impl Accumulator for AvgAccumulator { let values = &values[0]; self.count += (values.len() - values.data().null_count()) as u64; - self.sum = sum::sum(&self.sum, &sum::sum_batch(values)?)?; + self.sum = sum::sum( + &self.sum, + &sum::sum_batch(values, &self.sum.get_datatype())?, + )?; Ok(()) } @@ -140,7 +171,10 @@ impl Accumulator for AvgAccumulator { self.count += compute::sum(counts).unwrap_or(0); // sums are summed - self.sum = sum::sum(&self.sum, &sum::sum_batch(&states[1])?)?; + self.sum = sum::sum( + &self.sum, + &sum::sum_batch(&states[1], &self.sum.get_datatype())?, + )?; Ok(()) } @@ -167,6 +201,81 @@ impl Accumulator for AvgAccumulator { } } +#[derive(Debug)] +struct AvgRowAccumulator { + state_index: usize, + sum_datatype: DataType, +} + 
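+// Layout note (illustrative, not part of the original patch): the average state
+// occupies two consecutive fields of the word-aligned state row: a u64 count at
+// `state_index` and the running sum at `state_index + 1`, mirroring the
+// count/sum pair declared by Avg's state_fields(), so Partial-mode output stays
+// column-compatible with the non-row aggregation path. For example, after
+// accumulating [1.0, 2.0, 3.0] the row holds count = 3 and sum = 6.0, and
+// evaluate() returns Float64(Some(2.0)).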
+impl AvgRowAccumulator { + pub fn new(start_index: usize, sum_datatype: DataType) -> Self { + Self { + state_index: start_index, + sum_datatype, + } + } +} + +impl RowAccumulator for AvgRowAccumulator { + fn update_batch( + &mut self, + values: &[ArrayRef], + accessor: &mut RowAccessor, + ) -> Result<()> { + let values = &values[0]; + // count + let delta = (values.len() - values.data().null_count()) as u64; + accessor.add_u64(self.state_index(), delta); + + // sum + sum::add_to_row( + &self.sum_datatype, + self.state_index() + 1, + accessor, + &sum::sum_batch(values, &self.sum_datatype)?, + )?; + Ok(()) + } + + fn merge_batch( + &mut self, + states: &[ArrayRef], + accessor: &mut RowAccessor, + ) -> Result<()> { + let counts = states[0].as_any().downcast_ref::().unwrap(); + // count + let delta = compute::sum(counts).unwrap_or(0); + accessor.add_u64(self.state_index(), delta); + + // sum + sum::add_to_row( + &self.sum_datatype, + self.state_index() + 1, + accessor, + &sum::sum_batch(&states[1], &self.sum_datatype)?, + )?; + Ok(()) + } + + fn evaluate(&self, accessor: &RowAccessor) -> Result { + assert_eq!(self.sum_datatype, DataType::Float64); + Ok(match accessor.get_u64_opt(self.state_index()) { + None => ScalarValue::Float64(None), + Some(0) => ScalarValue::Float64(Some(0.0)), + Some(n) => ScalarValue::Float64( + accessor + .get_f64_opt(self.state_index() + 1) + .map(|f| f / n as f64), + ), + }) + } + + #[inline(always)] + fn state_index(&self) -> usize { + self.state_index + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/datafusion/physical-expr/src/aggregate/count.rs b/datafusion/physical-expr/src/aggregate/count.rs index 9e8485e928c2..54bec05d72f0 100644 --- a/datafusion/physical-expr/src/aggregate/count.rs +++ b/datafusion/physical-expr/src/aggregate/count.rs @@ -18,8 +18,10 @@ //! 
Defines physical expressions that can evaluated at runtime during query execution use std::any::Any; +use std::fmt::Debug; use std::sync::Arc; +use crate::aggregate::row_accumulator::RowAccumulator; use crate::{AggregateExpr, PhysicalExpr}; use arrow::compute; use arrow::datatypes::DataType; @@ -30,6 +32,7 @@ use arrow::{ use datafusion_common::Result; use datafusion_common::ScalarValue; use datafusion_expr::Accumulator; +use datafusion_row::accessor::RowAccessor; use crate::expressions::format_state_name; @@ -92,6 +95,17 @@ impl AggregateExpr for Count { fn name(&self) -> &str { &self.name } + + fn row_accumulator_supported(&self) -> bool { + true + } + + fn create_row_accumulator( + &self, + start_index: usize, + ) -> Result> { + Ok(Box::new(CountRowAccumulator::new(start_index))) + } } #[derive(Debug)] @@ -131,6 +145,52 @@ impl Accumulator for CountAccumulator { } } +#[derive(Debug)] +struct CountRowAccumulator { + state_index: usize, +} + +impl CountRowAccumulator { + pub fn new(index: usize) -> Self { + Self { state_index: index } + } +} + +impl RowAccumulator for CountRowAccumulator { + fn update_batch( + &mut self, + values: &[ArrayRef], + accessor: &mut RowAccessor, + ) -> Result<()> { + let array = &values[0]; + let delta = (array.len() - array.data().null_count()) as u64; + accessor.add_u64(self.state_index, delta); + Ok(()) + } + + fn merge_batch( + &mut self, + states: &[ArrayRef], + accessor: &mut RowAccessor, + ) -> Result<()> { + let counts = states[0].as_any().downcast_ref::().unwrap(); + let delta = &compute::sum(counts); + if let Some(d) = delta { + accessor.add_u64(self.state_index, *d); + } + Ok(()) + } + + fn evaluate(&self, accessor: &RowAccessor) -> Result { + Ok(accessor.get_as_scalar(&DataType::UInt64, self.state_index)) + } + + #[inline(always)] + fn state_index(&self) -> usize { + self.state_index + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/datafusion/physical-expr/src/aggregate/min_max.rs b/datafusion/physical-expr/src/aggregate/min_max.rs index 7de10e4b8a7e..dd2f44b22c07 100644 --- a/datafusion/physical-expr/src/aggregate/min_max.rs +++ b/datafusion/physical-expr/src/aggregate/min_max.rs @@ -37,9 +37,11 @@ use datafusion_common::ScalarValue; use datafusion_common::{DataFusionError, Result}; use datafusion_expr::Accumulator; +use crate::aggregate::row_accumulator::RowAccumulator; use crate::expressions::format_state_name; use arrow::array::Array; use arrow::array::DecimalArray; +use datafusion_row::accessor::RowAccessor; // Min/max aggregation can take Dictionary encode input but always produces unpacked // (aka non Dictionary) output. We need to adjust the output data type to reflect this. @@ -111,6 +113,32 @@ impl AggregateExpr for Max { fn name(&self) -> &str { &self.name } + + fn row_accumulator_supported(&self) -> bool { + matches!( + self.data_type, + DataType::UInt8 + | DataType::UInt16 + | DataType::UInt32 + | DataType::UInt64 + | DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::Float32 + | DataType::Float64 + ) + } + + fn create_row_accumulator( + &self, + start_index: usize, + ) -> Result> { + Ok(Box::new(MaxRowAccumulator::new( + start_index, + self.data_type.clone(), + ))) + } } // Statically-typed version of min/max(array) -> ScalarValue for string types. @@ -303,6 +331,18 @@ macro_rules! typed_min_max { }}; } +// min/max of two non-string scalar values. +macro_rules! typed_min_max_v2 { + ($INDEX:ident, $ACC:ident, $SCALAR:expr, $TYPE:ident, $OP:ident) => {{ + paste::item! 
{ + match $SCALAR { + None => {} + Some(v) => $ACC.[<$OP _ $TYPE>]($INDEX, *v as $TYPE) + } + } + }}; +} + // min/max of two scalar string values. macro_rules! typed_min_max_string { ($VALUE:expr, $DELTA:expr, $SCALAR:ident, $OP:ident) => {{ @@ -408,16 +448,68 @@ macro_rules! min_max { }}; } +// min/max of two scalar values of the same type +macro_rules! min_max_v2 { + ($INDEX:ident, $ACC:ident, $SCALAR:expr, $OP:ident) => {{ + Ok(match $SCALAR { + ScalarValue::Float64(rhs) => { + typed_min_max_v2!($INDEX, $ACC, rhs, f64, $OP) + } + ScalarValue::Float32(rhs) => { + typed_min_max_v2!($INDEX, $ACC, rhs, f32, $OP) + } + ScalarValue::UInt64(rhs) => { + typed_min_max_v2!($INDEX, $ACC, rhs, u64, $OP) + } + ScalarValue::UInt32(rhs) => { + typed_min_max_v2!($INDEX, $ACC, rhs, u32, $OP) + } + ScalarValue::UInt16(rhs) => { + typed_min_max_v2!($INDEX, $ACC, rhs, u16, $OP) + } + ScalarValue::UInt8(rhs) => { + typed_min_max_v2!($INDEX, $ACC, rhs, u8, $OP) + } + ScalarValue::Int64(rhs) => { + typed_min_max_v2!($INDEX, $ACC, rhs, i64, $OP) + } + ScalarValue::Int32(rhs) => { + typed_min_max_v2!($INDEX, $ACC, rhs, i32, $OP) + } + ScalarValue::Int16(rhs) => { + typed_min_max_v2!($INDEX, $ACC, rhs, i16, $OP) + } + ScalarValue::Int8(rhs) => { + typed_min_max_v2!($INDEX, $ACC, rhs, i8, $OP) + } + e => { + return Err(DataFusionError::Internal(format!( + "MIN/MAX is not expected to receive scalars of incompatible types {:?}", + e + ))) + } + }) + }}; +} + /// the minimum of two scalar values pub fn min(lhs: &ScalarValue, rhs: &ScalarValue) -> Result { min_max!(lhs, rhs, min) } +pub fn min_row(index: usize, accessor: &mut RowAccessor, s: &ScalarValue) -> Result<()> { + min_max_v2!(index, accessor, s, min) +} + /// the maximum of two scalar values pub fn max(lhs: &ScalarValue, rhs: &ScalarValue) -> Result { min_max!(lhs, rhs, max) } +pub fn max_row(index: usize, accessor: &mut RowAccessor, s: &ScalarValue) -> Result<()> { + min_max_v2!(index, accessor, s, max) +} + /// An accumulator to compute the maximum value #[derive(Debug)] pub struct MaxAccumulator { @@ -454,6 +546,48 @@ impl Accumulator for MaxAccumulator { } } +#[derive(Debug)] +struct MaxRowAccumulator { + index: usize, + data_type: DataType, +} + +impl MaxRowAccumulator { + pub fn new(index: usize, data_type: DataType) -> Self { + Self { index, data_type } + } +} + +impl RowAccumulator for MaxRowAccumulator { + fn update_batch( + &mut self, + values: &[ArrayRef], + accessor: &mut RowAccessor, + ) -> Result<()> { + let values = &values[0]; + let delta = &max_batch(values)?; + max_row(self.index, accessor, delta)?; + Ok(()) + } + + fn merge_batch( + &mut self, + states: &[ArrayRef], + accessor: &mut RowAccessor, + ) -> Result<()> { + self.update_batch(states, accessor) + } + + fn evaluate(&self, accessor: &RowAccessor) -> Result { + Ok(accessor.get_as_scalar(&self.data_type, self.index)) + } + + #[inline(always)] + fn state_index(&self) -> usize { + self.index + } +} + /// MIN aggregate expression #[derive(Debug)] pub struct Min { @@ -512,6 +646,32 @@ impl AggregateExpr for Min { fn name(&self) -> &str { &self.name } + + fn row_accumulator_supported(&self) -> bool { + matches!( + self.data_type, + DataType::UInt8 + | DataType::UInt16 + | DataType::UInt32 + | DataType::UInt64 + | DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::Float32 + | DataType::Float64 + ) + } + + fn create_row_accumulator( + &self, + start_index: usize, + ) -> Result> { + Ok(Box::new(MinRowAccumulator::new( + start_index, + 
self.data_type.clone(), + ))) + } } /// An accumulator to compute the minimum value @@ -550,6 +710,48 @@ impl Accumulator for MinAccumulator { } } +#[derive(Debug)] +struct MinRowAccumulator { + index: usize, + data_type: DataType, +} + +impl MinRowAccumulator { + pub fn new(index: usize, data_type: DataType) -> Self { + Self { index, data_type } + } +} + +impl RowAccumulator for MinRowAccumulator { + fn update_batch( + &mut self, + values: &[ArrayRef], + accessor: &mut RowAccessor, + ) -> Result<()> { + let values = &values[0]; + let delta = &min_batch(values)?; + min_row(self.index, accessor, delta)?; + Ok(()) + } + + fn merge_batch( + &mut self, + states: &[ArrayRef], + accessor: &mut RowAccessor, + ) -> Result<()> { + self.update_batch(states, accessor) + } + + fn evaluate(&self, accessor: &RowAccessor) -> Result { + Ok(accessor.get_as_scalar(&self.data_type, self.index)) + } + + #[inline(always)] + fn state_index(&self) -> usize { + self.index + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/datafusion/physical-expr/src/aggregate/mod.rs b/datafusion/physical-expr/src/aggregate/mod.rs index 37719369d3e2..0db35d109c2d 100644 --- a/datafusion/physical-expr/src/aggregate/mod.rs +++ b/datafusion/physical-expr/src/aggregate/mod.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. +use crate::aggregate::row_accumulator::RowAccumulator; use crate::PhysicalExpr; use arrow::datatypes::Field; use datafusion_common::Result; @@ -39,6 +40,7 @@ pub(crate) mod covariance; pub(crate) mod min_max; pub mod build_in; mod hyperloglog; +pub mod row_accumulator; pub(crate) mod stats; pub(crate) mod stddev; pub(crate) mod sum; @@ -77,4 +79,21 @@ pub trait AggregateExpr: Send + Sync + Debug { fn name(&self) -> &str { "AggregateExpr: default name" } + + /// If the aggregate expression is supported by row format + fn row_accumulator_supported(&self) -> bool { + false + } + + /// RowAccumulator to access/update row-based aggregation state in-place. + /// Currently, row accumulator only supports states of fixed-sized type. + /// + /// We recommend implementing `RowAccumulator` along with the standard `Accumulator`, + /// when its state is of fixed size, as RowAccumulator is more memory efficient and CPU-friendly. + fn create_row_accumulator( + &self, + _start_index: usize, + ) -> Result> { + unreachable!() + } } diff --git a/datafusion/physical-expr/src/aggregate/row_accumulator.rs b/datafusion/physical-expr/src/aggregate/row_accumulator.rs new file mode 100644 index 000000000000..386787454f85 --- /dev/null +++ b/datafusion/physical-expr/src/aggregate/row_accumulator.rs @@ -0,0 +1,65 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! 
Accumulator over row format
+
+use arrow::array::ArrayRef;
+use datafusion_common::{Result, ScalarValue};
+use datafusion_row::accessor::RowAccessor;
+use std::fmt::Debug;
+
+/// Row-based accumulator where the internal aggregate state(s) are stored using row format.
+///
+/// Unlike the [`datafusion_expr::Accumulator`], the [`RowAccumulator`] does not store the state internally.
+/// Instead, it knows how to access/update the state stored in a row via the provided accessor and
+/// its state's starting field index in the row.
+///
+/// For example, when evaluating `SELECT a, sum(b), avg(c), count(d) FROM t GROUP BY a;`, we use one row to hold the
+/// aggregation state for each distinct `a` value; the index of the first and only state of `sum(b)` is 0,
+/// the index of the first state of `avg(c)` is 1, and the index of the first and only state of `count(d)` is 3:
+///
+/// sum(b) state_index = 0              count(d) state_index = 3
+///        |                                   |
+///        v                                   v
+///        +--------+----------+--------+----------+
+///        | sum(b) | count(c) | sum(c) | count(d) |
+///        +--------+----------+--------+----------+
+///                 ^
+///                 |
+///           avg(c) state_index = 1
+///
+pub trait RowAccumulator: Send + Sync + Debug {
+    /// updates the accumulator's state from a vector of arrays.
+    fn update_batch(
+        &mut self,
+        values: &[ArrayRef],
+        accessor: &mut RowAccessor,
+    ) -> Result<()>;
+
+    /// updates the accumulator's state from a vector of states.
+    fn merge_batch(
+        &mut self,
+        states: &[ArrayRef],
+        accessor: &mut RowAccessor,
+    ) -> Result<()>;
+
+    /// returns its value based on its current state.
+    fn evaluate(&self, accessor: &RowAccessor) -> Result<ScalarValue>;
+
+    /// State's starting field index in the row.
+    fn state_index(&self) -> usize;
+}
diff --git a/datafusion/physical-expr/src/aggregate/sum.rs b/datafusion/physical-expr/src/aggregate/sum.rs
index cca54733d246..c369e7af0081 100644
--- a/datafusion/physical-expr/src/aggregate/sum.rs
+++ b/datafusion/physical-expr/src/aggregate/sum.rs
@@ -34,9 +34,12 @@ use arrow::{
 use datafusion_common::{DataFusionError, Result, ScalarValue};
 use datafusion_expr::Accumulator;
 
+use crate::aggregate::row_accumulator::RowAccumulator;
 use crate::expressions::format_state_name;
 use arrow::array::Array;
 use arrow::array::DecimalArray;
+use arrow::compute::cast;
+use datafusion_row::accessor::RowAccessor;
 
 /// SUM aggregate expression
 #[derive(Debug)]
@@ -96,6 +99,32 @@ impl AggregateExpr for Sum {
     fn name(&self) -> &str {
         &self.name
     }
+
+    fn row_accumulator_supported(&self) -> bool {
+        matches!(
+            self.data_type,
+            DataType::UInt8
+                | DataType::UInt16
+                | DataType::UInt32
+                | DataType::UInt64
+                | DataType::Int8
+                | DataType::Int16
+                | DataType::Int32
+                | DataType::Int64
+                | DataType::Float32
+                | DataType::Float64
+        )
+    }
+
+    fn create_row_accumulator(
+        &self,
+        start_index: usize,
+    ) -> Result<Box<dyn RowAccumulator>> {
+        Ok(Box::new(SumRowAccumulator::new(
+            start_index,
+            self.data_type.clone(),
+        )))
+    }
 }
 
 #[derive(Debug)]
@@ -144,7 +173,8 @@ fn sum_decimal_batch(
 }
 
 // sums the array and returns a ScalarValue of its corresponding type.
-pub(crate) fn sum_batch(values: &ArrayRef) -> Result<ScalarValue> {
+pub(crate) fn sum_batch(values: &ArrayRef, sum_type: &DataType) -> Result<ScalarValue> {
+    let values = &cast(values, sum_type)?;
     Ok(match values.data_type() {
         DataType::Decimal(precision, scale) => {
             sum_decimal_batch(values, precision, scale)?
@@ -180,6 +210,17 @@ macro_rules! typed_sum {
     }};
 }
 
+macro_rules! sum_row {
+    ($INDEX:ident, $ACC:ident, $DELTA:expr, $TYPE:ident) => {{
+        paste::item!
{ + match $DELTA { + None => {} + Some(v) => $ACC.[]($INDEX, *v as $TYPE) + } + } + }}; +} + // TODO implement this in arrow-rs with simd // https://github.com/apache/arrow-rs/issues/1010 fn sum_decimal( @@ -284,7 +325,7 @@ pub(crate) fn sum(lhs: &ScalarValue, rhs: &ScalarValue) -> Result { (ScalarValue::UInt64(lhs), ScalarValue::UInt8(rhs)) => { typed_sum!(lhs, rhs, UInt64, u64) } - // i64 coerces i* to u64 + // i64 coerces i* to i64 (ScalarValue::Int64(lhs), ScalarValue::Int64(rhs)) => { typed_sum!(lhs, rhs, Int64, i64) } @@ -315,6 +356,84 @@ pub(crate) fn sum(lhs: &ScalarValue, rhs: &ScalarValue) -> Result { }) } +pub(crate) fn add_to_row( + dt: &DataType, + index: usize, + accessor: &mut RowAccessor, + s: &ScalarValue, +) -> Result<()> { + match (dt, s) { + // float64 coerces everything to f64 + (DataType::Float64, ScalarValue::Float64(rhs)) => { + sum_row!(index, accessor, rhs, f64) + } + (DataType::Float64, ScalarValue::Float32(rhs)) => { + sum_row!(index, accessor, rhs, f64) + } + (DataType::Float64, ScalarValue::Int64(rhs)) => { + sum_row!(index, accessor, rhs, f64) + } + (DataType::Float64, ScalarValue::Int32(rhs)) => { + sum_row!(index, accessor, rhs, f64) + } + (DataType::Float64, ScalarValue::Int16(rhs)) => { + sum_row!(index, accessor, rhs, f64) + } + (DataType::Float64, ScalarValue::Int8(rhs)) => { + sum_row!(index, accessor, rhs, f64) + } + (DataType::Float64, ScalarValue::UInt64(rhs)) => { + sum_row!(index, accessor, rhs, f64) + } + (DataType::Float64, ScalarValue::UInt32(rhs)) => { + sum_row!(index, accessor, rhs, f64) + } + (DataType::Float64, ScalarValue::UInt16(rhs)) => { + sum_row!(index, accessor, rhs, f64) + } + (DataType::Float64, ScalarValue::UInt8(rhs)) => { + sum_row!(index, accessor, rhs, f64) + } + // float32 has no cast + (DataType::Float32, ScalarValue::Float32(rhs)) => { + sum_row!(index, accessor, rhs, f32) + } + // u64 coerces u* to u64 + (DataType::UInt64, ScalarValue::UInt64(rhs)) => { + sum_row!(index, accessor, rhs, u64) + } + (DataType::UInt64, ScalarValue::UInt32(rhs)) => { + sum_row!(index, accessor, rhs, u64) + } + (DataType::UInt64, ScalarValue::UInt16(rhs)) => { + sum_row!(index, accessor, rhs, u64) + } + (DataType::UInt64, ScalarValue::UInt8(rhs)) => { + sum_row!(index, accessor, rhs, u64) + } + // i64 coerces i* to i64 + (DataType::Int64, ScalarValue::Int64(rhs)) => { + sum_row!(index, accessor, rhs, i64) + } + (DataType::Int64, ScalarValue::Int32(rhs)) => { + sum_row!(index, accessor, rhs, i64) + } + (DataType::Int64, ScalarValue::Int16(rhs)) => { + sum_row!(index, accessor, rhs, i64) + } + (DataType::Int64, ScalarValue::Int8(rhs)) => { + sum_row!(index, accessor, rhs, i64) + } + e => { + return Err(DataFusionError::Internal(format!( + "Row sum updater is not expected to receive a scalar {:?}", + e + ))); + } + } + Ok(()) +} + impl Accumulator for SumAccumulator { fn state(&self) -> Result> { Ok(vec![self.sum.clone()]) @@ -322,7 +441,7 @@ impl Accumulator for SumAccumulator { fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { let values = &values[0]; - self.sum = sum(&self.sum, &sum_batch(values)?)?; + self.sum = sum(&self.sum, &sum_batch(values, &self.sum.get_datatype())?)?; Ok(()) } @@ -338,6 +457,52 @@ impl Accumulator for SumAccumulator { } } +#[derive(Debug)] +struct SumRowAccumulator { + index: usize, + datatype: DataType, +} + +impl SumRowAccumulator { + pub fn new(index: usize, datatype: DataType) -> Self { + Self { index, datatype } + } +} + +impl RowAccumulator for SumRowAccumulator { + fn update_batch( + &mut self, + 
values: &[ArrayRef], + accessor: &mut RowAccessor, + ) -> Result<()> { + let values = &values[0]; + add_to_row( + &self.datatype, + self.index, + accessor, + &sum_batch(values, &self.datatype)?, + )?; + Ok(()) + } + + fn merge_batch( + &mut self, + states: &[ArrayRef], + accessor: &mut RowAccessor, + ) -> Result<()> { + self.update_batch(states, accessor) + } + + fn evaluate(&self, accessor: &RowAccessor) -> Result { + Ok(accessor.get_as_scalar(&self.datatype, self.index)) + } + + #[inline(always)] + fn state_index(&self) -> usize { + self.index + } +} + #[cfg(test)] mod tests { use super::*; @@ -379,7 +544,7 @@ mod tests { .collect::() .with_precision_and_scale(10, 0)?, ); - let result = sum_batch(&array)?; + let result = sum_batch(&array, &DataType::Decimal(10, 0))?; assert_eq!(ScalarValue::Decimal128(Some(15), 10, 0), result); // test agg @@ -414,7 +579,7 @@ mod tests { .collect::() .with_precision_and_scale(10, 0)?, ); - let result = sum_batch(&array)?; + let result = sum_batch(&array, &DataType::Decimal(10, 0))?; assert_eq!(ScalarValue::Decimal128(Some(13), 10, 0), result); // test agg @@ -448,7 +613,7 @@ mod tests { .collect::() .with_precision_and_scale(10, 0)?, ); - let result = sum_batch(&array)?; + let result = sum_batch(&array, &DataType::Decimal(10, 0))?; assert_eq!(ScalarValue::Decimal128(None, 10, 0), result); // test agg diff --git a/datafusion/row/src/accessor.rs b/datafusion/row/src/accessor.rs new file mode 100644 index 000000000000..b6ec41d3345b --- /dev/null +++ b/datafusion/row/src/accessor.rs @@ -0,0 +1,296 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Setter/Getter for row with all fixed-sized fields. + +use crate::layout::{RowLayout, RowType}; +use crate::validity::NullBitsFormatter; +use crate::{fn_get_idx, fn_get_idx_opt, fn_set_idx}; +use arrow::datatypes::{DataType, Schema}; +use arrow::util::bit_util::{get_bit_raw, set_bit_raw}; +use datafusion_common::ScalarValue; +use std::sync::Arc; + +//TODO: DRY with reader and writer + +/// Read the tuple `data[base_offset..]` we are currently pointing to +pub struct RowAccessor<'a> { + /// Layout on how to read each field + layout: Arc, + /// Raw bytes slice where the tuple stores + data: &'a mut [u8], + /// Start position for the current tuple in the raw bytes slice. + base_offset: usize, +} + +impl<'a> std::fmt::Debug for RowAccessor<'a> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + if self.null_free() { + write!(f, "null_free") + } else { + let null_bits = self.null_bits(); + write!( + f, + "{:?}", + NullBitsFormatter::new(null_bits, self.layout.field_count) + ) + } + } +} + +#[macro_export] +macro_rules! fn_add_idx { + ($NATIVE: ident) => { + paste::item! 
{ + /// add field at `idx` with `value` + pub fn [](&mut self, idx: usize, value: $NATIVE) { + if self.is_valid_at(idx) { + self.[](idx, value + self.[](idx)); + } else { + self.set_non_null_at(idx); + self.[](idx, value); + } + } + } + }; +} + +macro_rules! fn_max_min_idx { + ($NATIVE: ident, $OP: ident) => { + paste::item! { + /// check max then update + pub fn [<$OP _ $NATIVE>](&mut self, idx: usize, value: $NATIVE) { + if self.is_valid_at(idx) { + let v = value.$OP(self.[](idx)); + self.[](idx, v); + } else { + self.set_non_null_at(idx); + self.[](idx, value); + } + } + } + }; +} + +macro_rules! fn_get_idx_scalar { + ($NATIVE: ident, $SCALAR:ident) => { + paste::item! { + pub fn [](&self, idx: usize) -> ScalarValue { + if self.is_valid_at(idx) { + ScalarValue::$SCALAR(Some(self.[](idx))) + } else { + ScalarValue::$SCALAR(None) + } + } + } + }; +} + +impl<'a> RowAccessor<'a> { + /// new + pub fn new(schema: &Schema, row_type: RowType) -> Self { + Self { + layout: Arc::new(RowLayout::new(schema, row_type)), + data: &mut [], + base_offset: 0, + } + } + + pub fn new_from_layout(layout: Arc) -> Self { + Self { + layout, + data: &mut [], + base_offset: 0, + } + } + + /// Update this row to point to position `offset` in `base` + pub fn point_to(&mut self, offset: usize, data: &'a mut [u8]) { + self.base_offset = offset; + self.data = data; + } + + #[inline] + fn assert_index_valid(&self, idx: usize) { + assert!(idx < self.layout.field_count); + } + + #[inline(always)] + fn field_offsets(&self) -> &[usize] { + &self.layout.field_offsets + } + + #[inline(always)] + fn null_free(&self) -> bool { + self.layout.null_free + } + + #[inline(always)] + fn null_bits(&self) -> &[u8] { + if self.null_free() { + &[] + } else { + let start = self.base_offset; + &self.data[start..start + self.layout.null_width] + } + } + + fn is_valid_at(&self, idx: usize) -> bool { + unsafe { get_bit_raw(self.null_bits().as_ptr(), idx) } + } + + // ------------------------------ + // ----- Fixed Sized getters ---- + // ------------------------------ + + fn get_bool(&self, idx: usize) -> bool { + self.assert_index_valid(idx); + let offset = self.field_offsets()[idx]; + let value = &self.data[self.base_offset + offset..]; + value[0] != 0 + } + + fn get_u8(&self, idx: usize) -> u8 { + self.assert_index_valid(idx); + let offset = self.field_offsets()[idx]; + self.data[self.base_offset + offset] + } + + fn_get_idx!(u16, 2); + fn_get_idx!(u32, 4); + fn_get_idx!(u64, 8); + fn_get_idx!(i8, 1); + fn_get_idx!(i16, 2); + fn_get_idx!(i32, 4); + fn_get_idx!(i64, 8); + fn_get_idx!(f32, 4); + fn_get_idx!(f64, 8); + + fn_get_idx_opt!(bool); + fn_get_idx_opt!(u8); + fn_get_idx_opt!(u16); + fn_get_idx_opt!(u32); + fn_get_idx_opt!(u64); + fn_get_idx_opt!(i8); + fn_get_idx_opt!(i16); + fn_get_idx_opt!(i32); + fn_get_idx_opt!(i64); + fn_get_idx_opt!(f32); + fn_get_idx_opt!(f64); + + fn_get_idx_scalar!(bool, Boolean); + fn_get_idx_scalar!(u8, UInt8); + fn_get_idx_scalar!(u16, UInt16); + fn_get_idx_scalar!(u32, UInt32); + fn_get_idx_scalar!(u64, UInt64); + fn_get_idx_scalar!(i8, Int8); + fn_get_idx_scalar!(i16, Int16); + fn_get_idx_scalar!(i32, Int32); + fn_get_idx_scalar!(i64, Int64); + fn_get_idx_scalar!(f32, Float32); + fn_get_idx_scalar!(f64, Float64); + + pub fn get_as_scalar(&self, dt: &DataType, index: usize) -> ScalarValue { + match dt { + DataType::Boolean => self.get_bool_scalar(index), + DataType::Int8 => self.get_i8_scalar(index), + DataType::Int16 => self.get_i16_scalar(index), + DataType::Int32 => self.get_i32_scalar(index), + 
DataType::Int64 => self.get_i64_scalar(index), + DataType::UInt8 => self.get_u8_scalar(index), + DataType::UInt16 => self.get_u16_scalar(index), + DataType::UInt32 => self.get_u32_scalar(index), + DataType::UInt64 => self.get_u64_scalar(index), + DataType::Float32 => self.get_f32_scalar(index), + DataType::Float64 => self.get_f64_scalar(index), + _ => unreachable!(), + } + } + + // ------------------------------ + // ----- Fixed Sized setters ---- + // ------------------------------ + + pub(crate) fn set_non_null_at(&mut self, idx: usize) { + assert!( + !self.null_free(), + "Unexpected call to set_non_null_at on null-free row writer" + ); + let null_bits = &mut self.data[0..self.layout.null_width]; + unsafe { + set_bit_raw(null_bits.as_mut_ptr(), idx); + } + } + + fn set_u8(&mut self, idx: usize, value: u8) { + self.assert_index_valid(idx); + let offset = self.field_offsets()[idx]; + self.data[offset] = value; + } + + fn_set_idx!(u16, 2); + fn_set_idx!(u32, 4); + fn_set_idx!(u64, 8); + fn_set_idx!(i16, 2); + fn_set_idx!(i32, 4); + fn_set_idx!(i64, 8); + fn_set_idx!(f32, 4); + fn_set_idx!(f64, 8); + + fn set_i8(&mut self, idx: usize, value: i8) { + self.assert_index_valid(idx); + let offset = self.field_offsets()[idx]; + self.data[offset] = value.to_le_bytes()[0]; + } + + // ------------------------------ + // ---- Fixed sized updaters ---- + // ------------------------------ + + fn_add_idx!(u8); + fn_add_idx!(u16); + fn_add_idx!(u32); + fn_add_idx!(u64); + fn_add_idx!(i8); + fn_add_idx!(i16); + fn_add_idx!(i32); + fn_add_idx!(i64); + fn_add_idx!(f32); + fn_add_idx!(f64); + + fn_max_min_idx!(u8, max); + fn_max_min_idx!(u16, max); + fn_max_min_idx!(u32, max); + fn_max_min_idx!(u64, max); + fn_max_min_idx!(i8, max); + fn_max_min_idx!(i16, max); + fn_max_min_idx!(i32, max); + fn_max_min_idx!(i64, max); + fn_max_min_idx!(f32, max); + fn_max_min_idx!(f64, max); + + fn_max_min_idx!(u8, min); + fn_max_min_idx!(u16, min); + fn_max_min_idx!(u32, min); + fn_max_min_idx!(u64, min); + fn_max_min_idx!(i8, min); + fn_max_min_idx!(i16, min); + fn_max_min_idx!(i32, min); + fn_max_min_idx!(i64, min); + fn_max_min_idx!(f32, min); + fn_max_min_idx!(f64, min); +} diff --git a/datafusion/row/src/layout.rs b/datafusion/row/src/layout.rs index b017d195836d..0c92025a74f4 100644 --- a/datafusion/row/src/layout.rs +++ b/datafusion/row/src/layout.rs @@ -38,8 +38,8 @@ pub enum RowType { } /// Reveals how the fields of a record are stored in the raw-bytes format -#[derive(Debug)] -pub(crate) struct RowLayout { +#[derive(Debug, Clone)] +pub struct RowLayout { /// Type of the layout row_type: RowType, /// If a row is null free according to its schema @@ -55,8 +55,14 @@ pub(crate) struct RowLayout { } impl RowLayout { - pub(crate) fn new(schema: &Schema, row_type: RowType) -> Self { - assert!(row_supported(schema, row_type)); + /// new + pub fn new(schema: &Schema, row_type: RowType) -> Self { + assert!( + row_supported(schema, row_type), + "{:?}Row with {:?} not supported yet.", + row_type, + schema, + ); let null_free = schema_null_free(schema); let field_count = schema.fields().len(); let null_width = if null_free { @@ -81,8 +87,9 @@ impl RowLayout { } } + /// Get fixed part width for this layout #[inline(always)] - pub(crate) fn fixed_part_width(&self) -> usize { + pub fn fixed_part_width(&self) -> usize { self.null_width + self.values_width } } @@ -149,7 +156,7 @@ pub(crate) fn estimate_row_width(schema: &Schema, layout: &RowLayout) -> usize { /// Tell if we can create raw-bytes based rows since we currently /// has 
limited data type supports in the row format -fn row_supported(schema: &Schema, row_type: RowType) -> bool { +pub fn row_supported(schema: &Schema, row_type: RowType) -> bool { schema .fields() .iter() diff --git a/datafusion/row/src/lib.rs b/datafusion/row/src/lib.rs index 54c112dd5e06..c05cbcd0ef1c 100644 --- a/datafusion/row/src/lib.rs +++ b/datafusion/row/src/lib.rs @@ -47,13 +47,15 @@ //! 0 1 2 10 14 22 31 32 //! -use arrow::array::{make_builder, ArrayBuilder}; +use arrow::array::{make_builder, ArrayBuilder, ArrayRef}; use arrow::datatypes::Schema; use arrow::error::Result as ArrowResult; use arrow::record_batch::RecordBatch; +pub use layout::row_supported; pub use layout::RowType; use std::sync::Arc; +pub mod accessor; #[cfg(feature = "jit")] pub mod jit; pub mod layout; @@ -84,6 +86,10 @@ impl MutableRecordBatch { let result = make_batch(self.schema.clone(), self.arrays.drain(..).collect()); result } + + pub fn output_as_columns(&mut self) -> Vec { + get_columns(self.arrays.drain(..).collect()) + } } fn new_arrays(schema: &Schema, batch_size: usize) -> Vec> { @@ -105,6 +111,10 @@ fn make_batch( RecordBatch::try_new(schema, columns) } +fn get_columns(mut arrays: Vec>) -> Vec { + arrays.iter_mut().map(|array| array.finish()).collect() +} + #[cfg(test)] mod tests { use super::*; @@ -341,7 +351,7 @@ mod tests { ); #[test] - #[should_panic(expected = "row_supported(schema, row_type)")] + #[should_panic(expected = "not supported yet")] fn test_unsupported_word_aligned_type() { let a: ArrayRef = Arc::new(StringArray::from(vec!["hello", "world"])); let batch = RecordBatch::try_from_iter(vec![("a", a)]).unwrap(); @@ -380,7 +390,7 @@ mod tests { } #[test] - #[should_panic(expected = "row_supported(schema, row_type)")] + #[should_panic(expected = "not supported yet")] fn test_unsupported_type_write() { let a: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![8, 7, 6, 5, 8])); let batch = RecordBatch::try_from_iter(vec![("a", a)]).unwrap(); @@ -390,7 +400,7 @@ mod tests { } #[test] - #[should_panic(expected = "row_supported(schema, row_type)")] + #[should_panic(expected = "not supported yet")] fn test_unsupported_type_read() { let schema = Arc::new(Schema::new(vec![Field::new( "a", diff --git a/datafusion/row/src/reader.rs b/datafusion/row/src/reader.rs index e7ee004b0076..1bf6e102a9f2 100644 --- a/datafusion/row/src/reader.rs +++ b/datafusion/row/src/reader.rs @@ -46,6 +46,7 @@ pub fn read_as_batch( output.output().map_err(DataFusionError::ArrowError) } +#[macro_export] macro_rules! get_idx { ($NATIVE: ident, $SELF: ident, $IDX: ident, $WIDTH: literal) => {{ $SELF.assert_index_valid($IDX); @@ -56,6 +57,7 @@ macro_rules! get_idx { }}; } +#[macro_export] macro_rules! fn_get_idx { ($NATIVE: ident, $WIDTH: literal) => { paste::item! { @@ -70,10 +72,11 @@ macro_rules! fn_get_idx { }; } +#[macro_export] macro_rules! fn_get_idx_opt { ($NATIVE: ident) => { paste::item! { - fn [](&self, idx: usize) -> Option<$NATIVE> { + pub fn [](&self, idx: usize) -> Option<$NATIVE> { if self.is_valid_at(idx) { Some(self.[](idx)) } else { diff --git a/datafusion/row/src/writer.rs b/datafusion/row/src/writer.rs index 6b9ffdc0e31d..d71e1dbc073c 100644 --- a/datafusion/row/src/writer.rs +++ b/datafusion/row/src/writer.rs @@ -75,6 +75,7 @@ pub fn bench_write_batch( Ok(lengths) } +#[macro_export] macro_rules! set_idx { ($WIDTH: literal, $SELF: ident, $IDX: ident, $VALUE: ident) => {{ $SELF.assert_index_valid($IDX); @@ -83,6 +84,7 @@ macro_rules! set_idx { }}; } +#[macro_export] macro_rules! 
fn_set_idx { ($NATIVE: ident, $WIDTH: literal) => { paste::item! {