From d518dd8e38140c1505b33da1ac7605a379b6e3bd Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Marko=20Milenkovi=C4=87?=
Date: Sun, 9 Feb 2025 08:41:06 +0000
Subject: [PATCH] feat: add support for insert into

---
 ballista/client/tests/context_checks.rs      | 65 +++++++++++++
 ballista/client/tests/context_unsupported.rs | 65 -------------
 ballista/core/src/planner.rs                 | 36 ++++++-
 ballista/core/src/serde/mod.rs               | 94 ++++++++++++++++++-
 .../scheduler/src/scheduler_server/grpc.rs   | 27 ++++++
 5 files changed, 218 insertions(+), 69 deletions(-)

diff --git a/ballista/client/tests/context_checks.rs b/ballista/client/tests/context_checks.rs
index ac83e3bdf..2601293de 100644
--- a/ballista/client/tests/context_checks.rs
+++ b/ballista/client/tests/context_checks.rs
@@ -399,4 +399,69 @@ mod supported {
 
         assert_batches_eq!(expected, &result);
     }
+
+    #[rstest]
+    #[case::standalone(standalone_context())]
+    #[case::remote(remote_context())]
+    #[tokio::test]
+    #[cfg(not(windows))] // test is failing on Windows, can't debug it there
+    async fn should_support_sql_insert_into(
+        #[future(awt)]
+        #[case]
+        ctx: SessionContext,
+        test_data: String,
+    ) {
+        ctx.register_parquet(
+            "test",
+            &format!("{test_data}/alltypes_plain.parquet"),
+            Default::default(),
+        )
+        .await
+        .unwrap();
+
+        let write_dir = tempfile::tempdir().expect("temporary directory to be created");
+        let write_dir_path = write_dir
+            .path()
+            .to_str()
+            .expect("path to be converted to str");
+
+        ctx.sql("select * from test")
+            .await
+            .unwrap()
+            .write_parquet(write_dir_path, Default::default(), Default::default())
+            .await
+            .unwrap();
+
+        ctx.register_parquet("written_table", write_dir_path, Default::default())
+            .await
+            .unwrap();
+
+        ctx.sql("INSERT INTO written_table select * from test")
+            .await
+            .unwrap()
+            .show()
+            .await
+            .unwrap();
+
+        let result = ctx
+            .sql("select id, string_col, timestamp_col from written_table where id > 4 order by id")
+            .await.unwrap()
+            .collect()
+            .await.unwrap();
+
+        let expected = [
+            "+----+------------+---------------------+",
+            "| id | string_col | timestamp_col       |",
+            "+----+------------+---------------------+",
+            "| 5  | 31         | 2009-03-01T00:01:00 |",
+            "| 5  | 31         | 2009-03-01T00:01:00 |",
+            "| 6  | 30         | 2009-04-01T00:00:00 |",
+            "| 6  | 30         | 2009-04-01T00:00:00 |",
+            "| 7  | 31         | 2009-04-01T00:01:00 |",
+            "| 7  | 31         | 2009-04-01T00:01:00 |",
+            "+----+------------+---------------------+",
+        ];
+
+        assert_batches_eq!(expected, &result);
+    }
 }
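The new test drives INSERT INTO through the regular DataFusion SQL API, with Ballista handling the distributed planning. Outside the rstest harness the same flow looks roughly like the sketch below. This is a sketch only: it assumes the SessionContextExt entry points used by the standalone_context()/remote_context() fixtures above, and the paths and table names are illustrative rather than taken from this patch.

    // Sketch only: assumes ballista's SessionContextExt (the API behind the test
    // fixtures above); paths and table names are illustrative.
    use ballista::prelude::*;
    use datafusion::prelude::SessionContext;

    #[tokio::main]
    async fn main() -> datafusion::error::Result<()> {
        // Embedded scheduler/executor; SessionContext::remote("df://host:port")
        // would target an external cluster instead.
        let ctx: SessionContext = SessionContext::standalone().await?;

        ctx.register_parquet("source", "./data/alltypes_plain.parquet", Default::default())
            .await?;

        // Materialize an initial copy of the data, then register it as the writable target.
        ctx.sql("SELECT * FROM source")
            .await?
            .write_parquet("./data/target", Default::default(), Default::default())
            .await?;
        ctx.register_parquet("target", "./data/target", Default::default())
            .await?;

        // With this patch the INSERT is planned, serialized and executed via the scheduler.
        ctx.sql("INSERT INTO target SELECT * FROM source")
            .await?
            .show()
            .await?;

        Ok(())
    }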
diff --git a/ballista/client/tests/context_unsupported.rs b/ballista/client/tests/context_unsupported.rs
index 805e81325..aa9827993 100644
--- a/ballista/client/tests/context_unsupported.rs
+++ b/ballista/client/tests/context_unsupported.rs
@@ -144,71 +144,6 @@ mod unsupported {
             "+----+----------+---------------------+",
         ];
 
-        assert_batches_eq!(expected, &result);
-    }
-    #[rstest]
-    #[case::standalone(standalone_context())]
-    #[case::remote(remote_context())]
-    #[tokio::test]
-    #[should_panic]
-    // "Error: Internal(failed to serialize logical plan: Internal(LogicalPlan serde is not yet implemented for Dml))"
-    async fn should_support_sql_insert_into(
-        #[future(awt)]
-        #[case]
-        ctx: SessionContext,
-        test_data: String,
-    ) {
-        ctx.register_parquet(
-            "test",
-            &format!("{test_data}/alltypes_plain.parquet"),
-            Default::default(),
-        )
-        .await
-        .unwrap();
-        let write_dir = tempfile::tempdir().expect("temporary directory to be created");
-        let write_dir_path = write_dir
-            .path()
-            .to_str()
-            .expect("path to be converted to str");
-
-        ctx.sql("select * from test")
-            .await
-            .unwrap()
-            .write_parquet(write_dir_path, Default::default(), Default::default())
-            .await
-            .unwrap();
-
-        ctx.register_parquet("written_table", write_dir_path, Default::default())
-            .await
-            .unwrap();
-
-        let _ = ctx
-            .sql("INSERT INTO written_table select * from written_table")
-            .await
-            .unwrap()
-            .collect()
-            .await
-            .unwrap();
-
-        let result = ctx
-            .sql("select id, string_col, timestamp_col from written_table where id > 4 order by id")
-            .await.unwrap()
-            .collect()
-            .await.unwrap();
-
-        let expected = [
-            "+----+------------+---------------------+",
-            "| id | string_col | timestamp_col       |",
-            "+----+------------+---------------------+",
-            "| 5  | 31         | 2009-03-01T00:01:00 |",
-            "| 5  | 31         | 2009-03-01T00:01:00 |",
-            "| 6  | 30         | 2009-04-01T00:00:00 |",
-            "| 6  | 30         | 2009-04-01T00:00:00 |",
-            "| 7  | 31         | 2009-04-01T00:01:00 |",
-            "| 7  | 31         | 2009-04-01T00:01:00 |",
-            "+----+------------+---------------------+",
-        ];
-
         assert_batches_eq!(expected, &result);
     }
 }
diff --git a/ballista/core/src/planner.rs b/ballista/core/src/planner.rs
index 266da3c6e..f3f0ad975 100644
--- a/ballista/core/src/planner.rs
+++ b/ballista/core/src/planner.rs
@@ -17,14 +17,16 @@
 
 use crate::config::BallistaConfig;
 use crate::execution_plans::DistributedQueryExec;
-use crate::serde::BallistaLogicalExtensionCodec;
+use crate::serde::{BallistaDml, BallistaLogicalExtensionCodec};
 
 use async_trait::async_trait;
 use datafusion::arrow::datatypes::Schema;
+use datafusion::common::plan_err;
 use datafusion::common::tree_node::{TreeNode, TreeNodeVisitor};
+use datafusion::datasource::DefaultTableSource;
 use datafusion::error::DataFusionError;
 use datafusion::execution::context::{QueryPlanner, SessionState};
-use datafusion::logical_expr::{LogicalPlan, TableScan};
+use datafusion::logical_expr::{DmlStatement, Extension, LogicalPlan, TableScan};
 use datafusion::physical_plan::empty::EmptyExec;
 use datafusion::physical_plan::ExecutionPlan;
 use datafusion::physical_planner::{DefaultPhysicalPlanner, PhysicalPlanner};
@@ -125,6 +127,36 @@ impl<T: 'static + AsLogicalPlan> QueryPlanner for BallistaQueryPlanner<T> {
                 log::debug!("create_physical_plan - handling empty exec");
                 Ok(Arc::new(EmptyExec::new(Arc::new(Schema::empty()))))
             }
+            // TODO: MM make this translation configurable so it can be disabled
+            //       when a shared schema registry is available
+            LogicalPlan::Dml(DmlStatement { table_name, .. }) => {
+                let table_name = table_name.to_owned();
+                let table = table_name.table().to_string();
+                let schema = session_state.schema_for_ref(table_name.clone())?;
+                let table_provider = match schema.table(&table).await? {
+                    Some(ref provider) => Ok(Arc::clone(provider)),
+                    _ => plan_err!("No table named '{table}'"),
+                }?;
+
+                let table_source = Arc::new(DefaultTableSource::new(table_provider));
+                let table =
+                    TableScan::try_new(table_name, table_source, None, vec![], None)?;
+
+                let node = Arc::new(BallistaDml {
+                    input: logical_plan.clone(),
+                    table,
+                });
+                let plan = LogicalPlan::Extension(Extension { node });
+                log::debug!("create_physical_plan - handling DML statement");
+
+                Ok(Arc::new(DistributedQueryExec::<T>::with_extension(
+                    self.scheduler_url.clone(),
+                    self.config.clone(),
+                    plan.clone(),
+                    self.extension_codec.clone(),
+                    session_state.session_id().to_string(),
+                )))
+            }
             _ => {
                 log::debug!("create_physical_plan - handling general statement");
 
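The essential move in the planner change is that the client resolves the INSERT target itself and ships the resulting TableProvider to the scheduler by embedding it, wrapped in a DefaultTableSource, inside a TableScan node attached to the BallistaDml extension. That wrapping step in isolation looks like the sketch below; the function name is illustrative, not part of the patch.

    // Sketch of the wrapping step used in the Dml branch above: embed a client-side
    // TableProvider in a TableScan (via DefaultTableSource) so it travels with the
    // serialized plan to the scheduler.
    use std::sync::Arc;
    use datafusion::datasource::{DefaultTableSource, TableProvider};
    use datafusion::error::Result;
    use datafusion::logical_expr::TableScan;

    fn scan_carrying_provider(
        table_name: &str,
        provider: Arc<dyn TableProvider>,
    ) -> Result<TableScan> {
        let source = Arc::new(DefaultTableSource::new(provider));
        // No projection, no filters, no limit: the scan exists only to carry the provider.
        TableScan::try_new(table_name, source, None, vec![], None)
    }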
diff --git a/ballista/core/src/serde/mod.rs b/ballista/core/src/serde/mod.rs
index 84cf80684..9f51d5fac 100644
--- a/ballista/core/src/serde/mod.rs
+++ b/ballista/core/src/serde/mod.rs
@@ -24,6 +24,9 @@ use arrow_flight::sql::ProstMessageExt;
 use datafusion::arrow::datatypes::SchemaRef;
 use datafusion::common::{DataFusionError, Result};
 use datafusion::execution::FunctionRegistry;
+use datafusion::logical_expr::{
+    Extension, LogicalPlan, TableScan, UserDefinedLogicalNodeCore,
+};
 use datafusion::physical_plan::{ExecutionPlan, Partitioning};
 use datafusion_proto::logical_plan::file_formats::{
     ArrowLogicalExtensionCodec, AvroLogicalExtensionCodec, CsvLogicalExtensionCodec,
@@ -179,7 +182,22 @@ impl LogicalExtensionCodec for BallistaLogicalExtensionCodec {
         inputs: &[datafusion::logical_expr::LogicalPlan],
         ctx: &datafusion::prelude::SessionContext,
     ) -> Result<Extension> {
-        self.default_codec.try_decode(buf, inputs, ctx)
+        let e = ExtensionBallistaDml::decode(buf)
+            .map_err(|e| DataFusionError::Execution(e.to_string()))?;
+
+        let input = e.input.unwrap();
+        let table = e.table.unwrap();
+        let input = input.try_into_logical_plan(ctx, self)?;
+
+        let table = table.try_into_logical_plan(ctx, self)?;
+
+        if let LogicalPlan::TableScan(table) = table {
+            Ok(Extension {
+                node: Arc::new(BallistaDml { input, table }),
+            })
+        } else {
+            self.default_codec.try_decode(buf, inputs, ctx)
+        }
     }
 
@@ -187,7 +205,28 @@
     fn try_encode(
         &self,
         node: &datafusion::logical_expr::Extension,
         buf: &mut Vec<u8>,
     ) -> Result<()> {
-        self.default_codec.try_encode(node, buf)
+        if let Some(BallistaDml { input, table }) =
+            node.node.as_any().downcast_ref::<BallistaDml>()
+        {
+            let input = LogicalPlanNode::try_from_logical_plan(input, self)?;
+
+            let table = LogicalPlanNode::try_from_logical_plan(
+                &LogicalPlan::TableScan(table.clone()),
+                self,
+            )?;
+            let extension = ExtensionBallistaDml {
+                input: Some(input),
+                table: Some(table),
+            };
+
+            extension
+                .encode(buf)
+                .map_err(|e| DataFusionError::Execution(e.to_string()))?;
+
+            Ok(())
+        } else {
+            self.default_codec.try_encode(node, buf)
+        }
     }
 
     fn try_decode_table_provider(
@@ -486,6 +525,57 @@ struct FileFormatProto {
     #[prost(bytes, tag = 2)]
     pub blob: Vec<u8>,
 }
+#[derive(Clone, PartialEq, prost::Message)]
+struct ExtensionBallistaDml {
+    #[prost(message, tag = 1)]
+    pub input: Option<LogicalPlanNode>,
+    #[prost(message, tag = 2)]
+    pub table: Option<LogicalPlanNode>,
+}
+
+#[derive(Debug, Hash, PartialEq, Eq)]
+pub struct BallistaDml {
+    pub input: LogicalPlan,
+    pub table: TableScan,
+}
+
+impl std::cmp::PartialOrd for BallistaDml {
+    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
+        self.input.partial_cmp(&other.input)
+    }
+}
+impl UserDefinedLogicalNodeCore for BallistaDml {
+    fn name(&self) -> &str {
+        "BallistaDml"
+    }
+
+    fn inputs(&self) -> Vec<&datafusion::logical_expr::LogicalPlan> {
+        vec![&self.input]
+    }
+
+    fn schema(&self) -> &datafusion::common::DFSchemaRef {
+        self.input.schema()
+    }
+
+    fn expressions(&self) -> Vec<datafusion::logical_expr::Expr> {
+        self.input.expressions()
+    }
+
+    fn fmt_for_explain(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
+        self.input.fmt(f)
+    }
+
+    fn with_exprs_and_inputs(
+        &self,
+        _exprs: Vec<datafusion::logical_expr::Expr>,
+        inputs: Vec<datafusion::logical_expr::LogicalPlan>,
+    ) -> Result<Self> {
+        Ok(Self {
+            input: inputs[0].clone(),
+            table: self.table.clone(),
+        })
+    }
+}
 
 #[cfg(test)]
 mod test {
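ExtensionBallistaDml above is a prost message declared directly in Rust, so no .proto file or build step is needed; it is just an envelope carrying two pre-serialized LogicalPlanNode blobs (the DML input plan and the target table scan). The pattern in isolation, round-tripped the same way the codec does it, is sketched below; Envelope and Payload are illustrative stand-ins, not types from this patch.

    // Sketch of the prost pattern behind ExtensionBallistaDml: an envelope message
    // declared directly in Rust and round-tripped with Message::encode / decode.
    use prost::Message;

    #[derive(Clone, PartialEq, prost::Message)]
    struct Payload {
        #[prost(string, tag = 1)]
        plan: String,
    }

    #[derive(Clone, PartialEq, prost::Message)]
    struct Envelope {
        #[prost(message, tag = 1)]
        input: Option<Payload>,
        #[prost(message, tag = 2)]
        table: Option<Payload>,
    }

    fn main() -> Result<(), prost::DecodeError> {
        let original = Envelope {
            input: Some(Payload { plan: "serialized input plan".into() }),
            table: Some(Payload { plan: "serialized target table scan".into() }),
        };

        // Same calls the codec uses above: encode into a byte buffer, decode on the other side.
        let mut buf = Vec::new();
        original.encode(&mut buf).expect("encoding into a Vec cannot run out of space");
        let decoded = Envelope::decode(buf.as_slice())?;
        assert_eq!(original, decoded);
        Ok(())
    }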
diff --git a/ballista/scheduler/src/scheduler_server/grpc.rs b/ballista/scheduler/src/scheduler_server/grpc.rs
index 02c21a884..4302b78f1 100644
--- a/ballista/scheduler/src/scheduler_server/grpc.rs
+++ b/ballista/scheduler/src/scheduler_server/grpc.rs
@@ -32,6 +32,9 @@ use ballista_core::serde::protobuf::{
     UpdateTaskStatusParams, UpdateTaskStatusResult,
 };
 use ballista_core::serde::scheduler::ExecutorMetadata;
+use ballista_core::serde::BallistaDml;
+use datafusion::datasource::DefaultTableSource;
+use datafusion::logical_expr::LogicalPlan;
 use datafusion_proto::logical_plan::AsLogicalPlan;
 use datafusion_proto::physical_plan::AsExecutionPlan;
 use log::{debug, error, info, trace, warn};
@@ -446,6 +449,30 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> SchedulerGrpc
             }
         };
 
+        let plan = match plan {
+            LogicalPlan::Extension(e)
+                if e.node.as_any().downcast_ref::<BallistaDml>().is_some() =>
+            {
+                let plan = e.node.as_any().downcast_ref::<BallistaDml>().unwrap();
+                let table_provider = &plan
+                    .table
+                    .source
+                    .as_any()
+                    .downcast_ref::<DefaultTableSource>()
+                    .expect("dts")
+                    .table_provider;
+
+                let _ = session_ctx.deregister_table(plan.table.table_name.clone());
+                let _ = session_ctx.register_table(
+                    plan.table.table_name.clone(),
+                    table_provider.clone(),
+                );
+
+                plan.input.clone()
+            }
+            _ => plan,
+        };
+
         debug!(
             "Decoded logical plan for execution:\n{}",
             plan.display_indent()
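On the scheduler side, the handler above swaps whatever is registered under the target name for the provider that arrived inside the plan, then continues with the unwrapped DML input. That registration swap in isolation is sketched below; replace_table and the MemTable-backed provider are illustrative, not scheduler code.

    // Sketch of the registration swap performed by the scheduler: replace whatever is
    // registered under a name with the provider that was shipped inside the plan.
    use std::sync::Arc;
    use datafusion::arrow::datatypes::{DataType, Field, Schema};
    use datafusion::datasource::{MemTable, TableProvider};
    use datafusion::error::Result;
    use datafusion::prelude::SessionContext;

    fn replace_table(
        ctx: &SessionContext,
        name: &str,
        provider: Arc<dyn TableProvider>,
    ) -> Result<()> {
        // Ignore the result, as the grpc handler does: the name may not be registered yet.
        let _ = ctx.deregister_table(name);
        ctx.register_table(name, provider)?;
        Ok(())
    }

    fn main() -> Result<()> {
        let ctx = SessionContext::new();
        let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)]));
        let provider = Arc::new(MemTable::try_new(schema, vec![vec![]])?);
        replace_table(&ctx, "written_table", provider)?;
        Ok(())
    }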