From 1f160d600afaab6293bb3414d0d33cd28680ef7e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marko=20Milenkovi=C4=87?= Date: Sun, 9 Feb 2025 08:41:06 +0000 Subject: [PATCH] feat: add support for insert into --- ballista/client/tests/context_checks.rs | 65 +++++++++ ballista/client/tests/context_unsupported.rs | 65 --------- ballista/core/src/config.rs | 15 +++ ballista/core/src/planner.rs | 41 +++++- ballista/core/src/serde/mod.rs | 126 +++++++++++++++++- .../scheduler/src/scheduler_server/grpc.rs | 31 +++++ 6 files changed, 273 insertions(+), 70 deletions(-) diff --git a/ballista/client/tests/context_checks.rs b/ballista/client/tests/context_checks.rs index ac83e3bdf..2601293de 100644 --- a/ballista/client/tests/context_checks.rs +++ b/ballista/client/tests/context_checks.rs @@ -399,4 +399,69 @@ mod supported { assert_batches_eq!(expected, &result); } + + #[rstest] + #[case::standalone(standalone_context())] + #[case::remote(remote_context())] + #[tokio::test] + #[cfg(not(windows))] // test is failing at windows, can't debug it + async fn should_support_sql_insert_into( + #[future(awt)] + #[case] + ctx: SessionContext, + test_data: String, + ) { + ctx.register_parquet( + "test", + &format!("{test_data}/alltypes_plain.parquet"), + Default::default(), + ) + .await + .unwrap(); + + let write_dir = tempfile::tempdir().expect("temporary directory to be created"); + let write_dir_path = write_dir + .path() + .to_str() + .expect("path to be converted to str"); + + ctx.sql("select * from test") + .await + .unwrap() + .write_parquet(write_dir_path, Default::default(), Default::default()) + .await + .unwrap(); + + ctx.register_parquet("written_table", write_dir_path, Default::default()) + .await + .unwrap(); + + ctx.sql("INSERT INTO written_table select * from test") + .await + .unwrap() + .show() + .await + .unwrap(); + + let result = ctx + .sql("select id, string_col, timestamp_col from written_table where id > 4 order by id") + .await.unwrap() + .collect() + .await.unwrap(); + + let expected = [ + "+----+------------+---------------------+", + "| id | string_col | timestamp_col |", + "+----+------------+---------------------+", + "| 5 | 31 | 2009-03-01T00:01:00 |", + "| 5 | 31 | 2009-03-01T00:01:00 |", + "| 6 | 30 | 2009-04-01T00:00:00 |", + "| 6 | 30 | 2009-04-01T00:00:00 |", + "| 7 | 31 | 2009-04-01T00:01:00 |", + "| 7 | 31 | 2009-04-01T00:01:00 |", + "+----+------------+---------------------+", + ]; + + assert_batches_eq!(expected, &result); + } } diff --git a/ballista/client/tests/context_unsupported.rs b/ballista/client/tests/context_unsupported.rs index 805e81325..aa9827993 100644 --- a/ballista/client/tests/context_unsupported.rs +++ b/ballista/client/tests/context_unsupported.rs @@ -144,71 +144,6 @@ mod unsupported { "+----+----------+---------------------+", ]; - assert_batches_eq!(expected, &result); - } - #[rstest] - #[case::standalone(standalone_context())] - #[case::remote(remote_context())] - #[tokio::test] - #[should_panic] - // "Error: Internal(failed to serialize logical plan: Internal(LogicalPlan serde is not yet implemented for Dml))" - async fn should_support_sql_insert_into( - #[future(awt)] - #[case] - ctx: SessionContext, - test_data: String, - ) { - ctx.register_parquet( - "test", - &format!("{test_data}/alltypes_plain.parquet"), - Default::default(), - ) - .await - .unwrap(); - let write_dir = tempfile::tempdir().expect("temporary directory to be created"); - let write_dir_path = write_dir - .path() - .to_str() - .expect("path to be converted to str"); - - ctx.sql("select * from test") - .await - .unwrap() - .write_parquet(write_dir_path, Default::default(), Default::default()) - .await - .unwrap(); - - ctx.register_parquet("written_table", write_dir_path, Default::default()) - .await - .unwrap(); - - let _ = ctx - .sql("INSERT INTO written_table select * from written_table") - .await - .unwrap() - .collect() - .await - .unwrap(); - - let result = ctx - .sql("select id, string_col, timestamp_col from written_table where id > 4 order by id") - .await.unwrap() - .collect() - .await.unwrap(); - - let expected = [ - "+----+------------+---------------------+", - "| id | string_col | timestamp_col |", - "+----+------------+---------------------+", - "| 5 | 31 | 2009-03-01T00:01:00 |", - "| 5 | 31 | 2009-03-01T00:01:00 |", - "| 6 | 30 | 2009-04-01T00:00:00 |", - "| 6 | 30 | 2009-04-01T00:00:00 |", - "| 7 | 31 | 2009-04-01T00:01:00 |", - "| 7 | 31 | 2009-04-01T00:01:00 |", - "+----+------------+---------------------+", - ]; - assert_batches_eq!(expected, &result); } } diff --git a/ballista/core/src/config.rs b/ballista/core/src/config.rs index cb7f7c5d7..74256489a 100644 --- a/ballista/core/src/config.rs +++ b/ballista/core/src/config.rs @@ -32,6 +32,13 @@ pub const BALLISTA_STANDALONE_PARALLELISM: &str = "ballista.standalone.paralleli /// max message size for gRPC clients pub const BALLISTA_GRPC_CLIENT_MAX_MESSAGE_SIZE: &str = "ballista.grpc_client_max_message_size"; +/// enable or disable ballista dml planner extension. +/// when enabled planner will use custom logical planner DML +/// extension which will serialize table provider used in DML +/// +/// this configuration should be disabled if using remote schema +/// registries. +pub const BALLISTA_PLANNER_DML_EXTENSION: &str = "ballista.planner.dml_extension"; pub type ParseResult = result::Result; use std::sync::LazyLock; @@ -48,6 +55,10 @@ static CONFIG_ENTRIES: LazyLock> = LazyLock::new(|| "Configuration for max message size in gRPC clients".to_string(), DataType::UInt64, Some((16 * 1024 * 1024).to_string())), + ConfigEntry::new(BALLISTA_PLANNER_DML_EXTENSION.to_string(), + "Enable ballista planner DML extension".to_string(), + DataType::Boolean, + Some((true).to_string())), ]; entries .into_iter() @@ -165,6 +176,10 @@ impl BallistaConfig { self.get_usize_setting(BALLISTA_STANDALONE_PARALLELISM) } + pub fn planner_dml_extension(&self) -> bool { + self.get_bool_setting(BALLISTA_PLANNER_DML_EXTENSION) + } + fn get_usize_setting(&self, key: &str) -> usize { if let Some(v) = self.settings.get(key) { // infallible because we validate all configs in the constructor diff --git a/ballista/core/src/planner.rs b/ballista/core/src/planner.rs index 266da3c6e..77690cd44 100644 --- a/ballista/core/src/planner.rs +++ b/ballista/core/src/planner.rs @@ -17,14 +17,16 @@ use crate::config::BallistaConfig; use crate::execution_plans::DistributedQueryExec; -use crate::serde::BallistaLogicalExtensionCodec; +use crate::serde::{BallistaDmlExtension, BallistaLogicalExtensionCodec}; use async_trait::async_trait; use datafusion::arrow::datatypes::Schema; +use datafusion::common::plan_err; use datafusion::common::tree_node::{TreeNode, TreeNodeVisitor}; +use datafusion::datasource::DefaultTableSource; use datafusion::error::DataFusionError; use datafusion::execution::context::{QueryPlanner, SessionState}; -use datafusion::logical_expr::{LogicalPlan, TableScan}; +use datafusion::logical_expr::{DmlStatement, Extension, LogicalPlan, TableScan}; use datafusion::physical_plan::empty::EmptyExec; use datafusion::physical_plan::ExecutionPlan; use datafusion::physical_planner::{DefaultPhysicalPlanner, PhysicalPlanner}; @@ -125,6 +127,41 @@ impl QueryPlanner for BallistaQueryPlanner { log::debug!("create_physical_plan - handling empty exec"); Ok(Arc::new(EmptyExec::new(Arc::new(Schema::empty())))) } + // At the moment DML statement uses TableReference instead of TableProvider. + // As ballista has two contexts (client and scheduler) scheduler context may not + // know table provider for given table reference, thus we need to attach + // table provider to this DML statement. + LogicalPlan::Dml(DmlStatement { table_name, .. }) + if self.config.planner_dml_extension() => + { + let table_name = table_name.to_owned(); + let table = table_name.table().to_string(); + let schema = session_state.schema_for_ref(table_name.clone())?; + let table_provider = match schema.table(&table).await? { + Some(ref provider) => Ok(Arc::clone(provider)), + _ => plan_err!("No table named '{table}'"), + }?; + + let table_source = Arc::new(DefaultTableSource::new(table_provider)); + let table = + TableScan::try_new(table_name, table_source, None, vec![], None)?; + + // custom made logical extension node is used to attach table reference + let node = Arc::new(BallistaDmlExtension { + dml: logical_plan.clone(), + table, + }); + let plan = LogicalPlan::Extension(Extension { node }); + log::debug!("create_physical_plan - handling DML statement"); + + Ok(Arc::new(DistributedQueryExec::::with_extension( + self.scheduler_url.clone(), + self.config.clone(), + plan.clone(), + self.extension_codec.clone(), + session_state.session_id().to_string(), + ))) + } _ => { log::debug!("create_physical_plan - handling general statement"); diff --git a/ballista/core/src/serde/mod.rs b/ballista/core/src/serde/mod.rs index 84cf80684..95bd6084e 100644 --- a/ballista/core/src/serde/mod.rs +++ b/ballista/core/src/serde/mod.rs @@ -22,8 +22,11 @@ use crate::{error::BallistaError, serde::scheduler::Action as BallistaAction}; use arrow_flight::sql::ProstMessageExt; use datafusion::arrow::datatypes::SchemaRef; -use datafusion::common::{DataFusionError, Result}; +use datafusion::common::{plan_err, DataFusionError, Result}; use datafusion::execution::FunctionRegistry; +use datafusion::logical_expr::{ + Extension, LogicalPlan, TableScan, UserDefinedLogicalNodeCore, +}; use datafusion::physical_plan::{ExecutionPlan, Partitioning}; use datafusion_proto::logical_plan::file_formats::{ ArrowLogicalExtensionCodec, AvroLogicalExtensionCodec, CsvLogicalExtensionCodec, @@ -179,7 +182,31 @@ impl LogicalExtensionCodec for BallistaLogicalExtensionCodec { inputs: &[datafusion::logical_expr::LogicalPlan], ctx: &datafusion::prelude::SessionContext, ) -> Result { - self.default_codec.try_decode(buf, inputs, ctx) + match BallistaExtensionProto::decode(buf) { + Ok(extension) => match extension.extension { + Some(BallistaExtensionType::Dml(BallistaDmlExtensionProto { + dml: Some(dml), + table: Some(table), + })) => { + let table = table.try_into_logical_plan(ctx, self)?; + match table { + LogicalPlan::TableScan(scan) => { + let dml = dml.try_into_logical_plan(ctx, self)?; + Ok(Extension { + node: Arc::new(BallistaDmlExtension { dml, table: scan }), + }) + } + _ => plan_err!( + "TableScan expected in ballista DML extension definition" + ), + } + } + None => plan_err!("Ballista extension can't be None"), + _ => plan_err!("Ballista extension not supported"), + }, + + Err(_e) => self.default_codec.try_decode(buf, inputs, ctx), + } } fn try_encode( @@ -187,7 +214,32 @@ impl LogicalExtensionCodec for BallistaLogicalExtensionCodec { node: &datafusion::logical_expr::Extension, buf: &mut Vec, ) -> Result<()> { - self.default_codec.try_encode(node, buf) + if let Some(BallistaDmlExtension { dml: input, table }) = + node.node.as_any().downcast_ref::() + { + let input = LogicalPlanNode::try_from_logical_plan(input, self)?; + + let table = LogicalPlanNode::try_from_logical_plan( + &LogicalPlan::TableScan(table.clone()), + self, + )?; + let extension = BallistaDmlExtensionProto { + dml: Some(input), + table: Some(table), + }; + + let extension = BallistaExtensionProto { + extension: Some(BallistaExtensionType::Dml(extension)), + }; + + extension + .encode(buf) + .map_err(|e| DataFusionError::Execution(e.to_string()))?; + + Ok(()) + } else { + self.default_codec.try_encode(node, buf) + } } fn try_decode_table_provider( @@ -487,6 +539,74 @@ struct FileFormatProto { pub blob: Vec, } +#[derive(Clone, PartialEq, prost::Message)] +struct BallistaExtensionProto { + #[prost(oneof = "BallistaExtensionType", tags = "1")] + extension: Option, +} + +#[derive(Clone, PartialEq, ::prost::Oneof)] +enum BallistaExtensionType { + #[prost(message, tag = "1")] + Dml(BallistaDmlExtensionProto), +} + +#[derive(Clone, PartialEq, prost::Message)] +struct BallistaDmlExtensionProto { + #[prost(message, tag = 1)] + pub dml: Option, + #[prost(message, tag = 2)] + pub table: Option, +} + +#[derive(Debug, Hash, PartialEq, Eq, Clone)] +pub struct BallistaDmlExtension { + /// LogicalPlan::DML + /// DMLStatement is expected + pub dml: LogicalPlan, + /// Table provider which is referenced + /// from LogicalPlan::DML + pub table: TableScan, +} + +impl std::cmp::PartialOrd for BallistaDmlExtension { + fn partial_cmp(&self, other: &Self) -> Option { + self.dml.partial_cmp(&other.dml) + } +} +impl UserDefinedLogicalNodeCore for BallistaDmlExtension { + fn name(&self) -> &str { + "BallistaDmlExtension" + } + + fn inputs(&self) -> Vec<&datafusion::logical_expr::LogicalPlan> { + vec![&self.dml] + } + + fn schema(&self) -> &datafusion::common::DFSchemaRef { + self.dml.schema() + } + + fn expressions(&self) -> Vec { + self.dml.expressions() + } + + fn fmt_for_explain(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + self.dml.fmt(f) + } + + fn with_exprs_and_inputs( + &self, + _exprs: Vec, + inputs: Vec, + ) -> Result { + Ok(Self { + dml: inputs[0].clone(), + table: self.table.clone(), + }) + } +} + #[cfg(test)] mod test { use super::*; diff --git a/ballista/scheduler/src/scheduler_server/grpc.rs b/ballista/scheduler/src/scheduler_server/grpc.rs index 02c21a884..4a65aed3b 100644 --- a/ballista/scheduler/src/scheduler_server/grpc.rs +++ b/ballista/scheduler/src/scheduler_server/grpc.rs @@ -32,6 +32,9 @@ use ballista_core::serde::protobuf::{ UpdateTaskStatusParams, UpdateTaskStatusResult, }; use ballista_core::serde::scheduler::ExecutorMetadata; +use ballista_core::serde::BallistaDmlExtension; +use datafusion::datasource::DefaultTableSource; +use datafusion::logical_expr::{Extension, LogicalPlan}; use datafusion_proto::logical_plan::AsLogicalPlan; use datafusion_proto::physical_plan::AsExecutionPlan; use log::{debug, error, info, trace, warn}; @@ -409,6 +412,34 @@ impl SchedulerGrpc self.state.codec.logical_extension_codec(), ) }) { + Ok(LogicalPlan::Extension(Extension { node })) + if node + .as_any() + .downcast_ref::() + .is_some() => + { + let plan = node + .as_any() + .downcast_ref::() + .unwrap(); + + let table_provider = &plan + .table + .source + .as_any() + .downcast_ref::() + .expect("Default Table Source is expected") + .table_provider; + + let _ = session_ctx + .deregister_table(plan.table.table_name.clone()); + let _ = session_ctx.register_table( + plan.table.table_name.clone(), + table_provider.clone(), + ); + + plan.dml.clone() + } Ok(plan) => plan, Err(e) => { let msg =