From 84e27425c8d3bb78c23c24c6e21bb53a6361ed80 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Marko=20Milenkovi=C4=87?=
Date: Mon, 18 Nov 2024 10:24:23 +0000
Subject: [PATCH] Make it easier to create schedulers and executors

---
 Cargo.toml                                    |   1 +
 ballista/executor/Cargo.toml                  |   2 +-
 ballista/executor/src/bin/main.rs             |  58 +---
 ballista/executor/src/config.rs               |  71 +++++
 ballista/executor/src/executor_process.rs     |  22 +-
 ballista/executor/src/lib.rs                  |   1 +
 ballista/scheduler/Cargo.toml                 |   2 +-
 ballista/scheduler/scheduler_config_spec.toml |   4 +-
 ballista/scheduler/src/bin/main.rs            | 111 ++-----
 ballista/scheduler/src/cluster/memory.rs      |   2 +-
 ballista/scheduler/src/cluster/mod.rs         |  14 +-
 ballista/scheduler/src/config.rs              | 129 +++++++-
 ballista/scheduler/src/scheduler_process.rs   |  18 +-
 examples/Cargo.toml                           |  10 +-
 examples/examples/custom_client.rs            |  96 ++++++
 examples/examples/custom_executor.rs          |  48 +++
 examples/examples/custom_scheduler.rs         |  58 ++++
 examples/src/lib.rs                           |   1 +
 examples/src/object_store.rs                  | 294 ++++++++++++++++++
 19 files changed, 774 insertions(+), 168 deletions(-)
 create mode 100644 ballista/executor/src/config.rs
 create mode 100644 examples/examples/custom_client.rs
 create mode 100644 examples/examples/custom_executor.rs
 create mode 100644 examples/examples/custom_scheduler.rs
 create mode 100644 examples/src/object_store.rs

diff --git a/Cargo.toml b/Cargo.toml
index 4e88716dc4..0467d8ab67 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -21,6 +21,7 @@ members = ["ballista-cli", "ballista/client", "ballista/core", "ballista/executo
 resolver = "2"
 
 [workspace.dependencies]
+anyhow = "1"
 arrow = { version = "53", features = ["ipc_compression"] }
 arrow-flight = { version = "53", features = ["flight-sql-experimental"] }
 clap = { version = "3", features = ["derive", "cargo"] }
diff --git a/ballista/executor/Cargo.toml b/ballista/executor/Cargo.toml
index e1822e9c1b..a7c5c65cc9 100644
--- a/ballista/executor/Cargo.toml
+++ b/ballista/executor/Cargo.toml
@@ -37,7 +37,7 @@ path = "src/bin/main.rs"
 default = ["mimalloc"]
 
 [dependencies]
-anyhow = "1"
+anyhow = { workspace = true }
 arrow = { workspace = true }
 arrow-flight = { workspace = true }
 async-trait = { workspace = true }
diff --git a/ballista/executor/src/bin/main.rs b/ballista/executor/src/bin/main.rs
index 9f5ed12f15..5833930e01 100644
--- a/ballista/executor/src/bin/main.rs
+++ b/ballista/executor/src/bin/main.rs
@@ -18,24 +18,12 @@
 //! Ballista Rust executor binary.
 
 use anyhow::Result;
-use std::sync::Arc;
-
 use ballista_core::print_version;
+use ballista_executor::config::prelude::*;
 use ballista_executor::executor_process::{
     start_executor_process, ExecutorProcessConfig,
 };
-use config::prelude::*;
-
-#[allow(unused_imports)]
-#[macro_use]
-extern crate configure_me;
-
-#[allow(clippy::all, warnings)]
-mod config {
-    // Ideally we would use the include_config macro from configure_me, but then we cannot use
-    // #[allow(clippy::all)] to silence clippy warnings from the generated code
-    include!(concat!(env!("OUT_DIR"), "/executor_configure_me_config.rs"));
-}
+use std::sync::Arc;
 
 #[cfg(feature = "mimalloc")]
 #[global_allocator]
@@ -53,46 +41,6 @@ async fn main() -> Result<()> {
         std::process::exit(0);
     }
 
-    let log_file_name_prefix = format!(
-        "executor_{}_{}",
-        opt.external_host
-            .clone()
-            .unwrap_or_else(|| "localhost".to_string()),
-        opt.bind_port
-    );
-
-    let config = ExecutorProcessConfig {
-        special_mod_log_level: opt.log_level_setting,
-        external_host: opt.external_host,
-        bind_host: opt.bind_host,
-        port: opt.bind_port,
-        grpc_port: opt.bind_grpc_port,
-        scheduler_host: opt.scheduler_host,
-        scheduler_port: opt.scheduler_port,
-        scheduler_connect_timeout_seconds: opt.scheduler_connect_timeout_seconds,
-        concurrent_tasks: opt.concurrent_tasks,
-        task_scheduling_policy: opt.task_scheduling_policy,
-        work_dir: opt.work_dir,
-        log_dir: opt.log_dir,
-        log_file_name_prefix,
-        log_rotation_policy: opt.log_rotation_policy,
-        print_thread_info: opt.print_thread_info,
-        job_data_ttl_seconds: opt.job_data_ttl_seconds,
-        job_data_clean_up_interval_seconds: opt.job_data_clean_up_interval_seconds,
-        grpc_max_decoding_message_size: opt.grpc_server_max_decoding_message_size,
-        grpc_max_encoding_message_size: opt.grpc_server_max_encoding_message_size,
-        executor_heartbeat_interval_seconds: opt.executor_heartbeat_interval_seconds,
-        data_cache_policy: opt.data_cache_policy,
-        cache_dir: opt.cache_dir,
-        cache_capacity: opt.cache_capacity,
-        cache_io_concurrency: opt.cache_io_concurrency,
-        execution_engine: None,
-        function_registry: None,
-        config_producer: None,
-        runtime_producer: None,
-        logical_codec: None,
-        physical_codec: None,
-    };
-
+    let config: ExecutorProcessConfig = opt.try_into()?;
     start_executor_process(Arc::new(config)).await
 }
diff --git a/ballista/executor/src/config.rs b/ballista/executor/src/config.rs
new file mode 100644
index 0000000000..78db477f9b
--- /dev/null
+++ b/ballista/executor/src/config.rs
@@ -0,0 +1,71 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use ballista_core::error::BallistaError;
+
+use crate::executor_process::ExecutorProcessConfig;
+
+// Ideally we would use the include_config macro from configure_me, but then we cannot use
+// #[allow(clippy::all)] to silence clippy warnings from the generated code
+include!(concat!(env!("OUT_DIR"), "/executor_configure_me_config.rs"));
+
+impl TryFrom<Config> for ExecutorProcessConfig {
+    type Error = BallistaError;
+
+    fn try_from(opt: Config) -> Result<Self, Self::Error> {
+        let log_file_name_prefix = format!(
+            "executor_{}_{}",
+            opt.external_host
+                .clone()
+                .unwrap_or_else(|| "localhost".to_string()),
+            opt.bind_port
+        );
+
+        Ok(ExecutorProcessConfig {
+            special_mod_log_level: opt.log_level_setting,
+            external_host: opt.external_host,
+            bind_host: opt.bind_host,
+            port: opt.bind_port,
+            grpc_port: opt.bind_grpc_port,
+            scheduler_host: opt.scheduler_host,
+            scheduler_port: opt.scheduler_port,
+            scheduler_connect_timeout_seconds: opt.scheduler_connect_timeout_seconds,
+            concurrent_tasks: opt.concurrent_tasks,
+            task_scheduling_policy: opt.task_scheduling_policy,
+            work_dir: opt.work_dir,
+            log_dir: opt.log_dir,
+            log_file_name_prefix,
+            log_rotation_policy: opt.log_rotation_policy,
+            print_thread_info: opt.print_thread_info,
+            job_data_ttl_seconds: opt.job_data_ttl_seconds,
+            job_data_clean_up_interval_seconds: opt.job_data_clean_up_interval_seconds,
+            grpc_max_decoding_message_size: opt.grpc_server_max_decoding_message_size,
+            grpc_max_encoding_message_size: opt.grpc_server_max_encoding_message_size,
+            executor_heartbeat_interval_seconds: opt.executor_heartbeat_interval_seconds,
+            data_cache_policy: opt.data_cache_policy,
+            cache_dir: opt.cache_dir,
+            cache_capacity: opt.cache_capacity,
+            cache_io_concurrency: opt.cache_io_concurrency,
+            override_execution_engine: None,
+            override_function_registry: None,
+            override_config_producer: None,
+            override_runtime_producer: None,
+            override_logical_codec: None,
+            override_physical_codec: None,
+        })
+    }
+}
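The conversion above fills every override_* slot with None, so embedders can parse the stock CLI options first and then swap in their own pieces. A minimal sketch (the Ballista default codecs stand in here for user-defined codecs; any types implementing the DataFusion extension codec traits would do):

use ballista_core::serde::{BallistaLogicalExtensionCodec, BallistaPhysicalExtensionCodec};
use ballista_executor::executor_process::ExecutorProcessConfig;
use std::sync::Arc;

// Replace the plan codecs before handing the config to start_executor_process.
fn with_custom_codecs(mut config: ExecutorProcessConfig) -> ExecutorProcessConfig {
    config.override_logical_codec =
        Some(Arc::new(BallistaLogicalExtensionCodec::default()));
    config.override_physical_codec =
        Some(Arc::new(BallistaPhysicalExtensionCodec::default()));
    config
}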
diff --git a/ballista/executor/src/executor_process.rs b/ballista/executor/src/executor_process.rs
index 9a6187bdaa..41e4c029e1 100644
--- a/ballista/executor/src/executor_process.rs
+++ b/ballista/executor/src/executor_process.rs
@@ -98,17 +98,17 @@ pub struct ExecutorProcessConfig {
     pub executor_heartbeat_interval_seconds: u64,
     /// Optional execution engine to use to execute physical plans, will default to
     /// DataFusion if none is provided.
-    pub execution_engine: Option<Arc<dyn ExecutionEngine>>,
+    pub override_execution_engine: Option<Arc<dyn ExecutionEngine>>,
     /// Overrides default function registry
-    pub function_registry: Option<Arc<BallistaFunctionRegistry>>,
+    pub override_function_registry: Option<Arc<BallistaFunctionRegistry>>,
     /// [RuntimeProducer] override option
-    pub runtime_producer: Option<RuntimeProducer>,
+    pub override_runtime_producer: Option<RuntimeProducer>,
     /// [ConfigProducer] override option
-    pub config_producer: Option<ConfigProducer>,
+    pub override_config_producer: Option<ConfigProducer>,
     /// [PhysicalExtensionCodec] override option
-    pub logical_codec: Option<Arc<dyn LogicalExtensionCodec>>,
+    pub override_logical_codec: Option<Arc<dyn LogicalExtensionCodec>>,
     /// [PhysicalExtensionCodec] override option
-    pub physical_codec: Option<Arc<dyn PhysicalExtensionCodec>>,
+    pub override_physical_codec: Option<Arc<dyn PhysicalExtensionCodec>>,
 }
 
 pub async fn start_executor_process(opt: Arc<ExecutorProcessConfig>) -> Result<()> {
@@ -194,7 +194,7 @@ pub async fn start_executor_process(opt: Arc<ExecutorProcessConfig>) -> Result<(
     // put them to session config
     let metrics_collector = Arc::new(LoggingMetricsCollector::default());
     let config_producer = opt
-        .config_producer
+        .override_config_producer
        .clone()
         .unwrap_or_else(|| Arc::new(default_config_producer));
 
@@ -205,12 +205,12 @@ pub async fn start_executor_process(opt: Arc<ExecutorProcessConfig>) -> Result<(
     });
 
     let logical = opt
-        .logical_codec
+        .override_logical_codec
         .clone()
         .unwrap_or_else(|| Arc::new(BallistaLogicalExtensionCodec::default()));
 
     let physical = opt
-        .physical_codec
+        .override_physical_codec
         .clone()
         .unwrap_or_else(|| Arc::new(BallistaPhysicalExtensionCodec::default()));
 
@@ -224,10 +224,10 @@ pub async fn start_executor_process(opt: Arc<ExecutorProcessConfig>) -> Result<(
         &work_dir,
         runtime_producer,
         config_producer,
-        opt.function_registry.clone().unwrap_or_default(),
+        opt.override_function_registry.clone().unwrap_or_default(),
         metrics_collector,
         concurrent_tasks,
-        opt.execution_engine.clone(),
+        opt.override_execution_engine.clone(),
     ));
 
     let connect_timeout = opt.scheduler_connect_timeout_seconds as u64;
diff --git a/ballista/executor/src/lib.rs b/ballista/executor/src/lib.rs
index bc9d23e87d..f0284cbdb3 100644
--- a/ballista/executor/src/lib.rs
+++ b/ballista/executor/src/lib.rs
@@ -18,6 +18,7 @@
 #![doc = include_str!("../README.md")]
 
 pub mod collect;
+pub mod config;
 pub mod execution_engine;
 pub mod execution_loop;
 pub mod executor;
diff --git a/ballista/scheduler/Cargo.toml b/ballista/scheduler/Cargo.toml
index 642e63d480..ad3e09636f 100644
--- a/ballista/scheduler/Cargo.toml
+++ b/ballista/scheduler/Cargo.toml
@@ -41,7 +41,7 @@ prometheus-metrics = ["prometheus", "once_cell"]
 rest-api = []
 
 [dependencies]
-anyhow = "1"
+anyhow = { workspace = true }
 arrow-flight = { workspace = true }
 async-trait = { workspace = true }
 axum = "0.7.7"
diff --git a/ballista/scheduler/scheduler_config_spec.toml b/ballista/scheduler/scheduler_config_spec.toml
index 804987d9af..20bceb5f2f 100644
--- a/ballista/scheduler/scheduler_config_spec.toml
+++ b/ballista/scheduler/scheduler_config_spec.toml
@@ -82,9 +82,9 @@ doc = "Delayed interval for cleaning up finished job state. Default: 3600"
 
 [[param]]
 name = "task_distribution"
-type = "ballista_scheduler::config::TaskDistribution"
+type = "crate::config::TaskDistribution"
 doc = "The policy of distributing tasks to available executor slots, possible values: bias, round-robin, consistent-hash. Default: bias"
-default = "ballista_scheduler::config::TaskDistribution::Bias"
+default = "crate::config::TaskDistribution::Bias"
 
 [[param]]
 name = "consistent_hash_num_replicas"
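The same policy switch is available without the CLI, since TaskDistributionPolicy can be set on a SchedulerConfig directly. A minimal sketch (the replica and tolerance values are arbitrary illustrations, not defaults taken from the spec):

use ballista_scheduler::config::{SchedulerConfig, TaskDistributionPolicy};

// Consistent-hash task distribution configured in code rather than via
// the task_distribution command-line/TOML parameter.
fn consistent_hash_config() -> SchedulerConfig {
    SchedulerConfig {
        task_distribution: TaskDistributionPolicy::ConsistentHash {
            num_replicas: 31,
            tolerance: 0,
        },
        ..Default::default()
    }
}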
diff --git a/ballista/scheduler/src/bin/main.rs b/ballista/scheduler/src/bin/main.rs
index 7d8b4b1b09..f6a0632840 100644
--- a/ballista/scheduler/src/bin/main.rs
+++ b/ballista/scheduler/src/bin/main.rs
@@ -17,35 +17,16 @@
 
 //! Ballista Rust scheduler binary.
 
-use std::sync::Arc;
-use std::{env, io};
-
 use anyhow::Result;
-
-use crate::config::{Config, ResultExt};
 use ballista_core::config::LogRotationPolicy;
 use ballista_core::print_version;
 use ballista_scheduler::cluster::BallistaCluster;
-use ballista_scheduler::config::{
-    ClusterStorageConfig, SchedulerConfig, TaskDistribution, TaskDistributionPolicy,
-};
+use ballista_scheduler::config::{Config, ResultExt};
 use ballista_scheduler::scheduler_process::start_server;
+use std::sync::Arc;
+use std::{env, io};
 use tracing_subscriber::EnvFilter;
 
-#[allow(unused_imports)]
-#[macro_use]
-extern crate configure_me;
-
-#[allow(clippy::all, warnings)]
-mod config {
-    // Ideally we would use the include_config macro from configure_me, but then we cannot use
-    // #[allow(clippy::all)] to silence clippy warnings from the generated code
-    include!(concat!(
-        env!("OUT_DIR"),
-        "/scheduler_configure_me_config.rs"
-    ));
-}
-
 fn main() -> Result<()> {
     let runtime = tokio::runtime::Builder::new_multi_thread()
         .enable_io()
@@ -67,19 +48,23 @@ async fn inner() -> Result<()> {
         std::process::exit(0);
     }
 
-    let special_mod_log_level = opt.log_level_setting;
-    let log_dir = opt.log_dir;
-    let print_thread_info = opt.print_thread_info;
+    let rust_log = env::var(EnvFilter::DEFAULT_ENV);
+    let log_filter = EnvFilter::new(rust_log.unwrap_or(opt.log_level_setting.clone()));
 
-    let log_file_name_prefix = format!(
-        "scheduler_{}_{}_{}",
-        opt.namespace, opt.external_host, opt.bind_port
-    );
+    let tracing = tracing_subscriber::fmt()
+        .with_ansi(false)
+        .with_thread_names(opt.print_thread_info)
+        .with_thread_ids(opt.print_thread_info)
+        .with_writer(io::stdout)
+        .with_env_filter(log_filter);
 
-    let rust_log = env::var(EnvFilter::DEFAULT_ENV);
-    let log_filter = EnvFilter::new(rust_log.unwrap_or(special_mod_log_level));
     // File layer
-    if let Some(log_dir) = log_dir {
+    if let Some(log_dir) = &opt.log_dir {
+        let log_file_name_prefix = format!(
+            "scheduler_{}_{}_{}",
+            opt.namespace, opt.external_host, opt.bind_port
+        );
+
         let log_file = match opt.log_rotation_policy {
             LogRotationPolicy::Minutely => {
                 tracing_appender::rolling::minutely(log_dir, &log_file_name_prefix)
@@ -94,68 +79,16 @@ async fn inner() -> Result<()> {
                 tracing_appender::rolling::never(log_dir, &log_file_name_prefix)
             }
         };
-        tracing_subscriber::fmt()
-            .with_ansi(false)
-            .with_thread_names(print_thread_info)
-            .with_thread_ids(print_thread_info)
-            .with_writer(log_file)
-            .with_env_filter(log_filter)
-            .init();
+
+        tracing.with_writer(log_file).init();
     } else {
-        // Console layer
-        tracing_subscriber::fmt()
-            .with_ansi(false)
-            .with_thread_names(print_thread_info)
-            .with_thread_ids(print_thread_info)
-            .with_writer(io::stdout)
-            .with_env_filter(log_filter)
-            .init();
+        tracing.init();
     }
-
     let addr = format!("{}:{}", opt.bind_host, opt.bind_port);
     let addr = addr.parse()?;
-
-    let cluster_storage_config = ClusterStorageConfig::Memory;
-
-    let task_distribution = match opt.task_distribution {
-        TaskDistribution::Bias => TaskDistributionPolicy::Bias,
-        TaskDistribution::RoundRobin => TaskDistributionPolicy::RoundRobin,
-        TaskDistribution::ConsistentHash => {
-            let num_replicas = opt.consistent_hash_num_replicas as usize;
-            let tolerance = opt.consistent_hash_tolerance as usize;
-            TaskDistributionPolicy::ConsistentHash {
-                num_replicas,
-                tolerance,
-            }
-        }
-    };
-
-    let config = SchedulerConfig {
-        namespace: opt.namespace,
-        external_host: opt.external_host,
-        bind_port: opt.bind_port,
-        scheduling_policy: opt.scheduler_policy,
-        event_loop_buffer_size: opt.event_loop_buffer_size,
-        task_distribution,
-        finished_job_data_clean_up_interval_seconds: opt
-            .finished_job_data_clean_up_interval_seconds,
-        finished_job_state_clean_up_interval_seconds: opt
-            .finished_job_state_clean_up_interval_seconds,
-        advertise_flight_sql_endpoint: opt.advertise_flight_sql_endpoint,
-        cluster_storage: cluster_storage_config,
-        job_resubmit_interval_ms: (opt.job_resubmit_interval_ms > 0)
-            .then_some(opt.job_resubmit_interval_ms),
-        executor_termination_grace_period: opt.executor_termination_grace_period,
-        scheduler_event_expected_processing_duration: opt
-            .scheduler_event_expected_processing_duration,
-        grpc_server_max_decoding_message_size: opt.grpc_server_max_decoding_message_size,
-        grpc_server_max_encoding_message_size: opt.grpc_server_max_encoding_message_size,
-        executor_timeout_seconds: opt.executor_timeout_seconds,
-        expire_dead_executor_interval_seconds: opt.expire_dead_executor_interval_seconds,
-    };
-
+    let config = opt.try_into()?;
     let cluster = BallistaCluster::new_from_config(&config).await?;
-
     start_server(cluster, addr, Arc::new(config)).await?;
+
     Ok(())
 }
diff --git a/ballista/scheduler/src/cluster/memory.rs b/ballista/scheduler/src/cluster/memory.rs
index 6e32510a0a..6df0440357 100644
--- a/ballista/scheduler/src/cluster/memory.rs
+++ b/ballista/scheduler/src/cluster/memory.rs
@@ -290,7 +290,7 @@ pub struct InMemoryJobState {
     session_builder: SessionBuilder,
     /// Sender of job events
     job_event_sender: ClusterEventSender<JobStateEvent>,
-
+    /// Config producer
     config_producer: ConfigProducer,
 }
 
diff --git a/ballista/scheduler/src/cluster/mod.rs b/ballista/scheduler/src/cluster/mod.rs
index 2869c8876e..94f86969e2 100644
--- a/ballista/scheduler/src/cluster/mod.rs
+++ b/ballista/scheduler/src/cluster/mod.rs
@@ -111,11 +111,21 @@ impl BallistaCluster {
     pub async fn new_from_config(config: &SchedulerConfig) -> Result<Self> {
         let scheduler = config.scheduler_name();
 
+        let session_builder = config
+            .override_session_builder
+            .clone()
+            .unwrap_or_else(|| Arc::new(default_session_builder));
+
+        let config_producer = config
+            .override_config_producer
+            .clone()
+            .unwrap_or_else(|| Arc::new(default_config_producer));
+
         match &config.cluster_storage {
             ClusterStorageConfig::Memory => Ok(BallistaCluster::new_memory(
                 scheduler,
-                Arc::new(default_session_builder),
-                Arc::new(default_config_producer),
+                session_builder,
+                config_producer,
             )),
         }
     }
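Because new_from_config now honors the overrides, a custom session builder reaches every session the scheduler creates. A minimal sketch (it assumes, as the custom_scheduler example later in this patch suggests, that the SessionBuilder alias takes a SessionConfig and returns a SessionState):

use ballista_scheduler::cluster::BallistaCluster;
use ballista_scheduler::config::SchedulerConfig;
use datafusion::execution::SessionStateBuilder;
use datafusion::prelude::SessionConfig;
use std::sync::Arc;

// Build cluster state whose sessions come from a user-supplied builder.
async fn cluster_with_custom_sessions() -> anyhow::Result<BallistaCluster> {
    let mut config = SchedulerConfig::default();
    config.override_session_builder = Some(Arc::new(|session_config: SessionConfig| {
        SessionStateBuilder::new()
            .with_config(session_config)
            .with_default_features()
            .build()
    }));
    Ok(BallistaCluster::new_from_config(&config).await?)
}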
diff --git a/ballista/scheduler/src/config.rs b/ballista/scheduler/src/config.rs
index ce542e5194..7bb85bd48f 100644
--- a/ballista/scheduler/src/config.rs
+++ b/ballista/scheduler/src/config.rs
@@ -18,12 +18,20 @@
 
 //! Ballista scheduler specific configuration
 
-use ballista_core::config::TaskSchedulingPolicy;
+use crate::SessionBuilder;
+use ballista_core::{config::TaskSchedulingPolicy, error::BallistaError, ConfigProducer};
 use clap::ValueEnum;
-use std::fmt;
+use datafusion_proto::logical_plan::LogicalExtensionCodec;
+use datafusion_proto::physical_plan::PhysicalExtensionCodec;
+use std::{fmt, sync::Arc};
+
+include!(concat!(
+    env!("OUT_DIR"),
+    "/scheduler_configure_me_config.rs"
+));
 
 /// Configurations for the ballista scheduler of scheduling jobs and tasks
-#[derive(Debug, Clone)]
+#[derive(Clone)]
 pub struct SchedulerConfig {
     /// Namespace of this scheduler. Schedulers using the same cluster storage and namespace
     /// will share global cluster state.
@@ -62,6 +70,65 @@ pub struct SchedulerConfig {
     pub executor_timeout_seconds: u64,
     /// The interval to check expired or dead executors
     pub expire_dead_executor_interval_seconds: u64,
+
+    /// [ConfigProducer] override option
+    pub override_config_producer: Option<ConfigProducer>,
+    /// [SessionBuilder] override option
+    pub override_session_builder: Option<SessionBuilder>,
+    /// [LogicalExtensionCodec] override option
+    pub override_logical_codec: Option<Arc<dyn LogicalExtensionCodec>>,
+    /// [PhysicalExtensionCodec] override option
+    pub override_physical_codec: Option<Arc<dyn PhysicalExtensionCodec>>,
+}
+
+impl std::fmt::Debug for SchedulerConfig {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        f.debug_struct("SchedulerConfig")
+            .field("namespace", &self.namespace)
+            .field("external_host", &self.external_host)
+            .field("bind_port", &self.bind_port)
+            .field("scheduling_policy", &self.scheduling_policy)
+            .field("event_loop_buffer_size", &self.event_loop_buffer_size)
+            .field("task_distribution", &self.task_distribution)
+            .field(
+                "finished_job_data_clean_up_interval_seconds",
+                &self.finished_job_data_clean_up_interval_seconds,
+            )
+            .field(
+                "finished_job_state_clean_up_interval_seconds",
+                &self.finished_job_state_clean_up_interval_seconds,
+            )
+            .field(
+                "advertise_flight_sql_endpoint",
+                &self.advertise_flight_sql_endpoint,
+            )
+            .field("job_resubmit_interval_ms", &self.job_resubmit_interval_ms)
+            .field("cluster_storage", &self.cluster_storage)
+            .field(
+                "executor_termination_grace_period",
+                &self.executor_termination_grace_period,
+            )
+            .field(
+                "scheduler_event_expected_processing_duration",
+                &self.scheduler_event_expected_processing_duration,
+            )
+            .field(
+                "grpc_server_max_decoding_message_size",
+                &self.grpc_server_max_decoding_message_size,
+            )
+            .field(
+                "grpc_server_max_encoding_message_size",
+                &self.grpc_server_max_encoding_message_size,
+            )
+            .field("executor_timeout_seconds", &self.executor_timeout_seconds)
+            .field(
+                "expire_dead_executor_interval_seconds",
+                &self.expire_dead_executor_interval_seconds,
+            )
+            .field("override_logical_codec", &self.override_logical_codec)
+            .field("override_physical_codec", &self.override_physical_codec)
+            .finish()
+    }
 }
 
 impl Default for SchedulerConfig {
@@ -84,6 +151,10 @@
             grpc_server_max_encoding_message_size: 16777216,
             executor_timeout_seconds: 180,
             expire_dead_executor_interval_seconds: 15,
+            override_config_producer: None,
+            override_session_builder: None,
+            override_logical_codec: None,
+            override_physical_codec: None,
         }
     }
 }
@@ -231,3 +302,55 @@ pub enum TaskDistributionPolicy {
         tolerance: usize,
     },
 }
+
+impl TryFrom<Config> for SchedulerConfig {
+    type Error = BallistaError;
+
+    fn try_from(opt: Config) -> Result<Self, Self::Error> {
+        let task_distribution = match opt.task_distribution {
+            TaskDistribution::Bias => TaskDistributionPolicy::Bias,
+            TaskDistribution::RoundRobin => TaskDistributionPolicy::RoundRobin,
+            TaskDistribution::ConsistentHash => {
+                let num_replicas = opt.consistent_hash_num_replicas as usize;
+                let tolerance = opt.consistent_hash_tolerance as usize;
+                TaskDistributionPolicy::ConsistentHash {
+                    num_replicas,
+                    tolerance,
+                }
+            }
+        };
+
+        let config = SchedulerConfig {
+            namespace: opt.namespace,
+            external_host: opt.external_host,
+            bind_port: opt.bind_port,
+            scheduling_policy: opt.scheduler_policy,
+            event_loop_buffer_size: opt.event_loop_buffer_size,
+            task_distribution,
+            finished_job_data_clean_up_interval_seconds: opt
+                .finished_job_data_clean_up_interval_seconds,
+            finished_job_state_clean_up_interval_seconds: opt
+                .finished_job_state_clean_up_interval_seconds,
+            advertise_flight_sql_endpoint: opt.advertise_flight_sql_endpoint,
+            cluster_storage: ClusterStorageConfig::Memory,
+            job_resubmit_interval_ms: (opt.job_resubmit_interval_ms > 0)
+                .then_some(opt.job_resubmit_interval_ms),
+            executor_termination_grace_period: opt.executor_termination_grace_period,
+            scheduler_event_expected_processing_duration: opt
+                .scheduler_event_expected_processing_duration,
+            grpc_server_max_decoding_message_size: opt
+                .grpc_server_max_decoding_message_size,
+            grpc_server_max_encoding_message_size: opt
+                .grpc_server_max_encoding_message_size,
+            executor_timeout_seconds: opt.executor_timeout_seconds,
+            expire_dead_executor_interval_seconds: opt
+                .expire_dead_executor_interval_seconds,
+            override_config_producer: None,
+            override_logical_codec: None,
+            override_physical_codec: None,
+            override_session_builder: None,
+        };
+
+        Ok(config)
+    }
+}
diff --git a/ballista/scheduler/src/scheduler_process.rs b/ballista/scheduler/src/scheduler_process.rs
index 4b97060797..393b03b624 100644
--- a/ballista/scheduler/src/scheduler_process.rs
+++ b/ballista/scheduler/src/scheduler_process.rs
@@ -19,7 +19,9 @@ use anyhow::{Error, Result};
 #[cfg(feature = "flight-sql")]
 use arrow_flight::flight_service_server::FlightServiceServer;
 use ballista_core::serde::protobuf::scheduler_grpc_server::SchedulerGrpcServer;
-use ballista_core::serde::BallistaCodec;
+use ballista_core::serde::{
+    BallistaCodec, BallistaLogicalExtensionCodec, BallistaPhysicalExtensionCodec,
+};
 use ballista_core::utils::create_grpc_server;
 use ballista_core::BALLISTA_VERSION;
 use datafusion_proto::protobuf::{LogicalPlanNode, PhysicalPlanNode};
@@ -54,11 +56,23 @@ pub async fn start_server(
 
     let metrics_collector = default_metrics_collector()?;
 
+    let codec_logical = config
+        .override_logical_codec
+        .clone()
+        .unwrap_or_else(|| Arc::new(BallistaLogicalExtensionCodec::default()));
+
+    let codec_physical = config
+        .override_physical_codec
+        .clone()
+        .unwrap_or_else(|| Arc::new(BallistaPhysicalExtensionCodec::default()));
+
+    let codec = BallistaCodec::new(codec_logical, codec_physical);
+
     let mut scheduler_server: SchedulerServer<LogicalPlanNode, PhysicalPlanNode> =
         SchedulerServer::new(
             config.scheduler_name(),
             cluster,
-            BallistaCodec::default(),
+            codec,
             config,
             metrics_collector,
         );
diff --git a/examples/Cargo.toml b/examples/Cargo.toml
index c87c039cf0..97b9f441ba 100644
--- a/examples/Cargo.toml
+++ b/examples/Cargo.toml
@@ -34,8 +34,16 @@ path = "examples/standalone-sql.rs"
 required-features = ["ballista/standalone"]
 
 [dependencies]
+anyhow = { workspace = true }
 ballista = { path = "../ballista/client", version = "0.12.0" }
+ballista-core = { path = "../ballista/core", version = "0.12.0" }
+ballista-executor = { path = "../ballista/executor", version = "0.12.0" }
+ballista-scheduler = { path = "../ballista/scheduler", version = "0.12.0" }
 datafusion = { workspace = true }
+env_logger = { workspace = true }
+log = { workspace = true }
+object_store = { workspace = true, features = ["aws"] }
+parking_lot = { workspace = true }
 tokio = { workspace = true, features = [
     "macros",
     "rt",
@@ -43,4 +51,4 @@ tokio = { workspace = true, features = [
     "sync",
     "parking_lot"
 ] }
-
+url = { workspace = true }
diff --git a/examples/examples/custom_client.rs b/examples/examples/custom_client.rs
new file mode 100644
index 0000000000..f42ad8b21b
--- /dev/null
+++ b/examples/examples/custom_client.rs
@@ -0,0 +1,96 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use ballista::extension::SessionContextExt;
+use ballista_examples::object_store::{
+    custom_config_with_s3_options, custom_state_with_s3_support,
+};
+use datafusion::error::Result;
+use datafusion::{assert_batches_eq, prelude::SessionContext};
+
+const BUCKET: &str = "ballista";
+const ACCESS_KEY_ID: &str = "minio";
+const SECRET_KEY: &str = "miniominio";
+
+#[tokio::main]
+async fn main() -> Result<()> {
+    let test_data = ballista_examples::test_util::examples_test_data();
+
+    let state = custom_state_with_s3_support(custom_config_with_s3_options());
+
+    let ctx: SessionContext =
+        SessionContext::remote_with_state("df://localhost:50050", state).await?;
+
+    // setting up relevant S3 options
+    ctx.sql("SET s3.allow_http = true").await?.show().await?;
+    ctx.sql(&format!("SET s3.access_key_id = '{}'", ACCESS_KEY_ID))
+        .await?
+        .show()
+        .await?;
+    ctx.sql(&format!("SET s3.secret_access_key = '{}'", SECRET_KEY))
+        .await?
+        .show()
+        .await?;
+    ctx.sql("SET s3.endpoint = 'http://localhost:9000'")
+        .await?
+        .show()
+        .await?;
+    ctx.sql("SET s3.allow_http = true").await?.show().await?;
+
+    // verifying that we have set S3Options
+    ctx.sql(
+        "select name, value from information_schema.df_settings where name like 's3.%'",
+    )
+    .await?
+    .show()
+    .await?;
+
+    ctx.register_parquet(
+        "test",
+        &format!("{test_data}/alltypes_plain.parquet"),
+        Default::default(),
+    )
+    .await?;
+
+    let write_dir_path = &format!("s3://{}/write_test.parquet", BUCKET);
+
+    ctx.sql("select * from test")
+        .await?
+        .write_parquet(write_dir_path, Default::default(), Default::default())
+        .await?;
+
+    ctx.register_parquet("written_table", write_dir_path, Default::default())
+        .await?;
+
+    let result = ctx
+        .sql("select id, string_col, timestamp_col from written_table where id > 4")
+        .await?
+        .collect()
+        .await?;
+    let expected = [
+        "+----+------------+---------------------+",
+        "| id | string_col | timestamp_col       |",
+        "+----+------------+---------------------+",
+        "| 5  | 31         | 2009-03-01T00:01:00 |",
+        "| 6  | 30         | 2009-04-01T00:00:00 |",
+        "| 7  | 31         | 2009-04-01T00:01:00 |",
+        "+----+------------+---------------------+",
+    ];
+
+    assert_batches_eq!(expected, &result);
+    Ok(())
+}
diff --git a/examples/examples/custom_executor.rs b/examples/examples/custom_executor.rs
new file mode 100644
index 0000000000..77cab97f7c
--- /dev/null
+++ b/examples/examples/custom_executor.rs
@@ -0,0 +1,48 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use anyhow::Result;
+use ballista_examples::object_store::{
+    custom_config_with_s3_options, custom_runtime_with_s3_support,
+};
+use ballista_executor::config::prelude::*;
+use ballista_executor::executor_process::{
+    start_executor_process, ExecutorProcessConfig,
+};
+use datafusion::prelude::SessionConfig;
+use std::sync::Arc;
+
+#[tokio::main]
+async fn main() -> Result<()> {
+    let (opt, _remaining_args) =
+        Config::including_optional_config_files(&["/etc/ballista/executor.toml"])
+            .unwrap_or_exit();
+
+    if opt.version {
+        ballista_core::print_version();
+        std::process::exit(0);
+    }
+
+    let mut config: ExecutorProcessConfig = opt.try_into().unwrap();
+    config.override_config_producer = Some(Arc::new(|| custom_config_with_s3_options()));
+    config.override_runtime_producer =
+        Some(Arc::new(|session_config: &SessionConfig| {
+            custom_runtime_with_s3_support(session_config)
+        }));
+
+    start_executor_process(Arc::new(config)).await
+}
diff --git a/examples/examples/custom_scheduler.rs b/examples/examples/custom_scheduler.rs
new file mode 100644
index 0000000000..e23b2385dc
--- /dev/null
+++ b/examples/examples/custom_scheduler.rs
@@ -0,0 +1,58 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use anyhow::Result;
+use ballista_core::print_version;
+use ballista_examples::object_store::{
+    custom_config_with_s3_options, custom_state_with_s3_support,
+};
+use ballista_scheduler::cluster::BallistaCluster;
+use ballista_scheduler::config::{Config, ResultExt, SchedulerConfig};
+use ballista_scheduler::scheduler_process::start_server;
+use datafusion::prelude::SessionConfig;
+use std::sync::Arc;
+
+#[tokio::main]
+async fn main() -> Result<()> {
+    let _ = env_logger::builder()
+        .filter_level(log::LevelFilter::Info)
+        .is_test(true)
+        .try_init();
+
+    // parse options
+    let (opt, _remaining_args) =
+        Config::including_optional_config_files(&["/etc/ballista/scheduler.toml"])
+            .unwrap_or_exit();
+
+    if opt.version {
+        print_version();
+        std::process::exit(0);
+    }
+
+    let addr = format!("{}:{}", opt.bind_host, opt.bind_port);
+    let addr = addr.parse()?;
+    let mut config: SchedulerConfig = opt.try_into()?;
+
+    config.override_config_producer = Some(Arc::new(|| custom_config_with_s3_options()));
+    config.override_session_builder = Some(Arc::new(|session_config: SessionConfig| {
+        custom_state_with_s3_support(session_config)
+    }));
+    let cluster = BallistaCluster::new_from_config(&config).await?;
+    start_server(cluster, addr, Arc::new(config)).await?;
+
+    Ok(())
+}
diff --git a/examples/src/lib.rs b/examples/src/lib.rs
index 6dc48f6b98..cc83fac5fa 100644
--- a/examples/src/lib.rs
+++ b/examples/src/lib.rs
@@ -15,4 +15,5 @@
 // specific language governing permissions and limitations
 // under the License.
 
+pub mod object_store;
 pub mod test_util;
diff --git a/examples/src/object_store.rs b/examples/src/object_store.rs
new file mode 100644
index 0000000000..b1b070916e
--- /dev/null
+++ b/examples/src/object_store.rs
@@ -0,0 +1,294 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use ballista::prelude::SessionConfigExt;
+use datafusion::common::{config_err, exec_err};
+use datafusion::config::{
+    ConfigEntry, ConfigExtension, ConfigField, ExtensionOptions, Visit,
+};
+use datafusion::error::Result;
+use datafusion::execution::object_store::ObjectStoreRegistry;
+use datafusion::execution::SessionState;
+use datafusion::prelude::SessionConfig;
+use datafusion::{
+    error::DataFusionError,
+    execution::{
+        runtime_env::{RuntimeConfig, RuntimeEnv},
+        SessionStateBuilder,
+    },
+};
+use object_store::aws::AmazonS3Builder;
+use object_store::local::LocalFileSystem;
+use object_store::ObjectStore;
+use parking_lot::RwLock;
+use std::any::Any;
+use std::fmt::Display;
+use std::sync::Arc;
+use url::Url;
+
+pub fn custom_config_with_s3_options() -> SessionConfig {
+    SessionConfig::new_with_ballista()
+        .with_information_schema(true)
+        .with_option_extension(S3Options::default())
+}
+
+pub fn custom_runtime_with_s3_support(
+    session_config: &SessionConfig,
+) -> Result<Arc<RuntimeEnv>> {
+    let s3options = session_config
+        .options()
+        .extensions
+        .get::<S3Options>()
+        .ok_or(DataFusionError::Configuration(
+            "S3 Options not set".to_string(),
+        ))?;
+
+    let config = RuntimeConfig::new().with_object_store_registry(Arc::new(
+        CustomObjectStoreRegistry::new(s3options.clone()),
+    ));
+
+    Ok(Arc::new(RuntimeEnv::new(config)?))
+}
+
+pub fn custom_state_with_s3_support(session_config: SessionConfig) -> SessionState {
+    let runtime_env = custom_runtime_with_s3_support(&session_config).unwrap();
+
+    SessionStateBuilder::new()
+        .with_runtime_env(runtime_env.into())
+        .with_config(session_config)
+        .build()
+}
+
+#[derive(Debug)]
+pub struct CustomObjectStoreRegistry {
+    local: Arc<dyn ObjectStore>,
+    s3options: S3Options,
+}
+
+impl CustomObjectStoreRegistry {
+    pub fn new(s3options: S3Options) -> Self {
+        Self {
+            s3options,
+            local: Arc::new(LocalFileSystem::new()),
+        }
+    }
+}
+
+impl ObjectStoreRegistry for CustomObjectStoreRegistry {
+    fn register_store(
+        &self,
+        _url: &Url,
+        _store: Arc<dyn ObjectStore>,
+    ) -> Option<Arc<dyn ObjectStore>> {
+        unreachable!("register_store not supported ")
+    }
+
+    fn get_store(&self, url: &Url) -> Result<Arc<dyn ObjectStore>> {
+        let scheme = url.scheme();
+        log::info!("get_store: {:?}", &self.s3options.config.read());
+        match scheme {
+            "" | "file" => Ok(self.local.clone()),
+            "s3" => {
+                let s3store =
+                    Self::s3_object_store_builder(url, &self.s3options.config.read())?
+                        .build()?;
+
+                Ok(Arc::new(s3store))
+            }
+
+            _ => exec_err!("get_store - store not supported, url {}", url),
+        }
+    }
+}
+
+impl CustomObjectStoreRegistry {
+    pub fn s3_object_store_builder(
+        url: &Url,
+        aws_options: &S3RegistryConfiguration,
+    ) -> Result<AmazonS3Builder> {
+        let S3RegistryConfiguration {
+            access_key_id,
+            secret_access_key,
+            session_token,
+            region,
+            endpoint,
+            allow_http,
+        } = aws_options;
+
+        let bucket_name = Self::get_bucket_name(url)?;
+        let mut builder = AmazonS3Builder::from_env().with_bucket_name(bucket_name);
+
+        if let (Some(access_key_id), Some(secret_access_key)) =
+            (access_key_id, secret_access_key)
+        {
+            builder = builder
+                .with_access_key_id(access_key_id)
+                .with_secret_access_key(secret_access_key);
+
+            if let Some(session_token) = session_token {
+                builder = builder.with_token(session_token);
+            }
+        } else {
+            return config_err!(
+                "'s3.access_key_id' & 's3.secret_access_key' must be configured"
+            );
+        }
+
+        if let Some(region) = region {
+            builder = builder.with_region(region);
+        }
+
+        if let Some(endpoint) = endpoint {
+            if let Ok(endpoint_url) = Url::try_from(endpoint.as_str()) {
+                if !matches!(allow_http, Some(true)) && endpoint_url.scheme() == "http" {
+                    return config_err!("Invalid endpoint: {endpoint}. HTTP is not allowed for S3 endpoints. To allow HTTP, set 's3.allow_http' to true");
+                }
+            }
+
+            builder = builder.with_endpoint(endpoint);
+        }
+
+        if let Some(allow_http) = allow_http {
+            builder = builder.with_allow_http(*allow_http);
+        }
+
+        Ok(builder)
+    }
+
+    fn get_bucket_name(url: &Url) -> Result<&str> {
+        url.host_str().ok_or_else(|| {
+            DataFusionError::Execution(format!(
+                "Not able to parse bucket name from url: {}",
+                url.as_str()
+            ))
+        })
+    }
+}
+
+#[derive(Debug, Clone, Default)]
+pub struct S3Options {
+    config: Arc<RwLock<S3RegistryConfiguration>>,
+}
+
+impl ExtensionOptions for S3Options {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn as_any_mut(&mut self) -> &mut dyn Any {
+        self
+    }
+
+    fn cloned(&self) -> Box<dyn ExtensionOptions> {
+        Box::new(self.clone())
+    }
+
+    fn set(&mut self, key: &str, value: &str) -> Result<()> {
+        log::debug!("set config, key:{}, value:{}", key, value);
+        match key {
+            "access_key_id" => {
+                let mut c = self.config.write();
+                c.access_key_id.set(key, value)?;
+            }
+            "secret_access_key" => {
+                let mut c = self.config.write();
+                c.secret_access_key.set(key, value)?;
+            }
+            "session_token" => {
+                let mut c = self.config.write();
+                c.session_token.set(key, value)?;
+            }
+            "region" => {
+                let mut c = self.config.write();
+                c.region.set(key, value)?;
+            }
+            "endpoint" => {
+                let mut c = self.config.write();
+                c.endpoint.set(key, value)?;
+            }
+            "allow_http" => {
+                let mut c = self.config.write();
+                c.allow_http.set(key, value)?;
+            }
+            _ => {
+                log::warn!("Config value {} can't be set to {}", key, value);
+                return config_err!("Config value \"{}\" not found in S3Options", key);
+            }
+        }
+        Ok(())
+    }
+
+    fn entries(&self) -> Vec<ConfigEntry> {
+        struct Visitor(Vec<ConfigEntry>);
+
+        impl Visit for Visitor {
+            fn some<V: Display>(
+                &mut self,
+                key: &str,
+                value: V,
+                description: &'static str,
+            ) {
+                self.0.push(ConfigEntry {
+                    key: format!("{}.{}", S3Options::PREFIX, key),
+                    value: Some(value.to_string()),
+                    description,
+                })
+            }
+
+            fn none(&mut self, key: &str, description: &'static str) {
+                self.0.push(ConfigEntry {
+                    key: format!("{}.{}", S3Options::PREFIX, key),
+                    value: None,
+                    description,
+                })
+            }
+        }
+        let c = self.config.read();
+
+        let mut v = Visitor(vec![]);
+        c.access_key_id
+            .visit(&mut v, "access_key_id", "S3 Access Key");
+        c.secret_access_key
+            .visit(&mut v, "secret_access_key", "S3 Secret Key");
+        c.session_token
+            .visit(&mut v, "session_token", "S3 Session token");
+        c.region.visit(&mut v, "region", "S3 region");
+        c.endpoint.visit(&mut v, "endpoint", "S3 Endpoint");
+        c.allow_http.visit(&mut v, "allow_http", "S3 Allow Http");
+
+        v.0
+    }
+}
+
+impl ConfigExtension for S3Options {
+    const PREFIX: &'static str = "s3";
+}
+#[derive(Default, Debug, Clone)]
+pub struct S3RegistryConfiguration {
+    /// Access Key ID
+    pub access_key_id: Option<String>,
+    /// Secret Access Key
+    pub secret_access_key: Option<String>,
+    /// Session token
+    pub session_token: Option<String>,
+    /// AWS Region
+    pub region: Option<String>,
+    /// OSS or COS Endpoint
+    pub endpoint: Option<String>,
+    /// Allow HTTP (otherwise will always use https)
+    pub allow_http: Option<bool>,
+}
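For completeness, the registry's builder can also be exercised on its own, outside of a SessionContext. A minimal sketch (bucket, credentials, and endpoint mirror the MinIO setup in custom_client.rs; the URL value is illustrative only):

use ballista_examples::object_store::{CustomObjectStoreRegistry, S3RegistryConfiguration};
use datafusion::error::Result;
use object_store::aws::AmazonS3;
use url::Url;

// Resolve an S3 store for one bucket directly from a registry configuration.
fn store_for_bucket() -> Result<AmazonS3> {
    let options = S3RegistryConfiguration {
        access_key_id: Some("minio".into()),
        secret_access_key: Some("miniominio".into()),
        endpoint: Some("http://localhost:9000".into()),
        allow_http: Some(true),
        ..Default::default()
    };
    let url = Url::parse("s3://ballista/write_test.parquet").expect("static url");
    Ok(CustomObjectStoreRegistry::s3_object_store_builder(&url, &options)?.build()?)
}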