Skip to content

Commit

Permalink
try DML patch
Browse files Browse the repository at this point in the history
  • Loading branch information
milenkovicm committed Jan 11, 2025
1 parent ae54d70 commit db30f68
Show file tree
Hide file tree
Showing 5 changed files with 198 additions and 111 deletions.
8 changes: 4 additions & 4 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,10 @@ arrow-flight = { version = "53", features = ["flight-sql-experimental"] }
clap = { version = "4.5", features = ["derive", "cargo"] }
configure_me = { version = "0.4.0" }
configure_me_codegen = { version = "0.4.4" }
datafusion = "44.0.0"
datafusion-cli = "44.0.0"
datafusion-proto = "44.0.0"
datafusion-proto-common = "44.0.0"
# NOTE(review): these are local filesystem path dependencies into a sibling
# checkout ("../arrow-datafusion-fork"). They only build on a machine that has
# that checkout at that exact relative path — CI and fresh clones will fail —
# and they are inconsistent with `datafusion-cli`, which still resolves from
# crates.io. If the fork is genuinely required for DML serde, prefer pinned
# `git = "...", rev = "..."` dependencies or a `[patch.crates-io]` section so
# the build is reproducible; otherwise restore the "44.0.0" versions.
datafusion = { path = "../arrow-datafusion-fork/datafusion/core" }
datafusion-cli = "44"
datafusion-proto = { path = "../arrow-datafusion-fork/datafusion/proto" }
datafusion-proto-common = { path = "../arrow-datafusion-fork/datafusion/proto-common" }
object_store = "0.11"
prost = "0.13"
prost-types = "0.13"
Expand Down
33 changes: 33 additions & 0 deletions ballista/client/tests/context_checks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -365,4 +365,37 @@ mod supported {

Ok(())
}

// With `enable_url_table`, SQL may reference a parquet file by a quoted path
// instead of a registered table name; verify this works against both the
// standalone and the remote ballista contexts.
#[rstest]
#[case::standalone(standalone_context())]
#[case::remote(remote_context())]
#[tokio::test]
async fn should_execute_sql_show_with_url_table(
    #[future(awt)]
    #[case]
    ctx: SessionContext,
    test_data: String,
) {
    // NOTE(review): `enable_url_table` returns a new context — presumably it
    // preserves the session id now; previously this changed it (see the test
    // formerly in context_unsupported.rs).
    let ctx = ctx.enable_url_table();

    // The "table" here is the quoted file path itself.
    let query = format!(
        "select string_col, timestamp_col from '{test_data}/alltypes_plain.parquet' where id > 4"
    );
    let df = ctx.sql(&query).await.unwrap();
    let batches = df.collect().await.unwrap();

    let expected = [
        "+------------+---------------------+",
        "| string_col | timestamp_col |",
        "+------------+---------------------+",
        "| 31 | 2009-03-01T00:01:00 |",
        "| 30 | 2009-04-01T00:00:00 |",
        "| 31 | 2009-04-01T00:01:00 |",
        "+------------+---------------------+",
    ];
    assert_batches_eq!(expected, &batches);
}
}
144 changes: 144 additions & 0 deletions ballista/client/tests/context_setup.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,80 @@ mod remote {

Ok(())
}

// Verifies that `INSERT INTO ... SELECT` round-trips through a *remote*
// ballista cluster: the DML logical plan must be serialized, shipped to the
// scheduler, executed, and the appended rows must be readable afterwards.
// (DML serde was previously unsupported — see the commit message.)
#[tokio::test]
async fn should_support_sql_insert_into() -> datafusion::error::Result<()> {
// Location of the shared fixtures (contains alltypes_plain.parquet).
let test_data = crate::common::example_test_data();

// Ballista-enabled session config; the job name is cosmetic (scheduler UI).
let session_config = SessionConfig::new_with_ballista()
.with_information_schema(true)
.with_ballista_job_name("Insert INTO Test");

let state = SessionStateBuilder::new()
.with_default_features()
.with_config(session_config)
.build();

// Spin up an in-process scheduler/executor pair sharing this session state.
let (host, port) =
crate::common::setup_test_cluster_with_state(state.clone()).await;
let url = format!("df://{host}:{port}");

let ctx: SessionContext = SessionContext::remote_with_state(&url, state).await?;

// Source table backed by the read-only fixture file.
ctx.register_parquet(
"test",
&format!("{test_data}/alltypes_plain.parquet"),
Default::default(),
)
.await
.unwrap();
// Copy the fixture into a temp dir so the INSERT targets a writable table.
// NOTE(review): `write_dir` (the TempDir guard) must outlive every query
// below — the directory is deleted when it drops.
let write_dir = tempfile::tempdir().expect("temporary directory to be created");
let write_dir_path = write_dir
.path()
.to_str()
.expect("path to be converted to str");

ctx.sql("select * from test")
.await
.unwrap()
.write_parquet(write_dir_path, Default::default(), Default::default())
.await
.unwrap();

ctx.register_parquet("written_table", write_dir_path, Default::default())
.await
.unwrap();

// The DML statement under test: self-append doubles every row.
let _ = ctx
.sql("INSERT INTO written_table select * from written_table")
.await
.unwrap()
.collect()
.await
.unwrap();

let result = ctx
.sql("select id, string_col, timestamp_col from written_table where id > 4 order by id")
.await.unwrap()
.collect()
.await.unwrap();

// Each id appears twice because the table's rows were appended to itself.
let expected = [
"+----+------------+---------------------+",
"| id | string_col | timestamp_col |",
"+----+------------+---------------------+",
"| 5 | 31 | 2009-03-01T00:01:00 |",
"| 5 | 31 | 2009-03-01T00:01:00 |",
"| 6 | 30 | 2009-04-01T00:00:00 |",
"| 6 | 30 | 2009-04-01T00:00:00 |",
"| 7 | 31 | 2009-04-01T00:01:00 |",
"| 7 | 31 | 2009-04-01T00:01:00 |",
"+----+------------+---------------------+",
];

assert_batches_eq!(expected, &result);
Ok(())
}
}

#[cfg(test)]
Expand Down Expand Up @@ -265,6 +339,76 @@ mod standalone {
Ok(())
}

// Verifies that `INSERT INTO ... SELECT` executes on a *standalone* ballista
// context, exercising DML logical-plan serde end to end.
#[tokio::test]
async fn should_support_sql_insert_into() -> datafusion::error::Result<()> {
    let test_data = crate::common::example_test_data();

    // Ballista-enabled session; the job name is cosmetic.
    let config = SessionConfig::new_with_ballista()
        .with_information_schema(true)
        .with_ballista_job_name("Insert INTO Test");
    let state = SessionStateBuilder::new()
        .with_default_features()
        .with_config(config)
        .build();

    let ctx: SessionContext = SessionContext::standalone_with_state(state).await?;

    // Source table backed by the read-only fixture file.
    let source_path = format!("{test_data}/alltypes_plain.parquet");
    ctx.register_parquet("test", &source_path, Default::default())
        .await
        .unwrap();

    // Copy the fixture into a temp dir so the INSERT targets a writable table.
    // The TempDir guard must outlive the queries below (dir is removed on drop).
    let tmp_dir = tempfile::tempdir().expect("temporary directory to be created");
    let target_path = tmp_dir
        .path()
        .to_str()
        .expect("path to be converted to str");

    let source_df = ctx.sql("select * from test").await.unwrap();
    source_df
        .write_parquet(target_path, Default::default(), Default::default())
        .await
        .unwrap();

    ctx.register_parquet("written_table", target_path, Default::default())
        .await
        .unwrap();

    // The DML statement under test: self-append doubles every row.
    let insert = ctx
        .sql("INSERT INTO written_table select * from written_table")
        .await
        .unwrap();
    let _ = insert.collect().await.unwrap();

    let actual = ctx
        .sql("select id, string_col, timestamp_col from written_table where id > 4 order by id")
        .await
        .unwrap()
        .collect()
        .await
        .unwrap();

    // Each id appears twice because the table's rows were appended to itself.
    let expected = [
        "+----+------------+---------------------+",
        "| id | string_col | timestamp_col |",
        "+----+------------+---------------------+",
        "| 5 | 31 | 2009-03-01T00:01:00 |",
        "| 5 | 31 | 2009-03-01T00:01:00 |",
        "| 6 | 30 | 2009-04-01T00:00:00 |",
        "| 6 | 30 | 2009-04-01T00:00:00 |",
        "| 7 | 31 | 2009-04-01T00:01:00 |",
        "| 7 | 31 | 2009-04-01T00:01:00 |",
        "+----+------------+---------------------+",
    ];
    assert_batches_eq!(expected, &actual);
    Ok(())
}

#[derive(Debug, Default)]
struct BadLogicalCodec {
invoked: AtomicBool,
Expand Down
106 changes: 0 additions & 106 deletions ballista/client/tests/context_unsupported.rs
Original file line number Diff line number Diff line change
Expand Up @@ -144,112 +144,6 @@ mod unsupported {
"+----+----------+---------------------+",
];

assert_batches_eq!(expected, &result);
}
// NOTE(review): deleted lines — this test documented that DML serde was NOT
// supported (`#[should_panic]` on the Internal serde error quoted below).
// It is removed because INSERT INTO now works; passing versions live in
// context_setup.rs in this same change.
#[rstest]
#[case::standalone(standalone_context())]
#[case::remote(remote_context())]
#[tokio::test]
#[should_panic]
// "Error: Internal(failed to serialize logical plan: Internal(LogicalPlan serde is not yet implemented for Dml))"
async fn should_support_sql_insert_into(
#[future(awt)]
#[case]
ctx: SessionContext,
test_data: String,
) {
// Read-only source table from the fixture file.
ctx.register_parquet(
"test",
&format!("{test_data}/alltypes_plain.parquet"),
Default::default(),
)
.await
.unwrap();
// Writable copy in a temp dir; the TempDir guard keeps it alive.
let write_dir = tempfile::tempdir().expect("temporary directory to be created");
let write_dir_path = write_dir
.path()
.to_str()
.expect("path to be converted to str");

ctx.sql("select * from test")
.await
.unwrap()
.write_parquet(write_dir_path, Default::default(), Default::default())
.await
.unwrap();

ctx.register_parquet("written_table", write_dir_path, Default::default())
.await
.unwrap();

// This INSERT was the statement expected to panic during plan serialization.
let _ = ctx
.sql("INSERT INTO written_table select * from written_table")
.await
.unwrap()
.collect()
.await
.unwrap();

let result = ctx
.sql("select id, string_col, timestamp_col from written_table where id > 4 order by id")
.await.unwrap()
.collect()
.await.unwrap();

let expected = [
"+----+------------+---------------------+",
"| id | string_col | timestamp_col |",
"+----+------------+---------------------+",
"| 5 | 31 | 2009-03-01T00:01:00 |",
"| 5 | 31 | 2009-03-01T00:01:00 |",
"| 6 | 30 | 2009-04-01T00:00:00 |",
"| 6 | 30 | 2009-04-01T00:00:00 |",
"| 7 | 31 | 2009-04-01T00:01:00 |",
"| 7 | 31 | 2009-04-01T00:01:00 |",
"+----+------------+---------------------+",
];

assert_batches_eq!(expected, &result);
}

// NOTE(review): deleted lines — this test documented that `enable_url_table()`
// previously broke session lookup (error quoted below). It is removed because
// the scenario now passes; see context_checks.rs in this same change.
/// looks like `ctx.enable_url_table()` changes session context id.
///
/// Error returned:
/// ```
/// Failed to load SessionContext for session ID b5530099-63d1-43b1-9e11-87ac83bb33e5:
/// General error: No session for b5530099-63d1-43b1-9e11-87ac83bb33e5 found
/// ```
#[rstest]
#[case::standalone(standalone_context())]
#[case::remote(remote_context())]
#[tokio::test]
#[should_panic]
async fn should_execute_sql_show_with_url_table(
#[future(awt)]
#[case]
ctx: SessionContext,
test_data: String,
) {
// Re-wrapping the context was the step that invalidated the session id.
let ctx = ctx.enable_url_table();

let result = ctx
.sql(&format!("select string_col, timestamp_col from '{test_data}/alltypes_plain.parquet' where id > 4"))
.await
.unwrap()
.collect()
.await
.unwrap();

let expected = [
"+------------+---------------------+",
"| string_col | timestamp_col |",
"+------------+---------------------+",
"| 31 | 2009-03-01T00:01:00 |",
"| 30 | 2009-04-01T00:00:00 |",
"| 31 | 2009-04-01T00:01:00 |",
"+------------+---------------------+",
];

assert_batches_eq!(expected, &result);
}
}
18 changes: 17 additions & 1 deletion ballista/core/proto/datafusion.proto
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ message LogicalPlanNode {
UnnestNode unnest = 30;
RecursiveQueryNode recursive_query = 31;
CteWorkTableScanNode cte_work_table_scan = 32;
DmlNode dml = 33;
}
}

Expand Down Expand Up @@ -264,6 +265,21 @@ message CopyToNode {
repeated string partition_by = 7;
}

// Serialized form of a DML logical plan (INSERT/UPDATE/DELETE/CTAS), added so
// these statements can be shipped to the ballista scheduler.
message DmlNode {
  // Kind of DML operation — presumably mirrors datafusion's `WriteOp`
  // variants one-to-one; confirm against the `LogicalPlan::Dml` serde code.
  enum Type {
    UPDATE = 0;
    DELETE = 1;
    INSERT_APPEND = 2;
    INSERT_OVERWRITE = 3;
    INSERT_REPLACE = 4;
    CTAS = 5;
  }
  Type dml_type = 1;
  // Plan producing the rows the operation consumes.
  LogicalPlanNode input = 2;
  // Target table of the operation.
  TableReference table_name = 3;
  datafusion_common.DfSchema schema = 4;
}

message UnnestNode {
LogicalPlanNode input = 1;
repeated datafusion_common.Column exec_columns = 2;
Expand Down Expand Up @@ -1255,4 +1271,4 @@ message RecursiveQueryNode {
message CteWorkTableScanNode {
string name = 1;
datafusion_common.Schema schema = 2;
}
}

0 comments on commit db30f68

Please sign in to comment.