diff --git a/ballista/client/Cargo.toml b/ballista/client/Cargo.toml
index da61dab94..038c62c82 100644
--- a/ballista/client/Cargo.toml
+++ b/ballista/client/Cargo.toml
@@ -47,9 +47,12 @@ ballista-executor = { path = "../executor", version = "0.12.0" }
 ballista-scheduler = { path = "../scheduler", version = "0.12.0" }
 ctor = { version = "0.2" }
 env_logger = { workspace = true }
+object_store = { workspace = true, features = ["aws"] }
+testcontainers-modules = { version = "0.11", features = ["minio"] }
 
 [features]
 azure = ["ballista-core/azure"]
 default = []
 s3 = ["ballista-core/s3"]
 standalone = ["ballista-executor", "ballista-scheduler"]
+testcontainers = []
diff --git a/ballista/client/src/extension.rs b/ballista/client/src/extension.rs
index 99c8a88fc..38931e280 100644
--- a/ballista/client/src/extension.rs
+++ b/ballista/client/src/extension.rs
@@ -15,13 +15,13 @@
 // specific language governing permissions and limitations
 // under the License.
 
-pub use ballista_core::utils::BallistaSessionConfigExt;
+pub use ballista_core::utils::SessionConfigExt;
 use ballista_core::{
     config::BallistaConfig,
     serde::protobuf::{
         scheduler_grpc_client::SchedulerGrpcClient, CreateSessionParams, KeyValuePair,
     },
-    utils::{create_grpc_client_connection, BallistaSessionStateExt},
+    utils::{create_grpc_client_connection, SessionStateExt},
 };
 use datafusion::{
     error::DataFusionError, execution::SessionState, prelude::SessionContext,
@@ -65,6 +65,7 @@ const DEFAULT_SCHEDULER_PORT: u16 = 50050;
 /// There are still a few limitations on query distribution, thus not all
 /// [SessionContext] functionalities are supported.
 ///
+
 #[async_trait::async_trait]
 pub trait SessionContextExt {
     /// Creates a context for executing queries against a standalone Ballista scheduler instance
@@ -144,14 +145,8 @@ impl SessionContextExt for SessionContext {
     ) -> datafusion::error::Result<SessionContext> {
         let config = state.ballista_config();
 
-        let codec_logical = state.config().ballista_logical_extension_codec();
-        let codec_physical = state.config().ballista_physical_extension_codec();
-
-        let ballista_codec =
-            ballista_core::serde::BallistaCodec::new(codec_logical, codec_physical);
-
         let (remote_session_id, scheduler_url) =
-            Extension::setup_standalone(config, ballista_codec).await?;
+            Extension::setup_standalone(config, Some(&state)).await?;
 
         let session_state =
             state.upgrade_for_ballista(scheduler_url, remote_session_id.clone())?;
@@ -170,10 +165,8 @@ impl SessionContextExt for SessionContext {
         let config = BallistaConfig::new()
             .map_err(|e| DataFusionError::Configuration(e.to_string()))?;
 
-        let ballista_codec = ballista_core::serde::BallistaCodec::default();
-
         let (remote_session_id, scheduler_url) =
-            Extension::setup_standalone(config, ballista_codec).await?;
+            Extension::setup_standalone(config, None).await?;
 
         let session_state =
             SessionState::new_ballista_state(scheduler_url, remote_session_id.clone())?;
@@ -205,14 +198,22 @@ impl Extension {
     #[cfg(feature = "standalone")]
     async fn setup_standalone(
         config: BallistaConfig,
-        ballista_codec: ballista_core::serde::BallistaCodec<
-            datafusion_proto::protobuf::LogicalPlanNode,
-            datafusion_proto::protobuf::PhysicalPlanNode,
-        >,
+        session_state: Option<&SessionState>,
    ) -> datafusion::error::Result<(String, String)> {
-        let addr = ballista_scheduler::standalone::new_standalone_scheduler()
-            .await
-            .map_err(|e| DataFusionError::Configuration(e.to_string()))?;
+        use ballista_core::serde::BallistaCodec;
+
+        let addr = match session_state {
+            None => ballista_scheduler::standalone::new_standalone_scheduler()
+                .await
+                .map_err(|e| DataFusionError::Configuration(e.to_string()))?,
+            Some(session_state) => {
+                ballista_scheduler::standalone::new_standalone_scheduler_from_state(
+                    session_state,
+                )
+                .await
+                .map_err(|e| DataFusionError::Configuration(e.to_string()))?
+            }
+        };
 
         let scheduler_url = format!("http://localhost:{}", addr.port());
 
@@ -243,13 +244,26 @@ impl Extension {
             .session_id;
 
         let concurrent_tasks = config.default_standalone_parallelism();
-        ballista_executor::new_standalone_executor(
-            scheduler,
-            concurrent_tasks,
-            ballista_codec,
-        )
-        .await
-        .map_err(|e| DataFusionError::Configuration(e.to_string()))?;
+
+        match session_state {
+            None => {
+                ballista_executor::new_standalone_executor(
+                    scheduler,
+                    concurrent_tasks,
+                    BallistaCodec::default(),
+                )
+                .await
+                .map_err(|e| DataFusionError::Configuration(e.to_string()))?;
+            }
+            Some(session_state) => {
+                ballista_executor::new_standalone_executor_from_state::<
+                    datafusion_proto::protobuf::LogicalPlanNode,
+                    datafusion_proto::protobuf::PhysicalPlanNode,
+                >(scheduler, concurrent_tasks, session_state)
+                .await
+                .map_err(|e| DataFusionError::Configuration(e.to_string()))?;
+            }
+        }
 
         Ok((remote_session_id, scheduler_url))
     }
diff --git a/ballista/client/tests/common/mod.rs b/ballista/client/tests/common/mod.rs
index 02f25d7be..afc32aeaa 100644
--- a/ballista/client/tests/common/mod.rs
+++ b/ballista/client/tests/common/mod.rs
@@ -23,6 +23,53 @@ use ballista::prelude::BallistaConfig;
 use ballista_core::serde::{
     protobuf::scheduler_grpc_client::SchedulerGrpcClient, BallistaCodec,
 };
+use datafusion::execution::SessionState;
+use object_store::aws::AmazonS3Builder;
+use testcontainers_modules::minio::MinIO;
+use testcontainers_modules::testcontainers::core::{CmdWaitFor, ExecCommand};
+use testcontainers_modules::testcontainers::ContainerRequest;
+use testcontainers_modules::{minio, testcontainers::ImageExt};
+
+pub const REGION: &str = "eu-west-1";
+pub const BUCKET: &str = "ballista";
+pub const ACCESS_KEY_ID: &str = "MINIO";
+pub const SECRET_KEY: &str = "MINIOMINIO";
+
+#[allow(dead_code)]
+pub fn create_s3_store(
+    port: u16,
+) -> std::result::Result<object_store::aws::AmazonS3, object_store::Error> {
+    AmazonS3Builder::new()
+        .with_endpoint(format!("http://localhost:{port}"))
+        .with_region(REGION)
+        .with_bucket_name(BUCKET)
+        .with_access_key_id(ACCESS_KEY_ID)
+        .with_secret_access_key(SECRET_KEY)
+        .with_allow_http(true)
+        .build()
+}
+
+#[allow(dead_code)]
+pub fn create_minio_container() -> ContainerRequest<minio::MinIO> {
+    MinIO::default()
+        .with_env_var("MINIO_ACCESS_KEY", ACCESS_KEY_ID)
+        .with_env_var("MINIO_SECRET_KEY", SECRET_KEY)
+}
+
+#[allow(dead_code)]
+pub fn create_bucket_command() -> ExecCommand {
+    // This is a hack to create a bucket without creating an s3 client.
+    // It works with the current testcontainers (and image) version 'RELEASE.2022-02-07T08-17-33Z'.
+    // (testcontainers does not await properly on the latest image version)
+    //
+    // If the testcontainers image version changes to something newer, we should use
+    // "mc mb /data/ballista" to create the bucket instead.
+    ExecCommand::new(vec![
+        "mkdir".to_string(),
+        format!("/data/{}", crate::common::BUCKET),
+    ])
+    .with_cmd_ready_condition(CmdWaitFor::seconds(1))
+}
 
 // /// Remote ballista cluster to be used for local testing.
 // static BALLISTA_CLUSTER: tokio::sync::OnceCell<(String, u16)> =
@@ -136,6 +183,48 @@ pub async fn setup_test_cluster() -> (String, u16) {
     (host, addr.port())
 }
 
+/// Starts a ballista cluster for integration tests
+#[allow(dead_code)]
+pub async fn setup_test_cluster_with_state(session_state: SessionState) -> (String, u16) {
+    let config = BallistaConfig::builder().build().unwrap();
+
+    let addr = ballista_scheduler::standalone::new_standalone_scheduler_from_state(
+        &session_state,
+    )
+    .await
+    .expect("scheduler to be created");
+
+    let host = "localhost".to_string();
+
+    let scheduler_url = format!("http://{}:{}", host, addr.port());
+
+    let scheduler = loop {
+        match SchedulerGrpcClient::connect(scheduler_url.clone()).await {
+            Err(_) => {
+                tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
+                log::info!("Attempting to connect to test scheduler...");
+            }
+            Ok(scheduler) => break scheduler,
+        }
+    };
+
+    ballista_executor::new_standalone_executor_from_state::<
+        datafusion_proto::protobuf::LogicalPlanNode,
+        datafusion_proto::protobuf::PhysicalPlanNode,
+    >(
+        scheduler,
+        config.default_standalone_parallelism(),
+        &session_state,
+    )
+    .await
+    .expect("executor to be created");
+
+    log::info!("test scheduler created at: {}:{}", host, addr.port());
+
+    (host, addr.port())
+}
+
 #[ctor::ctor]
 fn init() {
     // Enable RUST_LOG logging configuration for test
diff --git a/ballista/client/tests/object_store.rs b/ballista/client/tests/object_store.rs
new file mode 100644
index 000000000..b58bcb905
--- /dev/null
+++ b/ballista/client/tests/object_store.rs
@@ -0,0 +1,201 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! # Object Store Support
+//!
+//! These tests demonstrate how to set up object stores with Ballista.
+//!
+//! They depend on a MinIO testcontainer acting as an S3 object
+//! store.
+//!
+//! Testcontainers requires Docker to run.
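+//!
+//! A typical test follows this pattern (a sketch only; the helper
+//! functions live in `tests/common/mod.rs`):
+//!
+//! ```text
+//! let node = create_minio_container().start().await?;  // start MinIO in Docker
+//! node.exec(create_bucket_command()).await?;           // create the test bucket
+//! let port = node.get_host_port_ipv4(9000).await?;
+//! let store = create_s3_store(port)?;                  // S3 client pointed at MinIO
+//! // register `store` with the session's RuntimeEnv, then query s3:// URLs
+//! ```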
+ +mod common; + +#[cfg(test)] +#[cfg(feature = "standalone")] +#[cfg(feature = "testcontainers")] +mod standalone { + + use ballista::extension::SessionContextExt; + use datafusion::{assert_batches_eq, prelude::SessionContext}; + use datafusion::{ + error::DataFusionError, + execution::{ + runtime_env::{RuntimeConfig, RuntimeEnv}, + SessionStateBuilder, + }, + }; + use std::sync::Arc; + use testcontainers_modules::testcontainers::runners::AsyncRunner; + + #[tokio::test] + async fn should_execute_sql_write() -> datafusion::error::Result<()> { + let container = crate::common::create_minio_container(); + let node = container.start().await.unwrap(); + + node.exec(crate::common::create_bucket_command()) + .await + .unwrap(); + + let port = node.get_host_port_ipv4(9000).await.unwrap(); + + let object_store = crate::common::create_s3_store(port) + .map_err(|e| DataFusionError::External(e.into()))?; + + let test_data = crate::common::example_test_data(); + let config = RuntimeConfig::new(); + let runtime_env = RuntimeEnv::new(config)?; + + runtime_env.register_object_store( + &format!("s3://{}", crate::common::BUCKET) + .as_str() + .try_into() + .unwrap(), + Arc::new(object_store), + ); + let state = SessionStateBuilder::new() + .with_runtime_env(runtime_env.into()) + .build(); + + let ctx: SessionContext = SessionContext::standalone_with_state(state).await?; + ctx.register_parquet( + "test", + &format!("{test_data}/alltypes_plain.parquet"), + Default::default(), + ) + .await?; + + let write_dir_path = + &format!("s3://{}/write_test.parquet", crate::common::BUCKET); + + ctx.sql("select * from test") + .await? + .write_parquet(write_dir_path, Default::default(), Default::default()) + .await?; + + ctx.register_parquet("written_table", write_dir_path, Default::default()) + .await?; + + let result = ctx + .sql("select id, string_col, timestamp_col from written_table where id > 4") + .await? 
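+            // bring the distributed result batches back to the client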
+ .collect() + .await?; + let expected = [ + "+----+------------+---------------------+", + "| id | string_col | timestamp_col |", + "+----+------------+---------------------+", + "| 5 | 31 | 2009-03-01T00:01:00 |", + "| 6 | 30 | 2009-04-01T00:00:00 |", + "| 7 | 31 | 2009-04-01T00:01:00 |", + "+----+------------+---------------------+", + ]; + + assert_batches_eq!(expected, &result); + Ok(()) + } +} + +#[cfg(test)] +#[cfg(feature = "testcontainers")] +mod remote { + + use ballista::extension::SessionContextExt; + use datafusion::{assert_batches_eq, prelude::SessionContext}; + use datafusion::{ + error::DataFusionError, + execution::{ + runtime_env::{RuntimeConfig, RuntimeEnv}, + SessionStateBuilder, + }, + }; + use std::sync::Arc; + use testcontainers_modules::testcontainers::runners::AsyncRunner; + + #[tokio::test] + async fn should_execute_sql_write() -> datafusion::error::Result<()> { + let test_data = crate::common::example_test_data(); + + let container = crate::common::create_minio_container(); + let node = container.start().await.unwrap(); + + node.exec(crate::common::create_bucket_command()) + .await + .unwrap(); + + let port = node.get_host_port_ipv4(9000).await.unwrap(); + + let object_store = crate::common::create_s3_store(port) + .map_err(|e| DataFusionError::External(e.into()))?; + + let config = RuntimeConfig::new(); + let runtime_env = RuntimeEnv::new(config)?; + + runtime_env.register_object_store( + &format!("s3://{}", crate::common::BUCKET) + .as_str() + .try_into() + .unwrap(), + Arc::new(object_store), + ); + let state = SessionStateBuilder::new() + .with_runtime_env(runtime_env.into()) + .build(); + + let (host, port) = + crate::common::setup_test_cluster_with_state(state.clone()).await; + let url = format!("df://{host}:{port}"); + + let ctx: SessionContext = SessionContext::remote_with_state(&url, state).await?; + ctx.register_parquet( + "test", + &format!("{test_data}/alltypes_plain.parquet"), + Default::default(), + ) + .await?; + + let write_dir_path = + &format!("s3://{}/write_test.parquet", crate::common::BUCKET); + + ctx.sql("select * from test") + .await? + .write_parquet(write_dir_path, Default::default(), Default::default()) + .await?; + + ctx.register_parquet("written_table", write_dir_path, Default::default()) + .await?; + + let result = ctx + .sql("select id, string_col, timestamp_col from written_table where id > 4") + .await? + .collect() + .await?; + let expected = [ + "+----+------------+---------------------+", + "| id | string_col | timestamp_col |", + "+----+------------+---------------------+", + "| 5 | 31 | 2009-03-01T00:01:00 |", + "| 6 | 30 | 2009-04-01T00:00:00 |", + "| 7 | 31 | 2009-04-01T00:01:00 |", + "+----+------------+---------------------+", + ]; + + assert_batches_eq!(expected, &result); + Ok(()) + } +} diff --git a/ballista/client/tests/remote.rs b/ballista/client/tests/remote.rs index 619c4cd62..b0184b265 100644 --- a/ballista/client/tests/remote.rs +++ b/ballista/client/tests/remote.rs @@ -142,4 +142,44 @@ mod remote { Ok(()) } + + #[tokio::test] + async fn should_execute_sql_app_name_show() -> datafusion::error::Result<()> { + let (host, port) = crate::common::setup_test_cluster().await; + let url = format!("df://{host}:{port}"); + + let test_data = crate::common::example_test_data(); + let ctx: SessionContext = SessionContext::remote(&url).await?; + + ctx.sql("SET ballista.job.name = 'Super Cool Ballista App'") + .await? 
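+            // session settings like this are forwarded to the scheduler
+            // with each query (see the DistributedQueryExec change above)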
+ .show() + .await?; + + ctx.register_parquet( + "test", + &format!("{test_data}/alltypes_plain.parquet"), + Default::default(), + ) + .await?; + + let result = ctx + .sql("select string_col, timestamp_col from test where id > 4") + .await? + .collect() + .await?; + let expected = [ + "+------------+---------------------+", + "| string_col | timestamp_col |", + "+------------+---------------------+", + "| 31 | 2009-03-01T00:01:00 |", + "| 30 | 2009-04-01T00:00:00 |", + "| 31 | 2009-04-01T00:01:00 |", + "+------------+---------------------+", + ]; + + assert_batches_eq!(expected, &result); + + Ok(()) + } } diff --git a/ballista/client/tests/setup.rs b/ballista/client/tests/setup.rs index 30a6df84a..10b482906 100644 --- a/ballista/client/tests/setup.rs +++ b/ballista/client/tests/setup.rs @@ -20,7 +20,7 @@ mod common; #[cfg(test)] mod remote { use ballista::{ - extension::{BallistaSessionConfigExt, SessionContextExt}, + extension::{SessionConfigExt, SessionContextExt}, prelude::BALLISTA_JOB_NAME, }; use datafusion::{ @@ -109,12 +109,10 @@ mod standalone { use std::sync::{atomic::AtomicBool, Arc}; use ballista::{ - extension::{BallistaSessionConfigExt, SessionContextExt}, + extension::{SessionConfigExt, SessionContextExt}, prelude::BALLISTA_JOB_NAME, }; - use ballista_core::{ - config::BALLISTA_PLANNER_OVERRIDE, serde::BallistaPhysicalExtensionCodec, - }; + use ballista_core::serde::BallistaPhysicalExtensionCodec; use datafusion::{ assert_batches_eq, common::exec_err, @@ -243,12 +241,11 @@ mod standalone { async fn should_override_planner() -> datafusion::error::Result<()> { let session_config = SessionConfig::new_with_ballista() .with_information_schema(true) - .set_str(BALLISTA_PLANNER_OVERRIDE, "false"); + .with_ballista_query_planner(Arc::new(BadPlanner::default())); let state = SessionStateBuilder::new() .with_default_features() .with_config(session_config) - .with_query_planner(Arc::new(BadPlanner::default())) .build(); let ctx: SessionContext = SessionContext::standalone_with_state(state).await?; @@ -257,14 +254,12 @@ mod standalone { assert!(result.is_err()); - let session_config = SessionConfig::new_with_ballista() - .with_information_schema(true) - .set_str(BALLISTA_PLANNER_OVERRIDE, "true"); + let session_config = + SessionConfig::new_with_ballista().with_information_schema(true); let state = SessionStateBuilder::new() .with_default_features() .with_config(session_config) - .with_query_planner(Arc::new(BadPlanner::default())) .build(); let ctx: SessionContext = SessionContext::standalone_with_state(state).await?; diff --git a/ballista/core/src/config.rs b/ballista/core/src/config.rs index 782b8b9d0..88cba1d9a 100644 --- a/ballista/core/src/config.rs +++ b/ballista/core/src/config.rs @@ -43,11 +43,6 @@ pub const BALLISTA_REPARTITION_WINDOWS: &str = "ballista.repartition.windows"; pub const BALLISTA_PARQUET_PRUNING: &str = "ballista.parquet.pruning"; pub const BALLISTA_COLLECT_STATISTICS: &str = "ballista.collect_statistics"; pub const BALLISTA_STANDALONE_PARALLELISM: &str = "ballista.standalone.parallelism"; -/// If set to false, planner will not be overridden by ballista. 
-/// This allows user to replace ballista planner
-// this is a bit of a hack, as we can't detect if there is a
-// custom planner provided
-pub const BALLISTA_PLANNER_OVERRIDE: &str = "ballista.planner.override";
 
 pub const BALLISTA_WITH_INFORMATION_SCHEMA: &str = "ballista.with_information_schema";
 
@@ -221,10 +216,6 @@ impl BallistaConfig {
                 "Configuration for max message size in gRPC clients".to_string(),
                 DataType::UInt64,
                 Some((16 * 1024 * 1024).to_string())),
-            ConfigEntry::new(BALLISTA_PLANNER_OVERRIDE.to_string(),
-                "Disable overriding provided planner".to_string(),
-                DataType::Boolean,
-                Some((true).to_string())),
         ];
         entries
             .iter()
@@ -280,10 +271,6 @@ impl BallistaConfig {
         self.get_bool_setting(BALLISTA_WITH_INFORMATION_SCHEMA)
     }
 
-    pub fn planner_override(&self) -> bool {
-        self.get_bool_setting(BALLISTA_PLANNER_OVERRIDE)
-    }
-
     fn get_usize_setting(&self, key: &str) -> usize {
         if let Some(v) = self.settings.get(key) {
             // infallible because we validate all configs in the constructor
diff --git a/ballista/core/src/execution_plans/distributed_query.rs b/ballista/core/src/execution_plans/distributed_query.rs
index 050ba877a..dae4bb8ee 100644
--- a/ballista/core/src/execution_plans/distributed_query.rs
+++ b/ballista/core/src/execution_plans/distributed_query.rs
@@ -194,7 +194,7 @@ impl ExecutionPlan for DistributedQueryExec {
     fn execute(
         &self,
         partition: usize,
-        _context: Arc<TaskContext>,
+        context: Arc<TaskContext>,
     ) -> Result<SendableRecordBatchStream> {
         assert_eq!(0, partition);
 
@@ -210,17 +210,22 @@ impl ExecutionPlan for DistributedQueryExec {
             DataFusionError::Execution(format!("failed to encode logical plan: {e:?}"))
         })?;
 
+        let settings = context
+            .session_config()
+            .options()
+            .entries()
+            .iter()
+            .map(
+                |datafusion::config::ConfigEntry { key, value, .. }| KeyValuePair {
+                    key: key.to_owned(),
+                    value: value.clone().unwrap_or_else(|| String::from("")),
+                },
+            )
+            .collect();
+
         let query = ExecuteQueryParams {
             query: Some(Query::LogicalPlan(buf)),
-            settings: self
-                .config
-                .settings()
-                .iter()
-                .map(|(k, v)| KeyValuePair {
-                    key: k.to_owned(),
-                    value: v.to_owned(),
-                })
-                .collect::<Vec<_>>(),
+            settings,
             optional_session_id: Some(OptionalSessionId::SessionId(
                 self.session_id.clone(),
             )),
diff --git a/ballista/core/src/lib.rs b/ballista/core/src/lib.rs
index c52d2ef4e..8ae5dfb59 100644
--- a/ballista/core/src/lib.rs
+++ b/ballista/core/src/lib.rs
@@ -16,6 +16,10 @@
 // under the License.
 
 #![doc = include_str!("../README.md")]
+
+use std::sync::Arc;
+
+use datafusion::{execution::runtime_env::RuntimeEnv, prelude::SessionConfig};
 pub const BALLISTA_VERSION: &str = env!("CARGO_PKG_VERSION");
 
 pub fn print_version() {
@@ -33,3 +37,23 @@ pub mod utils;
 
 #[macro_use]
 pub mod serde;
+
+///
+/// [RuntimeProducer] is a factory which creates a [RuntimeEnv]
+/// from a [SessionConfig]. As the [SessionConfig] is propagated
+/// from the client to the executors, this makes it possible to
+/// create [RuntimeEnv] components and configure them according
+/// to the [SessionConfig] or one of its config extensions.
+///
+/// It is intended to be used with executor configuration.
+///
+pub type RuntimeProducer = Arc<
+    dyn Fn(&SessionConfig) -> datafusion::error::Result<Arc<RuntimeEnv>> + Send + Sync,
+>;
+///
+/// [ConfigProducer] is a factory which can create a [SessionConfig] with
+/// additional extensions or configuration codecs.
+///
+/// It is intended to be used with executor configuration.
+///
+pub type ConfigProducer = Arc<dyn Fn() -> SessionConfig + Send + Sync>;
diff --git a/ballista/core/src/object_store_registry/mod.rs b/ballista/core/src/object_store_registry/mod.rs
index aedccc5e9..e7fbee216 100644
--- a/ballista/core/src/object_store_registry/mod.rs
+++ b/ballista/core/src/object_store_registry/mod.rs
@@ -31,6 +31,7 @@ use std::sync::Arc;
 use url::Url;
 
 /// Get a RuntimeConfig with specific ObjectStoreRegistry
+// TODO: #[deprecated] this method
 pub fn with_object_store_registry(config: RuntimeConfig) -> RuntimeConfig {
     let registry = Arc::new(BallistaObjectStoreRegistry::default());
     config.with_object_store_registry(registry)
diff --git a/ballista/core/src/serde/mod.rs b/ballista/core/src/serde/mod.rs
index 7464fe686..5400b00ca 100644
--- a/ballista/core/src/serde/mod.rs
+++ b/ballista/core/src/serde/mod.rs
@@ -89,7 +89,7 @@ impl Default for BallistaCodec {
     fn default() -> Self {
         Self {
             logical_extension_codec: Arc::new(BallistaLogicalExtensionCodec::default()),
-            physical_extension_codec: Arc::new(BallistaPhysicalExtensionCodec {}),
+            physical_extension_codec: Arc::new(BallistaPhysicalExtensionCodec::default()),
             logical_plan_repr: PhantomData,
             physical_plan_repr: PhantomData,
         }
diff --git a/ballista/core/src/serde/scheduler/from_proto.rs b/ballista/core/src/serde/scheduler/from_proto.rs
index 4821eab26..28a1e8a59 100644
--- a/ballista/core/src/serde/scheduler/from_proto.rs
+++ b/ballista/core/src/serde/scheduler/from_proto.rs
@@ -17,12 +17,13 @@
 
 use chrono::{TimeZone, Utc};
 use datafusion::common::tree_node::{Transformed, TransformedResult, TreeNode};
-use datafusion::execution::runtime_env::RuntimeEnv;
+
 use datafusion::logical_expr::{AggregateUDF, ScalarUDF, WindowUDF};
 use datafusion::physical_plan::metrics::{
     Count, Gauge, MetricValue, MetricsSet, Time, Timestamp,
 };
 use datafusion::physical_plan::{ExecutionPlan, Metric};
+use datafusion::prelude::SessionConfig;
 use datafusion_proto::logical_plan::AsLogicalPlan;
 use datafusion_proto::physical_plan::AsExecutionPlan;
 use std::collections::HashMap;
@@ -32,11 +33,13 @@ use std::time::Duration;
 
 use crate::error::BallistaError;
 use crate::serde::scheduler::{
-    Action, ExecutorData, ExecutorMetadata, ExecutorSpecification, PartitionId,
-    PartitionLocation, PartitionStats, SimpleFunctionRegistry, TaskDefinition,
+    Action, BallistaFunctionRegistry, ExecutorData, ExecutorMetadata,
+    ExecutorSpecification, PartitionId, PartitionLocation, PartitionStats,
+    TaskDefinition,
 };
 use crate::serde::{protobuf, BallistaCodec};
+use crate::RuntimeProducer;
 use protobuf::{operator_metric, NamedCount, NamedGauge, NamedTime};
 
 impl TryInto<Action> for protobuf::Action {
@@ -281,17 +284,17 @@ impl Into<ExecutorData> for protobuf::ExecutorData {
 
 pub fn get_task_definition<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan>(
     task: protobuf::TaskDefinition,
-    runtime: Arc<RuntimeEnv>,
+    produce_runtime: RuntimeProducer,
+    session_config: SessionConfig,
     scalar_functions: HashMap<String, Arc<ScalarUDF>>,
     aggregate_functions: HashMap<String, Arc<AggregateUDF>>,
     window_functions: HashMap<String, Arc<WindowUDF>>,
     codec: BallistaCodec<T, U>,
) -> Result { - let mut props = HashMap::new(); + let mut session_config = session_config; for kv_pair in task.props { - props.insert(kv_pair.key, kv_pair.value); + session_config = session_config.set_str(&kv_pair.key, &kv_pair.value); } - let props = Arc::new(props); let mut task_scalar_functions = HashMap::new(); let mut task_aggregate_functions = HashMap::new(); @@ -306,12 +309,12 @@ pub fn get_task_definition = U::try_decode(encoded_plan).and_then(|proto| { proto.try_into_physical_plan( @@ -340,7 +343,7 @@ pub fn get_task_definition( multi_task: protobuf::MultiTaskDefinition, - runtime: Arc, + runtime_producer: RuntimeProducer, + session_config: SessionConfig, scalar_functions: HashMap>, aggregate_functions: HashMap>, window_functions: HashMap>, codec: BallistaCodec, ) -> Result, BallistaError> { - let mut props = HashMap::new(); + let mut session_config = session_config; for kv_pair in multi_task.props { - props.insert(kv_pair.key, kv_pair.value); + session_config = session_config.set_str(&kv_pair.key, &kv_pair.value); } - let props = Arc::new(props); let mut task_scalar_functions = HashMap::new(); let mut task_aggregate_functions = HashMap::new(); @@ -375,12 +378,14 @@ pub fn get_task_definition_vec< for agg_func in window_functions { task_window_functions.insert(agg_func.0, agg_func.1); } - let function_registry = Arc::new(SimpleFunctionRegistry { + let function_registry = Arc::new(BallistaFunctionRegistry { scalar_functions: task_scalar_functions, aggregate_functions: task_aggregate_functions, window_functions: task_window_functions, }); + let runtime = runtime_producer(&session_config)?; + let encoded_plan = multi_task.plan.as_slice(); let plan: Arc = U::try_decode(encoded_plan).and_then(|proto| { proto.try_into_physical_plan( @@ -410,7 +415,7 @@ pub fn get_task_definition_vec< plan: reset_metrics_for_execution_plan(plan.clone())?, launch_time, session_id: session_id.clone(), - props: props.clone(), + session_config: session_config.clone(), function_registry: function_registry.clone(), }) }) diff --git a/ballista/core/src/serde/scheduler/mod.rs b/ballista/core/src/serde/scheduler/mod.rs index 23c9c4256..2905455eb 100644 --- a/ballista/core/src/serde/scheduler/mod.rs +++ b/ballista/core/src/serde/scheduler/mod.rs @@ -24,11 +24,15 @@ use datafusion::arrow::array::{ }; use datafusion::arrow::datatypes::{DataType, Field}; use datafusion::common::DataFusionError; -use datafusion::execution::FunctionRegistry; +use datafusion::execution::{FunctionRegistry, SessionState}; +use datafusion::functions::all_default_functions; +use datafusion::functions_aggregate::all_default_aggregate_functions; +use datafusion::functions_window::all_default_window_functions; use datafusion::logical_expr::planner::ExprPlanner; use datafusion::logical_expr::{AggregateUDF, ScalarUDF, WindowUDF}; use datafusion::physical_plan::ExecutionPlan; use datafusion::physical_plan::Partitioning; +use datafusion::prelude::SessionConfig; use serde::Serialize; use crate::error::BallistaError; @@ -288,18 +292,43 @@ pub struct TaskDefinition { pub plan: Arc, pub launch_time: u64, pub session_id: String, - pub props: Arc>, - pub function_registry: Arc, + pub session_config: SessionConfig, + pub function_registry: Arc, } #[derive(Debug)] -pub struct SimpleFunctionRegistry { +pub struct BallistaFunctionRegistry { pub scalar_functions: HashMap>, pub aggregate_functions: HashMap>, pub window_functions: HashMap>, } -impl FunctionRegistry for SimpleFunctionRegistry { +impl Default for BallistaFunctionRegistry { + fn default() -> Self 
{ + let scalar_functions = all_default_functions() + .into_iter() + .map(|f| (f.name().to_string(), f)) + .collect(); + + let aggregate_functions = all_default_aggregate_functions() + .into_iter() + .map(|f| (f.name().to_string(), f)) + .collect(); + + let window_functions = all_default_window_functions() + .into_iter() + .map(|f| (f.name().to_string(), f)) + .collect(); + + Self { + scalar_functions, + aggregate_functions, + window_functions, + } + } +} + +impl FunctionRegistry for BallistaFunctionRegistry { fn expr_planners(&self) -> Vec> { vec![] } @@ -338,3 +367,17 @@ impl FunctionRegistry for SimpleFunctionRegistry { }) } } + +impl From<&SessionState> for BallistaFunctionRegistry { + fn from(state: &SessionState) -> Self { + let scalar_functions = state.scalar_functions().clone(); + let aggregate_functions = state.aggregate_functions().clone(); + let window_functions = state.window_functions().clone(); + + Self { + scalar_functions, + aggregate_functions, + window_functions, + } + } +} diff --git a/ballista/core/src/utils.rs b/ballista/core/src/utils.rs index 8be32c402..3f8f6bfea 100644 --- a/ballista/core/src/utils.rs +++ b/ballista/core/src/utils.rs @@ -277,7 +277,7 @@ pub fn create_df_ctx_with_ballista_query_planner( SessionContext::new_with_state(session_state) } -pub trait BallistaSessionStateExt { +pub trait SessionStateExt { fn new_ballista_state( scheduler_url: String, session_id: String, @@ -291,7 +291,7 @@ pub trait BallistaSessionStateExt { fn ballista_config(&self) -> BallistaConfig; } -impl BallistaSessionStateExt for SessionState { +impl SessionStateExt for SessionState { fn ballista_config(&self) -> BallistaConfig { self.config() .options() @@ -313,7 +313,9 @@ impl BallistaSessionStateExt for SessionState { let session_config = SessionConfig::new() .with_information_schema(true) - .with_option_extension(config.clone()); + .with_option_extension(config.clone()) + // Ballista disables this option + .with_round_robin_repartition(false); let runtime_config = RuntimeConfig::default(); let runtime_env = RuntimeEnv::new(runtime_config)?; @@ -334,6 +336,7 @@ impl BallistaSessionStateExt for SessionState { session_id: String, ) -> datafusion::error::Result { let codec_logical = self.config().ballista_logical_extension_codec(); + let planner_override = self.config().ballista_query_planner(); let new_config = self .config() @@ -346,39 +349,31 @@ impl BallistaSessionStateExt for SessionState { let session_config = self .config() .clone() - .with_option_extension(new_config.clone()); - - // at the moment we don't have a way to detect if - // user set planner so we provide a configuration to - // user to disable planner override - let planner_override = self - .config() - .options() - .extensions - .get::() - .map(|config| config.planner_override()) - .unwrap_or(true); + .with_option_extension(new_config.clone()) + // Ballista disables this option + .with_round_robin_repartition(false); let builder = SessionStateBuilder::new_from_existing(self) .with_config(session_config) .with_session_id(session_id); - let builder = if planner_override { - let query_planner = BallistaQueryPlanner::::with_extension( - scheduler_url, - new_config, - codec_logical, - ); - builder.with_query_planner(Arc::new(query_planner)) - } else { - builder + let builder = match planner_override { + Some(planner) => builder.with_query_planner(planner), + None => { + let planner = BallistaQueryPlanner::::with_extension( + scheduler_url, + new_config, + codec_logical, + ); + 
                builder.with_query_planner(Arc::new(planner))
+            }
         };
 
         Ok(builder.build())
     }
 }
 
-pub trait BallistaSessionConfigExt {
+pub trait SessionConfigExt {
     /// Creates session config which has
     /// ballista configuration initialized
     fn new_with_ballista() -> SessionConfig;
@@ -402,9 +397,20 @@
     /// returns [PhysicalExtensionCodec] if set
     /// or default ballista codec if not
     fn ballista_physical_extension_codec(&self) -> Arc<dyn PhysicalExtensionCodec>;
+
+    /// Overrides ballista's [QueryPlanner]
+    fn with_ballista_query_planner(
+        self,
+        planner: Arc<dyn QueryPlanner + Send + Sync>,
+    ) -> SessionConfig;
+
+    /// Returns ballista's [QueryPlanner] if overridden
+    fn ballista_query_planner(
+        &self,
+    ) -> Option<Arc<dyn QueryPlanner + Send + Sync>>;
 }
 
-impl BallistaSessionConfigExt for SessionConfig {
+impl SessionConfigExt for SessionConfig {
     fn new_with_ballista() -> SessionConfig {
         SessionConfig::new().with_option_extension(BallistaConfig::new().unwrap())
     }
@@ -433,6 +439,21 @@ impl SessionConfigExt for SessionConfig {
             .map(|c| c.codec())
             .unwrap_or_else(|| Arc::new(BallistaPhysicalExtensionCodec::default()))
     }
+
+    fn with_ballista_query_planner(
+        self,
+        planner: Arc<dyn QueryPlanner + Send + Sync>,
+    ) -> SessionConfig {
+        let extension = BallistaQueryPlannerExtension::new(planner);
+        self.with_extension(Arc::new(extension))
+    }
+
+    fn ballista_query_planner(
+        &self,
+    ) -> Option<Arc<dyn QueryPlanner + Send + Sync>> {
+        self.get_extension::<BallistaQueryPlannerExtension>()
+            .map(|c| c.planner())
+    }
 }
 
 /// Wrapper for [SessionConfig] extension
@@ -465,6 +486,21 @@ impl BallistaConfigExtensionPhysicalCodec {
     }
 }
 
+/// Wrapper for [SessionConfig] extension
+/// holding overridden [QueryPlanner]
+struct BallistaQueryPlannerExtension {
+    planner: Arc<dyn QueryPlanner + Send + Sync>,
+}
+
+impl BallistaQueryPlannerExtension {
+    fn new(planner: Arc<dyn QueryPlanner + Send + Sync>) -> Self {
+        Self { planner }
+    }
+    fn planner(&self) -> Arc<dyn QueryPlanner + Send + Sync> {
+        self.planner.clone()
+    }
+}
+
 pub struct BallistaQueryPlanner<T: AsLogicalPlan> {
     scheduler_url: String,
     config: BallistaConfig,
@@ -656,12 +692,12 @@ mod test {
         error::Result,
         execution::{
             runtime_env::{RuntimeConfig, RuntimeEnv},
-            SessionStateBuilder,
+            SessionState, SessionStateBuilder,
         },
         prelude::{SessionConfig, SessionContext},
     };
 
-    use crate::utils::LocalRun;
+    use crate::utils::{LocalRun, SessionStateExt};
 
     fn context() -> SessionContext {
         let runtime_environment = RuntimeEnv::new(RuntimeConfig::new()).unwrap();
@@ -738,4 +774,25 @@ mod test {
 
         Ok(())
     }
+
+    // Ballista disables round robin repartitioning
+    #[tokio::test]
+    async fn should_disable_round_robin_repartition() {
+        let state = SessionState::new_ballista_state(
+            "scheduler_url".to_string(),
+            "session_id".to_string(),
+        )
+        .unwrap();
+
+        assert!(!state.config().round_robin_repartition());
+
+        let state = SessionStateBuilder::new().build();
+
+        assert!(state.config().round_robin_repartition());
+        let state = state
+            .upgrade_for_ballista("scheduler_url".to_string(), "session_id".to_string())
+            .unwrap();
+
+        assert!(!state.config().round_robin_repartition());
+    }
 }
diff --git a/ballista/executor/Cargo.toml b/ballista/executor/Cargo.toml
index ed7c43186..b04abd9d5 100644
--- a/ballista/executor/Cargo.toml
+++ b/ballista/executor/Cargo.toml
@@ -41,7 +41,7 @@ anyhow = "1"
 arrow = { workspace = true }
 arrow-flight = { workspace = true }
 async-trait = { workspace = true }
-ballista-core = { path = "../core", version = "0.12.0", features = ["s3"] }
+ballista-core = { path = "../core", version = "0.12.0" }
 configure_me = { workspace = true }
 dashmap = { workspace = true }
 datafusion = { workspace = true }
diff --git a/ballista/executor/src/bin/main.rs b/ballista/executor/src/bin/main.rs
index ba56b3335..9f5ed12f1 100644
---
a/ballista/executor/src/bin/main.rs +++ b/ballista/executor/src/bin/main.rs @@ -87,6 +87,11 @@ async fn main() -> Result<()> { cache_capacity: opt.cache_capacity, cache_io_concurrency: opt.cache_io_concurrency, execution_engine: None, + function_registry: None, + config_producer: None, + runtime_producer: None, + logical_codec: None, + physical_codec: None, }; start_executor_process(Arc::new(config)).await diff --git a/ballista/executor/src/execution_loop.rs b/ballista/executor/src/execution_loop.rs index 591c5c453..8056d6c52 100644 --- a/ballista/executor/src/execution_loop.rs +++ b/ballista/executor/src/execution_loop.rs @@ -15,40 +15,30 @@ // specific language governing permissions and limitations // under the License. -use datafusion::config::ConfigOptions; -use datafusion::physical_plan::ExecutionPlan; - -use ballista_core::serde::protobuf::{ - scheduler_grpc_client::SchedulerGrpcClient, PollWorkParams, PollWorkResult, - TaskDefinition, TaskStatus, -}; -use datafusion::prelude::SessionConfig; -use tokio::sync::{OwnedSemaphorePermit, Semaphore}; - use crate::cpu_bound_executor::DedicatedExecutor; use crate::executor::Executor; use crate::{as_task_status, TaskExecutionTimes}; use ballista_core::error::BallistaError; +use ballista_core::serde::protobuf::{ + scheduler_grpc_client::SchedulerGrpcClient, PollWorkParams, PollWorkResult, + TaskDefinition, TaskStatus, +}; use ballista_core::serde::scheduler::{ExecutorSpecification, PartitionId}; use ballista_core::serde::BallistaCodec; use datafusion::execution::context::TaskContext; -use datafusion::functions::datetime::date_part; -use datafusion::functions::unicode::substr; -use datafusion::functions_aggregate::covariance::{covar_pop_udaf, covar_samp_udaf}; -use datafusion::functions_aggregate::sum::sum_udaf; -use datafusion::functions_aggregate::variance::var_samp_udaf; +use datafusion::physical_plan::ExecutionPlan; use datafusion_proto::logical_plan::AsLogicalPlan; use datafusion_proto::physical_plan::AsExecutionPlan; use futures::FutureExt; use log::{debug, error, info, warn}; use std::any::Any; -use std::collections::HashMap; use std::convert::TryInto; use std::error::Error; use std::ops::Deref; use std::sync::mpsc::{Receiver, Sender, TryRecvError}; use std::time::{SystemTime, UNIX_EPOCH}; use std::{sync::Arc, time::Duration}; +use tokio::sync::{OwnedSemaphorePermit, Semaphore}; use tonic::transport::Channel; pub async fn poll_loop( @@ -172,43 +162,20 @@ async fn run_received_task>, + /// Function registry + pub function_registry: Arc, - /// Aggregate functions registered in the Executor - pub aggregate_functions: HashMap>, + /// Creates [RuntimeEnv] based on [SessionConfig] + pub runtime_producer: RuntimeProducer, - /// Window functions registered in the Executor - pub window_functions: HashMap>, - - /// Runtime environment for Executor - runtime: Arc, + /// Creates default [SessionConfig] + pub config_producer: ConfigProducer, /// Collector for runtime execution metrics pub metrics_collector: Arc, @@ -90,33 +88,47 @@ pub struct Executor { } impl Executor { - /// Create a new executor instance + /// Create a new executor instance with given [RuntimeEnv] + /// It will use default scalar, aggregate and window functions + pub fn new_basic( + metadata: ExecutorRegistration, + work_dir: &str, + runtime_producer: RuntimeProducer, + config_producer: ConfigProducer, + concurrent_tasks: usize, + ) -> Self { + Self::new( + metadata, + work_dir, + runtime_producer, + config_producer, + Arc::new(BallistaFunctionRegistry::default()), + 
Arc::new(LoggingMetricsCollector::default()), + concurrent_tasks, + None, + ) + } + + /// Create a new executor instance with given [RuntimeEnv], + /// [ScalarUDF], [AggregateUDF] and [WindowUDF] + + #[allow(clippy::too_many_arguments)] pub fn new( metadata: ExecutorRegistration, work_dir: &str, - runtime: Arc, + runtime_producer: RuntimeProducer, + config_producer: ConfigProducer, + function_registry: Arc, metrics_collector: Arc, concurrent_tasks: usize, execution_engine: Option>, ) -> Self { - let scalar_functions = all_default_functions() - .into_iter() - .map(|f| (f.name().to_string(), f)) - .collect(); - - let aggregate_functions = all_default_aggregate_functions() - .into_iter() - .map(|f| (f.name().to_string(), f)) - .collect(); - Self { metadata, work_dir: work_dir.to_owned(), - scalar_functions, - aggregate_functions, - // TODO: set to default window functions when they are moved to udwf - window_functions: HashMap::new(), - runtime, + function_registry, + runtime_producer, + config_producer, metrics_collector, concurrent_tasks, abort_handles: Default::default(), @@ -127,8 +139,15 @@ impl Executor { } impl Executor { - pub fn get_runtime(&self) -> Arc { - self.runtime.clone() + pub fn produce_runtime( + &self, + config: &SessionConfig, + ) -> datafusion::error::Result> { + (self.runtime_producer)(config) + } + + pub fn produce_config(&self) -> SessionConfig { + (self.config_producer)() } /// Execute one partition of a query stage and persist the result to disk in IPC format. On @@ -197,12 +216,13 @@ impl Executor { mod test { use crate::execution_engine::DefaultQueryStageExec; use crate::executor::Executor; - use crate::metrics::LoggingMetricsCollector; use arrow::datatypes::{Schema, SchemaRef}; use arrow::record_batch::RecordBatch; + use ballista_core::config::BallistaConfig; use ballista_core::execution_plans::ShuffleWriterExec; use ballista_core::serde::protobuf::ExecutorRegistration; use ballista_core::serde::scheduler::PartitionId; + use ballista_core::RuntimeProducer; use datafusion::error::{DataFusionError, Result}; use datafusion::execution::context::TaskContext; @@ -210,7 +230,7 @@ mod test { DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, PlanProperties, RecordBatchStream, SendableRecordBatchStream, Statistics, }; - use datafusion::prelude::SessionContext; + use datafusion::prelude::{SessionConfig, SessionContext}; use futures::Stream; use std::any::Any; use std::pin::Pin; @@ -341,16 +361,20 @@ mod test { specification: None, optional_host: None, }; - + let config_producer = Arc::new(|| { + SessionConfig::new().with_option_extension(BallistaConfig::new().unwrap()) + }); let ctx = SessionContext::new(); + let runtime_env = ctx.runtime_env().clone(); + let runtime_producer: RuntimeProducer = + Arc::new(move |_| Ok(runtime_env.clone())); - let executor = Executor::new( + let executor = Executor::new_basic( executor_registration, &work_dir, - ctx.runtime_env(), - Arc::new(LoggingMetricsCollector {}), + runtime_producer, + config_producer, 2, - None, ); let (sender, receiver) = tokio::sync::oneshot::channel(); diff --git a/ballista/executor/src/executor_process.rs b/ballista/executor/src/executor_process.rs index c19f0656a..a15bfadbd 100644 --- a/ballista/executor/src/executor_process.rs +++ b/ballista/executor/src/executor_process.rs @@ -25,6 +25,10 @@ use std::{env, io}; use anyhow::{Context, Result}; use arrow_flight::flight_service_server::FlightServiceServer; +use ballista_core::serde::scheduler::BallistaFunctionRegistry; +use 
datafusion::prelude::SessionConfig; +use datafusion_proto::logical_plan::LogicalExtensionCodec; +use datafusion_proto::physical_plan::PhysicalExtensionCodec; use futures::stream::FuturesUnordered; use futures::StreamExt; use log::{error, info, warn}; @@ -38,11 +42,11 @@ use tracing_subscriber::EnvFilter; use uuid::Uuid; use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv}; -use datafusion_proto::protobuf::{LogicalPlanNode, PhysicalPlanNode}; -use ballista_core::config::{DataCachePolicy, LogRotationPolicy, TaskSchedulingPolicy}; +use ballista_core::config::{ + BallistaConfig, DataCachePolicy, LogRotationPolicy, TaskSchedulingPolicy, +}; use ballista_core::error::BallistaError; -use ballista_core::object_store_registry::with_object_store_registry; use ballista_core::serde::protobuf::executor_resource::Resource; use ballista_core::serde::protobuf::executor_status::Status; use ballista_core::serde::protobuf::{ @@ -50,11 +54,13 @@ use ballista_core::serde::protobuf::{ ExecutorRegistration, ExecutorResource, ExecutorSpecification, ExecutorStatus, ExecutorStoppedParams, HeartBeatParams, }; -use ballista_core::serde::BallistaCodec; +use ballista_core::serde::{ + BallistaCodec, BallistaLogicalExtensionCodec, BallistaPhysicalExtensionCodec, +}; use ballista_core::utils::{ create_grpc_client_connection, create_grpc_server, get_time_before, }; -use ballista_core::BALLISTA_VERSION; +use ballista_core::{ConfigProducer, RuntimeProducer, BALLISTA_VERSION}; use crate::execution_engine::ExecutionEngine; use crate::executor::{Executor, TasksDrainedFuture}; @@ -96,6 +102,16 @@ pub struct ExecutorProcessConfig { /// Optional execution engine to use to execute physical plans, will default to /// DataFusion if none is provided. pub execution_engine: Option>, + /// Overrides default function registry + pub function_registry: Option>, + /// [RuntimeProducer] override option + pub runtime_producer: Option, + /// [ConfigProducer] override option + pub config_producer: Option, + /// [PhysicalExtensionCodec] override option + pub logical_codec: Option>, + /// [PhysicalExtensionCodec] override option + pub physical_codec: Option>, } pub async fn start_executor_process(opt: Arc) -> Result<()> { @@ -181,20 +197,40 @@ pub async fn start_executor_process(opt: Arc) -> Result<( }), }; - let config = RuntimeConfig::new().with_temp_file_path(work_dir.clone()); - let runtime = { - let config = with_object_store_registry(config.clone()); - Arc::new(RuntimeEnv::new(config).map_err(|_| { - BallistaError::Internal("Failed to init Executor RuntimeEnv".to_owned()) - })?) 
- }; - + // put them to session config let metrics_collector = Arc::new(LoggingMetricsCollector::default()); + let config_producer = opt.config_producer.clone().unwrap_or_else(|| { + Arc::new(|| { + SessionConfig::new().with_option_extension(BallistaConfig::new().unwrap()) + }) + }); + let wd = work_dir.clone(); + let runtime_producer: RuntimeProducer = Arc::new(move |_| { + let config = RuntimeConfig::new().with_temp_file_path(wd.clone()); + Ok(Arc::new(RuntimeEnv::new(config)?)) + }); + + let logical = opt + .logical_codec + .clone() + .unwrap_or_else(|| Arc::new(BallistaLogicalExtensionCodec::default())); + + let physical = opt + .physical_codec + .clone() + .unwrap_or_else(|| Arc::new(BallistaPhysicalExtensionCodec::default())); + + let default_codec: BallistaCodec< + datafusion_proto::protobuf::LogicalPlanNode, + datafusion_proto::protobuf::PhysicalPlanNode, + > = BallistaCodec::new(logical, physical); let executor = Arc::new(Executor::new( executor_meta, &work_dir, - runtime, + runtime_producer, + config_producer, + opt.function_registry.clone().unwrap_or_default(), metrics_collector, concurrent_tasks, opt.execution_engine.clone(), @@ -244,9 +280,6 @@ pub async fn start_executor_process(opt: Arc) -> Result<( .max_encoding_message_size(opt.grpc_max_encoding_message_size as usize) .max_decoding_message_size(opt.grpc_max_decoding_message_size as usize); - let default_codec: BallistaCodec = - BallistaCodec::default(); - let scheduler_policy = opt.task_scheduling_policy; let job_data_ttl_seconds = opt.job_data_ttl_seconds; diff --git a/ballista/executor/src/executor_server.rs b/ballista/executor/src/executor_server.rs index 6e3d5589b..cfbc2bd4c 100644 --- a/ballista/executor/src/executor_server.rs +++ b/ballista/executor/src/executor_server.rs @@ -46,9 +46,7 @@ use ballista_core::serde::scheduler::TaskDefinition; use ballista_core::serde::BallistaCodec; use ballista_core::utils::{create_grpc_client_connection, create_grpc_server}; use dashmap::DashMap; -use datafusion::config::ConfigOptions; use datafusion::execution::TaskContext; -use datafusion::prelude::SessionConfig; use datafusion_proto::{logical_plan::AsLogicalPlan, physical_plan::AsExecutionPlan}; use tokio::sync::mpsc::error::TryRecvError; use tokio::task::JoinHandle; @@ -342,22 +340,13 @@ impl ExecutorServer ExecutorGrpc scheduler_id: scheduler_id.clone(), task: get_task_definition( task, - self.executor.get_runtime(), - self.executor.scalar_functions.clone(), - self.executor.aggregate_functions.clone(), - self.executor.window_functions.clone(), + self.executor.runtime_producer.clone(), + self.executor.produce_config(), + self.executor.function_registry.scalar_functions.clone(), + self.executor.function_registry.aggregate_functions.clone(), + self.executor.function_registry.window_functions.clone(), self.codec.clone(), ) .map_err(|e| Status::invalid_argument(format!("{e}")))?, @@ -669,10 +659,11 @@ impl ExecutorGrpc for multi_task in multi_tasks { let multi_task: Vec = get_task_definition_vec( multi_task, - self.executor.get_runtime(), - self.executor.scalar_functions.clone(), - self.executor.aggregate_functions.clone(), - self.executor.window_functions.clone(), + self.executor.runtime_producer.clone(), + self.executor.produce_config(), + self.executor.function_registry.scalar_functions.clone(), + self.executor.function_registry.aggregate_functions.clone(), + self.executor.function_registry.window_functions.clone(), self.codec.clone(), ) .map_err(|e| Status::invalid_argument(format!("{e}")))?; diff --git 
a/ballista/executor/src/lib.rs b/ballista/executor/src/lib.rs
index beb9faac2..b7219225a 100644
--- a/ballista/executor/src/lib.rs
+++ b/ballista/executor/src/lib.rs
@@ -32,6 +32,7 @@ mod cpu_bound_executor;
 mod standalone;
 
 pub use standalone::new_standalone_executor;
+pub use standalone::new_standalone_executor_from_state;
 
 use log::info;
 
diff --git a/ballista/executor/src/standalone.rs b/ballista/executor/src/standalone.rs
index 38e277134..628de96f4 100644
--- a/ballista/executor/src/standalone.rs
+++ b/ballista/executor/src/standalone.rs
@@ -18,6 +18,8 @@
 use crate::metrics::LoggingMetricsCollector;
 use crate::{execution_loop, executor::Executor, flight_service::BallistaFlightService};
 use arrow_flight::flight_service_server::FlightServiceServer;
+use ballista_core::config::BallistaConfig;
+use ballista_core::utils::SessionConfigExt;
 use ballista_core::{
     error::Result,
     object_store_registry::with_object_store_registry,
@@ -28,7 +30,10 @@ use ballista_core::{
     utils::create_grpc_server,
     BALLISTA_VERSION,
 };
+use ballista_core::{ConfigProducer, RuntimeProducer};
 use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv};
+use datafusion::execution::SessionState;
+use datafusion::prelude::SessionConfig;
 use datafusion_proto::logical_plan::AsLogicalPlan;
 use datafusion_proto::physical_plan::AsExecutionPlan;
 use log::info;
@@ -38,14 +43,26 @@ use tokio::net::TcpListener;
 use tonic::transport::Channel;
 use uuid::Uuid;
 
-pub async fn new_standalone_executor<
+/// Creates a new standalone executor based on the
+/// provided session_state.
+///
+/// This provides a flexible way of configuring the
+/// underlying components.
+pub async fn new_standalone_executor_from_state<
     T: 'static + AsLogicalPlan,
     U: 'static + AsExecutionPlan,
 >(
     scheduler: SchedulerGrpcClient<Channel>,
     concurrent_tasks: usize,
-    codec: BallistaCodec<T, U>,
+    session_state: &SessionState,
 ) -> Result<()> {
+    let logical = session_state.config().ballista_logical_extension_codec();
+    let physical = session_state.config().ballista_physical_extension_codec();
+    let codec: BallistaCodec<
+        datafusion_proto::protobuf::LogicalPlanNode,
+        datafusion_proto::protobuf::PhysicalPlanNode,
+    > = BallistaCodec::new(logical, physical);
+
     // Let the OS assign a random, free port
     let listener = TcpListener::bind("localhost:0").await?;
     let addr = listener.local_addr()?;
@@ -74,14 +91,21 @@ pub async fn new_standalone_executor<
         .unwrap();
     info!("work_dir: {}", work_dir);
 
-    let config = with_object_store_registry(
-        RuntimeConfig::new().with_temp_file_path(work_dir.clone()),
-    );
+    let config = session_state
+        .config()
+        .clone()
+        .with_option_extension(BallistaConfig::new().unwrap());
+    let runtime = session_state.runtime_env().clone();
+
+    let config_producer: ConfigProducer = Arc::new(move || config.clone());
+    let runtime_producer: RuntimeProducer = Arc::new(move |_| Ok(runtime.clone()));
 
     let executor = Arc::new(Executor::new(
         executor_meta,
         &work_dir,
-        Arc::new(RuntimeEnv::new(config).unwrap()),
+        runtime_producer,
+        config_producer,
+        Arc::new(session_state.into()),
         Arc::new(LoggingMetricsCollector::default()),
         concurrent_tasks,
         None,
@@ -100,3 +124,74 @@
     tokio::spawn(execution_loop::poll_loop(scheduler, executor, codec));
     Ok(())
 }
+
+/// Creates a standalone executor with most values
+/// set to their defaults.
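+///
+/// Unlike [new_standalone_executor_from_state], this variant does not reuse a
+/// caller-provided [SessionState]: it builds its own [RuntimeEnv] (with
+/// Ballista's object store registry and a temporary work directory) and a
+/// default [BallistaConfig].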
+pub async fn new_standalone_executor< + T: 'static + AsLogicalPlan, + U: 'static + AsExecutionPlan, +>( + scheduler: SchedulerGrpcClient, + concurrent_tasks: usize, + codec: BallistaCodec, +) -> Result<()> { + // Let the OS assign a random, free port + let listener = TcpListener::bind("localhost:0").await?; + let addr = listener.local_addr()?; + info!( + "Ballista v{} Rust Executor listening on {:?}", + BALLISTA_VERSION, addr + ); + + let executor_meta = ExecutorRegistration { + id: Uuid::new_v4().to_string(), // assign this executor a unique ID + optional_host: Some(OptionalHost::Host("localhost".to_string())), + port: addr.port() as u32, + // TODO Make it configurable + grpc_port: 50020, + specification: Some( + ExecutorSpecification { + task_slots: concurrent_tasks as u32, + } + .into(), + ), + }; + let work_dir = TempDir::new()? + .into_path() + .into_os_string() + .into_string() + .unwrap(); + info!("work_dir: {}", work_dir); + + let config_producer = Arc::new(|| { + SessionConfig::new().with_option_extension(BallistaConfig::new().unwrap()) + }); + let wd = work_dir.clone(); + let runtime_producer: RuntimeProducer = Arc::new(move |_: &SessionConfig| { + let config = with_object_store_registry( + RuntimeConfig::new().with_temp_file_path(wd.clone()), + ); + Ok(Arc::new(RuntimeEnv::new(config)?)) + }); + + let executor = Arc::new(Executor::new_basic( + executor_meta, + &work_dir, + runtime_producer, + config_producer, + concurrent_tasks, + )); + + let service = BallistaFlightService::new(); + let server = FlightServiceServer::new(service); + tokio::spawn( + create_grpc_server() + .add_service(server) + .serve_with_incoming(tokio_stream::wrappers::TcpListenerStream::new( + listener, + )), + ); + + tokio::spawn(execution_loop::poll_loop(scheduler, executor, codec)); + Ok(()) +} diff --git a/ballista/scheduler/Cargo.toml b/ballista/scheduler/Cargo.toml index a1d597352..642e63d48 100644 --- a/ballista/scheduler/Cargo.toml +++ b/ballista/scheduler/Cargo.toml @@ -45,7 +45,7 @@ anyhow = "1" arrow-flight = { workspace = true } async-trait = { workspace = true } axum = "0.7.7" -ballista-core = { path = "../core", version = "0.12.0", features = ["s3"] } +ballista-core = { path = "../core", version = "0.12.0" } base64 = { version = "0.22" } clap = { workspace = true } configure_me = { workspace = true } diff --git a/ballista/scheduler/src/cluster/memory.rs b/ballista/scheduler/src/cluster/memory.rs index f2fe589a8..861b86578 100644 --- a/ballista/scheduler/src/cluster/memory.rs +++ b/ballista/scheduler/src/cluster/memory.rs @@ -401,7 +401,7 @@ impl JobState for InMemoryJobState { &self, config: &BallistaConfig, ) -> Result> { - let session = create_datafusion_context(config, self.session_builder); + let session = create_datafusion_context(config, self.session_builder.clone()); self.sessions.insert(session.session_id(), session.clone()); Ok(session) @@ -412,7 +412,7 @@ impl JobState for InMemoryJobState { session_id: &str, config: &BallistaConfig, ) -> Result> { - let session = create_datafusion_context(config, self.session_builder); + let session = create_datafusion_context(config, self.session_builder.clone()); self.sessions .insert(session_id.to_string(), session.clone()); @@ -486,6 +486,8 @@ impl JobState for InMemoryJobState { #[cfg(test)] mod test { + use std::sync::Arc; + use crate::cluster::memory::InMemoryJobState; use crate::cluster::test_util::{test_job_lifecycle, test_job_planning_failure}; use crate::test_utils::{ @@ -497,17 +499,17 @@ mod test { #[tokio::test] async fn 
test_in_memory_job_lifecycle() -> Result<()> { test_job_lifecycle( - InMemoryJobState::new("", default_session_builder), + InMemoryJobState::new("", Arc::new(default_session_builder)), test_aggregation_plan(4).await, ) .await?; test_job_lifecycle( - InMemoryJobState::new("", default_session_builder), + InMemoryJobState::new("", Arc::new(default_session_builder)), test_two_aggregations_plan(4).await, ) .await?; test_job_lifecycle( - InMemoryJobState::new("", default_session_builder), + InMemoryJobState::new("", Arc::new(default_session_builder)), test_join_plan(4).await, ) .await?; @@ -518,17 +520,17 @@ mod test { #[tokio::test] async fn test_in_memory_job_planning_failure() -> Result<()> { test_job_planning_failure( - InMemoryJobState::new("", default_session_builder), + InMemoryJobState::new("", Arc::new(default_session_builder)), test_aggregation_plan(4).await, ) .await?; test_job_planning_failure( - InMemoryJobState::new("", default_session_builder), + InMemoryJobState::new("", Arc::new(default_session_builder)), test_two_aggregations_plan(4).await, ) .await?; test_job_planning_failure( - InMemoryJobState::new("", default_session_builder), + InMemoryJobState::new("", Arc::new(default_session_builder)), test_join_plan(4).await, ) .await?; diff --git a/ballista/scheduler/src/cluster/mod.rs b/ballista/scheduler/src/cluster/mod.rs index 81432056a..450c8018c 100644 --- a/ballista/scheduler/src/cluster/mod.rs +++ b/ballista/scheduler/src/cluster/mod.rs @@ -109,7 +109,7 @@ impl BallistaCluster { match &config.cluster_storage { ClusterStorageConfig::Memory => Ok(BallistaCluster::new_memory( scheduler, - default_session_builder, + Arc::new(default_session_builder), )), } } diff --git a/ballista/scheduler/src/scheduler_server/grpc.rs b/ballista/scheduler/src/scheduler_server/grpc.rs index 653bda834..e475e438a 100644 --- a/ballista/scheduler/src/scheduler_server/grpc.rs +++ b/ballista/scheduler/src/scheduler_server/grpc.rs @@ -424,6 +424,7 @@ impl SchedulerGrpc } = query_params { let mut query_settings = HashMap::new(); + log::trace!("received query settings: {:?}", settings); for kv_pair in settings { query_settings.insert(kv_pair.key, kv_pair.value); } @@ -523,6 +524,7 @@ impl SchedulerGrpc .cloned() .unwrap_or_else(|| "None".to_string()); + log::trace!("setting job name: {}", job_name); self.submit_job(&job_id, &job_name, session_ctx, &plan) .await .map_err(|e| { diff --git a/ballista/scheduler/src/scheduler_server/mod.rs b/ballista/scheduler/src/scheduler_server/mod.rs index 3e2da13b1..7ec0e63e3 100644 --- a/ballista/scheduler/src/scheduler_server/mod.rs +++ b/ballista/scheduler/src/scheduler_server/mod.rs @@ -56,7 +56,7 @@ mod external_scaler; mod grpc; pub(crate) mod query_stage_scheduler; -pub(crate) type SessionBuilder = fn(SessionConfig) -> SessionState; +pub(crate) type SessionBuilder = Arc SessionState + Send + Sync>; #[derive(Clone)] pub struct SchedulerServer { diff --git a/ballista/scheduler/src/standalone.rs b/ballista/scheduler/src/standalone.rs index bb6d70064..5ff4d6111 100644 --- a/ballista/scheduler/src/standalone.rs +++ b/ballista/scheduler/src/standalone.rs @@ -20,11 +20,15 @@ use crate::config::SchedulerConfig; use crate::metrics::default_metrics_collector; use crate::scheduler_server::SchedulerServer; use ballista_core::serde::BallistaCodec; -use ballista_core::utils::{create_grpc_server, default_session_builder}; +use ballista_core::utils::{ + create_grpc_server, default_session_builder, SessionConfigExt, +}; use ballista_core::{ error::Result, 
serde::protobuf::scheduler_grpc_server::SchedulerGrpcServer, BALLISTA_VERSION, }; +use datafusion::execution::{SessionState, SessionStateBuilder}; +use datafusion::prelude::SessionConfig; use datafusion_proto::protobuf::LogicalPlanNode; use datafusion_proto::protobuf::PhysicalPlanNode; use log::info; @@ -33,15 +37,39 @@ use std::sync::Arc; use tokio::net::TcpListener; pub async fn new_standalone_scheduler() -> Result { - let metrics_collector = default_metrics_collector()?; + let codec = BallistaCodec::default(); + new_standalone_scheduler_with_builder(Arc::new(default_session_builder), codec).await +} + +pub async fn new_standalone_scheduler_from_state( + session_state: &SessionState, +) -> Result { + let logical = session_state.config().ballista_logical_extension_codec(); + let physical = session_state.config().ballista_physical_extension_codec(); + let codec = BallistaCodec::new(logical, physical); - let cluster = BallistaCluster::new_memory("localhost:50050", default_session_builder); + let session_state = session_state.clone(); + let session_builder = Arc::new(move |c: SessionConfig| { + SessionStateBuilder::new_from_existing(session_state.clone()) + .with_config(c) + .build() + }); + + new_standalone_scheduler_with_builder(session_builder, codec).await +} + +async fn new_standalone_scheduler_with_builder( + session_builder: crate::scheduler_server::SessionBuilder, + codec: BallistaCodec, +) -> Result { + let cluster = BallistaCluster::new_memory("localhost:50050", session_builder); + let metrics_collector = default_metrics_collector()?; let mut scheduler_server: SchedulerServer = SchedulerServer::new( "localhost:50050".to_owned(), cluster, - BallistaCodec::default(), + codec, Arc::new(SchedulerConfig::default()), metrics_collector, ); diff --git a/ballista/scheduler/src/test_utils.rs b/ballista/scheduler/src/test_utils.rs index 27bc0ec8b..f9eae3156 100644 --- a/ballista/scheduler/src/test_utils.rs +++ b/ballista/scheduler/src/test_utils.rs @@ -124,7 +124,7 @@ pub async fn await_condition>, F: Fn() -> Fut> } pub fn test_cluster_context() -> BallistaCluster { - BallistaCluster::new_memory(TEST_SCHEDULER_NAME, default_session_builder) + BallistaCluster::new_memory(TEST_SCHEDULER_NAME, Arc::new(default_session_builder)) } pub async fn datafusion_test_context(path: &str) -> Result {
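
Taken together, these changes let one customized SessionState drive the client, the scheduler and the executor. A minimal usage sketch (assuming the `standalone` feature and the re-exports shown in this diff; the query is illustrative):

    use ballista::extension::{SessionConfigExt, SessionContextExt};
    use datafusion::execution::SessionStateBuilder;
    use datafusion::prelude::{SessionConfig, SessionContext};

    #[tokio::main]
    async fn main() -> datafusion::error::Result<()> {
        // Ballista options (and, if needed, custom extension codecs or a
        // replacement query planner) are attached to the SessionConfig ...
        let config = SessionConfig::new_with_ballista().with_information_schema(true);

        // ... which travels inside the SessionState ...
        let state = SessionStateBuilder::new()
            .with_default_features()
            .with_config(config)
            .build();

        // ... and is honoured by the standalone scheduler and executor as well.
        let ctx: SessionContext = SessionContext::standalone_with_state(state).await?;
        ctx.sql("SHOW TABLES").await?.show().await?;
        Ok(())
    }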