From af61919204faa2d97cb9ce137a49420419cd8181 Mon Sep 17 00:00:00 2001 From: Brent Gardner Date: Mon, 24 Oct 2022 15:24:21 -0600 Subject: [PATCH] Add serde for plans with tables from `TableProviderFactory`s (#3907) * Can compile and run test Failing on scheduler due to no factories Tests pass Back to "no object store available for delta-rs://home-bgardner-workspace" Switch back to git refs CI fixes Add roundtrip test Passing deltalake test Passing serde test Remove unrelated refactor Formatting Fix typo that was hard to debug CI fixes delta & ballista tests pass * Take Andy's advice and turn it async * Fix CI * No suitable object store on executor * Fix test * Fix test * Bump CI * Update datafusion/core/src/datasource/datasource.rs Co-authored-by: xudong.w * Update datafusion/proto/src/bytes/mod.rs Co-authored-by: Andrew Lamb * Update datafusion/proto/src/bytes/mod.rs Co-authored-by: Andrew Lamb Co-authored-by: xudong.w Co-authored-by: Andrew Lamb --- datafusion/core/src/datasource/datasource.rs | 4 +- datafusion/core/src/execution/context.rs | 7 +- datafusion/core/src/test_util.rs | 60 +++++++- datafusion/core/tests/sql/create_drop.rs | 45 +----- datafusion/proto/Cargo.toml | 1 + datafusion/proto/README.md | 2 +- datafusion/proto/examples/plan_serde.rs | 2 +- datafusion/proto/proto/datafusion.proto | 10 ++ datafusion/proto/src/bytes/mod.rs | 47 ++++-- datafusion/proto/src/lib.rs | 138 +++++++++++++++++- datafusion/proto/src/logical_plan.rs | 145 ++++++++++++++++--- 11 files changed, 372 insertions(+), 89 deletions(-) diff --git a/datafusion/core/src/datasource/datasource.rs b/datafusion/core/src/datasource/datasource.rs index 84111fed06ca7..38e3193c4935f 100644 --- a/datafusion/core/src/datasource/datasource.rs +++ b/datafusion/core/src/datasource/datasource.rs @@ -85,6 +85,6 @@ pub trait TableProvider: Sync + Send { /// from a directory of files only when that name is referenced. #[async_trait] pub trait TableProviderFactory: Sync + Send { - /// Create a TableProvider given name and url - async fn create(&self, name: &str, url: &str) -> Result>; + /// Create a TableProvider with the given url + async fn create(&self, url: &str) -> Result>; } diff --git a/datafusion/core/src/execution/context.rs b/datafusion/core/src/execution/context.rs index f140ce1c3b98d..c66199e073b71 100644 --- a/datafusion/core/src/execution/context.rs +++ b/datafusion/core/src/execution/context.rs @@ -418,19 +418,18 @@ impl SessionContext { cmd: &CreateExternalTable, ) -> Result> { let state = self.state.read().clone(); + let file_type = cmd.file_type.to_lowercase(); let factory = &state .runtime_env .table_factories - .get(&cmd.file_type) + .get(file_type.as_str()) .ok_or_else(|| { DataFusionError::Execution(format!( "Unable to find factory for {}", cmd.file_type )) })?; - let table = (*factory) - .create(cmd.name.as_str(), cmd.location.as_str()) - .await?; + let table = (*factory).create(cmd.location.as_str()).await?; self.register_table(cmd.name.as_str(), table)?; let plan = LogicalPlanBuilder::empty(false).build()?; Ok(Arc::new(DataFrame::new(self.state.clone(), &plan))) diff --git a/datafusion/core/src/test_util.rs b/datafusion/core/src/test_util.rs index d92b9db6082c2..769ab47feb69f 100644 --- a/datafusion/core/src/test_util.rs +++ b/datafusion/core/src/test_util.rs @@ -17,13 +17,19 @@ //! Utility functions to make testing DataFusion based crates easier +use std::any::Any; use std::collections::BTreeMap; use std::{env, error::Error, path::PathBuf, sync::Arc}; -use crate::datasource::{empty::EmptyTable, provider_as_source}; +use crate::datasource::datasource::TableProviderFactory; +use crate::datasource::{empty::EmptyTable, provider_as_source, TableProvider}; +use crate::execution::context::SessionState; use crate::logical_expr::{LogicalPlanBuilder, UNNAMED_TABLE}; +use crate::physical_plan::ExecutionPlan; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; +use async_trait::async_trait; use datafusion_common::DataFusionError; +use datafusion_expr::{Expr, TableType}; /// Compares formatted output of a record batch with an expected /// vector of strings, with the result of pretty formatting record @@ -317,6 +323,58 @@ pub fn aggr_test_schema_with_missing_col() -> SchemaRef { Arc::new(schema) } +/// TableFactory for tests +pub struct TestTableFactory {} + +#[async_trait] +impl TableProviderFactory for TestTableFactory { + async fn create( + &self, + url: &str, + ) -> datafusion_common::Result> { + Ok(Arc::new(TestTableProvider { + url: url.to_string(), + })) + } +} + +/// TableProvider for testing purposes +pub struct TestTableProvider { + /// URL of table files or folder + pub url: String, +} + +impl TestTableProvider {} + +#[async_trait] +impl TableProvider for TestTableProvider { + fn as_any(&self) -> &dyn Any { + self + } + + fn schema(&self) -> SchemaRef { + let schema = Schema::new(vec![ + Field::new("a", DataType::Int64, true), + Field::new("b", DataType::Decimal128(15, 2), true), + ]); + Arc::new(schema) + } + + fn table_type(&self) -> TableType { + unimplemented!("TestTableProvider is a stub for testing.") + } + + async fn scan( + &self, + _ctx: &SessionState, + _projection: &Option>, + _filters: &[Expr], + _limit: Option, + ) -> datafusion_common::Result> { + unimplemented!("TestTableProvider is a stub for testing.") + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/datafusion/core/tests/sql/create_drop.rs b/datafusion/core/tests/sql/create_drop.rs index 40d61767f9544..567ee022b526e 100644 --- a/datafusion/core/tests/sql/create_drop.rs +++ b/datafusion/core/tests/sql/create_drop.rs @@ -15,15 +15,12 @@ // specific language governing permissions and limitations // under the License. -use async_trait::async_trait; -use std::any::Any; use std::collections::HashMap; use std::io::Write; use datafusion::datasource::datasource::TableProviderFactory; -use datafusion::execution::context::SessionState; use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv}; -use datafusion_expr::TableType; +use datafusion::test_util::TestTableFactory; use tempfile::TempDir; use super::*; @@ -369,49 +366,11 @@ async fn create_pipe_delimited_csv_table() -> Result<()> { Ok(()) } -struct TestTableProvider {} - -impl TestTableProvider {} - -#[async_trait] -impl TableProvider for TestTableProvider { - fn as_any(&self) -> &dyn Any { - unimplemented!("TestTableProvider is a stub for testing.") - } - - fn schema(&self) -> SchemaRef { - unimplemented!("TestTableProvider is a stub for testing.") - } - - fn table_type(&self) -> TableType { - unimplemented!("TestTableProvider is a stub for testing.") - } - - async fn scan( - &self, - _ctx: &SessionState, - _projection: &Option>, - _filters: &[Expr], - _limit: Option, - ) -> Result> { - unimplemented!("TestTableProvider is a stub for testing.") - } -} - -struct TestTableFactory {} - -#[async_trait] -impl TableProviderFactory for TestTableFactory { - async fn create(&self, _name: &str, _url: &str) -> Result> { - Ok(Arc::new(TestTableProvider {})) - } -} - #[tokio::test] async fn create_custom_table() -> Result<()> { let mut table_factories: HashMap> = HashMap::new(); - table_factories.insert("DELTATABLE".to_string(), Arc::new(TestTableFactory {})); + table_factories.insert("deltatable".to_string(), Arc::new(TestTableFactory {})); let cfg = RuntimeConfig::new().with_table_factories(table_factories); let env = RuntimeEnv::new(cfg).unwrap(); let ses = SessionConfig::new(); diff --git a/datafusion/proto/Cargo.toml b/datafusion/proto/Cargo.toml index ff2af33b22d84..ef5d8211248af 100644 --- a/datafusion/proto/Cargo.toml +++ b/datafusion/proto/Cargo.toml @@ -41,6 +41,7 @@ json = ["pbjson", "pbjson-build", "serde", "serde_json"] [dependencies] arrow = "25.0.0" +async-trait = "0.1.41" datafusion = { path = "../core", version = "13.0.0" } datafusion-common = { path = "../common", version = "13.0.0" } datafusion-expr = { path = "../expr", version = "13.0.0" } diff --git a/datafusion/proto/README.md b/datafusion/proto/README.md index a3878447e042e..8c8962e506a67 100644 --- a/datafusion/proto/README.md +++ b/datafusion/proto/README.md @@ -63,7 +63,7 @@ async fn main() -> Result<()> { ?; let plan = ctx.table("t1")?.to_logical_plan()?; let bytes = logical_plan_to_bytes(&plan)?; - let logical_round_trip = logical_plan_from_bytes(&bytes, &ctx)?; + let logical_round_trip = logical_plan_from_bytes(&bytes, &ctx).await?; assert_eq!(format!("{:?}", plan), format!("{:?}", logical_round_trip)); Ok(()) } diff --git a/datafusion/proto/examples/plan_serde.rs b/datafusion/proto/examples/plan_serde.rs index d98d88b2a46a6..eed372476fff0 100644 --- a/datafusion/proto/examples/plan_serde.rs +++ b/datafusion/proto/examples/plan_serde.rs @@ -26,7 +26,7 @@ async fn main() -> Result<()> { .await?; let plan = ctx.table("t1")?.to_logical_plan()?; let bytes = logical_plan_to_bytes(&plan)?; - let logical_round_trip = logical_plan_from_bytes(&bytes, &ctx)?; + let logical_round_trip = logical_plan_from_bytes(&bytes, &ctx).await?; assert_eq!(format!("{:?}", plan), format!("{:?}", logical_round_trip)); Ok(()) } diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index d61f52ee7bb27..45c6071259091 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -70,6 +70,7 @@ message LogicalPlanNode { CreateViewNode create_view = 22; DistinctNode distinct = 23; ViewTableScanNode view_scan = 24; + CustomTableScanNode custom_scan = 25; } } @@ -118,6 +119,15 @@ message ViewTableScanNode { string definition = 5; } +// Logical Plan to Scan a CustomTableProvider registered at runtime +message CustomTableScanNode { + string table_name = 1; + ProjectionColumns projection = 2; + datafusion.Schema schema = 3; + repeated datafusion.LogicalExprNode filters = 4; + bytes custom_table_data = 5; +} + message ProjectionNode { LogicalPlanNode input = 1; repeated datafusion.LogicalExprNode expr = 2; diff --git a/datafusion/proto/src/bytes/mod.rs b/datafusion/proto/src/bytes/mod.rs index 8eab5baebe722..5695bf50686a0 100644 --- a/datafusion/proto/src/bytes/mod.rs +++ b/datafusion/proto/src/bytes/mod.rs @@ -18,12 +18,16 @@ //! Serialization / Deserialization to Bytes use crate::logical_plan::{AsLogicalPlan, LogicalExtensionCodec}; use crate::{from_proto::parse_expr, protobuf}; +use arrow::datatypes::SchemaRef; +use async_trait::async_trait; +use datafusion::datasource::TableProvider; use datafusion_common::{DataFusionError, Result}; use datafusion_expr::{Expr, Extension, LogicalPlan}; use prost::{ bytes::{Bytes, BytesMut}, Message, }; +use std::sync::Arc; // Reexport Bytes which appears in the API use datafusion::execution::registry::FunctionRegistry; @@ -132,24 +136,27 @@ pub fn logical_plan_to_bytes_with_extension_codec( /// Deserialize a LogicalPlan from json #[cfg(feature = "json")] -pub fn logical_plan_from_json(json: &str, ctx: &SessionContext) -> Result { +pub async fn logical_plan_from_json( + json: &str, + ctx: &SessionContext, +) -> Result { let back: protobuf::LogicalPlanNode = serde_json::from_str(json) .map_err(|e| DataFusionError::Plan(format!("Error serializing plan: {}", e)))?; let extension_codec = DefaultExtensionCodec {}; - back.try_into_logical_plan(ctx, &extension_codec) + back.try_into_logical_plan(ctx, &extension_codec).await } /// Deserialize a LogicalPlan from bytes -pub fn logical_plan_from_bytes( +pub async fn logical_plan_from_bytes( bytes: &[u8], ctx: &SessionContext, ) -> Result { let extension_codec = DefaultExtensionCodec {}; - logical_plan_from_bytes_with_extension_codec(bytes, ctx, &extension_codec) + logical_plan_from_bytes_with_extension_codec(bytes, ctx, &extension_codec).await } /// Deserialize a LogicalPlan from bytes -pub fn logical_plan_from_bytes_with_extension_codec( +pub async fn logical_plan_from_bytes_with_extension_codec( bytes: &[u8], ctx: &SessionContext, extension_codec: &dyn LogicalExtensionCodec, @@ -157,12 +164,13 @@ pub fn logical_plan_from_bytes_with_extension_codec( let protobuf = protobuf::LogicalPlanNode::decode(bytes).map_err(|e| { DataFusionError::Plan(format!("Error decoding expr as protobuf: {}", e)) })?; - protobuf.try_into_logical_plan(ctx, extension_codec) + protobuf.try_into_logical_plan(ctx, extension_codec).await } #[derive(Debug)] struct DefaultExtensionCodec {} +#[async_trait] impl LogicalExtensionCodec for DefaultExtensionCodec { fn try_decode( &self, @@ -180,6 +188,27 @@ impl LogicalExtensionCodec for DefaultExtensionCodec { "No extension codec provided".to_string(), )) } + + async fn try_decode_table_provider( + &self, + _buf: &[u8], + _schema: SchemaRef, + _ctx: &SessionContext, + ) -> std::result::Result, DataFusionError> { + Err(DataFusionError::NotImplemented( + "No codec provided to for TableProviders".to_string(), + )) + } + + fn try_encode_table_provider( + &self, + _node: Arc, + _buf: &mut Vec, + ) -> std::result::Result<(), DataFusionError> { + Err(DataFusionError::NotImplemented( + "No codec provided to for TableProviders".to_string(), + )) + } } #[cfg(test)] @@ -214,12 +243,12 @@ mod test { assert_eq!(actual, expected); } - #[test] + #[tokio::test] #[cfg(feature = "json")] - fn json_to_plan() { + async fn json_to_plan() { let input = r#"{"emptyRelation":{}}"#.to_string(); let ctx = SessionContext::new(); - let actual = logical_plan_from_json(&input, &ctx).unwrap(); + let actual = logical_plan_from_json(&input, &ctx).await.unwrap(); let result = matches!(actual, LogicalPlan::EmptyRelation(_)); assert!(result, "Should parse empty relation"); } diff --git a/datafusion/proto/src/lib.rs b/datafusion/proto/src/lib.rs index 8dd1b55f5baee..21ded84a57ca7 100644 --- a/datafusion/proto/src/lib.rs +++ b/datafusion/proto/src/lib.rs @@ -51,7 +51,7 @@ mod roundtrip_tests { logical_plan_to_bytes, logical_plan_to_bytes_with_extension_codec, }; use crate::logical_plan::LogicalExtensionCodec; - use arrow::datatypes::Schema; + use arrow::datatypes::{Schema, SchemaRef}; use arrow::{ array::ArrayRef, datatypes::{ @@ -59,8 +59,15 @@ mod roundtrip_tests { TimeUnit, UnionMode, }, }; + use async_trait::async_trait; + use datafusion::datasource::datasource::TableProviderFactory; + use datafusion::datasource::TableProvider; + use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv}; use datafusion::physical_plan::functions::make_scalar_function; - use datafusion::prelude::{create_udf, CsvReadOptions, SessionContext}; + use datafusion::prelude::{ + create_udf, CsvReadOptions, SessionConfig, SessionContext, + }; + use datafusion::test_util::{TestTableFactory, TestTableProvider}; use datafusion_common::{DFSchemaRef, DataFusionError, ScalarValue}; use datafusion_expr::create_udaf; use datafusion_expr::expr::{Between, BinaryExpr, Case, Cast, GroupingSet, Like}; @@ -72,6 +79,7 @@ mod roundtrip_tests { }; use prost::Message; use std::any::Any; + use std::collections::HashMap; use std::fmt; use std::fmt::Debug; use std::fmt::Formatter; @@ -122,7 +130,8 @@ mod roundtrip_tests { let bytes = logical_plan_to_bytes_with_extension_codec(&topk_plan, &extension_codec)?; let logical_round_trip = - logical_plan_from_bytes_with_extension_codec(&bytes, &ctx, &extension_codec)?; + logical_plan_from_bytes_with_extension_codec(&bytes, &ctx, &extension_codec) + .await?; assert_eq!( format!("{:?}", topk_plan), format!("{:?}", logical_round_trip) @@ -130,6 +139,101 @@ mod roundtrip_tests { Ok(()) } + #[derive(Clone, PartialEq, Eq, ::prost::Message)] + pub struct TestTableProto { + /// URL of the table root + #[prost(string, tag = "1")] + pub url: String, + } + + #[derive(Debug)] + pub struct TestTableProviderCodec {} + + #[async_trait] + impl LogicalExtensionCodec for TestTableProviderCodec { + fn try_decode( + &self, + _buf: &[u8], + _inputs: &[LogicalPlan], + _ctx: &SessionContext, + ) -> Result { + Err(DataFusionError::NotImplemented( + "No extension codec provided".to_string(), + )) + } + + fn try_encode( + &self, + _node: &Extension, + _buf: &mut Vec, + ) -> Result<(), DataFusionError> { + Err(DataFusionError::NotImplemented( + "No extension codec provided".to_string(), + )) + } + + async fn try_decode_table_provider( + &self, + buf: &[u8], + _schema: SchemaRef, + ctx: &SessionContext, + ) -> Result, DataFusionError> { + let msg = TestTableProto::decode(buf).map_err(|_| { + DataFusionError::Internal("Error encoding test table".to_string()) + })?; + let factory = ctx + .state + .read() + .runtime_env + .table_factories + .get("testtable") + .expect("Unable to find testtable factory") + .clone(); + let provider = (*factory).create(msg.url.as_str()).await?; + Ok(provider) + } + + fn try_encode_table_provider( + &self, + node: Arc, + buf: &mut Vec, + ) -> Result<(), DataFusionError> { + let table = node + .as_ref() + .as_any() + .downcast_ref::() + .expect("Can't encode non-test tables"); + let msg = TestTableProto { + url: table.url.clone(), + }; + msg.encode(buf).map_err(|_| { + DataFusionError::Internal("Error encoding test table".to_string()) + }) + } + } + + #[tokio::test] + async fn roundtrip_custom_tables() -> Result<(), DataFusionError> { + let mut table_factories: HashMap> = + HashMap::new(); + table_factories.insert("testtable".to_string(), Arc::new(TestTableFactory {})); + let cfg = RuntimeConfig::new().with_table_factories(table_factories); + let env = RuntimeEnv::new(cfg).unwrap(); + let ses = SessionConfig::new(); + let ctx = SessionContext::with_config_rt(ses, Arc::new(env)); + + let sql = "CREATE EXTERNAL TABLE t STORED AS testtable LOCATION 's3://bucket/schema/table';"; + ctx.sql(sql).await.unwrap(); + + let codec = TestTableProviderCodec {}; + let scan = ctx.table("t")?.to_logical_plan()?; + let bytes = logical_plan_to_bytes_with_extension_codec(&scan, &codec)?; + let logical_round_trip = + logical_plan_from_bytes_with_extension_codec(&bytes, &ctx, &codec).await?; + assert_eq!(format!("{:?}", scan), format!("{:?}", logical_round_trip)); + Ok(()) + } + #[tokio::test] async fn roundtrip_logical_plan_aggregation() -> Result<(), DataFusionError> { let ctx = SessionContext::new(); @@ -153,7 +257,7 @@ mod roundtrip_tests { println!("{:?}", plan); let bytes = logical_plan_to_bytes(&plan)?; - let logical_round_trip = logical_plan_from_bytes(&bytes, &ctx)?; + let logical_round_trip = logical_plan_from_bytes(&bytes, &ctx).await?; assert_eq!(format!("{:?}", plan), format!("{:?}", logical_round_trip)); Ok(()) @@ -166,7 +270,7 @@ mod roundtrip_tests { .await?; let plan = ctx.table("t1")?.to_logical_plan()?; let bytes = logical_plan_to_bytes(&plan)?; - let logical_round_trip = logical_plan_from_bytes(&bytes, &ctx)?; + let logical_round_trip = logical_plan_from_bytes(&bytes, &ctx).await?; assert_eq!(format!("{:?}", plan), format!("{:?}", logical_round_trip)); Ok(()) } @@ -180,7 +284,7 @@ mod roundtrip_tests { .await?; let plan = ctx.sql("SELECT * FROM view_t1").await?.to_logical_plan()?; let bytes = logical_plan_to_bytes(&plan)?; - let logical_round_trip = logical_plan_from_bytes(&bytes, &ctx)?; + let logical_round_trip = logical_plan_from_bytes(&bytes, &ctx).await?; assert_eq!(format!("{:?}", plan), format!("{:?}", logical_round_trip)); Ok(()) } @@ -263,6 +367,7 @@ mod roundtrip_tests { #[derive(Debug)] pub struct TopKExtensionCodec {} + #[async_trait] impl LogicalExtensionCodec for TopKExtensionCodec { fn try_decode( &self, @@ -325,6 +430,27 @@ mod roundtrip_tests { )) } } + + async fn try_decode_table_provider( + &self, + _buf: &[u8], + _schema: SchemaRef, + _ctx: &SessionContext, + ) -> Result, DataFusionError> { + Err(DataFusionError::Internal( + "unsupported plan type".to_string(), + )) + } + + fn try_encode_table_provider( + &self, + _node: Arc, + _buf: &mut Vec, + ) -> Result<(), DataFusionError> { + Err(DataFusionError::Internal( + "unsupported plan type".to_string(), + )) + } } #[test] diff --git a/datafusion/proto/src/logical_plan.rs b/datafusion/proto/src/logical_plan.rs index d61bb2d65baea..b9f34ff02265e 100644 --- a/datafusion/proto/src/logical_plan.rs +++ b/datafusion/proto/src/logical_plan.rs @@ -15,6 +15,8 @@ // specific language governing permissions and limitations // under the License. +use crate::protobuf::logical_plan_node::LogicalPlanType::CustomScan; +use crate::protobuf::CustomTableScanNode; use crate::{ from_proto::{self, parse_expr}, protobuf::{ @@ -23,7 +25,11 @@ use crate::{ }, to_proto, }; -use arrow::datatypes::Schema; +use arrow::datatypes::{Schema, SchemaRef}; +use async_trait::async_trait; +use datafusion::datasource::TableProvider; +use datafusion::execution::FunctionRegistry; +use datafusion::physical_plan::ExecutionPlan; use datafusion::{ datasource::{ file_format::{ @@ -35,7 +41,7 @@ use datafusion::{ datasource::{provider_as_source, source_as_provider}, prelude::SessionContext, }; -use datafusion_common::{Column, DataFusionError}; +use datafusion_common::{context, Column, DataFusionError}; use datafusion_expr::{ logical_plan::{ Aggregate, CreateCatalog, CreateCatalogSchema, CreateExternalTable, CreateView, @@ -69,6 +75,7 @@ pub(crate) fn proto_error>(message: S) -> DataFusionError { DataFusionError::Internal(message.into()) } +#[async_trait] pub trait AsLogicalPlan: Debug + Send + Sync + Clone { fn try_decode(buf: &[u8]) -> Result where @@ -79,7 +86,7 @@ pub trait AsLogicalPlan: Debug + Send + Sync + Clone { B: BufMut, Self: Sized; - fn try_into_logical_plan( + async fn try_into_logical_plan( &self, ctx: &SessionContext, extension_codec: &dyn LogicalExtensionCodec, @@ -93,6 +100,22 @@ pub trait AsLogicalPlan: Debug + Send + Sync + Clone { Self: Sized; } +pub trait PhysicalExtensionCodec: Debug + Send + Sync { + fn try_decode( + &self, + buf: &[u8], + inputs: &[Arc], + registry: &dyn FunctionRegistry, + ) -> Result, DataFusionError>; + + fn try_encode( + &self, + node: Arc, + buf: &mut Vec, + ) -> Result<(), DataFusionError>; +} + +#[async_trait] pub trait LogicalExtensionCodec: Debug + Send + Sync { fn try_decode( &self, @@ -106,11 +129,25 @@ pub trait LogicalExtensionCodec: Debug + Send + Sync { node: &Extension, buf: &mut Vec, ) -> Result<(), DataFusionError>; + + async fn try_decode_table_provider( + &self, + buf: &[u8], + schema: SchemaRef, + ctx: &SessionContext, + ) -> Result, DataFusionError>; + + fn try_encode_table_provider( + &self, + node: Arc, + buf: &mut Vec, + ) -> Result<(), DataFusionError>; } #[derive(Debug, Clone)] pub struct DefaultLogicalExtensionCodec {} +#[async_trait] impl LogicalExtensionCodec for DefaultLogicalExtensionCodec { fn try_decode( &self, @@ -132,13 +169,34 @@ impl LogicalExtensionCodec for DefaultLogicalExtensionCodec { "LogicalExtensionCodec is not provided".to_string(), )) } + + async fn try_decode_table_provider( + &self, + _buf: &[u8], + _schema: SchemaRef, + _ctx: &SessionContext, + ) -> Result, DataFusionError> { + Err(DataFusionError::NotImplemented( + "LogicalExtensionCodec is not provided".to_string(), + )) + } + + fn try_encode_table_provider( + &self, + _node: Arc, + _buf: &mut Vec, + ) -> Result<(), DataFusionError> { + Err(DataFusionError::NotImplemented( + "LogicalExtensionCodec is not provided".to_string(), + )) + } } #[macro_export] macro_rules! into_logical_plan { ($PB:expr, $CTX:expr, $CODEC:expr) => {{ if let Some(field) = $PB.as_ref() { - field.as_ref().try_into_logical_plan($CTX, $CODEC) + field.as_ref().try_into_logical_plan($CTX, $CODEC).await } else { Err(proto_error("Missing required field in protobuf")) } @@ -222,6 +280,7 @@ impl From for protobuf::JoinConstraint { } } +#[async_trait] impl AsLogicalPlan for LogicalPlanNode { fn try_decode(buf: &[u8]) -> Result where @@ -242,7 +301,7 @@ impl AsLogicalPlan for LogicalPlanNode { }) } - fn try_into_logical_plan( + async fn try_into_logical_plan( &self, ctx: &SessionContext, extension_codec: &dyn LogicalExtensionCodec, @@ -410,6 +469,36 @@ impl AsLogicalPlan for LogicalPlanNode { )? .build() } + LogicalPlanType::CustomScan(scan) => { + let schema: Schema = convert_required!(scan.schema)?; + let schema = Arc::new(schema); + let mut projection = None; + if let Some(columns) = &scan.projection { + let column_indices = columns + .columns + .iter() + .map(|name| schema.index_of(name)) + .collect::, _>>()?; + projection = Some(column_indices); + } + + let filters = scan + .filters + .iter() + .map(|expr| parse_expr(expr, ctx)) + .collect::, _>>()?; + let provider = extension_codec + .try_decode_table_provider(&scan.custom_table_data, schema, ctx) + .await?; + + LogicalPlanBuilder::scan_with_filters( + &scan.table_name, + provider_as_source(provider), + projection, + filters, + )? + .build() + } LogicalPlanType::Sort(sort) => { let input: LogicalPlan = into_logical_plan!(sort.input, ctx, extension_codec)?; @@ -502,7 +591,7 @@ impl AsLogicalPlan for LogicalPlanNode { .input.clone().ok_or_else(|| DataFusionError::Internal(String::from( "Protobuf deserialization error, CreateViewNode has invalid LogicalPlan input.", )))? - .try_into_logical_plan(ctx, extension_codec)?; + .try_into_logical_plan(ctx, extension_codec).await?; let definition = if !create_view.definition.is_empty() { Some(create_view.definition.clone()) } else { @@ -625,11 +714,11 @@ impl AsLogicalPlan for LogicalPlanNode { builder.build() } LogicalPlanType::Union(union) => { - let mut input_plans: Vec = union - .inputs - .iter() - .map(|i| i.try_into_logical_plan(ctx, extension_codec)) - .collect::>()?; + let mut input_plans: Vec = vec![]; + for i in union.inputs.iter() { + let res = i.try_into_logical_plan(ctx, extension_codec).await?; + input_plans.push(res); + } if input_plans.len() < 2 { return Err( DataFusionError::Internal(String::from( @@ -653,10 +742,11 @@ impl AsLogicalPlan for LogicalPlanNode { LogicalPlanBuilder::from(left).cross_join(&right)?.build() } LogicalPlanType::Extension(LogicalExtensionNode { node, inputs }) => { - let input_plans: Vec = inputs - .iter() - .map(|i| i.try_into_logical_plan(ctx, extension_codec)) - .collect::>()?; + let mut input_plans: Vec = vec![]; + for i in inputs.iter() { + let res = i.try_into_logical_plan(ctx, extension_codec).await?; + input_plans.push(res); + } let extension_node = extension_codec.try_decode(node, &input_plans, ctx)?; @@ -736,9 +826,9 @@ impl AsLogicalPlan for LogicalPlanNode { projection, .. }) => { - let source = source_as_provider(source)?; - let schema = source.schema(); - let source = source.as_any(); + let provider = source_as_provider(source)?; + let schema = provider.schema(); + let source = provider.as_any(); let projection = match projection { None => None, @@ -830,10 +920,21 @@ impl AsLogicalPlan for LogicalPlanNode { ))), }) } else { - Err(DataFusionError::Internal(format!( - "logical plan to_proto unsupported table provider {:?}", - source - ))) + let mut bytes = vec![]; + extension_codec + .try_encode_table_provider(provider, &mut bytes) + .map_err(|e| context!("Error serializing custom table", e))?; + let scan = CustomScan(CustomTableScanNode { + table_name: table_name.clone(), + projection, + schema: Some(schema), + filters, + custom_table_data: bytes, + }); + let node = LogicalPlanNode { + logical_plan_type: Some(scan), + }; + Ok(node) } } LogicalPlan::Projection(Projection {