Generate hash aggregation output in smaller record batches #3461

Merged 3 commits on Oct 15, 2022
2 changes: 2 additions & 0 deletions datafusion/core/src/physical_plan/aggregates/mod.rs
@@ -298,6 +298,7 @@ impl ExecutionPlan for AggregateExec {
         partition: usize,
         context: Arc<TaskContext>,
     ) -> Result<SendableRecordBatchStream> {
+        let batch_size = context.session_config().batch_size();
         let input = self.input.execute(partition, context)?;
 
         let baseline_metrics = BaselineMetrics::new(&self.metrics, partition);
@@ -318,6 +319,7 @@ impl ExecutionPlan for AggregateExec {
                 self.aggr_expr.clone(),
                 input,
                 baseline_metrics,
+                batch_size,
             )?))
         } else {
             Ok(Box::pin(GroupedHashAggregateStream::new(
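For context: the `batch_size` read here is the session-level setting exposed through `SessionConfig`. A minimal sketch of tuning it from user code, assuming the `SessionConfig::with_batch_size` builder from DataFusion's public API at the time of this PR (the helper name is hypothetical):

use datafusion::prelude::{SessionConfig, SessionContext};

fn context_with_small_batches() -> SessionContext {
    // With this PR, the hash aggregation stream emits RecordBatches of
    // at most this many rows instead of one batch holding every group.
    // (DataFusion's default batch size is 8192.)
    let config = SessionConfig::new().with_batch_size(1024);
    SessionContext::with_config(config)
}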
121 changes: 72 additions & 49 deletions datafusion/core/src/physical_plan/aggregates/row_hash.rs
@@ -85,7 +85,11 @@ pub(crate) struct GroupedHashAggregateStreamV2 {
 
     baseline_metrics: BaselineMetrics,
     random_state: RandomState,
-    finished: bool,
+    /// size to be used for resulting RecordBatches
+    batch_size: usize,
+    /// if the result is chunked into batches,
+    /// last offset is preserved for continuation.
+    row_group_skip_position: usize,
 }
 
 fn aggr_state_schema(aggr_expr: &[Arc<dyn AggregateExpr>]) -> Result<SchemaRef> {
@@ -105,6 +109,7 @@ impl GroupedHashAggregateStreamV2 {
         aggr_expr: Vec<Arc<dyn AggregateExpr>>,
         input: SendableRecordBatchStream,
         baseline_metrics: BaselineMetrics,
+        batch_size: usize,
     ) -> Result<Self> {
         let timer = baseline_metrics.elapsed_compute().timer();
@@ -135,7 +140,8 @@ impl GroupedHashAggregateStreamV2 {
             aggregate_expressions,
             aggr_state: Default::default(),
             random_state: Default::default(),
-            finished: false,
+            batch_size,
+            row_group_skip_position: 0,
         })
     }
 }
@@ -148,56 +154,62 @@ impl Stream for GroupedHashAggregateStreamV2 {
         cx: &mut Context<'_>,
     ) -> Poll<Option<Self::Item>> {
         let this = &mut *self;
-        if this.finished {
-            return Poll::Ready(None);
-        }
 
         let elapsed_compute = this.baseline_metrics.elapsed_compute();
 
         loop {
-            let result = match ready!(this.input.poll_next_unpin(cx)) {
-                Some(Ok(batch)) => {
-                    let timer = elapsed_compute.timer();
-                    let result = group_aggregate_batch(
-                        &this.mode,
-                        &this.random_state,
-                        &this.group_by,
-                        &mut this.accumulators,
-                        &this.group_schema,
-                        this.aggr_layout.clone(),
-                        batch,
-                        &mut this.aggr_state,
-                        &this.aggregate_expressions,
-                    );
-
-                    timer.done();
-
-                    match result {
-                        Ok(_) => continue,
-                        Err(e) => Err(ArrowError::ExternalError(Box::new(e))),
-                    }
-                }
-                Some(Err(e)) => Err(e),
-                None => {
-                    this.finished = true;
-                    let timer = this.baseline_metrics.elapsed_compute().timer();
-                    let result = create_batch_from_map(
-                        &this.mode,
-                        &this.group_schema,
-                        &this.aggr_schema,
-                        &mut this.aggr_state,
-                        &mut this.accumulators,
-                        &this.schema,
-                    )
-                    .record_output(&this.baseline_metrics);
-
-                    timer.done();
-                    result
-                }
-            };
-
-            this.finished = true;
-            return Poll::Ready(Some(result));
+            let result: ArrowResult<Option<RecordBatch>> =
+                match ready!(this.input.poll_next_unpin(cx)) {
+                    Some(Ok(batch)) => {
+                        let timer = elapsed_compute.timer();
+                        let result = group_aggregate_batch(
+                            &this.mode,
+                            &this.random_state,
+                            &this.group_by,
+                            &mut this.accumulators,
+                            &this.group_schema,
+                            this.aggr_layout.clone(),
+                            batch,
+                            &mut this.aggr_state,
+                            &this.aggregate_expressions,
+                        );
+
+                        timer.done();
+
+                        match result {
+                            Ok(_) => continue,
+                            Err(e) => Err(ArrowError::ExternalError(Box::new(e))),
+                        }
+                    }
+                    Some(Err(e)) => Err(e),
+                    None => {
+                        let timer = this.baseline_metrics.elapsed_compute().timer();
+                        let result = create_batch_from_map(
+                            &this.mode,
+                            &this.group_schema,
+                            &this.aggr_schema,
+                            this.batch_size,
+                            this.row_group_skip_position,
+                            &mut this.aggr_state,
+                            &mut this.accumulators,
+                            &this.schema,
+                        );
+
+                        timer.done();
+                        result
+                    }
+                };
+
+            this.row_group_skip_position += this.batch_size;
+            match result {
+                Ok(Some(result)) => {
+                    return Poll::Ready(Some(Ok(
+                        result.record_output(&this.baseline_metrics)
+                    )))
+                }
+                Ok(None) => return Poll::Ready(None),
+                Err(error) => return Poll::Ready(Some(Err(error))),
+            }
         }
     }
 }
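The hunk above replaces the old `finished` flag with an offset: once the input is exhausted, each poll emits at most `batch_size` groups and then advances `row_group_skip_position`; the stream ends when `create_batch_from_map` returns `Ok(None)`, meaning the offset has moved past the last group. A standalone sketch of that draining pattern, using hypothetical `next_chunk`/`drain_all` helpers that are not part of the PR:

/// Emit one chunk of at most `batch_size` items starting at `skip`, or
/// `None` once `skip` has moved past the end -- mirroring the
/// `skip_items > group_states.len()` check in `create_batch_from_map`.
fn next_chunk(groups: &[u64], skip: usize, batch_size: usize) -> Option<Vec<u64>> {
    if skip > groups.len() {
        return None;
    }
    Some(groups.iter().skip(skip).take(batch_size).copied().collect())
}

fn drain_all(groups: &[u64], batch_size: usize) -> Vec<Vec<u64>> {
    let mut chunks = Vec::new();
    let mut skip = 0;
    while let Some(chunk) = next_chunk(groups, skip, batch_size) {
        chunks.push(chunk);
        skip += batch_size; // mirrors `row_group_skip_position += batch_size`
    }
    chunks
}

With 10 groups and a batch size of 4, `drain_all` yields chunks of 4, 4, and 2 items, then stops; the real stream does the same with RecordBatches.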
@@ -419,23 +431,34 @@ fn create_group_rows(arrays: Vec<ArrayRef>, schema: &Schema) -> Vec<Vec<u8>> {
 }
 
 /// Create a RecordBatch with all group keys and accumulator' states or values.
+#[allow(clippy::too_many_arguments)]
 fn create_batch_from_map(
     mode: &AggregateMode,
     group_schema: &Schema,
     aggr_schema: &Schema,
+    batch_size: usize,
+    skip_items: usize,
     aggr_state: &mut AggregationState,
     accumulators: &mut [AccumulatorItemV2],
     output_schema: &Schema,
-) -> ArrowResult<RecordBatch> {
+) -> ArrowResult<Option<RecordBatch>> {
+    if skip_items > aggr_state.group_states.len() {
+        return Ok(None);
+    }
+
     if aggr_state.group_states.is_empty() {
-        return Ok(RecordBatch::new_empty(Arc::new(output_schema.to_owned())));
+        return Ok(Some(RecordBatch::new_empty(Arc::new(
+            output_schema.to_owned(),
+        ))));
     }
 
     let mut state_accessor = RowAccessor::new(aggr_schema, RowType::WordAligned);
 
     let (group_buffers, mut state_buffers): (Vec<_>, Vec<_>) = aggr_state
         .group_states
         .iter()
+        .skip(skip_items)
+        .take(batch_size)
         .map(|gs| (gs.group_by_values.clone(), gs.aggregation_buffer.clone()))
         .unzip();
 
@@ -471,7 +494,7 @@ fn create_batch_from_map(
         .map(|(col, desired_field)| cast(col, desired_field.data_type()))
         .collect::<ArrowResult<Vec<_>>>()?;
 
-    RecordBatch::try_new(Arc::new(output_schema.to_owned()), columns)
+    RecordBatch::try_new(Arc::new(output_schema.to_owned()), columns).map(Some)
 }
 
 fn read_as_batch(rows: &[Vec<u8>], schema: &Schema, row_type: RowType) -> Vec<ArrayRef> {
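End to end, the effect of the PR is that a grouped aggregation with many distinct keys streams its output in `batch_size`-row chunks instead of materializing one huge batch. A hedged usage sketch against the DataFusion API of this era (`SessionContext::with_config`, `MemTable::try_new`); it asserts only the total row count, since downstream operators may coalesce the small batches again:

use std::sync::Arc;

use datafusion::arrow::array::Int32Array;
use datafusion::arrow::datatypes::{DataType, Field, Schema};
use datafusion::arrow::record_batch::RecordBatch;
use datafusion::datasource::MemTable;
use datafusion::prelude::{SessionConfig, SessionContext};

#[tokio::main]
async fn main() -> datafusion::error::Result<()> {
    // 100 distinct group keys; with batch_size = 16 the hash aggregate
    // emits its result as several small batches rather than one.
    let schema = Arc::new(Schema::new(vec![Field::new("k", DataType::Int32, false)]));
    let input = RecordBatch::try_new(
        schema.clone(),
        vec![Arc::new(Int32Array::from_iter_values(0..100))],
    )?;
    let table = MemTable::try_new(schema, vec![vec![input]])?;

    let ctx = SessionContext::with_config(SessionConfig::new().with_batch_size(16));
    ctx.register_table("t", Arc::new(table))?;

    let batches = ctx
        .sql("SELECT k, COUNT(*) AS c FROM t GROUP BY k")
        .await?
        .collect()
        .await?;

    let rows: usize = batches.iter().map(|b| b.num_rows()).sum();
    assert_eq!(rows, 100); // every group survives the chunked output path
    Ok(())
}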