diff --git a/kernel/src/engine/default/json.rs b/kernel/src/engine/default/json.rs
index 62b7846e8..5b373dc20 100644
--- a/kernel/src/engine/default/json.rs
+++ b/kernel/src/engine/default/json.rs
@@ -33,7 +33,7 @@ pub struct DefaultJsonHandler<E: TaskExecutor> {
     /// The executor to run async tasks on
     task_executor: Arc<E>,
     /// The maximum number of batches to read ahead
-    readahead: usize,
+    buffer_size: usize,
     /// The number of rows to read per batch
     batch_size: usize,
 }
@@ -43,7 +43,7 @@ impl<E: TaskExecutor> DefaultJsonHandler<E> {
         Self {
             store,
             task_executor,
-            readahead: 1000,
+            buffer_size: 1000,
             batch_size: 1024 * 128,
         }
     }
@@ -52,7 +52,7 @@ impl<E: TaskExecutor> DefaultJsonHandler<E> {
     ///
-    /// Defaults to 10.
+    /// Defaults to 1000.
     pub fn with_readahead(mut self, readahead: usize) -> Self {
-        self.readahead = readahead;
+        self.buffer_size = readahead;
         self
     }

@@ -87,9 +87,9 @@ impl<E: TaskExecutor> JsonHandler for DefaultJsonHandler<E> {
         let schema: ArrowSchemaRef = Arc::new(physical_schema.as_ref().try_into()?);
         let file_opener = JsonOpener::new(self.batch_size, schema.clone(), self.store.clone());

-        let (tx, rx) = mpsc::sync_channel(self.readahead);
+        let (tx, rx) = mpsc::sync_channel(self.buffer_size);
         let files = files.to_vec();
-        let readahead = self.readahead;
+        let buffer_size = self.buffer_size;

         self.task_executor.spawn(async move {
             // an iterator of futures that open each file
@@ -97,7 +97,7 @@
-            // create a stream from that iterator which buffers up to `readahead` futures at a time
+            // create a stream from that iterator which buffers up to `buffer_size` futures at a time
             let mut stream = stream::iter(file_futures)
-                .buffered(readahead)
+                .buffered(buffer_size)
                 .try_flatten()
                 .map_ok(|record_batch| {
                     Box::new(ArrowEngineData::new(record_batch)) as Box<dyn EngineData>
                 });
@@ -228,8 +228,13 @@ mod tests {
     use std::sync::Mutex;
     use std::time::Duration;

+    use crate::actions::get_log_schema;
     use crate::arrow::array::{AsArray, RecordBatch, StringArray};
     use crate::arrow::datatypes::{DataType, Field, Schema as ArrowSchema};
+    use crate::engine::arrow_data::ArrowEngineData;
+    use crate::engine::default::executor::tokio::{
+        TokioBackgroundExecutor, TokioMultiThreadExecutor,
+    };
     use futures::future;
     use itertools::Itertools;
     use object_store::{local::LocalFileSystem, ObjectStore};
@@ -238,11 +243,14 @@
         PutOptions, PutPayload, PutResult, Result,
     };

+    // TODO: should just use the one from test_utils, but running into dependency issues
+    fn into_record_batch(engine_data: Box<dyn EngineData>) -> RecordBatch {
+        ArrowEngineData::try_from_engine_data(engine_data)
+            .unwrap()
+            .into()
+    }
+
     use super::*;
-    use crate::{
-        actions::get_log_schema, engine::arrow_data::ArrowEngineData,
-        engine::default::executor::tokio::TokioBackgroundExecutor,
-    };

     /// Store wrapper that wraps an inner store to purposefully delay GET requests of certain keys.
     #[derive(Debug)]
@@ -435,15 +443,7 @@ mod tests {
         let data: Vec<RecordBatch> = handler
             .read_json_files(files, Arc::new(physical_schema.try_into().unwrap()), None)
             .unwrap()
-            .map(|ed_res| {
-                // TODO(nick) make this easier
-                ed_res.and_then(|ed| {
-                    ed.into_any()
-                        .downcast::<ArrowEngineData>()
-                        .map_err(|_| Error::engine_data_type("ArrowEngineData"))
-                        .map(|sd| sd.into())
-                })
-            })
+            .map_ok(into_record_batch)
             .try_collect()
             .unwrap();

@@ -451,8 +451,8 @@ mod tests {
         assert_eq!(data[0].num_rows(), 4);
     }

-    #[tokio::test]
+    #[tokio::test(flavor = "multi_thread", worker_threads = 3)]
     async fn test_read_json_files_ordering() {
         let paths = [
             "./tests/data/table-with-dv-small/_delta_log/00000000000000000000.json",
             "./tests/data/table-with-dv-small/_delta_log/00000000000000000001.json",
@@ -485,22 +485,21 @@
             }
         })
         .collect();
+
+        // note: join_all is ordered
         let files = future::join_all(file_futures).await;

-        let handler = DefaultJsonHandler::new(store, Arc::new(TokioBackgroundExecutor::new()));
+        let handler = DefaultJsonHandler::new(
+            store,
+            Arc::new(TokioMultiThreadExecutor::new(
+                tokio::runtime::Handle::current(),
+            )),
+        );
         let physical_schema = Arc::new(ArrowSchema::try_from(get_log_schema().as_ref()).unwrap());
         let data: Vec<RecordBatch> = handler
             .read_json_files(&files, Arc::new(physical_schema.try_into().unwrap()), None)
             .unwrap()
-            .map(|ed_res| {
-                // TODO(nick) make this easier
-                ed_res.and_then(|ed| {
-                    ed.into_any()
-                        .downcast::<ArrowEngineData>()
-                        .map_err(|_| Error::engine_data_type("ArrowEngineData"))
-                        .map(|sd| sd.into())
-                })
-            })
+            .map_ok(into_record_batch)
             .try_collect()
             .unwrap();
         assert_eq!(data.len(), 2);
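
A note on the property `test_read_json_files_ordering` exercises: `futures::StreamExt::buffered(n)` drives up to `n` futures concurrently but yields their outputs in input order, unlike `buffer_unordered`. The sketch below is not part of the patch (it assumes a standalone binary with `tokio`'s `rt-multi-thread`, `macros`, and `time` features); it only demonstrates that ordering guarantee:

```rust
use std::time::Duration;

use futures::{stream, StreamExt};

#[tokio::main]
async fn main() {
    // Three futures whose completion order (10ms, 20ms, 30ms) is the
    // reverse of their input order (the 30ms future comes first).
    let delays = [30u64, 20, 10];
    let results: Vec<u64> = stream::iter(delays.into_iter().map(|ms| async move {
        tokio::time::sleep(Duration::from_millis(ms)).await;
        ms
    }))
    // Polls up to 3 futures at once, but yields results in input order.
    .buffered(3)
    .collect()
    .await;
    // The slowest future still yields first because it was first in the input.
    assert_eq!(results, vec![30, 20, 10]);
}
```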
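
Likewise, `buffer_size` bounds the `mpsc::sync_channel` between the background task and the returned iterator: a bounded sender blocks once the channel holds `buffer_size` items, so a slow consumer throttles the producer. A minimal sketch of that backpressure semantics using only `std` (the producer/consumer names are illustrative, not from the kernel code):

```rust
use std::sync::mpsc;
use std::thread;
use std::time::Duration;

fn main() {
    // Capacity of 2 mirrors `mpsc::sync_channel(self.buffer_size)`:
    // at most 2 batches can be queued before `send` blocks.
    let (tx, rx) = mpsc::sync_channel::<u32>(2);

    let producer = thread::spawn(move || {
        for batch in 0..5 {
            // Blocks whenever the channel already holds 2 unread batches.
            tx.send(batch).unwrap();
            println!("produced batch {batch}");
        }
        // `tx` drops here, which ends the consumer's iteration below.
    });

    for batch in rx {
        // A slow consumer: the producer can run at most 2 batches ahead.
        thread::sleep(Duration::from_millis(10));
        println!("consumed batch {batch}");
    }
    producer.join().unwrap();
}
```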