From c49df1a33d8f8efecdfd4aae17169f11b5cac083 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marko=20Milenkovi=C4=87?= Date: Mon, 21 Oct 2024 22:23:59 +0100 Subject: [PATCH] Initial session store implementation for extensions closes: #1092 --- ballista/client/Cargo.toml | 2 + ballista/client/src/extension.rs | 225 +++++++++++------ ballista/client/tests/setup.rs | 407 +++++++++++++++++++++++++++++++ ballista/core/src/config.rs | 13 + ballista/core/src/serde/mod.rs | 2 +- ballista/core/src/utils.rs | 192 ++++++++++++++- 6 files changed, 769 insertions(+), 72 deletions(-) create mode 100644 ballista/client/tests/setup.rs diff --git a/ballista/client/Cargo.toml b/ballista/client/Cargo.toml index a8de27362..da61dab94 100644 --- a/ballista/client/Cargo.toml +++ b/ballista/client/Cargo.toml @@ -43,6 +43,8 @@ tokio = { workspace = true } url = { version = "2.5" } [dev-dependencies] +ballista-executor = { path = "../executor", version = "0.12.0" } +ballista-scheduler = { path = "../scheduler", version = "0.12.0" } ctor = { version = "0.2" } env_logger = { workspace = true } diff --git a/ballista/client/src/extension.rs b/ballista/client/src/extension.rs index ca104d3b1..99c8a88fc 100644 --- a/ballista/client/src/extension.rs +++ b/ballista/client/src/extension.rs @@ -15,15 +15,17 @@ // specific language governing permissions and limitations // under the License. 
+pub use ballista_core::utils::BallistaSessionConfigExt; use ballista_core::{ config::BallistaConfig, serde::protobuf::{ scheduler_grpc_client::SchedulerGrpcClient, CreateSessionParams, KeyValuePair, }, - utils::{create_df_ctx_with_ballista_query_planner, create_grpc_client_connection}, + utils::{create_grpc_client_connection, BallistaSessionStateExt}, +}; +use datafusion::{ + error::DataFusionError, execution::SessionState, prelude::SessionContext, }; -use datafusion::{error::DataFusionError, prelude::SessionContext}; -use datafusion_proto::protobuf::LogicalPlanNode; use url::Url; const DEFAULT_SCHEDULER_PORT: u16 = 50050; @@ -65,86 +67,155 @@ const DEFAULT_SCHEDULER_PORT: u16 = 50050; /// #[async_trait::async_trait] pub trait SessionContextExt { - /// Create a context for executing queries against a standalone Ballista scheduler instance + /// Creates a context for executing queries against a standalone Ballista scheduler instance + /// + /// It will start a local ballista cluster with scheduler and executor. #[cfg(feature = "standalone")] async fn standalone() -> datafusion::error::Result; - /// Create a context for executing queries against a remote Ballista scheduler instance + /// Creates a context for executing queries against a standalone Ballista scheduler instance + /// with custom session state. + /// + /// It will start a local ballista cluster with scheduler and executor. 
+ #[cfg(feature = "standalone")] + async fn standalone_with_state( + state: SessionState, + ) -> datafusion::error::Result; + + /// Creates a context for executing queries against a remote Ballista scheduler instance async fn remote(url: &str) -> datafusion::error::Result; + + /// Creates a context for executing queries against a remote Ballista scheduler instance + /// with custom session state + async fn remote_with_state( + url: &str, + state: SessionState, + ) -> datafusion::error::Result; } #[async_trait::async_trait] impl SessionContextExt for SessionContext { - async fn remote(url: &str) -> datafusion::error::Result { - let url = - Url::parse(url).map_err(|e| DataFusionError::Configuration(e.to_string()))?; - let host = url.host().ok_or(DataFusionError::Configuration( - "hostname should be provided".to_string(), - ))?; - let port = url.port().unwrap_or(DEFAULT_SCHEDULER_PORT); - let scheduler_url = format!("http://{}:{}", &host, port); + async fn remote_with_state( + url: &str, + state: SessionState, + ) -> datafusion::error::Result { + let config = state.ballista_config(); + + let scheduler_url = Extension::parse_url(url)?; log::info!( "Connecting to Ballista scheduler at {}", scheduler_url.clone() ); - let connection = create_grpc_client_connection(scheduler_url.clone()) - .await - .map_err(|e| DataFusionError::Execution(format!("{e:?}")))?; + let remote_session_id = + Extension::setup_remote(config, scheduler_url.clone()).await?; + log::info!( + "Server side SessionContext created with session id: {}", + remote_session_id + ); + let session_state = + state.upgrade_for_ballista(scheduler_url, remote_session_id)?; + + Ok(SessionContext::new_with_state(session_state)) + } - let config = BallistaConfig::builder() - .build() + async fn remote(url: &str) -> datafusion::error::Result { + let config = BallistaConfig::new() .map_err(|e| DataFusionError::Configuration(e.to_string()))?; + let scheduler_url = Extension::parse_url(url)?; + log::info!( + "Connecting to 
Ballista scheduler at {}", + scheduler_url.clone() + ); + let remote_session_id = + Extension::setup_remote(config, scheduler_url.clone()).await?; + log::info!( + "Server side SessionContext created with session id: {}", + remote_session_id + ); + let session_state = + SessionState::new_ballista_state(scheduler_url, remote_session_id)?; - let limit = config.default_grpc_client_max_message_size(); - let mut scheduler = SchedulerGrpcClient::new(connection) - .max_encoding_message_size(limit) - .max_decoding_message_size(limit); + Ok(SessionContext::new_with_state(session_state)) + } - let remote_session_id = scheduler - .create_session(CreateSessionParams { - settings: config - .settings() - .iter() - .map(|(k, v)| KeyValuePair { - key: k.to_owned(), - value: v.to_owned(), - }) - .collect::>(), - }) - .await - .map_err(|e| DataFusionError::Execution(format!("{e:?}")))? - .into_inner() - .session_id; + #[cfg(feature = "standalone")] + async fn standalone_with_state( + state: SessionState, + ) -> datafusion::error::Result { + let config = state.ballista_config(); + + let codec_logical = state.config().ballista_logical_extension_codec(); + let codec_physical = state.config().ballista_physical_extension_codec(); + + let ballista_codec = + ballista_core::serde::BallistaCodec::new(codec_logical, codec_physical); + + let (remote_session_id, scheduler_url) = + Extension::setup_standalone(config, ballista_codec).await?; + + let session_state = + state.upgrade_for_ballista(scheduler_url, remote_session_id.clone())?; log::info!( "Server side SessionContext created with session id: {}", remote_session_id ); - let ctx = { - create_df_ctx_with_ballista_query_planner::( - scheduler_url, - remote_session_id, - &config, - ) - }; - - Ok(ctx) + Ok(SessionContext::new_with_state(session_state)) } #[cfg(feature = "standalone")] async fn standalone() -> datafusion::error::Result { - use ballista_core::serde::BallistaCodec; - use datafusion_proto::protobuf::PhysicalPlanNode; - 
log::info!("Running in local mode. Scheduler will be run in-proc"); + let config = BallistaConfig::new() + .map_err(|e| DataFusionError::Configuration(e.to_string()))?; + let ballista_codec = ballista_core::serde::BallistaCodec::default(); + + let (remote_session_id, scheduler_url) = + Extension::setup_standalone(config, ballista_codec).await?; + + let session_state = + SessionState::new_ballista_state(scheduler_url, remote_session_id.clone())?; + + log::info!( + "Server side SessionContext created with session id: {}", + remote_session_id + ); + + Ok(SessionContext::new_with_state(session_state)) + } +} + +struct Extension {} + +impl Extension { + fn parse_url(url: &str) -> datafusion::error::Result { + let url = + Url::parse(url).map_err(|e| DataFusionError::Configuration(e.to_string()))?; + let host = url.host().ok_or(DataFusionError::Configuration( + "hostname should be provided".to_string(), + ))?; + let port = url.port().unwrap_or(DEFAULT_SCHEDULER_PORT); + let scheduler_url = format!("http://{}:{}", &host, port); + + Ok(scheduler_url) + } + + #[cfg(feature = "standalone")] + async fn setup_standalone( + config: BallistaConfig, + ballista_codec: ballista_core::serde::BallistaCodec< + datafusion_proto::protobuf::LogicalPlanNode, + datafusion_proto::protobuf::PhysicalPlanNode, + >, + ) -> datafusion::error::Result<(String, String)> { let addr = ballista_scheduler::standalone::new_standalone_scheduler() .await .map_err(|e| DataFusionError::Configuration(e.to_string()))?; let scheduler_url = format!("http://localhost:{}", addr.port()); + let mut scheduler = loop { match SchedulerGrpcClient::connect(scheduler_url.clone()).await { Err(_) => { @@ -154,9 +225,7 @@ impl SessionContextExt for SessionContext { Ok(scheduler) => break scheduler, } }; - let config = BallistaConfig::builder() - .build() - .map_err(|e| DataFusionError::Configuration(e.to_string()))?; + let remote_session_id = scheduler .create_session(CreateSessionParams { settings: config @@ -173,31 +242,47 
@@ impl SessionContextExt for SessionContext { .into_inner() .session_id; - log::info!( - "Server side SessionContext created with session id: {}", - remote_session_id - ); - - let ctx = { - create_df_ctx_with_ballista_query_planner::( - scheduler_url, - remote_session_id, - &config, - ) - }; - - let default_codec: BallistaCodec = - BallistaCodec::default(); - let concurrent_tasks = config.default_standalone_parallelism(); ballista_executor::new_standalone_executor( scheduler, concurrent_tasks, - default_codec, + ballista_codec, ) .await .map_err(|e| DataFusionError::Configuration(e.to_string()))?; - Ok(ctx) + Ok((remote_session_id, scheduler_url)) + } + + async fn setup_remote( + config: BallistaConfig, + scheduler_url: String, + ) -> datafusion::error::Result { + let connection = create_grpc_client_connection(scheduler_url.clone()) + .await + .map_err(|e| DataFusionError::Execution(format!("{e:?}")))?; + + let limit = config.default_grpc_client_max_message_size(); + let mut scheduler = SchedulerGrpcClient::new(connection) + .max_encoding_message_size(limit) + .max_decoding_message_size(limit); + + let remote_session_id = scheduler + .create_session(CreateSessionParams { + settings: config + .settings() + .iter() + .map(|(k, v)| KeyValuePair { + key: k.to_owned(), + value: v.to_owned(), + }) + .collect::>(), + }) + .await + .map_err(|e| DataFusionError::Execution(format!("{e:?}")))? + .into_inner() + .session_id; + + Ok(remote_session_id) } } diff --git a/ballista/client/tests/setup.rs b/ballista/client/tests/setup.rs new file mode 100644 index 000000000..30a6df84a --- /dev/null +++ b/ballista/client/tests/setup.rs @@ -0,0 +1,407 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +mod common; + +#[cfg(test)] +mod remote { + use ballista::{ + extension::{BallistaSessionConfigExt, SessionContextExt}, + prelude::BALLISTA_JOB_NAME, + }; + use datafusion::{ + assert_batches_eq, + execution::SessionStateBuilder, + prelude::{SessionConfig, SessionContext}, + }; + + #[tokio::test] + async fn should_execute_sql_show_with_custom_state() -> datafusion::error::Result<()> + { + let (host, port) = crate::common::setup_test_cluster().await; + let url = format!("df://{host}:{port}"); + let state = SessionStateBuilder::new().with_default_features().build(); + + let test_data = crate::common::example_test_data(); + let ctx: SessionContext = SessionContext::remote_with_state(&url, state).await?; + + ctx.register_parquet( + "test", + &format!("{test_data}/alltypes_plain.parquet"), + Default::default(), + ) + .await?; + + let result = ctx + .sql("select string_col, timestamp_col from test where id > 4") + .await? 
+ .collect() + .await?; + let expected = [ + "+------------+---------------------+", + "| string_col | timestamp_col |", + "+------------+---------------------+", + "| 31 | 2009-03-01T00:01:00 |", + "| 30 | 2009-04-01T00:00:00 |", + "| 31 | 2009-04-01T00:01:00 |", + "+------------+---------------------+", + ]; + + assert_batches_eq!(expected, &result); + + Ok(()) + } + + #[tokio::test] + async fn should_execute_sql_set_configs() -> datafusion::error::Result<()> { + let (host, port) = crate::common::setup_test_cluster().await; + let url = format!("df://{host}:{port}"); + + let session_config = SessionConfig::new_with_ballista() + .with_information_schema(true) + .set_str(BALLISTA_JOB_NAME, "Super Cool Ballista App"); + + let state = SessionStateBuilder::new() + .with_default_features() + .with_config(session_config) + .build(); + + let ctx: SessionContext = SessionContext::remote_with_state(&url, state).await?; + + let result = ctx + .sql("select name, value from information_schema.df_settings where name like 'ballista.job.name' order by name limit 1") + .await? 
+ .collect() + .await?; + + let expected = [ + "+-------------------+-------------------------+", + "| name | value |", + "+-------------------+-------------------------+", + "| ballista.job.name | Super Cool Ballista App |", + "+-------------------+-------------------------+", + ]; + + assert_batches_eq!(expected, &result); + + Ok(()) + } +} + +#[cfg(test)] +#[cfg(feature = "standalone")] +mod standalone { + + use std::sync::{atomic::AtomicBool, Arc}; + + use ballista::{ + extension::{BallistaSessionConfigExt, SessionContextExt}, + prelude::BALLISTA_JOB_NAME, + }; + use ballista_core::{ + config::BALLISTA_PLANNER_OVERRIDE, serde::BallistaPhysicalExtensionCodec, + }; + use datafusion::{ + assert_batches_eq, + common::exec_err, + execution::{context::QueryPlanner, SessionState, SessionStateBuilder}, + logical_expr::LogicalPlan, + physical_plan::ExecutionPlan, + prelude::{SessionConfig, SessionContext}, + }; + use datafusion_proto::{ + logical_plan::LogicalExtensionCodec, physical_plan::PhysicalExtensionCodec, + }; + + #[tokio::test] + async fn should_execute_sql_set_configs() -> datafusion::error::Result<()> { + let session_config = SessionConfig::new_with_ballista() + .with_information_schema(true) + .set_str(BALLISTA_JOB_NAME, "Super Cool Ballista App"); + + let state = SessionStateBuilder::new() + .with_default_features() + .with_config(session_config) + .build(); + + let ctx: SessionContext = SessionContext::standalone_with_state(state).await?; + + let result = ctx + .sql("select name, value from information_schema.df_settings where name like 'ballista.job.name' order by name limit 1") + .await? 
+ .collect() + .await?; + + let expected = [ + "+-------------------+-------------------------+", + "| name | value |", + "+-------------------+-------------------------+", + "| ballista.job.name | Super Cool Ballista App |", + "+-------------------+-------------------------+", + ]; + + assert_batches_eq!(expected, &result); + + Ok(()) + } + + // we are testing if we can override the default logical codec; + // in this specific test the codec will return an error which will + // fail the query. + #[tokio::test] + async fn should_set_logical_codec() -> datafusion::error::Result<()> { + let test_data = crate::common::example_test_data(); + let codec = Arc::new(BadLogicalCodec::default()); + + let session_config = SessionConfig::new_with_ballista() + .with_information_schema(true) + .with_ballista_logical_extension_codec(codec.clone()); + + let state = SessionStateBuilder::new() + .with_default_features() + .with_config(session_config) + .build(); + + let ctx: SessionContext = SessionContext::standalone_with_state(state).await?; + + ctx.register_parquet( + "test", + &format!("{test_data}/alltypes_plain.parquet"), + Default::default(), + ) + .await?; + + let write_dir = tempfile::tempdir().expect("temporary directory to be created"); + let write_dir_path = write_dir + .path() + .to_str() + .expect("path to be converted to str"); + + let result = ctx + .sql("select * from test") + .await? 
+ .write_parquet(write_dir_path, Default::default(), Default::default()) + .await; + + // this codec should make the query fail + assert!(result.is_err()); + assert!(codec.invoked.load(std::sync::atomic::Ordering::Relaxed)); + Ok(()) + } + + // tests if we can correctly set physical codec + #[tokio::test] + async fn should_set_physical_codec() -> datafusion::error::Result<()> { + let test_data = crate::common::example_test_data(); + let physical_codec = Arc::new(MockPhysicalCodec::default()); + let session_config = SessionConfig::new_with_ballista() + .with_information_schema(true) + .with_ballista_physical_extension_codec(physical_codec.clone()); + + let state = SessionStateBuilder::new() + .with_default_features() + .with_config(session_config) + .build(); + + let ctx: SessionContext = SessionContext::standalone_with_state(state).await?; + + ctx.register_parquet( + "test", + &format!("{test_data}/alltypes_plain.parquet"), + Default::default(), + ) + .await?; + + let _result = ctx + .sql("select string_col, timestamp_col from test where id > 4") + .await? 
+ .collect() + .await; + + assert!(physical_codec + .invoked + .load(std::sync::atomic::Ordering::Relaxed)); + Ok(()) + } + + // checks that BALLISTA_PLANNER_OVERRIDE decides whether the user-provided planner is replaced + #[tokio::test] + async fn should_override_planner() -> datafusion::error::Result<()> { + let session_config = SessionConfig::new_with_ballista() + .with_information_schema(true) + .set_str(BALLISTA_PLANNER_OVERRIDE, "false"); + + let state = SessionStateBuilder::new() + .with_default_features() + .with_config(session_config) + .with_query_planner(Arc::new(BadPlanner::default())) + .build(); + + let ctx: SessionContext = SessionContext::standalone_with_state(state).await?; + + let result = ctx.sql("SELECT 1").await?.collect().await; + + assert!(result.is_err()); + + let session_config = SessionConfig::new_with_ballista() + .with_information_schema(true) + .set_str(BALLISTA_PLANNER_OVERRIDE, "true"); + + let state = SessionStateBuilder::new() + .with_default_features() + .with_config(session_config) + .with_query_planner(Arc::new(BadPlanner::default())) + .build(); + + let ctx: SessionContext = SessionContext::standalone_with_state(state).await?; + + let result = ctx.sql("SELECT 1").await?.collect().await; + + assert!(result.is_ok()); + + Ok(()) + } + + #[derive(Debug, Default)] + struct BadLogicalCodec { + invoked: AtomicBool, + } + + impl LogicalExtensionCodec for BadLogicalCodec { + fn try_decode( + &self, + _buf: &[u8], + _inputs: &[datafusion::logical_expr::LogicalPlan], + _ctx: &SessionContext, + ) -> datafusion::error::Result { + self.invoked + .store(true, std::sync::atomic::Ordering::Relaxed); + exec_err!("this codec does not work") + } + + fn try_encode( + &self, + _node: &datafusion::logical_expr::Extension, + _buf: &mut Vec, + ) -> datafusion::error::Result<()> { + self.invoked + .store(true, std::sync::atomic::Ordering::Relaxed); + exec_err!("this codec does not work") + } + + fn try_decode_table_provider( + &self, + _buf: &[u8], + _table_ref: &datafusion::sql::TableReference, + _schema: 
datafusion::arrow::datatypes::SchemaRef, + _ctx: &SessionContext, + ) -> datafusion::error::Result< + std::sync::Arc, + > { + self.invoked + .store(true, std::sync::atomic::Ordering::Relaxed); + exec_err!("this codec does not work") + } + + fn try_encode_table_provider( + &self, + _table_ref: &datafusion::sql::TableReference, + _node: std::sync::Arc, + _buf: &mut Vec, + ) -> datafusion::error::Result<()> { + self.invoked + .store(true, std::sync::atomic::Ordering::Relaxed); + exec_err!("this codec does not work") + } + + fn try_decode_file_format( + &self, + _buf: &[u8], + _ctx: &SessionContext, + ) -> datafusion::error::Result< + Arc, + > { + self.invoked + .store(true, std::sync::atomic::Ordering::Relaxed); + exec_err!("this codec does not work") + } + + fn try_encode_file_format( + &self, + _buf: &mut Vec, + _node: Arc, + ) -> datafusion::error::Result<()> { + self.invoked + .store(true, std::sync::atomic::Ordering::Relaxed); + //Ok(()) + exec_err!("this codec does not work") + } + } + + #[derive(Debug)] + struct MockPhysicalCodec { + invoked: AtomicBool, + codec: Arc, + } + + impl Default for MockPhysicalCodec { + fn default() -> Self { + Self { + invoked: AtomicBool::new(false), + codec: Arc::new(BallistaPhysicalExtensionCodec::default()), + } + } + } + + impl PhysicalExtensionCodec for MockPhysicalCodec { + fn try_decode( + &self, + buf: &[u8], + inputs: &[Arc], + registry: &dyn datafusion::execution::FunctionRegistry, + ) -> datafusion::error::Result> + { + self.invoked + .store(true, std::sync::atomic::Ordering::Relaxed); + self.codec.try_decode(buf, inputs, registry) + } + + fn try_encode( + &self, + node: Arc, + buf: &mut Vec, + ) -> datafusion::error::Result<()> { + self.invoked + .store(true, std::sync::atomic::Ordering::Relaxed); + self.codec.try_encode(node, buf) + } + } + + #[derive(Default)] + struct BadPlanner {} + + #[async_trait::async_trait] + impl QueryPlanner for BadPlanner { + async fn create_physical_plan( + &self, + _logical_plan: 
&LogicalPlan, + _session_state: &SessionState, + ) -> datafusion::error::Result> { + exec_err!("does not work") + } + } +} diff --git a/ballista/core/src/config.rs b/ballista/core/src/config.rs index 88cba1d9a..782b8b9d0 100644 --- a/ballista/core/src/config.rs +++ b/ballista/core/src/config.rs @@ -43,6 +43,11 @@ pub const BALLISTA_REPARTITION_WINDOWS: &str = "ballista.repartition.windows"; pub const BALLISTA_PARQUET_PRUNING: &str = "ballista.parquet.pruning"; pub const BALLISTA_COLLECT_STATISTICS: &str = "ballista.collect_statistics"; pub const BALLISTA_STANDALONE_PARALLELISM: &str = "ballista.standalone.parallelism"; +/// If set to false, planner will not be overridden by ballista. +/// This allows user to replace ballista planner +// this is a bit of a hack, as we can't detect if there is a +// custom planner provided +pub const BALLISTA_PLANNER_OVERRIDE: &str = "ballista.planner.override"; pub const BALLISTA_WITH_INFORMATION_SCHEMA: &str = "ballista.with_information_schema"; @@ -216,6 +221,10 @@ impl BallistaConfig { "Configuration for max message size in gRPC clients".to_string(), DataType::UInt64, Some((16 * 1024 * 1024).to_string())), + ConfigEntry::new(BALLISTA_PLANNER_OVERRIDE.to_string(), + "Disable overriding provided planner".to_string(), + DataType::Boolean, + Some((true).to_string())), ]; entries .iter() @@ -271,6 +280,10 @@ impl BallistaConfig { self.get_bool_setting(BALLISTA_WITH_INFORMATION_SCHEMA) } + pub fn planner_override(&self) -> bool { + self.get_bool_setting(BALLISTA_PLANNER_OVERRIDE) + } + fn get_usize_setting(&self, key: &str) -> usize { if let Some(v) = self.settings.get(key) { // infallible because we validate all configs in the constructor diff --git a/ballista/core/src/serde/mod.rs b/ballista/core/src/serde/mod.rs index 0381ba6f7..7464fe686 100644 --- a/ballista/core/src/serde/mod.rs +++ b/ballista/core/src/serde/mod.rs @@ -245,7 +245,7 @@ impl LogicalExtensionCodec for BallistaLogicalExtensionCodec { } } -#[derive(Debug)] 
+#[derive(Debug, Default)] pub struct BallistaPhysicalExtensionCodec {} impl PhysicalExtensionCodec for BallistaPhysicalExtensionCodec { diff --git a/ballista/core/src/utils.rs b/ballista/core/src/utils.rs index eceb9d447..8be32c402 100644 --- a/ballista/core/src/utils.rs +++ b/ballista/core/src/utils.rs @@ -22,7 +22,7 @@ use crate::execution_plans::{ }; use crate::object_store_registry::with_object_store_registry; use crate::serde::scheduler::PartitionStats; -use crate::serde::BallistaLogicalExtensionCodec; +use crate::serde::{BallistaLogicalExtensionCodec, BallistaPhysicalExtensionCodec}; use async_trait::async_trait; use datafusion::arrow::datatypes::Schema; @@ -51,6 +51,8 @@ use datafusion::physical_plan::sorts::sort::SortExec; use datafusion::physical_plan::{metrics, ExecutionPlan, RecordBatchStream}; use datafusion::physical_planner::{DefaultPhysicalPlanner, PhysicalPlanner}; use datafusion_proto::logical_plan::{AsLogicalPlan, LogicalExtensionCodec}; +use datafusion_proto::physical_plan::PhysicalExtensionCodec; +use datafusion_proto::protobuf::LogicalPlanNode; use futures::StreamExt; use log::error; use std::io::{BufWriter, Write}; @@ -275,6 +277,194 @@ pub fn create_df_ctx_with_ballista_query_planner( SessionContext::new_with_state(session_state) } +pub trait BallistaSessionStateExt { + fn new_ballista_state( + scheduler_url: String, + session_id: String, + ) -> datafusion::error::Result; + fn upgrade_for_ballista( + self, + scheduler_url: String, + session_id: String, + ) -> datafusion::error::Result; + + fn ballista_config(&self) -> BallistaConfig; +} + +impl BallistaSessionStateExt for SessionState { + fn ballista_config(&self) -> BallistaConfig { + self.config() + .options() + .extensions + .get::() + .cloned() + .unwrap_or_else(|| BallistaConfig::new().unwrap()) + } + + fn new_ballista_state( + scheduler_url: String, + session_id: String, + ) -> datafusion::error::Result { + let config = BallistaConfig::new() + .map_err(|e| 
DataFusionError::Configuration(e.to_string()))?; + + let planner = + BallistaQueryPlanner::::new(scheduler_url, config.clone()); + + let session_config = SessionConfig::new() + .with_information_schema(true) + .with_option_extension(config.clone()); + + let runtime_config = RuntimeConfig::default(); + let runtime_env = RuntimeEnv::new(runtime_config)?; + let session_state = SessionStateBuilder::new() + .with_default_features() + .with_config(session_config) + .with_runtime_env(Arc::new(runtime_env)) + .with_query_planner(Arc::new(planner)) + .with_session_id(session_id) + .build(); + + Ok(session_state) + } + + fn upgrade_for_ballista( + self, + scheduler_url: String, + session_id: String, + ) -> datafusion::error::Result { + let codec_logical = self.config().ballista_logical_extension_codec(); + + let new_config = self + .config() + .options() + .extensions + .get::() + .cloned() + .unwrap_or_else(|| BallistaConfig::new().unwrap()); + + let session_config = self + .config() + .clone() + .with_option_extension(new_config.clone()); + + // at the moment we don't have a way to detect if + // user set planner so we provide a configuration to + // user to disable planner override + let planner_override = self + .config() + .options() + .extensions + .get::() + .map(|config| config.planner_override()) + .unwrap_or(true); + + let builder = SessionStateBuilder::new_from_existing(self) + .with_config(session_config) + .with_session_id(session_id); + + let builder = if planner_override { + let query_planner = BallistaQueryPlanner::::with_extension( + scheduler_url, + new_config, + codec_logical, + ); + builder.with_query_planner(Arc::new(query_planner)) + } else { + builder + }; + + Ok(builder.build()) + } +} + +pub trait BallistaSessionConfigExt { + /// Creates session config which has + /// ballista configuration initialized + fn new_with_ballista() -> SessionConfig; + + /// Overrides ballista's [LogicalExtensionCodec] + fn with_ballista_logical_extension_codec( + self, + 
codec: Arc, + ) -> SessionConfig; + + /// Overrides ballista's [PhysicalExtensionCodec] + fn with_ballista_physical_extension_codec( + self, + codec: Arc, + ) -> SessionConfig; + + /// returns [LogicalExtensionCodec] if set + /// or default ballista codec if not + fn ballista_logical_extension_codec(&self) -> Arc; + + /// returns [PhysicalExtensionCodec] if set + /// or default ballista codec if not + fn ballista_physical_extension_codec(&self) -> Arc; +} + +impl BallistaSessionConfigExt for SessionConfig { + fn new_with_ballista() -> SessionConfig { + SessionConfig::new().with_option_extension(BallistaConfig::new().unwrap()) + } + fn with_ballista_logical_extension_codec( + self, + codec: Arc, + ) -> SessionConfig { + let extension = BallistaConfigExtensionLogicalCodec::new(codec); + self.with_extension(Arc::new(extension)) + } + fn with_ballista_physical_extension_codec( + self, + codec: Arc, + ) -> SessionConfig { + let extension = BallistaConfigExtensionPhysicalCodec::new(codec); + self.with_extension(Arc::new(extension)) + } + + fn ballista_logical_extension_codec(&self) -> Arc { + self.get_extension::() + .map(|c| c.codec()) + .unwrap_or_else(|| Arc::new(BallistaLogicalExtensionCodec::default())) + } + fn ballista_physical_extension_codec(&self) -> Arc { + self.get_extension::() + .map(|c| c.codec()) + .unwrap_or_else(|| Arc::new(BallistaPhysicalExtensionCodec::default())) + } +} + +/// Wrapper for [SessionConfig] extension +/// holding [LogicalExtensionCodec] if overridden +struct BallistaConfigExtensionLogicalCodec { + codec: Arc, +} + +impl BallistaConfigExtensionLogicalCodec { + fn new(codec: Arc) -> Self { + Self { codec } + } + fn codec(&self) -> Arc { + self.codec.clone() + } +} + +/// Wrapper for [SessionConfig] extension +/// holding [PhysicalExtensionCodec] if overridden +struct BallistaConfigExtensionPhysicalCodec { + codec: Arc, +} + +impl BallistaConfigExtensionPhysicalCodec { + fn new(codec: Arc) -> Self { + Self { codec } + } + fn codec(&self) 
-> Arc { + self.codec.clone() + } +} + pub struct BallistaQueryPlanner { scheduler_url: String, config: BallistaConfig,