Skip to content

Commit

Permalink
Add serde for plans with tables from TableProviderFactorys (#3907)
Browse files Browse the repository at this point in the history
* Can compile and run test

Failing on scheduler due to no factories

Tests pass

Back to "no object store available for delta-rs://home-bgardner-workspace"

Switch back to git refs

CI fixes

Add roundtrip test

Passing deltalake test

Passing serde test

Remove unrelated refactor

Formatting

Fix typo that was hard to debug

CI fixes

delta & ballista tests pass

* Take Andy's advice and turn it async

* Fix CI

* No suitable object store on executor

* Fix test

* Fix test

* Bump CI

* Update datafusion/core/src/datasource/datasource.rs

Co-authored-by: xudong.w <[email protected]>

* Update datafusion/proto/src/bytes/mod.rs

Co-authored-by: Andrew Lamb <[email protected]>

* Update datafusion/proto/src/bytes/mod.rs

Co-authored-by: Andrew Lamb <[email protected]>

Co-authored-by: xudong.w <[email protected]>
Co-authored-by: Andrew Lamb <[email protected]>
  • Loading branch information
3 people authored Oct 24, 2022
1 parent e1f866e commit 9595b8d
Show file tree
Hide file tree
Showing 11 changed files with 372 additions and 89 deletions.
4 changes: 2 additions & 2 deletions datafusion/core/src/datasource/datasource.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,6 @@ pub trait TableProvider: Sync + Send {
/// from a directory of files only when that name is referenced.
#[async_trait]
pub trait TableProviderFactory: Sync + Send {
/// Create a TableProvider given name and url
async fn create(&self, name: &str, url: &str) -> Result<Arc<dyn TableProvider>>;
/// Create a TableProvider with the given url
async fn create(&self, url: &str) -> Result<Arc<dyn TableProvider>>;
}
7 changes: 3 additions & 4 deletions datafusion/core/src/execution/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -418,19 +418,18 @@ impl SessionContext {
cmd: &CreateExternalTable,
) -> Result<Arc<DataFrame>> {
let state = self.state.read().clone();
let file_type = cmd.file_type.to_lowercase();
let factory = &state
.runtime_env
.table_factories
.get(&cmd.file_type)
.get(file_type.as_str())
.ok_or_else(|| {
DataFusionError::Execution(format!(
"Unable to find factory for {}",
cmd.file_type
))
})?;
let table = (*factory)
.create(cmd.name.as_str(), cmd.location.as_str())
.await?;
let table = (*factory).create(cmd.location.as_str()).await?;
self.register_table(cmd.name.as_str(), table)?;
let plan = LogicalPlanBuilder::empty(false).build()?;
Ok(Arc::new(DataFrame::new(self.state.clone(), &plan)))
Expand Down
60 changes: 59 additions & 1 deletion datafusion/core/src/test_util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,19 @@

//! Utility functions to make testing DataFusion based crates easier
use std::any::Any;
use std::collections::BTreeMap;
use std::{env, error::Error, path::PathBuf, sync::Arc};

use crate::datasource::{empty::EmptyTable, provider_as_source};
use crate::datasource::datasource::TableProviderFactory;
use crate::datasource::{empty::EmptyTable, provider_as_source, TableProvider};
use crate::execution::context::SessionState;
use crate::logical_expr::{LogicalPlanBuilder, UNNAMED_TABLE};
use crate::physical_plan::ExecutionPlan;
use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
use async_trait::async_trait;
use datafusion_common::DataFusionError;
use datafusion_expr::{Expr, TableType};

/// Compares formatted output of a record batch with an expected
/// vector of strings, with the result of pretty formatting record
Expand Down Expand Up @@ -317,6 +323,58 @@ pub fn aggr_test_schema_with_missing_col() -> SchemaRef {
Arc::new(schema)
}

/// TableFactory for tests
pub struct TestTableFactory {}

#[async_trait]
impl TableProviderFactory for TestTableFactory {
async fn create(
&self,
url: &str,
) -> datafusion_common::Result<Arc<dyn TableProvider>> {
Ok(Arc::new(TestTableProvider {
url: url.to_string(),
}))
}
}

/// TableProvider for testing purposes
pub struct TestTableProvider {
/// URL of table files or folder
pub url: String,
}

impl TestTableProvider {}

#[async_trait]
impl TableProvider for TestTableProvider {
fn as_any(&self) -> &dyn Any {
self
}

fn schema(&self) -> SchemaRef {
let schema = Schema::new(vec![
Field::new("a", DataType::Int64, true),
Field::new("b", DataType::Decimal128(15, 2), true),
]);
Arc::new(schema)
}

fn table_type(&self) -> TableType {
unimplemented!("TestTableProvider is a stub for testing.")
}

async fn scan(
&self,
_ctx: &SessionState,
_projection: &Option<Vec<usize>>,
_filters: &[Expr],
_limit: Option<usize>,
) -> datafusion_common::Result<Arc<dyn ExecutionPlan>> {
unimplemented!("TestTableProvider is a stub for testing.")
}
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down
45 changes: 2 additions & 43 deletions datafusion/core/tests/sql/create_drop.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,12 @@
// specific language governing permissions and limitations
// under the License.

use async_trait::async_trait;
use std::any::Any;
use std::collections::HashMap;
use std::io::Write;

use datafusion::datasource::datasource::TableProviderFactory;
use datafusion::execution::context::SessionState;
use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv};
use datafusion_expr::TableType;
use datafusion::test_util::TestTableFactory;
use tempfile::TempDir;

use super::*;
Expand Down Expand Up @@ -369,49 +366,11 @@ async fn create_pipe_delimited_csv_table() -> Result<()> {
Ok(())
}

struct TestTableProvider {}

impl TestTableProvider {}

#[async_trait]
impl TableProvider for TestTableProvider {
fn as_any(&self) -> &dyn Any {
unimplemented!("TestTableProvider is a stub for testing.")
}

fn schema(&self) -> SchemaRef {
unimplemented!("TestTableProvider is a stub for testing.")
}

fn table_type(&self) -> TableType {
unimplemented!("TestTableProvider is a stub for testing.")
}

async fn scan(
&self,
_ctx: &SessionState,
_projection: &Option<Vec<usize>>,
_filters: &[Expr],
_limit: Option<usize>,
) -> Result<Arc<dyn ExecutionPlan>> {
unimplemented!("TestTableProvider is a stub for testing.")
}
}

struct TestTableFactory {}

#[async_trait]
impl TableProviderFactory for TestTableFactory {
async fn create(&self, _name: &str, _url: &str) -> Result<Arc<dyn TableProvider>> {
Ok(Arc::new(TestTableProvider {}))
}
}

#[tokio::test]
async fn create_custom_table() -> Result<()> {
let mut table_factories: HashMap<String, Arc<dyn TableProviderFactory>> =
HashMap::new();
table_factories.insert("DELTATABLE".to_string(), Arc::new(TestTableFactory {}));
table_factories.insert("deltatable".to_string(), Arc::new(TestTableFactory {}));
let cfg = RuntimeConfig::new().with_table_factories(table_factories);
let env = RuntimeEnv::new(cfg).unwrap();
let ses = SessionConfig::new();
Expand Down
1 change: 1 addition & 0 deletions datafusion/proto/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ json = ["pbjson", "pbjson-build", "serde", "serde_json"]

[dependencies]
arrow = "25.0.0"
async-trait = "0.1.41"
datafusion = { path = "../core", version = "13.0.0" }
datafusion-common = { path = "../common", version = "13.0.0" }
datafusion-expr = { path = "../expr", version = "13.0.0" }
Expand Down
2 changes: 1 addition & 1 deletion datafusion/proto/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ async fn main() -> Result<()> {
?;
let plan = ctx.table("t1")?.to_logical_plan()?;
let bytes = logical_plan_to_bytes(&plan)?;
let logical_round_trip = logical_plan_from_bytes(&bytes, &ctx)?;
let logical_round_trip = logical_plan_from_bytes(&bytes, &ctx).await?;
assert_eq!(format!("{:?}", plan), format!("{:?}", logical_round_trip));
Ok(())
}
Expand Down
2 changes: 1 addition & 1 deletion datafusion/proto/examples/plan_serde.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ async fn main() -> Result<()> {
.await?;
let plan = ctx.table("t1")?.to_logical_plan()?;
let bytes = logical_plan_to_bytes(&plan)?;
let logical_round_trip = logical_plan_from_bytes(&bytes, &ctx)?;
let logical_round_trip = logical_plan_from_bytes(&bytes, &ctx).await?;
assert_eq!(format!("{:?}", plan), format!("{:?}", logical_round_trip));
Ok(())
}
10 changes: 10 additions & 0 deletions datafusion/proto/proto/datafusion.proto
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ message LogicalPlanNode {
CreateViewNode create_view = 22;
DistinctNode distinct = 23;
ViewTableScanNode view_scan = 24;
CustomTableScanNode custom_scan = 25;
}
}

Expand Down Expand Up @@ -118,6 +119,15 @@ message ViewTableScanNode {
string definition = 5;
}

// Logical Plan to Scan a CustomTableProvider registered at runtime
message CustomTableScanNode {
string table_name = 1;
ProjectionColumns projection = 2;
datafusion.Schema schema = 3;
repeated datafusion.LogicalExprNode filters = 4;
bytes custom_table_data = 5;
}

message ProjectionNode {
LogicalPlanNode input = 1;
repeated datafusion.LogicalExprNode expr = 2;
Expand Down
47 changes: 38 additions & 9 deletions datafusion/proto/src/bytes/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,16 @@
//! Serialization / Deserialization to Bytes
use crate::logical_plan::{AsLogicalPlan, LogicalExtensionCodec};
use crate::{from_proto::parse_expr, protobuf};
use arrow::datatypes::SchemaRef;
use async_trait::async_trait;
use datafusion::datasource::TableProvider;
use datafusion_common::{DataFusionError, Result};
use datafusion_expr::{Expr, Extension, LogicalPlan};
use prost::{
bytes::{Bytes, BytesMut},
Message,
};
use std::sync::Arc;

// Reexport Bytes which appears in the API
use datafusion::execution::registry::FunctionRegistry;
Expand Down Expand Up @@ -132,37 +136,41 @@ pub fn logical_plan_to_bytes_with_extension_codec(

/// Deserialize a LogicalPlan from json
#[cfg(feature = "json")]
pub fn logical_plan_from_json(json: &str, ctx: &SessionContext) -> Result<LogicalPlan> {
pub async fn logical_plan_from_json(
json: &str,
ctx: &SessionContext,
) -> Result<LogicalPlan> {
let back: protobuf::LogicalPlanNode = serde_json::from_str(json)
.map_err(|e| DataFusionError::Plan(format!("Error serializing plan: {}", e)))?;
let extension_codec = DefaultExtensionCodec {};
back.try_into_logical_plan(ctx, &extension_codec)
back.try_into_logical_plan(ctx, &extension_codec).await
}

/// Deserialize a LogicalPlan from bytes
pub fn logical_plan_from_bytes(
pub async fn logical_plan_from_bytes(
bytes: &[u8],
ctx: &SessionContext,
) -> Result<LogicalPlan> {
let extension_codec = DefaultExtensionCodec {};
logical_plan_from_bytes_with_extension_codec(bytes, ctx, &extension_codec)
logical_plan_from_bytes_with_extension_codec(bytes, ctx, &extension_codec).await
}

/// Deserialize a LogicalPlan from bytes
pub fn logical_plan_from_bytes_with_extension_codec(
pub async fn logical_plan_from_bytes_with_extension_codec(
bytes: &[u8],
ctx: &SessionContext,
extension_codec: &dyn LogicalExtensionCodec,
) -> Result<LogicalPlan> {
let protobuf = protobuf::LogicalPlanNode::decode(bytes).map_err(|e| {
DataFusionError::Plan(format!("Error decoding expr as protobuf: {}", e))
})?;
protobuf.try_into_logical_plan(ctx, extension_codec)
protobuf.try_into_logical_plan(ctx, extension_codec).await
}

#[derive(Debug)]
struct DefaultExtensionCodec {}

#[async_trait]
impl LogicalExtensionCodec for DefaultExtensionCodec {
fn try_decode(
&self,
Expand All @@ -180,6 +188,27 @@ impl LogicalExtensionCodec for DefaultExtensionCodec {
"No extension codec provided".to_string(),
))
}

async fn try_decode_table_provider(
&self,
_buf: &[u8],
_schema: SchemaRef,
_ctx: &SessionContext,
) -> std::result::Result<Arc<dyn TableProvider>, DataFusionError> {
Err(DataFusionError::NotImplemented(
"No codec provided to for TableProviders".to_string(),
))
}

fn try_encode_table_provider(
&self,
_node: Arc<dyn TableProvider>,
_buf: &mut Vec<u8>,
) -> std::result::Result<(), DataFusionError> {
Err(DataFusionError::NotImplemented(
"No codec provided to for TableProviders".to_string(),
))
}
}

#[cfg(test)]
Expand Down Expand Up @@ -214,12 +243,12 @@ mod test {
assert_eq!(actual, expected);
}

#[test]
#[tokio::test]
#[cfg(feature = "json")]
fn json_to_plan() {
async fn json_to_plan() {
let input = r#"{"emptyRelation":{}}"#.to_string();
let ctx = SessionContext::new();
let actual = logical_plan_from_json(&input, &ctx).unwrap();
let actual = logical_plan_from_json(&input, &ctx).await.unwrap();
let result = matches!(actual, LogicalPlan::EmptyRelation(_));
assert!(result, "Should parse empty relation");
}
Expand Down
Loading

0 comments on commit 9595b8d

Please sign in to comment.