Skip to content

Commit

Permalink
feat: add support for insert into
Browse files Browse the repository at this point in the history
  • Loading branch information
milenkovicm committed Feb 9, 2025
1 parent 0823267 commit 1f160d6
Show file tree
Hide file tree
Showing 6 changed files with 273 additions and 70 deletions.
65 changes: 65 additions & 0 deletions ballista/client/tests/context_checks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -399,4 +399,69 @@ mod supported {

assert_batches_eq!(expected, &result);
}

// Verifies that `INSERT INTO <table> SELECT ...` is executed for both the
// standalone and the remote context: the source parquet data is written to a
// temporary directory, registered as `written_table`, and then inserted into
// that table a second time, so every selected row is expected to appear twice.
#[rstest]
#[case::standalone(standalone_context())]
#[case::remote(remote_context())]
#[tokio::test]
#[cfg(not(windows))] // test is failing on Windows; cause not yet investigated
async fn should_support_sql_insert_into(
    #[future(awt)]
    #[case]
    ctx: SessionContext,
    test_data: String,
) {
    ctx.register_parquet(
        "test",
        &format!("{test_data}/alltypes_plain.parquet"),
        Default::default(),
    )
    .await
    .unwrap();

    // `write_dir` must stay alive until the end of the test: dropping the
    // `TempDir` guard removes the directory backing `written_table`.
    let write_dir = tempfile::tempdir().expect("temporary directory to be created");
    let write_dir_path = write_dir
        .path()
        .to_str()
        .expect("path to be converted to str");

    // Seed the target directory with one copy of the source data.
    ctx.sql("select * from test")
        .await
        .unwrap()
        .write_parquet(write_dir_path, Default::default(), Default::default())
        .await
        .unwrap();

    ctx.register_parquet("written_table", write_dir_path, Default::default())
        .await
        .unwrap();

    // The statement under test: append a second copy of the source data.
    ctx.sql("INSERT INTO written_table select * from test")
        .await
        .unwrap()
        .show()
        .await
        .unwrap();

    let result = ctx
        .sql("select id, string_col, timestamp_col from written_table where id > 4 order by id")
        .await
        .unwrap()
        .collect()
        .await
        .unwrap();

    // Cells are padded to the column widths given by the border lines.
    let expected = [
        "+----+------------+---------------------+",
        "| id | string_col | timestamp_col       |",
        "+----+------------+---------------------+",
        "| 5  | 31         | 2009-03-01T00:01:00 |",
        "| 5  | 31         | 2009-03-01T00:01:00 |",
        "| 6  | 30         | 2009-04-01T00:00:00 |",
        "| 6  | 30         | 2009-04-01T00:00:00 |",
        "| 7  | 31         | 2009-04-01T00:01:00 |",
        "| 7  | 31         | 2009-04-01T00:01:00 |",
        "+----+------------+---------------------+",
    ];

    assert_batches_eq!(expected, &result);
}
}
65 changes: 0 additions & 65 deletions ballista/client/tests/context_unsupported.rs
Original file line number Diff line number Diff line change
Expand Up @@ -144,71 +144,6 @@ mod unsupported {
"+----+----------+---------------------+",
];

assert_batches_eq!(expected, &result);
}
// Documents that `INSERT INTO` is not supported yet: the test is expected to
// panic (see the error quoted below) when the DML plan is executed.
#[rstest]
#[case::standalone(standalone_context())]
#[case::remote(remote_context())]
#[tokio::test]
#[should_panic]
// "Error: Internal(failed to serialize logical plan: Internal(LogicalPlan serde is not yet implemented for Dml))"
async fn should_support_sql_insert_into(
    #[future(awt)]
    #[case]
    ctx: SessionContext,
    test_data: String,
) {
    ctx.register_parquet(
        "test",
        &format!("{test_data}/alltypes_plain.parquet"),
        Default::default(),
    )
    .await
    .unwrap();

    // `write_dir` must stay alive until the end of the test: dropping the
    // `TempDir` guard removes the directory backing `written_table`.
    let write_dir = tempfile::tempdir().expect("temporary directory to be created");
    let write_dir_path = write_dir
        .path()
        .to_str()
        .expect("path to be converted to str");

    // Seed the target directory with one copy of the source data.
    ctx.sql("select * from test")
        .await
        .unwrap()
        .write_parquet(write_dir_path, Default::default(), Default::default())
        .await
        .unwrap();

    ctx.register_parquet("written_table", write_dir_path, Default::default())
        .await
        .unwrap();

    // The statement under test: insert the table's own rows back into it.
    let _ = ctx
        .sql("INSERT INTO written_table select * from written_table")
        .await
        .unwrap()
        .collect()
        .await
        .unwrap();

    let result = ctx
        .sql("select id, string_col, timestamp_col from written_table where id > 4 order by id")
        .await
        .unwrap()
        .collect()
        .await
        .unwrap();

    // Cells are padded to the column widths given by the border lines.
    let expected = [
        "+----+------------+---------------------+",
        "| id | string_col | timestamp_col       |",
        "+----+------------+---------------------+",
        "| 5  | 31         | 2009-03-01T00:01:00 |",
        "| 5  | 31         | 2009-03-01T00:01:00 |",
        "| 6  | 30         | 2009-04-01T00:00:00 |",
        "| 6  | 30         | 2009-04-01T00:00:00 |",
        "| 7  | 31         | 2009-04-01T00:01:00 |",
        "| 7  | 31         | 2009-04-01T00:01:00 |",
        "+----+------------+---------------------+",
    ];

    assert_batches_eq!(expected, &result);
}
}
15 changes: 15 additions & 0 deletions ballista/core/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,13 @@ pub const BALLISTA_STANDALONE_PARALLELISM: &str = "ballista.standalone.paralleli
/// Configuration key for the maximum message size used by gRPC clients.
pub const BALLISTA_GRPC_CLIENT_MAX_MESSAGE_SIZE: &str =
"ballista.grpc_client_max_message_size";
/// Configuration key that enables or disables the Ballista DML planner
/// extension.
///
/// When enabled, the planner uses a custom logical-plan DML extension that
/// serializes the table provider referenced by the DML statement.
///
/// This setting should be disabled when using remote schema registries.
pub const BALLISTA_PLANNER_DML_EXTENSION: &str = "ballista.planner.dml_extension";

pub type ParseResult<T> = result::Result<T, String>;
use std::sync::LazyLock;
Expand All @@ -48,6 +55,10 @@ static CONFIG_ENTRIES: LazyLock<HashMap<String, ConfigEntry>> = LazyLock::new(||
"Configuration for max message size in gRPC clients".to_string(),
DataType::UInt64,
Some((16 * 1024 * 1024).to_string())),
ConfigEntry::new(BALLISTA_PLANNER_DML_EXTENSION.to_string(),
"Enable ballista planner DML extension".to_string(),
DataType::Boolean,
Some((true).to_string())),
];
entries
.into_iter()
Expand Down Expand Up @@ -165,6 +176,10 @@ impl BallistaConfig {
self.get_usize_setting(BALLISTA_STANDALONE_PARALLELISM)
}

/// Returns the boolean value of the [`BALLISTA_PLANNER_DML_EXTENSION`]
/// setting, i.e. whether the Ballista DML planner extension is enabled.
pub fn planner_dml_extension(&self) -> bool {
self.get_bool_setting(BALLISTA_PLANNER_DML_EXTENSION)
}

fn get_usize_setting(&self, key: &str) -> usize {
if let Some(v) = self.settings.get(key) {
// infallible because we validate all configs in the constructor
Expand Down
41 changes: 39 additions & 2 deletions ballista/core/src/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,16 @@

use crate::config::BallistaConfig;
use crate::execution_plans::DistributedQueryExec;
use crate::serde::BallistaLogicalExtensionCodec;
use crate::serde::{BallistaDmlExtension, BallistaLogicalExtensionCodec};

use async_trait::async_trait;
use datafusion::arrow::datatypes::Schema;
use datafusion::common::plan_err;
use datafusion::common::tree_node::{TreeNode, TreeNodeVisitor};
use datafusion::datasource::DefaultTableSource;
use datafusion::error::DataFusionError;
use datafusion::execution::context::{QueryPlanner, SessionState};
use datafusion::logical_expr::{LogicalPlan, TableScan};
use datafusion::logical_expr::{DmlStatement, Extension, LogicalPlan, TableScan};
use datafusion::physical_plan::empty::EmptyExec;
use datafusion::physical_plan::ExecutionPlan;
use datafusion::physical_planner::{DefaultPhysicalPlanner, PhysicalPlanner};
Expand Down Expand Up @@ -125,6 +127,41 @@ impl<T: 'static + AsLogicalPlan> QueryPlanner for BallistaQueryPlanner<T> {
log::debug!("create_physical_plan - handling empty exec");
Ok(Arc::new(EmptyExec::new(Arc::new(Schema::empty()))))
}
// At the moment DML statement uses TableReference instead of TableProvider.
// As ballista has two contexts (client and scheduler) scheduler context may not
// know table provider for given table reference, thus we need to attach
// table provider to this DML statement.
LogicalPlan::Dml(DmlStatement { table_name, .. })
if self.config.planner_dml_extension() =>
{
let table_name = table_name.to_owned();
let table = table_name.table().to_string();
let schema = session_state.schema_for_ref(table_name.clone())?;
let table_provider = match schema.table(&table).await? {
Some(ref provider) => Ok(Arc::clone(provider)),
_ => plan_err!("No table named '{table}'"),
}?;

let table_source = Arc::new(DefaultTableSource::new(table_provider));
let table =
TableScan::try_new(table_name, table_source, None, vec![], None)?;

// custom made logical extension node is used to attach table reference
let node = Arc::new(BallistaDmlExtension {
dml: logical_plan.clone(),
table,
});
let plan = LogicalPlan::Extension(Extension { node });
log::debug!("create_physical_plan - handling DML statement");

Ok(Arc::new(DistributedQueryExec::<T>::with_extension(
self.scheduler_url.clone(),
self.config.clone(),
plan.clone(),
self.extension_codec.clone(),
session_state.session_id().to_string(),
)))
}
_ => {
log::debug!("create_physical_plan - handling general statement");

Expand Down
126 changes: 123 additions & 3 deletions ballista/core/src/serde/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,11 @@ use crate::{error::BallistaError, serde::scheduler::Action as BallistaAction};

use arrow_flight::sql::ProstMessageExt;
use datafusion::arrow::datatypes::SchemaRef;
use datafusion::common::{DataFusionError, Result};
use datafusion::common::{plan_err, DataFusionError, Result};
use datafusion::execution::FunctionRegistry;
use datafusion::logical_expr::{
Extension, LogicalPlan, TableScan, UserDefinedLogicalNodeCore,
};
use datafusion::physical_plan::{ExecutionPlan, Partitioning};
use datafusion_proto::logical_plan::file_formats::{
ArrowLogicalExtensionCodec, AvroLogicalExtensionCodec, CsvLogicalExtensionCodec,
Expand Down Expand Up @@ -179,15 +182,64 @@ impl LogicalExtensionCodec for BallistaLogicalExtensionCodec {
inputs: &[datafusion::logical_expr::LogicalPlan],
ctx: &datafusion::prelude::SessionContext,
) -> Result<datafusion::logical_expr::Extension> {
self.default_codec.try_decode(buf, inputs, ctx)
match BallistaExtensionProto::decode(buf) {
Ok(extension) => match extension.extension {
Some(BallistaExtensionType::Dml(BallistaDmlExtensionProto {
dml: Some(dml),
table: Some(table),
})) => {
let table = table.try_into_logical_plan(ctx, self)?;
match table {
LogicalPlan::TableScan(scan) => {
let dml = dml.try_into_logical_plan(ctx, self)?;
Ok(Extension {
node: Arc::new(BallistaDmlExtension { dml, table: scan }),
})
}
_ => plan_err!(
"TableScan expected in ballista DML extension definition"
),
}
}
None => plan_err!("Ballista extension can't be None"),
_ => plan_err!("Ballista extension not supported"),
},

Err(_e) => self.default_codec.try_decode(buf, inputs, ctx),
}
}

/// Encodes a logical extension node into `buf`.
///
/// A [`BallistaDmlExtension`] is written as a `BallistaExtensionProto`
/// holding the serialized DML plan and the serialized table scan; any other
/// node is handed to the default codec.
fn try_encode(
&self,
node: &datafusion::logical_expr::Extension,
buf: &mut Vec<u8>,
) -> Result<()> {
// NOTE(review): the next line looks like the pre-change body left in place
// by the diff view (the same call appears in the `else` branch) — confirm.
self.default_codec.try_encode(node, buf)
if let Some(BallistaDmlExtension { dml: input, table }) =
node.node.as_any().downcast_ref::<BallistaDmlExtension>()
{
// Serialize the wrapped DML plan with this codec.
let input = LogicalPlanNode::try_from_logical_plan(input, self)?;

// The table scan is wrapped back into a `LogicalPlan` so that it can be
// serialized through the same logical-plan path.
let table = LogicalPlanNode::try_from_logical_plan(
&LogicalPlan::TableScan(table.clone()),
self,
)?;
let extension = BallistaDmlExtensionProto {
dml: Some(input),
table: Some(table),
};

let extension = BallistaExtensionProto {
extension: Some(BallistaExtensionType::Dml(extension)),
};

// Encoding failures are reported as `DataFusionError::Execution`.
extension
.encode(buf)
.map_err(|e| DataFusionError::Execution(e.to_string()))?;

Ok(())
} else {
// Not a Ballista DML extension: delegate to the default codec.
self.default_codec.try_encode(node, buf)
}
}

fn try_decode_table_provider(
Expand Down Expand Up @@ -487,6 +539,74 @@ struct FileFormatProto {
pub blob: Vec<u8>,
}

/// Protobuf envelope for Ballista-specific logical extension nodes.
#[derive(Clone, PartialEq, prost::Message)]
struct BallistaExtensionProto {
/// The concrete extension payload (a oneof with a single tag, `1`).
#[prost(oneof = "BallistaExtensionType", tags = "1")]
extension: Option<BallistaExtensionType>,
}

/// Payload variants that can be carried by `BallistaExtensionProto`.
#[derive(Clone, PartialEq, ::prost::Oneof)]
enum BallistaExtensionType {
/// DML extension payload (tag `1`).
#[prost(message, tag = "1")]
Dml(BallistaDmlExtensionProto),
}

/// Protobuf representation of a `BallistaDmlExtension`: both parts are stored
/// as serialized logical plan nodes.
#[derive(Clone, PartialEq, prost::Message)]
struct BallistaDmlExtensionProto {
/// Serialized DML plan (tag `1`).
#[prost(message, tag = 1)]
pub dml: Option<LogicalPlanNode>,
/// Serialized table scan that accompanies the DML plan (tag `2`).
#[prost(message, tag = 2)]
pub table: Option<LogicalPlanNode>,
}

/// Logical extension node that pairs a DML plan with a scan of the table the
/// DML statement refers to.
#[derive(Debug, Hash, PartialEq, Eq, Clone)]
pub struct BallistaDmlExtension {
/// The wrapped plan; expected to be a `LogicalPlan::Dml` holding a
/// `DmlStatement`.
pub dml: LogicalPlan,
/// Table scan carrying the table provider that is referenced
/// from the DML plan.
pub table: TableScan,
}

/// Orders extension nodes by their DML plan first and by their table scan
/// second.
///
/// Including `table` keeps the ordering in line with the derived `PartialEq`,
/// which compares both fields: two nodes that differ only in `table` must not
/// be reported as `Ordering::Equal`.
impl std::cmp::PartialOrd for BallistaDmlExtension {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        match self.dml.partial_cmp(&other.dml) {
            // Same DML plan: fall back to the table scan as tie-breaker.
            Some(std::cmp::Ordering::Equal) => self.table.partial_cmp(&other.table),
            ordering => ordering,
        }
    }
}
impl UserDefinedLogicalNodeCore for BallistaDmlExtension {
fn name(&self) -> &str {
"BallistaDmlExtension"
}

fn inputs(&self) -> Vec<&datafusion::logical_expr::LogicalPlan> {
vec![&self.dml]
}

fn schema(&self) -> &datafusion::common::DFSchemaRef {
self.dml.schema()
}

fn expressions(&self) -> Vec<datafusion::prelude::Expr> {
self.dml.expressions()
}

fn fmt_for_explain(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
self.dml.fmt(f)
}

fn with_exprs_and_inputs(
&self,
_exprs: Vec<datafusion::prelude::Expr>,
inputs: Vec<datafusion::logical_expr::LogicalPlan>,
) -> Result<Self> {
Ok(Self {
dml: inputs[0].clone(),
table: self.table.clone(),
})
}
}

#[cfg(test)]
mod test {
use super::*;
Expand Down
Loading

0 comments on commit 1f160d6

Please sign in to comment.