From 7f6a99eded4f3426119dbc609cbf8da1c0d94d66 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marko=20Milenkovi=C4=87?= Date: Tue, 18 Feb 2025 20:40:01 +0000 Subject: [PATCH] use configuration to propagate grpc settings --- ballista/core/src/client.rs | 41 ++++++-- ballista/core/src/config.rs | 9 ++ .../src/execution_plans/distributed_query.rs | 17 ++-- .../src/execution_plans/shuffle_reader.rs | 60 ++++++++---- ballista/core/src/extension.rs | 49 +++++++--- ballista/executor/src/executor_process.rs | 18 +++- ballista/executor/src/standalone.rs | 98 ++++++------------- ballista/scheduler/src/scheduler_process.rs | 12 ++- ballista/scheduler/src/standalone.rs | 8 +- 9 files changed, 188 insertions(+), 124 deletions(-) diff --git a/ballista/core/src/client.rs b/ballista/core/src/client.rs index 2c99569b9..06d862764 100644 --- a/ballista/core/src/client.rs +++ b/ballista/core/src/client.rs @@ -61,7 +61,11 @@ const IO_RETRY_WAIT_TIME_MS: u64 = 3000; impl BallistaClient { /// Create a new BallistaClient to connect to the executor listening on the specified /// host and port - pub async fn try_new(host: &str, port: u16) -> Result { + pub async fn try_new( + host: &str, + port: u16, + grpc_client_max_message_size: usize, + ) -> Result { let addr = format!("http://{host}:{port}"); debug!("BallistaClient connecting to {}", addr); let connection = @@ -72,8 +76,11 @@ impl BallistaClient { "Error connecting to Ballista scheduler or executor at {addr}: {e:?}" )) })?; - let flight_client = FlightServiceClient::new(connection); - debug!("BallistaClient connected OK"); + let flight_client = FlightServiceClient::new(connection) + .max_decoding_message_size(grpc_client_max_message_size) + .max_encoding_message_size(grpc_client_max_message_size); + + debug!("BallistaClient connected OK: {:?}", flight_client); Ok(Self { flight_client }) } @@ -99,13 +106,27 @@ impl BallistaClient { .await .map_err(|error| match error { // map grpc connection error to partition fetch error. 
- BallistaError::GrpcActionError(msg) => BallistaError::FetchFailed( - executor_id.to_owned(), - partition_id.stage_id, - partition_id.partition_id, - msg, - ), - other => other, + BallistaError::GrpcActionError(msg) => { + log::warn!( + "grpc client failed to fetch partition: {:?} , message: {:?}", + partition_id, + msg + ); + BallistaError::FetchFailed( + executor_id.to_owned(), + partition_id.stage_id, + partition_id.partition_id, + msg, + ) + } + error => { + log::warn!( + "grpc client failed to fetch partition: {:?} , error: {:?}", + partition_id, + error + ); + error + } }) } diff --git a/ballista/core/src/config.rs b/ballista/core/src/config.rs index 628821447..18a298253 100644 --- a/ballista/core/src/config.rs +++ b/ballista/core/src/config.rs @@ -32,6 +32,7 @@ pub const BALLISTA_STANDALONE_PARALLELISM: &str = "ballista.standalone.paralleli /// max message size for gRPC clients pub const BALLISTA_GRPC_CLIENT_MAX_MESSAGE_SIZE: &str = "ballista.grpc_client_max_message_size"; +pub const BALLISTA_SHUFFLE_READER_MAX_REQUESTS: &str = "ballista.shuffle.max_requests"; pub type ParseResult = result::Result; use std::sync::LazyLock; @@ -48,6 +49,10 @@ static CONFIG_ENTRIES: LazyLock> = LazyLock::new(|| "Configuration for max message size in gRPC clients".to_string(), DataType::UInt64, Some((16 * 1024 * 1024).to_string())), + ConfigEntry::new(BALLISTA_SHUFFLE_READER_MAX_REQUESTS.to_string(), + "Maximum concurrent requests shuffle reader can serve".to_string(), + DataType::UInt64, + Some((64).to_string())), ]; entries .into_iter() @@ -165,6 +170,10 @@ impl BallistaConfig { self.get_usize_setting(BALLISTA_STANDALONE_PARALLELISM) } + pub fn shuffle_reader_maximum_in_flight_requests(&self) -> usize { + self.get_usize_setting(BALLISTA_SHUFFLE_READER_MAX_REQUESTS) + } + fn get_usize_setting(&self, key: &str) -> usize { if let Some(v) = self.settings.get(key) { // infallible because we validate all configs in the constructor diff --git 
a/ballista/core/src/execution_plans/distributed_query.rs b/ballista/core/src/execution_plans/distributed_query.rs index e9596ef19..595038140 100644 --- a/ballista/core/src/execution_plans/distributed_query.rs +++ b/ballista/core/src/execution_plans/distributed_query.rs @@ -310,12 +310,16 @@ async fn execute_query( break Err(DataFusionError::Execution(msg)); } Some(job_status::Status::Successful(successful)) => { - let streams = successful.partition_location.into_iter().map(|p| { - let f = fetch_partition(p) - .map_err(|e| ArrowError::ExternalError(Box::new(e))); + let streams = + successful + .partition_location + .into_iter() + .map(move |partition| { + let f = fetch_partition(partition, max_message_size) + .map_err(|e| ArrowError::ExternalError(Box::new(e))); - futures::stream::once(f).try_flatten() - }); + futures::stream::once(f).try_flatten() + }); break Ok(futures::stream::iter(streams).flatten()); } @@ -325,6 +329,7 @@ async fn execute_query( async fn fetch_partition( location: PartitionLocation, + max_message_size: usize, ) -> Result { let metadata = location.executor_meta.ok_or_else(|| { DataFusionError::Internal("Received empty executor metadata".to_owned()) @@ -334,7 +339,7 @@ async fn fetch_partition( })?; let host = metadata.host.as_str(); let port = metadata.port as u16; - let mut ballista_client = BallistaClient::try_new(host, port) + let mut ballista_client = BallistaClient::try_new(host, port, max_message_size) .await .map_err(|e| DataFusionError::Execution(format!("{e:?}")))?; ballista_client diff --git a/ballista/core/src/execution_plans/shuffle_reader.rs b/ballista/core/src/execution_plans/shuffle_reader.rs index 7a20f1215..2ad389283 100644 --- a/ballista/core/src/execution_plans/shuffle_reader.rs +++ b/ballista/core/src/execution_plans/shuffle_reader.rs @@ -29,6 +29,7 @@ use std::sync::Arc; use std::task::{Context, Poll}; use crate::client::BallistaClient; +use crate::extension::SessionConfigExt; use crate::serde::scheduler::{PartitionLocation, 
PartitionStats}; use datafusion::arrow::datatypes::SchemaRef; @@ -146,8 +147,18 @@ impl ExecutionPlan for ShuffleReaderExec { let task_id = context.task_id().unwrap_or_else(|| partition.to_string()); info!("ShuffleReaderExec::execute({})", task_id); - // TODO make the maximum size configurable, or make it depends on global memory control - let max_request_num = 50usize; + let config = context.session_config(); + + let max_message_size = config.ballista_grpc_client_max_message_size(); + let max_request_num = + config.ballista_shuffle_reader_maximum_in_flight_requests(); + + log::debug!( + "ShuffleReaderExec::execute({}) max_request_num: {}, max_message_size: {}", + task_id, + max_request_num, + max_message_size + ); let mut partition_locations = HashMap::new(); for p in &self.partition[partition] { partition_locations @@ -166,7 +177,7 @@ impl ExecutionPlan for ShuffleReaderExec { partition_locations.shuffle(&mut thread_rng()); let response_receiver = - send_fetch_partitions(partition_locations, max_request_num); + send_fetch_partitions(partition_locations, max_request_num, max_message_size); let result = RecordBatchStreamAdapter::new( Arc::new(self.schema.as_ref().clone()), @@ -284,6 +295,7 @@ impl Stream for AbortableReceiverStream { fn send_fetch_partitions( partition_locations: Vec, max_request_num: usize, + max_message_size: usize, ) -> AbortableReceiverStream { let (response_sender, response_receiver) = mpsc::channel(max_request_num); let semaphore = Arc::new(Semaphore::new(max_request_num)); @@ -302,7 +314,9 @@ fn send_fetch_partitions( let response_sender_c = response_sender.clone(); spawned_tasks.push(SpawnedTask::spawn(async move { for p in local_locations { - let r = PartitionReaderEnum::Local.fetch_partition(&p).await; + let r = PartitionReaderEnum::Local + .fetch_partition(&p, max_message_size) + .await; if let Err(e) = response_sender_c.send(r).await { error!("Fail to send response event to the channel due to {}", e); } @@ -315,7 +329,9 @@ fn 
send_fetch_partitions( spawned_tasks.push(SpawnedTask::spawn(async move { // Block if exceeds max request number. let permit = semaphore.acquire_owned().await.unwrap(); - let r = PartitionReaderEnum::FlightRemote.fetch_partition(&p).await; + let r = PartitionReaderEnum::FlightRemote + .fetch_partition(&p, max_message_size) + .await; // Block if the channel buffer is full. if let Err(e) = response_sender.send(r).await { error!("Fail to send response event to the channel due to {}", e); @@ -339,6 +355,7 @@ trait PartitionReader: Send + Sync + Clone { async fn fetch_partition( &self, location: &PartitionLocation, + max_message_size: usize, ) -> result::Result; } @@ -356,9 +373,12 @@ impl PartitionReader for PartitionReaderEnum { async fn fetch_partition( &self, location: &PartitionLocation, + max_message_size: usize, ) -> result::Result { match self { - PartitionReaderEnum::FlightRemote => fetch_partition_remote(location).await, + PartitionReaderEnum::FlightRemote => { + fetch_partition_remote(location, max_message_size).await + } PartitionReaderEnum::Local => fetch_partition_local(location).await, PartitionReaderEnum::ObjectStoreRemote => { fetch_partition_object_store(location).await @@ -369,6 +389,7 @@ impl PartitionReader for PartitionReaderEnum { async fn fetch_partition_remote( location: &PartitionLocation, + max_message_size: usize, ) -> result::Result { let metadata = &location.executor_meta; let partition_id = &location.partition_id; @@ -376,19 +397,18 @@ async fn fetch_partition_remote( // And we should also avoid to keep alive too many connections for long time. let host = metadata.host.as_str(); let port = metadata.port; - let mut ballista_client = - BallistaClient::try_new(host, port) - .await - .map_err(|error| match error { - // map grpc connection error to partition fetch error. 
- BallistaError::GrpcConnectionError(msg) => BallistaError::FetchFailed( - metadata.id.clone(), - partition_id.stage_id, - partition_id.partition_id, - msg, - ), - other => other, - })?; + let mut ballista_client = BallistaClient::try_new(host, port, max_message_size) + .await + .map_err(|error| match error { + // map grpc connection error to partition fetch error. + BallistaError::GrpcConnectionError(msg) => BallistaError::FetchFailed( + metadata.id.clone(), + partition_id.stage_id, + partition_id.partition_id, + msg, + ), + other => other, + })?; ballista_client .fetch_partition(&metadata.id, partition_id, &location.path, host, port) @@ -644,7 +664,7 @@ mod tests { ); let response_receiver = - send_fetch_partitions(partition_locations, max_request_num); + send_fetch_partitions(partition_locations, max_request_num, 4 * 1024 * 1024); let stream = RecordBatchStreamAdapter::new( Arc::new(schema), diff --git a/ballista/core/src/extension.rs b/ballista/core/src/extension.rs index 182113b8f..1d6062858 100644 --- a/ballista/core/src/extension.rs +++ b/ballista/core/src/extension.rs @@ -17,7 +17,7 @@ use crate::config::{ BallistaConfig, BALLISTA_GRPC_CLIENT_MAX_MESSAGE_SIZE, BALLISTA_JOB_NAME, - BALLISTA_STANDALONE_PARALLELISM, + BALLISTA_SHUFFLE_READER_MAX_REQUESTS, BALLISTA_STANDALONE_PARALLELISM, }; use crate::planner::BallistaQueryPlanner; use crate::serde::protobuf::KeyValuePair; @@ -103,6 +103,15 @@ pub trait SessionConfigExt { /// Sets ballista job name fn with_ballista_job_name(self, job_name: &str) -> Self; + + /// get maximum in flight requests for shuffle reader + fn ballista_shuffle_reader_maximum_in_flight_requests(&self) -> usize; + + /// Sets maximum in flight requests for shuffle reader + fn with_ballista_shuffle_reader_maximum_in_flight_requests( + self, + max_requests: usize, + ) -> Self; } /// [SessionConfigHelperExt] is set of [SessionConfig] extension methods @@ -121,16 +130,11 @@ impl SessionStateExt for SessionState { scheduler_url: String, 
session_id: String, ) -> datafusion::error::Result { - let config = BallistaConfig::default(); - - let planner = - BallistaQueryPlanner::::new(scheduler_url, config.clone()); - - let session_config = SessionConfig::new() - .with_information_schema(true) - .with_option_extension(config.clone()) - // Ballista disables this option - .with_round_robin_repartition(false); + let session_config = SessionConfig::new_with_ballista(); + let planner = BallistaQueryPlanner::::new( + scheduler_url, + BallistaConfig::default(), + ); let runtime_env = RuntimeEnvBuilder::new().build()?; let session_state = SessionStateBuilder::new() @@ -191,6 +195,7 @@ impl SessionConfigExt for SessionConfig { fn new_with_ballista() -> SessionConfig { SessionConfig::new() .with_option_extension(BallistaConfig::default()) + .with_information_schema(true) .with_target_partitions(16) .with_round_robin_repartition(false) } @@ -279,6 +284,28 @@ impl SessionConfigExt for SessionConfig { .set_usize(BALLISTA_STANDALONE_PARALLELISM, parallelism) } } + + fn ballista_shuffle_reader_maximum_in_flight_requests(&self) -> usize { + self.options() + .extensions + .get::() + .map(|c| c.shuffle_reader_maximum_in_flight_requests()) + .unwrap_or_else(|| { + BallistaConfig::default().shuffle_reader_maximum_in_flight_requests() + }) + } + + fn with_ballista_shuffle_reader_maximum_in_flight_requests( + self, + max_requests: usize, + ) -> Self { + if self.options().extensions.get::().is_some() { + self.set_usize(BALLISTA_SHUFFLE_READER_MAX_REQUESTS, max_requests) + } else { + self.with_option_extension(BallistaConfig::default()) + .set_usize(BALLISTA_SHUFFLE_READER_MAX_REQUESTS, max_requests) + } + } } impl SessionConfigHelperExt for SessionConfig { diff --git a/ballista/executor/src/executor_process.rs b/ballista/executor/src/executor_process.rs index b8edc5969..b3aa629ac 100644 --- a/ballista/executor/src/executor_process.rs +++ b/ballista/executor/src/executor_process.rs @@ -365,7 +365,13 @@ pub async fn 
start_executor_process( service_handlers.push(match override_flight { None => { info!("Starting built-in arrow flight service"); - flight_server_task(address, shutdown).await + flight_server_task( + address, + shutdown, + opt.grpc_max_encoding_message_size as usize, + opt.grpc_max_decoding_message_size as usize, + ) + .await } Some(flight_provider) => { info!("Starting custom, user provided, arrow flight service"); @@ -471,12 +477,18 @@ pub async fn start_executor_process( async fn flight_server_task( address: SocketAddr, mut grpc_shutdown: Shutdown, + max_encoding_message_size: usize, + max_decoding_message_size: usize, ) -> JoinHandle> { tokio::spawn(async move { - info!("Built-in arrow flight server listening on: {:?}", address); + info!("Built-in arrow flight server listening on: {:?} max_encoding_size: {} max_decoding_size: {}", address, max_encoding_message_size, max_decoding_message_size); let server_future = create_grpc_server() - .add_service(FlightServiceServer::new(BallistaFlightService::new())) + .add_service( + FlightServiceServer::new(BallistaFlightService::new()) + .max_decoding_message_size(max_decoding_message_size) + .max_encoding_message_size(max_encoding_message_size), + ) .serve_with_shutdown(address, grpc_shutdown.recv()); server_future.await.map_err(|e| { diff --git a/ballista/executor/src/standalone.rs b/ballista/executor/src/standalone.rs index b16e4a3a0..38c46d028 100644 --- a/ballista/executor/src/standalone.rs +++ b/ballista/executor/src/standalone.rs @@ -31,11 +31,7 @@ use ballista_core::{ BALLISTA_VERSION, }; use ballista_core::{ConfigProducer, RuntimeProducer}; -use datafusion::execution::runtime_env::RuntimeEnvBuilder; -use datafusion::execution::SessionState; -use datafusion::prelude::SessionConfig; -use datafusion_proto::logical_plan::AsLogicalPlan; -use datafusion_proto::physical_plan::AsExecutionPlan; +use datafusion::execution::{SessionState, SessionStateBuilder}; use log::info; use std::sync::Arc; use tempfile::TempDir; @@ 
-55,6 +51,7 @@ pub async fn new_standalone_executor_from_state( ) -> Result<()> { let logical = session_state.config().ballista_logical_extension_codec(); let physical = session_state.config().ballista_physical_extension_codec(); + let codec: BallistaCodec< datafusion_proto::protobuf::LogicalPlanNode, datafusion_proto::protobuf::PhysicalPlanNode, @@ -63,7 +60,9 @@ pub async fn new_standalone_executor_from_state( let config = session_state .config() .clone() - .with_option_extension(BallistaConfig::default()); + .with_option_extension(BallistaConfig::default()) // TODO: do we need this statement + ; + let runtime = session_state.runtime_env().clone(); let config_producer: ConfigProducer = Arc::new(move || config.clone()); @@ -90,16 +89,16 @@ pub async fn new_standalone_executor_from_builder( ) -> Result<()> { // Let the OS assign a random, free port let listener = TcpListener::bind("localhost:0").await?; - let addr = listener.local_addr()?; + let address = listener.local_addr()?; info!( "Ballista v{} Rust Executor listening on {:?}", - BALLISTA_VERSION, addr + BALLISTA_VERSION, address ); let executor_meta = ExecutorRegistration { id: Uuid::new_v4().to_string(), // assign this executor a unique ID host: Some("localhost".to_string()), - port: addr.port() as u32, + port: address.port() as u32, // TODO Make it configurable grpc_port: 50020, specification: Some( @@ -110,6 +109,9 @@ pub async fn new_standalone_executor_from_builder( ), }; + let config = config_producer(); + let max_message_size = config.ballista_grpc_client_max_message_size(); + let work_dir = TempDir::new()? 
.into_path() .into_os_string() @@ -130,7 +132,10 @@ pub async fn new_standalone_executor_from_builder( )); let service = BallistaFlightService::new(); - let server = FlightServiceServer::new(service); + let server = FlightServiceServer::new(service) + .max_decoding_message_size(max_message_size) + .max_encoding_message_size(max_message_size); + tokio::spawn( create_grpc_server() .add_service(server) @@ -145,69 +150,22 @@ pub async fn new_standalone_executor_from_builder( /// Creates standalone executor with most values /// set as default. -pub async fn new_standalone_executor< - T: 'static + AsLogicalPlan, - U: 'static + AsExecutionPlan, ->( +pub async fn new_standalone_executor( scheduler: SchedulerGrpcClient, concurrent_tasks: usize, - codec: BallistaCodec, + codec: BallistaCodec, ) -> Result<()> { - // Let the OS assign a random, free port - let listener = TcpListener::bind("localhost:0").await?; - let addr = listener.local_addr()?; - info!( - "Ballista v{} Rust Executor listening on {:?}", - BALLISTA_VERSION, addr - ); - - let executor_meta = ExecutorRegistration { - id: Uuid::new_v4().to_string(), // assign this executor a unique ID - host: Some("localhost".to_string()), - port: addr.port() as u32, - // TODO Make it configurable - grpc_port: 50020, - specification: Some( - ExecutorSpecification { - task_slots: concurrent_tasks as u32, - } - .into(), - ), - }; - let work_dir = TempDir::new()? 
- .into_path() - .into_os_string() - .into_string() - .unwrap(); - info!("work_dir: {}", work_dir); - - let config_producer = Arc::new(default_config_producer); - let wd = work_dir.clone(); - let runtime_producer: RuntimeProducer = Arc::new(move |_: &SessionConfig| { - let runtime_env = RuntimeEnvBuilder::new() - .with_temp_file_path(wd.clone()) - .build()?; - Ok(Arc::new(runtime_env)) - }); + let session_state = SessionStateBuilder::new().with_default_features().build(); + let runtime = session_state.runtime_env().clone(); + let runtime_producer: RuntimeProducer = Arc::new(move |_| Ok(runtime.clone())); - let executor = Arc::new(Executor::new_basic( - executor_meta, - &work_dir, - runtime_producer, - config_producer, + new_standalone_executor_from_builder( + scheduler, concurrent_tasks, - )); - - let service = BallistaFlightService::new(); - let server = FlightServiceServer::new(service); - tokio::spawn( - create_grpc_server() - .add_service(server) - .serve_with_incoming(tokio_stream::wrappers::TcpListenerStream::new( - listener, - )), - ); - - tokio::spawn(execution_loop::poll_loop(scheduler, executor, codec)); - Ok(()) + Arc::new(default_config_producer), + runtime_producer, + codec, + (&session_state).into(), + ) + .await } diff --git a/ballista/scheduler/src/scheduler_process.rs b/ballista/scheduler/src/scheduler_process.rs index bf6d484f0..f22fe658c 100644 --- a/ballista/scheduler/src/scheduler_process.rs +++ b/ballista/scheduler/src/scheduler_process.rs @@ -91,9 +91,15 @@ pub async fn start_server( tonic_builder.add_service(ExternalScalerServer::new(scheduler_server.clone())); #[cfg(feature = "flight-sql")] - let tonic_builder = tonic_builder.add_service(FlightServiceServer::new( - FlightSqlServiceImpl::new(scheduler_server.clone()), - )); + let tonic_builder = tonic_builder.add_service( + FlightServiceServer::new(FlightSqlServiceImpl::new(scheduler_server.clone())) + .max_encoding_message_size( + config.grpc_server_max_encoding_message_size as usize, + ) + 
.max_decoding_message_size( + config.grpc_server_max_decoding_message_size as usize, + ), + ); let tonic = tonic_builder.into_service().into_axum_router(); diff --git a/ballista/scheduler/src/standalone.rs b/ballista/scheduler/src/standalone.rs index e9c483456..f7d121521 100644 --- a/ballista/scheduler/src/standalone.rs +++ b/ballista/scheduler/src/standalone.rs @@ -74,8 +74,11 @@ pub async fn new_standalone_scheduler_with_builder( config_producer: ConfigProducer, codec: BallistaCodec, ) -> Result { + let config = config_producer(); + let cluster = BallistaCluster::new_memory("localhost:50050", session_builder, config_producer); + let metrics_collector = default_metrics_collector()?; let mut scheduler_server: SchedulerServer = @@ -88,7 +91,10 @@ pub async fn new_standalone_scheduler_with_builder( ); scheduler_server.init().await?; - let server = SchedulerGrpcServer::new(scheduler_server.clone()); + let server = SchedulerGrpcServer::new(scheduler_server.clone()) + .max_decoding_message_size(config.ballista_grpc_client_max_message_size()) + .max_encoding_message_size(config.ballista_grpc_client_max_message_size()); + // Let the OS assign a random, free port let listener = TcpListener::bind("localhost:0").await?; let addr = listener.local_addr()?;