diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index 53d1f26..e4440c9 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -23,10 +23,10 @@ jobs: - uses: arduino/setup-protoc@v3 with: repo-token: ${{ secrets.GITHUB_TOKEN }} - - run: cargo rustdoc -p datafusion-federation -- --cfg docsrs + - run: cargo rustdoc -p datafusion-flight-sql-server -- --cfg docsrs - run: chmod -c -R +rX "target/doc" - run: touch target/doc/index.html - - run: echo "" > target/doc/index.html + - run: echo "" > target/doc/index.html - if: github.event_name == 'push' && github.ref == 'refs/heads/main' uses: actions/upload-pages-artifact@v3 with: diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 0792434..5eb06a7 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -70,4 +70,4 @@ jobs: with: repo-token: ${{ secrets.GITHUB_TOKEN }} - run: cargo build --all - - run: cargo package -p datafusion-federation --allow-dirty + - run: cargo package -p datafusion-flight-sql-server --allow-dirty diff --git a/Cargo.toml b/Cargo.toml index 4b2cf85..aa6a0d0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -2,7 +2,6 @@ resolver = "2" members = [ - "datafusion-federation", "datafusion-flight-sql-server", "datafusion-flight-sql-table-provider", ] @@ -12,16 +11,15 @@ version = "0.3.5" edition = "2021" license = "Apache-2.0" readme = "README.md" -repository = "https://github.com/datafusion-contrib/datafusion-federation" +repository = "https://github.com/datafusion-contrib/datafusion-flight-sql-server" [workspace.dependencies] arrow = "53.3" arrow-flight = { version = "53.3", features = ["flight-sql-experimental"] } arrow-json = "53.3" -async-stream = "0.3.5" async-trait = "0.1.83" datafusion = "44.0.0" -datafusion-federation = { path = "./datafusion-federation", version = "0.3.5" } +datafusion-federation = { version = "0.3.5" } datafusion-substrait = "44.0.0" futures = "0.3.31" tokio = { version = "1.41", features = ["full"] } diff --git a/README.md b/README.md index 4891b12..6265e42 100644 --- a/README.md +++ b/README.md @@ -1,138 +1,52 @@ -# DataFusion Federation - -[![crates.io](https://img.shields.io/crates/v/datafusion-federation.svg)](https://crates.io/crates/datafusion-federation) -[![docs.rs](https://docs.rs/datafusion-federation/badge.svg)](https://docs.rs/datafusion-federation) - -DataFusion Federation allows -[DataFusion](https://github.com/apache/arrow-datafusion) to execute (part of) a -query plan by a remote execution engine. - - ┌────────────────┐ - ┌────────────┐ │ Remote DBMS(s) │ - SQL Query ───> │ DataFusion │ ───> │ ( execution │ - └────────────┘ │ happens here ) │ - └────────────────┘ - -The goal is to allow resolving queries across remote query engines while -pushing down as much compute as possible to the remote database(s). This allows -execution to happen as close to the storage as possible. This concept is -referred to as 'query federation'. - -> [!TIP] -> This repository implements the federation framework itself. If you want to -> connect to a specific database, check out the compatible providers available -> in -> [datafusion-contrib/datafusion-table-providers](https://github.com/datafusion-contrib/datafusion-table-providers/). - -## Usage - -Check out the [examples](./datafusion-federation/examples/) to get a feel for -how it works. - -For a complete step-by-step example of how federation works, you can check the -example [here](./datafusion-federation/examples/df-csv-advanced.rs). 
- -## Potential use-cases: - -- Querying across SQLite, MySQL, PostgreSQL, ... -- Pushing down SQL or [Substrait](https://substrait.io/) plans. -- DataFusion -> Flight SQL -> DataFusion -- .. - -## Design concept - -Say you have a query plan as follows: - - ┌────────────┐ - │ Join │ - └────────────┘ - ▲ - ┌───────┴────────┐ - ┌────────────┐ ┌────────────┐ - │ Scan A │ │ Join │ - └────────────┘ └────────────┘ - ▲ - ┌───────┴────────┐ - ┌────────────┐ ┌────────────┐ - │ Scan B │ │ Scan C │ - └────────────┘ └────────────┘ - -DataFusion Federation will identify the largest possible sub-plans that -can be executed by an external database: - - ┌────────────┐ Optimizer recognizes - │ Join │ that B and C are - └────────────┘ available in an - ▲ external database - ┌──────────────┴────────┐ - │ ┌ ─ ─ ─ ─ ─ ─ ┴ ─ ── ─ ─ ─ ─ ─┐ - ┌────────────┐ ┌────────────┐ │ - │ Scan A │ │ │ Join │ - └────────────┘ └────────────┘ │ - │ ▲ - ┌───────┴────────┐ │ - ┌────────────┐ ┌────────────┐ │ - ││ Scan B │ │ Scan C │ - └────────────┘ └────────────┘ │ - ─ ── ─ ─ ── ─ ─ ─ ─ ─ ─ ─ ── ─ ┘ - -The sub-plans are cut out and replaced by an opaque federation node in the plan: - - ┌────────────┐ - │ Join │ - └────────────┘ Rewritten Plan - ▲ - ┌────────┴───────────┐ - │ │ - ┌────────────┐ ┏━━━━━━━━━━━━━━━━━━┓ - │ Scan A │ ┃ Scan B+C ┃ - └────────────┘ ┃ (TableProvider ┃ - ┃ that can execute ┃ - ┃ sub-plan in an ┃ - ┃external database)┃ - ┗━━━━━━━━━━━━━━━━━━┛ - -Different databases may have different query languages and execution -capabilities. To accommodate for this, we allow each 'federation provider' to -self-determine what part of a sub-plan it will actually federate. This is done -by letting each federation provider define its own optimizer rule. When a -sub-plan is 'cut out' of the overall plan, it is first passed the federation -provider's optimizer rule. This optimizer rule determines the part of the plan -that is cut out, based on the execution capabilities of the database it -represents. - -## Implementation - -A remote database is represented by the `FederationProvider` trait. To identify -table scans that are available in the same database, they implement -`FederatedTableSource` trait. This trait allows lookup of the corresponding -`FederationProvider`. - -Identifying sub-plans to federate is done by the `FederationOptimizerRule`. -This rule needs to be registered in your DataFusion SessionState. One easy way -to do this is using `default_session_state`. To do its job, the -`FederationOptimizerRule` currently requires that all TableProviders that need -to be federated are `FederatedTableProviderAdaptor`s. The -`FederatedTableProviderAdaptor` also has a fallback mechanism that allows -implementations to fallback to a 'vanilla' TableProvider in case the -`FederationOptimizerRule` isn't registered. - -The `FederationProvider` can provide a `compute_context`. This allows it to -differentiate between multiple remote execution context of the same type. For -example two different mysql instances, database schemas, access level, etc. The -`FederationProvider` also returns the `Optimizer` that is allows it to -self-determine what part of a sub-plan it can federate. - -The `sql` module implements a generic `FederationProvider` for SQL execution -engines. A specific SQL engine implements the `SQLExecutor` trait for its -engine specific execution. There are a number of compatible providers available -in -[datafusion-contrib/datafusion-table-providers](https://github.com/datafusion-contrib/datafusion-table-providers/). 
-
-## Status
-
-The project is in alpha status. Contributions welcome; land a PR = commit
-access.
-
-- [Docs (release)](https://docs.rs/datafusion-federation)
-- [Docs (main)](https://datafusion-contrib.github.io/datafusion-federation/)
+# DataFusion Flight SQL Server
+
+`datafusion-flight-sql-server` is a Flight SQL server that implements the
+necessary endpoints to use DataFusion as the query engine.
+
+## Getting Started
+
+To use `datafusion-flight-sql-server` in your Rust project, run:
+
+```sh
+$ cargo add datafusion-flight-sql-server
+```
+
+## Example
+
+Here's a basic example of setting up a Flight SQL server:
+
+```rust
+use datafusion_flight_sql_server::service::FlightSqlService;
+use datafusion::execution::{context::SessionContext, options::CsvReadOptions};
+
+#[tokio::main]
+async fn main() {
+    let dsn: String = "0.0.0.0:50051".to_string();
+
+    // Register a CSV file as a table in the DataFusion session context.
+    let remote_ctx = SessionContext::new();
+    remote_ctx
+        .register_csv("test", "./examples/test.csv", CsvReadOptions::new())
+        .await
+        .expect("Register csv");
+
+    // Serve the DataFusion session state over Flight SQL.
+    FlightSqlService::new(remote_ctx.state())
+        .serve(dsn.clone())
+        .await
+        .expect("Run flight sql service");
+}
+```
+
+This example starts a Flight SQL server listening on `0.0.0.0:50051`.
+
+
+## Acknowledgments
+
+This crate was first developed as part of the
+[datafusion-federation](https://github.com/datafusion-contrib/datafusion-federation/)
+repository.
+
+For more details about the original repository, please visit
+[datafusion-federation](https://github.com/datafusion-contrib/datafusion-federation/).
+
diff --git a/datafusion-federation/CHANGELOG.md b/datafusion-federation/CHANGELOG.md
deleted file mode 100644
index ba3d0a6..0000000
--- a/datafusion-federation/CHANGELOG.md
+++ /dev/null
@@ -1,32 +0,0 @@
-# Changelog
-
-All notable changes to this project will be documented in this file.
-
-The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
-and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
-
-## [Unreleased]
-
-## [0.3.5](https://github.com/datafusion-contrib/datafusion-federation/compare/datafusion-federation-v0.3.4...datafusion-federation-v0.3.5) - 2025-01-20
-
-### Other
-
-- Use the Dialect and Unparser constructor when using the plan_to_sql function. (#105)
-
-## [0.3.4](https://github.com/datafusion-contrib/datafusion-federation/compare/datafusion-federation-v0.3.3...datafusion-federation-v0.3.4) - 2025-01-12
-
-### Other
-
-- upgrade datafusion to 44 (#103)
-
-## [0.3.3](https://github.com/datafusion-contrib/datafusion-federation/compare/datafusion-federation-v0.3.2...datafusion-federation-v0.3.3) - 2025-01-04
-
-### Fixed
-
-- handle `LogicalPlan::Limit` separately to preserve skip and offset in `rewrite_table_scans` (#101)
-
-## [0.3.2](https://github.com/datafusion-contrib/datafusion-federation/compare/datafusion-federation-v0.3.1...datafusion-federation-v0.3.2) - 2024-12-05
-
-### Other
-
-- Release plz action: install required dependencies ([#85](https://github.com/datafusion-contrib/datafusion-federation/pull/85))
diff --git a/datafusion-federation/Cargo.toml b/datafusion-federation/Cargo.toml
deleted file mode 100644
index 8844cd2..0000000
--- a/datafusion-federation/Cargo.toml
+++ /dev/null
@@ -1,43 +0,0 @@
-[package]
-name = "datafusion-federation"
-version.workspace = true
-edition.workspace = true
-license.workspace = true
-readme.workspace = true
-repository.workspace = true
-description = "Datafusion federation."
- -[lib] -name = "datafusion_federation" -path = "src/lib.rs" - -[package.metadata.docs.rs] -# Whether to pass `--all-features` to Cargo (default: false) -all-features = true -# Whether to pass `--no-default-features` to Cargo (default: false) -no-default-features = true - -[features] -sql = [] - -[dependencies] -futures.workspace = true -async-trait.workspace = true -datafusion.workspace = true -async-stream.workspace = true -arrow-json.workspace = true - -[dev-dependencies] -tokio.workspace = true -tracing-subscriber = { version = "0.3.18", features = ["env-filter"] } -tracing = "0.1.40" - -[[example]] -name = "df-csv" -path = "examples/df-csv.rs" -required-features = ["sql"] - -[[example]] -name = "df-csv-advanced" -path = "examples/df-csv-advanced.rs" -required-features = ["sql"] diff --git a/datafusion-federation/examples/data/test.csv b/datafusion-federation/examples/data/test.csv deleted file mode 100644 index 62d0b11..0000000 --- a/datafusion-federation/examples/data/test.csv +++ /dev/null @@ -1,4 +0,0 @@ -foo,bar -a,1 -b,2 -c,3 diff --git a/datafusion-federation/examples/data/test2.csv b/datafusion-federation/examples/data/test2.csv deleted file mode 100644 index 7196406..0000000 --- a/datafusion-federation/examples/data/test2.csv +++ /dev/null @@ -1,7 +0,0 @@ -foo,bar -a,1 -b,2 -c,3 -d,4 -e,5 -f,6 diff --git a/datafusion-federation/examples/df-csv-advanced.rs b/datafusion-federation/examples/df-csv-advanced.rs deleted file mode 100644 index e7b709f..0000000 --- a/datafusion-federation/examples/df-csv-advanced.rs +++ /dev/null @@ -1,148 +0,0 @@ -mod shared; - -use std::sync::Arc; - -use datafusion::{ - execution::{ - context::SessionContext, options::CsvReadOptions, session_state::SessionStateBuilder, - }, - optimizer::Optimizer, -}; - -use datafusion_federation::{ - sql::{MultiSchemaProvider, SQLFederationProvider, SQLSchemaProvider}, - FederatedQueryPlanner, FederationOptimizerRule, -}; - -use shared::{overwrite_default_schema, MockPostgresExecutor, MockSqliteExecutor}; - -const CSV_PATH_SQLITE: &str = "./examples/data/test.csv"; -const CSV_PATH_POSTGRES: &str = "./examples/data/test2.csv"; -const TABLE_NAME_SQLITE: &str = "test_sqlite"; -const TABLE_NAME_POSTGRES: &str = "test_pg"; - -#[tokio::main] -async fn main() { - // This example demonstrates how DataFusion, with federation enabled, to - // executes a query using two execution engines. - // - // The query used in this example is: - // - // ```sql - // SELECT t.* - // FROM test_pg AS t - // JOIN test_sqlite AS a - // ON t.foo = a.foo - // ``` - // - // In this query, `test_pg` is a table in a PostgreSQL database, and `test_sqlite` is a table - // in an SQLite database. DataFusion Federation will identify the sub-plans that can be - // executed by external databases. In this example, there will be only two sub-plans. - // - // ┌────────────┐ - // │ Join │ - // └────────────┘ - // ▲ - // ┌───────┴──────────┐ - // ┌──────────────┐ ┌────────────┐ - // │ test_sqlite │ │ Join │ - // └──────────────┘ └────────────┘ - // ▲ - // | - // ┌────────────┐ - // │ test_pg │ - // └────────────┘ - // - // Note: For the purpose of this example, both the SQLite and PostgreSQL engines are dummy - // engines that use DataFusion SessionContexts with registered CSV files. However, this setup - // works fine for demonstration purposes. If you'd like to use actual SQLite and PostgreSQL - // engines, you can check out the table-providers repository at - // https://github.com/datafusion-contrib/datafusion-table-providers/. 
- - ///////////////////// - // Remote sqlite DB - ///////////////////// - // Create a datafusion::SessionContext and register a csv file as a table in that context - // This will be passed to the MockSqliteExecutor and acts as a dummy sqlite engine. - let sqlite_remote_ctx = Arc::new(SessionContext::new()); - // Registers a CSV file - sqlite_remote_ctx - .register_csv(TABLE_NAME_SQLITE, CSV_PATH_SQLITE, CsvReadOptions::new()) - .await - .expect("Register csv file"); - - let sqlite_known_tables: Vec = [TABLE_NAME_SQLITE].iter().map(|&x| x.into()).collect(); - - // Create the federation provider - let sqlite_executor = Arc::new(MockSqliteExecutor::new(sqlite_remote_ctx)); - let sqlite_federation_provider = Arc::new(SQLFederationProvider::new(sqlite_executor)); - // Create the schema provider - let sqlite_schema_provider = Arc::new( - SQLSchemaProvider::new_with_tables(sqlite_federation_provider, sqlite_known_tables) - .await - .expect("Create new schema provider with tables"), - ); - - ///////////////////// - // Remote postgres DB - ///////////////////// - // Create a datafusion::SessionContext and register a csv file as a table in that context - // This will be passed to the MockPostgresExecutor and acts as a dummy postgres engine. - let postgres_remote_ctx = Arc::new(SessionContext::new()); - // Registers a CSV file - postgres_remote_ctx - .register_csv( - TABLE_NAME_POSTGRES, - CSV_PATH_POSTGRES, - CsvReadOptions::new(), - ) - .await - .expect("Register csv file"); - - let postgres_known_tables: Vec = - [TABLE_NAME_POSTGRES].iter().map(|&x| x.into()).collect(); - - // Create the federation provider - let postgres_executor = Arc::new(MockPostgresExecutor::new(postgres_remote_ctx)); - let postgres_federation_provider = Arc::new(SQLFederationProvider::new(postgres_executor)); - // Create the schema provider - let postgres_schema_provider = Arc::new( - SQLSchemaProvider::new_with_tables(postgres_federation_provider, postgres_known_tables) - .await - .expect("Create new schema provider with tables"), - ); - - ///////////////////// - // Main(local) DB - ///////////////////// - // Get the default optimizer rules - let mut rules = Optimizer::new().rules; - - // Create a new federation optimizer rule and add it to the default rules - rules.push(Arc::new(FederationOptimizerRule::new())); - - // Create a new SessionState with the optimizer rule we created above - let state = SessionStateBuilder::new() - .with_optimizer_rules(rules) - .with_query_planner(Arc::new(FederatedQueryPlanner::new())) - .build(); - - // Replace the default schema for the main context with the schema providers - // from the remote DBs - let schema_provider = - MultiSchemaProvider::new(vec![sqlite_schema_provider, postgres_schema_provider]); - overwrite_default_schema(&state, Arc::new(schema_provider)) - .expect("Overwrite the default schema form the main context"); - - // Create the session context for the main db - let ctx = SessionContext::new_with_state(state); - - // Run a query - let query = r#"SELECT t.* FROM test_pg as t join test_sqlite as a ON t.foo = a.foo"#; - let df = ctx - .sql(query) - .await - .expect("Create a dataframe from sql query"); - - df.show().await.expect("Execute the dataframe"); -} diff --git a/datafusion-federation/examples/df-csv.rs b/datafusion-federation/examples/df-csv.rs deleted file mode 100644 index c71c6ab..0000000 --- a/datafusion-federation/examples/df-csv.rs +++ /dev/null @@ -1,45 +0,0 @@ -mod shared; - -use std::sync::Arc; - -use datafusion::{ - error::Result, - 
execution::{context::SessionContext, options::CsvReadOptions}, -}; -use datafusion_federation::sql::{SQLFederationProvider, SQLSchemaProvider}; - -const CSV_PATH: &str = "./examples/data/test.csv"; -const TABLE_NAME: &str = "test"; - -use shared::{overwrite_default_schema, MockSqliteExecutor}; - -#[tokio::main] -async fn main() -> Result<()> { - // Create a remote context for the mock sqlite DB - let remote_ctx = Arc::new(SessionContext::new()); - - // Registers a CSV file - remote_ctx - .register_csv(TABLE_NAME, CSV_PATH, CsvReadOptions::new()) - .await?; - let known_tables: Vec = [TABLE_NAME].iter().map(|&x| x.into()).collect(); - - // Create the federation provider - let executor = Arc::new(MockSqliteExecutor::new(remote_ctx)); - let provider = Arc::new(SQLFederationProvider::new(executor)); - - // Get the schema - let schema_provider = - Arc::new(SQLSchemaProvider::new_with_tables(provider, known_tables).await?); - - // Main context - let state = datafusion_federation::default_session_state(); - overwrite_default_schema(&state, schema_provider)?; - let ctx = SessionContext::new_with_state(state); - - // Run a query - let query = r#"SELECT * FROM test"#; - let df = ctx.sql(query).await?; - - df.show().await -} diff --git a/datafusion-federation/examples/shared/mod.rs b/datafusion-federation/examples/shared/mod.rs deleted file mode 100644 index 4b19990..0000000 --- a/datafusion-federation/examples/shared/mod.rs +++ /dev/null @@ -1,136 +0,0 @@ -use std::sync::Arc; - -use async_trait::async_trait; -use datafusion::{ - arrow::datatypes::SchemaRef, - catalog::SchemaProvider, - error::{DataFusionError, Result}, - execution::context::{SessionContext, SessionState}, - physical_plan::{stream::RecordBatchStreamAdapter, SendableRecordBatchStream}, - sql::unparser::dialect::{DefaultDialect, Dialect}, -}; -use futures::TryStreamExt; - -use datafusion_federation::sql::SQLExecutor; - -pub fn overwrite_default_schema( - state: &SessionState, - schema: Arc, -) -> Result<()> { - let options = &state.config().options().catalog; - let catalog = state - .catalog_list() - .catalog(options.default_catalog.as_str()) - .unwrap(); - - catalog.register_schema(options.default_schema.as_str(), schema)?; - Ok(()) -} - -pub struct MockSqliteExecutor { - session: Arc, -} - -impl MockSqliteExecutor { - pub fn new(session: Arc) -> Self { - Self { session } - } -} - -#[async_trait] -impl SQLExecutor for MockSqliteExecutor { - fn name(&self) -> &str { - "mock_sqlite_executor" - } - - fn compute_context(&self) -> Option { - // Don't return None here - it will cause incorrect federation with other providers of the - // same name that also have a compute_context of None. - // Instead return a random string that will never match any other provider's context. 
- Some("sqlite_exec".to_string()) - } - - fn execute(&self, sql: &str, schema: SchemaRef) -> Result { - // Execute it using the remote datafusion session context - let future_stream = _execute(self.session.clone(), sql.to_string()); - let stream = futures::stream::once(future_stream).try_flatten(); - Ok(Box::pin(RecordBatchStreamAdapter::new( - schema.clone(), - stream, - ))) - } - - async fn table_names(&self) -> Result> { - Err(DataFusionError::NotImplemented( - "table inference not implemented".to_string(), - )) - } - - async fn get_table_schema(&self, table_name: &str) -> Result { - let sql = format!("select * from {table_name} limit 1"); - let df = self.session.sql(&sql).await?; - let schema = df.schema().as_arrow().clone(); - Ok(Arc::new(schema)) - } - - fn dialect(&self) -> Arc { - Arc::new(DefaultDialect {}) - } -} - -#[allow(dead_code)] -pub struct MockPostgresExecutor { - session: Arc, -} - -#[allow(dead_code)] -impl MockPostgresExecutor { - pub fn new(session: Arc) -> Self { - Self { session } - } -} - -#[async_trait] -impl SQLExecutor for MockPostgresExecutor { - fn name(&self) -> &str { - "mock_postgres_executor" - } - - fn compute_context(&self) -> Option { - // Don't return None here - it will cause incorrect federation with other providers of the - // same name that also have a compute_context of None. - // Instead return a random string that will never match any other provider's context. - Some("postgres_exec".to_string()) - } - - fn execute(&self, sql: &str, schema: SchemaRef) -> Result { - // Execute it using the remote datafusion session context - let future_stream = _execute(self.session.clone(), sql.to_string()); - let stream = futures::stream::once(future_stream).try_flatten(); - Ok(Box::pin(RecordBatchStreamAdapter::new( - schema.clone(), - stream, - ))) - } - - async fn table_names(&self) -> Result> { - Err(DataFusionError::NotImplemented( - "table inference not implemented".to_string(), - )) - } - - async fn get_table_schema(&self, table_name: &str) -> Result { - let sql = format!("select * from {table_name} limit 1"); - let df = self.session.sql(&sql).await?; - let schema = df.schema().as_arrow().clone(); - Ok(Arc::new(schema)) - } - - fn dialect(&self) -> Arc { - Arc::new(DefaultDialect {}) - } -} - -async fn _execute(ctx: Arc, sql: String) -> Result { - ctx.sql(&sql).await?.execute_stream().await -} diff --git a/datafusion-federation/src/lib.rs b/datafusion-federation/src/lib.rs deleted file mode 100644 index 377b469..0000000 --- a/datafusion-federation/src/lib.rs +++ /dev/null @@ -1,90 +0,0 @@ -mod optimizer; -mod plan_node; -pub mod schema_cast; -#[cfg(feature = "sql")] -pub mod sql; -mod table_provider; - -use std::{ - fmt, - hash::{Hash, Hasher}, - sync::Arc, -}; - -use datafusion::{ - execution::session_state::{SessionState, SessionStateBuilder}, - optimizer::{optimizer::Optimizer, OptimizerRule}, -}; - -pub use optimizer::{get_table_source, FederationOptimizerRule}; -pub use plan_node::{ - FederatedPlanNode, FederatedPlanner, FederatedQueryPlanner, FederationPlanner, -}; -pub use table_provider::{FederatedTableProviderAdaptor, FederatedTableSource}; - -pub fn default_session_state() -> SessionState { - let rules = default_optimizer_rules(); - SessionStateBuilder::new() - .with_optimizer_rules(rules) - .with_query_planner(Arc::new(FederatedQueryPlanner::new())) - .with_default_features() - .build() -} - -pub fn default_optimizer_rules() -> Vec> { - // Get the default optimizer - let df_default = Optimizer::new(); - let mut default_rules = df_default.rules; 
- - // Insert the FederationOptimizerRule after the ScalarSubqueryToJoin. - // This ensures ScalarSubquery are replaced before we try to federate. - let Some(pos) = default_rules - .iter() - .position(|x| x.name() == "scalar_subquery_to_join") - else { - panic!("Could not locate ScalarSubqueryToJoin"); - }; - - // TODO: check if we should allow other optimizers to run before the federation rule. - - let federation_rule = Arc::new(FederationOptimizerRule::new()); - default_rules.insert(pos + 1, federation_rule); - - default_rules -} - -pub type FederationProviderRef = Arc; -pub trait FederationProvider: Send + Sync { - // Returns the name of the provider, used for comparison. - fn name(&self) -> &str; - - // Returns the compute context in which this federation provider - // will execute a query. For example: database instance & catalog. - fn compute_context(&self) -> Option; - - // Returns an optimizer that can cut out part of the plan - // to federate it. - fn optimizer(&self) -> Option>; -} - -impl fmt::Display for dyn FederationProvider { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "{} {:?}", self.name(), self.compute_context()) - } -} - -impl PartialEq for dyn FederationProvider { - /// Comparing name, args and return_type - fn eq(&self, other: &dyn FederationProvider) -> bool { - self.name() == other.name() && self.compute_context() == other.compute_context() - } -} - -impl Hash for dyn FederationProvider { - fn hash(&self, state: &mut H) { - self.name().hash(state); - self.compute_context().hash(state); - } -} - -impl Eq for dyn FederationProvider {} diff --git a/datafusion-federation/src/optimizer/mod.rs b/datafusion-federation/src/optimizer/mod.rs deleted file mode 100644 index 9c16dfc..0000000 --- a/datafusion-federation/src/optimizer/mod.rs +++ /dev/null @@ -1,368 +0,0 @@ -mod scan_result; - -use std::sync::Arc; - -use datafusion::{ - common::not_impl_err, - common::tree_node::{Transformed, TreeNode, TreeNodeRecursion}, - datasource::source_as_provider, - error::Result, - logical_expr::{Expr, Extension, LogicalPlan, Projection, TableScan, TableSource}, - optimizer::optimizer::{Optimizer, OptimizerConfig, OptimizerRule}, -}; - -use crate::{ - FederatedTableProviderAdaptor, FederatedTableSource, FederationProvider, FederationProviderRef, -}; - -use scan_result::ScanResult; - -/// An optimizer rule to identifying sub-plans to federate -/// -/// The optimizer logic walks over the plan, look for the largest subtrees that only have -/// TableScans from the same [`FederationProvider`]. There 'largest sub-trees' are passed to their -/// respective [`FederationProvider::optimizer`]. -#[derive(Default, Debug)] -pub struct FederationOptimizerRule {} - -impl OptimizerRule for FederationOptimizerRule { - /// Try to rewrite `plan` to an optimized form, returning `Transformed::yes` - /// if the plan was rewritten and `Transformed::no` if it was not. - /// - /// Note: this function is only called if [`Self::supports_rewrite`] returns - /// true. Otherwise the Optimizer calls [`Self::try_optimize`] - fn rewrite( - &self, - plan: LogicalPlan, - config: &dyn OptimizerConfig, - ) -> Result> { - match self.optimize_plan_recursively(&plan, true, config)? { - (Some(optimized_plan), _) => Ok(Transformed::yes(optimized_plan)), - (None, _) => Ok(Transformed::no(plan)), - } - } - - /// Does this rule support rewriting owned plans (rather than by reference)? 
- fn supports_rewrite(&self) -> bool { - true - } - - /// A human readable name for this optimizer rule - fn name(&self) -> &str { - "federation_optimizer_rule" - } -} - -impl FederationOptimizerRule { - /// Creates a new [`FederationOptimizerRule`] - pub fn new() -> Self { - Self::default() - } - - /// Scans a plan to see if it belongs to a single [`FederationProvider`]. - fn scan_plan_recursively(&self, plan: &LogicalPlan) -> Result { - let mut sole_provider: ScanResult = ScanResult::None; - - plan.apply(&mut |p: &LogicalPlan| -> Result { - let exprs_provider = self.scan_plan_exprs(p)?; - sole_provider.merge(exprs_provider); - - if sole_provider.is_ambiguous() { - return Ok(TreeNodeRecursion::Stop); - } - - let sub_provider = get_leaf_provider(p)?; - sole_provider.add(sub_provider); - - Ok(sole_provider.check_recursion()) - })?; - - Ok(sole_provider) - } - - /// Scans a plan's expressions to see if it belongs to a single [`FederationProvider`]. - fn scan_plan_exprs(&self, plan: &LogicalPlan) -> Result { - let mut sole_provider: ScanResult = ScanResult::None; - - let exprs = plan.expressions(); - for expr in &exprs { - let expr_result = self.scan_expr_recursively(expr)?; - sole_provider.merge(expr_result); - - if sole_provider.is_ambiguous() { - return Ok(sole_provider); - } - } - - Ok(sole_provider) - } - - /// scans an expression to see if it belongs to a single [`FederationProvider`] - fn scan_expr_recursively(&self, expr: &Expr) -> Result { - let mut sole_provider: ScanResult = ScanResult::None; - - expr.apply(&mut |e: &Expr| -> Result { - // TODO: Support other types of sub-queries - match e { - Expr::ScalarSubquery(ref subquery) => { - let plan_result = self.scan_plan_recursively(&subquery.subquery)?; - - sole_provider.merge(plan_result); - Ok(sole_provider.check_recursion()) - } - Expr::InSubquery(_) => not_impl_err!("InSubquery"), - Expr::OuterReferenceColumn(..) => { - // Subqueries that reference outer columns are not supported - // for now. We handle this here as ambiguity to force - // federation lower in the plan tree. - sole_provider = ScanResult::Ambiguous; - Ok(TreeNodeRecursion::Stop) - } - _ => Ok(TreeNodeRecursion::Continue), - } - })?; - - Ok(sole_provider) - } - - /// Recursively finds the largest sub-plans that can be federated - /// to a single FederationProvider. - /// - /// Returns a plan if a sub-tree was federated, otherwise None. - /// - /// Returns a ScanResult of all FederationProviders in the subtree. - fn optimize_plan_recursively( - &self, - plan: &LogicalPlan, - is_root: bool, - _config: &dyn OptimizerConfig, - ) -> Result<(Option, ScanResult)> { - let mut sole_provider: ScanResult = ScanResult::None; - - if let LogicalPlan::Extension(Extension { ref node }) = plan { - if node.name() == "Federated" { - // Avoid attempting double federation - return Ok((None, ScanResult::Ambiguous)); - } - } - - // Check if this plan node is a leaf that determines the FederationProvider - let leaf_provider = get_leaf_provider(plan)?; - - // Check if the expressions contain, a potentially different, FederationProvider - let exprs_result = self.scan_plan_exprs(plan)?; - let optimize_expressions = exprs_result.is_some(); - - // Return early if this is a leaf and there is no ambiguity with the expressions. 
- if leaf_provider.is_some() && (exprs_result.is_none() || exprs_result == leaf_provider) { - return Ok((None, leaf_provider.into())); - } - // Aggregate leaf & expression providers - sole_provider.add(leaf_provider); - sole_provider.merge(exprs_result); - - let inputs = plan.inputs(); - // Return early if there are no sources. - if inputs.is_empty() && sole_provider.is_none() { - return Ok((None, ScanResult::None)); - } - - // Recursively optimize inputs - let input_results = inputs - .iter() - .map(|i| self.optimize_plan_recursively(i, false, _config)) - .collect::>>()?; - - // Aggregate the input providers - input_results.iter().for_each(|(_, scan_result)| { - sole_provider.merge(scan_result.clone()); - }); - - if sole_provider.is_none() { - // No providers found - // TODO: Is/should this be reachable? - return Ok((None, ScanResult::None)); - } - - // If all sources are federated to the same provider - if let ScanResult::Distinct(provider) = sole_provider { - if !is_root { - // The largest sub-plan is higher up. - return Ok((None, ScanResult::Distinct(provider))); - } - - let Some(optimizer) = provider.optimizer() else { - // No optimizer provided - return Ok((None, ScanResult::None)); - }; - - // If this is the root plan node; federate the entire plan - let optimized = optimizer.optimize(plan.clone(), _config, |_, _| {})?; - return Ok((Some(optimized), ScanResult::None)); - } - - // The plan is ambiguous; any input that is not yet optimized and has a - // sole provider represents a largest sub-plan and should be federated. - // - // We loop over the input optimization results, federate where needed and - // return a complete list of new inputs for the optimized plan. - let new_inputs = input_results - .into_iter() - .enumerate() - .map(|(i, (input_plan, input_result))| { - if let Some(federated_plan) = input_plan { - // Already federated deeper in the plan tree - return Ok(federated_plan); - } - - let original_input = (*inputs.get(i).unwrap()).clone(); - if input_result.is_ambiguous() { - // Can happen if the input is already federated, so use - // the original input. - return Ok(original_input); - } - - let provider = input_result.unwrap(); - let Some(provider) = provider else { - // No provider for this input; use the original input. - return Ok(original_input); - }; - - let Some(optimizer) = provider.optimizer() else { - // No optimizer for this input; use the original input. - return Ok(original_input); - }; - - // Replace the input with the federated counterpart - let wrapped = wrap_projection(original_input)?; - let optimized = optimizer.optimize(wrapped, _config, |_, _| {})?; - - Ok(optimized) - }) - .collect::>>()?; - - // Optimize expressions if needed - let new_expressions = if optimize_expressions { - self.optimize_plan_exprs(plan, _config)? - } else { - plan.expressions() - }; - - // Construct the optimized plan - let new_plan = plan.with_new_exprs(new_expressions, new_inputs)?; - - // Return the federated plan - Ok((Some(new_plan), ScanResult::Ambiguous)) - } - - /// Optimizes all exprs of a plan - fn optimize_plan_exprs( - &self, - plan: &LogicalPlan, - _config: &dyn OptimizerConfig, - ) -> Result> { - plan.expressions() - .iter() - .map(|expr| { - let transformed = expr - .clone() - .transform(&|e| self.optimize_expr_recursively(e, _config))?; - Ok(transformed.data) - }) - .collect::>>() - } - - /// recursively optimize expressions - /// Current logic: individually federate every sub-query. 
- fn optimize_expr_recursively( - &self, - expr: Expr, - _config: &dyn OptimizerConfig, - ) -> Result> { - match expr { - Expr::ScalarSubquery(ref subquery) => { - // Optimize as root to force federating the sub-query - let (new_subquery, _) = - self.optimize_plan_recursively(&subquery.subquery, true, _config)?; - let Some(new_subquery) = new_subquery else { - return Ok(Transformed::no(expr)); - }; - Ok(Transformed::yes(Expr::ScalarSubquery( - subquery.with_plan(new_subquery.into()), - ))) - } - Expr::InSubquery(_) => not_impl_err!("InSubquery"), - _ => Ok(Transformed::no(expr)), - } - } -} - -/// NopFederationProvider is used to represent tables that are not federated, but -/// are resolved by DataFusion. This simplifies the logic of the optimizer rule. -struct NopFederationProvider {} - -impl FederationProvider for NopFederationProvider { - fn name(&self) -> &str { - "nop" - } - - fn compute_context(&self) -> Option { - None - } - - fn optimizer(&self) -> Option> { - None - } -} - -fn get_leaf_provider(plan: &LogicalPlan) -> Result> { - match plan { - LogicalPlan::TableScan(TableScan { ref source, .. }) => { - let Some(federated_source) = get_table_source(source)? else { - // Table is not federated but provided by a standard table provider. - // We use a placeholder federation provider to simplify the logic. - return Ok(Some(Arc::new(NopFederationProvider {}))); - }; - let provider = federated_source.federation_provider(); - Ok(Some(provider)) - } - _ => Ok(None), - } -} - -fn wrap_projection(plan: LogicalPlan) -> Result { - // TODO: minimize requested columns - match plan { - LogicalPlan::Projection(_) => Ok(plan), - _ => { - let expr = plan - .schema() - .columns() - .iter() - .map(|c| Expr::Column(c.clone())) - .collect::>(); - Ok(LogicalPlan::Projection(Projection::try_new( - expr, - Arc::new(plan), - )?)) - } - } -} - -pub fn get_table_source( - source: &Arc, -) -> Result>> { - // Unwrap TableSource - let source = source_as_provider(source)?; - - // Get FederatedTableProviderAdaptor - let Some(wrapper) = source - .as_any() - .downcast_ref::() - else { - return Ok(None); - }; - - // Return original FederatedTableSource - Ok(Some(Arc::clone(&wrapper.source))) -} diff --git a/datafusion-federation/src/optimizer/scan_result.rs b/datafusion-federation/src/optimizer/scan_result.rs deleted file mode 100644 index 639cff4..0000000 --- a/datafusion-federation/src/optimizer/scan_result.rs +++ /dev/null @@ -1,98 +0,0 @@ -use datafusion::common::tree_node::TreeNodeRecursion; - -use crate::FederationProviderRef; - -/// Used to track if all sources, including tableScan, plan inputs and -/// expressions, represents an un-ambiguous, none or a sole' [`crate::FederationProvider`]. 
-pub enum ScanResult { - None, - Distinct(FederationProviderRef), - Ambiguous, -} - -impl ScanResult { - pub fn merge(&mut self, other: Self) { - match (&self, &other) { - (_, ScanResult::None) => {} - (ScanResult::None, _) => *self = other, - (ScanResult::Ambiguous, _) | (_, ScanResult::Ambiguous) => { - *self = ScanResult::Ambiguous - } - (ScanResult::Distinct(provider), ScanResult::Distinct(other_provider)) => { - if provider != other_provider { - *self = ScanResult::Ambiguous - } - } - } - } - - pub fn add(&mut self, provider: Option) { - self.merge(ScanResult::from(provider)) - } - - pub fn is_ambiguous(&self) -> bool { - matches!(self, ScanResult::Ambiguous) - } - - pub fn is_none(&self) -> bool { - matches!(self, ScanResult::None) - } - pub fn is_some(&self) -> bool { - !self.is_none() - } - - pub fn unwrap(self) -> Option { - match self { - ScanResult::None => None, - ScanResult::Distinct(provider) => Some(provider), - ScanResult::Ambiguous => panic!("called `ScanResult::unwrap()` on a `Ambiguous` value"), - } - } - - pub fn check_recursion(&self) -> TreeNodeRecursion { - if self.is_ambiguous() { - TreeNodeRecursion::Stop - } else { - TreeNodeRecursion::Continue - } - } -} - -impl From> for ScanResult { - fn from(provider: Option) -> Self { - match provider { - Some(provider) => ScanResult::Distinct(provider), - None => ScanResult::None, - } - } -} - -impl PartialEq> for ScanResult { - fn eq(&self, other: &Option) -> bool { - match (self, other) { - (ScanResult::None, None) => true, - (ScanResult::Distinct(provider), Some(other_provider)) => provider == other_provider, - _ => false, - } - } -} - -impl std::fmt::Debug for ScanResult { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Self::None => write!(f, "ScanResult::None"), - Self::Distinct(provider) => write!(f, "ScanResult::Distinct({})", provider.name()), - Self::Ambiguous => write!(f, "ScanResult::Ambiguous"), - } - } -} - -impl Clone for ScanResult { - fn clone(&self) -> Self { - match self { - ScanResult::None => ScanResult::None, - ScanResult::Distinct(provider) => ScanResult::Distinct(provider.clone()), - ScanResult::Ambiguous => ScanResult::Ambiguous, - } - } -} diff --git a/datafusion-federation/src/plan_node.rs b/datafusion-federation/src/plan_node.rs deleted file mode 100644 index 0647f9d..0000000 --- a/datafusion-federation/src/plan_node.rs +++ /dev/null @@ -1,172 +0,0 @@ -use core::fmt; -use std::{ - fmt::Debug, - hash::{Hash, Hasher}, - sync::Arc, -}; - -use async_trait::async_trait; -use datafusion::{ - common::DFSchemaRef, - error::{DataFusionError, Result}, - execution::context::{QueryPlanner, SessionState}, - logical_expr::{Expr, LogicalPlan, UserDefinedLogicalNode, UserDefinedLogicalNodeCore}, - physical_plan::ExecutionPlan, - physical_planner::{DefaultPhysicalPlanner, ExtensionPlanner, PhysicalPlanner}, -}; - -pub struct FederatedPlanNode { - plan: LogicalPlan, - planner: Arc, -} - -impl FederatedPlanNode { - pub fn new(plan: LogicalPlan, planner: Arc) -> Self { - Self { plan, planner } - } - - pub fn plan(&self) -> &LogicalPlan { - &self.plan - } -} - -impl Debug for FederatedPlanNode { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - UserDefinedLogicalNodeCore::fmt_for_explain(self, f) - } -} - -impl UserDefinedLogicalNodeCore for FederatedPlanNode { - fn name(&self) -> &str { - "Federated" - } - - fn inputs(&self) -> Vec<&LogicalPlan> { - Vec::new() - } - - fn schema(&self) -> &DFSchemaRef { - self.plan.schema() - } - - fn expressions(&self) -> Vec { - 
Vec::new() - } - - fn fmt_for_explain(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "Federated\n {}", self.plan) - } - - fn with_exprs_and_inputs(&self, exprs: Vec, inputs: Vec) -> Result { - if !inputs.is_empty() { - return Err(DataFusionError::Plan("input size inconsistent".into())); - } - if !exprs.is_empty() { - return Err(DataFusionError::Plan("expression size inconsistent".into())); - } - - Ok(Self { - plan: self.plan.clone(), - planner: self.planner.clone(), - }) - } -} - -#[derive(Default, Debug)] -pub struct FederatedQueryPlanner {} - -impl FederatedQueryPlanner { - pub fn new() -> Self { - Self::default() - } -} - -#[async_trait] -impl QueryPlanner for FederatedQueryPlanner { - async fn create_physical_plan( - &self, - logical_plan: &LogicalPlan, - session_state: &SessionState, - ) -> Result> { - // Get provider here? - - let physical_planner = - DefaultPhysicalPlanner::with_extension_planners(vec![ - Arc::new(FederatedPlanner::new()), - ]); - physical_planner - .create_physical_plan(logical_plan, session_state) - .await - } -} - -#[async_trait] -pub trait FederationPlanner: Send + Sync { - async fn plan_federation( - &self, - node: &FederatedPlanNode, - session_state: &SessionState, - ) -> Result>; -} - -impl std::fmt::Debug for dyn FederationPlanner { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "FederationPlanner") - } -} - -impl PartialEq for FederatedPlanNode { - /// Comparing name, args and return_type - fn eq(&self, other: &FederatedPlanNode) -> bool { - self.plan == other.plan - } -} - -impl PartialOrd for FederatedPlanNode { - fn partial_cmp(&self, other: &FederatedPlanNode) -> Option { - self.plan.partial_cmp(&other.plan) - } -} - -impl Eq for FederatedPlanNode {} - -impl Hash for FederatedPlanNode { - fn hash(&self, state: &mut H) { - self.plan.hash(state); - } -} - -#[derive(Default)] -pub struct FederatedPlanner {} - -impl FederatedPlanner { - pub fn new() -> Self { - Self::default() - } -} - -#[async_trait] -impl ExtensionPlanner for FederatedPlanner { - async fn plan_extension( - &self, - _planner: &dyn PhysicalPlanner, - node: &dyn UserDefinedLogicalNode, - logical_inputs: &[&LogicalPlan], - physical_inputs: &[Arc], - session_state: &SessionState, - ) -> Result>> { - let dc_node = node.as_any().downcast_ref::(); - if let Some(fed_node) = dc_node { - if !logical_inputs.is_empty() || !physical_inputs.is_empty() { - return Err(DataFusionError::Plan( - "Inconsistent number of inputs".into(), - )); - } - - let fed_planner = Arc::clone(&fed_node.planner); - let exec_plan = fed_planner.plan_federation(fed_node, session_state).await?; - return Ok(Some(exec_plan)); - } - Ok(None) - } -} diff --git a/datafusion-federation/src/schema_cast/intervals_cast.rs b/datafusion-federation/src/schema_cast/intervals_cast.rs deleted file mode 100644 index 5fbd806..0000000 --- a/datafusion-federation/src/schema_cast/intervals_cast.rs +++ /dev/null @@ -1,190 +0,0 @@ -use datafusion::arrow::{ - array::{ - Array, ArrayRef, IntervalDayTimeBuilder, IntervalMonthDayNanoArray, - IntervalYearMonthBuilder, - }, - datatypes::{IntervalDayTimeType, IntervalYearMonthType}, - error::ArrowError, -}; -use std::sync::Arc; - -pub(crate) fn cast_interval_monthdaynano_to_yearmonth( - interval_monthdaynano_array: &dyn Array, -) -> Result { - let interval_monthdaynano_array = interval_monthdaynano_array - .as_any() - .downcast_ref::() - .ok_or_else(|| { - ArrowError::CastError("Failed to cast IntervalMonthDayNanoArray: Unable to downcast to 
IntervalMonthDayNanoArray".to_string()) - })?; - - let mut interval_yearmonth_builder = - IntervalYearMonthBuilder::with_capacity(interval_monthdaynano_array.len()); - - for value in interval_monthdaynano_array { - match value { - None => interval_yearmonth_builder.append_null(), - Some(interval_monthdaynano_value) => { - if interval_monthdaynano_value.days != 0 - || interval_monthdaynano_value.nanoseconds != 0 - { - return Err(ArrowError::CastError( - "Failed to cast IntervalMonthDayNanoArray to IntervalYearMonthArray: Non-zero days or nanoseconds".to_string(), - )); - } - interval_yearmonth_builder.append_value(IntervalYearMonthType::make_value( - 0, - interval_monthdaynano_value.months, - )); - } - } - } - - Ok(Arc::new(interval_yearmonth_builder.finish())) -} - -#[allow(clippy::cast_possible_truncation)] -pub(crate) fn cast_interval_monthdaynano_to_daytime( - interval_monthdaynano_array: &dyn Array, -) -> Result { - let interval_monthdaynano_array = interval_monthdaynano_array - .as_any() - .downcast_ref::() - .ok_or_else(|| ArrowError::CastError("Failed to cast IntervalMonthDayNanoArray: Unable to downcast to IntervalMonthDayNanoArray".to_string()))?; - - let mut interval_daytime_builder = - IntervalDayTimeBuilder::with_capacity(interval_monthdaynano_array.len()); - - for value in interval_monthdaynano_array { - match value { - None => interval_daytime_builder.append_null(), - Some(interval_monthdaynano_value) => { - if interval_monthdaynano_value.months != 0 { - return Err( - ArrowError::CastError("Failed to cast IntervalMonthDayNanoArray to IntervalDayTimeArray: Non-zero months".to_string()), - ); - } - interval_daytime_builder.append_value(IntervalDayTimeType::make_value( - interval_monthdaynano_value.days, - (interval_monthdaynano_value.nanoseconds / 1_000_000) as i32, - )); - } - } - } - Ok(Arc::new(interval_daytime_builder.finish())) -} - -#[cfg(test)] -mod test { - use datafusion::arrow::{ - array::{IntervalDayTimeArray, IntervalYearMonthArray, RecordBatch}, - datatypes::{ - DataType, Field, IntervalDayTime, IntervalMonthDayNano, IntervalUnit, Schema, SchemaRef, - }, - }; - - use crate::schema_cast::record_convert::try_cast_to; - - use super::*; - - fn input_schema() -> SchemaRef { - Arc::new(Schema::new(vec![ - Field::new( - "interval_daytime", - DataType::Interval(IntervalUnit::MonthDayNano), - false, - ), - Field::new( - "interval_monthday_nano", - DataType::Interval(IntervalUnit::MonthDayNano), - false, - ), - Field::new( - "interval_yearmonth", - DataType::Interval(IntervalUnit::MonthDayNano), - false, - ), - ])) - } - - fn output_schema() -> SchemaRef { - Arc::new(Schema::new(vec![ - Field::new( - "interval_daytime", - DataType::Interval(IntervalUnit::DayTime), - false, - ), - Field::new( - "interval_monthday_nano", - DataType::Interval(IntervalUnit::MonthDayNano), - false, - ), - Field::new( - "interval_yearmonth", - DataType::Interval(IntervalUnit::YearMonth), - false, - ), - ])) - } - - fn batch_input() -> RecordBatch { - let interval_daytime_array = IntervalMonthDayNanoArray::from(vec![ - IntervalMonthDayNano::new(0, 1, 1_000_000_000), - IntervalMonthDayNano::new(0, 33, 0), - IntervalMonthDayNano::new(0, 0, 43_200_000_000_000), - ]); - let interval_monthday_nano_array = IntervalMonthDayNanoArray::from(vec![ - IntervalMonthDayNano::new(1, 2, 1000), - IntervalMonthDayNano::new(12, 1, 0), - IntervalMonthDayNano::new(0, 0, 12 * 1000 * 1000), - ]); - let interval_yearmonth_array = IntervalMonthDayNanoArray::from(vec![ - IntervalMonthDayNano::new(2, 0, 0), - 
IntervalMonthDayNano::new(25, 0, 0), - IntervalMonthDayNano::new(-1, 0, 0), - ]); - - RecordBatch::try_new( - input_schema(), - vec![ - Arc::new(interval_daytime_array), - Arc::new(interval_monthday_nano_array), - Arc::new(interval_yearmonth_array), - ], - ) - .expect("Failed to created arrow interval record batch") - } - - fn batch_expected() -> RecordBatch { - let interval_daytime_array = IntervalDayTimeArray::from(vec![ - IntervalDayTime::new(1, 1000), - IntervalDayTime::new(33, 0), - IntervalDayTime::new(0, 12 * 60 * 60 * 1000), - ]); - let interval_monthday_nano_array = IntervalMonthDayNanoArray::from(vec![ - IntervalMonthDayNano::new(1, 2, 1000), - IntervalMonthDayNano::new(12, 1, 0), - IntervalMonthDayNano::new(0, 0, 12 * 1000 * 1000), - ]); - let interval_yearmonth_array = IntervalYearMonthArray::from(vec![2, 25, -1]); - - RecordBatch::try_new( - output_schema(), - vec![ - Arc::new(interval_daytime_array), - Arc::new(interval_monthday_nano_array), - Arc::new(interval_yearmonth_array), - ], - ) - .expect("Failed to created arrow interval record batch") - } - - #[test] - fn test_cast_interval_with_schema() { - let input_batch = batch_input(); - let expected = batch_expected(); - let actual = try_cast_to(input_batch, output_schema()).expect("cast should succeed"); - - assert_eq!(actual, expected); - } -} diff --git a/datafusion-federation/src/schema_cast/lists_cast.rs b/datafusion-federation/src/schema_cast/lists_cast.rs deleted file mode 100644 index 8c07d99..0000000 --- a/datafusion-federation/src/schema_cast/lists_cast.rs +++ /dev/null @@ -1,620 +0,0 @@ -use arrow_json::ReaderBuilder; -use datafusion::arrow::array::{GenericStringArray, OffsetSizeTrait}; -use datafusion::arrow::{ - array::{ - Array, ArrayRef, BooleanArray, BooleanBuilder, FixedSizeListBuilder, Float32Array, - Float32Builder, Float64Array, Float64Builder, Int16Array, Int16Builder, Int32Array, - Int32Builder, Int64Array, Int64Builder, Int8Array, Int8Builder, LargeListBuilder, - LargeStringArray, LargeStringBuilder, ListArray, ListBuilder, StringArray, StringBuilder, - }, - datatypes::{DataType, Field, FieldRef}, - error::ArrowError, -}; -use std::sync::Arc; - -pub type Result = std::result::Result; - -macro_rules! cast_string_to_list_array { - ($string_array:expr, $field_name:expr, $data_type:expr, $builder_type:expr, $primitive_type:ty) => {{ - let item_field = Arc::new(Field::new($field_name, $data_type, true)); - let mut list_builder = ListBuilder::with_capacity($builder_type, $string_array.len()) - .with_field(Arc::clone(&item_field)); - - let list_field = Arc::new(Field::new_list("i", item_field, true)); - let mut decoder = ReaderBuilder::new_with_field(Arc::clone(&list_field)) - .build_decoder() - .map_err(|e| ArrowError::CastError(format!("Failed to create decoder: {e}")))?; - - for value in $string_array { - match value { - None => list_builder.append_null(), - Some(string_value) => { - decoder.decode(string_value.as_bytes()).map_err(|e| { - ArrowError::CastError(format!("Failed to decode value: {e}")) - })?; - - if let Some(b) = decoder.flush().map_err(|e| { - ArrowError::CastError(format!("Failed to decode decoder: {e}")) - })? 
{ - let list_array = b - .column(0) - .as_any() - .downcast_ref::() - .ok_or_else(|| { - ArrowError::CastError( - "Failed to decode value: unable to downcast to ListArray" - .to_string(), - ) - })?; - let primitive_array = list_array - .values() - .as_any() - .downcast_ref::<$primitive_type>() - .ok_or_else(|| { - ArrowError::CastError( - "Failed to decode value: unable to downcast to PrimitiveType" - .to_string(), - ) - })?; - primitive_array - .iter() - .for_each(|maybe_value| match maybe_value { - Some(value) => list_builder.values().append_value(value), - None => list_builder.values().append_null(), - }); - list_builder.append(true); - } - } - } - } - - Ok(Arc::new(list_builder.finish())) - }}; -} - -macro_rules! cast_string_to_large_list_array { - ($string_array:expr, $field_name:expr, $data_type:expr, $builder_type:expr, $primitive_type:ty) => {{ - let item_field = Arc::new(Field::new($field_name, $data_type, true)); - let mut list_builder = LargeListBuilder::with_capacity($builder_type, $string_array.len()) - .with_field(Arc::clone(&item_field)); - - let list_field = Arc::new(Field::new_list("i", item_field, true)); - let mut decoder = ReaderBuilder::new_with_field(Arc::clone(&list_field)) - .build_decoder() - .map_err(|e| ArrowError::CastError(format!("Failed to create decoder: {e}")))?; - - for value in $string_array { - match value { - None => list_builder.append_null(), - Some(string_value) => { - decoder.decode(string_value.as_bytes()).map_err(|e| { - ArrowError::CastError(format!("Failed to decode value: {e}")) - })?; - - if let Some(b) = decoder.flush().map_err(|e| { - ArrowError::CastError(format!("Failed to decode decoder: {e}")) - })? { - let list_array = b - .column(0) - .as_any() - .downcast_ref::() - .ok_or_else(|| { - ArrowError::CastError( - "Failed to decode value: unable to downcast to ListArray" - .to_string(), - ) - })?; - let primitive_array = list_array - .values() - .as_any() - .downcast_ref::<$primitive_type>() - .ok_or_else(|| { - ArrowError::CastError( - "Failed to decode value: unable to downcast to PrimitiveType" - .to_string(), - ) - })?; - primitive_array - .iter() - .for_each(|maybe_value| match maybe_value { - Some(value) => list_builder.values().append_value(value), - None => list_builder.values().append_null(), - }); - list_builder.append(true); - } - } - } - } - - Ok(Arc::new(list_builder.finish())) - }}; -} - -macro_rules! cast_string_to_fixed_size_list_array { - ($string_array:expr, $field_name:expr, $data_type:expr, $builder_type:expr, $primitive_type:ty, $value_length:expr) => {{ - let item_field = Arc::new(Field::new($field_name, $data_type, true)); - let mut list_builder = - FixedSizeListBuilder::with_capacity($builder_type, $value_length, $string_array.len()) - .with_field(Arc::clone(&item_field)); - - let list_field = Arc::new(Field::new_list("i", item_field, true)); - let mut decoder = ReaderBuilder::new_with_field(Arc::clone(&list_field)) - .build_decoder() - .map_err(|e| ArrowError::CastError(format!("Failed to create decoder: {e}")))?; - - for value in $string_array { - match value { - None => { - for _ in 0..$value_length { - list_builder.values().append_null() - } - list_builder.append(true) - } - Some(string_value) => { - decoder.decode(string_value.as_bytes()).map_err(|e| { - ArrowError::CastError(format!("Failed to decode value: {e}")) - })?; - - if let Some(b) = decoder.flush().map_err(|e| { - ArrowError::CastError(format!("Failed to decode decoder: {e}")) - })? 
{ - let list_array = b - .column(0) - .as_any() - .downcast_ref::() - .ok_or_else(|| { - ArrowError::CastError( - "Failed to decode value: unable to downcast to ListArray" - .to_string(), - ) - })?; - let primitive_array = list_array - .values() - .as_any() - .downcast_ref::<$primitive_type>() - .ok_or_else(|| { - ArrowError::CastError( - "Failed to decode value: unable to downcast to PrimitiveType" - .to_string(), - ) - })?; - primitive_array - .iter() - .for_each(|maybe_value| match maybe_value { - Some(value) => list_builder.values().append_value(value), - None => list_builder.values().append_null(), - }); - list_builder.append(true); - } - } - } - } - - Ok(Arc::new(list_builder.finish())) - }}; -} - -pub(crate) fn cast_string_to_list( - array: &dyn Array, - list_item_field: &FieldRef, -) -> Result { - let string_array = array - .as_any() - .downcast_ref::>() - .ok_or_else(|| { - ArrowError::CastError( - "Failed to decode value: unable to downcast to StringArray".to_string(), - ) - })?; - - let field_name = list_item_field.name(); - - match list_item_field.data_type() { - DataType::Utf8 => { - cast_string_to_list_array!( - string_array, - field_name, - DataType::Utf8, - StringBuilder::new(), - StringArray - ) - } - DataType::LargeUtf8 => { - cast_string_to_list_array!( - string_array, - field_name, - DataType::LargeUtf8, - LargeStringBuilder::new(), - LargeStringArray - ) - } - DataType::Boolean => { - cast_string_to_list_array!( - string_array, - field_name, - DataType::Boolean, - BooleanBuilder::new(), - BooleanArray - ) - } - DataType::Int8 => { - cast_string_to_list_array!( - string_array, - field_name, - DataType::Int8, - Int8Builder::new(), - Int8Array - ) - } - DataType::Int16 => { - cast_string_to_list_array!( - string_array, - field_name, - DataType::Int16, - Int16Builder::new(), - Int16Array - ) - } - DataType::Int32 => { - cast_string_to_list_array!( - string_array, - field_name, - DataType::Int32, - Int32Builder::new(), - Int32Array - ) - } - DataType::Int64 => { - cast_string_to_list_array!( - string_array, - field_name, - DataType::Int64, - Int64Builder::new(), - Int64Array - ) - } - DataType::Float32 => { - cast_string_to_list_array!( - string_array, - field_name, - DataType::Float32, - Float32Builder::new(), - Float32Array - ) - } - DataType::Float64 => { - cast_string_to_list_array!( - string_array, - field_name, - DataType::Float64, - Float64Builder::new(), - Float64Array - ) - } - _ => Err(ArrowError::CastError(format!( - "Unsupported list item type: {}", - list_item_field.data_type() - ))), - } -} - -pub(crate) fn cast_string_to_large_list( - array: &dyn Array, - list_item_field: &FieldRef, -) -> Result { - let string_array = array - .as_any() - .downcast_ref::>() - .ok_or_else(|| { - ArrowError::CastError( - "Failed to decode value: unable to downcast to StringArray".to_string(), - ) - })?; - - let field_name = list_item_field.name(); - - match list_item_field.data_type() { - DataType::Utf8 => { - cast_string_to_large_list_array!( - string_array, - field_name, - DataType::Utf8, - StringBuilder::new(), - StringArray - ) - } - DataType::LargeUtf8 => { - cast_string_to_large_list_array!( - string_array, - field_name, - DataType::LargeUtf8, - LargeStringBuilder::new(), - LargeStringArray - ) - } - DataType::Boolean => { - cast_string_to_large_list_array!( - string_array, - field_name, - DataType::Boolean, - BooleanBuilder::new(), - BooleanArray - ) - } - DataType::Int8 => { - cast_string_to_large_list_array!( - string_array, - field_name, - DataType::Int8, - 
Int8Builder::new(), - Int8Array - ) - } - DataType::Int16 => { - cast_string_to_large_list_array!( - string_array, - field_name, - DataType::Int16, - Int16Builder::new(), - Int16Array - ) - } - DataType::Int32 => { - cast_string_to_large_list_array!( - string_array, - field_name, - DataType::Int32, - Int32Builder::new(), - Int32Array - ) - } - DataType::Int64 => { - cast_string_to_large_list_array!( - string_array, - field_name, - DataType::Int64, - Int64Builder::new(), - Int64Array - ) - } - DataType::Float32 => { - cast_string_to_large_list_array!( - string_array, - field_name, - DataType::Float32, - Float32Builder::new(), - Float32Array - ) - } - DataType::Float64 => { - cast_string_to_large_list_array!( - string_array, - field_name, - DataType::Float64, - Float64Builder::new(), - Float64Array - ) - } - _ => Err(ArrowError::CastError(format!( - "Unsupported list item type: {}", - list_item_field.data_type() - ))), - } -} - -pub(crate) fn cast_string_to_fixed_size_list( - array: &dyn Array, - list_item_field: &FieldRef, - value_length: i32, -) -> Result { - let string_array = array - .as_any() - .downcast_ref::>() - .ok_or_else(|| { - ArrowError::CastError( - "Failed to decode value: unable to downcast to StringArray".to_string(), - ) - })?; - - let field_name = list_item_field.name(); - - match list_item_field.data_type() { - DataType::Utf8 => { - cast_string_to_fixed_size_list_array!( - string_array, - field_name, - DataType::Utf8, - StringBuilder::new(), - StringArray, - value_length - ) - } - DataType::LargeUtf8 => { - cast_string_to_fixed_size_list_array!( - string_array, - field_name, - DataType::LargeUtf8, - LargeStringBuilder::new(), - LargeStringArray, - value_length - ) - } - DataType::Boolean => { - cast_string_to_fixed_size_list_array!( - string_array, - field_name, - DataType::Boolean, - BooleanBuilder::new(), - BooleanArray, - value_length - ) - } - DataType::Int8 => { - cast_string_to_fixed_size_list_array!( - string_array, - field_name, - DataType::Int8, - Int8Builder::new(), - Int8Array, - value_length - ) - } - DataType::Int16 => { - cast_string_to_fixed_size_list_array!( - string_array, - field_name, - DataType::Int16, - Int16Builder::new(), - Int16Array, - value_length - ) - } - DataType::Int32 => { - cast_string_to_fixed_size_list_array!( - string_array, - field_name, - DataType::Int32, - Int32Builder::new(), - Int32Array, - value_length - ) - } - DataType::Int64 => { - cast_string_to_fixed_size_list_array!( - string_array, - field_name, - DataType::Int64, - Int64Builder::new(), - Int64Array, - value_length - ) - } - DataType::Float32 => { - cast_string_to_fixed_size_list_array!( - string_array, - field_name, - DataType::Float32, - Float32Builder::new(), - Float32Array, - value_length - ) - } - DataType::Float64 => { - cast_string_to_fixed_size_list_array!( - string_array, - field_name, - DataType::Float64, - Float64Builder::new(), - Float64Array, - value_length - ) - } - _ => Err(ArrowError::CastError(format!( - "Unsupported list item type: {}", - list_item_field.data_type() - ))), - } -} - -#[cfg(test)] -mod test { - use datafusion::arrow::{ - array::{RecordBatch, StringArray}, - datatypes::{DataType, Field, Schema, SchemaRef}, - }; - - use crate::schema_cast::record_convert::try_cast_to; - - use super::*; - - fn input_schema() -> SchemaRef { - Arc::new(Schema::new(vec![ - Field::new("a", DataType::Utf8, false), - Field::new("b", DataType::Utf8, false), - Field::new("c", DataType::Utf8, false), - ])) - } - - fn output_schema() -> SchemaRef { - 
Arc::new(Schema::new(vec![ - Field::new( - "a", - DataType::List(Arc::new(Field::new("item", DataType::Int32, true))), - false, - ), - Field::new( - "b", - DataType::LargeList(Arc::new(Field::new("item", DataType::Utf8, true))), - false, - ), - Field::new( - "c", - DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Boolean, true)), 3), - false, - ), - ])) - } - - fn batch_input() -> RecordBatch { - RecordBatch::try_new( - input_schema(), - vec![ - Arc::new(StringArray::from(vec![ - Some("[1, 2, 3]"), - Some("[4, 5, 6]"), - ])), - Arc::new(StringArray::from(vec![ - Some("[\"foo\", \"bar\"]"), - Some("[\"baz\", \"qux\"]"), - ])), - Arc::new(StringArray::from(vec![ - Some("[true, false, true]"), - Some("[false, true, false]"), - ])), - ], - ) - .expect("record batch should not panic") - } - - fn batch_expected() -> RecordBatch { - let mut list_builder = ListBuilder::new(Int32Builder::new()); - list_builder.append_value([Some(1), Some(2), Some(3)]); - list_builder.append_value([Some(4), Some(5), Some(6)]); - let list_array = list_builder.finish(); - - let mut large_list_builder = LargeListBuilder::new(StringBuilder::new()); - large_list_builder.append_value([Some("foo"), Some("bar")]); - large_list_builder.append_value([Some("baz"), Some("qux")]); - let large_list_array = large_list_builder.finish(); - - let mut fixed_size_list_builder = FixedSizeListBuilder::new(BooleanBuilder::new(), 3); - fixed_size_list_builder.values().append_value(true); - fixed_size_list_builder.values().append_value(false); - fixed_size_list_builder.values().append_value(true); - fixed_size_list_builder.append(true); - fixed_size_list_builder.values().append_value(false); - fixed_size_list_builder.values().append_value(true); - fixed_size_list_builder.values().append_value(false); - fixed_size_list_builder.append(true); - let fixed_size_list_array = fixed_size_list_builder.finish(); - - RecordBatch::try_new( - output_schema(), - vec![ - Arc::new(list_array), - Arc::new(large_list_array), - Arc::new(fixed_size_list_array), - ], - ) - .expect("Failed to create expected RecordBatch") - } - - #[test] - fn test_cast_to_list_largelist_fixedsizelist() { - let input_batch = batch_input(); - let expected = batch_expected(); - let actual = try_cast_to(input_batch, output_schema()).expect("cast should succeed"); - - assert_eq!(actual, expected); - } -} diff --git a/datafusion-federation/src/schema_cast/mod.rs b/datafusion-federation/src/schema_cast/mod.rs deleted file mode 100644 index 3ee892c..0000000 --- a/datafusion-federation/src/schema_cast/mod.rs +++ /dev/null @@ -1,116 +0,0 @@ -use async_stream::stream; -use datafusion::arrow::datatypes::SchemaRef; -use datafusion::error::{DataFusionError, Result}; -use datafusion::execution::{SendableRecordBatchStream, TaskContext}; -use datafusion::physical_plan::stream::RecordBatchStreamAdapter; -use datafusion::physical_plan::{ - DisplayAs, DisplayFormatType, ExecutionPlan, ExecutionPlanProperties, PlanProperties, -}; -use futures::StreamExt; -use std::any::Any; -use std::clone::Clone; -use std::fmt; -use std::sync::Arc; - -mod intervals_cast; -mod lists_cast; -pub mod record_convert; -mod struct_cast; - -#[derive(Debug)] -#[allow(clippy::module_name_repetitions)] -pub struct SchemaCastScanExec { - input: Arc, - schema: SchemaRef, - properties: PlanProperties, -} - -impl SchemaCastScanExec { - pub fn new(input: Arc, schema: SchemaRef) -> Self { - let eq_properties = input.equivalence_properties().clone(); - let emission_type = input.pipeline_behavior(); - let boundedness = 
input.boundedness(); - let properties = PlanProperties::new( - eq_properties, - input.output_partitioning().clone(), - emission_type, - boundedness, - ); - Self { - input, - schema, - properties, - } - } -} - -impl DisplayAs for SchemaCastScanExec { - fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "SchemaCastScanExec") - } -} - -impl ExecutionPlan for SchemaCastScanExec { - fn name(&self) -> &str { - "SchemaCastScanExec" - } - - fn as_any(&self) -> &dyn Any { - self - } - - fn properties(&self) -> &PlanProperties { - &self.properties - } - - fn schema(&self) -> SchemaRef { - Arc::clone(&self.schema) - } - - fn children(&self) -> Vec<&Arc> { - vec![&self.input] - } - - /// Prevents the introduction of additional `RepartitionExec` and processing input in parallel. - /// This guarantees that the input is processed as a single stream, preserving the order of the data. - fn benefits_from_input_partitioning(&self) -> Vec { - vec![false] - } - - fn with_new_children( - self: Arc, - children: Vec>, - ) -> Result> { - if children.len() == 1 { - Ok(Arc::new(Self::new( - Arc::clone(&children[0]), - Arc::clone(&self.schema), - ))) - } else { - Err(DataFusionError::Execution( - "SchemaCastScanExec expects exactly one input".to_string(), - )) - } - } - - fn execute( - &self, - partition: usize, - context: Arc, - ) -> Result { - let mut stream = self.input.execute(partition, context)?; - let schema = Arc::clone(&self.schema); - - Ok(Box::pin(RecordBatchStreamAdapter::new( - Arc::clone(&schema), - { - stream! { - while let Some(batch) = stream.next().await { - let batch = record_convert::try_cast_to(batch?, Arc::clone(&schema)); - yield batch.map_err(|e| { DataFusionError::External(Box::new(e)) }); - } - } - }, - ))) - } -} diff --git a/datafusion-federation/src/schema_cast/record_convert.rs b/datafusion-federation/src/schema_cast/record_convert.rs deleted file mode 100644 index b2b2e0a..0000000 --- a/datafusion-federation/src/schema_cast/record_convert.rs +++ /dev/null @@ -1,237 +0,0 @@ -use datafusion::arrow::{ - array::{Array, RecordBatch}, - compute::cast, - datatypes::{DataType, IntervalUnit, SchemaRef}, -}; -use std::sync::Arc; - -use super::{ - intervals_cast::{ - cast_interval_monthdaynano_to_daytime, cast_interval_monthdaynano_to_yearmonth, - }, - lists_cast::{cast_string_to_fixed_size_list, cast_string_to_large_list, cast_string_to_list}, - struct_cast::cast_string_to_struct, -}; - -pub type Result = std::result::Result; - -#[derive(Debug)] -pub enum Error { - UnableToConvertRecordBatch { - source: datafusion::arrow::error::ArrowError, - }, - - UnexpectedNumberOfColumns { - expected: usize, - found: usize, - }, -} - -impl std::error::Error for Error {} - -impl std::fmt::Display for Error { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - match self { - Error::UnableToConvertRecordBatch { source } => { - write!(f, "Unable to convert record batch: {}", source) - } - Error::UnexpectedNumberOfColumns { expected, found } => { - write!( - f, - "Unexpected number of columns. Expected: {}, Found: {}", - expected, found - ) - } - } - } -} - -/// Cast a given record batch into a new record batch with the given schema. -/// It assumes the record batch columns are correctly ordered. 
-#[allow(clippy::needless_pass_by_value)] -pub fn try_cast_to(record_batch: RecordBatch, expected_schema: SchemaRef) -> Result { - let actual_schema = record_batch.schema(); - - if actual_schema.fields().len() != expected_schema.fields().len() { - return Err(Error::UnexpectedNumberOfColumns { - expected: expected_schema.fields().len(), - found: actual_schema.fields().len(), - }); - } - - let cols = expected_schema - .fields() - .iter() - .enumerate() - .map(|(i, expected_field)| { - let record_batch_col = record_batch.column(i); - - match (record_batch_col.data_type(), expected_field.data_type()) { - (DataType::Utf8, DataType::List(item_type)) => { - cast_string_to_list::(&Arc::clone(record_batch_col), item_type) - .map_err(|e| Error::UnableToConvertRecordBatch { source: e }) - } - (DataType::Utf8, DataType::LargeList(item_type)) => { - cast_string_to_large_list::(&Arc::clone(record_batch_col), item_type) - .map_err(|e| Error::UnableToConvertRecordBatch { source: e }) - } - (DataType::Utf8, DataType::FixedSizeList(item_type, value_length)) => { - cast_string_to_fixed_size_list::( - &Arc::clone(record_batch_col), - item_type, - *value_length, - ) - .map_err(|e| Error::UnableToConvertRecordBatch { source: e }) - } - (DataType::Utf8, DataType::Struct(_)) => cast_string_to_struct::( - &Arc::clone(record_batch_col), - expected_field.clone(), - ) - .map_err(|e| Error::UnableToConvertRecordBatch { source: e }), - (DataType::LargeUtf8, DataType::List(item_type)) => { - cast_string_to_list::(&Arc::clone(record_batch_col), item_type) - .map_err(|e| Error::UnableToConvertRecordBatch { source: e }) - } - (DataType::LargeUtf8, DataType::LargeList(item_type)) => { - cast_string_to_large_list::(&Arc::clone(record_batch_col), item_type) - .map_err(|e| Error::UnableToConvertRecordBatch { source: e }) - } - (DataType::LargeUtf8, DataType::FixedSizeList(item_type, value_length)) => { - cast_string_to_fixed_size_list::( - &Arc::clone(record_batch_col), - item_type, - *value_length, - ) - .map_err(|e| Error::UnableToConvertRecordBatch { source: e }) - } - (DataType::LargeUtf8, DataType::Struct(_)) => cast_string_to_struct::( - &Arc::clone(record_batch_col), - expected_field.clone(), - ) - .map_err(|e| Error::UnableToConvertRecordBatch { source: e }), - ( - DataType::Interval(IntervalUnit::MonthDayNano), - DataType::Interval(IntervalUnit::YearMonth), - ) => cast_interval_monthdaynano_to_yearmonth(&Arc::clone(record_batch_col)) - .map_err(|e| Error::UnableToConvertRecordBatch { source: e }), - ( - DataType::Interval(IntervalUnit::MonthDayNano), - DataType::Interval(IntervalUnit::DayTime), - ) => cast_interval_monthdaynano_to_daytime(&Arc::clone(record_batch_col)) - .map_err(|e| Error::UnableToConvertRecordBatch { source: e }), - _ => cast(&Arc::clone(record_batch_col), expected_field.data_type()) - .map_err(|e| Error::UnableToConvertRecordBatch { source: e }), - } - }) - .collect::>>>()?; - - RecordBatch::try_new(expected_schema, cols) - .map_err(|e| Error::UnableToConvertRecordBatch { source: e }) -} - -#[cfg(test)] -mod test { - use super::*; - use datafusion::arrow::array::LargeStringArray; - use datafusion::arrow::{ - array::{Int32Array, StringArray}, - datatypes::{DataType, Field, Schema, TimeUnit}, - }; - use datafusion::assert_batches_eq; - - fn schema() -> SchemaRef { - Arc::new(Schema::new(vec![ - Field::new("a", DataType::Int32, false), - Field::new("b", DataType::Utf8, false), - Field::new("c", DataType::Utf8, false), - ])) - } - - fn to_schema() -> SchemaRef { - Arc::new(Schema::new(vec![ - 
Field::new("a", DataType::Int64, false), - Field::new("b", DataType::LargeUtf8, false), - Field::new("c", DataType::Timestamp(TimeUnit::Microsecond, None), false), - ])) - } - - fn batch_input() -> RecordBatch { - RecordBatch::try_new( - schema(), - vec![ - Arc::new(Int32Array::from(vec![1, 2, 3])), - Arc::new(StringArray::from(vec!["foo", "bar", "baz"])), - Arc::new(StringArray::from(vec![ - "2024-01-13 03:18:09.000000", - "2024-01-13 03:18:09", - "2024-01-13 03:18:09.000", - ])), - ], - ) - .expect("record batch should not panic") - } - - #[test] - fn test_string_to_timestamp_conversion() { - let result = try_cast_to(batch_input(), to_schema()).expect("converted"); - let expected = vec![ - "+---+-----+---------------------+", - "| a | b | c |", - "+---+-----+---------------------+", - "| 1 | foo | 2024-01-13T03:18:09 |", - "| 2 | bar | 2024-01-13T03:18:09 |", - "| 3 | baz | 2024-01-13T03:18:09 |", - "+---+-----+---------------------+", - ]; - - assert_batches_eq!(expected, &[result]); - } - - fn large_string_from_schema() -> SchemaRef { - Arc::new(Schema::new(vec![ - Field::new("a", DataType::Int32, false), - Field::new("b", DataType::LargeUtf8, false), - Field::new("c", DataType::LargeUtf8, false), - ])) - } - - fn large_string_to_schema() -> SchemaRef { - Arc::new(Schema::new(vec![ - Field::new("a", DataType::Int64, false), - Field::new("b", DataType::LargeUtf8, false), - Field::new("c", DataType::Timestamp(TimeUnit::Microsecond, None), false), - ])) - } - - fn large_string_batch_input() -> RecordBatch { - RecordBatch::try_new( - large_string_from_schema(), - vec![ - Arc::new(Int32Array::from(vec![1, 2, 3])), - Arc::new(LargeStringArray::from(vec!["foo", "bar", "baz"])), - Arc::new(LargeStringArray::from(vec![ - "2024-01-13 03:18:09.000000", - "2024-01-13 03:18:09", - "2024-01-13 03:18:09.000", - ])), - ], - ) - .expect("record batch should not panic") - } - - #[test] - fn test_large_string_to_timestamp_conversion() { - let result = - try_cast_to(large_string_batch_input(), large_string_to_schema()).expect("converted"); - let expected = vec![ - "+---+-----+---------------------+", - "| a | b | c |", - "+---+-----+---------------------+", - "| 1 | foo | 2024-01-13T03:18:09 |", - "| 2 | bar | 2024-01-13T03:18:09 |", - "| 3 | baz | 2024-01-13T03:18:09 |", - "+---+-----+---------------------+", - ]; - assert_batches_eq!(expected, &[result]); - } -} diff --git a/datafusion-federation/src/schema_cast/struct_cast.rs b/datafusion-federation/src/schema_cast/struct_cast.rs deleted file mode 100644 index a6dad0a..0000000 --- a/datafusion-federation/src/schema_cast/struct_cast.rs +++ /dev/null @@ -1,170 +0,0 @@ -use arrow_json::ReaderBuilder; -use datafusion::arrow::array::{GenericStringArray, OffsetSizeTrait}; -use datafusion::arrow::{ - array::{Array, ArrayRef}, - datatypes::Field, - error::ArrowError, -}; -use std::sync::Arc; - -pub type Result = std::result::Result; - -pub(crate) fn cast_string_to_struct( - array: &dyn Array, - struct_field: Arc, -) -> Result { - let string_array = array - .as_any() - .downcast_ref::>() - .ok_or_else(|| ArrowError::CastError("Failed to downcast to StringArray".to_string()))?; - - let mut decoder = ReaderBuilder::new_with_field(struct_field) - .build_decoder() - .map_err(|e| ArrowError::CastError(format!("Failed to create JSON decoder: {e}")))?; - - for value in string_array { - match value { - Some(v) => { - decoder.decode(v.as_bytes()).map_err(|e| { - ArrowError::CastError(format!("Failed to decode struct array: {e}")) - })?; - } - None => { - 
decoder.decode("null".as_bytes()).map_err(|e| { - ArrowError::CastError(format!("Failed to decode struct array: {e}")) - })?; - } - } - } - - let record = match decoder.flush() { - Ok(Some(record)) => record, - Ok(None) => { - return Err(ArrowError::CastError( - "Failed to flush decoder: No record".to_string(), - )); - } - Err(e) => { - return Err(ArrowError::CastError(format!( - "Failed to decode struct array: {e}" - ))); - } - }; - // struct_field is single struct column - Ok(Arc::clone(record.column(0))) -} - -#[cfg(test)] -mod test { - use datafusion::arrow::{ - array::{Int32Builder, RecordBatch, StringArray, StringBuilder, StructBuilder}, - datatypes::{DataType, Field, Schema, SchemaRef}, - }; - - use crate::schema_cast::record_convert::try_cast_to; - - use super::*; - - fn input_schema() -> SchemaRef { - Arc::new(Schema::new(vec![Field::new( - "struct_string", - DataType::Utf8, - true, - )])) - } - - fn output_schema() -> SchemaRef { - Arc::new(Schema::new(vec![Field::new( - "struct", - DataType::Struct( - vec![ - Field::new("name", DataType::Utf8, false), - Field::new("age", DataType::Int32, false), - ] - .into(), - ), - true, - )])) - } - - fn batch_input() -> RecordBatch { - RecordBatch::try_new( - input_schema(), - vec![Arc::new(StringArray::from(vec![ - Some(r#"{"name":"John","age":30}"#), - None, - None, - Some(r#"{"name":"Jane","age":25}"#), - ]))], - ) - .expect("record batch should not panic") - } - - fn batch_expected() -> RecordBatch { - let name_field = Field::new("name", DataType::Utf8, false); - let age_field = Field::new("age", DataType::Int32, false); - - let mut struct_builder = StructBuilder::new( - vec![name_field, age_field], - vec![ - Box::new(StringBuilder::new()), - Box::new(Int32Builder::new()), - ], - ); - - struct_builder - .field_builder::(0) - .expect("should return field builder") - .append_value("John"); - struct_builder - .field_builder::(1) - .expect("should return field builder") - .append_value(30); - struct_builder.append(true); - - struct_builder - .field_builder::(0) - .expect("should return field builder") - .append_null(); - struct_builder - .field_builder::(1) - .expect("should return field builder") - .append_null(); - struct_builder.append(false); - - struct_builder - .field_builder::(0) - .expect("should return field builder") - .append_null(); - struct_builder - .field_builder::(1) - .expect("should return field builder") - .append_null(); - struct_builder.append(false); - - struct_builder - .field_builder::(0) - .expect("should return field builder") - .append_value("Jane"); - struct_builder - .field_builder::(1) - .expect("should return field builder") - .append_value(25); - struct_builder.append(true); - - let struct_array = struct_builder.finish(); - - RecordBatch::try_new(output_schema(), vec![Arc::new(struct_array)]) - .expect("Failed to create expected RecordBatch") - } - - #[test] - fn test_cast_to_struct() { - let input_batch = batch_input(); - let expected = batch_expected(); - - let actual = try_cast_to(input_batch, output_schema()).expect("cast should succeed"); - - assert_eq!(actual, expected); - } -} diff --git a/datafusion-federation/src/sql/executor.rs b/datafusion-federation/src/sql/executor.rs deleted file mode 100644 index ca04989..0000000 --- a/datafusion-federation/src/sql/executor.rs +++ /dev/null @@ -1,54 +0,0 @@ -use async_trait::async_trait; -use core::fmt; -use datafusion::{ - arrow::datatypes::SchemaRef, error::Result, physical_plan::SendableRecordBatchStream, - sql::sqlparser::ast, 
sql::unparser::dialect::Dialect, -}; -use std::sync::Arc; - -pub type SQLExecutorRef = Arc; -pub type AstAnalyzer = Box Result>; - -#[async_trait] -pub trait SQLExecutor: Sync + Send { - /// Executor name - fn name(&self) -> &str; - - /// Executor compute context allows differentiating the remote compute context - /// such as authorization or active database. - /// - /// Note: returning None here may cause incorrect federation with other providers of the - /// same name that also have a compute_context of None. - /// Instead try to return a unique string that will never match any other - /// provider's context. - fn compute_context(&self) -> Option; - - /// The specific SQL dialect (currently supports 'sqlite', 'postgres', 'flight') - fn dialect(&self) -> Arc; - - /// Returns an AST analyzer specific for this engine to modify the AST before execution - fn ast_analyzer(&self) -> Option { - None - } - - /// Execute a SQL query - fn execute(&self, query: &str, schema: SchemaRef) -> Result; - - /// Returns the tables provided by the remote - async fn table_names(&self) -> Result>; - - /// Returns the schema of table_name within this [`SQLExecutor`] - async fn get_table_schema(&self, table_name: &str) -> Result; -} - -impl fmt::Debug for dyn SQLExecutor { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "{} {:?}", self.name(), self.compute_context()) - } -} - -impl fmt::Display for dyn SQLExecutor { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "{} {:?}", self.name(), self.compute_context()) - } -} diff --git a/datafusion-federation/src/sql/mod.rs b/datafusion-federation/src/sql/mod.rs deleted file mode 100644 index 4b099dc..0000000 --- a/datafusion-federation/src/sql/mod.rs +++ /dev/null @@ -1,1078 +0,0 @@ -mod executor; -mod schema; - -use std::{any::Any, collections::HashMap, fmt, sync::Arc, vec}; - -use async_trait::async_trait; -use datafusion::{ - arrow::datatypes::{Schema, SchemaRef}, - common::{tree_node::Transformed, Column}, - error::Result, - execution::{context::SessionState, TaskContext}, - logical_expr::{ - expr::{ - AggregateFunction, Alias, Exists, InList, InSubquery, PlannedReplaceSelectItem, - ScalarFunction, Sort, Unnest, WildcardOptions, WindowFunction, - }, - Between, BinaryExpr, Case, Cast, Expr, Extension, GroupingSet, Like, Limit, LogicalPlan, - Subquery, TryCast, - }, - optimizer::{optimizer::Optimizer, OptimizerConfig, OptimizerRule}, - physical_expr::EquivalenceProperties, - physical_plan::{ - execution_plan::{Boundedness, EmissionType}, - DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, PlanProperties, - SendableRecordBatchStream, - }, - sql::{ - sqlparser::ast::Statement, - unparser::{plan_to_sql, Unparser}, - TableReference, - }, -}; - -pub use executor::{AstAnalyzer, SQLExecutor, SQLExecutorRef}; -pub use schema::{MultiSchemaProvider, SQLSchemaProvider, SQLTableSource}; - -use crate::{ - get_table_source, schema_cast, FederatedPlanNode, FederationPlanner, FederationProvider, -}; - -// #[macro_use] -// extern crate derive_builder; - -// SQLFederationProvider provides federation to SQL DMBSs. 
-#[derive(Debug)] -pub struct SQLFederationProvider { - optimizer: Arc, - executor: Arc, -} - -impl SQLFederationProvider { - pub fn new(executor: Arc) -> Self { - Self { - optimizer: Arc::new(Optimizer::with_rules(vec![Arc::new( - SQLFederationOptimizerRule::new(executor.clone()), - )])), - executor, - } - } -} - -impl FederationProvider for SQLFederationProvider { - fn name(&self) -> &str { - "sql_federation_provider" - } - - fn compute_context(&self) -> Option { - self.executor.compute_context() - } - - fn optimizer(&self) -> Option> { - Some(self.optimizer.clone()) - } -} - -#[derive(Debug)] -struct SQLFederationOptimizerRule { - planner: Arc, -} - -impl SQLFederationOptimizerRule { - pub fn new(executor: Arc) -> Self { - Self { - planner: Arc::new(SQLFederationPlanner::new(Arc::clone(&executor))), - } - } -} - -impl OptimizerRule for SQLFederationOptimizerRule { - /// Try to rewrite `plan` to an optimized form, returning `Transformed::yes` - /// if the plan was rewritten and `Transformed::no` if it was not. - /// - /// Note: this function is only called if [`Self::supports_rewrite`] returns - /// true. Otherwise the Optimizer calls [`Self::try_optimize`] - fn rewrite( - &self, - plan: LogicalPlan, - _config: &dyn OptimizerConfig, - ) -> Result> { - if let LogicalPlan::Extension(Extension { ref node }) = plan { - if node.name() == "Federated" { - // Avoid attempting double federation - return Ok(Transformed::no(plan)); - } - } - // Simply accept the entire plan for now - let fed_plan = FederatedPlanNode::new(plan.clone(), self.planner.clone()); - let ext_node = Extension { - node: Arc::new(fed_plan), - }; - Ok(Transformed::yes(LogicalPlan::Extension(ext_node))) - } - - /// A human readable name for this analyzer rule - fn name(&self) -> &str { - "federate_sql" - } - - /// Does this rule support rewriting owned plans (rather than by reference)? - fn supports_rewrite(&self) -> bool { - true - } -} - -/// Rewrite table scans to use the original federated table name. -fn rewrite_table_scans( - plan: &LogicalPlan, - known_rewrites: &mut HashMap, -) -> Result { - if plan.inputs().is_empty() { - if let LogicalPlan::TableScan(table_scan) = plan { - let original_table_name = table_scan.table_name.clone(); - let mut new_table_scan = table_scan.clone(); - - let Some(federated_source) = get_table_source(&table_scan.source)? else { - // Not a federated source - return Ok(plan.clone()); - }; - - match federated_source.as_any().downcast_ref::() { - Some(sql_table_source) => { - let remote_table_name = TableReference::from(sql_table_source.table_name()); - known_rewrites.insert(original_table_name, remote_table_name.clone()); - - // Rewrite the schema of this node to have the remote table as the qualifier. - let new_schema = (*new_table_scan.projected_schema) - .clone() - .replace_qualifier(remote_table_name.clone()); - new_table_scan.projected_schema = Arc::new(new_schema); - new_table_scan.table_name = remote_table_name; - } - None => { - // Not a SQLTableSource (is this possible?) 
- return Ok(plan.clone()); - } - } - - return Ok(LogicalPlan::TableScan(new_table_scan)); - } else { - return Ok(plan.clone()); - } - } - - let rewritten_inputs = plan - .inputs() - .into_iter() - .map(|i| rewrite_table_scans(i, known_rewrites)) - .collect::>>()?; - - if let LogicalPlan::Limit(limit) = plan { - let rewritten_skip = limit - .skip - .as_ref() - .map(|skip| rewrite_table_scans_in_expr(*skip.clone(), known_rewrites).map(Box::new)) - .transpose()?; - - let rewritten_fetch = limit - .fetch - .as_ref() - .map(|fetch| rewrite_table_scans_in_expr(*fetch.clone(), known_rewrites).map(Box::new)) - .transpose()?; - - // explicitly set fetch and skip - let new_plan = LogicalPlan::Limit(Limit { - skip: rewritten_skip, - fetch: rewritten_fetch, - input: Arc::new(rewritten_inputs[0].clone()), - }); - - return Ok(new_plan); - } - - let mut new_expressions = vec![]; - for expression in plan.expressions() { - let new_expr = rewrite_table_scans_in_expr(expression.clone(), known_rewrites)?; - new_expressions.push(new_expr); - } - - let new_plan = plan.with_new_exprs(new_expressions, rewritten_inputs)?; - - Ok(new_plan) -} - -// The function replaces occurrences of table_ref_str in col_name with the new name defined by rewrite. -// The name to rewrite should NOT be a substring of another name. -// Supports multiple occurrences of table_ref_str in col_name. -fn rewrite_column_name_in_expr( - col_name: &str, - table_ref_str: &str, - rewrite: &str, - start_pos: usize, -) -> Option { - if start_pos >= col_name.len() { - return None; - } - - // Find the first occurrence of table_ref_str starting from start_pos - let idx = col_name[start_pos..].find(table_ref_str)?; - - // Calculate the absolute index of the occurrence in string as the index above is relative to start_pos - let idx = start_pos + idx; - - if idx > 0 { - // Check if the previous character is alphabetic, numeric, underscore or period, in which case we - // should not rewrite as it is a part of another name. - if let Some(prev_char) = col_name.chars().nth(idx - 1) { - if prev_char.is_alphabetic() - || prev_char.is_numeric() - || prev_char == '_' - || prev_char == '.' - { - return rewrite_column_name_in_expr( - col_name, - table_ref_str, - rewrite, - idx + table_ref_str.len(), - ); - } - } - } - - // Check if the next character is alphabetic, numeric or underscore, in which case we - // should not rewrite as it is a part of another name. - if let Some(next_char) = col_name.chars().nth(idx + table_ref_str.len()) { - if next_char.is_alphabetic() || next_char.is_numeric() || next_char == '_' { - return rewrite_column_name_in_expr( - col_name, - table_ref_str, - rewrite, - idx + table_ref_str.len(), - ); - } - } - - // Found full match, replace table_ref_str occurrence with rewrite - let rewritten_name = format!( - "{}{}{}", - &col_name[..idx], - rewrite, - &col_name[idx + table_ref_str.len()..] 
- ); - // Check if the rewritten name contains more occurrence of table_ref_str, and rewrite them as well - // This is done by providing the updated start_pos for search - match rewrite_column_name_in_expr(&rewritten_name, table_ref_str, rewrite, idx + rewrite.len()) - { - Some(new_name) => Some(new_name), // more occurrences found - None => Some(rewritten_name), // no more occurrences/changes - } -} - -fn rewrite_table_scans_in_expr( - expr: Expr, - known_rewrites: &mut HashMap, -) -> Result { - match expr { - Expr::ScalarSubquery(subquery) => { - let new_subquery = rewrite_table_scans(&subquery.subquery, known_rewrites)?; - let outer_ref_columns = subquery - .outer_ref_columns - .into_iter() - .map(|e| rewrite_table_scans_in_expr(e, known_rewrites)) - .collect::>>()?; - Ok(Expr::ScalarSubquery(Subquery { - subquery: Arc::new(new_subquery), - outer_ref_columns, - })) - } - Expr::BinaryExpr(binary_expr) => { - let left = rewrite_table_scans_in_expr(*binary_expr.left, known_rewrites)?; - let right = rewrite_table_scans_in_expr(*binary_expr.right, known_rewrites)?; - Ok(Expr::BinaryExpr(BinaryExpr::new( - Box::new(left), - binary_expr.op, - Box::new(right), - ))) - } - Expr::Column(mut col) => { - if let Some(rewrite) = col.relation.as_ref().and_then(|r| known_rewrites.get(r)) { - Ok(Expr::Column(Column::new(Some(rewrite.clone()), &col.name))) - } else { - // This prevent over-eager rewrite and only pass the column into below rewritten - // rule like MAX(...) - if col.relation.is_some() { - return Ok(Expr::Column(col)); - } - - // Check if any of the rewrites match any substring in col.name, and replace that part of the string if so. - // This will handles cases like "MAX(foo.df_table.a)" -> "MAX(remote_table.a)" - let (new_name, was_rewritten) = known_rewrites.iter().fold( - (col.name.to_string(), false), - |(col_name, was_rewritten), (table_ref, rewrite)| { - match rewrite_column_name_in_expr( - &col_name, - &table_ref.to_string(), - &rewrite.to_string(), - 0, - ) { - Some(new_name) => (new_name, true), - None => (col_name, was_rewritten), - } - }, - ); - if was_rewritten { - Ok(Expr::Column(Column::new(col.relation.take(), new_name))) - } else { - Ok(Expr::Column(col)) - } - } - } - Expr::Alias(alias) => { - let expr = rewrite_table_scans_in_expr(*alias.expr, known_rewrites)?; - if let Some(relation) = &alias.relation { - if let Some(rewrite) = known_rewrites.get(relation) { - return Ok(Expr::Alias(Alias::new( - expr, - Some(rewrite.clone()), - alias.name, - ))); - } - } - Ok(Expr::Alias(Alias::new(expr, alias.relation, alias.name))) - } - Expr::Like(like) => { - let expr = rewrite_table_scans_in_expr(*like.expr, known_rewrites)?; - let pattern = rewrite_table_scans_in_expr(*like.pattern, known_rewrites)?; - Ok(Expr::Like(Like::new( - like.negated, - Box::new(expr), - Box::new(pattern), - like.escape_char, - like.case_insensitive, - ))) - } - Expr::SimilarTo(similar_to) => { - let expr = rewrite_table_scans_in_expr(*similar_to.expr, known_rewrites)?; - let pattern = rewrite_table_scans_in_expr(*similar_to.pattern, known_rewrites)?; - Ok(Expr::SimilarTo(Like::new( - similar_to.negated, - Box::new(expr), - Box::new(pattern), - similar_to.escape_char, - similar_to.case_insensitive, - ))) - } - Expr::Not(e) => { - let expr = rewrite_table_scans_in_expr(*e, known_rewrites)?; - Ok(Expr::Not(Box::new(expr))) - } - Expr::IsNotNull(e) => { - let expr = rewrite_table_scans_in_expr(*e, known_rewrites)?; - Ok(Expr::IsNotNull(Box::new(expr))) - } - Expr::IsNull(e) => { - let expr = 
rewrite_table_scans_in_expr(*e, known_rewrites)?; - Ok(Expr::IsNull(Box::new(expr))) - } - Expr::IsTrue(e) => { - let expr = rewrite_table_scans_in_expr(*e, known_rewrites)?; - Ok(Expr::IsTrue(Box::new(expr))) - } - Expr::IsFalse(e) => { - let expr = rewrite_table_scans_in_expr(*e, known_rewrites)?; - Ok(Expr::IsFalse(Box::new(expr))) - } - Expr::IsUnknown(e) => { - let expr = rewrite_table_scans_in_expr(*e, known_rewrites)?; - Ok(Expr::IsUnknown(Box::new(expr))) - } - Expr::IsNotTrue(e) => { - let expr = rewrite_table_scans_in_expr(*e, known_rewrites)?; - Ok(Expr::IsNotTrue(Box::new(expr))) - } - Expr::IsNotFalse(e) => { - let expr = rewrite_table_scans_in_expr(*e, known_rewrites)?; - Ok(Expr::IsNotFalse(Box::new(expr))) - } - Expr::IsNotUnknown(e) => { - let expr = rewrite_table_scans_in_expr(*e, known_rewrites)?; - Ok(Expr::IsNotUnknown(Box::new(expr))) - } - Expr::Negative(e) => { - let expr = rewrite_table_scans_in_expr(*e, known_rewrites)?; - Ok(Expr::Negative(Box::new(expr))) - } - Expr::Between(between) => { - let expr = rewrite_table_scans_in_expr(*between.expr, known_rewrites)?; - let low = rewrite_table_scans_in_expr(*between.low, known_rewrites)?; - let high = rewrite_table_scans_in_expr(*between.high, known_rewrites)?; - Ok(Expr::Between(Between::new( - Box::new(expr), - between.negated, - Box::new(low), - Box::new(high), - ))) - } - Expr::Case(case) => { - let expr = case - .expr - .map(|e| rewrite_table_scans_in_expr(*e, known_rewrites)) - .transpose()? - .map(Box::new); - let else_expr = case - .else_expr - .map(|e| rewrite_table_scans_in_expr(*e, known_rewrites)) - .transpose()? - .map(Box::new); - let when_expr = case - .when_then_expr - .into_iter() - .map(|(when, then)| { - let when = rewrite_table_scans_in_expr(*when, known_rewrites); - let then = rewrite_table_scans_in_expr(*then, known_rewrites); - - match (when, then) { - (Ok(when), Ok(then)) => Ok((Box::new(when), Box::new(then))), - (Err(e), _) | (_, Err(e)) => Err(e), - } - }) - .collect::, Box)>>>()?; - Ok(Expr::Case(Case::new(expr, when_expr, else_expr))) - } - Expr::Cast(cast) => { - let expr = rewrite_table_scans_in_expr(*cast.expr, known_rewrites)?; - Ok(Expr::Cast(Cast::new(Box::new(expr), cast.data_type))) - } - Expr::TryCast(try_cast) => { - let expr = rewrite_table_scans_in_expr(*try_cast.expr, known_rewrites)?; - Ok(Expr::TryCast(TryCast::new( - Box::new(expr), - try_cast.data_type, - ))) - } - Expr::ScalarFunction(sf) => { - let args = sf - .args - .into_iter() - .map(|e| rewrite_table_scans_in_expr(e, known_rewrites)) - .collect::>>()?; - Ok(Expr::ScalarFunction(ScalarFunction { - func: sf.func, - args, - })) - } - Expr::AggregateFunction(af) => { - let args = af - .args - .into_iter() - .map(|e| rewrite_table_scans_in_expr(e, known_rewrites)) - .collect::>>()?; - let filter = af - .filter - .map(|e| rewrite_table_scans_in_expr(*e, known_rewrites)) - .transpose()? 
- .map(Box::new); - let order_by = af - .order_by - .map(|e| { - e.into_iter() - .map(|sort| { - Ok(Sort { - expr: rewrite_table_scans_in_expr(sort.expr, known_rewrites)?, - ..sort - }) - }) - .collect::>>() - }) - .transpose()?; - Ok(Expr::AggregateFunction(AggregateFunction { - func: af.func, - args, - distinct: af.distinct, - filter, - order_by, - null_treatment: af.null_treatment, - })) - } - Expr::WindowFunction(wf) => { - let args = wf - .args - .into_iter() - .map(|e| rewrite_table_scans_in_expr(e, known_rewrites)) - .collect::>>()?; - let partition_by = wf - .partition_by - .into_iter() - .map(|e| rewrite_table_scans_in_expr(e, known_rewrites)) - .collect::>>()?; - let order_by = wf - .order_by - .into_iter() - .map(|sort| { - Ok(Sort { - expr: rewrite_table_scans_in_expr(sort.expr, known_rewrites)?, - ..sort - }) - }) - .collect::>>()?; - Ok(Expr::WindowFunction(WindowFunction { - fun: wf.fun, - args, - partition_by, - order_by, - window_frame: wf.window_frame, - null_treatment: wf.null_treatment, - })) - } - Expr::InList(il) => { - let expr = rewrite_table_scans_in_expr(*il.expr, known_rewrites)?; - let list = il - .list - .into_iter() - .map(|e| rewrite_table_scans_in_expr(e, known_rewrites)) - .collect::>>()?; - Ok(Expr::InList(InList::new(Box::new(expr), list, il.negated))) - } - Expr::Exists(exists) => { - let subquery_plan = rewrite_table_scans(&exists.subquery.subquery, known_rewrites)?; - let outer_ref_columns = exists - .subquery - .outer_ref_columns - .into_iter() - .map(|e| rewrite_table_scans_in_expr(e, known_rewrites)) - .collect::>>()?; - let subquery = Subquery { - subquery: Arc::new(subquery_plan), - outer_ref_columns, - }; - Ok(Expr::Exists(Exists::new(subquery, exists.negated))) - } - Expr::InSubquery(is) => { - let expr = rewrite_table_scans_in_expr(*is.expr, known_rewrites)?; - let subquery_plan = rewrite_table_scans(&is.subquery.subquery, known_rewrites)?; - let outer_ref_columns = is - .subquery - .outer_ref_columns - .into_iter() - .map(|e| rewrite_table_scans_in_expr(e, known_rewrites)) - .collect::>>()?; - let subquery = Subquery { - subquery: Arc::new(subquery_plan), - outer_ref_columns, - }; - Ok(Expr::InSubquery(InSubquery::new( - Box::new(expr), - subquery, - is.negated, - ))) - } - Expr::Wildcard { qualifier, options } => { - let options = WildcardOptions { - replace: options - .replace - .map(|replace| -> Result { - Ok(PlannedReplaceSelectItem { - planned_expressions: replace - .planned_expressions - .into_iter() - .map(|expr| rewrite_table_scans_in_expr(expr, known_rewrites)) - .collect::>>()?, - ..replace - }) - }) - .transpose()?, - ..*options - }; - if let Some(rewrite) = qualifier.as_ref().and_then(|q| known_rewrites.get(q)) { - Ok(Expr::Wildcard { - qualifier: Some(rewrite.clone()), - options: Box::new(options), - }) - } else { - Ok(Expr::Wildcard { - qualifier, - options: Box::new(options), - }) - } - } - Expr::GroupingSet(gs) => match gs { - GroupingSet::Rollup(exprs) => { - let exprs = exprs - .into_iter() - .map(|e| rewrite_table_scans_in_expr(e, known_rewrites)) - .collect::>>()?; - Ok(Expr::GroupingSet(GroupingSet::Rollup(exprs))) - } - GroupingSet::Cube(exprs) => { - let exprs = exprs - .into_iter() - .map(|e| rewrite_table_scans_in_expr(e, known_rewrites)) - .collect::>>()?; - Ok(Expr::GroupingSet(GroupingSet::Cube(exprs))) - } - GroupingSet::GroupingSets(vec_exprs) => { - let vec_exprs = vec_exprs - .into_iter() - .map(|exprs| { - exprs - .into_iter() - .map(|e| rewrite_table_scans_in_expr(e, known_rewrites)) - .collect::>>() - }) - 
.collect::>>>()?; - Ok(Expr::GroupingSet(GroupingSet::GroupingSets(vec_exprs))) - } - }, - Expr::OuterReferenceColumn(dt, col) => { - if let Some(rewrite) = col.relation.as_ref().and_then(|r| known_rewrites.get(r)) { - Ok(Expr::OuterReferenceColumn( - dt, - Column::new(Some(rewrite.clone()), &col.name), - )) - } else { - Ok(Expr::OuterReferenceColumn(dt, col)) - } - } - Expr::Unnest(unnest) => { - let expr = rewrite_table_scans_in_expr(*unnest.expr, known_rewrites)?; - Ok(Expr::Unnest(Unnest::new(expr))) - } - Expr::ScalarVariable(_, _) | Expr::Literal(_) | Expr::Placeholder(_) => Ok(expr), - } -} - -struct SQLFederationPlanner { - executor: Arc, -} - -impl SQLFederationPlanner { - pub fn new(executor: Arc) -> Self { - Self { executor } - } -} - -#[async_trait] -impl FederationPlanner for SQLFederationPlanner { - async fn plan_federation( - &self, - node: &FederatedPlanNode, - _session_state: &SessionState, - ) -> Result> { - let schema = Arc::new(node.plan().schema().as_arrow().clone()); - let input = Arc::new(VirtualExecutionPlan::new( - node.plan().clone(), - Arc::clone(&self.executor), - )); - let schema_cast_exec = schema_cast::SchemaCastScanExec::new(input, schema); - Ok(Arc::new(schema_cast_exec)) - } -} - -#[derive(Debug, Clone)] -struct VirtualExecutionPlan { - plan: LogicalPlan, - executor: Arc, - props: PlanProperties, -} - -impl VirtualExecutionPlan { - pub fn new(plan: LogicalPlan, executor: Arc) -> Self { - let schema: Schema = plan.schema().as_ref().into(); - let props = PlanProperties::new( - EquivalenceProperties::new(Arc::new(schema)), - Partitioning::UnknownPartitioning(1), - EmissionType::Incremental, - Boundedness::Bounded, - ); - Self { - plan, - executor, - props, - } - } - - fn schema(&self) -> SchemaRef { - let df_schema = self.plan.schema().as_ref(); - Arc::new(Schema::from(df_schema)) - } - - fn sql(&self) -> Result { - // Find all table scans, recover the SQLTableSource, find the remote table name and replace the name of the TableScan table. 
- let mut known_rewrites = HashMap::new(); - let plan = &rewrite_table_scans(&self.plan, &mut known_rewrites)?; - let mut ast = self.plan_to_sql(plan)?; - - if let Some(analyzer) = self.executor.ast_analyzer() { - ast = analyzer(ast)?; - } - - Ok(format!("{ast}")) - } - - fn plan_to_sql(&self, plan: &LogicalPlan) -> Result { - Unparser::new(self.executor.dialect().as_ref()).plan_to_sql(plan) - } -} - -impl DisplayAs for VirtualExecutionPlan { - fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter) -> std::fmt::Result { - write!(f, "VirtualExecutionPlan")?; - let Ok(ast) = plan_to_sql(&self.plan) else { - return Ok(()); - }; - write!(f, " name={}", self.executor.name())?; - if let Some(ctx) = self.executor.compute_context() { - write!(f, " compute_context={ctx}")?; - }; - - write!(f, " sql={ast}")?; - if let Ok(query) = self.sql() { - write!(f, " rewritten_sql={query}")?; - }; - - write!(f, " sql={ast}") - } -} - -impl ExecutionPlan for VirtualExecutionPlan { - fn name(&self) -> &str { - "sql_federation_exec" - } - - fn as_any(&self) -> &dyn Any { - self - } - - fn schema(&self) -> SchemaRef { - self.schema() - } - - fn children(&self) -> Vec<&Arc> { - vec![] - } - - fn with_new_children( - self: Arc, - _: Vec>, - ) -> Result> { - Ok(self) - } - - fn execute( - &self, - _partition: usize, - _context: Arc, - ) -> Result { - let query = self.plan_to_sql(&self.plan)?.to_string(); - self.executor.execute(query.as_str(), self.schema()) - } - - fn properties(&self) -> &PlanProperties { - &self.props - } -} - -#[cfg(test)] -mod tests { - use crate::FederatedTableProviderAdaptor; - use datafusion::{ - arrow::datatypes::{DataType, Field}, - catalog::SchemaProvider, - catalog_common::MemorySchemaProvider, - common::Column, - datasource::{DefaultTableSource, TableProvider}, - error::DataFusionError, - execution::context::SessionContext, - logical_expr::LogicalPlanBuilder, - sql::{unparser::dialect::DefaultDialect, unparser::dialect::Dialect}, - }; - - use super::*; - - struct TestSQLExecutor {} - - #[async_trait] - impl SQLExecutor for TestSQLExecutor { - fn name(&self) -> &str { - "test_sql_table_source" - } - - fn compute_context(&self) -> Option { - None - } - - fn dialect(&self) -> Arc { - Arc::new(DefaultDialect {}) - } - - fn execute(&self, _query: &str, _schema: SchemaRef) -> Result { - Err(DataFusionError::NotImplemented( - "execute not implemented".to_string(), - )) - } - - async fn table_names(&self) -> Result> { - Err(DataFusionError::NotImplemented( - "table inference not implemented".to_string(), - )) - } - - async fn get_table_schema(&self, _table_name: &str) -> Result { - Err(DataFusionError::NotImplemented( - "table inference not implemented".to_string(), - )) - } - } - - fn get_test_table_provider() -> Arc { - let sql_federation_provider = - Arc::new(SQLFederationProvider::new(Arc::new(TestSQLExecutor {}))); - - let schema = Arc::new(Schema::new(vec![ - Field::new("a", DataType::Int64, false), - Field::new("b", DataType::Utf8, false), - Field::new("c", DataType::Date32, false), - ])); - let table_source = Arc::new( - SQLTableSource::new_with_schema( - sql_federation_provider, - "remote_table".to_string(), - schema, - ) - .expect("to have a valid SQLTableSource"), - ); - Arc::new(FederatedTableProviderAdaptor::new(table_source)) - } - - fn get_test_table_source() -> Arc { - Arc::new(DefaultTableSource::new(get_test_table_provider())) - } - - fn get_test_df_context() -> SessionContext { - let ctx = SessionContext::new(); - let catalog = ctx - .catalog("datafusion") - 
.expect("default catalog is datafusion"); - let foo_schema = Arc::new(MemorySchemaProvider::new()) as Arc; - catalog - .register_schema("foo", Arc::clone(&foo_schema)) - .expect("to register schema"); - foo_schema - .register_table("df_table".to_string(), get_test_table_provider()) - .expect("to register table"); - - let public_schema = catalog - .schema("public") - .expect("public schema should exist"); - public_schema - .register_table("app_table".to_string(), get_test_table_provider()) - .expect("to register table"); - - ctx - } - - #[test] - fn test_rewrite_table_scans_basic() -> Result<()> { - let default_table_source = get_test_table_source(); - let plan = - LogicalPlanBuilder::scan("foo.df_table", default_table_source, None)?.project(vec![ - Expr::Column(Column::from_qualified_name("foo.df_table.a")), - Expr::Column(Column::from_qualified_name("foo.df_table.b")), - Expr::Column(Column::from_qualified_name("foo.df_table.c")), - ])?; - - let mut known_rewrites = HashMap::new(); - let rewritten_plan = rewrite_table_scans(&plan.build()?, &mut known_rewrites)?; - - println!("rewritten_plan: \n{:#?}", rewritten_plan); - - let unparsed_sql = plan_to_sql(&rewritten_plan)?; - - println!("unparsed_sql: \n{unparsed_sql}"); - - assert_eq!( - format!("{unparsed_sql}"), - r#"SELECT remote_table.a, remote_table.b, remote_table.c FROM remote_table"# - ); - - Ok(()) - } - - fn init_tracing() { - let subscriber = tracing_subscriber::FmtSubscriber::builder() - .with_env_filter("debug") - .with_ansi(true) - .finish(); - let _ = tracing::subscriber::set_global_default(subscriber); - } - - #[tokio::test] - async fn test_rewrite_table_scans_agg() -> Result<()> { - init_tracing(); - let ctx = get_test_df_context(); - - let agg_tests = vec![ - ( - "SELECT MAX(a) FROM foo.df_table", - r#"SELECT max(remote_table.a) FROM remote_table"#, - ), - ( - "SELECT foo.df_table.a FROM foo.df_table", - r#"SELECT remote_table.a FROM remote_table"#, - ), - ( - "SELECT MIN(a) FROM foo.df_table", - r#"SELECT min(remote_table.a) FROM remote_table"#, - ), - ( - "SELECT AVG(a) FROM foo.df_table", - r#"SELECT avg(remote_table.a) FROM remote_table"#, - ), - ( - "SELECT SUM(a) FROM foo.df_table", - r#"SELECT sum(remote_table.a) FROM remote_table"#, - ), - ( - "SELECT COUNT(a) FROM foo.df_table", - r#"SELECT count(remote_table.a) FROM remote_table"#, - ), - ( - "SELECT COUNT(a) as cnt FROM foo.df_table", - r#"SELECT count(remote_table.a) AS cnt FROM remote_table"#, - ), - ( - "SELECT COUNT(a) as cnt FROM foo.df_table", - r#"SELECT count(remote_table.a) AS cnt FROM remote_table"#, - ), - ( - "SELECT app_table from (SELECT a as app_table FROM app_table) b", - r#"SELECT b.app_table FROM (SELECT remote_table.a AS app_table FROM remote_table) AS b"#, - ), - ( - "SELECT MAX(app_table) from (SELECT a as app_table FROM app_table) b", - r#"SELECT max(b.app_table) FROM (SELECT remote_table.a AS app_table FROM remote_table) AS b"#, - ), - // multiple occurrences of the same table in single aggregation expression - ( - "SELECT COUNT(CASE WHEN a > 0 THEN a ELSE 0 END) FROM app_table", - r#"SELECT count(CASE WHEN (remote_table.a > 0) THEN remote_table.a ELSE 0 END) FROM remote_table"#, - ), - // different tables in single aggregation expression - ( - "SELECT COUNT(CASE WHEN appt.a > 0 THEN appt.a ELSE dft.a END) FROM app_table as appt, foo.df_table as dft", - "SELECT count(CASE WHEN (appt.a > 0) THEN appt.a ELSE dft.a END) FROM remote_table AS appt CROSS JOIN remote_table AS dft" - ), - ]; - - for test in agg_tests { - test_sql(&ctx, test.0, 
test.1).await?; - } - - Ok(()) - } - - #[tokio::test] - async fn test_rewrite_table_scans_alias() -> Result<()> { - init_tracing(); - let ctx = get_test_df_context(); - - let tests = vec![ - ( - "SELECT COUNT(app_table_a) FROM (SELECT a as app_table_a FROM app_table)", - r#"SELECT count(app_table_a) FROM (SELECT remote_table.a AS app_table_a FROM remote_table)"#, - ), - ( - "SELECT app_table_a FROM (SELECT a as app_table_a FROM app_table)", - r#"SELECT app_table_a FROM (SELECT remote_table.a AS app_table_a FROM remote_table)"#, - ), - ( - "SELECT aapp_table FROM (SELECT a as aapp_table FROM app_table)", - r#"SELECT aapp_table FROM (SELECT remote_table.a AS aapp_table FROM remote_table)"#, - ), - ]; - - for test in tests { - test_sql(&ctx, test.0, test.1).await?; - } - - Ok(()) - } - - async fn test_sql( - ctx: &SessionContext, - sql_query: &str, - expected_sql: &str, - ) -> Result<(), datafusion::error::DataFusionError> { - let data_frame = ctx.sql(sql_query).await?; - - println!("before optimization: \n{:#?}", data_frame.logical_plan()); - - let mut known_rewrites = HashMap::new(); - let rewritten_plan = rewrite_table_scans(data_frame.logical_plan(), &mut known_rewrites)?; - - println!("rewritten_plan: \n{:#?}", rewritten_plan); - - let unparsed_sql = plan_to_sql(&rewritten_plan)?; - - println!("unparsed_sql: \n{unparsed_sql}"); - - assert_eq!( - format!("{unparsed_sql}"), - expected_sql, - "SQL under test: {}", - sql_query - ); - - Ok(()) - } - - #[tokio::test] - async fn test_rewrite_table_scans_limit_offset() -> Result<()> { - init_tracing(); - let ctx = get_test_df_context(); - - let tests = vec![ - // Basic LIMIT - ( - "SELECT a FROM foo.df_table LIMIT 5", - r#"SELECT remote_table.a FROM remote_table LIMIT 5"#, - ), - // Basic OFFSET - ( - "SELECT a FROM foo.df_table OFFSET 5", - r#"SELECT remote_table.a FROM remote_table OFFSET 5"#, - ), - // OFFSET after LIMIT - ( - "SELECT a FROM foo.df_table LIMIT 10 OFFSET 5", - r#"SELECT remote_table.a FROM remote_table LIMIT 10 OFFSET 5"#, - ), - // LIMIT after OFFSET - ( - "SELECT a FROM foo.df_table OFFSET 5 LIMIT 10", - r#"SELECT remote_table.a FROM remote_table LIMIT 10 OFFSET 5"#, - ), - // Zero OFFSET - ( - "SELECT a FROM foo.df_table OFFSET 0", - r#"SELECT remote_table.a FROM remote_table OFFSET 0"#, - ), - // Zero LIMIT - ( - "SELECT a FROM foo.df_table LIMIT 0", - r#"SELECT remote_table.a FROM remote_table LIMIT 0"#, - ), - // Zero LIMIT and OFFSET - ( - "SELECT a FROM foo.df_table LIMIT 0 OFFSET 0", - r#"SELECT remote_table.a FROM remote_table LIMIT 0 OFFSET 0"#, - ), - ]; - - for test in tests { - test_sql(&ctx, test.0, test.1).await?; - } - - Ok(()) - } -} diff --git a/datafusion-federation/src/sql/schema.rs b/datafusion-federation/src/sql/schema.rs deleted file mode 100644 index 1961226..0000000 --- a/datafusion-federation/src/sql/schema.rs +++ /dev/null @@ -1,162 +0,0 @@ -use std::{any::Any, sync::Arc}; - -use async_trait::async_trait; -use datafusion::logical_expr::{TableSource, TableType}; -use datafusion::{ - arrow::datatypes::SchemaRef, catalog::SchemaProvider, datasource::TableProvider, error::Result, -}; -use futures::future::join_all; - -use crate::{ - sql::SQLFederationProvider, FederatedTableProviderAdaptor, FederatedTableSource, - FederationProvider, -}; - -#[derive(Debug)] -pub struct SQLSchemaProvider { - // provider: Arc, - tables: Vec>, -} - -impl SQLSchemaProvider { - pub async fn new(provider: Arc) -> Result { - let tables = Arc::clone(&provider).executor.table_names().await?; - - Self::new_with_tables(provider, 
tables).await - } - - pub async fn new_with_tables( - provider: Arc, - tables: Vec, - ) -> Result { - let futures: Vec<_> = tables - .into_iter() - .map(|t| SQLTableSource::new(Arc::clone(&provider), t)) - .collect(); - let results: Result> = join_all(futures).await.into_iter().collect(); - let sources = results?.into_iter().map(Arc::new).collect(); - Ok(Self::new_with_table_sources(sources)) - } - - pub fn new_with_table_sources(tables: Vec>) -> Self { - Self { tables } - } -} - -#[async_trait] -impl SchemaProvider for SQLSchemaProvider { - fn as_any(&self) -> &dyn Any { - self - } - - fn table_names(&self) -> Vec { - self.tables.iter().map(|s| s.table_name.clone()).collect() - } - - async fn table(&self, name: &str) -> Result>> { - if let Some(source) = self - .tables - .iter() - .find(|s| s.table_name.eq_ignore_ascii_case(name)) - { - let adaptor = FederatedTableProviderAdaptor::new( - Arc::clone(source) as Arc - ); - return Ok(Some(Arc::new(adaptor))); - } - Ok(None) - } - - fn table_exist(&self, name: &str) -> bool { - self.tables - .iter() - .any(|s| s.table_name.eq_ignore_ascii_case(name)) - } -} - -#[derive(Debug)] -pub struct MultiSchemaProvider { - children: Vec>, -} - -impl MultiSchemaProvider { - pub fn new(children: Vec>) -> Self { - Self { children } - } -} - -#[async_trait] -impl SchemaProvider for MultiSchemaProvider { - fn as_any(&self) -> &dyn Any { - self - } - - fn table_names(&self) -> Vec { - self.children.iter().flat_map(|p| p.table_names()).collect() - } - - async fn table(&self, name: &str) -> Result>> { - for child in &self.children { - if let Ok(Some(table)) = child.table(name).await { - return Ok(Some(table)); - } - } - Ok(None) - } - - fn table_exist(&self, name: &str) -> bool { - self.children.iter().any(|p| p.table_exist(name)) - } -} - -#[derive(Debug)] -pub struct SQLTableSource { - provider: Arc, - table_name: String, - schema: SchemaRef, -} - -impl SQLTableSource { - // creates a SQLTableSource and infers the table schema - pub async fn new(provider: Arc, table_name: String) -> Result { - let schema = Arc::clone(&provider) - .executor - .get_table_schema(table_name.as_str()) - .await?; - Self::new_with_schema(provider, table_name, schema) - } - - pub fn new_with_schema( - provider: Arc, - table_name: String, - schema: SchemaRef, - ) -> Result { - Ok(Self { - provider, - table_name, - schema, - }) - } - - pub fn table_name(&self) -> &str { - self.table_name.as_str() - } -} - -impl FederatedTableSource for SQLTableSource { - fn federation_provider(&self) -> Arc { - Arc::clone(&self.provider) as Arc - } -} - -impl TableSource for SQLTableSource { - fn as_any(&self) -> &dyn Any { - self - } - fn schema(&self) -> SchemaRef { - Arc::clone(&self.schema) - } - fn table_type(&self) -> TableType { - TableType::Temporary - } -} diff --git a/datafusion-federation/src/table_provider.rs b/datafusion-federation/src/table_provider.rs deleted file mode 100644 index 6da1eb1..0000000 --- a/datafusion-federation/src/table_provider.rs +++ /dev/null @@ -1,158 +0,0 @@ -use std::{any::Any, borrow::Cow, sync::Arc}; - -use async_trait::async_trait; -use datafusion::{ - arrow::datatypes::SchemaRef, - catalog::Session, - common::Constraints, - datasource::TableProvider, - error::{DataFusionError, Result}, - logical_expr::{ - dml::InsertOp, Expr, LogicalPlan, TableProviderFilterPushDown, TableSource, TableType, - }, - physical_plan::ExecutionPlan, -}; - -use crate::FederationProvider; - -// FederatedTableSourceWrapper helps to recover the FederatedTableSource -// from a TableScan. 
This wrapper may be avoidable. -#[derive(Debug)] -pub struct FederatedTableProviderAdaptor { - pub source: Arc, - pub table_provider: Option>, -} - -impl FederatedTableProviderAdaptor { - pub fn new(source: Arc) -> Self { - Self { - source, - table_provider: None, - } - } - - /// Creates a new FederatedTableProviderAdaptor that falls back to the - /// provided TableProvider. This is useful if used within a DataFusion - /// context without the federation optimizer. - pub fn new_with_provider( - source: Arc, - table_provider: Arc, - ) -> Self { - Self { - source, - table_provider: Some(table_provider), - } - } -} - -#[async_trait] -impl TableProvider for FederatedTableProviderAdaptor { - fn as_any(&self) -> &dyn Any { - self - } - fn schema(&self) -> SchemaRef { - if let Some(table_provider) = &self.table_provider { - return table_provider.schema(); - } - - self.source.schema() - } - fn constraints(&self) -> Option<&Constraints> { - if let Some(table_provider) = &self.table_provider { - return table_provider - .constraints() - .or_else(|| self.source.constraints()); - } - - self.source.constraints() - } - fn table_type(&self) -> TableType { - if let Some(table_provider) = &self.table_provider { - return table_provider.table_type(); - } - - self.source.table_type() - } - fn get_logical_plan(&self) -> Option> { - if let Some(table_provider) = &self.table_provider { - return table_provider - .get_logical_plan() - .or_else(|| self.source.get_logical_plan()); - } - - self.source.get_logical_plan() - } - fn get_column_default(&self, column: &str) -> Option<&Expr> { - if let Some(table_provider) = &self.table_provider { - return table_provider - .get_column_default(column) - .or_else(|| self.source.get_column_default(column)); - } - - self.source.get_column_default(column) - } - fn supports_filters_pushdown( - &self, - filters: &[&Expr], - ) -> Result> { - if let Some(table_provider) = &self.table_provider { - return table_provider.supports_filters_pushdown(filters); - } - - Ok(vec![ - TableProviderFilterPushDown::Unsupported; - filters.len() - ]) - } - - // Scan is not supported; the adaptor should be replaced - // with a virtual TableProvider that provides federation for a sub-plan. - async fn scan( - &self, - state: &dyn Session, - projection: Option<&Vec>, - filters: &[Expr], - limit: Option, - ) -> Result> { - if let Some(table_provider) = &self.table_provider { - return table_provider.scan(state, projection, filters, limit).await; - } - - Err(DataFusionError::NotImplemented( - "FederatedTableProviderAdaptor cannot scan".to_string(), - )) - } - - async fn insert_into( - &self, - _state: &dyn Session, - input: Arc, - insert_op: InsertOp, - ) -> Result> { - if let Some(table_provider) = &self.table_provider { - return table_provider.insert_into(_state, input, insert_op).await; - } - - Err(DataFusionError::NotImplemented( - "FederatedTableProviderAdaptor cannot insert_into".to_string(), - )) - } -} - -// FederatedTableProvider extends DataFusion's TableProvider trait -// to allow grouping of TableScans of the same FederationProvider. 
-#[async_trait]
-pub trait FederatedTableSource: TableSource {
-    // Return the FederationProvider associated with this Table
-    fn federation_provider(&self) -> Arc<dyn FederationProvider>;
-}
-
-impl std::fmt::Debug for dyn FederatedTableSource {
-    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
-        write!(
-            f,
-            "FederatedTableSource: {:?}",
-            self.federation_provider().name()
-        )
-    }
-}
diff --git a/datafusion-flight-sql-server/README.md b/datafusion-flight-sql-server/README.md
deleted file mode 100644
index f72903e..0000000
--- a/datafusion-flight-sql-server/README.md
+++ /dev/null
@@ -1,41 +0,0 @@
-# DataFusion Flight SQL Server
-
-The `datafusion-flight-sql-server` is a Flight SQL server that implements the
-necessary endpoints to use DataFusion as the query engine.
-
-## Getting Started
-
-To use `datafusion-flight-sql-server` in your Rust project, run:
-
-```sh
-$ cargo add datafusion-flight-sql-server
-```
-
-## Example
-
-Here's a basic example of setting up a Flight SQL server:
-
-```rust
-use datafusion_flight_sql_server::service::FlightSqlService;
-use datafusion::{
-    execution::{
-        context::SessionContext,
-        options::CsvReadOptions,
-    },
-};
-
-async {
-    let dsn: String = "0.0.0.0:50051".to_string();
-    let remote_ctx = SessionContext::new();
-    remote_ctx
-        .register_csv("test", "./examples/test.csv", CsvReadOptions::new())
-        .await.expect("Register csv");
-
-    FlightSqlService::new(remote_ctx.state()).serve(dsn.clone())
-        .await
-        .expect("Run flight sql service");
-
-};
-```
-
-This example sets up a Flight SQL server listening on `127.0.0.1:50051`.
diff --git a/datafusion-flight-sql-server/src/lib.rs b/datafusion-flight-sql-server/src/lib.rs
index 101d335..4a7f2b1 100644
--- a/datafusion-flight-sql-server/src/lib.rs
+++ b/datafusion-flight-sql-server/src/lib.rs
@@ -1,4 +1,4 @@
-#![doc = include_str!("../README.md")]
+#![doc = include_str!("../../README.md")]
 pub mod service;
 pub mod session;
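
A minimal client-side sketch to pair with the server example from the deleted README above, assuming the server is listening locally on port 50051 and that `arrow-flight` (with the `flight-sql-experimental` feature), `tonic`, `tokio`, and `futures` are available as dependencies. This is illustrative only, not part of the change itself, and the exact `FlightSqlServiceClient` method signatures can differ between arrow-flight releases.

```rust
// Hypothetical example: querying the Flight SQL server from the README sketch above.
// Not part of this change; API details may vary across arrow-flight versions.
use arrow_flight::sql::client::FlightSqlServiceClient;
use futures::TryStreamExt;
use tonic::transport::Endpoint;

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Connect to the server started in the README example (assumed address).
    let channel = Endpoint::from_static("http://127.0.0.1:50051")
        .connect()
        .await?;
    let mut client = FlightSqlServiceClient::new(channel);

    // Plan the query; the returned FlightInfo lists the endpoints holding results.
    let info = client.execute("SELECT * FROM test".to_string(), None).await?;

    // Fetch every result partition as Arrow record batches.
    for endpoint in info.endpoint {
        if let Some(ticket) = endpoint.ticket {
            let mut stream = client.do_get(ticket).await?;
            while let Some(batch) = stream.try_next().await? {
                println!("received {} rows", batch.num_rows());
            }
        }
    }
    Ok(())
}
```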