From 1774aac8965edc41e19cd9467812c947502b5243 Mon Sep 17 00:00:00 2001
From: Matthew Turner
Date: Fri, 31 May 2024 09:31:08 -0400
Subject: [PATCH] Feedback

---
 datafusion-examples/Cargo.toml                 |  1 -
 .../examples/file_stream_provider.rs           | 19 +++++++++--------
 datafusion/core/src/datasource/stream.rs       | 21 +++++++------------
 3 files changed, 17 insertions(+), 24 deletions(-)

diff --git a/datafusion-examples/Cargo.toml b/datafusion-examples/Cargo.toml
index dd24a85fdb15..0bcf7c1afc15 100644
--- a/datafusion-examples/Cargo.toml
+++ b/datafusion-examples/Cargo.toml
@@ -83,4 +83,3 @@ uuid = "1.7"
 
 [target.'cfg(not(target_os = "windows"))'.dev-dependencies]
 nix = { version = "0.28.0", features = ["fs"] }
-
diff --git a/datafusion-examples/examples/file_stream_provider.rs b/datafusion-examples/examples/file_stream_provider.rs
index c0590efc9c3e..4e79f9afc2ca 100644
--- a/datafusion-examples/examples/file_stream_provider.rs
+++ b/datafusion-examples/examples/file_stream_provider.rs
@@ -31,7 +31,7 @@ use futures::StreamExt;
 use nix::sys::stat;
 use nix::unistd;
 use tempfile::TempDir;
-use tokio::task::{spawn_blocking, JoinHandle};
+use tokio::task::JoinSet;
 
 use datafusion::datasource::stream::{FileStreamProvider, StreamConfig, StreamTable};
 use datafusion::datasource::TableProvider;
@@ -94,13 +94,14 @@ fn create_writing_thread(
     lines: Vec<String>,
     waiting_lock: Arc<AtomicBool>,
     wait_until: usize,
-) -> JoinHandle<()> {
+    tasks: &mut JoinSet<()>,
+) {
     // Timeout for a long period of BrokenPipe error
     let broken_pipe_timeout = Duration::from_secs(10);
     let sa = file_path.clone();
     // Spawn a new thread to write to the FIFO file
     #[allow(clippy::disallowed_methods)] // spawn allowed only in tests
-    spawn_blocking(move || {
+    tasks.spawn_blocking(move || {
         let file = OpenOptions::new().write(true).open(sa).unwrap();
         // Reference time to use when deciding to fail the test
         let execution_start = Instant::now();
@@ -114,7 +115,7 @@ fn create_writing_thread(
             write_to_fifo(&file, line, execution_start, broken_pipe_timeout).unwrap();
         }
         drop(file);
-    })
+    });
 }
 
 /// This example demonstrates a scanning against an Arrow data source (JSON) and
@@ -130,21 +131,22 @@ async fn main() -> Result<()> {
     let tmp_dir = TempDir::new()?;
     let fifo_path = create_fifo_file(&tmp_dir, "fifo_unbounded.csv")?;
 
-    let mut tasks: Vec<JoinHandle<()>> = vec![];
+    let mut tasks: JoinSet<()> = JoinSet::new();
     let waiting = Arc::new(AtomicBool::new(true));
 
     let data_iter = 0..TEST_DATA_SIZE;
     let lines = data_iter
         .map(|i| format!("{},{}\n", i, i + 1))
         .collect::<Vec<_>>();
-    // Create writing threads for the left and right FIFO files
-    tasks.push(create_writing_thread(
+
+    create_writing_thread(
         fifo_path.clone(),
         Some("a1,a2\n".to_owned()),
         lines.clone(),
         waiting.clone(),
         TEST_DATA_SIZE,
-    ));
+        &mut tasks,
+    );
 
     // Create schema
     let schema = Arc::new(Schema::new(vec![
@@ -161,7 +163,6 @@ async fn main() -> Result<()> {
     let df = ctx.sql("SELECT * FROM fifo").await.unwrap();
     let mut stream = df.execute_stream().await.unwrap();
 
-    futures::future::join_all(tasks).await;
     let mut batches = Vec::new();
     if let Some(Ok(batch)) = stream.next().await {
         batches.push(batch)
diff --git a/datafusion/core/src/datasource/stream.rs b/datafusion/core/src/datasource/stream.rs
index e0fc5b01a4fe..9cfdb7bb1168 100644
--- a/datafusion/core/src/datasource/stream.rs
+++ b/datafusion/core/src/datasource/stream.rs
@@ -110,9 +110,7 @@ impl FromStr for StreamEncoding {
 /// responsible for providing a `RecordBatchReader` and optionally a `RecordBatchWriter`.
 pub trait StreamProvider: std::fmt::Debug + Send + Sync {
     /// Get a reference to the schema for this stream
-    fn schema(&self) -> SchemaRef;
-    /// Needed for `PartitionStream` - maybe there is a better way to do this.
-    fn schema_ref(&self) -> &SchemaRef;
+    fn schema(&self) -> &SchemaRef;
     /// Provide `RecordBatchReader`
     fn reader(&self) -> Result<Box<dyn RecordBatchReader>>;
     /// Provide `RecordBatchWriter`
@@ -182,12 +180,7 @@ impl FileStreamProvider {
 }
 
 impl StreamProvider for FileStreamProvider {
-    fn schema(&self) -> SchemaRef {
-        self.schema.clone()
-    }
-
-    // Needed for `PartitionStream`
-    fn schema_ref(&self) -> &SchemaRef {
+    fn schema(&self) -> &SchemaRef {
         &self.schema
     }
 
@@ -339,11 +332,11 @@ impl TableProvider for StreamTable {
                 let projected = self.0.source.schema().project(p)?;
                 create_ordering(&projected, &self.0.order)?
             }
-            None => create_ordering(self.0.source.schema_ref(), &self.0.order)?,
+            None => create_ordering(self.0.source.schema(), &self.0.order)?,
         };
 
         Ok(Arc::new(StreamingTableExec::try_new(
-            self.0.source.schema(),
+            self.0.source.schema().clone(),
             vec![Arc::new(StreamRead(self.0.clone())) as _],
             projection,
             projected_schema,
@@ -360,7 +353,7 @@ impl TableProvider for StreamTable {
     ) -> Result<Arc<dyn ExecutionPlan>> {
         let ordering = match self.0.order.first() {
             Some(x) => {
-                let schema = self.0.source.schema_ref();
+                let schema = self.0.source.schema();
                 let orders = create_ordering(schema, std::slice::from_ref(x))?;
                 let ordering = orders.into_iter().next().unwrap();
                 Some(ordering.into_iter().map(Into::into).collect())
@@ -371,7 +364,7 @@ impl TableProvider for StreamTable {
         Ok(Arc::new(DataSinkExec::new(
             input,
             Arc::new(StreamWrite(self.0.clone())),
-            self.0.source.schema(),
+            self.0.source.schema().clone(),
             ordering,
         )))
     }
@@ -381,7 +374,7 @@ struct StreamRead(Arc<StreamConfig>);
 
 impl PartitionStream for StreamRead {
     fn schema(&self) -> &SchemaRef {
-        self.0.source.schema_ref()
+        self.0.source.schema()
     }
 
     fn execute(&self, _ctx: Arc<TaskContext>) -> SendableRecordBatchStream {
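Note (not part of the patch): a minimal, self-contained sketch of the tokio JoinSet pattern the example switches to, replacing Vec<JoinHandle<()>> plus futures::future::join_all. It assumes a tokio runtime with the "rt-multi-thread" and "macros" features enabled, and the closure body is only a hypothetical stand-in for the FIFO write loop in create_writing_thread.

    use std::io::Write;

    use tokio::task::JoinSet;

    #[tokio::main]
    async fn main() {
        // Collect blocking writer tasks in a JoinSet rather than a Vec<JoinHandle<()>>.
        let mut tasks: JoinSet<()> = JoinSet::new();

        for i in 0..2 {
            // Stand-in for the FIFO write loop the patched example registers
            // via tasks.spawn_blocking(move || { ... }).
            tasks.spawn_blocking(move || {
                let mut out = std::io::sink();
                writeln!(out, "{},{}", i, i + 1).unwrap();
            });
        }

        // A JoinSet is drained with join_next instead of futures::future::join_all;
        // dropping the set would instead abort any tasks still running.
        while let Some(result) = tasks.join_next().await {
            result.unwrap();
        }
    }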