From 1de1108d26466096ff0cf3faca0482c773c731cf Mon Sep 17 00:00:00 2001
From: Marco Neumann
Date: Mon, 11 Oct 2021 13:29:35 +0200
Subject: [PATCH] Clean up spawned task on `SortStream` drop

Ref #1103.
---
Reviewer notes (this area is ignored when the patch is applied):
 * `SortStream` now stores the `JoinHandle` returned by `tokio::spawn` and
   calls `abort()` on it from its `PinnedDrop` impl.
 * The result of `tx.send(..)` is discarded with `.ok()`; the send fails only
   when the receiver is already gone.
 * `test_drop_cancel` polls `collect(sort_exec)` once, drops the future, and
   waits up to 10 s for the `Weak` handle from `BlockingExec` to report zero
   strong references.
 * The pin-project-lite bump is presumably what makes `PinnedDrop` available
   inside `pin_project!` -- confirm against the crate changelog.

 datafusion/Cargo.toml                |   2 +-
 datafusion/src/physical_plan/sort.rs | 150 ++++++++++++++++++++++++++-
 2 files changed, 149 insertions(+), 3 deletions(-)

diff --git a/datafusion/Cargo.toml b/datafusion/Cargo.toml
index ea9ca218b017c..ecc434a5af5cb 100644
--- a/datafusion/Cargo.toml
+++ b/datafusion/Cargo.toml
@@ -58,7 +58,7 @@ num_cpus = "1.13.0"
 chrono = "0.4"
 async-trait = "0.1.41"
 futures = "0.3"
-pin-project-lite= "^0.2.0"
+pin-project-lite= "^0.2.7"
 tokio = { version = "1.0", features = ["macros", "rt", "rt-multi-thread", "sync", "fs"] }
 tokio-stream = "0.1"
 log = "^0.4"
diff --git a/datafusion/src/physical_plan/sort.rs b/datafusion/src/physical_plan/sort.rs
index b732797c1d26b..fc3d362932374 100644
--- a/datafusion/src/physical_plan/sort.rs
+++ b/datafusion/src/physical_plan/sort.rs
@@ -40,6 +40,7 @@ use std::any::Any;
 use std::pin::Pin;
 use std::sync::Arc;
 use std::task::{Context, Poll};
+use tokio::task::JoinHandle;
 
 /// Sort execution plan
 #[derive(Debug)]
@@ -228,6 +229,13 @@ pin_project! {
         output: futures::channel::oneshot::Receiver<ArrowResult<Option<RecordBatch>>>,
         finished: bool,
         schema: SchemaRef,
+        join_handle: JoinHandle<()>,
+    }
+
+    impl PinnedDrop for SortStream {
+        fn drop(this: Pin<&mut Self>) {
+            this.join_handle.abort();
+        }
     }
 }
 
@@ -239,7 +247,7 @@ impl SortStream {
     ) -> Self {
         let (tx, rx) = futures::channel::oneshot::channel();
         let schema = input.schema();
-        tokio::spawn(async move {
+        let join_handle = tokio::spawn(async move {
             let schema = input.schema();
             let sorted_batch = common::collect(input)
                 .await
@@ -257,13 +265,15 @@ impl SortStream {
                     Ok(result)
                 });
 
-            tx.send(sorted_batch)
+            // failing here is OK, the receiver is gone and does not care about the result
+            tx.send(sorted_batch).ok();
         });
 
         Self {
             output: rx,
             finished: false,
             schema,
+            join_handle,
         }
     }
 }
@@ -305,6 +315,8 @@ impl RecordBatchStream for SortStream {
 
 #[cfg(test)]
 mod tests {
+    use std::sync::Weak;
+
     use super::*;
     use crate::physical_plan::coalesce_partitions::CoalescePartitionsExec;
     use crate::physical_plan::expressions::col;
@@ -316,6 +328,7 @@ mod tests {
     use crate::test;
     use arrow::array::*;
     use arrow::datatypes::*;
+    use futures::FutureExt;
 
     #[tokio::test]
     async fn test_sort() -> Result<()> {
@@ -474,4 +487,137 @@ mod tests {
 
         Ok(())
     }
+
+    #[tokio::test]
+    async fn test_drop_cancel() -> Result<()> {
+        let schema =
+            Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, true)]));
+
+        let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema)));
+        let refs = blocking_exec.refs();
+        let sort_exec = Arc::new(SortExec::try_new(
+            vec![PhysicalSortExpr {
+                expr: col("a", &schema)?,
+                options: SortOptions::default(),
+            }],
+            blocking_exec,
+        )?);
+
+        let fut = collect(sort_exec);
+        let mut fut = fut.boxed();
+
+        let waker = futures::task::noop_waker();
+        let mut cx = futures::task::Context::from_waker(&waker);
+        let poll = fut.poll_unpin(&mut cx);
+
+        assert!(poll.is_pending());
+        drop(fut);
+        tokio::time::timeout(std::time::Duration::from_secs(10), async {
+            loop {
+                if dbg!(Weak::strong_count(&refs)) == 0 {
+                    break;
+                }
+                tokio::time::sleep(std::time::Duration::from_millis(10)).await;
+            }
+        })
+        .await
+        .unwrap();
+
+        Ok(())
+    }
+
+    #[derive(Debug)]
+    struct BlockingExec {
+        schema: SchemaRef,
+        refs: Arc<()>,
+    }
+
+    impl BlockingExec {
+        fn new(schema: SchemaRef) -> Self {
+            Self {
+                schema,
+                refs: Default::default(),
+            }
+        }
+
+        fn refs(&self) -> Weak<()> {
+            Arc::downgrade(&self.refs)
+        }
+    }
+
+    #[async_trait]
+    impl ExecutionPlan for BlockingExec {
+        fn as_any(&self) -> &dyn Any {
+            self
+        }
+
+        fn schema(&self) -> SchemaRef {
+            Arc::clone(&self.schema)
+        }
+
+        fn children(&self) -> Vec<Arc<dyn ExecutionPlan>> {
+            // this is a leaf node and has no children
+            vec![]
+        }
+
+        fn output_partitioning(&self) -> Partitioning {
+            Partitioning::UnknownPartitioning(1)
+        }
+
+        fn with_new_children(
+            &self,
+            _: Vec<Arc<dyn ExecutionPlan>>,
+        ) -> Result<Arc<dyn ExecutionPlan>> {
+            Err(DataFusionError::Internal(format!(
+                "Children cannot be replaced in {:?}",
+                self
+            )))
+        }
+
+        async fn execute(&self, _partition: usize) -> Result<SendableRecordBatchStream> {
+            Ok(Box::pin(BlockingStream {
+                schema: Arc::clone(&self.schema),
+                refs: Arc::clone(&self.refs),
+            }))
+        }
+
+        fn fmt_as(
+            &self,
+            t: DisplayFormatType,
+            f: &mut std::fmt::Formatter,
+        ) -> std::fmt::Result {
+            match t {
+                DisplayFormatType::Default => {
+                    write!(f, "BlockingExec",)
+                }
+            }
+        }
+
+        fn statistics(&self) -> Statistics {
+            unimplemented!()
+        }
+    }
+
+    #[derive(Debug)]
+    struct BlockingStream {
+        schema: SchemaRef,
+        refs: Arc<()>,
+    }
+
+    impl Stream for BlockingStream {
+        type Item = ArrowResult<RecordBatch>;
+
+        fn poll_next(
+            self: Pin<&mut Self>,
+            _cx: &mut Context<'_>,
+        ) -> Poll<Option<Self::Item>> {
+            Poll::Pending
+        }
+    }
+
+    impl RecordBatchStream for BlockingStream {
+        fn schema(&self) -> SchemaRef {
+            Arc::clone(&self.schema)
+        }
+    }
 }