Commit d518dd8
feat: add support for insert into
milenkovicm committed Feb 9, 2025
1 parent 0823267 commit d518dd8
Showing 5 changed files with 218 additions and 69 deletions.
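
For orientation, a minimal sketch of the client-side usage this commit enables, condensed from the new test in ballista/client/tests/context_checks.rs. The function name, file paths, and table names below are illustrative, and ctx is assumed to be a Ballista-backed DataFusion SessionContext (for example, one produced by the standalone_context() fixture used in the tests):

use datafusion::prelude::SessionContext;

// Sketch only: `ctx` is assumed to be a Ballista-backed SessionContext.
async fn insert_into_example(ctx: &SessionContext) -> datafusion::error::Result<()> {
    // Source table backed by an existing parquet file (path is illustrative).
    ctx.register_parquet("test", "testdata/alltypes_plain.parquet", Default::default())
        .await?;
    // Target table backed by a directory previously populated via write_parquet.
    ctx.register_parquet("written_table", "/tmp/written_table", Default::default())
        .await?;

    // Before this commit the distributed planner could not handle the DML plan
    // (the removed test documented "LogicalPlan serde is not yet implemented
    // for Dml"); it is now wrapped in a BallistaDml extension node and executed
    // through the scheduler.
    ctx.sql("INSERT INTO written_table SELECT * FROM test")
        .await?
        .collect()
        .await?;
    Ok(())
}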
65 changes: 65 additions & 0 deletions ballista/client/tests/context_checks.rs
@@ -399,4 +399,69 @@ mod supported {

assert_batches_eq!(expected, &result);
}

#[rstest]
#[case::standalone(standalone_context())]
#[case::remote(remote_context())]
#[tokio::test]
    #[cfg(not(windows))] // test is failing on Windows; not yet debugged
async fn should_support_sql_insert_into(
#[future(awt)]
#[case]
ctx: SessionContext,
test_data: String,
) {
ctx.register_parquet(
"test",
&format!("{test_data}/alltypes_plain.parquet"),
Default::default(),
)
.await
.unwrap();

let write_dir = tempfile::tempdir().expect("temporary directory to be created");
let write_dir_path = write_dir
.path()
.to_str()
.expect("path to be converted to str");

ctx.sql("select * from test")
.await
.unwrap()
.write_parquet(write_dir_path, Default::default(), Default::default())
.await
.unwrap();

ctx.register_parquet("written_table", write_dir_path, Default::default())
.await
.unwrap();

ctx.sql("INSERT INTO written_table select * from test")
.await
.unwrap()
.show()
.await
.unwrap();

let result = ctx
.sql("select id, string_col, timestamp_col from written_table where id > 4 order by id")
.await.unwrap()
.collect()
.await.unwrap();

let expected = [
"+----+------------+---------------------+",
"| id | string_col | timestamp_col |",
"+----+------------+---------------------+",
"| 5 | 31 | 2009-03-01T00:01:00 |",
"| 5 | 31 | 2009-03-01T00:01:00 |",
"| 6 | 30 | 2009-04-01T00:00:00 |",
"| 6 | 30 | 2009-04-01T00:00:00 |",
"| 7 | 31 | 2009-04-01T00:01:00 |",
"| 7 | 31 | 2009-04-01T00:01:00 |",
"+----+------------+---------------------+",
];

assert_batches_eq!(expected, &result);
}
}
65 changes: 0 additions & 65 deletions ballista/client/tests/context_unsupported.rs
@@ -144,71 +144,6 @@ mod unsupported {
"+----+----------+---------------------+",
];

assert_batches_eq!(expected, &result);
}
#[rstest]
#[case::standalone(standalone_context())]
#[case::remote(remote_context())]
#[tokio::test]
#[should_panic]
// "Error: Internal(failed to serialize logical plan: Internal(LogicalPlan serde is not yet implemented for Dml))"
async fn should_support_sql_insert_into(
#[future(awt)]
#[case]
ctx: SessionContext,
test_data: String,
) {
ctx.register_parquet(
"test",
&format!("{test_data}/alltypes_plain.parquet"),
Default::default(),
)
.await
.unwrap();
let write_dir = tempfile::tempdir().expect("temporary directory to be created");
let write_dir_path = write_dir
.path()
.to_str()
.expect("path to be converted to str");

ctx.sql("select * from test")
.await
.unwrap()
.write_parquet(write_dir_path, Default::default(), Default::default())
.await
.unwrap();

ctx.register_parquet("written_table", write_dir_path, Default::default())
.await
.unwrap();

let _ = ctx
.sql("INSERT INTO written_table select * from written_table")
.await
.unwrap()
.collect()
.await
.unwrap();

let result = ctx
.sql("select id, string_col, timestamp_col from written_table where id > 4 order by id")
.await.unwrap()
.collect()
.await.unwrap();

let expected = [
"+----+------------+---------------------+",
"| id | string_col | timestamp_col |",
"+----+------------+---------------------+",
"| 5 | 31 | 2009-03-01T00:01:00 |",
"| 5 | 31 | 2009-03-01T00:01:00 |",
"| 6 | 30 | 2009-04-01T00:00:00 |",
"| 6 | 30 | 2009-04-01T00:00:00 |",
"| 7 | 31 | 2009-04-01T00:01:00 |",
"| 7 | 31 | 2009-04-01T00:01:00 |",
"+----+------------+---------------------+",
];

assert_batches_eq!(expected, &result);
}
}
36 changes: 34 additions & 2 deletions ballista/core/src/planner.rs
@@ -17,14 +17,16 @@

use crate::config::BallistaConfig;
use crate::execution_plans::DistributedQueryExec;
use crate::serde::BallistaLogicalExtensionCodec;
use crate::serde::{BallistaDml, BallistaLogicalExtensionCodec};

use async_trait::async_trait;
use datafusion::arrow::datatypes::Schema;
use datafusion::common::plan_err;
use datafusion::common::tree_node::{TreeNode, TreeNodeVisitor};
use datafusion::datasource::DefaultTableSource;
use datafusion::error::DataFusionError;
use datafusion::execution::context::{QueryPlanner, SessionState};
use datafusion::logical_expr::{LogicalPlan, TableScan};
use datafusion::logical_expr::{DmlStatement, Extension, LogicalPlan, TableScan};
use datafusion::physical_plan::empty::EmptyExec;
use datafusion::physical_plan::ExecutionPlan;
use datafusion::physical_planner::{DefaultPhysicalPlanner, PhysicalPlanner};
@@ -125,6 +127,36 @@ impl<T: 'static + AsLogicalPlan> QueryPlanner for BallistaQueryPlanner<T> {
log::debug!("create_physical_plan - handling empty exec");
Ok(Arc::new(EmptyExec::new(Arc::new(Schema::empty()))))
}
                // TODO: MM add configuration to disable this translation
                // in case we have a shared schema registry
LogicalPlan::Dml(DmlStatement { table_name, .. }) => {
let table_name = table_name.to_owned();
let table = table_name.table().to_string();
let schema = session_state.schema_for_ref(table_name.clone())?;
let table_provider = match schema.table(&table).await? {
Some(ref provider) => Ok(Arc::clone(provider)),
_ => plan_err!("No table named '{table}'"),
}?;

let table_source = Arc::new(DefaultTableSource::new(table_provider));
let table =
TableScan::try_new(table_name, table_source, None, vec![], None)?;

let node = Arc::new(BallistaDml {
input: logical_plan.clone(),
table,
});
let plan = LogicalPlan::Extension(Extension { node });
log::debug!("create_physical_plan - handling DML statement");

Ok(Arc::new(DistributedQueryExec::<T>::with_extension(
self.scheduler_url.clone(),
self.config.clone(),
plan.clone(),
self.extension_codec.clone(),
session_state.session_id().to_string(),
)))
}
_ => {
log::debug!("create_physical_plan - handling general statement");

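To summarize the new branch above: a LogicalPlan::Dml statement is wrapped, together with a freshly built TableScan of the insert target, into the BallistaDml extension node (defined in serde/mod.rs below) and shipped to the scheduler via DistributedQueryExec. An approximate, illustrative sketch of the resulting plan shape; the query and table names are not taken from the commit:

// Illustrative only: rough plan shape produced for
// INSERT INTO written_table SELECT * FROM test
//
//   Extension: BallistaDml
//     input: the original Dml plan (Insert Into written_table),
//            including its unchanged SELECT input over table `test`
//     table: TableScan of written_table, rebuilt from the provider looked up
//            via session_state.schema_for_ref, so the scheduler side can
//            re-register the target table under its name
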
94 changes: 92 additions & 2 deletions ballista/core/src/serde/mod.rs
@@ -24,6 +24,9 @@ use arrow_flight::sql::ProstMessageExt;
use datafusion::arrow::datatypes::SchemaRef;
use datafusion::common::{DataFusionError, Result};
use datafusion::execution::FunctionRegistry;
use datafusion::logical_expr::{
Extension, LogicalPlan, TableScan, UserDefinedLogicalNodeCore,
};
use datafusion::physical_plan::{ExecutionPlan, Partitioning};
use datafusion_proto::logical_plan::file_formats::{
ArrowLogicalExtensionCodec, AvroLogicalExtensionCodec, CsvLogicalExtensionCodec,
@@ -179,15 +182,51 @@ impl LogicalExtensionCodec for BallistaLogicalExtensionCodec {
inputs: &[datafusion::logical_expr::LogicalPlan],
ctx: &datafusion::prelude::SessionContext,
) -> Result<datafusion::logical_expr::Extension> {
self.default_codec.try_decode(buf, inputs, ctx)
let e = ExtensionBallistaDml::decode(buf)
.map_err(|e| DataFusionError::Execution(e.to_string()))?;

let input = e.input.unwrap();
let table = e.table.unwrap();
let input = input.try_into_logical_plan(ctx, self)?;

let table = table.try_into_logical_plan(ctx, self)?;

if let LogicalPlan::TableScan(table) = table {
Ok(Extension {
node: Arc::new(BallistaDml { input, table }),
})
} else {
self.default_codec.try_decode(buf, inputs, ctx)
}
}

fn try_encode(
&self,
node: &datafusion::logical_expr::Extension,
buf: &mut Vec<u8>,
) -> Result<()> {
self.default_codec.try_encode(node, buf)
if let Some(BallistaDml { input, table }) =
node.node.as_any().downcast_ref::<BallistaDml>()
{
let input = LogicalPlanNode::try_from_logical_plan(input, self)?;

let table = LogicalPlanNode::try_from_logical_plan(
&LogicalPlan::TableScan(table.clone()),
self,
)?;
let extension = ExtensionBallistaDml {
input: Some(input),
table: Some(table),
};

extension
.encode(buf)
.map_err(|e| DataFusionError::Execution(e.to_string()))?;

Ok(())
} else {
self.default_codec.try_encode(node, buf)
}
}

fn try_decode_table_provider(
@@ -486,6 +525,57 @@ struct FileFormatProto {
#[prost(bytes, tag = 2)]
pub blob: Vec<u8>,
}
#[derive(Clone, PartialEq, prost::Message)]
struct ExtensionBallistaDml {
#[prost(message, tag = 1)]
pub input: Option<LogicalPlanNode>,
#[prost(message, tag = 2)]
pub table: Option<LogicalPlanNode>,
}

#[derive(Debug, Hash, PartialEq, Eq)]
pub struct BallistaDml {
pub input: LogicalPlan,
pub table: TableScan,
}

impl std::cmp::PartialOrd for BallistaDml {
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
self.input.partial_cmp(&other.input)
}
}
impl UserDefinedLogicalNodeCore for BallistaDml {
fn name(&self) -> &str {
"BallistaDml"
}

fn inputs(&self) -> Vec<&datafusion::logical_expr::LogicalPlan> {
vec![&self.input]
}

fn schema(&self) -> &datafusion::common::DFSchemaRef {
self.input.schema()
}

fn expressions(&self) -> Vec<datafusion::prelude::Expr> {
self.input.expressions()
}

fn fmt_for_explain(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
self.input.fmt(f)
}

fn with_exprs_and_inputs(
&self,
_exprs: Vec<datafusion::prelude::Expr>,
inputs: Vec<datafusion::logical_expr::LogicalPlan>,
) -> Result<Self> {
Ok(Self {
input: inputs[0].clone(),
table: self.table.clone(),
})
}
}

#[cfg(test)]
mod test {
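The codec changes above let a BallistaDml node round-trip through protobuf like any other extension. A minimal sketch of that round trip, assuming BallistaLogicalExtensionCodec provides a Default constructor and that the tables referenced by the plan are already registered in the context:

use ballista_core::serde::BallistaLogicalExtensionCodec;
use datafusion::error::Result;
use datafusion::logical_expr::Extension;
use datafusion::prelude::SessionContext;
use datafusion_proto::logical_plan::LogicalExtensionCodec;

// Sketch only: if extension.node is a BallistaDml it takes the new code path,
// otherwise both calls fall through to the default codec, as above.
fn roundtrip_ballista_dml(extension: &Extension, ctx: &SessionContext) -> Result<Extension> {
    let codec = BallistaLogicalExtensionCodec::default(); // assumed Default impl
    let mut buf = Vec::new();
    // Serializes `input` and `table` as LogicalPlanNode messages inside the
    // ExtensionBallistaDml protobuf wrapper.
    codec.try_encode(extension, &mut buf)?;
    // Deserializes them and rebuilds the BallistaDml extension node.
    codec.try_decode(&buf, &[], ctx)
}
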
27 changes: 27 additions & 0 deletions ballista/scheduler/src/scheduler_server/grpc.rs
@@ -32,6 +32,9 @@ use ballista_core::serde::protobuf::{
UpdateTaskStatusParams, UpdateTaskStatusResult,
};
use ballista_core::serde::scheduler::ExecutorMetadata;
use ballista_core::serde::BallistaDml;
use datafusion::datasource::DefaultTableSource;
use datafusion::logical_expr::LogicalPlan;
use datafusion_proto::logical_plan::AsLogicalPlan;
use datafusion_proto::physical_plan::AsExecutionPlan;
use log::{debug, error, info, trace, warn};
@@ -446,6 +449,30 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> SchedulerGrpc
}
};

let plan = match plan {
LogicalPlan::Extension(e)
if e.node.as_any().downcast_ref::<BallistaDml>().is_some() =>
{
let plan = e.node.as_any().downcast_ref::<BallistaDml>().unwrap();
let table_provider = &plan
.table
.source
.as_any()
.downcast_ref::<DefaultTableSource>()
.expect("dts")
.table_provider;

let _ = session_ctx.deregister_table(plan.table.table_name.clone());
let _ = session_ctx.register_table(
plan.table.table_name.clone(),
table_provider.clone(),
);

plan.input.clone()
}
_ => plan,
};

debug!(
"Decoded logical plan for execution:\n{}",
plan.display_indent()
