expose RuntimeProducer and SessionConfigProducer
so executors can configure the runtime per task,
and the session config they use
milenkovicm committed Oct 27, 2024
1 parent 592bf69 commit 8194b5a
Showing 9 changed files with 144 additions and 76 deletions.
10 changes: 10 additions & 0 deletions ballista/core/src/lib.rs
@@ -16,6 +16,10 @@
// under the License.

#![doc = include_str!("../README.md")]

use std::sync::Arc;

use datafusion::{execution::runtime_env::RuntimeEnv, prelude::SessionConfig};
pub const BALLISTA_VERSION: &str = env!("CARGO_PKG_VERSION");

pub fn print_version() {
@@ -33,3 +37,9 @@ pub mod utils;

#[macro_use]
pub mod serde;

pub type RuntimeProducer = Arc<
dyn Fn(&SessionConfig) -> datafusion::error::Result<Arc<RuntimeEnv>> + Send + Sync,
>;

pub type ConfigProducer = Arc<dyn Fn() -> SessionConfig + Send + Sync>;
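
These two aliases are the heart of the change: a ConfigProducer builds a fresh SessionConfig, and a RuntimeProducer derives a RuntimeEnv from it, so both can vary per task. A minimal sketch of producers an embedding application might supply (the specific options chosen here are illustrative assumptions, not part of this commit):

use std::sync::Arc;

use ballista_core::{ConfigProducer, RuntimeProducer};
use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv};
use datafusion::prelude::SessionConfig;

fn example_producers() -> (ConfigProducer, RuntimeProducer) {
    // Base SessionConfig handed to every task (target_partitions is illustrative).
    let config_producer: ConfigProducer =
        Arc::new(|| SessionConfig::new().with_target_partitions(4));
    // Build a RuntimeEnv from the (possibly task-specific) SessionConfig.
    let runtime_producer: RuntimeProducer = Arc::new(|_config: &SessionConfig| {
        Ok(Arc::new(RuntimeEnv::new(RuntimeConfig::new())?))
    });
    (config_producer, runtime_producer)
}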
34 changes: 22 additions & 12 deletions ballista/core/src/serde/scheduler/from_proto.rs
@@ -17,12 +17,13 @@

use chrono::{TimeZone, Utc};
use datafusion::common::tree_node::{Transformed, TransformedResult, TreeNode};
use datafusion::execution::runtime_env::RuntimeEnv;

use datafusion::logical_expr::{AggregateUDF, ScalarUDF, WindowUDF};
use datafusion::physical_plan::metrics::{
Count, Gauge, MetricValue, MetricsSet, Time, Timestamp,
};
use datafusion::physical_plan::{ExecutionPlan, Metric};
use datafusion::prelude::SessionConfig;
use datafusion_proto::logical_plan::AsLogicalPlan;
use datafusion_proto::physical_plan::AsExecutionPlan;
use std::collections::HashMap;
@@ -37,6 +38,7 @@ use crate::serde::scheduler::{
};

use crate::serde::{protobuf, BallistaCodec};
use crate::RuntimeProducer;
use protobuf::{operator_metric, NamedCount, NamedGauge, NamedTime};

impl TryInto<Action> for protobuf::Action {
@@ -281,17 +283,18 @@ impl Into<ExecutorData> for protobuf::ExecutorData {

pub fn get_task_definition<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan>(
task: protobuf::TaskDefinition,
runtime: Arc<RuntimeEnv>,
produce_runtime: RuntimeProducer,
session_config: SessionConfig,
//runtime: Arc<RuntimeEnv>,
scalar_functions: HashMap<String, Arc<ScalarUDF>>,
aggregate_functions: HashMap<String, Arc<AggregateUDF>>,
window_functions: HashMap<String, Arc<WindowUDF>>,
codec: BallistaCodec<T, U>,
) -> Result<TaskDefinition, BallistaError> {
let mut props = HashMap::new();
let mut session_config = session_config;
for kv_pair in task.props {
props.insert(kv_pair.key, kv_pair.value);
session_config = session_config.set_str(&kv_pair.key, &kv_pair.value);
}
let props = Arc::new(props);

let mut task_scalar_functions = HashMap::new();
let mut task_aggregate_functions = HashMap::new();
@@ -311,7 +314,7 @@ pub fn get_task_definition<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan>(
aggregate_functions: task_aggregate_functions,
window_functions: task_window_functions,
});

let runtime = produce_runtime(&session_config)?;
let encoded_plan = task.plan.as_slice();
let plan: Arc<dyn ExecutionPlan> = U::try_decode(encoded_plan).and_then(|proto| {
proto.try_into_physical_plan(
@@ -340,7 +343,7 @@ pub fn get_task_definition<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan>(
plan,
launch_time,
session_id,
props,
session_config,
function_registry,
})
}
@@ -350,17 +353,22 @@ pub fn get_task_definition_vec<
U: 'static + AsExecutionPlan,
>(
multi_task: protobuf::MultiTaskDefinition,
runtime: Arc<RuntimeEnv>,
runtime_producer: RuntimeProducer,
session_config: SessionConfig,
scalar_functions: HashMap<String, Arc<ScalarUDF>>,
aggregate_functions: HashMap<String, Arc<AggregateUDF>>,
window_functions: HashMap<String, Arc<WindowUDF>>,
codec: BallistaCodec<T, U>,
) -> Result<Vec<TaskDefinition>, BallistaError> {
let mut props = HashMap::new();
//let mut props = HashMap::new();
// for kv_pair in multi_task.props {
// props.insert(kv_pair.key, kv_pair.value);
// }
// let props = Arc::new(props);
let mut session_config = session_config;
for kv_pair in multi_task.props {
props.insert(kv_pair.key, kv_pair.value);
session_config = session_config.set_str(&kv_pair.key, &kv_pair.value);
}
let props = Arc::new(props);

let mut task_scalar_functions = HashMap::new();
let mut task_aggregate_functions = HashMap::new();
@@ -381,6 +389,8 @@ pub fn get_task_definition_vec<
window_functions: task_window_functions,
});

let runtime = runtime_producer(&session_config)?;

let encoded_plan = multi_task.plan.as_slice();
let plan: Arc<dyn ExecutionPlan> = U::try_decode(encoded_plan).and_then(|proto| {
proto.try_into_physical_plan(
@@ -410,7 +420,7 @@ pub fn get_task_definition_vec<
plan: reset_metrics_for_execution_plan(plan.clone())?,
launch_time,
session_id: session_id.clone(),
props: props.clone(),
session_config: session_config.clone(),
function_registry: function_registry.clone(),
})
})
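
Note the ordering in both functions above: the task's key/value props are folded into the SessionConfig via set_str before the runtime is produced, so a producer can react to per-task settings. A hedged sketch of such a producer; giving each task its own bounded memory pool is an illustrative policy, not something this commit prescribes:

use std::sync::Arc;

use ballista_core::RuntimeProducer;
use datafusion::execution::memory_pool::FairSpillPool;
use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv};
use datafusion::prelude::SessionConfig;

fn per_task_runtime_producer() -> RuntimeProducer {
    Arc::new(|config: &SessionConfig| {
        // `config` already carries the task's props, so any of its options
        // could drive this decision; the pool size below is simply fixed.
        let _batch_size = config.options().execution.batch_size;
        let pool = FairSpillPool::new(512 * 1024 * 1024); // 512 MiB per task (illustrative)
        let rt_config = RuntimeConfig::new().with_memory_pool(Arc::new(pool));
        Ok(Arc::new(RuntimeEnv::new(rt_config)?))
    })
}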
4 changes: 3 additions & 1 deletion ballista/core/src/serde/scheduler/mod.rs
@@ -29,6 +29,7 @@ use datafusion::logical_expr::planner::ExprPlanner;
use datafusion::logical_expr::{AggregateUDF, ScalarUDF, WindowUDF};
use datafusion::physical_plan::ExecutionPlan;
use datafusion::physical_plan::Partitioning;
use datafusion::prelude::SessionConfig;
use serde::Serialize;

use crate::error::BallistaError;
@@ -288,7 +289,8 @@ pub struct TaskDefinition {
pub plan: Arc<dyn ExecutionPlan>,
pub launch_time: u64,
pub session_id: String,
pub props: Arc<HashMap<String, String>>,
pub session_config: SessionConfig,
//pub props: Arc<HashMap<String, String>>,
pub function_registry: Arc<SimpleFunctionRegistry>,
}

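With props replaced by a fully resolved SessionConfig, consumers of a TaskDefinition can read typed options directly instead of re-parsing key/value strings. For example, given some hypothetical task: TaskDefinition (batch_size is a standard DataFusion option):

// Illustrative: read a standard option from the task's resolved config.
let batch_size = task.session_config.options().execution.batch_size;
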
2 changes: 2 additions & 0 deletions ballista/executor/src/bin/main.rs
@@ -88,6 +88,8 @@ async fn main() -> Result<()> {
cache_io_concurrency: opt.cache_io_concurrency,
execution_engine: None,
session_state: None,
config_producer: None,
runtime_producer: None,
};

start_executor_process(Arc::new(config)).await
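
Both new fields default to None here, which keeps the stock binary's behaviour; an embedding binary could pass its own producers instead. A sketch, assuming the rest of the config struct is built exactly as in this main:

use std::sync::Arc;

use ballista_core::{ConfigProducer, RuntimeProducer};
use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv};
use datafusion::prelude::SessionConfig;

// Hypothetical overrides for the two new fields; all other fields unchanged.
let config_producer: Option<ConfigProducer> =
    Some(Arc::new(|| SessionConfig::new()));
let runtime_producer: Option<RuntimeProducer> = Some(Arc::new(
    |_config: &SessionConfig| Ok(Arc::new(RuntimeEnv::new(RuntimeConfig::new())?)),
));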
11 changes: 5 additions & 6 deletions ballista/executor/src/execution_loop.rs
@@ -27,7 +27,6 @@ use ballista_core::serde::scheduler::{ExecutorSpecification, PartitionId};
use ballista_core::serde::BallistaCodec;
use datafusion::execution::context::TaskContext;
use datafusion::physical_plan::ExecutionPlan;
use datafusion::prelude::SessionConfig;
use datafusion_proto::logical_plan::AsLogicalPlan;
use datafusion_proto::physical_plan::AsExecutionPlan;
use futures::FutureExt;
@@ -46,7 +45,7 @@ pub async fn poll_loop<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan>
mut scheduler: SchedulerGrpcClient<Channel>,
executor: Arc<Executor>,
codec: BallistaCodec<T, U>,
session_config: SessionConfig,
//session_config: SessionConfig,
) -> Result<(), BallistaError> {
let executor_specification: ExecutorSpecification = executor
.metadata
@@ -108,7 +107,7 @@ pub async fn poll_loop<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan>
task,
&codec,
&dedicated_executor,
session_config.clone(),
// session_config.clone(),
)
.await
{
@@ -150,7 +149,7 @@ async fn run_received_task<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan>(
task: TaskDefinition,
codec: &BallistaCodec<T, U>,
dedicated_executor: &DedicatedExecutor,
session_config: SessionConfig,
// session_config: SessionConfig,
) -> Result<(), BallistaError> {
let task_id = task.task_id;
let task_attempt_num = task.task_attempt_num;
@@ -170,7 +169,7 @@ async fn run_received_task<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan>(
"Received task: {}, task_properties: {:?}",
task_identity, task.props
);
let mut session_config = session_config;
let mut session_config = executor.produce_config();
for kv_pair in task.props {
session_config = session_config.set_str(&kv_pair.key, &kv_pair.value);
}
@@ -179,7 +178,7 @@ async fn run_received_task<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan>(
let task_aggregate_functions = executor.aggregate_functions.clone();
let task_window_functions = executor.window_functions.clone();

let runtime = executor.get_runtime();
let runtime = executor.produce_runtime(&session_config)?;
let session_id = task.session_id.clone();
let task_context = Arc::new(TaskContext::new(
Some(task_identity.clone()),
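
Taken together, the per-task flow in the poll loop is now: produce a base config, overlay the task's key/value props, then derive the runtime from the result. Condensed from run_received_task above:

// Per-task configuration flow (condensed from the diff above).
let mut session_config = executor.produce_config(); // base SessionConfig
for kv_pair in task.props {
    session_config = session_config.set_str(&kv_pair.key, &kv_pair.value);
}
let runtime = executor.produce_runtime(&session_config)?; // per-task RuntimeEnv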
53 changes: 39 additions & 14 deletions ballista/executor/src/executor.rs
@@ -25,6 +25,8 @@ use ballista_core::error::BallistaError;
use ballista_core::serde::protobuf;
use ballista_core::serde::protobuf::ExecutorRegistration;
use ballista_core::serde::scheduler::PartitionId;
use ballista_core::ConfigProducer;
use ballista_core::RuntimeProducer;
use dashmap::DashMap;
use datafusion::execution::context::TaskContext;
use datafusion::execution::runtime_env::RuntimeEnv;
@@ -33,6 +35,7 @@ use datafusion::functions::all_default_functions;
use datafusion::functions_aggregate::all_default_aggregate_functions;
use datafusion::functions_window::all_default_window_functions;
use datafusion::logical_expr::{AggregateUDF, ScalarUDF, WindowUDF};
use datafusion::prelude::SessionConfig;
use futures::future::AbortHandle;
use std::collections::HashMap;
use std::future::Future;
@@ -74,8 +77,9 @@ pub struct Executor {
/// Window functions registered in the Executor
pub window_functions: HashMap<String, Arc<WindowUDF>>,

/// Runtime environment for Executor
runtime: Arc<RuntimeEnv>,
pub runtime_producer: RuntimeProducer,

pub config_producer: ConfigProducer,

/// Collector for runtime execution metrics
pub metrics_collector: Arc<dyn ExecutorMetricsCollector>,
@@ -97,7 +101,8 @@ impl Executor {
pub fn new_from_runtime(
metadata: ExecutorRegistration,
work_dir: &str,
runtime: Arc<RuntimeEnv>,
runtime_producer: RuntimeProducer,
config_producer: ConfigProducer,
metrics_collector: Arc<dyn ExecutorMetricsCollector>,
concurrent_tasks: usize,
execution_engine: Option<Arc<dyn ExecutionEngine>>,
@@ -120,7 +125,8 @@ impl Executor {
Self::new(
metadata,
work_dir,
runtime,
runtime_producer,
config_producer,
scalar_functions,
aggregate_functions,
window_functions,
@@ -136,7 +142,8 @@
fn new(
metadata: ExecutorRegistration,
work_dir: &str,
runtime: Arc<RuntimeEnv>,
runtime_producer: RuntimeProducer,
config_producer: ConfigProducer,
scalar_functions: HashMap<String, Arc<ScalarUDF>>,
aggregate_functions: HashMap<String, Arc<AggregateUDF>>,
window_functions: HashMap<String, Arc<WindowUDF>>,
@@ -150,7 +157,8 @@
scalar_functions,
aggregate_functions,
window_functions,
runtime,
runtime_producer,
config_producer,
metrics_collector,
concurrent_tasks,
abort_handles: Default::default(),
@@ -163,20 +171,22 @@
pub fn new_from_state(
metadata: ExecutorRegistration,
work_dir: &str,
state: &SessionState,
runtime_producer: RuntimeProducer,
config_producer: ConfigProducer,
state: &SessionState, // TODO MM narrow state down
metrics_collector: Arc<dyn ExecutorMetricsCollector>,
concurrent_tasks: usize,
execution_engine: Option<Arc<dyn ExecutionEngine>>,
) -> Self {
let scalar_functions = state.scalar_functions().clone();
let aggregate_functions = state.aggregate_functions().clone();
let window_functions = state.window_functions().clone();
let runtime = state.runtime_env().clone();

Self::new(
metadata,
work_dir,
runtime,
runtime_producer,
config_producer,
scalar_functions,
aggregate_functions,
window_functions,
@@ -188,8 +198,15 @@
}

impl Executor {
pub fn get_runtime(&self) -> Arc<RuntimeEnv> {
self.runtime.clone()
pub fn produce_runtime(
&self,
config: &SessionConfig,
) -> datafusion::error::Result<Arc<RuntimeEnv>> {
(self.runtime_producer)(config)
}

pub fn produce_config(&self) -> SessionConfig {
(self.config_producer)()
}

/// Execute one partition of a query stage and persist the result to disk in IPC format. On
@@ -261,17 +278,19 @@ mod test {
use crate::metrics::LoggingMetricsCollector;
use arrow::datatypes::{Schema, SchemaRef};
use arrow::record_batch::RecordBatch;
use ballista_core::config::BallistaConfig;
use ballista_core::execution_plans::ShuffleWriterExec;
use ballista_core::serde::protobuf::ExecutorRegistration;
use ballista_core::serde::scheduler::PartitionId;
use ballista_core::RuntimeProducer;
use datafusion::error::{DataFusionError, Result};
use datafusion::execution::context::TaskContext;

use datafusion::physical_plan::{
DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, PlanProperties,
RecordBatchStream, SendableRecordBatchStream, Statistics,
};
use datafusion::prelude::SessionContext;
use datafusion::prelude::{SessionConfig, SessionContext};
use futures::Stream;
use std::any::Any;
use std::pin::Pin;
@@ -402,13 +421,19 @@ mod test {
specification: None,
optional_host: None,
};

let config_producer = Arc::new(|| {
SessionConfig::new().with_option_extension(BallistaConfig::new().unwrap())
});
let ctx = SessionContext::new();
let runtime_env = ctx.runtime_env().clone();
let runtime_producer: RuntimeProducer =
Arc::new(move |_| Ok(runtime_env.clone()));

let executor = Executor::new_from_runtime(
executor_registration,
&work_dir,
ctx.runtime_env(),
runtime_producer,
config_producer,
Arc::new(LoggingMetricsCollector {}),
2,
None,