Commit

Fix: Pass all tests and make sure RoundRobinBatch is not applied after setting
Weijun-H committed Feb 5, 2025
1 parent b93fa5c commit 1cb85dd
Showing 1 changed file with 46 additions and 21 deletions.
67 changes: 46 additions & 21 deletions datafusion/physical-plan/src/repartition/on_demand_repartition.rs
@@ -35,6 +35,7 @@ use super::{
use crate::common::SharedMemoryReservation;
use crate::execution_plan::CardinalityEffect;
use crate::metrics::{self, BaselineMetrics, MetricBuilder};
+ use crate::projection::{all_columns, make_with_child, ProjectionExec};
use crate::repartition::distributor_channels::{
DistributionReceiver, DistributionSender,
};
@@ -202,7 +203,7 @@ impl ExecutionPlan for OnDemandRepartitionExec {
}

fn benefits_from_input_partitioning(&self) -> Vec<bool> {
- vec![matches!(self.partitioning(), Partitioning::Hash(_, _))]
+ vec![false]
}

fn maintains_input_order(&self) -> Vec<bool> {
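
Note: `benefits_from_input_partitioning` now always answers `false`, so the distribution-enforcement rule no longer inserts an extra RoundRobinBatch repartition below this operator just because its output is hash partitioned, which is what the commit title refers to. A self-contained toy sketch of how such a flag feeds that decision (a model of the idea, not DataFusion's actual rule):

// Toy model of the decision an EnforceDistribution-style rule makes with this
// flag: only a child that reports `true` gets an extra RoundRobinBatch
// repartition inserted to raise parallelism.
fn should_add_round_robin(
    child_benefits: bool,
    input_partitions: usize,
    target_partitions: usize,
) -> bool {
    child_benefits && input_partitions < target_partitions
}

fn main() {
    // Old behaviour: a hash-partitioned OnDemandRepartitionExec reported `true`.
    assert!(should_add_round_robin(true, 1, 8));
    // New behaviour: `vec![false]` leaves the input as-is.
    assert!(!should_add_round_robin(false, 1, 8));
}
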
@@ -247,10 +248,10 @@ impl ExecutionPlan for OnDemandRepartitionExec {
.get_or_init(|| async move {
let (txs, rxs) = if preserve_order {
(0..num_input_partitions)
- .map(|_| async_channel::bounded(2))
+ .map(|_| async_channel::unbounded())
.unzip::<_, _, Vec<_>, Vec<_>>()
} else {
- let (tx, rx) = async_channel::bounded(2);
+ let (tx, rx) = async_channel::unbounded();
(vec![tx], vec![rx])
};
Mutex::new((txs, rxs))
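
The partition-request channels above move from `async_channel::bounded(2)` to `async_channel::unbounded()`, so enqueueing a request never blocks the sender; a send only fails once every receiver has been dropped. A minimal, self-contained sketch of that difference (running it on `futures::executor::block_on` is just an illustrative choice):

fn main() {
    futures::executor::block_on(async {
        // Bounded: capacity 2, so a third `try_send` fails with `Full` and a
        // third `send().await` would suspend until the receiver drains a slot.
        let (tx, _rx) = async_channel::bounded::<usize>(2);
        tx.send(0).await.unwrap();
        tx.send(1).await.unwrap();
        assert!(tx.try_send(2).is_err());

        // Unbounded: sends never wait for capacity; they only fail after every
        // receiver has been dropped (the channel is then closed).
        let (utx, urx) = async_channel::unbounded::<usize>();
        for i in 0..1_000 {
            utx.try_send(i).unwrap();
        }
        drop(urx);
        assert!(utx.send(1_000).await.is_err());
    });
}
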
@@ -365,6 +366,30 @@ impl ExecutionPlan for OnDemandRepartitionExec {
fn cardinality_effect(&self) -> CardinalityEffect {
CardinalityEffect::Equal
}

+ fn try_swapping_with_projection(
+     &self,
+     projection: &ProjectionExec,
+ ) -> Result<Option<Arc<dyn ExecutionPlan>>> {
+     // If the projection does not narrow the schema, we should not try to push it down.
+     if projection.expr().len() >= projection.input().schema().fields().len() {
+         return Ok(None);
+     }
+
+     // If pushdown is not beneficial or applicable, break it.
+     if projection.benefits_from_input_partitioning()[0]
+         || !all_columns(projection.expr())
+     {
+         return Ok(None);
+     }
+
+     let new_projection = make_with_child(projection, self.input())?;
+
+     Ok(Some(Arc::new(OnDemandRepartitionExec::try_new(
+         new_projection,
+         self.partitioning().clone(),
+     )?)))
+ }
}

impl OnDemandRepartitionExec {
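
The new `try_swapping_with_projection` hook lets projection pushdown move a narrowing, column-only `ProjectionExec` below the repartition, so only the needed columns are buffered and re-sent. Roughly, with illustrative column names and an illustrative scan node:

Before:
  ProjectionExec: expr=[a, b]
    OnDemandRepartitionExec
      DataSourceExec: columns=[a, b, c]

After:
  OnDemandRepartitionExec
    ProjectionExec: expr=[a, b]
      DataSourceExec: columns=[a, b, c]
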
@@ -396,7 +421,7 @@ impl OnDemandRepartitionExec {
async fn process_input(
input: Arc<dyn ExecutionPlan>,
partition: usize,
- buffer_tx: tokio::sync::mpsc::Sender<RecordBatch>,
+ buffer_tx: Sender<RecordBatch>,
context: Arc<TaskContext>,
fetch_time: metrics::Time,
send_buffer_time: metrics::Time,
@@ -452,7 +477,7 @@ impl OnDemandRepartitionExec {
context: Arc<TaskContext>,
) -> Result<()> {
// execute the child operator in a separate task
- let (buffer_tx, mut buffer_rx) = tokio::sync::mpsc::channel::<RecordBatch>(2);
+ let (buffer_tx, buffer_rx) = async_channel::bounded::<RecordBatch>(2);
let processing_task = SpawnedTask::spawn(Self::process_input(
Arc::clone(&input),
partition,
@@ -467,8 +492,8 @@ impl OnDemandRepartitionExec {
while !output_channels.is_empty() {
// When the input is done, break the loop
let batch = match buffer_rx.recv().await {
-     Some(result) => result,
-     None => break,
+     Ok(batch) => batch,
+     _ => break,
};

// Get the partition number from the output partition
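
With the buffer now an `async_channel`, `recv()` returns `Result<RecordBatch, RecvError>` instead of the `Option<RecordBatch>` that the tokio mpsc receiver produced; the loop keeps going on `Ok` and breaks once the sender has been dropped and the channel is drained. A self-contained sketch of that drain pattern on a placeholder payload (not a `RecordBatch`):

fn main() {
    futures::executor::block_on(async {
        let (tx, rx) = async_channel::bounded::<u32>(2);

        // Producer: sends a few items and then drops the sender, which closes
        // the channel once it is empty.
        let producer = async move {
            for i in 0..5u32 {
                tx.send(i).await.unwrap();
            }
        };

        // Consumer: the same match shape as the loop above -- `Ok` keeps
        // receiving, anything else (channel closed and empty) breaks.
        let consumer = async move {
            let mut received = Vec::new();
            loop {
                match rx.recv().await {
                    Ok(item) => received.push(item),
                    _ => break,
                }
            }
            received
        };

        let (_, received) = futures::join!(producer, consumer);
        assert_eq!(received, vec![0, 1, 2, 3, 4]);
    });
}
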
@@ -595,13 +620,13 @@ impl Stream for OnDemandPerPartitionStream {
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Self::Item>> {
- if !self.is_requested {
-     match self.sender.try_send(self.partition) {
-         Ok(_) => {}
-         Err(_) => {
-             return Poll::Ready(None);
-         }
-     }
+ if !self.is_requested && !self.sender.is_closed() {
+     self.sender.try_send(self.partition).map_err(|_| {
+         internal_datafusion_err!(
+             "Error sending partition number to the receiver for partition {}",
+             self.partition
+         )
+     })?;
self.is_requested = true;
}
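
Both repartition streams now guard the request with `Sender::is_closed()` and turn a failed `try_send` into an error rather than silently ending the stream with `Poll::Ready(None)`. A self-contained sketch of the two `async_channel::Sender` calls involved (the plain string errors stand in for `internal_datafusion_err!`):

use async_channel::TrySendError;

fn request_partition(sender: &async_channel::Sender<usize>, partition: usize) -> Result<(), String> {
    // Mirrors the guard added above: only attempt the send while the receiving
    // side is still alive, and surface any failure instead of swallowing it.
    if !sender.is_closed() {
        sender.try_send(partition).map_err(|e| match e {
            TrySendError::Full(p) => format!("request channel full for partition {p}"),
            TrySendError::Closed(p) => format!("request channel closed for partition {p}"),
        })?;
    }
    Ok(())
}

fn main() {
    let (tx, rx) = async_channel::unbounded::<usize>();
    assert!(request_partition(&tx, 0).is_ok());

    // Once the receiver is gone the guard short-circuits, so no error is raised
    // for partitions that no longer have a consumer.
    drop(rx);
    assert!(tx.is_closed());
    assert!(request_partition(&tx, 1).is_ok());
}
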

@@ -667,13 +692,13 @@ impl Stream for OnDemandRepartitionStream {
) -> Poll<Option<Self::Item>> {
loop {
// Send partition number to input partitions
- if !self.is_requested {
-     match self.sender.try_send(self.partition) {
-         Ok(_) => {}
-         Err(_) => {
-             return Poll::Ready(None);
-         }
-     }
+ if !self.is_requested && !self.sender.is_closed() {
+     self.sender.try_send(self.partition).map_err(|_| {
+         internal_datafusion_err!(
+             "Error sending partition number to the receiver for partition {}",
+             self.partition
+         )
+     })?;
self.is_requested = true;
}

