Skip to content

Commit

Permalink
use configuration to propagate grpc settings
Browse files Browse the repository at this point in the history
  • Loading branch information
milenkovicm committed Feb 18, 2025
1 parent faa05af commit 7f6a99e
Show file tree
Hide file tree
Showing 9 changed files with 188 additions and 124 deletions.
41 changes: 31 additions & 10 deletions ballista/core/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,11 @@ const IO_RETRY_WAIT_TIME_MS: u64 = 3000;
impl BallistaClient {
/// Create a new BallistaClient to connect to the executor listening on the specified
/// host and port
pub async fn try_new(host: &str, port: u16) -> Result<Self> {
pub async fn try_new(
host: &str,
port: u16,
grpc_client_max_message_size: usize,
) -> Result<Self> {
let addr = format!("http://{host}:{port}");
debug!("BallistaClient connecting to {}", addr);
let connection =
Expand All @@ -72,8 +76,11 @@ impl BallistaClient {
"Error connecting to Ballista scheduler or executor at {addr}: {e:?}"
))
})?;
let flight_client = FlightServiceClient::new(connection);
debug!("BallistaClient connected OK");
let flight_client = FlightServiceClient::new(connection)
.max_decoding_message_size(grpc_client_max_message_size)
.max_encoding_message_size(grpc_client_max_message_size);

debug!("BallistaClient connected OK: {:?}", flight_client);

Ok(Self { flight_client })
}
Expand All @@ -99,13 +106,27 @@ impl BallistaClient {
.await
.map_err(|error| match error {
// map grpc connection error to partition fetch error.
BallistaError::GrpcActionError(msg) => BallistaError::FetchFailed(
executor_id.to_owned(),
partition_id.stage_id,
partition_id.partition_id,
msg,
),
other => other,
BallistaError::GrpcActionError(msg) => {
log::warn!(
"grpc client failed to fetch partition: {:?} , message: {:?}",
partition_id,
msg
);
BallistaError::FetchFailed(
executor_id.to_owned(),
partition_id.stage_id,
partition_id.partition_id,
msg,
)
}
error => {
log::warn!(
"grpc client failed to fetch partition: {:?} , error: {:?}",
partition_id,
error
);
error
}
})
}

Expand Down
9 changes: 9 additions & 0 deletions ballista/core/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ pub const BALLISTA_STANDALONE_PARALLELISM: &str = "ballista.standalone.paralleli
/// max message size for gRPC clients
pub const BALLISTA_GRPC_CLIENT_MAX_MESSAGE_SIZE: &str =
"ballista.grpc_client_max_message_size";
/// max concurrent requests the shuffle reader can serve
pub const BALLISTA_SHUFFLE_READER_MAX_REQUESTS: &str = "ballista.shuffle.max_requests";

pub type ParseResult<T> = result::Result<T, String>;
use std::sync::LazyLock;
Expand All @@ -48,6 +49,10 @@ static CONFIG_ENTRIES: LazyLock<HashMap<String, ConfigEntry>> = LazyLock::new(||
"Configuration for max message size in gRPC clients".to_string(),
DataType::UInt64,
Some((16 * 1024 * 1024).to_string())),
// Limit on concurrent shuffle reader requests; defaults to 64.
ConfigEntry::new(BALLISTA_SHUFFLE_READER_MAX_REQUESTS.to_string(),
"Maximum concurrent requests shuffle reader can serve".to_string(),
DataType::UInt64,
Some((64).to_string())),
];
entries
.into_iter()
Expand Down Expand Up @@ -165,6 +170,10 @@ impl BallistaConfig {
self.get_usize_setting(BALLISTA_STANDALONE_PARALLELISM)
}

/// Returns the setting stored under [BALLISTA_SHUFFLE_READER_MAX_REQUESTS]
/// as a `usize`.
pub fn shuffle_reader_maximum_in_flight_requests(&self) -> usize {
    let setting_key = BALLISTA_SHUFFLE_READER_MAX_REQUESTS;
    self.get_usize_setting(setting_key)
}

fn get_usize_setting(&self, key: &str) -> usize {
if let Some(v) = self.settings.get(key) {
// infallible because we validate all configs in the constructor
Expand Down
17 changes: 11 additions & 6 deletions ballista/core/src/execution_plans/distributed_query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -310,12 +310,16 @@ async fn execute_query(
break Err(DataFusionError::Execution(msg));
}
Some(job_status::Status::Successful(successful)) => {
let streams = successful.partition_location.into_iter().map(|p| {
let f = fetch_partition(p)
.map_err(|e| ArrowError::ExternalError(Box::new(e)));
let streams =
successful
.partition_location
.into_iter()
.map(move |partition| {
let f = fetch_partition(partition, max_message_size)
.map_err(|e| ArrowError::ExternalError(Box::new(e)));

futures::stream::once(f).try_flatten()
});
futures::stream::once(f).try_flatten()
});

break Ok(futures::stream::iter(streams).flatten());
}
Expand All @@ -325,6 +329,7 @@ async fn execute_query(

async fn fetch_partition(
location: PartitionLocation,
max_message_size: usize,
) -> Result<SendableRecordBatchStream> {
let metadata = location.executor_meta.ok_or_else(|| {
DataFusionError::Internal("Received empty executor metadata".to_owned())
Expand All @@ -334,7 +339,7 @@ async fn fetch_partition(
})?;
let host = metadata.host.as_str();
let port = metadata.port as u16;
let mut ballista_client = BallistaClient::try_new(host, port)
let mut ballista_client = BallistaClient::try_new(host, port, max_message_size)
.await
.map_err(|e| DataFusionError::Execution(format!("{e:?}")))?;
ballista_client
Expand Down
60 changes: 40 additions & 20 deletions ballista/core/src/execution_plans/shuffle_reader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ use std::sync::Arc;
use std::task::{Context, Poll};

use crate::client::BallistaClient;
use crate::extension::SessionConfigExt;
use crate::serde::scheduler::{PartitionLocation, PartitionStats};

use datafusion::arrow::datatypes::SchemaRef;
Expand Down Expand Up @@ -146,8 +147,18 @@ impl ExecutionPlan for ShuffleReaderExec {
let task_id = context.task_id().unwrap_or_else(|| partition.to_string());
info!("ShuffleReaderExec::execute({})", task_id);

// TODO make the maximum size configurable, or make it depends on global memory control
let max_request_num = 50usize;
let config = context.session_config();

// Number of concurrent fetch requests: `send_fetch_partitions` uses it as
// both the response channel capacity and the semaphore permit count.
let max_request_num =
    config.ballista_shuffle_reader_maximum_in_flight_requests();
// gRPC message size limit handed to `BallistaClient::try_new`, which applies
// it as the max encoding/decoding message size of the flight client.
let max_message_size = config.ballista_grpc_client_max_message_size();

log::debug!(
    "ShuffleReaderExec::execute({}) max_request_num: {}, max_message_size: {}",
    task_id,
    max_request_num,
    max_message_size
);
let mut partition_locations = HashMap::new();
for p in &self.partition[partition] {
partition_locations
Expand All @@ -166,7 +177,7 @@ impl ExecutionPlan for ShuffleReaderExec {
partition_locations.shuffle(&mut thread_rng());

let response_receiver =
send_fetch_partitions(partition_locations, max_request_num);
send_fetch_partitions(partition_locations, max_request_num, max_message_size);

let result = RecordBatchStreamAdapter::new(
Arc::new(self.schema.as_ref().clone()),
Expand Down Expand Up @@ -284,6 +295,7 @@ impl Stream for AbortableReceiverStream {
fn send_fetch_partitions(
partition_locations: Vec<PartitionLocation>,
max_request_num: usize,
max_message_size: usize,
) -> AbortableReceiverStream {
let (response_sender, response_receiver) = mpsc::channel(max_request_num);
let semaphore = Arc::new(Semaphore::new(max_request_num));
Expand All @@ -302,7 +314,9 @@ fn send_fetch_partitions(
let response_sender_c = response_sender.clone();
spawned_tasks.push(SpawnedTask::spawn(async move {
for p in local_locations {
let r = PartitionReaderEnum::Local.fetch_partition(&p).await;
let r = PartitionReaderEnum::Local
.fetch_partition(&p, max_message_size)
.await;
if let Err(e) = response_sender_c.send(r).await {
error!("Fail to send response event to the channel due to {}", e);
}
Expand All @@ -315,7 +329,9 @@ fn send_fetch_partitions(
spawned_tasks.push(SpawnedTask::spawn(async move {
// Block if exceeds max request number.
let permit = semaphore.acquire_owned().await.unwrap();
let r = PartitionReaderEnum::FlightRemote.fetch_partition(&p).await;
let r = PartitionReaderEnum::FlightRemote
.fetch_partition(&p, max_message_size)
.await;
// Block if the channel buffer is full.
if let Err(e) = response_sender.send(r).await {
error!("Fail to send response event to the channel due to {}", e);
Expand All @@ -339,6 +355,7 @@ trait PartitionReader: Send + Sync + Clone {
async fn fetch_partition(
&self,
location: &PartitionLocation,
max_message_size: usize,
) -> result::Result<SendableRecordBatchStream, BallistaError>;
}

Expand All @@ -356,9 +373,12 @@ impl PartitionReader for PartitionReaderEnum {
async fn fetch_partition(
&self,
location: &PartitionLocation,
max_message_size: usize,
) -> result::Result<SendableRecordBatchStream, BallistaError> {
match self {
PartitionReaderEnum::FlightRemote => fetch_partition_remote(location).await,
PartitionReaderEnum::FlightRemote => {
fetch_partition_remote(location, max_message_size).await
}
PartitionReaderEnum::Local => fetch_partition_local(location).await,
PartitionReaderEnum::ObjectStoreRemote => {
fetch_partition_object_store(location).await
Expand All @@ -369,26 +389,26 @@ impl PartitionReader for PartitionReaderEnum {

async fn fetch_partition_remote(
location: &PartitionLocation,
max_message_size: usize,
) -> result::Result<SendableRecordBatchStream, BallistaError> {
let metadata = &location.executor_meta;
let partition_id = &location.partition_id;
// TODO for shuffle client connections, we should avoid creating new connections again and again.
// And we should also avoid to keep alive too many connections for long time.
let host = metadata.host.as_str();
let port = metadata.port;
let mut ballista_client =
BallistaClient::try_new(host, port)
.await
.map_err(|error| match error {
// map grpc connection error to partition fetch error.
BallistaError::GrpcConnectionError(msg) => BallistaError::FetchFailed(
metadata.id.clone(),
partition_id.stage_id,
partition_id.partition_id,
msg,
),
other => other,
})?;
let mut ballista_client = BallistaClient::try_new(host, port, max_message_size)
.await
.map_err(|error| match error {
// map grpc connection error to partition fetch error.
BallistaError::GrpcConnectionError(msg) => BallistaError::FetchFailed(
metadata.id.clone(),
partition_id.stage_id,
partition_id.partition_id,
msg,
),
other => other,
})?;

ballista_client
.fetch_partition(&metadata.id, partition_id, &location.path, host, port)
Expand Down Expand Up @@ -644,7 +664,7 @@ mod tests {
);

let response_receiver =
send_fetch_partitions(partition_locations, max_request_num);
send_fetch_partitions(partition_locations, max_request_num, 4 * 1024 * 1024);

let stream = RecordBatchStreamAdapter::new(
Arc::new(schema),
Expand Down
49 changes: 38 additions & 11 deletions ballista/core/src/extension.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

use crate::config::{
BallistaConfig, BALLISTA_GRPC_CLIENT_MAX_MESSAGE_SIZE, BALLISTA_JOB_NAME,
BALLISTA_STANDALONE_PARALLELISM,
BALLISTA_SHUFFLE_READER_MAX_REQUESTS, BALLISTA_STANDALONE_PARALLELISM,
};
use crate::planner::BallistaQueryPlanner;
use crate::serde::protobuf::KeyValuePair;
Expand Down Expand Up @@ -103,6 +103,15 @@ pub trait SessionConfigExt {

/// Sets ballista job name
fn with_ballista_job_name(self, job_name: &str) -> Self;

/// Gets the maximum number of in-flight requests for the shuffle reader
fn ballista_shuffle_reader_maximum_in_flight_requests(&self) -> usize;

/// Sets maximum in flight requests for shuffle reader
fn with_ballista_shuffle_reader_maximum_in_flight_requests(
self,
max_requests: usize,
) -> Self;
}

/// [SessionConfigHelperExt] is set of [SessionConfig] extension methods
Expand All @@ -121,16 +130,11 @@ impl SessionStateExt for SessionState {
scheduler_url: String,
session_id: String,
) -> datafusion::error::Result<SessionState> {
let config = BallistaConfig::default();

let planner =
BallistaQueryPlanner::<LogicalPlanNode>::new(scheduler_url, config.clone());

let session_config = SessionConfig::new()
.with_information_schema(true)
.with_option_extension(config.clone())
// Ballista disables this option
.with_round_robin_repartition(false);
let session_config = SessionConfig::new_with_ballista();
let planner = BallistaQueryPlanner::<LogicalPlanNode>::new(
scheduler_url,
BallistaConfig::default(),
);

let runtime_env = RuntimeEnvBuilder::new().build()?;
let session_state = SessionStateBuilder::new()
Expand Down Expand Up @@ -191,6 +195,7 @@ impl SessionConfigExt for SessionConfig {
/// Builds a [SessionConfig] with a default [BallistaConfig] registered as an
/// option extension, the information schema enabled, 16 target partitions,
/// and round-robin repartitioning turned off.
fn new_with_ballista() -> SessionConfig {
    let base = SessionConfig::new();
    let with_extension = base.with_option_extension(BallistaConfig::default());

    with_extension
        .with_information_schema(true)
        .with_target_partitions(16)
        .with_round_robin_repartition(false)
}
Expand Down Expand Up @@ -279,6 +284,28 @@ impl SessionConfigExt for SessionConfig {
.set_usize(BALLISTA_STANDALONE_PARALLELISM, parallelism)
}
}

/// Reads the shuffle reader's maximum number of in-flight requests from the
/// registered [BallistaConfig] extension; when no such extension is present,
/// the value from `BallistaConfig::default()` is returned instead.
fn ballista_shuffle_reader_maximum_in_flight_requests(&self) -> usize {
    match self.options().extensions.get::<BallistaConfig>() {
        Some(ballista_config) => {
            ballista_config.shuffle_reader_maximum_in_flight_requests()
        }
        None => BallistaConfig::default().shuffle_reader_maximum_in_flight_requests(),
    }
}

/// Stores `max_requests` under [BALLISTA_SHUFFLE_READER_MAX_REQUESTS].
///
/// If no [BallistaConfig] extension is registered yet, a default one is
/// added first so the key has somewhere to live.
fn with_ballista_shuffle_reader_maximum_in_flight_requests(
    self,
    max_requests: usize,
) -> Self {
    let has_ballista_config =
        self.options().extensions.get::<BallistaConfig>().is_some();

    let session_config = if has_ballista_config {
        self
    } else {
        self.with_option_extension(BallistaConfig::default())
    };

    session_config.set_usize(BALLISTA_SHUFFLE_READER_MAX_REQUESTS, max_requests)
}
}

impl SessionConfigHelperExt for SessionConfig {
Expand Down
18 changes: 15 additions & 3 deletions ballista/executor/src/executor_process.rs
Original file line number Diff line number Diff line change
Expand Up @@ -365,7 +365,13 @@ pub async fn start_executor_process(
service_handlers.push(match override_flight {
None => {
info!("Starting built-in arrow flight service");
flight_server_task(address, shutdown).await
flight_server_task(
address,
shutdown,
opt.grpc_max_encoding_message_size as usize,
opt.grpc_max_decoding_message_size as usize,
)
.await
}
Some(flight_provider) => {
info!("Starting custom, user provided, arrow flight service");
Expand Down Expand Up @@ -471,12 +477,18 @@ pub async fn start_executor_process(
async fn flight_server_task(
address: SocketAddr,
mut grpc_shutdown: Shutdown,
max_encoding_message_size: usize,
max_decoding_message_size: usize,
) -> JoinHandle<Result<(), BallistaError>> {
tokio::spawn(async move {
info!("Built-in arrow flight server listening on: {:?}", address);
info!("Built-in arrow flight server listening on: {:?} max_encoding_size: {} max_decoding_size: {}", address, max_encoding_message_size, max_decoding_message_size);

let server_future = create_grpc_server()
.add_service(FlightServiceServer::new(BallistaFlightService::new()))
.add_service(
FlightServiceServer::new(BallistaFlightService::new())
.max_decoding_message_size(max_decoding_message_size)
.max_encoding_message_size(max_encoding_message_size),
)
.serve_with_shutdown(address, grpc_shutdown.recv());

server_future.await.map_err(|e| {
Expand Down
Loading

0 comments on commit 7f6a99e

Please sign in to comment.