From 5602098f0b0a517431752504db080c0c22192824 Mon Sep 17 00:00:00 2001 From: zhangli20 Date: Sat, 25 Jan 2025 22:13:29 +0800 Subject: [PATCH] remove arrow child allocator introduce spark.blaze.onHeapSpill.memoryFraction do not send empty batches set ArrowWriter vector initial capacity to zero coalesce scan input fix ffi reader hanging simplify logging --- .idea/vcs.xml | 5 +- .../src/spark_udf_wrapper.rs | 7 +- .../datafusion-ext-plans/src/agg/agg_table.rs | 41 ++----- .../src/common/execution_context.rs | 32 ++++- .../src/ffi_reader_exec.rs | 6 +- .../datafusion-ext-plans/src/orc_exec.rs | 36 +++--- .../datafusion-ext-plans/src/parquet_exec.rs | 5 +- .../datafusion-ext-plans/src/project_exec.rs | 2 + .../src/shuffle/rss_sort_repartitioner.rs | 4 +- .../src/shuffle/sort_repartitioner.rs | 29 +++-- .../datafusion-ext-plans/src/sort_exec.rs | 50 ++------ .../org/apache/spark/sql/blaze/BlazeConf.java | 5 +- .../sql/blaze/BlazeCallNativeWrapper.scala | 35 +++--- .../sql/blaze/SparkUDFWrapperContext.scala | 48 ++++---- .../sql/blaze/SparkUDTFWrapperContext.scala | 76 +++++------- .../sql/blaze/memory/OnHeapSpillManager.scala | 4 +- .../blaze/arrowio/ArrowFFIExporter.scala | 111 ++++++++---------- .../blaze/arrowio/util/ArrowUtils.scala | 8 +- .../blaze/arrowio/util/ArrowWriter.scala | 8 +- 19 files changed, 246 insertions(+), 266 deletions(-) diff --git a/.idea/vcs.xml b/.idea/vcs.xml index 6b45010c0..bd37bc65d 100644 --- a/.idea/vcs.xml +++ b/.idea/vcs.xml @@ -12,5 +12,8 @@ + + + - + \ No newline at end of file diff --git a/native-engine/datafusion-ext-exprs/src/spark_udf_wrapper.rs b/native-engine/datafusion-ext-exprs/src/spark_udf_wrapper.rs index 7be741d40..f2f7ec7ba 100644 --- a/native-engine/datafusion-ext-exprs/src/spark_udf_wrapper.rs +++ b/native-engine/datafusion-ext-exprs/src/spark_udf_wrapper.rs @@ -20,7 +20,7 @@ use std::{ }; use arrow::{ - array::{as_struct_array, make_array, Array, ArrayRef, StructArray}, + array::{as_struct_array, make_array, new_empty_array, Array, ArrayRef, StructArray}, datatypes::{DataType, Field, Schema, SchemaRef}, ffi::{from_ffi, FFI_ArrowArray, FFI_ArrowSchema}, record_batch::{RecordBatch, RecordBatchOptions}, @@ -123,6 +123,10 @@ impl PhysicalExpr for SparkUDFWrapperExpr { } let batch_schema = batch.schema(); + let num_rows = batch.num_rows(); + if num_rows == 0 { + return Ok(ColumnarValue::Array(new_empty_array(&self.return_type))); + } // init params schema let params_schema = self @@ -140,7 +144,6 @@ impl PhysicalExpr for SparkUDFWrapperExpr { })?; // evaluate params - let num_rows = batch.num_rows(); let params: Vec = self .params .iter() diff --git a/native-engine/datafusion-ext-plans/src/agg/agg_table.rs b/native-engine/datafusion-ext-plans/src/agg/agg_table.rs index 52d824057..d4538545c 100644 --- a/native-engine/datafusion-ext-plans/src/agg/agg_table.rs +++ b/native-engine/datafusion-ext-plans/src/agg/agg_table.rs @@ -30,7 +30,7 @@ use datafusion_ext_commons::{ rdx_queue::{KeyForRadixQueue, RadixQueue}, rdx_sort::radix_sort_by_key, }, - batch_size, compute_suggested_batch_size_for_output, df_execution_err, downcast_any, + batch_size, compute_suggested_batch_size_for_output, df_execution_err, io::{read_bytes_slice, read_len, write_len}, }; use futures::lock::Mutex; @@ -60,7 +60,6 @@ const SPILL_OFFHEAP_MEM_COST: usize = 200000; const NUM_SPILL_BUCKETS: usize = 64000; pub struct AggTable { - name: String, mem_consumer_info: Option>, in_mem: Mutex, spills: Mutex>>, @@ -71,14 +70,12 @@ pub struct AggTable { impl AggTable { pub fn new(agg_ctx: Arc, 
exec_ctx: Arc) -> Self { - let name = format!("AggTable[partition={}]", exec_ctx.partition_id()); let hashing_time = exec_ctx.register_timer_metric("hashing_time"); let merging_time = exec_ctx.register_timer_metric("merging_time"); let output_time = exec_ctx.register_timer_metric("output_time"); Self { mem_consumer_info: None, in_mem: Mutex::new(InMemTable::new( - name.clone(), 0, agg_ctx.clone(), exec_ctx.clone(), @@ -87,7 +84,6 @@ impl AggTable { merging_time.clone(), )), spills: Mutex::default(), - name, agg_ctx, exec_ctx, output_time, @@ -201,22 +197,24 @@ impl AggTable { return Ok(()); } - // convert all tables into cursors + // write rest data into an in-memory buffer if in-mem data is small + // otherwise write into spill let mut spills = spills; - let mut cursors = vec![]; if in_mem.num_records() > 0 { - let spill = tokio::task::spawn_blocking(|| { - let mut spill: Box = Box::new(vec![]); + let spill_metrics = self.exec_ctx.spill_metrics().clone(); + let spill = tokio::task::spawn_blocking(move || { + let mut spill: Box = try_new_spill(&spill_metrics)?; in_mem.try_into_spill(&mut spill)?; // spill staging records Ok::<_, DataFusionError>(spill) }) .await - .expect("tokio error")?; - let spill_size = downcast_any!(spill, Vec)?.len(); - self.update_mem_used(spill_size + spills.len() * SPILL_OFFHEAP_MEM_COST) + .expect("tokio spawn_blocking error")?; + self.update_mem_used(spills.len() * SPILL_OFFHEAP_MEM_COST) .await?; spills.push(spill); } + + let mut cursors = vec![]; for spill in &mut spills { cursors.push(RecordsSpillCursor::try_from_spill(spill, &self.agg_ctx)?); } @@ -277,7 +275,7 @@ impl AggTable { #[async_trait] impl MemConsumer for AggTable { fn name(&self) -> &str { - &self.name + "AggTable" } fn set_consumer_info(&mut self, consumer_info: Weak) { @@ -336,7 +334,6 @@ pub enum InMemMode { /// Unordered in-mem hash table which can be updated pub struct InMemTable { - name: String, id: usize, agg_ctx: Arc, exec_ctx: Arc, @@ -347,7 +344,6 @@ pub struct InMemTable { impl InMemTable { fn new( - name: String, id: usize, agg_ctx: Arc, exec_ctx: Arc, @@ -356,7 +352,6 @@ impl InMemTable { merging_time: Time, ) -> Self { Self { - name, id, hashing_data: HashingData::new(agg_ctx.clone(), hashing_time), merging_data: MergingData::new(agg_ctx.clone(), merging_time), @@ -367,7 +362,6 @@ impl InMemTable { } fn renew(&mut self, mode: InMemMode) -> Self { - let name = self.name.clone(); let agg_ctx = self.agg_ctx.clone(); let task_ctx = self.exec_ctx.clone(); let id = self.id + 1; @@ -375,15 +369,7 @@ impl InMemTable { let merging_time = self.merging_data.merging_time.clone(); std::mem::replace( self, - Self::new( - name, - id, - agg_ctx, - task_ctx, - mode, - hashing_time, - merging_time, - ), + Self::new(id, agg_ctx, task_ctx, mode, hashing_time, merging_time), ) } @@ -407,8 +393,7 @@ impl InMemTable { let cardinality_ratio = self.hashing_data.cardinality_ratio(); if cardinality_ratio > self.agg_ctx.partial_skipping_ratio { log::warn!( - "{} cardinality ratio = {cardinality_ratio}, will trigger partial skipping", - self.name, + "AggTable cardinality ratio = {cardinality_ratio}, will trigger partial skipping", ); return true; } diff --git a/native-engine/datafusion-ext-plans/src/common/execution_context.rs b/native-engine/datafusion-ext-plans/src/common/execution_context.rs index 1b9516584..b87b872a1 100644 --- a/native-engine/datafusion-ext-plans/src/common/execution_context.rs +++ b/native-engine/datafusion-ext-plans/src/common/execution_context.rs @@ -36,7 +36,7 @@ use 
datafusion_ext_commons::{ arrow::{array_size::ArraySize, coalesce::coalesce_batches_unchecked}, batch_size, df_execution_err, suggested_output_batch_mem_size, }; -use futures::{Stream, StreamExt}; +use futures::{executor::block_on_stream, Stream, StreamExt}; use futures_util::FutureExt; use once_cell::sync::OnceCell; use parking_lot::Mutex; @@ -122,6 +122,33 @@ impl ExecutionContext { .counter(name.to_owned(), self.partition_id) } + pub fn spawn_worker_thread_on_stream( + self: &Arc, + input: SendableRecordBatchStream, + ) -> SendableRecordBatchStream { + let (batch_sender, mut batch_receiver) = tokio::sync::mpsc::channel(1); + + tokio::task::spawn_blocking(move || { + let mut blocking_stream = block_on_stream(input); + while is_task_running() + && let Some(batch_result) = blocking_stream.next() + { + if batch_sender.blocking_send(batch_result).is_err() { + break; + } + } + }); + + self.output_with_sender("WorkerThreadOnStream", move |sender| async move { + while is_task_running() + && let Some(batch_result) = batch_receiver.recv().await + { + sender.send(batch_result?).await; + } + Ok(()) + }) + } + pub fn coalesce_with_default_batch_size( self: &Arc, input: SendableRecordBatchStream, @@ -393,6 +420,9 @@ impl WrappedRecordBatchSender { } pub async fn send(&self, batch: RecordBatch) { + if batch.num_rows() == 0 { + return; + } let exclude_time = self.exclude_time.get().cloned(); let send_time = exclude_time.as_ref().map(|_| Instant::now()); self.sender diff --git a/native-engine/datafusion-ext-plans/src/ffi_reader_exec.rs b/native-engine/datafusion-ext-plans/src/ffi_reader_exec.rs index f4d385d45..49a924756 100644 --- a/native-engine/datafusion-ext-plans/src/ffi_reader_exec.rs +++ b/native-engine/datafusion-ext-plans/src/ffi_reader_exec.rs @@ -149,14 +149,14 @@ fn read_ffi( Ok(exec_ctx .clone() .output_with_sender("FFIReader", move |sender| async move { - struct AutoCloseableExporer(GlobalRef); - impl Drop for AutoCloseableExporer { + struct AutoCloseableExporter(GlobalRef); + impl Drop for AutoCloseableExporter { fn drop(&mut self) { let _ = jni_call!(JavaAutoCloseable(self.0.as_obj()).close() -> ()); } } + let exporter = AutoCloseableExporter(exporter); - let exporter = AutoCloseableExporer(exporter); loop { let batch = { // load batch from ffi diff --git a/native-engine/datafusion-ext-plans/src/orc_exec.rs b/native-engine/datafusion-ext-plans/src/orc_exec.rs index 3dc7e958e..a2f0ac894 100644 --- a/native-engine/datafusion-ext-plans/src/orc_exec.rs +++ b/native-engine/datafusion-ext-plans/src/orc_exec.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -use std::{any::Any, fmt, fmt::Formatter, sync::Arc}; +use std::{any::Any, fmt, fmt::Formatter, pin::Pin, sync::Arc}; use arrow::{datatypes::SchemaRef, error::ArrowError}; use blaze_jni_bridge::{jni_call_static, jni_new_global_ref, jni_new_string}; @@ -165,24 +165,16 @@ impl ExecutionPlan for OrcExec { fs_provider, }; - let mut file_stream = Box::pin(FileStream::new( + let file_stream = Box::pin(FileStream::new( &self.base_config, partition, opener, exec_ctx.execution_plan_metrics(), )?); - let timed_stream = - exec_ctx - .clone() - .output_with_sender("OrcScan", move |sender| async move { - sender.exclude_time(exec_ctx.baseline_metrics().elapsed_compute()); - let _timer = exec_ctx.baseline_metrics().elapsed_compute().timer(); - while let Some(batch) = file_stream.next().await.transpose()? 
{ - sender.send(batch).await; - } - Ok(()) - }); - Ok(timed_stream) + + let timed_stream = execute_orc_scan(file_stream, exec_ctx.clone())?; + let nonblock_stream = exec_ctx.spawn_worker_thread_on_stream(timed_stream); + Ok(exec_ctx.coalesce_with_default_batch_size(nonblock_stream)) } fn metrics(&self) -> Option { @@ -194,6 +186,22 @@ impl ExecutionPlan for OrcExec { } } +fn execute_orc_scan( + mut stream: Pin>>, + exec_ctx: Arc, +) -> Result { + Ok(exec_ctx + .clone() + .output_with_sender("OrcScan", move |sender| async move { + sender.exclude_time(exec_ctx.baseline_metrics().elapsed_compute()); + let _timer = exec_ctx.baseline_metrics().elapsed_compute().timer(); + while let Some(batch) = stream.next().await.transpose()? { + sender.send(batch).await; + } + Ok(()) + })) +} + struct OrcOpener { projection: Vec, batch_size: usize, diff --git a/native-engine/datafusion-ext-plans/src/parquet_exec.rs b/native-engine/datafusion-ext-plans/src/parquet_exec.rs index 4fd2bd961..b5becb218 100644 --- a/native-engine/datafusion-ext-plans/src/parquet_exec.rs +++ b/native-engine/datafusion-ext-plans/src/parquet_exec.rs @@ -226,8 +226,9 @@ impl ExecutionPlan for ParquetExec { file_stream = file_stream.with_on_error(OnError::Skip); } - let timed_stream = execute_parquet_scan(Box::pin(file_stream), exec_ctx)?; - Ok(timed_stream) + let timed_stream = execute_parquet_scan(Box::pin(file_stream), exec_ctx.clone())?; + let nonblock_stream = exec_ctx.spawn_worker_thread_on_stream(timed_stream); + Ok(exec_ctx.coalesce_with_default_batch_size(nonblock_stream)) } fn metrics(&self) -> Option { diff --git a/native-engine/datafusion-ext-plans/src/project_exec.rs b/native-engine/datafusion-ext-plans/src/project_exec.rs index 63e3e976c..86e96a282 100644 --- a/native-engine/datafusion-ext-plans/src/project_exec.rs +++ b/native-engine/datafusion-ext-plans/src/project_exec.rs @@ -218,6 +218,8 @@ fn execute_project_with_filtering( .transpose()? 
{ let output_batch = cached_expr_evaluator.filter_project(&batch)?; + drop(batch); + exec_ctx .baseline_metrics() .record_output(output_batch.num_rows()); diff --git a/native-engine/datafusion-ext-plans/src/shuffle/rss_sort_repartitioner.rs b/native-engine/datafusion-ext-plans/src/shuffle/rss_sort_repartitioner.rs index fa5a23618..3ff50a57e 100644 --- a/native-engine/datafusion-ext-plans/src/shuffle/rss_sort_repartitioner.rs +++ b/native-engine/datafusion-ext-plans/src/shuffle/rss_sort_repartitioner.rs @@ -27,7 +27,6 @@ use crate::{ }; pub struct RssSortShuffleRepartitioner { - name: String, mem_consumer_info: Option>, data: Mutex, rss: GlobalRef, @@ -41,7 +40,6 @@ impl RssSortShuffleRepartitioner { sort_time: Time, ) -> Self { Self { - name: format!("RssSortShufflePartitioner[partition={}]", partition_id), mem_consumer_info: None, data: Mutex::new(BufferedData::new(partitioning, partition_id, sort_time)), rss: rss_partition_writer, @@ -52,7 +50,7 @@ impl RssSortShuffleRepartitioner { #[async_trait] impl MemConsumer for RssSortShuffleRepartitioner { fn name(&self) -> &str { - &self.name + "RssSortShuffleRepartitioner" } fn set_consumer_info(&mut self, consumer_info: Weak) { diff --git a/native-engine/datafusion-ext-plans/src/shuffle/sort_repartitioner.rs b/native-engine/datafusion-ext-plans/src/shuffle/sort_repartitioner.rs index 88a4a0da3..d23b5c453 100644 --- a/native-engine/datafusion-ext-plans/src/shuffle/sort_repartitioner.rs +++ b/native-engine/datafusion-ext-plans/src/shuffle/sort_repartitioner.rs @@ -43,7 +43,6 @@ use crate::{ pub struct SortShuffleRepartitioner { exec_ctx: Arc, - name: String, mem_consumer_info: Option>, output_data_file: String, output_index_file: String, @@ -66,7 +65,6 @@ impl SortShuffleRepartitioner { let num_output_partitions = partitioning.partition_count(); Self { exec_ctx, - name: format!("SortShufflePartitioner[partition={partition_id}]"), mem_consumer_info: None, output_data_file, output_index_file, @@ -81,7 +79,7 @@ impl SortShuffleRepartitioner { #[async_trait] impl MemConsumer for SortShuffleRepartitioner { fn name(&self) -> &str { - &self.name + "SortShuffleRepartitioner" } fn set_consumer_info(&mut self, consumer_info: Weak) { @@ -199,13 +197,26 @@ impl ShuffleRepartitioner for SortShuffleRepartitioner { return Ok(()); } - // write rest data into an in-memory buffer + // write rest data into a spill if !data.is_empty() { - let mut spill = Box::new(vec![]); - let writer = spill.get_buf_writer(); - let offsets = data.write(writer)?; - self.update_mem_used(spill.len()).await?; - spills.push(Offsetted::new(offsets, spill)); + if self.mem_used_percent() < 0.5 { + let mut spill = Box::new(vec![]); + let writer = spill.get_buf_writer(); + let offsets = data.write(writer)?; + self.update_mem_used(spill.len()).await?; + spills.push(Offsetted::new(offsets, spill)); + } else { + let spill_metrics = self.exec_ctx.spill_metrics().clone(); + let spill = tokio::task::spawn_blocking(move || { + let mut spill = try_new_spill(&spill_metrics)?; + let offsets = data.write(spill.get_buf_writer())?; + Ok::<_, DataFusionError>(Offsetted::new(offsets, spill)) + }) + .await + .expect("tokio spawn_blocking error")?; + self.update_mem_used(0).await?; + spills.push(spill); + } } // append partition in each spills diff --git a/native-engine/datafusion-ext-plans/src/sort_exec.rs b/native-engine/datafusion-ext-plans/src/sort_exec.rs index 2c397fb71..9b58cd4b6 100644 --- a/native-engine/datafusion-ext-plans/src/sort_exec.rs +++ b/native-engine/datafusion-ext-plans/src/sort_exec.rs 
@@ -53,7 +53,6 @@ use datafusion_ext_commons::{ selection::{create_batch_interleaver, take_batch, BatchInterleaver}, }, compute_suggested_batch_size_for_kway_merge, compute_suggested_batch_size_for_output, - downcast_any, io::{read_len, read_one_batch, write_len, write_one_batch}, }; use futures::{lock::Mutex, StreamExt}; @@ -196,7 +195,6 @@ struct LevelSpill { struct ExternalSorter { exec_ctx: Arc, - name: String, mem_consumer_info: Option>, prune_sort_keys_from_batch: Arc, limit: usize, @@ -209,7 +207,7 @@ struct ExternalSorter { #[async_trait] impl MemConsumer for ExternalSorter { fn name(&self) -> &str { - &self.name + "ExternalSorter" } fn set_consumer_info(&mut self, consumer_info: Weak) { @@ -501,7 +499,6 @@ impl ExecuteWithColumnPruning for SortExec { )?); let sorter = Arc::new(ExternalSorter { exec_ctx: exec_ctx.clone(), - name: format!("ExternalSorter[partition={}]", partition), mem_consumer_info: None, prune_sort_keys_from_batch, limit: self.fetch.unwrap_or(usize::MAX), @@ -609,39 +606,18 @@ impl ExternalSorter { self.num_total_rows(), ); let mut spills: Vec> = spills.into_iter().map(|spill| spill.spill).collect(); - - if self.mem_used_percent() < 0.25 { - // if in-mem data is small, try to spill it into native raw bytes - let limit = self.limit; - let mut spill = tokio::task::spawn_blocking(move || { - let mut spill: Box = Box::new(vec![]); - data.try_into_spill(&mut spill, sub_batch_size, limit)?; - Ok::<_, DataFusionError>(spill) - }) - .await - .expect("tokio error")?; - - let in_mem_spill = downcast_any!(spill, mut Vec)?; - in_mem_spill.shrink_to_fit(); - - let in_mem_spill_size = in_mem_spill.len(); - spills.push(spill); - self.update_mem_used(in_mem_spill_size + spills.len() * SPILL_OFFHEAP_MEM_COST) - .await?; - } else { - let limit = self.limit; - let spill_metrics = self.exec_ctx.spill_metrics().clone(); - let spill = tokio::task::spawn_blocking(move || { - let mut spill = try_new_spill(&spill_metrics)?; - data.try_into_spill(&mut spill, sub_batch_size, limit)?; - Ok::<_, DataFusionError>(spill) - }) - .await - .expect("tokio error")?; - spills.push(spill); - self.update_mem_used(spills.len() * SPILL_OFFHEAP_MEM_COST) - .await?; - } + let limit = self.limit; + let spill_metrics = self.exec_ctx.spill_metrics().clone(); + let spill = tokio::task::spawn_blocking(move || { + let mut spill = try_new_spill(&spill_metrics)?; + data.try_into_spill(&mut spill, sub_batch_size, limit)?; + Ok::<_, DataFusionError>(spill) + }) + .await + .expect("tokio error")?; + spills.push(spill); + self.update_mem_used(spills.len() * SPILL_OFFHEAP_MEM_COST) + .await?; let mut merger = ExternalMerger::::try_new( &mut spills, diff --git a/spark-extension/src/main/java/org/apache/spark/sql/blaze/BlazeConf.java b/spark-extension/src/main/java/org/apache/spark/sql/blaze/BlazeConf.java index fff0a8ca3..6a877b75a 100644 --- a/spark-extension/src/main/java/org/apache/spark/sql/blaze/BlazeConf.java +++ b/spark-extension/src/main/java/org/apache/spark/sql/blaze/BlazeConf.java @@ -77,7 +77,10 @@ public enum BlazeConf { SMJ_FALLBACK_ROWS_THRESHOLD("spark.blaze.smjfallback.rows.threshold", 10000000), // smj fallback threshold - SMJ_FALLBACK_MEM_SIZE_THRESHOLD("spark.blaze.smjfallback.mem.threshold", 134217728); + SMJ_FALLBACK_MEM_SIZE_THRESHOLD("spark.blaze.smjfallback.mem.threshold", 134217728), + + // max memory fraction of on-heap spills + ON_HEAP_SPILL_MEM_FRACTION("spark.blaze.onHeapSpill.memoryFraction", 0.9); public final String key; private final Object defaultValue; diff --git 
a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/BlazeCallNativeWrapper.scala b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/BlazeCallNativeWrapper.scala index aad2cb703..4b9f6001e 100644 --- a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/BlazeCallNativeWrapper.scala +++ b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/BlazeCallNativeWrapper.scala @@ -36,6 +36,7 @@ import org.apache.spark.sql.blaze.util.Using import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.UnsafeProjection import org.apache.spark.sql.execution.blaze.arrowio.util.ArrowUtils +import org.apache.spark.sql.execution.blaze.arrowio.util.ArrowUtils.ROOT_ALLOCATOR import org.apache.spark.sql.execution.blaze.arrowio.ColumnarHelper import org.apache.spark.sql.types.StructType import org.apache.spark.util.CompletionIterator @@ -102,12 +103,10 @@ case class BlazeCallNativeWrapper( metrics protected def importSchema(ffiSchemaPtr: Long): Unit = { - Using.resource(ArrowUtils.newChildAllocator(getClass.getName)) { schemaAllocator => - Using.resource(ArrowSchema.wrap(ffiSchemaPtr)) { ffiSchema => - arrowSchema = Data.importSchema(schemaAllocator, ffiSchema, dictionaryProvider) - schema = ArrowUtils.fromArrowSchema(arrowSchema) - toUnsafe = UnsafeProjection.create(schema) - } + Using.resource(ArrowSchema.wrap(ffiSchemaPtr)) { ffiSchema => + arrowSchema = Data.importSchema(ROOT_ALLOCATOR, ffiSchema, dictionaryProvider) + schema = ArrowUtils.fromArrowSchema(arrowSchema) + toUnsafe = UnsafeProjection.create(schema) } } @@ -116,19 +115,17 @@ case class BlazeCallNativeWrapper( throw new RuntimeException("Native runtime is finalized") } - Using.resource(ArrowUtils.newChildAllocator(getClass.getName)) { batchAllocator => - Using.resources( - ArrowArray.wrap(ffiArrayPtr), - VectorSchemaRoot.create(arrowSchema, batchAllocator)) { case (ffiArray, root) => - Data.importIntoVectorSchemaRoot(batchAllocator, ffiArray, root, dictionaryProvider) - val batch = ColumnarHelper.rootAsBatch(root) - - batchRows.append( - ColumnarHelper - .batchAsRowIter(batch) - .map(row => toUnsafe(row).copy().asInstanceOf[InternalRow]) - .toSeq: _*) - } + Using.resources( + ArrowArray.wrap(ffiArrayPtr), + VectorSchemaRoot.create(arrowSchema, ROOT_ALLOCATOR)) { case (ffiArray, root) => + Data.importIntoVectorSchemaRoot(ROOT_ALLOCATOR, ffiArray, root, dictionaryProvider) + val batch = ColumnarHelper.rootAsBatch(root) + + batchRows.append( + ColumnarHelper + .batchAsRowIter(batch) + .map(row => toUnsafe(row).copy().asInstanceOf[InternalRow]) + .toSeq: _*) } } diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/SparkUDFWrapperContext.scala b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/SparkUDFWrapperContext.scala index e9498137c..4c19d9b47 100644 --- a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/SparkUDFWrapperContext.scala +++ b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/SparkUDFWrapperContext.scala @@ -16,6 +16,7 @@ package org.apache.spark.sql.blaze import java.nio.ByteBuffer + import org.apache.arrow.c.ArrowArray import org.apache.arrow.c.Data import org.apache.arrow.vector.VectorSchemaRoot @@ -30,6 +31,7 @@ import org.apache.spark.sql.catalyst.expressions.Nondeterministic import org.apache.spark.sql.catalyst.expressions.UnsafeProjection import org.apache.spark.sql.execution.blaze.arrowio.ColumnarHelper import org.apache.spark.sql.execution.blaze.arrowio.util.ArrowUtils +import 
org.apache.spark.sql.execution.blaze.arrowio.util.ArrowUtils.ROOT_ALLOCATOR import org.apache.spark.sql.execution.blaze.arrowio.util.ArrowWriter import org.apache.spark.sql.types.StructField import org.apache.spark.sql.types.StructType @@ -61,36 +63,26 @@ case class SparkUDFWrapperContext(serialized: ByteBuffer) extends Logging { } def eval(importFFIArrayPtr: Long, exportFFIArrayPtr: Long): Unit = { - Using.resource(ArrowUtils.newChildAllocator(getClass.getName)) { batchAllocator => - Using.resources( - VectorSchemaRoot.create(outputSchema, batchAllocator), - VectorSchemaRoot.create(paramsSchema, batchAllocator), - ArrowArray.wrap(importFFIArrayPtr), - ArrowArray.wrap(exportFFIArrayPtr)) { - (outputRoot, paramsRoot, importArray, exportArray) => - // import into params root - Data.importIntoVectorSchemaRoot( - batchAllocator, - importArray, - paramsRoot, - dictionaryProvider) - val batch = ColumnarHelper.rootAsBatch(paramsRoot) - - // evaluate expression and write to output root - val outputWriter = ArrowWriter.create(outputRoot) - for (paramsRow <- ColumnarHelper.batchAsRowIter(batch)) { - val outputRow = InternalRow(expr.eval(paramsToUnsafe(paramsRow))) - outputWriter.write(outputRow) - } - outputWriter.finish() + Using.resources( + VectorSchemaRoot.create(outputSchema, ROOT_ALLOCATOR), + VectorSchemaRoot.create(paramsSchema, ROOT_ALLOCATOR), + ArrowArray.wrap(importFFIArrayPtr), + ArrowArray.wrap(exportFFIArrayPtr)) { (outputRoot, paramsRoot, importArray, exportArray) => + // import into params root + Data.importIntoVectorSchemaRoot(ROOT_ALLOCATOR, importArray, paramsRoot, dictionaryProvider) - // export to output using root allocator - Data.exportVectorSchemaRoot( - ArrowUtils.rootAllocator, - outputRoot, - dictionaryProvider, - exportArray) + // evaluate expression and write to output root + Using.resource(ColumnarHelper.rootAsBatch(paramsRoot)) { batch => + val outputWriter = ArrowWriter.create(outputRoot) + for (paramsRow <- ColumnarHelper.batchAsRowIter(batch)) { + val outputRow = InternalRow(expr.eval(paramsToUnsafe(paramsRow))) + outputWriter.write(outputRow) + } + outputWriter.finish() } + + // export to output using root allocator + Data.exportVectorSchemaRoot(ROOT_ALLOCATOR, outputRoot, dictionaryProvider, exportArray) } } } diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/SparkUDTFWrapperContext.scala b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/SparkUDTFWrapperContext.scala index d62771dc8..a73d2f381 100644 --- a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/SparkUDTFWrapperContext.scala +++ b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/SparkUDTFWrapperContext.scala @@ -31,6 +31,7 @@ import org.apache.spark.sql.catalyst.expressions.Nondeterministic import org.apache.spark.sql.catalyst.expressions.UnsafeProjection import org.apache.spark.sql.execution.blaze.arrowio.ColumnarHelper import org.apache.spark.sql.execution.blaze.arrowio.util.ArrowUtils +import org.apache.spark.sql.execution.blaze.arrowio.util.ArrowUtils.ROOT_ALLOCATOR import org.apache.spark.sql.execution.blaze.arrowio.util.ArrowWriter import org.apache.spark.sql.types.IntegerType import org.apache.spark.sql.types.StructField @@ -64,59 +65,42 @@ case class SparkUDTFWrapperContext(serialized: ByteBuffer) extends Logging { } def eval(importFFIArrayPtr: Long, exportFFIArrayPtr: Long): Unit = { - Using.resource(ArrowUtils.newChildAllocator(getClass.getName)) { batchAllocator => - Using.resources( - VectorSchemaRoot.create(outputSchema, batchAllocator), - 
VectorSchemaRoot.create(paramsSchema, batchAllocator), - ArrowArray.wrap(importFFIArrayPtr), - ArrowArray.wrap(exportFFIArrayPtr)) { - (outputRoot, paramsRoot, importArray, exportArray) => - // import into params root - Data.importIntoVectorSchemaRoot( - batchAllocator, - importArray, - paramsRoot, - dictionaryProvider) - val batch = ColumnarHelper.rootAsBatch(paramsRoot) + Using.resources( + VectorSchemaRoot.create(outputSchema, ROOT_ALLOCATOR), + VectorSchemaRoot.create(paramsSchema, ROOT_ALLOCATOR), + ArrowArray.wrap(importFFIArrayPtr), + ArrowArray.wrap(exportFFIArrayPtr)) { (outputRoot, paramsRoot, importArray, exportArray) => + // import into params root + Data.importIntoVectorSchemaRoot(ROOT_ALLOCATOR, importArray, paramsRoot, dictionaryProvider) + val batch = ColumnarHelper.rootAsBatch(paramsRoot) - // evaluate expression and write to output root - val outputWriter = ArrowWriter.create(outputRoot) - for ((paramsRow, rowId) <- ColumnarHelper.batchAsRowIter(batch).zipWithIndex) { - for (outputRow <- expr.eval(paramsToUnsafe(paramsRow))) { - outputWriter.write(InternalRow(rowId, outputRow)) - } - } - outputWriter.finish() - - // export to output using root allocator - Data.exportVectorSchemaRoot( - ArrowUtils.rootAllocator, - outputRoot, - dictionaryProvider, - exportArray) + // evaluate expression and write to output root + val outputWriter = ArrowWriter.create(outputRoot) + for ((paramsRow, rowId) <- ColumnarHelper.batchAsRowIter(batch).zipWithIndex) { + for (outputRow <- expr.eval(paramsToUnsafe(paramsRow))) { + outputWriter.write(InternalRow(rowId, outputRow)) + } } + outputWriter.finish() + + // export to output using root allocator + Data.exportVectorSchemaRoot(ROOT_ALLOCATOR, outputRoot, dictionaryProvider, exportArray) } } def terminate(rowId: Int, exportFFIArrayPtr: Long): Unit = { - Using.resource(ArrowUtils.newChildAllocator(getClass.getName)) { batchAllocator => - Using.resources( - VectorSchemaRoot.create(outputSchema, batchAllocator), - ArrowArray.wrap(exportFFIArrayPtr)) { (outputRoot, exportArray) => - // evaluate expression and write to output root - val outputWriter = ArrowWriter.create(outputRoot) - for (outputRow <- expr.terminate()) { - outputWriter.write(InternalRow(rowId, outputRow)) - } - outputWriter.finish() - - // export to output using root allocator - Data.exportVectorSchemaRoot( - ArrowUtils.rootAllocator, - outputRoot, - dictionaryProvider, - exportArray) + Using.resources( + VectorSchemaRoot.create(outputSchema, ROOT_ALLOCATOR), + ArrowArray.wrap(exportFFIArrayPtr)) { (outputRoot, exportArray) => + // evaluate expression and write to output root + val outputWriter = ArrowWriter.create(outputRoot) + for (outputRow <- expr.terminate()) { + outputWriter.write(InternalRow(rowId, outputRow)) } + outputWriter.finish() + + // export to output using root allocator + Data.exportVectorSchemaRoot(ROOT_ALLOCATOR, outputRoot, dictionaryProvider, exportArray) } } } diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/memory/OnHeapSpillManager.scala b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/memory/OnHeapSpillManager.scala index 3a270df9e..716a585f7 100644 --- a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/memory/OnHeapSpillManager.scala +++ b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/memory/OnHeapSpillManager.scala @@ -26,6 +26,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.memory.MemoryConsumer import org.apache.spark.memory.MemoryMode import 
org.apache.spark.memory.blaze.OnHeapSpillManagerHelper +import org.apache.spark.sql.blaze.BlazeConf import org.apache.spark.storage.BlockManager import org.apache.spark.util.Utils @@ -80,7 +81,8 @@ class OnHeapSpillManager(taskContext: TaskContext) s" ratio=$jvmMemoryUsedRatio") // we should have at least 10% free memory - memoryUsedRatio < 0.9 && jvmMemoryUsedRatio < 0.9 + val maxRatio = BlazeConf.ON_HEAP_SPILL_MEM_FRACTION.doubleConf() + memoryUsedRatio < maxRatio && jvmMemoryUsedRatio < maxRatio } /** diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/arrowio/ArrowFFIExporter.scala b/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/arrowio/ArrowFFIExporter.scala index f8819d2eb..c279760ff 100644 --- a/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/arrowio/ArrowFFIExporter.scala +++ b/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/arrowio/ArrowFFIExporter.scala @@ -15,6 +15,8 @@ */ package org.apache.spark.sql.execution.blaze.arrowio +import java.lang.Thread.UncaughtExceptionHandler + import org.apache.arrow.c.ArrowArray import org.apache.arrow.c.ArrowSchema import org.apache.arrow.c.Data @@ -22,18 +24,17 @@ import org.apache.arrow.vector.VectorSchemaRoot import org.apache.arrow.vector.dictionary.DictionaryProvider.MapDictionaryProvider import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.blaze.arrowio.util.ArrowUtils +import org.apache.spark.sql.execution.blaze.arrowio.util.ArrowUtils.ROOT_ALLOCATOR import org.apache.spark.sql.execution.blaze.arrowio.util.ArrowWriter import org.apache.spark.sql.types.StructType import org.apache.spark.sql.blaze.{BlazeConf, NativeHelper} import org.apache.spark.sql.blaze.util.Using import org.apache.spark.TaskContext import java.security.PrivilegedExceptionAction +import java.util.concurrent.ArrayBlockingQueue import java.util.concurrent.BlockingQueue -import java.util.concurrent.SynchronousQueue - -class ArrowFFIExporter(rowIter: Iterator[InternalRow], schema: StructType) - extends AutoCloseable { +class ArrowFFIExporter(rowIter: Iterator[InternalRow], schema: StructType) { private val maxBatchNumRows = BlazeConf.BATCH_SIZE.intConf() private val maxBatchMemorySize = 1 << 24 // 16MB @@ -41,22 +42,19 @@ class ArrowFFIExporter(rowIter: Iterator[InternalRow], schema: StructType) private val emptyDictionaryProvider = new MapDictionaryProvider() private val nativeCurrentUser = NativeHelper.currentUser - private trait QueueElement - private case class Root(root: VectorSchemaRoot) extends QueueElement - private case object RootConsumed extends QueueElement - private case object Finished extends QueueElement + private trait QueueState + private case object NextBatch extends QueueState + private case object Finished extends QueueState private val tc = TaskContext.get() - private val outputQueue: BlockingQueue[QueueElement] = new SynchronousQueue[QueueElement]() + private val outputQueue: BlockingQueue[QueueState] = new ArrayBlockingQueue[QueueState](16) + private val processingQueue: BlockingQueue[Unit] = new ArrayBlockingQueue[Unit](16) private var currentRoot: VectorSchemaRoot = _ - private var finished = false - private val outputThread = startOutputThread() + startOutputThread() def exportSchema(exportArrowSchemaPtr: Long): Unit = { - Using.resource(ArrowUtils.newChildAllocator(getClass.getName)) { schemaAllocator => - Using.resource(ArrowSchema.wrap(exportArrowSchemaPtr)) { exportSchema => - Data.exportSchema(schemaAllocator, 
arrowSchema, emptyDictionaryProvider, exportSchema) - } + Using.resource(ArrowSchema.wrap(exportArrowSchemaPtr)) { exportSchema => + Data.exportSchema(ROOT_ALLOCATOR, arrowSchema, emptyDictionaryProvider, exportSchema) } } @@ -66,37 +64,24 @@ class ArrowFFIExporter(rowIter: Iterator[InternalRow], schema: StructType) } // export using root allocator - val allocator = ArrowUtils.rootAllocator Using.resource(ArrowArray.wrap(exportArrowArrayPtr)) { exportArray => - Data.exportVectorSchemaRoot(allocator, currentRoot, emptyDictionaryProvider, exportArray) + Data.exportVectorSchemaRoot( + ROOT_ALLOCATOR, + currentRoot, + emptyDictionaryProvider, + exportArray) } - // consume RootConsumed state so that we can go to the next batch - outputQueue.take() match { - case RootConsumed => - case other => - throw new IllegalStateException(s"Unexpected queue element: $other, expect RootConsumed") - } + // to continue processing next batch + processingQueue.put(()) true } private def hasNext: Boolean = { if (tc != null && (tc.isCompleted() || tc.isInterrupted())) { - finished = true return false } - - outputQueue.take() match { - case Root(root) => - currentRoot = root - true - case Finished => - currentRoot = null - finished = true - false - case other => - throw new IllegalStateException(s"Unexpected queue element: $other, expect Root/Finished") - } + outputQueue.take() == NextBatch } private def startOutputThread(): Thread = { @@ -108,49 +93,55 @@ class ArrowFFIExporter(rowIter: Iterator[InternalRow], schema: StructType) nativeCurrentUser.doAs(new PrivilegedExceptionAction[Unit] { override def run(): Unit = { - while (!finished && (tc == null || (!tc.isCompleted() && !tc.isInterrupted()))) { + while (tc == null || (!tc.isCompleted() && !tc.isInterrupted())) { if (!rowIter.hasNext) { - finished = true outputQueue.put(Finished) return } - Using.resource(ArrowUtils.newChildAllocator(getClass.getName)) { batchAllocator => - Using.resource(VectorSchemaRoot.create(arrowSchema, batchAllocator)) { root => - val arrowWriter = ArrowWriter.create(root) - var rowCount = 0 - - while (rowIter.hasNext - && rowCount < maxBatchNumRows - && batchAllocator.getAllocatedMemory < maxBatchMemorySize) { - arrowWriter.write(rowIter.next()) - rowCount += 1 - } - arrowWriter.finish() - - // export root - outputQueue.put(Root(root)) - outputQueue.put(RootConsumed) + Using.resource(VectorSchemaRoot.create(arrowSchema, ROOT_ALLOCATOR)) { root => + val arrowWriter = ArrowWriter.create(root) + var rowCount = 0 + while (rowIter.hasNext + && rowCount < maxBatchNumRows + && (rowCount == 0 || ROOT_ALLOCATOR.getAllocatedMemory < maxBatchMemorySize)) { + arrowWriter.write(rowIter.next()) + rowCount += 1 } + arrowWriter.finish() + + // export root + currentRoot = root + outputQueue.put(NextBatch) + + // wait for processing next batch + processingQueue.take() } } + outputQueue.put(Finished) } }) } }) + def close(): Unit = { + thread.interrupt() + outputQueue.put(Finished) // to abort any pending call to exportNextBatch() + } + if (tc != null) { tc.addTaskCompletionListener[Unit]((_: TaskContext) => close()) tc.addTaskFailureListener((_, _) => close()) } + thread.setDaemon(true) + thread.setUncaughtExceptionHandler(new UncaughtExceptionHandler { + override def uncaughtException(t: Thread, e: Throwable): Unit = { + close() + throw e + } + }) thread.start() thread } - - override def close(): Unit = { - finished = true - outputThread.interrupt() - outputQueue.offer(Finished) // to abort any pending call to exportNextBatch() - } } diff --git 
a/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/arrowio/util/ArrowUtils.scala b/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/arrowio/util/ArrowUtils.scala index f6ddfc607..64b8ca9bb 100644 --- a/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/arrowio/util/ArrowUtils.scala +++ b/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/arrowio/util/ArrowUtils.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql.execution.blaze.arrowio.util import scala.collection.JavaConverters.asScalaBufferConverter import scala.collection.JavaConverters.seqAsJavaListConverter -import org.apache.arrow.memory.BufferAllocator import org.apache.arrow.memory.RootAllocator import org.apache.arrow.vector.complex.MapVector import org.apache.arrow.vector.types.DateUnit @@ -32,11 +31,8 @@ import org.apache.spark.sql.types._ import org.apache.spark.util.ShutdownHookManager object ArrowUtils { - val rootAllocator = new RootAllocator(Long.MaxValue) - ShutdownHookManager.addShutdownHook(() => rootAllocator.close()) - - def newChildAllocator(name: String): BufferAllocator = - rootAllocator.newChildAllocator(name, 0, Long.MaxValue) + val ROOT_ALLOCATOR = new RootAllocator(Long.MaxValue) + ShutdownHookManager.addShutdownHook(() => ROOT_ALLOCATOR.close()) /** Maps data type from Spark to Arrow. NOTE: timeZoneId is always NULL in TimestampTypes */ def toArrowType(dt: DataType): ArrowType = diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/arrowio/util/ArrowWriter.scala b/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/arrowio/util/ArrowWriter.scala index 44de5945f..5579ded6c 100644 --- a/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/arrowio/util/ArrowWriter.scala +++ b/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/arrowio/util/ArrowWriter.scala @@ -21,18 +21,20 @@ import org.apache.arrow.vector._ import org.apache.arrow.vector.complex._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.SpecializedGetters +import org.apache.spark.sql.execution.blaze.arrowio.util.ArrowUtils.ROOT_ALLOCATOR import org.apache.spark.sql.types._ object ArrowWriter { def create(schema: StructType): ArrowWriter = { val arrowSchema = ArrowUtils.toArrowSchema(schema) - val root = VectorSchemaRoot.create(arrowSchema, ArrowUtils.rootAllocator) + val root = VectorSchemaRoot.create(arrowSchema, ROOT_ALLOCATOR) create(root) } def create(root: VectorSchemaRoot): ArrowWriter = { val children = root.getFieldVectors().asScala.map { vector => + vector.setInitialCapacity(0) // don't allocate initial memory for the vector vector.allocateNew() createFieldWriter(vector) } @@ -211,7 +213,6 @@ private[sql] class FloatWriter(val valueVector: Float4Vector) extends ArrowField } private[sql] class DoubleWriter(val valueVector: Float8Vector) extends ArrowFieldWriter { - override def setNull(): Unit = { valueVector.setNull(count) } @@ -223,7 +224,6 @@ private[sql] class DoubleWriter(val valueVector: Float8Vector) extends ArrowFiel private[sql] class DecimalWriter(val valueVector: DecimalVector, precision: Int, scale: Int) extends ArrowFieldWriter { - override def setNull(): Unit = { valueVector.setNull(count) } @@ -239,7 +239,6 @@ private[sql] class DecimalWriter(val valueVector: DecimalVector, precision: Int, } private[sql] class StringWriter(val valueVector: VarCharVector) extends ArrowFieldWriter { - override def setNull(): Unit = { 
valueVector.setNull(count) } @@ -253,7 +252,6 @@ private[sql] class StringWriter(val valueVector: VarCharVector) extends ArrowFie } private[sql] class BinaryWriter(val valueVector: VarBinaryVector) extends ArrowFieldWriter { - override def setNull(): Unit = { valueVector.setNull(count) }
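
Note on the new configuration key (illustrative usage, not part of the patch): spark.blaze.onHeapSpill.memoryFraction (default 0.9) caps how far both the on-heap execution memory ratio and the overall JVM heap usage ratio may rise before OnHeapSpillManager stops accepting further on-heap spills; the 0.9 threshold was previously hard-coded. Below is a minimal sketch of setting the knob when building a session; the application name and the 0.8 value are assumptions for illustration only.

    import org.apache.spark.SparkConf
    import org.apache.spark.sql.SparkSession

    // Lower the on-heap spill cap from the 0.9 default to 0.8,
    // leaving more JVM heap headroom for other consumers.
    val conf = new SparkConf()
      .setAppName("blaze-onheap-spill-demo") // hypothetical app name
      .set("spark.blaze.onHeapSpill.memoryFraction", "0.8")

    val spark = SparkSession.builder().config(conf).getOrCreate()

The same key can also be supplied through spark-defaults.conf or --conf on spark-submit, like any other spark.blaze.* setting.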