Executor configuration accepts SessionState
This way we can configure many more options.
milenkovicm committed Oct 25, 2024
1 parent 92ce301 commit 647f426
Showing 9 changed files with 273 additions and 88 deletions.
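
A minimal sketch (not part of this commit) of what the change enables: build a SessionState carrying extra options and hand it to the executor. It assumes DataFusion's SessionStateBuilder and uses an example config key; the new session_state field on ExecutorProcessConfig (seen as session_state: None in bin/main.rs below) is presumably an Option over such a state.

use datafusion::execution::{SessionState, SessionStateBuilder};
use datafusion::prelude::SessionConfig;

// Sketch only: anything set or registered on this state is what the executor
// will see, which is the point of accepting a SessionState in its configuration.
fn custom_state() -> SessionState {
    let config = SessionConfig::new()
        // example option; any DataFusion or registered extension option works
        .set_str("datafusion.execution.batch_size", "16384");
    SessionStateBuilder::new()
        .with_config(config)
        .with_default_features()
        .build()
}

The resulting state could then be passed through ExecutorProcessConfig's session_state field or given directly to Executor::new_from_state further down in this diff.
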
2 changes: 2 additions & 0 deletions ballista/client/Cargo.toml
@@ -43,6 +43,8 @@ tokio = { workspace = true }
url = { version = "2.5" }

[dev-dependencies]
ballista-executor = { path = "../executor", version = "0.12.0" }
ballista-scheduler = { path = "../scheduler", version = "0.12.0" }
ctor = { version = "0.2" }
env_logger = { workspace = true }

40 changes: 40 additions & 0 deletions ballista/client/tests/remote.rs
@@ -142,4 +142,44 @@ mod remote {

Ok(())
}

#[tokio::test]
async fn should_execute_sql_app_name_show() -> datafusion::error::Result<()> {
let (host, port) = crate::common::setup_test_cluster().await;
let url = format!("df://{host}:{port}");

let test_data = crate::common::example_test_data();
let ctx: SessionContext = SessionContext::remote(&url).await?;

ctx.sql("SET ballista.job.name = 'Super Cool Ballista App'")
.await?
.show()
.await?;

ctx.register_parquet(
"test",
&format!("{test_data}/alltypes_plain.parquet"),
Default::default(),
)
.await?;

let result = ctx
.sql("select string_col, timestamp_col from test where id > 4")
.await?
.collect()
.await?;
let expected = [
"+------------+---------------------+",
"| string_col | timestamp_col |",
"+------------+---------------------+",
"| 31 | 2009-03-01T00:01:00 |",
"| 30 | 2009-04-01T00:00:00 |",
"| 31 | 2009-04-01T00:01:00 |",
"+------------+---------------------+",
];

assert_batches_eq!(expected, &result);

Ok(())
}
}
25 changes: 15 additions & 10 deletions ballista/core/src/execution_plans/distributed_query.rs
@@ -194,7 +194,7 @@ impl<T: 'static + AsLogicalPlan> ExecutionPlan for DistributedQueryExec<T> {
fn execute(
&self,
partition: usize,
_context: Arc<TaskContext>,
context: Arc<TaskContext>,
) -> Result<SendableRecordBatchStream> {
assert_eq!(0, partition);

@@ -210,17 +210,22 @@ impl<T: 'static + AsLogicalPlan> ExecutionPlan for DistributedQueryExec<T> {
DataFusionError::Execution(format!("failed to encode logical plan: {e:?}"))
})?;

let settings = context
.session_config()
.options()
.entries()
.iter()
.map(
|datafusion::config::ConfigEntry { key, value, .. }| KeyValuePair {
key: key.to_owned(),
value: value.clone().unwrap_or_else(|| String::from("")),
},
)
.collect();

let query = ExecuteQueryParams {
query: Some(Query::LogicalPlan(buf)),
settings: self
.config
.settings()
.iter()
.map(|(k, v)| KeyValuePair {
key: k.to_owned(),
value: v.to_owned(),
})
.collect::<Vec<_>>(),
settings,
optional_session_id: Some(OptionalSessionId::SessionId(
self.session_id.clone(),
)),
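
Worth noting about the distributed_query.rs change above: execute() now reads its settings from the TaskContext supplied at execution time and flattens every registered config entry into KeyValuePair values, rather than forwarding only the Ballista settings stored on the plan. A standalone sketch (not from the commit) of what that enumeration yields:

use datafusion::prelude::SessionConfig;

// Sketch: list every config entry the way the new execute() path does. entries()
// covers DataFusion's built-in options as well as registered extension options,
// which is what makes "many more options" configurable end to end.
fn settings_as_pairs(config: &SessionConfig) -> Vec<(String, String)> {
    config
        .options()
        .entries()
        .iter()
        .map(|e| (e.key.clone(), e.value.clone().unwrap_or_default()))
        .collect()
}
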
1 change: 1 addition & 0 deletions ballista/executor/src/bin/main.rs
@@ -87,6 +87,7 @@ async fn main() -> Result<()> {
cache_capacity: opt.cache_capacity,
cache_io_concurrency: opt.cache_io_concurrency,
execution_engine: None,
session_state: None,
};

start_executor_process(Arc::new(config)).await
67 changes: 19 additions & 48 deletions ballista/executor/src/execution_loop.rs
@@ -15,46 +15,38 @@
// specific language governing permissions and limitations
// under the License.

use datafusion::config::ConfigOptions;
use datafusion::physical_plan::ExecutionPlan;

use ballista_core::serde::protobuf::{
scheduler_grpc_client::SchedulerGrpcClient, PollWorkParams, PollWorkResult,
TaskDefinition, TaskStatus,
};
use datafusion::prelude::SessionConfig;
use tokio::sync::{OwnedSemaphorePermit, Semaphore};

use crate::cpu_bound_executor::DedicatedExecutor;
use crate::executor::Executor;
use crate::{as_task_status, TaskExecutionTimes};
use ballista_core::error::BallistaError;
use ballista_core::serde::protobuf::{
scheduler_grpc_client::SchedulerGrpcClient, PollWorkParams, PollWorkResult,
TaskDefinition, TaskStatus,
};
use ballista_core::serde::scheduler::{ExecutorSpecification, PartitionId};
use ballista_core::serde::BallistaCodec;
use datafusion::execution::context::TaskContext;
use datafusion::functions::datetime::date_part;
use datafusion::functions::unicode::substr;
use datafusion::functions_aggregate::covariance::{covar_pop_udaf, covar_samp_udaf};
use datafusion::functions_aggregate::sum::sum_udaf;
use datafusion::functions_aggregate::variance::var_samp_udaf;
use datafusion::physical_plan::ExecutionPlan;
use datafusion::prelude::SessionConfig;
use datafusion_proto::logical_plan::AsLogicalPlan;
use datafusion_proto::physical_plan::AsExecutionPlan;
use futures::FutureExt;
use log::{debug, error, info, warn};
use std::any::Any;
use std::collections::HashMap;
use std::convert::TryInto;
use std::error::Error;
use std::ops::Deref;
use std::sync::mpsc::{Receiver, Sender, TryRecvError};
use std::time::{SystemTime, UNIX_EPOCH};
use std::{sync::Arc, time::Duration};
use tokio::sync::{OwnedSemaphorePermit, Semaphore};
use tonic::transport::Channel;

pub async fn poll_loop<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan>(
mut scheduler: SchedulerGrpcClient<Channel>,
executor: Arc<Executor>,
codec: BallistaCodec<T, U>,
session_config: SessionConfig,
) -> Result<(), BallistaError> {
let executor_specification: ExecutorSpecification = executor
.metadata
@@ -116,6 +108,7 @@ pub async fn poll_loop<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan>
task,
&codec,
&dedicated_executor,
session_config.clone(),
)
.await
{
@@ -157,6 +150,7 @@ async fn run_received_task<T: 'static + AsLogicalPlan, U: 'static + AsExecutionP
task: TaskDefinition,
codec: &BallistaCodec<T, U>,
dedicated_executor: &DedicatedExecutor,
session_config: SessionConfig,
) -> Result<(), BallistaError> {
let task_id = task.task_id;
let task_attempt_num = task.task_attempt_num;
@@ -172,42 +166,19 @@ async fn run_received_task<T: 'static + AsLogicalPlan, U: 'static + AsExecutionP
let task_identity = format!(
"TID {task_id} {job_id}/{stage_id}.{stage_attempt_num}/{partition_id}.{task_attempt_num}"
);
info!("Received task {}", task_identity);

let mut task_props = HashMap::new();
info!(
"Received task: {}, task_properties: {:?}",
task_identity, task.props
);
let mut session_config = session_config;
for kv_pair in task.props {
task_props.insert(kv_pair.key, kv_pair.value);
session_config = session_config.set_str(&kv_pair.key, &kv_pair.value);
}
let mut config = ConfigOptions::new();
for (k, v) in task_props {
config.set(&k, &v)?;
}
let session_config = SessionConfig::from(config);

let mut task_scalar_functions = HashMap::new();
let mut task_aggregate_functions = HashMap::new();
let mut task_window_functions = HashMap::new();
    // TODO combine the functions from Executor's functions and TaskDefinition's function resources
for scalar_func in executor.scalar_functions.clone() {
task_scalar_functions.insert(scalar_func.0.clone(), scalar_func.1);
}
for agg_func in executor.aggregate_functions.clone() {
task_aggregate_functions.insert(agg_func.0, agg_func.1);
}
// since DataFusion 38 some internal functions were converted to UDAF, so
// we have to register them manually
task_aggregate_functions.insert("var".to_string(), var_samp_udaf());
task_aggregate_functions.insert("covar_samp".to_string(), covar_samp_udaf());
task_aggregate_functions.insert("covar_pop".to_string(), covar_pop_udaf());
task_aggregate_functions.insert("SUM".to_string(), sum_udaf());
let task_scalar_functions = executor.scalar_functions.clone();
let task_aggregate_functions = executor.aggregate_functions.clone();
let task_window_functions = executor.window_functions.clone();

// TODO which other functions need adding here?
task_scalar_functions.insert("date_part".to_string(), date_part());
task_scalar_functions.insert("substr".to_string(), substr());

for window_func in executor.window_functions.clone() {
task_window_functions.insert(window_func.0, window_func.1);
}
let runtime = executor.get_runtime();
let session_id = task.session_id.clone();
let task_context = Arc::new(TaskContext::new(
71 changes: 66 additions & 5 deletions ballista/executor/src/executor.rs
@@ -28,8 +28,10 @@ use ballista_core::serde::scheduler::PartitionId;
use dashmap::DashMap;
use datafusion::execution::context::TaskContext;
use datafusion::execution::runtime_env::RuntimeEnv;
use datafusion::execution::SessionState;
use datafusion::functions::all_default_functions;
use datafusion::functions_aggregate::all_default_aggregate_functions;
use datafusion::functions_window::all_default_window_functions;
use datafusion::logical_expr::{AggregateUDF, ScalarUDF, WindowUDF};
use futures::future::AbortHandle;
use std::collections::HashMap;
@@ -90,8 +92,9 @@ pub struct Executor {
}

impl Executor {
/// Create a new executor instance
pub fn new(
/// Create a new executor instance with given [RuntimeEnv]
/// It will use default scalar, aggregate and window functions
pub fn new_from_runtime(
metadata: ExecutorRegistration,
work_dir: &str,
runtime: Arc<RuntimeEnv>,
@@ -109,13 +112,44 @@ impl Executor {
.map(|f| (f.name().to_string(), f))
.collect();

let window_functions = all_default_window_functions()
.into_iter()
.map(|f| (f.name().to_string(), f))
.collect();

Self::new(
metadata,
work_dir,
runtime,
scalar_functions,
aggregate_functions,
window_functions,
metrics_collector,
concurrent_tasks,
execution_engine,
)
}

/// Create a new executor instance with given [RuntimeEnv],
/// [ScalarUDF], [AggregateUDF] and [WindowUDF]
#[allow(clippy::too_many_arguments)]
fn new(
metadata: ExecutorRegistration,
work_dir: &str,
runtime: Arc<RuntimeEnv>,
scalar_functions: HashMap<String, Arc<ScalarUDF>>,
aggregate_functions: HashMap<String, Arc<AggregateUDF>>,
window_functions: HashMap<String, Arc<WindowUDF>>,
metrics_collector: Arc<dyn ExecutorMetricsCollector>,
concurrent_tasks: usize,
execution_engine: Option<Arc<dyn ExecutionEngine>>,
) -> Self {
Self {
metadata,
work_dir: work_dir.to_owned(),
scalar_functions,
aggregate_functions,
// TODO: set to default window functions when they are moved to udwf
window_functions: HashMap::new(),
window_functions,
runtime,
metrics_collector,
concurrent_tasks,
Expand All @@ -124,6 +158,33 @@ impl Executor {
.unwrap_or_else(|| Arc::new(DefaultExecutionEngine {})),
}
}
    /// Create a new executor instance from a [SessionState], taking its
    /// [RuntimeEnv], [ScalarUDF], [AggregateUDF] and [WindowUDF] registrations from it.
pub fn new_from_state(
metadata: ExecutorRegistration,
work_dir: &str,
state: &SessionState,
metrics_collector: Arc<dyn ExecutorMetricsCollector>,
concurrent_tasks: usize,
execution_engine: Option<Arc<dyn ExecutionEngine>>,
) -> Self {
let scalar_functions = state.scalar_functions().clone();
let aggregate_functions = state.aggregate_functions().clone();
let window_functions = state.window_functions().clone();
let runtime = state.runtime_env().clone();

Self::new(
metadata,
work_dir,
runtime,
scalar_functions,
aggregate_functions,
window_functions,
metrics_collector,
concurrent_tasks,
execution_engine,
)
}
}

impl Executor {
@@ -344,7 +405,7 @@

let ctx = SessionContext::new();

let executor = Executor::new(
let executor = Executor::new_from_runtime(
executor_registration,
&work_dir,
ctx.runtime_env(),
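
For comparison with the (truncated) test above, a hedged sketch of the same setup using the new SessionState-based constructor. executor_registration, work_dir and the metrics collector are the values the test already builds; the argument order follows the new_from_state signature introduced in this diff.

// Sketch only: SessionContext::state() returns the SessionState behind the context,
// so UDFs and options registered on ctx are picked up by the executor.
let ctx = SessionContext::new();
let state = ctx.state();
let executor = Executor::new_from_state(
    executor_registration,
    &work_dir,
    &state,
    metrics_collector, // same Arc<dyn ExecutorMetricsCollector> as in the test
    2,                 // concurrent tasks (arbitrary for the sketch)
    None,              // fall back to the DefaultExecutionEngine
);
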
(Diffs for the remaining 3 changed files are not shown here.)
