diff --git a/.github/workflows/dev.yml b/.github/workflows/dev.yml index cc23e99e8cbad..19af21ec910be 100644 --- a/.github/workflows/dev.yml +++ b/.github/workflows/dev.yml @@ -30,7 +30,7 @@ jobs: - name: Checkout uses: actions/checkout@v4 - name: Setup Python - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: "3.10" - name: Audit licenses diff --git a/.github/workflows/dev_pr.yml b/.github/workflows/dev_pr.yml index 85aabc188934b..77b257743331e 100644 --- a/.github/workflows/dev_pr.yml +++ b/.github/workflows/dev_pr.yml @@ -46,7 +46,7 @@ jobs: github.event_name == 'pull_request_target' && (github.event.action == 'opened' || github.event.action == 'synchronize') - uses: actions/labeler@v4.3.0 + uses: actions/labeler@v5.0.0 with: repo-token: ${{ secrets.GITHUB_TOKEN }} configuration-path: .github/workflows/dev_pr/labeler.yml diff --git a/.github/workflows/dev_pr/labeler.yml b/.github/workflows/dev_pr/labeler.yml index e84cf5efb1d8a..34a37948785b5 100644 --- a/.github/workflows/dev_pr/labeler.yml +++ b/.github/workflows/dev_pr/labeler.yml @@ -16,35 +16,37 @@ # under the License. development-process: - - dev/**.* - - .github/**.* - - ci/**.* - - .asf.yaml +- changed-files: + - any-glob-to-any-file: ['dev/**.*', '.github/**.*', 'ci/**.*', '.asf.yaml'] documentation: - - docs/**.* - - README.md - - ./**/README.md - - DEVELOPERS.md - - datafusion/docs/**.* +- changed-files: + - any-glob-to-any-file: ['docs/**.*', 'README.md', './**/README.md', 'DEVELOPERS.md', 'datafusion/docs/**.*'] sql: - - datafusion/sql/**/* +- changed-files: + - any-glob-to-any-file: ['datafusion/sql/**/*'] logical-expr: - - datafusion/expr/**/* +- changed-files: + - any-glob-to-any-file: ['datafusion/expr/**/*'] physical-expr: - - datafusion/physical-expr/**/* +- changed-files: + - any-glob-to-any-file: ['datafusion/physical-expr/**/*'] optimizer: - - datafusion/optimizer/**/* +- changed-files: + - any-glob-to-any-file: ['datafusion/optimizer/**/*'] core: - - datafusion/core/**/* +- changed-files: + - any-glob-to-any-file: ['datafusion/core/**/*'] substrait: - - datafusion/substrait/**/* +- changed-files: + - any-glob-to-any-file: ['datafusion/substrait/**/*'] sqllogictest: - - datafusion/sqllogictest/**/* +- changed-files: + - any-glob-to-any-file: ['datafusion/sqllogictest/**/*'] diff --git a/.github/workflows/docs.yaml b/.github/workflows/docs.yaml index 14b2038e87941..ab6a615ab60be 100644 --- a/.github/workflows/docs.yaml +++ b/.github/workflows/docs.yaml @@ -24,7 +24,7 @@ jobs: path: asf-site - name: Setup Python - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: "3.10" diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index 485d179571e30..622521a6fbc77 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -68,6 +68,9 @@ jobs: - name: Check workspace without default features run: cargo check --no-default-features -p datafusion + - name: Check datafusion-common without default features + run: cargo check --tests --no-default-features -p datafusion-common + - name: Check workspace in debug mode run: cargo check @@ -96,6 +99,14 @@ jobs: rust-version: stable - name: Run tests (excluding doctests) run: cargo test --lib --tests --bins --features avro,json,backtrace + env: + # do not produce debug symbols to keep memory usage down + # hardcoding other profile params to avoid profile override values + # More on Cargo profiles https://doc.rust-lang.org/cargo/reference/profiles.html?profile-settings#profile-settings + 
RUSTFLAGS: "-C debuginfo=0 -C opt-level=0 -C incremental=false -C codegen-units=256" + RUST_BACKTRACE: "1" + # avoid rust stack overflows on tpc-ds tests + RUST_MINSTACK: "3000000" - name: Verify Working Directory Clean run: git diff --exit-code @@ -287,6 +298,7 @@ jobs: # with a OS-dependent path. - name: Setup Rust toolchain run: | + rustup update stable rustup toolchain install stable rustup default stable rustup component add rustfmt @@ -299,9 +311,13 @@ jobs: cargo test --lib --tests --bins --all-features env: # do not produce debug symbols to keep memory usage down - RUSTFLAGS: "-C debuginfo=0" + # use higher optimization level to overcome Windows rust slowness for tpc-ds + # and speed builds: https://github.com/apache/arrow-datafusion/issues/8696 + # Cargo profile docs https://doc.rust-lang.org/cargo/reference/profiles.html?profile-settings#profile-settings + RUSTFLAGS: "-C debuginfo=0 -C opt-level=1 -C target-feature=+crt-static -C incremental=false -C codegen-units=256" RUST_BACKTRACE: "1" - + # avoid rust stack overflows on tpc-ds tests + RUST_MINSTACK: "3000000" macos: name: cargo test (mac) runs-on: macos-latest @@ -324,6 +340,7 @@ jobs: # with a OS-dependent path. - name: Setup Rust toolchain run: | + rustup update stable rustup toolchain install stable rustup default stable rustup component add rustfmt @@ -335,8 +352,12 @@ jobs: cargo test --lib --tests --bins --all-features env: # do not produce debug symbols to keep memory usage down - RUSTFLAGS: "-C debuginfo=0" + # hardcoding other profile params to avoid profile override values + # More on Cargo profiles https://doc.rust-lang.org/cargo/reference/profiles.html?profile-settings#profile-settings + RUSTFLAGS: "-C debuginfo=0 -C opt-level=0 -C incremental=false -C codegen-units=256" RUST_BACKTRACE: "1" + # avoid rust stack overflows on tpc-ds tests + RUST_MINSTACK: "3000000" test-datafusion-pyarrow: name: cargo test pyarrow (amd64) @@ -348,7 +369,7 @@ jobs: - uses: actions/checkout@v4 with: submodules: true - - uses: actions/setup-python@v4 + - uses: actions/setup-python@v5 with: python-version: "3.8" - name: Install PyArrow diff --git a/Cargo.toml b/Cargo.toml index 39ebd1fa59b5b..a87923b6a1a00 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,24 +17,7 @@ [workspace] exclude = ["datafusion-cli"] -members = [ - "datafusion/common", - "datafusion/core", - "datafusion/expr", - "datafusion/execution", - "datafusion/optimizer", - "datafusion/physical-expr", - "datafusion/physical-plan", - "datafusion/proto", - "datafusion/proto/gen", - "datafusion/sql", - "datafusion/sqllogictest", - "datafusion/substrait", - "datafusion/wasmtest", - "datafusion-examples", - "docs", - "test-utils", - "benchmarks", +members = ["datafusion/common", "datafusion/core", "datafusion/expr", "datafusion/execution", "datafusion/optimizer", "datafusion/physical-expr", "datafusion/physical-plan", "datafusion/proto", "datafusion/proto/gen", "datafusion/sql", "datafusion/sqllogictest", "datafusion/substrait", "datafusion/wasmtest", "datafusion-examples", "docs", "test-utils", "benchmarks", ] resolver = "2" @@ -46,49 +29,50 @@ license = "Apache-2.0" readme = "README.md" repository = "https://github.com/apache/arrow-datafusion" rust-version = "1.70" -version = "33.0.0" +version = "34.0.0" [workspace.dependencies] -arrow = { version = "48.0.0", features = ["prettyprint"] } -arrow-array = { version = "48.0.0", default-features = false, features = ["chrono-tz"] } -arrow-buffer = { version = "48.0.0", default-features = false } -arrow-flight = { version = "48.0.0", 
features = ["flight-sql-experimental"] } -arrow-ord = { version = "48.0.0", default-features = false } -arrow-schema = { version = "48.0.0", default-features = false } +arrow = { version = "49.0.0", features = ["prettyprint"] } +arrow-array = { version = "49.0.0", default-features = false, features = ["chrono-tz"] } +arrow-buffer = { version = "49.0.0", default-features = false } +arrow-flight = { version = "49.0.0", features = ["flight-sql-experimental"] } +arrow-ipc = { version = "49.0.0", default-features = false, features = ["lz4"] } +arrow-ord = { version = "49.0.0", default-features = false } +arrow-schema = { version = "49.0.0", default-features = false } async-trait = "0.1.73" bigdecimal = "0.4.1" bytes = "1.4" +chrono = { version = "0.4.31", default-features = false } ctor = "0.2.0" -datafusion = { path = "datafusion/core" } -datafusion-common = { path = "datafusion/common" } -datafusion-expr = { path = "datafusion/expr" } -datafusion-sql = { path = "datafusion/sql" } -datafusion-optimizer = { path = "datafusion/optimizer" } -datafusion-physical-expr = { path = "datafusion/physical-expr" } -datafusion-physical-plan = { path = "datafusion/physical-plan" } -datafusion-execution = { path = "datafusion/execution" } -datafusion-proto = { path = "datafusion/proto" } -datafusion-sqllogictest = { path = "datafusion/sqllogictest" } -datafusion-substrait = { path = "datafusion/substrait" } dashmap = "5.4.0" +datafusion = { path = "datafusion/core", version = "34.0.0" } +datafusion-common = { path = "datafusion/common", version = "34.0.0" } +datafusion-execution = { path = "datafusion/execution", version = "34.0.0" } +datafusion-expr = { path = "datafusion/expr", version = "34.0.0" } +datafusion-optimizer = { path = "datafusion/optimizer", version = "34.0.0" } +datafusion-physical-expr = { path = "datafusion/physical-expr", version = "34.0.0" } +datafusion-physical-plan = { path = "datafusion/physical-plan", version = "34.0.0" } +datafusion-proto = { path = "datafusion/proto", version = "34.0.0" } +datafusion-sql = { path = "datafusion/sql", version = "34.0.0" } +datafusion-sqllogictest = { path = "datafusion/sqllogictest", version = "34.0.0" } +datafusion-substrait = { path = "datafusion/substrait", version = "34.0.0" } doc-comment = "0.3" env_logger = "0.10" futures = "0.3" half = "2.2.1" indexmap = "2.0.0" -itertools = "0.11" +itertools = "0.12" log = "^0.4" num_cpus = "1.13.0" -object_store = "0.7.0" +object_store = { version = "0.8.0", default-features = false } parking_lot = "0.12" -parquet = { version = "48.0.0", features = ["arrow", "async", "object_store"] } +parquet = { version = "49.0.0", default-features = false, features = ["arrow", "async", "object_store"] } rand = "0.8" rstest = "0.18.0" serde_json = "1" -sqlparser = { version = "0.39.0", features = ["visitor"] } +sqlparser = { version = "0.41.0", features = ["visitor"] } tempfile = "3" thiserror = "1.0.44" -chrono = { version = "0.4.31", default-features = false } url = "2.2" [profile.release] @@ -108,4 +92,3 @@ opt-level = 3 overflow-checks = false panic = 'unwind' rpath = false - diff --git a/README.md b/README.md index 1997a6f73dd5f..883700a39355a 100644 --- a/README.md +++ b/README.md @@ -28,6 +28,7 @@ in-memory format. 
[Python Bindings](https://github.com/apache/arrow-datafusion-p Here are links to some important information - [Project Site](https://arrow.apache.org/datafusion) +- [Installation](https://arrow.apache.org/datafusion/user-guide/cli.html#installation) - [Rust Getting Started](https://arrow.apache.org/datafusion/user-guide/example-usage.html) - [Rust DataFrame API](https://arrow.apache.org/datafusion/user-guide/dataframe.html) - [Rust API docs](https://docs.rs/datafusion/latest/datafusion) @@ -40,8 +41,19 @@ Here are links to some important information DataFusion is great for building projects such as domain specific query engines, new database platforms and data pipelines, query languages and more. It lets you start quickly from a fully working engine, and then customize those features specific to your use. [Click Here](https://arrow.apache.org/datafusion/user-guide/introduction.html#known-users) to see a list known users. +## Contributing to DataFusion + +Please see the [developer’s guide] for contributing and [communication] for getting in touch with us. + +[developer’s guide]: https://arrow.apache.org/datafusion/contributor-guide/index.html#developer-s-guide +[communication]: https://arrow.apache.org/datafusion/contributor-guide/communication.html + ## Crate features +This crate has several [features] which can be specified in your `Cargo.toml`. + +[features]: https://doc.rust-lang.org/cargo/reference/features.html + Default features: - `compression`: reading files compressed with `xz2`, `bzip2`, `flate2`, and `zstd` @@ -65,9 +77,3 @@ Optional features: ## Rust Version Compatibility This crate is tested with the latest stable version of Rust. We do not currently test against other, older versions of the Rust compiler. - -## Contributing to DataFusion - -The [developer’s guide] contains information on how to contribute. 
- -[developer’s guide]: https://arrow.apache.org/datafusion/contributor-guide/index.html#developer-s-guide diff --git a/benchmarks/Cargo.toml b/benchmarks/Cargo.toml index 35f94f677d86d..4ce46968e1f49 100644 --- a/benchmarks/Cargo.toml +++ b/benchmarks/Cargo.toml @@ -18,7 +18,7 @@ [package] name = "datafusion-benchmarks" description = "DataFusion Benchmarks" -version = "33.0.0" +version = "34.0.0" edition = { workspace = true } authors = ["Apache Arrow "] homepage = "https://github.com/apache/arrow-datafusion" @@ -34,14 +34,14 @@ snmalloc = ["snmalloc-rs"] [dependencies] arrow = { workspace = true } -datafusion = { path = "../datafusion/core", version = "33.0.0" } -datafusion-common = { path = "../datafusion/common", version = "33.0.0" } +datafusion = { path = "../datafusion/core", version = "34.0.0" } +datafusion-common = { path = "../datafusion/common", version = "34.0.0" } env_logger = { workspace = true } futures = { workspace = true } log = { workspace = true } mimalloc = { version = "0.1", optional = true, default-features = false } num_cpus = { workspace = true } -parquet = { workspace = true } +parquet = { workspace = true, default-features = true } serde = { version = "1.0.136", features = ["derive"] } serde_json = { workspace = true } snmalloc-rs = { version = "0.3", optional = true } @@ -50,4 +50,4 @@ test-utils = { path = "../test-utils/", version = "0.1.0" } tokio = { version = "^1.0", features = ["macros", "rt", "rt-multi-thread", "parking_lot"] } [dev-dependencies] -datafusion-proto = { path = "../datafusion/proto", version = "33.0.0" } +datafusion-proto = { path = "../datafusion/proto", version = "34.0.0" } diff --git a/benchmarks/compare.py b/benchmarks/compare.py index 80aa3c76b754c..ec2b28fa0556c 100755 --- a/benchmarks/compare.py +++ b/benchmarks/compare.py @@ -109,7 +109,6 @@ def compare( noise_threshold: float, ) -> None: baseline = BenchmarkRun.load_from_file(baseline_path) - comparison = BenchmarkRun.load_from_file(comparison_path) console = Console() @@ -124,27 +123,57 @@ def compare( table.add_column(comparison_header, justify="right", style="dim") table.add_column("Change", justify="right", style="dim") + faster_count = 0 + slower_count = 0 + no_change_count = 0 + total_baseline_time = 0 + total_comparison_time = 0 + for baseline_result, comparison_result in zip(baseline.queries, comparison.queries): assert baseline_result.query == comparison_result.query + total_baseline_time += baseline_result.execution_time + total_comparison_time += comparison_result.execution_time + change = comparison_result.execution_time / baseline_result.execution_time if (1.0 - noise_threshold) <= change <= (1.0 + noise_threshold): - change = "no change" + change_text = "no change" + no_change_count += 1 elif change < 1.0: - change = f"+{(1 / change):.2f}x faster" + change_text = f"+{(1 / change):.2f}x faster" + faster_count += 1 else: - change = f"{change:.2f}x slower" + change_text = f"{change:.2f}x slower" + slower_count += 1 table.add_row( f"Q{baseline_result.query}", f"{baseline_result.execution_time:.2f}ms", f"{comparison_result.execution_time:.2f}ms", - change, + change_text, ) console.print(table) + # Calculate averages + avg_baseline_time = total_baseline_time / len(baseline.queries) + avg_comparison_time = total_comparison_time / len(comparison.queries) + + # Summary table + summary_table = Table(show_header=True, header_style="bold magenta") + summary_table.add_column("Benchmark Summary", justify="left", style="dim") + summary_table.add_column("", justify="right", 
style="dim") + + summary_table.add_row(f"Total Time ({baseline_header})", f"{total_baseline_time:.2f}ms") + summary_table.add_row(f"Total Time ({comparison_header})", f"{total_comparison_time:.2f}ms") + summary_table.add_row(f"Average Time ({baseline_header})", f"{avg_baseline_time:.2f}ms") + summary_table.add_row(f"Average Time ({comparison_header})", f"{avg_comparison_time:.2f}ms") + summary_table.add_row("Queries Faster", str(faster_count)) + summary_table.add_row("Queries Slower", str(slower_count)) + summary_table.add_row("Queries with No Change", str(no_change_count)) + + console.print(summary_table) def main() -> None: parser = ArgumentParser() diff --git a/benchmarks/src/parquet_filter.rs b/benchmarks/src/parquet_filter.rs index e19596b80f54e..1d816908e2b04 100644 --- a/benchmarks/src/parquet_filter.rs +++ b/benchmarks/src/parquet_filter.rs @@ -19,8 +19,8 @@ use crate::AccessLogOpt; use crate::{BenchmarkRun, CommonOpt}; use arrow::util::pretty; use datafusion::common::Result; +use datafusion::logical_expr::utils::disjunction; use datafusion::logical_expr::{lit, or, Expr}; -use datafusion::optimizer::utils::disjunction; use datafusion::physical_plan::collect; use datafusion::prelude::{col, SessionContext}; use datafusion::test_util::parquet::{ParquetScanOptions, TestParquetFile}; diff --git a/benchmarks/src/sort.rs b/benchmarks/src/sort.rs index 5643c85619443..224f2b19c72e5 100644 --- a/benchmarks/src/sort.rs +++ b/benchmarks/src/sort.rs @@ -148,8 +148,9 @@ impl RunOpt { println!("Executing '{title}' (sorting by: {expr:?})"); rundata.start_new_case(title); for i in 0..self.common.iterations { - let config = - SessionConfig::new().with_target_partitions(self.common.partitions); + let config = SessionConfig::new().with_target_partitions( + self.common.partitions.unwrap_or(num_cpus::get()), + ); let ctx = SessionContext::new_with_config(config); let (rows, elapsed) = exec_sort(&ctx, &expr, &test_file, self.common.debug).await?; diff --git a/benchmarks/src/tpch/run.rs b/benchmarks/src/tpch/run.rs index 171b074d2a1b4..5193d578fb486 100644 --- a/benchmarks/src/tpch/run.rs +++ b/benchmarks/src/tpch/run.rs @@ -285,7 +285,7 @@ impl RunOpt { } fn partitions(&self) -> usize { - self.common.partitions + self.common.partitions.unwrap_or(num_cpus::get()) } } @@ -325,7 +325,7 @@ mod tests { let path = get_tpch_data_path()?; let common = CommonOpt { iterations: 1, - partitions: 2, + partitions: Some(2), batch_size: 8192, debug: false, }; @@ -357,7 +357,7 @@ mod tests { let path = get_tpch_data_path()?; let common = CommonOpt { iterations: 1, - partitions: 2, + partitions: Some(2), batch_size: 8192, debug: false, }; diff --git a/benchmarks/src/util/options.rs b/benchmarks/src/util/options.rs index 1d86d10fb88c6..b9398e5b522f2 100644 --- a/benchmarks/src/util/options.rs +++ b/benchmarks/src/util/options.rs @@ -26,9 +26,9 @@ pub struct CommonOpt { #[structopt(short = "i", long = "iterations", default_value = "3")] pub iterations: usize, - /// Number of partitions to process in parallel - #[structopt(short = "n", long = "partitions", default_value = "2")] - pub partitions: usize, + /// Number of partitions to process in parallel. Defaults to number of available cores. 
+ #[structopt(short = "n", long = "partitions")] + pub partitions: Option<usize>, /// Batch size when reading CSV or Parquet files #[structopt(short = "s", long = "batch-size", default_value = "8192")] @@ -48,7 +48,7 @@ impl CommonOpt { /// Modify the existing config appropriately pub fn update_config(&self, config: SessionConfig) -> SessionConfig { config - .with_target_partitions(self.partitions) + .with_target_partitions(self.partitions.unwrap_or(num_cpus::get())) .with_batch_size(self.batch_size) } } diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index 74df8aab01754..252b00ca0adc4 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -130,9 +130,9 @@ checksum = "96d30a06541fbafbc7f82ed10c06164cfbd2c401138f6addd8404629c4b16711" [[package]] name = "arrow" -version = "48.0.0" +version = "49.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "edb738d83750ec705808f6d44046d165e6bb8623f64e29a4d53fcb136ab22dfb" +checksum = "5bc25126d18a012146a888a0298f2c22e1150327bd2765fc76d710a556b2d614" dependencies = [ "ahash", "arrow-arith", @@ -152,9 +152,9 @@ dependencies = [ [[package]] name = "arrow-arith" -version = "48.0.0" +version = "49.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c5c3d17fc5b006e7beeaebfb1d2edfc92398b981f82d9744130437909b72a468" +checksum = "34ccd45e217ffa6e53bbb0080990e77113bdd4e91ddb84e97b77649810bcf1a7" dependencies = [ "arrow-array", "arrow-buffer", @@ -167,9 +167,9 @@ dependencies = [ [[package]] name = "arrow-array" -version = "48.0.0" +version = "49.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "55705ada5cdde4cb0f202ffa6aa756637e33fea30e13d8d0d0fd6a24ffcee1e3" +checksum = "6bda9acea48b25123c08340f3a8ac361aa0f74469bb36f5ee9acf923fce23e9d" dependencies = [ "ahash", "arrow-buffer", @@ -178,15 +178,15 @@ dependencies = [ "chrono", "chrono-tz", "half", - "hashbrown 0.14.2", + "hashbrown 0.14.3", "num", ] [[package]] name = "arrow-buffer" -version = "48.0.0" +version = "49.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a722f90a09b94f295ab7102542e97199d3500128843446ef63e410ad546c5333" +checksum = "01a0fc21915b00fc6c2667b069c1b64bdd920982f426079bc4a7cab86822886c" dependencies = [ "bytes", "half", @@ -195,15 +195,16 @@ dependencies = [ [[package]] name = "arrow-cast" -version = "48.0.0" +version = "49.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "af01fc1a06f6f2baf31a04776156d47f9f31ca5939fe6d00cd7a059f95a46ff1" +checksum = "5dc0368ed618d509636c1e3cc20db1281148190a78f43519487b2daf07b63b4a" dependencies = [ "arrow-array", "arrow-buffer", "arrow-data", "arrow-schema", "arrow-select", + "base64", "chrono", "comfy-table", "half", @@ -213,9 +214,9 @@ dependencies = [ [[package]] name = "arrow-csv" -version = "48.0.0" +version = "49.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "83cbbfde86f9ecd3f875c42a73d8aeab3d95149cd80129b18d09e039ecf5391b" +checksum = "2e09aa6246a1d6459b3f14baeaa49606cfdbca34435c46320e14054d244987ca" dependencies = [ "arrow-array", "arrow-buffer", @@ -232,9 +233,9 @@ dependencies = [ [[package]] name = "arrow-data" -version = "48.0.0" +version = "49.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d0a547195e607e625e7fafa1a7269b8df1a4a612c919efd9b26bd86e74538f3a" +checksum = "907fafe280a3874474678c1858b9ca4cb7fd83fb8034ff5b6d6376205a08c634" dependencies = [ "arrow-buffer", "arrow-schema", @@
-244,9 +245,9 @@ dependencies = [ [[package]] name = "arrow-ipc" -version = "48.0.0" +version = "49.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e36bf091502ab7e37775ff448413ef1ffff28ff93789acb669fffdd51b394d51" +checksum = "79a43d6808411886b8c7d4f6f7dd477029c1e77ffffffb7923555cc6579639cd" dependencies = [ "arrow-array", "arrow-buffer", @@ -254,13 +255,14 @@ dependencies = [ "arrow-data", "arrow-schema", "flatbuffers", + "lz4_flex", ] [[package]] name = "arrow-json" -version = "48.0.0" +version = "49.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7ac346bc84846ab425ab3c8c7b6721db90643bc218939677ed7e071ccbfb919d" +checksum = "d82565c91fd627922ebfe2810ee4e8346841b6f9361b87505a9acea38b614fee" dependencies = [ "arrow-array", "arrow-buffer", @@ -278,9 +280,9 @@ dependencies = [ [[package]] name = "arrow-ord" -version = "48.0.0" +version = "49.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4502123d2397319f3a13688432bc678c61cb1582f2daa01253186da650bf5841" +checksum = "9b23b0e53c0db57c6749997fd343d4c0354c994be7eca67152dd2bdb9a3e1bb4" dependencies = [ "arrow-array", "arrow-buffer", @@ -293,9 +295,9 @@ dependencies = [ [[package]] name = "arrow-row" -version = "48.0.0" +version = "49.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "249fc5a07906ab3f3536a6e9f118ec2883fbcde398a97a5ba70053f0276abda4" +checksum = "361249898d2d6d4a6eeb7484be6ac74977e48da12a4dd81a708d620cc558117a" dependencies = [ "ahash", "arrow-array", @@ -303,20 +305,20 @@ dependencies = [ "arrow-data", "arrow-schema", "half", - "hashbrown 0.14.2", + "hashbrown 0.14.3", ] [[package]] name = "arrow-schema" -version = "48.0.0" +version = "49.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9d7a8c3f97f5ef6abd862155a6f39aaba36b029322462d72bbcfa69782a50614" +checksum = "09e28a5e781bf1b0f981333684ad13f5901f4cd2f20589eab7cf1797da8fc167" [[package]] name = "arrow-select" -version = "48.0.0" +version = "49.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f868f4a5001429e20f7c1994b5cd1aa68b82e3db8cf96c559cdb56dc8be21410" +checksum = "4f6208466590960efc1d2a7172bc4ff18a67d6e25c529381d7f96ddaf0dc4036" dependencies = [ "ahash", "arrow-array", @@ -328,9 +330,9 @@ dependencies = [ [[package]] name = "arrow-string" -version = "48.0.0" +version = "49.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a27fdf8fc70040a2dee78af2e217479cb5b263bd7ab8711c7999e74056eb688a" +checksum = "a4a48149c63c11c9ff571e50ab8f017d2a7cb71037a882b42f6354ed2da9acc7" dependencies = [ "arrow-array", "arrow-buffer", @@ -359,9 +361,9 @@ dependencies = [ [[package]] name = "async-compression" -version = "0.4.4" +version = "0.4.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f658e2baef915ba0f26f1f7c42bfb8e12f532a01f449a090ded75ae7a07e9ba2" +checksum = "bc2d0cfb2a7388d34f590e76686704c494ed7aaceed62ee1ba35cbf363abc2a5" dependencies = [ "bzip2", "flate2", @@ -377,13 +379,13 @@ dependencies = [ [[package]] name = "async-trait" -version = "0.1.74" +version = "0.1.75" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a66537f1bb974b254c98ed142ff995236e81b9d0fe4db0575f46612cb15eb0f9" +checksum = "fdf6721fb0140e4f897002dd086c06f6c27775df19cfe1fccb21181a48fd2c98" dependencies = [ "proc-macro2", "quote", - "syn 2.0.38", + "syn 2.0.43", ] [[package]] @@ -790,9 +792,9 @@ dependencies = [ [[package]] 
name = "bstr" -version = "1.7.0" +version = "1.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c79ad7fb2dd38f3dabd76b09c6a5a20c038fc0213ef1e9afd30eb777f120f019" +checksum = "542f33a8835a0884b006a0c3df3dadd99c0c3f296ed26c2fdc8028e01ad6230c" dependencies = [ "memchr", "regex-automata", @@ -819,9 +821,9 @@ checksum = "a2bd12c1caf447e69cd4528f47f94d203fd2582878ecb9e9465484c4148a8223" [[package]] name = "bytes-utils" -version = "0.1.3" +version = "0.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e47d3a8076e283f3acd27400535992edb3ba4b5bb72f8891ad8fbe7932a7d4b9" +checksum = "7dafe3a8757b027e2be6e4e5601ed563c55989fcf1546e933c66c8eb3a058d35" dependencies = [ "bytes", "either", @@ -874,7 +876,7 @@ dependencies = [ "iana-time-zone", "num-traits", "serde", - "windows-targets", + "windows-targets 0.48.5", ] [[package]] @@ -988,9 +990,9 @@ checksum = "f7144d30dcf0fafbce74250a3963025d8d52177934239851c917d29f1df280c2" [[package]] name = "core-foundation" -version = "0.9.3" +version = "0.9.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "194a7a9e6de53fa55116934067c844d9d749312f75c6f6d0980e8c252f8c2146" +checksum = "91e195e091a93c46f7102ec7818a2aa394e1e1771c3ab4825963fa03e45afb8f" dependencies = [ "core-foundation-sys", "libc", @@ -998,9 +1000,9 @@ dependencies = [ [[package]] name = "core-foundation-sys" -version = "0.8.4" +version = "0.8.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e496a50fda8aacccc86d7529e2c1e0892dbd0f898a6b5645b5561b89c3210efa" +checksum = "06ea2b9bc92be3c2baa9334a323ebca2d6f074ff852cd1d7b11064035cd3868f" [[package]] name = "core2" @@ -1068,12 +1070,12 @@ dependencies = [ [[package]] name = "ctor" -version = "0.2.5" +version = "0.2.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "37e366bff8cd32dd8754b0991fb66b279dc48f598c3a18914852a6673deef583" +checksum = "30d2b3721e861707777e3195b0158f950ae6dc4a27e4d02ff9f67e3eb3de199e" dependencies = [ "quote", - "syn 2.0.38", + "syn 2.0.43", ] [[package]] @@ -1089,7 +1091,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "978747c1d849a7d2ee5e8adc0159961c48fb7e5db2f06af6723b80123bb53856" dependencies = [ "cfg-if", - "hashbrown 0.14.2", + "hashbrown 0.14.3", "lock_api", "once_cell", "parking_lot_core", @@ -1097,12 +1099,13 @@ dependencies = [ [[package]] name = "datafusion" -version = "33.0.0" +version = "34.0.0" dependencies = [ "ahash", "apache-avro", "arrow", "arrow-array", + "arrow-ipc", "arrow-schema", "async-compression", "async-trait", @@ -1121,9 +1124,9 @@ dependencies = [ "futures", "glob", "half", - "hashbrown 0.14.2", + "hashbrown 0.14.3", "indexmap 2.1.0", - "itertools", + "itertools 0.12.0", "log", "num-traits", "num_cpus", @@ -1144,7 +1147,7 @@ dependencies = [ [[package]] name = "datafusion-cli" -version = "33.0.0" +version = "34.0.0" dependencies = [ "arrow", "assert_cmd", @@ -1154,11 +1157,14 @@ dependencies = [ "clap", "ctor", "datafusion", + "datafusion-common", "dirs", "env_logger", + "futures", "mimalloc", "object_store", "parking_lot", + "parquet", "predicates", "regex", "rstest", @@ -1169,7 +1175,7 @@ dependencies = [ [[package]] name = "datafusion-common" -version = "33.0.0" +version = "34.0.0" dependencies = [ "ahash", "apache-avro", @@ -1179,6 +1185,7 @@ dependencies = [ "arrow-schema", "chrono", "half", + "libc", "num_cpus", "object_store", "parquet", @@ -1187,7 +1194,7 @@ dependencies = [ [[package]] name = 
"datafusion-execution" -version = "33.0.0" +version = "34.0.0" dependencies = [ "arrow", "chrono", @@ -1195,7 +1202,7 @@ dependencies = [ "datafusion-common", "datafusion-expr", "futures", - "hashbrown 0.14.2", + "hashbrown 0.14.3", "log", "object_store", "parking_lot", @@ -1206,12 +1213,13 @@ dependencies = [ [[package]] name = "datafusion-expr" -version = "33.0.0" +version = "34.0.0" dependencies = [ "ahash", "arrow", "arrow-array", "datafusion-common", + "paste", "sqlparser", "strum", "strum_macros", @@ -1219,7 +1227,7 @@ dependencies = [ [[package]] name = "datafusion-optimizer" -version = "33.0.0" +version = "34.0.0" dependencies = [ "arrow", "async-trait", @@ -1227,15 +1235,15 @@ dependencies = [ "datafusion-common", "datafusion-expr", "datafusion-physical-expr", - "hashbrown 0.14.2", - "itertools", + "hashbrown 0.14.3", + "itertools 0.12.0", "log", "regex-syntax", ] [[package]] name = "datafusion-physical-expr" -version = "33.0.0" +version = "34.0.0" dependencies = [ "ahash", "arrow", @@ -1250,11 +1258,10 @@ dependencies = [ "datafusion-common", "datafusion-expr", "half", - "hashbrown 0.14.2", + "hashbrown 0.14.3", "hex", "indexmap 2.1.0", - "itertools", - "libc", + "itertools 0.12.0", "log", "md-5", "paste", @@ -1268,7 +1275,7 @@ dependencies = [ [[package]] name = "datafusion-physical-plan" -version = "33.0.0" +version = "34.0.0" dependencies = [ "ahash", "arrow", @@ -1283,9 +1290,9 @@ dependencies = [ "datafusion-physical-expr", "futures", "half", - "hashbrown 0.14.2", + "hashbrown 0.14.3", "indexmap 2.1.0", - "itertools", + "itertools 0.12.0", "log", "once_cell", "parking_lot", @@ -1297,7 +1304,7 @@ dependencies = [ [[package]] name = "datafusion-sql" -version = "33.0.0" +version = "34.0.0" dependencies = [ "arrow", "arrow-schema", @@ -1309,9 +1316,9 @@ dependencies = [ [[package]] name = "deranged" -version = "0.3.9" +version = "0.3.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0f32d04922c60427da6f9fef14d042d9edddef64cb9d4ce0d64d0685fbeb1fd3" +checksum = "8eb30d70a07a3b04884d2677f06bec33509dc67ca60d92949e5535352d3191dc" dependencies = [ "powerfmt", ] @@ -1422,12 +1429,12 @@ checksum = "5443807d6dff69373d433ab9ef5378ad8df50ca6298caf15de6e52e24aaf54d5" [[package]] name = "errno" -version = "0.3.5" +version = "0.3.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ac3e13f66a2f95e32a39eaa81f6b95d42878ca0e1db0c7543723dfe12557e860" +checksum = "a258e46cdc063eb8519c00b9fc845fc47bcfca4130e2f08e88665ceda8474245" dependencies = [ "libc", - "windows-sys", + "windows-sys 0.52.0", ] [[package]] @@ -1463,7 +1470,7 @@ checksum = "ef033ed5e9bad94e55838ca0ca906db0e043f517adda0c8b79c7a8c66c93c1b5" dependencies = [ "cfg-if", "rustix", - "windows-sys", + "windows-sys 0.48.0", ] [[package]] @@ -1509,18 +1516,18 @@ checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" [[package]] name = "form_urlencoded" -version = "1.2.0" +version = "1.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a62bc1cf6f830c2ec14a513a9fb124d0a213a629668a4186f329db21fe045652" +checksum = "e13624c2627564efccf4934284bdd98cbaa14e79b0b5a141218e507b3a823456" dependencies = [ "percent-encoding", ] [[package]] name = "futures" -version = "0.3.29" +version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "da0290714b38af9b4a7b094b8a37086d1b4e61f2df9122c3cad2577669145335" +checksum = "645c6916888f6cb6350d2550b80fb63e734897a8498abe35cfb732b6487804b0" dependencies = [ 
"futures-channel", "futures-core", @@ -1533,9 +1540,9 @@ dependencies = [ [[package]] name = "futures-channel" -version = "0.3.29" +version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ff4dd66668b557604244583e3e1e1eada8c5c2e96a6d0d6653ede395b78bbacb" +checksum = "eac8f7d7865dcb88bd4373ab671c8cf4508703796caa2b1985a9ca867b3fcb78" dependencies = [ "futures-core", "futures-sink", @@ -1543,15 +1550,15 @@ dependencies = [ [[package]] name = "futures-core" -version = "0.3.29" +version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eb1d22c66e66d9d72e1758f0bd7d4fd0bee04cad842ee34587d68c07e45d088c" +checksum = "dfc6580bb841c5a68e9ef15c77ccc837b40a7504914d52e47b8b0e9bbda25a1d" [[package]] name = "futures-executor" -version = "0.3.29" +version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0f4fb8693db0cf099eadcca0efe2a5a22e4550f98ed16aba6c48700da29597bc" +checksum = "a576fc72ae164fca6b9db127eaa9a9dda0d61316034f33a0a0d4eda41f02b01d" dependencies = [ "futures-core", "futures-task", @@ -1560,32 +1567,32 @@ dependencies = [ [[package]] name = "futures-io" -version = "0.3.29" +version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8bf34a163b5c4c52d0478a4d757da8fb65cabef42ba90515efee0f6f9fa45aaa" +checksum = "a44623e20b9681a318efdd71c299b6b222ed6f231972bfe2f224ebad6311f0c1" [[package]] name = "futures-macro" -version = "0.3.29" +version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "53b153fd91e4b0147f4aced87be237c98248656bb01050b96bf3ee89220a8ddb" +checksum = "87750cf4b7a4c0625b1529e4c543c2182106e4dedc60a2a6455e00d212c489ac" dependencies = [ "proc-macro2", "quote", - "syn 2.0.38", + "syn 2.0.43", ] [[package]] name = "futures-sink" -version = "0.3.29" +version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e36d3378ee38c2a36ad710c5d30c2911d752cb941c00c72dbabfb786a7970817" +checksum = "9fb8e00e87438d937621c1c6269e53f536c14d3fbd6a042bb24879e57d474fb5" [[package]] name = "futures-task" -version = "0.3.29" +version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "efd193069b0ddadc69c46389b740bbccdd97203899b48d09c5f7969591d6bae2" +checksum = "38d84fa142264698cdce1a9f9172cf383a0c82de1bddcf3092901442c4097004" [[package]] name = "futures-timer" @@ -1595,9 +1602,9 @@ checksum = "e64b03909df88034c26dc1547e8970b91f98bdb65165d6a4e9110d94263dbb2c" [[package]] name = "futures-util" -version = "0.3.29" +version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a19526d624e703a3179b3d322efec918b6246ea0fa51d41124525f00f1cc8104" +checksum = "3d6401deb83407ab3da39eba7e33987a73c3df0c82b4bb5813ee871c19c41d48" dependencies = [ "futures-channel", "futures-core", @@ -1623,9 +1630,9 @@ dependencies = [ [[package]] name = "getrandom" -version = "0.2.10" +version = "0.2.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "be4136b2a15dd319360be1c07d9933517ccf0be8f16bf62a3bee4f0d618df427" +checksum = "fe9006bed769170c11f845cf00c7c1e9092aeb3f268e007c3e760ac68008070f" dependencies = [ "cfg-if", "libc", @@ -1634,9 +1641,9 @@ dependencies = [ [[package]] name = "gimli" -version = "0.28.0" +version = "0.28.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6fb8d784f27acf97159b40fc4db5ecd8aa23b9ad5ef69cdd136d3bc80665f0c0" +checksum = 
"4271d37baee1b8c7e4b708028c57d816cf9d2434acb33a549475f78c181f6253" [[package]] name = "glob" @@ -1646,9 +1653,9 @@ checksum = "d2fabcfbdc87f4758337ca535fb41a6d701b65693ce38287d856d1674551ec9b" [[package]] name = "h2" -version = "0.3.21" +version = "0.3.22" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "91fc23aa11be92976ef4729127f1a74adf36d8436f7816b185d18df956790833" +checksum = "4d6250322ef6e60f93f9a2162799302cd6f68f79f6e5d85c8c16f14d1d958178" dependencies = [ "bytes", "fnv", @@ -1656,7 +1663,7 @@ dependencies = [ "futures-sink", "futures-util", "http", - "indexmap 1.9.3", + "indexmap 2.1.0", "slab", "tokio", "tokio-util", @@ -1691,9 +1698,9 @@ dependencies = [ [[package]] name = "hashbrown" -version = "0.14.2" +version = "0.14.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f93e7192158dbcda357bdec5fb5788eebf8bbac027f3f33e719d29135ae84156" +checksum = "290f1a1d9242c78d09ce40a5e87e7554ee637af1351968159f4952f028f75604" dependencies = [ "ahash", "allocator-api2", @@ -1737,9 +1744,9 @@ dependencies = [ [[package]] name = "http" -version = "0.2.9" +version = "0.2.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bd6effc99afb63425aff9b05836f029929e345a6148a14b7ecd5ab67af944482" +checksum = "8947b1a6fad4393052c7ba1f4cd97bed3e953a95c79c92ad9b051a04611d9fbb" dependencies = [ "bytes", "fnv", @@ -1748,9 +1755,9 @@ dependencies = [ [[package]] name = "http-body" -version = "0.4.5" +version = "0.4.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d5f38f16d184e36f2408a55281cd658ecbd3ca05cce6d6510a176eca393e26d1" +checksum = "7ceab25649e9960c0311ea418d17bee82c0dcec1bd053b5f9a66e265a693bed2" dependencies = [ "bytes", "http", @@ -1777,9 +1784,9 @@ checksum = "9a3a5bfb195931eeb336b2a7b4d761daec841b97f947d34394601737a7bba5e4" [[package]] name = "hyper" -version = "0.14.27" +version = "0.14.28" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ffb1cfd654a8219eaef89881fdb3bb3b1cdc5fa75ded05d6933b2b382e395468" +checksum = "bf96e135eb83a2a8ddf766e426a841d8ddd7449d5f00d34ea02b41d2f19eef80" dependencies = [ "bytes", "futures-channel", @@ -1792,7 +1799,7 @@ dependencies = [ "httpdate", "itoa", "pin-project-lite", - "socket2 0.4.10", + "socket2", "tokio", "tower-service", "tracing", @@ -1823,7 +1830,7 @@ dependencies = [ "futures-util", "http", "hyper", - "rustls 0.21.8", + "rustls 0.21.10", "tokio", "tokio-rustls 0.24.1", ] @@ -1853,9 +1860,9 @@ dependencies = [ [[package]] name = "idna" -version = "0.4.0" +version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7d20d6b07bfbc108882d88ed8e37d39636dcc260e15e30c45e6ba089610b917c" +checksum = "634d9b1461af396cad843f47fdba5597a4f9e6ddd4bfb6ff5d85028c25cb12f6" dependencies = [ "unicode-bidi", "unicode-normalization", @@ -1878,7 +1885,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d530e1a18b1cb4c484e6e34556a0d948706958449fca0cab753d649f2bce3d1f" dependencies = [ "equivalent", - "hashbrown 0.14.2", + "hashbrown 0.14.3", ] [[package]] @@ -1911,11 +1918,20 @@ dependencies = [ "either", ] +[[package]] +name = "itertools" +version = "0.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "25db6b064527c5d482d0423354fcd07a89a2dfe07b67892e62411946db7f07b0" +dependencies = [ + "either", +] + [[package]] name = "itoa" -version = "1.0.9" +version = "1.0.10" source = "registry+https://github.com/rust-lang/crates.io-index" 
-checksum = "af150ab688ff2122fcef229be89cb50dd66af9e01a4ff320cc137eecc9bacc38" +checksum = "b1a46d1a171d865aa5f83f92695765caa047a9b4cbae2cbf37dbd613a793fd4c" [[package]] name = "jobserver" @@ -1928,9 +1944,9 @@ dependencies = [ [[package]] name = "js-sys" -version = "0.3.65" +version = "0.3.66" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "54c0c35952f67de54bb584e9fd912b3023117cbafc0a77d8f3dee1fb5f572fe8" +checksum = "cee9c64da59eae3b50095c18d3e74f8b73c0b86d2792824ff01bbce68ba229ca" dependencies = [ "wasm-bindgen", ] @@ -2007,9 +2023,9 @@ dependencies = [ [[package]] name = "libc" -version = "0.2.150" +version = "0.2.151" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "89d92a4743f9a61002fae18374ed11e7973f530cb3a3255fb354818118b2203c" +checksum = "302d7ab3130588088d277783b1e2d2e10c9e9e4a16dd9050e6ec93fb3e7048f4" [[package]] name = "libflate" @@ -2064,9 +2080,9 @@ dependencies = [ [[package]] name = "linux-raw-sys" -version = "0.4.10" +version = "0.4.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "da2479e8c062e40bf0066ffa0bc823de0a9368974af99c9f6df941d2c231e03f" +checksum = "c4cd1a83af159aa67994778be9070f0ae1bd732942279cabb14f86f986a21456" [[package]] name = "lock_api" @@ -2146,13 +2162,13 @@ dependencies = [ [[package]] name = "mio" -version = "0.8.9" +version = "0.8.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3dce281c5e46beae905d4de1870d8b1509a9142b62eedf18b443b011ca8343d0" +checksum = "8f3d0b296e374a4e6f3c7b0a1f5a51d748a0d34c85e7dc48fc3fa9a87657fe09" dependencies = [ "libc", "wasi", - "windows-sys", + "windows-sys 0.48.0", ] [[package]] @@ -2270,18 +2286,18 @@ dependencies = [ [[package]] name = "object" -version = "0.32.1" +version = "0.32.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9cf5f9dd3933bd50a9e1f149ec995f39ae2c496d31fd772c1fd45ebc27e902b0" +checksum = "a6a622008b6e321afc04970976f62ee297fdbaa6f95318ca343e3eebb9648441" dependencies = [ "memchr", ] [[package]] name = "object_store" -version = "0.7.1" +version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f930c88a43b1c3f6e776dfe495b4afab89882dbc81530c632db2ed65451ebcb4" +checksum = "2524735495ea1268be33d200e1ee97455096a0846295a21548cd2f3541de7050" dependencies = [ "async-trait", "base64", @@ -2290,13 +2306,13 @@ dependencies = [ "futures", "humantime", "hyper", - "itertools", + "itertools 0.11.0", "parking_lot", "percent-encoding", "quick-xml", "rand", "reqwest", - "ring 0.16.20", + "ring 0.17.7", "rustls-pemfile", "serde", "serde_json", @@ -2309,9 +2325,9 @@ dependencies = [ [[package]] name = "once_cell" -version = "1.18.0" +version = "1.19.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dd8b5dd2ae5ed71462c540258bedcb51965123ad7e7ccf4b9a8cafaa4a63576d" +checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92" [[package]] name = "openssl-probe" @@ -2360,14 +2376,14 @@ dependencies = [ "libc", "redox_syscall", "smallvec", - "windows-targets", + "windows-targets 0.48.5", ] [[package]] name = "parquet" -version = "48.0.0" +version = "49.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "239229e6a668ab50c61de3dce61cf0fa1069345f7aa0f4c934491f92205a4945" +checksum = "af88740a842787da39b3d69ce5fbf6fce97d20211d3b299fee0a0da6430c74d4" dependencies = [ "ahash", "arrow-array", @@ -2383,7 +2399,7 @@ dependencies = [ "chrono", "flate2", "futures", - 
"hashbrown 0.14.2", + "hashbrown 0.14.3", "lz4_flex", "num", "num-bigint", @@ -2414,9 +2430,9 @@ checksum = "de3145af08024dea9fa9914f381a17b8fc6034dfb00f3a84013f7ff43f29ed4c" [[package]] name = "percent-encoding" -version = "2.3.0" +version = "2.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9b2a4787296e9989611394c33f193f676704af1686e70b8f8033ab5ba9a35a94" +checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e" [[package]] name = "petgraph" @@ -2483,7 +2499,7 @@ checksum = "4359fd9c9171ec6e8c62926d6faaf553a8dc3f64e1507e76da7911b4f6a04405" dependencies = [ "proc-macro2", "quote", - "syn 2.0.38", + "syn 2.0.43", ] [[package]] @@ -2500,9 +2516,9 @@ checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" [[package]] name = "pkg-config" -version = "0.3.27" +version = "0.3.28" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "26072860ba924cbfa98ea39c8c19b4dd6a4a25423dbdf219c1eca91aa0cf6964" +checksum = "69d3587f8a9e599cc7ec2c00e331f71c4e69a5f9a4b8a6efd5b07466b9736f9a" [[package]] name = "powerfmt" @@ -2525,7 +2541,7 @@ dependencies = [ "anstyle", "difflib", "float-cmp", - "itertools", + "itertools 0.11.0", "normalize-line-endings", "predicates-core", "regex", @@ -2573,9 +2589,9 @@ dependencies = [ [[package]] name = "proc-macro2" -version = "1.0.69" +version = "1.0.71" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "134c189feb4956b20f6f547d2cf727d4c0fe06722b20a0eec87ed445a97f92da" +checksum = "75cb1540fadbd5b8fbccc4dddad2734eba435053f725621c070711a14bb5f4b8" dependencies = [ "unicode-ident", ] @@ -2588,9 +2604,9 @@ checksum = "658fa1faf7a4cc5f057c9ee5ef560f717ad9d8dc66d975267f709624d6e1ab88" [[package]] name = "quick-xml" -version = "0.30.0" +version = "0.31.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eff6510e86862b57b210fd8cbe8ed3f0d7d600b9c2863cd4549a2e033c66e956" +checksum = "1004a344b30a54e2ee58d66a71b32d2db2feb0a31f9a2d302bf0536f15de2a33" dependencies = [ "memchr", "serde", @@ -2702,9 +2718,9 @@ checksum = "c08c74e62047bb2de4ff487b251e4a92e24f48745648451635cec7d591162d9f" [[package]] name = "reqwest" -version = "0.11.22" +version = "0.11.23" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "046cd98826c46c2ac8ddecae268eb5c2e58628688a5fc7a2643704a73faba95b" +checksum = "37b1ae8d9ac08420c66222fb9096fc5de435c3c48542bc5336c51892cffafb41" dependencies = [ "base64", "bytes", @@ -2723,7 +2739,7 @@ dependencies = [ "once_cell", "percent-encoding", "pin-project-lite", - "rustls 0.21.8", + "rustls 0.21.10", "rustls-pemfile", "serde", "serde_json", @@ -2759,16 +2775,16 @@ dependencies = [ [[package]] name = "ring" -version = "0.17.5" +version = "0.17.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fb0205304757e5d899b9c2e448b867ffd03ae7f988002e47cd24954391394d0b" +checksum = "688c63d65483050968b2a8937f7995f443e27041a0f7700aa59b0822aedebb74" dependencies = [ "cc", "getrandom", "libc", "spin 0.9.8", "untrusted 0.9.0", - "windows-sys", + "windows-sys 0.48.0", ] [[package]] @@ -2820,15 +2836,15 @@ dependencies = [ [[package]] name = "rustix" -version = "0.38.21" +version = "0.38.28" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2b426b0506e5d50a7d8dafcf2e81471400deb602392c7dd110815afb4eaf02a3" +checksum = "72e572a5e8ca657d7366229cdde4bd14c4eb5499a9573d4d366fe1b599daa316" dependencies = [ "bitflags 2.4.1", "errno", "libc", "linux-raw-sys", - 
"windows-sys", + "windows-sys 0.52.0", ] [[package]] @@ -2845,12 +2861,12 @@ dependencies = [ [[package]] name = "rustls" -version = "0.21.8" +version = "0.21.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "446e14c5cda4f3f30fe71863c34ec70f5ac79d6087097ad0bb433e1be5edf04c" +checksum = "f9d5a6813c0759e4609cd494e8e725babae6a2ca7b62a5536a13daaec6fcb7ba" dependencies = [ "log", - "ring 0.17.5", + "ring 0.17.7", "rustls-webpki", "sct", ] @@ -2869,9 +2885,9 @@ dependencies = [ [[package]] name = "rustls-pemfile" -version = "1.0.3" +version = "1.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2d3987094b1d07b653b7dfdc3f70ce9a1da9c51ac18c1b06b662e4f9a0e9f4b2" +checksum = "1c74cae0a4cf6ccbbf5f359f08efdf8ee7e1dc532573bf0db71968cb56b1448c" dependencies = [ "base64", ] @@ -2882,7 +2898,7 @@ version = "0.101.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8b6275d1ee7a1cd780b64aca7726599a1dbc893b1e64144529e55c3c2f745765" dependencies = [ - "ring 0.17.5", + "ring 0.17.7", "untrusted 0.9.0", ] @@ -2917,9 +2933,9 @@ dependencies = [ [[package]] name = "ryu" -version = "1.0.15" +version = "1.0.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1ad4cc8da4ef723ed60bced201181d83791ad433213d8c24efffda1eec85d741" +checksum = "f98d2aa92eebf49b69786be48e4477826b256916e84a57ff2a4f21923b48eb4c" [[package]] name = "same-file" @@ -2936,7 +2952,7 @@ version = "0.1.22" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0c3733bf4cf7ea0880754e19cb5a462007c4a8c1914bff372ccc95b464f1df88" dependencies = [ - "windows-sys", + "windows-sys 0.48.0", ] [[package]] @@ -2951,7 +2967,7 @@ version = "0.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "da046153aa2352493d6cb7da4b6e5c0c057d8a1d0a9aa8560baffdd945acd414" dependencies = [ - "ring 0.17.5", + "ring 0.17.7", "untrusted 0.9.0", ] @@ -2992,22 +3008,22 @@ checksum = "a3f0bf26fd526d2a95683cd0f87bf103b8539e2ca1ef48ce002d67aad59aa0b4" [[package]] name = "serde" -version = "1.0.190" +version = "1.0.193" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "91d3c334ca1ee894a2c6f6ad698fe8c435b76d504b13d436f0685d648d6d96f7" +checksum = "25dd9975e68d0cb5aa1120c288333fc98731bd1dd12f561e468ea4728c042b89" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.190" +version = "1.0.193" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "67c5609f394e5c2bd7fc51efda478004ea80ef42fee983d5c67a65e34f32c0e3" +checksum = "43576ca501357b9b071ac53cdc7da8ef0cbd9493d8df094cd821777ea6e894d3" dependencies = [ "proc-macro2", "quote", - "syn 2.0.38", + "syn 2.0.43", ] [[package]] @@ -3044,6 +3060,15 @@ dependencies = [ "digest", ] +[[package]] +name = "signal-hook-registry" +version = "1.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d8229b473baa5980ac72ef434c4415e70c4b5e71b423043adb4ba059f89c99a1" +dependencies = [ + "libc", +] + [[package]] name = "siphasher" version = "0.3.11" @@ -3061,9 +3086,9 @@ dependencies = [ [[package]] name = "smallvec" -version = "1.11.1" +version = "1.11.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "942b4a808e05215192e39f4ab80813e599068285906cc91aa64f923db842bd5a" +checksum = "4dccd0940a2dcdf68d092b8cbab7dc0ad8fa938bf95787e1b916b0e3d0e8e970" [[package]] name = "snafu" @@ -3089,19 +3114,9 @@ dependencies = [ [[package]] name = "snap" -version = 
"1.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5e9f0ab6ef7eb7353d9119c170a436d1bf248eea575ac42d19d12f4e34130831" - -[[package]] -name = "socket2" -version = "0.4.10" +version = "1.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9f7916fc008ca5542385b89a3d3ce689953c143e9304a9bf8beec1de48994c0d" -dependencies = [ - "libc", - "winapi", -] +checksum = "1b6b67fb9a61334225b5b790716f609cd58395f895b3fe8b328786812a40bc3b" [[package]] name = "socket2" @@ -3110,7 +3125,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7b5fac59a5cb5dd637972e5fca70daf0523c9067fcdc4842f053dae04a18f8e9" dependencies = [ "libc", - "windows-sys", + "windows-sys 0.48.0", ] [[package]] @@ -3127,9 +3142,9 @@ checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67" [[package]] name = "sqlparser" -version = "0.39.0" +version = "0.41.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "743b4dc2cbde11890ccb254a8fc9d537fa41b36da00de2a1c5e9848c9bc42bd7" +checksum = "5cc2c25a6c66789625ef164b4c7d2e548d627902280c13710d33da8222169964" dependencies = [ "log", "sqlparser_derive", @@ -3137,9 +3152,9 @@ dependencies = [ [[package]] name = "sqlparser_derive" -version = "0.1.1" +version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "55fe75cb4a364c7f7ae06c7dbbc8d84bddd85d6cdf9975963c3935bc1991761e" +checksum = "3e9c2e1dde0efa87003e7923d94a90f46e3274ad1649f51de96812be561f041f" dependencies = [ "proc-macro2", "quote", @@ -3183,7 +3198,7 @@ dependencies = [ "proc-macro2", "quote", "rustversion", - "syn 2.0.38", + "syn 2.0.43", ] [[package]] @@ -3205,9 +3220,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.38" +version = "2.0.43" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e96b79aaa137db8f61e26363a0c9b47d8b4ec75da28b7d1d614c2303e232408b" +checksum = "ee659fb5f3d355364e1f3e5bc10fb82068efbf824a1e9d1c9504244a6469ad53" dependencies = [ "proc-macro2", "quote", @@ -3245,14 +3260,14 @@ dependencies = [ "fastrand 2.0.1", "redox_syscall", "rustix", - "windows-sys", + "windows-sys 0.48.0", ] [[package]] name = "termcolor" -version = "1.3.0" +version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6093bad37da69aab9d123a8091e4be0aa4a03e4d601ec641c327398315f62b64" +checksum = "ff1bc3d3f05aff0403e8ac0d92ced918ec05b666a43f83297ccef5bea8a3d449" dependencies = [ "winapi-util", ] @@ -3271,22 +3286,22 @@ checksum = "222a222a5bfe1bba4a77b45ec488a741b3cb8872e5e499451fd7d0129c9c7c3d" [[package]] name = "thiserror" -version = "1.0.50" +version = "1.0.52" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f9a7210f5c9a7156bb50aa36aed4c95afb51df0df00713949448cf9e97d382d2" +checksum = "83a48fd946b02c0a526b2e9481c8e2a17755e47039164a86c4070446e3a4614d" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "1.0.50" +version = "1.0.52" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "266b2e40bc00e5a6c09c3584011e08b06f123c00362c92b975ba9843aaaa14b8" +checksum = "e7fbe9b594d6568a6a1443250a7e67d80b74e1e96f6d1715e1e21cc1888291d3" dependencies = [ "proc-macro2", "quote", - "syn 2.0.38", + "syn 2.0.43", ] [[package]] @@ -3302,9 +3317,9 @@ dependencies = [ [[package]] name = "time" -version = "0.3.30" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"c4a34ab300f2dee6e562c10a046fc05e358b29f9bf92277f30c3c8d82275f6f5" +checksum = "f657ba42c3f86e7680e53c8cd3af8abbe56b5491790b46e22e19c0d57463583e" dependencies = [ "deranged", "powerfmt", @@ -3321,9 +3336,9 @@ checksum = "ef927ca75afb808a4d64dd374f00a2adf8d0fcff8e7b184af886c3c87ec4a3f3" [[package]] name = "time-macros" -version = "0.2.15" +version = "0.2.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4ad70d68dba9e1f8aceda7aa6711965dfec1cac869f311a51bd08b3a2ccbce20" +checksum = "26197e33420244aeb70c3e8c78376ca46571bc4e701e4791c2cd9f57dcb3a43f" dependencies = [ "time-core", ] @@ -3354,9 +3369,9 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" [[package]] name = "tokio" -version = "1.33.0" +version = "1.35.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4f38200e3ef7995e5ef13baec2f432a6da0aa9ac495b2c0e8f3b7eec2c92d653" +checksum = "c89b4efa943be685f629b149f53829423f8f5531ea21249408e8e2f8671ec104" dependencies = [ "backtrace", "bytes", @@ -3365,20 +3380,21 @@ dependencies = [ "num_cpus", "parking_lot", "pin-project-lite", - "socket2 0.5.5", + "signal-hook-registry", + "socket2", "tokio-macros", - "windows-sys", + "windows-sys 0.48.0", ] [[package]] name = "tokio-macros" -version = "2.1.0" +version = "2.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "630bdcf245f78637c13ec01ffae6187cca34625e8c63150d424b59e55af2675e" +checksum = "5b8a1e28f2deaa14e508979454cb3a223b10b938b45af148bc0986de36f1923b" dependencies = [ "proc-macro2", "quote", - "syn 2.0.38", + "syn 2.0.43", ] [[package]] @@ -3398,7 +3414,7 @@ version = "0.24.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c28327cf380ac148141087fbfb9de9d7bd4e84ab5d2c28fbc911d753de8a7081" dependencies = [ - "rustls 0.21.8", + "rustls 0.21.10", "tokio", ] @@ -3475,7 +3491,7 @@ checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.38", + "syn 2.0.43", ] [[package]] @@ -3489,9 +3505,9 @@ dependencies = [ [[package]] name = "try-lock" -version = "0.2.4" +version = "0.2.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3528ecfd12c466c6f163363caf2d02a71161dd5e1cc6ae7b34207ea2d42d81ed" +checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b" [[package]] name = "twox-hash" @@ -3520,7 +3536,7 @@ checksum = "f03ca4cb38206e2bef0700092660bb74d696f808514dae47fa1467cbfe26e96e" dependencies = [ "proc-macro2", "quote", - "syn 2.0.38", + "syn 2.0.43", ] [[package]] @@ -3531,9 +3547,9 @@ checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825" [[package]] name = "unicode-bidi" -version = "0.3.13" +version = "0.3.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "92888ba5573ff080736b3648696b70cafad7d250551175acbaa4e0385b3e1460" +checksum = "6f2528f27a9eb2b21e69c95319b30bd0efd85d09c379741b0f78ea1d86be2416" [[package]] name = "unicode-ident" @@ -3576,9 +3592,9 @@ checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1" [[package]] name = "url" -version = "2.4.1" +version = "2.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "143b538f18257fac9cad154828a57c6bf5157e1aa604d4816b5995bf6de87ae5" +checksum = "31e6302e3bb753d46e83516cae55ae196fc0c309407cf11ab35cc51a4c2a4633" dependencies = [ "form_urlencoded", "idna", @@ -3599,9 +3615,9 @@ checksum = 
"711b9620af191e0cdc7468a8d14e709c3dcdb115b36f838e601583af800a370a" [[package]] name = "uuid" -version = "1.5.0" +version = "1.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "88ad59a7560b41a70d191093a945f0b87bc1deeda46fb237479708a1d6b6cdfc" +checksum = "5e395fcf16a7a3d8127ec99782007af141946b4795001f876d54fb0d55978560" dependencies = [ "getrandom", "serde", @@ -3655,9 +3671,9 @@ checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" [[package]] name = "wasm-bindgen" -version = "0.2.88" +version = "0.2.89" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7daec296f25a1bae309c0cd5c29c4b260e510e6d813c286b19eaadf409d40fce" +checksum = "0ed0d4f68a3015cc185aff4db9506a015f4b96f95303897bfa23f846db54064e" dependencies = [ "cfg-if", "wasm-bindgen-macro", @@ -3665,24 +3681,24 @@ dependencies = [ [[package]] name = "wasm-bindgen-backend" -version = "0.2.88" +version = "0.2.89" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e397f4664c0e4e428e8313a469aaa58310d302159845980fd23b0f22a847f217" +checksum = "1b56f625e64f3a1084ded111c4d5f477df9f8c92df113852fa5a374dbda78826" dependencies = [ "bumpalo", "log", "once_cell", "proc-macro2", "quote", - "syn 2.0.38", + "syn 2.0.43", "wasm-bindgen-shared", ] [[package]] name = "wasm-bindgen-futures" -version = "0.4.38" +version = "0.4.39" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9afec9963e3d0994cac82455b2b3502b81a7f40f9a0d32181f7528d9f4b43e02" +checksum = "ac36a15a220124ac510204aec1c3e5db8a22ab06fd6706d881dc6149f8ed9a12" dependencies = [ "cfg-if", "js-sys", @@ -3692,9 +3708,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro" -version = "0.2.88" +version = "0.2.89" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5961017b3b08ad5f3fe39f1e79877f8ee7c23c5e5fd5eb80de95abc41f1f16b2" +checksum = "0162dbf37223cd2afce98f3d0785506dcb8d266223983e4b5b525859e6e182b2" dependencies = [ "quote", "wasm-bindgen-macro-support", @@ -3702,22 +3718,22 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro-support" -version = "0.2.88" +version = "0.2.89" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c5353b8dab669f5e10f5bd76df26a9360c748f054f862ff5f3f8aae0c7fb3907" +checksum = "f0eb82fcb7930ae6219a7ecfd55b217f5f0893484b7a13022ebb2b2bf20b5283" dependencies = [ "proc-macro2", "quote", - "syn 2.0.38", + "syn 2.0.43", "wasm-bindgen-backend", "wasm-bindgen-shared", ] [[package]] name = "wasm-bindgen-shared" -version = "0.2.88" +version = "0.2.89" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0d046c5d029ba91a1ed14da14dca44b68bf2f124cfbaf741c54151fdb3e0750b" +checksum = "7ab9b36309365056cd639da3134bf87fa8f3d86008abf99e612384a6eecd459f" [[package]] name = "wasm-streams" @@ -3734,9 +3750,9 @@ dependencies = [ [[package]] name = "web-sys" -version = "0.3.65" +version = "0.3.66" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5db499c5f66323272151db0e666cd34f78617522fb0c1604d31a27c50c206a85" +checksum = "50c24a44ec86bb68fbecd1b3efed7e85ea5621b39b35ef2766b66cd984f8010f" dependencies = [ "js-sys", "wasm-bindgen", @@ -3748,15 +3764,15 @@ version = "0.22.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ed63aea5ce73d0ff405984102c42de94fc55a6b75765d621c65262469b3c9b53" dependencies = [ - "ring 0.17.5", + "ring 0.17.7", "untrusted 0.9.0", ] [[package]] name = "webpki-roots" -version = "0.25.2" 
+version = "0.25.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "14247bb57be4f377dfb94c72830b8ce8fc6beac03cf4bf7b9732eadd414123fc" +checksum = "1778a42e8b3b90bff8d0f5032bf22250792889a5cdc752aa0020c84abe3aaf10" [[package]] name = "winapi" @@ -3795,7 +3811,7 @@ version = "0.51.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f1f8cf84f35d2db49a46868f947758c7a1138116f7fac3bc844f43ade1292e64" dependencies = [ - "windows-targets", + "windows-targets 0.48.5", ] [[package]] @@ -3804,7 +3820,16 @@ version = "0.48.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "677d2418bec65e3338edb076e806bc1ec15693c5d0104683f2efe857f61056a9" dependencies = [ - "windows-targets", + "windows-targets 0.48.5", +] + +[[package]] +name = "windows-sys" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d" +dependencies = [ + "windows-targets 0.52.0", ] [[package]] @@ -3813,13 +3838,28 @@ version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9a2fa6e2155d7247be68c096456083145c183cbbbc2764150dda45a87197940c" dependencies = [ - "windows_aarch64_gnullvm", - "windows_aarch64_msvc", - "windows_i686_gnu", - "windows_i686_msvc", - "windows_x86_64_gnu", - "windows_x86_64_gnullvm", - "windows_x86_64_msvc", + "windows_aarch64_gnullvm 0.48.5", + "windows_aarch64_msvc 0.48.5", + "windows_i686_gnu 0.48.5", + "windows_i686_msvc 0.48.5", + "windows_x86_64_gnu 0.48.5", + "windows_x86_64_gnullvm 0.48.5", + "windows_x86_64_msvc 0.48.5", +] + +[[package]] +name = "windows-targets" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8a18201040b24831fbb9e4eb208f8892e1f50a37feb53cc7ff887feb8f50e7cd" +dependencies = [ + "windows_aarch64_gnullvm 0.52.0", + "windows_aarch64_msvc 0.52.0", + "windows_i686_gnu 0.52.0", + "windows_i686_msvc 0.52.0", + "windows_x86_64_gnu 0.52.0", + "windows_x86_64_gnullvm 0.52.0", + "windows_x86_64_msvc 0.52.0", ] [[package]] @@ -3828,42 +3868,84 @@ version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2b38e32f0abccf9987a4e3079dfb67dcd799fb61361e53e2882c3cbaf0d905d8" +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cb7764e35d4db8a7921e09562a0304bf2f93e0a51bfccee0bd0bb0b666b015ea" + [[package]] name = "windows_aarch64_msvc" version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dc35310971f3b2dbbf3f0690a219f40e2d9afcf64f9ab7cc1be722937c26b4bc" +[[package]] +name = "windows_aarch64_msvc" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bbaa0368d4f1d2aaefc55b6fcfee13f41544ddf36801e793edbbfd7d7df075ef" + [[package]] name = "windows_i686_gnu" version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a75915e7def60c94dcef72200b9a8e58e5091744960da64ec734a6c6e9b3743e" +[[package]] +name = "windows_i686_gnu" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a28637cb1fa3560a16915793afb20081aba2c92ee8af57b4d5f28e4b3e7df313" + [[package]] name = "windows_i686_msvc" version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8f55c233f70c4b27f66c523580f78f1004e8b5a8b659e05a4eb49d4166cca406" 
+[[package]] +name = "windows_i686_msvc" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ffe5e8e31046ce6230cc7215707b816e339ff4d4d67c65dffa206fd0f7aa7b9a" + [[package]] name = "windows_x86_64_gnu" version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "53d40abd2583d23e4718fddf1ebec84dbff8381c07cae67ff7768bbf19c6718e" +[[package]] +name = "windows_x86_64_gnu" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3d6fa32db2bc4a2f5abeacf2b69f7992cd09dca97498da74a151a3132c26befd" + [[package]] name = "windows_x86_64_gnullvm" version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0b7b52767868a23d5bab768e390dc5f5c55825b6d30b86c844ff2dc7414044cc" +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1a657e1e9d3f514745a572a6846d3c7aa7dbe1658c056ed9c3344c4109a6949e" + [[package]] name = "windows_x86_64_msvc" version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ed94fce61571a4006852b7389a063ab983c02eb1bb37b47f8272ce92d06d9538" +[[package]] +name = "windows_x86_64_msvc" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dff9641d1cd4be8d1a070daf9e3773c5f67e78b4d9d42263020c057706765c04" + [[package]] name = "winreg" version = "0.50.0" @@ -3871,7 +3953,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "524e57b2c537c0f9b1e69f1965311ec12182b4122e45035b1508cd24d2adadb1" dependencies = [ "cfg-if", - "windows-sys", + "windows-sys 0.48.0", ] [[package]] @@ -3891,29 +3973,29 @@ dependencies = [ [[package]] name = "zerocopy" -version = "0.7.25" +version = "0.7.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8cd369a67c0edfef15010f980c3cbe45d7f651deac2cd67ce097cd801de16557" +checksum = "74d4d3961e53fa4c9a25a8637fc2bfaf2595b3d3ae34875568a5cf64787716be" dependencies = [ "zerocopy-derive", ] [[package]] name = "zerocopy-derive" -version = "0.7.25" +version = "0.7.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c2f140bda219a26ccc0cdb03dba58af72590c53b22642577d88a927bc5c87d6b" +checksum = "9ce1b18ccd8e73a9321186f97e46f9f04b778851177567b1975109d26a08d2a6" dependencies = [ "proc-macro2", "quote", - "syn 2.0.38", + "syn 2.0.43", ] [[package]] name = "zeroize" -version = "1.6.0" +version = "1.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2a0956f1ba7c7909bfb66c2e9e4124ab6f6482560f6628b5aaeba39207c9aad9" +checksum = "525b4ec142c6b68a2d10f01f7bbf6755599ca3f81ea53b8431b7dd348f5fdb2d" [[package]] name = "zstd" diff --git a/datafusion-cli/Cargo.toml b/datafusion-cli/Cargo.toml index 73c4431f43529..eab7c8e0d1f8b 100644 --- a/datafusion-cli/Cargo.toml +++ b/datafusion-cli/Cargo.toml @@ -18,7 +18,7 @@ [package] name = "datafusion-cli" description = "Command Line Client for DataFusion query engine." 
-version = "33.0.0" +version = "34.0.0" authors = ["Apache Arrow "] edition = "2021" keywords = ["arrow", "datafusion", "query", "sql"] @@ -29,20 +29,23 @@ rust-version = "1.70" readme = "README.md" [dependencies] -arrow = "48.0.0" +arrow = "49.0.0" async-trait = "0.1.41" aws-config = "0.55" aws-credential-types = "0.55" clap = { version = "3", features = ["derive", "cargo"] } -datafusion = { path = "../datafusion/core", version = "33.0.0", features = ["avro", "crypto_expressions", "encoding_expressions", "parquet", "regex_expressions", "unicode_expressions", "compression"] } +datafusion = { path = "../datafusion/core", version = "34.0.0", features = ["avro", "crypto_expressions", "encoding_expressions", "parquet", "regex_expressions", "unicode_expressions", "compression"] } +datafusion-common = { path = "../datafusion/common" } dirs = "4.0.0" env_logger = "0.9" +futures = "0.3" mimalloc = { version = "0.1", default-features = false } -object_store = { version = "0.7.0", features = ["aws", "gcp"] } +object_store = { version = "0.8.0", features = ["aws", "gcp"] } parking_lot = { version = "0.12" } +parquet = { version = "49.0.0", default-features = false } regex = "1.8" rustyline = "11.0" -tokio = { version = "1.24", features = ["macros", "rt", "rt-multi-thread", "sync", "parking_lot"] } +tokio = { version = "1.24", features = ["macros", "rt", "rt-multi-thread", "sync", "parking_lot", "signal"] } url = "2.2" [dev-dependencies] diff --git a/datafusion-cli/src/exec.rs b/datafusion-cli/src/exec.rs index b62ad12dbfbbd..2320a8c314cfe 100644 --- a/datafusion-cli/src/exec.rs +++ b/datafusion-cli/src/exec.rs @@ -17,6 +17,12 @@ //! Execution functions +use std::io::prelude::*; +use std::io::BufReader; +use std::time::Instant; +use std::{fs::File, sync::Arc}; + +use crate::print_format::PrintFormat; use crate::{ command::{Command, OutputFormat}, helper::{unescape_input, CliHelper}, @@ -26,21 +32,20 @@ use crate::{ }, print_options::{MaxRows, PrintOptions}, }; -use datafusion::common::plan_datafusion_err; + +use datafusion::common::{exec_datafusion_err, plan_datafusion_err}; +use datafusion::datasource::listing::ListingTableUrl; +use datafusion::datasource::physical_plan::is_plan_streaming; +use datafusion::error::{DataFusionError, Result}; +use datafusion::logical_expr::{CreateExternalTable, DdlStatement, LogicalPlan}; +use datafusion::physical_plan::{collect, execute_stream}; +use datafusion::prelude::SessionContext; use datafusion::sql::{parser::DFParser, sqlparser::dialect::dialect_from_str}; -use datafusion::{ - datasource::listing::ListingTableUrl, - error::{DataFusionError, Result}, - logical_expr::{CreateExternalTable, DdlStatement}, -}; -use datafusion::{logical_expr::LogicalPlan, prelude::SessionContext}; + use object_store::ObjectStore; use rustyline::error::ReadlineError; use rustyline::Editor; -use std::io::prelude::*; -use std::io::BufReader; -use std::time::Instant; -use std::{fs::File, sync::Arc}; +use tokio::signal; use url::Url; /// run and execute SQL statements and commands, against a context with the given print options @@ -125,8 +130,6 @@ pub async fn exec_from_repl( ))); rl.load_history(".history").ok(); - let mut print_options = print_options.clone(); - loop { match rl.readline("❯ ") { Ok(line) if line.starts_with('\\') => { @@ -138,9 +141,7 @@ pub async fn exec_from_repl( Command::OutputFormat(subcommand) => { if let Some(subcommand) = subcommand { if let Ok(command) = subcommand.parse::() { - if let Err(e) = - command.execute(&mut print_options).await - { + if let Err(e) = 
command.execute(print_options).await { eprintln!("{e}") } } else { @@ -154,7 +155,7 @@ pub async fn exec_from_repl( } } _ => { - if let Err(e) = cmd.execute(ctx, &mut print_options).await { + if let Err(e) = cmd.execute(ctx, print_options).await { eprintln!("{e}") } } @@ -165,9 +166,15 @@ pub async fn exec_from_repl( } Ok(line) => { rl.add_history_entry(line.trim_end())?; - match exec_and_print(ctx, &print_options, line).await { - Ok(_) => {} - Err(err) => eprintln!("{err}"), + tokio::select! { + res = exec_and_print(ctx, print_options, line) => match res { + Ok(_) => {} + Err(err) => eprintln!("{err}"), + }, + _ = signal::ctrl_c() => { + println!("^C"); + continue + }, } // dialect might have changed rl.helper_mut().unwrap().set_dialect( @@ -198,7 +205,6 @@ async fn exec_and_print( sql: String, ) -> Result<()> { let now = Instant::now(); - let sql = unescape_input(&sql)?; let task_ctx = ctx.task_ctx(); let dialect = &task_ctx.session_config().options().sql_parser.dialect; @@ -211,7 +217,7 @@ async fn exec_and_print( })?; let statements = DFParser::parse_sql_with_dialect(&sql, dialect.as_ref())?; for statement in statements { - let plan = ctx.state().statement_to_plan(statement).await?; + let mut plan = ctx.state().statement_to_plan(statement).await?; // For plans like `Explain` ignore `MaxRows` option and always display all rows let should_ignore_maxrows = matches!( @@ -221,25 +227,30 @@ async fn exec_and_print( | LogicalPlan::Analyze(_) ); - let df = match &plan { - LogicalPlan::Ddl(DdlStatement::CreateExternalTable(cmd)) => { - create_external_table(ctx, cmd).await?; - ctx.execute_logical_plan(plan).await? - } - _ => ctx.execute_logical_plan(plan).await?, - }; + // Note that cmd is a mutable reference so that create_external_table function can remove all + // datafusion-cli specific options before passing through to datafusion. Otherwise, datafusion + // will raise Configuration errors. + if let LogicalPlan::Ddl(DdlStatement::CreateExternalTable(cmd)) = &mut plan { + create_external_table(ctx, cmd).await?; + } - let results = df.collect().await?; + let df = ctx.execute_logical_plan(plan).await?; + let physical_plan = df.create_physical_plan().await?; - let print_options = if should_ignore_maxrows { - PrintOptions { - maxrows: MaxRows::Unlimited, - ..print_options.clone() - } + if is_plan_streaming(&physical_plan)? { + let stream = execute_stream(physical_plan, task_ctx.clone())?; + print_options.print_stream(stream, now).await?; } else { - print_options.clone() - }; - print_options.print_batches(&results, now)?; + let mut print_options = print_options.clone(); + if should_ignore_maxrows { + print_options.maxrows = MaxRows::Unlimited; + } + if print_options.format == PrintFormat::Automatic { + print_options.format = PrintFormat::Table; + } + let results = collect(physical_plan, task_ctx.clone()).await?; + print_options.print_batches(&results, now)?; + } } Ok(()) @@ -247,7 +258,7 @@ async fn exec_and_print( async fn create_external_table( ctx: &SessionContext, - cmd: &CreateExternalTable, + cmd: &mut CreateExternalTable, ) -> Result<()> { let table_path = ListingTableUrl::parse(&cmd.location)?; let scheme = table_path.scheme(); @@ -273,10 +284,7 @@ async fn create_external_table( .object_store_registry .get_store(url) .map_err(|_| { - DataFusionError::Execution(format!( - "Unsupported object store scheme: {}", - scheme - )) + exec_datafusion_err!("Unsupported object store scheme: {}", scheme) })? 
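The REPL loop above now races each statement against `tokio::signal::ctrl_c()` inside `tokio::select!`, so a long-running query can be interrupted without killing the shell. A minimal, self-contained sketch of that pattern (illustrative only, not the CLI code itself; it assumes the `tokio` crate with the `signal` feature, which the `Cargo.toml` change above enables):

```rust
use datafusion::error::Result;
use datafusion::prelude::SessionContext;
use tokio::signal;

#[tokio::main]
async fn main() -> Result<()> {
    let ctx = SessionContext::new();
    let df = ctx.sql("SELECT 1 AS one").await?;

    tokio::select! {
        // Normal path: the query completes and its batches are returned.
        res = df.collect() => {
            println!("query finished with {} batch(es)", res?.len());
        }
        // Ctrl-C path: the pending query future is dropped, which cancels the work.
        _ = signal::ctrl_c() => {
            println!("^C");
        }
    }
    Ok(())
}
```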
} }; @@ -288,15 +296,32 @@ async fn create_external_table( #[cfg(test)] mod tests { + use std::str::FromStr; + use super::*; use datafusion::common::plan_err; + use datafusion_common::{file_options::StatementOptions, FileTypeWriterOptions}; async fn create_external_table_test(location: &str, sql: &str) -> Result<()> { let ctx = SessionContext::new(); - let plan = ctx.state().create_logical_plan(sql).await?; + let mut plan = ctx.state().create_logical_plan(sql).await?; - if let LogicalPlan::Ddl(DdlStatement::CreateExternalTable(cmd)) = &plan { + if let LogicalPlan::Ddl(DdlStatement::CreateExternalTable(cmd)) = &mut plan { create_external_table(&ctx, cmd).await?; + let options: Vec<_> = cmd + .options + .iter() + .map(|(k, v)| (k.clone(), v.clone())) + .collect(); + let statement_options = StatementOptions::new(options); + let file_type = + datafusion_common::FileType::from_str(cmd.file_type.as_str())?; + + let _file_type_writer_options = FileTypeWriterOptions::build( + &file_type, + ctx.state().config_options(), + &statement_options, + )?; } else { return plan_err!("LogicalPlan is not a CreateExternalTable"); } @@ -350,7 +375,7 @@ mod tests { async fn create_object_store_table_gcs() -> Result<()> { let service_account_path = "fake_service_account_path"; let service_account_key = - "{\"private_key\": \"fake_private_key.pem\",\"client_email\":\"fake_client_email\"}"; + "{\"private_key\": \"fake_private_key.pem\",\"client_email\":\"fake_client_email\", \"private_key_id\":\"id\"}"; let application_credentials_path = "fake_application_credentials_path"; let location = "gcs://bucket/path/file.parquet"; @@ -366,8 +391,9 @@ mod tests { let sql = format!("CREATE EXTERNAL TABLE test STORED AS PARQUET OPTIONS('service_account_key' '{service_account_key}') LOCATION '{location}'"); let err = create_external_table_test(location, &sql) .await - .unwrap_err(); - assert!(err.to_string().contains("No RSA key found in pem file")); + .unwrap_err() + .to_string(); + assert!(err.contains("No RSA key found in pem file"), "{err}"); // for application_credentials_path let sql = format!("CREATE EXTERNAL TABLE test STORED AS PARQUET @@ -387,15 +413,7 @@ mod tests { // Ensure that local files are also registered let sql = format!("CREATE EXTERNAL TABLE test STORED AS PARQUET LOCATION '{location}'"); - let err = create_external_table_test(location, &sql) - .await - .unwrap_err(); - - if let DataFusionError::IoError(e) = err { - assert_eq!(e.kind(), std::io::ErrorKind::NotFound); - } else { - return Err(err); - } + create_external_table_test(location, &sql).await.unwrap(); Ok(()) } diff --git a/datafusion-cli/src/functions.rs b/datafusion-cli/src/functions.rs index eeebe713d716e..5390fa9f2271a 100644 --- a/datafusion-cli/src/functions.rs +++ b/datafusion-cli/src/functions.rs @@ -16,12 +16,27 @@ // under the License. //! 
Functions that are query-able and searchable via the `\h` command -use arrow::array::StringArray; -use arrow::datatypes::{DataType, Field, Schema}; +use arrow::array::{Int64Array, StringArray}; +use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use arrow::record_batch::RecordBatch; use arrow::util::pretty::pretty_format_batches; +use async_trait::async_trait; +use datafusion::common::DataFusionError; +use datafusion::common::{plan_err, Column}; +use datafusion::datasource::function::TableFunctionImpl; +use datafusion::datasource::TableProvider; use datafusion::error::Result; +use datafusion::execution::context::SessionState; +use datafusion::logical_expr::Expr; +use datafusion::physical_plan::memory::MemoryExec; +use datafusion::physical_plan::ExecutionPlan; +use datafusion::scalar::ScalarValue; +use parquet::basic::ConvertedType; +use parquet::file::reader::FileReader; +use parquet::file::serialized_reader::SerializedFileReader; +use parquet::file::statistics::Statistics; use std::fmt; +use std::fs::File; use std::str::FromStr; use std::sync::Arc; @@ -196,3 +211,232 @@ pub fn display_all_functions() -> Result<()> { println!("{}", pretty_format_batches(&[batch]).unwrap()); Ok(()) } + +/// PARQUET_META table function +struct ParquetMetadataTable { + schema: SchemaRef, + batch: RecordBatch, +} + +#[async_trait] +impl TableProvider for ParquetMetadataTable { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn schema(&self) -> arrow::datatypes::SchemaRef { + self.schema.clone() + } + + fn table_type(&self) -> datafusion::logical_expr::TableType { + datafusion::logical_expr::TableType::Base + } + + async fn scan( + &self, + _state: &SessionState, + projection: Option<&Vec>, + _filters: &[Expr], + _limit: Option, + ) -> Result> { + Ok(Arc::new(MemoryExec::try_new( + &[vec![self.batch.clone()]], + TableProvider::schema(self), + projection.cloned(), + )?)) + } +} + +fn convert_parquet_statistics( + value: &Statistics, + converted_type: ConvertedType, +) -> (String, String) { + match (value, converted_type) { + (Statistics::Boolean(val), _) => (val.min().to_string(), val.max().to_string()), + (Statistics::Int32(val), _) => (val.min().to_string(), val.max().to_string()), + (Statistics::Int64(val), _) => (val.min().to_string(), val.max().to_string()), + (Statistics::Int96(val), _) => (val.min().to_string(), val.max().to_string()), + (Statistics::Float(val), _) => (val.min().to_string(), val.max().to_string()), + (Statistics::Double(val), _) => (val.min().to_string(), val.max().to_string()), + (Statistics::ByteArray(val), ConvertedType::UTF8) => { + let min_bytes = val.min(); + let max_bytes = val.max(); + let min = min_bytes + .as_utf8() + .map(|v| v.to_string()) + .unwrap_or_else(|_| min_bytes.to_string()); + + let max = max_bytes + .as_utf8() + .map(|v| v.to_string()) + .unwrap_or_else(|_| max_bytes.to_string()); + (min, max) + } + (Statistics::ByteArray(val), _) => (val.min().to_string(), val.max().to_string()), + (Statistics::FixedLenByteArray(val), ConvertedType::UTF8) => { + let min_bytes = val.min(); + let max_bytes = val.max(); + let min = min_bytes + .as_utf8() + .map(|v| v.to_string()) + .unwrap_or_else(|_| min_bytes.to_string()); + + let max = max_bytes + .as_utf8() + .map(|v| v.to_string()) + .unwrap_or_else(|_| max_bytes.to_string()); + (min, max) + } + (Statistics::FixedLenByteArray(val), _) => { + (val.min().to_string(), val.max().to_string()) + } + } +} + +pub struct ParquetMetadataFunc {} + +impl TableFunctionImpl for ParquetMetadataFunc { + fn call(&self, exprs: 
&[Expr]) -> Result> { + let filename = match exprs.first() { + Some(Expr::Literal(ScalarValue::Utf8(Some(s)))) => s, // single quote: parquet_metadata('x.parquet') + Some(Expr::Column(Column { name, .. })) => name, // double quote: parquet_metadata("x.parquet") + _ => { + return plan_err!( + "parquet_metadata requires string argument as its input" + ); + } + }; + + let file = File::open(filename.clone())?; + let reader = SerializedFileReader::new(file)?; + let metadata = reader.metadata(); + + let schema = Arc::new(Schema::new(vec![ + Field::new("filename", DataType::Utf8, true), + Field::new("row_group_id", DataType::Int64, true), + Field::new("row_group_num_rows", DataType::Int64, true), + Field::new("row_group_num_columns", DataType::Int64, true), + Field::new("row_group_bytes", DataType::Int64, true), + Field::new("column_id", DataType::Int64, true), + Field::new("file_offset", DataType::Int64, true), + Field::new("num_values", DataType::Int64, true), + Field::new("path_in_schema", DataType::Utf8, true), + Field::new("type", DataType::Utf8, true), + Field::new("stats_min", DataType::Utf8, true), + Field::new("stats_max", DataType::Utf8, true), + Field::new("stats_null_count", DataType::Int64, true), + Field::new("stats_distinct_count", DataType::Int64, true), + Field::new("stats_min_value", DataType::Utf8, true), + Field::new("stats_max_value", DataType::Utf8, true), + Field::new("compression", DataType::Utf8, true), + Field::new("encodings", DataType::Utf8, true), + Field::new("index_page_offset", DataType::Int64, true), + Field::new("dictionary_page_offset", DataType::Int64, true), + Field::new("data_page_offset", DataType::Int64, true), + Field::new("total_compressed_size", DataType::Int64, true), + Field::new("total_uncompressed_size", DataType::Int64, true), + ])); + + // construct recordbatch from metadata + let mut filename_arr = vec![]; + let mut row_group_id_arr = vec![]; + let mut row_group_num_rows_arr = vec![]; + let mut row_group_num_columns_arr = vec![]; + let mut row_group_bytes_arr = vec![]; + let mut column_id_arr = vec![]; + let mut file_offset_arr = vec![]; + let mut num_values_arr = vec![]; + let mut path_in_schema_arr = vec![]; + let mut type_arr = vec![]; + let mut stats_min_arr = vec![]; + let mut stats_max_arr = vec![]; + let mut stats_null_count_arr = vec![]; + let mut stats_distinct_count_arr = vec![]; + let mut stats_min_value_arr = vec![]; + let mut stats_max_value_arr = vec![]; + let mut compression_arr = vec![]; + let mut encodings_arr = vec![]; + let mut index_page_offset_arr = vec![]; + let mut dictionary_page_offset_arr = vec![]; + let mut data_page_offset_arr = vec![]; + let mut total_compressed_size_arr = vec![]; + let mut total_uncompressed_size_arr = vec![]; + for (rg_idx, row_group) in metadata.row_groups().iter().enumerate() { + for (col_idx, column) in row_group.columns().iter().enumerate() { + filename_arr.push(filename.clone()); + row_group_id_arr.push(rg_idx as i64); + row_group_num_rows_arr.push(row_group.num_rows()); + row_group_num_columns_arr.push(row_group.num_columns() as i64); + row_group_bytes_arr.push(row_group.total_byte_size()); + column_id_arr.push(col_idx as i64); + file_offset_arr.push(column.file_offset()); + num_values_arr.push(column.num_values()); + path_in_schema_arr.push(column.column_path().to_string()); + type_arr.push(column.column_type().to_string()); + let converted_type = column.column_descr().converted_type(); + + if let Some(s) = column.statistics() { + let (min_val, max_val) = if s.has_min_max_set() { + let 
(min_val, max_val) = + convert_parquet_statistics(s, converted_type); + (Some(min_val), Some(max_val)) + } else { + (None, None) + }; + stats_min_arr.push(min_val.clone()); + stats_max_arr.push(max_val.clone()); + stats_null_count_arr.push(Some(s.null_count() as i64)); + stats_distinct_count_arr.push(s.distinct_count().map(|c| c as i64)); + stats_min_value_arr.push(min_val); + stats_max_value_arr.push(max_val); + } else { + stats_min_arr.push(None); + stats_max_arr.push(None); + stats_null_count_arr.push(None); + stats_distinct_count_arr.push(None); + stats_min_value_arr.push(None); + stats_max_value_arr.push(None); + }; + compression_arr.push(format!("{:?}", column.compression())); + encodings_arr.push(format!("{:?}", column.encodings())); + index_page_offset_arr.push(column.index_page_offset()); + dictionary_page_offset_arr.push(column.dictionary_page_offset()); + data_page_offset_arr.push(column.data_page_offset()); + total_compressed_size_arr.push(column.compressed_size()); + total_uncompressed_size_arr.push(column.uncompressed_size()); + } + } + + let rb = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(StringArray::from(filename_arr)), + Arc::new(Int64Array::from(row_group_id_arr)), + Arc::new(Int64Array::from(row_group_num_rows_arr)), + Arc::new(Int64Array::from(row_group_num_columns_arr)), + Arc::new(Int64Array::from(row_group_bytes_arr)), + Arc::new(Int64Array::from(column_id_arr)), + Arc::new(Int64Array::from(file_offset_arr)), + Arc::new(Int64Array::from(num_values_arr)), + Arc::new(StringArray::from(path_in_schema_arr)), + Arc::new(StringArray::from(type_arr)), + Arc::new(StringArray::from(stats_min_arr)), + Arc::new(StringArray::from(stats_max_arr)), + Arc::new(Int64Array::from(stats_null_count_arr)), + Arc::new(Int64Array::from(stats_distinct_count_arr)), + Arc::new(StringArray::from(stats_min_value_arr)), + Arc::new(StringArray::from(stats_max_value_arr)), + Arc::new(StringArray::from(compression_arr)), + Arc::new(StringArray::from(encodings_arr)), + Arc::new(Int64Array::from(index_page_offset_arr)), + Arc::new(Int64Array::from(dictionary_page_offset_arr)), + Arc::new(Int64Array::from(data_page_offset_arr)), + Arc::new(Int64Array::from(total_compressed_size_arr)), + Arc::new(Int64Array::from(total_uncompressed_size_arr)), + ], + )?; + + let parquet_metadata = ParquetMetadataTable { schema, batch: rb }; + Ok(Arc::new(parquet_metadata)) + } +} diff --git a/datafusion-cli/src/main.rs b/datafusion-cli/src/main.rs index c069f458f196c..563d172f2c95e 100644 --- a/datafusion-cli/src/main.rs +++ b/datafusion-cli/src/main.rs @@ -15,25 +15,28 @@ // specific language governing permissions and limitations // under the License. 
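Since `ParquetMetadataFunc` implements `TableFunctionImpl`, registering it as a user-defined table function makes it callable straight from SQL, which is what the `main.rs` change below does. A short usage sketch (the parquet path is a placeholder):

```rust
use std::sync::Arc;

use datafusion::error::Result;
use datafusion::prelude::SessionContext;
use datafusion_cli::functions::ParquetMetadataFunc;

#[tokio::main]
async fn main() -> Result<()> {
    let ctx = SessionContext::new();
    // Register the table function under the name datafusion-cli uses.
    ctx.register_udtf("parquet_metadata", Arc::new(ParquetMetadataFunc {}));

    // Each output row describes one column chunk within one row group.
    let df = ctx
        .sql(
            "SELECT row_group_id, path_in_schema, stats_min, stats_max \
             FROM parquet_metadata('data/example.parquet')",
        )
        .await?;
    df.show().await?;
    Ok(())
}
```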
-use clap::Parser; +use std::collections::HashMap; +use std::env; +use std::path::Path; +use std::str::FromStr; +use std::sync::{Arc, OnceLock}; + use datafusion::error::{DataFusionError, Result}; use datafusion::execution::context::SessionConfig; use datafusion::execution::memory_pool::{FairSpillPool, GreedyMemoryPool}; use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv}; use datafusion::prelude::SessionContext; use datafusion_cli::catalog::DynamicFileCatalog; +use datafusion_cli::functions::ParquetMetadataFunc; use datafusion_cli::{ exec, print_format::PrintFormat, print_options::{MaxRows, PrintOptions}, DATAFUSION_CLI_VERSION, }; + +use clap::Parser; use mimalloc::MiMalloc; -use std::collections::HashMap; -use std::env; -use std::path::Path; -use std::str::FromStr; -use std::sync::{Arc, OnceLock}; #[global_allocator] static GLOBAL: MiMalloc = MiMalloc; @@ -110,7 +113,7 @@ struct Args { )] rc: Option>, - #[clap(long, arg_enum, default_value_t = PrintFormat::Table)] + #[clap(long, arg_enum, default_value_t = PrintFormat::Automatic)] format: PrintFormat, #[clap( @@ -185,6 +188,8 @@ pub async fn main() -> Result<()> { ctx.state().catalog_list(), ctx.state_weak_ref(), ))); + // register `parquet_metadata` table function to get metadata from parquet files + ctx.register_udtf("parquet_metadata", Arc::new(ParquetMetadataFunc {})); let mut print_options = PrintOptions { format: args.format, @@ -329,6 +334,7 @@ fn extract_memory_pool_size(size: &str) -> Result { #[cfg(test)] mod tests { use super::*; + use datafusion::assert_batches_eq; fn assert_conversion(input: &str, expected: Result) { let result = extract_memory_pool_size(input); @@ -385,4 +391,58 @@ mod tests { Ok(()) } + + #[tokio::test] + async fn test_parquet_metadata_works() -> Result<(), DataFusionError> { + let ctx = SessionContext::new(); + ctx.register_udtf("parquet_metadata", Arc::new(ParquetMetadataFunc {})); + + // input with single quote + let sql = + "SELECT * FROM parquet_metadata('../datafusion/core/tests/data/fixed_size_list_array.parquet')"; + let df = ctx.sql(sql).await?; + let rbs = df.collect().await?; + + let excepted = [ + "+-------------------------------------------------------------+--------------+--------------------+-----------------------+-----------------+-----------+-------------+------------+----------------+-------+-----------+-----------+------------------+----------------------+-----------------+-----------------+-------------+------------------------------+-------------------+------------------------+------------------+-----------------------+-------------------------+", + "| filename | row_group_id | row_group_num_rows | row_group_num_columns | row_group_bytes | column_id | file_offset | num_values | path_in_schema | type | stats_min | stats_max | stats_null_count | stats_distinct_count | stats_min_value | stats_max_value | compression | encodings | index_page_offset | dictionary_page_offset | data_page_offset | total_compressed_size | total_uncompressed_size |", + "+-------------------------------------------------------------+--------------+--------------------+-----------------------+-----------------+-----------+-------------+------------+----------------+-------+-----------+-----------+------------------+----------------------+-----------------+-----------------+-------------+------------------------------+-------------------+------------------------+------------------+-----------------------+-------------------------+", + "| ../datafusion/core/tests/data/fixed_size_list_array.parquet 
| 0 | 2 | 1 | 123 | 0 | 125 | 4 | \"f0.list.item\" | INT64 | 1 | 4 | 0 | | 1 | 4 | SNAPPY | [RLE_DICTIONARY, PLAIN, RLE] | | 4 | 46 | 121 | 123 |", + "+-------------------------------------------------------------+--------------+--------------------+-----------------------+-----------------+-----------+-------------+------------+----------------+-------+-----------+-----------+------------------+----------------------+-----------------+-----------------+-------------+------------------------------+-------------------+------------------------+------------------+-----------------------+-------------------------+", + ]; + assert_batches_eq!(excepted, &rbs); + + // input with double quote + let sql = + "SELECT * FROM parquet_metadata(\"../datafusion/core/tests/data/fixed_size_list_array.parquet\")"; + let df = ctx.sql(sql).await?; + let rbs = df.collect().await?; + assert_batches_eq!(excepted, &rbs); + + Ok(()) + } + + #[tokio::test] + async fn test_parquet_metadata_works_with_strings() -> Result<(), DataFusionError> { + let ctx = SessionContext::new(); + ctx.register_udtf("parquet_metadata", Arc::new(ParquetMetadataFunc {})); + + // input with string columns + let sql = + "SELECT * FROM parquet_metadata('../parquet-testing/data/data_index_bloom_encoding_stats.parquet')"; + let df = ctx.sql(sql).await?; + let rbs = df.collect().await?; + + let excepted = [ + +"+-----------------------------------------------------------------+--------------+--------------------+-----------------------+-----------------+-----------+-------------+------------+----------------+------------+-----------+-----------+------------------+----------------------+-----------------+-----------------+--------------------+--------------------------+-------------------+------------------------+------------------+-----------------------+-------------------------+", +"| filename | row_group_id | row_group_num_rows | row_group_num_columns | row_group_bytes | column_id | file_offset | num_values | path_in_schema | type | stats_min | stats_max | stats_null_count | stats_distinct_count | stats_min_value | stats_max_value | compression | encodings | index_page_offset | dictionary_page_offset | data_page_offset | total_compressed_size | total_uncompressed_size |", +"+-----------------------------------------------------------------+--------------+--------------------+-----------------------+-----------------+-----------+-------------+------------+----------------+------------+-----------+-----------+------------------+----------------------+-----------------+-----------------+--------------------+--------------------------+-------------------+------------------------+------------------+-----------------------+-------------------------+", +"| ../parquet-testing/data/data_index_bloom_encoding_stats.parquet | 0 | 14 | 1 | 163 | 0 | 4 | 14 | \"String\" | BYTE_ARRAY | Hello | today | 0 | | Hello | today | GZIP(GzipLevel(6)) | [BIT_PACKED, RLE, PLAIN] | | | 4 | 152 | 163 |", +"+-----------------------------------------------------------------+--------------+--------------------+-----------------------+-----------------+-----------+-------------+------------+----------------+------------+-----------+-----------+------------------+----------------------+-----------------+-----------------+--------------------+--------------------------+-------------------+------------------------+------------------+-----------------------+-------------------------+" + ]; + assert_batches_eq!(excepted, &rbs); + + Ok(()) + } } diff --git 
a/datafusion-cli/src/object_storage.rs b/datafusion-cli/src/object_storage.rs index c39d1915eb435..9d79c7e0ec78e 100644 --- a/datafusion-cli/src/object_storage.rs +++ b/datafusion-cli/src/object_storage.rs @@ -30,20 +30,23 @@ use url::Url; pub async fn get_s3_object_store_builder( url: &Url, - cmd: &CreateExternalTable, + cmd: &mut CreateExternalTable, ) -> Result { let bucket_name = get_bucket_name(url)?; let mut builder = AmazonS3Builder::from_env().with_bucket_name(bucket_name); if let (Some(access_key_id), Some(secret_access_key)) = ( - cmd.options.get("access_key_id"), - cmd.options.get("secret_access_key"), + // These options are datafusion-cli specific and must be removed before passing through to datafusion. + // Otherwise, a Configuration error will be raised. + cmd.options.remove("access_key_id"), + cmd.options.remove("secret_access_key"), ) { + println!("removing secret access key!"); builder = builder .with_access_key_id(access_key_id) .with_secret_access_key(secret_access_key); - if let Some(session_token) = cmd.options.get("session_token") { + if let Some(session_token) = cmd.options.remove("session_token") { builder = builder.with_token(session_token); } } else { @@ -66,7 +69,7 @@ pub async fn get_s3_object_store_builder( builder = builder.with_credentials(credentials); } - if let Some(region) = cmd.options.get("region") { + if let Some(region) = cmd.options.remove("region") { builder = builder.with_region(region); } @@ -99,7 +102,7 @@ impl CredentialProvider for S3CredentialProvider { pub fn get_oss_object_store_builder( url: &Url, - cmd: &CreateExternalTable, + cmd: &mut CreateExternalTable, ) -> Result { let bucket_name = get_bucket_name(url)?; let mut builder = AmazonS3Builder::from_env() @@ -109,15 +112,15 @@ pub fn get_oss_object_store_builder( .with_region("do_not_care"); if let (Some(access_key_id), Some(secret_access_key)) = ( - cmd.options.get("access_key_id"), - cmd.options.get("secret_access_key"), + cmd.options.remove("access_key_id"), + cmd.options.remove("secret_access_key"), ) { builder = builder .with_access_key_id(access_key_id) .with_secret_access_key(secret_access_key); } - if let Some(endpoint) = cmd.options.get("endpoint") { + if let Some(endpoint) = cmd.options.remove("endpoint") { builder = builder.with_endpoint(endpoint); } @@ -126,21 +129,21 @@ pub fn get_oss_object_store_builder( pub fn get_gcs_object_store_builder( url: &Url, - cmd: &CreateExternalTable, + cmd: &mut CreateExternalTable, ) -> Result { let bucket_name = get_bucket_name(url)?; let mut builder = GoogleCloudStorageBuilder::from_env().with_bucket_name(bucket_name); - if let Some(service_account_path) = cmd.options.get("service_account_path") { + if let Some(service_account_path) = cmd.options.remove("service_account_path") { builder = builder.with_service_account_path(service_account_path); } - if let Some(service_account_key) = cmd.options.get("service_account_key") { + if let Some(service_account_key) = cmd.options.remove("service_account_key") { builder = builder.with_service_account_key(service_account_key); } if let Some(application_credentials_path) = - cmd.options.get("application_credentials_path") + cmd.options.remove("application_credentials_path") { builder = builder.with_application_credentials(application_credentials_path); } @@ -180,9 +183,9 @@ mod tests { let sql = format!("CREATE EXTERNAL TABLE test STORED AS PARQUET OPTIONS('access_key_id' '{access_key_id}', 'secret_access_key' '{secret_access_key}', 'region' '{region}', 'session_token' {session_token}) LOCATION 
'{location}'"); let ctx = SessionContext::new(); - let plan = ctx.state().create_logical_plan(&sql).await?; + let mut plan = ctx.state().create_logical_plan(&sql).await?; - if let LogicalPlan::Ddl(DdlStatement::CreateExternalTable(cmd)) = &plan { + if let LogicalPlan::Ddl(DdlStatement::CreateExternalTable(cmd)) = &mut plan { let builder = get_s3_object_store_builder(table_url.as_ref(), cmd).await?; // get the actual configuration information, then assert_eq! let config = [ @@ -212,9 +215,9 @@ mod tests { let sql = format!("CREATE EXTERNAL TABLE test STORED AS PARQUET OPTIONS('access_key_id' '{access_key_id}', 'secret_access_key' '{secret_access_key}', 'endpoint' '{endpoint}') LOCATION '{location}'"); let ctx = SessionContext::new(); - let plan = ctx.state().create_logical_plan(&sql).await?; + let mut plan = ctx.state().create_logical_plan(&sql).await?; - if let LogicalPlan::Ddl(DdlStatement::CreateExternalTable(cmd)) = &plan { + if let LogicalPlan::Ddl(DdlStatement::CreateExternalTable(cmd)) = &mut plan { let builder = get_oss_object_store_builder(table_url.as_ref(), cmd)?; // get the actual configuration information, then assert_eq! let config = [ @@ -244,9 +247,9 @@ mod tests { let sql = format!("CREATE EXTERNAL TABLE test STORED AS PARQUET OPTIONS('service_account_path' '{service_account_path}', 'service_account_key' '{service_account_key}', 'application_credentials_path' '{application_credentials_path}') LOCATION '{location}'"); let ctx = SessionContext::new(); - let plan = ctx.state().create_logical_plan(&sql).await?; + let mut plan = ctx.state().create_logical_plan(&sql).await?; - if let LogicalPlan::Ddl(DdlStatement::CreateExternalTable(cmd)) = &plan { + if let LogicalPlan::Ddl(DdlStatement::CreateExternalTable(cmd)) = &mut plan { let builder = get_gcs_object_store_builder(table_url.as_ref(), cmd)?; // get the actual configuration information, then assert_eq! let config = [ diff --git a/datafusion-cli/src/print_format.rs b/datafusion-cli/src/print_format.rs index 0738bf6f9b47c..ea418562495d1 100644 --- a/datafusion-cli/src/print_format.rs +++ b/datafusion-cli/src/print_format.rs @@ -16,23 +16,27 @@ // under the License. //! Print format variants + +use std::str::FromStr; + use crate::print_options::MaxRows; + use arrow::csv::writer::WriterBuilder; use arrow::json::{ArrayWriter, LineDelimitedWriter}; +use arrow::record_batch::RecordBatch; use arrow::util::pretty::pretty_format_batches_with_options; -use datafusion::arrow::record_batch::RecordBatch; use datafusion::common::format::DEFAULT_FORMAT_OPTIONS; -use datafusion::error::{DataFusionError, Result}; -use std::str::FromStr; +use datafusion::error::Result; /// Allow records to be printed in different formats -#[derive(Debug, PartialEq, Eq, clap::ArgEnum, Clone)] +#[derive(Debug, PartialEq, Eq, clap::ArgEnum, Clone, Copy)] pub enum PrintFormat { Csv, Tsv, Table, Json, NdJson, + Automatic, } impl FromStr for PrintFormat { @@ -44,31 +48,44 @@ impl FromStr for PrintFormat { } macro_rules! batches_to_json { - ($WRITER: ident, $batches: expr) => {{ - let mut bytes = vec![]; + ($WRITER: ident, $writer: expr, $batches: expr) => {{ { - let mut writer = $WRITER::new(&mut bytes); - $batches.iter().try_for_each(|batch| writer.write(batch))?; - writer.finish()?; + if !$batches.is_empty() { + let mut json_writer = $WRITER::new(&mut *$writer); + for batch in $batches { + json_writer.write(batch)?; + } + json_writer.finish()?; + json_finish!($WRITER, $writer); + } } - String::from_utf8(bytes).map_err(|e| DataFusionError::External(Box::new(e)))? 
+ Ok(()) as Result<()> }}; } -fn print_batches_with_sep(batches: &[RecordBatch], delimiter: u8) -> Result { - let mut bytes = vec![]; - { - let builder = WriterBuilder::new() - .with_header(true) - .with_delimiter(delimiter); - let mut writer = builder.build(&mut bytes); - for batch in batches { - writer.write(batch)?; - } +macro_rules! json_finish { + (ArrayWriter, $writer: expr) => {{ + writeln!($writer)?; + }}; + (LineDelimitedWriter, $writer: expr) => {{}}; +} + +fn print_batches_with_sep( + writer: &mut W, + batches: &[RecordBatch], + delimiter: u8, + with_header: bool, +) -> Result<()> { + let builder = WriterBuilder::new() + .with_header(with_header) + .with_delimiter(delimiter); + let mut csv_writer = builder.build(writer); + + for batch in batches { + csv_writer.write(batch)?; } - let formatted = - String::from_utf8(bytes).map_err(|e| DataFusionError::External(Box::new(e)))?; - Ok(formatted) + + Ok(()) } fn keep_only_maxrows(s: &str, maxrows: usize) -> String { @@ -88,97 +105,118 @@ fn keep_only_maxrows(s: &str, maxrows: usize) -> String { result.join("\n") } -fn format_batches_with_maxrows( +fn format_batches_with_maxrows( + writer: &mut W, batches: &[RecordBatch], maxrows: MaxRows, -) -> Result { +) -> Result<()> { match maxrows { MaxRows::Limited(maxrows) => { - // Only format enough batches for maxrows + // Filter batches to meet the maxrows condition let mut filtered_batches = Vec::new(); - let mut batches = batches; - let row_count: usize = batches.iter().map(|b| b.num_rows()).sum(); - if row_count > maxrows { - let mut accumulated_rows = 0; - - for batch in batches { + let mut row_count: usize = 0; + let mut over_limit = false; + for batch in batches { + if row_count + batch.num_rows() > maxrows { + // If adding this batch exceeds maxrows, slice the batch + let limit = maxrows - row_count; + let sliced_batch = batch.slice(0, limit); + filtered_batches.push(sliced_batch); + over_limit = true; + break; + } else { filtered_batches.push(batch.clone()); - if accumulated_rows + batch.num_rows() > maxrows { - break; - } - accumulated_rows += batch.num_rows(); + row_count += batch.num_rows(); } - - batches = &filtered_batches; } - let mut formatted = format!( - "{}", - pretty_format_batches_with_options(batches, &DEFAULT_FORMAT_OPTIONS)?, - ); - - if row_count > maxrows { - formatted = keep_only_maxrows(&formatted, maxrows); + let formatted = pretty_format_batches_with_options( + &filtered_batches, + &DEFAULT_FORMAT_OPTIONS, + )?; + if over_limit { + let mut formatted_str = format!("{}", formatted); + formatted_str = keep_only_maxrows(&formatted_str, maxrows); + writeln!(writer, "{}", formatted_str)?; + } else { + writeln!(writer, "{}", formatted)?; } - - Ok(formatted) } MaxRows::Unlimited => { - // maxrows not specified, print all rows - Ok(format!( - "{}", - pretty_format_batches_with_options(batches, &DEFAULT_FORMAT_OPTIONS)?, - )) + let formatted = + pretty_format_batches_with_options(batches, &DEFAULT_FORMAT_OPTIONS)?; + writeln!(writer, "{}", formatted)?; } } + + Ok(()) } impl PrintFormat { - /// print the batches to stdout using the specified format - /// `maxrows` option is only used for `Table` format: - /// If `maxrows` is Some(n), then at most n rows will be displayed - /// If `maxrows` is None, then every row will be displayed - pub fn print_batches(&self, batches: &[RecordBatch], maxrows: MaxRows) -> Result<()> { - if batches.is_empty() { + /// Print the batches to a writer using the specified format + pub fn print_batches( + &self, + writer: &mut W, + batches: 
&[RecordBatch], + maxrows: MaxRows, + with_header: bool, + ) -> Result<()> { + if batches.is_empty() || batches[0].num_rows() == 0 { return Ok(()); } match self { - Self::Csv => println!("{}", print_batches_with_sep(batches, b',')?), - Self::Tsv => println!("{}", print_batches_with_sep(batches, b'\t')?), + Self::Csv | Self::Automatic => { + print_batches_with_sep(writer, batches, b',', with_header) + } + Self::Tsv => print_batches_with_sep(writer, batches, b'\t', with_header), Self::Table => { if maxrows == MaxRows::Limited(0) { return Ok(()); } - println!("{}", format_batches_with_maxrows(batches, maxrows)?,) - } - Self::Json => println!("{}", batches_to_json!(ArrayWriter, batches)), - Self::NdJson => { - println!("{}", batches_to_json!(LineDelimitedWriter, batches)) + format_batches_with_maxrows(writer, batches, maxrows) } + Self::Json => batches_to_json!(ArrayWriter, writer, batches), + Self::NdJson => batches_to_json!(LineDelimitedWriter, writer, batches), } - Ok(()) } } #[cfg(test)] mod tests { + use std::io::{Cursor, Read, Write}; + use std::sync::Arc; + use super::*; + use arrow::array::Int32Array; use arrow::datatypes::{DataType, Field, Schema}; - use std::sync::Arc; + use datafusion::error::Result; + + fn run_test(batches: &[RecordBatch], test_fn: F) -> Result + where + F: Fn(&mut Cursor>, &[RecordBatch]) -> Result<()>, + { + let mut buffer = Cursor::new(Vec::new()); + test_fn(&mut buffer, batches)?; + buffer.set_position(0); + let mut contents = String::new(); + buffer.read_to_string(&mut contents)?; + Ok(contents) + } #[test] - fn test_print_batches_with_sep() { - let batches = vec![]; - assert_eq!("", print_batches_with_sep(&batches, b',').unwrap()); + fn test_print_batches_with_sep() -> Result<()> { + let contents = run_test(&[], |buffer, batches| { + print_batches_with_sep(buffer, batches, b',', true) + })?; + assert_eq!(contents, ""); let schema = Arc::new(Schema::new(vec![ Field::new("a", DataType::Int32, false), Field::new("b", DataType::Int32, false), Field::new("c", DataType::Int32, false), ])); - let batch = RecordBatch::try_new( schema, vec![ @@ -186,29 +224,33 @@ mod tests { Arc::new(Int32Array::from(vec![4, 5, 6])), Arc::new(Int32Array::from(vec![7, 8, 9])), ], - ) - .unwrap(); + )?; - let batches = vec![batch]; - let r = print_batches_with_sep(&batches, b',').unwrap(); - assert_eq!("a,b,c\n1,4,7\n2,5,8\n3,6,9\n", r); + let contents = run_test(&[batch], |buffer, batches| { + print_batches_with_sep(buffer, batches, b',', true) + })?; + assert_eq!(contents, "a,b,c\n1,4,7\n2,5,8\n3,6,9\n"); + + Ok(()) } #[test] fn test_print_batches_to_json_empty() -> Result<()> { - let batches = vec![]; - let r = batches_to_json!(ArrayWriter, &batches); - assert_eq!("", r); + let contents = run_test(&[], |buffer, batches| { + batches_to_json!(ArrayWriter, buffer, batches) + })?; + assert_eq!(contents, ""); - let r = batches_to_json!(LineDelimitedWriter, &batches); - assert_eq!("", r); + let contents = run_test(&[], |buffer, batches| { + batches_to_json!(LineDelimitedWriter, buffer, batches) + })?; + assert_eq!(contents, ""); let schema = Arc::new(Schema::new(vec![ Field::new("a", DataType::Int32, false), Field::new("b", DataType::Int32, false), Field::new("c", DataType::Int32, false), ])); - let batch = RecordBatch::try_new( schema, vec![ @@ -216,25 +258,29 @@ mod tests { Arc::new(Int32Array::from(vec![4, 5, 6])), Arc::new(Int32Array::from(vec![7, 8, 9])), ], - ) - .unwrap(); - + )?; let batches = vec![batch]; - let r = batches_to_json!(ArrayWriter, &batches); - 
assert_eq!("[{\"a\":1,\"b\":4,\"c\":7},{\"a\":2,\"b\":5,\"c\":8},{\"a\":3,\"b\":6,\"c\":9}]", r); - let r = batches_to_json!(LineDelimitedWriter, &batches); - assert_eq!("{\"a\":1,\"b\":4,\"c\":7}\n{\"a\":2,\"b\":5,\"c\":8}\n{\"a\":3,\"b\":6,\"c\":9}\n", r); + let contents = run_test(&batches, |buffer, batches| { + batches_to_json!(ArrayWriter, buffer, batches) + })?; + assert_eq!(contents, "[{\"a\":1,\"b\":4,\"c\":7},{\"a\":2,\"b\":5,\"c\":8},{\"a\":3,\"b\":6,\"c\":9}]\n"); + + let contents = run_test(&batches, |buffer, batches| { + batches_to_json!(LineDelimitedWriter, buffer, batches) + })?; + assert_eq!(contents, "{\"a\":1,\"b\":4,\"c\":7}\n{\"a\":2,\"b\":5,\"c\":8}\n{\"a\":3,\"b\":6,\"c\":9}\n"); + Ok(()) } #[test] fn test_format_batches_with_maxrows() -> Result<()> { let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); - - let batch = - RecordBatch::try_new(schema, vec![Arc::new(Int32Array::from(vec![1, 2, 3]))]) - .unwrap(); + let batch = RecordBatch::try_new( + schema, + vec![Arc::new(Int32Array::from(vec![1, 2, 3]))], + )?; #[rustfmt::skip] let all_rows_expected = [ @@ -244,7 +290,7 @@ mod tests { "| 1 |", "| 2 |", "| 3 |", - "+---+", + "+---+\n", ].join("\n"); #[rustfmt::skip] @@ -256,7 +302,7 @@ mod tests { "| . |", "| . |", "| . |", - "+---+", + "+---+\n", ].join("\n"); #[rustfmt::skip] @@ -272,26 +318,36 @@ mod tests { "| . |", "| . |", "| . |", - "+---+", + "+---+\n", ].join("\n"); - let no_limit = format_batches_with_maxrows(&[batch.clone()], MaxRows::Unlimited)?; - assert_eq!(all_rows_expected, no_limit); - - let maxrows_less_than_actual = - format_batches_with_maxrows(&[batch.clone()], MaxRows::Limited(1))?; - assert_eq!(one_row_expected, maxrows_less_than_actual); - let maxrows_more_than_actual = - format_batches_with_maxrows(&[batch.clone()], MaxRows::Limited(5))?; - assert_eq!(all_rows_expected, maxrows_more_than_actual); - let maxrows_equals_actual = - format_batches_with_maxrows(&[batch.clone()], MaxRows::Limited(3))?; - assert_eq!(all_rows_expected, maxrows_equals_actual); - let multi_batches = format_batches_with_maxrows( + let no_limit = run_test(&[batch.clone()], |buffer, batches| { + format_batches_with_maxrows(buffer, batches, MaxRows::Unlimited) + })?; + assert_eq!(no_limit, all_rows_expected); + + let maxrows_less_than_actual = run_test(&[batch.clone()], |buffer, batches| { + format_batches_with_maxrows(buffer, batches, MaxRows::Limited(1)) + })?; + assert_eq!(maxrows_less_than_actual, one_row_expected); + + let maxrows_more_than_actual = run_test(&[batch.clone()], |buffer, batches| { + format_batches_with_maxrows(buffer, batches, MaxRows::Limited(5)) + })?; + assert_eq!(maxrows_more_than_actual, all_rows_expected); + + let maxrows_equals_actual = run_test(&[batch.clone()], |buffer, batches| { + format_batches_with_maxrows(buffer, batches, MaxRows::Limited(3)) + })?; + assert_eq!(maxrows_equals_actual, all_rows_expected); + + let multi_batches = run_test( &[batch.clone(), batch.clone(), batch.clone()], - MaxRows::Limited(5), + |buffer, batches| { + format_batches_with_maxrows(buffer, batches, MaxRows::Limited(5)) + }, )?; - assert_eq!(multi_batches_expected, multi_batches); + assert_eq!(multi_batches, multi_batches_expected); Ok(()) } diff --git a/datafusion-cli/src/print_options.rs b/datafusion-cli/src/print_options.rs index 0a6c8d4c36fce..b8594352b585b 100644 --- a/datafusion-cli/src/print_options.rs +++ b/datafusion-cli/src/print_options.rs @@ -15,13 +15,21 @@ // specific language governing permissions and limitations // under 
the License. -use crate::print_format::PrintFormat; -use datafusion::arrow::record_batch::RecordBatch; -use datafusion::error::Result; use std::fmt::{Display, Formatter}; +use std::io::Write; +use std::pin::Pin; use std::str::FromStr; use std::time::Instant; +use crate::print_format::PrintFormat; + +use arrow::record_batch::RecordBatch; +use datafusion::common::DataFusionError; +use datafusion::error::Result; +use datafusion::physical_plan::RecordBatchStream; + +use futures::StreamExt; + #[derive(Debug, Clone, PartialEq, Copy)] pub enum MaxRows { /// show all rows in the output @@ -85,20 +93,70 @@ fn get_timing_info_str( } impl PrintOptions { - /// print the batches to stdout using the specified format + /// Print the batches to stdout using the specified format pub fn print_batches( &self, batches: &[RecordBatch], query_start_time: Instant, ) -> Result<()> { + let stdout = std::io::stdout(); + let mut writer = stdout.lock(); + + self.format + .print_batches(&mut writer, batches, self.maxrows, true)?; + let row_count: usize = batches.iter().map(|b| b.num_rows()).sum(); - // Elapsed time should not count time for printing batches - let timing_info = get_timing_info_str(row_count, self.maxrows, query_start_time); + let timing_info = get_timing_info_str( + row_count, + if self.format == PrintFormat::Table { + self.maxrows + } else { + MaxRows::Unlimited + }, + query_start_time, + ); + + if !self.quiet { + writeln!(writer, "{timing_info}")?; + } + + Ok(()) + } + + /// Print the stream to stdout using the specified format + pub async fn print_stream( + &self, + mut stream: Pin>, + query_start_time: Instant, + ) -> Result<()> { + if self.format == PrintFormat::Table { + return Err(DataFusionError::External( + "PrintFormat::Table is not implemented".to_string().into(), + )); + }; + + let stdout = std::io::stdout(); + let mut writer = stdout.lock(); + + let mut row_count = 0_usize; + let mut with_header = true; + + while let Some(Ok(batch)) = stream.next().await { + row_count += batch.num_rows(); + self.format.print_batches( + &mut writer, + &[batch], + MaxRows::Unlimited, + with_header, + )?; + with_header = false; + } - self.format.print_batches(batches, self.maxrows)?; + let timing_info = + get_timing_info_str(row_count, MaxRows::Unlimited, query_start_time); if !self.quiet { - println!("{timing_info}"); + writeln!(writer, "{timing_info}")?; } Ok(()) diff --git a/datafusion-examples/Cargo.toml b/datafusion-examples/Cargo.toml index 57691520a401b..59580bcb6a05a 100644 --- a/datafusion-examples/Cargo.toml +++ b/datafusion-examples/Cargo.toml @@ -46,9 +46,9 @@ futures = { workspace = true } log = { workspace = true } mimalloc = { version = "0.1", default-features = false } num_cpus = { workspace = true } -object_store = { version = "0.7.0", features = ["aws", "http"] } +object_store = { workspace = true, features = ["aws", "http"] } prost = { version = "0.12", default-features = false } -prost-derive = { version = "0.11", default-features = false } +prost-derive = { version = "0.12", default-features = false } serde = { version = "1.0.136", features = ["derive"] } serde_json = { workspace = true } tempfile = { workspace = true } diff --git a/datafusion-examples/README.md b/datafusion-examples/README.md index 9f7c9f99d14e1..aae451add9e75 100644 --- a/datafusion-examples/README.md +++ b/datafusion-examples/README.md @@ -47,10 +47,11 @@ cargo run --example csv_sql - [`catalog.rs`](examples/external_dependency/catalog.rs): Register the table into a custom catalog - 
[`custom_datasource.rs`](examples/custom_datasource.rs): Run queries against a custom datasource (TableProvider) - [`dataframe.rs`](examples/dataframe.rs): Run a query using a DataFrame against a local parquet file -- [`dataframe-to-s3.rs`](examples/external_dependency/dataframe-to-s3.rs): Run a query using a DataFrame against a parquet file from s3 +- [`dataframe-to-s3.rs`](examples/external_dependency/dataframe-to-s3.rs): Run a query using a DataFrame against a parquet file from s3 and write the results back to s3 +- [`dataframe_output.rs`](examples/dataframe_output.rs): Examples of methods which write data out from a DataFrame - [`dataframe_in_memory.rs`](examples/dataframe_in_memory.rs): Run a query using a DataFrame against data in memory - [`deserialize_to_struct.rs`](examples/deserialize_to_struct.rs): Convert query results into rust structs using serde -- [`expr_api.rs`](examples/expr_api.rs): Use the `Expr` construction and simplification API +- [`expr_api.rs`](examples/expr_api.rs): Create, execute, simplify and analyze `Expr`s - [`flight_sql_server.rs`](examples/flight/flight_sql_server.rs): Run DataFusion as a standalone process and execute SQL queries from JDBC clients - [`memtable.rs`](examples/memtable.rs): Create an query data in memory using SQL and `RecordBatch`es - [`parquet_sql.rs`](examples/parquet_sql.rs): Build and run a query plan from a SQL statement against a local Parquet file @@ -58,9 +59,11 @@ cargo run --example csv_sql - [`query-aws-s3.rs`](examples/external_dependency/query-aws-s3.rs): Configure `object_store` and run a query against files stored in AWS S3 - [`query-http-csv.rs`](examples/query-http-csv.rs): Configure `object_store` and run a query against files vi HTTP - [`rewrite_expr.rs`](examples/rewrite_expr.rs): Define and invoke a custom Query Optimizer pass +- [`simple_udf.rs`](examples/simple_udf.rs): Define and invoke a User Defined Scalar Function (UDF) +- [`advanced_udf.rs`](examples/advanced_udf.rs): Define and invoke a more complicated User Defined Scalar Function (UDF) - [`simple_udaf.rs`](examples/simple_udaf.rs): Define and invoke a User Defined Aggregate Function (UDAF) -- [`simple_udf.rs`](examples/simple_udf.rs): Define and invoke a User Defined (scalar) Function (UDF) - [`simple_udfw.rs`](examples/simple_udwf.rs): Define and invoke a User Defined Window Function (UDWF) +- [`advanced_udwf.rs`](examples/advanced_udwf.rs): Define and invoke a more complicated User Defined Window Function (UDWF) ## Distributed diff --git a/datafusion-examples/examples/advanced_udf.rs b/datafusion-examples/examples/advanced_udf.rs new file mode 100644 index 0000000000000..6ebf88a0b671b --- /dev/null +++ b/datafusion-examples/examples/advanced_udf.rs @@ -0,0 +1,243 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License.
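The new `advanced_udf.rs` example (full listing follows) boils down to: implement `ScalarUDFImpl`, wrap it in a `ScalarUDF`, register it with a `SessionContext`, and call it from the DataFrame API or SQL. A condensed sketch of that flow, separate from the patch itself; the UDF passed in would be e.g. `ScalarUDF::from(PowUdf::new())` from the listing below, and the SQL string assumes that particular `pow`/`my_pow` function:

```rust
use datafusion::error::Result;
use datafusion::prelude::*;
use datafusion_expr::ScalarUDF;

/// Sketch: register any `ScalarUDF` and invoke it via DataFrame API and SQL.
async fn register_and_call(ctx: &SessionContext, udf: ScalarUDF) -> Result<()> {
    // Register the UDF so SQL can refer to it by name (and by any aliases it declares)
    ctx.register_udf(udf.clone());

    // DataFrame API invocation: udf(a, 10)
    let df = ctx.table("t").await?;
    df.select(vec![udf.call(vec![col("a"), lit(10i32)])])?.show().await?;

    // SQL invocation by name and by alias (matches the pow example below)
    ctx.sql("SELECT pow(2, 10), my_pow(a, b) FROM t").await?.show().await?;
    Ok(())
}
```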
+ +use datafusion::{ + arrow::{ + array::{ArrayRef, Float32Array, Float64Array}, + datatypes::DataType, + record_batch::RecordBatch, + }, + logical_expr::Volatility, +}; +use std::any::Any; + +use arrow::array::{new_null_array, Array, AsArray}; +use arrow::compute; +use arrow::datatypes::Float64Type; +use datafusion::error::Result; +use datafusion::prelude::*; +use datafusion_common::{internal_err, ScalarValue}; +use datafusion_expr::{ColumnarValue, ScalarUDF, ScalarUDFImpl, Signature}; +use std::sync::Arc; + +/// This example shows how to use the full ScalarUDFImpl API to implement a user +/// defined function. As in the `simple_udf.rs` example, this struct implements +/// a function that takes two arguments and returns the first argument raised to +/// the power of the second argument `a^b`. +/// +/// To do so, we must implement the `ScalarUDFImpl` trait. +struct PowUdf { + signature: Signature, + aliases: Vec, +} + +impl PowUdf { + /// Create a new instance of the `PowUdf` struct + fn new() -> Self { + Self { + signature: Signature::exact( + // this function will always take two arguments of type f64 + vec![DataType::Float64, DataType::Float64], + // this function is deterministic and will always return the same + // result for the same input + Volatility::Immutable, + ), + // we will also add an alias of "my_pow" + aliases: vec!["my_pow".to_string()], + } + } +} + +impl ScalarUDFImpl for PowUdf { + /// We implement as_any so that we can downcast the ScalarUDFImpl trait object + fn as_any(&self) -> &dyn Any { + self + } + + /// Return the name of this function + fn name(&self) -> &str { + "pow" + } + + /// Return the "signature" of this function -- namely what types of arguments it will take + fn signature(&self) -> &Signature { + &self.signature + } + + /// What is the type of value that will be returned by this function? In + /// this case it will always be a constant value, but it could also be a + /// function of the input types. + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Float64) + } + + /// This is the function that actually calculates the results. + /// + /// This is the same way that functions built into DataFusion are invoked, + /// which permits important special cases when one or both of the arguments + /// are single values (constants). For example `pow(a, 2)` + /// + /// However, it also means the implementation is more complex than when + /// using `create_udf`. + fn invoke(&self, args: &[ColumnarValue]) -> Result { + // DataFusion has arranged for the correct inputs to be passed to this + // function, but we check again to make sure + assert_eq!(args.len(), 2); + let (base, exp) = (&args[0], &args[1]); + assert_eq!(base.data_type(), DataType::Float64); + assert_eq!(exp.data_type(), DataType::Float64); + + match (base, exp) { + // For demonstration purposes we also implement the scalar / scalar + // case here, but it is not typically required for high performance. + // + // For performance it is most important to optimize cases where at + // least one argument is an array. If all arguments are constants, + // the DataFusion expression simplification logic will often invoke + // this path once during planning, and simply use the result during + // execution. + ( + ColumnarValue::Scalar(ScalarValue::Float64(base)), + ColumnarValue::Scalar(ScalarValue::Float64(exp)), + ) => { + // compute the output. Note DataFusion treats `None` as NULL. 
+ let res = match (base, exp) { + (Some(base), Some(exp)) => Some(base.powf(*exp)), + // one or both arguments were NULL + _ => None, + }; + Ok(ColumnarValue::Scalar(ScalarValue::from(res))) + } + // special case if the exponent is a constant + ( + ColumnarValue::Array(base_array), + ColumnarValue::Scalar(ScalarValue::Float64(exp)), + ) => { + let result_array = match exp { + // a ^ null = null + None => new_null_array(base_array.data_type(), base_array.len()), + // a ^ exp + Some(exp) => { + // DataFusion has ensured both arguments are Float64: + let base_array = base_array.as_primitive::(); + // calculate the result for every row. The `unary` + // kernel creates very fast "vectorized" code and + // handles things like null values for us. + let res: Float64Array = + compute::unary(base_array, |base| base.powf(*exp)); + Arc::new(res) + } + }; + Ok(ColumnarValue::Array(result_array)) + } + + // special case if the base is a constant (note this code is quite + // similar to the previous case, so we omit comments) + ( + ColumnarValue::Scalar(ScalarValue::Float64(base)), + ColumnarValue::Array(exp_array), + ) => { + let res = match base { + None => new_null_array(exp_array.data_type(), exp_array.len()), + Some(base) => { + let exp_array = exp_array.as_primitive::(); + let res: Float64Array = + compute::unary(exp_array, |exp| base.powf(exp)); + Arc::new(res) + } + }; + Ok(ColumnarValue::Array(res)) + } + // Both arguments are arrays so we have to perform the calculation for every row + (ColumnarValue::Array(base_array), ColumnarValue::Array(exp_array)) => { + let res: Float64Array = compute::binary( + base_array.as_primitive::(), + exp_array.as_primitive::(), + |base, exp| base.powf(exp), + )?; + Ok(ColumnarValue::Array(Arc::new(res))) + } + // if the types were not float, it is a bug in DataFusion + _ => { + use datafusion_common::DataFusionError; + internal_err!("Invalid argument types to pow function") + } + } + } + + /// We will also add an alias of "my_pow" + fn aliases(&self) -> &[String] { + &self.aliases + } +} + +/// In this example we register `PowUdf` as a user defined function +/// and invoke it via the DataFrame API and SQL +#[tokio::main] +async fn main() -> Result<()> { + let ctx = create_context()?; + + // create the UDF + let pow = ScalarUDF::from(PowUdf::new()); + + // register the UDF with the context so it can be invoked by name and from SQL + ctx.register_udf(pow.clone()); + + // get a DataFrame from the context for scanning the "t" table + let df = ctx.table("t").await?; + + // Call pow(a, 10) using the DataFrame API + let df = df.select(vec![pow.call(vec![col("a"), lit(10i32)])])?; + + // note that the second argument is passed as an i32, not f64. DataFusion + // automatically coerces the types to match the UDF's defined signature. + + // print the results + df.show().await?; + + // You can also invoke both pow(2, 10) and its alias my_pow(a, b) using SQL + let sql_df = ctx.sql("SELECT pow(2, 10), my_pow(a, b) FROM t").await?; + sql_df.show().await?; + + Ok(()) +} + +/// create local execution context with an in-memory table: +/// +/// ```text +/// +-----+-----+ +/// | a | b | +/// +-----+-----+ +/// | 2.1 | 1.0 | +/// | 3.1 | 2.0 | +/// | 4.1 | 3.0 | +/// | 5.1 | 4.0 | +/// +-----+-----+ +/// ``` +fn create_context() -> Result { + // define data. 
+ let a: ArrayRef = Arc::new(Float32Array::from(vec![2.1, 3.1, 4.1, 5.1])); + let b: ArrayRef = Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0, 4.0])); + let batch = RecordBatch::try_from_iter(vec![("a", a), ("b", b)])?; + + // declare a new context. In Spark API, this corresponds to a new SparkSession + let ctx = SessionContext::new(); + + // declare a table in memory. In Spark API, this corresponds to createDataFrame(...). + ctx.register_batch("t", batch)?; + Ok(ctx) +} diff --git a/datafusion-examples/examples/advanced_udwf.rs b/datafusion-examples/examples/advanced_udwf.rs new file mode 100644 index 0000000000000..91869d80a41ac --- /dev/null +++ b/datafusion-examples/examples/advanced_udwf.rs @@ -0,0 +1,230 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion::{arrow::datatypes::DataType, logical_expr::Volatility}; +use std::any::Any; + +use arrow::{ + array::{ArrayRef, AsArray, Float64Array}, + datatypes::Float64Type, +}; +use datafusion::error::Result; +use datafusion::prelude::*; +use datafusion_common::ScalarValue; +use datafusion_expr::{ + PartitionEvaluator, Signature, WindowFrame, WindowUDF, WindowUDFImpl, +}; + +/// This example shows how to use the full WindowUDFImpl API to implement a user +/// defined window function. As in the `simple_udwf.rs` example, this struct implements +/// a function `partition_evaluator` that returns the `MyPartitionEvaluator` instance. +/// +/// To do so, we must implement the `WindowUDFImpl` trait. +struct SmoothItUdf { + signature: Signature, +} + +impl SmoothItUdf { + /// Create a new instance of the SmoothItUdf struct + fn new() -> Self { + Self { + signature: Signature::exact( + // this function will always take one argument of type f64 + vec![DataType::Float64], + // this function is deterministic and will always return the same + // result for the same input + Volatility::Immutable, + ), + } + } +} + +impl WindowUDFImpl for SmoothItUdf { + /// We implement as_any so that we can downcast the WindowUDFImpl trait object + fn as_any(&self) -> &dyn Any { + self + } + + /// Return the name of this function + fn name(&self) -> &str { + "smooth_it" + } + + /// Return the "signature" of this function -- namely what types of arguments it will take + fn signature(&self) -> &Signature { + &self.signature + } + + /// What is the type of value that will be returned by this function. + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Float64) + } + + /// Create a `PartitionEvaluator` to evaluate this function on a new + /// partition.
+ fn partition_evaluator(&self) -> Result> { + Ok(Box::new(MyPartitionEvaluator::new())) + } +} + +/// This implements the lowest level evaluation for a window function +/// +/// It handles calculating the value of the window function for each +/// distinct values of `PARTITION BY` (each car type in our example) +#[derive(Clone, Debug)] +struct MyPartitionEvaluator {} + +impl MyPartitionEvaluator { + fn new() -> Self { + Self {} + } +} + +/// Different evaluation methods are called depending on the various +/// settings of WindowUDF. This example uses the simplest and most +/// general, `evaluate`. See `PartitionEvaluator` for the other more +/// advanced uses. +impl PartitionEvaluator for MyPartitionEvaluator { + /// Tell DataFusion the window function varies based on the value + /// of the window frame. + fn uses_window_frame(&self) -> bool { + true + } + + /// This function is called once per input row. + /// + /// `range`specifies which indexes of `values` should be + /// considered for the calculation. + /// + /// Note this is the SLOWEST, but simplest, way to evaluate a + /// window function. It is much faster to implement + /// evaluate_all or evaluate_all_with_rank, if possible + fn evaluate( + &mut self, + values: &[ArrayRef], + range: &std::ops::Range, + ) -> Result { + // Again, the input argument is an array of floating + // point numbers to calculate a moving average + let arr: &Float64Array = values[0].as_ref().as_primitive::(); + + let range_len = range.end - range.start; + + // our smoothing function will average all the values in the + let output = if range_len > 0 { + let sum: f64 = arr.values().iter().skip(range.start).take(range_len).sum(); + Some(sum / range_len as f64) + } else { + None + }; + + Ok(ScalarValue::Float64(output)) + } +} + +// create local execution context with `cars.csv` registered as a table named `cars` +async fn create_context() -> Result { + // declare a new context. In spark API, this corresponds to a new spark SQL session + let ctx = SessionContext::new(); + + // declare a table in memory. In spark API, this corresponds to createDataFrame(...). + println!("pwd: {}", std::env::current_dir().unwrap().display()); + let csv_path = "../../datafusion/core/tests/data/cars.csv".to_string(); + let read_options = CsvReadOptions::default().has_header(true); + + ctx.register_csv("cars", &csv_path, read_options).await?; + Ok(ctx) +} + +#[tokio::main] +async fn main() -> Result<()> { + let ctx = create_context().await?; + let smooth_it = WindowUDF::from(SmoothItUdf::new()); + ctx.register_udwf(smooth_it.clone()); + + // Use SQL to run the new window function + let df = ctx.sql("SELECT * from cars").await?; + // print the results + df.show().await?; + + // Use SQL to run the new window function: + // + // `PARTITION BY car`:each distinct value of car (red, and green) + // should be treated as a separate partition (and will result in + // creating a new `PartitionEvaluator`) + // + // `ORDER BY time`: within each partition ('green' or 'red') the + // rows will be be ordered by the value in the `time` column + // + // `evaluate_inside_range` is invoked with a window defined by the + // SQL. In this case: + // + // The first invocation will be passed row 0, the first row in the + // partition. + // + // The second invocation will be passed rows 0 and 1, the first + // two rows in the partition. + // + // etc. 
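The `evaluate` method above averages the slice of input values selected by `range`. A self-contained sketch of that arithmetic on a hand-made speed column (the values here are illustrative only, not part of the example file):

```rust
use arrow::array::Float64Array;

fn main() {
    // Mirror of the averaging in `MyPartitionEvaluator::evaluate`:
    // a frame covering rows 0..=1 arrives as the range 0..2.
    let speeds = Float64Array::from(vec![10.0, 20.0, 30.0]);
    let range = 0..2usize;
    let range_len = range.end - range.start;

    // Sum only the rows inside the window frame
    let sum: f64 = speeds.values().iter().skip(range.start).take(range_len).sum();

    // The window function returns the mean of those rows: (10 + 20) / 2
    assert_eq!(sum / range_len as f64, 15.0);
}
```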
+ let df = ctx + .sql( + "SELECT \ + car, \ + speed, \ + smooth_it(speed) OVER (PARTITION BY car ORDER BY time) AS smooth_speed,\ + time \ + from cars \ + ORDER BY \ + car", + ) + .await?; + // print the results + df.show().await?; + + // this time, call the new widow function with an explicit + // window so evaluate will be invoked with each window. + // + // `ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING`: each invocation + // sees at most 3 rows: the row before, the current row, and the 1 + // row afterward. + let df = ctx.sql( + "SELECT \ + car, \ + speed, \ + smooth_it(speed) OVER (PARTITION BY car ORDER BY time ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) AS smooth_speed,\ + time \ + from cars \ + ORDER BY \ + car", + ).await?; + // print the results + df.show().await?; + + // Now, run the function using the DataFrame API: + let window_expr = smooth_it.call( + vec![col("speed")], // smooth_it(speed) + vec![col("car")], // PARTITION BY car + vec![col("time").sort(true, true)], // ORDER BY time ASC + WindowFrame::new(false), + ); + let df = ctx.table("cars").await?.window(vec![window_expr])?; + + // print the results + df.show().await?; + + Ok(()) +} diff --git a/datafusion-examples/examples/csv_opener.rs b/datafusion-examples/examples/csv_opener.rs index 15fb07ded4811..96753c8c52608 100644 --- a/datafusion-examples/examples/csv_opener.rs +++ b/datafusion-examples/examples/csv_opener.rs @@ -67,7 +67,6 @@ async fn main() -> Result<()> { limit: Some(5), table_partition_cols: vec![], output_ordering: vec![], - infinite_source: false, }; let result = diff --git a/datafusion-examples/examples/custom_datasource.rs b/datafusion-examples/examples/custom_datasource.rs index 9f25a0b2fa477..69f9c9530e871 100644 --- a/datafusion-examples/examples/custom_datasource.rs +++ b/datafusion-examples/examples/custom_datasource.rs @@ -80,7 +80,7 @@ async fn search_accounts( timeout(Duration::from_secs(10), async move { let result = dataframe.collect().await.unwrap(); - let record_batch = result.get(0).unwrap(); + let record_batch = result.first().unwrap(); assert_eq!(expected_result_length, record_batch.column(1).len()); dbg!(record_batch.columns()); diff --git a/datafusion-examples/examples/dataframe_output.rs b/datafusion-examples/examples/dataframe_output.rs new file mode 100644 index 0000000000000..c773384dfcd50 --- /dev/null +++ b/datafusion-examples/examples/dataframe_output.rs @@ -0,0 +1,76 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion::{dataframe::DataFrameWriteOptions, prelude::*}; +use datafusion_common::{parsers::CompressionTypeVariant, DataFusionError}; + +/// This example demonstrates the various methods to write out a DataFrame to local storage. 
+/// See datafusion-examples/examples/external_dependency/dataframe-to-s3.rs for an example +/// using a remote object store. +#[tokio::main] +async fn main() -> Result<(), DataFusionError> { + let ctx = SessionContext::new(); + + let mut df = ctx.sql("values ('a'), ('b'), ('c')").await.unwrap(); + + // Ensure the column names and types match the target table + df = df.with_column_renamed("column1", "tablecol1").unwrap(); + + ctx.sql( + "create external table + test(tablecol1 varchar) + stored as parquet + location './datafusion-examples/test_table/'", + ) + .await? + .collect() + .await?; + + // This is equivalent to INSERT INTO test VALUES ('a'), ('b'), ('c'). + // The behavior of write_table depends on the TableProvider's implementation + // of the insert_into method. + df.clone() + .write_table("test", DataFrameWriteOptions::new()) + .await?; + + df.clone() + .write_parquet( + "./datafusion-examples/test_parquet/", + DataFrameWriteOptions::new(), + None, + ) + .await?; + + df.clone() + .write_csv( + "./datafusion-examples/test_csv/", + // DataFrameWriteOptions contains options which control how data is written + // such as compression codec + DataFrameWriteOptions::new().with_compression(CompressionTypeVariant::GZIP), + None, + ) + .await?; + + df.clone() + .write_json( + "./datafusion-examples/test_json/", + DataFrameWriteOptions::new(), + ) + .await?; + + Ok(()) +} diff --git a/datafusion-examples/examples/expr_api.rs b/datafusion-examples/examples/expr_api.rs index 97abf4d552a9d..715e1ff2dce60 100644 --- a/datafusion-examples/examples/expr_api.rs +++ b/datafusion-examples/examples/expr_api.rs @@ -15,28 +15,43 @@ // specific language governing permissions and limitations // under the License. +use arrow::array::{BooleanArray, Int32Array}; +use arrow::record_batch::RecordBatch; use datafusion::arrow::datatypes::{DataType, Field, Schema, TimeUnit}; use datafusion::error::Result; use datafusion::optimizer::simplify_expressions::{ExprSimplifier, SimplifyContext}; use datafusion::physical_expr::execution_props::ExecutionProps; +use datafusion::physical_expr::{ + analyze, create_physical_expr, AnalysisContext, ExprBoundaries, PhysicalExpr, +}; use datafusion::prelude::*; use datafusion_common::{ScalarValue, ToDFSchema}; use datafusion_expr::expr::BinaryExpr; -use datafusion_expr::Operator; +use datafusion_expr::interval_arithmetic::Interval; +use datafusion_expr::{ColumnarValue, ExprSchemable, Operator}; +use std::sync::Arc; /// This example demonstrates the DataFusion [`Expr`] API. /// /// DataFusion comes with a powerful and extensive system for /// representing and manipulating expressions such as `A + 5` and `X -/// IN ('foo', 'bar', 'baz')` and many other constructs. +/// IN ('foo', 'bar', 'baz')`. +/// +/// In addition to building and manipulating [`Expr`]s, DataFusion +/// also comes with APIs for evaluation, simplification, and analysis. +/// +/// The code in this example shows how to: +/// 1. Create [`Exprs`] using different APIs: [`main`]` +/// 2. Evaluate [`Exprs`] against data: [`evaluate_demo`] +/// 3. Simplify expressions: [`simplify_demo`] +/// 4. 
Analyze predicates for boundary ranges: [`range_analysis_demo`] #[tokio::main] async fn main() -> Result<()> { // The easiest way to do create expressions is to use the - "fluent"-style API, like this: + "fluent"-style API: let expr = col("a") + lit(5); - // this creates the same expression as the following though with - // much less code, + // The same expression can be created directly, with much more code: let expr2 = Expr::BinaryExpr(BinaryExpr::new( Box::new(col("a")), Operator::Plus, @@ -44,15 +59,51 @@ async fn main() -> Result<()> { )); assert_eq!(expr, expr2); + // See how to evaluate expressions + evaluate_demo()?; + + // See how to simplify expressions simplify_demo()?; + // See how to analyze ranges in expressions + range_analysis_demo()?; + + Ok(()) +} + +/// DataFusion can also evaluate arbitrary expressions on Arrow arrays. +fn evaluate_demo() -> Result<()> { + // For example, let's say you have some integers in an array + let batch = RecordBatch::try_from_iter([( + "a", + Arc::new(Int32Array::from(vec![4, 5, 6, 7, 8, 7, 4])) as _, + )])?; + + // If you want to find all rows where the expression `a < 5 OR a = 8` is true + let expr = col("a").lt(lit(5)).or(col("a").eq(lit(8))); + + // First, you make a "physical expression" from the logical `Expr` + let physical_expr = physical_expr(&batch.schema(), expr)?; + + // Now, you can evaluate the expression against the RecordBatch + let result = physical_expr.evaluate(&batch)?; + + // The result contains an array that is true only where `a < 5 OR a = 8` + let expected_result = Arc::new(BooleanArray::from(vec![ + true, false, false, false, true, false, true, + ])) as _; + assert!( + matches!(&result, ColumnarValue::Array(r) if r == &expected_result), + "result: {:?}", + result + ); + + Ok(()) } -/// In addition to easy construction, DataFusion exposes APIs for -/// working with and simplifying such expressions that call into the -/// same powerful and extensive implementation used for the query -/// engine. +/// In addition to easy construction, DataFusion exposes APIs for simplifying +/// such expressions so they are more efficient to evaluate. This code is also +/// used by the query engine to optimize queries. fn simplify_demo() -> Result<()> { // For example, lets say you have has created an expression such // ts = to_timestamp("2020-09-08T12:00:00+00:00") @@ -94,7 +145,7 @@ fn simplify_demo() -> Result<()> { make_field("b", DataType::Boolean), ]) .to_dfschema_ref()?; - let context = SimplifyContext::new(&props).with_schema(schema); + let context = SimplifyContext::new(&props).with_schema(schema.clone()); let simplifier = ExprSimplifier::new(context); // basic arithmetic simplification @@ -120,6 +171,64 @@ fn simplify_demo() -> Result<()> { col("i").lt(lit(10)) ); + // String --> Date simplification + // `cast('2020-09-01' as date)` --> 18506 + assert_eq!( + simplifier.simplify(lit("2020-09-01").cast_to(&DataType::Date32, &schema)?)?, + lit(ScalarValue::Date32(Some(18506))) + ); + + Ok(()) +} + +/// DataFusion also has APIs for analyzing predicates (boolean expressions) to +/// determine any range restrictions on the inputs required for the predicate +/// to evaluate to true.
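The `range_analysis_demo` function that follows hard-codes the `Date32` day counts 18506 and 18536. As a quick cross-check, separate from the example itself and assuming only the `chrono` crate (already a DataFusion dependency), those are indeed 2020-09-01 and 2020-10-01 expressed as days since the Unix epoch:

```rust
use chrono::NaiveDate;

fn main() {
    // Date32 stores the number of days since 1970-01-01
    let epoch = NaiveDate::from_ymd_opt(1970, 1, 1).unwrap();
    let sep_1 = NaiveDate::from_ymd_opt(2020, 9, 1).unwrap();
    let oct_1 = NaiveDate::from_ymd_opt(2020, 10, 1).unwrap();
    assert_eq!((sep_1 - epoch).num_days(), 18506);
    assert_eq!((oct_1 - epoch).num_days(), 18536);
}
```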
+fn range_analysis_demo() -> Result<()> { + // For example, let's say you are interested in finding data for all days + // in the month of September, 2020 + let september_1 = ScalarValue::Date32(Some(18506)); // 2020-09-01 + let october_1 = ScalarValue::Date32(Some(18536)); // 2020-10-01 + + // The predicate to find all such days could be + // `date > '2020-09-01' AND date < '2020-10-01'` + let expr = col("date") + .gt(lit(september_1.clone())) + .and(col("date").lt(lit(october_1.clone()))); + + // Using the analysis API, DataFusion can determine that the value of `date` + // must be in the range `['2020-09-01', '2020-10-01']`. If your data is + // organized in files according to day, this information permits skipping + // entire files without reading them. + // + // While this simple example could be handled with a special case, the + // DataFusion API handles arbitrary expressions (so for example, you don't + // have to handle the case where the predicate clauses are reversed such as + // `date < '2020-10-01' AND date > '2020-09-01'`). + + // As always, we need to tell DataFusion the type of column "date" + let schema = Schema::new(vec![make_field("date", DataType::Date32)]); + + // You can provide DataFusion any known boundaries on the values of `date` + // (for example, maybe you know you only have data up to `2020-09-15`), but + // in this case, let's say we don't know any boundaries beforehand so we use + // `try_new_unbounded` + let boundaries = ExprBoundaries::try_new_unbounded(&schema)?; + + // Now, we invoke the analysis code to perform the range analysis + let physical_expr = physical_expr(&schema, expr)?; + let analysis_result = + analyze(&physical_expr, AnalysisContext::new(boundaries), &schema)?; + + // The result of the analysis is a range, encoded as an `Interval`, for + // each column in the schema, that must hold in order for the predicate + // to be true.
+ // + // In this case, we can see that, as expected, `analyze` has figured out + // that in this case, `date` must be in the range `['2020-09-01', '2020-10-01']` + let expected_range = Interval::try_new(september_1, october_1)?; + assert_eq!(analysis_result.boundaries[0].interval, expected_range); + Ok(()) } @@ -132,3 +241,18 @@ fn make_ts_field(name: &str) -> Field { let tz = None; make_field(name, DataType::Timestamp(TimeUnit::Nanosecond, tz)) } + +/// Build a physical expression from a logical one, after applying simplification and type coercion +pub fn physical_expr(schema: &Schema, expr: Expr) -> Result> { + let df_schema = schema.clone().to_dfschema_ref()?; + + // Simplify + let props = ExecutionProps::new(); + let simplifier = + ExprSimplifier::new(SimplifyContext::new(&props).with_schema(df_schema.clone())); + + // apply type coercion here to ensure types match + let expr = simplifier.coerce(expr, df_schema.clone())?; + + create_physical_expr(&expr, df_schema.as_ref(), schema, &props) +} diff --git a/datafusion-examples/examples/json_opener.rs b/datafusion-examples/examples/json_opener.rs index 1a3dbe57be75e..ee33f969caa9f 100644 --- a/datafusion-examples/examples/json_opener.rs +++ b/datafusion-examples/examples/json_opener.rs @@ -70,7 +70,6 @@ async fn main() -> Result<()> { limit: Some(5), table_partition_cols: vec![], output_ordering: vec![], - infinite_source: false, }; let result = diff --git a/datafusion-examples/examples/memtable.rs b/datafusion-examples/examples/memtable.rs index bef8f3e5bb8f5..5cce578039e74 100644 --- a/datafusion-examples/examples/memtable.rs +++ b/datafusion-examples/examples/memtable.rs @@ -40,7 +40,7 @@ async fn main() -> Result<()> { timeout(Duration::from_secs(10), async move { let result = dataframe.collect().await.unwrap(); - let record_batch = result.get(0).unwrap(); + let record_batch = result.first().unwrap(); assert_eq!(1, record_batch.column(0).len()); dbg!(record_batch.columns()); diff --git a/datafusion-examples/examples/simple_udaf.rs b/datafusion-examples/examples/simple_udaf.rs index 7aec9698d92f3..2c797f221b2cc 100644 --- a/datafusion-examples/examples/simple_udaf.rs +++ b/datafusion-examples/examples/simple_udaf.rs @@ -154,6 +154,10 @@ async fn main() -> Result<()> { // This is the description of the state. `state()` must match the types here. Arc::new(vec![DataType::Float64, DataType::UInt32]), ); + ctx.register_udaf(geometric_mean.clone()); + + let sql_df = ctx.sql("SELECT geo_mean(a) FROM t").await?; + sql_df.show().await?; // get a DataFrame from the context // this table has 1 column `a` f32 with values {2,4,8,64}, whose geometric mean is 8.0. diff --git a/datafusion-examples/examples/simple_udf.rs b/datafusion-examples/examples/simple_udf.rs index dba4385b8eeaa..39e1e13ce39aa 100644 --- a/datafusion-examples/examples/simple_udf.rs +++ b/datafusion-examples/examples/simple_udf.rs @@ -29,23 +29,23 @@ use datafusion::{error::Result, physical_plan::functions::make_scalar_function}; use datafusion_common::cast::as_float64_array; use std::sync::Arc; -// create local execution context with an in-memory table +/// create local execution context with an in-memory table: +/// +/// ```text +/// +-----+-----+ +/// | a | b | +/// +-----+-----+ +/// | 2.1 | 1.0 | +/// | 3.1 | 2.0 | +/// | 4.1 | 3.0 | +/// | 5.1 | 4.0 | +/// +-----+-----+ +/// ``` fn create_context() -> Result { - use datafusion::arrow::datatypes::{Field, Schema}; - // define a schema. 
- let schema = Arc::new(Schema::new(vec![ - Field::new("a", DataType::Float32, false), - Field::new("b", DataType::Float64, false), - ])); - // define data. - let batch = RecordBatch::try_new( - schema, - vec![ - Arc::new(Float32Array::from(vec![2.1, 3.1, 4.1, 5.1])), - Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0, 4.0])), - ], - )?; + let a: ArrayRef = Arc::new(Float32Array::from(vec![2.1, 3.1, 4.1, 5.1])); + let b: ArrayRef = Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0, 4.0])); + let batch = RecordBatch::try_from_iter(vec![("a", a), ("b", b)])?; // declare a new context. In spark API, this corresponds to a new spark SQLsession let ctx = SessionContext::new(); @@ -140,5 +140,11 @@ async fn main() -> Result<()> { // print the results df.show().await?; + // Given that `pow` is registered in the context, we can also use it in SQL: + let sql_df = ctx.sql("SELECT pow(a, b) FROM t").await?; + + // print the results + sql_df.show().await?; + Ok(()) } diff --git a/datafusion-examples/examples/simple_udtf.rs b/datafusion-examples/examples/simple_udtf.rs new file mode 100644 index 0000000000000..f1d763ba6e413 --- /dev/null +++ b/datafusion-examples/examples/simple_udtf.rs @@ -0,0 +1,178 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::csv::reader::Format; +use arrow::csv::ReaderBuilder; +use async_trait::async_trait; +use datafusion::arrow::datatypes::SchemaRef; +use datafusion::arrow::record_batch::RecordBatch; +use datafusion::datasource::function::TableFunctionImpl; +use datafusion::datasource::TableProvider; +use datafusion::error::Result; +use datafusion::execution::context::{ExecutionProps, SessionState}; +use datafusion::physical_plan::memory::MemoryExec; +use datafusion::physical_plan::ExecutionPlan; +use datafusion::prelude::SessionContext; +use datafusion_common::{plan_err, DataFusionError, ScalarValue}; +use datafusion_expr::{Expr, TableType}; +use datafusion_optimizer::simplify_expressions::{ExprSimplifier, SimplifyContext}; +use std::fs::File; +use std::io::Seek; +use std::path::Path; +use std::sync::Arc; + +// To define your own table function, you only need to do the following 3 things: +// 1. Implement your own [`TableProvider`] +// 2. Implement your own [`TableFunctionImpl`] and return your [`TableProvider`] +// 3. 
Register the function using [`SessionContext::register_udtf`] + +/// This example demonstrates how to register a TableFunction +#[tokio::main] +async fn main() -> Result<()> { + // create local execution context + let ctx = SessionContext::new(); + + // register the table function that will be called in SQL statements by `read_csv` + ctx.register_udtf("read_csv", Arc::new(LocalCsvTableFunc {})); + + let testdata = datafusion::test_util::arrow_test_data(); + let csv_file = format!("{testdata}/csv/aggregate_test_100.csv"); + + // Pass 2 arguments, read csv with at most 2 rows (simplify logic makes 1+1 --> 2) + let df = ctx + .sql(format!("SELECT * FROM read_csv('{csv_file}', 1 + 1);").as_str()) + .await?; + df.show().await?; + + // just run, return all rows + let df = ctx + .sql(format!("SELECT * FROM read_csv('{csv_file}');").as_str()) + .await?; + df.show().await?; + + Ok(()) +} + +/// Table Function that mimics the [`read_csv`] function in DuckDB. +/// +/// Usage: `read_csv(filename, [limit])` +/// +/// [`read_csv`]: https://duckdb.org/docs/data/csv/overview.html +struct LocalCsvTable { + schema: SchemaRef, + limit: Option, + batches: Vec, +} + +#[async_trait] +impl TableProvider for LocalCsvTable { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn schema(&self) -> SchemaRef { + self.schema.clone() + } + + fn table_type(&self) -> TableType { + TableType::Base + } + + async fn scan( + &self, + _state: &SessionState, + projection: Option<&Vec>, + _filters: &[Expr], + _limit: Option, + ) -> Result> { + let batches = if let Some(max_return_lines) = self.limit { + // get max return rows from self.batches + let mut batches = vec![]; + let mut lines = 0; + for batch in &self.batches { + let batch_lines = batch.num_rows(); + if lines + batch_lines > max_return_lines { + let batch_lines = max_return_lines - lines; + batches.push(batch.slice(0, batch_lines)); + break; + } else { + batches.push(batch.clone()); + lines += batch_lines; + } + } + batches + } else { + self.batches.clone() + }; + Ok(Arc::new(MemoryExec::try_new( + &[batches], + TableProvider::schema(self), + projection.cloned(), + )?)) + } +} + +struct LocalCsvTableFunc {} + +impl TableFunctionImpl for LocalCsvTableFunc { + fn call(&self, exprs: &[Expr]) -> Result> { + let Some(Expr::Literal(ScalarValue::Utf8(Some(ref path)))) = exprs.first() else { + return plan_err!("read_csv requires at least one string argument"); + }; + + let limit = exprs + .get(1) + .map(|expr| { + // try to simpify the expression, so 1+2 becomes 3, for example + let execution_props = ExecutionProps::new(); + let info = SimplifyContext::new(&execution_props); + let expr = ExprSimplifier::new(info).simplify(expr.clone())?; + + if let Expr::Literal(ScalarValue::Int64(Some(limit))) = expr { + Ok(limit as usize) + } else { + plan_err!("Limit must be an integer") + } + }) + .transpose()?; + + let (schema, batches) = read_csv_batches(path)?; + + let table = LocalCsvTable { + schema, + limit, + batches, + }; + Ok(Arc::new(table)) + } +} + +fn read_csv_batches(csv_path: impl AsRef) -> Result<(SchemaRef, Vec)> { + let mut file = File::open(csv_path)?; + let (schema, _) = Format::default().infer_schema(&mut file, None)?; + file.rewind()?; + + let reader = ReaderBuilder::new(Arc::new(schema.clone())) + .with_header(true) + .build(file)?; + let mut batches = vec![]; + for bacth in reader { + batches.push(bacth?); + } + let schema = Arc::new(schema); + Ok((schema, batches)) +} diff --git a/datafusion-examples/examples/simple_udwf.rs 
b/datafusion-examples/examples/simple_udwf.rs index d1cbcc7c43896..0d04c093e1478 100644 --- a/datafusion-examples/examples/simple_udwf.rs +++ b/datafusion-examples/examples/simple_udwf.rs @@ -89,7 +89,7 @@ async fn main() -> Result<()> { "SELECT \ car, \ speed, \ - smooth_it(speed) OVER (PARTITION BY car ORDER BY time),\ + smooth_it(speed) OVER (PARTITION BY car ORDER BY time) AS smooth_speed,\ time \ from cars \ ORDER BY \ @@ -109,7 +109,7 @@ async fn main() -> Result<()> { "SELECT \ car, \ speed, \ - smooth_it(speed) OVER (PARTITION BY car ORDER BY time ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING),\ + smooth_it(speed) OVER (PARTITION BY car ORDER BY time ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) AS smooth_speed,\ time \ from cars \ ORDER BY \ diff --git a/datafusion/CHANGELOG.md b/datafusion/CHANGELOG.md index e224b93876551..d64bbeda877df 100644 --- a/datafusion/CHANGELOG.md +++ b/datafusion/CHANGELOG.md @@ -19,6 +19,7 @@ # Changelog +- [34.0.0](../dev/changelog/34.0.0.md) - [33.0.0](../dev/changelog/33.0.0.md) - [32.0.0](../dev/changelog/32.0.0.md) - [31.0.0](../dev/changelog/31.0.0.md) diff --git a/datafusion/common/Cargo.toml b/datafusion/common/Cargo.toml index d04db86b78301..b69e1f7f3d108 100644 --- a/datafusion/common/Cargo.toml +++ b/datafusion/common/Cargo.toml @@ -38,17 +38,25 @@ backtrace = [] pyarrow = ["pyo3", "arrow/pyarrow", "parquet"] [dependencies] -ahash = { version = "0.8", default-features = false, features = ["runtime-rng"] } -apache-avro = { version = "0.16", default-features = false, features = ["bzip", "snappy", "xz", "zstandard"], optional = true } +ahash = { version = "0.8", default-features = false, features = [ + "runtime-rng", +] } +apache-avro = { version = "0.16", default-features = false, features = [ + "bzip", + "snappy", + "xz", + "zstandard", +], optional = true } arrow = { workspace = true } arrow-array = { workspace = true } arrow-buffer = { workspace = true } arrow-schema = { workspace = true } chrono = { workspace = true } half = { version = "2.1", default-features = false } +libc = "0.2.140" num_cpus = { workspace = true } -object_store = { version = "0.7.0", default-features = false, optional = true } -parquet = { workspace = true, optional = true } +object_store = { workspace = true, optional = true } +parquet = { workspace = true, optional = true, default-features = true } pyo3 = { version = "0.20.0", optional = true } sqlparser = { workspace = true } diff --git a/datafusion/common/src/cast.rs b/datafusion/common/src/cast.rs index 4356f36b18d86..088f03e002ed3 100644 --- a/datafusion/common/src/cast.rs +++ b/datafusion/common/src/cast.rs @@ -181,23 +181,17 @@ pub fn as_timestamp_second_array(array: &dyn Array) -> Result<&TimestampSecondAr } // Downcast ArrayRef to IntervalYearMonthArray -pub fn as_interval_ym_array( - array: &dyn Array, -) -> Result<&IntervalYearMonthArray, DataFusionError> { +pub fn as_interval_ym_array(array: &dyn Array) -> Result<&IntervalYearMonthArray> { Ok(downcast_value!(array, IntervalYearMonthArray)) } // Downcast ArrayRef to IntervalDayTimeArray -pub fn as_interval_dt_array( - array: &dyn Array, -) -> Result<&IntervalDayTimeArray, DataFusionError> { +pub fn as_interval_dt_array(array: &dyn Array) -> Result<&IntervalDayTimeArray> { Ok(downcast_value!(array, IntervalDayTimeArray)) } // Downcast ArrayRef to IntervalMonthDayNanoArray -pub fn as_interval_mdn_array( - array: &dyn Array, -) -> Result<&IntervalMonthDayNanoArray, DataFusionError> { +pub fn as_interval_mdn_array(array: &dyn Array) -> 
Result<&IntervalMonthDayNanoArray> { Ok(downcast_value!(array, IntervalMonthDayNanoArray)) } diff --git a/datafusion/common/src/config.rs b/datafusion/common/src/config.rs index 403241fcce581..5b1325ec06eef 100644 --- a/datafusion/common/src/config.rs +++ b/datafusion/common/src/config.rs @@ -273,6 +273,11 @@ config_namespace! { /// memory consumption pub max_buffered_batches_per_output_file: usize, default = 2 + /// Should sub directories be ignored when scanning directories for data + /// files. Defaults to true (ignores subdirectories), consistent with + /// Hive. Note that this setting does not affect reading partitioned + /// tables (e.g. `/table/year=2021/month=01/data.parquet`). + pub listing_table_ignore_subdirectory: bool, default = true } } @@ -427,6 +432,11 @@ config_namespace! { config_namespace! { /// Options related to query optimization pub struct OptimizerOptions { + /// When set to true, the optimizer will push a limit operation into + /// grouped aggregations which have no aggregate expressions, as a soft limit, + /// emitting groups once the limit is reached, before all rows in the group are read. + pub enable_distinct_aggregation_soft_limit: bool, default = true + /// When set to true, the physical plan optimizer will try to add round robin /// repartitioning to increase parallelism to leverage more CPU cores pub enable_round_robin_repartition: bool, default = true @@ -519,6 +529,11 @@ config_namespace! { /// The maximum estimated size in bytes for one input side of a HashJoin /// will be collected into a single partition pub hash_join_single_partition_threshold: usize, default = 1024 * 1024 + + /// The default filter selectivity used by Filter Statistics + /// when an exact selectivity cannot be determined. Valid values are + /// between 0 (no selectivity) and 100 (all rows are selected). + pub default_filter_selectivity: u8, default = 20 } } @@ -872,6 +887,7 @@ config_field!(String); config_field!(bool); config_field!(usize); config_field!(f64); +config_field!(u8); config_field!(u64); /// An implementation trait used to recursively walk configuration diff --git a/datafusion/common/src/dfschema.rs b/datafusion/common/src/dfschema.rs index d8cd103a47778..d6e4490cec4c1 100644 --- a/datafusion/common/src/dfschema.rs +++ b/datafusion/common/src/dfschema.rs @@ -34,10 +34,75 @@ use crate::{ use arrow::compute::can_cast_types; use arrow::datatypes::{DataType, Field, FieldRef, Fields, Schema, SchemaRef}; -/// A reference-counted reference to a `DFSchema`. +/// A reference-counted reference to a [DFSchema]. pub type DFSchemaRef = Arc; -/// DFSchema wraps an Arrow schema and adds relation names +/// DFSchema wraps an Arrow schema and adds relation names. +/// +/// The schema may hold the fields across multiple tables. Some fields may be +/// qualified and some unqualified. A qualified field is a field that has a +/// relation name associated with it. +/// +/// Unqualified fields must be unique not only amongst themselves, but also must +/// have a distinct name from any qualified field names. This allows finding a +/// qualified field by name to be possible, so long as there aren't multiple +/// qualified fields with the same name. +/// +/// There is an alias to `Arc` named [DFSchemaRef]. +/// +/// # Creating qualified schemas +/// +/// Use [DFSchema::try_from_qualified_schema] to create a qualified schema from +/// an Arrow schema. 
+/// +/// ```rust +/// use datafusion_common::{DFSchema, Column}; +/// use arrow_schema::{DataType, Field, Schema}; +/// +/// let arrow_schema = Schema::new(vec![ +/// Field::new("c1", DataType::Int32, false), +/// ]); +/// +/// let df_schema = DFSchema::try_from_qualified_schema("t1", &arrow_schema).unwrap(); +/// let column = Column::from_qualified_name("t1.c1"); +/// assert!(df_schema.has_column(&column)); +/// +/// // Can also access qualified fields with unqualified name, if it's unambiguous +/// let column = Column::from_qualified_name("c1"); +/// assert!(df_schema.has_column(&column)); +/// ``` +/// +/// # Creating unqualified schemas +/// +/// Create an unqualified schema using TryFrom: +/// +/// ```rust +/// use datafusion_common::{DFSchema, Column}; +/// use arrow_schema::{DataType, Field, Schema}; +/// +/// let arrow_schema = Schema::new(vec![ +/// Field::new("c1", DataType::Int32, false), +/// ]); +/// +/// let df_schema = DFSchema::try_from(arrow_schema).unwrap(); +/// let column = Column::new_unqualified("c1"); +/// assert!(df_schema.has_column(&column)); +/// ``` +/// +/// # Converting back to Arrow schema +/// +/// Use the `Into` trait to convert `DFSchema` into an Arrow schema: +/// +/// ```rust +/// use datafusion_common::{DFSchema, DFField}; +/// use arrow_schema::Schema; +/// +/// let df_schema = DFSchema::new(vec![ +/// DFField::new_unqualified("c1", arrow::datatypes::DataType::Int32, false), +/// ]).unwrap(); +/// let schema = Schema::from(df_schema); +/// assert_eq!(schema.fields().len(), 1); +/// ``` #[derive(Debug, Clone, PartialEq, Eq)] pub struct DFSchema { /// Fields @@ -112,6 +177,9 @@ impl DFSchema { } /// Create a `DFSchema` from an Arrow schema and a given qualifier + /// + /// To create a schema from an Arrow schema without a qualifier, use + /// `DFSchema::try_from`. 
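Stepping back to the configuration entries added in `config.rs` above (`listing_table_ignore_subdirectory`, `enable_distinct_aggregation_soft_limit`, `default_filter_selectivity`): a rough sketch of how such options can be set programmatically. The field paths mirror the `config_namespace!` blocks above and should be treated as an assumption, not a guarantee:

```rust
use datafusion_common::config::ConfigOptions;

fn main() {
    let mut options = ConfigOptions::default();
    // Also scan sub directories of a listing table's location
    options.execution.listing_table_ignore_subdirectory = false;
    // Assume filters keep roughly 10% of rows when no exact selectivity is known
    options.optimizer.default_filter_selectivity = 10;
    // Turn off the soft limit pushed into grouped aggregations without aggregate expressions
    options.optimizer.enable_distinct_aggregation_soft_limit = false;
    assert!(!options.optimizer.enable_distinct_aggregation_soft_limit);
}
```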
pub fn try_from_qualified_schema<'a>( qualifier: impl Into>, schema: &Schema, @@ -131,9 +199,16 @@ impl DFSchema { pub fn with_functional_dependencies( mut self, functional_dependencies: FunctionalDependencies, - ) -> Self { - self.functional_dependencies = functional_dependencies; - self + ) -> Result { + if functional_dependencies.is_valid(self.fields.len()) { + self.functional_dependencies = functional_dependencies; + Ok(self) + } else { + _plan_err!( + "Invalid functional dependency: {:?}", + functional_dependencies + ) + } } /// Create a new schema that contains the fields from this schema followed by the fields @@ -272,6 +347,22 @@ impl DFSchema { .collect() } + /// Find all fields indices having the given qualifier + pub fn fields_indices_with_qualified( + &self, + qualifier: &TableReference, + ) -> Vec { + self.fields + .iter() + .enumerate() + .filter_map(|(idx, field)| { + field + .qualifier() + .and_then(|q| q.eq(qualifier).then_some(idx)) + }) + .collect() + } + /// Find all fields match the given name pub fn fields_with_unqualified_name(&self, name: &str) -> Vec<&DFField> { self.fields @@ -1408,8 +1499,8 @@ mod tests { DFSchema::new_with_metadata([a, b].to_vec(), HashMap::new()).unwrap(), ); let schema: Schema = df_schema.as_ref().clone().into(); - let a_df = df_schema.fields.get(0).unwrap().field(); - let a_arrow = schema.fields.get(0).unwrap(); + let a_df = df_schema.fields.first().unwrap().field(); + let a_arrow = schema.fields.first().unwrap(); assert_eq!(a_df.metadata(), a_arrow.metadata()) } diff --git a/datafusion/common/src/display/mod.rs b/datafusion/common/src/display/mod.rs index 766b37ce2891b..4d1d48bf9fcc7 100644 --- a/datafusion/common/src/display/mod.rs +++ b/datafusion/common/src/display/mod.rs @@ -47,6 +47,8 @@ pub enum PlanType { FinalLogicalPlan, /// The initial physical plan, prepared for execution InitialPhysicalPlan, + /// The initial physical plan with stats, prepared for execution + InitialPhysicalPlanWithStats, /// The ExecutionPlan which results from applying an optimizer pass OptimizedPhysicalPlan { /// The name of the optimizer which produced this plan @@ -54,6 +56,8 @@ pub enum PlanType { }, /// The final, fully optimized physical which would be executed FinalPhysicalPlan, + /// The final with stats, fully optimized physical which would be executed + FinalPhysicalPlanWithStats, } impl Display for PlanType { @@ -69,10 +73,14 @@ impl Display for PlanType { } PlanType::FinalLogicalPlan => write!(f, "logical_plan"), PlanType::InitialPhysicalPlan => write!(f, "initial_physical_plan"), + PlanType::InitialPhysicalPlanWithStats => { + write!(f, "initial_physical_plan_with_stats") + } PlanType::OptimizedPhysicalPlan { optimizer_name } => { write!(f, "physical_plan after {optimizer_name}") } PlanType::FinalPhysicalPlan => write!(f, "physical_plan"), + PlanType::FinalPhysicalPlanWithStats => write!(f, "physical_plan_with_stats"), } } } diff --git a/datafusion/common/src/error.rs b/datafusion/common/src/error.rs index 9114c669ab8bc..e58faaa15096d 100644 --- a/datafusion/common/src/error.rs +++ b/datafusion/common/src/error.rs @@ -47,7 +47,8 @@ pub type GenericError = Box; #[derive(Debug)] pub enum DataFusionError { /// Error returned by arrow. - ArrowError(ArrowError), + /// 2nd argument is for optional backtrace + ArrowError(ArrowError, Option), /// Wraps an error from the Parquet crate #[cfg(feature = "parquet")] ParquetError(ParquetError), @@ -60,7 +61,8 @@ pub enum DataFusionError { /// Error associated to I/O operations and associated traits. 
IoError(io::Error), /// Error returned when SQL is syntactically incorrect. - SQL(ParserError), + /// 2nd argument is for optional backtrace + SQL(ParserError, Option), /// Error returned on a branch that we know it is possible /// but to which we still have no implementation for. /// Often, these errors are tracked in our issue tracker. @@ -223,14 +225,14 @@ impl From for DataFusionError { impl From for DataFusionError { fn from(e: ArrowError) -> Self { - DataFusionError::ArrowError(e) + DataFusionError::ArrowError(e, None) } } impl From for ArrowError { fn from(e: DataFusionError) -> Self { match e { - DataFusionError::ArrowError(e) => e, + DataFusionError::ArrowError(e, _) => e, DataFusionError::External(e) => ArrowError::ExternalError(e), other => ArrowError::ExternalError(Box::new(other)), } @@ -267,7 +269,7 @@ impl From for DataFusionError { impl From for DataFusionError { fn from(e: ParserError) -> Self { - DataFusionError::SQL(e) + DataFusionError::SQL(e, None) } } @@ -280,8 +282,9 @@ impl From for DataFusionError { impl Display for DataFusionError { fn fmt(&self, f: &mut Formatter) -> std::fmt::Result { match *self { - DataFusionError::ArrowError(ref desc) => { - write!(f, "Arrow error: {desc}") + DataFusionError::ArrowError(ref desc, ref backtrace) => { + let backtrace = backtrace.clone().unwrap_or("".to_owned()); + write!(f, "Arrow error: {desc}{backtrace}") } #[cfg(feature = "parquet")] DataFusionError::ParquetError(ref desc) => { @@ -294,8 +297,9 @@ impl Display for DataFusionError { DataFusionError::IoError(ref desc) => { write!(f, "IO error: {desc}") } - DataFusionError::SQL(ref desc) => { - write!(f, "SQL error: {desc:?}") + DataFusionError::SQL(ref desc, ref backtrace) => { + let backtrace = backtrace.clone().unwrap_or("".to_owned()); + write!(f, "SQL error: {desc:?}{backtrace}") } DataFusionError::Configuration(ref desc) => { write!(f, "Invalid or Unsupported Configuration: {desc}") @@ -339,7 +343,7 @@ impl Display for DataFusionError { impl Error for DataFusionError { fn source(&self) -> Option<&(dyn Error + 'static)> { match self { - DataFusionError::ArrowError(e) => Some(e), + DataFusionError::ArrowError(e, _) => Some(e), #[cfg(feature = "parquet")] DataFusionError::ParquetError(e) => Some(e), #[cfg(feature = "avro")] @@ -347,7 +351,7 @@ impl Error for DataFusionError { #[cfg(feature = "object_store")] DataFusionError::ObjectStore(e) => Some(e), DataFusionError::IoError(e) => Some(e), - DataFusionError::SQL(e) => Some(e), + DataFusionError::SQL(e, _) => Some(e), DataFusionError::NotImplemented(_) => None, DataFusionError::Internal(_) => None, DataFusionError::Configuration(_) => None, @@ -505,29 +509,56 @@ macro_rules! 
make_error { }; } -// Exposes a macro to create `DataFusionError::Plan` +// Exposes a macro to create `DataFusionError::Plan` with optional backtrace make_error!(plan_err, plan_datafusion_err, Plan); -// Exposes a macro to create `DataFusionError::Internal` +// Exposes a macro to create `DataFusionError::Internal` with optional backtrace make_error!(internal_err, internal_datafusion_err, Internal); -// Exposes a macro to create `DataFusionError::NotImplemented` +// Exposes a macro to create `DataFusionError::NotImplemented` with optional backtrace make_error!(not_impl_err, not_impl_datafusion_err, NotImplemented); -// Exposes a macro to create `DataFusionError::Execution` +// Exposes a macro to create `DataFusionError::Execution` with optional backtrace make_error!(exec_err, exec_datafusion_err, Execution); -// Exposes a macro to create `DataFusionError::SQL` +// Exposes a macro to create `DataFusionError::Substrait` with optional backtrace +make_error!(substrait_err, substrait_datafusion_err, Substrait); + +// Exposes a macro to create `DataFusionError::SQL` with optional backtrace +#[macro_export] +macro_rules! sql_datafusion_err { + ($ERR:expr) => { + DataFusionError::SQL($ERR, Some(DataFusionError::get_back_trace())) + }; +} + +// Exposes a macro to create `Err(DataFusionError::SQL)` with optional backtrace #[macro_export] macro_rules! sql_err { ($ERR:expr) => { - Err(DataFusionError::SQL($ERR)) + Err(datafusion_common::sql_datafusion_err!($ERR)) + }; +} + +// Exposes a macro to create `DataFusionError::ArrowError` with optional backtrace +#[macro_export] +macro_rules! arrow_datafusion_err { + ($ERR:expr) => { + DataFusionError::ArrowError($ERR, Some(DataFusionError::get_back_trace())) + }; +} + +// Exposes a macro to create `Err(DataFusionError::ArrowError)` with optional backtrace +#[macro_export] +macro_rules! 
arrow_err { + ($ERR:expr) => { + Err(datafusion_common::arrow_datafusion_err!($ERR)) }; } // To avoid compiler error when using macro in the same crate: // macros from the current crate cannot be referred to by absolute paths -pub use exec_err as _exec_err; +pub use internal_datafusion_err as _internal_datafusion_err; pub use internal_err as _internal_err; pub use not_impl_err as _not_impl_err; pub use plan_err as _plan_err; @@ -564,18 +595,16 @@ mod test { assert_eq!( err.split(DataFusionError::BACK_TRACE_SEP) .collect::>() - .get(0) + .first() .unwrap(), &"Error during planning: Err" ); - assert!( - err.split(DataFusionError::BACK_TRACE_SEP) - .collect::>() - .get(1) - .unwrap() - .len() - > 0 - ); + assert!(!err + .split(DataFusionError::BACK_TRACE_SEP) + .collect::>() + .get(1) + .unwrap() + .is_empty()); } #[cfg(not(feature = "backtrace"))] @@ -599,9 +628,12 @@ mod test { ); do_root_test( - DataFusionError::ArrowError(ArrowError::ExternalError(Box::new( - DataFusionError::ResourcesExhausted("foo".to_string()), - ))), + DataFusionError::ArrowError( + ArrowError::ExternalError(Box::new(DataFusionError::ResourcesExhausted( + "foo".to_string(), + ))), + None, + ), DataFusionError::ResourcesExhausted("foo".to_string()), ); @@ -620,11 +652,12 @@ mod test { ); do_root_test( - DataFusionError::ArrowError(ArrowError::ExternalError(Box::new( - ArrowError::ExternalError(Box::new(DataFusionError::ResourcesExhausted( - "foo".to_string(), - ))), - ))), + DataFusionError::ArrowError( + ArrowError::ExternalError(Box::new(ArrowError::ExternalError(Box::new( + DataFusionError::ResourcesExhausted("foo".to_string()), + )))), + None, + ), DataFusionError::ResourcesExhausted("foo".to_string()), ); diff --git a/datafusion/common/src/file_options/csv_writer.rs b/datafusion/common/src/file_options/csv_writer.rs index fef4a1d21b4bc..d6046f0219dd3 100644 --- a/datafusion/common/src/file_options/csv_writer.rs +++ b/datafusion/common/src/file_options/csv_writer.rs @@ -91,6 +91,12 @@ impl TryFrom<(&ConfigOptions, &StatementOptions)> for CsvWriterOptions { ) })?) }, + "quote" | "escape" => { + // https://github.com/apache/arrow-rs/issues/5146 + // These two attributes are only available when reading csv files. 
+ // To avoid error + builder + }, _ => return Err(DataFusionError::Configuration(format!("Found unsupported option {option} with value {value} for CSV format!"))) } } diff --git a/datafusion/common/src/file_options/file_type.rs b/datafusion/common/src/file_options/file_type.rs index a07f2e0cb847b..97362bdad3ccc 100644 --- a/datafusion/common/src/file_options/file_type.rs +++ b/datafusion/common/src/file_options/file_type.rs @@ -103,6 +103,7 @@ impl FromStr for FileType { } #[cfg(test)] +#[cfg(feature = "parquet")] mod tests { use crate::error::DataFusionError; use crate::file_options::FileType; diff --git a/datafusion/common/src/file_options/mod.rs b/datafusion/common/src/file_options/mod.rs index b7c1341e30460..1d661b17eb1c0 100644 --- a/datafusion/common/src/file_options/mod.rs +++ b/datafusion/common/src/file_options/mod.rs @@ -296,6 +296,7 @@ impl Display for FileTypeWriterOptions { } #[cfg(test)] +#[cfg(feature = "parquet")] mod tests { use std::collections::HashMap; @@ -506,6 +507,7 @@ mod tests { } #[test] + // for StatementOptions fn test_writeroptions_csv_from_statement_options() -> Result<()> { let mut option_map: HashMap = HashMap::new(); option_map.insert("header".to_owned(), "true".to_owned()); @@ -533,6 +535,7 @@ mod tests { } #[test] + // for StatementOptions fn test_writeroptions_json_from_statement_options() -> Result<()> { let mut option_map: HashMap = HashMap::new(); option_map.insert("compression".to_owned(), "gzip".to_owned()); diff --git a/datafusion/common/src/functional_dependencies.rs b/datafusion/common/src/functional_dependencies.rs index fbddcddab4bcb..1cb1751d713ef 100644 --- a/datafusion/common/src/functional_dependencies.rs +++ b/datafusion/common/src/functional_dependencies.rs @@ -24,6 +24,7 @@ use std::ops::Deref; use std::vec::IntoIter; use crate::error::_plan_err; +use crate::utils::{merge_and_order_indices, set_difference}; use crate::{DFSchema, DFSchemaRef, DataFusionError, JoinType, Result}; use sqlparser::ast::TableConstraint; @@ -271,6 +272,29 @@ impl FunctionalDependencies { self.deps.extend(other.deps); } + /// Sanity checks if functional dependencies are valid. For example, if + /// there are 10 fields, we cannot receive any index further than 9. + pub fn is_valid(&self, n_field: usize) -> bool { + self.deps.iter().all( + |FunctionalDependence { + source_indices, + target_indices, + .. + }| { + source_indices + .iter() + .max() + .map(|&max_index| max_index < n_field) + .unwrap_or(true) + && target_indices + .iter() + .max() + .map(|&max_index| max_index < n_field) + .unwrap_or(true) + }, + ) + } + /// Adds the `offset` value to `source_indices` and `target_indices` for /// each functional dependency. pub fn add_offset(&mut self, offset: usize) { @@ -413,6 +437,14 @@ impl FunctionalDependencies { } } +impl Deref for FunctionalDependencies { + type Target = [FunctionalDependence]; + + fn deref(&self) -> &Self::Target { + self.deps.as_slice() + } +} + /// Calculates functional dependencies for aggregate output, when there is a GROUP BY expression. 
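The new `is_valid` check and the `Deref<Target = [FunctionalDependence]>` impl above can be exercised on their own. A minimal sketch, assuming the existing public constructors `FunctionalDependencies::new` and `FunctionalDependence::new(source_indices, target_indices, nullable)`:

```rust
use datafusion_common::{FunctionalDependence, FunctionalDependencies};

fn main() {
    // One dependency: column 0 determines columns 1 and 2 (e.g. a primary key).
    let deps = FunctionalDependencies::new(vec![FunctionalDependence::new(
        vec![0],
        vec![1, 2],
        false,
    )]);

    // `is_valid` rejects dependencies that reference columns outside the schema width.
    assert!(deps.is_valid(3));
    assert!(!deps.is_valid(2));

    // The new `Deref` impl exposes the dependencies as a slice.
    assert_eq!(deps.len(), 1);
}
```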
pub fn aggregate_functional_dependencies( aggr_input_schema: &DFSchema, @@ -434,44 +466,56 @@ pub fn aggregate_functional_dependencies( } in &func_dependencies.deps { // Keep source indices in a `HashSet` to prevent duplicate entries: - let mut new_source_indices = HashSet::new(); + let mut new_source_indices = vec![]; + let mut new_source_field_names = vec![]; let source_field_names = source_indices .iter() .map(|&idx| aggr_input_fields[idx].qualified_name()) .collect::>(); + for (idx, group_by_expr_name) in group_by_expr_names.iter().enumerate() { // When one of the input determinant expressions matches with // the GROUP BY expression, add the index of the GROUP BY // expression as a new determinant key: if source_field_names.contains(group_by_expr_name) { - new_source_indices.insert(idx); + new_source_indices.push(idx); + new_source_field_names.push(group_by_expr_name.clone()); } } + let existing_target_indices = + get_target_functional_dependencies(aggr_input_schema, group_by_expr_names); + let new_target_indices = get_target_functional_dependencies( + aggr_input_schema, + &new_source_field_names, + ); + let mode = if existing_target_indices == new_target_indices + && new_target_indices.is_some() + { + // If dependency covers all GROUP BY expressions, mode will be `Single`: + Dependency::Single + } else { + // Otherwise, existing mode is preserved: + *mode + }; // All of the composite indices occur in the GROUP BY expression: if new_source_indices.len() == source_indices.len() { aggregate_func_dependencies.push( FunctionalDependence::new( - new_source_indices.into_iter().collect(), + new_source_indices, target_indices.clone(), *nullable, ) - // input uniqueness stays the same when GROUP BY matches with input functional dependence determinants - .with_mode(*mode), + .with_mode(mode), ); } } + // If we have a single GROUP BY key, we can guarantee uniqueness after // aggregation: if group_by_expr_names.len() == 1 { // If `source_indices` contain 0, delete this functional dependency // as it will be added anyway with mode `Dependency::Single`: - if let Some(idx) = aggregate_func_dependencies - .iter() - .position(|item| item.source_indices.contains(&0)) - { - // Delete the functional dependency that contains zeroth idx: - aggregate_func_dependencies.remove(idx); - } + aggregate_func_dependencies.retain(|item| !item.source_indices.contains(&0)); // Add a new functional dependency associated with the whole table: aggregate_func_dependencies.push( // Use nullable property of the group by expression @@ -519,8 +563,61 @@ pub fn get_target_functional_dependencies( combined_target_indices.extend(target_indices.iter()); } } - (!combined_target_indices.is_empty()) - .then_some(combined_target_indices.iter().cloned().collect::>()) + (!combined_target_indices.is_empty()).then_some({ + let mut result = combined_target_indices.into_iter().collect::>(); + result.sort(); + result + }) +} + +/// Returns indices for the minimal subset of GROUP BY expressions that are +/// functionally equivalent to the original set of GROUP BY expressions. 
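The helper declared just below prunes GROUP BY expressions that are functionally determined by other GROUP BY expressions. A simplified, standalone sketch of the idea on plain index vectors (the names here are illustrative and do not use the crate's private utilities):

```rust
/// Illustrative only: if every source index of a dependency is already grouped
/// on, its target indices are redundant GROUP BY keys and can be dropped.
fn minimal_group_by(
    mut group_by: Vec<usize>,
    deps: &[(Vec<usize>, Vec<usize>)],
) -> Vec<usize> {
    for (sources, targets) in deps {
        if sources.iter().all(|s| group_by.contains(s)) {
            // Drop determined columns, keep the determinants.
            group_by.retain(|idx| !targets.contains(idx));
            for s in sources {
                if !group_by.contains(s) {
                    group_by.push(*s);
                }
            }
            group_by.sort_unstable();
        }
    }
    group_by
}

fn main() {
    // Column 0 determines columns 1 and 2, so GROUP BY (0, 1, 2) reduces to GROUP BY (0).
    let deps = vec![(vec![0], vec![1, 2])];
    assert_eq!(minimal_group_by(vec![0, 1, 2], &deps), vec![0]);
}
```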
+pub fn get_required_group_by_exprs_indices( + schema: &DFSchema, + group_by_expr_names: &[String], +) -> Option> { + let dependencies = schema.functional_dependencies(); + let field_names = schema + .fields() + .iter() + .map(|item| item.qualified_name()) + .collect::>(); + let mut groupby_expr_indices = group_by_expr_names + .iter() + .map(|group_by_expr_name| { + field_names + .iter() + .position(|field_name| field_name == group_by_expr_name) + }) + .collect::>>()?; + + groupby_expr_indices.sort(); + for FunctionalDependence { + source_indices, + target_indices, + .. + } in &dependencies.deps + { + if source_indices + .iter() + .all(|source_idx| groupby_expr_indices.contains(source_idx)) + { + // If all source indices are among GROUP BY expression indices, we + // can remove target indices from GROUP BY expression indices and + // use source indices instead. + groupby_expr_indices = set_difference(&groupby_expr_indices, target_indices); + groupby_expr_indices = + merge_and_order_indices(groupby_expr_indices, source_indices); + } + } + groupby_expr_indices + .iter() + .map(|idx| { + group_by_expr_names + .iter() + .position(|name| &field_names[*idx] == name) + }) + .collect() } /// Updates entries inside the `entries` vector with their corresponding diff --git a/datafusion/common/src/hash_utils.rs b/datafusion/common/src/hash_utils.rs index 9198461e00bf9..5c36f41a6e424 100644 --- a/datafusion/common/src/hash_utils.rs +++ b/datafusion/common/src/hash_utils.rs @@ -27,7 +27,8 @@ use arrow::{downcast_dictionary_array, downcast_primitive_array}; use arrow_buffer::i256; use crate::cast::{ - as_boolean_array, as_generic_binary_array, as_primitive_array, as_string_array, + as_boolean_array, as_generic_binary_array, as_large_list_array, as_list_array, + as_primitive_array, as_string_array, as_struct_array, }; use crate::error::{DataFusionError, Result, _internal_err}; @@ -207,6 +208,35 @@ fn hash_dictionary( Ok(()) } +fn hash_struct_array( + array: &StructArray, + random_state: &RandomState, + hashes_buffer: &mut [u64], +) -> Result<()> { + let nulls = array.nulls(); + let num_columns = array.num_columns(); + + // Skip null columns + let valid_indices: Vec = if let Some(nulls) = nulls { + nulls.valid_indices().collect() + } else { + (0..num_columns).collect() + }; + + // Create hashes for each row that combines the hashes over all the column at that row. + // array.len() is the number of rows. + let mut values_hashes = vec![0u64; array.len()]; + create_hashes(array.columns(), random_state, &mut values_hashes)?; + + // Skip the null columns, nulls should get hash value 0. 
+ for i in valid_indices { + let hash = &mut hashes_buffer[i]; + *hash = combine_hashes(*hash, values_hashes[i]); + } + + Ok(()) +} + fn hash_list_array( array: &GenericListArray, random_state: &RandomState, @@ -327,12 +357,16 @@ pub fn create_hashes<'a>( array => hash_dictionary(array, random_state, hashes_buffer, rehash)?, _ => unreachable!() } + DataType::Struct(_) => { + let array = as_struct_array(array)?; + hash_struct_array(array, random_state, hashes_buffer)?; + } DataType::List(_) => { - let array = as_list_array(array); + let array = as_list_array(array)?; hash_list_array(array, random_state, hashes_buffer)?; } DataType::LargeList(_) => { - let array = as_large_list_array(array); + let array = as_large_list_array(array)?; hash_list_array(array, random_state, hashes_buffer)?; } _ => { @@ -515,6 +549,58 @@ mod tests { assert_eq!(hashes[2], hashes[3]); } + #[test] + // Tests actual values of hashes, which are different if forcing collisions + #[cfg(not(feature = "force_hash_collisions"))] + fn create_hashes_for_struct_arrays() { + use arrow_buffer::Buffer; + + let boolarr = Arc::new(BooleanArray::from(vec![ + false, false, true, true, true, true, + ])); + let i32arr = Arc::new(Int32Array::from(vec![10, 10, 20, 20, 30, 31])); + + let struct_array = StructArray::from(( + vec![ + ( + Arc::new(Field::new("bool", DataType::Boolean, false)), + boolarr.clone() as ArrayRef, + ), + ( + Arc::new(Field::new("i32", DataType::Int32, false)), + i32arr.clone() as ArrayRef, + ), + ( + Arc::new(Field::new("i32", DataType::Int32, false)), + i32arr.clone() as ArrayRef, + ), + ( + Arc::new(Field::new("bool", DataType::Boolean, false)), + boolarr.clone() as ArrayRef, + ), + ], + Buffer::from(&[0b001011]), + )); + + assert!(struct_array.is_valid(0)); + assert!(struct_array.is_valid(1)); + assert!(struct_array.is_null(2)); + assert!(struct_array.is_valid(3)); + assert!(struct_array.is_null(4)); + assert!(struct_array.is_null(5)); + + let array = Arc::new(struct_array) as ArrayRef; + + let random_state = RandomState::with_seeds(0, 0, 0, 0); + let mut hashes = vec![0; array.len()]; + create_hashes(&[array], &random_state, &mut hashes).unwrap(); + assert_eq!(hashes[0], hashes[1]); + // same value but the third row ( hashes[2] ) is null + assert_ne!(hashes[2], hashes[3]); + // different values but both are null + assert_eq!(hashes[4], hashes[5]); + } + #[test] // Tests actual values of hashes, which are different if forcing collisions #[cfg(not(feature = "force_hash_collisions"))] diff --git a/datafusion/common/src/lib.rs b/datafusion/common/src/lib.rs index 53c3cfddff8d3..ed547782e4a5e 100644 --- a/datafusion/common/src/lib.rs +++ b/datafusion/common/src/lib.rs @@ -20,6 +20,7 @@ mod dfschema; mod error; mod functional_dependencies; mod join_type; +mod param_value; #[cfg(feature = "pyarrow")] mod pyarrow; mod schema_reference; @@ -34,6 +35,7 @@ pub mod file_options; pub mod format; pub mod hash_utils; pub mod parsers; +pub mod rounding; pub mod scalar; pub mod stats; pub mod test_util; @@ -54,10 +56,12 @@ pub use file_options::file_type::{ }; pub use file_options::FileTypeWriterOptions; pub use functional_dependencies::{ - aggregate_functional_dependencies, get_target_functional_dependencies, Constraint, - Constraints, Dependency, FunctionalDependence, FunctionalDependencies, + aggregate_functional_dependencies, get_required_group_by_exprs_indices, + get_target_functional_dependencies, Constraint, Constraints, Dependency, + FunctionalDependence, FunctionalDependencies, }; pub use join_type::{JoinConstraint, 
JoinSide, JoinType}; +pub use param_value::ParamValues; pub use scalar::{ScalarType, ScalarValue}; pub use schema_reference::{OwnedSchemaReference, SchemaReference}; pub use stats::{ColumnStatistics, Statistics}; diff --git a/datafusion/common/src/param_value.rs b/datafusion/common/src/param_value.rs new file mode 100644 index 0000000000000..3fe2ba99ab836 --- /dev/null +++ b/datafusion/common/src/param_value.rs @@ -0,0 +1,152 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::error::{_internal_err, _plan_err}; +use crate::{DataFusionError, Result, ScalarValue}; +use arrow_schema::DataType; +use std::collections::HashMap; + +/// The parameter value corresponding to the placeholder +#[derive(Debug, Clone)] +pub enum ParamValues { + /// For positional query parameters, like `SELECT * FROM test WHERE a > $1 AND b = $2` + List(Vec), + /// For named query parameters, like `SELECT * FROM test WHERE a > $foo AND b = $goo` + Map(HashMap), +} + +impl ParamValues { + /// Verify parameter list length and type + pub fn verify(&self, expect: &[DataType]) -> Result<()> { + match self { + ParamValues::List(list) => { + // Verify if the number of params matches the number of values + if expect.len() != list.len() { + return _plan_err!( + "Expected {} parameters, got {}", + expect.len(), + list.len() + ); + } + + // Verify if the types of the params matches the types of the values + let iter = expect.iter().zip(list.iter()); + for (i, (param_type, value)) in iter.enumerate() { + if *param_type != value.data_type() { + return _plan_err!( + "Expected parameter of type {:?}, got {:?} at index {}", + param_type, + value.data_type(), + i + ); + } + } + Ok(()) + } + ParamValues::Map(_) => { + // If it is a named query, variables can be reused, + // but the lengths are not necessarily equal + Ok(()) + } + } + } + + pub fn get_placeholders_with_values( + &self, + id: &str, + data_type: Option<&DataType>, + ) -> Result { + match self { + ParamValues::List(list) => { + if id.is_empty() { + return _plan_err!("Empty placeholder id"); + } + // convert id (in format $1, $2, ..) to idx (0, 1, ..) + let idx = id[1..] + .parse::() + .map_err(|e| { + DataFusionError::Internal(format!( + "Failed to parse placeholder id: {e}" + )) + })? 
+ .checked_sub(1); + // value at the idx-th position in param_values should be the value for the placeholder + let value = idx.and_then(|idx| list.get(idx)).ok_or_else(|| { + DataFusionError::Internal(format!( + "No value found for placeholder with id {id}" + )) + })?; + // check if the data type of the value matches the data type of the placeholder + if Some(&value.data_type()) != data_type { + return _internal_err!( + "Placeholder value type mismatch: expected {:?}, got {:?}", + data_type, + value.data_type() + ); + } + Ok(value.clone()) + } + ParamValues::Map(map) => { + // convert name (in format $a, $b, ..) to mapped values (a, b, ..) + let name = &id[1..]; + // value at the name position in param_values should be the value for the placeholder + let value = map.get(name).ok_or_else(|| { + DataFusionError::Internal(format!( + "No value found for placeholder with name {id}" + )) + })?; + // check if the data type of the value matches the data type of the placeholder + if Some(&value.data_type()) != data_type { + return _internal_err!( + "Placeholder value type mismatch: expected {:?}, got {:?}", + data_type, + value.data_type() + ); + } + Ok(value.clone()) + } + } + } +} + +impl From> for ParamValues { + fn from(value: Vec) -> Self { + Self::List(value) + } +} + +impl From> for ParamValues +where + K: Into, +{ + fn from(value: Vec<(K, ScalarValue)>) -> Self { + let value: HashMap = + value.into_iter().map(|(k, v)| (k.into(), v)).collect(); + Self::Map(value) + } +} + +impl From> for ParamValues +where + K: Into, +{ + fn from(value: HashMap) -> Self { + let value: HashMap = + value.into_iter().map(|(k, v)| (k.into(), v)).collect(); + Self::Map(value) + } +} diff --git a/datafusion/common/src/pyarrow.rs b/datafusion/common/src/pyarrow.rs index 59a8b811e3c8e..f4356477532f4 100644 --- a/datafusion/common/src/pyarrow.rs +++ b/datafusion/common/src/pyarrow.rs @@ -54,7 +54,7 @@ impl FromPyArrow for ScalarValue { impl ToPyArrow for ScalarValue { fn to_pyarrow(&self, py: Python) -> PyResult { - let array = self.to_array(); + let array = self.to_array()?; // convert to pyarrow array using C data interface let pyarray = array.to_data().to_pyarrow(py)?; let pyscalar = pyarray.call_method1(py, "__getitem__", (0,))?; @@ -119,7 +119,7 @@ mod tests { ScalarValue::Boolean(Some(true)), ScalarValue::Int32(Some(23)), ScalarValue::Float64(Some(12.34)), - ScalarValue::Utf8(Some("Hello!".to_string())), + ScalarValue::from("Hello!"), ScalarValue::Date32(Some(1234)), ]; diff --git a/datafusion/physical-expr/src/intervals/rounding.rs b/datafusion/common/src/rounding.rs similarity index 98% rename from datafusion/physical-expr/src/intervals/rounding.rs rename to datafusion/common/src/rounding.rs index c1172fba91526..413067ecd61ed 100644 --- a/datafusion/physical-expr/src/intervals/rounding.rs +++ b/datafusion/common/src/rounding.rs @@ -22,8 +22,8 @@ use std::ops::{Add, BitAnd, Sub}; -use datafusion_common::Result; -use datafusion_common::ScalarValue; +use crate::Result; +use crate::ScalarValue; // Define constants for ARM #[cfg(all(target_arch = "aarch64", not(target_os = "windows")))] @@ -162,7 +162,7 @@ impl FloatBits for f64 { /// # Examples /// /// ``` -/// use datafusion_physical_expr::intervals::rounding::next_up; +/// use datafusion_common::rounding::next_up; /// /// let f: f32 = 1.0; /// let next_f = next_up(f); @@ -195,7 +195,7 @@ pub fn next_up(float: F) -> F { /// # Examples /// /// ``` -/// use datafusion_physical_expr::intervals::rounding::next_down; +/// use datafusion_common::rounding::next_down; 
/// /// let f: f32 = 1.0; /// let next_f = next_down(f); diff --git a/datafusion/common/src/scalar.rs b/datafusion/common/src/scalar.rs index 0d701eaad2836..48878aa9bd99a 100644 --- a/datafusion/common/src/scalar.rs +++ b/datafusion/common/src/scalar.rs @@ -24,34 +24,78 @@ use std::convert::{Infallible, TryInto}; use std::str::FromStr; use std::{convert::TryFrom, fmt, iter::repeat, sync::Arc}; +use crate::arrow_datafusion_err; use crate::cast::{ as_decimal128_array, as_decimal256_array, as_dictionary_array, as_fixed_size_binary_array, as_fixed_size_list_array, as_struct_array, }; use crate::error::{DataFusionError, Result, _internal_err, _not_impl_err}; use crate::hash_utils::create_hashes; -use crate::utils::array_into_list_array; -use arrow::buffer::{NullBuffer, OffsetBuffer}; +use crate::utils::{array_into_large_list_array, array_into_list_array}; use arrow::compute::kernels::numeric::*; -use arrow::datatypes::{i256, FieldRef, Fields, SchemaBuilder}; +use arrow::datatypes::{i256, Fields, SchemaBuilder}; +use arrow::util::display::{ArrayFormatter, FormatOptions}; use arrow::{ array::*, compute::kernels::cast::{cast_with_options, CastOptions}, datatypes::{ - ArrowDictionaryKeyType, ArrowNativeType, DataType, Field, Float32Type, - Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, IntervalDayTimeType, - IntervalMonthDayNanoType, IntervalUnit, IntervalYearMonthType, TimeUnit, - TimestampMicrosecondType, TimestampMillisecondType, TimestampNanosecondType, - TimestampSecondType, UInt16Type, UInt32Type, UInt64Type, UInt8Type, - DECIMAL128_MAX_PRECISION, + ArrowDictionaryKeyType, ArrowNativeType, DataType, Field, Float32Type, Int16Type, + Int32Type, Int64Type, Int8Type, IntervalDayTimeType, IntervalMonthDayNanoType, + IntervalUnit, IntervalYearMonthType, TimeUnit, TimestampMicrosecondType, + TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, + UInt16Type, UInt32Type, UInt64Type, UInt8Type, DECIMAL128_MAX_PRECISION, }, }; use arrow_array::cast::as_list_array; +use arrow_array::types::ArrowTimestampType; use arrow_array::{ArrowNativeTypeOp, Scalar}; -/// Represents a dynamically typed, nullable single value. -/// This is the single-valued counter-part to arrow's [`Array`]. +/// A dynamically typed, nullable single value, (the single-valued counter-part +/// to arrow's [`Array`]) /// +/// # Performance +/// +/// In general, please use arrow [`Array`]s rather than [`ScalarValue`] whenever +/// possible, as it is far more efficient for multiple values. +/// +/// # Example +/// ``` +/// # use datafusion_common::ScalarValue; +/// // Create single scalar value for an Int32 value +/// let s1 = ScalarValue::Int32(Some(10)); +/// +/// // You can also create values using the From impl: +/// let s2 = ScalarValue::from(10i32); +/// assert_eq!(s1, s2); +/// ``` +/// +/// # Null Handling +/// +/// `ScalarValue` represents null values in the same way as Arrow. Nulls are +/// "typed" in the sense that a null value in an [`Int32Array`] is different +/// than a null value in a [`Float64Array`], and is different than the values in +/// a [`NullArray`]. 
+/// +/// ``` +/// # fn main() -> datafusion_common::Result<()> { +/// # use std::collections::hash_set::Difference; +/// # use datafusion_common::ScalarValue; +/// # use arrow::datatypes::DataType; +/// // You can create a 'null' Int32 value directly: +/// let s1 = ScalarValue::Int32(None); +/// +/// // You can also create a null value for a given datatype: +/// let s2 = ScalarValue::try_from(&DataType::Int32)?; +/// assert_eq!(s1, s2); +/// +/// // Note that this is DIFFERENT than a `ScalarValue::Null` +/// let s3 = ScalarValue::Null; +/// assert_ne!(s1, s3); +/// # Ok(()) +/// # } +/// ``` +/// +/// # Further Reading /// See [datatypes](https://arrow.apache.org/docs/python/api/datatypes.html) for /// details on datatypes and the [format](https://github.com/apache/arrow/blob/master/format/Schema.fbs#L354-L375) /// for the definitive reference. @@ -95,10 +139,16 @@ pub enum ScalarValue { FixedSizeBinary(i32, Option>), /// large binary LargeBinary(Option>), - /// Fixed size list of nested ScalarValue - Fixedsizelist(Option>, FieldRef, i32), + /// Fixed size list scalar. + /// + /// The array must be a FixedSizeListArray with length 1. + FixedSizeList(ArrayRef), /// Represents a single element of a [`ListArray`] as an [`ArrayRef`] + /// + /// The array must be a ListArray with length 1. List(ArrayRef), + /// The array must be a LargeListArray with length 1. + LargeList(ArrayRef), /// Date stored as a signed 32bit int days since UNIX epoch 1970-01-01 Date32(Option), /// Date stored as a signed 64bit int milliseconds since UNIX epoch 1970-01-01 @@ -196,12 +246,12 @@ impl PartialEq for ScalarValue { (FixedSizeBinary(_, _), _) => false, (LargeBinary(v1), LargeBinary(v2)) => v1.eq(v2), (LargeBinary(_), _) => false, - (Fixedsizelist(v1, t1, l1), Fixedsizelist(v2, t2, l2)) => { - v1.eq(v2) && t1.eq(t2) && l1.eq(l2) - } - (Fixedsizelist(_, _, _), _) => false, + (FixedSizeList(v1), FixedSizeList(v2)) => v1.eq(v2), + (FixedSizeList(_), _) => false, (List(v1), List(v2)) => v1.eq(v2), (List(_), _) => false, + (LargeList(v1), LargeList(v2)) => v1.eq(v2), + (LargeList(_), _) => false, (Date32(v1), Date32(v2)) => v1.eq(v2), (Date32(_), _) => false, (Date64(v1), Date64(v2)) => v1.eq(v2), @@ -310,45 +360,47 @@ impl PartialOrd for ScalarValue { (FixedSizeBinary(_, _), _) => None, (LargeBinary(v1), LargeBinary(v2)) => v1.partial_cmp(v2), (LargeBinary(_), _) => None, - (Fixedsizelist(v1, t1, l1), Fixedsizelist(v2, t2, l2)) => { - if t1.eq(t2) && l1.eq(l2) { - v1.partial_cmp(v2) - } else { - None + (List(arr1), List(arr2)) + | (FixedSizeList(arr1), FixedSizeList(arr2)) + | (LargeList(arr1), LargeList(arr2)) => { + // ScalarValue::List / ScalarValue::FixedSizeList / ScalarValue::LargeList are ensure to have length 1 + assert_eq!(arr1.len(), 1); + assert_eq!(arr2.len(), 1); + + if arr1.data_type() != arr2.data_type() { + return None; } - } - (Fixedsizelist(_, _, _), _) => None, - (List(arr1), List(arr2)) => { - if arr1.data_type() == arr2.data_type() { - let list_arr1 = as_list_array(arr1); - let list_arr2 = as_list_array(arr2); - if list_arr1.len() != list_arr2.len() { - return None; + + fn first_array_for_list(arr: &ArrayRef) -> ArrayRef { + if let Some(arr) = arr.as_list_opt::() { + arr.value(0) + } else if let Some(arr) = arr.as_list_opt::() { + arr.value(0) + } else if let Some(arr) = arr.as_fixed_size_list_opt() { + arr.value(0) + } else { + unreachable!("Since only List / LargeList / FixedSizeList are supported, this should never happen") } - for i in 0..list_arr1.len() { - let arr1 = list_arr1.value(i); - 
let arr2 = list_arr2.value(i); - - let lt_res = - arrow::compute::kernels::cmp::lt(&arr1, &arr2).unwrap(); - let eq_res = - arrow::compute::kernels::cmp::eq(&arr1, &arr2).unwrap(); - - for j in 0..lt_res.len() { - if lt_res.is_valid(j) && lt_res.value(j) { - return Some(Ordering::Less); - } - if eq_res.is_valid(j) && !eq_res.value(j) { - return Some(Ordering::Greater); - } - } + } + + let arr1 = first_array_for_list(arr1); + let arr2 = first_array_for_list(arr2); + + let lt_res = arrow::compute::kernels::cmp::lt(&arr1, &arr2).ok()?; + let eq_res = arrow::compute::kernels::cmp::eq(&arr1, &arr2).ok()?; + + for j in 0..lt_res.len() { + if lt_res.is_valid(j) && lt_res.value(j) { + return Some(Ordering::Less); + } + if eq_res.is_valid(j) && !eq_res.value(j) { + return Some(Ordering::Greater); } - Some(Ordering::Equal) - } else { - None } + + Some(Ordering::Equal) } - (List(_), _) => None, + (List(_), _) | (LargeList(_), _) | (FixedSizeList(_), _) => None, (Date32(v1), Date32(v2)) => v1.partial_cmp(v2), (Date32(_), _) => None, (Date64(v1), Date64(v2)) => v1.partial_cmp(v2), @@ -431,6 +483,10 @@ macro_rules! hash_float_value { hash_float_value!((f64, u64), (f32, u32)); // manual implementation of `Hash` +// +// # Panics +// +// Panics if there is an error when creating hash values for rows impl std::hash::Hash for ScalarValue { fn hash(&self, state: &mut H) { use ScalarValue::*; @@ -461,12 +517,7 @@ impl std::hash::Hash for ScalarValue { Binary(v) => v.hash(state), FixedSizeBinary(_, v) => v.hash(state), LargeBinary(v) => v.hash(state), - Fixedsizelist(v, t, l) => { - v.hash(state); - t.hash(state); - l.hash(state); - } - List(arr) => { + List(arr) | LargeList(arr) | FixedSizeList(arr) => { let arrays = vec![arr.to_owned()]; let hashes_buffer = &mut vec![0; arr.len()]; let random_state = ahash::RandomState::with_seeds(0, 0, 0, 0); @@ -506,15 +557,19 @@ impl std::hash::Hash for ScalarValue { } } -/// return a reference to the values array and the index into it for a +/// Return a reference to the values array and the index into it for a /// dictionary array +/// +/// # Errors +/// +/// Errors if the array cannot be downcasted to DictionaryArray #[inline] pub fn get_dict_value( array: &dyn Array, index: usize, -) -> (&ArrayRef, Option) { - let dict_array = as_dictionary_array::(array).unwrap(); - (dict_array.values(), dict_array.key(index)) +) -> Result<(&ArrayRef, Option)> { + let dict_array = as_dictionary_array::(array)?; + Ok((dict_array.values(), dict_array.key(index))) } /// Create a dictionary array representing `value` repeated `size` @@ -522,9 +577,9 @@ pub fn get_dict_value( fn dict_from_scalar( value: &ScalarValue, size: usize, -) -> ArrayRef { +) -> Result { // values array is one element long (the value) - let values_array = value.to_array_of_size(1); + let values_array = value.to_array_of_size(1)?; // Create a key array with `size` elements, each of 0 let key_array: PrimitiveArray = std::iter::repeat(Some(K::default_value())) @@ -536,11 +591,9 @@ fn dict_from_scalar( // Note: this path could be made faster by using the ArrayData // APIs and skipping validation, if it every comes up in // performance traces. 
- Arc::new( - DictionaryArray::::try_new(key_array, values_array) - // should always be valid by construction above - .expect("Can not construct dictionary array"), - ) + Ok(Arc::new( + DictionaryArray::::try_new(key_array, values_array)?, // should always be valid by construction above + )) } /// Create a dictionary array representing all the values in values @@ -579,24 +632,44 @@ fn dict_from_values( macro_rules! typed_cast_tz { ($array:expr, $index:expr, $ARRAYTYPE:ident, $SCALAR:ident, $TZ:expr) => {{ - let array = $array.as_any().downcast_ref::<$ARRAYTYPE>().unwrap(); - ScalarValue::$SCALAR( + use std::any::type_name; + let array = $array + .as_any() + .downcast_ref::<$ARRAYTYPE>() + .ok_or_else(|| { + DataFusionError::Internal(format!( + "could not cast value to {}", + type_name::<$ARRAYTYPE>() + )) + })?; + Ok::(ScalarValue::$SCALAR( match array.is_null($index) { true => None, false => Some(array.value($index).into()), }, $TZ.clone(), - ) + )) }}; } macro_rules! typed_cast { ($array:expr, $index:expr, $ARRAYTYPE:ident, $SCALAR:ident) => {{ - let array = $array.as_any().downcast_ref::<$ARRAYTYPE>().unwrap(); - ScalarValue::$SCALAR(match array.is_null($index) { - true => None, - false => Some(array.value($index).into()), - }) + use std::any::type_name; + let array = $array + .as_any() + .downcast_ref::<$ARRAYTYPE>() + .ok_or_else(|| { + DataFusionError::Internal(format!( + "could not cast value to {}", + type_name::<$ARRAYTYPE>() + )) + })?; + Ok::(ScalarValue::$SCALAR( + match array.is_null($index) { + true => None, + false => Some(array.value($index).into()), + }, + )) }}; } @@ -628,12 +701,21 @@ macro_rules! build_timestamp_array_from_option { macro_rules! eq_array_primitive { ($array:expr, $index:expr, $ARRAYTYPE:ident, $VALUE:expr) => {{ - let array = $array.as_any().downcast_ref::<$ARRAYTYPE>().unwrap(); + use std::any::type_name; + let array = $array + .as_any() + .downcast_ref::<$ARRAYTYPE>() + .ok_or_else(|| { + DataFusionError::Internal(format!( + "could not cast value to {}", + type_name::<$ARRAYTYPE>() + )) + })?; let is_valid = array.is_valid($index); - match $VALUE { + Ok::(match $VALUE { Some(val) => is_valid && &array.value($index) == val, None => !is_valid, - } + }) }}; } @@ -670,7 +752,7 @@ impl ScalarValue { /// Returns a [`ScalarValue::Utf8`] representing `val` pub fn new_utf8(val: impl Into) -> Self { - ScalarValue::Utf8(Some(val.into())) + ScalarValue::from(val.into()) } /// Returns a [`ScalarValue::IntervalYearMonth`] representing @@ -694,6 +776,20 @@ impl ScalarValue { ScalarValue::IntervalMonthDayNano(Some(val)) } + /// Returns a [`ScalarValue`] representing + /// `value` and `tz_opt` timezone + pub fn new_timestamp( + value: Option, + tz_opt: Option>, + ) -> Self { + match T::UNIT { + TimeUnit::Second => ScalarValue::TimestampSecond(value, tz_opt), + TimeUnit::Millisecond => ScalarValue::TimestampMillisecond(value, tz_opt), + TimeUnit::Microsecond => ScalarValue::TimestampMicrosecond(value, tz_opt), + TimeUnit::Nanosecond => ScalarValue::TimestampNanosecond(value, tz_opt), + } + } + /// Create a zero value in the given type. 
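A minimal sketch of the new `ScalarValue::new_timestamp` constructor added above, assuming `arrow_array` is available for the timestamp marker type:

```rust
use arrow_array::types::TimestampNanosecondType;
use datafusion_common::ScalarValue;

fn main() {
    // The time unit comes from the type parameter; this is equivalent to
    // building the TimestampNanosecond variant by hand.
    let ts = ScalarValue::new_timestamp::<TimestampNanosecondType>(
        Some(1_700_000_000_000_000_000),
        Some("UTC".into()),
    );
    assert_eq!(
        ts,
        ScalarValue::TimestampNanosecond(
            Some(1_700_000_000_000_000_000),
            Some("UTC".into())
        )
    );
}
```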
pub fn new_zero(datatype: &DataType) -> Result { assert!(datatype.is_primitive()); @@ -846,11 +942,9 @@ impl ScalarValue { ScalarValue::Binary(_) => DataType::Binary, ScalarValue::FixedSizeBinary(sz, _) => DataType::FixedSizeBinary(*sz), ScalarValue::LargeBinary(_) => DataType::LargeBinary, - ScalarValue::Fixedsizelist(_, field, length) => DataType::FixedSizeList( - Arc::new(Field::new("item", field.data_type().clone(), true)), - *length, - ), - ScalarValue::List(arr) => arr.data_type().to_owned(), + ScalarValue::List(arr) + | ScalarValue::LargeList(arr) + | ScalarValue::FixedSizeList(arr) => arr.data_type().to_owned(), ScalarValue::Date32(_) => DataType::Date32, ScalarValue::Date64(_) => DataType::Date64, ScalarValue::Time32Second(_) => DataType::Time32(TimeUnit::Second), @@ -924,6 +1018,18 @@ impl ScalarValue { ScalarValue::Decimal256(Some(v), precision, scale) => Ok( ScalarValue::Decimal256(Some(v.neg_wrapping()), *precision, *scale), ), + ScalarValue::TimestampSecond(Some(v), tz) => { + Ok(ScalarValue::TimestampSecond(Some(-v), tz.clone())) + } + ScalarValue::TimestampNanosecond(Some(v), tz) => { + Ok(ScalarValue::TimestampNanosecond(Some(-v), tz.clone())) + } + ScalarValue::TimestampMicrosecond(Some(v), tz) => { + Ok(ScalarValue::TimestampMicrosecond(Some(-v), tz.clone())) + } + ScalarValue::TimestampMillisecond(Some(v), tz) => { + Ok(ScalarValue::TimestampMillisecond(Some(-v), tz.clone())) + } value => _internal_err!( "Can not run arithmetic negative on scalar value {value:?}" ), @@ -935,7 +1041,7 @@ impl ScalarValue { /// NB: operating on `ScalarValue` directly is not efficient, performance sensitive code /// should operate on Arrays directly, using vectorized array kernels pub fn add>(&self, other: T) -> Result { - let r = add_wrapping(&self.to_scalar(), &other.borrow().to_scalar())?; + let r = add_wrapping(&self.to_scalar()?, &other.borrow().to_scalar()?)?; Self::try_from_array(r.as_ref(), 0) } /// Checked addition of `ScalarValue` @@ -943,7 +1049,7 @@ impl ScalarValue { /// NB: operating on `ScalarValue` directly is not efficient, performance sensitive code /// should operate on Arrays directly, using vectorized array kernels pub fn add_checked>(&self, other: T) -> Result { - let r = add(&self.to_scalar(), &other.borrow().to_scalar())?; + let r = add(&self.to_scalar()?, &other.borrow().to_scalar()?)?; Self::try_from_array(r.as_ref(), 0) } @@ -952,7 +1058,7 @@ impl ScalarValue { /// NB: operating on `ScalarValue` directly is not efficient, performance sensitive code /// should operate on Arrays directly, using vectorized array kernels pub fn sub>(&self, other: T) -> Result { - let r = sub_wrapping(&self.to_scalar(), &other.borrow().to_scalar())?; + let r = sub_wrapping(&self.to_scalar()?, &other.borrow().to_scalar()?)?; Self::try_from_array(r.as_ref(), 0) } @@ -961,7 +1067,49 @@ impl ScalarValue { /// NB: operating on `ScalarValue` directly is not efficient, performance sensitive code /// should operate on Arrays directly, using vectorized array kernels pub fn sub_checked>(&self, other: T) -> Result { - let r = sub(&self.to_scalar(), &other.borrow().to_scalar())?; + let r = sub(&self.to_scalar()?, &other.borrow().to_scalar()?)?; + Self::try_from_array(r.as_ref(), 0) + } + + /// Wrapping multiplication of `ScalarValue` + /// + /// NB: operating on `ScalarValue` directly is not efficient, performance sensitive code + /// should operate on Arrays directly, using vectorized array kernels. 
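The `add`/`sub` helpers above now thread `Result` through `to_scalar`, and the hunk below adds `mul`, `div` and `rem` in the same style. A small usage sketch (integer division truncates, following the Arrow kernels):

```rust
use datafusion_common::{Result, ScalarValue};

fn main() -> Result<()> {
    let a = ScalarValue::Int32(Some(7));
    let b = ScalarValue::Int32(Some(3));

    assert_eq!(a.add(&b)?, ScalarValue::Int32(Some(10)));
    assert_eq!(a.sub(&b)?, ScalarValue::Int32(Some(4)));
    assert_eq!(a.mul(&b)?, ScalarValue::Int32(Some(21)));
    assert_eq!(a.div(&b)?, ScalarValue::Int32(Some(2)));
    assert_eq!(a.rem(&b)?, ScalarValue::Int32(Some(1)));
    Ok(())
}
```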
+ pub fn mul>(&self, other: T) -> Result { + let r = mul_wrapping(&self.to_scalar()?, &other.borrow().to_scalar()?)?; + Self::try_from_array(r.as_ref(), 0) + } + + /// Checked multiplication of `ScalarValue` + /// + /// NB: operating on `ScalarValue` directly is not efficient, performance sensitive code + /// should operate on Arrays directly, using vectorized array kernels. + pub fn mul_checked>(&self, other: T) -> Result { + let r = mul(&self.to_scalar()?, &other.borrow().to_scalar()?)?; + Self::try_from_array(r.as_ref(), 0) + } + + /// Performs `lhs / rhs` + /// + /// Overflow or division by zero will result in an error, with exception to + /// floating point numbers, which instead follow the IEEE 754 rules. + /// + /// NB: operating on `ScalarValue` directly is not efficient, performance sensitive code + /// should operate on Arrays directly, using vectorized array kernels. + pub fn div>(&self, other: T) -> Result { + let r = div(&self.to_scalar()?, &other.borrow().to_scalar()?)?; + Self::try_from_array(r.as_ref(), 0) + } + + /// Performs `lhs % rhs` + /// + /// Overflow or division by zero will result in an error, with exception to + /// floating point numbers, which instead follow the IEEE 754 rules. + /// + /// NB: operating on `ScalarValue` directly is not efficient, performance sensitive code + /// should operate on Arrays directly, using vectorized array kernels. + pub fn rem>(&self, other: T) -> Result { + let r = rem(&self.to_scalar()?, &other.borrow().to_scalar()?)?; Self::try_from_array(r.as_ref(), 0) } @@ -997,8 +1145,11 @@ impl ScalarValue { ScalarValue::Binary(v) => v.is_none(), ScalarValue::FixedSizeBinary(_, v) => v.is_none(), ScalarValue::LargeBinary(v) => v.is_none(), - ScalarValue::Fixedsizelist(v, ..) => v.is_none(), - ScalarValue::List(arr) => arr.len() == arr.null_count(), + // arr.len() should be 1 for a list scalar, but we don't seem to + // enforce that anywhere, so we still check against array length. + ScalarValue::List(arr) + | ScalarValue::LargeList(arr) + | ScalarValue::FixedSizeList(arr) => arr.len() == arr.null_count(), ScalarValue::Date32(v) => v.is_none(), ScalarValue::Date64(v) => v.is_none(), ScalarValue::Time32Second(v) => v.is_none(), @@ -1050,7 +1201,11 @@ impl ScalarValue { } /// Converts a scalar value into an 1-row array. 
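`to_array`, `to_array_of_size` and `to_scalar` become fallible in the hunk below, so callers propagate the error instead of relying on internal `unwrap`s. A minimal sketch:

```rust
use arrow_array::Array;
use datafusion_common::cast::as_int32_array;
use datafusion_common::{Result, ScalarValue};

fn main() -> Result<()> {
    let five = ScalarValue::Int32(Some(5));

    // Both conversions now return Result instead of panicking on failure.
    let array = five.to_array_of_size(3)?;
    let array = as_int32_array(&array)?;
    assert_eq!(array.len(), 3);
    assert_eq!(array.value(0), 5);

    // `to_scalar` wraps a 1-row array for use as an arrow `Datum`.
    let _datum = five.to_scalar()?;
    Ok(())
}
```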
- pub fn to_array(&self) -> ArrayRef { + /// + /// # Errors + /// + /// Errors if the ScalarValue cannot be converted into a 1-row array + pub fn to_array(&self) -> Result { self.to_array_of_size(1) } @@ -1059,6 +1214,10 @@ impl ScalarValue { /// /// This can be used to call arrow compute kernels such as `lt` /// + /// # Errors + /// + /// Errors if the ScalarValue cannot be converted into a 1-row array + /// /// # Example /// ``` /// use datafusion_common::ScalarValue; @@ -1069,7 +1228,7 @@ impl ScalarValue { /// /// let result = arrow::compute::kernels::cmp::lt( /// &arr, - /// &five.to_scalar(), + /// &five.to_scalar().unwrap(), /// ).unwrap(); /// /// let expected = BooleanArray::from(vec![ @@ -1082,8 +1241,8 @@ impl ScalarValue { /// assert_eq!(&result, &expected); /// ``` /// [`Datum`]: arrow_array::Datum - pub fn to_scalar(&self) -> Scalar { - Scalar::new(self.to_array_of_size(1)) + pub fn to_scalar(&self) -> Result> { + Ok(Scalar::new(self.to_array_of_size(1)?)) } /// Converts an iterator of references [`ScalarValue`] into an [`ArrayRef`] @@ -1093,6 +1252,10 @@ impl ScalarValue { /// Returns an error if the iterator is empty or if the /// [`ScalarValue`]s are not all the same type /// + /// # Panics + /// + /// Panics if `self` is a dictionary with invalid key type + /// /// # Example /// ``` /// use datafusion_common::ScalarValue; @@ -1197,69 +1360,36 @@ impl ScalarValue { }}; } - macro_rules! build_array_list_primitive { - ($ARRAY_TY:ident, $SCALAR_TY:ident, $NATIVE_TYPE:ident) => {{ - Arc::new(ListArray::from_iter_primitive::<$ARRAY_TY, _, _>( - scalars.into_iter().map(|x| match x { - ScalarValue::List(arr) => { - // `ScalarValue::List` contains a single element `ListArray`. - let list_arr = as_list_array(&arr); - if list_arr.is_null(0) { - None - } else { - let primitive_arr = - list_arr.values().as_primitive::<$ARRAY_TY>(); - Some( - primitive_arr.into_iter().collect::>>(), - ) - } - } - sv => panic!( - "Inconsistent types in ScalarValue::iter_to_array. \ - Expected {:?}, got {:?}", - data_type, sv - ), - }), - )) - }}; - } - - macro_rules! build_array_list_string { - ($BUILDER:ident, $STRING_ARRAY:ident) => {{ - let mut builder = ListBuilder::new($BUILDER::new()); - for scalar in scalars.into_iter() { - match scalar { - ScalarValue::List(arr) => { - // `ScalarValue::List` contains a single element `ListArray`. - let list_arr = as_list_array(&arr); - - if list_arr.is_null(0) { - builder.append(false); - continue; - } - - let string_arr = $STRING_ARRAY(list_arr.values()); + fn build_list_array( + scalars: impl IntoIterator, + ) -> Result { + let arrays = scalars + .into_iter() + .map(|s| s.to_array()) + .collect::>>()?; - for v in string_arr.iter() { - if let Some(v) = v { - builder.values().append_value(v); - } else { - builder.values().append_null(); - } - } - builder.append(true); - } - sv => { - return _internal_err!( - "Inconsistent types in ScalarValue::iter_to_array. \ - Expected List, got {:?}", - sv - ) - } - } + let capacity = Capacities::Array(arrays.iter().map(|arr| arr.len()).sum()); + // ScalarValue::List contains a single element ListArray. + let nulls = arrays + .iter() + .map(|arr| arr.is_null(0)) + .collect::>(); + let arrays_data = arrays.iter().map(|arr| arr.to_data()).collect::>(); + + let arrays_ref = arrays_data.iter().collect::>(); + let mut mutable = + MutableArrayData::with_capacities(arrays_ref, true, capacity); + + // ScalarValue::List contains a single element ListArray. 
+ for (index, is_null) in (0..arrays.len()).zip(nulls.into_iter()) { + if is_null { + mutable.extend_nulls(1) + } else { + mutable.extend(index, 0, 1); } - Arc::new(builder.finish()) - }}; + } + let data = mutable.freeze(); + Ok(arrow_array::make_array(data)) } let array: ArrayRef = match &data_type { @@ -1273,7 +1403,7 @@ impl ScalarValue { ScalarValue::iter_to_decimal256_array(scalars, *precision, *scale)?; Arc::new(decimal_array) } - DataType::Null => ScalarValue::iter_to_null_array(scalars), + DataType::Null => ScalarValue::iter_to_null_array(scalars)?, DataType::Boolean => build_array_primitive!(BooleanArray, Boolean), DataType::Float32 => build_array_primitive!(Float32Array, Float32), DataType::Float64 => build_array_primitive!(Float64Array, Float64), @@ -1336,47 +1466,7 @@ impl ScalarValue { DataType::Interval(IntervalUnit::MonthDayNano) => { build_array_primitive!(IntervalMonthDayNanoArray, IntervalMonthDayNano) } - DataType::List(fields) if fields.data_type() == &DataType::Int8 => { - build_array_list_primitive!(Int8Type, Int8, i8) - } - DataType::List(fields) if fields.data_type() == &DataType::Int16 => { - build_array_list_primitive!(Int16Type, Int16, i16) - } - DataType::List(fields) if fields.data_type() == &DataType::Int32 => { - build_array_list_primitive!(Int32Type, Int32, i32) - } - DataType::List(fields) if fields.data_type() == &DataType::Int64 => { - build_array_list_primitive!(Int64Type, Int64, i64) - } - DataType::List(fields) if fields.data_type() == &DataType::UInt8 => { - build_array_list_primitive!(UInt8Type, UInt8, u8) - } - DataType::List(fields) if fields.data_type() == &DataType::UInt16 => { - build_array_list_primitive!(UInt16Type, UInt16, u16) - } - DataType::List(fields) if fields.data_type() == &DataType::UInt32 => { - build_array_list_primitive!(UInt32Type, UInt32, u32) - } - DataType::List(fields) if fields.data_type() == &DataType::UInt64 => { - build_array_list_primitive!(UInt64Type, UInt64, u64) - } - DataType::List(fields) if fields.data_type() == &DataType::Float32 => { - build_array_list_primitive!(Float32Type, Float32, f32) - } - DataType::List(fields) if fields.data_type() == &DataType::Float64 => { - build_array_list_primitive!(Float64Type, Float64, f64) - } - DataType::List(fields) if fields.data_type() == &DataType::Utf8 => { - build_array_list_string!(StringBuilder, as_string_array) - } - DataType::List(fields) if fields.data_type() == &DataType::LargeUtf8 => { - build_array_list_string!(LargeStringBuilder, as_largestring_array) - } - DataType::List(_) => { - // Fallback case handling homogeneous lists with any ScalarValue element type - let list_array = ScalarValue::iter_to_array_list(scalars)?; - Arc::new(list_array) - } + DataType::List(_) | DataType::LargeList(_) => build_list_array(scalars)?, DataType::Struct(fields) => { // Initialize a Vector to store the ScalarValues for each column let mut columns: Vec> = @@ -1432,7 +1522,7 @@ impl ScalarValue { if &inner_key_type == key_type { Ok(*scalar) } else { - panic!("Expected inner key type of {key_type} but found: {inner_key_type}, value was ({scalar:?})"); + _internal_err!("Expected inner key type of {key_type} but found: {inner_key_type}, value was ({scalar:?})") } } _ => { @@ -1489,7 +1579,6 @@ impl ScalarValue { | DataType::Time64(TimeUnit::Millisecond) | DataType::Duration(_) | DataType::FixedSizeList(_, _) - | DataType::LargeList(_) | DataType::Union(_, _) | DataType::Map(_, _) | DataType::RunEndEncoded(_, _) => { @@ -1504,15 +1593,19 @@ impl ScalarValue { Ok(array) } - fn 
iter_to_null_array(scalars: impl IntoIterator) -> ArrayRef { - let length = - scalars - .into_iter() - .fold(0usize, |r, element: ScalarValue| match element { - ScalarValue::Null => r + 1, - _ => unreachable!(), - }); - new_null_array(&DataType::Null, length) + fn iter_to_null_array( + scalars: impl IntoIterator, + ) -> Result { + let length = scalars.into_iter().try_fold( + 0usize, + |r, element: ScalarValue| match element { + ScalarValue::Null => Ok::(r + 1), + s => { + _internal_err!("Expected ScalarValue::Null element. Received {s:?}") + } + }, + )?; + Ok(new_null_array(&DataType::Null, length)) } fn iter_to_decimal_array( @@ -1523,10 +1616,12 @@ impl ScalarValue { let array = scalars .into_iter() .map(|element: ScalarValue| match element { - ScalarValue::Decimal128(v1, _, _) => v1, - _ => unreachable!(), + ScalarValue::Decimal128(v1, _, _) => Ok(v1), + s => { + _internal_err!("Expected ScalarValue::Null element. Received {s:?}") + } }) - .collect::() + .collect::>()? .with_precision_and_scale(precision, scale)?; Ok(array) } @@ -1539,85 +1634,34 @@ impl ScalarValue { let array = scalars .into_iter() .map(|element: ScalarValue| match element { - ScalarValue::Decimal256(v1, _, _) => v1, - _ => unreachable!(), + ScalarValue::Decimal256(v1, _, _) => Ok(v1), + s => { + _internal_err!( + "Expected ScalarValue::Decimal256 element. Received {s:?}" + ) + } }) - .collect::() + .collect::>()? .with_precision_and_scale(precision, scale)?; Ok(array) } - /// This function build with nulls with nulls buffer. - fn iter_to_array_list( - scalars: impl IntoIterator, - ) -> Result> { - let mut elements: Vec = vec![]; - let mut valid = BooleanBufferBuilder::new(0); - let mut offsets = vec![]; - - for scalar in scalars { - if let ScalarValue::List(arr) = scalar { - // `ScalarValue::List` contains a single element `ListArray`. - let list_arr = as_list_array(&arr); - - if list_arr.is_null(0) { - // Repeat previous offset index - offsets.push(0); - - // Element is null - valid.append(false); - } else { - let arr = list_arr.values().to_owned(); - offsets.push(arr.len()); - elements.push(arr); - - // Element is valid - valid.append(true); - } - } else { - return _internal_err!( - "Expected ScalarValue::List element. 
Received {scalar:?}" - ); - } - } - - // Concatenate element arrays to create single flat array - let element_arrays: Vec<&dyn Array> = - elements.iter().map(|a| a.as_ref()).collect(); - - let flat_array = match arrow::compute::concat(&element_arrays) { - Ok(flat_array) => flat_array, - Err(err) => return Err(DataFusionError::ArrowError(err)), - }; - - let buffer = valid.finish(); - - let list_array = ListArray::new( - Arc::new(Field::new("item", flat_array.data_type().clone(), true)), - OffsetBuffer::::from_lengths(offsets), - flat_array, - Some(NullBuffer::new(buffer)), - ); - - Ok(list_array) - } - fn build_decimal_array( value: Option, precision: u8, scale: i8, size: usize, - ) -> Decimal128Array { + ) -> Result { match value { Some(val) => Decimal128Array::from(vec![val; size]) .with_precision_and_scale(precision, scale) - .unwrap(), + .map_err(|e| arrow_datafusion_err!(e)), None => { let mut builder = Decimal128Array::builder(size) .with_precision_and_scale(precision, scale) - .unwrap(); + .map_err(|e| arrow_datafusion_err!(e))?; builder.append_nulls(size); - builder.finish() + Ok(builder.finish()) } } } @@ -1627,12 +1671,12 @@ impl ScalarValue { precision: u8, scale: i8, size: usize, - ) -> Decimal256Array { + ) -> Result { std::iter::repeat(value) .take(size) .collect::() .with_precision_and_scale(precision, scale) - .unwrap() + .map_err(|e| arrow_datafusion_err!(e)) } /// Converts `Vec` where each element has type corresponding to @@ -1670,14 +1714,57 @@ impl ScalarValue { Arc::new(array_into_list_array(values)) } + /// Converts `Vec` where each element has type corresponding to + /// `data_type`, to a [`LargeListArray`]. + /// + /// Example + /// ``` + /// use datafusion_common::ScalarValue; + /// use arrow::array::{LargeListArray, Int32Array}; + /// use arrow::datatypes::{DataType, Int32Type}; + /// use datafusion_common::cast::as_large_list_array; + /// + /// let scalars = vec![ + /// ScalarValue::Int32(Some(1)), + /// ScalarValue::Int32(None), + /// ScalarValue::Int32(Some(2)) + /// ]; + /// + /// let array = ScalarValue::new_large_list(&scalars, &DataType::Int32); + /// let result = as_large_list_array(&array).unwrap(); + /// + /// let expected = LargeListArray::from_iter_primitive::( + /// vec![ + /// Some(vec![Some(1), None, Some(2)]) + /// ]); + /// + /// assert_eq!(result, &expected); + /// ``` + pub fn new_large_list(values: &[ScalarValue], data_type: &DataType) -> ArrayRef { + let values = if values.is_empty() { + new_empty_array(data_type) + } else { + Self::iter_to_array(values.iter().cloned()).unwrap() + }; + Arc::new(array_into_large_list_array(values)) + } + /// Converts a scalar value into an array of `size` rows. 
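The removed `iter_to_array_list` above is superseded by the generic `build_list_array` path, while `ScalarValue::iter_to_array` remains the public entry point for collecting scalars into an array. A small usage sketch:

```rust
use arrow_array::Array;
use datafusion_common::cast::as_string_array;
use datafusion_common::{Result, ScalarValue};

fn main() -> Result<()> {
    // All scalars must share one type; nulls are represented per element.
    let scalars = vec![
        ScalarValue::from("a"),
        ScalarValue::Utf8(None),
        ScalarValue::from("c"),
    ];
    let array = ScalarValue::iter_to_array(scalars)?;
    let array = as_string_array(&array)?;
    assert_eq!(array.len(), 3);
    assert!(array.is_null(1));
    assert_eq!(array.value(2), "c");
    Ok(())
}
```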
- pub fn to_array_of_size(&self, size: usize) -> ArrayRef { - match self { + /// + /// # Errors + /// + /// Errors if `self` is + /// - a decimal that fails be converted to a decimal array of size + /// - a `Fixedsizelist` that fails to be concatenated into an array of size + /// - a `List` that fails to be concatenated into an array of size + /// - a `Dictionary` that fails be converted to a dictionary array of size + pub fn to_array_of_size(&self, size: usize) -> Result { + Ok(match self { ScalarValue::Decimal128(e, precision, scale) => Arc::new( - ScalarValue::build_decimal_array(*e, *precision, *scale, size), + ScalarValue::build_decimal_array(*e, *precision, *scale, size)?, ), ScalarValue::Decimal256(e, precision, scale) => Arc::new( - ScalarValue::build_decimal256_array(*e, *precision, *scale, size), + ScalarValue::build_decimal256_array(*e, *precision, *scale, size)?, ), ScalarValue::Boolean(e) => { Arc::new(BooleanArray::from(vec![*e; size])) as ArrayRef @@ -1789,14 +1876,14 @@ impl ScalarValue { .collect::(), ), }, - ScalarValue::Fixedsizelist(..) => { - unimplemented!("FixedSizeList is not supported yet") - } - ScalarValue::List(arr) => { + ScalarValue::List(arr) + | ScalarValue::LargeList(arr) + | ScalarValue::FixedSizeList(arr) => { let arrays = std::iter::repeat(arr.as_ref()) .take(size) .collect::>(); - arrow::compute::concat(arrays.as_slice()).unwrap() + arrow::compute::concat(arrays.as_slice()) + .map_err(|e| arrow_datafusion_err!(e))? } ScalarValue::Date32(e) => { build_array_from_option!(Date32, Date32Array, e, size) @@ -1891,13 +1978,13 @@ impl ScalarValue { ), ScalarValue::Struct(values, fields) => match values { Some(values) => { - let field_values: Vec<_> = fields + let field_values = fields .iter() .zip(values.iter()) .map(|(field, value)| { - (field.clone(), value.to_array_of_size(size)) + Ok((field.clone(), value.to_array_of_size(size)?)) }) - .collect(); + .collect::>>()?; Arc::new(StructArray::from(field_values)) } @@ -1909,19 +1996,19 @@ impl ScalarValue { ScalarValue::Dictionary(key_type, v) => { // values array is one element long (the value) match key_type.as_ref() { - DataType::Int8 => dict_from_scalar::(v, size), - DataType::Int16 => dict_from_scalar::(v, size), - DataType::Int32 => dict_from_scalar::(v, size), - DataType::Int64 => dict_from_scalar::(v, size), - DataType::UInt8 => dict_from_scalar::(v, size), - DataType::UInt16 => dict_from_scalar::(v, size), - DataType::UInt32 => dict_from_scalar::(v, size), - DataType::UInt64 => dict_from_scalar::(v, size), + DataType::Int8 => dict_from_scalar::(v, size)?, + DataType::Int16 => dict_from_scalar::(v, size)?, + DataType::Int32 => dict_from_scalar::(v, size)?, + DataType::Int64 => dict_from_scalar::(v, size)?, + DataType::UInt8 => dict_from_scalar::(v, size)?, + DataType::UInt16 => dict_from_scalar::(v, size)?, + DataType::UInt32 => dict_from_scalar::(v, size)?, + DataType::UInt64 => dict_from_scalar::(v, size)?, _ => unreachable!("Invalid dictionary keys type: {:?}", key_type), } } ScalarValue::Null => new_null_array(&DataType::Null, size), - } + }) } fn get_decimal_value_from_array( @@ -2037,23 +2124,25 @@ impl ScalarValue { array, index, *precision, *scale, )? 
} - DataType::Boolean => typed_cast!(array, index, BooleanArray, Boolean), - DataType::Float64 => typed_cast!(array, index, Float64Array, Float64), - DataType::Float32 => typed_cast!(array, index, Float32Array, Float32), - DataType::UInt64 => typed_cast!(array, index, UInt64Array, UInt64), - DataType::UInt32 => typed_cast!(array, index, UInt32Array, UInt32), - DataType::UInt16 => typed_cast!(array, index, UInt16Array, UInt16), - DataType::UInt8 => typed_cast!(array, index, UInt8Array, UInt8), - DataType::Int64 => typed_cast!(array, index, Int64Array, Int64), - DataType::Int32 => typed_cast!(array, index, Int32Array, Int32), - DataType::Int16 => typed_cast!(array, index, Int16Array, Int16), - DataType::Int8 => typed_cast!(array, index, Int8Array, Int8), - DataType::Binary => typed_cast!(array, index, BinaryArray, Binary), + DataType::Boolean => typed_cast!(array, index, BooleanArray, Boolean)?, + DataType::Float64 => typed_cast!(array, index, Float64Array, Float64)?, + DataType::Float32 => typed_cast!(array, index, Float32Array, Float32)?, + DataType::UInt64 => typed_cast!(array, index, UInt64Array, UInt64)?, + DataType::UInt32 => typed_cast!(array, index, UInt32Array, UInt32)?, + DataType::UInt16 => typed_cast!(array, index, UInt16Array, UInt16)?, + DataType::UInt8 => typed_cast!(array, index, UInt8Array, UInt8)?, + DataType::Int64 => typed_cast!(array, index, Int64Array, Int64)?, + DataType::Int32 => typed_cast!(array, index, Int32Array, Int32)?, + DataType::Int16 => typed_cast!(array, index, Int16Array, Int16)?, + DataType::Int8 => typed_cast!(array, index, Int8Array, Int8)?, + DataType::Binary => typed_cast!(array, index, BinaryArray, Binary)?, DataType::LargeBinary => { - typed_cast!(array, index, LargeBinaryArray, LargeBinary) + typed_cast!(array, index, LargeBinaryArray, LargeBinary)? + } + DataType::Utf8 => typed_cast!(array, index, StringArray, Utf8)?, + DataType::LargeUtf8 => { + typed_cast!(array, index, LargeStringArray, LargeUtf8)? } - DataType::Utf8 => typed_cast!(array, index, StringArray, Utf8), - DataType::LargeUtf8 => typed_cast!(array, index, LargeStringArray, LargeUtf8), DataType::List(_) => { let list_array = as_list_array(array); let nested_array = list_array.value(index); @@ -2062,6 +2151,14 @@ impl ScalarValue { ScalarValue::List(arr) } + DataType::LargeList(_) => { + let list_array = as_large_list_array(array); + let nested_array = list_array.value(index); + // Produces a single element `LargeListArray` with the value at `index`. + let arr = Arc::new(array_into_large_list_array(nested_array)); + + ScalarValue::LargeList(arr) + } // TODO: There is no test for FixedSizeList now, add it later DataType::FixedSizeList(_, _) => { let list_array = as_fixed_size_list_array(array)?; @@ -2071,70 +2168,58 @@ impl ScalarValue { ScalarValue::List(arr) } - DataType::Date32 => { - typed_cast!(array, index, Date32Array, Date32) - } - DataType::Date64 => { - typed_cast!(array, index, Date64Array, Date64) - } + DataType::Date32 => typed_cast!(array, index, Date32Array, Date32)?, + DataType::Date64 => typed_cast!(array, index, Date64Array, Date64)?, DataType::Time32(TimeUnit::Second) => { - typed_cast!(array, index, Time32SecondArray, Time32Second) + typed_cast!(array, index, Time32SecondArray, Time32Second)? } DataType::Time32(TimeUnit::Millisecond) => { - typed_cast!(array, index, Time32MillisecondArray, Time32Millisecond) + typed_cast!(array, index, Time32MillisecondArray, Time32Millisecond)? 
} DataType::Time64(TimeUnit::Microsecond) => { - typed_cast!(array, index, Time64MicrosecondArray, Time64Microsecond) + typed_cast!(array, index, Time64MicrosecondArray, Time64Microsecond)? } DataType::Time64(TimeUnit::Nanosecond) => { - typed_cast!(array, index, Time64NanosecondArray, Time64Nanosecond) - } - DataType::Timestamp(TimeUnit::Second, tz_opt) => { - typed_cast_tz!( - array, - index, - TimestampSecondArray, - TimestampSecond, - tz_opt - ) - } - DataType::Timestamp(TimeUnit::Millisecond, tz_opt) => { - typed_cast_tz!( - array, - index, - TimestampMillisecondArray, - TimestampMillisecond, - tz_opt - ) - } - DataType::Timestamp(TimeUnit::Microsecond, tz_opt) => { - typed_cast_tz!( - array, - index, - TimestampMicrosecondArray, - TimestampMicrosecond, - tz_opt - ) - } - DataType::Timestamp(TimeUnit::Nanosecond, tz_opt) => { - typed_cast_tz!( - array, - index, - TimestampNanosecondArray, - TimestampNanosecond, - tz_opt - ) + typed_cast!(array, index, Time64NanosecondArray, Time64Nanosecond)? } + DataType::Timestamp(TimeUnit::Second, tz_opt) => typed_cast_tz!( + array, + index, + TimestampSecondArray, + TimestampSecond, + tz_opt + )?, + DataType::Timestamp(TimeUnit::Millisecond, tz_opt) => typed_cast_tz!( + array, + index, + TimestampMillisecondArray, + TimestampMillisecond, + tz_opt + )?, + DataType::Timestamp(TimeUnit::Microsecond, tz_opt) => typed_cast_tz!( + array, + index, + TimestampMicrosecondArray, + TimestampMicrosecond, + tz_opt + )?, + DataType::Timestamp(TimeUnit::Nanosecond, tz_opt) => typed_cast_tz!( + array, + index, + TimestampNanosecondArray, + TimestampNanosecond, + tz_opt + )?, DataType::Dictionary(key_type, _) => { let (values_array, values_index) = match key_type.as_ref() { - DataType::Int8 => get_dict_value::(array, index), - DataType::Int16 => get_dict_value::(array, index), - DataType::Int32 => get_dict_value::(array, index), - DataType::Int64 => get_dict_value::(array, index), - DataType::UInt8 => get_dict_value::(array, index), - DataType::UInt16 => get_dict_value::(array, index), - DataType::UInt32 => get_dict_value::(array, index), - DataType::UInt64 => get_dict_value::(array, index), + DataType::Int8 => get_dict_value::(array, index)?, + DataType::Int16 => get_dict_value::(array, index)?, + DataType::Int32 => get_dict_value::(array, index)?, + DataType::Int64 => get_dict_value::(array, index)?, + DataType::UInt8 => get_dict_value::(array, index)?, + DataType::UInt16 => get_dict_value::(array, index)?, + DataType::UInt32 => get_dict_value::(array, index)?, + DataType::UInt64 => get_dict_value::(array, index)?, _ => unreachable!("Invalid dictionary keys type: {:?}", key_type), }; // look up the index in the values dictionary @@ -2173,31 +2258,29 @@ impl ScalarValue { ) } DataType::Interval(IntervalUnit::DayTime) => { - typed_cast!(array, index, IntervalDayTimeArray, IntervalDayTime) + typed_cast!(array, index, IntervalDayTimeArray, IntervalDayTime)? } DataType::Interval(IntervalUnit::YearMonth) => { - typed_cast!(array, index, IntervalYearMonthArray, IntervalYearMonth) - } - DataType::Interval(IntervalUnit::MonthDayNano) => { - typed_cast!( - array, - index, - IntervalMonthDayNanoArray, - IntervalMonthDayNano - ) + typed_cast!(array, index, IntervalYearMonthArray, IntervalYearMonth)? 
} + DataType::Interval(IntervalUnit::MonthDayNano) => typed_cast!( + array, + index, + IntervalMonthDayNanoArray, + IntervalMonthDayNano + )?, DataType::Duration(TimeUnit::Second) => { - typed_cast!(array, index, DurationSecondArray, DurationSecond) + typed_cast!(array, index, DurationSecondArray, DurationSecond)? } DataType::Duration(TimeUnit::Millisecond) => { - typed_cast!(array, index, DurationMillisecondArray, DurationMillisecond) + typed_cast!(array, index, DurationMillisecondArray, DurationMillisecond)? } DataType::Duration(TimeUnit::Microsecond) => { - typed_cast!(array, index, DurationMicrosecondArray, DurationMicrosecond) + typed_cast!(array, index, DurationMicrosecondArray, DurationMicrosecond)? } DataType::Duration(TimeUnit::Nanosecond) => { - typed_cast!(array, index, DurationNanosecondArray, DurationNanosecond) + typed_cast!(array, index, DurationNanosecondArray, DurationNanosecond)? } other => { @@ -2210,12 +2293,12 @@ impl ScalarValue { /// Try to parse `value` into a ScalarValue of type `target_type` pub fn try_from_string(value: String, target_type: &DataType) -> Result { - let value = ScalarValue::Utf8(Some(value)); + let value = ScalarValue::from(value); let cast_options = CastOptions { safe: false, format_options: Default::default(), }; - let cast_arr = cast_with_options(&value.to_array(), target_type, &cast_options)?; + let cast_arr = cast_with_options(&value.to_array()?, target_type, &cast_options)?; ScalarValue::try_from_array(&cast_arr, 0) } @@ -2273,9 +2356,19 @@ impl ScalarValue { /// /// This function has a few narrow usescases such as hash table key /// comparisons where comparing a single row at a time is necessary. + /// + /// # Errors + /// + /// Errors if + /// - it fails to downcast `array` to the data type of `self` + /// - `self` is a `Struct` + /// + /// # Panics + /// + /// Panics if `self` is a dictionary with invalid key type #[inline] - pub fn eq_array(&self, array: &ArrayRef, index: usize) -> bool { - match self { + pub fn eq_array(&self, array: &ArrayRef, index: usize) -> Result { + Ok(match self { ScalarValue::Decimal128(v, precision, scale) => { ScalarValue::eq_array_decimal( array, @@ -2283,8 +2376,7 @@ impl ScalarValue { v.as_ref(), *precision, *scale, - ) - .unwrap() + )? } ScalarValue::Decimal256(v, precision, scale) => { ScalarValue::eq_array_decimal256( @@ -2293,119 +2385,134 @@ impl ScalarValue { v.as_ref(), *precision, *scale, - ) - .unwrap() + )? } ScalarValue::Boolean(val) => { - eq_array_primitive!(array, index, BooleanArray, val) + eq_array_primitive!(array, index, BooleanArray, val)? } ScalarValue::Float32(val) => { - eq_array_primitive!(array, index, Float32Array, val) + eq_array_primitive!(array, index, Float32Array, val)? } ScalarValue::Float64(val) => { - eq_array_primitive!(array, index, Float64Array, val) + eq_array_primitive!(array, index, Float64Array, val)? + } + ScalarValue::Int8(val) => eq_array_primitive!(array, index, Int8Array, val)?, + ScalarValue::Int16(val) => { + eq_array_primitive!(array, index, Int16Array, val)? + } + ScalarValue::Int32(val) => { + eq_array_primitive!(array, index, Int32Array, val)? + } + ScalarValue::Int64(val) => { + eq_array_primitive!(array, index, Int64Array, val)? + } + ScalarValue::UInt8(val) => { + eq_array_primitive!(array, index, UInt8Array, val)? 
} - ScalarValue::Int8(val) => eq_array_primitive!(array, index, Int8Array, val), - ScalarValue::Int16(val) => eq_array_primitive!(array, index, Int16Array, val), - ScalarValue::Int32(val) => eq_array_primitive!(array, index, Int32Array, val), - ScalarValue::Int64(val) => eq_array_primitive!(array, index, Int64Array, val), - ScalarValue::UInt8(val) => eq_array_primitive!(array, index, UInt8Array, val), ScalarValue::UInt16(val) => { - eq_array_primitive!(array, index, UInt16Array, val) + eq_array_primitive!(array, index, UInt16Array, val)? } ScalarValue::UInt32(val) => { - eq_array_primitive!(array, index, UInt32Array, val) + eq_array_primitive!(array, index, UInt32Array, val)? } ScalarValue::UInt64(val) => { - eq_array_primitive!(array, index, UInt64Array, val) + eq_array_primitive!(array, index, UInt64Array, val)? + } + ScalarValue::Utf8(val) => { + eq_array_primitive!(array, index, StringArray, val)? } - ScalarValue::Utf8(val) => eq_array_primitive!(array, index, StringArray, val), ScalarValue::LargeUtf8(val) => { - eq_array_primitive!(array, index, LargeStringArray, val) + eq_array_primitive!(array, index, LargeStringArray, val)? } ScalarValue::Binary(val) => { - eq_array_primitive!(array, index, BinaryArray, val) + eq_array_primitive!(array, index, BinaryArray, val)? } ScalarValue::FixedSizeBinary(_, val) => { - eq_array_primitive!(array, index, FixedSizeBinaryArray, val) + eq_array_primitive!(array, index, FixedSizeBinaryArray, val)? } ScalarValue::LargeBinary(val) => { - eq_array_primitive!(array, index, LargeBinaryArray, val) + eq_array_primitive!(array, index, LargeBinaryArray, val)? + } + ScalarValue::List(arr) + | ScalarValue::LargeList(arr) + | ScalarValue::FixedSizeList(arr) => { + let right = array.slice(index, 1); + arr == &right } - ScalarValue::Fixedsizelist(..) => unimplemented!(), - ScalarValue::List(_) => unimplemented!("ListArr"), ScalarValue::Date32(val) => { - eq_array_primitive!(array, index, Date32Array, val) + eq_array_primitive!(array, index, Date32Array, val)? } ScalarValue::Date64(val) => { - eq_array_primitive!(array, index, Date64Array, val) + eq_array_primitive!(array, index, Date64Array, val)? } ScalarValue::Time32Second(val) => { - eq_array_primitive!(array, index, Time32SecondArray, val) + eq_array_primitive!(array, index, Time32SecondArray, val)? } ScalarValue::Time32Millisecond(val) => { - eq_array_primitive!(array, index, Time32MillisecondArray, val) + eq_array_primitive!(array, index, Time32MillisecondArray, val)? } ScalarValue::Time64Microsecond(val) => { - eq_array_primitive!(array, index, Time64MicrosecondArray, val) + eq_array_primitive!(array, index, Time64MicrosecondArray, val)? } ScalarValue::Time64Nanosecond(val) => { - eq_array_primitive!(array, index, Time64NanosecondArray, val) + eq_array_primitive!(array, index, Time64NanosecondArray, val)? } ScalarValue::TimestampSecond(val, _) => { - eq_array_primitive!(array, index, TimestampSecondArray, val) + eq_array_primitive!(array, index, TimestampSecondArray, val)? } ScalarValue::TimestampMillisecond(val, _) => { - eq_array_primitive!(array, index, TimestampMillisecondArray, val) + eq_array_primitive!(array, index, TimestampMillisecondArray, val)? } ScalarValue::TimestampMicrosecond(val, _) => { - eq_array_primitive!(array, index, TimestampMicrosecondArray, val) + eq_array_primitive!(array, index, TimestampMicrosecondArray, val)? 
} ScalarValue::TimestampNanosecond(val, _) => { - eq_array_primitive!(array, index, TimestampNanosecondArray, val) + eq_array_primitive!(array, index, TimestampNanosecondArray, val)? } ScalarValue::IntervalYearMonth(val) => { - eq_array_primitive!(array, index, IntervalYearMonthArray, val) + eq_array_primitive!(array, index, IntervalYearMonthArray, val)? } ScalarValue::IntervalDayTime(val) => { - eq_array_primitive!(array, index, IntervalDayTimeArray, val) + eq_array_primitive!(array, index, IntervalDayTimeArray, val)? } ScalarValue::IntervalMonthDayNano(val) => { - eq_array_primitive!(array, index, IntervalMonthDayNanoArray, val) + eq_array_primitive!(array, index, IntervalMonthDayNanoArray, val)? } ScalarValue::DurationSecond(val) => { - eq_array_primitive!(array, index, DurationSecondArray, val) + eq_array_primitive!(array, index, DurationSecondArray, val)? } ScalarValue::DurationMillisecond(val) => { - eq_array_primitive!(array, index, DurationMillisecondArray, val) + eq_array_primitive!(array, index, DurationMillisecondArray, val)? } ScalarValue::DurationMicrosecond(val) => { - eq_array_primitive!(array, index, DurationMicrosecondArray, val) + eq_array_primitive!(array, index, DurationMicrosecondArray, val)? } ScalarValue::DurationNanosecond(val) => { - eq_array_primitive!(array, index, DurationNanosecondArray, val) + eq_array_primitive!(array, index, DurationNanosecondArray, val)? + } + ScalarValue::Struct(_, _) => { + return _not_impl_err!("Struct is not supported yet") } - ScalarValue::Struct(_, _) => unimplemented!(), ScalarValue::Dictionary(key_type, v) => { let (values_array, values_index) = match key_type.as_ref() { - DataType::Int8 => get_dict_value::(array, index), - DataType::Int16 => get_dict_value::(array, index), - DataType::Int32 => get_dict_value::(array, index), - DataType::Int64 => get_dict_value::(array, index), - DataType::UInt8 => get_dict_value::(array, index), - DataType::UInt16 => get_dict_value::(array, index), - DataType::UInt32 => get_dict_value::(array, index), - DataType::UInt64 => get_dict_value::(array, index), + DataType::Int8 => get_dict_value::(array, index)?, + DataType::Int16 => get_dict_value::(array, index)?, + DataType::Int32 => get_dict_value::(array, index)?, + DataType::Int64 => get_dict_value::(array, index)?, + DataType::UInt8 => get_dict_value::(array, index)?, + DataType::UInt16 => get_dict_value::(array, index)?, + DataType::UInt32 => get_dict_value::(array, index)?, + DataType::UInt64 => get_dict_value::(array, index)?, _ => unreachable!("Invalid dictionary keys type: {:?}", key_type), }; // was the value in the array non null? match values_index { - Some(values_index) => v.eq_array(values_array, values_index), + Some(values_index) => v.eq_array(values_array, values_index)?, None => v.is_null(), } } ScalarValue::Null => array.is_null(index), - } + }) } /// Estimate size if bytes including `Self`. 
For values with internal containers such as `String` @@ -2454,14 +2561,9 @@ impl ScalarValue { | ScalarValue::LargeBinary(b) => { b.as_ref().map(|b| b.capacity()).unwrap_or_default() } - ScalarValue::Fixedsizelist(vals, field, _) => { - vals.as_ref() - .map(|vals| Self::size_of_vec(vals) - std::mem::size_of_val(vals)) - .unwrap_or_default() - // `field` is boxed, so it is NOT already included in `self` - + field.size() - } - ScalarValue::List(arr) => arr.get_array_memory_size(), + ScalarValue::List(arr) + | ScalarValue::LargeList(arr) + | ScalarValue::FixedSizeList(arr) => arr.get_array_memory_size(), ScalarValue::Struct(vals, fields) => { vals.as_ref() .map(|vals| { @@ -2557,6 +2659,12 @@ impl FromStr for ScalarValue { } } +impl From for ScalarValue { + fn from(value: String) -> Self { + ScalarValue::Utf8(Some(value)) + } +} + impl From> for ScalarValue { fn from(value: Vec<(&str, ScalarValue)>) -> Self { let (fields, scalars): (SchemaBuilder, Vec<_>) = value @@ -2785,6 +2893,11 @@ macro_rules! format_option { }}; } +// Implement Display trait for ScalarValue +// +// # Panics +// +// Panics if there is an error when creating a visual representation of columns via `arrow::util::pretty` impl fmt::Display for ScalarValue { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { @@ -2824,23 +2937,16 @@ impl fmt::Display for ScalarValue { )?, None => write!(f, "NULL")?, }, - ScalarValue::Fixedsizelist(e, ..) => match e { - Some(l) => write!( - f, - "{}", - l.iter() - .map(|v| format!("{v}")) - .collect::>() - .join(",") - )?, - None => write!(f, "NULL")?, - }, - ScalarValue::List(arr) => write!( - f, - "{}", - arrow::util::pretty::pretty_format_columns("col", &[arr.to_owned()]) - .unwrap() - )?, + ScalarValue::List(arr) + | ScalarValue::LargeList(arr) + | ScalarValue::FixedSizeList(arr) => { + // ScalarValue List should always have a single element + assert_eq!(arr.len(), 1); + let options = FormatOptions::default().with_display_error(true); + let formatter = ArrayFormatter::try_new(arr, &options).unwrap(); + let value_formatter = formatter.value(0); + write!(f, "{value_formatter}")? + } ScalarValue::Date32(e) => format_option!(f, e)?, ScalarValue::Date64(e) => format_option!(f, e)?, ScalarValue::Time32Second(e) => format_option!(f, e)?, @@ -2915,8 +3021,9 @@ impl fmt::Debug for ScalarValue { } ScalarValue::LargeBinary(None) => write!(f, "LargeBinary({self})"), ScalarValue::LargeBinary(Some(_)) => write!(f, "LargeBinary(\"{self}\")"), - ScalarValue::Fixedsizelist(..) 
=> write!(f, "FixedSizeList([{self}])"), - ScalarValue::List(arr) => write!(f, "List([{arr:?}])"), + ScalarValue::FixedSizeList(_) => write!(f, "FixedSizeList({self})"), + ScalarValue::List(_) => write!(f, "List({self})"), + ScalarValue::LargeList(_) => write!(f, "LargeList({self})"), ScalarValue::Date32(_) => write!(f, "Date32(\"{self}\")"), ScalarValue::Date64(_) => write!(f, "Date64(\"{self}\")"), ScalarValue::Time32Second(_) => write!(f, "Time32Second(\"{self}\")"), @@ -3007,21 +3114,23 @@ impl ScalarType for TimestampNanosecondType { #[cfg(test)] mod tests { + use super::*; + use std::cmp::Ordering; use std::sync::Arc; + use chrono::NaiveDate; + use rand::Rng; + + use arrow::buffer::OffsetBuffer; use arrow::compute::kernels; use arrow::compute::{concat, is_null}; use arrow::datatypes::ArrowPrimitiveType; use arrow::util::pretty::pretty_format_columns; use arrow_array::ArrowNumericType; - use chrono::NaiveDate; - use rand::Rng; use crate::cast::{as_string_array, as_uint32_array, as_uint64_array}; - use super::*; - #[test] fn test_to_array_of_size_for_list() { let arr = ListArray::from_iter_primitive::(vec![Some(vec![ @@ -3031,7 +3140,9 @@ mod tests { ])]); let sv = ScalarValue::List(Arc::new(arr)); - let actual_arr = sv.to_array_of_size(2); + let actual_arr = sv + .to_array_of_size(2) + .expect("Failed to convert to array of size"); let actual_list_arr = as_list_array(&actual_arr); let arr = ListArray::from_iter_primitive::(vec![ @@ -3042,12 +3153,33 @@ mod tests { assert_eq!(&arr, actual_list_arr); } + #[test] + fn test_to_array_of_size_for_fsl() { + let values = Int32Array::from_iter([Some(1), None, Some(2)]); + let field = Arc::new(Field::new("item", DataType::Int32, true)); + let arr = FixedSizeListArray::new(field.clone(), 3, Arc::new(values), None); + let sv = ScalarValue::FixedSizeList(Arc::new(arr)); + let actual_arr = sv + .to_array_of_size(2) + .expect("Failed to convert to array of size"); + + let expected_values = + Int32Array::from_iter([Some(1), None, Some(2), Some(1), None, Some(2)]); + let expected_arr = + FixedSizeListArray::new(field, 3, Arc::new(expected_values), None); + + assert_eq!( + &expected_arr, + as_fixed_size_list_array(actual_arr.as_ref()).unwrap() + ); + } + #[test] fn test_list_to_array_string() { let scalars = vec![ - ScalarValue::Utf8(Some(String::from("rust"))), - ScalarValue::Utf8(Some(String::from("arrow"))), - ScalarValue::Utf8(Some(String::from("data-fusion"))), + ScalarValue::from("rust"), + ScalarValue::from("arrow"), + ScalarValue::from("data-fusion"), ]; let array = ScalarValue::new_list(scalars.as_slice(), &DataType::Utf8); @@ -3061,28 +3193,77 @@ mod tests { assert_eq!(result, &expected); } + fn build_list( + values: Vec>>>, + ) -> Vec { + values + .into_iter() + .map(|v| { + let arr = if v.is_some() { + Arc::new( + GenericListArray::::from_iter_primitive::( + vec![v], + ), + ) + } else if O::IS_LARGE { + new_null_array( + &DataType::LargeList(Arc::new(Field::new( + "item", + DataType::Int64, + true, + ))), + 1, + ) + } else { + new_null_array( + &DataType::List(Arc::new(Field::new( + "item", + DataType::Int64, + true, + ))), + 1, + ) + }; + + if O::IS_LARGE { + ScalarValue::LargeList(arr) + } else { + ScalarValue::List(arr) + } + }) + .collect() + } + #[test] fn iter_to_array_primitive_test() { - let scalars = vec![ - ScalarValue::List(Arc::new( - ListArray::from_iter_primitive::(vec![Some(vec![ - Some(1), - Some(2), - Some(3), - ])]), - )), - ScalarValue::List(Arc::new( - ListArray::from_iter_primitive::(vec![Some(vec![ - Some(4), - 
Some(5), - ])]), - )), - ]; + // List[[1,2,3]], List[null], List[[4,5]] + let scalars = build_list::(vec![ + Some(vec![Some(1), Some(2), Some(3)]), + None, + Some(vec![Some(4), Some(5)]), + ]); let array = ScalarValue::iter_to_array(scalars).unwrap(); let list_array = as_list_array(&array); + // List[[1,2,3], null, [4,5]] let expected = ListArray::from_iter_primitive::(vec![ Some(vec![Some(1), Some(2), Some(3)]), + None, + Some(vec![Some(4), Some(5)]), + ]); + assert_eq!(list_array, &expected); + + let scalars = build_list::(vec![ + Some(vec![Some(1), Some(2), Some(3)]), + None, + Some(vec![Some(4), Some(5)]), + ]); + + let array = ScalarValue::iter_to_array(scalars).unwrap(); + let list_array = as_large_list_array(&array); + let expected = LargeListArray::from_iter_primitive::(vec![ + Some(vec![Some(1), Some(2), Some(3)]), + None, Some(vec![Some(4), Some(5)]), ]); assert_eq!(list_array, &expected); @@ -3120,6 +3301,33 @@ mod tests { assert_eq!(result, &expected); } + #[test] + fn test_list_scalar_eq_to_array() { + let list_array: ArrayRef = + Arc::new(ListArray::from_iter_primitive::(vec![ + Some(vec![Some(0), Some(1), Some(2)]), + None, + Some(vec![None, Some(5)]), + ])); + + let fsl_array: ArrayRef = + Arc::new(FixedSizeListArray::from_iter_primitive::( + vec![ + Some(vec![Some(0), Some(1), Some(2)]), + None, + Some(vec![Some(3), None, Some(5)]), + ], + 3, + )); + + for arr in [list_array, fsl_array] { + for i in 0..arr.len() { + let scalar = ScalarValue::List(arr.slice(i, 1)); + assert!(scalar.eq_array(&arr, i).unwrap()); + } + } + } + #[test] fn scalar_add_trait_test() -> Result<()> { let float_value = ScalarValue::Float64(Some(123.)); @@ -3238,8 +3446,8 @@ mod tests { { let scalar_result = left.add_checked(&right); - let left_array = left.to_array(); - let right_array = right.to_array(); + let left_array = left.to_array().expect("Failed to convert to array"); + let right_array = right.to_array().expect("Failed to convert to array"); let arrow_left_array = left_array.as_primitive::(); let arrow_right_array = right_array.as_primitive::(); let arrow_result = kernels::numeric::add(arrow_left_array, arrow_right_array); @@ -3287,22 +3495,30 @@ mod tests { } // decimal scalar to array - let array = decimal_value.to_array(); + let array = decimal_value + .to_array() + .expect("Failed to convert to array"); let array = as_decimal128_array(&array)?; assert_eq!(1, array.len()); assert_eq!(DataType::Decimal128(10, 1), array.data_type().clone()); assert_eq!(123i128, array.value(0)); // decimal scalar to array with size - let array = decimal_value.to_array_of_size(10); + let array = decimal_value + .to_array_of_size(10) + .expect("Failed to convert to array of size"); let array_decimal = as_decimal128_array(&array)?; assert_eq!(10, array.len()); assert_eq!(DataType::Decimal128(10, 1), array.data_type().clone()); assert_eq!(123i128, array_decimal.value(0)); assert_eq!(123i128, array_decimal.value(9)); // test eq array - assert!(decimal_value.eq_array(&array, 1)); - assert!(decimal_value.eq_array(&array, 5)); + assert!(decimal_value + .eq_array(&array, 1) + .expect("Failed to compare arrays")); + assert!(decimal_value + .eq_array(&array, 5) + .expect("Failed to compare arrays")); // test try from array assert_eq!( decimal_value, @@ -3349,13 +3565,16 @@ mod tests { assert!(ScalarValue::try_new_decimal128(1, 10, 2) .unwrap() - .eq_array(&array, 0)); + .eq_array(&array, 0) + .expect("Failed to compare arrays")); assert!(ScalarValue::try_new_decimal128(2, 10, 2) .unwrap() - .eq_array(&array, 1)); + 
.eq_array(&array, 1) + .expect("Failed to compare arrays")); assert!(ScalarValue::try_new_decimal128(3, 10, 2) .unwrap() - .eq_array(&array, 2)); + .eq_array(&array, 2) + .expect("Failed to compare arrays")); assert_eq!( ScalarValue::Decimal128(None, 10, 2), ScalarValue::try_from_array(&array, 3).unwrap() @@ -3419,37 +3638,19 @@ mod tests { ])]), )); assert_eq!(a.partial_cmp(&b), Some(Ordering::Less)); - - let a = - ScalarValue::List(Arc::new( - ListArray::from_iter_primitive::(vec![ - Some(vec![Some(10), Some(2), Some(3)]), - None, - Some(vec![Some(10), Some(2), Some(3)]), - ]), - )); - let b = - ScalarValue::List(Arc::new( - ListArray::from_iter_primitive::(vec![ - Some(vec![Some(10), Some(2), Some(3)]), - None, - Some(vec![Some(10), Some(2), Some(3)]), - ]), - )); - assert_eq!(a.partial_cmp(&b), Some(Ordering::Equal)); } #[test] fn scalar_value_to_array_u64() -> Result<()> { let value = ScalarValue::UInt64(Some(13u64)); - let array = value.to_array(); + let array = value.to_array().expect("Failed to convert to array"); let array = as_uint64_array(&array)?; assert_eq!(array.len(), 1); assert!(!array.is_null(0)); assert_eq!(array.value(0), 13); let value = ScalarValue::UInt64(None); - let array = value.to_array(); + let array = value.to_array().expect("Failed to convert to array"); let array = as_uint64_array(&array)?; assert_eq!(array.len(), 1); assert!(array.is_null(0)); @@ -3459,14 +3660,14 @@ mod tests { #[test] fn scalar_value_to_array_u32() -> Result<()> { let value = ScalarValue::UInt32(Some(13u32)); - let array = value.to_array(); + let array = value.to_array().expect("Failed to convert to array"); let array = as_uint32_array(&array)?; assert_eq!(array.len(), 1); assert!(!array.is_null(0)); assert_eq!(array.value(0), 13); let value = ScalarValue::UInt32(None); - let array = value.to_array(); + let array = value.to_array().expect("Failed to convert to array"); let array = as_uint32_array(&array)?; assert_eq!(array.len(), 1); assert!(array.is_null(0)); @@ -3482,6 +3683,15 @@ mod tests { assert_eq!(list_array.values().len(), 0); } + #[test] + fn scalar_large_list_null_to_array() { + let list_array_ref = ScalarValue::new_large_list(&[], &DataType::UInt64); + let list_array = as_large_list_array(&list_array_ref); + + assert_eq!(list_array.len(), 1); + assert_eq!(list_array.values().len(), 0); + } + #[test] fn scalar_list_to_array() -> Result<()> { let values = vec![ @@ -3503,6 +3713,27 @@ mod tests { Ok(()) } + #[test] + fn scalar_large_list_to_array() -> Result<()> { + let values = vec![ + ScalarValue::UInt64(Some(100)), + ScalarValue::UInt64(None), + ScalarValue::UInt64(Some(101)), + ]; + let list_array_ref = ScalarValue::new_large_list(&values, &DataType::UInt64); + let list_array = as_large_list_array(&list_array_ref); + assert_eq!(list_array.len(), 1); + assert_eq!(list_array.values().len(), 3); + + let prim_array_ref = list_array.value(0); + let prim_array = as_uint64_array(&prim_array_ref)?; + assert_eq!(prim_array.len(), 3); + assert_eq!(prim_array.value(0), 100); + assert!(prim_array.is_null(1)); + assert_eq!(prim_array.value(2), 101); + Ok(()) + } + /// Creates array directly and via ScalarValue and ensures they are the same macro_rules! 
check_scalar_iter { ($SCALAR_T:ident, $ARRAYTYPE:ident, $INPUT:expr) => {{ @@ -4025,7 +4256,9 @@ mod tests { for (index, scalar) in scalars.into_iter().enumerate() { assert!( - scalar.eq_array(&array, index), + scalar + .eq_array(&array, index) + .expect("Failed to compare arrays"), "Expected {scalar:?} to be equal to {array:?} at index {index}" ); @@ -4033,7 +4266,7 @@ mod tests { for other_index in 0..array.len() { if index != other_index { assert!( - !scalar.eq_array(&array, other_index), + !scalar.eq_array(&array, other_index).expect("Failed to compare arrays"), "Expected {scalar:?} to be NOT equal to {array:?} at index {other_index}" ); } @@ -4088,6 +4321,16 @@ mod tests { ); } + #[test] + fn test_scalar_value_from_string() { + let scalar = ScalarValue::from("foo"); + assert_eq!(scalar, ScalarValue::Utf8(Some("foo".to_string()))); + let scalar = ScalarValue::from("foo".to_string()); + assert_eq!(scalar, ScalarValue::Utf8(Some("foo".to_string()))); + let scalar = ScalarValue::from_str("foo").unwrap(); + assert_eq!(scalar, ScalarValue::Utf8(Some("foo".to_string()))); + } + #[test] fn test_scalar_struct() { let field_a = Arc::new(Field::new("A", DataType::Int32, false)); @@ -4106,7 +4349,7 @@ mod tests { Some(vec![ ScalarValue::Int32(Some(23)), ScalarValue::Boolean(Some(false)), - ScalarValue::Utf8(Some("Hello".to_string())), + ScalarValue::from("Hello"), ScalarValue::from(vec![ ("e", ScalarValue::from(2i16)), ("f", ScalarValue::from(3i64)), @@ -4136,7 +4379,9 @@ mod tests { ); // Convert to length-2 array - let array = scalar.to_array_of_size(2); + let array = scalar + .to_array_of_size(2) + .expect("Failed to convert to array of size"); let expected = Arc::new(StructArray::from(vec![ ( @@ -4297,17 +4542,17 @@ mod tests { // Define struct scalars let s0 = ScalarValue::from(vec![ - ("A", ScalarValue::Utf8(Some(String::from("First")))), + ("A", ScalarValue::from("First")), ("primitive_list", l0), ]); let s1 = ScalarValue::from(vec![ - ("A", ScalarValue::Utf8(Some(String::from("Second")))), + ("A", ScalarValue::from("Second")), ("primitive_list", l1), ]); let s2 = ScalarValue::from(vec![ - ("A", ScalarValue::Utf8(Some(String::from("Third")))), + ("A", ScalarValue::from("Third")), ("primitive_list", l2), ]); @@ -4465,69 +4710,37 @@ mod tests { assert_eq!(array, &expected); } - #[test] - fn test_nested_lists() { - // Define inner list scalars - let a1 = ListArray::from_iter_primitive::(vec![Some(vec![ - Some(1), - Some(2), - Some(3), - ])]); - let a2 = ListArray::from_iter_primitive::(vec![Some(vec![ - Some(4), - Some(5), - ])]); - let l1 = ListArray::new( - Arc::new(Field::new( - "item", - DataType::List(Arc::new(Field::new("item", DataType::Int32, true))), - true, - )), - OffsetBuffer::::from_lengths([1, 1]), - arrow::compute::concat(&[&a1, &a2]).unwrap(), - None, - ); - - let a1 = - ListArray::from_iter_primitive::(vec![Some(vec![Some(6)])]); - let a2 = ListArray::from_iter_primitive::(vec![Some(vec![ - Some(7), - Some(8), - ])]); - let l2 = ListArray::new( - Arc::new(Field::new( - "item", - DataType::List(Arc::new(Field::new("item", DataType::Int32, true))), - true, - )), - OffsetBuffer::::from_lengths([1, 1]), - arrow::compute::concat(&[&a1, &a2]).unwrap(), - None, - ); - - let a1 = - ListArray::from_iter_primitive::(vec![Some(vec![Some(9)])]); - let l3 = ListArray::new( + fn build_2d_list(data: Vec>) -> ListArray { + let a1 = ListArray::from_iter_primitive::(vec![Some(data)]); + ListArray::new( Arc::new(Field::new( "item", DataType::List(Arc::new(Field::new("item", DataType::Int32, 
true))), true, )), OffsetBuffer::::from_lengths([1]), - arrow::compute::concat(&[&a1]).unwrap(), + Arc::new(a1), None, - ); + ) + } + + #[test] + fn test_nested_lists() { + // Define inner list scalars + let arr1 = build_2d_list(vec![Some(1), Some(2), Some(3)]); + let arr2 = build_2d_list(vec![Some(4), Some(5)]); + let arr3 = build_2d_list(vec![Some(6)]); let array = ScalarValue::iter_to_array(vec![ - ScalarValue::List(Arc::new(l1)), - ScalarValue::List(Arc::new(l2)), - ScalarValue::List(Arc::new(l3)), + ScalarValue::List(Arc::new(arr1)), + ScalarValue::List(Arc::new(arr2)), + ScalarValue::List(Arc::new(arr3)), ]) .unwrap(); let array = as_list_array(&array); // Construct expected array with array builders - let inner_builder = Int32Array::builder(8); + let inner_builder = Int32Array::builder(6); let middle_builder = ListBuilder::new(inner_builder); let mut outer_builder = ListBuilder::new(middle_builder); @@ -4535,6 +4748,7 @@ mod tests { outer_builder.values().values().append_value(2); outer_builder.values().values().append_value(3); outer_builder.values().append(true); + outer_builder.append(true); outer_builder.values().values().append_value(4); outer_builder.values().values().append_value(5); @@ -4543,14 +4757,6 @@ mod tests { outer_builder.values().values().append_value(6); outer_builder.values().append(true); - - outer_builder.values().values().append_value(7); - outer_builder.values().values().append_value(8); - outer_builder.values().append(true); - outer_builder.append(true); - - outer_builder.values().values().append_value(9); - outer_builder.values().append(true); outer_builder.append(true); let expected = outer_builder.finish(); @@ -4570,7 +4776,7 @@ mod tests { DataType::Timestamp(TimeUnit::Nanosecond, Some("UTC".into())) ); - let array = scalar.to_array(); + let array = scalar.to_array().expect("Failed to convert to array"); assert_eq!(array.len(), 1); assert_eq!( array.data_type(), @@ -4594,7 +4800,7 @@ mod tests { check_scalar_cast(ScalarValue::Float64(None), DataType::Int16); check_scalar_cast( - ScalarValue::Utf8(Some("foo".to_string())), + ScalarValue::from("foo"), DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)), ); @@ -4607,7 +4813,7 @@ mod tests { // mimics how casting work on scalar values by `casting` `scalar` to `desired_type` fn check_scalar_cast(scalar: ScalarValue, desired_type: DataType) { // convert from scalar --> Array to call cast - let scalar_array = scalar.to_array(); + let scalar_array = scalar.to_array().expect("Failed to convert to array"); // cast the actual value let cast_array = kernels::cast::cast(&scalar_array, &desired_type).unwrap(); @@ -4616,7 +4822,9 @@ mod tests { assert_eq!(cast_scalar.data_type(), desired_type); // Some time later the "cast" scalar is turned back into an array: - let array = cast_scalar.to_array_of_size(10); + let array = cast_scalar + .to_array_of_size(10) + .expect("Failed to convert to array of size"); // The datatype should be "Dictionary" but is actually Utf8!!! 
assert_eq!(array.data_type(), &desired_type) @@ -4873,10 +5081,7 @@ mod tests { (ScalarValue::Int8(None), ScalarValue::Int16(Some(1))), (ScalarValue::Int8(Some(1)), ScalarValue::Int16(None)), // Unsupported types - ( - ScalarValue::Utf8(Some("foo".to_string())), - ScalarValue::Utf8(Some("bar".to_string())), - ), + (ScalarValue::from("foo"), ScalarValue::from("bar")), ( ScalarValue::Boolean(Some(true)), ScalarValue::Boolean(Some(false)), @@ -5065,7 +5270,8 @@ mod tests { let arrays = scalars .iter() .map(ScalarValue::to_array) - .collect::>(); + .collect::>>() + .expect("Failed to convert to array"); let arrays = arrays.iter().map(|a| a.as_ref()).collect::>(); let array = concat(&arrays).unwrap(); check_array(array); diff --git a/datafusion/common/src/stats.rs b/datafusion/common/src/stats.rs index fbf639a321827..7ad8992ca9aec 100644 --- a/datafusion/common/src/stats.rs +++ b/datafusion/common/src/stats.rs @@ -151,6 +151,15 @@ impl Precision { (_, _) => Precision::Absent, } } + + /// Return the estimate of applying a filter with estimated selectivity + /// `selectivity` to this Precision. A selectivity of `1.0` means that all + /// rows are selected. A selectivity of `0.5` means half the rows are + /// selected. Will always return inexact statistics. + pub fn with_estimated_selectivity(self, selectivity: f64) -> Self { + self.map(|v| ((v as f64 * selectivity).ceil()) as usize) + .to_inexact() + } } impl Precision { @@ -257,7 +266,44 @@ impl Statistics { impl Display for Statistics { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "Rows={}, Bytes={}", self.num_rows, self.total_byte_size)?; + // string of column statistics + let column_stats = self + .column_statistics + .iter() + .enumerate() + .map(|(i, cs)| { + let s = format!("(Col[{}]:", i); + let s = if cs.min_value != Precision::Absent { + format!("{} Min={}", s, cs.min_value) + } else { + s + }; + let s = if cs.max_value != Precision::Absent { + format!("{} Max={}", s, cs.max_value) + } else { + s + }; + let s = if cs.null_count != Precision::Absent { + format!("{} Null={}", s, cs.null_count) + } else { + s + }; + let s = if cs.distinct_count != Precision::Absent { + format!("{} Distinct={}", s, cs.distinct_count) + } else { + s + }; + + s + ")" + }) + .collect::>() + .join(","); + + write!( + f, + "Rows={}, Bytes={}, [{}]", + self.num_rows, self.total_byte_size, column_stats + )?; Ok(()) } @@ -279,9 +325,11 @@ pub struct ColumnStatistics { impl ColumnStatistics { /// Column contains a single non null value (e.g constant). pub fn is_singleton(&self) -> bool { - match (self.min_value.get_value(), self.max_value.get_value()) { + match (&self.min_value, &self.max_value) { // Min and max values are the same and not infinity. 
- (Some(min), Some(max)) => !min.is_null() && !max.is_null() && (min == max), + (Precision::Exact(min), Precision::Exact(max)) => { + !min.is_null() && !max.is_null() && (min == max) + } (_, _) => false, } } diff --git a/datafusion/common/src/test_util.rs b/datafusion/common/src/test_util.rs index 9a44337821570..eeace97eebfa3 100644 --- a/datafusion/common/src/test_util.rs +++ b/datafusion/common/src/test_util.rs @@ -285,6 +285,7 @@ mod tests { } #[test] + #[cfg(feature = "parquet")] fn test_happy() { let res = arrow_test_data(); assert!(PathBuf::from(res).is_dir()); diff --git a/datafusion/common/src/tree_node.rs b/datafusion/common/src/tree_node.rs index 2919d9a39c9c8..5f11c8cc1d11b 100644 --- a/datafusion/common/src/tree_node.rs +++ b/datafusion/common/src/tree_node.rs @@ -18,6 +18,7 @@ //! This module provides common traits for visiting or rewriting tree //! data structures easily. +use std::borrow::Cow; use std::sync::Arc; use crate::Result; @@ -32,7 +33,10 @@ use crate::Result; /// [`PhysicalExpr`]: https://docs.rs/datafusion/latest/datafusion/physical_plan/trait.PhysicalExpr.html /// [`LogicalPlan`]: https://docs.rs/datafusion-expr/latest/datafusion_expr/logical_plan/enum.LogicalPlan.html /// [`Expr`]: https://docs.rs/datafusion-expr/latest/datafusion_expr/expr/enum.Expr.html -pub trait TreeNode: Sized { +pub trait TreeNode: Sized + Clone { + /// Returns all children of the TreeNode + fn children_nodes(&self) -> Vec>; + /// Use preorder to iterate the node on the tree so that we can /// stop fast for some cases. /// @@ -125,6 +129,17 @@ pub trait TreeNode: Sized { after_op.map_children(|node| node.transform_down(op)) } + /// Convenience utils for writing optimizers rule: recursively apply the given 'op' to the node and all of its + /// children(Preorder Traversal) using a mutable function, `F`. + /// When the `op` does not apply to a given node, it is left unchanged. + fn transform_down_mut(self, op: &mut F) -> Result + where + F: FnMut(Self) -> Result>, + { + let after_op = op(self)?.into(); + after_op.map_children(|node| node.transform_down_mut(op)) + } + /// Convenience utils for writing optimizers rule: recursively apply the given 'op' first to all of its /// children and then itself(Postorder Traversal). /// When the `op` does not apply to a given node, it is left unchanged. @@ -138,6 +153,19 @@ pub trait TreeNode: Sized { Ok(new_node) } + /// Convenience utils for writing optimizers rule: recursively apply the given 'op' first to all of its + /// children and then itself(Postorder Traversal) using a mutable function, `F`. + /// When the `op` does not apply to a given node, it is left unchanged. + fn transform_up_mut(self, op: &mut F) -> Result + where + F: FnMut(Self) -> Result>, + { + let after_op_children = self.map_children(|node| node.transform_up_mut(op))?; + + let new_node = op(after_op_children)?.into(); + Ok(new_node) + } + /// Transform the tree node using the given [TreeNodeRewriter] /// It performs a depth first walk of an node and its children. /// @@ -187,7 +215,17 @@ pub trait TreeNode: Sized { /// Apply the closure `F` to the node's children fn apply_children(&self, op: &mut F) -> Result where - F: FnMut(&Self) -> Result; + F: FnMut(&Self) -> Result, + { + for child in self.children_nodes() { + match op(&child)? 
{ + VisitRecursion::Continue => {} + VisitRecursion::Skip => return Ok(VisitRecursion::Continue), + VisitRecursion::Stop => return Ok(VisitRecursion::Stop), + } + } + Ok(VisitRecursion::Continue) + } /// Apply transform `F` to the node's children, the transform `F` might have a direction(Preorder or Postorder) fn map_children(self, transform: F) -> Result @@ -318,19 +356,8 @@ pub trait DynTreeNode { /// Blanket implementation for Arc for any tye that implements /// [`DynTreeNode`] (such as [`Arc`]) impl TreeNode for Arc { - fn apply_children(&self, op: &mut F) -> Result - where - F: FnMut(&Self) -> Result, - { - for child in self.arc_children() { - match op(&child)? { - VisitRecursion::Continue => {} - VisitRecursion::Skip => return Ok(VisitRecursion::Continue), - VisitRecursion::Stop => return Ok(VisitRecursion::Stop), - } - } - - Ok(VisitRecursion::Continue) + fn children_nodes(&self) -> Vec> { + self.arc_children().into_iter().map(Cow::Owned).collect() } fn map_children(self, transform: F) -> Result diff --git a/datafusion/common/src/utils.rs b/datafusion/common/src/utils.rs index f031f7880436b..0a61fce15482d 100644 --- a/datafusion/common/src/utils.rs +++ b/datafusion/common/src/utils.rs @@ -17,15 +17,16 @@ //! This module provides the bisect function, which implements binary search. -use crate::error::_internal_err; -use crate::{DataFusionError, Result, ScalarValue}; +use crate::error::{_internal_datafusion_err, _internal_err}; +use crate::{arrow_datafusion_err, DataFusionError, Result, ScalarValue}; use arrow::array::{ArrayRef, PrimitiveArray}; use arrow::buffer::OffsetBuffer; use arrow::compute; use arrow::compute::{partition, SortColumn, SortOptions}; use arrow::datatypes::{Field, SchemaRef, UInt32Type}; use arrow::record_batch::RecordBatch; -use arrow_array::{Array, ListArray}; +use arrow_array::{Array, LargeListArray, ListArray, RecordBatchOptions}; +use arrow_schema::DataType; use sqlparser::ast::Ident; use sqlparser::dialect::GenericDialect; use sqlparser::parser::Parser; @@ -89,8 +90,12 @@ pub fn get_record_batch_at_indices( indices: &PrimitiveArray, ) -> Result { let new_columns = get_arrayref_at_indices(record_batch.columns(), indices)?; - RecordBatch::try_new(record_batch.schema(), new_columns) - .map_err(DataFusionError::ArrowError) + RecordBatch::try_new_with_options( + record_batch.schema(), + new_columns, + &RecordBatchOptions::new().with_row_count(Some(indices.len())), + ) + .map_err(|e| arrow_datafusion_err!(e)) } /// This function compares two tuples depending on the given sort options. @@ -112,7 +117,7 @@ pub fn compare_rows( lhs.partial_cmp(rhs) } .ok_or_else(|| { - DataFusionError::Internal("Column array shouldn't be empty".to_string()) + _internal_datafusion_err!("Column array shouldn't be empty") })?, (true, true, _) => continue, }; @@ -134,7 +139,7 @@ pub fn bisect( ) -> Result { let low: usize = 0; let high: usize = item_columns - .get(0) + .first() .ok_or_else(|| { DataFusionError::Internal("Column array shouldn't be empty".to_string()) })? @@ -185,7 +190,7 @@ pub fn linear_search( ) -> Result { let low: usize = 0; let high: usize = item_columns - .get(0) + .first() .ok_or_else(|| { DataFusionError::Internal("Column array shouldn't be empty".to_string()) })? 
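
The `transform_down_mut` and `transform_up_mut` methods added to `TreeNode` above differ from the existing `transform_down`/`transform_up` only in accepting a `FnMut` closure, so a rewrite can carry mutable state across the traversal. A minimal sketch of how that might look (assuming the existing `TreeNode` impl for `Expr` in `datafusion-expr`; the function and names here are illustrative, not part of the patch):

use datafusion_common::tree_node::{Transformed, TreeNode};
use datafusion_expr::{lit, Expr};

/// Replace every column reference with the literal 0, counting how many
/// rewrites happened. The mutable counter is why the `Fn`-bound
/// `transform_down` is not enough here.
fn zero_out_columns(expr: Expr) -> datafusion_common::Result<(Expr, usize)> {
    let mut rewrites = 0;
    let rewritten = expr.transform_down_mut(&mut |e| {
        Ok(match e {
            Expr::Column(_) => {
                rewrites += 1;
                Transformed::Yes(lit(0))
            }
            other => Transformed::No(other),
        })
    })?;
    Ok((rewritten, rewrites))
}

Applied to an expression such as `col("a") + col("b") + lit(1)`, this would return the rewritten expression together with a rewrite count of 2.
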
@@ -286,7 +291,7 @@ pub fn get_arrayref_at_indices( indices, None, // None: no index check ) - .map_err(DataFusionError::ArrowError) + .map_err(|e| arrow_datafusion_err!(e)) }) .collect() } @@ -337,6 +342,8 @@ pub fn longest_consecutive_prefix>( count } +/// Array Utils + /// Wrap an array into a single element `ListArray`. /// For example `[1, 2, 3]` would be converted into `[[1, 2, 3]]` pub fn array_into_list_array(arr: ArrayRef) -> ListArray { @@ -349,6 +356,18 @@ pub fn array_into_list_array(arr: ArrayRef) -> ListArray { ) } +/// Wrap an array into a single element `LargeListArray`. +/// For example `[1, 2, 3]` would be converted into `[[1, 2, 3]]` +pub fn array_into_large_list_array(arr: ArrayRef) -> LargeListArray { + let offsets = OffsetBuffer::from_lengths([arr.len()]); + LargeListArray::new( + Arc::new(Field::new("item", arr.data_type().to_owned(), true)), + offsets, + arr, + None, + ) +} + /// Wrap arrays into a single element `ListArray`. /// /// Example: @@ -390,6 +409,89 @@ pub fn arrays_into_list_array( )) } +/// Get the base type of a data type. +/// +/// Example +/// ``` +/// use arrow::datatypes::{DataType, Field}; +/// use datafusion_common::utils::base_type; +/// use std::sync::Arc; +/// +/// let data_type = DataType::List(Arc::new(Field::new("item", DataType::Int32, true))); +/// assert_eq!(base_type(&data_type), DataType::Int32); +/// +/// let data_type = DataType::Int32; +/// assert_eq!(base_type(&data_type), DataType::Int32); +/// ``` +pub fn base_type(data_type: &DataType) -> DataType { + match data_type { + DataType::List(field) | DataType::LargeList(field) => { + base_type(field.data_type()) + } + _ => data_type.to_owned(), + } +} + +/// A helper function to coerce base type in List. +/// +/// Example +/// ``` +/// use arrow::datatypes::{DataType, Field}; +/// use datafusion_common::utils::coerced_type_with_base_type_only; +/// use std::sync::Arc; +/// +/// let data_type = DataType::List(Arc::new(Field::new("item", DataType::Int32, true))); +/// let base_type = DataType::Float64; +/// let coerced_type = coerced_type_with_base_type_only(&data_type, &base_type); +/// assert_eq!(coerced_type, DataType::List(Arc::new(Field::new("item", DataType::Float64, true)))); +pub fn coerced_type_with_base_type_only( + data_type: &DataType, + base_type: &DataType, +) -> DataType { + match data_type { + DataType::List(field) => { + let data_type = match field.data_type() { + DataType::List(_) => { + coerced_type_with_base_type_only(field.data_type(), base_type) + } + _ => base_type.to_owned(), + }; + + DataType::List(Arc::new(Field::new( + field.name(), + data_type, + field.is_nullable(), + ))) + } + DataType::LargeList(field) => { + let data_type = match field.data_type() { + DataType::LargeList(_) => { + coerced_type_with_base_type_only(field.data_type(), base_type) + } + _ => base_type.to_owned(), + }; + + DataType::LargeList(Arc::new(Field::new( + field.name(), + data_type, + field.is_nullable(), + ))) + } + + _ => base_type.clone(), + } +} + +/// Compute the number of dimensions in a list data type. +pub fn list_ndims(data_type: &DataType) -> u64 { + match data_type { + DataType::List(field) | DataType::LargeList(field) => { + 1 + list_ndims(field.data_type()) + } + _ => 0, + } +} + /// An extension trait for smart pointers. Provides an interface to get a /// raw pointer to the data (with metadata stripped away). 
/// diff --git a/datafusion/core/Cargo.toml b/datafusion/core/Cargo.toml index b44914ec719fb..9de6a7f7d6a0b 100644 --- a/datafusion/core/Cargo.toml +++ b/datafusion/core/Cargo.toml @@ -55,6 +55,7 @@ ahash = { version = "0.8", default-features = false, features = ["runtime-rng"] apache-avro = { version = "0.16", optional = true } arrow = { workspace = true } arrow-array = { workspace = true } +arrow-ipc = { workspace = true } arrow-schema = { workspace = true } async-compression = { version = "0.4.0", features = ["bzip2", "gzip", "xz", "zstd", "futures-io", "tokio"], optional = true } async-trait = { workspace = true } @@ -62,11 +63,11 @@ bytes = { workspace = true } bzip2 = { version = "0.4.3", optional = true } chrono = { workspace = true } dashmap = { workspace = true } -datafusion-common = { path = "../common", version = "33.0.0", features = ["object_store"], default-features = false } +datafusion-common = { path = "../common", version = "34.0.0", features = ["object_store"], default-features = false } datafusion-execution = { workspace = true } datafusion-expr = { workspace = true } -datafusion-optimizer = { path = "../optimizer", version = "33.0.0", default-features = false } -datafusion-physical-expr = { path = "../physical-expr", version = "33.0.0", default-features = false } +datafusion-optimizer = { path = "../optimizer", version = "34.0.0", default-features = false } +datafusion-physical-expr = { path = "../physical-expr", version = "34.0.0", default-features = false } datafusion-physical-plan = { workspace = true } datafusion-sql = { workspace = true } flate2 = { version = "1.0.24", optional = true } @@ -81,7 +82,7 @@ num-traits = { version = "0.2", optional = true } num_cpus = { workspace = true } object_store = { workspace = true } parking_lot = { workspace = true } -parquet = { workspace = true, optional = true } +parquet = { workspace = true, optional = true, default-features = true } pin-project-lite = "^0.2.7" rand = { workspace = true } sqlparser = { workspace = true } @@ -120,6 +121,10 @@ nix = { version = "0.27.1", features = ["fs"] } harness = false name = "aggregate_query_sql" +[[bench]] +harness = false +name = "distinct_query_sql" + [[bench]] harness = false name = "sort_limit_query_sql" @@ -163,3 +168,7 @@ name = "sort" [[bench]] harness = false name = "topk_aggregate" + +[[bench]] +harness = false +name = "array_expression" diff --git a/datafusion/core/benches/array_expression.rs b/datafusion/core/benches/array_expression.rs new file mode 100644 index 0000000000000..95bc93e0e353a --- /dev/null +++ b/datafusion/core/benches/array_expression.rs @@ -0,0 +1,73 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
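
Before the body of the new `array_expression` benchmark below, a brief illustration of the `ScalarValue` API change that dominates this patch: `to_array`, `to_array_of_size`, `eq_array` and the dictionary/list paths of `try_from_array` are now fallible and propagate errors instead of panicking. A minimal sketch of the resulting call sites (illustrative only, not taken from the patch):

use std::sync::Arc;
use arrow::array::{ArrayRef, Int32Array};
use datafusion_common::{Result, ScalarValue};

fn scalar_roundtrip() -> Result<()> {
    let scalar = ScalarValue::Int32(Some(2));
    let array: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3]));

    // to_array now returns Result<ArrayRef> instead of panicking.
    let single = scalar.to_array()?;
    assert_eq!(single.len(), 1);

    // eq_array returns Result<bool>; the Ok value is the comparison itself.
    assert!(scalar.eq_array(&array, 1)?);
    assert!(!scalar.eq_array(&array, 0)?);

    // try_from_array goes the other way, extracting the value at an index.
    assert_eq!(
        ScalarValue::try_from_array(&array, 2)?,
        ScalarValue::Int32(Some(3))
    );
    Ok(())
}

Callers inside the patch follow the same pattern: library code propagates with `?`, while tests unwrap with `expect`.
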
+ +#[macro_use] +extern crate criterion; +extern crate arrow; +extern crate datafusion; + +mod data_utils; +use crate::criterion::Criterion; +use arrow_array::cast::AsArray; +use arrow_array::types::Int64Type; +use arrow_array::{ArrayRef, Int64Array, ListArray}; +use datafusion_physical_expr::array_expressions; +use std::sync::Arc; + +fn criterion_benchmark(c: &mut Criterion) { + // Construct large arrays for benchmarking + + let array_len = 100000000; + + let array = (0..array_len).map(|_| Some(2_i64)).collect::>(); + let list_array = ListArray::from_iter_primitive::(vec![ + Some(array.clone()), + Some(array.clone()), + Some(array), + ]); + let from_array = Int64Array::from_value(2, 3); + let to_array = Int64Array::from_value(-2, 3); + + let args = vec![ + Arc::new(list_array) as ArrayRef, + Arc::new(from_array) as ArrayRef, + Arc::new(to_array) as ArrayRef, + ]; + + let array = (0..array_len).map(|_| Some(-2_i64)).collect::>(); + let expected_array = ListArray::from_iter_primitive::(vec![ + Some(array.clone()), + Some(array.clone()), + Some(array), + ]); + + // Benchmark array functions + + c.bench_function("array_replace", |b| { + b.iter(|| { + assert_eq!( + array_expressions::array_replace_all(args.as_slice()) + .unwrap() + .as_list::(), + criterion::black_box(&expected_array) + ) + }) + }); +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/core/benches/data_utils/mod.rs b/datafusion/core/benches/data_utils/mod.rs index 64c0e4b100a14..9d2864919225a 100644 --- a/datafusion/core/benches/data_utils/mod.rs +++ b/datafusion/core/benches/data_utils/mod.rs @@ -25,11 +25,16 @@ use arrow::{ datatypes::{DataType, Field, Schema, SchemaRef}, record_batch::RecordBatch, }; +use arrow_array::builder::{Int64Builder, StringBuilder}; use datafusion::datasource::MemTable; use datafusion::error::Result; +use datafusion_common::DataFusionError; use rand::rngs::StdRng; use rand::seq::SliceRandom; use rand::{Rng, SeedableRng}; +use rand_distr::Distribution; +use rand_distr::{Normal, Pareto}; +use std::fmt::Write; use std::sync::Arc; /// create an in-memory table given the partition len, array len, and batch size, @@ -156,3 +161,83 @@ pub fn create_record_batches( }) .collect::>() } + +/// Create time series data with `partition_cnt` partitions and `sample_cnt` rows per partition +/// in ascending order, if `asc` is true, otherwise randomly sampled using a Pareto distribution +#[allow(dead_code)] +pub(crate) fn make_data( + partition_cnt: i32, + sample_cnt: i32, + asc: bool, +) -> Result<(Arc, Vec>), DataFusionError> { + // constants observed from trace data + let simultaneous_group_cnt = 2000; + let fitted_shape = 12f64; + let fitted_scale = 5f64; + let mean = 0.1; + let stddev = 1.1; + let pareto = Pareto::new(fitted_scale, fitted_shape).unwrap(); + let normal = Normal::new(mean, stddev).unwrap(); + let mut rng = rand::rngs::SmallRng::from_seed([0; 32]); + + // populate data + let schema = test_schema(); + let mut partitions = vec![]; + let mut cur_time = 16909000000000i64; + for _ in 0..partition_cnt { + let mut id_builder = StringBuilder::new(); + let mut ts_builder = Int64Builder::new(); + let gen_id = |rng: &mut rand::rngs::SmallRng| { + rng.gen::<[u8; 16]>() + .iter() + .fold(String::new(), |mut output, b| { + let _ = write!(output, "{b:02X}"); + output + }) + }; + let gen_sample_cnt = + |mut rng: &mut rand::rngs::SmallRng| pareto.sample(&mut rng).ceil() as u32; + let mut group_ids = (0..simultaneous_group_cnt) + .map(|_| gen_id(&mut rng)) + 
.collect::>(); + let mut group_sample_cnts = (0..simultaneous_group_cnt) + .map(|_| gen_sample_cnt(&mut rng)) + .collect::>(); + for _ in 0..sample_cnt { + let random_index = rng.gen_range(0..simultaneous_group_cnt); + let trace_id = &mut group_ids[random_index]; + let sample_cnt = &mut group_sample_cnts[random_index]; + *sample_cnt -= 1; + if *sample_cnt == 0 { + *trace_id = gen_id(&mut rng); + *sample_cnt = gen_sample_cnt(&mut rng); + } + + id_builder.append_value(trace_id); + ts_builder.append_value(cur_time); + + if asc { + cur_time += 1; + } else { + let samp: f64 = normal.sample(&mut rng); + let samp = samp.round(); + cur_time += samp as i64; + } + } + + // convert to MemTable + let id_col = Arc::new(id_builder.finish()); + let ts_col = Arc::new(ts_builder.finish()); + let batch = RecordBatch::try_new(schema.clone(), vec![id_col, ts_col])?; + partitions.push(vec![batch]); + } + Ok((schema, partitions)) +} + +/// The Schema used by make_data +fn test_schema() -> SchemaRef { + Arc::new(Schema::new(vec![ + Field::new("trace_id", DataType::Utf8, false), + Field::new("timestamp_ms", DataType::Int64, false), + ])) +} diff --git a/datafusion/core/benches/distinct_query_sql.rs b/datafusion/core/benches/distinct_query_sql.rs new file mode 100644 index 0000000000000..c242798a56f00 --- /dev/null +++ b/datafusion/core/benches/distinct_query_sql.rs @@ -0,0 +1,208 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
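
Ahead of the `distinct_query_sql` benchmark body below, one more illustration, this time of the `Precision::with_estimated_selectivity` helper added to `stats.rs` earlier in this patch. As documented there, it scales a row-count estimate by a filter's estimated selectivity and always downgrades the result to inexact; the constants below are made up and the `Absent` case reflects my reading of the implementation:

use datafusion_common::stats::Precision;

fn main() {
    // 1000 exact rows under an estimated filter selectivity of 0.5:
    // the count is scaled (rounding up) and marked inexact.
    let rows = Precision::Exact(1000_usize);
    assert_eq!(rows.with_estimated_selectivity(0.5), Precision::Inexact(500));

    // An absent estimate stays absent no matter the selectivity.
    let unknown: Precision<usize> = Precision::Absent;
    assert_eq!(unknown.with_estimated_selectivity(0.5), Precision::Absent);
}
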
+ +#[macro_use] +extern crate criterion; +extern crate arrow; +extern crate datafusion; + +mod data_utils; +use crate::criterion::Criterion; +use data_utils::{create_table_provider, make_data}; +use datafusion::execution::context::SessionContext; +use datafusion::physical_plan::{collect, ExecutionPlan}; +use datafusion::{datasource::MemTable, error::Result}; +use datafusion_execution::config::SessionConfig; +use datafusion_execution::TaskContext; + +use parking_lot::Mutex; +use std::{sync::Arc, time::Duration}; +use tokio::runtime::Runtime; + +fn query(ctx: Arc>, sql: &str) { + let rt = Runtime::new().unwrap(); + let df = rt.block_on(ctx.lock().sql(sql)).unwrap(); + criterion::black_box(rt.block_on(df.collect()).unwrap()); +} + +fn create_context( + partitions_len: usize, + array_len: usize, + batch_size: usize, +) -> Result>> { + let ctx = SessionContext::new(); + let provider = create_table_provider(partitions_len, array_len, batch_size)?; + ctx.register_table("t", provider)?; + Ok(Arc::new(Mutex::new(ctx))) +} + +fn criterion_benchmark_limited_distinct(c: &mut Criterion) { + let partitions_len = 10; + let array_len = 1 << 26; // 64 M + let batch_size = 8192; + let ctx = create_context(partitions_len, array_len, batch_size).unwrap(); + + let mut group = c.benchmark_group("custom-measurement-time"); + group.measurement_time(Duration::from_secs(40)); + + group.bench_function("distinct_group_by_u64_narrow_limit_10", |b| { + b.iter(|| { + query( + ctx.clone(), + "SELECT DISTINCT u64_narrow FROM t GROUP BY u64_narrow LIMIT 10", + ) + }) + }); + + group.bench_function("distinct_group_by_u64_narrow_limit_100", |b| { + b.iter(|| { + query( + ctx.clone(), + "SELECT DISTINCT u64_narrow FROM t GROUP BY u64_narrow LIMIT 100", + ) + }) + }); + + group.bench_function("distinct_group_by_u64_narrow_limit_1000", |b| { + b.iter(|| { + query( + ctx.clone(), + "SELECT DISTINCT u64_narrow FROM t GROUP BY u64_narrow LIMIT 1000", + ) + }) + }); + + group.bench_function("distinct_group_by_u64_narrow_limit_10000", |b| { + b.iter(|| { + query( + ctx.clone(), + "SELECT DISTINCT u64_narrow FROM t GROUP BY u64_narrow LIMIT 10000", + ) + }) + }); + + group.bench_function("group_by_multiple_columns_limit_10", |b| { + b.iter(|| { + query( + ctx.clone(), + "SELECT u64_narrow, u64_wide, utf8, f64 FROM t GROUP BY 1, 2, 3, 4 LIMIT 10", + ) + }) + }); + group.finish(); +} + +async fn distinct_with_limit( + plan: Arc, + ctx: Arc, +) -> Result<()> { + let batches = collect(plan, ctx).await?; + assert_eq!(batches.len(), 1); + let batch = batches.first().unwrap(); + assert_eq!(batch.num_rows(), 10); + + Ok(()) +} + +fn run(plan: Arc, ctx: Arc) { + let rt = Runtime::new().unwrap(); + criterion::black_box( + rt.block_on(async { distinct_with_limit(plan.clone(), ctx.clone()).await }), + ) + .unwrap(); +} + +pub async fn create_context_sampled_data( + sql: &str, + partition_cnt: i32, + sample_cnt: i32, +) -> Result<(Arc, Arc)> { + let (schema, parts) = make_data(partition_cnt, sample_cnt, false /* asc */).unwrap(); + let mem_table = Arc::new(MemTable::try_new(schema, parts).unwrap()); + + // Create the DataFrame + let cfg = SessionConfig::new(); + let ctx = SessionContext::new_with_config(cfg); + let _ = ctx.register_table("traces", mem_table)?; + let df = ctx.sql(sql).await?; + let physical_plan = df.create_physical_plan().await?; + Ok((physical_plan, ctx.task_ctx())) +} + +fn criterion_benchmark_limited_distinct_sampled(c: &mut Criterion) { + let rt = Runtime::new().unwrap(); + + let limit = 10; + let partitions = 100; + let 
samples = 100_000; + let sql = + format!("select DISTINCT trace_id from traces group by trace_id limit {limit};"); + + let distinct_trace_id_100_partitions_100_000_samples_limit_100 = rt.block_on(async { + create_context_sampled_data(sql.as_str(), partitions, samples) + .await + .unwrap() + }); + + c.bench_function( + format!("distinct query with {} partitions and {} samples per partition with limit {}", partitions, samples, limit).as_str(), + |b| b.iter(|| run(distinct_trace_id_100_partitions_100_000_samples_limit_100.0.clone(), + distinct_trace_id_100_partitions_100_000_samples_limit_100.1.clone())), + ); + + let partitions = 10; + let samples = 1_000_000; + let sql = + format!("select DISTINCT trace_id from traces group by trace_id limit {limit};"); + + let distinct_trace_id_10_partitions_1_000_000_samples_limit_10 = rt.block_on(async { + create_context_sampled_data(sql.as_str(), partitions, samples) + .await + .unwrap() + }); + + c.bench_function( + format!("distinct query with {} partitions and {} samples per partition with limit {}", partitions, samples, limit).as_str(), + |b| b.iter(|| run(distinct_trace_id_10_partitions_1_000_000_samples_limit_10.0.clone(), + distinct_trace_id_10_partitions_1_000_000_samples_limit_10.1.clone())), + ); + + let partitions = 1; + let samples = 10_000_000; + let sql = + format!("select DISTINCT trace_id from traces group by trace_id limit {limit};"); + + let rt = Runtime::new().unwrap(); + let distinct_trace_id_1_partition_10_000_000_samples_limit_10 = rt.block_on(async { + create_context_sampled_data(sql.as_str(), partitions, samples) + .await + .unwrap() + }); + + c.bench_function( + format!("distinct query with {} partitions and {} samples per partition with limit {}", partitions, samples, limit).as_str(), + |b| b.iter(|| run(distinct_trace_id_1_partition_10_000_000_samples_limit_10.0.clone(), + distinct_trace_id_1_partition_10_000_000_samples_limit_10.1.clone())), + ); +} + +criterion_group!( + benches, + criterion_benchmark_limited_distinct, + criterion_benchmark_limited_distinct_sampled +); +criterion_main!(benches); diff --git a/datafusion/core/benches/scalar.rs b/datafusion/core/benches/scalar.rs index 30f21a964d5f7..540f7212e96e9 100644 --- a/datafusion/core/benches/scalar.rs +++ b/datafusion/core/benches/scalar.rs @@ -22,7 +22,15 @@ fn criterion_benchmark(c: &mut Criterion) { c.bench_function("to_array_of_size 100000", |b| { let scalar = ScalarValue::Int32(Some(100)); - b.iter(|| assert_eq!(scalar.to_array_of_size(100000).null_count(), 0)) + b.iter(|| { + assert_eq!( + scalar + .to_array_of_size(100000) + .expect("Failed to convert to array of size") + .null_count(), + 0 + ) + }) }); } diff --git a/datafusion/core/benches/sort_limit_query_sql.rs b/datafusion/core/benches/sort_limit_query_sql.rs index efed5a04e7a5e..cfd4b8bc4bba8 100644 --- a/datafusion/core/benches/sort_limit_query_sql.rs +++ b/datafusion/core/benches/sort_limit_query_sql.rs @@ -99,7 +99,7 @@ fn create_context() -> Arc> { ctx_holder.lock().push(Arc::new(Mutex::new(ctx))) }); - let ctx = ctx_holder.lock().get(0).unwrap().clone(); + let ctx = ctx_holder.lock().first().unwrap().clone(); ctx } diff --git a/datafusion/core/benches/sql_planner.rs b/datafusion/core/benches/sql_planner.rs index 7a41b6bec6f50..1754129a768fa 100644 --- a/datafusion/core/benches/sql_planner.rs +++ b/datafusion/core/benches/sql_planner.rs @@ -60,6 +60,104 @@ pub fn create_table_provider(column_prefix: &str, num_columns: usize) -> Arc [(String, Schema); 8] { + let lineitem_schema = Schema::new(vec![ + 
Field::new("l_orderkey", DataType::Int64, false), + Field::new("l_partkey", DataType::Int64, false), + Field::new("l_suppkey", DataType::Int64, false), + Field::new("l_linenumber", DataType::Int32, false), + Field::new("l_quantity", DataType::Decimal128(15, 2), false), + Field::new("l_extendedprice", DataType::Decimal128(15, 2), false), + Field::new("l_discount", DataType::Decimal128(15, 2), false), + Field::new("l_tax", DataType::Decimal128(15, 2), false), + Field::new("l_returnflag", DataType::Utf8, false), + Field::new("l_linestatus", DataType::Utf8, false), + Field::new("l_shipdate", DataType::Date32, false), + Field::new("l_commitdate", DataType::Date32, false), + Field::new("l_receiptdate", DataType::Date32, false), + Field::new("l_shipinstruct", DataType::Utf8, false), + Field::new("l_shipmode", DataType::Utf8, false), + Field::new("l_comment", DataType::Utf8, false), + ]); + + let orders_schema = Schema::new(vec![ + Field::new("o_orderkey", DataType::Int64, false), + Field::new("o_custkey", DataType::Int64, false), + Field::new("o_orderstatus", DataType::Utf8, false), + Field::new("o_totalprice", DataType::Decimal128(15, 2), false), + Field::new("o_orderdate", DataType::Date32, false), + Field::new("o_orderpriority", DataType::Utf8, false), + Field::new("o_clerk", DataType::Utf8, false), + Field::new("o_shippriority", DataType::Int32, false), + Field::new("o_comment", DataType::Utf8, false), + ]); + + let part_schema = Schema::new(vec![ + Field::new("p_partkey", DataType::Int64, false), + Field::new("p_name", DataType::Utf8, false), + Field::new("p_mfgr", DataType::Utf8, false), + Field::new("p_brand", DataType::Utf8, false), + Field::new("p_type", DataType::Utf8, false), + Field::new("p_size", DataType::Int32, false), + Field::new("p_container", DataType::Utf8, false), + Field::new("p_retailprice", DataType::Decimal128(15, 2), false), + Field::new("p_comment", DataType::Utf8, false), + ]); + + let supplier_schema = Schema::new(vec![ + Field::new("s_suppkey", DataType::Int64, false), + Field::new("s_name", DataType::Utf8, false), + Field::new("s_address", DataType::Utf8, false), + Field::new("s_nationkey", DataType::Int64, false), + Field::new("s_phone", DataType::Utf8, false), + Field::new("s_acctbal", DataType::Decimal128(15, 2), false), + Field::new("s_comment", DataType::Utf8, false), + ]); + + let partsupp_schema = Schema::new(vec![ + Field::new("ps_partkey", DataType::Int64, false), + Field::new("ps_suppkey", DataType::Int64, false), + Field::new("ps_availqty", DataType::Int32, false), + Field::new("ps_supplycost", DataType::Decimal128(15, 2), false), + Field::new("ps_comment", DataType::Utf8, false), + ]); + + let customer_schema = Schema::new(vec![ + Field::new("c_custkey", DataType::Int64, false), + Field::new("c_name", DataType::Utf8, false), + Field::new("c_address", DataType::Utf8, false), + Field::new("c_nationkey", DataType::Int64, false), + Field::new("c_phone", DataType::Utf8, false), + Field::new("c_acctbal", DataType::Decimal128(15, 2), false), + Field::new("c_mktsegment", DataType::Utf8, false), + Field::new("c_comment", DataType::Utf8, false), + ]); + + let nation_schema = Schema::new(vec![ + Field::new("n_nationkey", DataType::Int64, false), + Field::new("n_name", DataType::Utf8, false), + Field::new("n_regionkey", DataType::Int64, false), + Field::new("n_comment", DataType::Utf8, false), + ]); + + let region_schema = Schema::new(vec![ + Field::new("r_regionkey", DataType::Int64, false), + Field::new("r_name", DataType::Utf8, false), + Field::new("r_comment", 
DataType::Utf8, false), + ]); + + [ + ("lineitem".to_string(), lineitem_schema), + ("orders".to_string(), orders_schema), + ("part".to_string(), part_schema), + ("supplier".to_string(), supplier_schema), + ("partsupp".to_string(), partsupp_schema), + ("customer".to_string(), customer_schema), + ("nation".to_string(), nation_schema), + ("region".to_string(), region_schema), + ] +} + fn create_context() -> SessionContext { let ctx = SessionContext::new(); ctx.register_table("t1", create_table_provider("a", 200)) @@ -68,6 +166,16 @@ fn create_context() -> SessionContext { .unwrap(); ctx.register_table("t700", create_table_provider("c", 700)) .unwrap(); + + let tpch_schemas = create_tpch_schemas(); + tpch_schemas.iter().for_each(|(name, schema)| { + ctx.register_table( + name, + Arc::new(MemTable::try_new(Arc::new(schema.clone()), vec![]).unwrap()), + ) + .unwrap(); + }); + ctx } @@ -115,6 +223,54 @@ fn criterion_benchmark(c: &mut Criterion) { ) }) }); + + let q1_sql = std::fs::read_to_string("../../benchmarks/queries/q1.sql").unwrap(); + let q2_sql = std::fs::read_to_string("../../benchmarks/queries/q2.sql").unwrap(); + let q3_sql = std::fs::read_to_string("../../benchmarks/queries/q3.sql").unwrap(); + let q4_sql = std::fs::read_to_string("../../benchmarks/queries/q4.sql").unwrap(); + let q5_sql = std::fs::read_to_string("../../benchmarks/queries/q5.sql").unwrap(); + let q6_sql = std::fs::read_to_string("../../benchmarks/queries/q6.sql").unwrap(); + let q7_sql = std::fs::read_to_string("../../benchmarks/queries/q7.sql").unwrap(); + let q8_sql = std::fs::read_to_string("../../benchmarks/queries/q8.sql").unwrap(); + let q9_sql = std::fs::read_to_string("../../benchmarks/queries/q9.sql").unwrap(); + let q10_sql = std::fs::read_to_string("../../benchmarks/queries/q10.sql").unwrap(); + let q11_sql = std::fs::read_to_string("../../benchmarks/queries/q11.sql").unwrap(); + let q12_sql = std::fs::read_to_string("../../benchmarks/queries/q12.sql").unwrap(); + let q13_sql = std::fs::read_to_string("../../benchmarks/queries/q13.sql").unwrap(); + let q14_sql = std::fs::read_to_string("../../benchmarks/queries/q14.sql").unwrap(); + // let q15_sql = std::fs::read_to_string("../../benchmarks/queries/q15.sql").unwrap(); + let q16_sql = std::fs::read_to_string("../../benchmarks/queries/q16.sql").unwrap(); + let q17_sql = std::fs::read_to_string("../../benchmarks/queries/q17.sql").unwrap(); + let q18_sql = std::fs::read_to_string("../../benchmarks/queries/q18.sql").unwrap(); + let q19_sql = std::fs::read_to_string("../../benchmarks/queries/q19.sql").unwrap(); + let q20_sql = std::fs::read_to_string("../../benchmarks/queries/q20.sql").unwrap(); + let q21_sql = std::fs::read_to_string("../../benchmarks/queries/q21.sql").unwrap(); + let q22_sql = std::fs::read_to_string("../../benchmarks/queries/q22.sql").unwrap(); + + c.bench_function("physical_plan_tpch", |b| { + b.iter(|| physical_plan(&ctx, &q1_sql)); + b.iter(|| physical_plan(&ctx, &q2_sql)); + b.iter(|| physical_plan(&ctx, &q3_sql)); + b.iter(|| physical_plan(&ctx, &q4_sql)); + b.iter(|| physical_plan(&ctx, &q5_sql)); + b.iter(|| physical_plan(&ctx, &q6_sql)); + b.iter(|| physical_plan(&ctx, &q7_sql)); + b.iter(|| physical_plan(&ctx, &q8_sql)); + b.iter(|| physical_plan(&ctx, &q9_sql)); + b.iter(|| physical_plan(&ctx, &q10_sql)); + b.iter(|| physical_plan(&ctx, &q11_sql)); + b.iter(|| physical_plan(&ctx, &q12_sql)); + b.iter(|| physical_plan(&ctx, &q13_sql)); + b.iter(|| physical_plan(&ctx, &q14_sql)); + // b.iter(|| physical_plan(&ctx, &q15_sql)); + 
b.iter(|| physical_plan(&ctx, &q16_sql)); + b.iter(|| physical_plan(&ctx, &q17_sql)); + b.iter(|| physical_plan(&ctx, &q18_sql)); + b.iter(|| physical_plan(&ctx, &q19_sql)); + b.iter(|| physical_plan(&ctx, &q20_sql)); + b.iter(|| physical_plan(&ctx, &q21_sql)); + b.iter(|| physical_plan(&ctx, &q22_sql)); + }); } criterion_group!(benches, criterion_benchmark); diff --git a/datafusion/core/benches/sql_query_with_io.rs b/datafusion/core/benches/sql_query_with_io.rs index 1f9b4dc6ccf70..c7a838385bd68 100644 --- a/datafusion/core/benches/sql_query_with_io.rs +++ b/datafusion/core/benches/sql_query_with_io.rs @@ -93,10 +93,9 @@ async fn setup_files(store: Arc) { for partition in 0..TABLE_PARTITIONS { for file in 0..PARTITION_FILES { let data = create_parquet_file(&mut rng, file * FILE_ROWS); - let location = Path::try_from(format!( + let location = Path::from(format!( "{table_name}/partition={partition}/{file}.parquet" - )) - .unwrap(); + )); store.put(&location, data).await.unwrap(); } } diff --git a/datafusion/core/benches/topk_aggregate.rs b/datafusion/core/benches/topk_aggregate.rs index ef84d6e3cac8c..922cbd2b42292 100644 --- a/datafusion/core/benches/topk_aggregate.rs +++ b/datafusion/core/benches/topk_aggregate.rs @@ -15,20 +15,15 @@ // specific language governing permissions and limitations // under the License. +mod data_utils; use arrow::util::pretty::pretty_format_batches; -use arrow::{datatypes::Schema, record_batch::RecordBatch}; -use arrow_array::builder::{Int64Builder, StringBuilder}; -use arrow_schema::{DataType, Field, SchemaRef}; use criterion::{criterion_group, criterion_main, Criterion}; +use data_utils::make_data; use datafusion::physical_plan::{collect, displayable, ExecutionPlan}; use datafusion::prelude::SessionContext; use datafusion::{datasource::MemTable, error::Result}; -use datafusion_common::DataFusionError; use datafusion_execution::config::SessionConfig; use datafusion_execution::TaskContext; -use rand_distr::Distribution; -use rand_distr::{Normal, Pareto}; -use std::fmt::Write; use std::sync::Arc; use tokio::runtime::Runtime; @@ -78,10 +73,10 @@ async fn aggregate( let batch = batches.first().unwrap(); assert_eq!(batch.num_rows(), 10); - let actual = format!("{}", pretty_format_batches(&batches)?); + let actual = format!("{}", pretty_format_batches(&batches)?).to_lowercase(); let expected_asc = r#" +----------------------------------+--------------------------+ -| trace_id | MAX(traces.timestamp_ms) | +| trace_id | max(traces.timestamp_ms) | +----------------------------------+--------------------------+ | 5868861a23ed31355efc5200eb80fe74 | 16909009999999 | | 4040e64656804c3d77320d7a0e7eb1f0 | 16909009999998 | @@ -103,85 +98,6 @@ async fn aggregate( Ok(()) } -fn make_data( - partition_cnt: i32, - sample_cnt: i32, - asc: bool, -) -> Result<(Arc, Vec>), DataFusionError> { - use rand::Rng; - use rand::SeedableRng; - - // constants observed from trace data - let simultaneous_group_cnt = 2000; - let fitted_shape = 12f64; - let fitted_scale = 5f64; - let mean = 0.1; - let stddev = 1.1; - let pareto = Pareto::new(fitted_scale, fitted_shape).unwrap(); - let normal = Normal::new(mean, stddev).unwrap(); - let mut rng = rand::rngs::SmallRng::from_seed([0; 32]); - - // populate data - let schema = test_schema(); - let mut partitions = vec![]; - let mut cur_time = 16909000000000i64; - for _ in 0..partition_cnt { - let mut id_builder = StringBuilder::new(); - let mut ts_builder = Int64Builder::new(); - let gen_id = |rng: &mut rand::rngs::SmallRng| { - rng.gen::<[u8; 16]>() - 
.iter() - .fold(String::new(), |mut output, b| { - let _ = write!(output, "{b:02X}"); - output - }) - }; - let gen_sample_cnt = - |mut rng: &mut rand::rngs::SmallRng| pareto.sample(&mut rng).ceil() as u32; - let mut group_ids = (0..simultaneous_group_cnt) - .map(|_| gen_id(&mut rng)) - .collect::>(); - let mut group_sample_cnts = (0..simultaneous_group_cnt) - .map(|_| gen_sample_cnt(&mut rng)) - .collect::>(); - for _ in 0..sample_cnt { - let random_index = rng.gen_range(0..simultaneous_group_cnt); - let trace_id = &mut group_ids[random_index]; - let sample_cnt = &mut group_sample_cnts[random_index]; - *sample_cnt -= 1; - if *sample_cnt == 0 { - *trace_id = gen_id(&mut rng); - *sample_cnt = gen_sample_cnt(&mut rng); - } - - id_builder.append_value(trace_id); - ts_builder.append_value(cur_time); - - if asc { - cur_time += 1; - } else { - let samp: f64 = normal.sample(&mut rng); - let samp = samp.round(); - cur_time += samp as i64; - } - } - - // convert to MemTable - let id_col = Arc::new(id_builder.finish()); - let ts_col = Arc::new(ts_builder.finish()); - let batch = RecordBatch::try_new(schema.clone(), vec![id_col, ts_col])?; - partitions.push(vec![batch]); - } - Ok((schema, partitions)) -} - -fn test_schema() -> SchemaRef { - Arc::new(Schema::new(vec![ - Field::new("trace_id", DataType::Utf8, false), - Field::new("timestamp_ms", DataType::Int64, false), - ])) -} - fn criterion_benchmark(c: &mut Criterion) { let limit = 10; let partitions = 10; diff --git a/datafusion/core/src/catalog/listing_schema.rs b/datafusion/core/src/catalog/listing_schema.rs index 7e527642be164..c3c6826895421 100644 --- a/datafusion/core/src/catalog/listing_schema.rs +++ b/datafusion/core/src/catalog/listing_schema.rs @@ -92,12 +92,7 @@ impl ListingSchemaProvider { /// Reload table information from ObjectStore pub async fn refresh(&self, state: &SessionState) -> datafusion_common::Result<()> { - let entries: Vec<_> = self - .store - .list(Some(&self.path)) - .await? 
- .try_collect() - .await?; + let entries: Vec<_> = self.store.list(Some(&self.path)).try_collect().await?; let base = Path::new(self.path.as_ref()); let mut tables = HashSet::new(); for file in entries.iter() { @@ -154,6 +149,7 @@ impl ListingSchemaProvider { unbounded: false, options: Default::default(), constraints: Constraints::empty(), + column_defaults: Default::default(), }, ) .await?; diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs index 0a99c331826c1..5a8c706e32cd4 100644 --- a/datafusion/core/src/dataframe/mod.rs +++ b/datafusion/core/src/dataframe/mod.rs @@ -23,43 +23,43 @@ mod parquet; use std::any::Any; use std::sync::Arc; +use crate::arrow::datatypes::{Schema, SchemaRef}; +use crate::arrow::record_batch::RecordBatch; +use crate::arrow::util::pretty; +use crate::datasource::{provider_as_source, MemTable, TableProvider}; +use crate::error::Result; +use crate::execution::{ + context::{SessionState, TaskContext}, + FunctionRegistry, +}; +use crate::logical_expr::utils::find_window_exprs; +use crate::logical_expr::{ + col, Expr, JoinType, LogicalPlan, LogicalPlanBuilder, Partitioning, TableType, +}; +use crate::physical_plan::{ + collect, collect_partitioned, execute_stream, execute_stream_partitioned, + ExecutionPlan, SendableRecordBatchStream, +}; +use crate::prelude::SessionContext; + use arrow::array::{Array, ArrayRef, Int64Array, StringArray}; use arrow::compute::{cast, concat}; use arrow::csv::WriterBuilder; use arrow::datatypes::{DataType, Field}; -use async_trait::async_trait; use datafusion_common::file_options::csv_writer::CsvWriterOptions; use datafusion_common::file_options::json_writer::JsonWriterOptions; use datafusion_common::parsers::CompressionTypeVariant; use datafusion_common::{ - DataFusionError, FileType, FileTypeWriterOptions, SchemaError, UnnestOptions, + Column, DFSchema, DataFusionError, FileType, FileTypeWriterOptions, ParamValues, + SchemaError, UnnestOptions, }; use datafusion_expr::dml::CopyOptions; - -use datafusion_common::{Column, DFSchema, ScalarValue}; use datafusion_expr::{ avg, count, is_null, max, median, min, stddev, utils::COUNT_STAR_EXPANSION, TableProviderFilterPushDown, UNNAMED_TABLE, }; -use crate::arrow::datatypes::Schema; -use crate::arrow::datatypes::SchemaRef; -use crate::arrow::record_batch::RecordBatch; -use crate::arrow::util::pretty; -use crate::datasource::{provider_as_source, MemTable, TableProvider}; -use crate::error::Result; -use crate::execution::{ - context::{SessionState, TaskContext}, - FunctionRegistry, -}; -use crate::logical_expr::{ - col, utils::find_window_exprs, Expr, JoinType, LogicalPlan, LogicalPlanBuilder, - Partitioning, TableType, -}; -use crate::physical_plan::SendableRecordBatchStream; -use crate::physical_plan::{collect, collect_partitioned}; -use crate::physical_plan::{execute_stream, execute_stream_partitioned, ExecutionPlan}; -use crate::prelude::SessionContext; +use async_trait::async_trait; /// Contains options that control how data is /// written out from a DataFrame @@ -1013,11 +1013,16 @@ impl DataFrame { )) } - /// Write this DataFrame to the referenced table + /// Write this DataFrame to the referenced table by name. /// This method uses on the same underlying implementation - /// as the SQL Insert Into statement. - /// Unlike most other DataFrame methods, this method executes - /// eagerly, writing data, and returning the count of rows written. + /// as the SQL Insert Into statement. Unlike most other DataFrame methods, + /// this method executes eagerly. 
Data is written to the table using an + /// execution plan returned by the [TableProvider]'s insert_into method. + /// Refer to the documentation of the specific [TableProvider] to determine + /// the expected data returned by the insert_into plan via this method. + /// For the built in ListingTable provider, a single [RecordBatch] containing + /// a single column and row representing the count of total rows written + /// is returned. pub async fn write_table( self, table_name: &str, @@ -1227,11 +1232,32 @@ impl DataFrame { /// ], /// &results /// ); + /// // Note you can also provide named parameters + /// let results = ctx + /// .sql("SELECT a FROM example WHERE b = $my_param") + /// .await? + /// // replace $my_param with value 2 + /// // Note you can also use a HashMap as well + /// .with_param_values(vec![ + /// ("my_param", ScalarValue::from(2i64)) + /// ])? + /// .collect() + /// .await?; + /// assert_batches_eq!( + /// &[ + /// "+---+", + /// "| a |", + /// "+---+", + /// "| 1 |", + /// "+---+", + /// ], + /// &results + /// ); /// # Ok(()) /// # } /// ``` - pub fn with_param_values(self, param_values: Vec) -> Result { - let plan = self.plan.with_param_values(param_values)?; + pub fn with_param_values(self, query_values: impl Into) -> Result { + let plan = self.plan.with_param_values(query_values)?; Ok(Self::new(self.session_state, plan)) } @@ -1250,11 +1276,12 @@ impl DataFrame { /// ``` pub async fn cache(self) -> Result { let context = SessionContext::new_with_state(self.session_state.clone()); - let mem_table = MemTable::try_new( - SchemaRef::from(self.schema().clone()), - self.collect_partitioned().await?, - )?; - + // The schema is consistent with the output + let plan = self.clone().create_physical_plan().await?; + let schema = plan.schema(); + let task_ctx = Arc::new(self.task_ctx()); + let partitions = collect_partitioned(plan, task_ctx).await?; + let mem_table = MemTable::try_new(schema, partitions)?; context.read_table(Arc::new(mem_table)) } } @@ -1321,24 +1348,144 @@ impl TableProvider for DataFrameTableProvider { mod tests { use std::vec; - use arrow::array::Int32Array; - use arrow::datatypes::DataType; + use super::*; + use crate::execution::context::SessionConfig; + use crate::physical_plan::{ColumnarValue, Partitioning, PhysicalExpr}; + use crate::test_util::{register_aggregate_csv, test_table, test_table_with_name}; + use crate::{assert_batches_sorted_eq, execution::context::SessionContext}; + use arrow::array::{self, Int32Array}; + use arrow::datatypes::DataType; + use datafusion_common::{Constraint, Constraints}; use datafusion_expr::{ avg, cast, count, count_distinct, create_udf, expr, lit, max, min, sum, BuiltInWindowFunction, ScalarFunctionImplementation, Volatility, WindowFrame, - WindowFunction, + WindowFunctionDefinition, }; use datafusion_physical_expr::expressions::Column; + use datafusion_physical_plan::get_plan_string; - use crate::execution::context::SessionConfig; - use crate::physical_plan::ColumnarValue; - use crate::physical_plan::Partitioning; - use crate::physical_plan::PhysicalExpr; - use crate::test_util::{register_aggregate_csv, test_table, test_table_with_name}; - use crate::{assert_batches_sorted_eq, execution::context::SessionContext}; + // Get string representation of the plan + async fn assert_physical_plan(df: &DataFrame, expected: Vec<&str>) { + let physical_plan = df + .clone() + .create_physical_plan() + .await + .expect("Error creating physical plan"); - use super::*; + let actual = get_plan_string(&physical_plan); + assert_eq!( + 
expected, actual, + "\n**Optimized Plan Mismatch\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" + ); + } + + pub fn table_with_constraints() -> Arc { + let dual_schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("name", DataType::Utf8, false), + ])); + let batch = RecordBatch::try_new( + dual_schema.clone(), + vec![ + Arc::new(array::Int32Array::from(vec![1])), + Arc::new(array::StringArray::from(vec!["a"])), + ], + ) + .unwrap(); + let provider = MemTable::try_new(dual_schema, vec![vec![batch]]) + .unwrap() + .with_constraints(Constraints::new_unverified(vec![Constraint::PrimaryKey( + vec![0], + )])); + Arc::new(provider) + } + + async fn assert_logical_expr_schema_eq_physical_expr_schema( + df: DataFrame, + ) -> Result<()> { + let logical_expr_dfschema = df.schema(); + let logical_expr_schema = SchemaRef::from(logical_expr_dfschema.to_owned()); + let batches = df.collect().await?; + let physical_expr_schema = batches[0].schema(); + assert_eq!(logical_expr_schema, physical_expr_schema); + Ok(()) + } + + #[tokio::test] + async fn test_array_agg_ord_schema() -> Result<()> { + let ctx = SessionContext::new(); + + let create_table_query = r#" + CREATE TABLE test_table ( + "double_field" DOUBLE, + "string_field" VARCHAR + ) AS VALUES + (1.0, 'a'), + (2.0, 'b'), + (3.0, 'c') + "#; + ctx.sql(create_table_query).await?; + + let query = r#"SELECT + array_agg("double_field" ORDER BY "string_field") as "double_field", + array_agg("string_field" ORDER BY "string_field") as "string_field" + FROM test_table"#; + + let result = ctx.sql(query).await?; + assert_logical_expr_schema_eq_physical_expr_schema(result).await?; + Ok(()) + } + + #[tokio::test] + async fn test_array_agg_schema() -> Result<()> { + let ctx = SessionContext::new(); + + let create_table_query = r#" + CREATE TABLE test_table ( + "double_field" DOUBLE, + "string_field" VARCHAR + ) AS VALUES + (1.0, 'a'), + (2.0, 'b'), + (3.0, 'c') + "#; + ctx.sql(create_table_query).await?; + + let query = r#"SELECT + array_agg("double_field") as "double_field", + array_agg("string_field") as "string_field" + FROM test_table"#; + + let result = ctx.sql(query).await?; + assert_logical_expr_schema_eq_physical_expr_schema(result).await?; + Ok(()) + } + + #[tokio::test] + async fn test_array_agg_distinct_schema() -> Result<()> { + let ctx = SessionContext::new(); + + let create_table_query = r#" + CREATE TABLE test_table ( + "double_field" DOUBLE, + "string_field" VARCHAR + ) AS VALUES + (1.0, 'a'), + (2.0, 'b'), + (2.0, 'a') + "#; + ctx.sql(create_table_query).await?; + + let query = r#"SELECT + array_agg(distinct "double_field") as "double_field", + array_agg(distinct "string_field") as "string_field" + FROM test_table"#; + + let result = ctx.sql(query).await?; + assert_logical_expr_schema_eq_physical_expr_schema(result).await?; + Ok(()) + } #[tokio::test] async fn select_columns() -> Result<()> { @@ -1378,7 +1525,9 @@ mod tests { // build plan using Table API let t = test_table().await?; let first_row = Expr::WindowFunction(expr::WindowFunction::new( - WindowFunction::BuiltInWindowFunction(BuiltInWindowFunction::FirstValue), + WindowFunctionDefinition::BuiltInWindowFunction( + BuiltInWindowFunction::FirstValue, + ), vec![col("aggregate_test_100.c1")], vec![col("aggregate_test_100.c2")], vec![], @@ -1449,6 +1598,223 @@ mod tests { Ok(()) } + #[tokio::test] + async fn test_aggregate_with_pk() -> Result<()> { + // create the dataframe + let config = SessionConfig::new().with_target_partitions(1); + let 
ctx = SessionContext::new_with_config(config); + + let df = ctx.read_table(table_with_constraints())?; + + // GROUP BY id + let group_expr = vec![col("id")]; + let aggr_expr = vec![]; + let df = df.aggregate(group_expr, aggr_expr)?; + + // Since id and name are functionally dependant, we can use name among + // expression even if it is not part of the group by expression and can + // select "name" column even though it wasn't explicitly grouped + let df = df.select(vec![col("id"), col("name")])?; + assert_physical_plan( + &df, + vec![ + "AggregateExec: mode=Single, gby=[id@0 as id, name@1 as name], aggr=[]", + " MemoryExec: partitions=1, partition_sizes=[1]", + ], + ) + .await; + + let df_results = df.collect().await?; + + #[rustfmt::skip] + assert_batches_sorted_eq!([ + "+----+------+", + "| id | name |", + "+----+------+", + "| 1 | a |", + "+----+------+" + ], + &df_results + ); + + Ok(()) + } + + #[tokio::test] + async fn test_aggregate_with_pk2() -> Result<()> { + // create the dataframe + let config = SessionConfig::new().with_target_partitions(1); + let ctx = SessionContext::new_with_config(config); + + let df = ctx.read_table(table_with_constraints())?; + + // GROUP BY id + let group_expr = vec![col("id")]; + let aggr_expr = vec![]; + let df = df.aggregate(group_expr, aggr_expr)?; + + // Predicate refers to id, and name fields: + // id = 1 AND name = 'a' + let predicate = col("id").eq(lit(1i32)).and(col("name").eq(lit("a"))); + let df = df.filter(predicate)?; + assert_physical_plan( + &df, + vec![ + "CoalesceBatchesExec: target_batch_size=8192", + " FilterExec: id@0 = 1 AND name@1 = a", + " AggregateExec: mode=Single, gby=[id@0 as id, name@1 as name], aggr=[]", + " MemoryExec: partitions=1, partition_sizes=[1]", + ], + ) + .await; + + // Since id and name are functionally dependant, we can use name among expression + // even if it is not part of the group by expression. + let df_results = df.collect().await?; + + #[rustfmt::skip] + assert_batches_sorted_eq!( + ["+----+------+", + "| id | name |", + "+----+------+", + "| 1 | a |", + "+----+------+",], + &df_results + ); + + Ok(()) + } + + #[tokio::test] + async fn test_aggregate_with_pk3() -> Result<()> { + // create the dataframe + let config = SessionConfig::new().with_target_partitions(1); + let ctx = SessionContext::new_with_config(config); + + let df = ctx.read_table(table_with_constraints())?; + + // GROUP BY id + let group_expr = vec![col("id")]; + let aggr_expr = vec![]; + // group by id, + let df = df.aggregate(group_expr, aggr_expr)?; + + // Predicate refers to id field + // id = 1 + let predicate = col("id").eq(lit(1i32)); + let df = df.filter(predicate)?; + // Select expression refers to id, and name columns. + // id, name + let df = df.select(vec![col("id"), col("name")])?; + assert_physical_plan( + &df, + vec![ + "CoalesceBatchesExec: target_batch_size=8192", + " FilterExec: id@0 = 1", + " AggregateExec: mode=Single, gby=[id@0 as id, name@1 as name], aggr=[]", + " MemoryExec: partitions=1, partition_sizes=[1]", + ], + ) + .await; + + // Since id and name are functionally dependant, we can use name among expression + // even if it is not part of the group by expression. 
+ let df_results = df.collect().await?; + + #[rustfmt::skip] + assert_batches_sorted_eq!( + ["+----+------+", + "| id | name |", + "+----+------+", + "| 1 | a |", + "+----+------+",], + &df_results + ); + + Ok(()) + } + + #[tokio::test] + async fn test_aggregate_with_pk4() -> Result<()> { + // create the dataframe + let config = SessionConfig::new().with_target_partitions(1); + let ctx = SessionContext::new_with_config(config); + + let df = ctx.read_table(table_with_constraints())?; + + // GROUP BY id + let group_expr = vec![col("id")]; + let aggr_expr = vec![]; + let df = df.aggregate(group_expr, aggr_expr)?; + + // Predicate refers to id field + // id = 1 + let predicate = col("id").eq(lit(1i32)); + let df = df.filter(predicate)?; + // Select expression refers to id column. + // id + let df = df.select(vec![col("id")])?; + + // In this case aggregate shouldn't be expanded, since these + // columns are not used. + assert_physical_plan( + &df, + vec![ + "CoalesceBatchesExec: target_batch_size=8192", + " FilterExec: id@0 = 1", + " AggregateExec: mode=Single, gby=[id@0 as id], aggr=[]", + " MemoryExec: partitions=1, partition_sizes=[1]", + ], + ) + .await; + + let df_results = df.collect().await?; + + #[rustfmt::skip] + assert_batches_sorted_eq!([ + "+----+", + "| id |", + "+----+", + "| 1 |", + "+----+",], + &df_results + ); + + Ok(()) + } + + #[tokio::test] + async fn test_aggregate_alias() -> Result<()> { + let df = test_table().await?; + + let df = df + // GROUP BY `c2 + 1` + .aggregate(vec![col("c2") + lit(1)], vec![])? + // SELECT `c2 + 1` as c2 + .select(vec![(col("c2") + lit(1)).alias("c2")])? + // GROUP BY c2 as "c2" (alias in expr is not supported by SQL) + .aggregate(vec![col("c2").alias("c2")], vec![])?; + + let df_results = df.collect().await?; + + #[rustfmt::skip] + assert_batches_sorted_eq!([ + "+----+", + "| c2 |", + "+----+", + "| 2 |", + "| 3 |", + "| 4 |", + "| 5 |", + "| 6 |", + "+----+", + ], + &df_results + ); + + Ok(()) + } + #[tokio::test] async fn test_distinct() -> Result<()> { let t = test_table().await?; @@ -2251,6 +2617,17 @@ mod tests { Ok(()) } + #[tokio::test] + async fn test_cache_mismatch() -> Result<()> { + let ctx = SessionContext::new(); + let df = ctx + .sql("SELECT CASE WHEN true THEN NULL ELSE 1 END") + .await?; + let cache_df = df.cache().await; + assert!(cache_df.is_ok()); + Ok(()) + } + #[tokio::test] async fn cache_test() -> Result<()> { let df = test_table() diff --git a/datafusion/core/src/datasource/avro_to_arrow/arrow_array_reader.rs b/datafusion/core/src/datasource/avro_to_arrow/arrow_array_reader.rs index fd91ea1cc538d..a16c1ae3333fb 100644 --- a/datafusion/core/src/datasource/avro_to_arrow/arrow_array_reader.rs +++ b/datafusion/core/src/datasource/avro_to_arrow/arrow_array_reader.rs @@ -45,6 +45,7 @@ use arrow::array::{BinaryArray, FixedSizeBinaryArray, GenericListArray}; use arrow::datatypes::{Fields, SchemaRef}; use arrow::error::ArrowError::SchemaError; use arrow::error::Result as ArrowResult; +use datafusion_common::arrow_err; use num_traits::NumCast; use std::collections::BTreeMap; use std::io::Read; @@ -86,9 +87,9 @@ impl<'a, R: Read> AvroArrowArrayReader<'a, R> { } Ok(lookup) } - _ => Err(DataFusionError::ArrowError(SchemaError( + _ => arrow_err!(SchemaError( "expected avro schema to be a record".to_string(), - ))), + )), } } @@ -1536,12 +1537,10 @@ mod test { .unwrap() .resolve(&schema) .unwrap(); - let r4 = apache_avro::to_value(serde_json::json!({ - "col1": null - })) - .unwrap() - .resolve(&schema) - .unwrap(); + let r4 = 
apache_avro::to_value(serde_json::json!({ "col1": null })) + .unwrap() + .resolve(&schema) + .unwrap(); let mut w = apache_avro::Writer::new(&schema, vec![]); w.append(r1).unwrap(); @@ -1600,12 +1599,10 @@ mod test { }"#, ) .unwrap(); - let r1 = apache_avro::to_value(serde_json::json!({ - "col1": null - })) - .unwrap() - .resolve(&schema) - .unwrap(); + let r1 = apache_avro::to_value(serde_json::json!({ "col1": null })) + .unwrap() + .resolve(&schema) + .unwrap(); let r2 = apache_avro::to_value(serde_json::json!({ "col1": { "col2": "hello" diff --git a/datafusion/core/src/datasource/default_table_source.rs b/datafusion/core/src/datasource/default_table_source.rs index 00a9c123ceeec..fadf01c74c5d4 100644 --- a/datafusion/core/src/datasource/default_table_source.rs +++ b/datafusion/core/src/datasource/default_table_source.rs @@ -73,6 +73,10 @@ impl TableSource for DefaultTableSource { fn get_logical_plan(&self) -> Option<&datafusion_expr::LogicalPlan> { self.table_provider.get_logical_plan() } + + fn get_column_default(&self, column: &str) -> Option<&Expr> { + self.table_provider.get_column_default(column) + } } /// Wrap TableProvider in TableSource diff --git a/datafusion/core/src/datasource/empty.rs b/datafusion/core/src/datasource/empty.rs index 77160aa5d1c0c..5100987520ee1 100644 --- a/datafusion/core/src/datasource/empty.rs +++ b/datafusion/core/src/datasource/empty.rs @@ -77,7 +77,7 @@ impl TableProvider for EmptyTable { // even though there is no data, projections apply let projected_schema = project_schema(&self.schema, projection)?; Ok(Arc::new( - EmptyExec::new(false, projected_schema).with_partitions(self.partitions), + EmptyExec::new(projected_schema).with_partitions(self.partitions), )) } } diff --git a/datafusion/core/src/datasource/file_format/arrow.rs b/datafusion/core/src/datasource/file_format/arrow.rs index a9bd7d0e27bb8..650f8c844eda3 100644 --- a/datafusion/core/src/datasource/file_format/arrow.rs +++ b/datafusion/core/src/datasource/file_format/arrow.rs @@ -15,16 +15,19 @@ // specific language governing permissions and limitations // under the License. -//! Apache Arrow format abstractions +//! [`ArrowFormat`]: Apache Arrow [`FileFormat`] abstractions //! //! 
Works with files following the [Arrow IPC format](https://arrow.apache.org/docs/format/Columnar.html#ipc-file-format) use std::any::Any; use std::borrow::Cow; +use std::fmt::{self, Debug}; use std::sync::Arc; use crate::datasource::file_format::FileFormat; -use crate::datasource::physical_plan::{ArrowExec, FileScanConfig}; +use crate::datasource::physical_plan::{ + ArrowExec, FileGroupDisplay, FileScanConfig, FileSinkConfig, +}; use crate::error::Result; use crate::execution::context::SessionState; use crate::physical_plan::ExecutionPlan; @@ -32,16 +35,35 @@ use crate::physical_plan::ExecutionPlan; use arrow::ipc::convert::fb_to_schema; use arrow::ipc::reader::FileReader; use arrow::ipc::root_as_message; +use arrow_ipc::writer::IpcWriteOptions; +use arrow_ipc::CompressionType; use arrow_schema::{ArrowError, Schema, SchemaRef}; use bytes::Bytes; -use datafusion_common::{FileType, Statistics}; -use datafusion_physical_expr::PhysicalExpr; +use datafusion_common::{not_impl_err, DataFusionError, FileType, Statistics}; +use datafusion_execution::{SendableRecordBatchStream, TaskContext}; +use datafusion_physical_expr::{PhysicalExpr, PhysicalSortRequirement}; +use crate::physical_plan::{DisplayAs, DisplayFormatType}; use async_trait::async_trait; +use datafusion_physical_plan::insert::{DataSink, FileSinkExec}; +use datafusion_physical_plan::metrics::MetricsSet; use futures::stream::BoxStream; use futures::StreamExt; use object_store::{GetResultPayload, ObjectMeta, ObjectStore}; +use tokio::io::AsyncWriteExt; +use tokio::task::JoinSet; + +use super::file_compression_type::FileCompressionType; +use super::write::demux::start_demuxer_task; +use super::write::{create_writer, SharedBuffer}; + +/// Initial writing buffer size. Note this is just a size hint for efficiency. It +/// will grow beyond the set value if needed. +const INITIAL_BUFFER_BYTES: usize = 1048576; + +/// If the buffered Arrow data exceeds this size, it is flushed to object store +const BUFFER_FLUSH_BYTES: usize = 1024000; /// Arrow `FileFormat` implementation. #[derive(Default, Debug)] @@ -97,11 +119,197 @@ impl FileFormat for ArrowFormat { Ok(Arc::new(exec)) } + async fn create_writer_physical_plan( + &self, + input: Arc, + _state: &SessionState, + conf: FileSinkConfig, + order_requirements: Option>, + ) -> Result> { + if conf.overwrite { + return not_impl_err!("Overwrites are not implemented yet for Arrow format"); + } + + let sink_schema = conf.output_schema().clone(); + let sink = Arc::new(ArrowFileSink::new(conf)); + + Ok(Arc::new(FileSinkExec::new( + input, + sink, + sink_schema, + order_requirements, + )) as _) + } + fn file_type(&self) -> FileType { FileType::ARROW } } +/// Implements [`DataSink`] for writing to arrow_ipc files +struct ArrowFileSink { + config: FileSinkConfig, +} + +impl ArrowFileSink { + fn new(config: FileSinkConfig) -> Self { + Self { config } + } + + /// Converts table schema to writer schema, which may differ in the case + /// of hive style partitioning where some columns are removed from the + /// underlying files. 
+ fn get_writer_schema(&self) -> Arc { + if !self.config.table_partition_cols.is_empty() { + let schema = self.config.output_schema(); + let partition_names: Vec<_> = self + .config + .table_partition_cols + .iter() + .map(|(s, _)| s) + .collect(); + Arc::new(Schema::new( + schema + .fields() + .iter() + .filter(|f| !partition_names.contains(&f.name())) + .map(|f| (**f).clone()) + .collect::>(), + )) + } else { + self.config.output_schema().clone() + } + } +} + +impl Debug for ArrowFileSink { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("ArrowFileSink").finish() + } +} + +impl DisplayAs for ArrowFileSink { + fn fmt_as(&self, t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match t { + DisplayFormatType::Default | DisplayFormatType::Verbose => { + write!(f, "ArrowFileSink(file_groups=",)?; + FileGroupDisplay(&self.config.file_groups).fmt_as(t, f)?; + write!(f, ")") + } + } + } +} + +#[async_trait] +impl DataSink for ArrowFileSink { + fn as_any(&self) -> &dyn Any { + self + } + + fn metrics(&self) -> Option { + None + } + + async fn write_all( + &self, + data: SendableRecordBatchStream, + context: &Arc, + ) -> Result { + // No props are supported yet, but can be by updating FileTypeWriterOptions + // to populate this struct and use those options to initialize the arrow_ipc::writer::FileWriter + // https://github.com/apache/arrow-datafusion/issues/8635 + let _arrow_props = self.config.file_type_writer_options.try_into_arrow()?; + + let object_store = context + .runtime_env() + .object_store(&self.config.object_store_url)?; + + let part_col = if !self.config.table_partition_cols.is_empty() { + Some(self.config.table_partition_cols.clone()) + } else { + None + }; + + let (demux_task, mut file_stream_rx) = start_demuxer_task( + data, + context, + part_col, + self.config.table_paths[0].clone(), + "arrow".into(), + self.config.single_file_output, + ); + + let mut file_write_tasks: JoinSet> = + JoinSet::new(); + + let ipc_options = + IpcWriteOptions::try_new(64, false, arrow_ipc::MetadataVersion::V5)? 
+ .try_with_compression(Some(CompressionType::LZ4_FRAME))?; + while let Some((path, mut rx)) = file_stream_rx.recv().await { + let shared_buffer = SharedBuffer::new(INITIAL_BUFFER_BYTES); + let mut arrow_writer = arrow_ipc::writer::FileWriter::try_new_with_options( + shared_buffer.clone(), + &self.get_writer_schema(), + ipc_options.clone(), + )?; + let mut object_store_writer = create_writer( + FileCompressionType::UNCOMPRESSED, + &path, + object_store.clone(), + ) + .await?; + file_write_tasks.spawn(async move { + let mut row_count = 0; + while let Some(batch) = rx.recv().await { + row_count += batch.num_rows(); + arrow_writer.write(&batch)?; + let mut buff_to_flush = shared_buffer.buffer.try_lock().unwrap(); + if buff_to_flush.len() > BUFFER_FLUSH_BYTES { + object_store_writer + .write_all(buff_to_flush.as_slice()) + .await?; + buff_to_flush.clear(); + } + } + arrow_writer.finish()?; + let final_buff = shared_buffer.buffer.try_lock().unwrap(); + + object_store_writer.write_all(final_buff.as_slice()).await?; + object_store_writer.shutdown().await?; + Ok(row_count) + }); + } + + let mut row_count = 0; + while let Some(result) = file_write_tasks.join_next().await { + match result { + Ok(r) => { + row_count += r?; + } + Err(e) => { + if e.is_panic() { + std::panic::resume_unwind(e.into_panic()); + } else { + unreachable!(); + } + } + } + } + + match demux_task.await { + Ok(r) => r?, + Err(e) => { + if e.is_panic() { + std::panic::resume_unwind(e.into_panic()); + } else { + unreachable!(); + } + } + } + Ok(row_count as u64) + } +} + const ARROW_MAGIC: [u8; 6] = [b'A', b'R', b'R', b'O', b'W', b'1']; const CONTINUATION_MARKER: [u8; 4] = [0xff; 4]; @@ -214,6 +422,7 @@ mod tests { last_modified: DateTime::default(), size: usize::MAX, e_tag: None, + version: None, }; let arrow_format = ArrowFormat {}; @@ -256,6 +465,7 @@ mod tests { last_modified: DateTime::default(), size: usize::MAX, e_tag: None, + version: None, }; let arrow_format = ArrowFormat {}; diff --git a/datafusion/core/src/datasource/file_format/avro.rs b/datafusion/core/src/datasource/file_format/avro.rs index a24a28ad6fdd4..6d424bf0b28f3 100644 --- a/datafusion/core/src/datasource/file_format/avro.rs +++ b/datafusion/core/src/datasource/file_format/avro.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -//! Apache Avro format abstractions +//! [`AvroFormat`] Apache Avro [`FileFormat`] abstractions use std::any::Any; use std::sync::Arc; diff --git a/datafusion/core/src/datasource/file_format/csv.rs b/datafusion/core/src/datasource/file_format/csv.rs index 5f2084bc80a85..7a0af3ff0809b 100644 --- a/datafusion/core/src/datasource/file_format/csv.rs +++ b/datafusion/core/src/datasource/file_format/csv.rs @@ -15,29 +15,17 @@ // specific language governing permissions and limitations // under the License. -//! CSV format abstractions +//! 
[`CsvFormat`], Comma Separated Value (CSV) [`FileFormat`] abstractions use std::any::Any; use std::collections::HashSet; -use std::fmt; -use std::fmt::Debug; +use std::fmt::{self, Debug}; use std::sync::Arc; -use arrow_array::RecordBatch; -use datafusion_common::{exec_err, not_impl_err, DataFusionError, FileType}; -use datafusion_execution::TaskContext; -use datafusion_physical_expr::{PhysicalExpr, PhysicalSortRequirement}; - -use bytes::{Buf, Bytes}; -use datafusion_physical_plan::metrics::MetricsSet; -use futures::stream::BoxStream; -use futures::{pin_mut, Stream, StreamExt, TryStreamExt}; -use object_store::{delimited::newline_delimited_stream, ObjectMeta, ObjectStore}; - -use super::write::orchestration::{stateless_append_all, stateless_multipart_put}; +use super::write::orchestration::stateless_multipart_put; use super::{FileFormat, DEFAULT_SCHEMA_INFER_MAX_RECORD}; use crate::datasource::file_format::file_compression_type::FileCompressionType; -use crate::datasource::file_format::write::{BatchSerializer, FileWriterMode}; +use crate::datasource::file_format::write::BatchSerializer; use crate::datasource::physical_plan::{ CsvExec, FileGroupDisplay, FileScanConfig, FileSinkConfig, }; @@ -47,11 +35,20 @@ use crate::physical_plan::insert::{DataSink, FileSinkExec}; use crate::physical_plan::{DisplayAs, DisplayFormatType, Statistics}; use crate::physical_plan::{ExecutionPlan, SendableRecordBatchStream}; +use arrow::array::RecordBatch; use arrow::csv::WriterBuilder; use arrow::datatypes::{DataType, Field, Fields, Schema}; use arrow::{self, datatypes::SchemaRef}; +use datafusion_common::{exec_err, not_impl_err, DataFusionError, FileType}; +use datafusion_execution::TaskContext; +use datafusion_physical_expr::{PhysicalExpr, PhysicalSortRequirement}; +use datafusion_physical_plan::metrics::MetricsSet; use async_trait::async_trait; +use bytes::{Buf, Bytes}; +use futures::stream::BoxStream; +use futures::{pin_mut, Stream, StreamExt, TryStreamExt}; +use object_store::{delimited::newline_delimited_stream, ObjectMeta, ObjectStore}; /// Character Separated Value `FileFormat` implementation. #[derive(Debug)] @@ -400,8 +397,6 @@ impl Default for CsvSerializer { pub struct CsvSerializer { // CSV writer builder builder: WriterBuilder, - // Inner buffer for avoiding reallocation - buffer: Vec, // Flag to indicate whether there will be a header header: bool, } @@ -412,7 +407,6 @@ impl CsvSerializer { Self { builder: WriterBuilder::new(), header: true, - buffer: Vec::with_capacity(4096), } } @@ -431,26 +425,19 @@ impl CsvSerializer { #[async_trait] impl BatchSerializer for CsvSerializer { - async fn serialize(&mut self, batch: RecordBatch) -> Result { + async fn serialize(&self, batch: RecordBatch, initial: bool) -> Result { + let mut buffer = Vec::with_capacity(4096); let builder = self.builder.clone(); - let mut writer = builder.with_header(self.header).build(&mut self.buffer); + let header = self.header && initial; + let mut writer = builder.with_header(header).build(&mut buffer); writer.write(&batch)?; drop(writer); - self.header = false; - Ok(Bytes::from(self.buffer.drain(..).collect::>())) - } - - fn duplicate(&mut self) -> Result> { - let new_self = CsvSerializer::new() - .with_builder(self.builder.clone()) - .with_header(self.header); - self.header = false; - Ok(Box::new(new_self)) + Ok(Bytes::from(buffer)) } } /// Implements [`DataSink`] for writing to a CSV file. 
-struct CsvSink { +pub struct CsvSink { /// Config options for writing data config: FileSinkConfig, } @@ -465,11 +452,7 @@ impl DisplayAs for CsvSink { fn fmt_as(&self, t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result { match t { DisplayFormatType::Default | DisplayFormatType::Verbose => { - write!( - f, - "CsvSink(writer_mode={:?}, file_groups=", - self.config.writer_mode - )?; + write!(f, "CsvSink(file_groups=",)?; FileGroupDisplay(&self.config.file_groups).fmt_as(t, f)?; write!(f, ")") } @@ -478,56 +461,14 @@ impl DisplayAs for CsvSink { } impl CsvSink { - fn new(config: FileSinkConfig) -> Self { + /// Create from config. + pub fn new(config: FileSinkConfig) -> Self { Self { config } } - async fn append_all( - &self, - data: SendableRecordBatchStream, - context: &Arc, - ) -> Result { - if !self.config.table_partition_cols.is_empty() { - return Err(DataFusionError::NotImplemented("Inserting in append mode to hive style partitioned tables is not supported".into())); - } - let writer_options = self.config.file_type_writer_options.try_into_csv()?; - let (builder, compression) = - (&writer_options.writer_options, &writer_options.compression); - let compression = FileCompressionType::from(*compression); - - let object_store = context - .runtime_env() - .object_store(&self.config.object_store_url)?; - let file_groups = &self.config.file_groups; - - let builder_clone = builder.clone(); - let options_clone = writer_options.clone(); - let get_serializer = move |file_size| { - let inner_clone = builder_clone.clone(); - // In append mode, consider has_header flag only when file is empty (at the start). - // For other modes, use has_header flag as is. - let serializer: Box = Box::new(if file_size > 0 { - CsvSerializer::new() - .with_builder(inner_clone) - .with_header(false) - } else { - CsvSerializer::new() - .with_builder(inner_clone) - .with_header(options_clone.writer_options.header()) - }); - serializer - }; - - stateless_append_all( - data, - context, - object_store, - file_groups, - self.config.unbounded_input, - compression, - Box::new(get_serializer), - ) - .await + /// Retrieve the inner [`FileSinkConfig`]. 
+ pub fn config(&self) -> &FileSinkConfig { + &self.config } async fn multipartput_all( @@ -541,13 +482,11 @@ impl CsvSink { let builder_clone = builder.clone(); let options_clone = writer_options.clone(); let get_serializer = move || { - let inner_clone = builder_clone.clone(); - let serializer: Box = Box::new( + Arc::new( CsvSerializer::new() - .with_builder(inner_clone) + .with_builder(builder_clone.clone()) .with_header(options_clone.writer_options.header()), - ); - serializer + ) as _ }; stateless_multipart_put( @@ -577,19 +516,8 @@ impl DataSink for CsvSink { data: SendableRecordBatchStream, context: &Arc, ) -> Result { - match self.config.writer_mode { - FileWriterMode::Append => { - let total_count = self.append_all(data, context).await?; - Ok(total_count) - } - FileWriterMode::PutMultipart => { - let total_count = self.multipartput_all(data, context).await?; - Ok(total_count) - } - FileWriterMode::Put => { - return not_impl_err!("FileWriterMode::Put is not supported yet!") - } - } + let total_count = self.multipartput_all(data, context).await?; + Ok(total_count) } } @@ -605,15 +533,15 @@ mod tests { use crate::physical_plan::collect; use crate::prelude::{CsvReadOptions, SessionConfig, SessionContext}; use crate::test_util::arrow_test_data; + use arrow::compute::concat_batches; - use bytes::Bytes; - use chrono::DateTime; use datafusion_common::cast::as_string_array; - use datafusion_common::internal_err; use datafusion_common::stats::Precision; - use datafusion_common::FileType; - use datafusion_common::GetExt; + use datafusion_common::{internal_err, FileType, GetExt}; use datafusion_expr::{col, lit}; + + use bytes::Bytes; + use chrono::DateTime; use futures::StreamExt; use object_store::local::LocalFileSystem; use object_store::path::Path; @@ -737,6 +665,7 @@ mod tests { last_modified: DateTime::default(), size: usize::MAX, e_tag: None, + version: None, }; let num_rows_to_read = 100; @@ -899,8 +828,8 @@ mod tests { .collect() .await?; let batch = concat_batches(&batches[0].schema(), &batches)?; - let mut serializer = CsvSerializer::new(); - let bytes = serializer.serialize(batch).await?; + let serializer = CsvSerializer::new(); + let bytes = serializer.serialize(batch, true).await?; assert_eq!( "c2,c3\n2,1\n5,-40\n1,29\n1,-85\n5,-82\n4,-111\n3,104\n3,13\n1,38\n4,-38\n", String::from_utf8(bytes.into()).unwrap() @@ -923,8 +852,8 @@ mod tests { .collect() .await?; let batch = concat_batches(&batches[0].schema(), &batches)?; - let mut serializer = CsvSerializer::new().with_header(false); - let bytes = serializer.serialize(batch).await?; + let serializer = CsvSerializer::new().with_header(false); + let bytes = serializer.serialize(batch, true).await?; assert_eq!( "2,1\n5,-40\n1,29\n1,-85\n5,-82\n4,-111\n3,104\n3,13\n1,38\n4,-38\n", String::from_utf8(bytes.into()).unwrap() diff --git a/datafusion/core/src/datasource/file_format/json.rs b/datafusion/core/src/datasource/file_format/json.rs index 70cfd1836efec..8c02955ad363a 100644 --- a/datafusion/core/src/datasource/file_format/json.rs +++ b/datafusion/core/src/datasource/file_format/json.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -//! Line delimited JSON format abstractions +//! 
[`JsonFormat`]: Line delimited JSON [`FileFormat`] abstractions use std::any::Any; use std::fmt; @@ -23,40 +23,34 @@ use std::fmt::Debug; use std::io::BufReader; use std::sync::Arc; +use super::write::orchestration::stateless_multipart_put; use super::{FileFormat, FileScanConfig}; -use arrow::datatypes::Schema; -use arrow::datatypes::SchemaRef; -use arrow::json; -use arrow::json::reader::infer_json_schema_from_iterator; -use arrow::json::reader::ValueIter; -use arrow_array::RecordBatch; -use async_trait::async_trait; -use bytes::Buf; - -use bytes::Bytes; -use datafusion_physical_expr::PhysicalExpr; -use datafusion_physical_expr::PhysicalSortRequirement; -use datafusion_physical_plan::ExecutionPlan; -use object_store::{GetResultPayload, ObjectMeta, ObjectStore}; - -use crate::datasource::physical_plan::FileGroupDisplay; -use crate::physical_plan::insert::DataSink; -use crate::physical_plan::insert::FileSinkExec; -use crate::physical_plan::SendableRecordBatchStream; -use crate::physical_plan::{DisplayAs, DisplayFormatType, Statistics}; - -use super::write::orchestration::{stateless_append_all, stateless_multipart_put}; - use crate::datasource::file_format::file_compression_type::FileCompressionType; -use crate::datasource::file_format::write::{BatchSerializer, FileWriterMode}; +use crate::datasource::file_format::write::BatchSerializer; use crate::datasource::file_format::DEFAULT_SCHEMA_INFER_MAX_RECORD; +use crate::datasource::physical_plan::FileGroupDisplay; use crate::datasource::physical_plan::{FileSinkConfig, NdJsonExec}; use crate::error::Result; use crate::execution::context::SessionState; +use crate::physical_plan::insert::{DataSink, FileSinkExec}; +use crate::physical_plan::{ + DisplayAs, DisplayFormatType, SendableRecordBatchStream, Statistics, +}; +use arrow::datatypes::Schema; +use arrow::datatypes::SchemaRef; +use arrow::json; +use arrow::json::reader::{infer_json_schema_from_iterator, ValueIter}; +use arrow_array::RecordBatch; use datafusion_common::{not_impl_err, DataFusionError, FileType}; use datafusion_execution::TaskContext; +use datafusion_physical_expr::{PhysicalExpr, PhysicalSortRequirement}; use datafusion_physical_plan::metrics::MetricsSet; +use datafusion_physical_plan::ExecutionPlan; + +use async_trait::async_trait; +use bytes::{Buf, Bytes}; +use object_store::{GetResultPayload, ObjectMeta, ObjectStore}; /// New line delimited JSON `FileFormat` implementation. #[derive(Debug)] @@ -201,36 +195,27 @@ impl Default for JsonSerializer { } /// Define a struct for serializing Json records to a stream -pub struct JsonSerializer { - // Inner buffer for avoiding reallocation - buffer: Vec, -} +pub struct JsonSerializer {} impl JsonSerializer { /// Constructor for the JsonSerializer object pub fn new() -> Self { - Self { - buffer: Vec::with_capacity(4096), - } + Self {} } } #[async_trait] impl BatchSerializer for JsonSerializer { - async fn serialize(&mut self, batch: RecordBatch) -> Result { - let mut writer = json::LineDelimitedWriter::new(&mut self.buffer); + async fn serialize(&self, batch: RecordBatch, _initial: bool) -> Result { + let mut buffer = Vec::with_capacity(4096); + let mut writer = json::LineDelimitedWriter::new(&mut buffer); writer.write(&batch)?; - //drop(writer); - Ok(Bytes::from(self.buffer.drain(..).collect::>())) - } - - fn duplicate(&mut self) -> Result> { - Ok(Box::new(JsonSerializer::new())) + Ok(Bytes::from(buffer)) } } /// Implements [`DataSink`] for writing to a Json file. 
-struct JsonSink { +pub struct JsonSink { /// Config options for writing data config: FileSinkConfig, } @@ -245,11 +230,7 @@ impl DisplayAs for JsonSink { fn fmt_as(&self, t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result { match t { DisplayFormatType::Default | DisplayFormatType::Verbose => { - write!( - f, - "JsonSink(writer_mode={:?}, file_groups=", - self.config.writer_mode - )?; + write!(f, "JsonSink(file_groups=",)?; FileGroupDisplay(&self.config.file_groups).fmt_as(t, f)?; write!(f, ")") } @@ -258,42 +239,14 @@ impl DisplayAs for JsonSink { } impl JsonSink { - fn new(config: FileSinkConfig) -> Self { + /// Create from config. + pub fn new(config: FileSinkConfig) -> Self { Self { config } } - async fn append_all( - &self, - data: SendableRecordBatchStream, - context: &Arc, - ) -> Result { - if !self.config.table_partition_cols.is_empty() { - return Err(DataFusionError::NotImplemented("Inserting in append mode to hive style partitioned tables is not supported".into())); - } - - let writer_options = self.config.file_type_writer_options.try_into_json()?; - let compression = &writer_options.compression; - - let object_store = context - .runtime_env() - .object_store(&self.config.object_store_url)?; - let file_groups = &self.config.file_groups; - - let get_serializer = move |_| { - let serializer: Box = Box::new(JsonSerializer::new()); - serializer - }; - - stateless_append_all( - data, - context, - object_store, - file_groups, - self.config.unbounded_input, - (*compression).into(), - Box::new(get_serializer), - ) - .await + /// Retrieve the inner [`FileSinkConfig`]. + pub fn config(&self) -> &FileSinkConfig { + &self.config } async fn multipartput_all( @@ -304,10 +257,7 @@ impl JsonSink { let writer_options = self.config.file_type_writer_options.try_into_json()?; let compression = &writer_options.compression; - let get_serializer = move || { - let serializer: Box = Box::new(JsonSerializer::new()); - serializer - }; + let get_serializer = move || Arc::new(JsonSerializer::new()) as _; stateless_multipart_put( data, @@ -336,31 +286,25 @@ impl DataSink for JsonSink { data: SendableRecordBatchStream, context: &Arc, ) -> Result { - match self.config.writer_mode { - FileWriterMode::Append => { - let total_count = self.append_all(data, context).await?; - Ok(total_count) - } - FileWriterMode::PutMultipart => { - let total_count = self.multipartput_all(data, context).await?; - Ok(total_count) - } - FileWriterMode::Put => { - return not_impl_err!("FileWriterMode::Put is not supported yet!") - } - } + let total_count = self.multipartput_all(data, context).await?; + Ok(total_count) } } #[cfg(test)] mod tests { use super::super::test_util::scan_format; + use arrow::util::pretty; use datafusion_common::cast::as_int64_array; use datafusion_common::stats::Precision; + use datafusion_common::{assert_batches_eq, internal_err}; use futures::StreamExt; use object_store::local::LocalFileSystem; + use regex::Regex; + use rstest::rstest; use super::*; + use crate::execution::options::NdJsonReadOptions; use crate::physical_plan::collect; use crate::prelude::{SessionConfig, SessionContext}; use crate::test::object_store::local_unpartitioned_file; @@ -484,4 +428,94 @@ mod tests { .collect::>(); assert_eq!(vec!["a: Int64", "b: Float64", "c: Boolean"], fields); } + + async fn count_num_partitions(ctx: &SessionContext, query: &str) -> Result { + let result = ctx + .sql(&format!("EXPLAIN {query}")) + .await? 
+ .collect() + .await?; + + let plan = format!("{}", &pretty::pretty_format_batches(&result)?); + + let re = Regex::new(r"file_groups=\{(\d+) group").unwrap(); + + if let Some(captures) = re.captures(&plan) { + if let Some(match_) = captures.get(1) { + let count = match_.as_str().parse::().unwrap(); + return Ok(count); + } + } + + internal_err!("Query contains no Exec: file_groups") + } + + #[rstest(n_partitions, case(1), case(2), case(3), case(4))] + #[tokio::test] + async fn it_can_read_ndjson_in_parallel(n_partitions: usize) -> Result<()> { + let config = SessionConfig::new() + .with_repartition_file_scans(true) + .with_repartition_file_min_size(0) + .with_target_partitions(n_partitions); + + let ctx = SessionContext::new_with_config(config); + + let table_path = "tests/data/1.json"; + let options = NdJsonReadOptions::default(); + + ctx.register_json("json_parallel", table_path, options) + .await?; + + let query = "SELECT SUM(a) FROM json_parallel;"; + + let result = ctx.sql(query).await?.collect().await?; + let actual_partitions = count_num_partitions(&ctx, query).await?; + + #[rustfmt::skip] + let expected = [ + "+----------------------+", + "| SUM(json_parallel.a) |", + "+----------------------+", + "| -7 |", + "+----------------------+" + ]; + + assert_batches_eq!(expected, &result); + assert_eq!(n_partitions, actual_partitions); + + Ok(()) + } + + #[rstest(n_partitions, case(1), case(2), case(3), case(4))] + #[tokio::test] + async fn it_can_read_empty_ndjson_in_parallel(n_partitions: usize) -> Result<()> { + let config = SessionConfig::new() + .with_repartition_file_scans(true) + .with_repartition_file_min_size(0) + .with_target_partitions(n_partitions); + + let ctx = SessionContext::new_with_config(config); + + let table_path = "tests/data/empty.json"; + let options = NdJsonReadOptions::default(); + + ctx.register_json("json_parallel_empty", table_path, options) + .await?; + + let query = "SELECT * FROM json_parallel_empty WHERE random() > 0.5;"; + + let result = ctx.sql(query).await?.collect().await?; + let actual_partitions = count_num_partitions(&ctx, query).await?; + + #[rustfmt::skip] + let expected = [ + "++", + "++", + ]; + + assert_batches_eq!(expected, &result); + assert_eq!(1, actual_partitions); + + Ok(()) + } } diff --git a/datafusion/core/src/datasource/file_format/mod.rs b/datafusion/core/src/datasource/file_format/mod.rs index b541e2a1d44c1..12c9fb91adb1a 100644 --- a/datafusion/core/src/datasource/file_format/mod.rs +++ b/datafusion/core/src/datasource/file_format/mod.rs @@ -124,7 +124,8 @@ pub(crate) mod test_util { use object_store::local::LocalFileSystem; use object_store::path::Path; use object_store::{ - GetOptions, GetResult, GetResultPayload, ListResult, MultipartId, + GetOptions, GetResult, GetResultPayload, ListResult, MultipartId, PutOptions, + PutResult, }; use tokio::io::AsyncWrite; @@ -164,7 +165,6 @@ pub(crate) mod test_util { limit, table_partition_cols: vec![], output_ordering: vec![], - infinite_source: false, }, None, ) @@ -189,7 +189,12 @@ pub(crate) mod test_util { #[async_trait] impl ObjectStore for VariableStream { - async fn put(&self, _location: &Path, _bytes: Bytes) -> object_store::Result<()> { + async fn put_opts( + &self, + _location: &Path, + _bytes: Bytes, + _opts: PutOptions, + ) -> object_store::Result { unimplemented!() } @@ -228,6 +233,7 @@ pub(crate) mod test_util { last_modified: Default::default(), size: range.end, e_tag: None, + version: None, }, range: Default::default(), }) @@ -257,11 +263,10 @@ pub(crate) mod test_util { 
unimplemented!() } - async fn list( + fn list( &self, _prefix: Option<&Path>, - ) -> object_store::Result>> - { + ) -> BoxStream<'_, object_store::Result> { unimplemented!() } diff --git a/datafusion/core/src/datasource/file_format/options.rs b/datafusion/core/src/datasource/file_format/options.rs index 41a70e6d2f8fb..d389137785ff8 100644 --- a/datafusion/core/src/datasource/file_format/options.rs +++ b/datafusion/core/src/datasource/file_format/options.rs @@ -21,14 +21,13 @@ use std::sync::Arc; use arrow::datatypes::{DataType, Schema, SchemaRef}; use async_trait::async_trait; -use datafusion_common::{plan_err, DataFusionError}; use crate::datasource::file_format::arrow::ArrowFormat; use crate::datasource::file_format::file_compression_type::FileCompressionType; #[cfg(feature = "parquet")] use crate::datasource::file_format::parquet::ParquetFormat; use crate::datasource::file_format::DEFAULT_SCHEMA_INFER_MAX_RECORD; -use crate::datasource::listing::{ListingTableInsertMode, ListingTableUrl}; +use crate::datasource::listing::ListingTableUrl; use crate::datasource::{ file_format::{avro::AvroFormat, csv::CsvFormat, json::JsonFormat}, listing::ListingOptions, @@ -72,12 +71,8 @@ pub struct CsvReadOptions<'a> { pub table_partition_cols: Vec<(String, DataType)>, /// File compression type pub file_compression_type: FileCompressionType, - /// Flag indicating whether this file may be unbounded (as in a FIFO file). - pub infinite: bool, /// Indicates how the file is sorted pub file_sort_order: Vec>, - /// Setting controls how inserts to this file should be handled - pub insert_mode: ListingTableInsertMode, } impl<'a> Default for CsvReadOptions<'a> { @@ -99,9 +94,7 @@ impl<'a> CsvReadOptions<'a> { file_extension: DEFAULT_CSV_EXTENSION, table_partition_cols: vec![], file_compression_type: FileCompressionType::UNCOMPRESSED, - infinite: false, file_sort_order: vec![], - insert_mode: ListingTableInsertMode::AppendToFile, } } @@ -111,12 +104,6 @@ impl<'a> CsvReadOptions<'a> { self } - /// Configure mark_infinite setting - pub fn mark_infinite(mut self, infinite: bool) -> Self { - self.infinite = infinite; - self - } - /// Specify delimiter to use for CSV read pub fn delimiter(mut self, delimiter: u8) -> Self { self.delimiter = delimiter; @@ -184,12 +171,6 @@ impl<'a> CsvReadOptions<'a> { self.file_sort_order = file_sort_order; self } - - /// Configure how insertions to this table should be handled - pub fn insert_mode(mut self, insert_mode: ListingTableInsertMode) -> Self { - self.insert_mode = insert_mode; - self - } } /// Options that control the reading of Parquet files. @@ -219,8 +200,6 @@ pub struct ParquetReadOptions<'a> { pub schema: Option<&'a Schema>, /// Indicates how the file is sorted pub file_sort_order: Vec>, - /// Setting controls how inserts to this file should be handled - pub insert_mode: ListingTableInsertMode, } impl<'a> Default for ParquetReadOptions<'a> { @@ -232,7 +211,6 @@ impl<'a> Default for ParquetReadOptions<'a> { skip_metadata: None, schema: None, file_sort_order: vec![], - insert_mode: ListingTableInsertMode::AppendNewFiles, } } } @@ -272,12 +250,6 @@ impl<'a> ParquetReadOptions<'a> { self.file_sort_order = file_sort_order; self } - - /// Configure how insertions to this table should be handled - pub fn insert_mode(mut self, insert_mode: ListingTableInsertMode) -> Self { - self.insert_mode = insert_mode; - self - } } /// Options that control the reading of ARROW files. 
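With the `mark_infinite` and `insert_mode` builders removed, CSV reads are configured through the remaining options only. A minimal sketch of the surviving `CsvReadOptions` builder surface; the file path and delimiter below are illustrative, not taken from this patch:

use datafusion::error::Result;
use datafusion::prelude::{CsvReadOptions, SessionContext};

#[tokio::main]
async fn main() -> Result<()> {
    let ctx = SessionContext::new();
    // Only read-related knobs remain; unbounded-ness and insert handling are
    // no longer properties of the read options.
    let options = CsvReadOptions::new()
        .has_header(true)
        .delimiter(b'|')
        .file_extension(".csv");
    let df = ctx.read_csv("data/example.csv", options).await?;
    df.show().await?;
    Ok(())
}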
@@ -342,8 +314,6 @@ pub struct AvroReadOptions<'a> { pub file_extension: &'a str, /// Partition Columns pub table_partition_cols: Vec<(String, DataType)>, - /// Flag indicating whether this file may be unbounded (as in a FIFO file). - pub infinite: bool, } impl<'a> Default for AvroReadOptions<'a> { @@ -352,7 +322,6 @@ impl<'a> Default for AvroReadOptions<'a> { schema: None, file_extension: DEFAULT_AVRO_EXTENSION, table_partition_cols: vec![], - infinite: false, } } } @@ -367,12 +336,6 @@ impl<'a> AvroReadOptions<'a> { self } - /// Configure mark_infinite setting - pub fn mark_infinite(mut self, infinite: bool) -> Self { - self.infinite = infinite; - self - } - /// Specify schema to use for AVRO read pub fn schema(mut self, schema: &'a Schema) -> Self { self.schema = Some(schema); @@ -403,8 +366,6 @@ pub struct NdJsonReadOptions<'a> { pub infinite: bool, /// Indicates how the file is sorted pub file_sort_order: Vec>, - /// Setting controls how inserts to this file should be handled - pub insert_mode: ListingTableInsertMode, } impl<'a> Default for NdJsonReadOptions<'a> { @@ -417,7 +378,6 @@ impl<'a> Default for NdJsonReadOptions<'a> { file_compression_type: FileCompressionType::UNCOMPRESSED, infinite: false, file_sort_order: vec![], - insert_mode: ListingTableInsertMode::AppendToFile, } } } @@ -464,12 +424,6 @@ impl<'a> NdJsonReadOptions<'a> { self.file_sort_order = file_sort_order; self } - - /// Configure how insertions to this table should be handled - pub fn insert_mode(mut self, insert_mode: ListingTableInsertMode) -> Self { - self.insert_mode = insert_mode; - self - } } #[async_trait] @@ -493,21 +447,17 @@ pub trait ReadOptions<'a> { state: SessionState, table_path: ListingTableUrl, schema: Option<&'a Schema>, - infinite: bool, ) -> Result where 'a: 'async_trait, { - match (schema, infinite) { - (Some(s), _) => Ok(Arc::new(s.to_owned())), - (None, false) => Ok(self - .to_listing_options(config) - .infer_schema(&state, &table_path) - .await?), - (None, true) => { - plan_err!("Schema inference for infinite data sources is not supported.") - } + if let Some(s) = schema { + return Ok(Arc::new(s.to_owned())); } + + self.to_listing_options(config) + .infer_schema(&state, &table_path) + .await } } @@ -527,8 +477,6 @@ impl ReadOptions<'_> for CsvReadOptions<'_> { .with_target_partitions(config.target_partitions()) .with_table_partition_cols(self.table_partition_cols.clone()) .with_file_sort_order(self.file_sort_order.clone()) - .with_infinite_source(self.infinite) - .with_insert_mode(self.insert_mode.clone()) } async fn get_resolved_schema( @@ -537,7 +485,7 @@ impl ReadOptions<'_> for CsvReadOptions<'_> { state: SessionState, table_path: ListingTableUrl, ) -> Result { - self._get_resolved_schema(config, state, table_path, self.schema, self.infinite) + self._get_resolved_schema(config, state, table_path, self.schema) .await } } @@ -555,7 +503,6 @@ impl ReadOptions<'_> for ParquetReadOptions<'_> { .with_target_partitions(config.target_partitions()) .with_table_partition_cols(self.table_partition_cols.clone()) .with_file_sort_order(self.file_sort_order.clone()) - .with_insert_mode(self.insert_mode.clone()) } async fn get_resolved_schema( @@ -564,7 +511,7 @@ impl ReadOptions<'_> for ParquetReadOptions<'_> { state: SessionState, table_path: ListingTableUrl, ) -> Result { - self._get_resolved_schema(config, state, table_path, self.schema, false) + self._get_resolved_schema(config, state, table_path, self.schema) .await } } @@ -580,9 +527,7 @@ impl ReadOptions<'_> for NdJsonReadOptions<'_> { 
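The resolved-schema logic is now a plain either/or: use a caller-supplied schema, otherwise infer one from the files. A sketch of the user-facing effect for NDJSON; the path and column layout are made up for illustration:

use datafusion::arrow::datatypes::{DataType, Field, Schema};
use datafusion::error::Result;
use datafusion::execution::options::NdJsonReadOptions;
use datafusion::prelude::SessionContext;

#[tokio::main]
async fn main() -> Result<()> {
    let ctx = SessionContext::new();
    // With an explicit schema, `get_resolved_schema` returns it directly and
    // never lists the object store for inference.
    let schema = Schema::new(vec![
        Field::new("a", DataType::Int64, true),
        Field::new("b", DataType::Float64, true),
    ]);
    let options = NdJsonReadOptions {
        schema: Some(&schema),
        ..Default::default()
    };
    let df = ctx.read_json("data/example.ndjson", options).await?;
    df.show().await?;
    Ok(())
}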
.with_file_extension(self.file_extension) .with_target_partitions(config.target_partitions()) .with_table_partition_cols(self.table_partition_cols.clone()) - .with_infinite_source(self.infinite) .with_file_sort_order(self.file_sort_order.clone()) - .with_insert_mode(self.insert_mode.clone()) } async fn get_resolved_schema( @@ -591,7 +536,7 @@ impl ReadOptions<'_> for NdJsonReadOptions<'_> { state: SessionState, table_path: ListingTableUrl, ) -> Result { - self._get_resolved_schema(config, state, table_path, self.schema, self.infinite) + self._get_resolved_schema(config, state, table_path, self.schema) .await } } @@ -605,7 +550,6 @@ impl ReadOptions<'_> for AvroReadOptions<'_> { .with_file_extension(self.file_extension) .with_target_partitions(config.target_partitions()) .with_table_partition_cols(self.table_partition_cols.clone()) - .with_infinite_source(self.infinite) } async fn get_resolved_schema( @@ -614,7 +558,7 @@ impl ReadOptions<'_> for AvroReadOptions<'_> { state: SessionState, table_path: ListingTableUrl, ) -> Result { - self._get_resolved_schema(config, state, table_path, self.schema, self.infinite) + self._get_resolved_schema(config, state, table_path, self.schema) .await } } @@ -636,7 +580,7 @@ impl ReadOptions<'_> for ArrowReadOptions<'_> { state: SessionState, table_path: ListingTableUrl, ) -> Result { - self._get_resolved_schema(config, state, table_path, self.schema, false) + self._get_resolved_schema(config, state, table_path, self.schema) .await } } diff --git a/datafusion/core/src/datasource/file_format/parquet.rs b/datafusion/core/src/datasource/file_format/parquet.rs index 2cba474e559e7..9729bfa163af1 100644 --- a/datafusion/core/src/datasource/file_format/parquet.rs +++ b/datafusion/core/src/datasource/file_format/parquet.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -//! Parquet format abstractions +//! 
[`ParquetFormat`]: Parquet [`FileFormat`] abstractions use arrow_array::RecordBatch; use async_trait::async_trait; @@ -29,7 +29,6 @@ use parquet::file::writer::SerializedFileWriter; use std::any::Any; use std::fmt; use std::fmt::Debug; -use std::io::Write; use std::sync::Arc; use tokio::io::{AsyncWrite, AsyncWriteExt}; use tokio::sync::mpsc::{self, Receiver, Sender}; @@ -40,11 +39,12 @@ use crate::datasource::statistics::{create_max_min_accs, get_col_stats}; use arrow::datatypes::SchemaRef; use arrow::datatypes::{Fields, Schema}; use bytes::{BufMut, BytesMut}; -use datafusion_common::{exec_err, not_impl_err, plan_err, DataFusionError, FileType}; +use datafusion_common::{exec_err, not_impl_err, DataFusionError, FileType}; use datafusion_execution::TaskContext; use datafusion_physical_expr::{PhysicalExpr, PhysicalSortRequirement}; use futures::{StreamExt, TryStreamExt}; use hashbrown::HashMap; +use object_store::path::Path; use object_store::{ObjectMeta, ObjectStore}; use parquet::arrow::{ arrow_to_parquet_schema, parquet_to_arrow_schema, AsyncArrowWriter, @@ -55,7 +55,7 @@ use parquet::file::properties::WriterProperties; use parquet::file::statistics::Statistics as ParquetStatistics; use super::write::demux::start_demuxer_task; -use super::write::{create_writer, AbortableWrite, FileWriterMode}; +use super::write::{create_writer, AbortableWrite, SharedBuffer}; use super::{FileFormat, FileScanConfig}; use crate::arrow::array::{ BooleanArray, Float32Array, Float64Array, Int32Array, Int64Array, @@ -64,7 +64,7 @@ use crate::arrow::datatypes::DataType; use crate::config::ConfigOptions; use crate::datasource::physical_plan::{ - FileGroupDisplay, FileMeta, FileSinkConfig, ParquetExec, SchemaAdapter, + FileGroupDisplay, FileSinkConfig, ParquetExec, SchemaAdapter, }; use crate::error::Result; use crate::execution::context::SessionState; @@ -75,6 +75,17 @@ use crate::physical_plan::{ Statistics, }; +/// Size of the buffer for [`AsyncArrowWriter`]. +const PARQUET_WRITER_BUFFER_SIZE: usize = 10485760; + +/// Initial writing buffer size. Note this is just a size hint for efficiency. It +/// will grow beyond the set value if needed. 
+const INITIAL_BUFFER_BYTES: usize = 1048576; + +/// When writing parquet files in parallel, if the buffered Parquet data exceeds +/// this size, it is flushed to object store +const BUFFER_FLUSH_BYTES: usize = 1024000; + /// The Apache Parquet `FileFormat` implementation /// /// Note it is recommended these are instead configured on the [`ConfigOptions`] @@ -163,6 +174,16 @@ fn clear_metadata( }) } +async fn fetch_schema_with_location( + store: &dyn ObjectStore, + file: &ObjectMeta, + metadata_size_hint: Option, +) -> Result<(Path, Schema)> { + let loc_path = file.location.clone(); + let schema = fetch_schema(store, file, metadata_size_hint).await?; + Ok((loc_path, schema)) +} + #[async_trait] impl FileFormat for ParquetFormat { fn as_any(&self) -> &dyn Any { @@ -175,13 +196,32 @@ impl FileFormat for ParquetFormat { store: &Arc, objects: &[ObjectMeta], ) -> Result { - let schemas: Vec<_> = futures::stream::iter(objects) - .map(|object| fetch_schema(store.as_ref(), object, self.metadata_size_hint)) + let mut schemas: Vec<_> = futures::stream::iter(objects) + .map(|object| { + fetch_schema_with_location( + store.as_ref(), + object, + self.metadata_size_hint, + ) + }) .boxed() // Workaround https://github.com/rust-lang/rust/issues/64552 .buffered(state.config_options().execution.meta_fetch_concurrency) .try_collect() .await?; + // Schema inference adds fields based the order they are seen + // which depends on the order the files are processed. For some + // object stores (like local file systems) the order returned from list + // is not deterministic. Thus, to ensure deterministic schema inference + // sort the files first. + // https://github.com/apache/arrow-datafusion/pull/6629 + schemas.sort_by(|(location1, _), (location2, _)| location1.cmp(location2)); + + let schemas = schemas + .into_iter() + .map(|(_, schema)| schema) + .collect::>(); + let schema = if self.skip_metadata(state.config_options()) { Schema::try_merge(clear_metadata(schemas)) } else { @@ -581,7 +621,7 @@ async fn fetch_statistics( } /// Implements [`DataSink`] for writing to a parquet file. -struct ParquetSink { +pub struct ParquetSink { /// Config options for writing data config: FileSinkConfig, } @@ -596,11 +636,7 @@ impl DisplayAs for ParquetSink { fn fmt_as(&self, t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result { match t { DisplayFormatType::Default | DisplayFormatType::Verbose => { - write!( - f, - "ParquetSink(writer_mode={:?}, file_groups=", - self.config.writer_mode - )?; + write!(f, "ParquetSink(file_groups=",)?; FileGroupDisplay(&self.config.file_groups).fmt_as(t, f)?; write!(f, ")") } @@ -609,10 +645,15 @@ impl DisplayAs for ParquetSink { } impl ParquetSink { - fn new(config: FileSinkConfig) -> Self { + /// Create from config. + pub fn new(config: FileSinkConfig) -> Self { Self { config } } + /// Retrieve the inner [`FileSinkConfig`]. + pub fn config(&self) -> &FileSinkConfig { + &self.config + } /// Converts table schema to writer schema, which may differ in the case /// of hive style partitioning where some columns are removed from the /// underlying files. 
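The deterministic-inference change above boils down to sorting per-file schemas by object path before merging, so field order no longer depends on listing order. A condensed, standalone sketch of that idea (not the exact helper in this file):

use arrow::datatypes::Schema;
use arrow::error::ArrowError;
use object_store::path::Path;

/// Merge per-file schemas in path order so the merged field order does not
/// depend on the (possibly non-deterministic) listing order of the store.
fn merge_schemas_deterministically(
    mut schemas: Vec<(Path, Schema)>,
) -> Result<Schema, ArrowError> {
    schemas.sort_by(|(a, _), (b, _)| a.cmp(b));
    Schema::try_merge(schemas.into_iter().map(|(_, schema)| schema))
}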
@@ -642,36 +683,23 @@ impl ParquetSink { /// AsyncArrowWriters are used when individual parquet file serialization is not parallelized async fn create_async_arrow_writer( &self, - file_meta: FileMeta, + location: &Path, object_store: Arc, parquet_props: WriterProperties, ) -> Result< AsyncArrowWriter>, > { - let object = &file_meta.object_meta; - match self.config.writer_mode { - FileWriterMode::Append => { - plan_err!( - "Appending to Parquet files is not supported by the file format!" - ) - } - FileWriterMode::Put => { - not_impl_err!("FileWriterMode::Put is not implemented for ParquetSink") - } - FileWriterMode::PutMultipart => { - let (_, multipart_writer) = object_store - .put_multipart(&object.location) - .await - .map_err(DataFusionError::ObjectStore)?; - let writer = AsyncArrowWriter::try_new( - multipart_writer, - self.get_writer_schema(), - 10485760, - Some(parquet_props), - )?; - Ok(writer) - } - } + let (_, multipart_writer) = object_store + .put_multipart(location) + .await + .map_err(DataFusionError::ObjectStore)?; + let writer = AsyncArrowWriter::try_new( + multipart_writer, + self.get_writer_schema(), + PARQUET_WRITER_BUFFER_SIZE, + Some(parquet_props), + )?; + Ok(writer) } } @@ -730,13 +758,7 @@ impl DataSink for ParquetSink { if !allow_single_file_parallelism { let mut writer = self .create_async_arrow_writer( - ObjectMeta { - location: path, - last_modified: chrono::offset::Utc::now(), - size: 0, - e_tag: None, - } - .into(), + &path, object_store.clone(), parquet_props.clone(), ) @@ -752,17 +774,10 @@ impl DataSink for ParquetSink { }); } else { let writer = create_writer( - FileWriterMode::PutMultipart, // Parquet files as a whole are never compressed, since they // manage compressed blocks themselves. FileCompressionType::UNCOMPRESSED, - ObjectMeta { - location: path, - last_modified: chrono::offset::Utc::now(), - size: 0, - e_tag: None, - } - .into(), + &path, object_store.clone(), ) .await?; @@ -1005,7 +1020,7 @@ async fn concatenate_parallel_row_groups( writer_props: Arc, mut object_store_writer: AbortableWrite>, ) -> Result { - let merged_buff = SharedBuffer::new(1048576); + let merged_buff = SharedBuffer::new(INITIAL_BUFFER_BYTES); let schema_desc = arrow_to_parquet_schema(schema.as_ref())?; let mut parquet_writer = SerializedFileWriter::new( @@ -1026,7 +1041,7 @@ async fn concatenate_parallel_row_groups( for chunk in serialized_columns { chunk.append_to_row_group(&mut rg_out)?; let mut buff_to_flush = merged_buff.buffer.try_lock().unwrap(); - if buff_to_flush.len() > 1024000 { + if buff_to_flush.len() > BUFFER_FLUSH_BYTES { object_store_writer .write_all(buff_to_flush.as_slice()) .await?; @@ -1101,37 +1116,6 @@ async fn output_single_parquet_file_parallelized( Ok(row_count) } -/// A buffer with interior mutability shared by the SerializedFileWriter and -/// ObjectStore writer -#[derive(Clone)] -struct SharedBuffer { - /// The inner buffer for reading and writing - /// - /// The lock is used to obtain internal mutability, so no worry about the - /// lock contention. 
- buffer: Arc>>, -} - -impl SharedBuffer { - pub fn new(capacity: usize) -> Self { - Self { - buffer: Arc::new(futures::lock::Mutex::new(Vec::with_capacity(capacity))), - } - } -} - -impl Write for SharedBuffer { - fn write(&mut self, buf: &[u8]) -> std::io::Result { - let mut buffer = self.buffer.try_lock().unwrap(); - Write::write(&mut *buffer, buf) - } - - fn flush(&mut self) -> std::io::Result<()> { - let mut buffer = self.buffer.try_lock().unwrap(); - Write::flush(&mut *buffer) - } -} - #[cfg(test)] pub(crate) mod test_util { use super::*; @@ -1153,12 +1137,21 @@ pub(crate) mod test_util { batches: Vec, multi_page: bool, ) -> Result<(Vec, Vec)> { + // we need the tmp files to be sorted as some tests rely on the how the returning files are ordered + // https://github.com/apache/arrow-datafusion/pull/6629 + let tmp_files = { + let mut tmp_files: Vec<_> = (0..batches.len()) + .map(|_| NamedTempFile::new().expect("creating temp file")) + .collect(); + tmp_files.sort_by(|a, b| a.path().cmp(b.path())); + tmp_files + }; + // Each batch writes to their own file let files: Vec<_> = batches .into_iter() - .map(|batch| { - let mut output = NamedTempFile::new().expect("creating temp file"); - + .zip(tmp_files.into_iter()) + .map(|(batch, mut output)| { let builder = WriterProperties::builder(); let props = if multi_page { builder.set_data_page_row_count_limit(ROWS_PER_PAGE) @@ -1184,6 +1177,7 @@ pub(crate) mod test_util { .collect(); let meta: Vec<_> = files.iter().map(local_unpartitioned_file).collect(); + Ok((meta, files)) } @@ -1228,7 +1222,9 @@ mod tests { use log::error; use object_store::local::LocalFileSystem; use object_store::path::Path; - use object_store::{GetOptions, GetResult, ListResult, MultipartId}; + use object_store::{ + GetOptions, GetResult, ListResult, MultipartId, PutOptions, PutResult, + }; use parquet::arrow::arrow_reader::ArrowReaderOptions; use parquet::arrow::ParquetRecordBatchStreamBuilder; use parquet::file::metadata::{ParquetColumnIndex, ParquetOffsetIndex}; @@ -1281,6 +1277,42 @@ mod tests { Ok(()) } + #[tokio::test] + async fn is_schema_stable() -> Result<()> { + let c1: ArrayRef = + Arc::new(StringArray::from(vec![Some("Foo"), None, Some("bar")])); + + let c2: ArrayRef = Arc::new(Int64Array::from(vec![Some(1), Some(2), None])); + + let batch1 = + RecordBatch::try_from_iter(vec![("a", c1.clone()), ("b", c1.clone())]) + .unwrap(); + let batch2 = + RecordBatch::try_from_iter(vec![("c", c2.clone()), ("d", c2.clone())]) + .unwrap(); + + let store = Arc::new(LocalFileSystem::new()) as _; + let (meta, _files) = store_parquet(vec![batch1, batch2], false).await?; + + let session = SessionContext::new(); + let ctx = session.state(); + let format = ParquetFormat::default(); + let schema = format.infer_schema(&ctx, &store, &meta).await.unwrap(); + + let order: Vec<_> = ["a", "b", "c", "d"] + .into_iter() + .map(|i| i.to_string()) + .collect(); + let coll: Vec<_> = schema + .all_fields() + .into_iter() + .map(|i| i.name().to_string()) + .collect(); + assert_eq!(coll, order); + + Ok(()) + } + #[derive(Debug)] struct RequestCountingObjectStore { inner: Arc, @@ -1312,7 +1344,12 @@ mod tests { #[async_trait] impl ObjectStore for RequestCountingObjectStore { - async fn put(&self, _location: &Path, _bytes: Bytes) -> object_store::Result<()> { + async fn put_opts( + &self, + _location: &Path, + _bytes: Bytes, + _opts: PutOptions, + ) -> object_store::Result { Err(object_store::Error::NotImplemented) } @@ -1349,12 +1386,13 @@ mod tests { Err(object_store::Error::NotImplemented) } - 
async fn list( + fn list( &self, _prefix: Option<&Path>, - ) -> object_store::Result>> - { - Err(object_store::Error::NotImplemented) + ) -> BoxStream<'_, object_store::Result> { + Box::pin(futures::stream::once(async { + Err(object_store::Error::NotImplemented) + })) } async fn list_with_delimiter( @@ -1824,8 +1862,8 @@ mod tests { // there is only one row group in one file. assert_eq!(page_index.len(), 1); assert_eq!(offset_index.len(), 1); - let page_index = page_index.get(0).unwrap(); - let offset_index = offset_index.get(0).unwrap(); + let page_index = page_index.first().unwrap(); + let offset_index = offset_index.first().unwrap(); // 13 col in one row group assert_eq!(page_index.len(), 13); diff --git a/datafusion/core/src/datasource/file_format/write/demux.rs b/datafusion/core/src/datasource/file_format/write/demux.rs index 27c65dd459ec4..dbfeb67eaeb96 100644 --- a/datafusion/core/src/datasource/file_format/write/demux.rs +++ b/datafusion/core/src/datasource/file_format/write/demux.rs @@ -264,12 +264,9 @@ async fn hive_style_partitions_demuxer( // TODO: upstream RecordBatch::take to arrow-rs let take_indices = builder.finish(); let struct_array: StructArray = rb.clone().into(); - let parted_batch = RecordBatch::try_from( + let parted_batch = RecordBatch::from( arrow::compute::take(&struct_array, &take_indices, None)?.as_struct(), - ) - .map_err(|_| { - DataFusionError::Internal("Unexpected error partitioning batch!".into()) - })?; + ); // Get or create channel for this batch let part_tx = match value_map.get_mut(&part_key) { @@ -386,7 +383,7 @@ fn compute_take_arrays( fn remove_partition_by_columns( parted_batch: &RecordBatch, - partition_by: &Vec<(String, DataType)>, + partition_by: &[(String, DataType)], ) -> Result { let end_idx = parted_batch.num_columns() - partition_by.len(); let non_part_cols = &parted_batch.columns()[..end_idx]; @@ -408,7 +405,7 @@ fn remove_partition_by_columns( } fn compute_hive_style_file_path( - part_key: &Vec, + part_key: &[String], partition_by: &[(String, DataType)], write_id: &str, file_extension: &str, diff --git a/datafusion/core/src/datasource/file_format/write/mod.rs b/datafusion/core/src/datasource/file_format/write/mod.rs index 770c7a49c326d..c481f2accf199 100644 --- a/datafusion/core/src/datasource/file_format/write/mod.rs +++ b/datafusion/core/src/datasource/file_format/write/mod.rs @@ -18,129 +18,60 @@ //! Module containing helper methods/traits related to enabling //! write support for the various file formats -use std::io::Error; -use std::mem; +use std::io::{Error, Write}; use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; use crate::datasource::file_format::file_compression_type::FileCompressionType; - -use crate::datasource::physical_plan::FileMeta; use crate::error::Result; use arrow_array::RecordBatch; - -use datafusion_common::{exec_err, DataFusionError}; +use datafusion_common::DataFusionError; use async_trait::async_trait; use bytes::Bytes; - use futures::future::BoxFuture; -use futures::ready; -use futures::FutureExt; use object_store::path::Path; -use object_store::{MultipartId, ObjectMeta, ObjectStore}; - +use object_store::{MultipartId, ObjectStore}; use tokio::io::AsyncWrite; pub(crate) mod demux; pub(crate) mod orchestration; -/// `AsyncPutWriter` is an object that facilitates asynchronous writing to object stores. -/// It is specifically designed for the `object_store` crate's `put` method and sends -/// whole bytes at once when the buffer is flushed. 
-pub struct AsyncPutWriter { - /// Object metadata - object_meta: ObjectMeta, - /// A shared reference to the object store - store: Arc, - /// A buffer that stores the bytes to be sent - current_buffer: Vec, - /// Used for async handling in flush method - inner_state: AsyncPutState, +/// A buffer with interior mutability shared by the SerializedFileWriter and +/// ObjectStore writer +#[derive(Clone)] +pub(crate) struct SharedBuffer { + /// The inner buffer for reading and writing + /// + /// The lock is used to obtain internal mutability, so no worry about the + /// lock contention. + pub(crate) buffer: Arc>>, } -impl AsyncPutWriter { - /// Constructor for the `AsyncPutWriter` object - pub fn new(object_meta: ObjectMeta, store: Arc) -> Self { +impl SharedBuffer { + pub fn new(capacity: usize) -> Self { Self { - object_meta, - store, - current_buffer: vec![], - // The writer starts out in buffering mode - inner_state: AsyncPutState::Buffer, - } - } - - /// Separate implementation function that unpins the [`AsyncPutWriter`] so - /// that partial borrows work correctly - fn poll_shutdown_inner( - &mut self, - cx: &mut Context<'_>, - ) -> Poll> { - loop { - match &mut self.inner_state { - AsyncPutState::Buffer => { - // Convert the current buffer to bytes and take ownership of it - let bytes = Bytes::from(mem::take(&mut self.current_buffer)); - // Set the inner state to Put variant with the bytes - self.inner_state = AsyncPutState::Put { bytes } - } - AsyncPutState::Put { bytes } => { - // Send the bytes to the object store's put method - return Poll::Ready( - ready!(self - .store - .put(&self.object_meta.location, bytes.clone()) - .poll_unpin(cx)) - .map_err(Error::from), - ); - } - } + buffer: Arc::new(futures::lock::Mutex::new(Vec::with_capacity(capacity))), } } } -/// An enum that represents the inner state of AsyncPut -enum AsyncPutState { - /// Building Bytes struct in this state - Buffer, - /// Data in the buffer is being sent to the object store - Put { bytes: Bytes }, -} - -impl AsyncWrite for AsyncPutWriter { - // Define the implementation of the AsyncWrite trait for the `AsyncPutWriter` struct - fn poll_write( - mut self: Pin<&mut Self>, - _: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { - // Extend the current buffer with the incoming buffer - self.current_buffer.extend_from_slice(buf); - // Return a ready poll with the length of the incoming buffer - Poll::Ready(Ok(buf.len())) +impl Write for SharedBuffer { + fn write(&mut self, buf: &[u8]) -> std::io::Result { + let mut buffer = self.buffer.try_lock().unwrap(); + Write::write(&mut *buffer, buf) } - fn poll_flush( - self: Pin<&mut Self>, - _: &mut Context<'_>, - ) -> Poll> { - // Return a ready poll with an empty result - Poll::Ready(Ok(())) - } - - fn poll_shutdown( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll> { - // Call the poll_shutdown_inner method to handle the actual sending of data to the object store - self.poll_shutdown_inner(cx) + fn flush(&mut self) -> std::io::Result<()> { + let mut buffer = self.buffer.try_lock().unwrap(); + Write::flush(&mut *buffer) } } /// Stores data needed during abortion of MultiPart writers +#[derive(Clone)] pub(crate) struct MultiPart { /// A shared reference to the object store store: Arc, @@ -163,45 +94,28 @@ impl MultiPart { } } -pub(crate) enum AbortMode { - Put, - Append, - MultiPart(MultiPart), -} - /// A wrapper struct with abort method and writer pub(crate) struct AbortableWrite { writer: W, - mode: AbortMode, + multipart: MultiPart, } impl AbortableWrite { /// 
Create a new `AbortableWrite` instance with the given writer, and write mode. - pub(crate) fn new(writer: W, mode: AbortMode) -> Self { - Self { writer, mode } + pub(crate) fn new(writer: W, multipart: MultiPart) -> Self { + Self { writer, multipart } } /// handling of abort for different write modes pub(crate) fn abort_writer(&self) -> Result>> { - match &self.mode { - AbortMode::Put => Ok(async { Ok(()) }.boxed()), - AbortMode::Append => exec_err!("Cannot abort in append mode"), - AbortMode::MultiPart(MultiPart { - store, - multipart_id, - location, - }) => { - let location = location.clone(); - let multipart_id = multipart_id.clone(); - let store = store.clone(); - Ok(Box::pin(async move { - store - .abort_multipart(&location, &multipart_id) - .await - .map_err(DataFusionError::ObjectStore) - })) - } - } + let multi = self.multipart.clone(); + Ok(Box::pin(async move { + multi + .store + .abort_multipart(&multi.location, &multi.multipart_id) + .await + .map_err(DataFusionError::ObjectStore) + })) } } @@ -229,77 +143,28 @@ impl AsyncWrite for AbortableWrite { } } -/// An enum that defines different file writer modes. -#[derive(Debug, Clone, Copy)] -pub enum FileWriterMode { - /// Data is appended to an existing file. - Append, - /// Data is written to a new file. - Put, - /// Data is written to a new file in multiple parts. - PutMultipart, -} /// A trait that defines the methods required for a RecordBatch serializer. #[async_trait] -pub trait BatchSerializer: Unpin + Send { +pub trait BatchSerializer: Sync + Send { /// Asynchronously serializes a `RecordBatch` and returns the serialized bytes. - async fn serialize(&mut self, batch: RecordBatch) -> Result; - /// Duplicates self to support serializing multiple batches in parallel on multiple cores - fn duplicate(&mut self) -> Result> { - Err(DataFusionError::NotImplemented( - "Parallel serialization is not implemented for this file type".into(), - )) - } + /// Parameter `initial` signals whether the given batch is the first batch. + /// This distinction is important for certain serializers (like CSV). + async fn serialize(&self, batch: RecordBatch, initial: bool) -> Result; } /// Returns an [`AbortableWrite`] which writes to the given object store location /// with the specified compression pub(crate) async fn create_writer( - writer_mode: FileWriterMode, file_compression_type: FileCompressionType, - file_meta: FileMeta, + location: &Path, object_store: Arc, ) -> Result>> { - let object = &file_meta.object_meta; - match writer_mode { - // If the mode is append, call the store's append method and return wrapped in - // a boxed trait object. - FileWriterMode::Append => { - let writer = object_store - .append(&object.location) - .await - .map_err(DataFusionError::ObjectStore)?; - let writer = AbortableWrite::new( - file_compression_type.convert_async_writer(writer)?, - AbortMode::Append, - ); - Ok(writer) - } - // If the mode is put, create a new AsyncPut writer and return it wrapped in - // a boxed trait object - FileWriterMode::Put => { - let writer = Box::new(AsyncPutWriter::new(object.clone(), object_store)); - let writer = AbortableWrite::new( - file_compression_type.convert_async_writer(writer)?, - AbortMode::Put, - ); - Ok(writer) - } - // If the mode is put multipart, call the store's put_multipart method and - // return the writer wrapped in a boxed trait object. 
- FileWriterMode::PutMultipart => { - let (multipart_id, writer) = object_store - .put_multipart(&object.location) - .await - .map_err(DataFusionError::ObjectStore)?; - Ok(AbortableWrite::new( - file_compression_type.convert_async_writer(writer)?, - AbortMode::MultiPart(MultiPart::new( - object_store, - multipart_id, - object.location.clone(), - )), - )) - } - } + let (multipart_id, writer) = object_store + .put_multipart(location) + .await + .map_err(DataFusionError::ObjectStore)?; + Ok(AbortableWrite::new( + file_compression_type.convert_async_writer(writer)?, + MultiPart::new(object_store, multipart_id, location.clone()), + )) } diff --git a/datafusion/core/src/datasource/file_format/write/orchestration.rs b/datafusion/core/src/datasource/file_format/write/orchestration.rs index f84baa9ac2252..9b820a15b280c 100644 --- a/datafusion/core/src/datasource/file_format/write/orchestration.rs +++ b/datafusion/core/src/datasource/file_format/write/orchestration.rs @@ -21,33 +21,25 @@ use std::sync::Arc; +use super::demux::start_demuxer_task; +use super::{create_writer, AbortableWrite, BatchSerializer}; use crate::datasource::file_format::file_compression_type::FileCompressionType; -use crate::datasource::listing::PartitionedFile; use crate::datasource::physical_plan::FileSinkConfig; use crate::error::Result; use crate::physical_plan::SendableRecordBatchStream; use arrow_array::RecordBatch; - -use datafusion_common::DataFusionError; - -use bytes::Bytes; +use datafusion_common::{internal_datafusion_err, internal_err, DataFusionError}; use datafusion_execution::TaskContext; -use futures::StreamExt; - -use object_store::{ObjectMeta, ObjectStore}; - +use bytes::Bytes; use tokio::io::{AsyncWrite, AsyncWriteExt}; use tokio::sync::mpsc::{self, Receiver}; use tokio::task::{JoinHandle, JoinSet}; use tokio::try_join; -use super::demux::start_demuxer_task; -use super::{create_writer, AbortableWrite, BatchSerializer, FileWriterMode}; - type WriterType = AbortableWrite>; -type SerializerType = Box; +type SerializerType = Arc; /// Serializes a single data stream in parallel and writes to an ObjectStore /// concurrently. Data order is preserved. In the event of an error, @@ -55,37 +47,28 @@ type SerializerType = Box; /// so that the caller may handle aborting failed writes. pub(crate) async fn serialize_rb_stream_to_object_store( mut data_rx: Receiver, - mut serializer: Box, + serializer: Arc, mut writer: AbortableWrite>, - unbounded_input: bool, ) -> std::result::Result<(WriterType, u64), (WriterType, DataFusionError)> { let (tx, mut rx) = mpsc::channel::>>(100); - let serialize_task = tokio::spawn(async move { + // Some serializers (like CSV) handle the first batch differently than + // subsequent batches, so we track that here. 
+ let mut initial = true; while let Some(batch) = data_rx.recv().await { - match serializer.duplicate() { - Ok(mut serializer_clone) => { - let handle = tokio::spawn(async move { - let num_rows = batch.num_rows(); - let bytes = serializer_clone.serialize(batch).await?; - Ok((num_rows, bytes)) - }); - tx.send(handle).await.map_err(|_| { - DataFusionError::Internal( - "Unknown error writing to object store".into(), - ) - })?; - if unbounded_input { - tokio::task::yield_now().await; - } - } - Err(_) => { - return Err(DataFusionError::Internal( - "Unknown error writing to object store".into(), - )) - } + let serializer_clone = serializer.clone(); + let handle = tokio::spawn(async move { + let num_rows = batch.num_rows(); + let bytes = serializer_clone.serialize(batch, initial).await?; + Ok((num_rows, bytes)) + }); + if initial { + initial = false; } + tx.send(handle).await.map_err(|_| { + internal_datafusion_err!("Unknown error writing to object store") + })?; } Ok(()) }); @@ -129,7 +112,7 @@ pub(crate) async fn serialize_rb_stream_to_object_store( Err(_) => { return Err(( writer, - DataFusionError::Internal("Unknown error writing to object store".into()), + internal_datafusion_err!("Unknown error writing to object store"), )) } }; @@ -145,7 +128,6 @@ type FileWriteBundle = (Receiver, SerializerType, WriterType); pub(crate) async fn stateless_serialize_and_write_files( mut rx: Receiver, tx: tokio::sync::oneshot::Sender, - unbounded_input: bool, ) -> Result<()> { let mut row_count = 0; // tracks if any writers encountered an error triggering the need to abort @@ -158,13 +140,7 @@ pub(crate) async fn stateless_serialize_and_write_files( let mut join_set = JoinSet::new(); while let Some((data_rx, serializer, writer)) = rx.recv().await { join_set.spawn(async move { - serialize_rb_stream_to_object_store( - data_rx, - serializer, - writer, - unbounded_input, - ) - .await + serialize_rb_stream_to_object_store(data_rx, serializer, writer).await }); } let mut finished_writers = Vec::new(); @@ -187,9 +163,9 @@ pub(crate) async fn stateless_serialize_and_write_files( // this thread, so we cannot clean it up (hence any_abort_errors is true) any_errors = true; any_abort_errors = true; - triggering_error = Some(DataFusionError::Internal(format!( + triggering_error = Some(internal_datafusion_err!( "Unexpected join error while serializing file {e}" - ))); + )); } } } @@ -206,24 +182,24 @@ pub(crate) async fn stateless_serialize_and_write_files( false => { writer.shutdown() .await - .map_err(|_| DataFusionError::Internal("Error encountered while finalizing writes! Partial results may have been written to ObjectStore!".into()))?; + .map_err(|_| internal_datafusion_err!("Error encountered while finalizing writes! Partial results may have been written to ObjectStore!"))?; } } } if any_errors { match any_abort_errors{ - true => return Err(DataFusionError::Internal("Error encountered during writing to ObjectStore and failed to abort all writers. Partial result may have been written.".into())), + true => return internal_err!("Error encountered during writing to ObjectStore and failed to abort all writers. Partial result may have been written."), false => match triggering_error { Some(e) => return Err(e), - None => return Err(DataFusionError::Internal("Unknown Error encountered during writing to ObjectStore. All writers succesfully aborted.".into())) + None => return internal_err!("Unknown Error encountered during writing to ObjectStore. 
All writers succesfully aborted.") } } } tx.send(row_count).map_err(|_| { - DataFusionError::Internal( - "Error encountered while sending row count back to file sink!".into(), + internal_datafusion_err!( + "Error encountered while sending row count back to file sink!" ) })?; Ok(()) @@ -236,7 +212,7 @@ pub(crate) async fn stateless_multipart_put( data: SendableRecordBatchStream, context: &Arc, file_extension: String, - get_serializer: Box Box + Send>, + get_serializer: Box Arc + Send>, config: &FileSinkConfig, compression: FileCompressionType, ) -> Result { @@ -246,7 +222,6 @@ pub(crate) async fn stateless_multipart_put( let single_file_output = config.single_file_output; let base_output_path = &config.table_paths[0]; - let unbounded_input = config.unbounded_input; let part_cols = if !config.table_partition_cols.is_empty() { Some(config.table_partition_cols.clone()) } else { @@ -271,31 +246,18 @@ pub(crate) async fn stateless_multipart_put( let (tx_file_bundle, rx_file_bundle) = tokio::sync::mpsc::channel(rb_buffer_size / 2); let (tx_row_cnt, rx_row_cnt) = tokio::sync::oneshot::channel(); let write_coordinater_task = tokio::spawn(async move { - stateless_serialize_and_write_files(rx_file_bundle, tx_row_cnt, unbounded_input) - .await + stateless_serialize_and_write_files(rx_file_bundle, tx_row_cnt).await }); - while let Some((output_location, rb_stream)) = file_stream_rx.recv().await { + while let Some((location, rb_stream)) = file_stream_rx.recv().await { let serializer = get_serializer(); - let object_meta = ObjectMeta { - location: output_location, - last_modified: chrono::offset::Utc::now(), - size: 0, - e_tag: None, - }; - let writer = create_writer( - FileWriterMode::PutMultipart, - compression, - object_meta.into(), - object_store.clone(), - ) - .await?; + let writer = create_writer(compression, &location, object_store.clone()).await?; tx_file_bundle .send((rb_stream, serializer, writer)) .await .map_err(|_| { - DataFusionError::Internal( - "Writer receive file bundle channel closed unexpectedly!".into(), + internal_datafusion_err!( + "Writer receive file bundle channel closed unexpectedly!" ) })?; } @@ -318,98 +280,8 @@ pub(crate) async fn stateless_multipart_put( } let total_count = rx_row_cnt.await.map_err(|_| { - DataFusionError::Internal( - "Did not receieve row count from write coordinater".into(), - ) - })?; - - Ok(total_count) -} - -/// Orchestrates append_all for any statelessly serialized file type. Appends to all files provided -/// in a round robin fashion. 
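Because serializers are now shared as `Arc<dyn BatchSerializer>` and take `&self`, per-stream state lives with the caller, and the `initial` flag is how header-emitting formats know they are writing the first batch. A self-contained sketch of that contract; the trait is re-declared here only to keep the example standalone, and `ToyCsvSerializer` is hypothetical:

use arrow_array::RecordBatch;
use async_trait::async_trait;
use bytes::Bytes;
use datafusion_common::Result;

#[async_trait]
trait BatchSerializer: Sync + Send {
    async fn serialize(&self, batch: RecordBatch, initial: bool) -> Result<Bytes>;
}

struct ToyCsvSerializer;

#[async_trait]
impl BatchSerializer for ToyCsvSerializer {
    async fn serialize(&self, batch: RecordBatch, initial: bool) -> Result<Bytes> {
        let mut out = Vec::new();
        if initial {
            // Emit the header only for the first batch of the stream.
            let header = batch
                .schema()
                .fields()
                .iter()
                .map(|f| f.name().clone())
                .collect::<Vec<_>>()
                .join(",");
            out.extend_from_slice(header.as_bytes());
            out.push(b'\n');
        }
        // Row rendering elided; a real serializer would append the data here.
        Ok(Bytes::from(out))
    }
}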
-pub(crate) async fn stateless_append_all( - mut data: SendableRecordBatchStream, - context: &Arc, - object_store: Arc, - file_groups: &Vec, - unbounded_input: bool, - compression: FileCompressionType, - get_serializer: Box Box + Send>, -) -> Result { - let rb_buffer_size = &context - .session_config() - .options() - .execution - .max_buffered_batches_per_output_file; - - let (tx_file_bundle, rx_file_bundle) = tokio::sync::mpsc::channel(file_groups.len()); - let mut send_channels = vec![]; - for file_group in file_groups { - let serializer = get_serializer(file_group.object_meta.size); - - let file = file_group.clone(); - let writer = create_writer( - FileWriterMode::Append, - compression, - file.object_meta.clone().into(), - object_store.clone(), - ) - .await?; - - let (tx, rx) = tokio::sync::mpsc::channel(rb_buffer_size / 2); - send_channels.push(tx); - tx_file_bundle - .send((rx, serializer, writer)) - .await - .map_err(|_| { - DataFusionError::Internal( - "Writer receive file bundle channel closed unexpectedly!".into(), - ) - })?; - } - - let (tx_row_cnt, rx_row_cnt) = tokio::sync::oneshot::channel(); - let write_coordinater_task = tokio::spawn(async move { - stateless_serialize_and_write_files(rx_file_bundle, tx_row_cnt, unbounded_input) - .await - }); - - // Append to file groups in round robin - let mut next_file_idx = 0; - while let Some(rb) = data.next().await.transpose()? { - send_channels[next_file_idx].send(rb).await.map_err(|_| { - DataFusionError::Internal( - "Recordbatch file append stream closed unexpectedly!".into(), - ) - })?; - next_file_idx = (next_file_idx + 1) % send_channels.len(); - if unbounded_input { - tokio::task::yield_now().await; - } - } - // Signal to the write coordinater that no more files are coming - drop(tx_file_bundle); - drop(send_channels); - - let total_count = rx_row_cnt.await.map_err(|_| { - DataFusionError::Internal( - "Did not receieve row count from write coordinater".into(), - ) + internal_datafusion_err!("Did not receieve row count from write coordinater") })?; - match try_join!(write_coordinater_task) { - Ok(r1) => { - r1.0?; - } - Err(e) => { - if e.is_panic() { - std::panic::resume_unwind(e.into_panic()); - } else { - unreachable!(); - } - } - } - Ok(total_count) } diff --git a/datafusion/core/src/datasource/function.rs b/datafusion/core/src/datasource/function.rs new file mode 100644 index 0000000000000..2fd352ee4eb31 --- /dev/null +++ b/datafusion/core/src/datasource/function.rs @@ -0,0 +1,56 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! 
A table that uses a function to generate data + +use super::TableProvider; + +use datafusion_common::Result; +use datafusion_expr::Expr; + +use std::sync::Arc; + +/// A trait for table function implementations +pub trait TableFunctionImpl: Sync + Send { + /// Create a table provider + fn call(&self, args: &[Expr]) -> Result>; +} + +/// A table that uses a function to generate data +pub struct TableFunction { + /// Name of the table function + name: String, + /// Function implementation + fun: Arc, +} + +impl TableFunction { + /// Create a new table function + pub fn new(name: String, fun: Arc) -> Self { + Self { name, fun } + } + + /// Get the name of the table function + pub fn name(&self) -> &str { + &self.name + } + + /// Get the function implementation and generate a table + pub fn create_table_provider(&self, args: &[Expr]) -> Result> { + self.fun.call(args) + } +} diff --git a/datafusion/core/src/datasource/listing/helpers.rs b/datafusion/core/src/datasource/listing/helpers.rs index d6a0add9b2537..68de55e1a4108 100644 --- a/datafusion/core/src/datasource/listing/helpers.rs +++ b/datafusion/core/src/datasource/listing/helpers.rs @@ -38,9 +38,8 @@ use super::PartitionedFile; use crate::datasource::listing::ListingTableUrl; use crate::execution::context::SessionState; use datafusion_common::tree_node::{TreeNode, VisitRecursion}; -use datafusion_common::{Column, DFField, DFSchema, DataFusionError}; -use datafusion_expr::expr::ScalarUDF; -use datafusion_expr::{Expr, Volatility}; +use datafusion_common::{internal_err, Column, DFField, DFSchema, DataFusionError}; +use datafusion_expr::{Expr, ScalarFunctionDefinition, Volatility}; use datafusion_physical_expr::create_physical_expr; use datafusion_physical_expr::execution_props::ExecutionProps; use object_store::path::Path; @@ -54,13 +53,13 @@ use object_store::{ObjectMeta, ObjectStore}; pub fn expr_applicable_for_cols(col_names: &[String], expr: &Expr) -> bool { let mut is_applicable = true; expr.apply(&mut |expr| { - Ok(match expr { + match expr { Expr::Column(Column { ref name, .. }) => { is_applicable &= col_names.contains(name); if is_applicable { - VisitRecursion::Skip + Ok(VisitRecursion::Skip) } else { - VisitRecursion::Stop + Ok(VisitRecursion::Stop) } } Expr::Literal(_) @@ -89,25 +88,32 @@ pub fn expr_applicable_for_cols(col_names: &[String], expr: &Expr) -> bool { | Expr::ScalarSubquery(_) | Expr::GetIndexedField { .. } | Expr::GroupingSet(_) - | Expr::Case { .. } => VisitRecursion::Continue, + | Expr::Case { .. } => Ok(VisitRecursion::Continue), Expr::ScalarFunction(scalar_function) => { - match scalar_function.fun.volatility() { - Volatility::Immutable => VisitRecursion::Continue, - // TODO: Stable functions could be `applicable`, but that would require access to the context - Volatility::Stable | Volatility::Volatile => { - is_applicable = false; - VisitRecursion::Stop + match &scalar_function.func_def { + ScalarFunctionDefinition::BuiltIn(fun) => { + match fun.volatility() { + Volatility::Immutable => Ok(VisitRecursion::Continue), + // TODO: Stable functions could be `applicable`, but that would require access to the context + Volatility::Stable | Volatility::Volatile => { + is_applicable = false; + Ok(VisitRecursion::Stop) + } + } } - } - } - Expr::ScalarUDF(ScalarUDF { fun, .. 
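A minimal sketch of implementing the new `TableFunctionImpl` trait, assuming the new module is exported as `datafusion::datasource::function`; `OnesFunction` and its fixed output are invented for illustration:

use std::sync::Arc;

use datafusion::arrow::array::{ArrayRef, Int64Array};
use datafusion::arrow::datatypes::{DataType, Field, Schema};
use datafusion::arrow::record_batch::RecordBatch;
use datafusion::datasource::function::{TableFunction, TableFunctionImpl};
use datafusion::datasource::{MemTable, TableProvider};
use datafusion::error::Result;
use datafusion::prelude::Expr;

/// A toy table function that ignores its arguments and returns a one-column table.
struct OnesFunction;

impl TableFunctionImpl for OnesFunction {
    fn call(&self, _args: &[Expr]) -> Result<Arc<dyn TableProvider>> {
        let schema = Arc::new(Schema::new(vec![Field::new("n", DataType::Int64, false)]));
        let column: ArrayRef = Arc::new(Int64Array::from(vec![1, 1, 1]));
        let batch = RecordBatch::try_new(schema.clone(), vec![column])?;
        Ok(Arc::new(MemTable::try_new(schema, vec![vec![batch]])?))
    }
}

fn main() -> Result<()> {
    let func = TableFunction::new("ones".to_string(), Arc::new(OnesFunction));
    let provider = func.create_table_provider(&[])?;
    assert_eq!(provider.schema().fields().len(), 1);
    Ok(())
}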
}) => { - match fun.signature.volatility { - Volatility::Immutable => VisitRecursion::Continue, - // TODO: Stable functions could be `applicable`, but that would require access to the context - Volatility::Stable | Volatility::Volatile => { - is_applicable = false; - VisitRecursion::Stop + ScalarFunctionDefinition::UDF(fun) => { + match fun.signature().volatility { + Volatility::Immutable => Ok(VisitRecursion::Continue), + // TODO: Stable functions could be `applicable`, but that would require access to the context + Volatility::Stable | Volatility::Volatile => { + is_applicable = false; + Ok(VisitRecursion::Stop) + } + } + } + ScalarFunctionDefinition::Name(_) => { + internal_err!("Function `Expr` with name should be resolved.") } } } @@ -116,17 +122,15 @@ pub fn expr_applicable_for_cols(col_names: &[String], expr: &Expr) -> bool { // - AGGREGATE, WINDOW and SORT should not end up in filter conditions, except maybe in some edge cases // - Can `Wildcard` be considered as a `Literal`? // - ScalarVariable could be `applicable`, but that would require access to the context - Expr::AggregateUDF { .. } - | Expr::AggregateFunction { .. } + Expr::AggregateFunction { .. } | Expr::Sort { .. } | Expr::WindowFunction { .. } - | Expr::Wildcard - | Expr::QualifiedWildcard { .. } + | Expr::Wildcard { .. } | Expr::Placeholder(_) => { is_applicable = false; - VisitRecursion::Stop + Ok(VisitRecursion::Stop) } - }) + } }) .unwrap(); is_applicable @@ -137,12 +141,18 @@ const CONCURRENCY_LIMIT: usize = 100; /// Partition the list of files into `n` groups pub fn split_files( - partitioned_files: Vec, + mut partitioned_files: Vec, n: usize, ) -> Vec> { if partitioned_files.is_empty() { return vec![]; } + + // ObjectStore::list does not guarantee any consistent order and for some + // implementations such as LocalFileSystem, it may be inconsistent. Thus + // Sort files by path to ensure consistent plans when run more than once. + partitioned_files.sort_by(|a, b| a.path().cmp(b.path())); + // effectively this is div with rounding up instead of truncating let chunk_size = (partitioned_files.len() + n - 1) / n; partitioned_files @@ -276,7 +286,10 @@ async fn prune_partitions( // Applies `filter` to `batch` returning `None` on error let do_filter = |filter| -> Option { let expr = create_physical_expr(filter, &df_schema, &schema, &props).ok()?; - Some(expr.evaluate(&batch).ok()?.into_array(partitions.len())) + expr.evaluate(&batch) + .ok()? + .into_array(partitions.len()) + .ok() }; //.Compute the conjunction of the filters, ignoring errors @@ -359,14 +372,13 @@ pub async fn pruned_partition_list<'a>( Some(files) => files, None => { trace!("Recursively listing partition {}", partition.path); - let s = store.list(Some(&partition.path)).await?; - s.try_collect().await? + store.list(Some(&partition.path)).try_collect().await? 
} }; - let files = files.into_iter().filter(move |o| { let extension_match = o.location.as_ref().ends_with(file_extension); - let glob_match = table_path.contains(&o.location); + // here need to scan subdirectories(`listing_table_ignore_subdirectory` = false) + let glob_match = table_path.contains(&o.location, false); extension_match && glob_match }); @@ -520,19 +532,13 @@ mod tests { f1.object_meta.location.as_ref(), "tablepath/mypartition=val1/file.parquet" ); - assert_eq!( - &f1.partition_values, - &[ScalarValue::Utf8(Some(String::from("val1"))),] - ); + assert_eq!(&f1.partition_values, &[ScalarValue::from("val1")]); let f2 = &pruned[1]; assert_eq!( f2.object_meta.location.as_ref(), "tablepath/mypartition=val1/other=val3/file.parquet" ); - assert_eq!( - f2.partition_values, - &[ScalarValue::Utf8(Some(String::from("val1"))),] - ); + assert_eq!(f2.partition_values, &[ScalarValue::from("val1"),]); } #[tokio::test] @@ -573,10 +579,7 @@ mod tests { ); assert_eq!( &f1.partition_values, - &[ - ScalarValue::Utf8(Some(String::from("p1v2"))), - ScalarValue::Utf8(Some(String::from("p2v1"))) - ] + &[ScalarValue::from("p1v2"), ScalarValue::from("p2v1"),] ); let f2 = &pruned[1]; assert_eq!( @@ -585,10 +588,7 @@ mod tests { ); assert_eq!( &f2.partition_values, - &[ - ScalarValue::Utf8(Some(String::from("p1v2"))), - ScalarValue::Utf8(Some(String::from("p2v1"))) - ] + &[ScalarValue::from("p1v2"), ScalarValue::from("p2v1")] ); } diff --git a/datafusion/core/src/datasource/listing/mod.rs b/datafusion/core/src/datasource/listing/mod.rs index 8b0f021f02777..e7583501f9d90 100644 --- a/datafusion/core/src/datasource/listing/mod.rs +++ b/datafusion/core/src/datasource/listing/mod.rs @@ -31,9 +31,7 @@ use std::pin::Pin; use std::sync::Arc; pub use self::url::ListingTableUrl; -pub use table::{ - ListingOptions, ListingTable, ListingTableConfig, ListingTableInsertMode, -}; +pub use table::{ListingOptions, ListingTable, ListingTableConfig}; /// Stream of files get listed from object store pub type PartitionedFileStream = @@ -42,7 +40,7 @@ pub type PartitionedFileStream = /// Only scan a subset of Row Groups from the Parquet file whose data "midpoint" /// lies within the [start, end) byte offsets. This option can be used to scan non-overlapping /// sections of a Parquet file in parallel. 
-#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq, Hash, Eq, PartialOrd, Ord)] pub struct FileRange { /// Range start pub start: i64, @@ -72,16 +70,16 @@ pub struct PartitionedFile { /// An optional field for user defined per object metadata pub extensions: Option>, } - impl PartitionedFile { /// Create a simple file without metadata or partition - pub fn new(path: String, size: u64) -> Self { + pub fn new(path: impl Into, size: u64) -> Self { Self { object_meta: ObjectMeta { - location: Path::from(path), + location: Path::from(path.into()), last_modified: chrono::Utc.timestamp_nanos(0), size: size as usize, e_tag: None, + version: None, }, partition_values: vec![], range: None, @@ -97,11 +95,13 @@ impl PartitionedFile { last_modified: chrono::Utc.timestamp_nanos(0), size: size as usize, e_tag: None, + version: None, }, partition_values: vec![], - range: Some(FileRange { start, end }), + range: None, extensions: None, } + .with_range(start, end) } /// Return a file reference from the given path @@ -109,6 +109,17 @@ impl PartitionedFile { let size = std::fs::metadata(path.clone())?.len(); Ok(Self::new(path, size)) } + + /// Return the path of this partitioned file + pub fn path(&self) -> &Path { + &self.object_meta.location + } + + /// Update the file to only scan the specified range (in bytes) + pub fn with_range(mut self, start: i64, end: i64) -> Self { + self.range = Some(FileRange { start, end }); + self + } } impl From for PartitionedFile { diff --git a/datafusion/core/src/datasource/listing/table.rs b/datafusion/core/src/datasource/listing/table.rs index d26d417bd8b2f..a7af1bf1be28a 100644 --- a/datafusion/core/src/datasource/listing/table.rs +++ b/datafusion/core/src/datasource/listing/table.rs @@ -17,6 +17,7 @@ //! The table implementation. +use std::collections::HashMap; use std::str::FromStr; use std::{any::Any, sync::Arc}; @@ -26,6 +27,7 @@ use super::PartitionedFile; #[cfg(feature = "parquet")] use crate::datasource::file_format::parquet::ParquetFormat; use crate::datasource::{ + create_ordering, file_format::{ arrow::ArrowFormat, avro::AvroFormat, @@ -36,19 +38,16 @@ use crate::datasource::{ }, get_statistics_with_limit, listing::ListingTableUrl, - physical_plan::{is_plan_streaming, FileScanConfig, FileSinkConfig}, + physical_plan::{FileScanConfig, FileSinkConfig}, TableProvider, TableType, }; -use crate::logical_expr::TableProviderFilterPushDown; -use crate::physical_plan; use crate::{ error::{DataFusionError, Result}, execution::context::SessionState, - logical_expr::Expr, + logical_expr::{utils::conjunction, Expr, TableProviderFilterPushDown}, physical_plan::{empty::EmptyExec, ExecutionPlan, Statistics}, }; -use arrow::compute::SortOptions; use arrow::datatypes::{DataType, Field, SchemaBuilder, SchemaRef}; use arrow_schema::Schema; use datafusion_common::{ @@ -57,10 +56,8 @@ use datafusion_common::{ }; use datafusion_execution::cache::cache_manager::FileStatisticsCache; use datafusion_execution::cache::cache_unit::DefaultFileStatisticsCache; -use datafusion_expr::expr::Sort; -use datafusion_optimizer::utils::conjunction; use datafusion_physical_expr::{ - create_physical_expr, LexOrdering, PhysicalSortExpr, PhysicalSortRequirement, + create_physical_expr, LexOrdering, PhysicalSortRequirement, }; use async_trait::async_trait; @@ -160,7 +157,7 @@ impl ListingTableConfig { /// Infer `ListingOptions` based on `table_path` suffix. 
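The `PartitionedFile` changes above are additive conveniences; a small sketch using a made-up path:

use datafusion::datasource::listing::PartitionedFile;

fn main() {
    // `new` accepts anything `Into<String>`, byte ranges attach via the new
    // `with_range` builder, and `path()` exposes the object-store location.
    let file = PartitionedFile::new("data/part-0.parquet", 1024).with_range(0, 512);
    assert_eq!(file.path().as_ref(), "data/part-0.parquet");
    assert!(file.range.is_some());
}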
pub async fn infer_options(self, state: &SessionState) -> Result { - let store = if let Some(url) = self.table_paths.get(0) { + let store = if let Some(url) = self.table_paths.first() { state.runtime_env().object_store(url)? } else { return Ok(self); @@ -168,7 +165,7 @@ impl ListingTableConfig { let file = self .table_paths - .get(0) + .first() .unwrap() .list_all_files(state, store.as_ref(), "") .await? @@ -194,7 +191,7 @@ impl ListingTableConfig { pub async fn infer_schema(self, state: &SessionState) -> Result { match self.options { Some(options) => { - let schema = if let Some(url) = self.table_paths.get(0) { + let schema = if let Some(url) = self.table_paths.first() { options.infer_schema(state, url).await? } else { Arc::new(Schema::empty()) @@ -216,33 +213,6 @@ impl ListingTableConfig { } } -#[derive(Debug, Clone)] -///controls how new data should be inserted to a ListingTable -pub enum ListingTableInsertMode { - ///Data should be appended to an existing file - AppendToFile, - ///Data is appended as new files in existing TablePaths - AppendNewFiles, - ///Throw an error if insert into is attempted on this table - Error, -} - -impl FromStr for ListingTableInsertMode { - type Err = DataFusionError; - fn from_str(s: &str) -> Result { - let s_lower = s.to_lowercase(); - match s_lower.as_str() { - "append_to_file" => Ok(ListingTableInsertMode::AppendToFile), - "append_new_files" => Ok(ListingTableInsertMode::AppendNewFiles), - "error" => Ok(ListingTableInsertMode::Error), - _ => plan_err!( - "Unknown or unsupported insert mode {s}. Supported options are \ - append_to_file, append_new_files, and error." - ), - } - } -} - /// Options for creating a [`ListingTable`] #[derive(Clone, Debug)] pub struct ListingOptions { @@ -276,16 +246,6 @@ pub struct ListingOptions { /// multiple equivalent orderings, the outer `Vec` will have a /// single element. pub file_sort_order: Vec>, - /// Infinite source means that the input is not guaranteed to end. - /// Currently, CSV, JSON, and AVRO formats are supported. - /// In order to support infinite inputs, DataFusion may adjust query - /// plans (e.g. joins) to run the given query in full pipelining mode. - pub infinite_source: bool, - /// This setting controls how inserts to this table should be handled - pub insert_mode: ListingTableInsertMode, - /// This setting when true indicates that the table is backed by a single file. - /// Any inserts to the table may only append to this existing file. - pub single_file: bool, /// This setting holds file format specific options which should be used /// when inserting into this table. pub file_type_write_options: Option, @@ -306,31 +266,10 @@ impl ListingOptions { collect_stat: true, target_partitions: 1, file_sort_order: vec![], - infinite_source: false, - insert_mode: ListingTableInsertMode::AppendToFile, - single_file: false, file_type_write_options: None, } } - /// Set unbounded assumption on [`ListingOptions`] and returns self. 
- /// - /// ``` - /// use std::sync::Arc; - /// use datafusion::datasource::{listing::ListingOptions, file_format::csv::CsvFormat}; - /// use datafusion::prelude::SessionContext; - /// let ctx = SessionContext::new(); - /// let listing_options = ListingOptions::new(Arc::new( - /// CsvFormat::default() - /// )).with_infinite_source(true); - /// - /// assert_eq!(listing_options.infinite_source, true); - /// ``` - pub fn with_infinite_source(mut self, infinite_source: bool) -> Self { - self.infinite_source = infinite_source; - self - } - /// Set file extension on [`ListingOptions`] and returns self. /// /// ``` @@ -478,18 +417,6 @@ impl ListingOptions { self } - /// Configure how insertions to this table should be handled. - pub fn with_insert_mode(mut self, insert_mode: ListingTableInsertMode) -> Self { - self.insert_mode = insert_mode; - self - } - - /// Configure if this table is backed by a sigle file - pub fn with_single_file(mut self, single_file: bool) -> Self { - self.single_file = single_file; - self - } - /// Configure file format specific writing options. pub fn with_write_options( mut self, @@ -529,7 +456,7 @@ impl ListingOptions { /// /// # Features /// -/// 1. Merges schemas if the files have compatible but not indentical schemas +/// 1. Merges schemas if the files have compatible but not identical schemas /// /// 2. Hive-style partitioning support, where a path such as /// `/files/date=1/1/2022/data.parquet` is injected as a `date` column. @@ -596,8 +523,8 @@ pub struct ListingTable { options: ListingOptions, definition: Option, collected_statistics: FileStatisticsCache, - infinite_source: bool, constraints: Constraints, + column_defaults: HashMap, } impl ListingTable { @@ -625,7 +552,6 @@ impl ListingTable { for (part_col_name, part_col_type) in &options.table_partition_cols { builder.push(Field::new(part_col_name, part_col_type.clone(), false)); } - let infinite_source = options.infinite_source; let table = Self { table_paths: config.table_paths, @@ -634,8 +560,8 @@ impl ListingTable { options, definition: None, collected_statistics: Arc::new(DefaultFileStatisticsCache::default()), - infinite_source, constraints: Constraints::empty(), + column_defaults: HashMap::new(), }; Ok(table) @@ -647,6 +573,15 @@ impl ListingTable { self } + /// Assign column defaults + pub fn with_column_defaults( + mut self, + column_defaults: HashMap, + ) -> Self { + self.column_defaults = column_defaults; + self + } + /// Set the [`FileStatisticsCache`] used to cache parquet file statistics. 
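With `with_infinite_source`, `with_insert_mode`, and `with_single_file` removed, a `ListingOptions` is now configured only through the remaining builder methods. A hedged sketch of the post-removal configuration (the extension and partition count are illustrative):

```rust
use std::sync::Arc;
use datafusion::datasource::file_format::csv::CsvFormat;
use datafusion::datasource::listing::ListingOptions;

fn example() -> ListingOptions {
    // No more infinite_source / insert_mode / single_file knobs
    ListingOptions::new(Arc::new(CsvFormat::default()))
        .with_file_extension(".csv")
        .with_target_partitions(8)
        .with_collect_stat(true)
}
```
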
/// /// Setting a statistics cache on the `SessionContext` can avoid refetching statistics @@ -677,34 +612,7 @@ impl ListingTable { /// If file_sort_order is specified, creates the appropriate physical expressions fn try_create_output_ordering(&self) -> Result> { - let mut all_sort_orders = vec![]; - - for exprs in &self.options.file_sort_order { - // Construct PhsyicalSortExpr objects from Expr objects: - let sort_exprs = exprs - .iter() - .map(|expr| { - if let Expr::Sort(Sort { expr, asc, nulls_first }) = expr { - if let Expr::Column(col) = expr.as_ref() { - let expr = physical_plan::expressions::col(&col.name, self.table_schema.as_ref())?; - Ok(PhysicalSortExpr { - expr, - options: SortOptions { - descending: !asc, - nulls_first: *nulls_first, - }, - }) - } else { - plan_err!("Expected single column references in output_ordering, got {expr}") - } - } else { - plan_err!("Expected Expr::Sort in output_ordering, but got {expr}") - } - }) - .collect::>>()?; - all_sort_orders.push(sort_exprs); - } - Ok(all_sort_orders) + create_ordering(&self.table_schema, &self.options.file_sort_order) } } @@ -740,7 +648,7 @@ impl TableProvider for ListingTable { if partitioned_file_lists.is_empty() { let schema = self.schema(); let projected_schema = project_schema(&schema, projection)?; - return Ok(Arc::new(EmptyExec::new(false, projected_schema))); + return Ok(Arc::new(EmptyExec::new(projected_schema))); } // extract types of partition columns @@ -765,10 +673,10 @@ impl TableProvider for ListingTable { None }; - let object_store_url = if let Some(url) = self.table_paths.get(0) { + let object_store_url = if let Some(url) = self.table_paths.first() { url.object_store() } else { - return Ok(Arc::new(EmptyExec::new(false, Arc::new(Schema::empty())))); + return Ok(Arc::new(EmptyExec::new(Arc::new(Schema::empty())))); }; // create the execution plan self.options @@ -784,7 +692,6 @@ impl TableProvider for ListingTable { limit, output_ordering: self.try_create_output_ordering()?, table_partition_cols, - infinite_source: self.infinite_source, }, filters.as_ref(), ) @@ -835,6 +742,13 @@ impl TableProvider for ListingTable { } let table_path = &self.table_paths()[0]; + if !table_path.is_collection() { + return plan_err!( + "Inserting into a ListingTable backed by a single file is not supported, URL is possibly missing a trailing `/`. \ + To append to an existing file use StreamTable, e.g. by using CREATE UNBOUNDED EXTERNAL TABLE" + ); + } + // Get the object store for the table path. let store = state.runtime_env().object_store(table_path)?; @@ -849,31 +763,6 @@ impl TableProvider for ListingTable { .await?; let file_groups = file_list_stream.try_collect::>().await?; - //if we are writing a single output_partition to a table backed by a single file - //we can append to that file. Otherwise, we can write new files into the directory - //adding new files to the listing table in order to insert to the table. 
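`try_create_output_ordering` now delegates to the shared `create_ordering` helper (added in `datasource/mod.rs` later in this diff); the ordering itself is still declared on the options as plain `Expr::Sort` column references. A sketch under the assumption that the files contain a `ts` column:

```rust
use std::sync::Arc;
use datafusion::datasource::file_format::parquet::ParquetFormat;
use datafusion::datasource::listing::ListingOptions;
use datafusion::prelude::col;

fn example() -> ListingOptions {
    // Declare the files as pre-sorted by `ts` ascending, nulls last
    ListingOptions::new(Arc::new(ParquetFormat::default()))
        .with_file_sort_order(vec![vec![col("ts").sort(true, false)]])
}
```
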
- let input_partitions = input.output_partitioning().partition_count(); - let writer_mode = match self.options.insert_mode { - ListingTableInsertMode::AppendToFile => { - if input_partitions > file_groups.len() { - return plan_err!( - "Cannot append {input_partitions} partitions to {} files!", - file_groups.len() - ); - } - - crate::datasource::file_format::write::FileWriterMode::Append - } - ListingTableInsertMode::AppendNewFiles => { - crate::datasource::file_format::write::FileWriterMode::PutMultipart - } - ListingTableInsertMode::Error => { - return plan_err!( - "Invalid plan attempting write to table with TableWriteMode::Error!" - ); - } - }; - let file_format = self.options().format.as_ref(); let file_type_writer_options = match &self.options().file_type_write_options { @@ -891,33 +780,17 @@ impl TableProvider for ListingTable { file_groups, output_schema: self.schema(), table_partition_cols: self.options.table_partition_cols.clone(), - writer_mode, - // A plan can produce finite number of rows even if it has unbounded sources, like LIMIT - // queries. Thus, we can check if the plan is streaming to ensure file sink input is - // unbounded. When `unbounded_input` flag is `true` for sink, we occasionally call `yield_now` - // to consume data at the input. When `unbounded_input` flag is `false` (e.g non-streaming data), - // all of the data at the input is sink after execution finishes. See discussion for rationale: - // https://github.com/apache/arrow-datafusion/pull/7610#issuecomment-1728979918 - unbounded_input: is_plan_streaming(&input)?, - single_file_output: self.options.single_file, + single_file_output: false, overwrite, file_type_writer_options, }; let unsorted: Vec> = vec![]; let order_requirements = if self.options().file_sort_order != unsorted { - if matches!( - self.options().insert_mode, - ListingTableInsertMode::AppendToFile - ) { - return plan_err!( - "Cannot insert into a sorted ListingTable with mode append!" - ); - } // Multiple sort orders in outer vec are equivalent, so we pass only the first one let ordering = self .try_create_output_ordering()? - .get(0) + .first() .ok_or(DataFusionError::Internal( "Expected ListingTable to have a sort order, but none found!".into(), ))? @@ -938,6 +811,10 @@ impl TableProvider for ListingTable { .create_writer_physical_plan(input, state, config, order_requirements) .await } + + fn get_column_default(&self, column: &str) -> Option<&Expr> { + self.column_defaults.get(column) + } } impl ListingTable { @@ -950,7 +827,7 @@ impl ListingTable { filters: &'a [Expr], limit: Option, ) -> Result<(Vec>, Statistics)> { - let store = if let Some(url) = self.table_paths.get(0) { + let store = if let Some(url) = self.table_paths.first() { ctx.runtime_env().object_store(url)? 
} else { return Ok((vec![], Statistics::new_unknown(&self.file_schema))); @@ -1021,7 +898,6 @@ impl ListingTable { #[cfg(test)] mod tests { use std::collections::HashMap; - use std::fs::File; use super::*; #[cfg(feature = "parquet")] @@ -1032,48 +908,20 @@ mod tests { use crate::prelude::*; use crate::{ assert_batches_eq, - datasource::file_format::{avro::AvroFormat, file_compression_type::FileTypeExt}, - execution::options::ReadOptions, + datasource::file_format::avro::AvroFormat, logical_expr::{col, lit}, test::{columns, object_store::register_test_store}, }; use arrow::datatypes::{DataType, Schema}; use arrow::record_batch::RecordBatch; + use arrow_schema::SortOptions; use datafusion_common::stats::Precision; use datafusion_common::{assert_contains, GetExt, ScalarValue}; use datafusion_expr::{BinaryExpr, LogicalPlanBuilder, Operator}; - use rstest::*; + use datafusion_physical_expr::PhysicalSortExpr; use tempfile::TempDir; - /// It creates dummy file and checks if it can create unbounded input executors. - async fn unbounded_table_helper( - file_type: FileType, - listing_option: ListingOptions, - infinite_data: bool, - ) -> Result<()> { - let ctx = SessionContext::new(); - register_test_store( - &ctx, - &[(&format!("table/file{}", file_type.get_ext()), 100)], - ); - - let schema = Schema::new(vec![Field::new("a", DataType::Boolean, false)]); - - let table_path = ListingTableUrl::parse("test:///table/").unwrap(); - let config = ListingTableConfig::new(table_path) - .with_listing_options(listing_option) - .with_schema(Arc::new(schema)); - // Create a table - let table = ListingTable::try_new(config)?; - // Create executor from table - let source_exec = table.scan(&ctx.state(), None, &[], None).await?; - - assert_eq!(source_exec.unbounded_output(&[])?, infinite_data); - - Ok(()) - } - #[tokio::test] async fn read_single_file() -> Result<()> { let ctx = SessionContext::new(); @@ -1281,99 +1129,6 @@ mod tests { Ok(()) } - #[tokio::test] - async fn unbounded_csv_table_without_schema() -> Result<()> { - let tmp_dir = TempDir::new()?; - let file_path = tmp_dir.path().join("dummy.csv"); - File::create(file_path)?; - let ctx = SessionContext::new(); - let error = ctx - .register_csv( - "test", - tmp_dir.path().to_str().unwrap(), - CsvReadOptions::new().mark_infinite(true), - ) - .await - .unwrap_err(); - match error { - DataFusionError::Plan(_) => Ok(()), - val => Err(val), - } - } - - #[tokio::test] - async fn unbounded_json_table_without_schema() -> Result<()> { - let tmp_dir = TempDir::new()?; - let file_path = tmp_dir.path().join("dummy.json"); - File::create(file_path)?; - let ctx = SessionContext::new(); - let error = ctx - .register_json( - "test", - tmp_dir.path().to_str().unwrap(), - NdJsonReadOptions::default().mark_infinite(true), - ) - .await - .unwrap_err(); - match error { - DataFusionError::Plan(_) => Ok(()), - val => Err(val), - } - } - - #[tokio::test] - async fn unbounded_avro_table_without_schema() -> Result<()> { - let tmp_dir = TempDir::new()?; - let file_path = tmp_dir.path().join("dummy.avro"); - File::create(file_path)?; - let ctx = SessionContext::new(); - let error = ctx - .register_avro( - "test", - tmp_dir.path().to_str().unwrap(), - AvroReadOptions::default().mark_infinite(true), - ) - .await - .unwrap_err(); - match error { - DataFusionError::Plan(_) => Ok(()), - val => Err(val), - } - } - - #[rstest] - #[tokio::test] - async fn unbounded_csv_table( - #[values(true, false)] infinite_data: bool, - ) -> Result<()> { - let config = 
CsvReadOptions::new().mark_infinite(infinite_data); - let session_config = SessionConfig::new().with_target_partitions(1); - let listing_options = config.to_listing_options(&session_config); - unbounded_table_helper(FileType::CSV, listing_options, infinite_data).await - } - - #[rstest] - #[tokio::test] - async fn unbounded_json_table( - #[values(true, false)] infinite_data: bool, - ) -> Result<()> { - let config = NdJsonReadOptions::default().mark_infinite(infinite_data); - let session_config = SessionConfig::new().with_target_partitions(1); - let listing_options = config.to_listing_options(&session_config); - unbounded_table_helper(FileType::JSON, listing_options, infinite_data).await - } - - #[rstest] - #[tokio::test] - async fn unbounded_avro_table( - #[values(true, false)] infinite_data: bool, - ) -> Result<()> { - let config = AvroReadOptions::default().mark_infinite(infinite_data); - let session_config = SessionConfig::new().with_target_partitions(1); - let listing_options = config.to_listing_options(&session_config); - unbounded_table_helper(FileType::AVRO, listing_options, infinite_data).await - } - #[tokio::test] async fn test_assert_list_files_for_scan_grouping() -> Result<()> { // more expected partitions than files @@ -1594,17 +1349,6 @@ mod tests { Ok(()) } - #[tokio::test] - async fn test_insert_into_append_to_json_file() -> Result<()> { - helper_test_insert_into_append_to_existing_files( - FileType::JSON, - FileCompressionType::UNCOMPRESSED, - None, - ) - .await?; - Ok(()) - } - #[tokio::test] async fn test_insert_into_append_new_json_files() -> Result<()> { let mut config_map: HashMap = HashMap::new(); @@ -1623,17 +1367,6 @@ mod tests { Ok(()) } - #[tokio::test] - async fn test_insert_into_append_to_csv_file() -> Result<()> { - helper_test_insert_into_append_to_existing_files( - FileType::CSV, - FileCompressionType::UNCOMPRESSED, - None, - ) - .await?; - Ok(()) - } - #[tokio::test] async fn test_insert_into_append_new_csv_files() -> Result<()> { let mut config_map: HashMap = HashMap::new(); @@ -1690,13 +1423,8 @@ mod tests { #[tokio::test] async fn test_insert_into_sql_csv_defaults() -> Result<()> { - helper_test_insert_into_sql( - "csv", - FileCompressionType::UNCOMPRESSED, - "OPTIONS (insert_mode 'append_new_files')", - None, - ) - .await?; + helper_test_insert_into_sql("csv", FileCompressionType::UNCOMPRESSED, "", None) + .await?; Ok(()) } @@ -1705,8 +1433,7 @@ mod tests { helper_test_insert_into_sql( "csv", FileCompressionType::UNCOMPRESSED, - "WITH HEADER ROW \ - OPTIONS (insert_mode 'append_new_files')", + "WITH HEADER ROW", None, ) .await?; @@ -1715,13 +1442,8 @@ mod tests { #[tokio::test] async fn test_insert_into_sql_json_defaults() -> Result<()> { - helper_test_insert_into_sql( - "json", - FileCompressionType::UNCOMPRESSED, - "OPTIONS (insert_mode 'append_new_files')", - None, - ) - .await?; + helper_test_insert_into_sql("json", FileCompressionType::UNCOMPRESSED, "", None) + .await?; Ok(()) } @@ -1906,211 +1628,6 @@ mod tests { Ok(()) } - #[tokio::test] - async fn test_insert_into_append_to_parquet_file_fails() -> Result<()> { - let maybe_err = helper_test_insert_into_append_to_existing_files( - FileType::PARQUET, - FileCompressionType::UNCOMPRESSED, - None, - ) - .await; - let _err = - maybe_err.expect_err("Appending to existing parquet file did not fail!"); - Ok(()) - } - - fn load_empty_schema_table( - schema: SchemaRef, - temp_path: &str, - insert_mode: ListingTableInsertMode, - file_format: Arc, - ) -> Result> { - File::create(temp_path)?; - let table_path = 
ListingTableUrl::parse(temp_path).unwrap(); - - let listing_options = - ListingOptions::new(file_format.clone()).with_insert_mode(insert_mode); - - let config = ListingTableConfig::new(table_path) - .with_listing_options(listing_options) - .with_schema(schema); - - let table = ListingTable::try_new(config)?; - Ok(Arc::new(table)) - } - - /// Logic of testing inserting into listing table by Appending to existing files - /// is the same for all formats/options which support this. This helper allows - /// passing different options to execute the same test with different settings. - async fn helper_test_insert_into_append_to_existing_files( - file_type: FileType, - file_compression_type: FileCompressionType, - session_config_map: Option>, - ) -> Result<()> { - // Create the initial context, schema, and batch. - let session_ctx = match session_config_map { - Some(cfg) => { - let config = SessionConfig::from_string_hash_map(cfg)?; - SessionContext::new_with_config(config) - } - None => SessionContext::new(), - }; - // Create a new schema with one field called "a" of type Int32 - let schema = Arc::new(Schema::new(vec![Field::new( - "column1", - DataType::Int32, - false, - )])); - - // Create a new batch of data to insert into the table - let batch = RecordBatch::try_new( - schema.clone(), - vec![Arc::new(arrow_array::Int32Array::from(vec![1, 2, 3]))], - )?; - - // Filename with extension - let filename = format!( - "path{}", - file_type - .to_owned() - .get_ext_with_compression(file_compression_type) - .unwrap() - ); - - // Create a temporary directory and a CSV file within it. - let tmp_dir = TempDir::new()?; - let path = tmp_dir.path().join(filename); - - let file_format: Arc = match file_type { - FileType::CSV => Arc::new( - CsvFormat::default().with_file_compression_type(file_compression_type), - ), - FileType::JSON => Arc::new( - JsonFormat::default().with_file_compression_type(file_compression_type), - ), - FileType::PARQUET => Arc::new(ParquetFormat::default()), - FileType::AVRO => Arc::new(AvroFormat {}), - FileType::ARROW => Arc::new(ArrowFormat {}), - }; - - let initial_table = load_empty_schema_table( - schema.clone(), - path.to_str().unwrap(), - ListingTableInsertMode::AppendToFile, - file_format, - )?; - session_ctx.register_table("t", initial_table)?; - // Create and register the source table with the provided schema and inserted data - let source_table = Arc::new(MemTable::try_new( - schema.clone(), - vec![vec![batch.clone(), batch.clone()]], - )?); - session_ctx.register_table("source", source_table.clone())?; - // Convert the source table into a provider so that it can be used in a query - let source = provider_as_source(source_table); - // Create a table scan logical plan to read from the source table - let scan_plan = LogicalPlanBuilder::scan("source", source, None)?.build()?; - // Create an insert plan to insert the source data into the initial table - let insert_into_table = - LogicalPlanBuilder::insert_into(scan_plan, "t", &schema, false)?.build()?; - // Create a physical plan from the insert plan - let plan = session_ctx - .state() - .create_physical_plan(&insert_into_table) - .await?; - - // Execute the physical plan and collect the results - let res = collect(plan, session_ctx.task_ctx()).await?; - // Insert returns the number of rows written, in our case this would be 6. - let expected = [ - "+-------+", - "| count |", - "+-------+", - "| 6 |", - "+-------+", - ]; - - // Assert that the batches read from the file match the expected result. 
- assert_batches_eq!(expected, &res); - - // Read the records in the table - let batches = session_ctx.sql("select * from t").await?.collect().await?; - - // Define the expected result as a vector of strings. - let expected = [ - "+---------+", - "| column1 |", - "+---------+", - "| 1 |", - "| 2 |", - "| 3 |", - "| 1 |", - "| 2 |", - "| 3 |", - "+---------+", - ]; - - // Assert that the batches read from the file match the expected result. - assert_batches_eq!(expected, &batches); - - // Assert that only 1 file was added to the table - let num_files = tmp_dir.path().read_dir()?.count(); - assert_eq!(num_files, 1); - - // Create a physical plan from the insert plan - let plan = session_ctx - .state() - .create_physical_plan(&insert_into_table) - .await?; - - // Again, execute the physical plan and collect the results - let res = collect(plan, session_ctx.task_ctx()).await?; - // Insert returns the number of rows written, in our case this would be 6. - let expected = [ - "+-------+", - "| count |", - "+-------+", - "| 6 |", - "+-------+", - ]; - - // Assert that the batches read from the file match the expected result. - assert_batches_eq!(expected, &res); - - // Open the CSV file, read its contents as a record batch, and collect the batches into a vector. - let batches = session_ctx.sql("select * from t").await?.collect().await?; - - // Define the expected result after the second append. - let expected = vec![ - "+---------+", - "| column1 |", - "+---------+", - "| 1 |", - "| 2 |", - "| 3 |", - "| 1 |", - "| 2 |", - "| 3 |", - "| 1 |", - "| 2 |", - "| 3 |", - "| 1 |", - "| 2 |", - "| 3 |", - "+---------+", - ]; - - // Assert that the batches read from the file after the second append match the expected result. - assert_batches_eq!(expected, &batches); - - // Assert that no additional files were added to the table - let num_files = tmp_dir.path().read_dir()?.count(); - assert_eq!(num_files, 1); - - // Return Ok if the function - Ok(()) - } - async fn helper_test_append_new_files_to_table( file_type: FileType, file_compression_type: FileCompressionType, @@ -2156,7 +1673,6 @@ mod tests { "t", tmp_dir.path().to_str().unwrap(), CsvReadOptions::new() - .insert_mode(ListingTableInsertMode::AppendNewFiles) .schema(schema.as_ref()) .file_compression_type(file_compression_type), ) @@ -2168,7 +1684,6 @@ mod tests { "t", tmp_dir.path().to_str().unwrap(), NdJsonReadOptions::default() - .insert_mode(ListingTableInsertMode::AppendNewFiles) .schema(schema.as_ref()) .file_compression_type(file_compression_type), ) @@ -2179,9 +1694,7 @@ mod tests { .register_parquet( "t", tmp_dir.path().to_str().unwrap(), - ParquetReadOptions::default() - .insert_mode(ListingTableInsertMode::AppendNewFiles) - .schema(schema.as_ref()), + ParquetReadOptions::default().schema(schema.as_ref()), ) .await?; } @@ -2190,10 +1703,7 @@ mod tests { .register_avro( "t", tmp_dir.path().to_str().unwrap(), - AvroReadOptions::default() - // TODO implement insert_mode for avro - //.insert_mode(ListingTableInsertMode::AppendNewFiles) - .schema(schema.as_ref()), + AvroReadOptions::default().schema(schema.as_ref()), ) .await?; } @@ -2202,10 +1712,7 @@ mod tests { .register_arrow( "t", tmp_dir.path().to_str().unwrap(), - ArrowReadOptions::default() - // TODO implement insert_mode for arrow - //.insert_mode(ListingTableInsertMode::AppendNewFiles) - .schema(schema.as_ref()), + ArrowReadOptions::default().schema(schema.as_ref()), ) .await?; } diff --git a/datafusion/core/src/datasource/listing/url.rs b/datafusion/core/src/datasource/listing/url.rs 
index 9197e37adbd5d..766dee7de9010 100644
--- a/datafusion/core/src/datasource/listing/url.rs
+++ b/datafusion/core/src/datasource/listing/url.rs
@@ -20,6 +20,7 @@ use std::fs;
 use crate::datasource::object_store::ObjectStoreUrl;
 use crate::execution::context::SessionState;
 use datafusion_common::{DataFusionError, Result};
+use datafusion_optimizer::OptimizerConfig;
 use futures::stream::BoxStream;
 use futures::{StreamExt, TryStreamExt};
 use glob::Pattern;
@@ -45,6 +46,17 @@ pub struct ListingTableUrl {
 impl ListingTableUrl {
     /// Parse a provided string as a `ListingTableUrl`
     ///
+    /// A URL can either refer to a single object, or a collection of objects with a
+    /// common prefix, with the presence of a trailing `/` indicating a collection.
+    ///
+    /// For example, `file:///foo.txt` refers to the file at `/foo.txt`, whereas
+    /// `file:///foo/` refers to all the files under the directory `/foo` and its
+    /// subdirectories.
+    ///
+    /// Similarly `s3://BUCKET/blob.csv` refers to `blob.csv` in the S3 bucket `BUCKET`,
+    /// whereas `s3://BUCKET/foo/` refers to all objects with the prefix `foo/` in the
+    /// S3 bucket `BUCKET`
+    ///
     /// # URL Encoding
     ///
     /// URL paths are expected to be URL-encoded. That is, the URL for a file named `bar%2Efoo`
     ///
@@ -58,19 +70,21 @@ impl ListingTableUrl {
     /// # Paths without a Scheme
     ///
     /// If no scheme is provided, or the string is an absolute filesystem path
-    /// as determined [`std::path::Path::is_absolute`], the string will be
+    /// as determined by [`std::path::Path::is_absolute`], the string will be
     /// interpreted as a path on the local filesystem using the operating
     /// system's standard path delimiter, i.e. `\` on Windows, `/` on Unix.
     ///
     /// If the path contains any of `'?', '*', '['`, it will be considered
     /// a glob expression and resolved as described in the section below.
     ///
-    /// Otherwise, the path will be resolved to an absolute path, returning
-    /// an error if it does not exist, and converted to a [file URI]
+    /// Otherwise, the path will be resolved to an absolute path based on the current
+    /// working directory, and converted to a [file URI].
     ///
-    /// If you wish to specify a path that does not exist on the local
-    /// machine you must provide it as a fully-qualified [file URI]
-    /// e.g. `file:///myfile.txt`
+    /// If the path already exists in the local filesystem this will be used to determine if this
+    /// [`ListingTableUrl`] refers to a collection or a single object, otherwise the presence
+    /// of a trailing path delimiter will be used to indicate a directory. For the avoidance
+    /// of ambiguity it is recommended users always include a trailing `/` when intending to
+    /// refer to a directory.
     ///
     /// ## Glob File Paths
     ///
@@ -78,9 +92,7 @@ impl ListingTableUrl {
     /// be resolved as follows.
     ///
     /// The string up to the first path segment containing a glob expression will be extracted,
-    /// and resolved in the same manner as a normal scheme-less path. That is, resolved to
-    /// an absolute path on the local filesystem, returning an error if it does not exist,
-    /// and converted to a [file URI]
+    /// and resolved in the same manner as a normal scheme-less path above.
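The reworked parsing above distinguishes a single object from a collection by the trailing `/`. A minimal sketch of the resulting behaviour, using the `is_collection` helper introduced further down in this file (the bucket and object names are made up):

```rust
use datafusion::datasource::listing::ListingTableUrl;
use datafusion::error::Result;

fn example() -> Result<()> {
    let dir = ListingTableUrl::parse("s3://my-bucket/data/")?;
    let file = ListingTableUrl::parse("s3://my-bucket/data/blob.csv")?;

    assert!(dir.is_collection()); // trailing `/` => listing prefix
    assert!(!file.is_collection()); // no trailing `/` => single object
    Ok(())
}
```
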
/// /// The remaining string will be interpreted as a [`glob::Pattern`] and used as a /// filter when listing files from object storage @@ -105,6 +117,7 @@ impl ListingTableUrl { /// Get object store for specified input_url /// if input_url is actually not a url, we assume it is a local file path /// if we have a local path, create it if not exists so ListingTableUrl::parse works + #[deprecated(note = "Use parse")] pub fn parse_create_local_if_not_exists( s: impl AsRef, is_directory: bool, @@ -120,6 +133,10 @@ impl ListingTableUrl { if is_directory { fs::create_dir_all(path)?; } else { + // ensure parent directory exists + if let Some(parent) = path.parent() { + fs::create_dir_all(parent)?; + } fs::File::create(path)?; } } @@ -130,7 +147,7 @@ impl ListingTableUrl { /// Creates a new [`ListingTableUrl`] interpreting `s` as a filesystem path fn parse_path(s: &str) -> Result { - let (prefix, glob) = match split_glob_expression(s) { + let (path, glob) = match split_glob_expression(s) { Some((prefix, glob)) => { let glob = Pattern::new(glob) .map_err(|e| DataFusionError::External(Box::new(e)))?; @@ -139,15 +156,12 @@ impl ListingTableUrl { None => (s, None), }; - let path = std::path::Path::new(prefix).canonicalize()?; - let url = if path.is_dir() { - Url::from_directory_path(path) - } else { - Url::from_file_path(path) - } - .map_err(|_| DataFusionError::Internal(format!("Can not open path: {s}")))?; - // TODO: Currently we do not have an IO-related error variant that accepts () - // or a string. Once we have such a variant, change the error type above. + let url = url_from_filesystem_path(path).ok_or_else(|| { + DataFusionError::External( + format!("Failed to convert path to URL: {path}").into(), + ) + })?; + Self::try_new(url, glob) } @@ -162,25 +176,46 @@ impl ListingTableUrl { self.url.scheme() } - /// Return the prefix from which to list files + /// Return the URL path not excluding any glob expression + /// + /// If [`Self::is_collection`], this is the listing prefix + /// Otherwise, this is the path to the object pub fn prefix(&self) -> &Path { &self.prefix } /// Returns `true` if `path` matches this [`ListingTableUrl`] - pub fn contains(&self, path: &Path) -> bool { + pub fn contains(&self, path: &Path, ignore_subdirectory: bool) -> bool { match self.strip_prefix(path) { Some(mut segments) => match &self.glob { Some(glob) => { - let stripped = segments.join("/"); - glob.matches(&stripped) + if ignore_subdirectory { + segments + .next() + .map_or(false, |file_name| glob.matches(file_name)) + } else { + let stripped = segments.join("/"); + glob.matches(&stripped) + } + } + None => { + if ignore_subdirectory { + let has_subdirectory = segments.collect::>().len() > 1; + !has_subdirectory + } else { + true + } } - None => true, }, None => false, } } + /// Returns `true` if `path` refers to a collection of objects + pub fn is_collection(&self) -> bool { + self.url.as_str().ends_with('/') + } + /// Strips the prefix of this [`ListingTableUrl`] from the provided path, returning /// an iterator of the remaining path segments pub(crate) fn strip_prefix<'a, 'b: 'a>( @@ -202,21 +237,20 @@ impl ListingTableUrl { store: &'a dyn ObjectStore, file_extension: &'a str, ) -> Result>> { + let exec_options = &ctx.options().execution; + let ignore_subdirectory = exec_options.listing_table_ignore_subdirectory; // If the prefix is a file, use a head request, otherwise list - let is_dir = self.url.as_str().ends_with('/'); - let list = match is_dir { + let list = match self.is_collection() { true => match 
ctx.runtime_env().cache_manager.get_list_files_cache() { - None => futures::stream::once(store.list(Some(&self.prefix))) - .try_flatten() - .boxed(), + None => store.list(Some(&self.prefix)), Some(cache) => { if let Some(res) = cache.get(&self.prefix) { debug!("Hit list all files cache"); futures::stream::iter(res.as_ref().clone().into_iter().map(Ok)) .boxed() } else { - let list_res = store.list(Some(&self.prefix)).await; - let vec = list_res?.try_collect::>().await?; + let list_res = store.list(Some(&self.prefix)); + let vec = list_res.try_collect::>().await?; cache.put(&self.prefix, Arc::new(vec.clone())); futures::stream::iter(vec.into_iter().map(Ok)).boxed() } @@ -228,7 +262,7 @@ impl ListingTableUrl { .try_filter(move |meta| { let path = &meta.location; let extension_match = path.as_ref().ends_with(file_extension); - let glob_match = self.contains(path); + let glob_match = self.contains(path, ignore_subdirectory); futures::future::ready(extension_match && glob_match) }) .map_err(DataFusionError::ObjectStore) @@ -247,6 +281,34 @@ impl ListingTableUrl { } } +/// Creates a file URL from a potentially relative filesystem path +fn url_from_filesystem_path(s: &str) -> Option { + let path = std::path::Path::new(s); + let is_dir = match path.exists() { + true => path.is_dir(), + // Fallback to inferring from trailing separator + false => std::path::is_separator(s.chars().last()?), + }; + + let from_absolute_path = |p| { + let first = match is_dir { + true => Url::from_directory_path(p).ok(), + false => Url::from_file_path(p).ok(), + }?; + + // By default from_*_path preserve relative path segments + // We therefore parse the URL again to resolve these + Url::parse(first.as_str()).ok() + }; + + if path.is_absolute() { + return from_absolute_path(path); + } + + let absolute = std::env::current_dir().ok()?.join(path); + from_absolute_path(&absolute) +} + impl AsRef for ListingTableUrl { fn as_ref(&self) -> &str { self.url.as_ref() @@ -326,8 +388,8 @@ mod tests { let url = ListingTableUrl::parse("file:///foo/bar?").unwrap(); assert_eq!(url.prefix.as_ref(), "foo/bar"); - let err = ListingTableUrl::parse("file:///foo/😺").unwrap_err(); - assert_eq!(err.to_string(), "Object Store error: Encountered object with invalid path: Error parsing Path \"/foo/😺\": Encountered illegal character sequence \"😺\" whilst parsing path segment \"😺\""); + let url = ListingTableUrl::parse("file:///foo/😺").unwrap(); + assert_eq!(url.prefix.as_ref(), "foo/😺"); let url = ListingTableUrl::parse("file:///foo/bar%2Efoo").unwrap(); assert_eq!(url.prefix.as_ref(), "foo/bar.foo"); @@ -347,6 +409,37 @@ mod tests { let url = ListingTableUrl::parse(path.to_str().unwrap()).unwrap(); assert!(url.prefix.as_ref().ends_with("bar%2Ffoo"), "{}", url.prefix); + + let url = ListingTableUrl::parse("file:///foo/../a%252Fb.txt").unwrap(); + assert_eq!(url.prefix.as_ref(), "a%2Fb.txt"); + + let url = + ListingTableUrl::parse("file:///foo/./bar/../../baz/./test.txt").unwrap(); + assert_eq!(url.prefix.as_ref(), "baz/test.txt"); + + let workdir = std::env::current_dir().unwrap(); + let t = workdir.join("non-existent"); + let a = ListingTableUrl::parse(t.to_str().unwrap()).unwrap(); + let b = ListingTableUrl::parse("non-existent").unwrap(); + assert_eq!(a, b); + assert!(a.prefix.as_ref().ends_with("non-existent")); + + let t = workdir.parent().unwrap(); + let a = ListingTableUrl::parse(t.to_str().unwrap()).unwrap(); + let b = ListingTableUrl::parse("..").unwrap(); + assert_eq!(a, b); + + let t = t.join("bar"); + let a = 
ListingTableUrl::parse(t.to_str().unwrap()).unwrap(); + let b = ListingTableUrl::parse("../bar").unwrap(); + assert_eq!(a, b); + assert!(a.prefix.as_ref().ends_with("bar")); + + let t = t.join(".").join("foo").join("..").join("baz"); + let a = ListingTableUrl::parse(t.to_str().unwrap()).unwrap(); + let b = ListingTableUrl::parse("../bar/./foo/../baz").unwrap(); + assert_eq!(a, b); + assert!(a.prefix.as_ref().ends_with("bar/baz")); } #[test] diff --git a/datafusion/core/src/datasource/listing_table_factory.rs b/datafusion/core/src/datasource/listing_table_factory.rs index 26f40518979a2..e8ffece320d7d 100644 --- a/datafusion/core/src/datasource/listing_table_factory.rs +++ b/datafusion/core/src/datasource/listing_table_factory.rs @@ -21,8 +21,6 @@ use std::path::Path; use std::str::FromStr; use std::sync::Arc; -use super::listing::ListingTableInsertMode; - #[cfg(feature = "parquet")] use crate::datasource::file_format::parquet::ParquetFormat; use crate::datasource::file_format::{ @@ -38,24 +36,19 @@ use crate::execution::context::SessionState; use arrow::datatypes::{DataType, SchemaRef}; use datafusion_common::file_options::{FileTypeWriterOptions, StatementOptions}; -use datafusion_common::{DataFusionError, FileType}; +use datafusion_common::{arrow_datafusion_err, plan_err, DataFusionError, FileType}; use datafusion_expr::CreateExternalTable; use async_trait::async_trait; /// A `TableProviderFactory` capable of creating new `ListingTable`s +#[derive(Debug, Default)] pub struct ListingTableFactory {} impl ListingTableFactory { /// Creates a new `ListingTableFactory` pub fn new() -> Self { - Self {} - } -} - -impl Default for ListingTableFactory { - fn default() -> Self { - Self::new() + Self::default() } } @@ -74,12 +67,20 @@ impl TableProviderFactory for ListingTableFactory { let file_extension = get_extension(cmd.location.as_str()); let file_format: Arc = match file_type { - FileType::CSV => Arc::new( - CsvFormat::default() + FileType::CSV => { + let mut statement_options = StatementOptions::from(&cmd.options); + let mut csv_format = CsvFormat::default() .with_has_header(cmd.has_header) .with_delimiter(cmd.delimiter as u8) - .with_file_compression_type(file_compression_type), - ), + .with_file_compression_type(file_compression_type); + if let Some(quote) = statement_options.take_str_option("quote") { + csv_format = csv_format.with_quote(quote.as_bytes()[0]) + } + if let Some(escape) = statement_options.take_str_option("escape") { + csv_format = csv_format.with_escape(Some(escape.as_bytes()[0])) + } + Arc::new(csv_format) + } #[cfg(feature = "parquet")] FileType::PARQUET => Arc::new(ParquetFormat::default()), FileType::AVRO => Arc::new(AvroFormat), @@ -113,7 +114,7 @@ impl TableProviderFactory for ListingTableFactory { .map(|col| { schema .field_with_name(col) - .map_err(DataFusionError::ArrowError) + .map_err(|e| arrow_datafusion_err!(e)) }) .collect::>>()? .into_iter() @@ -132,40 +133,17 @@ impl TableProviderFactory for ListingTableFactory { (Some(schema), table_partition_cols) }; - // look for 'infinite' as an option - let infinite_source = cmd.unbounded; - let mut statement_options = StatementOptions::from(&cmd.options); - // Extract ListingTable specific options if present or set default - let unbounded = if infinite_source { - statement_options.take_str_option("unbounded"); - infinite_source - } else { - statement_options - .take_bool_option("unbounded")? - .unwrap_or(false) - }; - - let create_local_path = statement_options - .take_bool_option("create_local_path")? 
- .unwrap_or(false); - let single_file = statement_options - .take_bool_option("single_file")? - .unwrap_or(false); - - let explicit_insert_mode = statement_options.take_str_option("insert_mode"); - let insert_mode = match explicit_insert_mode { - Some(mode) => ListingTableInsertMode::from_str(mode.as_str()), - None => match file_type { - FileType::CSV => Ok(ListingTableInsertMode::AppendToFile), - #[cfg(feature = "parquet")] - FileType::PARQUET => Ok(ListingTableInsertMode::AppendNewFiles), - FileType::AVRO => Ok(ListingTableInsertMode::AppendNewFiles), - FileType::JSON => Ok(ListingTableInsertMode::AppendToFile), - FileType::ARROW => Ok(ListingTableInsertMode::AppendNewFiles), - }, - }?; + // Backwards compatibility (#8547), discard deprecated options + statement_options.take_bool_option("single_file")?; + if let Some(s) = statement_options.take_str_option("insert_mode") { + if !s.eq_ignore_ascii_case("append_new_files") { + return plan_err!("Unknown or unsupported insert mode {s}. Only append_new_files supported"); + } + } + statement_options.take_bool_option("create_local_path")?; + statement_options.take_str_option("unbounded"); let file_type = file_format.file_type(); @@ -205,13 +183,7 @@ impl TableProviderFactory for ListingTableFactory { FileType::AVRO => file_type_writer_options, }; - let table_path = match create_local_path { - true => ListingTableUrl::parse_create_local_if_not_exists( - &cmd.location, - !single_file, - ), - false => ListingTableUrl::parse(&cmd.location), - }?; + let table_path = ListingTableUrl::parse(&cmd.location)?; let options = ListingOptions::new(file_format) .with_collect_stat(state.config().collect_statistics()) @@ -219,10 +191,7 @@ impl TableProviderFactory for ListingTableFactory { .with_target_partitions(state.config().target_partitions()) .with_table_partition_cols(table_partition_cols) .with_file_sort_order(cmd.order_exprs.clone()) - .with_insert_mode(insert_mode) - .with_single_file(single_file) - .with_write_options(file_type_writer_options) - .with_infinite_source(unbounded); + .with_write_options(file_type_writer_options); let resolved_schema = match provided_schema { None => options.infer_schema(state, &table_path).await?, @@ -235,7 +204,8 @@ impl TableProviderFactory for ListingTableFactory { .with_cache(state.runtime_env().cache_manager.get_file_statistic_cache()); let table = provider .with_definition(cmd.definition.clone()) - .with_constraints(cmd.constraints.clone()); + .with_constraints(cmd.constraints.clone()) + .with_column_defaults(cmd.column_defaults.clone()); Ok(Arc::new(table)) } } @@ -286,6 +256,7 @@ mod tests { unbounded: false, options: HashMap::new(), constraints: Constraints::empty(), + column_defaults: HashMap::new(), }; let table_provider = factory.create(&state, &cmd).await.unwrap(); let listing_table = table_provider diff --git a/datafusion/core/src/datasource/memory.rs b/datafusion/core/src/datasource/memory.rs index 6bcaa97a408fc..7c61cc5368608 100644 --- a/datafusion/core/src/datasource/memory.rs +++ b/datafusion/core/src/datasource/memory.rs @@ -21,6 +21,7 @@ use datafusion_physical_plan::metrics::MetricsSet; use futures::StreamExt; use log::debug; use std::any::Any; +use std::collections::HashMap; use std::fmt::{self, Debug}; use std::sync::Arc; @@ -56,6 +57,7 @@ pub struct MemTable { schema: SchemaRef, pub(crate) batches: Vec, constraints: Constraints, + column_defaults: HashMap, } impl MemTable { @@ -79,6 +81,7 @@ impl MemTable { .map(|e| Arc::new(RwLock::new(e))) .collect::>(), constraints: Constraints::empty(), + 
column_defaults: HashMap::new(), }) } @@ -88,6 +91,15 @@ impl MemTable { self } + /// Assign column defaults + pub fn with_column_defaults( + mut self, + column_defaults: HashMap, + ) -> Self { + self.column_defaults = column_defaults; + self + } + /// Create a mem table by reading from another data source pub async fn load( t: Arc, @@ -228,6 +240,10 @@ impl TableProvider for MemTable { None, ))) } + + fn get_column_default(&self, column: &str) -> Option<&Expr> { + self.column_defaults.get(column) + } } /// Implements for writing to a [`MemTable`] @@ -407,7 +423,7 @@ mod tests { .scan(&session_ctx.state(), Some(&projection), &[], None) .await { - Err(DataFusionError::ArrowError(ArrowError::SchemaError(e))) => { + Err(DataFusionError::ArrowError(ArrowError::SchemaError(e), _)) => { assert_eq!( "\"project index 4 out of bounds, max field 3\"", format!("{e:?}") diff --git a/datafusion/core/src/datasource/mod.rs b/datafusion/core/src/datasource/mod.rs index 48e9d6992124d..2e516cc36a01d 100644 --- a/datafusion/core/src/datasource/mod.rs +++ b/datafusion/core/src/datasource/mod.rs @@ -23,12 +23,14 @@ pub mod avro_to_arrow; pub mod default_table_source; pub mod empty; pub mod file_format; +pub mod function; pub mod listing; pub mod listing_table_factory; pub mod memory; pub mod physical_plan; pub mod provider; mod statistics; +pub mod stream; pub mod streaming; pub mod view; @@ -43,3 +45,46 @@ pub use self::provider::TableProvider; pub use self::view::ViewTable; pub use crate::logical_expr::TableType; pub use statistics::get_statistics_with_limit; + +use arrow_schema::{Schema, SortOptions}; +use datafusion_common::{plan_err, DataFusionError, Result}; +use datafusion_expr::Expr; +use datafusion_physical_expr::{expressions, LexOrdering, PhysicalSortExpr}; + +fn create_ordering( + schema: &Schema, + sort_order: &[Vec], +) -> Result> { + let mut all_sort_orders = vec![]; + + for exprs in sort_order { + // Construct PhysicalSortExpr objects from Expr objects: + let mut sort_exprs = vec![]; + for expr in exprs { + match expr { + Expr::Sort(sort) => match sort.expr.as_ref() { + Expr::Column(col) => match expressions::col(&col.name, schema) { + Ok(expr) => { + sort_exprs.push(PhysicalSortExpr { + expr, + options: SortOptions { + descending: !sort.asc, + nulls_first: sort.nulls_first, + }, + }); + } + // Cannot find expression in the projected_schema, stop iterating + // since rest of the orderings are violated + Err(_) => break, + } + expr => return plan_err!("Expected single column references in output_ordering, got {expr}"), + } + expr => return plan_err!("Expected Expr::Sort in output_ordering, but got {expr}"), + } + } + if !sort_exprs.is_empty() { + all_sort_orders.push(sort_exprs); + } + } + Ok(all_sort_orders) +} diff --git a/datafusion/core/src/datasource/physical_plan/arrow_file.rs b/datafusion/core/src/datasource/physical_plan/arrow_file.rs index 30b55db284918..ae1e879d0da1c 100644 --- a/datafusion/core/src/datasource/physical_plan/arrow_file.rs +++ b/datafusion/core/src/datasource/physical_plan/arrow_file.rs @@ -93,10 +93,6 @@ impl ExecutionPlan for ArrowExec { Partitioning::UnknownPartitioning(self.base_config.file_groups.len()) } - fn unbounded_output(&self, _: &[bool]) -> Result { - Ok(self.base_config().infinite_source) - } - fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> { self.projected_output_ordering .first() diff --git a/datafusion/core/src/datasource/physical_plan/avro.rs b/datafusion/core/src/datasource/physical_plan/avro.rs index b97f162fd2f5c..e448bf39f4272 100644 --- 
a/datafusion/core/src/datasource/physical_plan/avro.rs +++ b/datafusion/core/src/datasource/physical_plan/avro.rs @@ -89,10 +89,6 @@ impl ExecutionPlan for AvroExec { Partitioning::UnknownPartitioning(self.base_config.file_groups.len()) } - fn unbounded_output(&self, _: &[bool]) -> Result { - Ok(self.base_config().infinite_source) - } - fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> { self.projected_output_ordering .first() @@ -276,7 +272,6 @@ mod tests { limit: None, table_partition_cols: vec![], output_ordering: vec![], - infinite_source: false, }); assert_eq!(avro_exec.output_partitioning().partition_count(), 1); let mut results = avro_exec @@ -348,7 +343,6 @@ mod tests { limit: None, table_partition_cols: vec![], output_ordering: vec![], - infinite_source: false, }); assert_eq!(avro_exec.output_partitioning().partition_count(), 1); @@ -406,8 +400,7 @@ mod tests { .await?; let mut partitioned_file = PartitionedFile::from(meta); - partitioned_file.partition_values = - vec![ScalarValue::Utf8(Some("2021-10-26".to_owned()))]; + partitioned_file.partition_values = vec![ScalarValue::from("2021-10-26")]; let avro_exec = AvroExec::new(FileScanConfig { // select specific columns of the files as well as the partitioning @@ -420,7 +413,6 @@ mod tests { limit: None, table_partition_cols: vec![Field::new("date", DataType::Utf8, false)], output_ordering: vec![], - infinite_source: false, }); assert_eq!(avro_exec.output_partitioning().partition_count(), 1); diff --git a/datafusion/core/src/datasource/physical_plan/csv.rs b/datafusion/core/src/datasource/physical_plan/csv.rs index 75aa343ffbfc6..b28bc7d566882 100644 --- a/datafusion/core/src/datasource/physical_plan/csv.rs +++ b/datafusion/core/src/datasource/physical_plan/csv.rs @@ -19,11 +19,10 @@ use std::any::Any; use std::io::{Read, Seek, SeekFrom}; -use std::ops::Range; use std::sync::Arc; use std::task::Poll; -use super::FileScanConfig; +use super::{calculate_range, FileGroupPartitioner, FileScanConfig, RangeCalculation}; use crate::datasource::file_format::file_compression_type::FileCompressionType; use crate::datasource::listing::{FileRange, ListingTableUrl}; use crate::datasource::physical_plan::file_stream::{ @@ -146,10 +145,6 @@ impl ExecutionPlan for CsvExec { Partitioning::UnknownPartitioning(self.base_config.file_groups.len()) } - fn unbounded_output(&self, _: &[bool]) -> Result { - Ok(self.base_config().infinite_source) - } - /// See comments on `impl ExecutionPlan for ParquetExec`: output order can't be fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> { self.projected_output_ordering @@ -177,7 +172,7 @@ impl ExecutionPlan for CsvExec { } /// Redistribute files across partitions according to their size - /// See comments on `repartition_file_groups()` for more detail. + /// See comments on [`FileGroupPartitioner`] for more detail. /// /// Return `None` if can't get repartitioned(empty/compressed file). 
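The `repartitioned` implementations now delegate to the new `FileGroupPartitioner` (defined in `file_groups.rs` further down in this diff). A hedged sketch of the same call pattern outside an `ExecutionPlan`, with illustrative file names and sizes:

```rust
use datafusion::datasource::listing::PartitionedFile;
use datafusion::datasource::physical_plan::FileGroupPartitioner;

fn example() {
    // One 64 MiB file, split into 4 even byte ranges
    let groups = vec![vec![PartitionedFile::new("a.csv", 64 * 1024 * 1024)]];

    let repartitioned = FileGroupPartitioner::new()
        .with_target_partitions(4)
        .with_repartition_file_min_size(10 * 1024 * 1024)
        .repartition_file_groups(&groups);

    // `None` would mean "no repartitioning possible"; here we expect 4 groups
    assert_eq!(repartitioned.map(|g| g.len()), Some(4));
}
```
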
fn repartitioned( @@ -191,11 +186,11 @@ impl ExecutionPlan for CsvExec { return Ok(None); } - let repartitioned_file_groups_option = FileScanConfig::repartition_file_groups( - self.base_config.file_groups.clone(), - target_partitions, - repartition_file_min_size, - ); + let repartitioned_file_groups_option = FileGroupPartitioner::new() + .with_target_partitions(target_partitions) + .with_preserve_order_within_groups(self.output_ordering().is_some()) + .with_repartition_file_min_size(repartition_file_min_size) + .repartition_file_groups(&self.base_config.file_groups); if let Some(repartitioned_file_groups) = repartitioned_file_groups_option { let mut new_plan = self.clone(); @@ -322,47 +317,6 @@ impl CsvOpener { } } -/// Returns the offset of the first newline in the object store range [start, end), or the end offset if no newline is found. -async fn find_first_newline( - object_store: &Arc, - location: &object_store::path::Path, - start_byte: usize, - end_byte: usize, -) -> Result { - let options = GetOptions { - range: Some(Range { - start: start_byte, - end: end_byte, - }), - ..Default::default() - }; - - let r = object_store.get_opts(location, options).await?; - let mut input = r.into_stream(); - - let mut buffered = Bytes::new(); - let mut index = 0; - - loop { - if buffered.is_empty() { - match input.next().await { - Some(Ok(b)) => buffered = b, - Some(Err(e)) => return Err(e.into()), - None => return Ok(index), - }; - } - - for byte in &buffered { - if *byte == b'\n' { - return Ok(index); - } - index += 1; - } - - buffered.advance(buffered.len()); - } -} - impl FileOpener for CsvOpener { /// Open a partitioned CSV file. /// @@ -412,44 +366,20 @@ impl FileOpener for CsvOpener { ); } + let store = self.config.object_store.clone(); + Ok(Box::pin(async move { - let file_size = file_meta.object_meta.size; // Current partition contains bytes [start_byte, end_byte) (might contain incomplete lines at boundaries) - let range = match file_meta.range { - None => None, - Some(FileRange { start, end }) => { - let (start, end) = (start as usize, end as usize); - // Partition byte range is [start, end), the boundary might be in the middle of - // some line. Need to find out the exact line boundaries. - let start_delta = if start != 0 { - find_first_newline( - &config.object_store, - file_meta.location(), - start - 1, - file_size, - ) - .await? - } else { - 0 - }; - let end_delta = if end != file_size { - find_first_newline( - &config.object_store, - file_meta.location(), - end - 1, - file_size, - ) - .await? 
- } else { - 0 - }; - let range = start + start_delta..end + end_delta; - if range.start == range.end { - return Ok( - futures::stream::poll_fn(move |_| Poll::Ready(None)).boxed() - ); - } - Some(range) + + let calculated_range = calculate_range(&file_meta, &store).await?; + + let range = match calculated_range { + RangeCalculation::Range(None) => None, + RangeCalculation::Range(Some(range)) => Some(range), + RangeCalculation::TerminateEarly => { + return Ok( + futures::stream::poll_fn(move |_| Poll::Ready(None)).boxed() + ) } }; @@ -457,10 +387,8 @@ impl FileOpener for CsvOpener { range, ..Default::default() }; - let result = config - .object_store - .get_opts(file_meta.location(), options) - .await?; + + let result = store.get_opts(file_meta.location(), options).await?; match result.payload { GetResultPayload::File(mut file, _) => { @@ -872,8 +800,7 @@ mod tests { // Add partition columns config.table_partition_cols = vec![Field::new("date", DataType::Utf8, false)]; - config.file_groups[0][0].partition_values = - vec![ScalarValue::Utf8(Some("2021-10-26".to_owned()))]; + config.file_groups[0][0].partition_values = vec![ScalarValue::from("2021-10-26")]; // We should be able to project on the partition column // Which is supposed to be after the file fields diff --git a/datafusion/core/src/datasource/physical_plan/file_groups.rs b/datafusion/core/src/datasource/physical_plan/file_groups.rs new file mode 100644 index 0000000000000..6456bd5c72766 --- /dev/null +++ b/datafusion/core/src/datasource/physical_plan/file_groups.rs @@ -0,0 +1,826 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Logic for managing groups of [`PartitionedFile`]s in DataFusion + +use crate::datasource::listing::{FileRange, PartitionedFile}; +use itertools::Itertools; +use std::cmp::min; +use std::collections::BinaryHeap; +use std::iter::repeat_with; + +/// Repartition input files into `target_partitions` partitions, if total file size exceed +/// `repartition_file_min_size` +/// +/// This partitions evenly by file byte range, and does not have any knowledge +/// of how data is laid out in specific files. The specific `FileOpener` are +/// responsible for the actual partitioning on specific data source type. (e.g. 
+/// the `CsvOpener` will read lines overlap with byte range as well as +/// handle boundaries to ensure all lines will be read exactly once) +/// +/// # Example +/// +/// For example, if there are two files `A` and `B` that we wish to read with 4 +/// partitions (with 4 threads) they will be divided as follows: +/// +/// ```text +/// ┌ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ┐ +/// ┌─────────────────┐ +/// │ │ │ │ +/// │ File A │ +/// │ │ Range: 0-2MB │ │ +/// │ │ +/// │ └─────────────────┘ │ +/// ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ +/// ┌─────────────────┐ ┌ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ┐ +/// │ │ ┌─────────────────┐ +/// │ │ │ │ │ │ +/// │ │ │ File A │ +/// │ │ │ │ Range 2-4MB │ │ +/// │ │ │ │ +/// │ │ │ └─────────────────┘ │ +/// │ File A (7MB) │ ────────▶ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ +/// │ │ ┌ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ┐ +/// │ │ ┌─────────────────┐ +/// │ │ │ │ │ │ +/// │ │ │ File A │ +/// │ │ │ │ Range: 4-6MB │ │ +/// │ │ │ │ +/// │ │ │ └─────────────────┘ │ +/// └─────────────────┘ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ +/// ┌─────────────────┐ ┌ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ┐ +/// │ File B (1MB) │ ┌─────────────────┐ +/// │ │ │ │ File A │ │ +/// └─────────────────┘ │ Range: 6-7MB │ +/// │ └─────────────────┘ │ +/// ┌─────────────────┐ +/// │ │ File B (1MB) │ │ +/// │ │ +/// │ └─────────────────┘ │ +/// ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ +/// +/// If target_partitions = 4, +/// divides into 4 groups +/// ``` +/// +/// # Maintaining Order +/// +/// Within each group files are read sequentially. Thus, if the overall order of +/// tuples must be preserved, multiple files can not be mixed in the same group. +/// +/// In this case, the code will split the largest files evenly into any +/// available empty groups, but the overall distribution may not not be as even +/// as as even as if the order did not need to be preserved. +/// +/// ```text +/// ┌ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ┐ +/// ┌─────────────────┐ +/// │ │ │ │ +/// │ File A │ +/// │ │ Range: 0-2MB │ │ +/// │ │ +/// ┌─────────────────┐ │ └─────────────────┘ │ +/// │ │ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ +/// │ │ ┌ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ┐ +/// │ │ ┌─────────────────┐ +/// │ │ │ │ │ │ +/// │ │ │ File A │ +/// │ │ │ │ Range 2-4MB │ │ +/// │ File A (6MB) │ ────────▶ │ │ +/// │ (ordered) │ │ └─────────────────┘ │ +/// │ │ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ +/// │ │ ┌ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ┐ +/// │ │ ┌─────────────────┐ +/// │ │ │ │ │ │ +/// │ │ │ File A │ +/// │ │ │ │ Range: 4-6MB │ │ +/// └─────────────────┘ │ │ +/// ┌─────────────────┐ │ └─────────────────┘ │ +/// │ File B (1MB) │ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ +/// │ (ordered) │ ┌ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ┐ +/// └─────────────────┘ ┌─────────────────┐ +/// │ │ File B (1MB) │ │ +/// │ │ +/// │ └─────────────────┘ │ +/// ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ +/// +/// If target_partitions = 4, +/// divides into 4 groups +/// ``` +#[derive(Debug, Clone, Copy)] +pub struct FileGroupPartitioner { + /// how many partitions should be created + target_partitions: usize, + /// the minimum size for a file to be repartitioned. + repartition_file_min_size: usize, + /// if the order when reading the files must be preserved + preserve_order_within_groups: bool, +} + +impl Default for FileGroupPartitioner { + fn default() -> Self { + Self::new() + } +} + +impl FileGroupPartitioner { + /// Creates a new [`FileGroupPartitioner`] with default values: + /// 1. `target_partitions = 1` + /// 2. `repartition_file_min_size = 10MB` + /// 3. 
`preserve_order_within_groups = false` + pub fn new() -> Self { + Self { + target_partitions: 1, + repartition_file_min_size: 10 * 1024 * 1024, + preserve_order_within_groups: false, + } + } + + /// Set the target partitions + pub fn with_target_partitions(mut self, target_partitions: usize) -> Self { + self.target_partitions = target_partitions; + self + } + + /// Set the minimum size at which to repartition a file + pub fn with_repartition_file_min_size( + mut self, + repartition_file_min_size: usize, + ) -> Self { + self.repartition_file_min_size = repartition_file_min_size; + self + } + + /// Set whether the order of tuples within a file must be preserved + pub fn with_preserve_order_within_groups( + mut self, + preserve_order_within_groups: bool, + ) -> Self { + self.preserve_order_within_groups = preserve_order_within_groups; + self + } + + /// Repartition input files according to the settings on this [`FileGroupPartitioner`]. + /// + /// If no repartitioning is needed or possible, return `None`. + pub fn repartition_file_groups( + &self, + file_groups: &[Vec], + ) -> Option>> { + if file_groups.is_empty() { + return None; + } + + // Perform redistribution only in case all files should be read from beginning to end + let has_ranges = file_groups.iter().flatten().any(|f| f.range.is_some()); + if has_ranges { + return None; + } + + // special case when order must be preserved + if self.preserve_order_within_groups { + self.repartition_preserving_order(file_groups) + } else { + self.repartition_evenly_by_size(file_groups) + } + } + + /// Evenly repartition files across partitions by size, ignoring any + /// existing grouping / ordering + fn repartition_evenly_by_size( + &self, + file_groups: &[Vec], + ) -> Option>> { + let target_partitions = self.target_partitions; + let repartition_file_min_size = self.repartition_file_min_size; + let flattened_files = file_groups.iter().flatten().collect::>(); + + let total_size = flattened_files + .iter() + .map(|f| f.object_meta.size as i64) + .sum::(); + if total_size < (repartition_file_min_size as i64) || total_size == 0 { + return None; + } + + let target_partition_size = + (total_size as usize + (target_partitions) - 1) / (target_partitions); + + let current_partition_index: usize = 0; + let current_partition_size: usize = 0; + + // Partition byte range evenly for all `PartitionedFile`s + let repartitioned_files = flattened_files + .into_iter() + .scan( + (current_partition_index, current_partition_size), + |state, source_file| { + let mut produced_files = vec![]; + let mut range_start = 0; + while range_start < source_file.object_meta.size { + let range_end = min( + range_start + (target_partition_size - state.1), + source_file.object_meta.size, + ); + + let mut produced_file = source_file.clone(); + produced_file.range = Some(FileRange { + start: range_start as i64, + end: range_end as i64, + }); + produced_files.push((state.0, produced_file)); + + if state.1 + (range_end - range_start) >= target_partition_size { + state.0 += 1; + state.1 = 0; + } else { + state.1 += range_end - range_start; + } + range_start = range_end; + } + Some(produced_files) + }, + ) + .flatten() + .group_by(|(partition_idx, _)| *partition_idx) + .into_iter() + .map(|(_, group)| group.map(|(_, vals)| vals).collect_vec()) + .collect_vec(); + + Some(repartitioned_files) + } + + /// Redistribute file groups across size preserving order + fn repartition_preserving_order( + &self, + file_groups: &[Vec], + ) -> Option>> { + // Can't repartition and preserve order if there 
+        // than partitions
+        if file_groups.len() >= self.target_partitions {
+            return None;
+        }
+        let num_new_groups = self.target_partitions - file_groups.len();
+
+        // If there is only a single file
+        if file_groups.len() == 1 && file_groups[0].len() == 1 {
+            return self.repartition_evenly_by_size(file_groups);
+        }
+
+        // Find which files could be split (single file groups)
+        let mut heap: BinaryHeap<_> = file_groups
+            .iter()
+            .enumerate()
+            .filter_map(|(group_index, group)| {
+                // ignore groups that do not have exactly 1 file
+                if group.len() == 1 {
+                    Some(ToRepartition {
+                        source_index: group_index,
+                        file_size: group[0].object_meta.size,
+                        new_groups: vec![group_index],
+                    })
+                } else {
+                    None
+                }
+            })
+            .collect();
+
+        // No files can be redistributed
+        if heap.is_empty() {
+            return None;
+        }
+
+        // Add new empty groups to which we will redistribute ranges of existing files
+        let mut file_groups: Vec<_> = file_groups
+            .iter()
+            .cloned()
+            .chain(repeat_with(Vec::new).take(num_new_groups))
+            .collect();
+
+        // Divide up empty groups
+        for (group_index, group) in file_groups.iter().enumerate() {
+            if !group.is_empty() {
+                continue;
+            }
+            // Pick the file that has the largest ranges to read so far
+            let mut largest_group = heap.pop().unwrap();
+            largest_group.new_groups.push(group_index);
+            heap.push(largest_group);
+        }
+
+        // Distribute files to their newly assigned groups
+        while let Some(to_repartition) = heap.pop() {
+            let range_size = to_repartition.range_size() as i64;
+            let ToRepartition {
+                source_index,
+                file_size,
+                new_groups,
+            } = to_repartition;
+            assert_eq!(file_groups[source_index].len(), 1);
+            let original_file = file_groups[source_index].pop().unwrap();
+
+            let last_group = new_groups.len() - 1;
+            let mut range_start: i64 = 0;
+            let mut range_end: i64 = range_size;
+            for (i, group_index) in new_groups.into_iter().enumerate() {
+                let target_group = &mut file_groups[group_index];
+                assert!(target_group.is_empty());
+
+                // adjust last range to include the entire file
+                if i == last_group {
+                    range_end = file_size as i64;
+                }
+                target_group
+                    .push(original_file.clone().with_range(range_start, range_end));
+                range_start = range_end;
+                range_end += range_size;
+            }
+        }
+
+        Some(file_groups)
+    }
+}
+
+/// Tracks how an individual file will be repartitioned
+#[derive(Debug, Clone, PartialEq, Eq)]
+struct ToRepartition {
+    /// the index from which the original file will be taken
+    source_index: usize,
+    /// the size of the original file
+    file_size: usize,
+    /// indexes of which group(s) this file will be distributed to (including `source_index`)
+    new_groups: Vec<usize>,
+}
+
+impl ToRepartition {
+    /// How big will each file range be when this file is read in its new groups?
+ fn range_size(&self) -> usize { + self.file_size / self.new_groups.len() + } +} + +impl PartialOrd for ToRepartition { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +/// Order based on individual range +impl Ord for ToRepartition { + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + self.range_size().cmp(&other.range_size()) + } +} + +#[cfg(test)] +mod test { + use super::*; + + /// Empty file won't get partitioned + #[test] + fn repartition_empty_file_only() { + let partitioned_file_empty = pfile("empty", 0); + let file_group = vec![vec![partitioned_file_empty.clone()]]; + + let partitioned_files = FileGroupPartitioner::new() + .with_target_partitions(4) + .with_repartition_file_min_size(0) + .repartition_file_groups(&file_group); + + assert_partitioned_files(None, partitioned_files); + } + + /// Repartition when there is a empty file in file groups + #[test] + fn repartition_empty_files() { + let pfile_a = pfile("a", 10); + let pfile_b = pfile("b", 10); + let pfile_empty = pfile("empty", 0); + + let empty_first = vec![ + vec![pfile_empty.clone()], + vec![pfile_a.clone()], + vec![pfile_b.clone()], + ]; + let empty_middle = vec![ + vec![pfile_a.clone()], + vec![pfile_empty.clone()], + vec![pfile_b.clone()], + ]; + let empty_last = vec![vec![pfile_a], vec![pfile_b], vec![pfile_empty]]; + + // Repartition file groups into x partitions + let expected_2 = vec![ + vec![pfile("a", 10).with_range(0, 10)], + vec![pfile("b", 10).with_range(0, 10)], + ]; + let expected_3 = vec![ + vec![pfile("a", 10).with_range(0, 7)], + vec![ + pfile("a", 10).with_range(7, 10), + pfile("b", 10).with_range(0, 4), + ], + vec![pfile("b", 10).with_range(4, 10)], + ]; + + let file_groups_tests = [empty_first, empty_middle, empty_last]; + + for fg in file_groups_tests { + let all_expected = [(2, expected_2.clone()), (3, expected_3.clone())]; + for (n_partition, expected) in all_expected { + let actual = FileGroupPartitioner::new() + .with_target_partitions(n_partition) + .with_repartition_file_min_size(10) + .repartition_file_groups(&fg); + + assert_partitioned_files(Some(expected), actual); + } + } + } + + #[test] + fn repartition_single_file() { + // Single file, single partition into multiple partitions + let single_partition = vec![vec![pfile("a", 123)]]; + + let actual = FileGroupPartitioner::new() + .with_target_partitions(4) + .with_repartition_file_min_size(10) + .repartition_file_groups(&single_partition); + + let expected = Some(vec![ + vec![pfile("a", 123).with_range(0, 31)], + vec![pfile("a", 123).with_range(31, 62)], + vec![pfile("a", 123).with_range(62, 93)], + vec![pfile("a", 123).with_range(93, 123)], + ]); + assert_partitioned_files(expected, actual); + } + + #[test] + fn repartition_too_much_partitions() { + // Single file, single partition into 96 partitions + let partitioned_file = pfile("a", 8); + let single_partition = vec![vec![partitioned_file]]; + + let actual = FileGroupPartitioner::new() + .with_target_partitions(96) + .with_repartition_file_min_size(5) + .repartition_file_groups(&single_partition); + + let expected = Some(vec![ + vec![pfile("a", 8).with_range(0, 1)], + vec![pfile("a", 8).with_range(1, 2)], + vec![pfile("a", 8).with_range(2, 3)], + vec![pfile("a", 8).with_range(3, 4)], + vec![pfile("a", 8).with_range(4, 5)], + vec![pfile("a", 8).with_range(5, 6)], + vec![pfile("a", 8).with_range(6, 7)], + vec![pfile("a", 8).with_range(7, 8)], + ]); + + assert_partitioned_files(expected, actual); + } + + #[test] + fn repartition_multiple_partitions() { 
+ // Multiple files in single partition after redistribution + let source_partitions = vec![vec![pfile("a", 40)], vec![pfile("b", 60)]]; + + let actual = FileGroupPartitioner::new() + .with_target_partitions(3) + .with_repartition_file_min_size(10) + .repartition_file_groups(&source_partitions); + + let expected = Some(vec![ + vec![pfile("a", 40).with_range(0, 34)], + vec![ + pfile("a", 40).with_range(34, 40), + pfile("b", 60).with_range(0, 28), + ], + vec![pfile("b", 60).with_range(28, 60)], + ]); + assert_partitioned_files(expected, actual); + } + + #[test] + fn repartition_same_num_partitions() { + // "Rebalance" files across partitions + let source_partitions = vec![vec![pfile("a", 40)], vec![pfile("b", 60)]]; + + let actual = FileGroupPartitioner::new() + .with_target_partitions(2) + .with_repartition_file_min_size(10) + .repartition_file_groups(&source_partitions); + + let expected = Some(vec![ + vec![ + pfile("a", 40).with_range(0, 40), + pfile("b", 60).with_range(0, 10), + ], + vec![pfile("b", 60).with_range(10, 60)], + ]); + assert_partitioned_files(expected, actual); + } + + #[test] + fn repartition_no_action_ranges() { + // No action due to Some(range) in second file + let source_partitions = vec![ + vec![pfile("a", 123)], + vec![pfile("b", 144).with_range(1, 50)], + ]; + + let actual = FileGroupPartitioner::new() + .with_target_partitions(65) + .with_repartition_file_min_size(10) + .repartition_file_groups(&source_partitions); + + assert_partitioned_files(None, actual) + } + + #[test] + fn repartition_no_action_min_size() { + // No action due to target_partition_size + let single_partition = vec![vec![pfile("a", 123)]]; + + let actual = FileGroupPartitioner::new() + .with_target_partitions(65) + .with_repartition_file_min_size(500) + .repartition_file_groups(&single_partition); + + assert_partitioned_files(None, actual) + } + + #[test] + fn repartition_no_action_zero_files() { + // No action due to no files + let empty_partition = vec![]; + + let partitioner = FileGroupPartitioner::new() + .with_target_partitions(65) + .with_repartition_file_min_size(500); + + assert_partitioned_files(None, repartition_test(partitioner, empty_partition)) + } + + #[test] + fn repartition_ordered_no_action_too_few_partitions() { + // No action as there are no new groups to redistribute to + let input_partitions = vec![vec![pfile("a", 100)], vec![pfile("b", 200)]]; + + let actual = FileGroupPartitioner::new() + .with_preserve_order_within_groups(true) + .with_target_partitions(2) + .with_repartition_file_min_size(10) + .repartition_file_groups(&input_partitions); + + assert_partitioned_files(None, actual) + } + + #[test] + fn repartition_ordered_no_action_file_too_small() { + // No action as there are no new groups to redistribute to + let single_partition = vec![vec![pfile("a", 100)]]; + + let actual = FileGroupPartitioner::new() + .with_preserve_order_within_groups(true) + .with_target_partitions(2) + // file is too small to repartition + .with_repartition_file_min_size(1000) + .repartition_file_groups(&single_partition); + + assert_partitioned_files(None, actual) + } + + #[test] + fn repartition_ordered_one_large_file() { + // "Rebalance" the single large file across partitions + let source_partitions = vec![vec![pfile("a", 100)]]; + + let actual = FileGroupPartitioner::new() + .with_preserve_order_within_groups(true) + .with_target_partitions(3) + .with_repartition_file_min_size(10) + .repartition_file_groups(&source_partitions); + + let expected = Some(vec![ + vec![pfile("a", 
100).with_range(0, 34)], + vec![pfile("a", 100).with_range(34, 68)], + vec![pfile("a", 100).with_range(68, 100)], + ]); + assert_partitioned_files(expected, actual); + } + + #[test] + fn repartition_ordered_one_large_one_small_file() { + // "Rebalance" the single large file across empty partitions, but can't split + // small file + let source_partitions = vec![vec![pfile("a", 100)], vec![pfile("b", 30)]]; + + let actual = FileGroupPartitioner::new() + .with_preserve_order_within_groups(true) + .with_target_partitions(4) + .with_repartition_file_min_size(10) + .repartition_file_groups(&source_partitions); + + let expected = Some(vec![ + // scan first third of "a" + vec![pfile("a", 100).with_range(0, 33)], + // only b in this group (can't do this) + vec![pfile("b", 30).with_range(0, 30)], + // second third of "a" + vec![pfile("a", 100).with_range(33, 66)], + // final third of "a" + vec![pfile("a", 100).with_range(66, 100)], + ]); + assert_partitioned_files(expected, actual); + } + + #[test] + fn repartition_ordered_two_large_files() { + // "Rebalance" two large files across empty partitions, but can't mix them + let source_partitions = vec![vec![pfile("a", 100)], vec![pfile("b", 100)]]; + + let actual = FileGroupPartitioner::new() + .with_preserve_order_within_groups(true) + .with_target_partitions(4) + .with_repartition_file_min_size(10) + .repartition_file_groups(&source_partitions); + + let expected = Some(vec![ + // scan first half of "a" + vec![pfile("a", 100).with_range(0, 50)], + // scan first half of "b" + vec![pfile("b", 100).with_range(0, 50)], + // second half of "a" + vec![pfile("a", 100).with_range(50, 100)], + // second half of "b" + vec![pfile("b", 100).with_range(50, 100)], + ]); + assert_partitioned_files(expected, actual); + } + + #[test] + fn repartition_ordered_two_large_one_small_files() { + // "Rebalance" two large files and one small file across empty partitions + let source_partitions = vec![ + vec![pfile("a", 100)], + vec![pfile("b", 100)], + vec![pfile("c", 30)], + ]; + + let partitioner = FileGroupPartitioner::new() + .with_preserve_order_within_groups(true) + .with_repartition_file_min_size(10); + + // with 4 partitions, can only split the first large file "a" + let actual = partitioner + .with_target_partitions(4) + .repartition_file_groups(&source_partitions); + + let expected = Some(vec![ + // scan first half of "a" + vec![pfile("a", 100).with_range(0, 50)], + // All of "b" + vec![pfile("b", 100).with_range(0, 100)], + // All of "c" + vec![pfile("c", 30).with_range(0, 30)], + // second half of "a" + vec![pfile("a", 100).with_range(50, 100)], + ]); + assert_partitioned_files(expected, actual); + + // With 5 partitions, we can split both "a" and "b", but they can't be intermixed + let actual = partitioner + .with_target_partitions(5) + .repartition_file_groups(&source_partitions); + + let expected = Some(vec![ + // scan first half of "a" + vec![pfile("a", 100).with_range(0, 50)], + // scan first half of "b" + vec![pfile("b", 100).with_range(0, 50)], + // All of "c" + vec![pfile("c", 30).with_range(0, 30)], + // second half of "a" + vec![pfile("a", 100).with_range(50, 100)], + // second half of "b" + vec![pfile("b", 100).with_range(50, 100)], + ]); + assert_partitioned_files(expected, actual); + } + + #[test] + fn repartition_ordered_one_large_one_small_existing_empty() { + // "Rebalance" files using existing empty partition + let source_partitions = + vec![vec![pfile("a", 100)], vec![], vec![pfile("b", 40)], vec![]]; + + let actual = FileGroupPartitioner::new() + 
.with_preserve_order_within_groups(true) + .with_target_partitions(5) + .with_repartition_file_min_size(10) + .repartition_file_groups(&source_partitions); + + // Of the three available groups (2 original empty and 1 new from the + // target partitions), assign two to "a" and one to "b" + let expected = Some(vec![ + // Scan of "a" across three groups + vec![pfile("a", 100).with_range(0, 33)], + vec![pfile("a", 100).with_range(33, 66)], + // scan first half of "b" + vec![pfile("b", 40).with_range(0, 20)], + // final third of "a" + vec![pfile("a", 100).with_range(66, 100)], + // second half of "b" + vec![pfile("b", 40).with_range(20, 40)], + ]); + assert_partitioned_files(expected, actual); + } + #[test] + fn repartition_ordered_existing_group_multiple_files() { + // groups with multiple files in a group can not be changed, but can divide others + let source_partitions = vec![ + // two files in an existing partition + vec![pfile("a", 100), pfile("b", 100)], + vec![pfile("c", 40)], + ]; + + let actual = FileGroupPartitioner::new() + .with_preserve_order_within_groups(true) + .with_target_partitions(3) + .with_repartition_file_min_size(10) + .repartition_file_groups(&source_partitions); + + // Of the three available groups (2 original empty and 1 new from the + // target partitions), assign two to "a" and one to "b" + let expected = Some(vec![ + // don't try and rearrange files in the existing partition + // assuming that the caller had a good reason to put them that way. + // (it is technically possible to split off ranges from the files if desired) + vec![pfile("a", 100), pfile("b", 100)], + // first half of "c" + vec![pfile("c", 40).with_range(0, 20)], + // second half of "c" + vec![pfile("c", 40).with_range(20, 40)], + ]); + assert_partitioned_files(expected, actual); + } + + /// Asserts that the two groups of `ParititonedFile` are the same + /// (PartitionedFile doesn't implement PartialEq) + fn assert_partitioned_files( + expected: Option>>, + actual: Option>>, + ) { + match (expected, actual) { + (None, None) => {} + (Some(_), None) => panic!("Expected Some, got None"), + (None, Some(_)) => panic!("Expected None, got Some"), + (Some(expected), Some(actual)) => { + let expected_string = format!("{:#?}", expected); + let actual_string = format!("{:#?}", actual); + assert_eq!(expected_string, actual_string); + } + } + } + + /// returns a partitioned file with the specified path and size + fn pfile(path: impl Into, file_size: u64) -> PartitionedFile { + PartitionedFile::new(path, file_size) + } + + /// repartition the file groups both with and without preserving order + /// asserting they return the same value and returns that value + fn repartition_test( + partitioner: FileGroupPartitioner, + file_groups: Vec>, + ) -> Option>> { + let repartitioned = partitioner.repartition_file_groups(&file_groups); + + let repartitioned_preserving_sort = partitioner + .with_preserve_order_within_groups(true) + .repartition_file_groups(&file_groups); + + assert_partitioned_files( + repartitioned.clone(), + repartitioned_preserving_sort.clone(), + ); + repartitioned + } +} diff --git a/datafusion/core/src/datasource/physical_plan/file_scan_config.rs b/datafusion/core/src/datasource/physical_plan/file_scan_config.rs index 3efb0df9df7cc..516755e4d2939 100644 --- a/datafusion/core/src/datasource/physical_plan/file_scan_config.rs +++ b/datafusion/core/src/datasource/physical_plan/file_scan_config.rs @@ -19,15 +19,11 @@ //! file sources. 
use std::{ - borrow::Cow, cmp::min, collections::HashMap, fmt::Debug, marker::PhantomData, - sync::Arc, vec, + borrow::Cow, collections::HashMap, fmt::Debug, marker::PhantomData, sync::Arc, vec, }; -use super::get_projected_output_ordering; -use crate::datasource::{ - listing::{FileRange, PartitionedFile}, - object_store::ObjectStoreUrl, -}; +use super::{get_projected_output_ordering, FileGroupPartitioner}; +use crate::datasource::{listing::PartitionedFile, object_store::ObjectStoreUrl}; use crate::{ error::{DataFusionError, Result}, scalar::ScalarValue, @@ -42,7 +38,6 @@ use datafusion_common::stats::Precision; use datafusion_common::{exec_err, ColumnStatistics, Statistics}; use datafusion_physical_expr::LexOrdering; -use itertools::Itertools; use log::warn; /// Convert type to a type suitable for use as a [`ListingTable`] @@ -104,8 +99,6 @@ pub struct FileScanConfig { pub table_partition_cols: Vec, /// All equivalent lexicographical orderings that describe the schema. pub output_ordering: Vec, - /// Indicates whether this plan may produce an infinite stream of records. - pub infinite_source: bool, } impl FileScanConfig { @@ -176,79 +169,17 @@ impl FileScanConfig { }) } - /// Repartition all input files into `target_partitions` partitions, if total file size exceed - /// `repartition_file_min_size` - /// `target_partitions` and `repartition_file_min_size` directly come from configuration. - /// - /// This function only try to partition file byte range evenly, and let specific `FileOpener` to - /// do actual partition on specific data source type. (e.g. `CsvOpener` will only read lines - /// overlap with byte range but also handle boundaries to ensure all lines will be read exactly once) + #[allow(missing_docs)] + #[deprecated(since = "33.0.0", note = "Use SessionContext::new_with_config")] pub fn repartition_file_groups( file_groups: Vec>, target_partitions: usize, repartition_file_min_size: usize, ) -> Option>> { - let flattened_files = file_groups.iter().flatten().collect::>(); - - // Perform redistribution only in case all files should be read from beginning to end - let has_ranges = flattened_files.iter().any(|f| f.range.is_some()); - if has_ranges { - return None; - } - - let total_size = flattened_files - .iter() - .map(|f| f.object_meta.size as i64) - .sum::(); - if total_size < (repartition_file_min_size as i64) || total_size == 0 { - return None; - } - - let target_partition_size = - (total_size as usize + (target_partitions) - 1) / (target_partitions); - - let current_partition_index: usize = 0; - let current_partition_size: usize = 0; - - // Partition byte range evenly for all `PartitionedFile`s - let repartitioned_files = flattened_files - .into_iter() - .scan( - (current_partition_index, current_partition_size), - |state, source_file| { - let mut produced_files = vec![]; - let mut range_start = 0; - while range_start < source_file.object_meta.size { - let range_end = min( - range_start + (target_partition_size - state.1), - source_file.object_meta.size, - ); - - let mut produced_file = source_file.clone(); - produced_file.range = Some(FileRange { - start: range_start as i64, - end: range_end as i64, - }); - produced_files.push((state.0, produced_file)); - - if state.1 + (range_end - range_start) >= target_partition_size { - state.0 += 1; - state.1 = 0; - } else { - state.1 += range_end - range_start; - } - range_start = range_end; - } - Some(produced_files) - }, - ) - .flatten() - .group_by(|(partition_idx, _)| *partition_idx) - .into_iter() - .map(|(_, group)| group.map(|(_, 
vals)| vals).collect_vec()) - .collect_vec(); - - Some(repartitioned_files) + FileGroupPartitioner::new() + .with_target_partitions(target_partitions) + .with_repartition_file_min_size(repartition_file_min_size) + .repartition_file_groups(&file_groups) } } @@ -336,7 +267,7 @@ impl PartitionColumnProjector { &mut self.key_buffer_cache, partition_value.as_ref(), file_batch.num_rows(), - ), + )?, ) } @@ -396,11 +327,11 @@ fn create_dict_array( dict_val: &ScalarValue, len: usize, data_type: DataType, -) -> ArrayRef +) -> Result where T: ArrowNativeType, { - let dict_vals = dict_val.to_array(); + let dict_vals = dict_val.to_array()?; let sliced_key_buffer = buffer_gen.get_buffer(len); @@ -409,16 +340,16 @@ where .len(len) .add_buffer(sliced_key_buffer); builder = builder.add_child_data(dict_vals.to_data()); - Arc::new(DictionaryArray::::from( + Ok(Arc::new(DictionaryArray::::from( builder.build().unwrap(), - )) + ))) } fn create_output_array( key_buffer_cache: &mut ZeroBufferGenerators, val: &ScalarValue, len: usize, -) -> ArrayRef { +) -> Result { if let ScalarValue::Dictionary(key_type, dict_val) = &val { match key_type.as_ref() { DataType::Int8 => { @@ -654,15 +585,9 @@ mod tests { // file_batch is ok here because we kept all the file cols in the projection file_batch, &[ - wrap_partition_value_in_dict(ScalarValue::Utf8(Some( - "2021".to_owned(), - ))), - wrap_partition_value_in_dict(ScalarValue::Utf8(Some( - "10".to_owned(), - ))), - wrap_partition_value_in_dict(ScalarValue::Utf8(Some( - "26".to_owned(), - ))), + wrap_partition_value_in_dict(ScalarValue::from("2021")), + wrap_partition_value_in_dict(ScalarValue::from("10")), + wrap_partition_value_in_dict(ScalarValue::from("26")), ], ) .expect("Projection of partition columns into record batch failed"); @@ -688,15 +613,9 @@ mod tests { // file_batch is ok here because we kept all the file cols in the projection file_batch, &[ - wrap_partition_value_in_dict(ScalarValue::Utf8(Some( - "2021".to_owned(), - ))), - wrap_partition_value_in_dict(ScalarValue::Utf8(Some( - "10".to_owned(), - ))), - wrap_partition_value_in_dict(ScalarValue::Utf8(Some( - "27".to_owned(), - ))), + wrap_partition_value_in_dict(ScalarValue::from("2021")), + wrap_partition_value_in_dict(ScalarValue::from("10")), + wrap_partition_value_in_dict(ScalarValue::from("27")), ], ) .expect("Projection of partition columns into record batch failed"); @@ -724,15 +643,9 @@ mod tests { // file_batch is ok here because we kept all the file cols in the projection file_batch, &[ - wrap_partition_value_in_dict(ScalarValue::Utf8(Some( - "2021".to_owned(), - ))), - wrap_partition_value_in_dict(ScalarValue::Utf8(Some( - "10".to_owned(), - ))), - wrap_partition_value_in_dict(ScalarValue::Utf8(Some( - "28".to_owned(), - ))), + wrap_partition_value_in_dict(ScalarValue::from("2021")), + wrap_partition_value_in_dict(ScalarValue::from("10")), + wrap_partition_value_in_dict(ScalarValue::from("28")), ], ) .expect("Projection of partition columns into record batch failed"); @@ -758,9 +671,9 @@ mod tests { // file_batch is ok here because we kept all the file cols in the projection file_batch, &[ - ScalarValue::Utf8(Some("2021".to_owned())), - ScalarValue::Utf8(Some("10".to_owned())), - ScalarValue::Utf8(Some("26".to_owned())), + ScalarValue::from("2021"), + ScalarValue::from("10"), + ScalarValue::from("26"), ], ) .expect("Projection of partition columns into record batch failed"); @@ -792,7 +705,6 @@ mod tests { statistics, table_partition_cols, output_ordering: vec![], - infinite_source: false, } } 
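
The deprecation above moves callers from the old associated function to the builder-style `FileGroupPartitioner`. The sketch below shows what that migration might look like; it assumes the re-exported path introduced elsewhere in this PR (`datafusion::datasource::physical_plan::FileGroupPartitioner`), and the file names, sizes, and partition counts are invented for illustration, with the expected group counts mirroring the unit tests added earlier in this diff.

```rust
use datafusion::datasource::listing::PartitionedFile;
use datafusion::datasource::physical_plan::FileGroupPartitioner;

fn main() {
    // Two single-file groups of 40 and 60 bytes (illustrative sizes only).
    let file_groups = vec![
        vec![PartitionedFile::new("a".to_string(), 40)],
        vec![PartitionedFile::new("b".to_string(), 60)],
    ];

    // Previously: FileScanConfig::repartition_file_groups(file_groups.clone(), 3, 10)
    // Now the same request is expressed with the builder:
    let repartitioned = FileGroupPartitioner::new()
        .with_target_partitions(3)
        .with_repartition_file_min_size(10)
        .repartition_file_groups(&file_groups);

    // `None` would mean no repartitioning was needed or possible; here the two
    // files are split by byte range into three groups.
    assert_eq!(repartitioned.map(|groups| groups.len()), Some(3));

    // When ordering must be preserved, ranges of different files are not mixed
    // within a group; only the largest single-file groups are split further.
    let ordered = FileGroupPartitioner::new()
        .with_target_partitions(3)
        .with_repartition_file_min_size(10)
        .with_preserve_order_within_groups(true)
        .repartition_file_groups(&file_groups);
    assert_eq!(ordered.map(|groups| groups.len()), Some(3));
}
```

Expressing the knobs as builder methods is presumably what lets `preserve_order_within_groups` be added without growing the argument list again, which would explain why the old three-argument function is deprecated rather than extended.
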
diff --git a/datafusion/core/src/datasource/physical_plan/file_stream.rs b/datafusion/core/src/datasource/physical_plan/file_stream.rs index a715f6e8e3cde..bb4c8313642cd 100644 --- a/datafusion/core/src/datasource/physical_plan/file_stream.rs +++ b/datafusion/core/src/datasource/physical_plan/file_stream.rs @@ -518,10 +518,8 @@ impl RecordBatchStream for FileStream { #[cfg(test)] mod tests { - use arrow_schema::Schema; - use datafusion_common::internal_err; - use datafusion_common::DataFusionError; - use datafusion_common::Statistics; + use std::sync::atomic::{AtomicUsize, Ordering}; + use std::sync::Arc; use super::*; use crate::datasource::file_format::write::BatchSerializer; @@ -534,8 +532,8 @@ mod tests { test::{make_partition, object_store::register_test_store}, }; - use std::sync::atomic::{AtomicUsize, Ordering}; - use std::sync::Arc; + use arrow_schema::Schema; + use datafusion_common::{internal_err, DataFusionError, Statistics}; use async_trait::async_trait; use bytes::Bytes; @@ -667,7 +665,6 @@ mod tests { limit: self.limit, table_partition_cols: vec![], output_ordering: vec![], - infinite_source: false, }; let metrics_set = ExecutionPlanMetricsSet::new(); let file_stream = FileStream::new(&config, 0, self.opener, &metrics_set) @@ -994,7 +991,7 @@ mod tests { #[async_trait] impl BatchSerializer for TestSerializer { - async fn serialize(&mut self, _batch: RecordBatch) -> Result { + async fn serialize(&self, _batch: RecordBatch, _initial: bool) -> Result { Ok(self.bytes.clone()) } } diff --git a/datafusion/core/src/datasource/physical_plan/json.rs b/datafusion/core/src/datasource/physical_plan/json.rs index 73dcb32ac81f7..529632dab85a8 100644 --- a/datafusion/core/src/datasource/physical_plan/json.rs +++ b/datafusion/core/src/datasource/physical_plan/json.rs @@ -18,11 +18,11 @@ //! 
Execution plan for reading line-delimited JSON files use std::any::Any; -use std::io::BufReader; +use std::io::{BufReader, Read, Seek, SeekFrom}; use std::sync::Arc; use std::task::Poll; -use super::FileScanConfig; +use super::{calculate_range, FileGroupPartitioner, FileScanConfig, RangeCalculation}; use crate::datasource::file_format::file_compression_type::FileCompressionType; use crate::datasource::listing::ListingTableUrl; use crate::datasource::physical_plan::file_stream::{ @@ -43,8 +43,8 @@ use datafusion_execution::TaskContext; use datafusion_physical_expr::{EquivalenceProperties, LexOrdering}; use bytes::{Buf, Bytes}; -use futures::{ready, stream, StreamExt, TryStreamExt}; -use object_store; +use futures::{ready, StreamExt, TryStreamExt}; +use object_store::{self, GetOptions}; use object_store::{GetResultPayload, ObjectStore}; use tokio::io::AsyncWriteExt; use tokio::task::JoinSet; @@ -110,10 +110,6 @@ impl ExecutionPlan for NdJsonExec { Partitioning::UnknownPartitioning(self.base_config.file_groups.len()) } - fn unbounded_output(&self, _: &[bool]) -> Result { - Ok(self.base_config.infinite_source) - } - fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> { self.projected_output_ordering .first() @@ -138,6 +134,30 @@ impl ExecutionPlan for NdJsonExec { Ok(self) } + fn repartitioned( + &self, + target_partitions: usize, + config: &datafusion_common::config::ConfigOptions, + ) -> Result>> { + let repartition_file_min_size = config.optimizer.repartition_file_min_size; + let preserve_order_within_groups = self.output_ordering().is_some(); + let file_groups = &self.base_config.file_groups; + + let repartitioned_file_groups_option = FileGroupPartitioner::new() + .with_target_partitions(target_partitions) + .with_preserve_order_within_groups(preserve_order_within_groups) + .with_repartition_file_min_size(repartition_file_min_size) + .repartition_file_groups(file_groups); + + if let Some(repartitioned_file_groups) = repartitioned_file_groups_option { + let mut new_plan = self.clone(); + new_plan.base_config.file_groups = repartitioned_file_groups; + return Ok(Some(Arc::new(new_plan))); + } + + Ok(None) + } + fn execute( &self, partition: usize, @@ -197,54 +217,89 @@ impl JsonOpener { } impl FileOpener for JsonOpener { + /// Open a partitioned NDJSON file. + /// + /// If `file_meta.range` is `None`, the entire file is opened. + /// Else `file_meta.range` is `Some(FileRange{start, end})`, which corresponds to the byte range [start, end) within the file. + /// + /// Note: `start` or `end` might be in the middle of some lines. In such cases, the following rules + /// are applied to determine which lines to read: + /// 1. The first line of the partition is the line in which the index of the first character >= `start`. + /// 2. The last line of the partition is the line in which the byte at position `end - 1` resides. + /// + /// See [`CsvOpener`](super::CsvOpener) for an example. 
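
Before the `open` implementation below, it may help to see the two boundary rules in isolation. This is a self-contained sketch rather than the DataFusion code: it applies the same "skip to the first newline at or after `start - 1`" adjustment to an in-memory buffer, whereas the real opener issues ranged `GetOptions` requests against the object store. The `adjust_range` helper and the sample data are hypothetical.

```rust
/// Given the full file contents and a byte range `[start, end)`, return the
/// byte range a partition would actually read under the rules above.
fn adjust_range(data: &[u8], start: usize, end: usize) -> Option<std::ops::Range<usize>> {
    // Position just past the first newline at or after `pos`, or the end of
    // the data if there is none.
    let after_next_newline = |pos: usize| -> usize {
        data[pos..]
            .iter()
            .position(|&b| b == b'\n')
            .map(|i| pos + i + 1)
            .unwrap_or(data.len())
    };

    let new_start = if start == 0 { 0 } else { after_next_newline(start - 1) };
    let new_end = if end == data.len() { data.len() } else { after_next_newline(end - 1) };

    // An empty adjusted range means this partition has no complete line to read.
    (new_start != new_end).then(|| new_start..new_end)
}

fn main() {
    let data = b"{\"a\":1}\n{\"a\":2}\n{\"a\":3}\n";
    // A naive split at byte 10 lands in the middle of the second line:
    // the first partition finishes that line, the second skips past it.
    assert_eq!(adjust_range(data, 0, 10), Some(0..16));
    assert_eq!(adjust_range(data, 10, 24), Some(16..24));
}
```

A line that straddles a partition boundary is therefore read entirely by the earlier partition, and the later partition starts at the next full line, so every line is read exactly once.
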
fn open(&self, file_meta: FileMeta) -> Result { let store = self.object_store.clone(); let schema = self.projected_schema.clone(); let batch_size = self.batch_size; - let file_compression_type = self.file_compression_type.to_owned(); + Ok(Box::pin(async move { - let r = store.get(file_meta.location()).await?; - match r.payload { - GetResultPayload::File(file, _) => { - let bytes = file_compression_type.convert_read(file)?; + let calculated_range = calculate_range(&file_meta, &store).await?; + + let range = match calculated_range { + RangeCalculation::Range(None) => None, + RangeCalculation::Range(Some(range)) => Some(range), + RangeCalculation::TerminateEarly => { + return Ok( + futures::stream::poll_fn(move |_| Poll::Ready(None)).boxed() + ) + } + }; + + let options = GetOptions { + range, + ..Default::default() + }; + + let result = store.get_opts(file_meta.location(), options).await?; + + match result.payload { + GetResultPayload::File(mut file, _) => { + let bytes = match file_meta.range { + None => file_compression_type.convert_read(file)?, + Some(_) => { + file.seek(SeekFrom::Start(result.range.start as _))?; + let limit = result.range.end - result.range.start; + file_compression_type.convert_read(file.take(limit as u64))? + } + }; + let reader = ReaderBuilder::new(schema) .with_batch_size(batch_size) .build(BufReader::new(bytes))?; + Ok(futures::stream::iter(reader).boxed()) } GetResultPayload::Stream(s) => { + let s = s.map_err(DataFusionError::from); + let mut decoder = ReaderBuilder::new(schema) .with_batch_size(batch_size) .build_decoder()?; - - let s = s.map_err(DataFusionError::from); let mut input = file_compression_type.convert_stream(s.boxed())?.fuse(); - let mut buffered = Bytes::new(); + let mut buffer = Bytes::new(); - let s = stream::poll_fn(move |cx| { + let s = futures::stream::poll_fn(move |cx| { loop { - if buffered.is_empty() { - buffered = match ready!(input.poll_next_unpin(cx)) { - Some(Ok(b)) => b, + if buffer.is_empty() { + match ready!(input.poll_next_unpin(cx)) { + Some(Ok(b)) => buffer = b, Some(Err(e)) => { return Poll::Ready(Some(Err(e.into()))) } - None => break, + None => {} }; } - let read = buffered.len(); - let decoded = match decoder.decode(buffered.as_ref()) { + let decoded = match decoder.decode(buffer.as_ref()) { + Ok(0) => break, Ok(decoded) => decoded, Err(e) => return Poll::Ready(Some(Err(e))), }; - buffered.advance(decoded); - if decoded != read { - break; - } + buffer.advance(decoded); } Poll::Ready(decoder.flush().transpose()) @@ -357,9 +412,9 @@ mod tests { ) .unwrap(); let meta = file_groups - .get(0) + .first() .unwrap() - .get(0) + .first() .unwrap() .clone() .object_meta; @@ -391,9 +446,9 @@ mod tests { ) .unwrap(); let path = file_groups - .get(0) + .first() .unwrap() - .get(0) + .first() .unwrap() .object_meta .location @@ -462,7 +517,6 @@ mod tests { limit: Some(3), table_partition_cols: vec![], output_ordering: vec![], - infinite_source: false, }, file_compression_type.to_owned(), ); @@ -541,7 +595,6 @@ mod tests { limit: Some(3), table_partition_cols: vec![], output_ordering: vec![], - infinite_source: false, }, file_compression_type.to_owned(), ); @@ -589,7 +642,6 @@ mod tests { limit: None, table_partition_cols: vec![], output_ordering: vec![], - infinite_source: false, }, file_compression_type.to_owned(), ); @@ -642,7 +694,6 @@ mod tests { limit: None, table_partition_cols: vec![], output_ordering: vec![], - infinite_source: false, }, file_compression_type.to_owned(), ); diff --git 
a/datafusion/core/src/datasource/physical_plan/mod.rs b/datafusion/core/src/datasource/physical_plan/mod.rs index ea0a9698ff5ca..d7be017a18682 100644 --- a/datafusion/core/src/datasource/physical_plan/mod.rs +++ b/datafusion/core/src/datasource/physical_plan/mod.rs @@ -20,11 +20,14 @@ mod arrow_file; mod avro; mod csv; +mod file_groups; mod file_scan_config; mod file_stream; mod json; #[cfg(feature = "parquet")] pub mod parquet; +pub use file_groups::FileGroupPartitioner; +use futures::StreamExt; pub(crate) use self::csv::plan_to_csv; pub use self::csv::{CsvConfig, CsvExec, CsvOpener}; @@ -43,16 +46,14 @@ pub use json::{JsonOpener, NdJsonExec}; use std::{ fmt::{Debug, Formatter, Result as FmtResult}, + ops::Range, sync::Arc, vec, }; use super::listing::ListingTableUrl; use crate::error::{DataFusionError, Result}; -use crate::{ - datasource::file_format::write::FileWriterMode, - physical_plan::{DisplayAs, DisplayFormatType}, -}; +use crate::physical_plan::{DisplayAs, DisplayFormatType}; use crate::{ datasource::{ listing::{FileRange, PartitionedFile}, @@ -73,8 +74,8 @@ use datafusion_physical_expr::PhysicalSortExpr; use datafusion_physical_plan::ExecutionPlan; use log::debug; -use object_store::path::Path; use object_store::ObjectMeta; +use object_store::{path::Path, GetOptions, ObjectStore}; /// The base configurations to provide when creating a physical plan for /// writing to any given file format. @@ -90,14 +91,10 @@ pub struct FileSinkConfig { /// A vector of column names and their corresponding data types, /// representing the partitioning columns for the file pub table_partition_cols: Vec<(String, DataType)>, - /// A writer mode that determines how data is written to the file - pub writer_mode: FileWriterMode, /// If true, it is assumed there is a single table_path which is a file to which all data should be written /// regardless of input partitioning. Otherwise, each table path is assumed to be a directory /// to which each output partition is written to its own output file. pub single_file_output: bool, - /// If input is unbounded, tokio tasks need to yield to not block execution forever - pub unbounded_input: bool, /// Controls whether existing data should be overwritten by this sink pub overwrite: bool, /// Contains settings specific to writing a given FileType, e.g. parquet max_row_group_size @@ -136,13 +133,24 @@ impl DisplayAs for FileScanConfig { write!(f, ", limit={limit}")?; } - if self.infinite_source { - write!(f, ", infinite_source=true")?; - } - if let Some(ordering) = orderings.first() { if !ordering.is_empty() { - write!(f, ", output_ordering={}", OutputOrderingDisplay(ordering))?; + let start = if orderings.len() == 1 { + ", output_ordering=" + } else { + ", output_orderings=[" + }; + write!(f, "{}", start)?; + for (idx, ordering) in + orderings.iter().enumerate().filter(|(_, o)| !o.is_empty()) + { + match idx { + 0 => write!(f, "{}", OutputOrderingDisplay(ordering))?, + _ => write!(f, ", {}", OutputOrderingDisplay(ordering))?, + } + } + let end = if orderings.len() == 1 { "" } else { "]" }; + write!(f, "{}", end)?; } } @@ -502,9 +510,9 @@ fn get_projected_output_ordering( all_orderings } -// Get output (un)boundedness information for the given `plan`. -pub(crate) fn is_plan_streaming(plan: &Arc) -> Result { - let result = if plan.children().is_empty() { +/// Get output (un)boundedness information for the given `plan`. 
+pub fn is_plan_streaming(plan: &Arc) -> Result { + if plan.children().is_empty() { plan.unbounded_output(&[]) } else { let children_unbounded_output = plan @@ -513,8 +521,110 @@ pub(crate) fn is_plan_streaming(plan: &Arc) -> Result { .map(is_plan_streaming) .collect::>>(); plan.unbounded_output(&children_unbounded_output?) + } +} + +/// Represents the possible outcomes of a range calculation. +/// +/// This enum is used to encapsulate the result of calculating the range of +/// bytes to read from an object (like a file) in an object store. +/// +/// Variants: +/// - `Range(Option>)`: +/// Represents a range of bytes to be read. It contains an `Option` wrapping a +/// `Range`. `None` signifies that the entire object should be read, +/// while `Some(range)` specifies the exact byte range to read. +/// - `TerminateEarly`: +/// Indicates that the range calculation determined no further action is +/// necessary, possibly because the calculated range is empty or invalid. +enum RangeCalculation { + Range(Option>), + TerminateEarly, +} + +/// Calculates an appropriate byte range for reading from an object based on the +/// provided metadata. +/// +/// This asynchronous function examines the `FileMeta` of an object in an object store +/// and determines the range of bytes to be read. The range calculation may adjust +/// the start and end points to align with meaningful data boundaries (like newlines). +/// +/// Returns a `Result` wrapping a `RangeCalculation`, which is either a calculated byte range or an indication to terminate early. +/// +/// Returns an `Error` if any part of the range calculation fails, such as issues in reading from the object store or invalid range boundaries. +async fn calculate_range( + file_meta: &FileMeta, + store: &Arc, +) -> Result { + let location = file_meta.location(); + let file_size = file_meta.object_meta.size; + + match file_meta.range { + None => Ok(RangeCalculation::Range(None)), + Some(FileRange { start, end }) => { + let (start, end) = (start as usize, end as usize); + + let start_delta = if start != 0 { + find_first_newline(store, location, start - 1, file_size).await? + } else { + 0 + }; + + let end_delta = if end != file_size { + find_first_newline(store, location, end - 1, file_size).await? + } else { + 0 + }; + + let range = start + start_delta..end + end_delta; + + if range.start == range.end { + return Ok(RangeCalculation::TerminateEarly); + } + + Ok(RangeCalculation::Range(Some(range))) + } + } +} + +/// Asynchronously finds the position of the first newline character in a specified byte range +/// within an object, such as a file, in an object store. +/// +/// This function scans the contents of the object starting from the specified `start` position +/// up to the `end` position, looking for the first occurrence of a newline (`'\n'`) character. +/// It returns the position of the first newline relative to the start of the range. +/// +/// Returns a `Result` wrapping a `usize` that represents the position of the first newline character found within the specified range. If no newline is found, it returns the length of the scanned data, effectively indicating the end of the range. +/// +/// The function returns an `Error` if any issues arise while reading from the object store or processing the data stream. 
+/// +async fn find_first_newline( + object_store: &Arc, + location: &Path, + start: usize, + end: usize, +) -> Result { + let range = Some(Range { start, end }); + + let options = GetOptions { + range, + ..Default::default() }; - result + + let result = object_store.get_opts(location, options).await?; + let mut result_stream = result.into_stream(); + + let mut index = 0; + + while let Some(chunk) = result_stream.next().await.transpose()? { + if let Some(position) = chunk.iter().position(|&byte| byte == b'\n') { + return Ok(index + position); + } + + index += chunk.len(); + } + + Ok(index) } #[cfg(test)] @@ -527,7 +637,6 @@ mod tests { }; use arrow_schema::Field; use chrono::Utc; - use datafusion_common::config::ConfigOptions; use crate::physical_plan::{DefaultDisplay, VerboseDisplay}; @@ -789,6 +898,7 @@ mod tests { last_modified: Utc::now(), size: 42, e_tag: None, + version: None, }; PartitionedFile { @@ -798,345 +908,4 @@ mod tests { extensions: None, } } - - /// Unit tests for `repartition_file_groups()` - #[cfg(feature = "parquet")] - mod repartition_file_groups_test { - use datafusion_common::Statistics; - use itertools::Itertools; - - use super::*; - - /// Empty file won't get partitioned - #[tokio::test] - async fn repartition_empty_file_only() { - let partitioned_file_empty = PartitionedFile::new("empty".to_string(), 0); - let file_group = vec![vec![partitioned_file_empty]]; - - let parquet_exec = ParquetExec::new( - FileScanConfig { - object_store_url: ObjectStoreUrl::local_filesystem(), - file_groups: file_group, - file_schema: Arc::new(Schema::empty()), - statistics: Statistics::new_unknown(&Schema::empty()), - projection: None, - limit: None, - table_partition_cols: vec![], - output_ordering: vec![], - infinite_source: false, - }, - None, - None, - ); - - let partitioned_file = repartition_with_size(&parquet_exec, 4, 0); - - assert!(partitioned_file[0][0].range.is_none()); - } - - // Repartition when there is a empty file in file groups - #[tokio::test] - async fn repartition_empty_files() { - let partitioned_file_a = PartitionedFile::new("a".to_string(), 10); - let partitioned_file_b = PartitionedFile::new("b".to_string(), 10); - let partitioned_file_empty = PartitionedFile::new("empty".to_string(), 0); - - let empty_first = vec![ - vec![partitioned_file_empty.clone()], - vec![partitioned_file_a.clone()], - vec![partitioned_file_b.clone()], - ]; - let empty_middle = vec![ - vec![partitioned_file_a.clone()], - vec![partitioned_file_empty.clone()], - vec![partitioned_file_b.clone()], - ]; - let empty_last = vec![ - vec![partitioned_file_a], - vec![partitioned_file_b], - vec![partitioned_file_empty], - ]; - - // Repartition file groups into x partitions - let expected_2 = - vec![(0, "a".to_string(), 0, 10), (1, "b".to_string(), 0, 10)]; - let expected_3 = vec![ - (0, "a".to_string(), 0, 7), - (1, "a".to_string(), 7, 10), - (1, "b".to_string(), 0, 4), - (2, "b".to_string(), 4, 10), - ]; - - //let file_groups_testset = [empty_first, empty_middle, empty_last]; - let file_groups_testset = [empty_first, empty_middle, empty_last]; - - for fg in file_groups_testset { - for (n_partition, expected) in [(2, &expected_2), (3, &expected_3)] { - let parquet_exec = ParquetExec::new( - FileScanConfig { - object_store_url: ObjectStoreUrl::local_filesystem(), - file_groups: fg.clone(), - file_schema: Arc::new(Schema::empty()), - statistics: Statistics::new_unknown(&Arc::new( - Schema::empty(), - )), - projection: None, - limit: None, - table_partition_cols: vec![], - output_ordering: vec![], - 
infinite_source: false, - }, - None, - None, - ); - - let actual = - repartition_with_size_to_vec(&parquet_exec, n_partition, 10); - - assert_eq!(expected, &actual); - } - } - } - - #[tokio::test] - async fn repartition_single_file() { - // Single file, single partition into multiple partitions - let partitioned_file = PartitionedFile::new("a".to_string(), 123); - let single_partition = vec![vec![partitioned_file]]; - let parquet_exec = ParquetExec::new( - FileScanConfig { - object_store_url: ObjectStoreUrl::local_filesystem(), - file_groups: single_partition, - file_schema: Arc::new(Schema::empty()), - statistics: Statistics::new_unknown(&Schema::empty()), - projection: None, - limit: None, - table_partition_cols: vec![], - output_ordering: vec![], - infinite_source: false, - }, - None, - None, - ); - - let actual = repartition_with_size_to_vec(&parquet_exec, 4, 10); - let expected = vec![ - (0, "a".to_string(), 0, 31), - (1, "a".to_string(), 31, 62), - (2, "a".to_string(), 62, 93), - (3, "a".to_string(), 93, 123), - ]; - assert_eq!(expected, actual); - } - - #[tokio::test] - async fn repartition_too_much_partitions() { - // Single file, single parittion into 96 partitions - let partitioned_file = PartitionedFile::new("a".to_string(), 8); - let single_partition = vec![vec![partitioned_file]]; - let parquet_exec = ParquetExec::new( - FileScanConfig { - object_store_url: ObjectStoreUrl::local_filesystem(), - file_groups: single_partition, - file_schema: Arc::new(Schema::empty()), - statistics: Statistics::new_unknown(&Schema::empty()), - projection: None, - limit: None, - table_partition_cols: vec![], - output_ordering: vec![], - infinite_source: false, - }, - None, - None, - ); - - let actual = repartition_with_size_to_vec(&parquet_exec, 96, 5); - let expected = vec![ - (0, "a".to_string(), 0, 1), - (1, "a".to_string(), 1, 2), - (2, "a".to_string(), 2, 3), - (3, "a".to_string(), 3, 4), - (4, "a".to_string(), 4, 5), - (5, "a".to_string(), 5, 6), - (6, "a".to_string(), 6, 7), - (7, "a".to_string(), 7, 8), - ]; - assert_eq!(expected, actual); - } - - #[tokio::test] - async fn repartition_multiple_partitions() { - // Multiple files in single partition after redistribution - let partitioned_file_1 = PartitionedFile::new("a".to_string(), 40); - let partitioned_file_2 = PartitionedFile::new("b".to_string(), 60); - let source_partitions = - vec![vec![partitioned_file_1], vec![partitioned_file_2]]; - let parquet_exec = ParquetExec::new( - FileScanConfig { - object_store_url: ObjectStoreUrl::local_filesystem(), - file_groups: source_partitions, - file_schema: Arc::new(Schema::empty()), - statistics: Statistics::new_unknown(&Schema::empty()), - projection: None, - limit: None, - table_partition_cols: vec![], - output_ordering: vec![], - infinite_source: false, - }, - None, - None, - ); - - let actual = repartition_with_size_to_vec(&parquet_exec, 3, 10); - let expected = vec![ - (0, "a".to_string(), 0, 34), - (1, "a".to_string(), 34, 40), - (1, "b".to_string(), 0, 28), - (2, "b".to_string(), 28, 60), - ]; - assert_eq!(expected, actual); - } - - #[tokio::test] - async fn repartition_same_num_partitions() { - // "Rebalance" files across partitions - let partitioned_file_1 = PartitionedFile::new("a".to_string(), 40); - let partitioned_file_2 = PartitionedFile::new("b".to_string(), 60); - let source_partitions = - vec![vec![partitioned_file_1], vec![partitioned_file_2]]; - let parquet_exec = ParquetExec::new( - FileScanConfig { - object_store_url: ObjectStoreUrl::local_filesystem(), - file_groups: 
source_partitions, - file_schema: Arc::new(Schema::empty()), - statistics: Statistics::new_unknown(&Schema::empty()), - projection: None, - limit: None, - table_partition_cols: vec![], - output_ordering: vec![], - infinite_source: false, - }, - None, - None, - ); - - let actual = repartition_with_size_to_vec(&parquet_exec, 2, 10); - let expected = vec![ - (0, "a".to_string(), 0, 40), - (0, "b".to_string(), 0, 10), - (1, "b".to_string(), 10, 60), - ]; - assert_eq!(expected, actual); - } - - #[tokio::test] - async fn repartition_no_action_ranges() { - // No action due to Some(range) in second file - let partitioned_file_1 = PartitionedFile::new("a".to_string(), 123); - let mut partitioned_file_2 = PartitionedFile::new("b".to_string(), 144); - partitioned_file_2.range = Some(FileRange { start: 1, end: 50 }); - - let source_partitions = - vec![vec![partitioned_file_1], vec![partitioned_file_2]]; - let parquet_exec = ParquetExec::new( - FileScanConfig { - object_store_url: ObjectStoreUrl::local_filesystem(), - file_groups: source_partitions, - file_schema: Arc::new(Schema::empty()), - statistics: Statistics::new_unknown(&Schema::empty()), - projection: None, - limit: None, - table_partition_cols: vec![], - output_ordering: vec![], - infinite_source: false, - }, - None, - None, - ); - - let actual = repartition_with_size(&parquet_exec, 65, 10); - assert_eq!(2, actual.len()); - } - - #[tokio::test] - async fn repartition_no_action_min_size() { - // No action due to target_partition_size - let partitioned_file = PartitionedFile::new("a".to_string(), 123); - let single_partition = vec![vec![partitioned_file]]; - let parquet_exec = ParquetExec::new( - FileScanConfig { - object_store_url: ObjectStoreUrl::local_filesystem(), - file_groups: single_partition, - file_schema: Arc::new(Schema::empty()), - statistics: Statistics::new_unknown(&Schema::empty()), - projection: None, - limit: None, - table_partition_cols: vec![], - output_ordering: vec![], - infinite_source: false, - }, - None, - None, - ); - - let actual = repartition_with_size(&parquet_exec, 65, 500); - assert_eq!(1, actual.len()); - } - - /// Calls `ParquetExec.repartitioned` with the specified - /// `target_partitions` and `repartition_file_min_size`, returning the - /// resulting `PartitionedFile`s - fn repartition_with_size( - parquet_exec: &ParquetExec, - target_partitions: usize, - repartition_file_min_size: usize, - ) -> Vec> { - let mut config = ConfigOptions::new(); - config.optimizer.repartition_file_min_size = repartition_file_min_size; - - parquet_exec - .repartitioned(target_partitions, &config) - .unwrap() // unwrap Result - .unwrap() // unwrap Option - .as_any() - .downcast_ref::() - .unwrap() - .base_config() - .file_groups - .clone() - } - - /// Calls `repartition_with_size` and returns a tuple for each output `PartitionedFile`: - /// - /// `(partition index, file path, start, end)` - fn repartition_with_size_to_vec( - parquet_exec: &ParquetExec, - target_partitions: usize, - repartition_file_min_size: usize, - ) -> Vec<(usize, String, i64, i64)> { - let file_groups = repartition_with_size( - parquet_exec, - target_partitions, - repartition_file_min_size, - ); - - file_groups - .iter() - .enumerate() - .flat_map(|(part_idx, files)| { - files - .iter() - .map(|f| { - ( - part_idx, - f.object_meta.location.to_string(), - f.range.as_ref().unwrap().start, - f.range.as_ref().unwrap().end, - ) - }) - .collect_vec() - }) - .collect_vec() - } - } } diff --git a/datafusion/core/src/datasource/physical_plan/parquet.rs 
b/datafusion/core/src/datasource/physical_plan/parquet/mod.rs similarity index 91% rename from datafusion/core/src/datasource/physical_plan/parquet.rs rename to datafusion/core/src/datasource/physical_plan/parquet/mod.rs index 960b2ec7337de..9d81d8d083c28 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/mod.rs @@ -26,8 +26,8 @@ use crate::datasource::physical_plan::file_stream::{ FileOpenFuture, FileOpener, FileStream, }; use crate::datasource::physical_plan::{ - parquet::page_filter::PagePruningPredicate, DisplayAs, FileMeta, FileScanConfig, - SchemaAdapter, + parquet::page_filter::PagePruningPredicate, DisplayAs, FileGroupPartitioner, + FileMeta, FileScanConfig, SchemaAdapter, }; use crate::{ config::ConfigOptions, @@ -66,6 +66,7 @@ mod metrics; pub mod page_filter; mod row_filter; mod row_groups; +mod statistics; pub use metrics::ParquetFileMetrics; @@ -329,18 +330,18 @@ impl ExecutionPlan for ParquetExec { } /// Redistribute files across partitions according to their size - /// See comments on `get_file_groups_repartitioned()` for more detail. + /// See comments on [`FileGroupPartitioner`] for more detail. fn repartitioned( &self, target_partitions: usize, config: &ConfigOptions, ) -> Result>> { let repartition_file_min_size = config.optimizer.repartition_file_min_size; - let repartitioned_file_groups_option = FileScanConfig::repartition_file_groups( - self.base_config.file_groups.clone(), - target_partitions, - repartition_file_min_size, - ); + let repartitioned_file_groups_option = FileGroupPartitioner::new() + .with_target_partitions(target_partitions) + .with_repartition_file_min_size(repartition_file_min_size) + .with_preserve_order_within_groups(self.output_ordering().is_some()) + .repartition_file_groups(&self.base_config.file_groups); let mut new_plan = self.clone(); if let Some(repartitioned_file_groups) = repartitioned_file_groups_option { @@ -467,8 +468,10 @@ impl FileOpener for ParquetOpener { ParquetRecordBatchStreamBuilder::new_with_options(reader, options) .await?; + let file_schema = builder.schema().clone(); + let (schema_mapping, adapted_projections) = - schema_adapter.map_schema(builder.schema())?; + schema_adapter.map_schema(&file_schema)?; // let predicate = predicate.map(|p| reassign_predicate_columns(p, builder.schema(), true)).transpose()?; let mask = ProjectionMask::roots( @@ -480,8 +483,8 @@ impl FileOpener for ParquetOpener { if let Some(predicate) = pushdown_filters.then_some(predicate).flatten() { let row_filter = row_filter::build_row_filter( &predicate, - builder.schema().as_ref(), - table_schema.as_ref(), + &file_schema, + &table_schema, builder.metadata(), reorder_predicates, &file_metrics, @@ -506,6 +509,8 @@ impl FileOpener for ParquetOpener { let file_metadata = builder.metadata().clone(); let predicate = pruning_predicate.as_ref().map(|p| p.as_ref()); let mut row_groups = row_groups::prune_row_groups_by_statistics( + &file_schema, + builder.parquet_schema(), file_metadata.row_groups(), file_range, predicate, @@ -517,6 +522,7 @@ impl FileOpener for ParquetOpener { if enable_bloom_filter && !row_groups.is_empty() { if let Some(predicate) = predicate { row_groups = row_groups::prune_row_groups_by_bloom_filters( + &file_schema, &mut builder, &row_groups, file_metadata.row_groups(), @@ -718,28 +724,6 @@ pub async fn plan_to_parquet( Ok(()) } -// Copy from the arrow-rs -// 
https://github.com/apache/arrow-rs/blob/733b7e7fd1e8c43a404c3ce40ecf741d493c21b4/parquet/src/arrow/buffer/bit_util.rs#L55 -// Convert the byte slice to fixed length byte array with the length of 16 -fn sign_extend_be(b: &[u8]) -> [u8; 16] { - assert!(b.len() <= 16, "Array too large, expected less than 16"); - let is_negative = (b[0] & 128u8) == 128u8; - let mut result = if is_negative { [255u8; 16] } else { [0u8; 16] }; - for (d, s) in result.iter_mut().skip(16 - b.len()).zip(b) { - *d = *s; - } - result -} - -// Convert the bytes array to i128. -// The endian of the input bytes array must be big-endian. -pub(crate) fn from_bytes_to_i128(b: &[u8]) -> i128 { - // The bytes array are from parquet file and must be the big-endian. - // The endian is defined by parquet format, and the reference document - // https://github.com/apache/parquet-format/blob/54e53e5d7794d383529dd30746378f19a12afd58/src/main/thrift/parquet.thrift#L66 - i128::from_be_bytes(sign_extend_be(b)) -} - // Convert parquet column schema to arrow data type, and just consider the // decimal data type. pub(crate) fn parquet_to_arrow_decimal_type( @@ -769,7 +753,7 @@ mod tests { use crate::datasource::file_format::options::CsvReadOptions; use crate::datasource::file_format::parquet::test_util::store_parquet; use crate::datasource::file_format::test_util::scan_format; - use crate::datasource::listing::{FileRange, PartitionedFile}; + use crate::datasource::listing::{FileRange, ListingOptions, PartitionedFile}; use crate::datasource::object_store::ObjectStoreUrl; use crate::execution::context::SessionState; use crate::physical_plan::displayable; @@ -789,8 +773,8 @@ mod tests { }; use arrow_array::Date64Array; use chrono::{TimeZone, Utc}; - use datafusion_common::ScalarValue; use datafusion_common::{assert_contains, ToDFSchema}; + use datafusion_common::{FileType, GetExt, ScalarValue}; use datafusion_expr::{col, lit, when, Expr}; use datafusion_physical_expr::create_physical_expr; use datafusion_physical_expr::execution_props::ExecutionProps; @@ -899,7 +883,6 @@ mod tests { limit: None, table_partition_cols: vec![], output_ordering: vec![], - infinite_source: false, }, predicate, None, @@ -1556,7 +1539,6 @@ mod tests { limit: None, table_partition_cols: vec![], output_ordering: vec![], - infinite_source: false, }, None, None, @@ -1623,11 +1605,11 @@ mod tests { let partitioned_file = PartitionedFile { object_meta: meta, partition_values: vec![ - ScalarValue::Utf8(Some("2021".to_owned())), + ScalarValue::from("2021"), ScalarValue::UInt8(Some(10)), ScalarValue::Dictionary( Box::new(DataType::UInt16), - Box::new(ScalarValue::Utf8(Some("26".to_owned()))), + Box::new(ScalarValue::from("26")), ), ], range: None, @@ -1671,7 +1653,6 @@ mod tests { ), ], output_ordering: vec![], - infinite_source: false, }, None, None, @@ -1718,6 +1699,7 @@ mod tests { last_modified: Utc.timestamp_nanos(0), size: 1337, e_tag: None, + version: None, }, partition_values: vec![], range: None, @@ -1734,7 +1716,6 @@ mod tests { limit: None, table_partition_cols: vec![], output_ordering: vec![], - infinite_source: false, }, None, None, @@ -1787,8 +1768,9 @@ mod tests { ); } - #[tokio::test] - async fn parquet_exec_metrics() { + /// Returns a string array with contents: + /// "[Foo, null, bar, bar, bar, bar, zzz]" + fn string_batch() -> RecordBatch { let c1: ArrayRef = Arc::new(StringArray::from(vec![ Some("Foo"), None, @@ -1800,9 +1782,15 @@ mod tests { ])); // batch1: c1(string) - let batch1 = create_batch(vec![("c1", c1.clone())]); + create_batch(vec![("c1", 
c1.clone())]) + } + + #[tokio::test] + async fn parquet_exec_metrics() { + // batch1: c1(string) + let batch1 = string_batch(); - // on + // c1 != 'bar' let filter = col("c1").not_eq(lit("bar")); // read/write them files: @@ -1831,20 +1819,10 @@ mod tests { #[tokio::test] async fn parquet_exec_display() { - let c1: ArrayRef = Arc::new(StringArray::from(vec![ - Some("Foo"), - None, - Some("bar"), - Some("bar"), - Some("bar"), - Some("bar"), - Some("zzz"), - ])); - // batch1: c1(string) - let batch1 = create_batch(vec![("c1", c1.clone())]); + let batch1 = string_batch(); - // on + // c1 != 'bar' let filter = col("c1").not_eq(lit("bar")); let rt = RoundTrip::new() @@ -1873,21 +1851,15 @@ mod tests { } #[tokio::test] - async fn parquet_exec_skip_empty_pruning() { - let c1: ArrayRef = Arc::new(StringArray::from(vec![ - Some("Foo"), - None, - Some("bar"), - Some("bar"), - Some("bar"), - Some("bar"), - Some("zzz"), - ])); - + async fn parquet_exec_has_no_pruning_predicate_if_can_not_prune() { // batch1: c1(string) - let batch1 = create_batch(vec![("c1", c1.clone())]); + let batch1 = string_batch(); - // filter is too complicated for pruning + // filter is too complicated for pruning (PruningPredicate code does not + // handle case expressions), so the pruning predicate will always be + // "true" + + // WHEN c1 != bar THEN true ELSE false END let filter = when(col("c1").not_eq(lit("bar")), lit(true)) .otherwise(lit(false)) .unwrap(); @@ -1898,7 +1870,7 @@ mod tests { .round_trip(vec![batch1]) .await; - // Should not contain a pruning predicate + // Should not contain a pruning predicate (since nothing can be pruned) let pruning_predicate = &rt.parquet_exec.pruning_predicate; assert!( pruning_predicate.is_none(), @@ -1911,6 +1883,33 @@ mod tests { assert_eq!(predicate.unwrap().to_string(), filter_phys.to_string()); } + #[tokio::test] + async fn parquet_exec_has_pruning_predicate_for_guarantees() { + // batch1: c1(string) + let batch1 = string_batch(); + + // part of the filter is too complicated for pruning (PruningPredicate code does not + // handle case expressions), but part (c1 = 'foo') can be used for bloom filtering, so + // should still have the pruning predicate. + + // c1 = 'foo' AND (WHEN c1 != bar THEN true ELSE false END) + let filter = col("c1").eq(lit("foo")).and( + when(col("c1").not_eq(lit("bar")), lit(true)) + .otherwise(lit(false)) + .unwrap(), + ); + + let rt = RoundTrip::new() + .with_predicate(filter.clone()) + .with_pushdown_predicate() + .round_trip(vec![batch1]) + .await; + + // Should have a pruning predicate + let pruning_predicate = &rt.parquet_exec.pruning_predicate; + assert!(pruning_predicate.is_some()); + } + /// returns the sum of all the metrics with the specified name /// the returned set. 
/// @@ -1957,6 +1956,96 @@ mod tests { Ok(schema) } + #[tokio::test] + async fn write_table_results() -> Result<()> { + // create partitioned input file and context + let tmp_dir = TempDir::new()?; + // let mut ctx = create_ctx(&tmp_dir, 4).await?; + let ctx = SessionContext::new_with_config( + SessionConfig::new().with_target_partitions(8), + ); + let schema = populate_csv_partitions(&tmp_dir, 4, ".csv")?; + // register csv file with the execution context + ctx.register_csv( + "test", + tmp_dir.path().to_str().unwrap(), + CsvReadOptions::new().schema(&schema), + ) + .await?; + + // register a local file system object store for /tmp directory + let local = Arc::new(LocalFileSystem::new_with_prefix(&tmp_dir)?); + let local_url = Url::parse("file://local").unwrap(); + ctx.runtime_env().register_object_store(&local_url, local); + + // Configure listing options + let file_format = ParquetFormat::default().with_enable_pruning(Some(true)); + let listing_options = ListingOptions::new(Arc::new(file_format)) + .with_file_extension(FileType::PARQUET.get_ext()); + + // execute a simple query and write the results to parquet + let out_dir = tmp_dir.as_ref().to_str().unwrap().to_string() + "/out"; + std::fs::create_dir(&out_dir).unwrap(); + let df = ctx.sql("SELECT c1, c2 FROM test").await?; + let schema: Schema = df.schema().into(); + // Register a listing table - this will use all files in the directory as data sources + // for the query + ctx.register_listing_table( + "my_table", + &out_dir, + listing_options, + Some(Arc::new(schema)), + None, + ) + .await + .unwrap(); + df.write_table("my_table", DataFrameWriteOptions::new()) + .await?; + + // create a new context and verify that the results were saved to a partitioned parquet file + let ctx = SessionContext::new(); + + // get write_id + let mut paths = fs::read_dir(&out_dir).unwrap(); + let path = paths.next(); + let name = path + .unwrap()? + .path() + .file_name() + .expect("Should be a file name") + .to_str() + .expect("Should be a str") + .to_owned(); + let (parsed_id, _) = name.split_once('_').expect("File should contain _ !"); + let write_id = parsed_id.to_owned(); + + // register each partition as well as the top level dir + ctx.register_parquet( + "part0", + &format!("{out_dir}/{write_id}_0.parquet"), + ParquetReadOptions::default(), + ) + .await?; + + ctx.register_parquet("allparts", &out_dir, ParquetReadOptions::default()) + .await?; + + let part0 = ctx.sql("SELECT c1, c2 FROM part0").await?.collect().await?; + let allparts = ctx + .sql("SELECT c1, c2 FROM allparts") + .await? 
+ .collect() + .await?; + + let allparts_count: usize = allparts.iter().map(|batch| batch.num_rows()).sum(); + + assert_eq!(part0[0].schema(), allparts[0].schema()); + + assert_eq!(allparts_count, 40); + + Ok(()) + } + #[tokio::test] async fn write_parquet_results() -> Result<()> { // create partitioned input file and context @@ -2001,7 +2090,6 @@ mod tests { .to_str() .expect("Should be a str") .to_owned(); - println!("{name}"); let (parsed_id, _) = name.split_once('_').expect("File should contain _ !"); let write_id = parsed_id.to_owned(); diff --git a/datafusion/core/src/datasource/physical_plan/parquet/page_filter.rs b/datafusion/core/src/datasource/physical_plan/parquet/page_filter.rs index b5b5f154f7a0f..a0637f3796106 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/page_filter.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/page_filter.rs @@ -23,7 +23,7 @@ use arrow::array::{ }; use arrow::datatypes::DataType; use arrow::{array::ArrayRef, datatypes::SchemaRef, error::ArrowError}; -use datafusion_common::{DataFusionError, Result}; +use datafusion_common::{DataFusionError, Result, ScalarValue}; use datafusion_physical_expr::expressions::Column; use datafusion_physical_expr::{split_conjunction, PhysicalExpr}; use log::{debug, trace}; @@ -37,11 +37,11 @@ use parquet::{ }, format::PageLocation, }; +use std::collections::HashSet; use std::sync::Arc; -use crate::datasource::physical_plan::parquet::{ - from_bytes_to_i128, parquet_to_arrow_decimal_type, -}; +use crate::datasource::physical_plan::parquet::parquet_to_arrow_decimal_type; +use crate::datasource::physical_plan::parquet::statistics::from_bytes_to_i128; use crate::physical_optimizer::pruning::{PruningPredicate, PruningStatistics}; use super::metrics::ParquetFileMetrics; @@ -372,7 +372,7 @@ fn prune_pages_in_one_row_group( } fn create_row_count_in_each_page( - location: &Vec, + location: &[PageLocation], num_rows: usize, ) -> Vec { let mut vec = Vec::with_capacity(location.len()); @@ -555,4 +555,12 @@ impl<'a> PruningStatistics for PagesPruningStatistics<'a> { ))), } } + + fn contained( + &self, + _column: &datafusion_common::Column, + _values: &HashSet, + ) -> Option { + None + } } diff --git a/datafusion/core/src/datasource/physical_plan/parquet/row_filter.rs b/datafusion/core/src/datasource/physical_plan/parquet/row_filter.rs index 0f4b09caeded5..151ab5f657b15 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/row_filter.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/row_filter.rs @@ -21,7 +21,7 @@ use arrow::error::{ArrowError, Result as ArrowResult}; use arrow::record_batch::RecordBatch; use datafusion_common::cast::as_boolean_array; use datafusion_common::tree_node::{RewriteRecursion, TreeNode, TreeNodeRewriter}; -use datafusion_common::{DataFusionError, Result, ScalarValue}; +use datafusion_common::{arrow_err, DataFusionError, Result, ScalarValue}; use datafusion_physical_expr::expressions::{Column, Literal}; use datafusion_physical_expr::utils::reassign_predicate_columns; use std::collections::BTreeSet; @@ -126,7 +126,7 @@ impl ArrowPredicate for DatafusionArrowPredicate { match self .physical_expr .evaluate(&batch) - .map(|v| v.into_array(batch.num_rows())) + .and_then(|v| v.into_array(batch.num_rows())) { Ok(array) => { let bool_arr = as_boolean_array(&array)?.clone(); @@ -243,7 +243,7 @@ impl<'a> TreeNodeRewriter for FilterCandidateBuilder<'a> { } Err(e) => { // If the column is not in the table schema, should throw the error - 
Err(DataFusionError::ArrowError(e)) + arrow_err!(e) } }; } diff --git a/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs b/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs index 91bceed916027..24c65423dd4ca 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs @@ -15,30 +15,24 @@ // specific language governing permissions and limitations // under the License. -use arrow::{ - array::ArrayRef, - datatypes::{DataType, Schema}, -}; -use datafusion_common::tree_node::{TreeNode, VisitRecursion}; -use datafusion_common::{Column, DataFusionError, Result, ScalarValue}; +use arrow::{array::ArrayRef, datatypes::Schema}; +use arrow_array::BooleanArray; +use arrow_schema::FieldRef; +use datafusion_common::{Column, ScalarValue}; +use parquet::file::metadata::ColumnChunkMetaData; +use parquet::schema::types::SchemaDescriptor; use parquet::{ arrow::{async_reader::AsyncFileReader, ParquetRecordBatchStreamBuilder}, bloom_filter::Sbbf, - file::{metadata::RowGroupMetaData, statistics::Statistics as ParquetStatistics}, -}; -use std::{ - collections::{HashMap, HashSet}, - sync::Arc, + file::metadata::RowGroupMetaData, }; +use std::collections::{HashMap, HashSet}; -use crate::datasource::{ - listing::FileRange, - physical_plan::parquet::{from_bytes_to_i128, parquet_to_arrow_decimal_type}, +use crate::datasource::listing::FileRange; +use crate::datasource::physical_plan::parquet::statistics::{ + max_statistics, min_statistics, parquet_column, }; -use crate::logical_expr::Operator; -use crate::physical_expr::expressions as phys_expr; use crate::physical_optimizer::pruning::{PruningPredicate, PruningStatistics}; -use crate::physical_plan::PhysicalExpr; use super::ParquetFileMetrics; @@ -51,7 +45,12 @@ use super::ParquetFileMetrics; /// /// If an index IS present in the returned Vec it means the predicate /// did not filter out that row group. 
+/// +/// Note: This method currently ignores ColumnOrder +/// pub(crate) fn prune_row_groups_by_statistics( + arrow_schema: &Schema, + parquet_schema: &SchemaDescriptor, groups: &[RowGroupMetaData], range: Option, predicate: Option<&PruningPredicate>, @@ -74,8 +73,9 @@ pub(crate) fn prune_row_groups_by_statistics( if let Some(predicate) = predicate { let pruning_stats = RowGroupPruningStatistics { + parquet_schema, row_group_metadata: metadata, - parquet_schema: predicate.schema().as_ref(), + arrow_schema, }; match predicate.prune(&pruning_stats) { Ok(values) => { @@ -111,331 +111,159 @@ pub(crate) fn prune_row_groups_by_statistics( pub(crate) async fn prune_row_groups_by_bloom_filters< T: AsyncFileReader + Send + 'static, >( + arrow_schema: &Schema, builder: &mut ParquetRecordBatchStreamBuilder, row_groups: &[usize], groups: &[RowGroupMetaData], predicate: &PruningPredicate, metrics: &ParquetFileMetrics, ) -> Vec { - let bf_predicates = match BloomFilterPruningPredicate::try_new(predicate.orig_expr()) - { - Ok(predicates) => predicates, - Err(_) => { - return row_groups.to_vec(); - } - }; let mut filtered = Vec::with_capacity(groups.len()); for idx in row_groups { - let rg_metadata = &groups[*idx]; - // get all columns bloom filter - let mut column_sbbf = - HashMap::with_capacity(bf_predicates.required_columns.len()); - for column_name in bf_predicates.required_columns.iter() { - let column_idx = match rg_metadata - .columns() - .iter() - .enumerate() - .find(|(_, column)| column.column_path().string().eq(column_name)) - { - Some((column_idx, _)) => column_idx, - None => continue, + // get all columns in the predicate that we could use a bloom filter with + let literal_columns = predicate.literal_columns(); + let mut column_sbbf = HashMap::with_capacity(literal_columns.len()); + + for column_name in literal_columns { + let Some((column_idx, _field)) = + parquet_column(builder.parquet_schema(), arrow_schema, &column_name) + else { + continue; }; + let bf = match builder .get_row_group_column_bloom_filter(*idx, column_idx) .await { - Ok(bf) => match bf { - Some(bf) => bf, - None => { - continue; - } - }, + Ok(Some(bf)) => bf, + Ok(None) => continue, // no bloom filter for this column Err(e) => { - log::error!("Error evaluating row group predicate values when using BloomFilterPruningPredicate {e}"); + log::debug!("Ignoring error reading bloom filter: {e}"); metrics.predicate_evaluation_errors.add(1); continue; } }; - column_sbbf.insert(column_name.to_owned(), bf); + column_sbbf.insert(column_name.to_string(), bf); } - if bf_predicates.prune(&column_sbbf) { + + let stats = BloomFilterStatistics { column_sbbf }; + + // Can this group be pruned? 
+ let prune_group = match predicate.prune(&stats) { + Ok(values) => !values[0], + Err(e) => { + log::debug!("Error evaluating row group predicate on bloom filter: {e}"); + metrics.predicate_evaluation_errors.add(1); + false + } + }; + + if prune_group { metrics.row_groups_pruned.add(1); - continue; + } else { + filtered.push(*idx); } - filtered.push(*idx); } filtered } -struct BloomFilterPruningPredicate { - /// Actual pruning predicate - predicate_expr: Option, - /// The statistics required to evaluate this predicate - required_columns: Vec, +/// Implements `PruningStatistics` for Parquet Split Block Bloom Filters (SBBF) +struct BloomFilterStatistics { + /// Maps column name to the parquet bloom filter + column_sbbf: HashMap, } -impl BloomFilterPruningPredicate { - fn try_new(expr: &Arc) -> Result { - let binary_expr = expr.as_any().downcast_ref::(); - match binary_expr { - Some(binary_expr) => { - let columns = Self::get_predicate_columns(expr); - Ok(Self { - predicate_expr: Some(binary_expr.clone()), - required_columns: columns.into_iter().collect(), - }) - } - None => Err(DataFusionError::Execution( - "BloomFilterPruningPredicate only support binary expr".to_string(), - )), - } +impl PruningStatistics for BloomFilterStatistics { + fn min_values(&self, _column: &Column) -> Option { + None } - fn prune(&self, column_sbbf: &HashMap) -> bool { - Self::prune_expr_with_bloom_filter(self.predicate_expr.as_ref(), column_sbbf) + fn max_values(&self, _column: &Column) -> Option { + None } - /// Return true if the `expr` can be proved not `true` - /// based on the bloom filter. - /// - /// We only checked `BinaryExpr` but it also support `InList`, - /// Because of the `optimizer` will convert `InList` to `BinaryExpr`. - fn prune_expr_with_bloom_filter( - expr: Option<&phys_expr::BinaryExpr>, - column_sbbf: &HashMap, - ) -> bool { - let Some(expr) = expr else { - // unsupported predicate - return false; - }; - match expr.op() { - Operator::And | Operator::Or => { - let left = Self::prune_expr_with_bloom_filter( - expr.left().as_any().downcast_ref::(), - column_sbbf, - ); - let right = Self::prune_expr_with_bloom_filter( - expr.right() - .as_any() - .downcast_ref::(), - column_sbbf, - ); - match expr.op() { - Operator::And => left || right, - Operator::Or => left && right, - _ => false, - } - } - Operator::Eq => { - if let Some((col, val)) = Self::check_expr_is_col_equal_const(expr) { - if let Some(sbbf) = column_sbbf.get(col.name()) { - match val { - ScalarValue::Utf8(Some(v)) => !sbbf.check(&v.as_str()), - ScalarValue::Boolean(Some(v)) => !sbbf.check(&v), - ScalarValue::Float64(Some(v)) => !sbbf.check(&v), - ScalarValue::Float32(Some(v)) => !sbbf.check(&v), - ScalarValue::Int64(Some(v)) => !sbbf.check(&v), - ScalarValue::Int32(Some(v)) => !sbbf.check(&v), - ScalarValue::Int16(Some(v)) => !sbbf.check(&v), - ScalarValue::Int8(Some(v)) => !sbbf.check(&v), - _ => false, - } - } else { - false - } - } else { - false - } - } - _ => false, - } + fn num_containers(&self) -> usize { + 1 } - fn get_predicate_columns(expr: &Arc) -> HashSet { - let mut columns = HashSet::new(); - expr.apply(&mut |expr| { - if let Some(binary_expr) = - expr.as_any().downcast_ref::() - { - if let Some((column, _)) = - Self::check_expr_is_col_equal_const(binary_expr) - { - columns.insert(column.name().to_string()); - } - } - Ok(VisitRecursion::Continue) - }) - // no way to fail as only Ok(VisitRecursion::Continue) is returned - .unwrap(); - - columns + fn null_counts(&self, _column: &Column) -> Option { + None } - fn 
check_expr_is_col_equal_const( - exr: &phys_expr::BinaryExpr, - ) -> Option<(phys_expr::Column, ScalarValue)> { - if Operator::Eq.ne(exr.op()) { - return None; - } + /// Use bloom filters to determine if we are sure this column can not + /// possibly contain `values` + /// + /// The `contained` API returns false if the bloom filters knows that *ALL* + /// of the values in a column are not present. + fn contained( + &self, + column: &Column, + values: &HashSet, + ) -> Option { + let sbbf = self.column_sbbf.get(column.name.as_str())?; + + // Bloom filters are probabilistic data structures that can return false + // positives (i.e. it might return true even if the value is not + // present) however, the bloom filter will return `false` if the value is + // definitely not present. + + let known_not_present = values + .iter() + .map(|value| match value { + ScalarValue::Utf8(Some(v)) => sbbf.check(&v.as_str()), + ScalarValue::Boolean(Some(v)) => sbbf.check(v), + ScalarValue::Float64(Some(v)) => sbbf.check(v), + ScalarValue::Float32(Some(v)) => sbbf.check(v), + ScalarValue::Int64(Some(v)) => sbbf.check(v), + ScalarValue::Int32(Some(v)) => sbbf.check(v), + ScalarValue::Int16(Some(v)) => sbbf.check(v), + ScalarValue::Int8(Some(v)) => sbbf.check(v), + _ => true, + }) + // The row group doesn't contain any of the values if + // all the checks are false + .all(|v| !v); + + let contains = if known_not_present { + Some(false) + } else { + // Given the bloom filter is probabilistic, we can't be sure that + // the row group actually contains the values. Return `None` to + // indicate this uncertainty + None + }; - let left_any = exr.left().as_any(); - let right_any = exr.right().as_any(); - if let (Some(col), Some(liter)) = ( - left_any.downcast_ref::(), - right_any.downcast_ref::(), - ) { - return Some((col.clone(), liter.value().clone())); - } - if let (Some(liter), Some(col)) = ( - left_any.downcast_ref::(), - right_any.downcast_ref::(), - ) { - return Some((col.clone(), liter.value().clone())); - } - None + Some(BooleanArray::from(vec![contains])) } } -/// Wraps parquet statistics in a way -/// that implements [`PruningStatistics`] +/// Wraps [`RowGroupMetaData`] in a way that implements [`PruningStatistics`] +/// +/// Note: This should be implemented for an array of [`RowGroupMetaData`] instead +/// of per row-group struct RowGroupPruningStatistics<'a> { + parquet_schema: &'a SchemaDescriptor, row_group_metadata: &'a RowGroupMetaData, - parquet_schema: &'a Schema, -} - -/// Extract the min/max statistics from a `ParquetStatistics` object -macro_rules! 
get_statistic { - ($column_statistics:expr, $func:ident, $bytes_func:ident, $target_arrow_type:expr) => {{ - if !$column_statistics.has_min_max_set() { - return None; - } - match $column_statistics { - ParquetStatistics::Boolean(s) => Some(ScalarValue::Boolean(Some(*s.$func()))), - ParquetStatistics::Int32(s) => { - match $target_arrow_type { - // int32 to decimal with the precision and scale - Some(DataType::Decimal128(precision, scale)) => { - Some(ScalarValue::Decimal128( - Some(*s.$func() as i128), - precision, - scale, - )) - } - _ => Some(ScalarValue::Int32(Some(*s.$func()))), - } - } - ParquetStatistics::Int64(s) => { - match $target_arrow_type { - // int64 to decimal with the precision and scale - Some(DataType::Decimal128(precision, scale)) => { - Some(ScalarValue::Decimal128( - Some(*s.$func() as i128), - precision, - scale, - )) - } - _ => Some(ScalarValue::Int64(Some(*s.$func()))), - } - } - // 96 bit ints not supported - ParquetStatistics::Int96(_) => None, - ParquetStatistics::Float(s) => Some(ScalarValue::Float32(Some(*s.$func()))), - ParquetStatistics::Double(s) => Some(ScalarValue::Float64(Some(*s.$func()))), - ParquetStatistics::ByteArray(s) => { - match $target_arrow_type { - // decimal data type - Some(DataType::Decimal128(precision, scale)) => { - Some(ScalarValue::Decimal128( - Some(from_bytes_to_i128(s.$bytes_func())), - precision, - scale, - )) - } - _ => { - let s = std::str::from_utf8(s.$bytes_func()) - .map(|s| s.to_string()) - .ok(); - Some(ScalarValue::Utf8(s)) - } - } - } - // type not supported yet - ParquetStatistics::FixedLenByteArray(s) => { - match $target_arrow_type { - // just support the decimal data type - Some(DataType::Decimal128(precision, scale)) => { - Some(ScalarValue::Decimal128( - Some(from_bytes_to_i128(s.$bytes_func())), - precision, - scale, - )) - } - _ => None, - } - } - } - }}; + arrow_schema: &'a Schema, } -// Extract the min or max value calling `func` or `bytes_func` on the ParquetStatistics as appropriate -macro_rules! get_min_max_values { - ($self:expr, $column:expr, $func:ident, $bytes_func:ident) => {{ - let (_column_index, field) = - if let Some((v, f)) = $self.parquet_schema.column_with_name(&$column.name) { - (v, f) - } else { - // Named column was not present - return None; - }; - - let data_type = field.data_type(); - // The result may be None, because DataFusion doesn't have support for ScalarValues of the column type - let null_scalar: ScalarValue = data_type.try_into().ok()?; - - $self.row_group_metadata - .columns() - .iter() - .find(|c| c.column_descr().name() == &$column.name) - .and_then(|c| if c.statistics().is_some() {Some((c.statistics().unwrap(), c.column_descr()))} else {None}) - .map(|(stats, column_descr)| - { - let target_data_type = parquet_to_arrow_decimal_type(column_descr); - get_statistic!(stats, $func, $bytes_func, target_data_type) - }) - .flatten() - // column either didn't have statistics at all or didn't have min/max values - .or_else(|| Some(null_scalar.clone())) - .map(|s| s.to_array()) - }} -} - -// Extract the null count value on the ParquetStatistics -macro_rules! 
get_null_count_values { - ($self:expr, $column:expr) => {{ - let value = ScalarValue::UInt64( - if let Some(col) = $self - .row_group_metadata - .columns() - .iter() - .find(|c| c.column_descr().name() == &$column.name) - { - col.statistics().map(|s| s.null_count()) - } else { - Some($self.row_group_metadata.num_rows() as u64) - }, - ); - - Some(value.to_array()) - }}; +impl<'a> RowGroupPruningStatistics<'a> { + /// Lookups up the parquet column by name + fn column(&self, name: &str) -> Option<(&ColumnChunkMetaData, &FieldRef)> { + let (idx, field) = parquet_column(self.parquet_schema, self.arrow_schema, name)?; + Some((self.row_group_metadata.column(idx), field)) + } } impl<'a> PruningStatistics for RowGroupPruningStatistics<'a> { fn min_values(&self, column: &Column) -> Option { - get_min_max_values!(self, column, min, min_bytes) + let (column, field) = self.column(&column.name)?; + min_statistics(field.data_type(), std::iter::once(column.statistics())).ok() } fn max_values(&self, column: &Column) -> Option { - get_min_max_values!(self, column, max, max_bytes) + let (column, field) = self.column(&column.name)?; + max_statistics(field.data_type(), std::iter::once(column.statistics())).ok() } fn num_containers(&self) -> usize { @@ -443,7 +271,17 @@ impl<'a> PruningStatistics for RowGroupPruningStatistics<'a> { } fn null_counts(&self, column: &Column) -> Option { - get_null_count_values!(self, column) + let (c, _) = self.column(&column.name)?; + let scalar = ScalarValue::UInt64(Some(c.statistics()?.null_count())); + scalar.to_array().ok() + } + + fn contained( + &self, + _column: &Column, + _values: &HashSet, + ) -> Option { + None } } @@ -455,14 +293,11 @@ mod tests { use arrow::datatypes::DataType::Decimal128; use arrow::datatypes::Schema; use arrow::datatypes::{DataType, Field}; - use datafusion_common::{config::ConfigOptions, TableReference, ToDFSchema}; - use datafusion_expr::{ - builder::LogicalTableSource, cast, col, lit, AggregateUDF, Expr, ScalarUDF, - TableSource, WindowUDF, - }; + use datafusion_common::{Result, ToDFSchema}; + use datafusion_expr::{cast, col, lit, Expr}; use datafusion_physical_expr::execution_props::ExecutionProps; use datafusion_physical_expr::{create_physical_expr, PhysicalExpr}; - use datafusion_sql::planner::ContextProvider; + use parquet::arrow::arrow_to_parquet_schema; use parquet::arrow::async_reader::ParquetObjectReader; use parquet::basic::LogicalType; use parquet::data_type::{ByteArray, FixedLenByteArray}; @@ -520,11 +355,11 @@ mod tests { fn row_group_pruning_predicate_simple_expr() { use datafusion_expr::{col, lit}; // int > 1 => c1_max > 1 - let schema = Schema::new(vec![Field::new("c1", DataType::Int32, false)]); + let schema = + Arc::new(Schema::new(vec![Field::new("c1", DataType::Int32, false)])); let expr = col("c1").gt(lit(15)); let expr = logical2physical(&expr, &schema); - let pruning_predicate = - PruningPredicate::try_new(expr, Arc::new(schema)).unwrap(); + let pruning_predicate = PruningPredicate::try_new(expr, schema.clone()).unwrap(); let field = PrimitiveTypeField::new("c1", PhysicalType::INT32); let schema_descr = get_test_schema_descr(vec![field]); @@ -540,6 +375,8 @@ mod tests { let metrics = parquet_file_metrics(); assert_eq!( prune_row_groups_by_statistics( + &schema, + &schema_descr, &[rgm1, rgm2], None, Some(&pruning_predicate), @@ -553,11 +390,11 @@ mod tests { fn row_group_pruning_predicate_missing_stats() { use datafusion_expr::{col, lit}; // int > 1 => c1_max > 1 - let schema = Schema::new(vec![Field::new("c1", 
DataType::Int32, false)]); + let schema = + Arc::new(Schema::new(vec![Field::new("c1", DataType::Int32, false)])); let expr = col("c1").gt(lit(15)); let expr = logical2physical(&expr, &schema); - let pruning_predicate = - PruningPredicate::try_new(expr, Arc::new(schema)).unwrap(); + let pruning_predicate = PruningPredicate::try_new(expr, schema.clone()).unwrap(); let field = PrimitiveTypeField::new("c1", PhysicalType::INT32); let schema_descr = get_test_schema_descr(vec![field]); @@ -574,6 +411,8 @@ mod tests { // is null / undefined so the first row group can't be filtered out assert_eq!( prune_row_groups_by_statistics( + &schema, + &schema_descr, &[rgm1, rgm2], None, Some(&pruning_predicate), @@ -621,6 +460,8 @@ mod tests { // when conditions are joined using AND assert_eq!( prune_row_groups_by_statistics( + &schema, + &schema_descr, groups, None, Some(&pruning_predicate), @@ -633,12 +474,14 @@ mod tests { // this bypasses the entire predicate expression and no row groups are filtered out let expr = col("c1").gt(lit(15)).or(col("c2").rem(lit(2)).eq(lit(0))); let expr = logical2physical(&expr, &schema); - let pruning_predicate = PruningPredicate::try_new(expr, schema).unwrap(); + let pruning_predicate = PruningPredicate::try_new(expr, schema.clone()).unwrap(); // if conditions in predicate are joined with OR and an unsupported expression is used // this bypasses the entire predicate expression and no row groups are filtered out assert_eq!( prune_row_groups_by_statistics( + &schema, + &schema_descr, groups, None, Some(&pruning_predicate), @@ -648,6 +491,64 @@ mod tests { ); } + #[test] + fn row_group_pruning_predicate_file_schema() { + use datafusion_expr::{col, lit}; + // test row group predicate when file schema is different than table schema + // c1 > 0 + let table_schema = Arc::new(Schema::new(vec![ + Field::new("c1", DataType::Int32, false), + Field::new("c2", DataType::Int32, false), + ])); + let expr = col("c1").gt(lit(0)); + let expr = logical2physical(&expr, &table_schema); + let pruning_predicate = + PruningPredicate::try_new(expr, table_schema.clone()).unwrap(); + + // Model a file schema's column order c2 then c1, which is the opposite + // of the table schema + let file_schema = Arc::new(Schema::new(vec![ + Field::new("c2", DataType::Int32, false), + Field::new("c1", DataType::Int32, false), + ])); + let schema_descr = get_test_schema_descr(vec![ + PrimitiveTypeField::new("c2", PhysicalType::INT32), + PrimitiveTypeField::new("c1", PhysicalType::INT32), + ]); + // rg1 has c2 less than zero, c1 greater than zero + let rgm1 = get_row_group_meta_data( + &schema_descr, + vec![ + ParquetStatistics::int32(Some(-10), Some(-1), None, 0, false), // c2 + ParquetStatistics::int32(Some(1), Some(10), None, 0, false), + ], + ); + // rg1 has c2 greater than zero, c1 less than zero + let rgm2 = get_row_group_meta_data( + &schema_descr, + vec![ + ParquetStatistics::int32(Some(1), Some(10), None, 0, false), + ParquetStatistics::int32(Some(-10), Some(-1), None, 0, false), + ], + ); + + let metrics = parquet_file_metrics(); + let groups = &[rgm1, rgm2]; + // the first row group should be left because c1 is greater than zero + // the second should be filtered out because c1 is less than zero + assert_eq!( + prune_row_groups_by_statistics( + &file_schema, // NB must be file schema, not table_schema + &schema_descr, + groups, + None, + Some(&pruning_predicate), + &metrics + ), + vec![0] + ); + } + fn gen_row_group_meta_data_for_pruning_predicate() -> Vec { let schema_descr = 
get_test_schema_descr(vec![ PrimitiveTypeField::new("c1", PhysicalType::INT32), @@ -678,15 +579,18 @@ mod tests { Field::new("c1", DataType::Int32, false), Field::new("c2", DataType::Boolean, false), ])); + let schema_descr = arrow_to_parquet_schema(&schema).unwrap(); let expr = col("c1").gt(lit(15)).and(col("c2").is_null()); let expr = logical2physical(&expr, &schema); - let pruning_predicate = PruningPredicate::try_new(expr, schema).unwrap(); + let pruning_predicate = PruningPredicate::try_new(expr, schema.clone()).unwrap(); let groups = gen_row_group_meta_data_for_pruning_predicate(); let metrics = parquet_file_metrics(); // First row group was filtered out because it contains no null value on "c2". assert_eq!( prune_row_groups_by_statistics( + &schema, + &schema_descr, &groups, None, Some(&pruning_predicate), @@ -706,11 +610,12 @@ mod tests { Field::new("c1", DataType::Int32, false), Field::new("c2", DataType::Boolean, false), ])); + let schema_descr = arrow_to_parquet_schema(&schema).unwrap(); let expr = col("c1") .gt(lit(15)) .and(col("c2").eq(lit(ScalarValue::Boolean(None)))); let expr = logical2physical(&expr, &schema); - let pruning_predicate = PruningPredicate::try_new(expr, schema).unwrap(); + let pruning_predicate = PruningPredicate::try_new(expr, schema.clone()).unwrap(); let groups = gen_row_group_meta_data_for_pruning_predicate(); let metrics = parquet_file_metrics(); @@ -718,6 +623,8 @@ mod tests { // pass predicates. Ideally these should both be false assert_eq!( prune_row_groups_by_statistics( + &schema, + &schema_descr, &groups, None, Some(&pruning_predicate), @@ -735,8 +642,11 @@ mod tests { // INT32: c1 > 5, the c1 is decimal(9,2) // The type of scalar value if decimal(9,2), don't need to do cast - let schema = - Schema::new(vec![Field::new("c1", DataType::Decimal128(9, 2), false)]); + let schema = Arc::new(Schema::new(vec![Field::new( + "c1", + DataType::Decimal128(9, 2), + false, + )])); let field = PrimitiveTypeField::new("c1", PhysicalType::INT32) .with_logical_type(LogicalType::Decimal { scale: 2, @@ -747,8 +657,7 @@ mod tests { let schema_descr = get_test_schema_descr(vec![field]); let expr = col("c1").gt(lit(ScalarValue::Decimal128(Some(500), 9, 2))); let expr = logical2physical(&expr, &schema); - let pruning_predicate = - PruningPredicate::try_new(expr, Arc::new(schema)).unwrap(); + let pruning_predicate = PruningPredicate::try_new(expr, schema.clone()).unwrap(); let rgm1 = get_row_group_meta_data( &schema_descr, // [1.00, 6.00] @@ -776,6 +685,8 @@ mod tests { let metrics = parquet_file_metrics(); assert_eq!( prune_row_groups_by_statistics( + &schema, + &schema_descr, &[rgm1, rgm2, rgm3], None, Some(&pruning_predicate), @@ -788,8 +699,11 @@ mod tests { // The c1 type is decimal(9,0) in the parquet file, and the type of scalar is decimal(5,2). 
// We should convert all type to the coercion type, which is decimal(11,2) // The decimal of arrow is decimal(5,2), the decimal of parquet is decimal(9,0) - let schema = - Schema::new(vec![Field::new("c1", DataType::Decimal128(9, 0), false)]); + let schema = Arc::new(Schema::new(vec![Field::new( + "c1", + DataType::Decimal128(9, 0), + false, + )])); let field = PrimitiveTypeField::new("c1", PhysicalType::INT32) .with_logical_type(LogicalType::Decimal { @@ -804,8 +718,7 @@ mod tests { Decimal128(11, 2), )); let expr = logical2physical(&expr, &schema); - let pruning_predicate = - PruningPredicate::try_new(expr, Arc::new(schema)).unwrap(); + let pruning_predicate = PruningPredicate::try_new(expr, schema.clone()).unwrap(); let rgm1 = get_row_group_meta_data( &schema_descr, // [100, 600] @@ -839,6 +752,8 @@ mod tests { let metrics = parquet_file_metrics(); assert_eq!( prune_row_groups_by_statistics( + &schema, + &schema_descr, &[rgm1, rgm2, rgm3, rgm4], None, Some(&pruning_predicate), @@ -848,8 +763,11 @@ mod tests { ); // INT64: c1 < 5, the c1 is decimal(18,2) - let schema = - Schema::new(vec![Field::new("c1", DataType::Decimal128(18, 2), false)]); + let schema = Arc::new(Schema::new(vec![Field::new( + "c1", + DataType::Decimal128(18, 2), + false, + )])); let field = PrimitiveTypeField::new("c1", PhysicalType::INT64) .with_logical_type(LogicalType::Decimal { scale: 2, @@ -860,8 +778,7 @@ mod tests { let schema_descr = get_test_schema_descr(vec![field]); let expr = col("c1").lt(lit(ScalarValue::Decimal128(Some(500), 18, 2))); let expr = logical2physical(&expr, &schema); - let pruning_predicate = - PruningPredicate::try_new(expr, Arc::new(schema)).unwrap(); + let pruning_predicate = PruningPredicate::try_new(expr, schema.clone()).unwrap(); let rgm1 = get_row_group_meta_data( &schema_descr, // [6.00, 8.00] @@ -886,6 +803,8 @@ mod tests { let metrics = parquet_file_metrics(); assert_eq!( prune_row_groups_by_statistics( + &schema, + &schema_descr, &[rgm1, rgm2, rgm3], None, Some(&pruning_predicate), @@ -896,8 +815,11 @@ mod tests { // FIXED_LENGTH_BYTE_ARRAY: c1 = decimal128(100000, 28, 3), the c1 is decimal(18,2) // the type of parquet is decimal(18,2) - let schema = - Schema::new(vec![Field::new("c1", DataType::Decimal128(18, 2), false)]); + let schema = Arc::new(Schema::new(vec![Field::new( + "c1", + DataType::Decimal128(18, 2), + false, + )])); let field = PrimitiveTypeField::new("c1", PhysicalType::FIXED_LEN_BYTE_ARRAY) .with_logical_type(LogicalType::Decimal { scale: 2, @@ -911,8 +833,7 @@ mod tests { let left = cast(col("c1"), DataType::Decimal128(28, 3)); let expr = left.eq(lit(ScalarValue::Decimal128(Some(100000), 28, 3))); let expr = logical2physical(&expr, &schema); - let pruning_predicate = - PruningPredicate::try_new(expr, Arc::new(schema)).unwrap(); + let pruning_predicate = PruningPredicate::try_new(expr, schema.clone()).unwrap(); // we must use the big-endian when encode the i128 to bytes or vec[u8]. 
let rgm1 = get_row_group_meta_data( &schema_descr, @@ -956,6 +877,8 @@ mod tests { let metrics = parquet_file_metrics(); assert_eq!( prune_row_groups_by_statistics( + &schema, + &schema_descr, &[rgm1, rgm2, rgm3], None, Some(&pruning_predicate), @@ -966,8 +889,11 @@ mod tests { // BYTE_ARRAY: c1 = decimal128(100000, 28, 3), the c1 is decimal(18,2) // the type of parquet is decimal(18,2) - let schema = - Schema::new(vec![Field::new("c1", DataType::Decimal128(18, 2), false)]); + let schema = Arc::new(Schema::new(vec![Field::new( + "c1", + DataType::Decimal128(18, 2), + false, + )])); let field = PrimitiveTypeField::new("c1", PhysicalType::BYTE_ARRAY) .with_logical_type(LogicalType::Decimal { scale: 2, @@ -981,8 +907,7 @@ mod tests { let left = cast(col("c1"), DataType::Decimal128(28, 3)); let expr = left.eq(lit(ScalarValue::Decimal128(Some(100000), 28, 3))); let expr = logical2physical(&expr, &schema); - let pruning_predicate = - PruningPredicate::try_new(expr, Arc::new(schema)).unwrap(); + let pruning_predicate = PruningPredicate::try_new(expr, schema.clone()).unwrap(); // we must use the big-endian when encode the i128 to bytes or vec[u8]. let rgm1 = get_row_group_meta_data( &schema_descr, @@ -1015,6 +940,8 @@ mod tests { let metrics = parquet_file_metrics(); assert_eq!( prune_row_groups_by_statistics( + &schema, + &schema_descr, &[rgm1, rgm2, rgm3], None, Some(&pruning_predicate), @@ -1028,7 +955,6 @@ mod tests { schema_descr: &SchemaDescPtr, column_statistics: Vec, ) -> RowGroupMetaData { - use parquet::file::metadata::ColumnChunkMetaData; let mut columns = vec![]; for (i, s) in column_statistics.iter().enumerate() { let column = ColumnChunkMetaData::builder(schema_descr.column(i)) @@ -1046,7 +972,7 @@ mod tests { } fn get_test_schema_descr(fields: Vec) -> SchemaDescPtr { - use parquet::schema::types::{SchemaDescriptor, Type as SchemaType}; + use parquet::schema::types::Type as SchemaType; let schema_fields = fields .iter() .map(|field| { @@ -1089,33 +1015,30 @@ mod tests { #[tokio::test] async fn test_row_group_bloom_filter_pruning_predicate_simple_expr() { - // load parquet file - let testdata = datafusion_common::test_util::parquet_test_data(); - let file_name = "data_index_bloom_encoding_stats.parquet"; - let path = format!("{testdata}/{file_name}"); - let data = bytes::Bytes::from(std::fs::read(path).unwrap()); - - // generate pruning predicate - let schema = Schema::new(vec![Field::new("String", DataType::Utf8, false)]); - let expr = col(r#""String""#).eq(lit("Hello_Not_Exists")); - let expr = logical2physical(&expr, &schema); - let pruning_predicate = - PruningPredicate::try_new(expr, Arc::new(schema)).unwrap(); - - let row_groups = vec![0]; - let pruned_row_groups = test_row_group_bloom_filter_pruning_predicate( - file_name, - data, - &pruning_predicate, - &row_groups, - ) - .await - .unwrap(); - assert!(pruned_row_groups.is_empty()); + BloomFilterTest::new_data_index_bloom_encoding_stats() + .with_expect_all_pruned() + // generate pruning predicate `(String = "Hello_Not_exists")` + .run(col(r#""String""#).eq(lit("Hello_Not_Exists"))) + .await } #[tokio::test] async fn test_row_group_bloom_filter_pruning_predicate_mutiple_expr() { + BloomFilterTest::new_data_index_bloom_encoding_stats() + .with_expect_all_pruned() + // generate pruning predicate `(String = "Hello_Not_exists" OR String = "Hello_Not_exists2")` + .run( + lit("1").eq(lit("1")).and( + col(r#""String""#) + .eq(lit("Hello_Not_Exists")) + .or(col(r#""String""#).eq(lit("Hello_Not_Exists2"))), + ), + ) + .await + } + + 
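    // A minimal sketch of the decision rule the surrounding bloom-filter tests
    // exercise: a row group may be pruned only when *every* literal compared
    // against the column is reported as definitely absent by the bloom filter
    // probe (`Sbbf::check`). `might_contain` below is an assumed stand-in for a
    // real bloom filter, and the sample values mirror the `String` column of
    // `data_index_bloom_encoding_stats.parquet`.
    #[test]
    fn bloom_filter_pruning_rule_sketch() {
        fn known_not_present<'a>(
            might_contain: impl Fn(&str) -> bool,
            values: impl IntoIterator<Item = &'a str>,
        ) -> bool {
            // Prune only if all membership probes come back negative.
            values.into_iter().all(|v| !might_contain(v))
        }

        // Hypothetical stand-in for `Sbbf::check`: "Hello" and "the quick"
        // are present in the file, everything else is not.
        let might_contain = |v: &str| matches!(v, "Hello" | "the quick");

        // `String = 'Hello_Not_Exists' OR String = 'Hello_Not_Exists2'`:
        // both values are definitely absent, so the row group is pruned.
        assert!(known_not_present(
            &might_contain,
            ["Hello_Not_Exists", "Hello_Not_Exists2"]
        ));

        // `String = 'Hello' OR String = 'Hello_Not_Exists'`: "Hello" might be
        // present, so the row group cannot be pruned.
        assert!(!known_not_present(
            &might_contain,
            ["Hello", "Hello_Not_Exists"]
        ));
    }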
#[tokio::test] + async fn test_row_group_bloom_filter_pruning_predicate_sql_in() { // load parquet file let testdata = datafusion_common::test_util::parquet_test_data(); let file_name = "data_index_bloom_encoding_stats.parquet"; @@ -1124,10 +1047,15 @@ mod tests { // generate pruning predicate let schema = Schema::new(vec![Field::new("String", DataType::Utf8, false)]); - let expr = lit("1").eq(lit("1")).and( - col(r#""String""#) - .eq(lit("Hello_Not_Exists")) - .or(col(r#""String""#).eq(lit("Hello_Not_Exists2"))), + + let expr = col(r#""String""#).in_list( + vec![ + lit("Hello_Not_Exists"), + lit("Hello_Not_Exists2"), + lit("Hello_Not_Exists3"), + lit("Hello_Not_Exist4"), + ], + false, ); let expr = logical2physical(&expr, &schema); let pruning_predicate = @@ -1146,88 +1074,162 @@ mod tests { } #[tokio::test] - async fn test_row_group_bloom_filter_pruning_predicate_sql_in() { - // load parquet file - let testdata = datafusion_common::test_util::parquet_test_data(); - let file_name = "data_index_bloom_encoding_stats.parquet"; - let path = format!("{testdata}/{file_name}"); - let data = bytes::Bytes::from(std::fs::read(path).unwrap()); - - // generate pruning predicate - let schema = Schema::new(vec![ - Field::new("String", DataType::Utf8, false), - Field::new("String3", DataType::Utf8, false), - ]); - let sql = - "SELECT * FROM tbl WHERE \"String\" IN ('Hello_Not_Exists', 'Hello_Not_Exists2')"; - let expr = sql_to_physical_plan(sql).unwrap(); - let pruning_predicate = - PruningPredicate::try_new(expr, Arc::new(schema)).unwrap(); - - let row_groups = vec![0]; - let pruned_row_groups = test_row_group_bloom_filter_pruning_predicate( - file_name, - data, - &pruning_predicate, - &row_groups, - ) - .await - .unwrap(); - assert!(pruned_row_groups.is_empty()); + async fn test_row_group_bloom_filter_pruning_predicate_with_exists_value() { + BloomFilterTest::new_data_index_bloom_encoding_stats() + .with_expect_none_pruned() + // generate pruning predicate `(String = "Hello")` + .run(col(r#""String""#).eq(lit("Hello"))) + .await } #[tokio::test] - async fn test_row_group_bloom_filter_pruning_predicate_with_exists_value() { - // load parquet file - let testdata = datafusion_common::test_util::parquet_test_data(); - let file_name = "data_index_bloom_encoding_stats.parquet"; - let path = format!("{testdata}/{file_name}"); - let data = bytes::Bytes::from(std::fs::read(path).unwrap()); + async fn test_row_group_bloom_filter_pruning_predicate_with_exists_2_values() { + BloomFilterTest::new_data_index_bloom_encoding_stats() + .with_expect_none_pruned() + // generate pruning predicate `(String = "Hello") OR (String = "the quick")` + .run( + col(r#""String""#) + .eq(lit("Hello")) + .or(col(r#""String""#).eq(lit("the quick"))), + ) + .await + } - // generate pruning predicate - let schema = Schema::new(vec![Field::new("String", DataType::Utf8, false)]); - let expr = col(r#""String""#).eq(lit("Hello")); - let expr = logical2physical(&expr, &schema); - let pruning_predicate = - PruningPredicate::try_new(expr, Arc::new(schema)).unwrap(); + #[tokio::test] + async fn test_row_group_bloom_filter_pruning_predicate_with_exists_3_values() { + BloomFilterTest::new_data_index_bloom_encoding_stats() + .with_expect_none_pruned() + // generate pruning predicate `(String = "Hello") OR (String = "the quick") OR (String = "are you")` + .run( + col(r#""String""#) + .eq(lit("Hello")) + .or(col(r#""String""#).eq(lit("the quick"))) + .or(col(r#""String""#).eq(lit("are you"))), + ) + .await + } - let row_groups = vec![0]; - let 
pruned_row_groups = test_row_group_bloom_filter_pruning_predicate( - file_name, - data, - &pruning_predicate, - &row_groups, - ) - .await - .unwrap(); - assert_eq!(pruned_row_groups, row_groups); + #[tokio::test] + async fn test_row_group_bloom_filter_pruning_predicate_with_or_not_eq() { + BloomFilterTest::new_data_index_bloom_encoding_stats() + .with_expect_none_pruned() + // generate pruning predicate `(String = "foo") OR (String != "bar")` + .run( + col(r#""String""#) + .not_eq(lit("foo")) + .or(col(r#""String""#).not_eq(lit("bar"))), + ) + .await } #[tokio::test] async fn test_row_group_bloom_filter_pruning_predicate_without_bloom_filter() { - // load parquet file - let testdata = datafusion_common::test_util::parquet_test_data(); - let file_name = "alltypes_plain.parquet"; - let path = format!("{testdata}/{file_name}"); - let data = bytes::Bytes::from(std::fs::read(path).unwrap()); + // generate pruning predicate on a column without a bloom filter + BloomFilterTest::new_all_types() + .with_expect_none_pruned() + .run(col(r#""string_col""#).eq(lit("0"))) + .await + } - // generate pruning predicate - let schema = Schema::new(vec![Field::new("string_col", DataType::Utf8, false)]); - let expr = col(r#""string_col""#).eq(lit("0")); - let expr = logical2physical(&expr, &schema); - let pruning_predicate = - PruningPredicate::try_new(expr, Arc::new(schema)).unwrap(); + struct BloomFilterTest { + file_name: String, + schema: Schema, + // which row groups should be attempted to prune + row_groups: Vec, + // which row groups are expected to be left after pruning. Must be set + // otherwise will panic on run() + post_pruning_row_groups: Option>, + } - let row_groups = vec![0]; - let pruned_row_groups = test_row_group_bloom_filter_pruning_predicate( - file_name, - data, - &pruning_predicate, - &row_groups, - ) - .await - .unwrap(); - assert_eq!(pruned_row_groups, row_groups); + impl BloomFilterTest { + /// Return a test for data_index_bloom_encoding_stats.parquet + /// Note the values in the `String` column are: + /// ```sql + /// ❯ select * from './parquet-testing/data/data_index_bloom_encoding_stats.parquet'; + /// +-----------+ + /// | String | + /// +-----------+ + /// | Hello | + /// | This is | + /// | a | + /// | test | + /// | How | + /// | are you | + /// | doing | + /// | today | + /// | the quick | + /// | brown fox | + /// | jumps | + /// | over | + /// | the lazy | + /// | dog | + /// +-----------+ + /// ``` + fn new_data_index_bloom_encoding_stats() -> Self { + Self { + file_name: String::from("data_index_bloom_encoding_stats.parquet"), + schema: Schema::new(vec![Field::new("String", DataType::Utf8, false)]), + row_groups: vec![0], + post_pruning_row_groups: None, + } + } + + // Return a test for alltypes_plain.parquet + fn new_all_types() -> Self { + Self { + file_name: String::from("alltypes_plain.parquet"), + schema: Schema::new(vec![Field::new( + "string_col", + DataType::Utf8, + false, + )]), + row_groups: vec![0], + post_pruning_row_groups: None, + } + } + + /// Expect all row groups to be pruned + pub fn with_expect_all_pruned(mut self) -> Self { + self.post_pruning_row_groups = Some(vec![]); + self + } + + /// Expect all row groups not to be pruned + pub fn with_expect_none_pruned(mut self) -> Self { + self.post_pruning_row_groups = Some(self.row_groups.clone()); + self + } + + /// Prune this file using the specified expression and check that the expected row groups are left + async fn run(self, expr: Expr) { + let Self { + file_name, + schema, + row_groups, + 
post_pruning_row_groups, + } = self; + + let post_pruning_row_groups = + post_pruning_row_groups.expect("post_pruning_row_groups must be set"); + + let testdata = datafusion_common::test_util::parquet_test_data(); + let path = format!("{testdata}/{file_name}"); + let data = bytes::Bytes::from(std::fs::read(path).unwrap()); + + let expr = logical2physical(&expr, &schema); + let pruning_predicate = + PruningPredicate::try_new(expr, Arc::new(schema)).unwrap(); + + let pruned_row_groups = test_row_group_bloom_filter_pruning_predicate( + &file_name, + data, + &pruning_predicate, + &row_groups, + ) + .await + .unwrap(); + assert_eq!(pruned_row_groups, post_pruning_row_groups); + } } async fn test_row_group_bloom_filter_pruning_predicate( @@ -1243,6 +1245,7 @@ mod tests { last_modified: chrono::DateTime::from(std::time::SystemTime::now()), size: data.len(), e_tag: None, + version: None, }; let in_memory = object_store::memory::InMemory::new(); in_memory @@ -1261,6 +1264,7 @@ mod tests { let metadata = builder.metadata().clone(); let pruned_row_group = prune_row_groups_by_bloom_filters( + pruning_predicate.schema(), &mut builder, row_groups, metadata.row_groups(), @@ -1271,97 +1275,4 @@ mod tests { Ok(pruned_row_group) } - - fn sql_to_physical_plan(sql: &str) -> Result> { - use datafusion_optimizer::{ - analyzer::Analyzer, optimizer::Optimizer, OptimizerConfig, OptimizerContext, - }; - use datafusion_sql::{ - planner::SqlToRel, - sqlparser::{ast::Statement, parser::Parser}, - }; - use sqlparser::dialect::GenericDialect; - - // parse the SQL - let dialect = GenericDialect {}; // or AnsiDialect, or your own dialect ... - let ast: Vec = Parser::parse_sql(&dialect, sql).unwrap(); - let statement = &ast[0]; - - // create a logical query plan - let schema_provider = TestSchemaProvider::new(); - let sql_to_rel = SqlToRel::new(&schema_provider); - let plan = sql_to_rel.sql_statement_to_plan(statement.clone()).unwrap(); - - // hard code the return value of now() - let config = OptimizerContext::new().with_skip_failing_rules(false); - let analyzer = Analyzer::new(); - let optimizer = Optimizer::new(); - // analyze and optimize the logical plan - let plan = analyzer.execute_and_check(&plan, config.options(), |_, _| {})?; - let plan = optimizer.optimize(&plan, &config, |_, _| {})?; - // convert the logical plan into a physical plan - let exprs = plan.expressions(); - let expr = &exprs[0]; - let df_schema = plan.schema().as_ref().to_owned(); - let tb_schema: Schema = df_schema.clone().into(); - let execution_props = ExecutionProps::new(); - create_physical_expr(expr, &df_schema, &tb_schema, &execution_props) - } - - struct TestSchemaProvider { - options: ConfigOptions, - tables: HashMap>, - } - - impl TestSchemaProvider { - pub fn new() -> Self { - let mut tables = HashMap::new(); - tables.insert( - "tbl".to_string(), - create_table_source(vec![Field::new( - "String".to_string(), - DataType::Utf8, - false, - )]), - ); - - Self { - options: Default::default(), - tables, - } - } - } - - impl ContextProvider for TestSchemaProvider { - fn get_table_source(&self, name: TableReference) -> Result> { - match self.tables.get(name.table()) { - Some(table) => Ok(table.clone()), - _ => datafusion_common::plan_err!("Table not found: {}", name.table()), - } - } - - fn get_function_meta(&self, _name: &str) -> Option> { - None - } - - fn get_aggregate_meta(&self, _name: &str) -> Option> { - None - } - - fn get_variable_type(&self, _variable_names: &[String]) -> Option { - None - } - - fn options(&self) -> &ConfigOptions { - 
&self.options - } - - fn get_window_meta(&self, _name: &str) -> Option> { - None - } - } - - fn create_table_source(fields: Vec) -> Arc { - Arc::new(LogicalTableSource::new(Arc::new(Schema::new(fields)))) - } } diff --git a/datafusion/core/src/datasource/physical_plan/parquet/statistics.rs b/datafusion/core/src/datasource/physical_plan/parquet/statistics.rs new file mode 100644 index 0000000000000..4e472606da515 --- /dev/null +++ b/datafusion/core/src/datasource/physical_plan/parquet/statistics.rs @@ -0,0 +1,899 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`min_statistics`] and [`max_statistics`] convert statistics in parquet format to arrow [`ArrayRef`]. + +// TODO: potentially move this to arrow-rs: https://github.com/apache/arrow-rs/issues/4328 + +use arrow::{array::ArrayRef, datatypes::DataType}; +use arrow_array::new_empty_array; +use arrow_schema::{FieldRef, Schema}; +use datafusion_common::{Result, ScalarValue}; +use parquet::file::statistics::Statistics as ParquetStatistics; +use parquet::schema::types::SchemaDescriptor; + +// Convert the bytes array to i128. +// The endian of the input bytes array must be big-endian. +pub(crate) fn from_bytes_to_i128(b: &[u8]) -> i128 { + // The bytes array are from parquet file and must be the big-endian. + // The endian is defined by parquet format, and the reference document + // https://github.com/apache/parquet-format/blob/54e53e5d7794d383529dd30746378f19a12afd58/src/main/thrift/parquet.thrift#L66 + i128::from_be_bytes(sign_extend_be(b)) +} + +// Copy from arrow-rs +// https://github.com/apache/arrow-rs/blob/733b7e7fd1e8c43a404c3ce40ecf741d493c21b4/parquet/src/arrow/buffer/bit_util.rs#L55 +// Convert the byte slice to fixed length byte array with the length of 16 +fn sign_extend_be(b: &[u8]) -> [u8; 16] { + assert!(b.len() <= 16, "Array too large, expected less than 16"); + let is_negative = (b[0] & 128u8) == 128u8; + let mut result = if is_negative { [255u8; 16] } else { [0u8; 16] }; + for (d, s) in result.iter_mut().skip(16 - b.len()).zip(b) { + *d = *s; + } + result +} + +/// Extract a single min/max statistics from a [`ParquetStatistics`] object +/// +/// * `$column_statistics` is the `ParquetStatistics` object +/// * `$func is the function` (`min`/`max`) to call to get the value +/// * `$bytes_func` is the function (`min_bytes`/`max_bytes`) to call to get the value as bytes +/// * `$target_arrow_type` is the [`DataType`] of the target statistics +macro_rules! 
get_statistic { + ($column_statistics:expr, $func:ident, $bytes_func:ident, $target_arrow_type:expr) => {{ + if !$column_statistics.has_min_max_set() { + return None; + } + match $column_statistics { + ParquetStatistics::Boolean(s) => Some(ScalarValue::Boolean(Some(*s.$func()))), + ParquetStatistics::Int32(s) => { + match $target_arrow_type { + // int32 to decimal with the precision and scale + Some(DataType::Decimal128(precision, scale)) => { + Some(ScalarValue::Decimal128( + Some(*s.$func() as i128), + *precision, + *scale, + )) + } + _ => Some(ScalarValue::Int32(Some(*s.$func()))), + } + } + ParquetStatistics::Int64(s) => { + match $target_arrow_type { + // int64 to decimal with the precision and scale + Some(DataType::Decimal128(precision, scale)) => { + Some(ScalarValue::Decimal128( + Some(*s.$func() as i128), + *precision, + *scale, + )) + } + _ => Some(ScalarValue::Int64(Some(*s.$func()))), + } + } + // 96 bit ints not supported + ParquetStatistics::Int96(_) => None, + ParquetStatistics::Float(s) => Some(ScalarValue::Float32(Some(*s.$func()))), + ParquetStatistics::Double(s) => Some(ScalarValue::Float64(Some(*s.$func()))), + ParquetStatistics::ByteArray(s) => { + match $target_arrow_type { + // decimal data type + Some(DataType::Decimal128(precision, scale)) => { + Some(ScalarValue::Decimal128( + Some(from_bytes_to_i128(s.$bytes_func())), + *precision, + *scale, + )) + } + _ => { + let s = std::str::from_utf8(s.$bytes_func()) + .map(|s| s.to_string()) + .ok(); + Some(ScalarValue::Utf8(s)) + } + } + } + // type not supported yet + ParquetStatistics::FixedLenByteArray(s) => { + match $target_arrow_type { + // just support the decimal data type + Some(DataType::Decimal128(precision, scale)) => { + Some(ScalarValue::Decimal128( + Some(from_bytes_to_i128(s.$bytes_func())), + *precision, + *scale, + )) + } + _ => None, + } + } + } + }}; +} + +/// Lookups up the parquet column by name +/// +/// Returns the parquet column index and the corresponding arrow field +pub(crate) fn parquet_column<'a>( + parquet_schema: &SchemaDescriptor, + arrow_schema: &'a Schema, + name: &str, +) -> Option<(usize, &'a FieldRef)> { + let (root_idx, field) = arrow_schema.fields.find(name)?; + if field.data_type().is_nested() { + // Nested fields are not supported and require non-trivial logic + // to correctly walk the parquet schema accounting for the + // logical type rules - + // + // For example a ListArray could correspond to anything from 1 to 3 levels + // in the parquet schema + return None; + } + + // This could be made more efficient (#TBD) + let parquet_idx = (0..parquet_schema.columns().len()) + .find(|x| parquet_schema.get_column_root_idx(*x) == root_idx)?; + Some((parquet_idx, field)) +} + +/// Extracts the min statistics from an iterator of [`ParquetStatistics`] to an [`ArrayRef`] +pub(crate) fn min_statistics<'a, I: Iterator>>( + data_type: &DataType, + iterator: I, +) -> Result { + let scalars = iterator + .map(|x| x.and_then(|s| get_statistic!(s, min, min_bytes, Some(data_type)))); + collect_scalars(data_type, scalars) +} + +/// Extracts the max statistics from an iterator of [`ParquetStatistics`] to an [`ArrayRef`] +pub(crate) fn max_statistics<'a, I: Iterator>>( + data_type: &DataType, + iterator: I, +) -> Result { + let scalars = iterator + .map(|x| x.and_then(|s| get_statistic!(s, max, max_bytes, Some(data_type)))); + collect_scalars(data_type, scalars) +} + +/// Builds an array from an iterator of ScalarValue +fn collect_scalars>>( + data_type: &DataType, + iterator: I, +) -> Result { + let 
mut scalars = iterator.peekable(); + match scalars.peek().is_none() { + true => Ok(new_empty_array(data_type)), + false => { + let null = ScalarValue::try_from(data_type)?; + ScalarValue::iter_to_array(scalars.map(|x| x.unwrap_or_else(|| null.clone()))) + } + } +} + +#[cfg(test)] +mod test { + use super::*; + use arrow_array::{ + new_null_array, Array, BinaryArray, BooleanArray, Decimal128Array, Float32Array, + Float64Array, Int32Array, Int64Array, RecordBatch, StringArray, StructArray, + TimestampNanosecondArray, + }; + use arrow_schema::{Field, SchemaRef}; + use bytes::Bytes; + use datafusion_common::test_util::parquet_test_data; + use parquet::arrow::arrow_reader::ArrowReaderBuilder; + use parquet::arrow::arrow_writer::ArrowWriter; + use parquet::file::metadata::{ParquetMetaData, RowGroupMetaData}; + use parquet::file::properties::{EnabledStatistics, WriterProperties}; + use std::path::PathBuf; + use std::sync::Arc; + + // TODO error cases (with parquet statistics that are mismatched in expected type) + + #[test] + fn roundtrip_empty() { + let empty_bool_array = new_empty_array(&DataType::Boolean); + Test { + input: empty_bool_array.clone(), + expected_min: empty_bool_array.clone(), + expected_max: empty_bool_array.clone(), + } + .run() + } + + #[test] + fn roundtrip_bool() { + Test { + input: bool_array([ + // row group 1 + Some(true), + None, + Some(true), + // row group 2 + Some(true), + Some(false), + None, + // row group 3 + None, + None, + None, + ]), + expected_min: bool_array([Some(true), Some(false), None]), + expected_max: bool_array([Some(true), Some(true), None]), + } + .run() + } + + #[test] + fn roundtrip_int32() { + Test { + input: i32_array([ + // row group 1 + Some(1), + None, + Some(3), + // row group 2 + Some(0), + Some(5), + None, + // row group 3 + None, + None, + None, + ]), + expected_min: i32_array([Some(1), Some(0), None]), + expected_max: i32_array([Some(3), Some(5), None]), + } + .run() + } + + #[test] + fn roundtrip_int64() { + Test { + input: i64_array([ + // row group 1 + Some(1), + None, + Some(3), + // row group 2 + Some(0), + Some(5), + None, + // row group 3 + None, + None, + None, + ]), + expected_min: i64_array([Some(1), Some(0), None]), + expected_max: i64_array(vec![Some(3), Some(5), None]), + } + .run() + } + + #[test] + fn roundtrip_f32() { + Test { + input: f32_array([ + // row group 1 + Some(1.0), + None, + Some(3.0), + // row group 2 + Some(-1.0), + Some(5.0), + None, + // row group 3 + None, + None, + None, + ]), + expected_min: f32_array([Some(1.0), Some(-1.0), None]), + expected_max: f32_array([Some(3.0), Some(5.0), None]), + } + .run() + } + + #[test] + fn roundtrip_f64() { + Test { + input: f64_array([ + // row group 1 + Some(1.0), + None, + Some(3.0), + // row group 2 + Some(-1.0), + Some(5.0), + None, + // row group 3 + None, + None, + None, + ]), + expected_min: f64_array([Some(1.0), Some(-1.0), None]), + expected_max: f64_array([Some(3.0), Some(5.0), None]), + } + .run() + } + + #[test] + #[should_panic( + expected = "Inconsistent types in ScalarValue::iter_to_array. 
Expected Int64, got TimestampNanosecond(NULL, None)" + )] + // Due to https://github.com/apache/arrow-datafusion/issues/8295 + fn roundtrip_timestamp() { + Test { + input: timestamp_array([ + // row group 1 + Some(1), + None, + Some(3), + // row group 2 + Some(9), + Some(5), + None, + // row group 3 + None, + None, + None, + ]), + expected_min: timestamp_array([Some(1), Some(5), None]), + expected_max: timestamp_array([Some(3), Some(9), None]), + } + .run() + } + + #[test] + fn roundtrip_decimal() { + Test { + input: Arc::new( + Decimal128Array::from(vec![ + // row group 1 + Some(100), + None, + Some(22000), + // row group 2 + Some(500000), + Some(330000), + None, + // row group 3 + None, + None, + None, + ]) + .with_precision_and_scale(9, 2) + .unwrap(), + ), + expected_min: Arc::new( + Decimal128Array::from(vec![Some(100), Some(330000), None]) + .with_precision_and_scale(9, 2) + .unwrap(), + ), + expected_max: Arc::new( + Decimal128Array::from(vec![Some(22000), Some(500000), None]) + .with_precision_and_scale(9, 2) + .unwrap(), + ), + } + .run() + } + + #[test] + fn roundtrip_utf8() { + Test { + input: utf8_array([ + // row group 1 + Some("A"), + None, + Some("Q"), + // row group 2 + Some("ZZ"), + Some("AA"), + None, + // row group 3 + None, + None, + None, + ]), + expected_min: utf8_array([Some("A"), Some("AA"), None]), + expected_max: utf8_array([Some("Q"), Some("ZZ"), None]), + } + .run() + } + + #[test] + fn roundtrip_struct() { + let mut test = Test { + input: struct_array(vec![ + // row group 1 + (Some(true), Some(1)), + (None, None), + (Some(true), Some(3)), + // row group 2 + (Some(true), Some(0)), + (Some(false), Some(5)), + (None, None), + // row group 3 + (None, None), + (None, None), + (None, None), + ]), + expected_min: struct_array(vec![ + (Some(true), Some(1)), + (Some(true), Some(0)), + (None, None), + ]), + + expected_max: struct_array(vec![ + (Some(true), Some(3)), + (Some(true), Some(0)), + (None, None), + ]), + }; + // Due to https://github.com/apache/arrow-datafusion/issues/8334, + // statistics for struct arrays are not supported + test.expected_min = + new_null_array(test.input.data_type(), test.expected_min.len()); + test.expected_max = + new_null_array(test.input.data_type(), test.expected_min.len()); + test.run() + } + + #[test] + #[should_panic( + expected = "Inconsistent types in ScalarValue::iter_to_array. 
Expected Utf8, got Binary(NULL)" + )] + // Due to https://github.com/apache/arrow-datafusion/issues/8295 + fn roundtrip_binary() { + Test { + input: Arc::new(BinaryArray::from_opt_vec(vec![ + // row group 1 + Some(b"A"), + None, + Some(b"Q"), + // row group 2 + Some(b"ZZ"), + Some(b"AA"), + None, + // row group 3 + None, + None, + None, + ])), + expected_min: Arc::new(BinaryArray::from_opt_vec(vec![ + Some(b"A"), + Some(b"AA"), + None, + ])), + expected_max: Arc::new(BinaryArray::from_opt_vec(vec![ + Some(b"Q"), + Some(b"ZZ"), + None, + ])), + } + .run() + } + + #[test] + fn struct_and_non_struct() { + // Ensures that statistics for an array that appears *after* a struct + // array are not wrong + let struct_col = struct_array(vec![ + // row group 1 + (Some(true), Some(1)), + (None, None), + (Some(true), Some(3)), + ]); + let int_col = i32_array([Some(100), Some(200), Some(300)]); + let expected_min = i32_array([Some(100)]); + let expected_max = i32_array(vec![Some(300)]); + + // use a name that shadows a name in the struct column + match struct_col.data_type() { + DataType::Struct(fields) => { + assert_eq!(fields.get(1).unwrap().name(), "int_col") + } + _ => panic!("unexpected data type for struct column"), + }; + + let input_batch = RecordBatch::try_from_iter([ + ("struct_col", struct_col), + ("int_col", int_col), + ]) + .unwrap(); + + let schema = input_batch.schema(); + + let metadata = parquet_metadata(schema.clone(), input_batch); + let parquet_schema = metadata.file_metadata().schema_descr(); + + // read the int_col statistics + let (idx, _) = parquet_column(parquet_schema, &schema, "int_col").unwrap(); + assert_eq!(idx, 2); + + let row_groups = metadata.row_groups(); + let iter = row_groups.iter().map(|x| x.column(idx).statistics()); + + let min = min_statistics(&DataType::Int32, iter.clone()).unwrap(); + assert_eq!( + &min, + &expected_min, + "Min. Statistics\n\n{}\n\n", + DisplayStats(row_groups) + ); + + let max = max_statistics(&DataType::Int32, iter).unwrap(); + assert_eq!( + &max, + &expected_max, + "Max. 
Statistics\n\n{}\n\n", + DisplayStats(row_groups) + ); + } + + #[test] + fn nan_in_stats() { + // /parquet-testing/data/nan_in_stats.parquet + // row_groups: 1 + // "x": Double({min: Some(1.0), max: Some(NaN), distinct_count: None, null_count: 0, min_max_deprecated: false, min_max_backwards_compatible: false}) + + TestFile::new("nan_in_stats.parquet") + .with_column(ExpectedColumn { + name: "x", + expected_min: Arc::new(Float64Array::from(vec![Some(1.0)])), + expected_max: Arc::new(Float64Array::from(vec![Some(f64::NAN)])), + }) + .run(); + } + + #[test] + fn alltypes_plain() { + // /parquet-testing/data/datapage_v1-snappy-compressed-checksum.parquet + // row_groups: 1 + // (has no statistics) + TestFile::new("alltypes_plain.parquet") + // No column statistics should be read as NULL, but with the right type + .with_column(ExpectedColumn { + name: "id", + expected_min: i32_array([None]), + expected_max: i32_array([None]), + }) + .with_column(ExpectedColumn { + name: "bool_col", + expected_min: bool_array([None]), + expected_max: bool_array([None]), + }) + .run(); + } + + #[test] + fn alltypes_tiny_pages() { + // /parquet-testing/data/alltypes_tiny_pages.parquet + // row_groups: 1 + // "id": Int32({min: Some(0), max: Some(7299), distinct_count: None, null_count: 0, min_max_deprecated: false, min_max_backwards_compatible: false}) + // "bool_col": Boolean({min: Some(false), max: Some(true), distinct_count: None, null_count: 0, min_max_deprecated: false, min_max_backwards_compatible: false}) + // "tinyint_col": Int32({min: Some(0), max: Some(9), distinct_count: None, null_count: 0, min_max_deprecated: false, min_max_backwards_compatible: false}) + // "smallint_col": Int32({min: Some(0), max: Some(9), distinct_count: None, null_count: 0, min_max_deprecated: false, min_max_backwards_compatible: false}) + // "int_col": Int32({min: Some(0), max: Some(9), distinct_count: None, null_count: 0, min_max_deprecated: false, min_max_backwards_compatible: false}) + // "bigint_col": Int64({min: Some(0), max: Some(90), distinct_count: None, null_count: 0, min_max_deprecated: false, min_max_backwards_compatible: false}) + // "float_col": Float({min: Some(0.0), max: Some(9.9), distinct_count: None, null_count: 0, min_max_deprecated: false, min_max_backwards_compatible: false}) + // "double_col": Double({min: Some(0.0), max: Some(90.89999999999999), distinct_count: None, null_count: 0, min_max_deprecated: false, min_max_backwards_compatible: false}) + // "date_string_col": ByteArray({min: Some(ByteArray { data: "01/01/09" }), max: Some(ByteArray { data: "12/31/10" }), distinct_count: None, null_count: 0, min_max_deprecated: false, min_max_backwards_compatible: false}) + // "string_col": ByteArray({min: Some(ByteArray { data: "0" }), max: Some(ByteArray { data: "9" }), distinct_count: None, null_count: 0, min_max_deprecated: false, min_max_backwards_compatible: false}) + // "timestamp_col": Int96({min: None, max: None, distinct_count: None, null_count: 0, min_max_deprecated: true, min_max_backwards_compatible: true}) + // "year": Int32({min: Some(2009), max: Some(2010), distinct_count: None, null_count: 0, min_max_deprecated: false, min_max_backwards_compatible: false}) + // "month": Int32({min: Some(1), max: Some(12), distinct_count: None, null_count: 0, min_max_deprecated: false, min_max_backwards_compatible: false}) + TestFile::new("alltypes_tiny_pages.parquet") + .with_column(ExpectedColumn { + name: "id", + expected_min: i32_array([Some(0)]), + expected_max: i32_array([Some(7299)]), + }) + 
.with_column(ExpectedColumn { + name: "bool_col", + expected_min: bool_array([Some(false)]), + expected_max: bool_array([Some(true)]), + }) + .with_column(ExpectedColumn { + name: "tinyint_col", + expected_min: i32_array([Some(0)]), + expected_max: i32_array([Some(9)]), + }) + .with_column(ExpectedColumn { + name: "smallint_col", + expected_min: i32_array([Some(0)]), + expected_max: i32_array([Some(9)]), + }) + .with_column(ExpectedColumn { + name: "int_col", + expected_min: i32_array([Some(0)]), + expected_max: i32_array([Some(9)]), + }) + .with_column(ExpectedColumn { + name: "bigint_col", + expected_min: i64_array([Some(0)]), + expected_max: i64_array([Some(90)]), + }) + .with_column(ExpectedColumn { + name: "float_col", + expected_min: f32_array([Some(0.0)]), + expected_max: f32_array([Some(9.9)]), + }) + .with_column(ExpectedColumn { + name: "double_col", + expected_min: f64_array([Some(0.0)]), + expected_max: f64_array([Some(90.89999999999999)]), + }) + .with_column(ExpectedColumn { + name: "date_string_col", + expected_min: utf8_array([Some("01/01/09")]), + expected_max: utf8_array([Some("12/31/10")]), + }) + .with_column(ExpectedColumn { + name: "string_col", + expected_min: utf8_array([Some("0")]), + expected_max: utf8_array([Some("9")]), + }) + // File has no min/max for timestamp_col + .with_column(ExpectedColumn { + name: "timestamp_col", + expected_min: timestamp_array([None]), + expected_max: timestamp_array([None]), + }) + .with_column(ExpectedColumn { + name: "year", + expected_min: i32_array([Some(2009)]), + expected_max: i32_array([Some(2010)]), + }) + .with_column(ExpectedColumn { + name: "month", + expected_min: i32_array([Some(1)]), + expected_max: i32_array([Some(12)]), + }) + .run(); + } + + #[test] + fn fixed_length_decimal_legacy() { + // /parquet-testing/data/fixed_length_decimal_legacy.parquet + // row_groups: 1 + // "value": FixedLenByteArray({min: Some(FixedLenByteArray(ByteArray { data: Some(ByteBufferPtr { data: b"\0\0\0\0\0\xc8" }) })), max: Some(FixedLenByteArray(ByteArray { data: "\0\0\0\0\t`" })), distinct_count: None, null_count: 0, min_max_deprecated: true, min_max_backwards_compatible: true}) + + TestFile::new("fixed_length_decimal_legacy.parquet") + .with_column(ExpectedColumn { + name: "value", + expected_min: Arc::new( + Decimal128Array::from(vec![Some(200)]) + .with_precision_and_scale(13, 2) + .unwrap(), + ), + expected_max: Arc::new( + Decimal128Array::from(vec![Some(2400)]) + .with_precision_and_scale(13, 2) + .unwrap(), + ), + }) + .run(); + } + + const ROWS_PER_ROW_GROUP: usize = 3; + + /// Writes the input batch into a parquet file, with every every three rows as + /// their own row group, and compares the min/maxes to the expected values + struct Test { + input: ArrayRef, + expected_min: ArrayRef, + expected_max: ArrayRef, + } + + impl Test { + fn run(self) { + let Self { + input, + expected_min, + expected_max, + } = self; + + let input_batch = RecordBatch::try_from_iter([("c1", input)]).unwrap(); + + let schema = input_batch.schema(); + + let metadata = parquet_metadata(schema.clone(), input_batch); + let parquet_schema = metadata.file_metadata().schema_descr(); + + let row_groups = metadata.row_groups(); + + for field in schema.fields() { + if field.data_type().is_nested() { + let lookup = parquet_column(parquet_schema, &schema, field.name()); + assert_eq!(lookup, None); + continue; + } + + let (idx, f) = + parquet_column(parquet_schema, &schema, field.name()).unwrap(); + assert_eq!(f, field); + + let iter = row_groups.iter().map(|x| 
x.column(idx).statistics()); + let min = min_statistics(f.data_type(), iter.clone()).unwrap(); + assert_eq!( + &min, + &expected_min, + "Min. Statistics\n\n{}\n\n", + DisplayStats(row_groups) + ); + + let max = max_statistics(f.data_type(), iter).unwrap(); + assert_eq!( + &max, + &expected_max, + "Max. Statistics\n\n{}\n\n", + DisplayStats(row_groups) + ); + } + } + } + + /// Write the specified batches out as parquet and return the metadata + fn parquet_metadata(schema: SchemaRef, batch: RecordBatch) -> Arc<ParquetMetaData> { + let props = WriterProperties::builder() + .set_statistics_enabled(EnabledStatistics::Chunk) + .set_max_row_group_size(ROWS_PER_ROW_GROUP) + .build(); + + let mut buffer = Vec::new(); + let mut writer = ArrowWriter::try_new(&mut buffer, schema, Some(props)).unwrap(); + writer.write(&batch).unwrap(); + writer.close().unwrap(); + + let reader = ArrowReaderBuilder::try_new(Bytes::from(buffer)).unwrap(); + reader.metadata().clone() + } + + /// Formats the statistics nicely for display + struct DisplayStats<'a>(&'a [RowGroupMetaData]); + impl<'a> std::fmt::Display for DisplayStats<'a> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let row_groups = self.0; + writeln!(f, " row_groups: {}", row_groups.len())?; + for rg in row_groups { + for col in rg.columns() { + if let Some(statistics) = col.statistics() { + writeln!(f, " {}: {:?}", col.column_path(), statistics)?; + } + } + } + Ok(()) + } + } + + struct ExpectedColumn { + name: &'static str, + expected_min: ArrayRef, + expected_max: ArrayRef, + } + + /// Reads statistics out of the specified parquet file and compares them to the expected values + struct TestFile { + file_name: &'static str, + expected_columns: Vec<ExpectedColumn>, + } + + impl TestFile { + fn new(file_name: &'static str) -> Self { + Self { + file_name, + expected_columns: Vec::new(), + } + } + + fn with_column(mut self, column: ExpectedColumn) -> Self { + self.expected_columns.push(column); + self + } + + /// Reads the specified parquet file and validates that the expected min/max + /// values for the specified columns are as expected.
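The pub(crate) helpers defined earlier in this file (parquet_column, min_statistics and max_statistics) compose into a small routine that turns per-row-group parquet statistics into Arrow arrays, which is what the Test and TestFile harnesses here exercise. A minimal sketch, assuming it lives in the same module as those helpers; the column_min_max name is illustrative:

use arrow_array::ArrayRef;
use arrow_schema::Schema;
use datafusion_common::Result;
use parquet::file::metadata::ParquetMetaData;

/// Illustrative helper: min and max arrays (one entry per row group) for `name`.
fn column_min_max(
    metadata: &ParquetMetaData,
    arrow_schema: &Schema,
    name: &str,
) -> Option<Result<(ArrayRef, ArrayRef)>> {
    let parquet_schema = metadata.file_metadata().schema_descr();
    // `parquet_column` returns None for nested or missing columns
    let (idx, field) = parquet_column(parquet_schema, arrow_schema, name)?;

    // One Option<&Statistics> per row group for this leaf column
    let stats = || metadata.row_groups().iter().map(|rg| rg.column(idx).statistics());

    Some(min_statistics(field.data_type(), stats()).and_then(|min| {
        let max = max_statistics(field.data_type(), stats())?;
        Ok((min, max))
    }))
}

Row groups without statistics for the column come back as nulls in both arrays, which is the behaviour the alltypes_plain test checks.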
+ fn run(self) { + let path = PathBuf::from(parquet_test_data()).join(self.file_name); + let file = std::fs::File::open(path).unwrap(); + let reader = ArrowReaderBuilder::try_new(file).unwrap(); + let arrow_schema = reader.schema(); + let metadata = reader.metadata(); + let row_groups = metadata.row_groups(); + let parquet_schema = metadata.file_metadata().schema_descr(); + + for expected_column in self.expected_columns { + let ExpectedColumn { + name, + expected_min, + expected_max, + } = expected_column; + + let (idx, field) = + parquet_column(parquet_schema, arrow_schema, name).unwrap(); + + let iter = row_groups.iter().map(|x| x.column(idx).statistics()); + let actual_min = min_statistics(field.data_type(), iter.clone()).unwrap(); + assert_eq!(&expected_min, &actual_min, "column {name}"); + + let actual_max = max_statistics(field.data_type(), iter).unwrap(); + assert_eq!(&expected_max, &actual_max, "column {name}"); + } + } + } + + fn bool_array(input: impl IntoIterator>) -> ArrayRef { + let array: BooleanArray = input.into_iter().collect(); + Arc::new(array) + } + + fn i32_array(input: impl IntoIterator>) -> ArrayRef { + let array: Int32Array = input.into_iter().collect(); + Arc::new(array) + } + + fn i64_array(input: impl IntoIterator>) -> ArrayRef { + let array: Int64Array = input.into_iter().collect(); + Arc::new(array) + } + + fn f32_array(input: impl IntoIterator>) -> ArrayRef { + let array: Float32Array = input.into_iter().collect(); + Arc::new(array) + } + + fn f64_array(input: impl IntoIterator>) -> ArrayRef { + let array: Float64Array = input.into_iter().collect(); + Arc::new(array) + } + + fn timestamp_array(input: impl IntoIterator>) -> ArrayRef { + let array: TimestampNanosecondArray = input.into_iter().collect(); + Arc::new(array) + } + + fn utf8_array<'a>(input: impl IntoIterator>) -> ArrayRef { + let array: StringArray = input + .into_iter() + .map(|s| s.map(|s| s.to_string())) + .collect(); + Arc::new(array) + } + + // returns a struct array with columns "bool_col" and "int_col" with the specified values + fn struct_array(input: Vec<(Option, Option)>) -> ArrayRef { + let boolean: BooleanArray = input.iter().map(|(b, _i)| b).collect(); + let int: Int32Array = input.iter().map(|(_b, i)| i).collect(); + + let nullable = true; + let struct_array = StructArray::from(vec![ + ( + Arc::new(Field::new("bool_col", DataType::Boolean, nullable)), + Arc::new(boolean) as ArrayRef, + ), + ( + Arc::new(Field::new("int_col", DataType::Int32, nullable)), + Arc::new(int) as ArrayRef, + ), + ]); + Arc::new(struct_array) + } +} diff --git a/datafusion/core/src/datasource/provider.rs b/datafusion/core/src/datasource/provider.rs index 7d9f9e86d6030..c1cee849fe5cd 100644 --- a/datafusion/core/src/datasource/provider.rs +++ b/datafusion/core/src/datasource/provider.rs @@ -26,6 +26,8 @@ use datafusion_expr::{CreateExternalTable, LogicalPlan}; pub use datafusion_expr::{TableProviderFilterPushDown, TableType}; use crate::arrow::datatypes::SchemaRef; +use crate::datasource::listing_table_factory::ListingTableFactory; +use crate::datasource::stream::StreamTableFactory; use crate::error::Result; use crate::execution::context::SessionState; use crate::logical_expr::Expr; @@ -64,6 +66,11 @@ pub trait TableProvider: Sync + Send { None } + /// Get the default value for a column, if available. 
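To illustrate the new get_column_default hook: column defaults declared in CREATE TABLE flow into the provider (the MemTable changes later in this diff pass them via with_column_defaults), and the planner can then ask the provider for a default by column name. A hedged sketch; it assumes with_column_defaults accepts a HashMap<String, Expr> and that MemTable overrides get_column_default as part of this change:

use std::collections::HashMap;
use std::sync::Arc;

use arrow_array::{Int32Array, RecordBatch};
use arrow_schema::{DataType, Field, Schema};
use datafusion::datasource::{MemTable, TableProvider};
use datafusion::error::Result;
use datafusion_expr::{lit, Expr};

fn mem_table_with_defaults() -> Result<Arc<dyn TableProvider>> {
    let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, true)]));
    let batch = RecordBatch::try_new(
        schema.clone(),
        vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
    )?;

    // Default `a` to 42 when an INSERT omits the column (parameter type assumed)
    let defaults: HashMap<String, Expr> = HashMap::from([("a".to_string(), lit(42))]);

    let table = MemTable::try_new(schema, vec![vec![batch]])?.with_column_defaults(defaults);

    // The planner can now look the default up by column name
    let _default: Option<&Expr> = table.get_column_default("a");
    Ok(Arc::new(table))
}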
+ fn get_column_default(&self, _column: &str) -> Option<&Expr> { + None + } + /// Create an [`ExecutionPlan`] for scanning the table with optionally /// specified `projection`, `filter` and `limit`, described below. /// @@ -134,7 +141,11 @@ pub trait TableProvider: Sync + Send { /// (though it may return more). Like Projection Pushdown and Filter /// Pushdown, DataFusion pushes `LIMIT`s as far down in the plan as /// possible, called "Limit Pushdown" as some sources can use this - /// information to improve their performance. + /// information to improve their performance. Note that if there are any + /// Inexact filters pushed down, the LIMIT cannot be pushed down. This is + /// because inexact filters do not guarentee that every filtered row is + /// removed, so applying the limit could lead to too few rows being available + /// to return as a final result. async fn scan( &self, state: &SessionState, @@ -214,3 +225,41 @@ pub trait TableProviderFactory: Sync + Send { cmd: &CreateExternalTable, ) -> Result>; } + +/// The default [`TableProviderFactory`] +/// +/// If [`CreateExternalTable`] is unbounded calls [`StreamTableFactory::create`], +/// otherwise calls [`ListingTableFactory::create`] +#[derive(Debug, Default)] +pub struct DefaultTableFactory { + stream: StreamTableFactory, + listing: ListingTableFactory, +} + +impl DefaultTableFactory { + /// Creates a new [`DefaultTableFactory`] + pub fn new() -> Self { + Self::default() + } +} + +#[async_trait] +impl TableProviderFactory for DefaultTableFactory { + async fn create( + &self, + state: &SessionState, + cmd: &CreateExternalTable, + ) -> Result> { + let mut unbounded = cmd.unbounded; + for (k, v) in &cmd.options { + if k.eq_ignore_ascii_case("unbounded") && v.eq_ignore_ascii_case("true") { + unbounded = true + } + } + + match unbounded { + true => self.stream.create(state, cmd).await, + false => self.listing.create(state, cmd).await, + } + } +} diff --git a/datafusion/core/src/datasource/statistics.rs b/datafusion/core/src/datasource/statistics.rs index 3d8248dfdeb28..695e139517cff 100644 --- a/datafusion/core/src/datasource/statistics.rs +++ b/datafusion/core/src/datasource/statistics.rs @@ -70,7 +70,11 @@ pub async fn get_statistics_with_limit( // files. This only applies when we know the number of rows. It also // currently ignores tables that have no statistics regarding the // number of rows. - if num_rows.get_value().unwrap_or(&usize::MIN) <= &limit.unwrap_or(usize::MAX) { + let conservative_num_rows = match num_rows { + Precision::Exact(nr) => nr, + _ => usize::MIN, + }; + if conservative_num_rows <= limit.unwrap_or(usize::MAX) { while let Some(current) = all_files.next().await { let (file, file_stats) = current?; result_files.push(file); diff --git a/datafusion/core/src/datasource/stream.rs b/datafusion/core/src/datasource/stream.rs new file mode 100644 index 0000000000000..830cd7a07e460 --- /dev/null +++ b/datafusion/core/src/datasource/stream.rs @@ -0,0 +1,365 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! TableProvider for stream sources, such as FIFO files + +use std::any::Any; +use std::fmt::Formatter; +use std::fs::{File, OpenOptions}; +use std::io::BufReader; +use std::path::PathBuf; +use std::str::FromStr; +use std::sync::Arc; + +use arrow_array::{RecordBatch, RecordBatchReader, RecordBatchWriter}; +use arrow_schema::SchemaRef; +use async_trait::async_trait; +use futures::StreamExt; +use tokio::task::spawn_blocking; + +use datafusion_common::{plan_err, Constraints, DataFusionError, Result}; +use datafusion_execution::{SendableRecordBatchStream, TaskContext}; +use datafusion_expr::{CreateExternalTable, Expr, TableType}; +use datafusion_physical_plan::common::AbortOnDropSingle; +use datafusion_physical_plan::insert::{DataSink, FileSinkExec}; +use datafusion_physical_plan::metrics::MetricsSet; +use datafusion_physical_plan::stream::RecordBatchReceiverStreamBuilder; +use datafusion_physical_plan::streaming::{PartitionStream, StreamingTableExec}; +use datafusion_physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan}; + +use crate::datasource::provider::TableProviderFactory; +use crate::datasource::{create_ordering, TableProvider}; +use crate::execution::context::SessionState; + +/// A [`TableProviderFactory`] for [`StreamTable`] +#[derive(Debug, Default)] +pub struct StreamTableFactory {} + +#[async_trait] +impl TableProviderFactory for StreamTableFactory { + async fn create( + &self, + state: &SessionState, + cmd: &CreateExternalTable, + ) -> Result> { + let schema: SchemaRef = Arc::new(cmd.schema.as_ref().into()); + let location = cmd.location.clone(); + let encoding = cmd.file_type.parse()?; + + let config = StreamConfig::new_file(schema, location.into()) + .with_encoding(encoding) + .with_order(cmd.order_exprs.clone()) + .with_header(cmd.has_header) + .with_batch_size(state.config().batch_size()) + .with_constraints(cmd.constraints.clone()); + + Ok(Arc::new(StreamTable(Arc::new(config)))) + } +} + +/// The data encoding for [`StreamTable`] +#[derive(Debug, Clone)] +pub enum StreamEncoding { + /// CSV records + Csv, + /// Newline-delimited JSON records + Json, +} + +impl FromStr for StreamEncoding { + type Err = DataFusionError; + + fn from_str(s: &str) -> std::result::Result { + match s.to_ascii_lowercase().as_str() { + "csv" => Ok(Self::Csv), + "json" => Ok(Self::Json), + _ => plan_err!("Unrecognised StreamEncoding {}", s), + } + } +} + +/// The configuration for a [`StreamTable`] +#[derive(Debug)] +pub struct StreamConfig { + schema: SchemaRef, + location: PathBuf, + batch_size: usize, + encoding: StreamEncoding, + header: bool, + order: Vec>, + constraints: Constraints, +} + +impl StreamConfig { + /// Stream data from the file at `location` + /// + /// * Data will be read sequentially from the provided `location` + /// * New data will be appended to the end of the file + /// + /// The encoding can be configured with [`Self::with_encoding`] and + /// defaults to [`StreamEncoding::Csv`] + pub fn new_file(schema: SchemaRef, location: PathBuf) -> Self { + Self { + schema, + location, + batch_size: 1024, + encoding: StreamEncoding::Csv, + order: vec![], + 
header: false, + constraints: Constraints::empty(), + } + } + + /// Specify a sort order for the stream + pub fn with_order(mut self, order: Vec>) -> Self { + self.order = order; + self + } + + /// Specify the batch size + pub fn with_batch_size(mut self, batch_size: usize) -> Self { + self.batch_size = batch_size; + self + } + + /// Specify whether the file has a header (only applicable for [`StreamEncoding::Csv`]) + pub fn with_header(mut self, header: bool) -> Self { + self.header = header; + self + } + + /// Specify an encoding for the stream + pub fn with_encoding(mut self, encoding: StreamEncoding) -> Self { + self.encoding = encoding; + self + } + + /// Assign constraints + pub fn with_constraints(mut self, constraints: Constraints) -> Self { + self.constraints = constraints; + self + } + + fn reader(&self) -> Result> { + let file = File::open(&self.location)?; + let schema = self.schema.clone(); + match &self.encoding { + StreamEncoding::Csv => { + let reader = arrow::csv::ReaderBuilder::new(schema) + .with_header(self.header) + .with_batch_size(self.batch_size) + .build(file)?; + + Ok(Box::new(reader)) + } + StreamEncoding::Json => { + let reader = arrow::json::ReaderBuilder::new(schema) + .with_batch_size(self.batch_size) + .build(BufReader::new(file))?; + + Ok(Box::new(reader)) + } + } + } + + fn writer(&self) -> Result> { + match &self.encoding { + StreamEncoding::Csv => { + let header = self.header && !self.location.exists(); + let file = OpenOptions::new() + .create(true) + .append(true) + .open(&self.location)?; + let writer = arrow::csv::WriterBuilder::new() + .with_header(header) + .build(file); + + Ok(Box::new(writer)) + } + StreamEncoding::Json => { + let file = OpenOptions::new() + .create(true) + .append(true) + .open(&self.location)?; + Ok(Box::new(arrow::json::LineDelimitedWriter::new(file))) + } + } + } +} + +/// A [`TableProvider`] for an unbounded stream source +/// +/// Currently only reading from / appending to a single file in-place is supported, but +/// other stream sources and sinks may be added in future. +/// +/// Applications looking to read/write datasets comprising multiple files, e.g. [Hadoop]-style +/// data stored in object storage, should instead consider [`ListingTable`]. +/// +/// [Hadoop]: https://hadoop.apache.org/ +/// [`ListingTable`]: crate::datasource::listing::ListingTable +pub struct StreamTable(Arc); + +impl StreamTable { + /// Create a new [`StreamTable`] for the given [`StreamConfig`] + pub fn new(config: Arc) -> Self { + Self(config) + } +} + +#[async_trait] +impl TableProvider for StreamTable { + fn as_any(&self) -> &dyn Any { + self + } + + fn schema(&self) -> SchemaRef { + self.0.schema.clone() + } + + fn constraints(&self) -> Option<&Constraints> { + Some(&self.0.constraints) + } + + fn table_type(&self) -> TableType { + TableType::Base + } + + async fn scan( + &self, + _state: &SessionState, + projection: Option<&Vec>, + _filters: &[Expr], + _limit: Option, + ) -> Result> { + let projected_schema = match projection { + Some(p) => { + let projected = self.0.schema.project(p)?; + create_ordering(&projected, &self.0.order)? 
+ } + None => create_ordering(self.0.schema.as_ref(), &self.0.order)?, + }; + + Ok(Arc::new(StreamingTableExec::try_new( + self.0.schema.clone(), + vec![Arc::new(StreamRead(self.0.clone())) as _], + projection, + projected_schema, + true, + )?)) + } + + async fn insert_into( + &self, + _state: &SessionState, + input: Arc, + _overwrite: bool, + ) -> Result> { + let ordering = match self.0.order.first() { + Some(x) => { + let schema = self.0.schema.as_ref(); + let orders = create_ordering(schema, std::slice::from_ref(x))?; + let ordering = orders.into_iter().next().unwrap(); + Some(ordering.into_iter().map(Into::into).collect()) + } + None => None, + }; + + Ok(Arc::new(FileSinkExec::new( + input, + Arc::new(StreamWrite(self.0.clone())), + self.0.schema.clone(), + ordering, + ))) + } +} + +struct StreamRead(Arc); + +impl PartitionStream for StreamRead { + fn schema(&self) -> &SchemaRef { + &self.0.schema + } + + fn execute(&self, _ctx: Arc) -> SendableRecordBatchStream { + let config = self.0.clone(); + let schema = self.0.schema.clone(); + let mut builder = RecordBatchReceiverStreamBuilder::new(schema, 2); + let tx = builder.tx(); + builder.spawn_blocking(move || { + let reader = config.reader()?; + for b in reader { + if tx.blocking_send(b.map_err(Into::into)).is_err() { + break; + } + } + Ok(()) + }); + builder.build() + } +} + +#[derive(Debug)] +struct StreamWrite(Arc); + +impl DisplayAs for StreamWrite { + fn fmt_as(&self, _t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result { + f.debug_struct("StreamWrite") + .field("location", &self.0.location) + .field("batch_size", &self.0.batch_size) + .field("encoding", &self.0.encoding) + .field("header", &self.0.header) + .finish_non_exhaustive() + } +} + +#[async_trait] +impl DataSink for StreamWrite { + fn as_any(&self) -> &dyn Any { + self + } + + fn metrics(&self) -> Option { + None + } + + async fn write_all( + &self, + mut data: SendableRecordBatchStream, + _context: &Arc, + ) -> Result { + let config = self.0.clone(); + let (sender, mut receiver) = tokio::sync::mpsc::channel::(2); + // Note: FIFO Files support poll so this could use AsyncFd + let write = AbortOnDropSingle::new(spawn_blocking(move || { + let mut count = 0_u64; + let mut writer = config.writer()?; + while let Some(batch) = receiver.blocking_recv() { + count += batch.num_rows() as u64; + writer.write(&batch)?; + } + Ok(count) + })); + + while let Some(b) = data.next().await.transpose()? 
{ + if sender.send(b).await.is_err() { + break; + } + } + drop(sender); + write.await.unwrap() + } +} diff --git a/datafusion/core/src/execution/context/mod.rs b/datafusion/core/src/execution/context/mod.rs index 9c500ec07293b..d6b7f046f3e3f 100644 --- a/datafusion/core/src/execution/context/mod.rs +++ b/datafusion/core/src/execution/context/mod.rs @@ -26,8 +26,8 @@ mod parquet; use crate::{ catalog::{CatalogList, MemoryCatalogList}, datasource::{ + function::{TableFunction, TableFunctionImpl}, listing::{ListingOptions, ListingTable}, - listing_table_factory::ListingTableFactory, provider::TableProviderFactory, }, datasource::{MemTable, ViewTable}, @@ -43,7 +43,7 @@ use datafusion_common::{ use datafusion_execution::registry::SerializerRegistry; use datafusion_expr::{ logical_plan::{DdlStatement, Statement}, - StringifiedPlan, UserDefinedLogicalNode, WindowUDF, + Expr, StringifiedPlan, UserDefinedLogicalNode, WindowUDF, }; pub use datafusion_physical_expr::execution_props::ExecutionProps; use datafusion_physical_expr::var_provider::is_system_variables; @@ -111,6 +111,7 @@ use datafusion_sql::planner::object_name_to_table_reference; use uuid::Uuid; // backwards compatibility +use crate::datasource::provider::DefaultTableFactory; use crate::execution::options::ArrowReadOptions; pub use datafusion_execution::config::SessionConfig; pub use datafusion_execution::TaskContext; @@ -529,6 +530,7 @@ impl SessionContext { if_not_exists, or_replace, constraints, + column_defaults, } = cmd; let input = Arc::try_unwrap(input).unwrap_or_else(|e| e.as_ref().clone()); @@ -542,7 +544,12 @@ impl SessionContext { let physical = DataFrame::new(self.state(), input); let batches: Vec<_> = physical.collect_partitioned().await?; - let table = Arc::new(MemTable::try_new(schema, batches)?); + let table = Arc::new( + // pass constraints and column defaults to the mem table. + MemTable::try_new(schema, batches)? + .with_constraints(constraints) + .with_column_defaults(column_defaults.into_iter().collect()), + ); self.register_table(&name, table)?; self.return_empty_dataframe() @@ -557,8 +564,10 @@ impl SessionContext { let batches: Vec<_> = physical.collect_partitioned().await?; let table = Arc::new( - // pass constraints to the mem table. - MemTable::try_new(schema, batches)?.with_constraints(constraints), + // pass constraints and column defaults to the mem table. + MemTable::try_new(schema, batches)? + .with_constraints(constraints) + .with_column_defaults(column_defaults.into_iter().collect()), ); self.register_table(&name, table)?; @@ -795,6 +804,14 @@ impl SessionContext { .add_var_provider(variable_type, provider); } + /// Register a table UDF with this context + pub fn register_udtf(&self, name: &str, fun: Arc) { + self.state.write().table_functions.insert( + name.to_owned(), + Arc::new(TableFunction::new(name.to_owned(), fun)), + ); + } + /// Registers a scalar UDF within this context. 
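The new StreamTable can also be wired up programmatically instead of going through CREATE UNBOUNDED EXTERNAL TABLE. A minimal sketch, assuming the new file is exported as datafusion::datasource::stream; the schema and FIFO path are illustrative:

use std::sync::Arc;

use arrow_schema::{DataType, Field, Schema};
use datafusion::datasource::stream::{StreamConfig, StreamEncoding, StreamTable};
use datafusion::error::Result;
use datafusion::prelude::SessionContext;

fn register_fifo_table(ctx: &SessionContext, path: &str) -> Result<()> {
    let schema = Arc::new(Schema::new(vec![
        Field::new("ts", DataType::Int64, false),
        Field::new("value", DataType::Float64, true),
    ]));

    // Read CSV records sequentially from `path`; appends land at the end of the file
    let config = StreamConfig::new_file(schema, path.into())
        .with_encoding(StreamEncoding::Csv)
        .with_header(true)
        .with_batch_size(8192);

    ctx.register_table("fifo", Arc::new(StreamTable::new(Arc::new(config))))?;
    Ok(())
}

The scan implementation above wraps the config in a StreamingTableExec flagged as infinite, so the planner treats the registered table as an unbounded source.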
/// /// Note in SQL queries, function names are looked up using @@ -802,11 +819,18 @@ impl SessionContext { /// /// - `SELECT MY_FUNC(x)...` will look for a function named `"my_func"` /// - `SELECT "my_FUNC"(x)` will look for a function named `"my_FUNC"` + /// Any functions registered with the udf name or its aliases will be overwritten with this new function pub fn register_udf(&self, f: ScalarUDF) { - self.state - .write() + let mut state = self.state.write(); + let aliases = f.aliases(); + for alias in aliases { + state + .scalar_functions + .insert(alias.to_string(), Arc::new(f.clone())); + } + state .scalar_functions - .insert(f.name.clone(), Arc::new(f)); + .insert(f.name().to_string(), Arc::new(f)); } /// Registers an aggregate UDF within this context. @@ -820,7 +844,7 @@ impl SessionContext { self.state .write() .aggregate_functions - .insert(f.name.clone(), Arc::new(f)); + .insert(f.name().to_string(), Arc::new(f)); } /// Registers a window UDF within this context. @@ -834,7 +858,7 @@ impl SessionContext { self.state .write() .window_functions - .insert(f.name.clone(), Arc::new(f)); + .insert(f.name().to_string(), Arc::new(f)); } /// Creates a [`DataFrame`] for reading a data source. @@ -858,10 +882,12 @@ impl SessionContext { // check if the file extension matches the expected extension for path in &table_paths { - let file_name = path.prefix().filename().unwrap_or_default(); - if !path.as_str().ends_with(&option_extension) && file_name.contains('.') { + let file_path = path.as_str(); + if !file_path.ends_with(option_extension.clone().as_str()) + && !path.is_collection() + { return exec_err!( - "File '{file_name}' does not match the expected extension '{option_extension}'" + "File path '{file_path}' does not match the expected extension '{option_extension}'" ); } } @@ -938,14 +964,9 @@ impl SessionContext { sql_definition: Option, ) -> Result<()> { let table_path = ListingTableUrl::parse(table_path)?; - let resolved_schema = match (provided_schema, options.infinite_source) { - (Some(s), _) => s, - (None, false) => options.infer_schema(&self.state(), &table_path).await?, - (None, true) => { - return plan_err!( - "Schema inference for infinite data sources is not supported." 
- ) - } + let resolved_schema = match provided_schema { + Some(s) => s, + None => options.infer_schema(&self.state(), &table_path).await?, }; let config = ListingTableConfig::new(table_path) .with_listing_options(options) @@ -1224,6 +1245,8 @@ pub struct SessionState { query_planner: Arc, /// Collection of catalogs containing schemas and ultimately TableProviders catalog_list: Arc, + /// Table Functions + table_functions: HashMap>, /// Scalar functions that are registered with the context scalar_functions: HashMap>, /// Aggregate functions registered in the context @@ -1285,12 +1308,12 @@ impl SessionState { let mut table_factories: HashMap> = HashMap::new(); #[cfg(feature = "parquet")] - table_factories.insert("PARQUET".into(), Arc::new(ListingTableFactory::new())); - table_factories.insert("CSV".into(), Arc::new(ListingTableFactory::new())); - table_factories.insert("JSON".into(), Arc::new(ListingTableFactory::new())); - table_factories.insert("NDJSON".into(), Arc::new(ListingTableFactory::new())); - table_factories.insert("AVRO".into(), Arc::new(ListingTableFactory::new())); - table_factories.insert("ARROW".into(), Arc::new(ListingTableFactory::new())); + table_factories.insert("PARQUET".into(), Arc::new(DefaultTableFactory::new())); + table_factories.insert("CSV".into(), Arc::new(DefaultTableFactory::new())); + table_factories.insert("JSON".into(), Arc::new(DefaultTableFactory::new())); + table_factories.insert("NDJSON".into(), Arc::new(DefaultTableFactory::new())); + table_factories.insert("AVRO".into(), Arc::new(DefaultTableFactory::new())); + table_factories.insert("ARROW".into(), Arc::new(DefaultTableFactory::new())); if config.create_default_catalog_and_schema() { let default_catalog = MemoryCatalogProvider::new(); @@ -1322,6 +1345,7 @@ impl SessionState { physical_optimizers: PhysicalOptimizer::new(), query_planner: Arc::new(DefaultQueryPlanner {}), catalog_list, + table_functions: HashMap::new(), scalar_functions: HashMap::new(), aggregate_functions: HashMap::new(), window_functions: HashMap::new(), @@ -1597,9 +1621,6 @@ impl SessionState { .0 .insert(ObjectName(vec![Ident::from(table.name.as_str())])); } - DFStatement::DescribeTableStmt(table) => { - visitor.insert(&table.table_name) - } DFStatement::CopyTo(CopyToStatement { source, target: _, @@ -1698,7 +1719,7 @@ impl SessionState { let mut stringified_plans = e.stringified_plans.clone(); // analyze & capture output of each rule - let analyzed_plan = match self.analyzer.execute_and_check( + let analyzer_result = self.analyzer.execute_and_check( e.plan.as_ref(), self.options(), |analyzed_plan, analyzer| { @@ -1706,7 +1727,8 @@ impl SessionState { let plan_type = PlanType::AnalyzedLogicalPlan { analyzer_name }; stringified_plans.push(analyzed_plan.to_stringified(plan_type)); }, - ) { + ); + let analyzed_plan = match analyzer_result { Ok(plan) => plan, Err(DataFusionError::Context(analyzer_name, err)) => { let plan_type = PlanType::AnalyzedLogicalPlan { analyzer_name }; @@ -1729,7 +1751,7 @@ impl SessionState { .push(analyzed_plan.to_stringified(PlanType::FinalAnalyzedLogicalPlan)); // optimize the child plan, capturing the output of each optimizer - let (plan, logical_optimization_succeeded) = match self.optimizer.optimize( + let optimized_plan = self.optimizer.optimize( &analyzed_plan, self, |optimized_plan, optimizer| { @@ -1737,7 +1759,8 @@ impl SessionState { let plan_type = PlanType::OptimizedLogicalPlan { optimizer_name }; stringified_plans.push(optimized_plan.to_stringified(plan_type)); }, - ) { + ); + let (plan, 
logical_optimization_succeeded) = match optimized_plan { Ok(plan) => (Arc::new(plan), true), Err(DataFusionError::Context(optimizer_name, err)) => { let plan_type = PlanType::OptimizedLogicalPlan { optimizer_name }; @@ -1860,6 +1883,22 @@ impl<'a> ContextProvider for SessionContextProvider<'a> { .ok_or_else(|| plan_datafusion_err!("table '{name}' not found")) } + fn get_table_function_source( + &self, + name: &str, + args: Vec, + ) -> Result> { + let tbl_func = self + .state + .table_functions + .get(name) + .cloned() + .ok_or_else(|| plan_datafusion_err!("table function '{name}' not found"))?; + let provider = tbl_func.create_table_provider(&args)?; + + Ok(provider_as_source(provider)) + } + fn get_function_meta(&self, name: &str) -> Option> { self.state.scalar_functions().get(name).cloned() } diff --git a/datafusion/core/src/execution/context/parquet.rs b/datafusion/core/src/execution/context/parquet.rs index ef1f0143543d8..7825d9b882979 100644 --- a/datafusion/core/src/execution/context/parquet.rs +++ b/datafusion/core/src/execution/context/parquet.rs @@ -80,6 +80,8 @@ mod tests { use crate::dataframe::DataFrameWriteOptions; use crate::parquet::basic::Compression; use crate::test_util::parquet_test_data; + use datafusion_execution::config::SessionConfig; + use tempfile::tempdir; use super::*; @@ -102,8 +104,12 @@ mod tests { #[tokio::test] async fn read_with_glob_path_issue_2465() -> Result<()> { - let ctx = SessionContext::new(); - + let config = + SessionConfig::from_string_hash_map(std::collections::HashMap::from([( + "datafusion.execution.listing_table_ignore_subdirectory".to_owned(), + "false".to_owned(), + )]))?; + let ctx = SessionContext::new_with_config(config); let df = ctx .read_parquet( // it was reported that when a path contains // (two consecutive separator) no files were found @@ -140,6 +146,7 @@ mod tests { #[tokio::test] async fn read_from_different_file_extension() -> Result<()> { let ctx = SessionContext::new(); + let sep = std::path::MAIN_SEPARATOR.to_string(); // Make up a new dataframe. 
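Putting the register_udtf and get_table_function_source plumbing above to use: a table function implements TableFunctionImpl and returns a TableProvider built from its call arguments. This is a hedged sketch; the call(&self, &[Expr]) signature is assumed from how create_table_provider(&args) is invoked above, and the module path datafusion::datasource::function is assumed to be public.

use std::sync::Arc;

use arrow_array::{Int64Array, RecordBatch};
use arrow_schema::{DataType, Field, Schema};
use datafusion::datasource::function::TableFunctionImpl;
use datafusion::datasource::{MemTable, TableProvider};
use datafusion::error::Result;
use datafusion::prelude::SessionContext;
use datafusion_expr::Expr;

/// Returns a one-row table whose single column holds the number of arguments it was called with.
struct ArgCount;

impl TableFunctionImpl for ArgCount {
    // Assumed signature: build a TableProvider from the literal call arguments
    fn call(&self, args: &[Expr]) -> Result<Arc<dyn TableProvider>> {
        let schema = Arc::new(Schema::new(vec![Field::new("n", DataType::Int64, false)]));
        let batch = RecordBatch::try_new(
            schema.clone(),
            vec![Arc::new(Int64Array::from(vec![args.len() as i64]))],
        )?;
        Ok(Arc::new(MemTable::try_new(schema, vec![vec![batch]])?))
    }
}

async fn demo() -> Result<()> {
    let ctx = SessionContext::new();
    ctx.register_udtf("arg_count", Arc::new(ArgCount));
    // The planner resolves `arg_count(...)` through get_table_function_source
    ctx.sql("SELECT * FROM arg_count(1, 2, 3)").await?.show().await?;
    Ok(())
}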
let write_df = ctx.read_batch(RecordBatch::try_new( @@ -155,11 +162,48 @@ mod tests { ], )?)?; + let temp_dir = tempdir()?; + let temp_dir_path = temp_dir.path(); + let path1 = temp_dir_path + .join("output1.parquet") + .to_str() + .unwrap() + .to_string(); + let path2 = temp_dir_path + .join("output2.parquet.snappy") + .to_str() + .unwrap() + .to_string(); + let path3 = temp_dir_path + .join("output3.parquet.snappy.parquet") + .to_str() + .unwrap() + .to_string(); + + let path4 = temp_dir_path + .join("output4.parquet".to_owned() + &sep) + .to_str() + .unwrap() + .to_string(); + + let path5 = temp_dir_path + .join("bbb..bbb") + .join("filename.parquet") + .to_str() + .unwrap() + .to_string(); + let dir = temp_dir_path + .join("bbb..bbb".to_owned() + &sep) + .to_str() + .unwrap() + .to_string(); + std::fs::create_dir(dir).expect("create dir failed"); + // Write the dataframe to a parquet file named 'output1.parquet' write_df .clone() .write_parquet( - "output1.parquet", + &path1, DataFrameWriteOptions::new().with_single_file_output(true), Some( WriterProperties::builder() @@ -173,7 +217,7 @@ mod tests { write_df .clone() .write_parquet( - "output2.parquet.snappy", + &path2, DataFrameWriteOptions::new().with_single_file_output(true), Some( WriterProperties::builder() @@ -185,8 +229,22 @@ mod tests { // Write the dataframe to a parquet file named 'output3.parquet.snappy.parquet' write_df + .clone() .write_parquet( - "output3.parquet.snappy.parquet", + &path3, + DataFrameWriteOptions::new().with_single_file_output(true), + Some( + WriterProperties::builder() + .set_compression(Compression::SNAPPY) + .build(), + ), + ) + .await?; + + // Write the dataframe to a parquet file named 'bbb..bbb/filename.parquet' + write_df + .write_parquet( + &path5, DataFrameWriteOptions::new().with_single_file_output(true), Some( WriterProperties::builder() @@ -199,7 +257,7 @@ mod tests { // Read the dataframe from 'output1.parquet' with the default file extension. let read_df = ctx .read_parquet( - "output1.parquet", + &path1, ParquetReadOptions { ..Default::default() }, @@ -213,7 +271,7 @@ mod tests { // Read the dataframe from 'output2.parquet.snappy' with the correct file extension. let read_df = ctx .read_parquet( - "output2.parquet.snappy", + &path2, ParquetReadOptions { file_extension: "snappy", ..Default::default() @@ -227,22 +285,52 @@ mod tests { // Read the dataframe from 'output3.parquet.snappy.parquet' with the wrong file extension. let read_df = ctx .read_parquet( - "output2.parquet.snappy", + &path2, ParquetReadOptions { ..Default::default() }, ) .await; - + let binding = DataFilePaths::to_urls(&path2).unwrap(); + let expexted_path = binding[0].as_str(); assert_eq!( read_df.unwrap_err().strip_backtrace(), - "Execution error: File 'output2.parquet.snappy' does not match the expected extension '.parquet'" + format!("Execution error: File path '{}' does not match the expected extension '.parquet'", expexted_path) ); // Read the dataframe from 'output3.parquet.snappy.parquet' with the correct file extension. 
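The same write/read round trip the test above performs, in isolation: write a DataFrame as a single SNAPPY-compressed parquet file with a non-standard extension, then read it back by overriding file_extension. A sketch; it assumes `path` ends in ".parquet.snappy".

use datafusion::dataframe::DataFrameWriteOptions;
use datafusion::error::Result;
use datafusion::parquet::basic::Compression;
use datafusion::parquet::file::properties::WriterProperties;
use datafusion::prelude::{ParquetReadOptions, SessionContext};

async fn roundtrip_snappy(ctx: &SessionContext, path: &str) -> Result<()> {
    let df = ctx.sql("SELECT 1 AS a").await?;
    df.write_parquet(
        path,
        DataFrameWriteOptions::new().with_single_file_output(true),
        Some(
            WriterProperties::builder()
                .set_compression(Compression::SNAPPY)
                .build(),
        ),
    )
    .await?;

    // The default extension is ".parquet", so point the reader at ".snappy"
    let read = ctx
        .read_parquet(
            path,
            ParquetReadOptions {
                file_extension: "snappy",
                ..Default::default()
            },
        )
        .await?;
    read.show().await?;
    Ok(())
}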
let read_df = ctx .read_parquet( - "output3.parquet.snappy.parquet", + &path3, + ParquetReadOptions { + ..Default::default() + }, + ) + .await?; + + let results = read_df.collect().await?; + let total_rows: usize = results.iter().map(|rb| rb.num_rows()).sum(); + assert_eq!(total_rows, 5); + + // Read the dataframe from 'output4/' + std::fs::create_dir(&path4)?; + let read_df = ctx + .read_parquet( + &path4, + ParquetReadOptions { + ..Default::default() + }, + ) + .await?; + + let results = read_df.collect().await?; + let total_rows: usize = results.iter().map(|rb| rb.num_rows()).sum(); + assert_eq!(total_rows, 0); + + // Read the datafram from doule dot folder; + let read_df = ctx + .read_parquet( + &path5, ParquetReadOptions { ..Default::default() }, diff --git a/datafusion/core/src/lib.rs b/datafusion/core/src/lib.rs index bf9a4abf4f2d1..b3ebbc6e3637e 100644 --- a/datafusion/core/src/lib.rs +++ b/datafusion/core/src/lib.rs @@ -283,12 +283,20 @@ //! //! ## Plan Representations //! -//! Logical planning yields [`LogicalPlan`]s nodes and [`Expr`] +//! ### Logical Plans +//! Logical planning yields [`LogicalPlan`] nodes and [`Expr`] //! expressions which are [`Schema`] aware and represent statements //! independent of how they are physically executed. //! A [`LogicalPlan`] is a Directed Acyclic Graph (DAG) of other //! [`LogicalPlan`]s, each potentially containing embedded [`Expr`]s. //! +//! Examples of working with and executing `Expr`s can be found in the +//! [`expr_api`.rs] example +//! +//! [`expr_api`.rs]: https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/expr_api.rs +//! +//! ### Physical Plans +//! //! An [`ExecutionPlan`] (sometimes referred to as a "physical plan") //! is a plan that can be executed against data. It a DAG of other //! 
[`ExecutionPlan`]s each potentially containing expressions of the diff --git a/datafusion/core/src/physical_optimizer/aggregate_statistics.rs b/datafusion/core/src/physical_optimizer/aggregate_statistics.rs index 43def5d73f73d..86a8cdb7b3d4d 100644 --- a/datafusion/core/src/physical_optimizer/aggregate_statistics.rs +++ b/datafusion/core/src/physical_optimizer/aggregate_statistics.rs @@ -22,7 +22,6 @@ use super::optimizer::PhysicalOptimizerRule; use crate::config::ConfigOptions; use crate::error::Result; use crate::physical_plan::aggregates::AggregateExec; -use crate::physical_plan::empty::EmptyExec; use crate::physical_plan::projection::ProjectionExec; use crate::physical_plan::{expressions, AggregateExpr, ExecutionPlan, Statistics}; use crate::scalar::ScalarValue; @@ -30,6 +29,7 @@ use crate::scalar::ScalarValue; use datafusion_common::stats::Precision; use datafusion_common::tree_node::TreeNode; use datafusion_expr::utils::COUNT_STAR_EXPANSION; +use datafusion_physical_plan::placeholder_row::PlaceholderRowExec; /// Optimizer that uses available statistics for aggregate functions #[derive(Default)] @@ -82,7 +82,7 @@ impl PhysicalOptimizerRule for AggregateStatistics { // input can be entirely removed Ok(Arc::new(ProjectionExec::try_new( projections, - Arc::new(EmptyExec::new(true, plan.schema())), + Arc::new(PlaceholderRowExec::new(plan.schema())), )?)) } else { plan.map_children(|child| self.optimize(child, _config)) @@ -241,7 +241,7 @@ fn take_optimizable_max( } #[cfg(test)] -mod tests { +pub(crate) mod tests { use std::sync::Arc; use super::*; @@ -334,7 +334,7 @@ mod tests { } /// Describe the type of aggregate being tested - enum TestAggregate { + pub(crate) enum TestAggregate { /// Testing COUNT(*) type aggregates CountStar, @@ -343,7 +343,7 @@ mod tests { } impl TestAggregate { - fn new_count_star() -> Self { + pub(crate) fn new_count_star() -> Self { Self::CountStar } @@ -352,7 +352,7 @@ mod tests { } /// Return appropriate expr depending if COUNT is for col or table (*) - fn count_expr(&self) -> Arc { + pub(crate) fn count_expr(&self) -> Arc { Arc::new(Count::new( self.column(), self.column_name(), @@ -397,7 +397,6 @@ mod tests { PhysicalGroupBy::default(), vec![agg.count_expr()], vec![None], - vec![None], source, Arc::clone(&schema), )?; @@ -407,7 +406,6 @@ mod tests { PhysicalGroupBy::default(), vec![agg.count_expr()], vec![None], - vec![None], Arc::new(partial_agg), Arc::clone(&schema), )?; @@ -429,7 +427,6 @@ mod tests { PhysicalGroupBy::default(), vec![agg.count_expr()], vec![None], - vec![None], source, Arc::clone(&schema), )?; @@ -439,7 +436,6 @@ mod tests { PhysicalGroupBy::default(), vec![agg.count_expr()], vec![None], - vec![None], Arc::new(partial_agg), Arc::clone(&schema), )?; @@ -460,7 +456,6 @@ mod tests { PhysicalGroupBy::default(), vec![agg.count_expr()], vec![None], - vec![None], source, Arc::clone(&schema), )?; @@ -473,7 +468,6 @@ mod tests { PhysicalGroupBy::default(), vec![agg.count_expr()], vec![None], - vec![None], Arc::new(coalesce), Arc::clone(&schema), )?; @@ -494,7 +488,6 @@ mod tests { PhysicalGroupBy::default(), vec![agg.count_expr()], vec![None], - vec![None], source, Arc::clone(&schema), )?; @@ -507,7 +500,6 @@ mod tests { PhysicalGroupBy::default(), vec![agg.count_expr()], vec![None], - vec![None], Arc::new(coalesce), Arc::clone(&schema), )?; @@ -539,7 +531,6 @@ mod tests { PhysicalGroupBy::default(), vec![agg.count_expr()], vec![None], - vec![None], filter, Arc::clone(&schema), )?; @@ -549,7 +540,6 @@ mod tests { PhysicalGroupBy::default(), 
vec![agg.count_expr()], vec![None], - vec![None], Arc::new(partial_agg), Arc::clone(&schema), )?; @@ -586,7 +576,6 @@ mod tests { PhysicalGroupBy::default(), vec![agg.count_expr()], vec![None], - vec![None], filter, Arc::clone(&schema), )?; @@ -596,7 +585,6 @@ mod tests { PhysicalGroupBy::default(), vec![agg.count_expr()], vec![None], - vec![None], Arc::new(partial_agg), Arc::clone(&schema), )?; diff --git a/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs b/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs index 2c4e929788df9..7359a6463059f 100644 --- a/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs +++ b/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs @@ -91,10 +91,12 @@ impl PhysicalOptimizerRule for CombinePartialFinalAggregate { input_agg_exec.group_by().clone(), input_agg_exec.aggr_expr().to_vec(), input_agg_exec.filter_expr().to_vec(), - input_agg_exec.order_by_expr().to_vec(), input_agg_exec.input().clone(), input_agg_exec.input_schema(), ) + .map(|combined_agg| { + combined_agg.with_limit(agg_exec.limit()) + }) .ok() .map(Arc::new) } else { @@ -255,7 +257,6 @@ mod tests { limit: None, table_partition_cols: vec![], output_ordering: vec![], - infinite_source: false, }, None, None, @@ -274,7 +275,6 @@ mod tests { group_by, aggr_expr, vec![], - vec![], input, schema, ) @@ -294,7 +294,6 @@ mod tests { group_by, aggr_expr, vec![], - vec![], input, schema, ) @@ -428,4 +427,48 @@ mod tests { assert_optimized!(expected, plan); Ok(()) } + + #[test] + fn aggregations_with_limit_combined() -> Result<()> { + let schema = schema(); + let aggr_expr = vec![]; + + let groups: Vec<(Arc, String)> = + vec![(col("c", &schema)?, "c".to_string())]; + + let partial_group_by = PhysicalGroupBy::new_single(groups); + let partial_agg = partial_aggregate_exec( + parquet_exec(&schema), + partial_group_by, + aggr_expr.clone(), + ); + + let groups: Vec<(Arc, String)> = + vec![(col("c", &partial_agg.schema())?, "c".to_string())]; + let final_group_by = PhysicalGroupBy::new_single(groups); + + let schema = partial_agg.schema(); + let final_agg = Arc::new( + AggregateExec::try_new( + AggregateMode::Final, + final_group_by, + aggr_expr, + vec![], + partial_agg, + schema, + ) + .unwrap() + .with_limit(Some(5)), + ); + let plan: Arc = final_agg; + // should combine the Partial/Final AggregateExecs to a Single AggregateExec + // with the final limit preserved + let expected = &[ + "AggregateExec: mode=Single, gby=[c@2 as c], aggr=[], lim=[5]", + "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c]", + ]; + + assert_optimized!(expected, plan); + Ok(()) + } } diff --git a/datafusion/core/src/physical_optimizer/enforce_distribution.rs b/datafusion/core/src/physical_optimizer/enforce_distribution.rs index ee6e11bd271ab..bf5aa7d02272d 100644 --- a/datafusion/core/src/physical_optimizer/enforce_distribution.rs +++ b/datafusion/core/src/physical_optimizer/enforce_distribution.rs @@ -21,15 +21,16 @@ //! according to the configuration), this rule increases partition counts in //! the physical plan. 
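For context on what "increases partition counts" means in practice, the building block this rule inserts is a round-robin RepartitionExec. A minimal sketch of that single step; the helper name is illustrative, and the real rule performs it as part of a larger tree rewrite (see add_roundrobin_on_top below):

use std::sync::Arc;

use datafusion::error::Result;
use datafusion::physical_plan::repartition::RepartitionExec;
use datafusion::physical_plan::{ExecutionPlan, Partitioning};

fn add_roundrobin_if_beneficial(
    input: Arc<dyn ExecutionPlan>,
    n_target: usize,
) -> Result<Arc<dyn ExecutionPlan>> {
    if input.output_partitioning().partition_count() < n_target {
        // Distribute incoming batches across `n_target` partitions round-robin
        Ok(Arc::new(RepartitionExec::try_new(
            input,
            Partitioning::RoundRobinBatch(n_target),
        )?))
    } else {
        Ok(input)
    }
}

When an operator needs a hash distribution rather than raw parallelism, the rule uses the same operator with Partitioning::Hash instead.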
+use std::borrow::Cow; use std::fmt; use std::fmt::Formatter; use std::sync::Arc; +use super::output_requirements::OutputRequirementExec; use crate::config::ConfigOptions; use crate::error::Result; use crate::physical_optimizer::utils::{ - add_sort_above, get_children_exectrees, get_plan_string, is_coalesce_partitions, - is_repartition, is_sort_preserving_merge, ExecTree, + is_coalesce_partitions, is_repartition, is_sort_preserving_merge, }; use crate::physical_optimizer::PhysicalOptimizerRule; use crate::physical_plan::aggregates::{AggregateExec, AggregateMode, PhysicalGroupBy}; @@ -47,15 +48,17 @@ use crate::physical_plan::{ }; use arrow::compute::SortOptions; -use datafusion_common::tree_node::{Transformed, TreeNode, VisitRecursion}; +use datafusion_common::tree_node::{Transformed, TreeNode}; use datafusion_expr::logical_plan::JoinType; use datafusion_physical_expr::expressions::{Column, NoOp}; use datafusion_physical_expr::utils::map_columns_before_projection; use datafusion_physical_expr::{ - physical_exprs_equal, EquivalenceProperties, PhysicalExpr, + physical_exprs_equal, EquivalenceProperties, LexRequirementRef, PhysicalExpr, + PhysicalSortRequirement, }; -use datafusion_physical_plan::unbounded_output; +use datafusion_physical_plan::sorts::sort::SortExec; use datafusion_physical_plan::windows::{get_best_fitting_window, BoundedWindowAggExec}; +use datafusion_physical_plan::{get_plan_string, unbounded_output}; use itertools::izip; @@ -256,7 +259,7 @@ impl PhysicalOptimizerRule for EnforceDistribution { /// 1) If the current plan is Partitioned HashJoin, SortMergeJoin, check whether the requirements can be satisfied by adjusting join keys ordering: /// Requirements can not be satisfied, clear the current requirements, generate new requirements(to pushdown) based on the current join keys, return the unchanged plan. /// Requirements is already satisfied, clear the current requirements, generate new requirements(to pushdown) based on the current join keys, return the unchanged plan. -/// Requirements can be satisfied by adjusting keys ordering, clear the current requiements, generate new requirements(to pushdown) based on the adjusted join keys, return the changed plan. +/// Requirements can be satisfied by adjusting keys ordering, clear the current requirements, generate new requirements(to pushdown) based on the adjusted join keys, return the changed plan. /// /// 2) If the current plan is Aggregation, check whether the requirements can be satisfied by adjusting group by keys ordering: /// Requirements can not be satisfied, clear all the requirements, return the unchanged plan. @@ -268,11 +271,12 @@ impl PhysicalOptimizerRule for EnforceDistribution { /// 5) For other types of operators, by default, pushdown the parent requirements to children. /// fn adjust_input_keys_ordering( - requirements: PlanWithKeyRequirements, + mut requirements: PlanWithKeyRequirements, ) -> Result> { let parent_required = requirements.required_key_ordering.clone(); let plan_any = requirements.plan.as_any(); - let transformed = if let Some(HashJoinExec { + + if let Some(HashJoinExec { left, right, on, @@ -287,7 +291,7 @@ fn adjust_input_keys_ordering( PartitionMode::Partitioned => { let join_constructor = |new_conditions: (Vec<(Column, Column)>, Vec)| { - Ok(Arc::new(HashJoinExec::try_new( + HashJoinExec::try_new( left.clone(), right.clone(), new_conditions.0, @@ -295,15 +299,17 @@ fn adjust_input_keys_ordering( join_type, PartitionMode::Partitioned, *null_equals_null, - )?) 
as Arc) + ) + .map(|e| Arc::new(e) as _) }; - Some(reorder_partitioned_join_keys( + reorder_partitioned_join_keys( requirements.plan.clone(), &parent_required, on, vec![], &join_constructor, - )?) + ) + .map(Transformed::Yes) } PartitionMode::CollectLeft => { let new_right_request = match join_type { @@ -321,15 +327,15 @@ fn adjust_input_keys_ordering( }; // Push down requirements to the right side - Some(PlanWithKeyRequirements { - plan: requirements.plan.clone(), - required_key_ordering: vec![], - request_key_ordering: vec![None, new_right_request], - }) + requirements.children[1].required_key_ordering = + new_right_request.unwrap_or(vec![]); + Ok(Transformed::Yes(requirements)) } PartitionMode::Auto => { // Can not satisfy, clear the current requirements and generate new empty requirements - Some(PlanWithKeyRequirements::new(requirements.plan.clone())) + Ok(Transformed::Yes(PlanWithKeyRequirements::new( + requirements.plan, + ))) } } } else if let Some(CrossJoinExec { left, .. }) = @@ -337,14 +343,9 @@ fn adjust_input_keys_ordering( { let left_columns_len = left.schema().fields().len(); // Push down requirements to the right side - Some(PlanWithKeyRequirements { - plan: requirements.plan.clone(), - required_key_ordering: vec![], - request_key_ordering: vec![ - None, - shift_right_required(&parent_required, left_columns_len), - ], - }) + requirements.children[1].required_key_ordering = + shift_right_required(&parent_required, left_columns_len).unwrap_or_default(); + Ok(Transformed::Yes(requirements)) } else if let Some(SortMergeJoinExec { left, right, @@ -357,35 +358,40 @@ fn adjust_input_keys_ordering( { let join_constructor = |new_conditions: (Vec<(Column, Column)>, Vec)| { - Ok(Arc::new(SortMergeJoinExec::try_new( + SortMergeJoinExec::try_new( left.clone(), right.clone(), new_conditions.0, *join_type, new_conditions.1, *null_equals_null, - )?) as Arc) + ) + .map(|e| Arc::new(e) as _) }; - Some(reorder_partitioned_join_keys( + reorder_partitioned_join_keys( requirements.plan.clone(), &parent_required, on, sort_options.clone(), &join_constructor, - )?) 
+ ) + .map(Transformed::Yes) } else if let Some(aggregate_exec) = plan_any.downcast_ref::() { if !parent_required.is_empty() { match aggregate_exec.mode() { - AggregateMode::FinalPartitioned => Some(reorder_aggregate_keys( + AggregateMode::FinalPartitioned => reorder_aggregate_keys( requirements.plan.clone(), &parent_required, aggregate_exec, - )?), - _ => Some(PlanWithKeyRequirements::new(requirements.plan.clone())), + ) + .map(Transformed::Yes), + _ => Ok(Transformed::Yes(PlanWithKeyRequirements::new( + requirements.plan, + ))), } } else { // Keep everything unchanged - None + Ok(Transformed::No(requirements)) } } else if let Some(proj) = plan_any.downcast_ref::() { let expr = proj.expr(); @@ -394,34 +400,28 @@ fn adjust_input_keys_ordering( // Construct a mapping from new name to the the orginal Column let new_required = map_columns_before_projection(&parent_required, expr); if new_required.len() == parent_required.len() { - Some(PlanWithKeyRequirements { - plan: requirements.plan.clone(), - required_key_ordering: vec![], - request_key_ordering: vec![Some(new_required.clone())], - }) + requirements.children[0].required_key_ordering = new_required; + Ok(Transformed::Yes(requirements)) } else { // Can not satisfy, clear the current requirements and generate new empty requirements - Some(PlanWithKeyRequirements::new(requirements.plan.clone())) + Ok(Transformed::Yes(PlanWithKeyRequirements::new( + requirements.plan, + ))) } } else if plan_any.downcast_ref::().is_some() || plan_any.downcast_ref::().is_some() || plan_any.downcast_ref::().is_some() { - Some(PlanWithKeyRequirements::new(requirements.plan.clone())) + Ok(Transformed::Yes(PlanWithKeyRequirements::new( + requirements.plan, + ))) } else { // By default, push down the parent requirements to children - let children_len = requirements.plan.children().len(); - Some(PlanWithKeyRequirements { - plan: requirements.plan.clone(), - required_key_ordering: vec![], - request_key_ordering: vec![Some(parent_required.clone()); children_len], - }) - }; - Ok(if let Some(transformed) = transformed { - Transformed::Yes(transformed) - } else { - Transformed::No(requirements) - }) + requirements.children.iter_mut().for_each(|child| { + child.required_key_ordering = parent_required.clone(); + }); + Ok(Transformed::Yes(requirements)) + } } fn reorder_partitioned_join_keys( @@ -452,28 +452,24 @@ where for idx in 0..sort_options.len() { new_sort_options.push(sort_options[new_positions[idx]]) } - - Ok(PlanWithKeyRequirements { - plan: join_constructor((new_join_on, new_sort_options))?, - required_key_ordering: vec![], - request_key_ordering: vec![Some(left_keys), Some(right_keys)], - }) + let mut requirement_tree = PlanWithKeyRequirements::new(join_constructor(( + new_join_on, + new_sort_options, + ))?); + requirement_tree.children[0].required_key_ordering = left_keys; + requirement_tree.children[1].required_key_ordering = right_keys; + Ok(requirement_tree) } else { - Ok(PlanWithKeyRequirements { - plan: join_plan, - required_key_ordering: vec![], - request_key_ordering: vec![Some(left_keys), Some(right_keys)], - }) + let mut requirement_tree = PlanWithKeyRequirements::new(join_plan); + requirement_tree.children[0].required_key_ordering = left_keys; + requirement_tree.children[1].required_key_ordering = right_keys; + Ok(requirement_tree) } } else { - Ok(PlanWithKeyRequirements { - plan: join_plan, - required_key_ordering: vec![], - request_key_ordering: vec![ - Some(join_key_pairs.left_keys), - Some(join_key_pairs.right_keys), - ], - }) + let mut 
requirement_tree = PlanWithKeyRequirements::new(join_plan); + requirement_tree.children[0].required_key_ordering = join_key_pairs.left_keys; + requirement_tree.children[1].required_key_ordering = join_key_pairs.right_keys; + Ok(requirement_tree) } } @@ -521,7 +517,6 @@ fn reorder_aggregate_keys( new_partial_group_by, agg_exec.aggr_expr().to_vec(), agg_exec.filter_expr().to_vec(), - agg_exec.order_by_expr().to_vec(), agg_exec.input().clone(), agg_exec.input_schema.clone(), )?)) @@ -548,7 +543,6 @@ fn reorder_aggregate_keys( new_group_by, agg_exec.aggr_expr().to_vec(), agg_exec.filter_expr().to_vec(), - agg_exec.order_by_expr().to_vec(), partial_agg, agg_exec.input_schema(), )?); @@ -870,75 +864,41 @@ fn new_join_conditions( .collect() } -/// Updates `dist_onward` such that, to keep track of -/// `input` in the `exec_tree`. -/// -/// # Arguments -/// -/// * `input`: Current execution plan -/// * `dist_onward`: It keeps track of executors starting from a distribution -/// changing operator (e.g Repartition, SortPreservingMergeExec, etc.) -/// until child of `input` (`input` should have single child). -/// * `input_idx`: index of the `input`, for its parent. -/// -fn update_distribution_onward( - input: Arc, - dist_onward: &mut Option, - input_idx: usize, -) { - // Update the onward tree if there is an active branch - if let Some(exec_tree) = dist_onward { - // When we add a new operator to change distribution - // we add RepartitionExec, SortPreservingMergeExec, CoalescePartitionsExec - // in this case, we need to update exec tree idx such that exec tree is now child of these - // operators (change the 0, since all of the operators have single child). - exec_tree.idx = 0; - *exec_tree = ExecTree::new(input, input_idx, vec![exec_tree.clone()]); - } else { - *dist_onward = Some(ExecTree::new(input, input_idx, vec![])); - } -} - /// Adds RoundRobin repartition operator to the plan increase parallelism. /// /// # Arguments /// -/// * `input`: Current execution plan +/// * `input`: Current node. /// * `n_target`: desired target partition number, if partition number of the /// current executor is less than this value. Partition number will be increased. -/// * `dist_onward`: It keeps track of executors starting from a distribution -/// changing operator (e.g Repartition, SortPreservingMergeExec, etc.) -/// until `input` plan. -/// * `input_idx`: index of the `input`, for its parent. /// /// # Returns /// -/// A [Result] object that contains new execution plan, where desired partition number -/// is achieved by adding RoundRobin Repartition. +/// A [`Result`] object that contains new execution plan where the desired +/// partition number is achieved by adding a RoundRobin repartition. fn add_roundrobin_on_top( - input: Arc, + input: DistributionContext, n_target: usize, - dist_onward: &mut Option, - input_idx: usize, -) -> Result> { - // Adding repartition is helpful - if input.output_partitioning().partition_count() < n_target { +) -> Result { + // Adding repartition is helpful: + if input.plan.output_partitioning().partition_count() < n_target { // When there is an existing ordering, we preserve ordering // during repartition. 
This will be un-done in the future // If any of the following conditions is true // - Preserving ordering is not helpful in terms of satisfying ordering requirements // - Usage of order preserving variants is not desirable - // (determined by flag `config.optimizer.bounded_order_preserving_variants`) - let should_preserve_ordering = input.output_ordering().is_some(); - + // (determined by flag `config.optimizer.prefer_existing_sort`) let partitioning = Partitioning::RoundRobinBatch(n_target); - let repartition = RepartitionExec::try_new(input, partitioning)?; - let new_plan = Arc::new(repartition.with_preserve_order(should_preserve_ordering)) - as Arc; + let repartition = RepartitionExec::try_new(input.plan.clone(), partitioning)? + .with_preserve_order(); + + let new_plan = Arc::new(repartition) as _; - // update distribution onward with new operator - update_distribution_onward(new_plan.clone(), dist_onward, input_idx); - Ok(new_plan) + Ok(DistributionContext { + plan: new_plan, + distribution_connection: true, + children_nodes: vec![input], + }) } else { // Partition is not helpful, we already have desired number of partitions. Ok(input) @@ -952,122 +912,105 @@ fn add_roundrobin_on_top( /// /// # Arguments /// -/// * `input`: Current execution plan +/// * `input`: Current node. /// * `hash_exprs`: Stores Physical Exprs that are used during hashing. /// * `n_target`: desired target partition number, if partition number of the /// current executor is less than this value. Partition number will be increased. -/// * `dist_onward`: It keeps track of executors starting from a distribution -/// changing operator (e.g Repartition, SortPreservingMergeExec, etc.) -/// until `input` plan. -/// * `input_idx`: index of the `input`, for its parent. /// /// # Returns /// -/// A [`Result`] object that contains new execution plan, where desired distribution is -/// satisfied by adding Hash Repartition. +/// A [`Result`] object that contains new execution plan where the desired +/// distribution is satisfied by adding a Hash repartition. fn add_hash_on_top( - input: Arc, + mut input: DistributionContext, hash_exprs: Vec>, - // Repartition(Hash) will have `n_target` partitions at the output. n_target: usize, - // Stores executors starting from Repartition(RoundRobin) until - // current executor. When Repartition(Hash) is added, `dist_onward` - // is updated such that it stores connection from Repartition(RoundRobin) - // until Repartition(Hash). - dist_onward: &mut Option, - input_idx: usize, repartition_beneficial_stats: bool, -) -> Result> { - if n_target == input.output_partitioning().partition_count() && n_target == 1 { - // In this case adding a hash repartition is unnecessary as the hash - // requirement is implicitly satisfied. +) -> Result { + let partition_count = input.plan.output_partitioning().partition_count(); + // Early return if hash repartition is unnecessary + if n_target == partition_count && n_target == 1 { return Ok(input); } + let satisfied = input + .plan .output_partitioning() .satisfy(Distribution::HashPartitioned(hash_exprs.clone()), || { - input.equivalence_properties() + input.plan.equivalence_properties() }); + // Add hash repartitioning when: // - The hash distribution requirement is not satisfied, or // - We can increase parallelism by adding hash partitioning. 
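(Aside for reviewers — not part of the patch.) The check that follows is the core of the hash-repartition decision: a `RepartitionExec` with `Partitioning::Hash` is only inserted when the existing partitioning does not satisfy the hash requirement, or when doing so raises the partition count toward `target_partitions`. Below is a minimal, self-contained sketch of that decision using plain booleans and counts instead of DataFusion's `Partitioning`/`Distribution` types; the names `Decision` and `hash_repartition_decision` are illustrative and not part of the codebase.

```rust
// Illustrative only: models the hash-repartition decision with plain values.
#[derive(Debug, PartialEq)]
enum Decision {
    KeepAsIs,
    AddHashRepartition { n_target: usize },
}

/// `satisfied` stands in for the result of
/// `Partitioning::satisfy(Distribution::HashPartitioned(..), ..)`.
fn hash_repartition_decision(
    satisfied: bool,
    current_partitions: usize,
    n_target: usize,
) -> Decision {
    // A single-partition plan with a single-partition target trivially
    // satisfies any hash requirement, so nothing is added (early return above).
    if n_target == current_partitions && n_target == 1 {
        return Decision::KeepAsIs;
    }
    // Mirror of `!satisfied || n_target > partition_count`:
    if !satisfied || n_target > current_partitions {
        Decision::AddHashRepartition { n_target }
    } else {
        Decision::KeepAsIs
    }
}

fn main() {
    // Already hash-partitioned on the required keys with enough partitions: unchanged.
    assert_eq!(hash_repartition_decision(true, 10, 10), Decision::KeepAsIs);
    // Requirement unmet: a hash `RepartitionExec` would be inserted on top.
    assert_eq!(
        hash_repartition_decision(false, 2, 10),
        Decision::AddHashRepartition { n_target: 10 }
    );
}
```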
- if !satisfied || n_target > input.output_partitioning().partition_count() { + if !satisfied || n_target > input.plan.output_partitioning().partition_count() { // When there is an existing ordering, we preserve ordering during // repartition. This will be rolled back in the future if any of the // following conditions is true: // - Preserving ordering is not helpful in terms of satisfying ordering // requirements. // - Usage of order preserving variants is not desirable (per the flag - // `config.optimizer.bounded_order_preserving_variants`). - let should_preserve_ordering = input.output_ordering().is_some(); - let mut new_plan = if repartition_beneficial_stats { + // `config.optimizer.prefer_existing_sort`). + if repartition_beneficial_stats { // Since hashing benefits from partitioning, add a round-robin repartition // before it: - add_roundrobin_on_top(input, n_target, dist_onward, 0)? - } else { - input - }; + input = add_roundrobin_on_top(input, n_target)?; + } + let partitioning = Partitioning::Hash(hash_exprs, n_target); - let repartition = RepartitionExec::try_new(new_plan, partitioning)?; - new_plan = - Arc::new(repartition.with_preserve_order(should_preserve_ordering)) as _; + let repartition = RepartitionExec::try_new(input.plan.clone(), partitioning)? + .with_preserve_order(); - // update distribution onward with new operator - update_distribution_onward(new_plan.clone(), dist_onward, input_idx); - Ok(new_plan) - } else { - Ok(input) + input.children_nodes = vec![input.clone()]; + input.distribution_connection = true; + input.plan = Arc::new(repartition) as _; } + + Ok(input) } -/// Adds a `SortPreservingMergeExec` operator on top of input executor: -/// - to satisfy single distribution requirement. +/// Adds a [`SortPreservingMergeExec`] operator on top of input executor +/// to satisfy single distribution requirement. /// /// # Arguments /// -/// * `input`: Current execution plan -/// * `dist_onward`: It keeps track of executors starting from a distribution -/// changing operator (e.g Repartition, SortPreservingMergeExec, etc.) -/// until `input` plan. -/// * `input_idx`: index of the `input`, for its parent. +/// * `input`: Current node. /// /// # Returns /// -/// New execution plan, where desired single -/// distribution is satisfied by adding `SortPreservingMergeExec`. -fn add_spm_on_top( - input: Arc, - dist_onward: &mut Option, - input_idx: usize, -) -> Arc { +/// Updated node with an execution plan, where desired single +/// distribution is satisfied by adding [`SortPreservingMergeExec`]. +fn add_spm_on_top(input: DistributionContext) -> DistributionContext { // Add SortPreservingMerge only when partition count is larger than 1. - if input.output_partitioning().partition_count() > 1 { + if input.plan.output_partitioning().partition_count() > 1 { // When there is an existing ordering, we preserve ordering - // during decreasıng partıtıons. This will be un-done in the future - // If any of the following conditions is true + // when decreasing partitions. 
This will be un-done in the future + // if any of the following conditions is true // - Preserving ordering is not helpful in terms of satisfying ordering requirements // - Usage of order preserving variants is not desirable // (determined by flag `config.optimizer.bounded_order_preserving_variants`) - let should_preserve_ordering = input.output_ordering().is_some(); - let new_plan: Arc = if should_preserve_ordering { - let existing_ordering = input.output_ordering().unwrap_or(&[]); + let should_preserve_ordering = input.plan.output_ordering().is_some(); + + let new_plan = if should_preserve_ordering { Arc::new(SortPreservingMergeExec::new( - existing_ordering.to_vec(), - input, + input.plan.output_ordering().unwrap_or(&[]).to_vec(), + input.plan.clone(), )) as _ } else { - Arc::new(CoalescePartitionsExec::new(input)) as _ + Arc::new(CoalescePartitionsExec::new(input.plan.clone())) as _ }; - // update repartition onward with new operator - update_distribution_onward(new_plan.clone(), dist_onward, input_idx); - new_plan + DistributionContext { + plan: new_plan, + distribution_connection: true, + children_nodes: vec![input], + } } else { input } } -/// Updates the physical plan inside `distribution_context` so that distribution +/// Updates the physical plan inside [`DistributionContext`] so that distribution /// changing operators are removed from the top. If they are necessary, they will /// be added in subsequent stages. /// @@ -1085,58 +1028,33 @@ fn add_spm_on_top( /// "ParquetExec: file_groups={2 groups: \[\[x], \[y]]}, projection=\[a, b, c, d, e], output_ordering=\[a@0 ASC]", /// ``` fn remove_dist_changing_operators( - distribution_context: DistributionContext, + mut distribution_context: DistributionContext, ) -> Result { - let DistributionContext { - mut plan, - mut distribution_onwards, - } = distribution_context; - - // Remove any distribution changing operators at the beginning: - // Note that they will be re-inserted later on if necessary or helpful. - while is_repartition(&plan) - || is_coalesce_partitions(&plan) - || is_sort_preserving_merge(&plan) + while is_repartition(&distribution_context.plan) + || is_coalesce_partitions(&distribution_context.plan) + || is_sort_preserving_merge(&distribution_context.plan) { - // All of above operators have a single child. When we remove the top - // operator, we take the first child. - plan = plan.children().swap_remove(0); - distribution_onwards = - get_children_exectrees(plan.children().len(), &distribution_onwards[0]); + // All of above operators have a single child. First child is only child. + let child = distribution_context.children_nodes.swap_remove(0); + // Remove any distribution changing operators at the beginning: + // Note that they will be re-inserted later on if necessary or helpful. + distribution_context = child; } - // Create a plan with the updated children: - Ok(DistributionContext { - plan, - distribution_onwards, - }) + Ok(distribution_context) } -/// Updates the physical plan `input` by using `dist_onward` replace order preserving operator variants -/// with their corresponding operators that do not preserve order. 
It is a wrapper for `replace_order_preserving_variants_helper` -fn replace_order_preserving_variants( - input: &mut Arc, - dist_onward: &mut Option, -) -> Result<()> { - if let Some(dist_onward) = dist_onward { - *input = replace_order_preserving_variants_helper(dist_onward)?; - } - *dist_onward = None; - Ok(()) -} - -/// Updates the physical plan inside `ExecTree` if preserving ordering while changing partitioning -/// is not helpful or desirable. +/// Updates the [`DistributionContext`] if preserving ordering while changing partitioning is not helpful or desirable. /// /// Assume that following plan is given: /// ```text /// "SortPreservingMergeExec: \[a@0 ASC]" -/// " SortPreservingRepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=10", -/// " SortPreservingRepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2", +/// " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=10, preserve_order=true", +/// " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2, preserve_order=true", /// " ParquetExec: file_groups={2 groups: \[\[x], \[y]]}, projection=\[a, b, c, d, e], output_ordering=\[a@0 ASC]", /// ``` /// -/// This function converts plan above (inside `ExecTree`) to the following: +/// This function converts plan above to the following: /// /// ```text /// "CoalescePartitionsExec" @@ -1144,30 +1062,75 @@ fn replace_order_preserving_variants( /// " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2", /// " ParquetExec: file_groups={2 groups: \[\[x], \[y]]}, projection=\[a, b, c, d, e], output_ordering=\[a@0 ASC]", /// ``` -fn replace_order_preserving_variants_helper( - exec_tree: &ExecTree, -) -> Result> { - let mut updated_children = exec_tree.plan.children(); - for child in &exec_tree.children { - updated_children[child.idx] = replace_order_preserving_variants_helper(child)?; - } - if is_sort_preserving_merge(&exec_tree.plan) { - return Ok(Arc::new(CoalescePartitionsExec::new( - updated_children.swap_remove(0), - ))); - } - if let Some(repartition) = exec_tree.plan.as_any().downcast_ref::() { +fn replace_order_preserving_variants( + mut context: DistributionContext, +) -> Result { + let mut updated_children = context + .children_nodes + .iter() + .map(|child| { + if child.distribution_connection { + replace_order_preserving_variants(child.clone()) + } else { + Ok(child.clone()) + } + }) + .collect::>>()?; + + if is_sort_preserving_merge(&context.plan) { + let child = updated_children.swap_remove(0); + context.plan = Arc::new(CoalescePartitionsExec::new(child.plan.clone())); + context.children_nodes = vec![child]; + return Ok(context); + } else if let Some(repartition) = + context.plan.as_any().downcast_ref::() + { if repartition.preserve_order() { - return Ok(Arc::new( - RepartitionExec::try_new( - updated_children.swap_remove(0), - repartition.partitioning().clone(), - )? - .with_preserve_order(false), - )); + let child = updated_children.swap_remove(0); + context.plan = Arc::new(RepartitionExec::try_new( + child.plan.clone(), + repartition.partitioning().clone(), + )?); + context.children_nodes = vec![child]; + return Ok(context); + } + } + + context.plan = context + .plan + .clone() + .with_new_children(updated_children.into_iter().map(|c| c.plan).collect())?; + Ok(context) +} + +/// This utility function adds a [`SortExec`] above an operator according to the +/// given ordering requirements while preserving the original partitioning. 
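(Aside for reviewers — not part of the patch.) The helper defined next only adds a `SortExec` when the required ordering is not already satisfied, and it keeps the existing partitioning by sorting each partition locally (`with_preserve_partitioning(true)`) whenever the input has more than one partition. A simplified, standalone sketch of that behavior follows, using vectors of integers in place of `ExecutionPlan` partitions; `sort_each_partition` is a hypothetical name, and the per-partition `fetch` handling is only roughly analogous to `SortExec::with_fetch`.

```rust
// Illustrative only: each inner Vec stands for one partition of the input.
fn sort_each_partition(
    partitions: Vec<Vec<i32>>,
    already_sorted: bool,
    fetch: Option<usize>,
) -> Vec<Vec<i32>> {
    // Like the helper below: if the ordering requirement is already satisfied,
    // leave the input untouched instead of adding a sort.
    if already_sorted {
        return partitions;
    }
    partitions
        .into_iter()
        .map(|mut part| {
            // Sorting happens per partition, so partition boundaries (and the
            // partition count) are preserved.
            part.sort();
            if let Some(n) = fetch {
                // Hypothetical per-partition limit, loosely mirroring a fetch.
                part.truncate(n);
            }
            part
        })
        .collect()
}

fn main() {
    let two_partitions = vec![vec![3, 1, 2], vec![9, 7, 8]];
    // Two partitions in, two locally sorted partitions out.
    assert_eq!(
        sort_each_partition(two_partitions, false, None),
        vec![vec![1, 2, 3], vec![7, 8, 9]]
    );
}
```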
+fn add_sort_preserving_partitions( + node: DistributionContext, + sort_requirement: LexRequirementRef, + fetch: Option, +) -> DistributionContext { + // If the ordering requirement is already satisfied, do not add a sort. + if !node + .plan + .equivalence_properties() + .ordering_satisfy_requirement(sort_requirement) + { + let sort_expr = PhysicalSortRequirement::to_sort_exprs(sort_requirement.to_vec()); + let new_sort = SortExec::new(sort_expr, node.plan.clone()).with_fetch(fetch); + + DistributionContext { + plan: Arc::new(if node.plan.output_partitioning().partition_count() > 1 { + new_sort.with_preserve_partitioning(true) + } else { + new_sort + }), + distribution_connection: false, + children_nodes: vec![node], } + } else { + node } - exec_tree.plan.clone().with_new_children(updated_children) } /// This function checks whether we need to add additional data exchange @@ -1178,6 +1141,12 @@ fn ensure_distribution( dist_context: DistributionContext, config: &ConfigOptions, ) -> Result> { + let dist_context = dist_context.update_children()?; + + if dist_context.plan.children().is_empty() { + return Ok(Transformed::No(dist_context)); + } + let target_partitions = config.execution.target_partitions; // When `false`, round robin repartition will not be added to increase parallelism let enable_round_robin = config.optimizer.enable_round_robin_repartition; @@ -1190,14 +1159,11 @@ fn ensure_distribution( let order_preserving_variants_desirable = is_unbounded || config.optimizer.prefer_existing_sort; - if dist_context.plan.children().is_empty() { - return Ok(Transformed::No(dist_context)); - } - // Remove unnecessary repartition from the physical plan if any let DistributionContext { mut plan, - mut distribution_onwards, + distribution_connection, + children_nodes, } = remove_dist_changing_operators(dist_context)?; if let Some(exec) = plan.as_any().downcast_ref::() { @@ -1217,33 +1183,23 @@ fn ensure_distribution( plan = updated_window; } }; - let n_children = plan.children().len(); + // This loop iterates over all the children to: // - Increase parallelism for every child if it is beneficial. // - Satisfy the distribution requirements of every child, if it is not // already satisfied. // We store the updated children in `new_children`. - let new_children = izip!( - plan.children().into_iter(), + let children_nodes = izip!( + children_nodes.into_iter(), plan.required_input_distribution().iter(), plan.required_input_ordering().iter(), - distribution_onwards.iter_mut(), plan.benefits_from_input_partitioning(), - plan.maintains_input_order(), - 0..n_children + plan.maintains_input_order() ) .map( - |( - mut child, - requirement, - required_input_ordering, - dist_onward, - would_benefit, - maintains, - child_idx, - )| { + |(mut child, requirement, required_input_ordering, would_benefit, maintains)| { // Don't need to apply when the returned row count is not greater than 1: - let num_rows = child.statistics()?.num_rows; + let num_rows = child.plan.statistics()?.num_rows; let repartition_beneficial_stats = if num_rows.is_exact().unwrap_or(false) { num_rows .get_value() @@ -1252,45 +1208,39 @@ fn ensure_distribution( } else { true }; + if enable_round_robin // Operator benefits from partitioning (e.g. 
filter): && (would_benefit && repartition_beneficial_stats) // Unless partitioning doesn't increase the partition count, it is not beneficial: - && child.output_partitioning().partition_count() < target_partitions + && child.plan.output_partitioning().partition_count() < target_partitions { // When `repartition_file_scans` is set, attempt to increase // parallelism at the source. if repartition_file_scans { if let Some(new_child) = - child.repartitioned(target_partitions, config)? + child.plan.repartitioned(target_partitions, config)? { - child = new_child; + child.plan = new_child; } } // Increase parallelism by adding round-robin repartitioning // on top of the operator. Note that we only do this if the // partition count is not already equal to the desired partition // count. - child = add_roundrobin_on_top( - child, - target_partitions, - dist_onward, - child_idx, - )?; + child = add_roundrobin_on_top(child, target_partitions)?; } // Satisfy the distribution requirement if it is unmet. match requirement { Distribution::SinglePartition => { - child = add_spm_on_top(child, dist_onward, child_idx); + child = add_spm_on_top(child); } Distribution::HashPartitioned(exprs) => { child = add_hash_on_top( child, exprs.to_vec(), target_partitions, - dist_onward, - child_idx, repartition_beneficial_stats, )?; } @@ -1303,31 +1253,38 @@ fn ensure_distribution( // - Ordering requirement cannot be satisfied by preserving ordering through repartitions, or // - using order preserving variant is not desirable. let ordering_satisfied = child + .plan .equivalence_properties() .ordering_satisfy_requirement(required_input_ordering); - if !ordering_satisfied || !order_preserving_variants_desirable { - replace_order_preserving_variants(&mut child, dist_onward)?; + if (!ordering_satisfied || !order_preserving_variants_desirable) + && child.distribution_connection + { + child = replace_order_preserving_variants(child)?; // If ordering requirements were satisfied before repartitioning, // make sure ordering requirements are still satisfied after. if ordering_satisfied { // Make sure to satisfy ordering requirement: - add_sort_above(&mut child, required_input_ordering, None); + child = add_sort_preserving_partitions( + child, + required_input_ordering, + None, + ); } } // Stop tracking distribution changing operators - *dist_onward = None; + child.distribution_connection = false; } else { // no ordering requirement match requirement { // Operator requires specific distribution. Distribution::SinglePartition | Distribution::HashPartitioned(_) => { // Since there is no ordering requirement, preserving ordering is pointless - replace_order_preserving_variants(&mut child, dist_onward)?; + child = replace_order_preserving_variants(child)?; } Distribution::UnspecifiedDistribution => { // Since ordering is lost, trying to preserve ordering is pointless - if !maintains { - replace_order_preserving_variants(&mut child, dist_onward)?; + if !maintains || plan.as_any().is::() { + child = replace_order_preserving_variants(child)?; } } } @@ -1338,7 +1295,9 @@ fn ensure_distribution( .collect::>>()?; let new_distribution_context = DistributionContext { - plan: if plan.as_any().is::() && can_interleave(&new_children) { + plan: if plan.as_any().is::() + && can_interleave(children_nodes.iter().map(|c| c.plan.clone())) + { // Add a special case for [`UnionExec`] since we want to "bubble up" // hash-partitioned data. 
So instead of // @@ -1362,152 +1321,116 @@ fn ensure_distribution( // - Agg: // Repartition (hash): // Data - Arc::new(InterleaveExec::try_new(new_children)?) + Arc::new(InterleaveExec::try_new( + children_nodes.iter().map(|c| c.plan.clone()).collect(), + )?) } else { - plan.with_new_children(new_children)? + plan.with_new_children( + children_nodes.iter().map(|c| c.plan.clone()).collect(), + )? }, - distribution_onwards, + distribution_connection, + children_nodes, }; + Ok(Transformed::Yes(new_distribution_context)) } -/// A struct to keep track of distribution changing executors +/// A struct to keep track of distribution changing operators /// (`RepartitionExec`, `SortPreservingMergeExec`, `CoalescePartitionsExec`), /// and their associated parents inside `plan`. Using this information, /// we can optimize distribution of the plan if/when necessary. #[derive(Debug, Clone)] struct DistributionContext { plan: Arc, - /// Keep track of associations for each child of the plan. If `None`, - /// there is no distribution changing operator in its descendants. - distribution_onwards: Vec>, + /// Indicates whether this plan is connected to a distribution-changing + /// operator. + distribution_connection: bool, + children_nodes: Vec, } impl DistributionContext { - /// Creates an empty context. + /// Creates a tree according to the plan with empty states. fn new(plan: Arc) -> Self { - let length = plan.children().len(); - DistributionContext { + let children = plan.children(); + Self { plan, - distribution_onwards: vec![None; length], + distribution_connection: false, + children_nodes: children.into_iter().map(Self::new).collect(), } } - /// Constructs a new context from children contexts. - fn new_from_children_nodes( - children_nodes: Vec, - parent_plan: Arc, - ) -> Result { - let children_plans = children_nodes - .iter() - .map(|item| item.plan.clone()) - .collect(); - let distribution_onwards = children_nodes - .into_iter() - .enumerate() - .map(|(idx, context)| { - let DistributionContext { - plan, - // The `distribution_onwards` tree keeps track of operators - // that change distribution, or preserves the existing - // distribution (starting from an operator that change distribution). - distribution_onwards, - } = context; - if plan.children().is_empty() { - // Plan has no children, there is nothing to propagate. - None - } else if distribution_onwards[0].is_none() { - if let Some(repartition) = - plan.as_any().downcast_ref::() - { - match repartition.partitioning() { - Partitioning::RoundRobinBatch(_) - | Partitioning::Hash(_, _) => { - // Start tracking operators starting from this repartition (either roundrobin or hash): - return Some(ExecTree::new(plan, idx, vec![])); - } - _ => {} - } - } else if plan.as_any().is::() - || plan.as_any().is::() - { - // Start tracking operators starting from this sort preserving merge: - return Some(ExecTree::new(plan, idx, vec![])); - } - None - } else { - // Propagate children distribution tracking to the above - let new_distribution_onwards = izip!( - plan.required_input_distribution().iter(), - distribution_onwards.into_iter() - ) - .flat_map(|(required_dist, distribution_onwards)| { - if let Some(distribution_onwards) = distribution_onwards { - // Operator can safely propagate the distribution above. - // This is similar to maintaining order in the EnforceSorting rule. 
- if let Distribution::UnspecifiedDistribution = required_dist { - return Some(distribution_onwards); - } - } - None - }) - .collect::>(); - // Either: - // - None of the children has a connection to an operator that modifies distribution, or - // - The current operator requires distribution at its input so doesn't propagate it above. - if new_distribution_onwards.is_empty() { - None - } else { - Some(ExecTree::new(plan, idx, new_distribution_onwards)) - } + fn update_children(mut self) -> Result { + for child_context in self.children_nodes.iter_mut() { + child_context.distribution_connection = match child_context.plan.as_any() { + plan_any if plan_any.is::() => matches!( + plan_any + .downcast_ref::() + .unwrap() + .partitioning(), + Partitioning::RoundRobinBatch(_) | Partitioning::Hash(_, _) + ), + plan_any + if plan_any.is::() + || plan_any.is::() => + { + true } - }) - .collect(); - Ok(DistributionContext { - plan: with_new_children_if_necessary(parent_plan, children_plans)?.into(), - distribution_onwards, - }) - } + _ => { + child_context.plan.children().is_empty() + || child_context.children_nodes[0].distribution_connection + || child_context + .plan + .required_input_distribution() + .iter() + .zip(child_context.children_nodes.iter()) + .any(|(required_dist, child_context)| { + child_context.distribution_connection + && matches!( + required_dist, + Distribution::UnspecifiedDistribution + ) + }) + } + }; + } - /// Computes distribution tracking contexts for every child of the plan. - fn children(&self) -> Vec { - self.plan - .children() - .into_iter() - .map(DistributionContext::new) - .collect() + let children_plans = self + .children_nodes + .iter() + .map(|context| context.plan.clone()) + .collect::>(); + + Ok(Self { + plan: with_new_children_if_necessary(self.plan, children_plans)?.into(), + distribution_connection: false, + children_nodes: self.children_nodes, + }) } } impl TreeNode for DistributionContext { - fn apply_children(&self, op: &mut F) -> Result - where - F: FnMut(&Self) -> Result, - { - for child in self.children() { - match op(&child)? { - VisitRecursion::Continue => {} - VisitRecursion::Skip => return Ok(VisitRecursion::Continue), - VisitRecursion::Stop => return Ok(VisitRecursion::Stop), - } - } - Ok(VisitRecursion::Continue) + fn children_nodes(&self) -> Vec> { + self.children_nodes.iter().map(Cow::Borrowed).collect() } - fn map_children(self, transform: F) -> Result + fn map_children(mut self, transform: F) -> Result where F: FnMut(Self) -> Result, { - let children = self.children(); - if children.is_empty() { - Ok(self) - } else { - let children_nodes = children + if !self.children_nodes.is_empty() { + self.children_nodes = self + .children_nodes .into_iter() .map(transform) - .collect::>>()?; - DistributionContext::new_from_children_nodes(children_nodes, self.plan) + .collect::>()?; + self.plan = with_new_children_if_necessary( + self.plan, + self.children_nodes.iter().map(|c| c.plan.clone()).collect(), + )? 
+ .into(); } + Ok(self) } } @@ -1516,11 +1439,11 @@ impl fmt::Display for DistributionContext { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { let plan_string = get_plan_string(&self.plan); write!(f, "plan: {:?}", plan_string)?; - for (idx, child) in self.distribution_onwards.iter().enumerate() { - if let Some(child) = child { - write!(f, "idx:{:?}, exec_tree:{}", idx, child)?; - } - } + write!( + f, + "distribution_connection:{}", + self.distribution_connection, + )?; write!(f, "") } } @@ -1536,85 +1459,49 @@ struct PlanWithKeyRequirements { plan: Arc, /// Parent required key ordering required_key_ordering: Vec>, - /// The request key ordering to children - request_key_ordering: Vec>>>, + children: Vec, } impl PlanWithKeyRequirements { fn new(plan: Arc) -> Self { - let children_len = plan.children().len(); - PlanWithKeyRequirements { + let children = plan.children(); + Self { plan, required_key_ordering: vec![], - request_key_ordering: vec![None; children_len], + children: children.into_iter().map(Self::new).collect(), } } - - fn children(&self) -> Vec { - let plan_children = self.plan.children(); - assert_eq!(plan_children.len(), self.request_key_ordering.len()); - plan_children - .into_iter() - .zip(self.request_key_ordering.clone()) - .map(|(child, required)| { - let from_parent = required.unwrap_or_default(); - let length = child.children().len(); - PlanWithKeyRequirements { - plan: child, - required_key_ordering: from_parent, - request_key_ordering: vec![None; length], - } - }) - .collect() - } } impl TreeNode for PlanWithKeyRequirements { - fn apply_children(&self, op: &mut F) -> Result - where - F: FnMut(&Self) -> Result, - { - let children = self.children(); - for child in children { - match op(&child)? { - VisitRecursion::Continue => {} - VisitRecursion::Skip => return Ok(VisitRecursion::Continue), - VisitRecursion::Stop => return Ok(VisitRecursion::Stop), - } - } - - Ok(VisitRecursion::Continue) + fn children_nodes(&self) -> Vec> { + self.children.iter().map(Cow::Borrowed).collect() } - fn map_children(self, transform: F) -> Result + fn map_children(mut self, transform: F) -> Result where F: FnMut(Self) -> Result, { - let children = self.children(); - if !children.is_empty() { - let new_children: Result> = - children.into_iter().map(transform).collect(); - - let children_plans = new_children? + if !self.children.is_empty() { + self.children = self + .children .into_iter() - .map(|child| child.plan) - .collect::>(); - let new_plan = with_new_children_if_necessary(self.plan, children_plans)?; - Ok(PlanWithKeyRequirements { - plan: new_plan.into(), - required_key_ordering: self.required_key_ordering, - request_key_ordering: self.request_key_ordering, - }) - } else { - Ok(self) + .map(transform) + .collect::>()?; + self.plan = with_new_children_if_necessary( + self.plan, + self.children.iter().map(|c| c.plan.clone()).collect(), + )? 
+ .into(); } + Ok(self) } } /// Since almost all of these tests explicitly use `ParquetExec` they only run with the parquet feature flag on #[cfg(feature = "parquet")] #[cfg(test)] -mod tests { +pub(crate) mod tests { use std::ops::Deref; use super::*; @@ -1751,7 +1638,7 @@ mod tests { } } - fn schema() -> SchemaRef { + pub(crate) fn schema() -> SchemaRef { Arc::new(Schema::new(vec![ Field::new("a", DataType::Int64, true), Field::new("b", DataType::Int64, true), @@ -1765,7 +1652,8 @@ mod tests { parquet_exec_with_sort(vec![]) } - fn parquet_exec_with_sort( + /// create a single parquet file that is sorted + pub(crate) fn parquet_exec_with_sort( output_ordering: Vec>, ) -> Arc { Arc::new(ParquetExec::new( @@ -1778,7 +1666,6 @@ mod tests { limit: None, table_partition_cols: vec![], output_ordering, - infinite_source: false, }, None, None, @@ -1789,7 +1676,7 @@ mod tests { parquet_exec_multiple_sorted(vec![]) } - // Created a sorted parquet exec with multiple files + /// Created a sorted parquet exec with multiple files fn parquet_exec_multiple_sorted( output_ordering: Vec>, ) -> Arc { @@ -1806,7 +1693,6 @@ mod tests { limit: None, table_partition_cols: vec![], output_ordering, - infinite_source: false, }, None, None, @@ -1828,7 +1714,6 @@ mod tests { limit: None, table_partition_cols: vec![], output_ordering, - infinite_source: false, }, false, b',', @@ -1859,7 +1744,6 @@ mod tests { limit: None, table_partition_cols: vec![], output_ordering, - infinite_source: false, }, false, b',', @@ -1911,14 +1795,12 @@ mod tests { final_grouping, vec![], vec![], - vec![], Arc::new( AggregateExec::try_new( AggregateMode::Partial, group_by, vec![], vec![], - vec![], input, schema.clone(), ) @@ -2018,7 +1900,7 @@ mod tests { Arc::new(SortRequiredExec::new_with_requirement(input, sort_exprs)) } - fn trim_plan_display(plan: &str) -> Vec<&str> { + pub(crate) fn trim_plan_display(plan: &str) -> Vec<&str> { plan.split('\n') .map(|s| s.trim()) .filter(|s| !s.is_empty()) @@ -2028,7 +1910,7 @@ mod tests { fn ensure_distribution_helper( plan: Arc, target_partitions: usize, - bounded_order_preserving_variants: bool, + prefer_existing_sort: bool, ) -> Result> { let distribution_context = DistributionContext::new(plan); let mut config = ConfigOptions::new(); @@ -2036,7 +1918,7 @@ mod tests { config.optimizer.enable_round_robin_repartition = false; config.optimizer.repartition_file_scans = false; config.optimizer.repartition_file_min_size = 1024; - config.optimizer.prefer_existing_sort = bounded_order_preserving_variants; + config.optimizer.prefer_existing_sort = prefer_existing_sort; ensure_distribution(distribution_context, &config).map(|item| item.into().plan) } @@ -2058,23 +1940,33 @@ mod tests { } /// Runs the repartition optimizer and asserts the plan against the expected + /// Arguments + /// * `EXPECTED_LINES` - Expected output plan + /// * `PLAN` - Input plan + /// * `FIRST_ENFORCE_DIST` - + /// true: (EnforceDistribution, EnforceDistribution, EnforceSorting) + /// false: else runs (EnforceSorting, EnforceDistribution, EnforceDistribution) + /// * `PREFER_EXISTING_SORT` (optional) - if true, will not repartition / resort data if it is already sorted + /// * `TARGET_PARTITIONS` (optional) - number of partitions to repartition to + /// * `REPARTITION_FILE_SCANS` (optional) - if true, will repartition file scans + /// * `REPARTITION_FILE_MIN_SIZE` (optional) - minimum file size to repartition macro_rules! 
assert_optimized { ($EXPECTED_LINES: expr, $PLAN: expr, $FIRST_ENFORCE_DIST: expr) => { assert_optimized!($EXPECTED_LINES, $PLAN, $FIRST_ENFORCE_DIST, false, 10, false, 1024); }; - ($EXPECTED_LINES: expr, $PLAN: expr, $FIRST_ENFORCE_DIST: expr, $BOUNDED_ORDER_PRESERVING_VARIANTS: expr) => { - assert_optimized!($EXPECTED_LINES, $PLAN, $FIRST_ENFORCE_DIST, $BOUNDED_ORDER_PRESERVING_VARIANTS, 10, false, 1024); + ($EXPECTED_LINES: expr, $PLAN: expr, $FIRST_ENFORCE_DIST: expr, $PREFER_EXISTING_SORT: expr) => { + assert_optimized!($EXPECTED_LINES, $PLAN, $FIRST_ENFORCE_DIST, $PREFER_EXISTING_SORT, 10, false, 1024); }; - ($EXPECTED_LINES: expr, $PLAN: expr, $FIRST_ENFORCE_DIST: expr, $BOUNDED_ORDER_PRESERVING_VARIANTS: expr, $TARGET_PARTITIONS: expr, $REPARTITION_FILE_SCANS: expr, $REPARTITION_FILE_MIN_SIZE: expr) => { + ($EXPECTED_LINES: expr, $PLAN: expr, $FIRST_ENFORCE_DIST: expr, $PREFER_EXISTING_SORT: expr, $TARGET_PARTITIONS: expr, $REPARTITION_FILE_SCANS: expr, $REPARTITION_FILE_MIN_SIZE: expr) => { let expected_lines: Vec<&str> = $EXPECTED_LINES.iter().map(|s| *s).collect(); let mut config = ConfigOptions::new(); config.execution.target_partitions = $TARGET_PARTITIONS; config.optimizer.repartition_file_scans = $REPARTITION_FILE_SCANS; config.optimizer.repartition_file_min_size = $REPARTITION_FILE_MIN_SIZE; - config.optimizer.prefer_existing_sort = $BOUNDED_ORDER_PRESERVING_VARIANTS; + config.optimizer.prefer_existing_sort = $PREFER_EXISTING_SORT; // NOTE: These tests verify the joint `EnforceDistribution` + `EnforceSorting` cascade // because they were written prior to the separation of `BasicEnforcement` into @@ -3007,16 +2899,16 @@ mod tests { vec![ top_join_plan.as_str(), join_plan.as_str(), - "SortPreservingRepartitionExec: partitioning=Hash([a@0], 10), input_partitions=10, sort_exprs=a@0 ASC", + "RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=10, preserve_order=true, sort_exprs=a@0 ASC", "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", "SortExec: expr=[a@0 ASC]", "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", - "SortPreservingRepartitionExec: partitioning=Hash([b1@1], 10), input_partitions=10, sort_exprs=b1@1 ASC", + "RepartitionExec: partitioning=Hash([b1@1], 10), input_partitions=10, preserve_order=true, sort_exprs=b1@1 ASC", "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", "SortExec: expr=[b1@1 ASC]", "ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1, d@3 as d1, e@4 as e1]", "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", - "SortPreservingRepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10, sort_exprs=c@2 ASC", + "RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10, preserve_order=true, sort_exprs=c@2 ASC", "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", "SortExec: expr=[c@2 ASC]", "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", @@ -3033,21 +2925,21 @@ mod tests { _ => vec![ top_join_plan.as_str(), // Below 4 operators are differences introduced, when join mode is changed - "SortPreservingRepartitionExec: partitioning=Hash([a@0], 10), input_partitions=10, sort_exprs=a@0 ASC", + "RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=10, preserve_order=true, sort_exprs=a@0 ASC", "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", "SortExec: expr=[a@0 ASC]", "CoalescePartitionsExec", join_plan.as_str(), - "SortPreservingRepartitionExec: 
partitioning=Hash([a@0], 10), input_partitions=10, sort_exprs=a@0 ASC", + "RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=10, preserve_order=true, sort_exprs=a@0 ASC", "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", "SortExec: expr=[a@0 ASC]", "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", - "SortPreservingRepartitionExec: partitioning=Hash([b1@1], 10), input_partitions=10, sort_exprs=b1@1 ASC", + "RepartitionExec: partitioning=Hash([b1@1], 10), input_partitions=10, preserve_order=true, sort_exprs=b1@1 ASC", "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", "SortExec: expr=[b1@1 ASC]", "ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1, d@3 as d1, e@4 as e1]", "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", - "SortPreservingRepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10, sort_exprs=c@2 ASC", + "RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10, preserve_order=true, sort_exprs=c@2 ASC", "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", "SortExec: expr=[c@2 ASC]", "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", @@ -3121,16 +3013,16 @@ mod tests { JoinType::Inner | JoinType::Right => vec![ top_join_plan.as_str(), join_plan.as_str(), - "SortPreservingRepartitionExec: partitioning=Hash([a@0], 10), input_partitions=10, sort_exprs=a@0 ASC", + "RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=10, preserve_order=true, sort_exprs=a@0 ASC", "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", "SortExec: expr=[a@0 ASC]", "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", - "SortPreservingRepartitionExec: partitioning=Hash([b1@1], 10), input_partitions=10, sort_exprs=b1@1 ASC", + "RepartitionExec: partitioning=Hash([b1@1], 10), input_partitions=10, preserve_order=true, sort_exprs=b1@1 ASC", "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", "SortExec: expr=[b1@1 ASC]", "ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1, d@3 as d1, e@4 as e1]", "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", - "SortPreservingRepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10, sort_exprs=c@2 ASC", + "RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10, preserve_order=true, sort_exprs=c@2 ASC", "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", "SortExec: expr=[c@2 ASC]", "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", @@ -3138,21 +3030,21 @@ mod tests { // Should include 8 RepartitionExecs (4 of them preserves order) and 4 SortExecs JoinType::Left | JoinType::Full => vec![ top_join_plan.as_str(), - "SortPreservingRepartitionExec: partitioning=Hash([b1@6], 10), input_partitions=10, sort_exprs=b1@6 ASC", + "RepartitionExec: partitioning=Hash([b1@6], 10), input_partitions=10, preserve_order=true, sort_exprs=b1@6 ASC", "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", "SortExec: expr=[b1@6 ASC]", "CoalescePartitionsExec", join_plan.as_str(), - "SortPreservingRepartitionExec: partitioning=Hash([a@0], 10), input_partitions=10, sort_exprs=a@0 ASC", + "RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=10, preserve_order=true, sort_exprs=a@0 ASC", "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", "SortExec: expr=[a@0 ASC]", "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, 
d, e]", - "SortPreservingRepartitionExec: partitioning=Hash([b1@1], 10), input_partitions=10, sort_exprs=b1@1 ASC", + "RepartitionExec: partitioning=Hash([b1@1], 10), input_partitions=10, preserve_order=true, sort_exprs=b1@1 ASC", "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", "SortExec: expr=[b1@1 ASC]", "ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1, d@3 as d1, e@4 as e1]", "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", - "SortPreservingRepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10, sort_exprs=c@2 ASC", + "RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10, preserve_order=true, sort_exprs=c@2 ASC", "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", "SortExec: expr=[c@2 ASC]", "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", @@ -3243,7 +3135,7 @@ mod tests { let expected_first_sort_enforcement = &[ "SortMergeJoin: join_type=Inner, on=[(b3@1, b2@1), (a3@0, a2@0)]", - "SortPreservingRepartitionExec: partitioning=Hash([b3@1, a3@0], 10), input_partitions=10, sort_exprs=b3@1 ASC,a3@0 ASC", + "RepartitionExec: partitioning=Hash([b3@1, a3@0], 10), input_partitions=10, preserve_order=true, sort_exprs=b3@1 ASC,a3@0 ASC", "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", "SortExec: expr=[b3@1 ASC,a3@0 ASC]", "CoalescePartitionsExec", @@ -3254,7 +3146,7 @@ mod tests { "AggregateExec: mode=Partial, gby=[b@1 as b1, a@0 as a1], aggr=[]", "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", - "SortPreservingRepartitionExec: partitioning=Hash([b2@1, a2@0], 10), input_partitions=10, sort_exprs=b2@1 ASC,a2@0 ASC", + "RepartitionExec: partitioning=Hash([b2@1, a2@0], 10), input_partitions=10, preserve_order=true, sort_exprs=b2@1 ASC,a2@0 ASC", "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", "SortExec: expr=[b2@1 ASC,a2@0 ASC]", "CoalescePartitionsExec", @@ -3296,7 +3188,7 @@ mod tests { ]; assert_optimized!(expected, exec, true); // In this case preserving ordering through order preserving operators is not desirable - // (according to flag: bounded_order_preserving_variants) + // (according to flag: PREFER_EXISTING_SORT) // hence in this case ordering lost during CoalescePartitionsExec and re-introduced with // SortExec at the top. 
let expected = &[ @@ -3789,23 +3681,27 @@ mod tests { fn repartition_transitively_past_sort_with_projection_and_filter() -> Result<()> { let schema = schema(); let sort_key = vec![PhysicalSortExpr { - expr: col("c", &schema).unwrap(), + expr: col("a", &schema).unwrap(), options: SortOptions::default(), }]; let plan = sort_exec( sort_key, projection_exec_with_alias( filter_exec(parquet_exec()), - vec![("a".to_string(), "a".to_string())], + vec![ + ("a".to_string(), "a".to_string()), + ("b".to_string(), "b".to_string()), + ("c".to_string(), "c".to_string()), + ], ), false, ); let expected = &[ - "SortPreservingMergeExec: [c@2 ASC]", + "SortPreservingMergeExec: [a@0 ASC]", // Expect repartition on the input to the sort (as it can benefit from additional parallelism) - "SortExec: expr=[c@2 ASC]", - "ProjectionExec: expr=[a@0 as a]", + "SortExec: expr=[a@0 ASC]", + "ProjectionExec: expr=[a@0 as a, b@1 as b, c@2 as c]", "FilterExec: c@2 = 0", // repartition is lowest down "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", @@ -3815,9 +3711,9 @@ mod tests { assert_optimized!(expected, plan.clone(), true); let expected_first_sort_enforcement = &[ - "SortExec: expr=[c@2 ASC]", + "SortExec: expr=[a@0 ASC]", "CoalescePartitionsExec", - "ProjectionExec: expr=[a@0 as a]", + "ProjectionExec: expr=[a@0 as a, b@1 as b, c@2 as c]", "FilterExec: c@2 = 0", "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", @@ -3850,6 +3746,56 @@ mod tests { Ok(()) } + #[test] + fn parallelization_multiple_files() -> Result<()> { + let schema = schema(); + let sort_key = vec![PhysicalSortExpr { + expr: col("a", &schema).unwrap(), + options: SortOptions::default(), + }]; + + let plan = filter_exec(parquet_exec_multiple_sorted(vec![sort_key])); + let plan = sort_required_exec(plan); + + // The groups must have only contiguous ranges of rows from the same file + // if any group has rows from multiple files, the data is no longer sorted destroyed + // https://github.com/apache/arrow-datafusion/issues/8451 + let expected = [ + "SortRequiredExec: [a@0 ASC]", + "FilterExec: c@2 = 0", + "ParquetExec: file_groups={3 groups: [[x:0..50], [y:0..100], [x:50..100]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC]", ]; + let target_partitions = 3; + let repartition_size = 1; + assert_optimized!( + expected, + plan, + true, + true, + target_partitions, + true, + repartition_size + ); + + let expected = [ + "SortRequiredExec: [a@0 ASC]", + "FilterExec: c@2 = 0", + "ParquetExec: file_groups={8 groups: [[x:0..25], [y:0..25], [x:25..50], [y:25..50], [x:50..75], [y:50..75], [x:75..100], [y:75..100]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC]", + ]; + let target_partitions = 8; + let repartition_size = 1; + assert_optimized!( + expected, + plan, + true, + true, + target_partitions, + true, + repartition_size + ); + + Ok(()) + } + #[test] /// CsvExec on compressed csv file will not be partitioned /// (Not able to decompress chunked csv file) @@ -3898,7 +3844,6 @@ mod tests { limit: None, table_partition_cols: vec![], output_ordering: vec![], - infinite_source: false, }, false, b',', @@ -4335,11 +4280,11 @@ mod tests { let expected = &[ "SortPreservingMergeExec: [c@2 ASC]", "FilterExec: c@2 = 0", - "SortPreservingRepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2, sort_exprs=c@2 ASC", + "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2, preserve_order=true, sort_exprs=c@2 
ASC", "ParquetExec: file_groups={2 groups: [[x], [y]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC]", ]; - // last flag sets config.optimizer.bounded_order_preserving_variants + // last flag sets config.optimizer.PREFER_EXISTING_SORT assert_optimized!(expected, physical_plan.clone(), true, true); assert_optimized!(expected, physical_plan, false, true); @@ -4521,15 +4466,11 @@ mod tests { assert_plan_txt!(expected, physical_plan); let expected = &[ - "SortRequiredExec: [a@0 ASC]", // Since at the start of the rule ordering requirement is satisfied // EnforceDistribution rule satisfy this requirement also. - // ordering is re-satisfied by introduction of SortExec. - "SortExec: expr=[a@0 ASC]", + "SortRequiredExec: [a@0 ASC]", "FilterExec: c@2 = 0", - // ordering is lost here - "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2", - "ParquetExec: file_groups={2 groups: [[x], [y]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC]", + "ParquetExec: file_groups={10 groups: [[x:0..20], [y:0..20], [x:20..40], [y:20..40], [x:40..60], [y:40..60], [x:60..80], [y:60..80], [x:80..100], [y:80..100]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC]", ]; let mut config = ConfigOptions::new(); diff --git a/datafusion/core/src/physical_optimizer/enforce_sorting.rs b/datafusion/core/src/physical_optimizer/enforce_sorting.rs index 4779ced44f1ae..f609ddea66cff 100644 --- a/datafusion/core/src/physical_optimizer/enforce_sorting.rs +++ b/datafusion/core/src/physical_optimizer/enforce_sorting.rs @@ -34,6 +34,7 @@ //! in the physical plan. The first sort is unnecessary since its result is overwritten //! by another [`SortExec`]. Therefore, this rule removes it from the physical plan. +use std::borrow::Cow; use std::sync::Arc; use crate::config::ConfigOptions; @@ -44,7 +45,7 @@ use crate::physical_optimizer::replace_with_order_preserving_variants::{ use crate::physical_optimizer::sort_pushdown::{pushdown_sorts, SortPushDown}; use crate::physical_optimizer::utils::{ add_sort_above, is_coalesce_partitions, is_limit, is_repartition, is_sort, - is_sort_preserving_merge, is_union, is_window, ExecTree, + is_sort_preserving_merge, is_union, is_window, }; use crate::physical_optimizer::PhysicalOptimizerRule; use crate::physical_plan::coalesce_partitions::CoalescePartitionsExec; @@ -53,14 +54,15 @@ use crate::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; use crate::physical_plan::windows::{ get_best_fitting_window, BoundedWindowAggExec, WindowAggExec, }; -use crate::physical_plan::{with_new_children_if_necessary, Distribution, ExecutionPlan}; +use crate::physical_plan::{ + with_new_children_if_necessary, Distribution, ExecutionPlan, InputOrderMode, +}; -use datafusion_common::tree_node::{Transformed, TreeNode, VisitRecursion}; +use datafusion_common::tree_node::{Transformed, TreeNode}; use datafusion_common::{plan_err, DataFusionError}; use datafusion_physical_expr::{PhysicalSortExpr, PhysicalSortRequirement}; - use datafusion_physical_plan::repartition::RepartitionExec; -use datafusion_physical_plan::windows::PartitionSearchMode; + use itertools::izip; /// This rule inspects [`SortExec`]'s in the given physical plan and removes the @@ -80,231 +82,172 @@ impl EnforceSorting { #[derive(Debug, Clone)] struct PlanWithCorrespondingSort { plan: Arc, - // For every child, keep a subtree of `ExecutionPlan`s starting from the - // child until the `SortExec`(s) -- could be multiple for n-ary plans like - // Union -- that determine the output ordering of the child. 
If the child - // has no connection to any sort, simply store None (and not a subtree). - sort_onwards: Vec>, + // For every child, track `ExecutionPlan`s starting from the child until + // the `SortExec`(s). If the child has no connection to any sort, it simply + // stores false. + sort_connection: bool, + children_nodes: Vec, } impl PlanWithCorrespondingSort { fn new(plan: Arc) -> Self { - let length = plan.children().len(); - PlanWithCorrespondingSort { + let children = plan.children(); + Self { plan, - sort_onwards: vec![None; length], + sort_connection: false, + children_nodes: children.into_iter().map(Self::new).collect(), } } - fn new_from_children_nodes( - children_nodes: Vec, + fn update_children( parent_plan: Arc, + mut children_nodes: Vec, ) -> Result { - let children_plans = children_nodes - .iter() - .map(|item| item.plan.clone()) - .collect::>(); - let sort_onwards = children_nodes - .into_iter() - .enumerate() - .map(|(idx, item)| { - let plan = &item.plan; - // Leaves of `sort_onwards` are `SortExec` operators, which impose - // an ordering. This tree collects all the intermediate executors - // that maintain this ordering. If we just saw a order imposing - // operator, we reset the tree and start accumulating. - if is_sort(plan) { - return Some(ExecTree::new(item.plan, idx, vec![])); - } else if is_limit(plan) { - // There is no sort linkage for this path, it starts at a limit. - return None; - } + for node in children_nodes.iter_mut() { + let plan = &node.plan; + // Leaves of `sort_onwards` are `SortExec` operators, which impose + // an ordering. This tree collects all the intermediate executors + // that maintain this ordering. If we just saw a order imposing + // operator, we reset the tree and start accumulating. + node.sort_connection = if is_sort(plan) { + // Initiate connection + true + } else if is_limit(plan) { + // There is no sort linkage for this path, it starts at a limit. + false + } else { let is_spm = is_sort_preserving_merge(plan); let required_orderings = plan.required_input_ordering(); let flags = plan.maintains_input_order(); - let children = izip!(flags, item.sort_onwards, required_orderings) - .filter_map(|(maintains, element, required_ordering)| { - if (required_ordering.is_none() && maintains) || is_spm { - element - } else { - None - } - }) - .collect::>(); - if !children.is_empty() { - // Add parent node to the tree if there is at least one - // child with a subtree: - Some(ExecTree::new(item.plan, idx, children)) - } else { - // There is no sort linkage for this child, do nothing. 
- None - } - }) - .collect(); + // Add parent node to the tree if there is at least one + // child with a sort connection: + izip!(flags, required_orderings).any(|(maintains, required_ordering)| { + let propagates_ordering = + (maintains && required_ordering.is_none()) || is_spm; + let connected_to_sort = + node.children_nodes.iter().any(|item| item.sort_connection); + propagates_ordering && connected_to_sort + }) + } + } + let children_plans = children_nodes + .iter() + .map(|item| item.plan.clone()) + .collect::>(); let plan = with_new_children_if_necessary(parent_plan, children_plans)?.into(); - Ok(PlanWithCorrespondingSort { plan, sort_onwards }) - } - fn children(&self) -> Vec { - self.plan - .children() - .into_iter() - .map(PlanWithCorrespondingSort::new) - .collect() + Ok(Self { + plan, + sort_connection: false, + children_nodes, + }) } } impl TreeNode for PlanWithCorrespondingSort { - fn apply_children(&self, op: &mut F) -> Result - where - F: FnMut(&Self) -> Result, - { - let children = self.children(); - for child in children { - match op(&child)? { - VisitRecursion::Continue => {} - VisitRecursion::Skip => return Ok(VisitRecursion::Continue), - VisitRecursion::Stop => return Ok(VisitRecursion::Stop), - } - } - - Ok(VisitRecursion::Continue) + fn children_nodes(&self) -> Vec> { + self.children_nodes.iter().map(Cow::Borrowed).collect() } - fn map_children(self, transform: F) -> Result + fn map_children(mut self, transform: F) -> Result where F: FnMut(Self) -> Result, { - let children = self.children(); - if children.is_empty() { - Ok(self) - } else { - let children_nodes = children + if !self.children_nodes.is_empty() { + self.children_nodes = self + .children_nodes .into_iter() .map(transform) - .collect::>>()?; - PlanWithCorrespondingSort::new_from_children_nodes(children_nodes, self.plan) + .collect::>()?; + self.plan = with_new_children_if_necessary( + self.plan, + self.children_nodes.iter().map(|c| c.plan.clone()).collect(), + )? + .into(); } + Ok(self) } } -/// This object is used within the [EnforceSorting] rule to track the closest +/// This object is used within the [`EnforceSorting`] rule to track the closest /// [`CoalescePartitionsExec`] descendant(s) for every child of a plan. #[derive(Debug, Clone)] struct PlanWithCorrespondingCoalescePartitions { plan: Arc, - // For every child, keep a subtree of `ExecutionPlan`s starting from the - // child until the `CoalescePartitionsExec`(s) -- could be multiple for - // n-ary plans like Union -- that affect the output partitioning of the - // child. If the child has no connection to any `CoalescePartitionsExec`, - // simply store None (and not a subtree). - coalesce_onwards: Vec>, + // Stores whether the plan is a `CoalescePartitionsExec` or it is connected to + // a `CoalescePartitionsExec` via its children. + coalesce_connection: bool, + children_nodes: Vec, } impl PlanWithCorrespondingCoalescePartitions { + /// Creates an empty tree with empty connections. 
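Both context types above now carry a single boolean connection flag plus owned child nodes instead of an `ExecTree` subtree per child. A minimal sketch of that propagation under simplified, hypothetical types (the real `update_children` also consults `required_input_ordering` and treats `SortPreservingMergeExec` specially):

// Hypothetical, simplified model of the connection-flag propagation: a node is
// connected to a sort if it is a sort itself, or if it preserves its input
// order and at least one of its children is connected. Limits break the linkage.
struct Node {
    is_sort: bool,
    is_limit: bool,
    maintains_input_order: bool,
    sort_connection: bool,
    children: Vec<Node>,
}

fn update_connection(node: &mut Node) {
    // Bottom-up: settle the children first, then this node.
    node.children.iter_mut().for_each(update_connection);
    node.sort_connection = if node.is_sort {
        true
    } else if node.is_limit {
        false
    } else {
        node.maintains_input_order
            && node.children.iter().any(|child| child.sort_connection)
    };
}

fn main() {
    // A filter-like node (order preserving) on top of a sort-like leaf.
    let mut filter = Node {
        is_sort: false,
        is_limit: false,
        maintains_input_order: true,
        sort_connection: false,
        children: vec![Node {
            is_sort: true,
            is_limit: false,
            maintains_input_order: false,
            sort_connection: false,
            children: vec![],
        }],
    };
    update_connection(&mut filter);
    assert!(filter.sort_connection);
}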
fn new(plan: Arc) -> Self { - let length = plan.children().len(); - PlanWithCorrespondingCoalescePartitions { + let children = plan.children(); + Self { plan, - coalesce_onwards: vec![None; length], + coalesce_connection: false, + children_nodes: children.into_iter().map(Self::new).collect(), } } - fn new_from_children_nodes( - children_nodes: Vec, - parent_plan: Arc, - ) -> Result { - let children_plans = children_nodes + fn update_children(mut self) -> Result { + self.coalesce_connection = if self.plan.children().is_empty() { + // Plan has no children, it cannot be a `CoalescePartitionsExec`. + false + } else if is_coalesce_partitions(&self.plan) { + // Initiate a connection + true + } else { + self.children_nodes + .iter() + .enumerate() + .map(|(idx, node)| { + // Only consider operators that don't require a + // single partition, and connected to any coalesce + node.coalesce_connection + && !matches!( + self.plan.required_input_distribution()[idx], + Distribution::SinglePartition + ) + // If all children are None. There is nothing to track, set connection false. + }) + .any(|c| c) + }; + + let children_plans = self + .children_nodes .iter() .map(|item| item.plan.clone()) .collect(); - let coalesce_onwards = children_nodes - .into_iter() - .enumerate() - .map(|(idx, item)| { - // Leaves of the `coalesce_onwards` tree are `CoalescePartitionsExec` - // operators. This tree collects all the intermediate executors that - // maintain a single partition. If we just saw a `CoalescePartitionsExec` - // operator, we reset the tree and start accumulating. - let plan = item.plan; - if plan.children().is_empty() { - // Plan has no children, there is nothing to propagate. - None - } else if is_coalesce_partitions(&plan) { - Some(ExecTree::new(plan, idx, vec![])) - } else { - let children = item - .coalesce_onwards - .into_iter() - .flatten() - .filter(|item| { - // Only consider operators that don't require a - // single partition. - !matches!( - plan.required_input_distribution()[item.idx], - Distribution::SinglePartition - ) - }) - .collect::>(); - if children.is_empty() { - None - } else { - Some(ExecTree::new(plan, idx, children)) - } - } - }) - .collect(); - let plan = with_new_children_if_necessary(parent_plan, children_plans)?.into(); - Ok(PlanWithCorrespondingCoalescePartitions { - plan, - coalesce_onwards, - }) - } - - fn children(&self) -> Vec { - self.plan - .children() - .into_iter() - .map(PlanWithCorrespondingCoalescePartitions::new) - .collect() + self.plan = with_new_children_if_necessary(self.plan, children_plans)?.into(); + Ok(self) } } impl TreeNode for PlanWithCorrespondingCoalescePartitions { - fn apply_children(&self, op: &mut F) -> Result - where - F: FnMut(&Self) -> Result, - { - let children = self.children(); - for child in children { - match op(&child)? 
{ - VisitRecursion::Continue => {} - VisitRecursion::Skip => return Ok(VisitRecursion::Continue), - VisitRecursion::Stop => return Ok(VisitRecursion::Stop), - } - } - - Ok(VisitRecursion::Continue) + fn children_nodes(&self) -> Vec> { + self.children_nodes.iter().map(Cow::Borrowed).collect() } - fn map_children(self, transform: F) -> Result + fn map_children(mut self, transform: F) -> Result where F: FnMut(Self) -> Result, { - let children = self.children(); - if children.is_empty() { - Ok(self) - } else { - let children_nodes = children + if !self.children_nodes.is_empty() { + self.children_nodes = self + .children_nodes .into_iter() .map(transform) - .collect::>>()?; - PlanWithCorrespondingCoalescePartitions::new_from_children_nodes( - children_nodes, + .collect::>()?; + self.plan = with_new_children_if_necessary( self.plan, - ) + self.children_nodes.iter().map(|c| c.plan.clone()).collect(), + )? + .into(); } + Ok(self) } } @@ -331,6 +274,7 @@ impl PhysicalOptimizerRule for EnforceSorting { } else { adjusted.plan }; + let plan_with_pipeline_fixer = OrderPreservationContext::new(new_plan); let updated_plan = plan_with_pipeline_fixer.transform_up(&|plan_with_pipeline_fixer| { @@ -344,7 +288,8 @@ impl PhysicalOptimizerRule for EnforceSorting { // Execute a top-down traversal to exploit sort push-down opportunities // missed by the bottom-up traversal: - let sort_pushdown = SortPushDown::init(updated_plan.plan); + let mut sort_pushdown = SortPushDown::new(updated_plan.plan); + sort_pushdown.assign_initial_requirements(); let adjusted = sort_pushdown.transform_down(&pushdown_sorts)?; Ok(adjusted.plan) } @@ -375,16 +320,21 @@ impl PhysicalOptimizerRule for EnforceSorting { fn parallelize_sorts( requirements: PlanWithCorrespondingCoalescePartitions, ) -> Result> { - let plan = requirements.plan; - let mut coalesce_onwards = requirements.coalesce_onwards; - if plan.children().is_empty() || coalesce_onwards[0].is_none() { + let PlanWithCorrespondingCoalescePartitions { + mut plan, + coalesce_connection, + mut children_nodes, + } = requirements.update_children()?; + + if plan.children().is_empty() || !children_nodes[0].coalesce_connection { // We only take an action when the plan is either a SortExec, a // SortPreservingMergeExec or a CoalescePartitionsExec, and they // all have a single child. Therefore, if the first child is `None`, // we can return immediately. return Ok(Transformed::No(PlanWithCorrespondingCoalescePartitions { plan, - coalesce_onwards, + coalesce_connection, + children_nodes, })); } else if (is_sort(&plan) || is_sort_preserving_merge(&plan)) && plan.output_partitioning().partition_count() <= 1 @@ -394,34 +344,30 @@ fn parallelize_sorts( // executors don't require single partition), then we can replace // the CoalescePartitionsExec + Sort cascade with a SortExec + // SortPreservingMergeExec cascade to parallelize sorting. 
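For reference, the expected plans in the bounded-source test later in this patch show the shape of this rewrite (abridged here):

    before:
    SortExec: expr=[a@0 ASC]
      CoalescePartitionsExec
        RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10
          ...

    after:
    SortPreservingMergeExec: [a@0 ASC]
      SortExec: expr=[a@0 ASC]
        RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10
          ...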
- let mut prev_layer = plan.clone(); - update_child_to_remove_coalesce(&mut prev_layer, &mut coalesce_onwards[0])?; let (sort_exprs, fetch) = get_sort_exprs(&plan)?; - add_sort_above( - &mut prev_layer, - &PhysicalSortRequirement::from_sort_exprs(sort_exprs), - fetch, - ); - let spm = SortPreservingMergeExec::new(sort_exprs.to_vec(), prev_layer) - .with_fetch(fetch); - return Ok(Transformed::Yes(PlanWithCorrespondingCoalescePartitions { - plan: Arc::new(spm), - coalesce_onwards: vec![None], - })); + let sort_reqs = PhysicalSortRequirement::from_sort_exprs(sort_exprs); + let sort_exprs = sort_exprs.to_vec(); + update_child_to_remove_coalesce(&mut plan, &mut children_nodes[0])?; + add_sort_above(&mut plan, &sort_reqs, fetch); + let spm = SortPreservingMergeExec::new(sort_exprs, plan).with_fetch(fetch); + + return Ok(Transformed::Yes( + PlanWithCorrespondingCoalescePartitions::new(Arc::new(spm)), + )); } else if is_coalesce_partitions(&plan) { // There is an unnecessary `CoalescePartitionsExec` in the plan. - let mut prev_layer = plan.clone(); - update_child_to_remove_coalesce(&mut prev_layer, &mut coalesce_onwards[0])?; - let new_plan = plan.with_new_children(vec![prev_layer])?; - return Ok(Transformed::Yes(PlanWithCorrespondingCoalescePartitions { - plan: new_plan, - coalesce_onwards: vec![None], - })); + update_child_to_remove_coalesce(&mut plan, &mut children_nodes[0])?; + + let new_plan = Arc::new(CoalescePartitionsExec::new(plan)) as _; + return Ok(Transformed::Yes( + PlanWithCorrespondingCoalescePartitions::new(new_plan), + )); } Ok(Transformed::Yes(PlanWithCorrespondingCoalescePartitions { plan, - coalesce_onwards, + coalesce_connection, + children_nodes, })) } @@ -430,89 +376,102 @@ fn parallelize_sorts( fn ensure_sorting( requirements: PlanWithCorrespondingSort, ) -> Result> { + let requirements = PlanWithCorrespondingSort::update_children( + requirements.plan, + requirements.children_nodes, + )?; + // Perform naive analysis at the beginning -- remove already-satisfied sorts: if requirements.plan.children().is_empty() { return Ok(Transformed::No(requirements)); } - let plan = requirements.plan; - let mut children = plan.children(); - let mut sort_onwards = requirements.sort_onwards; - if let Some(result) = analyze_immediate_sort_removal(&plan, &sort_onwards) { + if let Some(result) = analyze_immediate_sort_removal(&requirements) { return Ok(Transformed::Yes(result)); } - for (idx, (child, sort_onwards, required_ordering)) in izip!( - children.iter_mut(), - sort_onwards.iter_mut(), - plan.required_input_ordering() - ) - .enumerate() + + let plan = requirements.plan; + let mut children_nodes = requirements.children_nodes; + + for (idx, (child_node, required_ordering)) in + izip!(children_nodes.iter_mut(), plan.required_input_ordering()).enumerate() { - let physical_ordering = child.output_ordering(); + let mut child_plan = child_node.plan.clone(); + let physical_ordering = child_plan.output_ordering(); match (required_ordering, physical_ordering) { (Some(required_ordering), Some(_)) => { - if !child + if !child_plan .equivalence_properties() .ordering_satisfy_requirement(&required_ordering) { // Make sure we preserve the ordering requirements: - update_child_to_remove_unnecessary_sort(child, sort_onwards, &plan)?; - add_sort_above(child, &required_ordering, None); - if is_sort(child) { - *sort_onwards = Some(ExecTree::new(child.clone(), idx, vec![])); - } else { - *sort_onwards = None; + update_child_to_remove_unnecessary_sort(idx, child_node, &plan)?; + add_sort_above(&mut 
child_plan, &required_ordering, None); + if is_sort(&child_plan) { + *child_node = PlanWithCorrespondingSort::update_children( + child_plan, + vec![child_node.clone()], + )?; + child_node.sort_connection = true; } } } (Some(required), None) => { // Ordering requirement is not met, we should add a `SortExec` to the plan. - add_sort_above(child, &required, None); - *sort_onwards = Some(ExecTree::new(child.clone(), idx, vec![])); + add_sort_above(&mut child_plan, &required, None); + *child_node = PlanWithCorrespondingSort::update_children( + child_plan, + vec![child_node.clone()], + )?; + child_node.sort_connection = true; } (None, Some(_)) => { // We have a `SortExec` whose effect may be neutralized by // another order-imposing operator. Remove this sort. if !plan.maintains_input_order()[idx] || is_union(&plan) { - update_child_to_remove_unnecessary_sort(child, sort_onwards, &plan)?; + update_child_to_remove_unnecessary_sort(idx, child_node, &plan)?; } } - (None, None) => {} + (None, None) => { + update_child_to_remove_unnecessary_sort(idx, child_node, &plan)?; + } } } // For window expressions, we can remove some sorts when we can // calculate the result in reverse: - if is_window(&plan) { - if let Some(tree) = &mut sort_onwards[0] { - if let Some(result) = analyze_window_sort_removal(tree, &plan)? { - return Ok(Transformed::Yes(result)); - } + if is_window(&plan) && children_nodes[0].sort_connection { + if let Some(result) = analyze_window_sort_removal(&mut children_nodes[0], &plan)? + { + return Ok(Transformed::Yes(result)); } } else if is_sort_preserving_merge(&plan) - && children[0].output_partitioning().partition_count() <= 1 + && children_nodes[0] + .plan + .output_partitioning() + .partition_count() + <= 1 { // This SortPreservingMergeExec is unnecessary, input already has a // single partition. - sort_onwards.truncate(1); - return Ok(Transformed::Yes(PlanWithCorrespondingSort { - plan: children.swap_remove(0), - sort_onwards, - })); + let child_node = children_nodes.swap_remove(0); + return Ok(Transformed::Yes(child_node)); } - Ok(Transformed::Yes(PlanWithCorrespondingSort { - plan: plan.with_new_children(children)?, - sort_onwards, - })) + Ok(Transformed::Yes( + PlanWithCorrespondingSort::update_children(plan, children_nodes)?, + )) } /// Analyzes a given [`SortExec`] (`plan`) to determine whether its input /// already has a finer ordering than it enforces. fn analyze_immediate_sort_removal( - plan: &Arc, - sort_onwards: &[Option], + node: &PlanWithCorrespondingSort, ) -> Option { + let PlanWithCorrespondingSort { + plan, + children_nodes, + .. + } = node; if let Some(sort_exec) = plan.as_any().downcast_ref::() { let sort_input = sort_exec.input().clone(); - // If this sort is unnecessary, we should remove it: if sort_input .equivalence_properties() @@ -530,20 +489,33 @@ fn analyze_immediate_sort_removal( sort_exec.expr().to_vec(), sort_input, )); - let new_tree = ExecTree::new( - new_plan.clone(), - 0, - sort_onwards.iter().flat_map(|e| e.clone()).collect(), - ); PlanWithCorrespondingSort { plan: new_plan, - sort_onwards: vec![Some(new_tree)], + // SortPreservingMergeExec has single child. 
+ sort_connection: false, + children_nodes: children_nodes + .iter() + .cloned() + .map(|mut node| { + node.sort_connection = false; + node + }) + .collect(), } } else { // Remove the sort: PlanWithCorrespondingSort { plan: sort_input, - sort_onwards: sort_onwards.to_vec(), + sort_connection: false, + children_nodes: children_nodes[0] + .children_nodes + .iter() + .cloned() + .map(|mut node| { + node.sort_connection = false; + node + }) + .collect(), } }, ); @@ -555,15 +527,15 @@ fn analyze_immediate_sort_removal( /// Analyzes a [`WindowAggExec`] or a [`BoundedWindowAggExec`] to determine /// whether it may allow removing a sort. fn analyze_window_sort_removal( - sort_tree: &mut ExecTree, + sort_tree: &mut PlanWithCorrespondingSort, window_exec: &Arc, ) -> Result> { let requires_single_partition = matches!( - window_exec.required_input_distribution()[sort_tree.idx], + window_exec.required_input_distribution()[0], Distribution::SinglePartition ); - let mut window_child = - remove_corresponding_sort_from_sub_plan(sort_tree, requires_single_partition)?; + remove_corresponding_sort_from_sub_plan(sort_tree, requires_single_partition)?; + let mut window_child = sort_tree.plan.clone(); let (window_expr, new_window) = if let Some(exec) = window_exec.as_any().downcast_ref::() { ( @@ -609,7 +581,7 @@ fn analyze_window_sort_removal( window_expr.to_vec(), window_child, partitionby_exprs.to_vec(), - PartitionSearchMode::Sorted, + InputOrderMode::Sorted, )?) as _ } else { Arc::new(WindowAggExec::try_new( @@ -625,9 +597,9 @@ fn analyze_window_sort_removal( /// Updates child to remove the unnecessary [`CoalescePartitionsExec`] below it. fn update_child_to_remove_coalesce( child: &mut Arc, - coalesce_onwards: &mut Option, + coalesce_onwards: &mut PlanWithCorrespondingCoalescePartitions, ) -> Result<()> { - if let Some(coalesce_onwards) = coalesce_onwards { + if coalesce_onwards.coalesce_connection { *child = remove_corresponding_coalesce_in_sub_plan(coalesce_onwards, child)?; } Ok(()) @@ -635,10 +607,10 @@ fn update_child_to_remove_coalesce( /// Removes the [`CoalescePartitionsExec`] from the plan in `coalesce_onwards`. fn remove_corresponding_coalesce_in_sub_plan( - coalesce_onwards: &mut ExecTree, + coalesce_onwards: &mut PlanWithCorrespondingCoalescePartitions, parent: &Arc, ) -> Result> { - Ok(if is_coalesce_partitions(&coalesce_onwards.plan) { + if is_coalesce_partitions(&coalesce_onwards.plan) { // We can safely use the 0th index since we have a `CoalescePartitionsExec`. let mut new_plan = coalesce_onwards.plan.children()[0].clone(); while new_plan.output_partitioning() == parent.output_partitioning() @@ -647,89 +619,113 @@ fn remove_corresponding_coalesce_in_sub_plan( { new_plan = new_plan.children().swap_remove(0) } - new_plan + Ok(new_plan) } else { let plan = coalesce_onwards.plan.clone(); let mut children = plan.children(); - for item in &mut coalesce_onwards.children { - children[item.idx] = remove_corresponding_coalesce_in_sub_plan(item, &plan)?; + for (idx, node) in coalesce_onwards.children_nodes.iter_mut().enumerate() { + if node.coalesce_connection { + children[idx] = remove_corresponding_coalesce_in_sub_plan(node, &plan)?; + } } - plan.with_new_children(children)? - }) + plan.with_new_children(children) + } } /// Updates child to remove the unnecessary sort below it. 
fn update_child_to_remove_unnecessary_sort( - child: &mut Arc, - sort_onwards: &mut Option, + child_idx: usize, + sort_onwards: &mut PlanWithCorrespondingSort, parent: &Arc, ) -> Result<()> { - if let Some(sort_onwards) = sort_onwards { + if sort_onwards.sort_connection { let requires_single_partition = matches!( - parent.required_input_distribution()[sort_onwards.idx], + parent.required_input_distribution()[child_idx], Distribution::SinglePartition ); - *child = remove_corresponding_sort_from_sub_plan( - sort_onwards, - requires_single_partition, - )?; + remove_corresponding_sort_from_sub_plan(sort_onwards, requires_single_partition)?; } - *sort_onwards = None; + sort_onwards.sort_connection = false; Ok(()) } /// Removes the sort from the plan in `sort_onwards`. fn remove_corresponding_sort_from_sub_plan( - sort_onwards: &mut ExecTree, + sort_onwards: &mut PlanWithCorrespondingSort, requires_single_partition: bool, -) -> Result> { +) -> Result<()> { // A `SortExec` is always at the bottom of the tree. - let mut updated_plan = if is_sort(&sort_onwards.plan) { - sort_onwards.plan.children().swap_remove(0) + if is_sort(&sort_onwards.plan) { + *sort_onwards = sort_onwards.children_nodes.swap_remove(0); } else { - let plan = &sort_onwards.plan; - let mut children = plan.children(); - for item in &mut sort_onwards.children { - let requires_single_partition = matches!( - plan.required_input_distribution()[item.idx], - Distribution::SinglePartition - ); - children[item.idx] = - remove_corresponding_sort_from_sub_plan(item, requires_single_partition)?; + let PlanWithCorrespondingSort { + plan, + sort_connection: _, + children_nodes, + } = sort_onwards; + let mut any_connection = false; + for (child_idx, child_node) in children_nodes.iter_mut().enumerate() { + if child_node.sort_connection { + any_connection = true; + let requires_single_partition = matches!( + plan.required_input_distribution()[child_idx], + Distribution::SinglePartition + ); + remove_corresponding_sort_from_sub_plan( + child_node, + requires_single_partition, + )?; + } } + if any_connection || children_nodes.is_empty() { + *sort_onwards = PlanWithCorrespondingSort::update_children( + plan.clone(), + children_nodes.clone(), + )?; + } + let PlanWithCorrespondingSort { + plan, + children_nodes, + .. + } = sort_onwards; // Replace with variants that do not preserve order. if is_sort_preserving_merge(plan) { - children.swap_remove(0) + children_nodes.swap_remove(0); + *plan = plan.children().swap_remove(0); } else if let Some(repartition) = plan.as_any().downcast_ref::() { - Arc::new( - RepartitionExec::try_new( - children.swap_remove(0), - repartition.partitioning().clone(), - )? - .with_preserve_order(false), - ) - } else { - plan.clone().with_new_children(children)? + *plan = Arc::new(RepartitionExec::try_new( + children_nodes[0].plan.clone(), + repartition.output_partitioning(), + )?) as _; } }; // Deleting a merging sort may invalidate distribution requirements. // Ensure that we stay compliant with such requirements: if requires_single_partition - && updated_plan.output_partitioning().partition_count() > 1 + && sort_onwards.plan.output_partitioning().partition_count() > 1 { // If there is existing ordering, to preserve ordering use SortPreservingMergeExec // instead of CoalescePartitionsExec. 
- if let Some(ordering) = updated_plan.output_ordering() { - updated_plan = Arc::new(SortPreservingMergeExec::new( + if let Some(ordering) = sort_onwards.plan.output_ordering() { + let plan = Arc::new(SortPreservingMergeExec::new( ordering.to_vec(), - updated_plan, - )); + sort_onwards.plan.clone(), + )) as _; + *sort_onwards = PlanWithCorrespondingSort::update_children( + plan, + vec![sort_onwards.clone()], + )?; } else { - updated_plan = Arc::new(CoalescePartitionsExec::new(updated_plan)); + let plan = + Arc::new(CoalescePartitionsExec::new(sort_onwards.plan.clone())) as _; + *sort_onwards = PlanWithCorrespondingSort::update_children( + plan, + vec![sort_onwards.clone()], + )?; } } - Ok(updated_plan) + Ok(()) } /// Converts an [ExecutionPlan] trait object to a [PhysicalSortExpr] slice when possible. @@ -763,11 +759,10 @@ mod tests { repartition_exec, sort_exec, sort_expr, sort_expr_options, sort_merge_join_exec, sort_preserving_merge_exec, spr_repartition_exec, union_exec, }; - use crate::physical_optimizer::utils::get_plan_string; use crate::physical_plan::repartition::RepartitionExec; - use crate::physical_plan::{displayable, Partitioning}; + use crate::physical_plan::{displayable, get_plan_string, Partitioning}; use crate::prelude::{SessionConfig, SessionContext}; - use crate::test::csv_exec_sorted; + use crate::test::{csv_exec_ordered, csv_exec_sorted, stream_exec_ordered}; use arrow::compute::SortOptions; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; @@ -775,6 +770,8 @@ mod tests { use datafusion_expr::JoinType; use datafusion_physical_expr::expressions::{col, Column, NotExpr}; + use rstest::rstest; + fn create_test_schema() -> Result { let nullable_column = Field::new("nullable_col", DataType::Int32, true); let non_nullable_column = Field::new("non_nullable_col", DataType::Int32, false); @@ -2115,7 +2112,7 @@ mod tests { async fn test_with_lost_ordering_bounded() -> Result<()> { let schema = create_test_schema3()?; let sort_exprs = vec![sort_expr("a", &schema)]; - let source = csv_exec_sorted(&schema, sort_exprs, false); + let source = csv_exec_sorted(&schema, sort_exprs); let repartition_rr = repartition_exec(source); let repartition_hash = Arc::new(RepartitionExec::try_new( repartition_rr, @@ -2138,11 +2135,19 @@ mod tests { Ok(()) } + #[rstest] #[tokio::test] - async fn test_with_lost_ordering_unbounded() -> Result<()> { + async fn test_with_lost_ordering_unbounded_bounded( + #[values(false, true)] source_unbounded: bool, + ) -> Result<()> { let schema = create_test_schema3()?; let sort_exprs = vec![sort_expr("a", &schema)]; - let source = csv_exec_sorted(&schema, sort_exprs, true); + // create either bounded or unbounded source + let source = if source_unbounded { + stream_exec_ordered(&schema, sort_exprs) + } else { + csv_exec_ordered(&schema, sort_exprs) + }; let repartition_rr = repartition_exec(source); let repartition_hash = Arc::new(RepartitionExec::try_new( repartition_rr, @@ -2151,49 +2156,71 @@ mod tests { let coalesce_partitions = coalesce_partitions_exec(repartition_hash); let physical_plan = sort_exec(vec![sort_expr("a", &schema)], coalesce_partitions); - let expected_input = [ + // Expected inputs unbounded and bounded + let expected_input_unbounded = vec![ "SortExec: expr=[a@0 ASC]", " CoalescePartitionsExec", " RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10", " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], infinite_source=true, 
output_ordering=[a@0 ASC], has_header=false" + " StreamingTableExec: partition_sizes=1, projection=[a, b, c, d, e], infinite_source=true, output_ordering=[a@0 ASC]", ]; - let expected_optimized = [ + let expected_input_bounded = vec![ + "SortExec: expr=[a@0 ASC]", + " CoalescePartitionsExec", + " RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10", + " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], has_header=true", + ]; + + // Expected unbounded result (same for with and without flag) + let expected_optimized_unbounded = vec![ "SortPreservingMergeExec: [a@0 ASC]", - " SortPreservingRepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10, sort_exprs=a@0 ASC", + " RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10, preserve_order=true, sort_exprs=a@0 ASC", " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], infinite_source=true, output_ordering=[a@0 ASC], has_header=false", + " StreamingTableExec: partition_sizes=1, projection=[a, b, c, d, e], infinite_source=true, output_ordering=[a@0 ASC]", ]; - assert_optimized!(expected_input, expected_optimized, physical_plan, true); - Ok(()) - } - - #[tokio::test] - async fn test_with_lost_ordering_unbounded_parallelize_off() -> Result<()> { - let schema = create_test_schema3()?; - let sort_exprs = vec![sort_expr("a", &schema)]; - let source = csv_exec_sorted(&schema, sort_exprs, true); - let repartition_rr = repartition_exec(source); - let repartition_hash = Arc::new(RepartitionExec::try_new( - repartition_rr, - Partitioning::Hash(vec![col("c", &schema).unwrap()], 10), - )?) 
as _; - let coalesce_partitions = coalesce_partitions_exec(repartition_hash); - let physical_plan = sort_exec(vec![sort_expr("a", &schema)], coalesce_partitions); - let expected_input = ["SortExec: expr=[a@0 ASC]", + // Expected bounded results with and without flag + let expected_optimized_bounded = vec![ + "SortExec: expr=[a@0 ASC]", " CoalescePartitionsExec", " RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10", " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], infinite_source=true, output_ordering=[a@0 ASC], has_header=false" + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], has_header=true", ]; - let expected_optimized = [ + let expected_optimized_bounded_parallelize_sort = vec![ "SortPreservingMergeExec: [a@0 ASC]", - " SortPreservingRepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10, sort_exprs=a@0 ASC", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], infinite_source=true, output_ordering=[a@0 ASC], has_header=false", + " SortExec: expr=[a@0 ASC]", + " RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10", + " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], has_header=true", ]; - assert_optimized!(expected_input, expected_optimized, physical_plan, false); + let (expected_input, expected_optimized, expected_optimized_sort_parallelize) = + if source_unbounded { + ( + expected_input_unbounded, + expected_optimized_unbounded.clone(), + expected_optimized_unbounded, + ) + } else { + ( + expected_input_bounded, + expected_optimized_bounded, + expected_optimized_bounded_parallelize_sort, + ) + }; + assert_optimized!( + expected_input, + expected_optimized, + physical_plan.clone(), + false + ); + assert_optimized!( + expected_input, + expected_optimized_sort_parallelize, + physical_plan, + true + ); Ok(()) } @@ -2201,7 +2228,7 @@ mod tests { async fn test_do_not_pushdown_through_spm() -> Result<()> { let schema = create_test_schema3()?; let sort_exprs = vec![sort_expr("a", &schema), sort_expr("b", &schema)]; - let source = csv_exec_sorted(&schema, sort_exprs.clone(), false); + let source = csv_exec_sorted(&schema, sort_exprs.clone()); let repartition_rr = repartition_exec(source); let spm = sort_preserving_merge_exec(sort_exprs, repartition_rr); let physical_plan = sort_exec(vec![sort_expr("b", &schema)], spm); @@ -2222,7 +2249,7 @@ mod tests { async fn test_pushdown_through_spm() -> Result<()> { let schema = create_test_schema3()?; let sort_exprs = vec![sort_expr("a", &schema), sort_expr("b", &schema)]; - let source = csv_exec_sorted(&schema, sort_exprs.clone(), false); + let source = csv_exec_sorted(&schema, sort_exprs.clone()); let repartition_rr = repartition_exec(source); let spm = sort_preserving_merge_exec(sort_exprs, repartition_rr); let physical_plan = sort_exec( @@ -2250,7 +2277,7 @@ mod tests { async fn test_window_multi_layer_requirement() -> Result<()> { let schema = create_test_schema3()?; let sort_exprs = vec![sort_expr("a", &schema), sort_expr("b", &schema)]; - let source = csv_exec_sorted(&schema, vec![], false); + let source = csv_exec_sorted(&schema, vec![]); let sort = sort_exec(sort_exprs.clone(), source); let repartition = repartition_exec(sort); let 
repartition = spr_repartition_exec(repartition); @@ -2261,7 +2288,7 @@ mod tests { let expected_input = [ "BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow }], mode=[Sorted]", " SortPreservingMergeExec: [a@0 ASC,b@1 ASC]", - " SortPreservingRepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=10, sort_exprs=a@0 ASC,b@1 ASC", + " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=10, preserve_order=true, sort_exprs=a@0 ASC,b@1 ASC", " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", " SortExec: expr=[a@0 ASC,b@1 ASC]", " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false", diff --git a/datafusion/core/src/physical_optimizer/join_selection.rs b/datafusion/core/src/physical_optimizer/join_selection.rs index 876a464257cce..6b2fe24acf005 100644 --- a/datafusion/core/src/physical_optimizer/join_selection.rs +++ b/datafusion/core/src/physical_optimizer/join_selection.rs @@ -95,6 +95,7 @@ fn supports_collect_by_size( let Ok(stats) = plan.statistics() else { return false; }; + if let Some(size) = stats.total_byte_size.get_value() { *size != 0 && *size < collection_size_threshold } else if let Some(row_count) = stats.num_rows.get_value() { @@ -433,7 +434,7 @@ fn hash_join_convert_symmetric_subrule( config_options: &ConfigOptions, ) -> Option> { if let Some(hash_join) = input.plan.as_any().downcast_ref::() { - let ub_flags = &input.children_unbounded; + let ub_flags = input.children_unbounded(); let (left_unbounded, right_unbounded) = (ub_flags[0], ub_flags[1]); input.unbounded = left_unbounded || right_unbounded; let result = if left_unbounded && right_unbounded { @@ -510,7 +511,7 @@ fn hash_join_swap_subrule( _config_options: &ConfigOptions, ) -> Option> { if let Some(hash_join) = input.plan.as_any().downcast_ref::() { - let ub_flags = &input.children_unbounded; + let ub_flags = input.children_unbounded(); let (left_unbounded, right_unbounded) = (ub_flags[0], ub_flags[1]); input.unbounded = left_unbounded || right_unbounded; let result = if left_unbounded @@ -576,7 +577,7 @@ fn apply_subrules( } let is_unbounded = input .plan - .unbounded_output(&input.children_unbounded) + .unbounded_output(&input.children_unbounded()) // Treat the case where an operator can not run on unbounded data as // if it can and it outputs unbounded data. Do not raise an error yet. 
// Such operators may be fixed, adjusted or replaced later on during @@ -1252,6 +1253,7 @@ mod hash_join_tests { use arrow::record_batch::RecordBatch; use datafusion_common::utils::DataPtr; use datafusion_common::JoinType; + use datafusion_physical_plan::empty::EmptyExec; use std::sync::Arc; struct TestCase { @@ -1619,10 +1621,22 @@ mod hash_join_tests { false, )?; + let children = vec![ + PipelineStatePropagator { + plan: Arc::new(EmptyExec::new(Arc::new(Schema::empty()))), + unbounded: left_unbounded, + children: vec![], + }, + PipelineStatePropagator { + plan: Arc::new(EmptyExec::new(Arc::new(Schema::empty()))), + unbounded: right_unbounded, + children: vec![], + }, + ]; let initial_hash_join_state = PipelineStatePropagator { plan: Arc::new(join), unbounded: false, - children_unbounded: vec![left_unbounded, right_unbounded], + children, }; let optimized_hash_join = diff --git a/datafusion/core/src/physical_optimizer/limited_distinct_aggregation.rs b/datafusion/core/src/physical_optimizer/limited_distinct_aggregation.rs new file mode 100644 index 0000000000000..540f9a6a132bf --- /dev/null +++ b/datafusion/core/src/physical_optimizer/limited_distinct_aggregation.rs @@ -0,0 +1,609 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! A special-case optimizer rule that pushes limit into a grouped aggregation +//! which has no aggregate expressions or sorting requirements + +use crate::physical_optimizer::PhysicalOptimizerRule; +use crate::physical_plan::aggregates::AggregateExec; +use crate::physical_plan::limit::{GlobalLimitExec, LocalLimitExec}; +use crate::physical_plan::ExecutionPlan; +use datafusion_common::config::ConfigOptions; +use datafusion_common::tree_node::{Transformed, TreeNode}; +use datafusion_common::Result; +use itertools::Itertools; +use std::sync::Arc; + +/// An optimizer rule that passes a `limit` hint into grouped aggregations which don't require all +/// rows in the group to be processed for correctness. 
Example queries fitting this description are: +/// `SELECT distinct l_orderkey FROM lineitem LIMIT 10;` +/// `SELECT l_orderkey FROM lineitem GROUP BY l_orderkey LIMIT 10;` +pub struct LimitedDistinctAggregation {} + +impl LimitedDistinctAggregation { + /// Create a new `LimitedDistinctAggregation` + pub fn new() -> Self { + Self {} + } + + fn transform_agg( + aggr: &AggregateExec, + limit: usize, + ) -> Option> { + // rules for transforming this Aggregate are held in this method + if !aggr.is_unordered_unfiltered_group_by_distinct() { + return None; + } + + // We found what we want: clone, copy the limit down, and return modified node + let new_aggr = AggregateExec::try_new( + *aggr.mode(), + aggr.group_by().clone(), + aggr.aggr_expr().to_vec(), + aggr.filter_expr().to_vec(), + aggr.input().clone(), + aggr.input_schema(), + ) + .expect("Unable to copy Aggregate!") + .with_limit(Some(limit)); + Some(Arc::new(new_aggr)) + } + + /// transform_limit matches an `AggregateExec` as the child of a `LocalLimitExec` + /// or `GlobalLimitExec` and pushes the limit into the aggregation as a soft limit when + /// there is a group by, but no sorting, no aggregate expressions, and no filters in the + /// aggregation + fn transform_limit(plan: Arc) -> Option> { + let limit: usize; + let mut global_fetch: Option = None; + let mut global_skip: usize = 0; + let children: Vec>; + let mut is_global_limit = false; + if let Some(local_limit) = plan.as_any().downcast_ref::() { + limit = local_limit.fetch(); + children = local_limit.children(); + } else if let Some(global_limit) = plan.as_any().downcast_ref::() + { + global_fetch = global_limit.fetch(); + global_fetch?; + global_skip = global_limit.skip(); + // the aggregate must read at least fetch+skip number of rows + limit = global_fetch.unwrap() + global_skip; + children = global_limit.children(); + is_global_limit = true + } else { + return None; + } + let child = children.iter().exactly_one().ok()?; + // ensure there is no output ordering; can this rule be relaxed? 
+ if plan.output_ordering().is_some() { + return None; + } + // ensure no ordering is required on the input + if plan.required_input_ordering()[0].is_some() { + return None; + } + + // if found_match_aggr is true, match_aggr holds a parent aggregation whose group_by + // must match that of a child aggregation in order to rewrite the child aggregation + let mut match_aggr: Arc = plan; + let mut found_match_aggr = false; + + let mut rewrite_applicable = true; + let mut closure = |plan: Arc| { + if !rewrite_applicable { + return Ok(Transformed::No(plan)); + } + if let Some(aggr) = plan.as_any().downcast_ref::() { + if found_match_aggr { + if let Some(parent_aggr) = + match_aggr.as_any().downcast_ref::() + { + if !parent_aggr.group_by().eq(aggr.group_by()) { + // a partial and final aggregation with different groupings disqualifies + // rewriting the child aggregation + rewrite_applicable = false; + return Ok(Transformed::No(plan)); + } + } + } + // either we run into an Aggregate and transform it, or disable the rewrite + // for subsequent children + match Self::transform_agg(aggr, limit) { + None => {} + Some(new_aggr) => { + match_aggr = plan; + found_match_aggr = true; + return Ok(Transformed::Yes(new_aggr)); + } + } + } + rewrite_applicable = false; + Ok(Transformed::No(plan)) + }; + let child = child.clone().transform_down_mut(&mut closure).ok()?; + if is_global_limit { + return Some(Arc::new(GlobalLimitExec::new( + child, + global_skip, + global_fetch, + ))); + } + Some(Arc::new(LocalLimitExec::new(child, limit))) + } +} + +impl Default for LimitedDistinctAggregation { + fn default() -> Self { + Self::new() + } +} + +impl PhysicalOptimizerRule for LimitedDistinctAggregation { + fn optimize( + &self, + plan: Arc, + config: &ConfigOptions, + ) -> Result> { + let plan = if config.optimizer.enable_distinct_aggregation_soft_limit { + plan.transform_down(&|plan| { + Ok( + if let Some(plan) = + LimitedDistinctAggregation::transform_limit(plan.clone()) + { + Transformed::Yes(plan) + } else { + Transformed::No(plan) + }, + ) + })? 
+ } else { + plan + }; + Ok(plan) + } + + fn name(&self) -> &str { + "LimitedDistinctAggregation" + } + + fn schema_check(&self) -> bool { + true + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::error::Result; + use crate::physical_optimizer::aggregate_statistics::tests::TestAggregate; + use crate::physical_optimizer::enforce_distribution::tests::{ + parquet_exec_with_sort, schema, trim_plan_display, + }; + use crate::physical_plan::aggregates::{AggregateExec, PhysicalGroupBy}; + use crate::physical_plan::collect; + use crate::physical_plan::memory::MemoryExec; + use crate::prelude::SessionContext; + use arrow::array::Int32Array; + use arrow::compute::SortOptions; + use arrow::datatypes::{DataType, Field, Schema}; + use arrow::record_batch::RecordBatch; + use arrow::util::pretty::pretty_format_batches; + use arrow_schema::SchemaRef; + use datafusion_execution::config::SessionConfig; + use datafusion_expr::Operator; + use datafusion_physical_expr::expressions::cast; + use datafusion_physical_expr::expressions::col; + use datafusion_physical_expr::PhysicalSortExpr; + use datafusion_physical_expr::{expressions, PhysicalExpr}; + use datafusion_physical_plan::aggregates::AggregateMode; + use datafusion_physical_plan::displayable; + use std::sync::Arc; + + fn mock_data() -> Result> { + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + ])); + + let batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(Int32Array::from(vec![ + Some(1), + Some(2), + None, + Some(1), + Some(4), + Some(5), + ])), + Arc::new(Int32Array::from(vec![ + Some(1), + None, + Some(6), + Some(2), + Some(8), + Some(9), + ])), + ], + )?; + + Ok(Arc::new(MemoryExec::try_new( + &[vec![batch]], + Arc::clone(&schema), + None, + )?)) + } + + fn assert_plan_matches_expected( + plan: &Arc, + expected: &[&str], + ) -> Result<()> { + let expected_lines: Vec<&str> = expected.to_vec(); + let session_ctx = SessionContext::new(); + let state = session_ctx.state(); + + let optimized = LimitedDistinctAggregation::new() + .optimize(Arc::clone(plan), state.config_options())?; + + let optimized_result = displayable(optimized.as_ref()).indent(true).to_string(); + let actual_lines = trim_plan_display(&optimized_result); + + assert_eq!( + &expected_lines, &actual_lines, + "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", + expected_lines, actual_lines + ); + + Ok(()) + } + + async fn assert_results_match_expected( + plan: Arc, + expected: &str, + ) -> Result<()> { + let cfg = SessionConfig::new().with_target_partitions(1); + let ctx = SessionContext::new_with_config(cfg); + let batches = collect(plan, ctx.task_ctx()).await?; + let actual = format!("{}", pretty_format_batches(&batches)?); + assert_eq!(actual, expected); + Ok(()) + } + + pub fn build_group_by( + input_schema: &SchemaRef, + columns: Vec, + ) -> PhysicalGroupBy { + let mut group_by_expr: Vec<(Arc, String)> = vec![]; + for column in columns.iter() { + group_by_expr.push((col(column, input_schema).unwrap(), column.to_string())); + } + PhysicalGroupBy::new_single(group_by_expr.clone()) + } + + #[tokio::test] + async fn test_partial_final() -> Result<()> { + let source = mock_data()?; + let schema = source.schema(); + + // `SELECT a FROM MemoryExec GROUP BY a LIMIT 4;`, Partial/Final AggregateExec + let partial_agg = AggregateExec::try_new( + AggregateMode::Partial, + build_group_by(&schema.clone(), vec!["a".to_string()]), + vec![], /* aggr_expr */ + vec![None], /* filter_expr */ + 
source, /* input */ + schema.clone(), /* input_schema */ + )?; + let final_agg = AggregateExec::try_new( + AggregateMode::Final, + build_group_by(&schema.clone(), vec!["a".to_string()]), + vec![], /* aggr_expr */ + vec![None], /* filter_expr */ + Arc::new(partial_agg), /* input */ + schema.clone(), /* input_schema */ + )?; + let limit_exec = LocalLimitExec::new( + Arc::new(final_agg), + 4, // fetch + ); + // expected to push the limit to the Partial and Final AggregateExecs + let expected = [ + "LocalLimitExec: fetch=4", + "AggregateExec: mode=Final, gby=[a@0 as a], aggr=[], lim=[4]", + "AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[], lim=[4]", + "MemoryExec: partitions=1, partition_sizes=[1]", + ]; + let plan: Arc = Arc::new(limit_exec); + assert_plan_matches_expected(&plan, &expected)?; + let expected = r#" ++---+ +| a | ++---+ +| 1 | +| 2 | +| | +| 4 | ++---+ +"# + .trim(); + assert_results_match_expected(plan, expected).await?; + Ok(()) + } + + #[tokio::test] + async fn test_single_local() -> Result<()> { + let source = mock_data()?; + let schema = source.schema(); + + // `SELECT a FROM MemoryExec GROUP BY a LIMIT 4;`, Single AggregateExec + let single_agg = AggregateExec::try_new( + AggregateMode::Single, + build_group_by(&schema.clone(), vec!["a".to_string()]), + vec![], /* aggr_expr */ + vec![None], /* filter_expr */ + source, /* input */ + schema.clone(), /* input_schema */ + )?; + let limit_exec = LocalLimitExec::new( + Arc::new(single_agg), + 4, // fetch + ); + // expected to push the limit to the AggregateExec + let expected = [ + "LocalLimitExec: fetch=4", + "AggregateExec: mode=Single, gby=[a@0 as a], aggr=[], lim=[4]", + "MemoryExec: partitions=1, partition_sizes=[1]", + ]; + let plan: Arc = Arc::new(limit_exec); + assert_plan_matches_expected(&plan, &expected)?; + let expected = r#" ++---+ +| a | ++---+ +| 1 | +| 2 | +| | +| 4 | ++---+ +"# + .trim(); + assert_results_match_expected(plan, expected).await?; + Ok(()) + } + + #[tokio::test] + async fn test_single_global() -> Result<()> { + let source = mock_data()?; + let schema = source.schema(); + + // `SELECT a FROM MemoryExec GROUP BY a LIMIT 4;`, Single AggregateExec + let single_agg = AggregateExec::try_new( + AggregateMode::Single, + build_group_by(&schema.clone(), vec!["a".to_string()]), + vec![], /* aggr_expr */ + vec![None], /* filter_expr */ + source, /* input */ + schema.clone(), /* input_schema */ + )?; + let limit_exec = GlobalLimitExec::new( + Arc::new(single_agg), + 1, // skip + Some(3), // fetch + ); + // expected to push the skip+fetch limit to the AggregateExec + let expected = [ + "GlobalLimitExec: skip=1, fetch=3", + "AggregateExec: mode=Single, gby=[a@0 as a], aggr=[], lim=[4]", + "MemoryExec: partitions=1, partition_sizes=[1]", + ]; + let plan: Arc = Arc::new(limit_exec); + assert_plan_matches_expected(&plan, &expected)?; + let expected = r#" ++---+ +| a | ++---+ +| 2 | +| | +| 4 | ++---+ +"# + .trim(); + assert_results_match_expected(plan, expected).await?; + Ok(()) + } + + #[tokio::test] + async fn test_distinct_cols_different_than_group_by_cols() -> Result<()> { + let source = mock_data()?; + let schema = source.schema(); + + // `SELECT distinct a FROM MemoryExec GROUP BY a, b LIMIT 4;`, Single/Single AggregateExec + let group_by_agg = AggregateExec::try_new( + AggregateMode::Single, + build_group_by(&schema.clone(), vec!["a".to_string(), "b".to_string()]), + vec![], /* aggr_expr */ + vec![None], /* filter_expr */ + source, /* input */ + schema.clone(), /* input_schema */ + )?; + let distinct_agg 
= AggregateExec::try_new( + AggregateMode::Single, + build_group_by(&schema.clone(), vec!["a".to_string()]), + vec![], /* aggr_expr */ + vec![None], /* filter_expr */ + Arc::new(group_by_agg), /* input */ + schema.clone(), /* input_schema */ + )?; + let limit_exec = LocalLimitExec::new( + Arc::new(distinct_agg), + 4, // fetch + ); + // expected to push the limit to the outer AggregateExec only + let expected = [ + "LocalLimitExec: fetch=4", + "AggregateExec: mode=Single, gby=[a@0 as a], aggr=[], lim=[4]", + "AggregateExec: mode=Single, gby=[a@0 as a, b@1 as b], aggr=[]", + "MemoryExec: partitions=1, partition_sizes=[1]", + ]; + let plan: Arc = Arc::new(limit_exec); + assert_plan_matches_expected(&plan, &expected)?; + let expected = r#" ++---+ +| a | ++---+ +| 1 | +| 2 | +| | +| 4 | ++---+ +"# + .trim(); + assert_results_match_expected(plan, expected).await?; + Ok(()) + } + + #[test] + fn test_no_group_by() -> Result<()> { + let source = mock_data()?; + let schema = source.schema(); + + // `SELECT FROM MemoryExec LIMIT 10;`, Single AggregateExec + let single_agg = AggregateExec::try_new( + AggregateMode::Single, + build_group_by(&schema.clone(), vec![]), + vec![], /* aggr_expr */ + vec![None], /* filter_expr */ + source, /* input */ + schema.clone(), /* input_schema */ + )?; + let limit_exec = LocalLimitExec::new( + Arc::new(single_agg), + 10, // fetch + ); + // expected not to push the limit to the AggregateExec + let expected = [ + "LocalLimitExec: fetch=10", + "AggregateExec: mode=Single, gby=[], aggr=[]", + "MemoryExec: partitions=1, partition_sizes=[1]", + ]; + let plan: Arc = Arc::new(limit_exec); + assert_plan_matches_expected(&plan, &expected)?; + Ok(()) + } + + #[test] + fn test_has_aggregate_expression() -> Result<()> { + let source = mock_data()?; + let schema = source.schema(); + let agg = TestAggregate::new_count_star(); + + // `SELECT FROM MemoryExec LIMIT 10;`, Single AggregateExec + let single_agg = AggregateExec::try_new( + AggregateMode::Single, + build_group_by(&schema.clone(), vec!["a".to_string()]), + vec![agg.count_expr()], /* aggr_expr */ + vec![None], /* filter_expr */ + source, /* input */ + schema.clone(), /* input_schema */ + )?; + let limit_exec = LocalLimitExec::new( + Arc::new(single_agg), + 10, // fetch + ); + // expected not to push the limit to the AggregateExec + let expected = [ + "LocalLimitExec: fetch=10", + "AggregateExec: mode=Single, gby=[a@0 as a], aggr=[COUNT(*)]", + "MemoryExec: partitions=1, partition_sizes=[1]", + ]; + let plan: Arc = Arc::new(limit_exec); + assert_plan_matches_expected(&plan, &expected)?; + Ok(()) + } + + #[test] + fn test_has_filter() -> Result<()> { + let source = mock_data()?; + let schema = source.schema(); + + // `SELECT a FROM MemoryExec WHERE a > 1 GROUP BY a LIMIT 10;`, Single AggregateExec + // the `a > 1` filter is applied in the AggregateExec + let filter_expr = Some(expressions::binary( + expressions::col("a", &schema)?, + Operator::Gt, + cast(expressions::lit(1u32), &schema, DataType::Int32)?, + &schema, + )?); + let single_agg = AggregateExec::try_new( + AggregateMode::Single, + build_group_by(&schema.clone(), vec!["a".to_string()]), + vec![], /* aggr_expr */ + vec![filter_expr], /* filter_expr */ + source, /* input */ + schema.clone(), /* input_schema */ + )?; + let limit_exec = LocalLimitExec::new( + Arc::new(single_agg), + 10, // fetch + ); + // expected not to push the limit to the AggregateExec + // TODO(msirek): open an issue for `filter_expr` of `AggregateExec` not printing out + let expected = [ + 
"LocalLimitExec: fetch=10", + "AggregateExec: mode=Single, gby=[a@0 as a], aggr=[]", + "MemoryExec: partitions=1, partition_sizes=[1]", + ]; + let plan: Arc = Arc::new(limit_exec); + assert_plan_matches_expected(&plan, &expected)?; + Ok(()) + } + + #[test] + fn test_has_order_by() -> Result<()> { + let sort_key = vec![PhysicalSortExpr { + expr: expressions::col("a", &schema()).unwrap(), + options: SortOptions::default(), + }]; + let source = parquet_exec_with_sort(vec![sort_key]); + let schema = source.schema(); + + // `SELECT a FROM MemoryExec WHERE a > 1 GROUP BY a LIMIT 10;`, Single AggregateExec + // the `a > 1` filter is applied in the AggregateExec + let single_agg = AggregateExec::try_new( + AggregateMode::Single, + build_group_by(&schema.clone(), vec!["a".to_string()]), + vec![], /* aggr_expr */ + vec![None], /* filter_expr */ + source, /* input */ + schema.clone(), /* input_schema */ + )?; + let limit_exec = LocalLimitExec::new( + Arc::new(single_agg), + 10, // fetch + ); + // expected not to push the limit to the AggregateExec + let expected = [ + "LocalLimitExec: fetch=10", + "AggregateExec: mode=Single, gby=[a@0 as a], aggr=[], ordering_mode=Sorted", + "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC]", + ]; + let plan: Arc = Arc::new(limit_exec); + assert_plan_matches_expected(&plan, &expected)?; + Ok(()) + } +} diff --git a/datafusion/core/src/physical_optimizer/mod.rs b/datafusion/core/src/physical_optimizer/mod.rs index 9e22bff340c99..e990fead610d1 100644 --- a/datafusion/core/src/physical_optimizer/mod.rs +++ b/datafusion/core/src/physical_optimizer/mod.rs @@ -27,9 +27,11 @@ pub mod combine_partial_final_agg; pub mod enforce_distribution; pub mod enforce_sorting; pub mod join_selection; +pub mod limited_distinct_aggregation; pub mod optimizer; pub mod output_requirements; pub mod pipeline_checker; +mod projection_pushdown; pub mod pruning; pub mod replace_with_order_preserving_variants; mod sort_pushdown; diff --git a/datafusion/core/src/physical_optimizer/optimizer.rs b/datafusion/core/src/physical_optimizer/optimizer.rs index 95035e5f81a01..f8c82576e2546 100644 --- a/datafusion/core/src/physical_optimizer/optimizer.rs +++ b/datafusion/core/src/physical_optimizer/optimizer.rs @@ -19,6 +19,7 @@ use std::sync::Arc; +use super::projection_pushdown::ProjectionPushdown; use crate::config::ConfigOptions; use crate::physical_optimizer::aggregate_statistics::AggregateStatistics; use crate::physical_optimizer::coalesce_batches::CoalesceBatches; @@ -26,6 +27,7 @@ use crate::physical_optimizer::combine_partial_final_agg::CombinePartialFinalAgg use crate::physical_optimizer::enforce_distribution::EnforceDistribution; use crate::physical_optimizer::enforce_sorting::EnforceSorting; use crate::physical_optimizer::join_selection::JoinSelection; +use crate::physical_optimizer::limited_distinct_aggregation::LimitedDistinctAggregation; use crate::physical_optimizer::output_requirements::OutputRequirements; use crate::physical_optimizer::pipeline_checker::PipelineChecker; use crate::physical_optimizer::topk_aggregation::TopKAggregation; @@ -79,6 +81,10 @@ impl PhysicalOptimizer { // repartitioning and local sorting steps to meet distribution and ordering requirements. // Therefore, it should run before EnforceDistribution and EnforceSorting. 
Arc::new(JoinSelection::new()), + // The LimitedDistinctAggregation rule should be applied before the EnforceDistribution rule, + // as that rule may inject other operations in between the different AggregateExecs. + // Applying the rule early means only directly-connected AggregateExecs must be examined. + Arc::new(LimitedDistinctAggregation::new()), // The EnforceDistribution rule is for adding essential repartitioning to satisfy distribution // requirements. Please make sure that the whole plan tree is determined before this rule. // This rule increases parallelism if doing so is beneficial to the physical plan; i.e. at @@ -107,6 +113,13 @@ impl PhysicalOptimizer { // into an `order by max(x) limit y`. In this case it will copy the limit value down // to the aggregation, allowing it to use only y number of accumulators. Arc::new(TopKAggregation::new()), + // The ProjectionPushdown rule tries to push projections towards + // the sources in the execution plan. As a result of this process, + // a projection can disappear if it reaches the source providers, and + // sequential projections can merge into one. Even if these two cases + // are not present, the load of executors such as join or union will be + // reduced by narrowing their input tables. + Arc::new(ProjectionPushdown::new()), ]; Self::with_rules(rules) diff --git a/datafusion/core/src/physical_optimizer/output_requirements.rs b/datafusion/core/src/physical_optimizer/output_requirements.rs index d9cdc292dd562..4d03840d3dd31 100644 --- a/datafusion/core/src/physical_optimizer/output_requirements.rs +++ b/datafusion/core/src/physical_optimizer/output_requirements.rs @@ -88,14 +88,14 @@ enum RuleMode { /// /// See [`OutputRequirements`] for more details #[derive(Debug)] -struct OutputRequirementExec { +pub(crate) struct OutputRequirementExec { input: Arc, order_requirement: Option, dist_requirement: Distribution, } impl OutputRequirementExec { - fn new( + pub(crate) fn new( input: Arc, requirements: Option, dist_requirement: Distribution, @@ -107,7 +107,7 @@ impl OutputRequirementExec { } } - fn input(&self) -> Arc { + pub(crate) fn input(&self) -> Arc { self.input.clone() } } @@ -147,6 +147,10 @@ impl ExecutionPlan for OutputRequirementExec { self.input.output_ordering() } + fn maintains_input_order(&self) -> Vec { + vec![true] + } + fn children(&self) -> Vec> { vec![self.input.clone()] } diff --git a/datafusion/core/src/physical_optimizer/pipeline_checker.rs b/datafusion/core/src/physical_optimizer/pipeline_checker.rs index 43ae7dbfe7b60..e281d0e7c23eb 100644 --- a/datafusion/core/src/physical_optimizer/pipeline_checker.rs +++ b/datafusion/core/src/physical_optimizer/pipeline_checker.rs @@ -19,18 +19,19 @@ //! infinite sources, if there are any. It will reject non-runnable query plans //! that use pipeline-breaking operators on infinite input(s). 
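The `PipelineStatePropagator` below follows the same pattern as the sort and coalesce contexts earlier in this patch: each node owns its children as nodes of the same type and hands them to the tree traversal by borrowing. A rough sketch with a simplified, hypothetical trait standing in for the real `TreeNode`:

use std::borrow::Cow;

// Pared-down stand-in for the TreeNode trait; only the two methods this
// patch touches are modeled here.
trait TreeNodeLike: Sized + Clone {
    fn children_nodes(&self) -> Vec<Cow<'_, Self>>;
    fn map_children<F>(self, transform: F) -> Result<Self, String>
    where
        F: FnMut(Self) -> Result<Self, String>;
}

// A context node carrying a flag (e.g. `unbounded`) next to its plan.
#[derive(Clone)]
struct Context {
    name: String, // stands in for Arc<dyn ExecutionPlan>
    unbounded: bool,
    children: Vec<Context>,
}

impl TreeNodeLike for Context {
    fn children_nodes(&self) -> Vec<Cow<'_, Context>> {
        // Borrow the owned children instead of rebuilding wrappers on every visit.
        self.children.iter().map(Cow::Borrowed).collect()
    }

    fn map_children<F>(mut self, mut transform: F) -> Result<Self, String>
    where
        F: FnMut(Self) -> Result<Self, String>,
    {
        // Transform the children in place; the real code also calls
        // with_new_children_if_necessary to keep the plan in sync.
        self.children = self
            .children
            .into_iter()
            .map(&mut transform)
            .collect::<Result<_, _>>()?;
        Ok(self)
    }
}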
+use std::borrow::Cow; use std::sync::Arc; use crate::config::ConfigOptions; use crate::error::Result; use crate::physical_optimizer::PhysicalOptimizerRule; -use crate::physical_plan::joins::SymmetricHashJoinExec; use crate::physical_plan::{with_new_children_if_necessary, ExecutionPlan}; use datafusion_common::config::OptimizerOptions; -use datafusion_common::tree_node::{Transformed, TreeNode, VisitRecursion}; +use datafusion_common::tree_node::{Transformed, TreeNode}; use datafusion_common::{plan_err, DataFusionError}; use datafusion_physical_expr::intervals::utils::{check_support, is_datatype_supported}; +use datafusion_physical_plan::joins::SymmetricHashJoinExec; /// The PipelineChecker rule rejects non-runnable query plans that use /// pipeline-breaking operators on infinite input(s). @@ -70,65 +71,48 @@ impl PhysicalOptimizerRule for PipelineChecker { pub struct PipelineStatePropagator { pub(crate) plan: Arc, pub(crate) unbounded: bool, - pub(crate) children_unbounded: Vec, + pub(crate) children: Vec, } impl PipelineStatePropagator { /// Constructs a new, default pipelining state. pub fn new(plan: Arc) -> Self { - let length = plan.children().len(); - PipelineStatePropagator { + let children = plan.children(); + Self { plan, unbounded: false, - children_unbounded: vec![false; length], + children: children.into_iter().map(Self::new).collect(), } } + + /// Returns the children unboundedness information. + pub fn children_unbounded(&self) -> Vec { + self.children.iter().map(|c| c.unbounded).collect() + } } impl TreeNode for PipelineStatePropagator { - fn apply_children(&self, op: &mut F) -> Result - where - F: FnMut(&Self) -> Result, - { - let children = self.plan.children(); - for child in children { - match op(&PipelineStatePropagator::new(child))? { - VisitRecursion::Continue => {} - VisitRecursion::Skip => return Ok(VisitRecursion::Continue), - VisitRecursion::Stop => return Ok(VisitRecursion::Stop), - } - } - - Ok(VisitRecursion::Continue) + fn children_nodes(&self) -> Vec> { + self.children.iter().map(Cow::Borrowed).collect() } - fn map_children(self, transform: F) -> Result + fn map_children(mut self, transform: F) -> Result where F: FnMut(Self) -> Result, { - let children = self.plan.children(); - if !children.is_empty() { - let new_children = children + if !self.children.is_empty() { + self.children = self + .children .into_iter() - .map(PipelineStatePropagator::new) .map(transform) - .collect::>>()?; - let children_unbounded = new_children - .iter() - .map(|c| c.unbounded) - .collect::>(); - let children_plans = new_children - .into_iter() - .map(|child| child.plan) - .collect::>(); - Ok(PipelineStatePropagator { - plan: with_new_children_if_necessary(self.plan, children_plans)?.into(), - unbounded: self.unbounded, - children_unbounded, - }) - } else { - Ok(self) + .collect::>()?; + self.plan = with_new_children_if_necessary( + self.plan, + self.children.iter().map(|c| c.plan.clone()).collect(), + )? 
+ .into(); } Ok(self) } } @@ -149,7 +133,7 @@ pub fn check_finiteness_requirements( } input .plan - .unbounded_output(&input.children_unbounded) + .unbounded_output(&input.children_unbounded()) .map(|value| { input.unbounded = value; Transformed::Yes(input) diff --git a/datafusion/core/src/physical_optimizer/projection_pushdown.rs b/datafusion/core/src/physical_optimizer/projection_pushdown.rs new file mode 100644 index 0000000000000..d237a3e8607e7 --- /dev/null +++ b/datafusion/core/src/physical_optimizer/projection_pushdown.rs @@ -0,0 +1,2313 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! This file implements the `ProjectionPushdown` physical optimization rule. +//! The function [`remove_unnecessary_projections`] tries to push down all +//! projections one by one if the operator below is amenable to this. If a +//! projection reaches a source, it can even disappear from the plan entirely. + +use std::collections::HashMap; +use std::sync::Arc; + +use super::output_requirements::OutputRequirementExec; +use super::PhysicalOptimizerRule; +use crate::datasource::physical_plan::CsvExec; +use crate::error::Result; +use crate::physical_plan::coalesce_partitions::CoalescePartitionsExec; +use crate::physical_plan::filter::FilterExec; +use crate::physical_plan::joins::utils::{ColumnIndex, JoinFilter}; +use crate::physical_plan::joins::{ + CrossJoinExec, HashJoinExec, NestedLoopJoinExec, SortMergeJoinExec, + SymmetricHashJoinExec, +}; +use crate::physical_plan::memory::MemoryExec; +use crate::physical_plan::projection::ProjectionExec; +use crate::physical_plan::repartition::RepartitionExec; +use crate::physical_plan::sorts::sort::SortExec; +use crate::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; +use crate::physical_plan::{Distribution, ExecutionPlan}; + +use arrow_schema::SchemaRef; +use datafusion_common::config::ConfigOptions; +use datafusion_common::tree_node::{Transformed, TreeNode, VisitRecursion}; +use datafusion_common::JoinSide; +use datafusion_physical_expr::expressions::{Column, Literal}; +use datafusion_physical_expr::{ + Partitioning, PhysicalExpr, PhysicalSortExpr, PhysicalSortRequirement, +}; +use datafusion_physical_plan::streaming::StreamingTableExec; +use datafusion_physical_plan::union::UnionExec; + +use itertools::Itertools; + +/// This rule inspects [`ProjectionExec`] nodes in the given physical plan and tries to +/// remove them or swap them with their child.
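+// A rough, hypothetical illustration of the kind of rewrite this rule performs (the +// plan strings below are simplified sketches, not produced by this code): given a plan such as +//   ProjectionExec: expr=[a@0 as a, b@1 as b] +//     FilterExec: b@1 > 2 +//       CsvExec: projection=[a, b, c] +// the rule can swap the projection with the filter, producing +//   FilterExec: b@1 > 2 +//     ProjectionExec: expr=[a@0 as a, b@1 as b] +//       CsvExec: projection=[a, b, c] +// and, because the remaining projection is alias-free and column-only, continuing down +// the tree can fold it into the CsvExec's own projection (see `test_filter_after_projection` +// and `test_csv_after_projection` in the tests below for the actual plan strings).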
+#[derive(Default)] +pub struct ProjectionPushdown {} + +impl ProjectionPushdown { + #[allow(missing_docs)] + pub fn new() -> Self { + Self {} + } +} + +impl PhysicalOptimizerRule for ProjectionPushdown { + fn optimize( + &self, + plan: Arc, + _config: &ConfigOptions, + ) -> Result> { + plan.transform_down(&remove_unnecessary_projections) + } + + fn name(&self) -> &str { + "ProjectionPushdown" + } + + fn schema_check(&self) -> bool { + true + } +} + +/// This function checks if `plan` is a [`ProjectionExec`], and inspects its +/// input(s) to test whether it can push `plan` under its input(s). This function +/// will operate on the entire tree and may ultimately remove `plan` entirely +/// by leveraging source providers with built-in projection capabilities. +pub fn remove_unnecessary_projections( + plan: Arc, +) -> Result>> { + let maybe_modified = if let Some(projection) = + plan.as_any().downcast_ref::() + { + // If the projection does not cause any change on the input, we can + // safely remove it: + if is_projection_removable(projection) { + return Ok(Transformed::Yes(projection.input().clone())); + } + // If it does, check if we can push it under its child(ren): + let input = projection.input().as_any(); + if let Some(csv) = input.downcast_ref::() { + try_swapping_with_csv(projection, csv) + } else if let Some(memory) = input.downcast_ref::() { + try_swapping_with_memory(projection, memory)? + } else if let Some(child_projection) = input.downcast_ref::() { + let maybe_unified = try_unifying_projections(projection, child_projection)?; + return if let Some(new_plan) = maybe_unified { + // To unify 3 or more sequential projections: + remove_unnecessary_projections(new_plan) + } else { + Ok(Transformed::No(plan)) + }; + } else if let Some(output_req) = input.downcast_ref::() { + try_swapping_with_output_req(projection, output_req)? + } else if input.is::() { + try_swapping_with_coalesce_partitions(projection)? + } else if let Some(filter) = input.downcast_ref::() { + try_swapping_with_filter(projection, filter)? + } else if let Some(repartition) = input.downcast_ref::() { + try_swapping_with_repartition(projection, repartition)? + } else if let Some(sort) = input.downcast_ref::() { + try_swapping_with_sort(projection, sort)? + } else if let Some(spm) = input.downcast_ref::() { + try_swapping_with_sort_preserving_merge(projection, spm)? + } else if let Some(union) = input.downcast_ref::() { + try_pushdown_through_union(projection, union)? + } else if let Some(hash_join) = input.downcast_ref::() { + try_pushdown_through_hash_join(projection, hash_join)? + } else if let Some(cross_join) = input.downcast_ref::() { + try_swapping_with_cross_join(projection, cross_join)? + } else if let Some(nl_join) = input.downcast_ref::() { + try_swapping_with_nested_loop_join(projection, nl_join)? + } else if let Some(sm_join) = input.downcast_ref::() { + try_swapping_with_sort_merge_join(projection, sm_join)? + } else if let Some(sym_join) = input.downcast_ref::() { + try_swapping_with_sym_hash_join(projection, sym_join)? + } else if let Some(ste) = input.downcast_ref::() { + try_swapping_with_streaming_table(projection, ste)? + } else { + // If the input plan of the projection is not one of the above, we + // conservatively assume that pushing the projection down may hurt. + // When adding new operators, consider adding them here if you + // think pushing projections under them is beneficial. 
+ None + } + } else { + return Ok(Transformed::No(plan)); + }; + + Ok(maybe_modified.map_or(Transformed::No(plan), Transformed::Yes)) +} + +/// Tries to embed `projection` to its input (`csv`). If possible, returns +/// [`CsvExec`] as the top plan. Otherwise, returns `None`. +fn try_swapping_with_csv( + projection: &ProjectionExec, + csv: &CsvExec, +) -> Option> { + // If there is any non-column or alias-carrier expression, Projection should not be removed. + // This process can be moved into CsvExec, but it would be an overlap of their responsibility. + all_alias_free_columns(projection.expr()).then(|| { + let mut file_scan = csv.base_config().clone(); + let new_projections = + new_projections_for_columns(projection, &file_scan.projection); + file_scan.projection = Some(new_projections); + + Arc::new(CsvExec::new( + file_scan, + csv.has_header(), + csv.delimiter(), + csv.quote(), + csv.escape(), + csv.file_compression_type, + )) as _ + }) +} + +/// Tries to embed `projection` to its input (`memory`). If possible, returns +/// [`MemoryExec`] as the top plan. Otherwise, returns `None`. +fn try_swapping_with_memory( + projection: &ProjectionExec, + memory: &MemoryExec, +) -> Result>> { + // If there is any non-column or alias-carrier expression, Projection should not be removed. + // This process can be moved into MemoryExec, but it would be an overlap of their responsibility. + all_alias_free_columns(projection.expr()) + .then(|| { + let new_projections = + new_projections_for_columns(projection, memory.projection()); + + MemoryExec::try_new( + memory.partitions(), + memory.original_schema(), + Some(new_projections), + ) + .map(|e| Arc::new(e) as _) + }) + .transpose() +} + +/// Tries to embed `projection` to its input (`streaming table`). +/// If possible, returns [`StreamingTableExec`] as the top plan. Otherwise, +/// returns `None`. +fn try_swapping_with_streaming_table( + projection: &ProjectionExec, + streaming_table: &StreamingTableExec, +) -> Result>> { + if !all_alias_free_columns(projection.expr()) { + return Ok(None); + } + + let streaming_table_projections = streaming_table + .projection() + .as_ref() + .map(|i| i.as_ref().to_vec()); + let new_projections = + new_projections_for_columns(projection, &streaming_table_projections); + + let mut lex_orderings = vec![]; + for lex_ordering in streaming_table.projected_output_ordering().into_iter() { + let mut orderings = vec![]; + for order in lex_ordering { + let Some(new_ordering) = update_expr(&order.expr, projection.expr(), false)? + else { + return Ok(None); + }; + orderings.push(PhysicalSortExpr { + expr: new_ordering, + options: order.options, + }); + } + lex_orderings.push(orderings); + } + + StreamingTableExec::try_new( + streaming_table.partition_schema().clone(), + streaming_table.partitions().clone(), + Some(&new_projections), + lex_orderings, + streaming_table.is_infinite(), + ) + .map(|e| Some(Arc::new(e) as _)) +} + +/// Unifies `projection` with its input (which is also a [`ProjectionExec`]). +fn try_unifying_projections( + projection: &ProjectionExec, + child: &ProjectionExec, +) -> Result>> { + let mut projected_exprs = vec![]; + let mut column_ref_map: HashMap = HashMap::new(); + + // Collect the column references usage in the outer projection. 
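+ // (These counts feed the check below: if a non-trivial child expression would be + // inlined at more than one reference site, we keep the two projections separate so the + // child projection keeps acting as a cache for that computation.)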
+ projection.expr().iter().for_each(|(expr, _)| { + expr.apply(&mut |expr| { + Ok({ + if let Some(column) = expr.as_any().downcast_ref::<Column>() { + *column_ref_map.entry(column.clone()).or_default() += 1; + } + VisitRecursion::Continue + }) + }) + .unwrap(); + }); + + // Merging these projections is not beneficial when a non-trivial expression in the + // child projection is referenced more than once by the parent: keeping the projections + // separate lets the child act as a cache for that computation. + // See discussion in: https://github.com/apache/arrow-datafusion/issues/8296 + if column_ref_map.iter().any(|(column, count)| { + *count > 1 && !is_expr_trivial(&child.expr()[column.index()].0.clone()) + }) { + return Ok(None); + } + + for (expr, alias) in projection.expr() { + // If there is no match in the input projection, we cannot unify these + // projections. This case will arise if the projection expression contains + // a `PhysicalExpr` variant `update_expr` doesn't support. + let Some(expr) = update_expr(expr, child.expr(), true)? else { + return Ok(None); + }; + projected_exprs.push((expr, alias.clone())); + } + + ProjectionExec::try_new(projected_exprs, child.input().clone()) + .map(|e| Some(Arc::new(e) as _)) +} + +/// Checks if the given expression is trivial. +/// An expression is considered trivial if it is either a `Column` or a `Literal`. +fn is_expr_trivial(expr: &Arc<dyn PhysicalExpr>) -> bool { + expr.as_any().downcast_ref::<Column>().is_some() + || expr.as_any().downcast_ref::<Literal>().is_some() +} + +/// Tries to swap `projection` with its input (`output_req`). If possible, +/// performs the swap and returns [`OutputRequirementExec`] as the top plan. +/// Otherwise, returns `None`. +fn try_swapping_with_output_req( + projection: &ProjectionExec, + output_req: &OutputRequirementExec, +) -> Result<Option<Arc<dyn ExecutionPlan>>> { + // If the projection does not narrow the schema, we should not try to push it down: + if projection.expr().len() >= projection.input().schema().fields().len() { + return Ok(None); + } + + let mut updated_sort_reqs = vec![]; + // `None` and an empty vector can be treated in the same way. + if let Some(reqs) = &output_req.required_input_ordering()[0] { + for req in reqs { + let Some(new_expr) = update_expr(&req.expr, projection.expr(), false)? else { + return Ok(None); + }; + updated_sort_reqs.push(PhysicalSortRequirement { + expr: new_expr, + options: req.options, + }); + } + } + + let dist_req = match &output_req.required_input_distribution()[0] { + Distribution::HashPartitioned(exprs) => { + let mut updated_exprs = vec![]; + for expr in exprs { + let Some(new_expr) = update_expr(expr, projection.expr(), false)? else { + return Ok(None); + }; + updated_exprs.push(new_expr); + } + Distribution::HashPartitioned(updated_exprs) + } + dist => dist.clone(), + }; + + make_with_child(projection, &output_req.input()) + .map(|input| { + OutputRequirementExec::new( + input, + (!updated_sort_reqs.is_empty()).then_some(updated_sort_reqs), + dist_req, + ) + }) + .map(|e| Some(Arc::new(e) as _)) +} + +/// Tries to swap `projection` with its input, which is known to be a +/// [`CoalescePartitionsExec`]. If possible, performs the swap and returns +/// [`CoalescePartitionsExec`] as the top plan. Otherwise, returns `None`.
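+/// +/// For instance, a narrowing `ProjectionExec` on top of a `CoalescePartitionsExec` is +/// rewritten so that the `CoalescePartitionsExec` ends up on top and the projection is +/// evaluated on each input partition before the partitions are merged; see +/// `test_coalesce_partitions_after_projection` below for the concrete plan strings.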
+fn try_swapping_with_coalesce_partitions( + projection: &ProjectionExec, +) -> Result<Option<Arc<dyn ExecutionPlan>>> { + // If the projection does not narrow the schema, we should not try to push it down: + if projection.expr().len() >= projection.input().schema().fields().len() { + return Ok(None); + } + // CoalescePartitionsExec always has a single child, so zero indexing is safe. + make_with_child(projection, &projection.input().children()[0]) + .map(|e| Some(Arc::new(CoalescePartitionsExec::new(e)) as _)) +} + +/// Tries to swap `projection` with its input (`filter`). If possible, performs +/// the swap and returns [`FilterExec`] as the top plan. Otherwise, returns `None`. +fn try_swapping_with_filter( + projection: &ProjectionExec, + filter: &FilterExec, +) -> Result<Option<Arc<dyn ExecutionPlan>>> { + // If the projection does not narrow the schema, we should not try to push it down: + if projection.expr().len() >= projection.input().schema().fields().len() { + return Ok(None); + } + // Each column in the predicate expression must exist after the projection. + let Some(new_predicate) = update_expr(filter.predicate(), projection.expr(), false)? + else { + return Ok(None); + }; + + FilterExec::try_new(new_predicate, make_with_child(projection, filter.input())?) + .and_then(|e| { + let selectivity = filter.default_selectivity(); + e.with_default_selectivity(selectivity) + }) + .map(|e| Some(Arc::new(e) as _)) +} + +/// Tries to swap the projection with its input [`RepartitionExec`]. If it can be done, +/// it returns the new swapped version having the [`RepartitionExec`] as the top plan. +/// Otherwise, it returns None. +fn try_swapping_with_repartition( + projection: &ProjectionExec, + repartition: &RepartitionExec, +) -> Result<Option<Arc<dyn ExecutionPlan>>> { + // If the projection does not narrow the schema, we should not try to push it down. + if projection.expr().len() >= projection.input().schema().fields().len() { + return Ok(None); + } + + // If the pushdown is not beneficial or not applicable, do not proceed. + if projection.benefits_from_input_partitioning()[0] || !all_columns(projection.expr()) + { + return Ok(None); + } + + let new_projection = make_with_child(projection, repartition.input())?; + + let new_partitioning = match repartition.partitioning() { + Partitioning::Hash(partitions, size) => { + let mut new_partitions = vec![]; + for partition in partitions { + let Some(new_partition) = + update_expr(partition, projection.expr(), false)? + else { + return Ok(None); + }; + new_partitions.push(new_partition); + } + Partitioning::Hash(new_partitions, *size) + } + others => others.clone(), + }; + + Ok(Some(Arc::new(RepartitionExec::try_new( + new_projection, + new_partitioning, + )?))) +} + +/// Tries to swap the projection with its input [`SortExec`]. If it can be done, +/// it returns the new swapped version having the [`SortExec`] as the top plan. +/// Otherwise, it returns None. +fn try_swapping_with_sort( + projection: &ProjectionExec, + sort: &SortExec, +) -> Result<Option<Arc<dyn ExecutionPlan>>> { + // If the projection does not narrow the schema, we should not try to push it down. + if projection.expr().len() >= projection.input().schema().fields().len() { + return Ok(None); + } + + let mut updated_exprs = vec![]; + for sort in sort.expr() { + let Some(new_expr) = update_expr(&sort.expr, projection.expr(), false)? else { + return Ok(None); + }; + updated_exprs.push(PhysicalSortExpr { + expr: new_expr, + options: sort.options, + }); + } + + Ok(Some(Arc::new( + SortExec::new(updated_exprs, make_with_child(projection, sort.input())?)
+ .with_fetch(sort.fetch()) + .with_preserve_partitioning(sort.preserve_partitioning()), + ))) +} + +/// Tries to swap the projection with its input [`SortPreservingMergeExec`]. +/// If this is possible, it returns the new [`SortPreservingMergeExec`] whose +/// child is a projection. Otherwise, it returns None. +fn try_swapping_with_sort_preserving_merge( + projection: &ProjectionExec, + spm: &SortPreservingMergeExec, +) -> Result<Option<Arc<dyn ExecutionPlan>>> { + // If the projection does not narrow the schema, we should not try to push it down. + if projection.expr().len() >= projection.input().schema().fields().len() { + return Ok(None); + } + + let mut updated_exprs = vec![]; + for sort in spm.expr() { + let Some(updated_expr) = update_expr(&sort.expr, projection.expr(), false)? + else { + return Ok(None); + }; + updated_exprs.push(PhysicalSortExpr { + expr: updated_expr, + options: sort.options, + }); + } + + Ok(Some(Arc::new( + SortPreservingMergeExec::new( + updated_exprs, + make_with_child(projection, spm.input())?, + ) + .with_fetch(spm.fetch()), + ))) +} + +/// Tries to push `projection` down through `union`. If possible, performs the +/// pushdown and returns a new [`UnionExec`] as the top plan which has projections +/// as its children. Otherwise, returns `None`. +fn try_pushdown_through_union( + projection: &ProjectionExec, + union: &UnionExec, +) -> Result<Option<Arc<dyn ExecutionPlan>>> { + // If the projection doesn't narrow the schema, we shouldn't try to push it down. + if projection.expr().len() >= projection.input().schema().fields().len() { + return Ok(None); + } + + let new_children = union + .children() + .into_iter() + .map(|child| make_with_child(projection, &child)) + .collect::<Result<Vec<_>>>()?; + + Ok(Some(Arc::new(UnionExec::new(new_children)))) +} + +/// Tries to push `projection` down through `hash_join`. If possible, performs the +/// pushdown and returns a new [`HashJoinExec`] as the top plan which has projections +/// as its children. Otherwise, returns `None`. +fn try_pushdown_through_hash_join( + projection: &ProjectionExec, + hash_join: &HashJoinExec, +) -> Result<Option<Arc<dyn ExecutionPlan>>> { + // Convert projected expressions to columns. We cannot proceed if this is + // not possible.
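+ // (Besides being all-column, the projection must narrow the join schema, keep every + // left-table column to the left of every right-table column, and reference both sides; + // `join_table_borders` and `join_allows_pushdown` below verify these conditions before + // the join keys, filter and children are rewritten.)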
+ let Some(projection_as_columns) = physical_to_column_exprs(projection.expr()) else { + return Ok(None); + }; + + let (far_right_left_col_ind, far_left_right_col_ind) = join_table_borders( + hash_join.left().schema().fields().len(), + &projection_as_columns, + ); + + if !join_allows_pushdown( + &projection_as_columns, + hash_join.schema(), + far_right_left_col_ind, + far_left_right_col_ind, + ) { + return Ok(None); + } + + let Some(new_on) = update_join_on( + &projection_as_columns[0..=far_right_left_col_ind as _], + &projection_as_columns[far_left_right_col_ind as _..], + hash_join.on(), + ) else { + return Ok(None); + }; + + let new_filter = if let Some(filter) = hash_join.filter() { + match update_join_filter( + &projection_as_columns[0..=far_right_left_col_ind as _], + &projection_as_columns[far_left_right_col_ind as _..], + filter, + hash_join.left(), + hash_join.right(), + ) { + Some(updated_filter) => Some(updated_filter), + None => return Ok(None), + } + } else { + None + }; + + let (new_left, new_right) = new_join_children( + projection_as_columns, + far_right_left_col_ind, + far_left_right_col_ind, + hash_join.left(), + hash_join.right(), + )?; + + Ok(Some(Arc::new(HashJoinExec::try_new( + Arc::new(new_left), + Arc::new(new_right), + new_on, + new_filter, + hash_join.join_type(), + *hash_join.partition_mode(), + hash_join.null_equals_null, + )?))) +} + +/// Tries to swap the projection with its input [`CrossJoinExec`]. If it can be done, +/// it returns the new swapped version having the [`CrossJoinExec`] as the top plan. +/// Otherwise, it returns None. +fn try_swapping_with_cross_join( + projection: &ProjectionExec, + cross_join: &CrossJoinExec, +) -> Result>> { + // Convert projected PhysicalExpr's to columns. If not possible, we cannot proceed. + let Some(projection_as_columns) = physical_to_column_exprs(projection.expr()) else { + return Ok(None); + }; + + let (far_right_left_col_ind, far_left_right_col_ind) = join_table_borders( + cross_join.left().schema().fields().len(), + &projection_as_columns, + ); + + if !join_allows_pushdown( + &projection_as_columns, + cross_join.schema(), + far_right_left_col_ind, + far_left_right_col_ind, + ) { + return Ok(None); + } + + let (new_left, new_right) = new_join_children( + projection_as_columns, + far_right_left_col_ind, + far_left_right_col_ind, + cross_join.left(), + cross_join.right(), + )?; + + Ok(Some(Arc::new(CrossJoinExec::new( + Arc::new(new_left), + Arc::new(new_right), + )))) +} + +/// Tries to swap the projection with its input [`NestedLoopJoinExec`]. If it can be done, +/// it returns the new swapped version having the [`NestedLoopJoinExec`] as the top plan. +/// Otherwise, it returns None. +fn try_swapping_with_nested_loop_join( + projection: &ProjectionExec, + nl_join: &NestedLoopJoinExec, +) -> Result>> { + // Convert projected PhysicalExpr's to columns. If not possible, we cannot proceed. 
+ let Some(projection_as_columns) = physical_to_column_exprs(projection.expr()) else { + return Ok(None); + }; + + let (far_right_left_col_ind, far_left_right_col_ind) = join_table_borders( + nl_join.left().schema().fields().len(), + &projection_as_columns, + ); + + if !join_allows_pushdown( + &projection_as_columns, + nl_join.schema(), + far_right_left_col_ind, + far_left_right_col_ind, + ) { + return Ok(None); + } + + let new_filter = if let Some(filter) = nl_join.filter() { + match update_join_filter( + &projection_as_columns[0..=far_right_left_col_ind as _], + &projection_as_columns[far_left_right_col_ind as _..], + filter, + nl_join.left(), + nl_join.right(), + ) { + Some(updated_filter) => Some(updated_filter), + None => return Ok(None), + } + } else { + None + }; + + let (new_left, new_right) = new_join_children( + projection_as_columns, + far_right_left_col_ind, + far_left_right_col_ind, + nl_join.left(), + nl_join.right(), + )?; + + Ok(Some(Arc::new(NestedLoopJoinExec::try_new( + Arc::new(new_left), + Arc::new(new_right), + new_filter, + nl_join.join_type(), + )?))) +} + +/// Tries to swap the projection with its input [`SortMergeJoinExec`]. If it can be done, +/// it returns the new swapped version having the [`SortMergeJoinExec`] as the top plan. +/// Otherwise, it returns None. +fn try_swapping_with_sort_merge_join( + projection: &ProjectionExec, + sm_join: &SortMergeJoinExec, +) -> Result>> { + // Convert projected PhysicalExpr's to columns. If not possible, we cannot proceed. + let Some(projection_as_columns) = physical_to_column_exprs(projection.expr()) else { + return Ok(None); + }; + + let (far_right_left_col_ind, far_left_right_col_ind) = join_table_borders( + sm_join.left().schema().fields().len(), + &projection_as_columns, + ); + + if !join_allows_pushdown( + &projection_as_columns, + sm_join.schema(), + far_right_left_col_ind, + far_left_right_col_ind, + ) { + return Ok(None); + } + + let Some(new_on) = update_join_on( + &projection_as_columns[0..=far_right_left_col_ind as _], + &projection_as_columns[far_left_right_col_ind as _..], + sm_join.on(), + ) else { + return Ok(None); + }; + + let (new_left, new_right) = new_join_children( + projection_as_columns, + far_right_left_col_ind, + far_left_right_col_ind, + &sm_join.children()[0], + &sm_join.children()[1], + )?; + + Ok(Some(Arc::new(SortMergeJoinExec::try_new( + Arc::new(new_left), + Arc::new(new_right), + new_on, + sm_join.join_type, + sm_join.sort_options.clone(), + sm_join.null_equals_null, + )?))) +} + +/// Tries to swap the projection with its input [`SymmetricHashJoinExec`]. If it can be done, +/// it returns the new swapped version having the [`SymmetricHashJoinExec`] as the top plan. +/// Otherwise, it returns None. +fn try_swapping_with_sym_hash_join( + projection: &ProjectionExec, + sym_join: &SymmetricHashJoinExec, +) -> Result>> { + // Convert projected PhysicalExpr's to columns. If not possible, we cannot proceed. 
+ let Some(projection_as_columns) = physical_to_column_exprs(projection.expr()) else { + return Ok(None); + }; + + let (far_right_left_col_ind, far_left_right_col_ind) = join_table_borders( + sym_join.left().schema().fields().len(), + &projection_as_columns, + ); + + if !join_allows_pushdown( + &projection_as_columns, + sym_join.schema(), + far_right_left_col_ind, + far_left_right_col_ind, + ) { + return Ok(None); + } + + let Some(new_on) = update_join_on( + &projection_as_columns[0..=far_right_left_col_ind as _], + &projection_as_columns[far_left_right_col_ind as _..], + sym_join.on(), + ) else { + return Ok(None); + }; + + let new_filter = if let Some(filter) = sym_join.filter() { + match update_join_filter( + &projection_as_columns[0..=far_right_left_col_ind as _], + &projection_as_columns[far_left_right_col_ind as _..], + filter, + sym_join.left(), + sym_join.right(), + ) { + Some(updated_filter) => Some(updated_filter), + None => return Ok(None), + } + } else { + None + }; + + let (new_left, new_right) = new_join_children( + projection_as_columns, + far_right_left_col_ind, + far_left_right_col_ind, + sym_join.left(), + sym_join.right(), + )?; + + Ok(Some(Arc::new(SymmetricHashJoinExec::try_new( + Arc::new(new_left), + Arc::new(new_right), + new_on, + new_filter, + sym_join.join_type(), + sym_join.null_equals_null(), + sym_join.partition_mode(), + )?))) +} + +/// Compare the inputs and outputs of the projection. If the projection causes +/// any change in the fields, it returns `false`. +fn is_projection_removable(projection: &ProjectionExec) -> bool { + all_alias_free_columns(projection.expr()) && { + let schema = projection.schema(); + let input_schema = projection.input().schema(); + let fields = schema.fields(); + let input_fields = input_schema.fields(); + fields.len() == input_fields.len() + && fields + .iter() + .zip(input_fields.iter()) + .all(|(out, input)| out.eq(input)) + } +} + +/// Given the expression set of a projection, checks if the projection causes +/// any renaming or constructs a non-`Column` physical expression. +fn all_alias_free_columns(exprs: &[(Arc, String)]) -> bool { + exprs.iter().all(|(expr, alias)| { + expr.as_any() + .downcast_ref::() + .map(|column| column.name() == alias) + .unwrap_or(false) + }) +} + +/// Updates a source provider's projected columns according to the given +/// projection operator's expressions. To use this function safely, one must +/// ensure that all expressions are `Column` expressions without aliases. +fn new_projections_for_columns( + projection: &ProjectionExec, + source: &Option>, +) -> Vec { + projection + .expr() + .iter() + .filter_map(|(expr, _)| { + expr.as_any() + .downcast_ref::() + .and_then(|expr| source.as_ref().map(|proj| proj[expr.index()])) + }) + .collect() +} + +/// The function operates in two modes: +/// +/// 1) When `sync_with_child` is `true`: +/// +/// The function updates the indices of `expr` if the expression resides +/// in the input plan. For instance, given the expressions `a@1 + b@2` +/// and `c@0` with the input schema `c@2, a@0, b@1`, the expressions are +/// updated to `a@0 + b@1` and `c@2`. +/// +/// 2) When `sync_with_child` is `false`: +/// +/// The function determines how the expression would be updated if a projection +/// was placed before the plan associated with the expression. If the expression +/// cannot be rewritten after the projection, it returns `None`. 
For example, +/// given the expressions `c@0`, `a@1` and `b@2`, and the [`ProjectionExec`] with +/// an output schema of `a, c_new`, then `c@0` becomes `c_new@1`, `a@1` becomes +/// `a@0`, but `b@2` results in `None` since the projection does not include `b`. +fn update_expr( + expr: &Arc, + projected_exprs: &[(Arc, String)], + sync_with_child: bool, +) -> Result>> { + #[derive(Debug, PartialEq)] + enum RewriteState { + /// The expression is unchanged. + Unchanged, + /// Some part of the expression has been rewritten + RewrittenValid, + /// Some part of the expression has been rewritten, but some column + /// references could not be. + RewrittenInvalid, + } + + let mut state = RewriteState::Unchanged; + + let new_expr = expr + .clone() + .transform_up_mut(&mut |expr: Arc| { + if state == RewriteState::RewrittenInvalid { + return Ok(Transformed::No(expr)); + } + + let Some(column) = expr.as_any().downcast_ref::() else { + return Ok(Transformed::No(expr)); + }; + if sync_with_child { + state = RewriteState::RewrittenValid; + // Update the index of `column`: + Ok(Transformed::Yes(projected_exprs[column.index()].0.clone())) + } else { + // default to invalid, in case we can't find the relevant column + state = RewriteState::RewrittenInvalid; + // Determine how to update `column` to accommodate `projected_exprs` + projected_exprs + .iter() + .enumerate() + .find_map(|(index, (projected_expr, alias))| { + projected_expr.as_any().downcast_ref::().and_then( + |projected_column| { + column.name().eq(projected_column.name()).then(|| { + state = RewriteState::RewrittenValid; + Arc::new(Column::new(alias, index)) as _ + }) + }, + ) + }) + .map_or_else( + || Ok(Transformed::No(expr)), + |c| Ok(Transformed::Yes(c)), + ) + } + }); + + new_expr.map(|e| (state == RewriteState::RewrittenValid).then_some(e)) +} + +/// Creates a new [`ProjectionExec`] instance with the given child plan and +/// projected expressions. +fn make_with_child( + projection: &ProjectionExec, + child: &Arc, +) -> Result> { + ProjectionExec::try_new(projection.expr().to_vec(), child.clone()) + .map(|e| Arc::new(e) as _) +} + +/// Returns `true` if all the expressions in the argument are `Column`s. +fn all_columns(exprs: &[(Arc, String)]) -> bool { + exprs.iter().all(|(expr, _)| expr.as_any().is::()) +} + +/// Downcasts all the expressions in `exprs` to `Column`s. If any of the given +/// expressions is not a `Column`, returns `None`. +fn physical_to_column_exprs( + exprs: &[(Arc, String)], +) -> Option> { + exprs + .iter() + .map(|(expr, alias)| { + expr.as_any() + .downcast_ref::() + .map(|col| (col.clone(), alias.clone())) + }) + .collect() +} + +/// Returns the last index before encountering a column coming from the right table when traveling +/// through the projection from left to right, and the last index before encountering a column +/// coming from the left table when traveling through the projection from right to left. +/// If there is no column in the projection coming from the left side, it returns (-1, ...), +/// if there is no column in the projection coming from the right side, it returns (..., projection length). 
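+/// +/// For example (mirroring `test_join_table_borders` below): with a left table of 5 +/// columns and the projection `[b@1, c@2, e@4, d@3, c@2, f@5, h@7, g@6]`, the left-side +/// columns occupy projection positions 0..=4 and the right-side columns positions 5..=7, +/// so this function returns `(4, 5)`.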
+fn join_table_borders( + left_table_column_count: usize, + projection_as_columns: &[(Column, String)], +) -> (i32, i32) { + let far_right_left_col_ind = projection_as_columns + .iter() + .enumerate() + .take_while(|(_, (projection_column, _))| { + projection_column.index() < left_table_column_count + }) + .last() + .map(|(index, _)| index as i32) + .unwrap_or(-1); + + let far_left_right_col_ind = projection_as_columns + .iter() + .enumerate() + .rev() + .take_while(|(_, (projection_column, _))| { + projection_column.index() >= left_table_column_count + }) + .last() + .map(|(index, _)| index as i32) + .unwrap_or(projection_as_columns.len() as i32); + + (far_right_left_col_ind, far_left_right_col_ind) +} + +/// Tries to update the equi-join `Column`s of a join as if the input of +/// the join was replaced by a projection. +fn update_join_on( + proj_left_exprs: &[(Column, String)], + proj_right_exprs: &[(Column, String)], + hash_join_on: &[(Column, Column)], +) -> Option<Vec<(Column, Column)>> { + // TODO: Clippy wants the "map" call removed, but doing so generates + // a compilation error. Remove the clippy directive once this + // issue is fixed. + #[allow(clippy::map_identity)] + let (left_idx, right_idx): (Vec<_>, Vec<_>) = hash_join_on + .iter() + .map(|(left, right)| (left, right)) + .unzip(); + + let new_left_columns = new_columns_for_join_on(&left_idx, proj_left_exprs); + let new_right_columns = new_columns_for_join_on(&right_idx, proj_right_exprs); + + match (new_left_columns, new_right_columns) { + (Some(left), Some(right)) => Some(left.into_iter().zip(right).collect()), + _ => None, + } +} + +/// This function generates a new set of columns to be used in a hash join +/// operation based on a set of equi-join conditions (`hash_join_on`) and a +/// list of projection expressions (`projection_exprs`). +fn new_columns_for_join_on( + hash_join_on: &[&Column], + projection_exprs: &[(Column, String)], +) -> Option<Vec<Column>> { + let new_columns = hash_join_on + .iter() + .filter_map(|on| { + projection_exprs + .iter() + .enumerate() + .find(|(_, (proj_column, _))| on.name() == proj_column.name()) + .map(|(index, (_, alias))| Column::new(alias, index)) + }) + .collect::<Vec<_>>(); + (new_columns.len() == hash_join_on.len()).then_some(new_columns) +} + +/// Tries to update the column indices of a [`JoinFilter`] as if the input of +/// the join was replaced by a projection.
+fn update_join_filter( + projection_left_exprs: &[(Column, String)], + projection_right_exprs: &[(Column, String)], + join_filter: &JoinFilter, + join_left: &Arc, + join_right: &Arc, +) -> Option { + let mut new_left_indices = new_indices_for_join_filter( + join_filter, + JoinSide::Left, + projection_left_exprs, + join_left.schema(), + ) + .into_iter(); + let mut new_right_indices = new_indices_for_join_filter( + join_filter, + JoinSide::Right, + projection_right_exprs, + join_right.schema(), + ) + .into_iter(); + + // Check if all columns match: + (new_right_indices.len() + new_left_indices.len() + == join_filter.column_indices().len()) + .then(|| { + JoinFilter::new( + join_filter.expression().clone(), + join_filter + .column_indices() + .iter() + .map(|col_idx| ColumnIndex { + index: if col_idx.side == JoinSide::Left { + new_left_indices.next().unwrap() + } else { + new_right_indices.next().unwrap() + }, + side: col_idx.side, + }) + .collect(), + join_filter.schema().clone(), + ) + }) +} + +/// This function determines and returns a vector of indices representing the +/// positions of columns in `projection_exprs` that are involved in `join_filter`, +/// and correspond to a particular side (`join_side`) of the join operation. +fn new_indices_for_join_filter( + join_filter: &JoinFilter, + join_side: JoinSide, + projection_exprs: &[(Column, String)], + join_child_schema: SchemaRef, +) -> Vec { + join_filter + .column_indices() + .iter() + .filter(|col_idx| col_idx.side == join_side) + .filter_map(|col_idx| { + projection_exprs.iter().position(|(col, _)| { + col.name() == join_child_schema.fields()[col_idx.index].name() + }) + }) + .collect() +} + +/// Checks three conditions for pushing a projection down through a join: +/// - Projection must narrow the join output schema. +/// - Columns coming from left/right tables must be collected at the left/right +/// sides of the output table. +/// - Left or right table is not lost after the projection. +fn join_allows_pushdown( + projection_as_columns: &[(Column, String)], + join_schema: SchemaRef, + far_right_left_col_ind: i32, + far_left_right_col_ind: i32, +) -> bool { + // Projection must narrow the join output: + projection_as_columns.len() < join_schema.fields().len() + // Are the columns from different tables mixed? + && (far_right_left_col_ind + 1 == far_left_right_col_ind) + // Left or right table is not lost after the projection. + && far_right_left_col_ind >= 0 + && far_left_right_col_ind < projection_as_columns.len() as i32 +} + +/// If pushing down the projection over this join's children seems possible, +/// this function constructs the new [`ProjectionExec`]s that will come on top +/// of the original children of the join. +fn new_join_children( + projection_as_columns: Vec<(Column, String)>, + far_right_left_col_ind: i32, + far_left_right_col_ind: i32, + left_child: &Arc, + right_child: &Arc, +) -> Result<(ProjectionExec, ProjectionExec)> { + let new_left = ProjectionExec::try_new( + projection_as_columns[0..=far_right_left_col_ind as _] + .iter() + .map(|(col, alias)| { + ( + Arc::new(Column::new(col.name(), col.index())) as _, + alias.clone(), + ) + }) + .collect_vec(), + left_child.clone(), + )?; + let left_size = left_child.schema().fields().len() as i32; + let new_right = ProjectionExec::try_new( + projection_as_columns[far_left_right_col_ind as _..] 
+ .iter() + .map(|(col, alias)| { + ( + Arc::new(Column::new( + col.name(), + // Align projected expressions coming from the right + // table with the new right child projection: + (col.index() as i32 - left_size) as _, + )) as _, + alias.clone(), + ) + }) + .collect_vec(), + right_child.clone(), + )?; + + Ok((new_left, new_right)) +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use crate::datasource::file_format::file_compression_type::FileCompressionType; + use crate::datasource::listing::PartitionedFile; + use crate::datasource::physical_plan::{CsvExec, FileScanConfig}; + use crate::physical_optimizer::output_requirements::OutputRequirementExec; + use crate::physical_optimizer::projection_pushdown::{ + join_table_borders, update_expr, ProjectionPushdown, + }; + use crate::physical_optimizer::PhysicalOptimizerRule; + use crate::physical_plan::coalesce_partitions::CoalescePartitionsExec; + use crate::physical_plan::filter::FilterExec; + use crate::physical_plan::joins::utils::{ColumnIndex, JoinFilter}; + use crate::physical_plan::joins::StreamJoinPartitionMode; + use crate::physical_plan::memory::MemoryExec; + use crate::physical_plan::projection::ProjectionExec; + use crate::physical_plan::repartition::RepartitionExec; + use crate::physical_plan::sorts::sort::SortExec; + use crate::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; + use crate::physical_plan::{get_plan_string, ExecutionPlan}; + + use arrow_schema::{DataType, Field, Schema, SchemaRef, SortOptions}; + use datafusion_common::config::ConfigOptions; + use datafusion_common::{JoinSide, JoinType, Result, ScalarValue, Statistics}; + use datafusion_execution::object_store::ObjectStoreUrl; + use datafusion_execution::{SendableRecordBatchStream, TaskContext}; + use datafusion_expr::{ColumnarValue, Operator}; + use datafusion_physical_expr::expressions::{ + BinaryExpr, CaseExpr, CastExpr, Column, Literal, NegativeExpr, + }; + use datafusion_physical_expr::{ + Distribution, Partitioning, PhysicalExpr, PhysicalSortExpr, + PhysicalSortRequirement, ScalarFunctionExpr, + }; + use datafusion_physical_plan::joins::SymmetricHashJoinExec; + use datafusion_physical_plan::streaming::{PartitionStream, StreamingTableExec}; + use datafusion_physical_plan::union::UnionExec; + + use itertools::Itertools; + + #[test] + fn test_update_matching_exprs() -> Result<()> { + let exprs: Vec> = vec![ + Arc::new(BinaryExpr::new( + Arc::new(Column::new("a", 3)), + Operator::Divide, + Arc::new(Column::new("e", 5)), + )), + Arc::new(CastExpr::new( + Arc::new(Column::new("a", 3)), + DataType::Float32, + None, + )), + Arc::new(NegativeExpr::new(Arc::new(Column::new("f", 4)))), + Arc::new(ScalarFunctionExpr::new( + "scalar_expr", + Arc::new(|_: &[ColumnarValue]| unimplemented!("not implemented")), + vec![ + Arc::new(BinaryExpr::new( + Arc::new(Column::new("b", 1)), + Operator::Divide, + Arc::new(Column::new("c", 0)), + )), + Arc::new(BinaryExpr::new( + Arc::new(Column::new("c", 0)), + Operator::Divide, + Arc::new(Column::new("b", 1)), + )), + ], + DataType::Int32, + None, + )), + Arc::new(CaseExpr::try_new( + Some(Arc::new(Column::new("d", 2))), + vec![ + ( + Arc::new(Column::new("a", 3)) as Arc, + Arc::new(BinaryExpr::new( + Arc::new(Column::new("d", 2)), + Operator::Plus, + Arc::new(Column::new("e", 5)), + )) as Arc, + ), + ( + Arc::new(Column::new("a", 3)) as Arc, + Arc::new(BinaryExpr::new( + Arc::new(Column::new("e", 5)), + Operator::Plus, + Arc::new(Column::new("d", 2)), + )) as Arc, + ), + ], + Some(Arc::new(BinaryExpr::new( + 
Arc::new(Column::new("a", 3)), + Operator::Modulo, + Arc::new(Column::new("e", 5)), + ))), + )?), + ]; + let child: Vec<(Arc, String)> = vec![ + (Arc::new(Column::new("c", 2)), "c".to_owned()), + (Arc::new(Column::new("b", 1)), "b".to_owned()), + (Arc::new(Column::new("d", 3)), "d".to_owned()), + (Arc::new(Column::new("a", 0)), "a".to_owned()), + (Arc::new(Column::new("f", 5)), "f".to_owned()), + (Arc::new(Column::new("e", 4)), "e".to_owned()), + ]; + + let expected_exprs: Vec> = vec![ + Arc::new(BinaryExpr::new( + Arc::new(Column::new("a", 0)), + Operator::Divide, + Arc::new(Column::new("e", 4)), + )), + Arc::new(CastExpr::new( + Arc::new(Column::new("a", 0)), + DataType::Float32, + None, + )), + Arc::new(NegativeExpr::new(Arc::new(Column::new("f", 5)))), + Arc::new(ScalarFunctionExpr::new( + "scalar_expr", + Arc::new(|_: &[ColumnarValue]| unimplemented!("not implemented")), + vec![ + Arc::new(BinaryExpr::new( + Arc::new(Column::new("b", 1)), + Operator::Divide, + Arc::new(Column::new("c", 2)), + )), + Arc::new(BinaryExpr::new( + Arc::new(Column::new("c", 2)), + Operator::Divide, + Arc::new(Column::new("b", 1)), + )), + ], + DataType::Int32, + None, + )), + Arc::new(CaseExpr::try_new( + Some(Arc::new(Column::new("d", 3))), + vec![ + ( + Arc::new(Column::new("a", 0)) as Arc, + Arc::new(BinaryExpr::new( + Arc::new(Column::new("d", 3)), + Operator::Plus, + Arc::new(Column::new("e", 4)), + )) as Arc, + ), + ( + Arc::new(Column::new("a", 0)) as Arc, + Arc::new(BinaryExpr::new( + Arc::new(Column::new("e", 4)), + Operator::Plus, + Arc::new(Column::new("d", 3)), + )) as Arc, + ), + ], + Some(Arc::new(BinaryExpr::new( + Arc::new(Column::new("a", 0)), + Operator::Modulo, + Arc::new(Column::new("e", 4)), + ))), + )?), + ]; + + for (expr, expected_expr) in exprs.into_iter().zip(expected_exprs.into_iter()) { + assert!(update_expr(&expr, &child, true)? 
+ .unwrap() + .eq(&expected_expr)); + } + + Ok(()) + } + + #[test] + fn test_update_projected_exprs() -> Result<()> { + let exprs: Vec> = vec![ + Arc::new(BinaryExpr::new( + Arc::new(Column::new("a", 3)), + Operator::Divide, + Arc::new(Column::new("e", 5)), + )), + Arc::new(CastExpr::new( + Arc::new(Column::new("a", 3)), + DataType::Float32, + None, + )), + Arc::new(NegativeExpr::new(Arc::new(Column::new("f", 4)))), + Arc::new(ScalarFunctionExpr::new( + "scalar_expr", + Arc::new(|_: &[ColumnarValue]| unimplemented!("not implemented")), + vec![ + Arc::new(BinaryExpr::new( + Arc::new(Column::new("b", 1)), + Operator::Divide, + Arc::new(Column::new("c", 0)), + )), + Arc::new(BinaryExpr::new( + Arc::new(Column::new("c", 0)), + Operator::Divide, + Arc::new(Column::new("b", 1)), + )), + ], + DataType::Int32, + None, + )), + Arc::new(CaseExpr::try_new( + Some(Arc::new(Column::new("d", 2))), + vec![ + ( + Arc::new(Column::new("a", 3)) as Arc, + Arc::new(BinaryExpr::new( + Arc::new(Column::new("d", 2)), + Operator::Plus, + Arc::new(Column::new("e", 5)), + )) as Arc, + ), + ( + Arc::new(Column::new("a", 3)) as Arc, + Arc::new(BinaryExpr::new( + Arc::new(Column::new("e", 5)), + Operator::Plus, + Arc::new(Column::new("d", 2)), + )) as Arc, + ), + ], + Some(Arc::new(BinaryExpr::new( + Arc::new(Column::new("a", 3)), + Operator::Modulo, + Arc::new(Column::new("e", 5)), + ))), + )?), + ]; + let projected_exprs: Vec<(Arc, String)> = vec![ + (Arc::new(Column::new("a", 0)), "a".to_owned()), + (Arc::new(Column::new("b", 1)), "b_new".to_owned()), + (Arc::new(Column::new("c", 2)), "c".to_owned()), + (Arc::new(Column::new("d", 3)), "d_new".to_owned()), + (Arc::new(Column::new("e", 4)), "e".to_owned()), + (Arc::new(Column::new("f", 5)), "f_new".to_owned()), + ]; + + let expected_exprs: Vec> = vec![ + Arc::new(BinaryExpr::new( + Arc::new(Column::new("a", 0)), + Operator::Divide, + Arc::new(Column::new("e", 4)), + )), + Arc::new(CastExpr::new( + Arc::new(Column::new("a", 0)), + DataType::Float32, + None, + )), + Arc::new(NegativeExpr::new(Arc::new(Column::new("f_new", 5)))), + Arc::new(ScalarFunctionExpr::new( + "scalar_expr", + Arc::new(|_: &[ColumnarValue]| unimplemented!("not implemented")), + vec![ + Arc::new(BinaryExpr::new( + Arc::new(Column::new("b_new", 1)), + Operator::Divide, + Arc::new(Column::new("c", 2)), + )), + Arc::new(BinaryExpr::new( + Arc::new(Column::new("c", 2)), + Operator::Divide, + Arc::new(Column::new("b_new", 1)), + )), + ], + DataType::Int32, + None, + )), + Arc::new(CaseExpr::try_new( + Some(Arc::new(Column::new("d_new", 3))), + vec![ + ( + Arc::new(Column::new("a", 0)) as Arc, + Arc::new(BinaryExpr::new( + Arc::new(Column::new("d_new", 3)), + Operator::Plus, + Arc::new(Column::new("e", 4)), + )) as Arc, + ), + ( + Arc::new(Column::new("a", 0)) as Arc, + Arc::new(BinaryExpr::new( + Arc::new(Column::new("e", 4)), + Operator::Plus, + Arc::new(Column::new("d_new", 3)), + )) as Arc, + ), + ], + Some(Arc::new(BinaryExpr::new( + Arc::new(Column::new("a", 0)), + Operator::Modulo, + Arc::new(Column::new("e", 4)), + ))), + )?), + ]; + + for (expr, expected_expr) in exprs.into_iter().zip(expected_exprs.into_iter()) { + assert!(update_expr(&expr, &projected_exprs, false)? 
+ .unwrap() + .eq(&expected_expr)); + } + + Ok(()) + } + + #[test] + fn test_join_table_borders() -> Result<()> { + let projections = vec![ + (Column::new("b", 1), "b".to_owned()), + (Column::new("c", 2), "c".to_owned()), + (Column::new("e", 4), "e".to_owned()), + (Column::new("d", 3), "d".to_owned()), + (Column::new("c", 2), "c".to_owned()), + (Column::new("f", 5), "f".to_owned()), + (Column::new("h", 7), "h".to_owned()), + (Column::new("g", 6), "g".to_owned()), + ]; + let left_table_column_count = 5; + assert_eq!( + join_table_borders(left_table_column_count, &projections), + (4, 5) + ); + + let left_table_column_count = 8; + assert_eq!( + join_table_borders(left_table_column_count, &projections), + (7, 8) + ); + + let left_table_column_count = 1; + assert_eq!( + join_table_borders(left_table_column_count, &projections), + (-1, 0) + ); + + let projections = vec![ + (Column::new("a", 0), "a".to_owned()), + (Column::new("b", 1), "b".to_owned()), + (Column::new("d", 3), "d".to_owned()), + (Column::new("g", 6), "g".to_owned()), + (Column::new("e", 4), "e".to_owned()), + (Column::new("f", 5), "f".to_owned()), + (Column::new("e", 4), "e".to_owned()), + (Column::new("h", 7), "h".to_owned()), + ]; + let left_table_column_count = 5; + assert_eq!( + join_table_borders(left_table_column_count, &projections), + (2, 7) + ); + + let left_table_column_count = 7; + assert_eq!( + join_table_borders(left_table_column_count, &projections), + (6, 7) + ); + + Ok(()) + } + + fn create_simple_csv_exec() -> Arc { + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + Field::new("c", DataType::Int32, true), + Field::new("d", DataType::Int32, true), + Field::new("e", DataType::Int32, true), + ])); + Arc::new(CsvExec::new( + FileScanConfig { + object_store_url: ObjectStoreUrl::parse("test:///").unwrap(), + file_schema: schema.clone(), + file_groups: vec![vec![PartitionedFile::new("x".to_string(), 100)]], + statistics: Statistics::new_unknown(&schema), + projection: Some(vec![0, 1, 2, 3, 4]), + limit: None, + table_partition_cols: vec![], + output_ordering: vec![vec![]], + }, + false, + 0, + 0, + None, + FileCompressionType::UNCOMPRESSED, + )) + } + + fn create_projecting_csv_exec() -> Arc { + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + Field::new("c", DataType::Int32, true), + Field::new("d", DataType::Int32, true), + ])); + Arc::new(CsvExec::new( + FileScanConfig { + object_store_url: ObjectStoreUrl::parse("test:///").unwrap(), + file_schema: schema.clone(), + file_groups: vec![vec![PartitionedFile::new("x".to_string(), 100)]], + statistics: Statistics::new_unknown(&schema), + projection: Some(vec![3, 2, 1]), + limit: None, + table_partition_cols: vec![], + output_ordering: vec![vec![]], + }, + false, + 0, + 0, + None, + FileCompressionType::UNCOMPRESSED, + )) + } + + fn create_projecting_memory_exec() -> Arc { + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + Field::new("c", DataType::Int32, true), + Field::new("d", DataType::Int32, true), + Field::new("e", DataType::Int32, true), + ])); + + Arc::new(MemoryExec::try_new(&[], schema, Some(vec![2, 0, 3, 4])).unwrap()) + } + + #[test] + fn test_csv_after_projection() -> Result<()> { + let csv = create_projecting_csv_exec(); + let projection: Arc = Arc::new(ProjectionExec::try_new( + vec![ + (Arc::new(Column::new("b", 2)), 
"b".to_string()), + (Arc::new(Column::new("d", 0)), "d".to_string()), + ], + csv.clone(), + )?); + let initial = get_plan_string(&projection); + let expected_initial = [ + "ProjectionExec: expr=[b@2 as b, d@0 as d]", + " CsvExec: file_groups={1 group: [[x]]}, projection=[d, c, b], has_header=false", + ]; + assert_eq!(initial, expected_initial); + + let after_optimize = + ProjectionPushdown::new().optimize(projection, &ConfigOptions::new())?; + + let expected = [ + "CsvExec: file_groups={1 group: [[x]]}, projection=[b, d], has_header=false", + ]; + assert_eq!(get_plan_string(&after_optimize), expected); + + Ok(()) + } + + #[test] + fn test_memory_after_projection() -> Result<()> { + let memory = create_projecting_memory_exec(); + let projection: Arc = Arc::new(ProjectionExec::try_new( + vec![ + (Arc::new(Column::new("d", 2)), "d".to_string()), + (Arc::new(Column::new("e", 3)), "e".to_string()), + (Arc::new(Column::new("a", 1)), "a".to_string()), + ], + memory.clone(), + )?); + let initial = get_plan_string(&projection); + let expected_initial = [ + "ProjectionExec: expr=[d@2 as d, e@3 as e, a@1 as a]", + " MemoryExec: partitions=0, partition_sizes=[]", + ]; + assert_eq!(initial, expected_initial); + + let after_optimize = + ProjectionPushdown::new().optimize(projection, &ConfigOptions::new())?; + + let expected = ["MemoryExec: partitions=0, partition_sizes=[]"]; + assert_eq!(get_plan_string(&after_optimize), expected); + assert_eq!( + after_optimize + .clone() + .as_any() + .downcast_ref::() + .unwrap() + .projection() + .clone() + .unwrap(), + vec![3, 4, 0] + ); + + Ok(()) + } + + #[test] + fn test_streaming_table_after_projection() -> Result<()> { + struct DummyStreamPartition { + schema: SchemaRef, + } + impl PartitionStream for DummyStreamPartition { + fn schema(&self) -> &SchemaRef { + &self.schema + } + fn execute(&self, _ctx: Arc) -> SendableRecordBatchStream { + unreachable!() + } + } + + let streaming_table = StreamingTableExec::try_new( + Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + Field::new("c", DataType::Int32, true), + Field::new("d", DataType::Int32, true), + Field::new("e", DataType::Int32, true), + ])), + vec![Arc::new(DummyStreamPartition { + schema: Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + Field::new("c", DataType::Int32, true), + Field::new("d", DataType::Int32, true), + Field::new("e", DataType::Int32, true), + ])), + }) as _], + Some(&vec![0_usize, 2, 4, 3]), + vec![ + vec![ + PhysicalSortExpr { + expr: Arc::new(Column::new("e", 2)), + options: SortOptions::default(), + }, + PhysicalSortExpr { + expr: Arc::new(Column::new("a", 0)), + options: SortOptions::default(), + }, + ], + vec![PhysicalSortExpr { + expr: Arc::new(Column::new("d", 3)), + options: SortOptions::default(), + }], + ] + .into_iter(), + true, + )?; + let projection = Arc::new(ProjectionExec::try_new( + vec![ + (Arc::new(Column::new("d", 3)), "d".to_string()), + (Arc::new(Column::new("e", 2)), "e".to_string()), + (Arc::new(Column::new("a", 0)), "a".to_string()), + ], + Arc::new(streaming_table) as _, + )?) 
as _; + + let after_optimize = + ProjectionPushdown::new().optimize(projection, &ConfigOptions::new())?; + + let result = after_optimize + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!( + result.partition_schema(), + &Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + Field::new("c", DataType::Int32, true), + Field::new("d", DataType::Int32, true), + Field::new("e", DataType::Int32, true), + ])) + ); + assert_eq!( + result.projection().clone().unwrap().to_vec(), + vec![3_usize, 4, 0] + ); + assert_eq!( + result.projected_schema(), + &Schema::new(vec![ + Field::new("d", DataType::Int32, true), + Field::new("e", DataType::Int32, true), + Field::new("a", DataType::Int32, true), + ]) + ); + assert_eq!( + result.projected_output_ordering().into_iter().collect_vec(), + vec![ + vec![ + PhysicalSortExpr { + expr: Arc::new(Column::new("e", 1)), + options: SortOptions::default(), + }, + PhysicalSortExpr { + expr: Arc::new(Column::new("a", 2)), + options: SortOptions::default(), + }, + ], + vec![PhysicalSortExpr { + expr: Arc::new(Column::new("d", 0)), + options: SortOptions::default(), + }], + ] + ); + assert!(result.is_infinite()); + + Ok(()) + } + + #[test] + fn test_projection_after_projection() -> Result<()> { + let csv = create_simple_csv_exec(); + let child_projection: Arc = Arc::new(ProjectionExec::try_new( + vec![ + (Arc::new(Column::new("c", 2)), "c".to_string()), + (Arc::new(Column::new("e", 4)), "new_e".to_string()), + (Arc::new(Column::new("a", 0)), "a".to_string()), + (Arc::new(Column::new("b", 1)), "new_b".to_string()), + ], + csv.clone(), + )?); + let top_projection: Arc = Arc::new(ProjectionExec::try_new( + vec![ + (Arc::new(Column::new("new_b", 3)), "new_b".to_string()), + ( + Arc::new(BinaryExpr::new( + Arc::new(Column::new("c", 0)), + Operator::Plus, + Arc::new(Column::new("new_e", 1)), + )), + "binary".to_string(), + ), + (Arc::new(Column::new("new_b", 3)), "newest_b".to_string()), + ], + child_projection.clone(), + )?); + + let initial = get_plan_string(&top_projection); + let expected_initial = [ + "ProjectionExec: expr=[new_b@3 as new_b, c@0 + new_e@1 as binary, new_b@3 as newest_b]", + " ProjectionExec: expr=[c@2 as c, e@4 as new_e, a@0 as a, b@1 as new_b]", + " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false" + ]; + assert_eq!(initial, expected_initial); + + let after_optimize = + ProjectionPushdown::new().optimize(top_projection, &ConfigOptions::new())?; + + let expected = [ + "ProjectionExec: expr=[b@1 as new_b, c@2 + e@4 as binary, b@1 as newest_b]", + " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false" + ]; + assert_eq!(get_plan_string(&after_optimize), expected); + + Ok(()) + } + + #[test] + fn test_output_req_after_projection() -> Result<()> { + let csv = create_simple_csv_exec(); + let sort_req: Arc = Arc::new(OutputRequirementExec::new( + csv.clone(), + Some(vec![ + PhysicalSortRequirement { + expr: Arc::new(Column::new("b", 1)), + options: Some(SortOptions::default()), + }, + PhysicalSortRequirement { + expr: Arc::new(BinaryExpr::new( + Arc::new(Column::new("c", 2)), + Operator::Plus, + Arc::new(Column::new("a", 0)), + )), + options: Some(SortOptions::default()), + }, + ]), + Distribution::HashPartitioned(vec![ + Arc::new(Column::new("a", 0)), + Arc::new(Column::new("b", 1)), + ]), + )); + let projection: Arc = Arc::new(ProjectionExec::try_new( + vec![ + (Arc::new(Column::new("c", 2)), "c".to_string()), + 
(Arc::new(Column::new("a", 0)), "new_a".to_string()), + (Arc::new(Column::new("b", 1)), "b".to_string()), + ], + sort_req.clone(), + )?); + + let initial = get_plan_string(&projection); + let expected_initial = [ + "ProjectionExec: expr=[c@2 as c, a@0 as new_a, b@1 as b]", + " OutputRequirementExec", + " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false" + ]; + assert_eq!(initial, expected_initial); + + let after_optimize = + ProjectionPushdown::new().optimize(projection, &ConfigOptions::new())?; + + let expected: [&str; 3] = [ + "OutputRequirementExec", + " ProjectionExec: expr=[c@2 as c, a@0 as new_a, b@1 as b]", + " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false" + ]; + + assert_eq!(get_plan_string(&after_optimize), expected); + let expected_reqs = vec![ + PhysicalSortRequirement { + expr: Arc::new(Column::new("b", 2)), + options: Some(SortOptions::default()), + }, + PhysicalSortRequirement { + expr: Arc::new(BinaryExpr::new( + Arc::new(Column::new("c", 0)), + Operator::Plus, + Arc::new(Column::new("new_a", 1)), + )), + options: Some(SortOptions::default()), + }, + ]; + assert_eq!( + after_optimize + .as_any() + .downcast_ref::() + .unwrap() + .required_input_ordering()[0] + .clone() + .unwrap(), + expected_reqs + ); + let expected_distribution: Vec> = vec![ + Arc::new(Column::new("new_a", 1)), + Arc::new(Column::new("b", 2)), + ]; + if let Distribution::HashPartitioned(vec) = after_optimize + .as_any() + .downcast_ref::() + .unwrap() + .required_input_distribution()[0] + .clone() + { + assert!(vec + .iter() + .zip(expected_distribution) + .all(|(actual, expected)| actual.eq(&expected))); + } else { + panic!("Expected HashPartitioned distribution!"); + }; + + Ok(()) + } + + #[test] + fn test_coalesce_partitions_after_projection() -> Result<()> { + let csv = create_simple_csv_exec(); + let coalesce_partitions: Arc = + Arc::new(CoalescePartitionsExec::new(csv)); + let projection: Arc = Arc::new(ProjectionExec::try_new( + vec![ + (Arc::new(Column::new("b", 1)), "b".to_string()), + (Arc::new(Column::new("a", 0)), "a_new".to_string()), + (Arc::new(Column::new("d", 3)), "d".to_string()), + ], + coalesce_partitions, + )?); + let initial = get_plan_string(&projection); + let expected_initial = [ + "ProjectionExec: expr=[b@1 as b, a@0 as a_new, d@3 as d]", + " CoalescePartitionsExec", + " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false", + ]; + assert_eq!(initial, expected_initial); + + let after_optimize = + ProjectionPushdown::new().optimize(projection, &ConfigOptions::new())?; + + let expected = [ + "CoalescePartitionsExec", + " ProjectionExec: expr=[b@1 as b, a@0 as a_new, d@3 as d]", + " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false", + ]; + assert_eq!(get_plan_string(&after_optimize), expected); + + Ok(()) + } + + #[test] + fn test_filter_after_projection() -> Result<()> { + let csv = create_simple_csv_exec(); + let predicate = Arc::new(BinaryExpr::new( + Arc::new(BinaryExpr::new( + Arc::new(Column::new("b", 1)), + Operator::Minus, + Arc::new(Column::new("a", 0)), + )), + Operator::Gt, + Arc::new(BinaryExpr::new( + Arc::new(Column::new("d", 3)), + Operator::Minus, + Arc::new(Column::new("a", 0)), + )), + )); + let filter: Arc = + Arc::new(FilterExec::try_new(predicate, csv)?); + let projection: Arc = Arc::new(ProjectionExec::try_new( + vec![ + (Arc::new(Column::new("a", 0)), "a_new".to_string()), + (Arc::new(Column::new("b", 1)), "b".to_string()), 
+ (Arc::new(Column::new("d", 3)), "d".to_string()), + ], + filter.clone(), + )?); + + let initial = get_plan_string(&projection); + let expected_initial = [ + "ProjectionExec: expr=[a@0 as a_new, b@1 as b, d@3 as d]", + " FilterExec: b@1 - a@0 > d@3 - a@0", + " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false", + ]; + assert_eq!(initial, expected_initial); + + let after_optimize = + ProjectionPushdown::new().optimize(projection, &ConfigOptions::new())?; + + let expected = [ + "FilterExec: b@1 - a_new@0 > d@2 - a_new@0", + " ProjectionExec: expr=[a@0 as a_new, b@1 as b, d@3 as d]", + " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false", + ]; + assert_eq!(get_plan_string(&after_optimize), expected); + + Ok(()) + } + + #[test] + fn test_join_after_projection() -> Result<()> { + let left_csv = create_simple_csv_exec(); + let right_csv = create_simple_csv_exec(); + + let join: Arc = Arc::new(SymmetricHashJoinExec::try_new( + left_csv, + right_csv, + vec![(Column::new("b", 1), Column::new("c", 2))], + // b_left-(1+a_right)<=a_right+c_left + Some(JoinFilter::new( + Arc::new(BinaryExpr::new( + Arc::new(BinaryExpr::new( + Arc::new(Column::new("b_left_inter", 0)), + Operator::Minus, + Arc::new(BinaryExpr::new( + Arc::new(Literal::new(ScalarValue::Int32(Some(1)))), + Operator::Plus, + Arc::new(Column::new("a_right_inter", 1)), + )), + )), + Operator::LtEq, + Arc::new(BinaryExpr::new( + Arc::new(Column::new("a_right_inter", 1)), + Operator::Plus, + Arc::new(Column::new("c_left_inter", 2)), + )), + )), + vec![ + ColumnIndex { + index: 1, + side: JoinSide::Left, + }, + ColumnIndex { + index: 0, + side: JoinSide::Right, + }, + ColumnIndex { + index: 2, + side: JoinSide::Left, + }, + ], + Schema::new(vec![ + Field::new("b_left_inter", DataType::Int32, true), + Field::new("a_right_inter", DataType::Int32, true), + Field::new("c_left_inter", DataType::Int32, true), + ]), + )), + &JoinType::Inner, + true, + StreamJoinPartitionMode::SinglePartition, + )?); + let projection: Arc = Arc::new(ProjectionExec::try_new( + vec![ + (Arc::new(Column::new("c", 2)), "c_from_left".to_string()), + (Arc::new(Column::new("b", 1)), "b_from_left".to_string()), + (Arc::new(Column::new("a", 0)), "a_from_left".to_string()), + (Arc::new(Column::new("a", 5)), "a_from_right".to_string()), + (Arc::new(Column::new("c", 7)), "c_from_right".to_string()), + ], + join, + )?); + let initial = get_plan_string(&projection); + let expected_initial = [ + "ProjectionExec: expr=[c@2 as c_from_left, b@1 as b_from_left, a@0 as a_from_left, a@5 as a_from_right, c@7 as c_from_right]", + " SymmetricHashJoinExec: mode=SinglePartition, join_type=Inner, on=[(b@1, c@2)], filter=b_left_inter@0 - 1 + a_right_inter@1 <= a_right_inter@1 + c_left_inter@2", + " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false", + " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false" + ]; + assert_eq!(initial, expected_initial); + + let after_optimize = + ProjectionPushdown::new().optimize(projection, &ConfigOptions::new())?; + + let expected = [ + "SymmetricHashJoinExec: mode=SinglePartition, join_type=Inner, on=[(b_from_left@1, c_from_right@1)], filter=b_left_inter@0 - 1 + a_right_inter@1 <= a_right_inter@1 + c_left_inter@2", + " ProjectionExec: expr=[c@2 as c_from_left, b@1 as b_from_left, a@0 as a_from_left]", + " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false", + " ProjectionExec: expr=[a@0 as 
a_from_right, c@2 as c_from_right]", + " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false" + ]; + assert_eq!(get_plan_string(&after_optimize), expected); + + let expected_filter_col_ind = vec![ + ColumnIndex { + index: 1, + side: JoinSide::Left, + }, + ColumnIndex { + index: 0, + side: JoinSide::Right, + }, + ColumnIndex { + index: 0, + side: JoinSide::Left, + }, + ]; + + assert_eq!( + expected_filter_col_ind, + after_optimize + .as_any() + .downcast_ref::() + .unwrap() + .filter() + .unwrap() + .column_indices() + ); + + Ok(()) + } + + #[test] + fn test_repartition_after_projection() -> Result<()> { + let csv = create_simple_csv_exec(); + let repartition: Arc = Arc::new(RepartitionExec::try_new( + csv, + Partitioning::Hash( + vec![ + Arc::new(Column::new("a", 0)), + Arc::new(Column::new("b", 1)), + Arc::new(Column::new("d", 3)), + ], + 6, + ), + )?); + let projection: Arc = Arc::new(ProjectionExec::try_new( + vec![ + (Arc::new(Column::new("b", 1)), "b_new".to_string()), + (Arc::new(Column::new("a", 0)), "a".to_string()), + (Arc::new(Column::new("d", 3)), "d_new".to_string()), + ], + repartition, + )?); + let initial = get_plan_string(&projection); + let expected_initial = [ + "ProjectionExec: expr=[b@1 as b_new, a@0 as a, d@3 as d_new]", + " RepartitionExec: partitioning=Hash([a@0, b@1, d@3], 6), input_partitions=1", + " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false", + ]; + assert_eq!(initial, expected_initial); + + let after_optimize = + ProjectionPushdown::new().optimize(projection, &ConfigOptions::new())?; + + let expected = [ + "RepartitionExec: partitioning=Hash([a@1, b_new@0, d_new@2], 6), input_partitions=1", + " ProjectionExec: expr=[b@1 as b_new, a@0 as a, d@3 as d_new]", + " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false", + ]; + assert_eq!(get_plan_string(&after_optimize), expected); + + assert_eq!( + after_optimize + .as_any() + .downcast_ref::() + .unwrap() + .partitioning() + .clone(), + Partitioning::Hash( + vec![ + Arc::new(Column::new("a", 1)), + Arc::new(Column::new("b_new", 0)), + Arc::new(Column::new("d_new", 2)), + ], + 6, + ), + ); + + Ok(()) + } + + #[test] + fn test_sort_after_projection() -> Result<()> { + let csv = create_simple_csv_exec(); + let sort_req: Arc = Arc::new(SortExec::new( + vec![ + PhysicalSortExpr { + expr: Arc::new(Column::new("b", 1)), + options: SortOptions::default(), + }, + PhysicalSortExpr { + expr: Arc::new(BinaryExpr::new( + Arc::new(Column::new("c", 2)), + Operator::Plus, + Arc::new(Column::new("a", 0)), + )), + options: SortOptions::default(), + }, + ], + csv.clone(), + )); + let projection: Arc = Arc::new(ProjectionExec::try_new( + vec![ + (Arc::new(Column::new("c", 2)), "c".to_string()), + (Arc::new(Column::new("a", 0)), "new_a".to_string()), + (Arc::new(Column::new("b", 1)), "b".to_string()), + ], + sort_req.clone(), + )?); + + let initial = get_plan_string(&projection); + let expected_initial = [ + "ProjectionExec: expr=[c@2 as c, a@0 as new_a, b@1 as b]", + " SortExec: expr=[b@1 ASC,c@2 + a@0 ASC]", + " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false" + ]; + assert_eq!(initial, expected_initial); + + let after_optimize = + ProjectionPushdown::new().optimize(projection, &ConfigOptions::new())?; + + let expected = [ + "SortExec: expr=[b@2 ASC,c@0 + new_a@1 ASC]", + " ProjectionExec: expr=[c@2 as c, a@0 as new_a, b@1 as b]", + " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, 
e], has_header=false" + ]; + assert_eq!(get_plan_string(&after_optimize), expected); + + Ok(()) + } + + #[test] + fn test_sort_preserving_after_projection() -> Result<()> { + let csv = create_simple_csv_exec(); + let sort_req: Arc = Arc::new(SortPreservingMergeExec::new( + vec![ + PhysicalSortExpr { + expr: Arc::new(Column::new("b", 1)), + options: SortOptions::default(), + }, + PhysicalSortExpr { + expr: Arc::new(BinaryExpr::new( + Arc::new(Column::new("c", 2)), + Operator::Plus, + Arc::new(Column::new("a", 0)), + )), + options: SortOptions::default(), + }, + ], + csv.clone(), + )); + let projection: Arc = Arc::new(ProjectionExec::try_new( + vec![ + (Arc::new(Column::new("c", 2)), "c".to_string()), + (Arc::new(Column::new("a", 0)), "new_a".to_string()), + (Arc::new(Column::new("b", 1)), "b".to_string()), + ], + sort_req.clone(), + )?); + + let initial = get_plan_string(&projection); + let expected_initial = [ + "ProjectionExec: expr=[c@2 as c, a@0 as new_a, b@1 as b]", + " SortPreservingMergeExec: [b@1 ASC,c@2 + a@0 ASC]", + " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false" + ]; + assert_eq!(initial, expected_initial); + + let after_optimize = + ProjectionPushdown::new().optimize(projection, &ConfigOptions::new())?; + + let expected = [ + "SortPreservingMergeExec: [b@2 ASC,c@0 + new_a@1 ASC]", + " ProjectionExec: expr=[c@2 as c, a@0 as new_a, b@1 as b]", + " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false" + ]; + assert_eq!(get_plan_string(&after_optimize), expected); + + Ok(()) + } + + #[test] + fn test_union_after_projection() -> Result<()> { + let csv = create_simple_csv_exec(); + let union: Arc = + Arc::new(UnionExec::new(vec![csv.clone(), csv.clone(), csv])); + let projection: Arc = Arc::new(ProjectionExec::try_new( + vec![ + (Arc::new(Column::new("c", 2)), "c".to_string()), + (Arc::new(Column::new("a", 0)), "new_a".to_string()), + (Arc::new(Column::new("b", 1)), "b".to_string()), + ], + union.clone(), + )?); + + let initial = get_plan_string(&projection); + let expected_initial = [ + "ProjectionExec: expr=[c@2 as c, a@0 as new_a, b@1 as b]", + " UnionExec", + " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false", + " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false", + " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false" + ]; + assert_eq!(initial, expected_initial); + + let after_optimize = + ProjectionPushdown::new().optimize(projection, &ConfigOptions::new())?; + + let expected = [ + "UnionExec", + " ProjectionExec: expr=[c@2 as c, a@0 as new_a, b@1 as b]", + " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false", + " ProjectionExec: expr=[c@2 as c, a@0 as new_a, b@1 as b]", + " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false", + " ProjectionExec: expr=[c@2 as c, a@0 as new_a, b@1 as b]", + " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false" + ]; + assert_eq!(get_plan_string(&after_optimize), expected); + + Ok(()) + } +} diff --git a/datafusion/core/src/physical_optimizer/pruning.rs b/datafusion/core/src/physical_optimizer/pruning.rs index de508327fade7..06cfc72824688 100644 --- a/datafusion/core/src/physical_optimizer/pruning.rs +++ b/datafusion/core/src/physical_optimizer/pruning.rs @@ -35,12 +35,13 @@ use arrow::{ datatypes::{DataType, Field, Schema, SchemaRef}, record_batch::RecordBatch, }; -use 
datafusion_common::{downcast_value, plan_datafusion_err, ScalarValue}; +use arrow_array::cast::AsArray; use datafusion_common::{ internal_err, plan_err, tree_node::{Transformed, TreeNode}, }; -use datafusion_physical_expr::utils::collect_columns; +use datafusion_common::{plan_datafusion_err, ScalarValue}; +use datafusion_physical_expr::utils::{collect_columns, Guarantee, LiteralGuarantee}; use datafusion_physical_expr::{expressions as phys_expr, PhysicalExprRef}; use log::trace; @@ -66,43 +67,81 @@ use log::trace; /// min_values("X") -> None /// ``` pub trait PruningStatistics { - /// return the minimum values for the named column, if known. - /// Note: the returned array must contain `num_containers()` rows + /// Return the minimum values for the named column, if known. + /// + /// If the minimum value for a particular container is not known, the + /// returned array should have `null` in that row. If the minimum value is + /// not known for any row, return `None`. + /// + /// Note: the returned array must contain [`Self::num_containers`] rows fn min_values(&self, column: &Column) -> Option<ArrayRef>; - /// return the maximum values for the named column, if known. - /// Note: the returned array must contain `num_containers()` rows. + /// Return the maximum values for the named column, if known. + /// + /// See [`Self::min_values`] for when to return `None` and null values. + /// + /// Note: the returned array must contain [`Self::num_containers`] rows fn max_values(&self, column: &Column) -> Option<ArrayRef>; - /// return the number of containers (e.g. row groups) being - /// pruned with these statistics + /// Return the number of containers (e.g. row groups) being + /// pruned with these statistics (the number of rows in each returned array) fn num_containers(&self) -> usize; - /// return the number of null values for the named column as an + /// Return the number of null values for the named column as an /// `Option`. /// - /// Note: the returned array must contain `num_containers()` rows. + /// See [`Self::min_values`] for when to return `None` and null values. + /// + /// Note: the returned array must contain [`Self::num_containers`] rows fn null_counts(&self, column: &Column) -> Option<ArrayRef>; + + /// Returns an array where each row represents information known about + /// the `values` contained in a column. + /// + /// This API is designed to be used along with [`LiteralGuarantee`] to prove + /// that predicates can not possibly evaluate to `true` and thus prune + /// containers. For example, Parquet Bloom Filters can prove that values are + /// not present. + /// + /// The returned array has one row for each container, with the following + /// meanings: + /// * `true` if the values in `column` ONLY contain values from `values` + /// * `false` if the values in `column` are NOT ANY of `values` + /// * `null` if neither of the above holds or is unknown. + /// + /// If these statistics can not determine column membership for any + /// container, return `None` (the default). + /// + /// Note: the returned array must contain [`Self::num_containers`] rows + fn contained( + &self, + column: &Column, + values: &HashSet<ScalarValue>, + ) -> Option<BooleanArray>; } -/// Evaluates filter expressions on statistics, rather than the actual data. If -/// no rows could possibly pass the filter entire containers can be "pruned" -/// (skipped), without reading any actual data, leading to significant +/// Evaluates filter expressions on statistics such as min/max values and null +/// counts, attempting to prove a "container" (e.g. 
Parquet Row Group) can be +/// skipped without reading the actual data, potentially leading to significant /// performance improvements. /// -/// [`PruningPredicate`]s are used to prune (avoid scanning) Parquet Row Groups +/// For example, [`PruningPredicate`]s are used to prune Parquet Row Groups /// based on the min/max values found in the Parquet metadata. If the /// `PruningPredicate` can guarantee that no rows in the Row Group match the /// filter, the entire Row Group is skipped during query execution. /// -/// Note that this API is designed to be general, as it works: +/// The `PruningPredicate` API is general, allowing it to be used for pruning +/// other types of containers (e.g. files) based on statistics that may be +/// known from external catalogs (e.g. Delta Lake) or other sources. Thus it +/// supports: /// /// 1. Arbitrary expressions (including user defined functions) /// -/// 2. Anything that implements the [`PruningStatistics`] trait, not just -/// Parquet metadata, allowing it to be used by other systems to prune entities -/// (e.g. entire files) if the statistics are known via some other source, such -/// as a catalog. +/// 2. Vectorized evaluation (provide more than one set of statistics at a time) +/// so it is suitable for pruning 1000s of containers. +/// +/// 3. Anything that implements the [`PruningStatistics`] trait, not just +/// Parquet metadata. /// /// # Example /// @@ -122,17 +161,23 @@ pub trait PruningStatistics { /// B: true (rows might match x = 5) /// C: true (rows might match x = 5) /// ``` +/// /// See [`PruningPredicate::try_new`] and [`PruningPredicate::prune`] for more information. #[derive(Debug, Clone)] pub struct PruningPredicate { /// The input schema against which the predicate will be evaluated schema: SchemaRef, - /// Actual pruning predicate (rewritten in terms of column min/max statistics) + /// A min/max pruning predicate (rewritten in terms of column min/max + /// values, which are supplied by statistics) predicate_expr: Arc<dyn PhysicalExpr>, - /// The statistics required to evaluate this predicate - required_columns: RequiredStatColumns, - /// Original physical predicate from which this predicate expr is derived (required for serialization) + /// Description of which statistics are required to evaluate `predicate_expr` + required_columns: RequiredColumns, + /// Original physical predicate from which this predicate expr is derived + /// (required for serialization) orig_expr: Arc<dyn PhysicalExpr>, + /// [`LiteralGuarantee`]s that are used to try and prove a predicate can not +/// possibly evaluate to `true`. 
+ literal_guarantees: Vec<LiteralGuarantee>, } impl PruningPredicate { @@ -157,14 +202,18 @@ impl PruningPredicate { /// `(column_min / 2) <= 4 && 4 <= (column_max / 2))` pub fn try_new(expr: Arc<dyn PhysicalExpr>, schema: SchemaRef) -> Result<Self> { // build predicate expression once - let mut required_columns = RequiredStatColumns::new(); + let mut required_columns = RequiredColumns::new(); let predicate_expr = build_predicate_expression(&expr, schema.as_ref(), &mut required_columns); + + let literal_guarantees = LiteralGuarantee::analyze(&expr); + Ok(Self { schema, predicate_expr, required_columns, orig_expr: expr, + literal_guarantees, }) } @@ -183,40 +232,52 @@ impl PruningPredicate { /// /// [`ExprSimplifier`]: crate::optimizer::simplify_expressions::ExprSimplifier pub fn prune<S: PruningStatistics>(&self, statistics: &S) -> Result<Vec<bool>> { + let mut builder = BoolVecBuilder::new(statistics.num_containers()); + + // Try to prove the predicate can't be true for the containers based on + // literal guarantees + for literal_guarantee in &self.literal_guarantees { + let LiteralGuarantee { + column, + guarantee, + literals, + } = literal_guarantee; + if let Some(results) = statistics.contained(column, literals) { + match guarantee { + // `In` means the values in the column must be one of the + // values in the set for the predicate to evaluate to true. + // If `contained` returns false, that means the column is + // not any of the values so we can prune the container + Guarantee::In => builder.combine_array(&results), + // `NotIn` means the values in the column must not be + // any of the values in the set for the predicate to + // evaluate to true. If contained returns true, it means the + // column only contains values from the set, so we can prune the + // container + Guarantee::NotIn => { + builder.combine_array(&arrow::compute::not(&results)?) + } + } + // if all containers are pruned (have rows that DEFINITELY DO NOT pass the predicate) + // we can return early without evaluating the rest of the predicates. + if builder.check_all_pruned() { + return Ok(builder.build()); + } + } + } + + // Next, try to prove the predicate can't be true for the containers based + // on min/max values + // build a RecordBatch that contains the min/max values in the - // appropriate statistics columns + // appropriate statistics columns for the min/max predicate let statistics_batch = build_statistics_record_batch(statistics, &self.required_columns)?; - // Evaluate the pruning predicate on that record batch. - // - // Use true when the result of evaluating a predicate - // expression on a row group is null (aka `None`). Null can - // arise when the statistics are unknown or some calculation - // in the predicate means we don't know for sure if the row - // group can be filtered out or not. To maintain correctness - // the row group must be kept and thus `true` is returned. - match self.predicate_expr.evaluate(&statistics_batch)? 
{ - ColumnarValue::Array(array) => { - let predicate_array = downcast_value!(array, BooleanArray); + // Evaluate the pruning predicate on that record batch and append any results to the builder + builder.combine_value(self.predicate_expr.evaluate(&statistics_batch)?); - Ok(predicate_array - .into_iter() - .map(|x| x.unwrap_or(true)) // None -> true per comments above - .collect::<Vec<bool>>()) - } - // result was a column - ColumnarValue::Scalar(ScalarValue::Boolean(v)) => { - let v = v.unwrap_or(true); // None -> true per comments above - Ok(vec![v; statistics.num_containers()]) - } - other => { - internal_err!( - "Unexpected result of pruning predicate evaluation. Expected Boolean array \ - or scalar but got {other:?}" - ) - } - } + Ok(builder.build()) } /// Return a reference to the input schema @@ -234,14 +295,104 @@ impl PruningPredicate { &self.predicate_expr } - /// Returns true if this pruning predicate is "always true" (aka will not prune anything) + /// Returns true if this pruning predicate can not prune anything. + /// + /// This happens if the predicate is a literal `true` and + /// literal_guarantees is empty. pub fn allways_true(&self) -> bool { - is_always_true(&self.predicate_expr) + is_always_true(&self.predicate_expr) && self.literal_guarantees.is_empty() } - pub(crate) fn required_columns(&self) -> &RequiredStatColumns { + pub(crate) fn required_columns(&self) -> &RequiredColumns { &self.required_columns } + + /// Names of the columns that are known to be / not be in a set + /// of literals (constants). These are the columns that may be passed to + /// [`PruningStatistics::contained`] during pruning. + /// + /// This is useful to avoid fetching statistics for columns that will not be + /// used in the predicate. For example, it can be used to avoid reading + /// unneeded bloom filters (a non-trivial operation). + pub fn literal_columns(&self) -> Vec<String> { + let mut seen = HashSet::new(); + self.literal_guarantees + .iter() + .map(|e| &e.column.name) + // avoid duplicates + .filter(|name| seen.insert(*name)) + .map(|s| s.to_string()) + .collect() + } +} + +/// Builds the return `Vec<bool>` for [`PruningPredicate::prune`]. +#[derive(Debug)] +struct BoolVecBuilder { + /// One element per container. Each element is + /// * `true`: if the container has rows that may pass the predicate + /// * `false`: if the container has rows that DEFINITELY DO NOT pass the predicate + inner: Vec<bool>, +} + +impl BoolVecBuilder { + /// Create a new `BoolVecBuilder` with `num_containers` elements + fn new(num_containers: usize) -> Self { + Self { + // assume by default all containers may pass the predicate + inner: vec![true; num_containers], + } + } + + /// Combines result `array` for a conjunct (e.g. `AND` clause) of a + /// predicate into the currently in progress array. + /// + /// Each `array` element is: + /// * `true`: container has rows that may pass the predicate + /// * `false`: all container rows DEFINITELY DO NOT pass the predicate + /// * `null`: container may or may not have rows that pass the predicate + fn combine_array(&mut self, array: &BooleanArray) { + assert_eq!(array.len(), self.inner.len()); + for (cur, new) in self.inner.iter_mut().zip(array.iter()) { + // `false` for this conjunct means we know for sure no rows could + // pass the predicate and thus we set the corresponding container + // location to false. 
+ if let Some(false) = new { + *cur = false; + } + } + } + + /// Combines the results in the [`ColumnarValue`] into the currently in + /// progress array, following the same rules as [`Self::combine_array`]. + /// + /// # Panics + /// If `value` is not boolean + fn combine_value(&mut self, value: ColumnarValue) { + match value { + ColumnarValue::Array(array) => { + self.combine_array(array.as_boolean()); + } + ColumnarValue::Scalar(ScalarValue::Boolean(Some(false))) => { + // False means all containers can not pass the predicate + self.inner = vec![false; self.inner.len()]; + } + _ => { + // Null or true means the rows in container may pass this + // conjunct so we can't prune any containers based on that + } + } + } + + /// Convert this builder into a Vec of bools + fn build(self) -> Vec<bool> { + self.inner + } + + /// Check whether all containers have rows that DEFINITELY DO NOT pass the predicate + fn check_all_pruned(&self) -> bool { + self.inner.iter().all(|&x| !x) + } } fn is_always_true(expr: &Arc<dyn PhysicalExpr>) -> bool { @@ -251,27 +402,31 @@ fn is_always_true(expr: &Arc<dyn PhysicalExpr>) -> bool { .unwrap_or_default() } -/// Records for which columns statistics are necessary to evaluate a -/// pruning predicate. +/// Describes which columns' statistics are necessary to evaluate a +/// [`PruningPredicate`]. +/// +/// This structure permits reading and creating the minimum number of statistics, +/// which is important since statistics may be non-trivial to read (e.g. large +/// strings or when there are 1000s of columns). /// /// Handles creating references to the min/max statistics /// for columns as well as recording which statistics are needed #[derive(Debug, Default, Clone)] -pub(crate) struct RequiredStatColumns { +pub(crate) struct RequiredColumns { /// The statistics required to evaluate this predicate: /// * The unqualified column in the input schema /// * Statistics type (e.g. Min or Max or Null_Count) /// * The field the statistics value should be placed in for - /// pruning predicate evaluation + /// pruning predicate evaluation (e.g. `min_value` or `max_value`) columns: Vec<(phys_expr::Column, StatisticsType, Field)>, } -impl RequiredStatColumns { +impl RequiredColumns { fn new() -> Self { Self::default() } - /// Returns number of unique columns. 
+ /// Returns number of unique columns pub(crate) fn n_columns(&self) -> usize { self.iter() .map(|(c, _s, _f)| c) @@ -325,11 +480,10 @@ impl RequiredStatColumns { // only add statistics column if not previously added if need_to_insert { - let stat_field = Field::new( - stat_column.name(), - field.data_type().clone(), - field.is_nullable(), - ); + // may be null if statistics are not present + let nullable = true; + let stat_field = + Field::new(stat_column.name(), field.data_type().clone(), nullable); self.columns.push((column.clone(), stat_type, stat_field)); } rewrite_column_expr(column_expr.clone(), column, &stat_column) @@ -372,7 +526,7 @@ impl RequiredStatColumns { } } -impl From> for RequiredStatColumns { +impl From> for RequiredColumns { fn from(columns: Vec<(phys_expr::Column, StatisticsType, Field)>) -> Self { Self { columns } } @@ -405,7 +559,7 @@ impl From> for RequiredStatColum /// ``` fn build_statistics_record_batch( statistics: &S, - required_columns: &RequiredStatColumns, + required_columns: &RequiredColumns, ) -> Result { let mut fields = Vec::::new(); let mut arrays = Vec::::new(); @@ -461,7 +615,7 @@ struct PruningExpressionBuilder<'a> { op: Operator, scalar_expr: Arc, field: &'a Field, - required_columns: &'a mut RequiredStatColumns, + required_columns: &'a mut RequiredColumns, } impl<'a> PruningExpressionBuilder<'a> { @@ -470,7 +624,7 @@ impl<'a> PruningExpressionBuilder<'a> { right: &'a Arc, op: Operator, schema: &'a Schema, - required_columns: &'a mut RequiredStatColumns, + required_columns: &'a mut RequiredColumns, ) -> Result { // find column name; input could be a more complicated expression let left_columns = collect_columns(left); @@ -685,7 +839,7 @@ fn reverse_operator(op: Operator) -> Result { fn build_single_column_expr( column: &phys_expr::Column, schema: &Schema, - required_columns: &mut RequiredStatColumns, + required_columns: &mut RequiredColumns, is_not: bool, // if true, treat as !col ) -> Option> { let field = schema.field_with_name(column.name()).ok()?; @@ -726,7 +880,7 @@ fn build_single_column_expr( fn build_is_null_column_expr( expr: &Arc, schema: &Schema, - required_columns: &mut RequiredStatColumns, + required_columns: &mut RequiredColumns, ) -> Option> { if let Some(col) = expr.as_any().downcast_ref::() { let field = schema.field_with_name(col.name()).ok()?; @@ -756,7 +910,7 @@ fn build_is_null_column_expr( fn build_predicate_expression( expr: &Arc, schema: &Schema, - required_columns: &mut RequiredStatColumns, + required_columns: &mut RequiredColumns, ) -> Arc { // Returned for unsupported expressions. Such expressions are // converted to TRUE. @@ -965,7 +1119,7 @@ mod tests { use std::collections::HashMap; use std::ops::{Not, Rem}; - #[derive(Debug)] + #[derive(Debug, Default)] /// Mock statistic provider for tests /// /// Each row represents the statistics for a "container" (which @@ -974,95 +1128,142 @@ mod tests { /// /// Note All `ArrayRefs` must be the same size. struct ContainerStats { - min: ArrayRef, - max: ArrayRef, + min: Option, + max: Option, /// Optional values null_counts: Option, + /// Optional known values (e.g. 
mimic a bloom filter) + /// (value, contained) + /// If present, all BooleanArrays must be the same size as min/max + contained: Vec<(HashSet, BooleanArray)>, } impl ContainerStats { + fn new() -> Self { + Default::default() + } fn new_decimal128( min: impl IntoIterator>, max: impl IntoIterator>, precision: u8, scale: i8, ) -> Self { - Self { - min: Arc::new( + Self::new() + .with_min(Arc::new( min.into_iter() .collect::() .with_precision_and_scale(precision, scale) .unwrap(), - ), - max: Arc::new( + )) + .with_max(Arc::new( max.into_iter() .collect::() .with_precision_and_scale(precision, scale) .unwrap(), - ), - null_counts: None, - } + )) } fn new_i64( min: impl IntoIterator>, max: impl IntoIterator>, ) -> Self { - Self { - min: Arc::new(min.into_iter().collect::()), - max: Arc::new(max.into_iter().collect::()), - null_counts: None, - } + Self::new() + .with_min(Arc::new(min.into_iter().collect::())) + .with_max(Arc::new(max.into_iter().collect::())) } fn new_i32( min: impl IntoIterator>, max: impl IntoIterator>, ) -> Self { - Self { - min: Arc::new(min.into_iter().collect::()), - max: Arc::new(max.into_iter().collect::()), - null_counts: None, - } + Self::new() + .with_min(Arc::new(min.into_iter().collect::())) + .with_max(Arc::new(max.into_iter().collect::())) } fn new_utf8<'a>( min: impl IntoIterator>, max: impl IntoIterator>, ) -> Self { - Self { - min: Arc::new(min.into_iter().collect::()), - max: Arc::new(max.into_iter().collect::()), - null_counts: None, - } + Self::new() + .with_min(Arc::new(min.into_iter().collect::())) + .with_max(Arc::new(max.into_iter().collect::())) } fn new_bool( min: impl IntoIterator>, max: impl IntoIterator>, ) -> Self { - Self { - min: Arc::new(min.into_iter().collect::()), - max: Arc::new(max.into_iter().collect::()), - null_counts: None, - } + Self::new() + .with_min(Arc::new(min.into_iter().collect::())) + .with_max(Arc::new(max.into_iter().collect::())) } fn min(&self) -> Option { - Some(self.min.clone()) + self.min.clone() } fn max(&self) -> Option { - Some(self.max.clone()) + self.max.clone() } fn null_counts(&self) -> Option { self.null_counts.clone() } + /// return an iterator over all arrays in this statistics + fn arrays(&self) -> Vec { + let contained_arrays = self + .contained + .iter() + .map(|(_values, contained)| Arc::new(contained.clone()) as ArrayRef); + + [ + self.min.as_ref().cloned(), + self.max.as_ref().cloned(), + self.null_counts.as_ref().cloned(), + ] + .into_iter() + .flatten() + .chain(contained_arrays) + .collect() + } + + /// Returns the number of containers represented by this statistics This + /// picks the length of the first array as all arrays must have the same + /// length (which is verified by `assert_invariants`). fn len(&self) -> usize { - assert_eq!(self.min.len(), self.max.len()); - self.min.len() + // pick the first non zero length + self.arrays().iter().map(|a| a.len()).next().unwrap_or(0) + } + + /// Ensure that the lengths of all arrays are consistent + fn assert_invariants(&self) { + let mut prev_len = None; + + for len in self.arrays().iter().map(|a| a.len()) { + // Get a length, if we don't already have one + match prev_len { + None => { + prev_len = Some(len); + } + Some(prev_len) => { + assert_eq!(prev_len, len); + } + } + } + } + + /// Add min values + fn with_min(mut self, min: ArrayRef) -> Self { + self.min = Some(min); + self + } + + /// Add max values + fn with_max(mut self, max: ArrayRef) -> Self { + self.max = Some(max); + self } /// Add null counts. 
There must be the same number of null counts as @@ -1071,14 +1272,36 @@ mod tests { mut self, counts: impl IntoIterator>, ) -> Self { - // take stats out and update them let null_counts: ArrayRef = Arc::new(counts.into_iter().collect::()); - assert_eq!(null_counts.len(), self.len()); + self.assert_invariants(); self.null_counts = Some(null_counts); self } + + /// Add contained information. + pub fn with_contained( + mut self, + values: impl IntoIterator, + contained: impl IntoIterator>, + ) -> Self { + let contained: BooleanArray = contained.into_iter().collect(); + let values: HashSet<_> = values.into_iter().collect(); + + self.contained.push((values, contained)); + self.assert_invariants(); + self + } + + /// get any contained information for the specified values + fn contained(&self, find_values: &HashSet) -> Option { + // find the one with the matching values + self.contained + .iter() + .find(|(values, _contained)| values == find_values) + .map(|(_values, contained)| contained.clone()) + } } #[derive(Debug, Default)] @@ -1116,13 +1339,34 @@ mod tests { let container_stats = self .stats .remove(&col) - .expect("Can not find stats for column") + .unwrap_or_default() .with_null_counts(counts); // put stats back in self.stats.insert(col, container_stats); self } + + /// Add contained information for the specified columm. + fn with_contained( + mut self, + name: impl Into, + values: impl IntoIterator, + contained: impl IntoIterator>, + ) -> Self { + let col = Column::from_name(name.into()); + + // take stats out and update them + let container_stats = self + .stats + .remove(&col) + .unwrap_or_default() + .with_contained(values, contained); + + // put stats back in + self.stats.insert(col, container_stats); + self + } } impl PruningStatistics for TestStatistics { @@ -1154,6 +1398,16 @@ mod tests { .map(|container_stats| container_stats.null_counts()) .unwrap_or(None) } + + fn contained( + &self, + column: &Column, + values: &HashSet, + ) -> Option { + self.stats + .get(column) + .and_then(|container_stats| container_stats.contained(values)) + } } /// Returns the specified min/max container values @@ -1179,12 +1433,20 @@ mod tests { fn null_counts(&self, _column: &Column) -> Option { None } + + fn contained( + &self, + _column: &Column, + _values: &HashSet, + ) -> Option { + None + } } #[test] fn test_build_statistics_record_batch() { // Request a record batch with of s1_min, s2_max, s3_max, s3_min - let required_columns = RequiredStatColumns::from(vec![ + let required_columns = RequiredColumns::from(vec![ // min of original column s1, named s1_min ( phys_expr::Column::new("s1", 1), @@ -1256,7 +1518,7 @@ mod tests { // which is what Parquet does // Request a record batch with of s1_min as a timestamp - let required_columns = RequiredStatColumns::from(vec![( + let required_columns = RequiredColumns::from(vec![( phys_expr::Column::new("s3", 3), StatisticsType::Min, Field::new( @@ -1288,7 +1550,7 @@ mod tests { #[test] fn test_build_statistics_no_required_stats() { - let required_columns = RequiredStatColumns::new(); + let required_columns = RequiredColumns::new(); let statistics = OneContainerStats { min_values: Some(Arc::new(Int64Array::from(vec![Some(10)]))), @@ -1306,7 +1568,7 @@ mod tests { // Test requesting a Utf8 column when the stats return some other type // Request a record batch with of s1_min as a timestamp - let required_columns = RequiredStatColumns::from(vec![( + let required_columns = RequiredColumns::from(vec![( phys_expr::Column::new("s3", 3), StatisticsType::Min, 
Field::new("s1_min", DataType::Utf8, true), @@ -1335,7 +1597,7 @@ mod tests { #[test] fn test_build_statistics_inconsistent_length() { // return an inconsistent length to the actual statistics arrays - let required_columns = RequiredStatColumns::from(vec![( + let required_columns = RequiredColumns::from(vec![( phys_expr::Column::new("s1", 3), StatisticsType::Min, Field::new("s1_min", DataType::Int64, true), @@ -1366,20 +1628,14 @@ mod tests { // test column on the left let expr = col("c1").eq(lit(1)); - let predicate_expr = test_build_predicate_expression( - &expr, - &schema, - &mut RequiredStatColumns::new(), - ); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); // test column on the right let expr = lit(1).eq(col("c1")); - let predicate_expr = test_build_predicate_expression( - &expr, - &schema, - &mut RequiredStatColumns::new(), - ); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); Ok(()) @@ -1392,20 +1648,14 @@ mod tests { // test column on the left let expr = col("c1").not_eq(lit(1)); - let predicate_expr = test_build_predicate_expression( - &expr, - &schema, - &mut RequiredStatColumns::new(), - ); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); // test column on the right let expr = lit(1).not_eq(col("c1")); - let predicate_expr = test_build_predicate_expression( - &expr, - &schema, - &mut RequiredStatColumns::new(), - ); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); Ok(()) @@ -1418,20 +1668,14 @@ mod tests { // test column on the left let expr = col("c1").gt(lit(1)); - let predicate_expr = test_build_predicate_expression( - &expr, - &schema, - &mut RequiredStatColumns::new(), - ); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); // test column on the right let expr = lit(1).lt(col("c1")); - let predicate_expr = test_build_predicate_expression( - &expr, - &schema, - &mut RequiredStatColumns::new(), - ); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); Ok(()) @@ -1444,19 +1688,13 @@ mod tests { // test column on the left let expr = col("c1").gt_eq(lit(1)); - let predicate_expr = test_build_predicate_expression( - &expr, - &schema, - &mut RequiredStatColumns::new(), - ); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); // test column on the right let expr = lit(1).lt_eq(col("c1")); - let predicate_expr = test_build_predicate_expression( - &expr, - &schema, - &mut RequiredStatColumns::new(), - ); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); Ok(()) @@ -1469,20 +1707,14 @@ mod tests { // test column on the left let expr = col("c1").lt(lit(1)); - let predicate_expr = test_build_predicate_expression( - &expr, - &schema, - &mut RequiredStatColumns::new(), - ); + let predicate_expr = + test_build_predicate_expression(&expr, 
&schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); // test column on the right let expr = lit(1).gt(col("c1")); - let predicate_expr = test_build_predicate_expression( - &expr, - &schema, - &mut RequiredStatColumns::new(), - ); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); Ok(()) @@ -1495,19 +1727,13 @@ mod tests { // test column on the left let expr = col("c1").lt_eq(lit(1)); - let predicate_expr = test_build_predicate_expression( - &expr, - &schema, - &mut RequiredStatColumns::new(), - ); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); // test column on the right let expr = lit(1).gt_eq(col("c1")); - let predicate_expr = test_build_predicate_expression( - &expr, - &schema, - &mut RequiredStatColumns::new(), - ); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); Ok(()) @@ -1523,11 +1749,8 @@ mod tests { // test AND operator joining supported c1 < 1 expression and unsupported c2 > c3 expression let expr = col("c1").lt(lit(1)).and(col("c2").lt(col("c3"))); let expected_expr = "c1_min@0 < 1"; - let predicate_expr = test_build_predicate_expression( - &expr, - &schema, - &mut RequiredStatColumns::new(), - ); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); Ok(()) @@ -1542,11 +1765,8 @@ mod tests { // test OR operator joining supported c1 < 1 expression and unsupported c2 % 2 = 0 expression let expr = col("c1").lt(lit(1)).or(col("c2").rem(lit(2)).eq(lit(0))); let expected_expr = "true"; - let predicate_expr = test_build_predicate_expression( - &expr, - &schema, - &mut RequiredStatColumns::new(), - ); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); Ok(()) @@ -1558,11 +1778,8 @@ mod tests { let expected_expr = "true"; let expr = col("c1").not(); - let predicate_expr = test_build_predicate_expression( - &expr, - &schema, - &mut RequiredStatColumns::new(), - ); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); Ok(()) @@ -1574,11 +1791,8 @@ mod tests { let expected_expr = "NOT c1_min@0 AND c1_max@1"; let expr = col("c1").not(); - let predicate_expr = test_build_predicate_expression( - &expr, - &schema, - &mut RequiredStatColumns::new(), - ); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); Ok(()) @@ -1590,11 +1804,8 @@ mod tests { let expected_expr = "c1_min@0 OR c1_max@1"; let expr = col("c1"); - let predicate_expr = test_build_predicate_expression( - &expr, - &schema, - &mut RequiredStatColumns::new(), - ); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); Ok(()) @@ -1608,11 +1819,8 @@ mod tests { // DF doesn't support arithmetic on boolean columns so // this predicate will error when evaluated let expr = col("c1").lt(lit(true)); - let predicate_expr = test_build_predicate_expression( - &expr, - 
&schema, - &mut RequiredStatColumns::new(), - ); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); Ok(()) @@ -1624,7 +1832,7 @@ mod tests { Field::new("c1", DataType::Int32, false), Field::new("c2", DataType::Int32, false), ]); - let mut required_columns = RequiredStatColumns::new(); + let mut required_columns = RequiredColumns::new(); // c1 < 1 and (c2 = 2 or c2 = 3) let expr = col("c1") .lt(lit(1)) @@ -1640,7 +1848,7 @@ mod tests { ( phys_expr::Column::new("c1", 0), StatisticsType::Min, - c1_min_field + c1_min_field.with_nullable(true) // could be nullable if stats are not present ) ); // c2 = 2 should add c2_min and c2_max @@ -1650,7 +1858,7 @@ mod tests { ( phys_expr::Column::new("c2", 1), StatisticsType::Min, - c2_min_field + c2_min_field.with_nullable(true) // could be nullable if stats are not present ) ); let c2_max_field = Field::new("c2_max", DataType::Int32, false); @@ -1659,7 +1867,7 @@ mod tests { ( phys_expr::Column::new("c2", 1), StatisticsType::Max, - c2_max_field + c2_max_field.with_nullable(true) // could be nullable if stats are not present ) ); // c2 = 3 shouldn't add any new statistics fields @@ -1681,11 +1889,8 @@ mod tests { false, )); let expected_expr = "c1_min@0 <= 1 AND 1 <= c1_max@1 OR c1_min@0 <= 2 AND 2 <= c1_max@1 OR c1_min@0 <= 3 AND 3 <= c1_max@1"; - let predicate_expr = test_build_predicate_expression( - &expr, - &schema, - &mut RequiredStatColumns::new(), - ); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); Ok(()) @@ -1700,11 +1905,8 @@ mod tests { // test c1 in() let expr = Expr::InList(InList::new(Box::new(col("c1")), vec![], false)); let expected_expr = "true"; - let predicate_expr = test_build_predicate_expression( - &expr, - &schema, - &mut RequiredStatColumns::new(), - ); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); Ok(()) @@ -1725,11 +1927,8 @@ mod tests { let expected_expr = "(c1_min@0 != 1 OR 1 != c1_max@1) \ AND (c1_min@0 != 2 OR 2 != c1_max@1) \ AND (c1_min@0 != 3 OR 3 != c1_max@1)"; - let predicate_expr = test_build_predicate_expression( - &expr, - &schema, - &mut RequiredStatColumns::new(), - ); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); Ok(()) @@ -1743,20 +1942,14 @@ mod tests { // test column on the left let expr = cast(col("c1"), DataType::Int64).eq(lit(ScalarValue::Int64(Some(1)))); - let predicate_expr = test_build_predicate_expression( - &expr, - &schema, - &mut RequiredStatColumns::new(), - ); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); // test column on the right let expr = lit(ScalarValue::Int64(Some(1))).eq(cast(col("c1"), DataType::Int64)); - let predicate_expr = test_build_predicate_expression( - &expr, - &schema, - &mut RequiredStatColumns::new(), - ); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); let expected_expr = "TRY_CAST(c1_max@0 AS Int64) > 1"; @@ -1764,21 +1957,15 @@ mod tests { // test column on the left let expr = try_cast(col("c1"), 
DataType::Int64).gt(lit(ScalarValue::Int64(Some(1)))); - let predicate_expr = test_build_predicate_expression( - &expr, - &schema, - &mut RequiredStatColumns::new(), - ); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); // test column on the right let expr = lit(ScalarValue::Int64(Some(1))).lt(try_cast(col("c1"), DataType::Int64)); - let predicate_expr = test_build_predicate_expression( - &expr, - &schema, - &mut RequiredStatColumns::new(), - ); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); Ok(()) @@ -1798,11 +1985,8 @@ mod tests { false, )); let expected_expr = "CAST(c1_min@0 AS Int64) <= 1 AND 1 <= CAST(c1_max@1 AS Int64) OR CAST(c1_min@0 AS Int64) <= 2 AND 2 <= CAST(c1_max@1 AS Int64) OR CAST(c1_min@0 AS Int64) <= 3 AND 3 <= CAST(c1_max@1 AS Int64)"; - let predicate_expr = test_build_predicate_expression( - &expr, - &schema, - &mut RequiredStatColumns::new(), - ); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); let expr = Expr::InList(InList::new( @@ -1818,11 +2002,8 @@ mod tests { "(CAST(c1_min@0 AS Int64) != 1 OR 1 != CAST(c1_max@1 AS Int64)) \ AND (CAST(c1_min@0 AS Int64) != 2 OR 2 != CAST(c1_max@1 AS Int64)) \ AND (CAST(c1_min@0 AS Int64) != 3 OR 3 != CAST(c1_max@1 AS Int64))"; - let predicate_expr = test_build_predicate_expression( - &expr, - &schema, - &mut RequiredStatColumns::new(), - ); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); Ok(()) @@ -2465,10 +2646,464 @@ mod tests { // TODO: add other negative test for other case and op } + #[test] + fn prune_with_contained_one_column() { + let schema = Arc::new(Schema::new(vec![Field::new("s1", DataType::Utf8, true)])); + + // Model having information like a bloom filter for s1 + let statistics = TestStatistics::new() + .with_contained( + "s1", + [ScalarValue::from("foo")], + [ + // container 0 known to only contain "foo"", + Some(true), + // container 1 known to not contain "foo" + Some(false), + // container 2 unknown about "foo" + None, + // container 3 known to only contain "foo" + Some(true), + // container 4 known to not contain "foo" + Some(false), + // container 5 unknown about "foo" + None, + // container 6 known to only contain "foo" + Some(true), + // container 7 known to not contain "foo" + Some(false), + // container 8 unknown about "foo" + None, + ], + ) + .with_contained( + "s1", + [ScalarValue::from("bar")], + [ + // containers 0,1,2 known to only contain "bar" + Some(true), + Some(true), + Some(true), + // container 3,4,5 known to not contain "bar" + Some(false), + Some(false), + Some(false), + // container 6,7,8 unknown about "bar" + None, + None, + None, + ], + ) + .with_contained( + // the way the tests are setup, this data is + // consulted if the "foo" and "bar" are being checked at the same time + "s1", + [ScalarValue::from("foo"), ScalarValue::from("bar")], + [ + // container 0,1,2 unknown about ("foo, "bar") + None, + None, + None, + // container 3,4,5 known to contain only either "foo" and "bar" + Some(true), + Some(true), + Some(true), + // container 6,7,8 known to contain neither "foo" and "bar" + Some(false), + Some(false), + Some(false), + ], + ); + + // s1 = 'foo' + 
prune_with_expr( + col("s1").eq(lit("foo")), + &schema, + &statistics, + // rule out containers ('false) where we know foo is not present + vec![true, false, true, true, false, true, true, false, true], + ); + + // s1 = 'bar' + prune_with_expr( + col("s1").eq(lit("bar")), + &schema, + &statistics, + // rule out containers where we know bar is not present + vec![true, true, true, false, false, false, true, true, true], + ); + + // s1 = 'baz' (unknown value) + prune_with_expr( + col("s1").eq(lit("baz")), + &schema, + &statistics, + // can't rule out anything + vec![true, true, true, true, true, true, true, true, true], + ); + + // s1 = 'foo' AND s1 = 'bar' + prune_with_expr( + col("s1").eq(lit("foo")).and(col("s1").eq(lit("bar"))), + &schema, + &statistics, + // logically this predicate can't possibly be true (the column can't + // take on both values) but we could rule it out if the stats tell + // us that both values are not present + vec![true, true, true, true, true, true, true, true, true], + ); + + // s1 = 'foo' OR s1 = 'bar' + prune_with_expr( + col("s1").eq(lit("foo")).or(col("s1").eq(lit("bar"))), + &schema, + &statistics, + // can rule out containers that we know contain neither foo nor bar + vec![true, true, true, true, true, true, false, false, false], + ); + + // s1 = 'foo' OR s1 = 'baz' + prune_with_expr( + col("s1").eq(lit("foo")).or(col("s1").eq(lit("baz"))), + &schema, + &statistics, + // can't rule out anything container + vec![true, true, true, true, true, true, true, true, true], + ); + + // s1 = 'foo' OR s1 = 'bar' OR s1 = 'baz' + prune_with_expr( + col("s1") + .eq(lit("foo")) + .or(col("s1").eq(lit("bar"))) + .or(col("s1").eq(lit("baz"))), + &schema, + &statistics, + // can rule out any containers based on knowledge of s1 and `foo`, + // `bar` and (`foo`, `bar`) + vec![true, true, true, true, true, true, true, true, true], + ); + + // s1 != foo + prune_with_expr( + col("s1").not_eq(lit("foo")), + &schema, + &statistics, + // rule out containers we know for sure only contain foo + vec![false, true, true, false, true, true, false, true, true], + ); + + // s1 != bar + prune_with_expr( + col("s1").not_eq(lit("bar")), + &schema, + &statistics, + // rule out when we know for sure s1 has the value bar + vec![false, false, false, true, true, true, true, true, true], + ); + + // s1 != foo AND s1 != bar + prune_with_expr( + col("s1") + .not_eq(lit("foo")) + .and(col("s1").not_eq(lit("bar"))), + &schema, + &statistics, + // can rule out any container where we know s1 does not have either 'foo' or 'bar' + vec![true, true, true, false, false, false, true, true, true], + ); + + // s1 != foo AND s1 != bar AND s1 != baz + prune_with_expr( + col("s1") + .not_eq(lit("foo")) + .and(col("s1").not_eq(lit("bar"))) + .and(col("s1").not_eq(lit("baz"))), + &schema, + &statistics, + // can't rule out any container based on knowledge of s1,s2 + vec![true, true, true, true, true, true, true, true, true], + ); + + // s1 != foo OR s1 != bar + prune_with_expr( + col("s1") + .not_eq(lit("foo")) + .or(col("s1").not_eq(lit("bar"))), + &schema, + &statistics, + // cant' rule out anything based on contains information + vec![true, true, true, true, true, true, true, true, true], + ); + + // s1 != foo OR s1 != bar OR s1 != baz + prune_with_expr( + col("s1") + .not_eq(lit("foo")) + .or(col("s1").not_eq(lit("bar"))) + .or(col("s1").not_eq(lit("baz"))), + &schema, + &statistics, + // cant' rule out anything based on contains information + vec![true, true, true, true, true, true, true, true, true], + ); + } + 
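    // Illustrative sketch (editorial, not part of this patch): the tests above drive
    // the new `PruningStatistics::contained` API through `TestStatistics`. A minimal
    // implementation of that API might look roughly like the type below, which models
    // a single container whose exact set of values for one column is known (a Bloom
    // filter would only approximate the "definitely not present" case). The name
    // `KnownValuesStats` and its fields are hypothetical, and the sketch assumes the
    // items already in scope in this module (`HashSet`, `ScalarValue`, `Column`,
    // `ArrayRef`, `BooleanArray`, `PruningStatistics`).
    #[derive(Debug)]
    struct KnownValuesStats {
        column_name: String,
        // the complete set of values known to appear in the single container
        known_values: HashSet<ScalarValue>,
    }

    impl PruningStatistics for KnownValuesStats {
        fn min_values(&self, _column: &Column) -> Option<ArrayRef> {
            None // no min/max statistics available in this sketch
        }

        fn max_values(&self, _column: &Column) -> Option<ArrayRef> {
            None
        }

        fn num_containers(&self) -> usize {
            1
        }

        fn null_counts(&self, _column: &Column) -> Option<ArrayRef> {
            None
        }

        fn contained(
            &self,
            column: &Column,
            values: &HashSet<ScalarValue>,
        ) -> Option<BooleanArray> {
            if column.name != self.column_name {
                return None; // no membership information for other columns
            }
            // One row per container (here: exactly one), with the meaning
            // required by the trait:
            //   Some(true)  => the column ONLY contains values from `values`
            //   Some(false) => the column contains NONE of `values`
            //   None (null) => neither can be proven
            let answer = if self.known_values.is_subset(values) {
                Some(true)
            } else if self.known_values.is_disjoint(values) {
                Some(false)
            } else {
                None
            };
            Some(BooleanArray::from(vec![answer]))
        }
    }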
+ #[test] + fn prune_with_contained_two_columns() { + let schema = Arc::new(Schema::new(vec![ + Field::new("s1", DataType::Utf8, true), + Field::new("s2", DataType::Utf8, true), + ])); + + // Model having information like bloom filters for s1 and s2 + let statistics = TestStatistics::new() + .with_contained( + "s1", + [ScalarValue::from("foo")], + [ + // container 0, s1 known to only contain "foo"", + Some(true), + // container 1, s1 known to not contain "foo" + Some(false), + // container 2, s1 unknown about "foo" + None, + // container 3, s1 known to only contain "foo" + Some(true), + // container 4, s1 known to not contain "foo" + Some(false), + // container 5, s1 unknown about "foo" + None, + // container 6, s1 known to only contain "foo" + Some(true), + // container 7, s1 known to not contain "foo" + Some(false), + // container 8, s1 unknown about "foo" + None, + ], + ) + .with_contained( + "s2", // for column s2 + [ScalarValue::from("bar")], + [ + // containers 0,1,2 s2 known to only contain "bar" + Some(true), + Some(true), + Some(true), + // container 3,4,5 s2 known to not contain "bar" + Some(false), + Some(false), + Some(false), + // container 6,7,8 s2 unknown about "bar" + None, + None, + None, + ], + ); + + // s1 = 'foo' + prune_with_expr( + col("s1").eq(lit("foo")), + &schema, + &statistics, + // rule out containers where we know s1 is not present + vec![true, false, true, true, false, true, true, false, true], + ); + + // s1 = 'foo' OR s2 = 'bar' + let expr = col("s1").eq(lit("foo")).or(col("s2").eq(lit("bar"))); + prune_with_expr( + expr, + &schema, + &statistics, + // can't rule out any container (would need to prove that s1 != foo AND s2 != bar) + vec![true, true, true, true, true, true, true, true, true], + ); + + // s1 = 'foo' AND s2 != 'bar' + prune_with_expr( + col("s1").eq(lit("foo")).and(col("s2").not_eq(lit("bar"))), + &schema, + &statistics, + // can only rule out container where we know either: + // 1. s1 doesn't have the value 'foo` or + // 2. s2 has only the value of 'bar' + vec![false, false, false, true, false, true, true, false, true], + ); + + // s1 != 'foo' AND s2 != 'bar' + prune_with_expr( + col("s1") + .not_eq(lit("foo")) + .and(col("s2").not_eq(lit("bar"))), + &schema, + &statistics, + // Can rule out any container where we know either + // 1. s1 has only the value 'foo' + // 2. s2 has only the value 'bar' + vec![false, false, false, false, true, true, false, true, true], + ); + + // s1 != 'foo' AND (s2 = 'bar' OR s2 = 'baz') + prune_with_expr( + col("s1") + .not_eq(lit("foo")) + .and(col("s2").eq(lit("bar")).or(col("s2").eq(lit("baz")))), + &schema, + &statistics, + // Can rule out any container where we know s1 has only the value + // 'foo'. 
Can't use knowledge of s2 and bar to rule out anything + vec![false, true, true, false, true, true, false, true, true], + ); + + // s1 like '%foo%bar%' + prune_with_expr( + col("s1").like(lit("foo%bar%")), + &schema, + &statistics, + // can't rule out anything with the information we know + vec![true, true, true, true, true, true, true, true, true], + ); + + // s1 like '%foo%bar%' AND s2 = 'bar' + prune_with_expr( + col("s1") + .like(lit("foo%bar%")) + .and(col("s2").eq(lit("bar"))), + &schema, + &statistics, + // can rule out any container where we know s2 does not have the value 'bar' + vec![true, true, true, false, false, false, true, true, true], + ); + + // s1 like '%foo%bar%' OR s2 = 'bar' + prune_with_expr( + col("s1").like(lit("foo%bar%")).or(col("s2").eq(lit("bar"))), + &schema, + &statistics, + // can't rule out anything (we would have to prove that both the + // like and the equality must be false) + vec![true, true, true, true, true, true, true, true, true], + ); + } + + #[test] + fn prune_with_range_and_contained() { + // Setup mimics range information for i, a bloom filter for s + let schema = Arc::new(Schema::new(vec![ + Field::new("i", DataType::Int32, true), + Field::new("s", DataType::Utf8, true), + ])); + + let statistics = TestStatistics::new() + .with( + "i", + ContainerStats::new_i32( + // Container 0, 3, 6: [-5 to 5] + // Container 1, 4, 7: [10 to 20] + // Container 2, 5, 8: unknown + vec![ + Some(-5), + Some(10), + None, + Some(-5), + Some(10), + None, + Some(-5), + Some(10), + None, + ], // min + vec![ + Some(5), + Some(20), + None, + Some(5), + Some(20), + None, + Some(5), + Some(20), + None, + ], // max + ), + ) + // Add contained information about s and "foo" + .with_contained( + "s", + [ScalarValue::from("foo")], + [ + // container 0,1,2 known to only contain "foo" + Some(true), + Some(true), + Some(true), + // container 3,4,5 known to not contain "foo" + Some(false), + Some(false), + Some(false), + // container 6,7,8 unknown about "foo" + None, + None, + None, + ], + ); + + // i = 0 and s = 'foo' + prune_with_expr( + col("i").eq(lit(0)).and(col("s").eq(lit("foo"))), + &schema, + &statistics, + // Can rule out containers where we know that either: + // 1. 0 is outside the min/max range of i + // 2. s does not contain foo + // (range is false, or contained is false) + vec![true, false, true, false, false, false, true, false, true], + ); + + // i = 0 and s != 'foo' + prune_with_expr( + col("i").eq(lit(0)).and(col("s").not_eq(lit("foo"))), + &schema, + &statistics, + // Can rule out containers where either: + // 1. 0 is outside the min/max range of i + // 2. s only contains foo + vec![false, false, false, true, false, true, true, false, true], + ); + + // i = 0 OR s = 'foo' + prune_with_expr( + col("i").eq(lit(0)).or(col("s").eq(lit("foo"))), + &schema, + &statistics, + // in theory could rule out containers if we had min/max values for + // s as well. But in this case we don't, so we can't rule out anything + vec![true, true, true, true, true, true, true, true, true], + ); + } + + /// Prunes the specified expr with the specified schema and statistics, and + /// ensures it returns `expected`. + /// + /// `expected` is a vector of bools, where true means the row group should + /// be kept, and false means it should be pruned. 
+ /// + // TODO refactor other tests to use this to reduce boilerplate + fn prune_with_expr( + expr: Expr, + schema: &SchemaRef, + statistics: &TestStatistics, + expected: Vec<bool>, + ) { + println!("Pruning with expr: {}", expr); + let expr = logical2physical(&expr, schema); + let p = PruningPredicate::try_new(expr, schema.clone()).unwrap(); + let result = p.prune(statistics).unwrap(); + assert_eq!(result, expected); + } + fn test_build_predicate_expression( expr: &Expr, schema: &Schema, - required_columns: &mut RequiredStatColumns, + required_columns: &mut RequiredColumns, ) -> Arc<dyn PhysicalExpr> { let expr = logical2physical(expr, schema); build_predicate_expression(&expr, schema, required_columns) diff --git a/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs b/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs index 0c2f21d11acdd..e49b358608aab 100644 --- a/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs +++ b/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs @@ -19,18 +19,18 @@ //! order-preserving variants when it is helpful; either in terms of //! performance or to accommodate unbounded streams by fixing the pipeline. +use std::borrow::Cow; use std::sync::Arc; +use super::utils::is_repartition; use crate::error::Result; -use crate::physical_optimizer::utils::{is_coalesce_partitions, is_sort, ExecTree}; +use crate::physical_optimizer::utils::{is_coalesce_partitions, is_sort}; use crate::physical_plan::repartition::RepartitionExec; use crate::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; use crate::physical_plan::{with_new_children_if_necessary, ExecutionPlan}; -use super::utils::is_repartition; - use datafusion_common::config::ConfigOptions; -use datafusion_common::tree_node::{Transformed, TreeNode, VisitRecursion}; +use datafusion_common::tree_node::{Transformed, TreeNode}; use datafusion_physical_plan::unbounded_output; /// For a given `plan`, this object carries the information one needs from its @@ -40,159 +40,157 @@ use datafusion_physical_plan::unbounded_output; #[derive(Debug, Clone)] pub(crate) struct OrderPreservationContext { pub(crate) plan: Arc<dyn ExecutionPlan>, - ordering_onwards: Vec<Option<ExecTree>>, + ordering_connection: bool, + children_nodes: Vec<Self>, } impl OrderPreservationContext { - /// Creates a "default" order-preservation context. + /// Creates an empty context tree. Each node has `false` connections. pub fn new(plan: Arc<dyn ExecutionPlan>) -> Self { - let length = plan.children().len(); - OrderPreservationContext { + let children = plan.children(); + Self { plan, - ordering_onwards: vec![None; length], + ordering_connection: false, + children_nodes: children.into_iter().map(Self::new).collect(), } } /// Creates a new order-preservation context from those of children nodes. - pub fn new_from_children_nodes( - children_nodes: Vec<OrderPreservationContext>, - parent_plan: Arc<dyn ExecutionPlan>, - ) -> Result<Self> { - let children_plans = children_nodes - .iter() - .map(|item| item.plan.clone()) - .collect(); - let ordering_onwards = children_nodes - .into_iter() - .enumerate() - .map(|(idx, item)| { - // `ordering_onwards` tree keeps track of executors that maintain - // ordering, (or that can maintain ordering with the replacement of - // its variant) - let plan = item.plan; - let children = plan.children(); - let ordering_onwards = item.ordering_onwards; - if children.is_empty() { - // Plan has no children, there is nothing to propagate. 
- None - } else if ordering_onwards[0].is_none() - && ((is_repartition(&plan) && !plan.maintains_input_order()[0]) - || (is_coalesce_partitions(&plan) - && children[0].output_ordering().is_some())) - { - Some(ExecTree::new(plan, idx, vec![])) - } else { - let children = ordering_onwards - .into_iter() - .flatten() - .filter(|item| { - // Only consider operators that maintains ordering - plan.maintains_input_order()[item.idx] - || is_coalesce_partitions(&plan) - || is_repartition(&plan) - }) - .collect::<Vec<_>>(); - if children.is_empty() { - None - } else { - Some(ExecTree::new(plan, idx, children)) - } - } - }) - .collect(); - let plan = with_new_children_if_necessary(parent_plan, children_plans)?.into(); - Ok(OrderPreservationContext { - plan, - ordering_onwards, - }) - } + pub fn update_children(mut self) -> Result<Self> { + for node in self.children_nodes.iter_mut() { + let plan = node.plan.clone(); + let children = plan.children(); + let maintains_input_order = plan.maintains_input_order(); + let inspect_child = |idx| { + maintains_input_order[idx] + || is_coalesce_partitions(&plan) + || is_repartition(&plan) + }; + + // We cut the path towards nodes that do not maintain ordering. + for (idx, c) in node.children_nodes.iter_mut().enumerate() { + c.ordering_connection &= inspect_child(idx); + } - /// Computes order-preservation contexts for every child of the plan. - pub fn children(&self) -> Vec<OrderPreservationContext> { - self.plan - .children() - .into_iter() - .map(OrderPreservationContext::new) - .collect() + node.ordering_connection = if children.is_empty() { + false + } else if !node.children_nodes[0].ordering_connection + && ((is_repartition(&plan) && !maintains_input_order[0]) + || (is_coalesce_partitions(&plan) + && children[0].output_ordering().is_some())) + { + // We either have a RepartitionExec or a CoalescePartitionsExec + // and they lose their input ordering, so initiate connection: + true + } else { + // Maintain connection if there is a child with a connection, + // and operator can possibly maintain that connection (either + // in its current form or when we replace it with the corresponding + // order preserving operator). + node.children_nodes + .iter() + .enumerate() + .any(|(idx, c)| c.ordering_connection && inspect_child(idx)) + } + } + + self.plan = with_new_children_if_necessary( + self.plan, + self.children_nodes.iter().map(|c| c.plan.clone()).collect(), + )? + .into(); + self.ordering_connection = false; + Ok(self) } } impl TreeNode for OrderPreservationContext { - fn apply_children<F>(&self, op: &mut F) -> Result<VisitRecursion> - where - F: FnMut(&Self) -> Result<VisitRecursion>, - { - for child in self.children() { - match op(&child)? { - VisitRecursion::Continue => {} - VisitRecursion::Skip => return Ok(VisitRecursion::Continue), - VisitRecursion::Stop => return Ok(VisitRecursion::Stop), - } - } - Ok(VisitRecursion::Continue) + fn children_nodes(&self) -> Vec<Cow<Self>> { + self.children_nodes.iter().map(Cow::Borrowed).collect() } - fn map_children<F>(self, transform: F) -> Result<Self> + fn map_children<F>(mut self, transform: F) -> Result<Self> where F: FnMut(Self) -> Result<Self>, { - let children = self.children(); - if children.is_empty() { - Ok(self) - } else { - let children_nodes = children + if !self.children_nodes.is_empty() { + self.children_nodes = self + .children_nodes .into_iter() .map(transform) - .collect::<Result<Vec<_>>>()?; - OrderPreservationContext::new_from_children_nodes(children_nodes, self.plan) + .collect::<Result<_>>()?; + self.plan = with_new_children_if_necessary( + self.plan, + self.children_nodes.iter().map(|c| c.plan.clone()).collect(), + )? 
+ .into(); } + Ok(self) } } -/// Calculates the updated plan by replacing executors that lose ordering -/// inside the `ExecTree` with their order-preserving variants. This will +/// Calculates the updated plan by replacing operators that lose ordering +/// inside `sort_input` with their order-preserving variants. This will /// generate an alternative plan, which will be accepted or rejected later on /// depending on whether it helps us remove a `SortExec`. fn get_updated_plan( - exec_tree: &ExecTree, + mut sort_input: OrderPreservationContext, // Flag indicating that it is desirable to replace `RepartitionExec`s with // `SortPreservingRepartitionExec`s: is_spr_better: bool, // Flag indicating that it is desirable to replace `CoalescePartitionsExec`s // with `SortPreservingMergeExec`s: is_spm_better: bool, -) -> Result<Arc<dyn ExecutionPlan>> { - let plan = exec_tree.plan.clone(); +) -> Result<OrderPreservationContext> { + let updated_children = sort_input + .children_nodes + .clone() + .into_iter() + .map(|item| { + // Update children and their descendants in the given tree if the connection is open: + if item.ordering_connection { + get_updated_plan(item, is_spr_better, is_spm_better) + } else { + Ok(item) + } + }) + .collect::<Result<Vec<_>>>()?; - let mut children = plan.children(); - // Update children and their descendants in the given tree: - for item in &exec_tree.children { - children[item.idx] = get_updated_plan(item, is_spr_better, is_spm_better)?; - } - // Construct the plan with updated children: - let mut plan = plan.with_new_children(children)?; + sort_input.plan = sort_input + .plan + .with_new_children(updated_children.iter().map(|c| c.plan.clone()).collect())?; + sort_input.ordering_connection = false; + sort_input.children_nodes = updated_children; // When a `RepartitionExec` doesn't preserve ordering, replace it with - // a `SortPreservingRepartitionExec` if appropriate: - if is_repartition(&plan) && !plan.maintains_input_order()[0] && is_spr_better { - let child = plan.children().swap_remove(0); - let repartition = RepartitionExec::try_new(child, plan.output_partitioning())?; - plan = Arc::new(repartition.with_preserve_order(true)) as _ - } - // When the input of a `CoalescePartitionsExec` has an ordering, replace it - // with a `SortPreservingMergeExec` if appropriate: - let mut children = plan.children(); - if is_coalesce_partitions(&plan) - && children[0].output_ordering().is_some() - && is_spm_better + // a sort-preserving variant if appropriate: + if is_repartition(&sort_input.plan) + && !sort_input.plan.maintains_input_order()[0] + && is_spr_better { - let child = children.swap_remove(0); - plan = Arc::new(SortPreservingMergeExec::new( - child.output_ordering().unwrap_or(&[]).to_vec(), - child, - )) as _ + let child = sort_input.plan.children().swap_remove(0); + let repartition = + RepartitionExec::try_new(child, sort_input.plan.output_partitioning())? 
+ .with_preserve_order(); + sort_input.plan = Arc::new(repartition) as _; + sort_input.children_nodes[0].ordering_connection = true; + } else if is_coalesce_partitions(&sort_input.plan) && is_spm_better { + // When the input of a `CoalescePartitionsExec` has an ordering, replace it + // with a `SortPreservingMergeExec` if appropriate: + if let Some(ordering) = sort_input.children_nodes[0] + .plan + .output_ordering() + .map(|o| o.to_vec()) + { + // Now we can mutate `new_node.children_nodes` safely + let child = sort_input.children_nodes.clone().swap_remove(0); + sort_input.plan = + Arc::new(SortPreservingMergeExec::new(ordering, child.plan)) as _; + sort_input.children_nodes[0].ordering_connection = true; + } } - Ok(plan) + + Ok(sort_input) } /// The `replace_with_order_preserving_variants` optimizer sub-rule tries to @@ -210,11 +208,11 @@ fn get_updated_plan( /// /// The algorithm flow is simply like this: /// 1. Visit nodes of the physical plan bottom-up and look for `SortExec` nodes. -/// 1_1. During the traversal, build an `ExecTree` to keep track of operators -/// that maintain ordering (or can maintain ordering when replaced by an -/// order-preserving variant) until a `SortExec` is found. +/// 1_1. During the traversal, keep track of operators that maintain ordering +/// (or can maintain ordering when replaced by an order-preserving variant) until +/// a `SortExec` is found. /// 2. When a `SortExec` is found, update the child of the `SortExec` by replacing -/// operators that do not preserve ordering in the `ExecTree` with their order +/// operators that do not preserve ordering in the tree with their order /// preserving variants. /// 3. Check if the `SortExec` is still necessary in the updated plan by comparing /// its input ordering with the output ordering it imposes. We do this because @@ -238,37 +236,41 @@ pub(crate) fn replace_with_order_preserving_variants( is_spm_better: bool, config: &ConfigOptions, ) -> Result> { - let plan = &requirements.plan; - let ordering_onwards = &requirements.ordering_onwards; - if is_sort(plan) { - let exec_tree = if let Some(exec_tree) = &ordering_onwards[0] { - exec_tree - } else { - return Ok(Transformed::No(requirements)); - }; - // For unbounded cases, replace with the order-preserving variant in - // any case, as doing so helps fix the pipeline. - // Also do the replacement if opted-in via config options. - let use_order_preserving_variant = - config.optimizer.prefer_existing_sort || unbounded_output(plan); - let updated_sort_input = get_updated_plan( - exec_tree, - is_spr_better || use_order_preserving_variant, - is_spm_better || use_order_preserving_variant, - )?; - // If this sort is unnecessary, we should remove it and update the plan: - if updated_sort_input - .equivalence_properties() - .ordering_satisfy(plan.output_ordering().unwrap_or(&[])) - { - return Ok(Transformed::Yes(OrderPreservationContext { - plan: updated_sort_input, - ordering_onwards: vec![None], - })); - } + let mut requirements = requirements.update_children()?; + if !(is_sort(&requirements.plan) + && requirements.children_nodes[0].ordering_connection) + { + return Ok(Transformed::No(requirements)); } - Ok(Transformed::No(requirements)) + // For unbounded cases, replace with the order-preserving variant in + // any case, as doing so helps fix the pipeline. + // Also do the replacement if opted-in via config options. 
+ let use_order_preserving_variant = + config.optimizer.prefer_existing_sort || unbounded_output(&requirements.plan); + + let mut updated_sort_input = get_updated_plan( + requirements.children_nodes.clone().swap_remove(0), + is_spr_better || use_order_preserving_variant, + is_spm_better || use_order_preserving_variant, + )?; + + // If this sort is unnecessary, we should remove it and update the plan: + if updated_sort_input + .plan + .equivalence_properties() + .ordering_satisfy(requirements.plan.output_ordering().unwrap_or(&[])) + { + for child in updated_sort_input.children_nodes.iter_mut() { + child.ordering_connection = false; + } + Ok(Transformed::Yes(updated_sort_input)) + } else { + for child in requirements.children_nodes.iter_mut() { + child.ordering_connection = false; + } + Ok(Transformed::Yes(requirements)) + } } #[cfg(test)] @@ -285,8 +287,9 @@ mod tests { use crate::physical_plan::repartition::RepartitionExec; use crate::physical_plan::sorts::sort::SortExec; use crate::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; - use crate::physical_plan::{displayable, Partitioning}; + use crate::physical_plan::{displayable, get_plan_string, Partitioning}; use crate::prelude::SessionConfig; + use crate::test::TestStreamPartition; use arrow::compute::SortOptions; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; @@ -296,24 +299,85 @@ mod tests { use datafusion_expr::{JoinType, Operator}; use datafusion_physical_expr::expressions::{self, col, Column}; use datafusion_physical_expr::PhysicalSortExpr; + use datafusion_physical_plan::streaming::StreamingTableExec; + + use rstest::rstest; - /// Runs the `replace_with_order_preserving_variants` sub-rule and asserts the plan - /// against the original and expected plans. + /// Runs the `replace_with_order_preserving_variants` sub-rule and asserts + /// the plan against the original and expected plans for both bounded and + /// unbounded cases. /// - /// `$EXPECTED_PLAN_LINES`: input plan - /// `$EXPECTED_OPTIMIZED_PLAN_LINES`: optimized plan - /// `$PLAN`: the plan to optimized - /// `$ALLOW_BOUNDED`: whether to allow the plan to be optimized for bounded cases - macro_rules! assert_optimized { - ($EXPECTED_PLAN_LINES: expr, $EXPECTED_OPTIMIZED_PLAN_LINES: expr, $PLAN: expr) => { + /// # Parameters + /// + /// * `EXPECTED_UNBOUNDED_PLAN_LINES`: Expected input unbounded plan. + /// * `EXPECTED_BOUNDED_PLAN_LINES`: Expected input bounded plan. + /// * `EXPECTED_UNBOUNDED_OPTIMIZED_PLAN_LINES`: Optimized plan, which is + /// the same regardless of the value of the `prefer_existing_sort` flag. + /// * `EXPECTED_BOUNDED_OPTIMIZED_PLAN_LINES`: Optimized plan when the flag + /// `prefer_existing_sort` is `false` for bounded cases. + /// * `EXPECTED_BOUNDED_PREFER_SORT_ON_OPTIMIZED_PLAN_LINES`: Optimized plan + /// when the flag `prefer_existing_sort` is `true` for bounded cases. + /// * `$PLAN`: The plan to optimize. + /// * `$SOURCE_UNBOUNDED`: Whether the given plan contains an unbounded source. + macro_rules! 
assert_optimized_in_all_boundedness_situations { + ($EXPECTED_UNBOUNDED_PLAN_LINES: expr, $EXPECTED_BOUNDED_PLAN_LINES: expr, $EXPECTED_UNBOUNDED_OPTIMIZED_PLAN_LINES: expr, $EXPECTED_BOUNDED_OPTIMIZED_PLAN_LINES: expr, $EXPECTED_BOUNDED_PREFER_SORT_ON_OPTIMIZED_PLAN_LINES: expr, $PLAN: expr, $SOURCE_UNBOUNDED: expr) => { + if $SOURCE_UNBOUNDED { + assert_optimized_prefer_sort_on_off!( + $EXPECTED_UNBOUNDED_PLAN_LINES, + $EXPECTED_UNBOUNDED_OPTIMIZED_PLAN_LINES, + $EXPECTED_UNBOUNDED_OPTIMIZED_PLAN_LINES, + $PLAN + ); + } else { + assert_optimized_prefer_sort_on_off!( + $EXPECTED_BOUNDED_PLAN_LINES, + $EXPECTED_BOUNDED_OPTIMIZED_PLAN_LINES, + $EXPECTED_BOUNDED_PREFER_SORT_ON_OPTIMIZED_PLAN_LINES, + $PLAN + ); + } + }; + } + + /// Runs the `replace_with_order_preserving_variants` sub-rule and asserts + /// the plan against the original and expected plans. + /// + /// # Parameters + /// + /// * `$EXPECTED_PLAN_LINES`: Expected input plan. + /// * `EXPECTED_OPTIMIZED_PLAN_LINES`: Optimized plan when the flag + /// `prefer_existing_sort` is `false`. + /// * `EXPECTED_PREFER_SORT_ON_OPTIMIZED_PLAN_LINES`: Optimized plan when + /// the flag `prefer_existing_sort` is `true`. + /// * `$PLAN`: The plan to optimize. + macro_rules! assert_optimized_prefer_sort_on_off { + ($EXPECTED_PLAN_LINES: expr, $EXPECTED_OPTIMIZED_PLAN_LINES: expr, $EXPECTED_PREFER_SORT_ON_OPTIMIZED_PLAN_LINES: expr, $PLAN: expr) => { assert_optimized!( $EXPECTED_PLAN_LINES, $EXPECTED_OPTIMIZED_PLAN_LINES, - $PLAN, + $PLAN.clone(), false ); + assert_optimized!( + $EXPECTED_PLAN_LINES, + $EXPECTED_PREFER_SORT_ON_OPTIMIZED_PLAN_LINES, + $PLAN, + true + ); }; - ($EXPECTED_PLAN_LINES: expr, $EXPECTED_OPTIMIZED_PLAN_LINES: expr, $PLAN: expr, $ALLOW_BOUNDED: expr) => { + } + + /// Runs the `replace_with_order_preserving_variants` sub-rule and asserts + /// the plan against the original and expected plans. + /// + /// # Parameters + /// + /// * `$EXPECTED_PLAN_LINES`: Expected input plan. + /// * `$EXPECTED_OPTIMIZED_PLAN_LINES`: Expected optimized plan. + /// * `$PLAN`: The plan to optimize. + /// * `$PREFER_EXISTING_SORT`: Value of the `prefer_existing_sort` flag. + macro_rules! 
assert_optimized { + ($EXPECTED_PLAN_LINES: expr, $EXPECTED_OPTIMIZED_PLAN_LINES: expr, $PLAN: expr, $PREFER_EXISTING_SORT: expr) => { let physical_plan = $PLAN; let formatted = displayable(physical_plan.as_ref()).indent(true).to_string(); let actual: Vec<&str> = formatted.trim().lines().collect(); @@ -329,8 +393,7 @@ mod tests { let expected_optimized_lines: Vec<&str> = $EXPECTED_OPTIMIZED_PLAN_LINES.iter().map(|s| *s).collect(); // Run the rule top-down - // let optimized_physical_plan = physical_plan.transform_down(&replace_repartition_execs)?; - let config = SessionConfig::new().with_prefer_existing_sort($ALLOW_BOUNDED); + let config = SessionConfig::new().with_prefer_existing_sort($PREFER_EXISTING_SORT); let plan_with_pipeline_fixer = OrderPreservationContext::new(physical_plan); let parallel = plan_with_pipeline_fixer.transform_up(&|plan_with_pipeline_fixer| replace_with_order_preserving_variants(plan_with_pipeline_fixer, false, false, config.options()))?; let optimized_physical_plan = parallel.plan; @@ -344,150 +407,351 @@ mod tests { }; } + #[rstest] #[tokio::test] // Searches for a simple sort and a repartition just after it, the second repartition with 1 input partition should not be affected - async fn test_replace_multiple_input_repartition_1() -> Result<()> { + async fn test_replace_multiple_input_repartition_1( + #[values(false, true)] source_unbounded: bool, + ) -> Result<()> { let schema = create_test_schema()?; let sort_exprs = vec![sort_expr("a", &schema)]; - let source = csv_exec_sorted(&schema, sort_exprs, true); + let source = if source_unbounded { + stream_exec_ordered(&schema, sort_exprs) + } else { + csv_exec_sorted(&schema, sort_exprs) + }; let repartition = repartition_exec_hash(repartition_exec_round_robin(source)); let sort = sort_exec(vec![sort_expr("a", &schema)], repartition, true); let physical_plan = sort_preserving_merge_exec(vec![sort_expr("a", &schema)], sort); - let expected_input = ["SortPreservingMergeExec: [a@0 ASC NULLS LAST]", + // Expected inputs unbounded and bounded + let expected_input_unbounded = [ + "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", + " SortExec: expr=[a@0 ASC NULLS LAST]", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", + ]; + let expected_input_bounded = [ + "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", " SortExec: expr=[a@0 ASC NULLS LAST]", - " RepartitionExec: partitioning=Hash([c1@0], 8), input_partitions=8", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true"]; - let expected_optimized = ["SortPreservingMergeExec: [a@0 ASC NULLS LAST]", - " SortPreservingRepartitionExec: partitioning=Hash([c1@0], 8), input_partitions=8, sort_exprs=a@0 ASC NULLS LAST", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + ]; + + // Expected unbounded result (same for with and without flag) + let expected_optimized_unbounded = [ + "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST", + " 
RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", + ]; + + // Expected bounded results with and without flag + let expected_optimized_bounded = [ + "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", + " SortExec: expr=[a@0 ASC NULLS LAST]", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + ]; + let expected_optimized_bounded_sort_preserve = [ + "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true"]; - assert_optimized!(expected_input, expected_optimized, physical_plan); + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + ]; + assert_optimized_in_all_boundedness_situations!( + expected_input_unbounded, + expected_input_bounded, + expected_optimized_unbounded, + expected_optimized_bounded, + expected_optimized_bounded_sort_preserve, + physical_plan, + source_unbounded + ); Ok(()) } + #[rstest] #[tokio::test] - async fn test_with_inter_children_change_only() -> Result<()> { + async fn test_with_inter_children_change_only( + #[values(false, true)] source_unbounded: bool, + ) -> Result<()> { let schema = create_test_schema()?; let sort_exprs = vec![sort_expr_default("a", &schema)]; - let source = csv_exec_sorted(&schema, sort_exprs, true); + let source = if source_unbounded { + stream_exec_ordered(&schema, sort_exprs) + } else { + csv_exec_sorted(&schema, sort_exprs) + }; let repartition_rr = repartition_exec_round_robin(source); let repartition_hash = repartition_exec_hash(repartition_rr); let coalesce_partitions = coalesce_partitions_exec(repartition_hash); let sort = sort_exec( - vec![sort_expr_default("a", &schema)], + vec![sort_expr_default("a", &coalesce_partitions.schema())], coalesce_partitions, false, ); let repartition_rr2 = repartition_exec_round_robin(sort); let repartition_hash2 = repartition_exec_hash(repartition_rr2); - let filter = filter_exec(repartition_hash2, &schema); - let sort2 = sort_exec(vec![sort_expr_default("a", &schema)], filter, true); + let filter = filter_exec(repartition_hash2); + let sort2 = + sort_exec(vec![sort_expr_default("a", &filter.schema())], filter, true); - let physical_plan = - sort_preserving_merge_exec(vec![sort_expr_default("a", &schema)], sort2); + let physical_plan = sort_preserving_merge_exec( + vec![sort_expr_default("a", &sort2.schema())], + sort2, + ); - let expected_input = [ + // Expected inputs unbounded and bounded + let expected_input_unbounded = [ "SortPreservingMergeExec: [a@0 ASC]", " SortExec: expr=[a@0 ASC]", - " FilterExec: c@2 > 3", - " RepartitionExec: partitioning=Hash([c1@0], 8), input_partitions=8", + " FilterExec: c@1 > 3", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", " SortExec: expr=[a@0 ASC]", " CoalescePartitionsExec", - " RepartitionExec: 
partitioning=Hash([c1@0], 8), input_partitions=8", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC], has_header=true", + " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC]", + ]; + let expected_input_bounded = [ + "SortPreservingMergeExec: [a@0 ASC]", + " SortExec: expr=[a@0 ASC]", + " FilterExec: c@1 > 3", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " SortExec: expr=[a@0 ASC]", + " CoalescePartitionsExec", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC], has_header=true", ]; - let expected_optimized = [ + // Expected unbounded result (same for with and without flag) + let expected_optimized_unbounded = [ + "SortPreservingMergeExec: [a@0 ASC]", + " FilterExec: c@1 > 3", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " SortPreservingMergeExec: [a@0 ASC]", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC]", + ]; + + // Expected bounded results with and without flag + let expected_optimized_bounded = [ + "SortPreservingMergeExec: [a@0 ASC]", + " SortExec: expr=[a@0 ASC]", + " FilterExec: c@1 > 3", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " SortExec: expr=[a@0 ASC]", + " CoalescePartitionsExec", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC], has_header=true", + ]; + let expected_optimized_bounded_sort_preserve = [ "SortPreservingMergeExec: [a@0 ASC]", - " FilterExec: c@2 > 3", - " SortPreservingRepartitionExec: partitioning=Hash([c1@0], 8), input_partitions=8, sort_exprs=a@0 ASC", + " FilterExec: c@1 > 3", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", " SortPreservingMergeExec: [a@0 ASC]", - " SortPreservingRepartitionExec: partitioning=Hash([c1@0], 8), input_partitions=8, sort_exprs=a@0 ASC", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC], has_header=true", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC], has_header=true", ]; - assert_optimized!(expected_input, expected_optimized, physical_plan); + 
assert_optimized_in_all_boundedness_situations!( + expected_input_unbounded, + expected_input_bounded, + expected_optimized_unbounded, + expected_optimized_bounded, + expected_optimized_bounded_sort_preserve, + physical_plan, + source_unbounded + ); Ok(()) } + #[rstest] #[tokio::test] - async fn test_replace_multiple_input_repartition_2() -> Result<()> { + async fn test_replace_multiple_input_repartition_2( + #[values(false, true)] source_unbounded: bool, + ) -> Result<()> { let schema = create_test_schema()?; let sort_exprs = vec![sort_expr("a", &schema)]; - let source = csv_exec_sorted(&schema, sort_exprs, true); + let source = if source_unbounded { + stream_exec_ordered(&schema, sort_exprs) + } else { + csv_exec_sorted(&schema, sort_exprs) + }; let repartition_rr = repartition_exec_round_robin(source); - let filter = filter_exec(repartition_rr, &schema); + let filter = filter_exec(repartition_rr); let repartition_hash = repartition_exec_hash(filter); let sort = sort_exec(vec![sort_expr("a", &schema)], repartition_hash, true); let physical_plan = sort_preserving_merge_exec(vec![sort_expr("a", &schema)], sort); - let expected_input = ["SortPreservingMergeExec: [a@0 ASC NULLS LAST]", + // Expected inputs unbounded and bounded + let expected_input_unbounded = [ + "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", " SortExec: expr=[a@0 ASC NULLS LAST]", - " RepartitionExec: partitioning=Hash([c1@0], 8), input_partitions=8", - " FilterExec: c@2 > 3", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", + " FilterExec: c@1 > 3", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true"]; - let expected_optimized = [ + " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", + ]; + let expected_input_bounded = [ + "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", + " SortExec: expr=[a@0 ASC NULLS LAST]", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", + " FilterExec: c@1 > 3", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + ]; + + // Expected unbounded result (same for with and without flag) + let expected_optimized_unbounded = [ "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", - " SortPreservingRepartitionExec: partitioning=Hash([c1@0], 8), input_partitions=8, sort_exprs=a@0 ASC NULLS LAST", - " FilterExec: c@2 > 3", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST", + " FilterExec: c@1 > 3", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", + ]; + // Expected bounded results with and without flag + let expected_optimized_bounded = [ + "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", + " SortExec: expr=[a@0 ASC NULLS LAST]", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", + " FilterExec: c@1 > 3", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " CsvExec: 
file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", ]; - assert_optimized!(expected_input, expected_optimized, physical_plan); + let expected_optimized_bounded_sort_preserve = [ + "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST", + " FilterExec: c@1 > 3", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + ]; + assert_optimized_in_all_boundedness_situations!( + expected_input_unbounded, + expected_input_bounded, + expected_optimized_unbounded, + expected_optimized_bounded, + expected_optimized_bounded_sort_preserve, + physical_plan, + source_unbounded + ); Ok(()) } + #[rstest] #[tokio::test] - async fn test_replace_multiple_input_repartition_with_extra_steps() -> Result<()> { + async fn test_replace_multiple_input_repartition_with_extra_steps( + #[values(false, true)] source_unbounded: bool, + ) -> Result<()> { let schema = create_test_schema()?; let sort_exprs = vec![sort_expr("a", &schema)]; - let source = csv_exec_sorted(&schema, sort_exprs, true); + let source = if source_unbounded { + stream_exec_ordered(&schema, sort_exprs) + } else { + csv_exec_sorted(&schema, sort_exprs) + }; let repartition_rr = repartition_exec_round_robin(source); let repartition_hash = repartition_exec_hash(repartition_rr); - let filter = filter_exec(repartition_hash, &schema); + let filter = filter_exec(repartition_hash); let coalesce_batches_exec: Arc = coalesce_batches_exec(filter); let sort = sort_exec(vec![sort_expr("a", &schema)], coalesce_batches_exec, true); let physical_plan = sort_preserving_merge_exec(vec![sort_expr("a", &schema)], sort); - let expected_input = ["SortPreservingMergeExec: [a@0 ASC NULLS LAST]", + // Expected inputs unbounded and bounded + let expected_input_unbounded = [ + "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", + " SortExec: expr=[a@0 ASC NULLS LAST]", + " CoalesceBatchesExec: target_batch_size=8192", + " FilterExec: c@1 > 3", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", + ]; + let expected_input_bounded = [ + "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", " SortExec: expr=[a@0 ASC NULLS LAST]", " CoalesceBatchesExec: target_batch_size=8192", - " FilterExec: c@2 > 3", - " RepartitionExec: partitioning=Hash([c1@0], 8), input_partitions=8", + " FilterExec: c@1 > 3", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true"]; - let expected_optimized = ["SortPreservingMergeExec: [a@0 ASC NULLS LAST]", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + ]; + + // Expected unbounded result (same for with and without flag) + let expected_optimized_unbounded = [ + "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", " CoalesceBatchesExec: target_batch_size=8192", - " FilterExec: c@2 > 3", - " 
SortPreservingRepartitionExec: partitioning=Hash([c1@0], 8), input_partitions=8, sort_exprs=a@0 ASC NULLS LAST", + " FilterExec: c@1 > 3", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true"]; - assert_optimized!(expected_input, expected_optimized, physical_plan); + " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", + ]; + + // Expected bounded results with and without flag + let expected_optimized_bounded = [ + "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", + " SortExec: expr=[a@0 ASC NULLS LAST]", + " CoalesceBatchesExec: target_batch_size=8192", + " FilterExec: c@1 > 3", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + ]; + let expected_optimized_bounded_sort_preserve = [ + "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", + " CoalesceBatchesExec: target_batch_size=8192", + " FilterExec: c@1 > 3", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + ]; + assert_optimized_in_all_boundedness_situations!( + expected_input_unbounded, + expected_input_bounded, + expected_optimized_unbounded, + expected_optimized_bounded, + expected_optimized_bounded_sort_preserve, + physical_plan, + source_unbounded + ); Ok(()) } + #[rstest] #[tokio::test] - async fn test_replace_multiple_input_repartition_with_extra_steps_2() -> Result<()> { + async fn test_replace_multiple_input_repartition_with_extra_steps_2( + #[values(false, true)] source_unbounded: bool, + ) -> Result<()> { let schema = create_test_schema()?; let sort_exprs = vec![sort_expr("a", &schema)]; - let source = csv_exec_sorted(&schema, sort_exprs, true); + let source = if source_unbounded { + stream_exec_ordered(&schema, sort_exprs) + } else { + csv_exec_sorted(&schema, sort_exprs) + }; let repartition_rr = repartition_exec_round_robin(source); let coalesce_batches_exec_1 = coalesce_batches_exec(repartition_rr); let repartition_hash = repartition_exec_hash(coalesce_batches_exec_1); - let filter = filter_exec(repartition_hash, &schema); + let filter = filter_exec(repartition_hash); let coalesce_batches_exec_2 = coalesce_batches_exec(filter); let sort = sort_exec(vec![sort_expr("a", &schema)], coalesce_batches_exec_2, true); @@ -495,62 +759,157 @@ mod tests { let physical_plan = sort_preserving_merge_exec(vec![sort_expr("a", &schema)], sort); - let expected_input = ["SortPreservingMergeExec: [a@0 ASC NULLS LAST]", + // Expected inputs unbounded and bounded + let expected_input_unbounded = [ + "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", + " SortExec: expr=[a@0 ASC NULLS LAST]", + " CoalesceBatchesExec: target_batch_size=8192", + " FilterExec: c@1 > 3", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", + " CoalesceBatchesExec: target_batch_size=8192", + " 
RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", + ]; + let expected_input_bounded = [ + "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", + " SortExec: expr=[a@0 ASC NULLS LAST]", + " CoalesceBatchesExec: target_batch_size=8192", + " FilterExec: c@1 > 3", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", + " CoalesceBatchesExec: target_batch_size=8192", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + ]; + + // Expected unbounded result (same for with and without flag) + let expected_optimized_unbounded = [ + "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", + " CoalesceBatchesExec: target_batch_size=8192", + " FilterExec: c@1 > 3", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST", + " CoalesceBatchesExec: target_batch_size=8192", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", + ]; + + // Expected bounded results with and without flag + let expected_optimized_bounded = [ + "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", " SortExec: expr=[a@0 ASC NULLS LAST]", " CoalesceBatchesExec: target_batch_size=8192", - " FilterExec: c@2 > 3", - " RepartitionExec: partitioning=Hash([c1@0], 8), input_partitions=8", + " FilterExec: c@1 > 3", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " CoalesceBatchesExec: target_batch_size=8192", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true"]; - let expected_optimized = ["SortPreservingMergeExec: [a@0 ASC NULLS LAST]", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + ]; + let expected_optimized_bounded_sort_preserve = [ + "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", " CoalesceBatchesExec: target_batch_size=8192", - " FilterExec: c@2 > 3", - " SortPreservingRepartitionExec: partitioning=Hash([c1@0], 8), input_partitions=8, sort_exprs=a@0 ASC NULLS LAST", + " FilterExec: c@1 > 3", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST", " CoalesceBatchesExec: target_batch_size=8192", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true"]; - assert_optimized!(expected_input, expected_optimized, physical_plan); + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + ]; + assert_optimized_in_all_boundedness_situations!( + expected_input_unbounded, + expected_input_bounded, + expected_optimized_unbounded, + expected_optimized_bounded, + expected_optimized_bounded_sort_preserve, + physical_plan, + source_unbounded + ); Ok(()) } + #[rstest] #[tokio::test] - async fn test_not_replacing_when_no_need_to_preserve_sorting() -> Result<()> { + async fn 
test_not_replacing_when_no_need_to_preserve_sorting( + #[values(false, true)] source_unbounded: bool, + ) -> Result<()> { let schema = create_test_schema()?; let sort_exprs = vec![sort_expr("a", &schema)]; - let source = csv_exec_sorted(&schema, sort_exprs, true); + let source = if source_unbounded { + stream_exec_ordered(&schema, sort_exprs) + } else { + csv_exec_sorted(&schema, sort_exprs) + }; let repartition_rr = repartition_exec_round_robin(source); let repartition_hash = repartition_exec_hash(repartition_rr); - let filter = filter_exec(repartition_hash, &schema); + let filter = filter_exec(repartition_hash); let coalesce_batches_exec: Arc = coalesce_batches_exec(filter); let physical_plan: Arc = coalesce_partitions_exec(coalesce_batches_exec); - let expected_input = ["CoalescePartitionsExec", + // Expected inputs unbounded and bounded + let expected_input_unbounded = [ + "CoalescePartitionsExec", + " CoalesceBatchesExec: target_batch_size=8192", + " FilterExec: c@1 > 3", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", + ]; + let expected_input_bounded = [ + "CoalescePartitionsExec", + " CoalesceBatchesExec: target_batch_size=8192", + " FilterExec: c@1 > 3", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + ]; + + // Expected unbounded result (same for with and without flag) + let expected_optimized_unbounded = [ + "CoalescePartitionsExec", " CoalesceBatchesExec: target_batch_size=8192", - " FilterExec: c@2 > 3", - " RepartitionExec: partitioning=Hash([c1@0], 8), input_partitions=8", + " FilterExec: c@1 > 3", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true"]; - let expected_optimized = ["CoalescePartitionsExec", + " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", + ]; + + // Expected bounded results same with and without flag, because there is no executor with ordering requirement + let expected_optimized_bounded = [ + "CoalescePartitionsExec", " CoalesceBatchesExec: target_batch_size=8192", - " FilterExec: c@2 > 3", - " RepartitionExec: partitioning=Hash([c1@0], 8), input_partitions=8", + " FilterExec: c@1 > 3", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true"]; - assert_optimized!(expected_input, expected_optimized, physical_plan); + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + ]; + let expected_optimized_bounded_sort_preserve = expected_optimized_bounded; + + assert_optimized_in_all_boundedness_situations!( + expected_input_unbounded, + expected_input_bounded, + expected_optimized_unbounded, + 
expected_optimized_bounded, + expected_optimized_bounded_sort_preserve, + physical_plan, + source_unbounded + ); Ok(()) } + #[rstest] #[tokio::test] - async fn test_with_multiple_replacable_repartitions() -> Result<()> { + async fn test_with_multiple_replacable_repartitions( + #[values(false, true)] source_unbounded: bool, + ) -> Result<()> { let schema = create_test_schema()?; let sort_exprs = vec![sort_expr("a", &schema)]; - let source = csv_exec_sorted(&schema, sort_exprs, true); + let source = if source_unbounded { + stream_exec_ordered(&schema, sort_exprs) + } else { + csv_exec_sorted(&schema, sort_exprs) + }; let repartition_rr = repartition_exec_round_robin(source); let repartition_hash = repartition_exec_hash(repartition_rr); - let filter = filter_exec(repartition_hash, &schema); + let filter = filter_exec(repartition_hash); let coalesce_batches = coalesce_batches_exec(filter); let repartition_hash_2 = repartition_exec_hash(coalesce_batches); let sort = sort_exec(vec![sort_expr("a", &schema)], repartition_hash_2, true); @@ -558,145 +917,341 @@ mod tests { let physical_plan = sort_preserving_merge_exec(vec![sort_expr("a", &schema)], sort); - let expected_input = [ + // Expected inputs unbounded and bounded + let expected_input_unbounded = [ "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", " SortExec: expr=[a@0 ASC NULLS LAST]", - " RepartitionExec: partitioning=Hash([c1@0], 8), input_partitions=8", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " CoalesceBatchesExec: target_batch_size=8192", - " FilterExec: c@2 > 3", - " RepartitionExec: partitioning=Hash([c1@0], 8), input_partitions=8", + " FilterExec: c@1 > 3", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true" + " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", ]; - let expected_optimized = [ + let expected_input_bounded = [ "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", - " SortPreservingRepartitionExec: partitioning=Hash([c1@0], 8), input_partitions=8, sort_exprs=a@0 ASC NULLS LAST", + " SortExec: expr=[a@0 ASC NULLS LAST]", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", + " CoalesceBatchesExec: target_batch_size=8192", + " FilterExec: c@1 > 3", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + ]; + + // Expected unbounded result (same for with and without flag) + let expected_optimized_unbounded = [ + "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST", " CoalesceBatchesExec: target_batch_size=8192", - " FilterExec: c@2 > 3", - " SortPreservingRepartitionExec: partitioning=Hash([c1@0], 8), input_partitions=8, sort_exprs=a@0 ASC NULLS LAST", + " FilterExec: c@1 > 3", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, 
d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", ]; - assert_optimized!(expected_input, expected_optimized, physical_plan); + + // Expected bounded results with and without flag + let expected_optimized_bounded = [ + "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", + " SortExec: expr=[a@0 ASC NULLS LAST]", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", + " CoalesceBatchesExec: target_batch_size=8192", + " FilterExec: c@1 > 3", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + ]; + let expected_optimized_bounded_sort_preserve = [ + "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST", + " CoalesceBatchesExec: target_batch_size=8192", + " FilterExec: c@1 > 3", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + ]; + assert_optimized_in_all_boundedness_situations!( + expected_input_unbounded, + expected_input_bounded, + expected_optimized_unbounded, + expected_optimized_bounded, + expected_optimized_bounded_sort_preserve, + physical_plan, + source_unbounded + ); Ok(()) } + #[rstest] #[tokio::test] - async fn test_not_replace_with_different_orderings() -> Result<()> { + async fn test_not_replace_with_different_orderings( + #[values(false, true)] source_unbounded: bool, + ) -> Result<()> { let schema = create_test_schema()?; let sort_exprs = vec![sort_expr("a", &schema)]; - let source = csv_exec_sorted(&schema, sort_exprs, true); + let source = if source_unbounded { + stream_exec_ordered(&schema, sort_exprs) + } else { + csv_exec_sorted(&schema, sort_exprs) + }; let repartition_rr = repartition_exec_round_robin(source); let repartition_hash = repartition_exec_hash(repartition_rr); let sort = sort_exec( - vec![sort_expr_default("c", &schema)], + vec![sort_expr_default("c", &repartition_hash.schema())], repartition_hash, true, ); - let physical_plan = - sort_preserving_merge_exec(vec![sort_expr_default("c", &schema)], sort); + let physical_plan = sort_preserving_merge_exec( + vec![sort_expr_default("c", &sort.schema())], + sort, + ); - let expected_input = ["SortPreservingMergeExec: [c@2 ASC]", - " SortExec: expr=[c@2 ASC]", - " RepartitionExec: partitioning=Hash([c1@0], 8), input_partitions=8", + // Expected inputs unbounded and bounded + let expected_input_unbounded = [ + "SortPreservingMergeExec: [c@1 ASC]", + " SortExec: expr=[c@1 ASC]", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true"]; - let expected_optimized = ["SortPreservingMergeExec: [c@2 ASC]", - " SortExec: expr=[c@2 ASC]", - " RepartitionExec: partitioning=Hash([c1@0], 8), input_partitions=8", + " 
StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", + ]; + let expected_input_bounded = [ + "SortPreservingMergeExec: [c@1 ASC]", + " SortExec: expr=[c@1 ASC]", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true"]; - assert_optimized!(expected_input, expected_optimized, physical_plan); + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + ]; + + // Expected unbounded result (same for with and without flag) + let expected_optimized_unbounded = [ + "SortPreservingMergeExec: [c@1 ASC]", + " SortExec: expr=[c@1 ASC]", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", + ]; + + // Expected bounded results same with and without flag, because ordering requirement of the executor is different than the existing ordering. + let expected_optimized_bounded = [ + "SortPreservingMergeExec: [c@1 ASC]", + " SortExec: expr=[c@1 ASC]", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + ]; + let expected_optimized_bounded_sort_preserve = expected_optimized_bounded; + + assert_optimized_in_all_boundedness_situations!( + expected_input_unbounded, + expected_input_bounded, + expected_optimized_unbounded, + expected_optimized_bounded, + expected_optimized_bounded_sort_preserve, + physical_plan, + source_unbounded + ); Ok(()) } + #[rstest] #[tokio::test] - async fn test_with_lost_ordering() -> Result<()> { + async fn test_with_lost_ordering( + #[values(false, true)] source_unbounded: bool, + ) -> Result<()> { let schema = create_test_schema()?; let sort_exprs = vec![sort_expr("a", &schema)]; - let source = csv_exec_sorted(&schema, sort_exprs, true); + let source = if source_unbounded { + stream_exec_ordered(&schema, sort_exprs) + } else { + csv_exec_sorted(&schema, sort_exprs) + }; let repartition_rr = repartition_exec_round_robin(source); let repartition_hash = repartition_exec_hash(repartition_rr); let coalesce_partitions = coalesce_partitions_exec(repartition_hash); let physical_plan = sort_exec(vec![sort_expr("a", &schema)], coalesce_partitions, false); - let expected_input = ["SortExec: expr=[a@0 ASC NULLS LAST]", + // Expected inputs unbounded and bounded + let expected_input_unbounded = [ + "SortExec: expr=[a@0 ASC NULLS LAST]", + " CoalescePartitionsExec", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", + ]; + let expected_input_bounded = [ + "SortExec: expr=[a@0 ASC NULLS LAST]", + " CoalescePartitionsExec", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " CsvExec: 
file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + ]; + + // Expected unbounded result (same for with and without flag) + let expected_optimized_unbounded = [ + "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", + ]; + + // Expected bounded results with and without flag + let expected_optimized_bounded = [ + "SortExec: expr=[a@0 ASC NULLS LAST]", " CoalescePartitionsExec", - " RepartitionExec: partitioning=Hash([c1@0], 8), input_partitions=8", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true"]; - let expected_optimized = ["SortPreservingMergeExec: [a@0 ASC NULLS LAST]", - " SortPreservingRepartitionExec: partitioning=Hash([c1@0], 8), input_partitions=8, sort_exprs=a@0 ASC NULLS LAST", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + ]; + let expected_optimized_bounded_sort_preserve = [ + "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true"]; - assert_optimized!(expected_input, expected_optimized, physical_plan); + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + ]; + assert_optimized_in_all_boundedness_situations!( + expected_input_unbounded, + expected_input_bounded, + expected_optimized_unbounded, + expected_optimized_bounded, + expected_optimized_bounded_sort_preserve, + physical_plan, + source_unbounded + ); Ok(()) } + #[rstest] #[tokio::test] - async fn test_with_lost_and_kept_ordering() -> Result<()> { + async fn test_with_lost_and_kept_ordering( + #[values(false, true)] source_unbounded: bool, + ) -> Result<()> { let schema = create_test_schema()?; let sort_exprs = vec![sort_expr("a", &schema)]; - let source = csv_exec_sorted(&schema, sort_exprs, true); + let source = if source_unbounded { + stream_exec_ordered(&schema, sort_exprs) + } else { + csv_exec_sorted(&schema, sort_exprs) + }; let repartition_rr = repartition_exec_round_robin(source); let repartition_hash = repartition_exec_hash(repartition_rr); let coalesce_partitions = coalesce_partitions_exec(repartition_hash); let sort = sort_exec( - vec![sort_expr_default("c", &schema)], + vec![sort_expr_default("c", &coalesce_partitions.schema())], coalesce_partitions, false, ); let repartition_rr2 = repartition_exec_round_robin(sort); let repartition_hash2 = repartition_exec_hash(repartition_rr2); - let filter = filter_exec(repartition_hash2, &schema); - let sort2 = sort_exec(vec![sort_expr_default("c", &schema)], filter, true); + let filter = filter_exec(repartition_hash2); + let sort2 = + 
sort_exec(vec![sort_expr_default("c", &filter.schema())], filter, true); - let physical_plan = - sort_preserving_merge_exec(vec![sort_expr_default("c", &schema)], sort2); + let physical_plan = sort_preserving_merge_exec( + vec![sort_expr_default("c", &sort2.schema())], + sort2, + ); - let expected_input = [ - "SortPreservingMergeExec: [c@2 ASC]", - " SortExec: expr=[c@2 ASC]", - " FilterExec: c@2 > 3", - " RepartitionExec: partitioning=Hash([c1@0], 8), input_partitions=8", + // Expected inputs unbounded and bounded + let expected_input_unbounded = [ + "SortPreservingMergeExec: [c@1 ASC]", + " SortExec: expr=[c@1 ASC]", + " FilterExec: c@1 > 3", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " SortExec: expr=[c@1 ASC]", + " CoalescePartitionsExec", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", + ]; + let expected_input_bounded = [ + "SortPreservingMergeExec: [c@1 ASC]", + " SortExec: expr=[c@1 ASC]", + " FilterExec: c@1 > 3", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " SortExec: expr=[c@2 ASC]", + " SortExec: expr=[c@1 ASC]", " CoalescePartitionsExec", - " RepartitionExec: partitioning=Hash([c1@0], 8), input_partitions=8", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", ]; - let expected_optimized = [ - "SortPreservingMergeExec: [c@2 ASC]", - " FilterExec: c@2 > 3", - " SortPreservingRepartitionExec: partitioning=Hash([c1@0], 8), input_partitions=8, sort_exprs=c@2 ASC", + // Expected unbounded result (same for with and without flag) + let expected_optimized_unbounded = [ + "SortPreservingMergeExec: [c@1 ASC]", + " FilterExec: c@1 > 3", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=c@1 ASC", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " SortExec: expr=[c@2 ASC]", + " SortExec: expr=[c@1 ASC]", " CoalescePartitionsExec", - " RepartitionExec: partitioning=Hash([c1@0], 8), input_partitions=8", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", ]; - assert_optimized!(expected_input, expected_optimized, physical_plan); + + // Expected bounded results with and without flag + let expected_optimized_bounded = [ + "SortPreservingMergeExec: [c@1 ASC]", + " SortExec: expr=[c@1 ASC]", + " FilterExec: c@1 > 3", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " 
SortExec: expr=[c@1 ASC]", + " CoalescePartitionsExec", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + ]; + let expected_optimized_bounded_sort_preserve = [ + "SortPreservingMergeExec: [c@1 ASC]", + " FilterExec: c@1 > 3", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=c@1 ASC", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " SortExec: expr=[c@1 ASC]", + " CoalescePartitionsExec", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + ]; + assert_optimized_in_all_boundedness_situations!( + expected_input_unbounded, + expected_input_bounded, + expected_optimized_unbounded, + expected_optimized_bounded, + expected_optimized_bounded_sort_preserve, + physical_plan, + source_unbounded + ); Ok(()) } + #[rstest] #[tokio::test] - async fn test_with_multiple_child_trees() -> Result<()> { + async fn test_with_multiple_child_trees( + #[values(false, true)] source_unbounded: bool, + ) -> Result<()> { let schema = create_test_schema()?; let left_sort_exprs = vec![sort_expr("a", &schema)]; - let left_source = csv_exec_sorted(&schema, left_sort_exprs, true); + let left_source = if source_unbounded { + stream_exec_ordered(&schema, left_sort_exprs) + } else { + csv_exec_sorted(&schema, left_sort_exprs) + }; let left_repartition_rr = repartition_exec_round_robin(left_source); let left_repartition_hash = repartition_exec_hash(left_repartition_rr); let left_coalesce_partitions = Arc::new(CoalesceBatchesExec::new(left_repartition_hash, 4096)); let right_sort_exprs = vec![sort_expr("a", &schema)]; - let right_source = csv_exec_sorted(&schema, right_sort_exprs, true); + let right_source = if source_unbounded { + stream_exec_ordered(&schema, right_sort_exprs) + } else { + csv_exec_sorted(&schema, right_sort_exprs) + }; let right_repartition_rr = repartition_exec_round_robin(right_source); let right_repartition_hash = repartition_exec_hash(right_repartition_rr); let right_coalesce_partitions = @@ -704,63 +1259,86 @@ mod tests { let hash_join_exec = hash_join_exec(left_coalesce_partitions, right_coalesce_partitions); - let sort = sort_exec(vec![sort_expr_default("a", &schema)], hash_join_exec, true); + let sort = sort_exec( + vec![sort_expr_default("a", &hash_join_exec.schema())], + hash_join_exec, + true, + ); - let physical_plan = - sort_preserving_merge_exec(vec![sort_expr_default("a", &schema)], sort); + let physical_plan = sort_preserving_merge_exec( + vec![sort_expr_default("a", &sort.schema())], + sort, + ); - let expected_input = [ + // Expected inputs unbounded and bounded + let expected_input_unbounded = [ "SortPreservingMergeExec: [a@0 ASC]", " SortExec: expr=[a@0 ASC]", " HashJoinExec: mode=Partitioned, join_type=Inner, on=[(c@1, c@1)]", " CoalesceBatchesExec: target_batch_size=4096", - " RepartitionExec: partitioning=Hash([c1@0], 8), input_partitions=8", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], 
infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", " CoalesceBatchesExec: target_batch_size=4096", - " RepartitionExec: partitioning=Hash([c1@0], 8), input_partitions=8", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", ]; - - let expected_optimized = [ + let expected_input_bounded = [ "SortPreservingMergeExec: [a@0 ASC]", " SortExec: expr=[a@0 ASC]", " HashJoinExec: mode=Partitioned, join_type=Inner, on=[(c@1, c@1)]", " CoalesceBatchesExec: target_batch_size=4096", - " RepartitionExec: partitioning=Hash([c1@0], 8), input_partitions=8", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", " CoalesceBatchesExec: target_batch_size=4096", - " RepartitionExec: partitioning=Hash([c1@0], 8), input_partitions=8", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", ]; - assert_optimized!(expected_input, expected_optimized, physical_plan); - Ok(()) - } - #[tokio::test] - async fn test_with_bounded_input() -> Result<()> { - let schema = create_test_schema()?; - let sort_exprs = vec![sort_expr("a", &schema)]; - let source = csv_exec_sorted(&schema, sort_exprs, false); - let repartition = repartition_exec_hash(repartition_exec_round_robin(source)); - let sort = sort_exec(vec![sort_expr("a", &schema)], repartition, true); - - let physical_plan = - sort_preserving_merge_exec(vec![sort_expr("a", &schema)], sort); + // Expected unbounded result (same for with and without flag) + let expected_optimized_unbounded = [ + "SortPreservingMergeExec: [a@0 ASC]", + " SortExec: expr=[a@0 ASC]", + " HashJoinExec: mode=Partitioned, join_type=Inner, on=[(c@1, c@1)]", + " CoalesceBatchesExec: target_batch_size=4096", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", + " CoalesceBatchesExec: target_batch_size=4096", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", + ]; - let expected_input = ["SortPreservingMergeExec: [a@0 ASC NULLS LAST]", 
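All of the test cases in this region follow the same shape: a single body parameterized over `source_unbounded` with rstest's `#[values]` attribute, so each test runs once against a bounded `CsvExec` source and once against an unbounded `StreamingTableExec` source. A minimal, self-contained sketch of that pattern, assuming `rstest` and `tokio` as dev-dependencies and using placeholder names:

```rust
use rstest::rstest;

// rstest generates one test per listed value, so a single body covers both the
// bounded and the unbounded variant of the plan under test.
#[rstest]
#[tokio::test]
async fn boundedness_case(#[values(false, true)] source_unbounded: bool) {
    // Stand-in for "build the source, run the optimizer, compare expected plans".
    let source = if source_unbounded {
        "StreamingTableExec"
    } else {
        "CsvExec"
    };
    assert!(matches!(source, "StreamingTableExec" | "CsvExec"));
}
```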
- " SortExec: expr=[a@0 ASC NULLS LAST]", - " RepartitionExec: partitioning=Hash([c1@0], 8), input_partitions=8", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true"]; - let expected_optimized = ["SortPreservingMergeExec: [a@0 ASC NULLS LAST]", - " SortPreservingRepartitionExec: partitioning=Hash([c1@0], 8), input_partitions=8, sort_exprs=a@0 ASC NULLS LAST", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true"]; - assert_optimized!(expected_input, expected_optimized, physical_plan, true); + // Expected bounded results same with and without flag, because ordering get lost during intermediate executor anyway. Hence no need to preserve + // existing ordering. + let expected_optimized_bounded = [ + "SortPreservingMergeExec: [a@0 ASC]", + " SortExec: expr=[a@0 ASC]", + " HashJoinExec: mode=Partitioned, join_type=Inner, on=[(c@1, c@1)]", + " CoalesceBatchesExec: target_batch_size=4096", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " CoalesceBatchesExec: target_batch_size=4096", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + ]; + let expected_optimized_bounded_sort_preserve = expected_optimized_bounded; + + assert_optimized_in_all_boundedness_situations!( + expected_input_unbounded, + expected_input_bounded, + expected_optimized_unbounded, + expected_optimized_bounded, + expected_optimized_bounded_sort_preserve, + physical_plan, + source_unbounded + ); Ok(()) } @@ -820,24 +1398,23 @@ mod tests { } fn repartition_exec_hash(input: Arc) -> Arc { + let input_schema = input.schema(); Arc::new( RepartitionExec::try_new( input, - Partitioning::Hash(vec![Arc::new(Column::new("c1", 0))], 8), + Partitioning::Hash(vec![col("c", &input_schema).unwrap()], 8), ) .unwrap(), ) } - fn filter_exec( - input: Arc, - schema: &SchemaRef, - ) -> Arc { + fn filter_exec(input: Arc) -> Arc { + let input_schema = input.schema(); let predicate = expressions::binary( - col("c", schema).unwrap(), + col("c", &input_schema).unwrap(), Operator::Gt, expressions::lit(3i32), - schema, + &input_schema, ) .unwrap(); Arc::new(FilterExec::try_new(predicate, input).unwrap()) @@ -855,11 +1432,15 @@ mod tests { left: Arc, right: Arc, ) -> Arc { + let left_on = col("c", &left.schema()).unwrap(); + let right_on = col("c", &right.schema()).unwrap(); + let left_col = left_on.as_any().downcast_ref::().unwrap(); + let right_col = right_on.as_any().downcast_ref::().unwrap(); Arc::new( HashJoinExec::try_new( left, right, - vec![(Column::new("c", 1), Column::new("c", 1))], + vec![(left_col.clone(), right_col.clone())], None, &JoinType::Inner, PartitionMode::Partitioned, @@ -879,12 +1460,33 @@ mod tests { Ok(schema) } + // creates a stream exec source for the test purposes + fn stream_exec_ordered( + schema: &SchemaRef, + sort_exprs: impl IntoIterator, + ) -> Arc { + let sort_exprs = 
sort_exprs.into_iter().collect(); + let projection: Vec = vec![0, 2, 3]; + + Arc::new( + StreamingTableExec::try_new( + schema.clone(), + vec![Arc::new(TestStreamPartition { + schema: schema.clone(), + }) as _], + Some(&projection), + vec![sort_exprs], + true, + ) + .unwrap(), + ) + } + // creates a csv exec source for the test purposes // projection and has_header parameters are given static due to testing needs fn csv_exec_sorted( schema: &SchemaRef, sort_exprs: impl IntoIterator, - infinite_source: bool, ) -> Arc { let sort_exprs = sort_exprs.into_iter().collect(); let projection: Vec = vec![0, 2, 3]; @@ -902,7 +1504,6 @@ mod tests { limit: None, table_partition_cols: vec![], output_ordering: vec![sort_exprs], - infinite_source, }, true, 0, @@ -911,11 +1512,4 @@ mod tests { FileCompressionType::UNCOMPRESSED, )) } - - // Util function to get string representation of a physical plan - fn get_plan_string(plan: &Arc) -> Vec { - let formatted = displayable(plan.as_ref()).indent(true).to_string(); - let actual: Vec<&str> = formatted.trim().lines().collect(); - actual.iter().map(|elem| elem.to_string()).collect() - } } diff --git a/datafusion/core/src/physical_optimizer/sort_pushdown.rs b/datafusion/core/src/physical_optimizer/sort_pushdown.rs index b9502d92ac12f..f0a8c8cfd3cba 100644 --- a/datafusion/core/src/physical_optimizer/sort_pushdown.rs +++ b/datafusion/core/src/physical_optimizer/sort_pushdown.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. +use std::borrow::Cow; use std::sync::Arc; use crate::physical_optimizer::utils::{ @@ -28,7 +29,7 @@ use crate::physical_plan::repartition::RepartitionExec; use crate::physical_plan::sorts::sort::SortExec; use crate::physical_plan::{with_new_children_if_necessary, ExecutionPlan}; -use datafusion_common::tree_node::{Transformed, TreeNode, VisitRecursion}; +use datafusion_common::tree_node::{Transformed, TreeNode}; use datafusion_common::{plan_err, DataFusionError, JoinSide, Result}; use datafusion_expr::JoinType; use datafusion_physical_expr::expressions::Column; @@ -36,8 +37,6 @@ use datafusion_physical_expr::{ LexRequirementRef, PhysicalSortExpr, PhysicalSortRequirement, }; -use itertools::izip; - /// This is a "data class" we use within the [`EnforceSorting`] rule to push /// down [`SortExec`] in the plan. In some cases, we can reduce the total /// computational cost by pushing down `SortExec`s through some executors. @@ -49,112 +48,92 @@ pub(crate) struct SortPushDown { pub plan: Arc, /// Parent required sort ordering required_ordering: Option>, - /// The adjusted request sort ordering to children. - /// By default they are the same as the plan's required input ordering, but can be adjusted based on parent required sort ordering properties. - adjusted_request_ordering: Vec>>, + children_nodes: Vec, } impl SortPushDown { - pub fn init(plan: Arc) -> Self { - let request_ordering = plan.required_input_ordering(); - SortPushDown { + /// Creates an empty tree with empty `required_ordering`'s. 
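For orientation on the refactor here: `SortPushDown` becomes a tree that mirrors the physical plan, where each node carries the ordering requirement pushed down by its parent plus one child node per plan child, and `assign_initial_requirements` seeds the children from the operator's own input requirements. A minimal, dependency-free sketch of that shape, with plain strings standing in for DataFusion's sort requirement types:

```rust
#[derive(Debug, Clone)]
struct ReqNode {
    name: &'static str,
    /// Ordering requirement pushed down from the parent, if any.
    required: Option<&'static str>,
    /// One node per child of the mirrored plan operator.
    children: Vec<ReqNode>,
}

impl ReqNode {
    fn new(name: &'static str, children: Vec<ReqNode>) -> Self {
        Self { name, required: None, children }
    }

    /// Analogue of `assign_initial_requirements`: hand each child the
    /// requirement the current operator wants from that input.
    fn assign_initial_requirements(&mut self, reqs: Vec<Option<&'static str>>) {
        for (child, req) in self.children.iter_mut().zip(reqs) {
            child.required = req;
        }
    }
}

fn main() {
    let mut root =
        ReqNode::new("SortExec", vec![ReqNode::new("RepartitionExec", vec![])]);
    root.assign_initial_requirements(vec![Some("a@0 ASC NULLS LAST")]);
    assert_eq!(root.children[0].required, Some("a@0 ASC NULLS LAST"));
}
```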
+ pub fn new(plan: Arc) -> Self { + let children = plan.children(); + Self { plan, required_ordering: None, - adjusted_request_ordering: request_ordering, + children_nodes: children.into_iter().map(Self::new).collect(), } } - pub fn children(&self) -> Vec { - izip!( - self.plan.children().into_iter(), - self.adjusted_request_ordering.clone().into_iter(), - ) - .map(|(child, from_parent)| { - let child_request_ordering = child.required_input_ordering(); - SortPushDown { - plan: child, - required_ordering: from_parent, - adjusted_request_ordering: child_request_ordering, - } - }) - .collect() + /// Assigns the ordering requirement of the root node to the its children. + pub fn assign_initial_requirements(&mut self) { + let reqs = self.plan.required_input_ordering(); + for (child, requirement) in self.children_nodes.iter_mut().zip(reqs) { + child.required_ordering = requirement; + } } } impl TreeNode for SortPushDown { - fn apply_children(&self, op: &mut F) -> Result - where - F: FnMut(&Self) -> Result, - { - let children = self.children(); - for child in children { - match op(&child)? { - VisitRecursion::Continue => {} - VisitRecursion::Skip => return Ok(VisitRecursion::Continue), - VisitRecursion::Stop => return Ok(VisitRecursion::Stop), - } - } - - Ok(VisitRecursion::Continue) + fn children_nodes(&self) -> Vec> { + self.children_nodes.iter().map(Cow::Borrowed).collect() } fn map_children(mut self, transform: F) -> Result where F: FnMut(Self) -> Result, { - let children = self.children(); - if !children.is_empty() { - let children_plans = children + if !self.children_nodes.is_empty() { + self.children_nodes = self + .children_nodes .into_iter() .map(transform) - .map(|r| r.map(|s| s.plan)) - .collect::>>()?; - - match with_new_children_if_necessary(self.plan, children_plans)? { - Transformed::Yes(plan) | Transformed::No(plan) => { - self.plan = plan; - } - } - }; + .collect::>()?; + self.plan = with_new_children_if_necessary( + self.plan, + self.children_nodes.iter().map(|c| c.plan.clone()).collect(), + )? + .into(); + } Ok(self) } } pub(crate) fn pushdown_sorts( - requirements: SortPushDown, + mut requirements: SortPushDown, ) -> Result> { let plan = &requirements.plan; let parent_required = requirements.required_ordering.as_deref().unwrap_or(&[]); + if let Some(sort_exec) = plan.as_any().downcast_ref::() { - let new_plan = if !plan + if !plan .equivalence_properties() .ordering_satisfy_requirement(parent_required) { // If the current plan is a SortExec, modify it to satisfy parent requirements: let mut new_plan = sort_exec.input().clone(); add_sort_above(&mut new_plan, parent_required, sort_exec.fetch()); - new_plan - } else { - requirements.plan + requirements.plan = new_plan; }; - let required_ordering = new_plan + + let required_ordering = requirements + .plan .output_ordering() .map(PhysicalSortRequirement::from_sort_exprs) .unwrap_or_default(); // Since new_plan is a SortExec, we can safely get the 0th index. - let child = new_plan.children().swap_remove(0); + let mut child = requirements.children_nodes.swap_remove(0); if let Some(adjusted) = - pushdown_requirement_to_children(&child, &required_ordering)? + pushdown_requirement_to_children(&child.plan, &required_ordering)? 
{ + for (c, o) in child.children_nodes.iter_mut().zip(adjusted) { + c.required_ordering = o; + } // Can push down requirements - Ok(Transformed::Yes(SortPushDown { - plan: child, - required_ordering: None, - adjusted_request_ordering: adjusted, - })) + child.required_ordering = None; + Ok(Transformed::Yes(child)) } else { // Can not push down requirements - Ok(Transformed::Yes(SortPushDown::init(new_plan))) + let mut empty_node = SortPushDown::new(requirements.plan); + empty_node.assign_initial_requirements(); + Ok(Transformed::Yes(empty_node)) } } else { // Executors other than SortExec @@ -163,23 +142,27 @@ pub(crate) fn pushdown_sorts( .ordering_satisfy_requirement(parent_required) { // Satisfies parent requirements, immediately return. - return Ok(Transformed::Yes(SortPushDown { - required_ordering: None, - ..requirements - })); + let reqs = requirements.plan.required_input_ordering(); + for (child, order) in requirements.children_nodes.iter_mut().zip(reqs) { + child.required_ordering = order; + } + return Ok(Transformed::Yes(requirements)); } // Can not satisfy the parent requirements, check whether the requirements can be pushed down: if let Some(adjusted) = pushdown_requirement_to_children(plan, parent_required)? { - Ok(Transformed::Yes(SortPushDown { - plan: requirements.plan, - required_ordering: None, - adjusted_request_ordering: adjusted, - })) + for (c, o) in requirements.children_nodes.iter_mut().zip(adjusted) { + c.required_ordering = o; + } + requirements.required_ordering = None; + Ok(Transformed::Yes(requirements)) } else { // Can not push down requirements, add new SortExec: let mut new_plan = requirements.plan; add_sort_above(&mut new_plan, parent_required, None); - Ok(Transformed::Yes(SortPushDown::init(new_plan))) + let mut new_empty = SortPushDown::new(new_plan); + new_empty.assign_initial_requirements(); + // Can not push down requirements + Ok(Transformed::Yes(new_empty)) } } } @@ -297,10 +280,11 @@ fn pushdown_requirement_to_children( // TODO: Add support for Projection push down } -/// Determine the children requirements -/// If the children requirements are more specific, do not push down the parent requirements -/// If the the parent requirements are more specific, push down the parent requirements -/// If they are not compatible, need to add Sort. +/// Determine children requirements: +/// - If children requirements are more specific, do not push down parent +/// requirements. +/// - If parent requirements are more specific, push down parent requirements. +/// - If they are not compatible, need to add a sort. fn determine_children_requirement( parent_required: LexRequirementRef, request_child: LexRequirementRef, @@ -310,18 +294,15 @@ fn determine_children_requirement( .equivalence_properties() .requirements_compatible(request_child, parent_required) { - // request child requirements are more specific, no need to push down the parent requirements + // Child requirements are more specific, no need to push down. 
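The three cases documented on `determine_children_requirement` can be pictured with a toy prefix check over lexicographic requirements; this is an intentional simplification (the real comparison goes through equivalence properties), with made-up column/direction strings:

```rust
/// Toy model: a requirement is a list of "col dir" strings, and `child`
/// satisfies `parent` when `parent` is a prefix of `child`, i.e. the child's
/// requirement is at least as specific.
fn is_prefix(parent: &[&str], child: &[&str]) -> bool {
    child.len() >= parent.len() && child.iter().zip(parent).all(|(c, p)| c == p)
}

fn main() {
    let parent = ["a ASC"];
    let child = ["a ASC", "b ASC"];
    // Child is more specific: nothing needs to be pushed down.
    assert!(is_prefix(&parent, &child));
    // Parent is more specific: push the parent's requirement down.
    assert!(!is_prefix(&child, &parent));
    // Incompatible orderings: a SortExec has to be added.
    assert!(!is_prefix(&["c ASC"], &child));
}
```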
RequirementsCompatibility::Satisfy } else if child_plan .equivalence_properties() .requirements_compatible(parent_required, request_child) { - // parent requirements are more specific, adjust the request child requirements and push down the new requirements - let adjusted = if parent_required.is_empty() { - None - } else { - Some(parent_required.to_vec()) - }; + // Parent requirements are more specific, adjust child's requirements + // and push down the new requirements: + let adjusted = (!parent_required.is_empty()).then(|| parent_required.to_vec()); RequirementsCompatibility::Compatible(adjusted) } else { RequirementsCompatibility::NonCompatible @@ -424,9 +405,7 @@ fn shift_right_required( let new_right_required: Vec = parent_required .iter() .filter_map(|r| { - let Some(col) = r.expr.as_any().downcast_ref::() else { - return None; - }; + let col = r.expr.as_any().downcast_ref::()?; if col.index() < left_columns_len { return None; diff --git a/datafusion/core/src/physical_optimizer/test_utils.rs b/datafusion/core/src/physical_optimizer/test_utils.rs index 159ee50890752..debafefe39ab9 100644 --- a/datafusion/core/src/physical_optimizer/test_utils.rs +++ b/datafusion/core/src/physical_optimizer/test_utils.rs @@ -35,17 +35,17 @@ use crate::physical_plan::sorts::sort::SortExec; use crate::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; use crate::physical_plan::union::UnionExec; use crate::physical_plan::windows::create_window_expr; -use crate::physical_plan::{ExecutionPlan, Partitioning}; +use crate::physical_plan::{ExecutionPlan, InputOrderMode, Partitioning}; use crate::prelude::{CsvReadOptions, SessionContext}; use arrow_schema::{Schema, SchemaRef, SortOptions}; use datafusion_common::{JoinType, Statistics}; use datafusion_execution::object_store::ObjectStoreUrl; -use datafusion_expr::{AggregateFunction, WindowFrame, WindowFunction}; +use datafusion_expr::{AggregateFunction, WindowFrame, WindowFunctionDefinition}; use datafusion_physical_expr::expressions::col; use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr}; -use datafusion_physical_plan::windows::PartitionSearchMode; +use crate::datasource::stream::{StreamConfig, StreamTable}; use async_trait::async_trait; async fn register_current_csv( @@ -55,14 +55,19 @@ async fn register_current_csv( ) -> Result<()> { let testdata = crate::test_util::arrow_test_data(); let schema = crate::test_util::aggr_test_schema(); - ctx.register_csv( - table_name, - &format!("{testdata}/csv/aggregate_test_100.csv"), - CsvReadOptions::new() - .schema(&schema) - .mark_infinite(infinite), - ) - .await?; + let path = format!("{testdata}/csv/aggregate_test_100.csv"); + + match infinite { + true => { + let config = StreamConfig::new_file(schema, path.into()); + ctx.register_table(table_name, Arc::new(StreamTable::new(Arc::new(config))))?; + } + false => { + ctx.register_csv(table_name, &path, CsvReadOptions::new().schema(&schema)) + .await?; + } + } + Ok(()) } @@ -229,7 +234,7 @@ pub fn bounded_window_exec( Arc::new( crate::physical_plan::windows::BoundedWindowAggExec::try_new( vec![create_window_expr( - &WindowFunction::AggregateFunction(AggregateFunction::Count), + &WindowFunctionDefinition::AggregateFunction(AggregateFunction::Count), "count".to_owned(), &[col(col_name, &schema).unwrap()], &[], @@ -240,7 +245,7 @@ pub fn bounded_window_exec( .unwrap()], input.clone(), vec![], - PartitionSearchMode::Sorted, + InputOrderMode::Sorted, ) .unwrap(), ) @@ -273,7 +278,6 @@ pub fn parquet_exec(schema: &SchemaRef) -> Arc { limit: None, 
table_partition_cols: vec![], output_ordering: vec![], - infinite_source: false, }, None, None, @@ -297,7 +301,6 @@ pub fn parquet_exec_sorted( limit: None, table_partition_cols: vec![], output_ordering: vec![sort_exprs], - infinite_source: false, }, None, None, @@ -328,7 +331,7 @@ pub fn spr_repartition_exec(input: Arc) -> Arc) -> Arc { PhysicalGroupBy::default(), vec![], vec![], - vec![], input, schema, ) diff --git a/datafusion/core/src/physical_optimizer/topk_aggregation.rs b/datafusion/core/src/physical_optimizer/topk_aggregation.rs index e0a8da82e35fc..dd02614203043 100644 --- a/datafusion/core/src/physical_optimizer/topk_aggregation.rs +++ b/datafusion/core/src/physical_optimizer/topk_aggregation.rs @@ -73,7 +73,6 @@ impl TopKAggregation { aggr.group_by().clone(), aggr.aggr_expr().to_vec(), aggr.filter_expr().to_vec(), - aggr.order_by_expr().to_vec(), aggr.input().clone(), aggr.input_schema(), ) @@ -118,7 +117,7 @@ impl TopKAggregation { } Ok(Transformed::No(plan)) }; - let child = transform_down_mut(child.clone(), &mut closure).ok()?; + let child = child.clone().transform_down_mut(&mut closure).ok()?; let sort = SortExec::new(sort.expr().to_vec(), child) .with_fetch(sort.fetch()) .with_preserve_partitioning(sort.preserve_partitioning()); @@ -126,17 +125,6 @@ impl TopKAggregation { } } -fn transform_down_mut( - me: Arc, - op: &mut F, -) -> Result> -where - F: FnMut(Arc) -> Result>>, -{ - let after_op = op(me)?.into(); - after_op.map_children(|node| transform_down_mut(node, op)) -} - impl Default for TopKAggregation { fn default() -> Self { Self::new() diff --git a/datafusion/core/src/physical_optimizer/utils.rs b/datafusion/core/src/physical_optimizer/utils.rs index 530df374ca7c0..f8063e9694223 100644 --- a/datafusion/core/src/physical_optimizer/utils.rs +++ b/datafusion/core/src/physical_optimizer/utils.rs @@ -17,83 +17,18 @@ //! Collection of utility functions that are leveraged by the query optimizer rules -use std::fmt; -use std::fmt::Formatter; use std::sync::Arc; use crate::physical_plan::coalesce_partitions::CoalescePartitionsExec; -use crate::physical_plan::limit::{GlobalLimitExec, LocalLimitExec}; use crate::physical_plan::repartition::RepartitionExec; use crate::physical_plan::sorts::sort::SortExec; use crate::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; use crate::physical_plan::union::UnionExec; use crate::physical_plan::windows::{BoundedWindowAggExec, WindowAggExec}; -use crate::physical_plan::{displayable, ExecutionPlan}; +use crate::physical_plan::ExecutionPlan; use datafusion_physical_expr::{LexRequirementRef, PhysicalSortRequirement}; - -/// This object implements a tree that we use while keeping track of paths -/// leading to [`SortExec`]s. 
-#[derive(Debug, Clone)] -pub(crate) struct ExecTree { - /// The `ExecutionPlan` associated with this node - pub plan: Arc, - /// Child index of the plan in its parent - pub idx: usize, - /// Children of the plan that would need updating if we remove leaf executors - pub children: Vec, -} - -impl fmt::Display for ExecTree { - fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - let plan_string = get_plan_string(&self.plan); - write!(f, "\nidx: {:?}", self.idx)?; - write!(f, "\nplan: {:?}", plan_string)?; - for child in self.children.iter() { - write!(f, "\nexec_tree:{}", child)?; - } - writeln!(f) - } -} - -impl ExecTree { - /// Create new Exec tree - pub fn new( - plan: Arc, - idx: usize, - children: Vec, - ) -> Self { - ExecTree { - plan, - idx, - children, - } - } -} - -/// Get `ExecTree` for each child of the plan if they are tracked. -/// # Arguments -/// -/// * `n_children` - Children count of the plan of interest -/// * `onward` - Contains `Some(ExecTree)` of the plan tracked. -/// - Contains `None` is plan is not tracked. -/// -/// # Returns -/// -/// A `Vec>` that contains tracking information of each child. -/// If a child is `None`, it is not tracked. If `Some(ExecTree)` child is tracked also. -pub(crate) fn get_children_exectrees( - n_children: usize, - onward: &Option, -) -> Vec> { - let mut children_onward = vec![None; n_children]; - if let Some(exec_tree) = &onward { - for child in &exec_tree.children { - children_onward[child.idx] = Some(child.clone()); - } - } - children_onward -} +use datafusion_physical_plan::limit::{GlobalLimitExec, LocalLimitExec}; /// This utility function adds a `SortExec` above an operator according to the /// given ordering requirements while preserving the original partitioning. @@ -154,10 +89,3 @@ pub fn is_union(plan: &Arc) -> bool { pub fn is_repartition(plan: &Arc) -> bool { plan.as_any().is::() } - -/// Utility function yielding a string representation of the given [`ExecutionPlan`]. 
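Both copies of `get_plan_string` disappear in this diff; the behavior they provided is simply the indented plan display split into lines, which is what the expected-plan string arrays in the tests compare against. A reference sketch of an equivalent helper, with the `datafusion::physical_plan` import path assumed:

```rust
use std::sync::Arc;

use datafusion::physical_plan::{displayable, ExecutionPlan};

/// Render a plan as one indented string per operator, matching the format of
/// the expected-plan arrays used in the optimizer tests.
fn plan_lines(plan: &Arc<dyn ExecutionPlan>) -> Vec<String> {
    displayable(plan.as_ref())
        .indent(true)
        .to_string()
        .trim()
        .lines()
        .map(str::to_string)
        .collect()
}
```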
-pub fn get_plan_string(plan: &Arc) -> Vec { - let formatted = displayable(plan.as_ref()).indent(true).to_string(); - let actual: Vec<&str> = formatted.trim().lines().collect(); - actual.iter().map(|elem| elem.to_string()).collect() -} diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index f941e88f3a36d..d696c55a8c131 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -27,7 +27,6 @@ use crate::datasource::file_format::csv::CsvFormat; use crate::datasource::file_format::json::JsonFormat; #[cfg(feature = "parquet")] use crate::datasource::file_format::parquet::ParquetFormat; -use crate::datasource::file_format::write::FileWriterMode; use crate::datasource::file_format::FileFormat; use crate::datasource::listing::ListingTableUrl; use crate::datasource::physical_plan::FileSinkConfig; @@ -64,12 +63,10 @@ use crate::physical_plan::sorts::sort::SortExec; use crate::physical_plan::union::UnionExec; use crate::physical_plan::unnest::UnnestExec; use crate::physical_plan::values::ValuesExec; -use crate::physical_plan::windows::{ - BoundedWindowAggExec, PartitionSearchMode, WindowAggExec, -}; +use crate::physical_plan::windows::{BoundedWindowAggExec, WindowAggExec}; use crate::physical_plan::{ - aggregates, displayable, udaf, windows, AggregateExpr, ExecutionPlan, Partitioning, - PhysicalExpr, WindowExpr, + aggregates, displayable, udaf, windows, AggregateExpr, ExecutionPlan, InputOrderMode, + Partitioning, PhysicalExpr, WindowExpr, }; use arrow::compute::SortOptions; @@ -83,16 +80,18 @@ use datafusion_common::{ }; use datafusion_expr::dml::{CopyOptions, CopyTo}; use datafusion_expr::expr::{ - self, AggregateFunction, AggregateUDF, Alias, Between, BinaryExpr, Cast, - GetFieldAccess, GetIndexedField, GroupingSet, InList, Like, ScalarUDF, TryCast, + self, AggregateFunction, AggregateFunctionDefinition, Alias, Between, BinaryExpr, + Cast, GetFieldAccess, GetIndexedField, GroupingSet, InList, Like, TryCast, WindowFunction, }; -use datafusion_expr::expr_rewriter::{unalias, unnormalize_cols}; +use datafusion_expr::expr_rewriter::unnormalize_cols; use datafusion_expr::logical_plan::builder::wrap_projection_for_join_if_necessary; use datafusion_expr::{ - DescribeTable, DmlStatement, StringifiedPlan, WindowFrame, WindowFrameBound, WriteOp, + DescribeTable, DmlStatement, ScalarFunctionDefinition, StringifiedPlan, WindowFrame, + WindowFrameBound, WriteOp, }; use datafusion_physical_expr::expressions::Literal; +use datafusion_physical_plan::placeholder_row::PlaceholderRowExec; use datafusion_sql::utils::window_expr_common_partition_keys; use async_trait::async_trait; @@ -218,40 +217,49 @@ fn create_physical_name(e: &Expr, is_first_expr: bool) -> Result { Ok(name) } - Expr::ScalarFunction(func) => { - create_function_physical_name(&func.fun.to_string(), false, &func.args) - } - Expr::ScalarUDF(ScalarUDF { fun, args }) => { - create_function_physical_name(&fun.name, false, args) + Expr::ScalarFunction(fun) => { + // function should be resolved during `AnalyzerRule`s + if let ScalarFunctionDefinition::Name(_) = fun.func_def { + return internal_err!("Function `Expr` with name should be resolved."); + } + + create_function_physical_name(fun.name(), false, &fun.args) } Expr::WindowFunction(WindowFunction { fun, args, .. }) => { create_function_physical_name(&fun.to_string(), false, args) } Expr::AggregateFunction(AggregateFunction { - fun, + func_def, distinct, args, - .. 
- }) => create_function_physical_name(&fun.to_string(), *distinct, args), - Expr::AggregateUDF(AggregateUDF { - fun, - args, filter, order_by, - }) => { - // TODO: Add support for filter and order by in AggregateUDF - if filter.is_some() { - return exec_err!("aggregate expression with filter is not supported"); + }) => match func_def { + AggregateFunctionDefinition::BuiltIn(..) => { + create_function_physical_name(func_def.name(), *distinct, args) } - if order_by.is_some() { - return exec_err!("aggregate expression with order_by is not supported"); + AggregateFunctionDefinition::UDF(fun) => { + // TODO: Add support for filter and order by in AggregateUDF + if filter.is_some() { + return exec_err!( + "aggregate expression with filter is not supported" + ); + } + if order_by.is_some() { + return exec_err!( + "aggregate expression with order_by is not supported" + ); + } + let names = args + .iter() + .map(|e| create_physical_name(e, false)) + .collect::>>()?; + Ok(format!("{}({})", fun.name(), names.join(","))) } - let mut names = Vec::with_capacity(args.len()); - for e in args { - names.push(create_physical_name(e, false)?); + AggregateFunctionDefinition::Name(_) => { + internal_err!("Aggregate function `Expr` with name should be resolved.") } - Ok(format!("{}({})", fun.name, names.join(","))) - } + }, Expr::GroupingSet(grouping_set) => match grouping_set { GroupingSet::Rollup(exprs) => Ok(format!( "ROLLUP ({})", @@ -364,9 +372,8 @@ fn create_physical_name(e: &Expr, is_first_expr: bool) -> Result { Expr::Sort { .. } => { internal_err!("Create physical name does not support sort expression") } - Expr::Wildcard => internal_err!("Create physical name does not support wildcard"), - Expr::QualifiedWildcard { .. } => { - internal_err!("Create physical name does not support qualified wildcard") + Expr::Wildcard { .. } => { + internal_err!("Create physical name does not support wildcard") } Expr::Placeholder(_) => { internal_err!("Create physical name does not support placeholder") @@ -554,8 +561,7 @@ impl DefaultPhysicalPlanner { // doesn't know (nor should care) how the relation was // referred to in the query let filters = unnormalize_cols(filters.iter().cloned()); - let unaliased: Vec = filters.into_iter().map(unalias).collect(); - source.scan(session_state, projection.as_ref(), &unaliased, *fetch).await + source.scan(session_state, projection.as_ref(), &filters, *fetch).await } LogicalPlan::Copy(CopyTo{ input, @@ -565,11 +571,7 @@ impl DefaultPhysicalPlanner { copy_options, }) => { let input_exec = self.create_initial_plan(input, session_state).await?; - - // TODO: make this behavior configurable via options (should copy to create path/file as needed?) - // TODO: add additional configurable options for if existing files should be overwritten or - // appended to - let parsed_url = ListingTableUrl::parse_create_local_if_not_exists(output_url, !*single_file_output)?; + let parsed_url = ListingTableUrl::parse(output_url)?; let object_store_url = parsed_url.object_store(); let schema: Schema = (**input.schema()).clone().into(); @@ -591,8 +593,6 @@ impl DefaultPhysicalPlanner { file_groups: vec![], output_schema: Arc::new(schema), table_partition_cols: vec![], - unbounded_input: false, - writer_mode: FileWriterMode::PutMultipart, single_file_output: *single_file_output, overwrite: false, file_type_writer_options @@ -755,7 +755,7 @@ impl DefaultPhysicalPlanner { window_expr, input_exec, physical_partition_keys, - PartitionSearchMode::Sorted, + InputOrderMode::Sorted, )?) 
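The aggregate branch of `create_physical_name` now dispatches on `AggregateFunctionDefinition`. A reduced stand-in (not DataFusion's real types) showing that shape: built-in aggregates and UDAFs produce a display name, while an unresolved `Name` variant is treated as an internal error by the time physical planning runs:

```rust
enum AggDef {
    BuiltIn(String),
    Udf(String),
    /// Should have been resolved by analyzer rules before planning.
    Name(String),
}

fn physical_name(def: &AggDef, distinct: bool, args: &[String]) -> Result<String, String> {
    let arg_list = args.join(",");
    match def {
        AggDef::BuiltIn(name) if distinct => Ok(format!("{name}(DISTINCT {arg_list})")),
        AggDef::BuiltIn(name) => Ok(format!("{name}({arg_list})")),
        // Filters and ORDER BY on UDAFs are rejected earlier in the real code.
        AggDef::Udf(name) => Ok(format!("{name}({arg_list})")),
        AggDef::Name(_) => {
            Err("aggregate function name should have been resolved".to_string())
        }
    }
}

fn main() {
    let args = vec!["c2".to_string()];
    assert_eq!(
        physical_name(&AggDef::BuiltIn("SUM".into()), false, &args).unwrap(),
        "SUM(c2)"
    );
    assert!(physical_name(&AggDef::Name("sum".into()), false, &args).is_err());
}
```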
} else { Arc::new(WindowAggExec::try_new( @@ -794,14 +794,13 @@ impl DefaultPhysicalPlanner { }) .collect::>>()?; - let (aggregates, filters, order_bys) : (Vec<_>, Vec<_>, Vec<_>) = multiunzip(agg_filter); + let (aggregates, filters, _order_bys) : (Vec<_>, Vec<_>, Vec<_>) = multiunzip(agg_filter); let initial_aggr = Arc::new(AggregateExec::try_new( AggregateMode::Partial, groups.clone(), aggregates.clone(), filters.clone(), - order_bys, input_exec, physical_input_schema.clone(), )?); @@ -819,18 +818,14 @@ impl DefaultPhysicalPlanner { // To reflect such changes to subsequent stages, use the updated // `AggregateExpr`/`PhysicalSortExpr` objects. let updated_aggregates = initial_aggr.aggr_expr().to_vec(); - let updated_order_bys = initial_aggr.order_by_expr().to_vec(); - let (initial_aggr, next_partition_mode): ( - Arc, - AggregateMode, - ) = if can_repartition { + let next_partition_mode = if can_repartition { // construct a second aggregation with 'AggregateMode::FinalPartitioned' - (initial_aggr, AggregateMode::FinalPartitioned) + AggregateMode::FinalPartitioned } else { // construct a second aggregation, keeping the final column name equal to the // first aggregation and the expressions corresponding to the respective aggregate - (initial_aggr, AggregateMode::Final) + AggregateMode::Final }; let final_grouping_set = PhysicalGroupBy::new_single( @@ -846,7 +841,6 @@ impl DefaultPhysicalPlanner { final_grouping_set, updated_aggregates, filters, - updated_order_bys, initial_aggr, physical_input_schema.clone(), )?)) @@ -914,19 +908,14 @@ impl DefaultPhysicalPlanner { &input_schema, session_state, )?; - Ok(Arc::new(FilterExec::try_new(runtime_expr, physical_input)?)) + let selectivity = session_state.config().options().optimizer.default_filter_selectivity; + let filter = FilterExec::try_new(runtime_expr, physical_input)?; + Ok(Arc::new(filter.with_default_selectivity(selectivity)?)) } - LogicalPlan::Union(Union { inputs, schema }) => { + LogicalPlan::Union(Union { inputs, .. }) => { let physical_plans = self.create_initial_plan_multi(inputs.iter().map(|lp| lp.as_ref()), session_state).await?; - if schema.fields().len() < physical_plans[0].schema().fields().len() { - // `schema` could be a subset of the child schema. For example - // for query "select count(*) from (select a from t union all select a from t)" - // `schema` is empty but child schema contains one field `a`. - Ok(Arc::new(UnionExec::try_new_with_schema(physical_plans, schema.clone())?)) - } else { - Ok(Arc::new(UnionExec::new(physical_plans))) - } + Ok(Arc::new(UnionExec::new(physical_plans))) } LogicalPlan::Repartition(Repartition { input, @@ -1200,10 +1189,15 @@ impl DefaultPhysicalPlanner { } LogicalPlan::Subquery(_) => todo!(), LogicalPlan::EmptyRelation(EmptyRelation { - produce_one_row, + produce_one_row: false, schema, }) => Ok(Arc::new(EmptyExec::new( - *produce_one_row, + SchemaRef::new(schema.as_ref().to_owned().into()), + ))), + LogicalPlan::EmptyRelation(EmptyRelation { + produce_one_row: true, + schema, + }) => Ok(Arc::new(PlaceholderRowExec::new( SchemaRef::new(schema.as_ref().to_owned().into()), ))), LogicalPlan::SubqueryAlias(SubqueryAlias { input, .. 
}) => { @@ -1709,7 +1703,7 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter( ) -> Result { match e { Expr::AggregateFunction(AggregateFunction { - fun, + func_def, distinct, args, filter, @@ -1750,63 +1744,35 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter( ), None => None, }; - let ordering_reqs = order_by.clone().unwrap_or(vec![]); - let agg_expr = aggregates::create_aggregate_expr( - fun, - *distinct, - &args, - &ordering_reqs, - physical_input_schema, - name, - )?; - Ok((agg_expr, filter, order_by)) - } - Expr::AggregateUDF(AggregateUDF { - fun, - args, - filter, - order_by, - }) => { - let args = args - .iter() - .map(|e| { - create_physical_expr( - e, - logical_input_schema, + let (agg_expr, filter, order_by) = match func_def { + AggregateFunctionDefinition::BuiltIn(fun) => { + let ordering_reqs = order_by.clone().unwrap_or(vec![]); + let agg_expr = aggregates::create_aggregate_expr( + fun, + *distinct, + &args, + &ordering_reqs, physical_input_schema, - execution_props, + name, + )?; + (agg_expr, filter, order_by) + } + AggregateFunctionDefinition::UDF(fun) => { + let agg_expr = udaf::create_aggregate_expr( + fun, + &args, + physical_input_schema, + name, + ); + (agg_expr?, filter, order_by) + } + AggregateFunctionDefinition::Name(_) => { + return internal_err!( + "Aggregate function name should have been resolved" ) - }) - .collect::>>()?; - - let filter = match filter { - Some(e) => Some(create_physical_expr( - e, - logical_input_schema, - physical_input_schema, - execution_props, - )?), - None => None, - }; - let order_by = match order_by { - Some(e) => Some( - e.iter() - .map(|expr| { - create_physical_sort_expr( - expr, - logical_input_schema, - physical_input_schema, - execution_props, - ) - }) - .collect::>>()?, - ), - None => None, + } }; - - let agg_expr = - udaf::create_aggregate_expr(fun, &args, physical_input_schema, name); - Ok((agg_expr?, filter, order_by)) + Ok((agg_expr, filter, order_by)) } other => internal_err!("Invalid aggregate expression '{other:?}'"), } @@ -1894,13 +1860,26 @@ impl DefaultPhysicalPlanner { .await { Ok(input) => { + // This plan will includes statistics if show_statistics is on stringified_plans.push( displayable(input.as_ref()) .set_show_statistics(config.show_statistics) .to_stringified(e.verbose, InitialPhysicalPlan), ); - match self.optimize_internal( + // If the show_statisitcs is off, add another line to show statsitics in the case of explain verbose + if e.verbose && !config.show_statistics { + stringified_plans.push( + displayable(input.as_ref()) + .set_show_statistics(true) + .to_stringified( + e.verbose, + InitialPhysicalPlanWithStats, + ), + ); + } + + let optimized_plan = self.optimize_internal( input, session_state, |plan, optimizer| { @@ -1912,12 +1891,28 @@ impl DefaultPhysicalPlanner { .to_stringified(e.verbose, plan_type), ); }, - ) { - Ok(input) => stringified_plans.push( - displayable(input.as_ref()) - .set_show_statistics(config.show_statistics) - .to_stringified(e.verbose, FinalPhysicalPlan), - ), + ); + match optimized_plan { + Ok(input) => { + // This plan will includes statistics if show_statistics is on + stringified_plans.push( + displayable(input.as_ref()) + .set_show_statistics(config.show_statistics) + .to_stringified(e.verbose, FinalPhysicalPlan), + ); + + // If the show_statisitcs is off, add another line to show statsitics in the case of explain verbose + if e.verbose && !config.show_statistics { + stringified_plans.push( + displayable(input.as_ref()) + .set_show_statistics(true) + 
.to_stringified( + e.verbose, + FinalPhysicalPlanWithStats, + ), + ); + } + } Err(DataFusionError::Context(optimizer_name, e)) => { let plan_type = OptimizedPhysicalPlan { optimizer_name }; stringified_plans @@ -2016,7 +2011,7 @@ impl DefaultPhysicalPlanner { let mut column_names = StringBuilder::new(); let mut data_types = StringBuilder::new(); let mut is_nullables = StringBuilder::new(); - for (_, field) in table_schema.fields().iter().enumerate() { + for field in table_schema.fields() { column_names.append_value(field.name()); // "System supplied type" --> Use debug format of the datatype @@ -2516,6 +2511,27 @@ mod tests { Ok(()) } + #[tokio::test] + async fn aggregate_with_alias() -> Result<()> { + let schema = Arc::new(Schema::new(vec![ + Field::new("c1", DataType::Utf8, false), + Field::new("c2", DataType::UInt32, false), + ])); + + let logical_plan = scan_empty(None, schema.as_ref(), None)? + .aggregate(vec![col("c1")], vec![sum(col("c2"))])? + .project(vec![col("c1"), sum(col("c2")).alias("total_salary")])? + .build()?; + + let physical_plan = plan(&logical_plan).await?; + assert_eq!("c1", physical_plan.schema().field(0).name().as_str()); + assert_eq!( + "total_salary", + physical_plan.schema().field(1).name().as_str() + ); + Ok(()) + } + #[tokio::test] async fn test_explain() { let schema = Schema::new(vec![Field::new("id", DataType::Int32, false)]); @@ -2750,7 +2766,7 @@ mod tests { digraph { 1[shape=box label="ProjectionExec: expr=[id@0 + 2 as employee.id + Int32(2)]", tooltip=""] - 2[shape=box label="EmptyExec: produce_one_row=false", tooltip=""] + 2[shape=box label="EmptyExec", tooltip=""] 1 -> 2 [arrowhead=none, arrowtail=normal, dir=back] } // End DataFusion GraphViz Plan diff --git a/datafusion/core/src/test/mod.rs b/datafusion/core/src/test/mod.rs index aad5c19044ea9..ed5aa15e291b5 100644 --- a/datafusion/core/src/test/mod.rs +++ b/datafusion/core/src/test/mod.rs @@ -43,6 +43,7 @@ use arrow::record_batch::RecordBatch; use datafusion_common::{DataFusionError, FileType, Statistics}; use datafusion_execution::{SendableRecordBatchStream, TaskContext}; use datafusion_physical_expr::{Partitioning, PhysicalSortExpr}; +use datafusion_physical_plan::streaming::{PartitionStream, StreamingTableExec}; use datafusion_physical_plan::{DisplayAs, DisplayFormatType}; #[cfg(feature = "compression")] @@ -203,7 +204,6 @@ pub fn partitioned_csv_config( limit: None, table_partition_cols: vec![], output_ordering: vec![], - infinite_source: false, }) } @@ -277,7 +277,6 @@ fn make_decimal() -> RecordBatch { pub fn csv_exec_sorted( schema: &SchemaRef, sort_exprs: impl IntoIterator, - infinite_source: bool, ) -> Arc { let sort_exprs = sort_exprs.into_iter().collect(); @@ -291,7 +290,6 @@ pub fn csv_exec_sorted( limit: None, table_partition_cols: vec![], output_ordering: vec![sort_exprs], - infinite_source, }, false, 0, @@ -301,6 +299,67 @@ pub fn csv_exec_sorted( )) } +// construct a stream partition for test purposes +pub(crate) struct TestStreamPartition { + pub schema: SchemaRef, +} + +impl PartitionStream for TestStreamPartition { + fn schema(&self) -> &SchemaRef { + &self.schema + } + fn execute(&self, _ctx: Arc) -> SendableRecordBatchStream { + unreachable!() + } +} + +/// Create an unbounded stream exec +pub fn stream_exec_ordered( + schema: &SchemaRef, + sort_exprs: impl IntoIterator, +) -> Arc { + let sort_exprs = sort_exprs.into_iter().collect(); + + Arc::new( + StreamingTableExec::try_new( + schema.clone(), + vec![Arc::new(TestStreamPartition { + schema: schema.clone(), + }) as _], + 
None, + vec![sort_exprs], + true, + ) + .unwrap(), + ) +} + +/// Create a csv exec for tests +pub fn csv_exec_ordered( + schema: &SchemaRef, + sort_exprs: impl IntoIterator, +) -> Arc { + let sort_exprs = sort_exprs.into_iter().collect(); + + Arc::new(CsvExec::new( + FileScanConfig { + object_store_url: ObjectStoreUrl::parse("test:///").unwrap(), + file_schema: schema.clone(), + file_groups: vec![vec![PartitionedFile::new("file_path".to_string(), 100)]], + statistics: Statistics::new_unknown(schema), + projection: None, + limit: None, + table_partition_cols: vec![], + output_ordering: vec![sort_exprs], + }, + true, + 0, + b'"', + None, + FileCompressionType::UNCOMPRESSED, + )) +} + /// A mock execution plan that simply returns the provided statistics #[derive(Debug, Clone)] pub struct StatisticsExec { diff --git a/datafusion/core/src/test/object_store.rs b/datafusion/core/src/test/object_store.rs index 08cebb56cc772..d6f324a7f1f95 100644 --- a/datafusion/core/src/test/object_store.rs +++ b/datafusion/core/src/test/object_store.rs @@ -61,5 +61,6 @@ pub fn local_unpartitioned_file(path: impl AsRef) -> ObjectMeta last_modified: metadata.modified().map(chrono::DateTime::from).unwrap(), size: metadata.len() as usize, e_tag: None, + version: None, } } diff --git a/datafusion/core/src/test/variable.rs b/datafusion/core/src/test/variable.rs index a55513841561f..38207b42cb7b8 100644 --- a/datafusion/core/src/test/variable.rs +++ b/datafusion/core/src/test/variable.rs @@ -37,7 +37,7 @@ impl VarProvider for SystemVar { /// get system variable value fn get_value(&self, var_names: Vec) -> Result { let s = format!("{}-{}", "system-var", var_names.concat()); - Ok(ScalarValue::Utf8(Some(s))) + Ok(ScalarValue::from(s)) } fn get_type(&self, _: &[String]) -> Option { @@ -61,7 +61,7 @@ impl VarProvider for UserDefinedVar { fn get_value(&self, var_names: Vec) -> Result { if var_names[0] != "@integer" { let s = format!("{}-{}", "user-defined-var", var_names.concat()); - Ok(ScalarValue::Utf8(Some(s))) + Ok(ScalarValue::from(s)) } else { Ok(ScalarValue::Int32(Some(41))) } diff --git a/datafusion/core/src/test_util/mod.rs b/datafusion/core/src/test_util/mod.rs index c6b43de0c18d5..282b0f7079ee2 100644 --- a/datafusion/core/src/test_util/mod.rs +++ b/datafusion/core/src/test_util/mod.rs @@ -36,7 +36,6 @@ use crate::datasource::provider::TableProviderFactory; use crate::datasource::{empty::EmptyTable, provider_as_source, TableProvider}; use crate::error::Result; use crate::execution::context::{SessionState, TaskContext}; -use crate::execution::options::ReadOptions; use crate::logical_expr::{LogicalPlanBuilder, UNNAMED_TABLE}; use crate::physical_plan::{ DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, RecordBatchStream, @@ -58,6 +57,7 @@ use futures::Stream; pub use datafusion_common::test_util::parquet_test_data; pub use datafusion_common::test_util::{arrow_test_data, get_data_dir}; +use crate::datasource::stream::{StreamConfig, StreamTable}; pub use datafusion_common::{assert_batches_eq, assert_batches_sorted_eq}; /// Scan an empty data source, mainly used in tests @@ -342,30 +342,17 @@ impl RecordBatchStream for UnboundedStream { } /// This function creates an unbounded sorted file for testing purposes. 
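The diff reworks `register_current_csv` and `register_unbounded_file_with_ordering` to register unbounded test inputs through `StreamTable` rather than an infinite-marked listing table. A hedged sketch of that registration path in isolation; the crate paths, function name, and parameter set are assumptions built around the `StreamConfig`/`StreamTable` calls visible in the diff:

```rust
use std::path::Path;
use std::sync::Arc;

use datafusion::arrow::datatypes::SchemaRef;
use datafusion::datasource::stream::{StreamConfig, StreamTable};
use datafusion::error::Result;
use datafusion::logical_expr::Expr;
use datafusion::prelude::SessionContext;

/// Register `path` as an unbounded, ordered stream table under `table_name`.
fn register_unbounded_file(
    ctx: &SessionContext,
    schema: SchemaRef,
    path: &Path,
    table_name: &str,
    file_sort_order: Vec<Vec<Expr>>,
) -> Result<()> {
    // `new_file` and `with_order` mirror the calls used in the diff; exact
    // module paths outside the core crate are assumed here.
    let config = StreamConfig::new_file(schema, path.into()).with_order(file_sort_order);
    ctx.register_table(table_name, Arc::new(StreamTable::new(Arc::new(config))))?;
    Ok(())
}
```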
-pub async fn register_unbounded_file_with_ordering( +pub fn register_unbounded_file_with_ordering( ctx: &SessionContext, schema: SchemaRef, file_path: &Path, table_name: &str, file_sort_order: Vec>, - with_unbounded_execution: bool, ) -> Result<()> { - // Mark infinite and provide schema: - let fifo_options = CsvReadOptions::new() - .schema(schema.as_ref()) - .mark_infinite(with_unbounded_execution); - // Get listing options: - let options_sort = fifo_options - .to_listing_options(&ctx.copied_config()) - .with_file_sort_order(file_sort_order); + let config = + StreamConfig::new_file(schema, file_path.into()).with_order(file_sort_order); + // Register table: - ctx.register_listing_table( - table_name, - file_path.as_os_str().to_str().unwrap(), - options_sort, - Some(schema), - None, - ) - .await?; + ctx.register_table(table_name, Arc::new(StreamTable::new(Arc::new(config))))?; Ok(()) } diff --git a/datafusion/core/src/test_util/parquet.rs b/datafusion/core/src/test_util/parquet.rs index 0d11526703b46..336a6804637ae 100644 --- a/datafusion/core/src/test_util/parquet.rs +++ b/datafusion/core/src/test_util/parquet.rs @@ -113,6 +113,7 @@ impl TestParquetFile { last_modified: Default::default(), size, e_tag: None, + version: None, }; Ok(Self { @@ -155,7 +156,6 @@ impl TestParquetFile { limit: None, table_partition_cols: vec![], output_ordering: vec![], - infinite_source: false, }; let df_schema = self.schema.clone().to_dfschema_ref()?; diff --git a/datafusion/core/tests/custom_sources.rs b/datafusion/core/tests/custom_sources.rs index daf1ef41a297a..a9ea5cc2a35c8 100644 --- a/datafusion/core/tests/custom_sources.rs +++ b/datafusion/core/tests/custom_sources.rs @@ -30,7 +30,6 @@ use datafusion::execution::context::{SessionContext, SessionState, TaskContext}; use datafusion::logical_expr::{ col, Expr, LogicalPlan, LogicalPlanBuilder, TableScan, UNNAMED_TABLE, }; -use datafusion::physical_plan::empty::EmptyExec; use datafusion::physical_plan::expressions::PhysicalSortExpr; use datafusion::physical_plan::{ collect, ColumnStatistics, DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, @@ -42,6 +41,7 @@ use datafusion_common::project_schema; use datafusion_common::stats::Precision; use async_trait::async_trait; +use datafusion_physical_plan::placeholder_row::PlaceholderRowExec; use futures::stream::Stream; /// Also run all tests that are found in the `custom_sources_cases` directory @@ -256,9 +256,9 @@ async fn optimizers_catch_all_statistics() { let physical_plan = df.create_physical_plan().await.unwrap(); - // when the optimization kicks in, the source is replaced by an EmptyExec + // when the optimization kicks in, the source is replaced by an PlaceholderRowExec assert!( - contains_empty_exec(Arc::clone(&physical_plan)), + contains_place_holder_exec(Arc::clone(&physical_plan)), "Expected aggregate_statistics optimizations missing: {physical_plan:?}" ); @@ -283,12 +283,12 @@ async fn optimizers_catch_all_statistics() { assert_eq!(format!("{:?}", actual[0]), format!("{expected:?}")); } -fn contains_empty_exec(plan: Arc) -> bool { - if plan.as_any().is::() { +fn contains_place_holder_exec(plan: Arc) -> bool { + if plan.as_any().is::() { true } else if plan.children().len() != 1 { false } else { - contains_empty_exec(Arc::clone(&plan.children()[0])) + contains_place_holder_exec(Arc::clone(&plan.children()[0])) } } diff --git a/datafusion/core/tests/data/aggregate_agg_multi_order.csv b/datafusion/core/tests/data/aggregate_agg_multi_order.csv new file mode 100644 index 
0000000000000..e9a65ceee4aab --- /dev/null +++ b/datafusion/core/tests/data/aggregate_agg_multi_order.csv @@ -0,0 +1,11 @@ +c1,c2,c3 +1,20,0 +2,20,1 +3,10,2 +4,10,3 +5,30,4 +6,30,5 +7,30,6 +8,30,7 +9,30,8 +10,10,9 \ No newline at end of file diff --git a/datafusion/core/tests/data/empty.json b/datafusion/core/tests/data/empty.json new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/datafusion/core/tests/data/escape.csv b/datafusion/core/tests/data/escape.csv new file mode 100644 index 0000000000000..331a1e697329f --- /dev/null +++ b/datafusion/core/tests/data/escape.csv @@ -0,0 +1,11 @@ +c1,c2 +"id0","value\"0" +"id1","value\"1" +"id2","value\"2" +"id3","value\"3" +"id4","value\"4" +"id5","value\"5" +"id6","value\"6" +"id7","value\"7" +"id8","value\"8" +"id9","value\"9" diff --git a/datafusion/core/tests/data/quote.csv b/datafusion/core/tests/data/quote.csv new file mode 100644 index 0000000000000..d814884364095 --- /dev/null +++ b/datafusion/core/tests/data/quote.csv @@ -0,0 +1,11 @@ +c1,c2 +~id0~,~value0~ +~id1~,~value1~ +~id2~,~value2~ +~id3~,~value3~ +~id4~,~value4~ +~id5~,~value5~ +~id6~,~value6~ +~id7~,~value7~ +~id8~,~value8~ +~id9~,~value9~ diff --git a/datafusion/core/tests/dataframe/dataframe_functions.rs b/datafusion/core/tests/dataframe/dataframe_functions.rs index 9677003ec226f..fe56fc22ea8cc 100644 --- a/datafusion/core/tests/dataframe/dataframe_functions.rs +++ b/datafusion/core/tests/dataframe/dataframe_functions.rs @@ -31,6 +31,7 @@ use datafusion::prelude::*; use datafusion::execution::context::SessionContext; use datafusion::assert_batches_eq; +use datafusion_expr::expr::Alias; use datafusion_expr::{approx_median, cast}; async fn create_test_table() -> Result { @@ -186,6 +187,25 @@ async fn test_fn_approx_percentile_cont() -> Result<()> { assert_batches_eq!(expected, &batches); + // the arg2 parameter is a complex expr, but it can be evaluated to the literal value + let alias_expr = Expr::Alias(Alias::new( + cast(lit(0.5), DataType::Float32), + None::<&str>, + "arg_2".to_string(), + )); + let expr = approx_percentile_cont(col("b"), alias_expr); + let df = create_test_table().await?; + let expected = [ + "+--------------------------------------+", + "| APPROX_PERCENTILE_CONT(test.b,arg_2) |", + "+--------------------------------------+", + "| 10 |", + "+--------------------------------------+", + ]; + let batches = df.aggregate(vec![], vec![expr]).unwrap().collect().await?; + + assert_batches_eq!(expected, &batches); + Ok(()) } diff --git a/datafusion/core/tests/dataframe/mod.rs b/datafusion/core/tests/dataframe/mod.rs index 845d77581b59c..cca23ac6847c6 100644 --- a/datafusion/core/tests/dataframe/mod.rs +++ b/datafusion/core/tests/dataframe/mod.rs @@ -39,14 +39,13 @@ use datafusion::prelude::JoinType; use datafusion::prelude::{CsvReadOptions, ParquetReadOptions}; use datafusion::test_util::parquet_test_data; use datafusion::{assert_batches_eq, assert_batches_sorted_eq}; -use datafusion_common::{DataFusionError, ScalarValue, UnnestOptions}; +use datafusion_common::{assert_contains, DataFusionError, ScalarValue, UnnestOptions}; use datafusion_execution::config::SessionConfig; use datafusion_expr::expr::{GroupingSet, Sort}; -use datafusion_expr::Expr::Wildcard; use datafusion_expr::{ array_agg, avg, col, count, exists, expr, in_subquery, lit, max, out_ref_col, - scalar_subquery, sum, AggregateFunction, Expr, ExprSchemable, WindowFrame, - WindowFrameBound, WindowFrameUnits, WindowFunction, + scalar_subquery, sum, wildcard, AggregateFunction, Expr, 
ExprSchemable, WindowFrame, + WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, }; use datafusion_physical_expr::var_provider::{VarProvider, VarType}; @@ -64,8 +63,8 @@ async fn test_count_wildcard_on_sort() -> Result<()> { let df_results = ctx .table("t1") .await? - .aggregate(vec![col("b")], vec![count(Wildcard)])? - .sort(vec![count(Wildcard).sort(true, false)])? + .aggregate(vec![col("b")], vec![count(wildcard())])? + .sort(vec![count(wildcard()).sort(true, false)])? .explain(false, false)? .collect() .await?; @@ -99,8 +98,8 @@ async fn test_count_wildcard_on_where_in() -> Result<()> { Arc::new( ctx.table("t2") .await? - .aggregate(vec![], vec![count(Expr::Wildcard)])? - .select(vec![count(Expr::Wildcard)])? + .aggregate(vec![], vec![count(wildcard())])? + .select(vec![count(wildcard())])? .into_unoptimized_plan(), // Usually, into_optimized_plan() should be used here, but due to // https://github.com/apache/arrow-datafusion/issues/5771, @@ -136,8 +135,8 @@ async fn test_count_wildcard_on_where_exist() -> Result<()> { .filter(exists(Arc::new( ctx.table("t2") .await? - .aggregate(vec![], vec![count(Expr::Wildcard)])? - .select(vec![count(Expr::Wildcard)])? + .aggregate(vec![], vec![count(wildcard())])? + .select(vec![count(wildcard())])? .into_unoptimized_plan(), // Usually, into_optimized_plan() should be used here, but due to // https://github.com/apache/arrow-datafusion/issues/5771, @@ -171,8 +170,8 @@ async fn test_count_wildcard_on_window() -> Result<()> { .table("t1") .await? .select(vec![Expr::WindowFunction(expr::WindowFunction::new( - WindowFunction::AggregateFunction(AggregateFunction::Count), - vec![Expr::Wildcard], + WindowFunctionDefinition::AggregateFunction(AggregateFunction::Count), + vec![wildcard()], vec![], vec![Expr::Sort(Sort::new(Box::new(col("a")), false, true))], WindowFrame { @@ -202,17 +201,17 @@ async fn test_count_wildcard_on_aggregate() -> Result<()> { let sql_results = ctx .sql("select count(*) from t1") .await? - .select(vec![count(Expr::Wildcard)])? + .select(vec![count(wildcard())])? .explain(false, false)? .collect() .await?; - // add `.select(vec![count(Expr::Wildcard)])?` to make sure we can analyze all node instead of just top node. + // add `.select(vec![count(wildcard())])?` to make sure we can analyze all node instead of just top node. let df_results = ctx .table("t1") .await? - .aggregate(vec![], vec![count(Expr::Wildcard)])? - .select(vec![count(Expr::Wildcard)])? + .aggregate(vec![], vec![count(wildcard())])? + .select(vec![count(wildcard())])? .explain(false, false)? .collect() .await?; @@ -248,8 +247,8 @@ async fn test_count_wildcard_on_where_scalar_subquery() -> Result<()> { ctx.table("t2") .await? .filter(out_ref_col(DataType::UInt32, "t1.a").eq(col("t2.a")))? - .aggregate(vec![], vec![count(Wildcard)])? - .select(vec![col(count(Wildcard).to_string())])? + .aggregate(vec![], vec![count(wildcard())])? + .select(vec![col(count(wildcard()).to_string())])? 
.into_unoptimized_plan(), )) .gt(lit(ScalarValue::UInt8(Some(0)))), @@ -1324,6 +1323,113 @@ async fn unnest_array_agg() -> Result<()> { Ok(()) } +#[tokio::test] +async fn unnest_with_redundant_columns() -> Result<()> { + let mut shape_id_builder = UInt32Builder::new(); + let mut tag_id_builder = UInt32Builder::new(); + + for shape_id in 1..=3 { + for tag_id in 1..=3 { + shape_id_builder.append_value(shape_id as u32); + tag_id_builder.append_value((shape_id * 10 + tag_id) as u32); + } + } + + let batch = RecordBatch::try_from_iter(vec![ + ("shape_id", Arc::new(shape_id_builder.finish()) as ArrayRef), + ("tag_id", Arc::new(tag_id_builder.finish()) as ArrayRef), + ])?; + + let ctx = SessionContext::new(); + ctx.register_batch("shapes", batch)?; + let df = ctx.table("shapes").await?; + + let results = df.clone().collect().await?; + let expected = vec![ + "+----------+--------+", + "| shape_id | tag_id |", + "+----------+--------+", + "| 1 | 11 |", + "| 1 | 12 |", + "| 1 | 13 |", + "| 2 | 21 |", + "| 2 | 22 |", + "| 2 | 23 |", + "| 3 | 31 |", + "| 3 | 32 |", + "| 3 | 33 |", + "+----------+--------+", + ]; + assert_batches_sorted_eq!(expected, &results); + + // Doing an `array_agg` by `shape_id` produces: + let df = df + .clone() + .aggregate( + vec![col("shape_id")], + vec![array_agg(col("shape_id")).alias("shape_id2")], + )? + .unnest_column("shape_id2")? + .select(vec![col("shape_id")])?; + + let optimized_plan = df.clone().into_optimized_plan()?; + let expected = vec![ + "Projection: shapes.shape_id [shape_id:UInt32]", + " Unnest: shape_id2 [shape_id:UInt32, shape_id2:UInt32;N]", + " Aggregate: groupBy=[[shapes.shape_id]], aggr=[[ARRAY_AGG(shapes.shape_id) AS shape_id2]] [shape_id:UInt32, shape_id2:List(Field { name: \"item\", data_type: UInt32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} });N]", + " TableScan: shapes projection=[shape_id] [shape_id:UInt32]", + ]; + + let formatted = optimized_plan.display_indent_schema().to_string(); + let actual: Vec<&str> = formatted.trim().lines().collect(); + assert_eq!( + expected, actual, + "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" + ); + + let results = df.collect().await?; + let expected = [ + "+----------+", + "| shape_id |", + "+----------+", + "| 1 |", + "| 1 |", + "| 1 |", + "| 2 |", + "| 2 |", + "| 2 |", + "| 3 |", + "| 3 |", + "| 3 |", + "+----------+", + ]; + assert_batches_sorted_eq!(expected, &results); + + Ok(()) +} + +#[tokio::test] +async fn unnest_analyze_metrics() -> Result<()> { + const NUM_ROWS: usize = 5; + + let df = table_with_nested_types(NUM_ROWS).await?; + let results = df + .unnest_column("tags")? + .explain(false, true)? + .collect() + .await?; + let formatted = arrow::util::pretty::pretty_format_batches(&results) + .unwrap() + .to_string(); + assert_contains!(&formatted, "elapsed_compute="); + assert_contains!(&formatted, "input_batches=1"); + assert_contains!(&formatted, "input_rows=5"); + assert_contains!(&formatted, "output_rows=10"); + assert_contains!(&formatted, "output_batches=1"); + + Ok(()) +} + async fn create_test_table(name: &str) -> Result { let schema = Arc::new(Schema::new(vec![ Field::new("a", DataType::Utf8, false), diff --git a/datafusion/core/tests/fifo.rs b/datafusion/core/tests/fifo.rs index 7d9ea97f7b5bc..93c7f7368065c 100644 --- a/datafusion/core/tests/fifo.rs +++ b/datafusion/core/tests/fifo.rs @@ -17,42 +17,48 @@ //! This test demonstrates the DataFusion FIFO capabilities. //! 
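Aside (not part of the patch): the fifo.rs hunks below retire the `mark_infinite` / `register_unbounded_file_with_ordering` path in favour of the `StreamTable` API that earlier hunks in this diff introduce. A minimal sketch of the registration pattern the updated tests rely on, reusing only builder calls that appear in the patch itself; `ctx`, `schema`, `fifo_path` and the `"left"` table name are placeholder assumptions for the example.

```rust
// Sketch only (not in the patch): registering an unbounded FIFO CSV source
// through StreamConfig/StreamTable instead of the removed
// `mark_infinite` / `infinite_source` flags.
use std::path::PathBuf;
use std::sync::Arc;

use arrow_schema::SchemaRef;
use datafusion::datasource::stream::{StreamConfig, StreamTable};
use datafusion::prelude::SessionContext;
use datafusion_common::Result;

fn register_fifo_source(
    ctx: &SessionContext,
    schema: SchemaRef,
    fifo_path: PathBuf,
) -> Result<()> {
    // Ordering and header handling are expressed on the stream config; the
    // sort expression mirrors the one used later in this diff.
    let config = StreamConfig::new_file(schema, fifo_path)
        .with_order(vec![vec![datafusion_expr::col("a1").sort(true, false)]])
        .with_header(true);
    ctx.register_table("left", Arc::new(StreamTable::new(Arc::new(config))))?;
    Ok(())
}
```

The net effect of the migration is that unboundedness becomes a property of the stream-backed table provider rather than a flag carried on the CSV scan configuration.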
-#[cfg(not(target_os = "windows"))] +#[cfg(target_family = "unix")] #[cfg(test)] mod unix_test { - use arrow::array::Array; - use arrow::csv::ReaderBuilder; - use arrow::datatypes::{DataType, Field, Schema}; - use datafusion::test_util::register_unbounded_file_with_ordering; - use datafusion::{ - prelude::{CsvReadOptions, SessionConfig, SessionContext}, - test_util::{aggr_test_schema, arrow_test_data}, - }; - use datafusion_common::{exec_err, DataFusionError, Result}; - use futures::StreamExt; - use itertools::enumerate; - use nix::sys::stat; - use nix::unistd; - use rstest::*; use std::fs::{File, OpenOptions}; use std::io::Write; use std::path::PathBuf; use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; use std::thread; - use std::thread::JoinHandle; use std::time::{Duration, Instant}; + + use arrow::array::Array; + use arrow::csv::ReaderBuilder; + use arrow::datatypes::{DataType, Field, Schema}; + use arrow_schema::SchemaRef; + use futures::StreamExt; + use nix::sys::stat; + use nix::unistd; use tempfile::TempDir; + use tokio::task::{spawn_blocking, JoinHandle}; - // ! For the sake of the test, do not alter the numbers. ! - // Session batch size - const TEST_BATCH_SIZE: usize = 20; - // Number of lines written to FIFO - const TEST_DATA_SIZE: usize = 20_000; - // Number of lines what can be joined. Each joinable key produced 20 lines with - // aggregate_test_100 dataset. We will use these joinable keys for understanding - // incremental execution. - const TEST_JOIN_RATIO: f64 = 0.01; + use datafusion::datasource::stream::{StreamConfig, StreamTable}; + use datafusion::datasource::TableProvider; + use datafusion::{ + prelude::{CsvReadOptions, SessionConfig, SessionContext}, + test_util::{aggr_test_schema, arrow_test_data}, + }; + use datafusion_common::{exec_err, DataFusionError, Result}; + use datafusion_expr::Expr; + + /// Makes a TableProvider for a fifo file + fn fifo_table( + schema: SchemaRef, + path: impl Into, + sort: Vec>, + ) -> Arc { + let config = StreamConfig::new_file(schema, path.into()) + .with_order(sort) + .with_batch_size(TEST_BATCH_SIZE) + .with_header(true); + Arc::new(StreamTable::new(Arc::new(config))) + } fn create_fifo_file(tmp_dir: &TempDir, file_name: &str) -> Result { let file_path = tmp_dir.path().join(file_name); @@ -86,14 +92,46 @@ mod unix_test { Ok(()) } + fn create_writing_thread( + file_path: PathBuf, + header: String, + lines: Vec, + waiting_lock: Arc, + wait_until: usize, + ) -> JoinHandle<()> { + // Timeout for a long period of BrokenPipe error + let broken_pipe_timeout = Duration::from_secs(10); + let sa = file_path.clone(); + // Spawn a new thread to write to the FIFO file + spawn_blocking(move || { + let file = OpenOptions::new().write(true).open(sa).unwrap(); + // Reference time to use when deciding to fail the test + let execution_start = Instant::now(); + write_to_fifo(&file, &header, execution_start, broken_pipe_timeout).unwrap(); + for (cnt, line) in lines.iter().enumerate() { + while waiting_lock.load(Ordering::SeqCst) && cnt > wait_until { + thread::sleep(Duration::from_millis(50)); + } + write_to_fifo(&file, line, execution_start, broken_pipe_timeout).unwrap(); + } + drop(file); + }) + } + + // ! For the sake of the test, do not alter the numbers. ! + // Session batch size + const TEST_BATCH_SIZE: usize = 20; + // Number of lines written to FIFO + const TEST_DATA_SIZE: usize = 20_000; + // Number of lines what can be joined. Each joinable key produced 20 lines with + // aggregate_test_100 dataset. 
We will use these joinable keys for understanding + // incremental execution. + const TEST_JOIN_RATIO: f64 = 0.01; + // This test provides a relatively realistic end-to-end scenario where // we swap join sides to accommodate a FIFO source. - #[rstest] - #[timeout(std::time::Duration::from_secs(30))] #[tokio::test(flavor = "multi_thread", worker_threads = 8)] - async fn unbounded_file_with_swapped_join( - #[values(true, false)] unbounded_file: bool, - ) -> Result<()> { + async fn unbounded_file_with_swapped_join() -> Result<()> { // Create session context let config = SessionConfig::new() .with_batch_size(TEST_BATCH_SIZE) @@ -101,11 +139,10 @@ mod unix_test { .with_target_partitions(1); let ctx = SessionContext::new_with_config(config); // To make unbounded deterministic - let waiting = Arc::new(AtomicBool::new(unbounded_file)); + let waiting = Arc::new(AtomicBool::new(true)); // Create a new temporary FIFO file let tmp_dir = TempDir::new()?; - let fifo_path = - create_fifo_file(&tmp_dir, &format!("fifo_{unbounded_file:?}.csv"))?; + let fifo_path = create_fifo_file(&tmp_dir, "fifo_unbounded.csv")?; // Execution can calculated at least one RecordBatch after the number of // "joinable_lines_length" lines are read. let joinable_lines_length = @@ -129,7 +166,7 @@ mod unix_test { "a1,a2\n".to_owned(), lines, waiting.clone(), - joinable_lines_length, + joinable_lines_length * 2, ); // Data Schema @@ -137,15 +174,10 @@ mod unix_test { Field::new("a1", DataType::Utf8, false), Field::new("a2", DataType::UInt32, false), ])); - // Create a file with bounded or unbounded flag. - ctx.register_csv( - "left", - fifo_path.as_os_str().to_str().unwrap(), - CsvReadOptions::new() - .schema(schema.as_ref()) - .mark_infinite(unbounded_file), - ) - .await?; + + let provider = fifo_table(schema, fifo_path, vec![]); + ctx.register_table("left", provider).unwrap(); + // Register right table let schema = aggr_test_schema(); let test_data = arrow_test_data(); @@ -161,7 +193,7 @@ mod unix_test { while (stream.next().await).is_some() { waiting.store(false, Ordering::SeqCst); } - task.join().unwrap(); + task.await.unwrap(); Ok(()) } @@ -172,39 +204,10 @@ mod unix_test { Equal, } - fn create_writing_thread( - file_path: PathBuf, - header: String, - lines: Vec, - waiting_lock: Arc, - wait_until: usize, - ) -> JoinHandle<()> { - // Timeout for a long period of BrokenPipe error - let broken_pipe_timeout = Duration::from_secs(10); - // Spawn a new thread to write to the FIFO file - thread::spawn(move || { - let file = OpenOptions::new().write(true).open(file_path).unwrap(); - // Reference time to use when deciding to fail the test - let execution_start = Instant::now(); - write_to_fifo(&file, &header, execution_start, broken_pipe_timeout).unwrap(); - for (cnt, line) in enumerate(lines) { - while waiting_lock.load(Ordering::SeqCst) && cnt > wait_until { - thread::sleep(Duration::from_millis(50)); - } - write_to_fifo(&file, &line, execution_start, broken_pipe_timeout) - .unwrap(); - } - drop(file); - }) - } - // This test provides a relatively realistic end-to-end scenario where // we change the join into a [SymmetricHashJoin] to accommodate two // unbounded (FIFO) sources. 
- #[rstest] - #[timeout(std::time::Duration::from_secs(30))] - #[tokio::test(flavor = "multi_thread")] - #[ignore] + #[tokio::test] async fn unbounded_file_with_symmetric_join() -> Result<()> { // Create session context let config = SessionConfig::new() @@ -254,47 +257,30 @@ mod unix_test { Field::new("a1", DataType::UInt32, false), Field::new("a2", DataType::UInt32, false), ])); + // Specify the ordering: - let file_sort_order = vec![[datafusion_expr::col("a1")] - .into_iter() - .map(|e| { - let ascending = true; - let nulls_first = false; - e.sort(ascending, nulls_first) - }) - .collect::>()]; + let order = vec![vec![datafusion_expr::col("a1").sort(true, false)]]; + // Set unbounded sorted files read configuration - register_unbounded_file_with_ordering( - &ctx, - schema.clone(), - &left_fifo, - "left", - file_sort_order.clone(), - true, - ) - .await?; - register_unbounded_file_with_ordering( - &ctx, - schema, - &right_fifo, - "right", - file_sort_order, - true, - ) - .await?; + let provider = fifo_table(schema.clone(), left_fifo, order.clone()); + ctx.register_table("left", provider)?; + + let provider = fifo_table(schema.clone(), right_fifo, order); + ctx.register_table("right", provider)?; + // Execute the query, with no matching rows. (since key is modulus 10) let df = ctx .sql( "SELECT - t1.a1, - t1.a2, - t2.a1, - t2.a2 - FROM - left as t1 FULL - JOIN right as t2 ON t1.a2 = t2.a2 - AND t1.a1 > t2.a1 + 4 - AND t1.a1 < t2.a1 + 9", + t1.a1, + t1.a2, + t2.a1, + t2.a2 + FROM + left as t1 FULL + JOIN right as t2 ON t1.a2 = t2.a2 + AND t1.a1 > t2.a1 + 4 + AND t1.a1 < t2.a1 + 9", ) .await?; let mut stream = df.execute_stream().await?; @@ -313,7 +299,8 @@ mod unix_test { }; operations.push(op); } - tasks.into_iter().for_each(|jh| jh.join().unwrap()); + futures::future::try_join_all(tasks).await.unwrap(); + // The SymmetricHashJoin executor produces FULL join results at every // pruning, which happens before it reaches the end of input and more // than once. In this test, we feed partially joinable data to both @@ -368,8 +355,9 @@ mod unix_test { // Prevent move let (sink_fifo_path_thread, sink_display_fifo_path) = (sink_fifo_path.clone(), sink_fifo_path.display()); + // Spawn a new thread to read sink EXTERNAL TABLE. 
- tasks.push(thread::spawn(move || { + tasks.push(spawn_blocking(move || { let file = File::open(sink_fifo_path_thread).unwrap(); let schema = Arc::new(Schema::new(vec![ Field::new("a1", DataType::Utf8, false), @@ -377,7 +365,6 @@ mod unix_test { ])); let mut reader = ReaderBuilder::new(schema) - .with_header(true) .with_batch_size(TEST_BATCH_SIZE) .build(file) .map_err(|e| DataFusionError::Internal(e.to_string())) @@ -389,38 +376,35 @@ mod unix_test { })); // register second csv file with the SQL (create an empty file if not found) ctx.sql(&format!( - "CREATE EXTERNAL TABLE source_table ( + "CREATE UNBOUNDED EXTERNAL TABLE source_table ( a1 VARCHAR NOT NULL, a2 INT NOT NULL ) STORED AS CSV WITH HEADER ROW - OPTIONS ('UNBOUNDED' 'TRUE') LOCATION '{source_display_fifo_path}'" )) .await?; // register csv file with the SQL ctx.sql(&format!( - "CREATE EXTERNAL TABLE sink_table ( + "CREATE UNBOUNDED EXTERNAL TABLE sink_table ( a1 VARCHAR NOT NULL, a2 INT NOT NULL ) STORED AS CSV WITH HEADER ROW - OPTIONS ('UNBOUNDED' 'TRUE') LOCATION '{sink_display_fifo_path}'" )) .await?; let df = ctx - .sql( - "INSERT INTO sink_table - SELECT a1, a2 FROM source_table", - ) + .sql("INSERT INTO sink_table SELECT a1, a2 FROM source_table") .await?; + + // Start execution df.collect().await?; - tasks.into_iter().for_each(|jh| jh.join().unwrap()); + futures::future::try_join_all(tasks).await.unwrap(); Ok(()) } } diff --git a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs index 821f236af87b5..9069dbbd5850e 100644 --- a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs @@ -109,7 +109,6 @@ async fn run_aggregate_test(input1: Vec, group_by_columns: Vec<&str group_by.clone(), aggregate_expr.clone(), vec![None], - vec![None], running_source, schema.clone(), ) @@ -122,7 +121,6 @@ async fn run_aggregate_test(input1: Vec, group_by_columns: Vec<&str group_by.clone(), aggregate_expr.clone(), vec![None], - vec![None], usual_source, schema.clone(), ) diff --git a/datafusion/core/tests/fuzz_cases/sort_preserving_repartition_fuzz.rs b/datafusion/core/tests/fuzz_cases/sort_preserving_repartition_fuzz.rs index 4d3a2a15c5e93..df6499e9b1e47 100644 --- a/datafusion/core/tests/fuzz_cases/sort_preserving_repartition_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/sort_preserving_repartition_fuzz.rs @@ -17,22 +17,273 @@ #[cfg(test)] mod sp_repartition_fuzz_tests { - use arrow::compute::concat_batches; - use arrow_array::{ArrayRef, Int64Array, RecordBatch}; - use arrow_schema::SortOptions; - use datafusion::physical_plan::memory::MemoryExec; - use datafusion::physical_plan::repartition::RepartitionExec; - use datafusion::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; - use datafusion::physical_plan::{collect, ExecutionPlan, Partitioning}; - use datafusion::prelude::SessionContext; - use datafusion_execution::config::SessionConfig; - use datafusion_physical_expr::expressions::col; - use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr}; - use rand::rngs::StdRng; - use rand::{Rng, SeedableRng}; use std::sync::Arc; + + use arrow::compute::{concat_batches, lexsort, SortColumn}; + use arrow_array::{ArrayRef, Int64Array, RecordBatch, UInt64Array}; + use arrow_schema::{DataType, Field, Schema, SchemaRef, SortOptions}; + + use datafusion::physical_plan::{ + collect, + memory::MemoryExec, + metrics::{BaselineMetrics, ExecutionPlanMetricsSet}, + repartition::RepartitionExec, + 
sorts::sort_preserving_merge::SortPreservingMergeExec, + sorts::streaming_merge::streaming_merge, + stream::RecordBatchStreamAdapter, + ExecutionPlan, Partitioning, + }; + use datafusion::prelude::SessionContext; + use datafusion_common::Result; + use datafusion_execution::{ + config::SessionConfig, memory_pool::MemoryConsumer, SendableRecordBatchStream, + }; + use datafusion_physical_expr::{ + expressions::{col, Column}, + EquivalenceProperties, PhysicalExpr, PhysicalSortExpr, + }; use test_utils::add_empty_batches; + use datafusion_physical_expr::equivalence::EquivalenceClass; + use itertools::izip; + use rand::{rngs::StdRng, seq::SliceRandom, Rng, SeedableRng}; + + // Generate a schema which consists of 6 columns (a, b, c, d, e, f) + fn create_test_schema() -> Result { + let a = Field::new("a", DataType::Int32, true); + let b = Field::new("b", DataType::Int32, true); + let c = Field::new("c", DataType::Int32, true); + let d = Field::new("d", DataType::Int32, true); + let e = Field::new("e", DataType::Int32, true); + let f = Field::new("f", DataType::Int32, true); + let schema = Arc::new(Schema::new(vec![a, b, c, d, e, f])); + + Ok(schema) + } + + /// Construct a schema with random ordering + /// among column a, b, c, d + /// where + /// Column [a=f] (e.g they are aliases). + /// Column e is constant. + fn create_random_schema(seed: u64) -> Result<(SchemaRef, EquivalenceProperties)> { + let test_schema = create_test_schema()?; + let col_a = &col("a", &test_schema)?; + let col_b = &col("b", &test_schema)?; + let col_c = &col("c", &test_schema)?; + let col_d = &col("d", &test_schema)?; + let col_e = &col("e", &test_schema)?; + let col_f = &col("f", &test_schema)?; + let col_exprs = [col_a, col_b, col_c, col_d, col_e, col_f]; + + let mut eq_properties = EquivalenceProperties::new(test_schema.clone()); + // Define a and f are aliases + eq_properties.add_equal_conditions(col_a, col_f); + // Column e has constant value. + eq_properties = eq_properties.add_constants([col_e.clone()]); + + // Randomly order columns for sorting + let mut rng = StdRng::seed_from_u64(seed); + let mut remaining_exprs = col_exprs[0..4].to_vec(); // only a, b, c, d are sorted + + let options_asc = SortOptions { + descending: false, + nulls_first: false, + }; + + while !remaining_exprs.is_empty() { + let n_sort_expr = rng.gen_range(0..remaining_exprs.len() + 1); + remaining_exprs.shuffle(&mut rng); + + let ordering = remaining_exprs + .drain(0..n_sort_expr) + .map(|expr| PhysicalSortExpr { + expr: expr.clone(), + options: options_asc, + }) + .collect(); + + eq_properties.add_new_orderings([ordering]); + } + + Ok((test_schema, eq_properties)) + } + + // If we already generated a random result for one of the + // expressions in the equivalence classes. For other expressions in the same + // equivalence class use same result. This util gets already calculated result, when available. + fn get_representative_arr( + eq_group: &EquivalenceClass, + existing_vec: &[Option], + schema: SchemaRef, + ) -> Option { + for expr in eq_group.iter() { + let col = expr.as_any().downcast_ref::().unwrap(); + let (idx, _field) = schema.column_with_name(col.name()).unwrap(); + if let Some(res) = &existing_vec[idx] { + return Some(res.clone()); + } + } + None + } + + // Generate a table that satisfies the given equivalence properties; i.e. + // equivalences, ordering equivalences, and constants. 
+ fn generate_table_for_eq_properties( + eq_properties: &EquivalenceProperties, + n_elem: usize, + n_distinct: usize, + ) -> Result { + let mut rng = StdRng::seed_from_u64(23); + + let schema = eq_properties.schema(); + let mut schema_vec = vec![None; schema.fields.len()]; + + // Utility closure to generate random array + let mut generate_random_array = |num_elems: usize, max_val: usize| -> ArrayRef { + let values: Vec = (0..num_elems) + .map(|_| rng.gen_range(0..max_val) as u64) + .collect(); + Arc::new(UInt64Array::from_iter_values(values)) + }; + + // Fill constant columns + for constant in eq_properties.constants() { + let col = constant.as_any().downcast_ref::().unwrap(); + let (idx, _field) = schema.column_with_name(col.name()).unwrap(); + let arr = + Arc::new(UInt64Array::from_iter_values(vec![0; n_elem])) as ArrayRef; + schema_vec[idx] = Some(arr); + } + + // Fill columns based on ordering equivalences + for ordering in eq_properties.oeq_class().iter() { + let (sort_columns, indices): (Vec<_>, Vec<_>) = ordering + .iter() + .map(|PhysicalSortExpr { expr, options }| { + let col = expr.as_any().downcast_ref::().unwrap(); + let (idx, _field) = schema.column_with_name(col.name()).unwrap(); + let arr = generate_random_array(n_elem, n_distinct); + ( + SortColumn { + values: arr, + options: Some(*options), + }, + idx, + ) + }) + .unzip(); + + let sort_arrs = arrow::compute::lexsort(&sort_columns, None)?; + for (idx, arr) in izip!(indices, sort_arrs) { + schema_vec[idx] = Some(arr); + } + } + + // Fill columns based on equivalence groups + for eq_group in eq_properties.eq_group().iter() { + let representative_array = + get_representative_arr(eq_group, &schema_vec, schema.clone()) + .unwrap_or_else(|| generate_random_array(n_elem, n_distinct)); + + for expr in eq_group.iter() { + let col = expr.as_any().downcast_ref::().unwrap(); + let (idx, _field) = schema.column_with_name(col.name()).unwrap(); + schema_vec[idx] = Some(representative_array.clone()); + } + } + + let res: Vec<_> = schema_vec + .into_iter() + .zip(schema.fields.iter()) + .map(|(elem, field)| { + ( + field.name(), + // Generate random values for columns that do not occur in any of the groups (equivalence, ordering equivalence, constants) + elem.unwrap_or_else(|| generate_random_array(n_elem, n_distinct)), + ) + }) + .collect(); + + Ok(RecordBatch::try_from_iter(res)?) + } + + // This test checks for whether during sort preserving merge we can preserve all of the valid orderings + // successfully. If at the input we have orderings [a ASC, b ASC], [c ASC, d ASC] + // After sort preserving merge orderings [a ASC, b ASC], [c ASC, d ASC] should still be valid. 
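Aside (not part of the patch): the fuzz test that follows asserts a single invariant, namely that every ordering valid at the inputs is still valid after `streaming_merge`. A tiny self-contained illustration of that check with made-up array values; it mirrors the `evaluate_to_sort_column` / `lexsort` comparison the test performs on the merged output.

```rust
// Illustration only (assumed data, not from the patch): lexsort over an
// ordering that the data already satisfies must be a no-op.
use std::sync::Arc;

use arrow::compute::{lexsort, SortColumn};
use arrow_array::{ArrayRef, UInt64Array};
use arrow_schema::SortOptions;

fn main() -> Result<(), arrow_schema::ArrowError> {
    // Already sorted ascending, nulls last (same options as `options_asc` above).
    let col_a: ArrayRef = Arc::new(UInt64Array::from(vec![1, 1, 2, 3, 5]));
    let sort_column = SortColumn {
        values: col_a.clone(),
        options: Some(SortOptions {
            descending: false,
            nulls_first: false,
        }),
    };
    let sorted = lexsort(&[sort_column], None)?;
    // The fuzz test applies the same equality check to every ordering reported
    // by the equivalence properties of the merged result.
    assert_eq!(&sorted[0], &col_a);
    Ok(())
}
```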
+ #[tokio::test] + async fn stream_merge_multi_order_preserve() -> Result<()> { + const N_PARTITION: usize = 8; + const N_ELEM: usize = 25; + const N_DISTINCT: usize = 5; + const N_DIFF_SCHEMA: usize = 20; + + use datafusion::physical_plan::common::collect; + for seed in 0..N_DIFF_SCHEMA { + // Create a schema with random equivalence properties + let (_test_schema, eq_properties) = create_random_schema(seed as u64)?; + let table_data_with_properties = + generate_table_for_eq_properties(&eq_properties, N_ELEM, N_DISTINCT)?; + let schema = table_data_with_properties.schema(); + let streams: Vec = (0..N_PARTITION) + .map(|_idx| { + let batch = table_data_with_properties.clone(); + Box::pin(RecordBatchStreamAdapter::new( + schema.clone(), + futures::stream::once(async { Ok(batch) }), + )) as SendableRecordBatchStream + }) + .collect::>(); + + // Returns concatenated version of the all available orderings + let exprs = eq_properties + .oeq_class() + .output_ordering() + .unwrap_or_default(); + + let context = SessionContext::new().task_ctx(); + let mem_reservation = + MemoryConsumer::new("test".to_string()).register(context.memory_pool()); + + // Internally SortPreservingMergeExec uses this function for merging. + let res = streaming_merge( + streams, + schema, + &exprs, + BaselineMetrics::new(&ExecutionPlanMetricsSet::new(), 0), + 1, + None, + mem_reservation, + )?; + let res = collect(res).await?; + // Contains the merged result. + let res = concat_batches(&res[0].schema(), &res)?; + + for ordering in eq_properties.oeq_class().iter() { + let err_msg = format!("error in eq properties: {:?}", eq_properties); + let sort_solumns = ordering + .iter() + .map(|sort_expr| sort_expr.evaluate_to_sort_column(&res)) + .collect::>>()?; + let orig_columns = sort_solumns + .iter() + .map(|sort_column| sort_column.values.clone()) + .collect::>(); + let sorted_columns = lexsort(&sort_solumns, None)?; + + // Make sure after merging ordering is still valid. 
+ assert_eq!(orig_columns.len(), sorted_columns.len(), "{}", err_msg); + assert!( + izip!(orig_columns.into_iter(), sorted_columns.into_iter()) + .all(|(lhs, rhs)| { lhs == rhs }), + "{}", + err_msg + ) + } + } + Ok(()) + } + #[tokio::test(flavor = "multi_thread", worker_threads = 8)] async fn sort_preserving_repartition_test() { let seed_start = 0; @@ -140,7 +391,7 @@ mod sp_repartition_fuzz_tests { Arc::new( RepartitionExec::try_new(input, Partitioning::RoundRobinBatch(2)) .unwrap() - .with_preserve_order(true), + .with_preserve_order(), ) } @@ -159,7 +410,7 @@ mod sp_repartition_fuzz_tests { Arc::new( RepartitionExec::try_new(input, Partitioning::Hash(hash_expr, 2)) .unwrap() - .with_preserve_order(true), + .with_preserve_order(), ) } diff --git a/datafusion/core/tests/fuzz_cases/window_fuzz.rs b/datafusion/core/tests/fuzz_cases/window_fuzz.rs index af96063ffb5fb..3037b4857a3b3 100644 --- a/datafusion/core/tests/fuzz_cases/window_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/window_fuzz.rs @@ -25,15 +25,15 @@ use arrow::util::pretty::pretty_format_batches; use datafusion::physical_plan::memory::MemoryExec; use datafusion::physical_plan::sorts::sort::SortExec; use datafusion::physical_plan::windows::{ - create_window_expr, BoundedWindowAggExec, PartitionSearchMode, WindowAggExec, + create_window_expr, BoundedWindowAggExec, WindowAggExec, }; -use datafusion::physical_plan::{collect, ExecutionPlan}; +use datafusion::physical_plan::{collect, ExecutionPlan, InputOrderMode}; use datafusion::prelude::{SessionConfig, SessionContext}; use datafusion_common::{Result, ScalarValue}; use datafusion_expr::type_coercion::aggregates::coerce_types; use datafusion_expr::{ AggregateFunction, BuiltInWindowFunction, WindowFrame, WindowFrameBound, - WindowFrameUnits, WindowFunction, + WindowFrameUnits, WindowFunctionDefinition, }; use datafusion_physical_expr::expressions::{cast, col, lit}; use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr}; @@ -43,9 +43,7 @@ use hashbrown::HashMap; use rand::rngs::StdRng; use rand::{Rng, SeedableRng}; -use datafusion_physical_plan::windows::PartitionSearchMode::{ - Linear, PartiallySorted, Sorted, -}; +use datafusion_physical_plan::InputOrderMode::{Linear, PartiallySorted, Sorted}; #[tokio::test(flavor = "multi_thread", worker_threads = 16)] async fn window_bounded_window_random_comparison() -> Result<()> { @@ -145,7 +143,7 @@ fn get_random_function( schema: &SchemaRef, rng: &mut StdRng, is_linear: bool, -) -> (WindowFunction, Vec>, String) { +) -> (WindowFunctionDefinition, Vec>, String) { let mut args = if is_linear { // In linear test for the test version with WindowAggExec we use insert SortExecs to the plan to be able to generate // same result with BoundedWindowAggExec which doesn't use any SortExec. 
To make result @@ -161,28 +159,28 @@ fn get_random_function( window_fn_map.insert( "sum", ( - WindowFunction::AggregateFunction(AggregateFunction::Sum), + WindowFunctionDefinition::AggregateFunction(AggregateFunction::Sum), vec![], ), ); window_fn_map.insert( "count", ( - WindowFunction::AggregateFunction(AggregateFunction::Count), + WindowFunctionDefinition::AggregateFunction(AggregateFunction::Count), vec![], ), ); window_fn_map.insert( "min", ( - WindowFunction::AggregateFunction(AggregateFunction::Min), + WindowFunctionDefinition::AggregateFunction(AggregateFunction::Min), vec![], ), ); window_fn_map.insert( "max", ( - WindowFunction::AggregateFunction(AggregateFunction::Max), + WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max), vec![], ), ); @@ -193,28 +191,36 @@ fn get_random_function( window_fn_map.insert( "row_number", ( - WindowFunction::BuiltInWindowFunction(BuiltInWindowFunction::RowNumber), + WindowFunctionDefinition::BuiltInWindowFunction( + BuiltInWindowFunction::RowNumber, + ), vec![], ), ); window_fn_map.insert( "rank", ( - WindowFunction::BuiltInWindowFunction(BuiltInWindowFunction::Rank), + WindowFunctionDefinition::BuiltInWindowFunction( + BuiltInWindowFunction::Rank, + ), vec![], ), ); window_fn_map.insert( "dense_rank", ( - WindowFunction::BuiltInWindowFunction(BuiltInWindowFunction::DenseRank), + WindowFunctionDefinition::BuiltInWindowFunction( + BuiltInWindowFunction::DenseRank, + ), vec![], ), ); window_fn_map.insert( "lead", ( - WindowFunction::BuiltInWindowFunction(BuiltInWindowFunction::Lead), + WindowFunctionDefinition::BuiltInWindowFunction( + BuiltInWindowFunction::Lead, + ), vec![ lit(ScalarValue::Int64(Some(rng.gen_range(1..10)))), lit(ScalarValue::Int64(Some(rng.gen_range(1..1000)))), @@ -224,7 +230,9 @@ fn get_random_function( window_fn_map.insert( "lag", ( - WindowFunction::BuiltInWindowFunction(BuiltInWindowFunction::Lag), + WindowFunctionDefinition::BuiltInWindowFunction( + BuiltInWindowFunction::Lag, + ), vec![ lit(ScalarValue::Int64(Some(rng.gen_range(1..10)))), lit(ScalarValue::Int64(Some(rng.gen_range(1..1000)))), @@ -235,21 +243,27 @@ fn get_random_function( window_fn_map.insert( "first_value", ( - WindowFunction::BuiltInWindowFunction(BuiltInWindowFunction::FirstValue), + WindowFunctionDefinition::BuiltInWindowFunction( + BuiltInWindowFunction::FirstValue, + ), vec![], ), ); window_fn_map.insert( "last_value", ( - WindowFunction::BuiltInWindowFunction(BuiltInWindowFunction::LastValue), + WindowFunctionDefinition::BuiltInWindowFunction( + BuiltInWindowFunction::LastValue, + ), vec![], ), ); window_fn_map.insert( "nth_value", ( - WindowFunction::BuiltInWindowFunction(BuiltInWindowFunction::NthValue), + WindowFunctionDefinition::BuiltInWindowFunction( + BuiltInWindowFunction::NthValue, + ), vec![lit(ScalarValue::Int64(Some(rng.gen_range(1..10))))], ), ); @@ -257,7 +271,7 @@ fn get_random_function( let rand_fn_idx = rng.gen_range(0..window_fn_map.len()); let fn_name = window_fn_map.keys().collect::>()[rand_fn_idx]; let (window_fn, new_args) = window_fn_map.values().collect::>()[rand_fn_idx]; - if let WindowFunction::AggregateFunction(f) = window_fn { + if let WindowFunctionDefinition::AggregateFunction(f) = window_fn { let a = args[0].clone(); let dt = a.data_type(schema.as_ref()).unwrap(); let sig = f.signature(); @@ -385,9 +399,9 @@ async fn run_window_test( random_seed: u64, partition_by_columns: Vec<&str>, orderby_columns: Vec<&str>, - search_mode: PartitionSearchMode, + search_mode: InputOrderMode, ) -> Result<()> { - let 
is_linear = !matches!(search_mode, PartitionSearchMode::Sorted); + let is_linear = !matches!(search_mode, InputOrderMode::Sorted); let mut rng = StdRng::seed_from_u64(random_seed); let schema = input1[0].schema(); let session_config = SessionConfig::new().with_batch_size(50); diff --git a/datafusion/core/tests/parquet/custom_reader.rs b/datafusion/core/tests/parquet/custom_reader.rs index 37481b936d24a..e76b201e0222e 100644 --- a/datafusion/core/tests/parquet/custom_reader.rs +++ b/datafusion/core/tests/parquet/custom_reader.rs @@ -85,7 +85,6 @@ async fn route_data_access_ops_to_parquet_file_reader_factory() { limit: None, table_partition_cols: vec![], output_ordering: vec![], - infinite_source: false, }, None, None, @@ -188,6 +187,7 @@ async fn store_parquet_in_memory( last_modified: chrono::DateTime::from(SystemTime::now()), size: buf.len(), e_tag: None, + version: None, }; (meta, Bytes::from(buf)) diff --git a/datafusion/core/tests/parquet/file_statistics.rs b/datafusion/core/tests/parquet/file_statistics.rs index 1ea154303d697..9f94a59a3e598 100644 --- a/datafusion/core/tests/parquet/file_statistics.rs +++ b/datafusion/core/tests/parquet/file_statistics.rs @@ -133,7 +133,7 @@ async fn list_files_with_session_level_cache() { assert_eq!(get_list_file_cache_size(&state1), 1); let fg = &parquet1.base_config().file_groups; assert_eq!(fg.len(), 1); - assert_eq!(fg.get(0).unwrap().len(), 1); + assert_eq!(fg.first().unwrap().len(), 1); //Session 2 first time list files //check session 1 cache result not show in session 2 @@ -144,7 +144,7 @@ async fn list_files_with_session_level_cache() { assert_eq!(get_list_file_cache_size(&state2), 1); let fg2 = &parquet2.base_config().file_groups; assert_eq!(fg2.len(), 1); - assert_eq!(fg2.get(0).unwrap().len(), 1); + assert_eq!(fg2.first().unwrap().len(), 1); //Session 1 second time list files //check session 1 cache result not show in session 2 @@ -155,7 +155,7 @@ async fn list_files_with_session_level_cache() { assert_eq!(get_list_file_cache_size(&state1), 1); let fg = &parquet3.base_config().file_groups; assert_eq!(fg.len(), 1); - assert_eq!(fg.get(0).unwrap().len(), 1); + assert_eq!(fg.first().unwrap().len(), 1); // List same file no increase assert_eq!(get_list_file_cache_size(&state1), 1); } diff --git a/datafusion/core/tests/parquet/filter_pushdown.rs b/datafusion/core/tests/parquet/filter_pushdown.rs index 61a8f87b9ea58..f214e8903a4f8 100644 --- a/datafusion/core/tests/parquet/filter_pushdown.rs +++ b/datafusion/core/tests/parquet/filter_pushdown.rs @@ -34,7 +34,7 @@ use datafusion::physical_plan::collect; use datafusion::physical_plan::metrics::MetricsSet; use datafusion::prelude::{col, lit, lit_timestamp_nano, Expr, SessionContext}; use datafusion::test_util::parquet::{ParquetScanOptions, TestParquetFile}; -use datafusion_optimizer::utils::{conjunction, disjunction, split_conjunction}; +use datafusion_expr::utils::{conjunction, disjunction, split_conjunction}; use itertools::Itertools; use parquet::file::properties::WriterProperties; use tempfile::TempDir; diff --git a/datafusion/core/tests/parquet/mod.rs b/datafusion/core/tests/parquet/mod.rs index 3f003c077d6a0..943f7fdbf4ac5 100644 --- a/datafusion/core/tests/parquet/mod.rs +++ b/datafusion/core/tests/parquet/mod.rs @@ -44,6 +44,7 @@ mod file_statistics; mod filter_pushdown; mod page_pruning; mod row_group_pruning; +mod schema; mod schema_coercion; #[cfg(test)] diff --git a/datafusion/core/tests/parquet/page_pruning.rs b/datafusion/core/tests/parquet/page_pruning.rs index 
b77643c35e84d..23a56bc821d44 100644 --- a/datafusion/core/tests/parquet/page_pruning.rs +++ b/datafusion/core/tests/parquet/page_pruning.rs @@ -50,6 +50,7 @@ async fn get_parquet_exec(state: &SessionState, filter: Expr) -> ParquetExec { last_modified: metadata.modified().map(chrono::DateTime::from).unwrap(), size: metadata.len() as usize, e_tag: None, + version: None, }; let schema = ParquetFormat::default() @@ -80,7 +81,6 @@ async fn get_parquet_exec(state: &SessionState, filter: Expr) -> ParquetExec { limit: None, table_partition_cols: vec![], output_ordering: vec![], - infinite_source: false, }, Some(predicate), None, diff --git a/datafusion/core/tests/sql/parquet_schema.rs b/datafusion/core/tests/parquet/schema.rs similarity index 95% rename from datafusion/core/tests/sql/parquet_schema.rs rename to datafusion/core/tests/parquet/schema.rs index bc1578da2c58a..30d4e1193022a 100644 --- a/datafusion/core/tests/sql/parquet_schema.rs +++ b/datafusion/core/tests/parquet/schema.rs @@ -22,6 +22,7 @@ use ::parquet::arrow::ArrowWriter; use tempfile::TempDir; use super::*; +use datafusion_common::assert_batches_sorted_eq; #[tokio::test] async fn schema_merge_ignores_metadata_by_default() { @@ -90,7 +91,13 @@ async fn schema_merge_ignores_metadata_by_default() { .await .unwrap(); - let actual = execute_to_batches(&ctx, "SELECT * from t").await; + let actual = ctx + .sql("SELECT * from t") + .await + .unwrap() + .collect() + .await + .unwrap(); assert_batches_sorted_eq!(expected, &actual); assert_no_metadata(&actual); } @@ -151,7 +158,13 @@ async fn schema_merge_can_preserve_metadata() { .await .unwrap(); - let actual = execute_to_batches(&ctx, "SELECT * from t").await; + let actual = ctx + .sql("SELECT * from t") + .await + .unwrap() + .collect() + .await + .unwrap(); assert_batches_sorted_eq!(expected, &actual); assert_metadata(&actual, &expected_metadata); } diff --git a/datafusion/core/tests/parquet/schema_coercion.rs b/datafusion/core/tests/parquet/schema_coercion.rs index b3134d470b56f..00f3eada496ee 100644 --- a/datafusion/core/tests/parquet/schema_coercion.rs +++ b/datafusion/core/tests/parquet/schema_coercion.rs @@ -69,7 +69,6 @@ async fn multi_parquet_coercion() { limit: None, table_partition_cols: vec![], output_ordering: vec![], - infinite_source: false, }, None, None, @@ -133,7 +132,6 @@ async fn multi_parquet_coercion_projection() { limit: None, table_partition_cols: vec![], output_ordering: vec![], - infinite_source: false, }, None, None, @@ -194,5 +192,6 @@ pub fn local_unpartitioned_file(path: impl AsRef) -> ObjectMeta last_modified: metadata.modified().map(chrono::DateTime::from).unwrap(), size: metadata.len() as usize, e_tag: None, + version: None, } } diff --git a/datafusion/core/tests/path_partition.rs b/datafusion/core/tests/path_partition.rs index 27d146de798d8..abe6ab283aff4 100644 --- a/datafusion/core/tests/path_partition.rs +++ b/datafusion/core/tests/path_partition.rs @@ -46,7 +46,7 @@ use futures::stream; use futures::stream::BoxStream; use object_store::{ path::Path, GetOptions, GetResult, GetResultPayload, ListResult, MultipartId, - ObjectMeta, ObjectStore, + ObjectMeta, ObjectStore, PutOptions, PutResult, }; use tokio::io::AsyncWrite; use url::Url; @@ -168,9 +168,9 @@ async fn parquet_distinct_partition_col() -> Result<()> { assert_eq!(min_limit, resulting_limit); let s = ScalarValue::try_from_array(results[0].column(1), 0)?; - let month = match extract_as_utf(&s) { - Some(month) => month, - s => panic!("Expected month as Dict(_, Utf8) found {s:?}"), + let month = 
match s { + ScalarValue::Utf8(Some(month)) => month, + s => panic!("Expected month as Utf8 found {s:?}"), }; let sql_on_partition_boundary = format!( @@ -191,15 +191,6 @@ async fn parquet_distinct_partition_col() -> Result<()> { Ok(()) } -fn extract_as_utf(v: &ScalarValue) -> Option { - if let ScalarValue::Dictionary(_, v) = v { - if let ScalarValue::Utf8(v) = v.as_ref() { - return v.clone(); - } - } - None -} - #[tokio::test] async fn csv_filter_with_file_col() -> Result<()> { let ctx = SessionContext::new(); @@ -620,7 +611,12 @@ impl MirroringObjectStore { #[async_trait] impl ObjectStore for MirroringObjectStore { - async fn put(&self, _location: &Path, _bytes: Bytes) -> object_store::Result<()> { + async fn put_opts( + &self, + _location: &Path, + _bytes: Bytes, + _opts: PutOptions, + ) -> object_store::Result { unimplemented!() } @@ -653,6 +649,7 @@ impl ObjectStore for MirroringObjectStore { last_modified: metadata.modified().map(chrono::DateTime::from).unwrap(), size: metadata.len() as usize, e_tag: None, + version: None, }; Ok(GetResult { @@ -680,26 +677,16 @@ impl ObjectStore for MirroringObjectStore { Ok(data.into()) } - async fn head(&self, location: &Path) -> object_store::Result { - self.files.iter().find(|x| *x == location).unwrap(); - Ok(ObjectMeta { - location: location.clone(), - last_modified: Utc.timestamp_nanos(0), - size: self.file_size as usize, - e_tag: None, - }) - } - async fn delete(&self, _location: &Path) -> object_store::Result<()> { unimplemented!() } - async fn list( + fn list( &self, prefix: Option<&Path>, - ) -> object_store::Result>> { + ) -> BoxStream<'_, object_store::Result> { let prefix = prefix.cloned().unwrap_or_default(); - Ok(Box::pin(stream::iter(self.files.iter().filter_map( + Box::pin(stream::iter(self.files.iter().filter_map( move |location| { // Don't return for exact prefix match let filter = location @@ -713,10 +700,11 @@ impl ObjectStore for MirroringObjectStore { last_modified: Utc.timestamp_nanos(0), size: self.file_size as usize, e_tag: None, + version: None, }) }) }, - )))) + ))) } async fn list_with_delimiter( @@ -750,6 +738,7 @@ impl ObjectStore for MirroringObjectStore { last_modified: Utc.timestamp_nanos(0), size: self.file_size as usize, e_tag: None, + version: None, }; objects.push(object); } diff --git a/datafusion/core/tests/sql/aggregates.rs b/datafusion/core/tests/sql/aggregates.rs index 03864e9efef80..af6d0d5f4e245 100644 --- a/datafusion/core/tests/sql/aggregates.rs +++ b/datafusion/core/tests/sql/aggregates.rs @@ -17,8 +17,6 @@ use super::*; use datafusion::scalar::ScalarValue; -use datafusion::test_util::scan_empty; -use datafusion_common::cast::as_float64_array; #[tokio::test] async fn csv_query_array_agg_distinct() -> Result<()> { @@ -68,324 +66,6 @@ async fn csv_query_array_agg_distinct() -> Result<()> { Ok(()) } -#[tokio::test] -async fn aggregate() -> Result<()> { - let results = execute_with_partition("SELECT SUM(c1), SUM(c2) FROM test", 4).await?; - assert_eq!(results.len(), 1); - - let expected = [ - "+--------------+--------------+", - "| SUM(test.c1) | SUM(test.c2) |", - "+--------------+--------------+", - "| 60 | 220 |", - "+--------------+--------------+", - ]; - assert_batches_sorted_eq!(expected, &results); - - Ok(()) -} - -#[tokio::test] -async fn aggregate_empty() -> Result<()> { - // The predicate on this query purposely generates no results - let results = - execute_with_partition("SELECT SUM(c1), SUM(c2) FROM test where c1 > 100000", 4) - .await - .unwrap(); - - assert_eq!(results.len(), 1); - - let 
expected = [ - "+--------------+--------------+", - "| SUM(test.c1) | SUM(test.c2) |", - "+--------------+--------------+", - "| | |", - "+--------------+--------------+", - ]; - assert_batches_sorted_eq!(expected, &results); - - Ok(()) -} - -#[tokio::test] -async fn aggregate_avg() -> Result<()> { - let results = execute_with_partition("SELECT AVG(c1), AVG(c2) FROM test", 4).await?; - assert_eq!(results.len(), 1); - - let expected = [ - "+--------------+--------------+", - "| AVG(test.c1) | AVG(test.c2) |", - "+--------------+--------------+", - "| 1.5 | 5.5 |", - "+--------------+--------------+", - ]; - assert_batches_sorted_eq!(expected, &results); - - Ok(()) -} - -#[tokio::test] -async fn aggregate_max() -> Result<()> { - let results = execute_with_partition("SELECT MAX(c1), MAX(c2) FROM test", 4).await?; - assert_eq!(results.len(), 1); - - let expected = [ - "+--------------+--------------+", - "| MAX(test.c1) | MAX(test.c2) |", - "+--------------+--------------+", - "| 3 | 10 |", - "+--------------+--------------+", - ]; - assert_batches_sorted_eq!(expected, &results); - - Ok(()) -} - -#[tokio::test] -async fn aggregate_min() -> Result<()> { - let results = execute_with_partition("SELECT MIN(c1), MIN(c2) FROM test", 4).await?; - assert_eq!(results.len(), 1); - - let expected = [ - "+--------------+--------------+", - "| MIN(test.c1) | MIN(test.c2) |", - "+--------------+--------------+", - "| 0 | 1 |", - "+--------------+--------------+", - ]; - assert_batches_sorted_eq!(expected, &results); - - Ok(()) -} - -#[tokio::test] -async fn aggregate_grouped() -> Result<()> { - let results = - execute_with_partition("SELECT c1, SUM(c2) FROM test GROUP BY c1", 4).await?; - - let expected = [ - "+----+--------------+", - "| c1 | SUM(test.c2) |", - "+----+--------------+", - "| 0 | 55 |", - "| 1 | 55 |", - "| 2 | 55 |", - "| 3 | 55 |", - "+----+--------------+", - ]; - assert_batches_sorted_eq!(expected, &results); - - Ok(()) -} - -#[tokio::test] -async fn aggregate_grouped_avg() -> Result<()> { - let results = - execute_with_partition("SELECT c1, AVG(c2) FROM test GROUP BY c1", 4).await?; - - let expected = [ - "+----+--------------+", - "| c1 | AVG(test.c2) |", - "+----+--------------+", - "| 0 | 5.5 |", - "| 1 | 5.5 |", - "| 2 | 5.5 |", - "| 3 | 5.5 |", - "+----+--------------+", - ]; - assert_batches_sorted_eq!(expected, &results); - - Ok(()) -} - -#[tokio::test] -async fn aggregate_grouped_empty() -> Result<()> { - let results = execute_with_partition( - "SELECT c1, AVG(c2) FROM test WHERE c1 = 123 GROUP BY c1", - 4, - ) - .await?; - - let expected = [ - "+----+--------------+", - "| c1 | AVG(test.c2) |", - "+----+--------------+", - "+----+--------------+", - ]; - assert_batches_sorted_eq!(expected, &results); - - Ok(()) -} - -#[tokio::test] -async fn aggregate_grouped_max() -> Result<()> { - let results = - execute_with_partition("SELECT c1, MAX(c2) FROM test GROUP BY c1", 4).await?; - - let expected = [ - "+----+--------------+", - "| c1 | MAX(test.c2) |", - "+----+--------------+", - "| 0 | 10 |", - "| 1 | 10 |", - "| 2 | 10 |", - "| 3 | 10 |", - "+----+--------------+", - ]; - assert_batches_sorted_eq!(expected, &results); - - Ok(()) -} - -#[tokio::test] -async fn aggregate_grouped_min() -> Result<()> { - let results = - execute_with_partition("SELECT c1, MIN(c2) FROM test GROUP BY c1", 4).await?; - - let expected = [ - "+----+--------------+", - "| c1 | MIN(test.c2) |", - "+----+--------------+", - "| 0 | 1 |", - "| 1 | 1 |", - "| 2 | 1 |", - "| 3 | 1 |", - "+----+--------------+", 
- ]; - assert_batches_sorted_eq!(expected, &results); - - Ok(()) -} - -#[tokio::test] -async fn aggregate_min_max_w_custom_window_frames() -> Result<()> { - let ctx = SessionContext::new(); - register_aggregate_csv(&ctx).await?; - let sql = - "SELECT - MIN(c12) OVER (ORDER BY C12 RANGE BETWEEN 0.3 PRECEDING AND 0.2 FOLLOWING) as min1, - MAX(c12) OVER (ORDER BY C11 RANGE BETWEEN 0.1 PRECEDING AND 0.2 FOLLOWING) as max1 - FROM aggregate_test_100 - ORDER BY C9 - LIMIT 5"; - let actual = execute_to_batches(&ctx, sql).await; - let expected = [ - "+---------------------+--------------------+", - "| min1 | max1 |", - "+---------------------+--------------------+", - "| 0.01479305307777301 | 0.9965400387585364 |", - "| 0.01479305307777301 | 0.9800193410444061 |", - "| 0.01479305307777301 | 0.9706712283358269 |", - "| 0.2667177795079635 | 0.9965400387585364 |", - "| 0.3600766362333053 | 0.9706712283358269 |", - "+---------------------+--------------------+", - ]; - assert_batches_eq!(expected, &actual); - Ok(()) -} - -#[tokio::test] -async fn aggregate_min_max_w_custom_window_frames_unbounded_start() -> Result<()> { - let ctx = SessionContext::new(); - register_aggregate_csv(&ctx).await?; - let sql = - "SELECT - MIN(c12) OVER (ORDER BY C12 RANGE BETWEEN UNBOUNDED PRECEDING AND 0.2 FOLLOWING) as min1, - MAX(c12) OVER (ORDER BY C11 RANGE BETWEEN UNBOUNDED PRECEDING AND 0.2 FOLLOWING) as max1 - FROM aggregate_test_100 - ORDER BY C9 - LIMIT 5"; - let actual = execute_to_batches(&ctx, sql).await; - let expected = [ - "+---------------------+--------------------+", - "| min1 | max1 |", - "+---------------------+--------------------+", - "| 0.01479305307777301 | 0.9965400387585364 |", - "| 0.01479305307777301 | 0.9800193410444061 |", - "| 0.01479305307777301 | 0.9800193410444061 |", - "| 0.01479305307777301 | 0.9965400387585364 |", - "| 0.01479305307777301 | 0.9800193410444061 |", - "+---------------------+--------------------+", - ]; - assert_batches_eq!(expected, &actual); - Ok(()) -} - -#[tokio::test] -async fn aggregate_avg_add() -> Result<()> { - let results = execute_with_partition( - "SELECT AVG(c1), AVG(c1) + 1, AVG(c1) + 2, 1 + AVG(c1) FROM test", - 4, - ) - .await?; - assert_eq!(results.len(), 1); - - let expected = ["+--------------+-------------------------+-------------------------+-------------------------+", - "| AVG(test.c1) | AVG(test.c1) + Int64(1) | AVG(test.c1) + Int64(2) | Int64(1) + AVG(test.c1) |", - "+--------------+-------------------------+-------------------------+-------------------------+", - "| 1.5 | 2.5 | 3.5 | 2.5 |", - "+--------------+-------------------------+-------------------------+-------------------------+"]; - assert_batches_sorted_eq!(expected, &results); - - Ok(()) -} - -#[tokio::test] -async fn case_sensitive_identifiers_aggregates() { - let ctx = SessionContext::new(); - ctx.register_table("t", table_with_sequence(1, 1).unwrap()) - .unwrap(); - - let expected = [ - "+----------+", - "| MAX(t.i) |", - "+----------+", - "| 1 |", - "+----------+", - ]; - - let results = plan_and_collect(&ctx, "SELECT max(i) FROM t") - .await - .unwrap(); - - assert_batches_sorted_eq!(expected, &results); - - let results = plan_and_collect(&ctx, "SELECT MAX(i) FROM t") - .await - .unwrap(); - assert_batches_sorted_eq!(expected, &results); - - // Using double quotes allows specifying the function name with capitalization - let err = plan_and_collect(&ctx, "SELECT \"MAX\"(i) FROM t") - .await - .unwrap_err(); - assert!(err - .to_string() - .contains("Error during planning: Invalid 
function 'MAX'")); - - let results = plan_and_collect(&ctx, "SELECT \"max\"(i) FROM t") - .await - .unwrap(); - assert_batches_sorted_eq!(expected, &results); -} - -#[tokio::test] -async fn count_basic() -> Result<()> { - let results = - execute_with_partition("SELECT COUNT(c1), COUNT(c2) FROM test", 1).await?; - assert_eq!(results.len(), 1); - - let expected = [ - "+----------------+----------------+", - "| COUNT(test.c1) | COUNT(test.c2) |", - "+----------------+----------------+", - "| 10 | 10 |", - "+----------------+----------------+", - ]; - assert_batches_sorted_eq!(expected, &results); - Ok(()) -} - #[tokio::test] async fn count_partitioned() -> Result<()> { let results = @@ -495,162 +175,6 @@ async fn count_aggregated_cube() -> Result<()> { Ok(()) } -#[tokio::test] -async fn count_multi_expr() -> Result<()> { - let schema = Arc::new(Schema::new(vec![ - Field::new("c1", DataType::Int32, true), - Field::new("c2", DataType::Int32, true), - ])); - - let data = RecordBatch::try_new( - schema.clone(), - vec![ - Arc::new(Int32Array::from(vec![ - Some(0), - None, - Some(1), - Some(2), - None, - ])), - Arc::new(Int32Array::from(vec![ - Some(1), - Some(1), - Some(0), - None, - None, - ])), - ], - )?; - - let ctx = SessionContext::new(); - ctx.register_batch("test", data)?; - let sql = "SELECT count(c1, c2) FROM test"; - let actual = execute_to_batches(&ctx, sql).await; - - let expected = [ - "+------------------------+", - "| COUNT(test.c1,test.c2) |", - "+------------------------+", - "| 2 |", - "+------------------------+", - ]; - assert_batches_sorted_eq!(expected, &actual); - Ok(()) -} - -#[tokio::test] -async fn count_multi_expr_group_by() -> Result<()> { - let schema = Arc::new(Schema::new(vec![ - Field::new("c1", DataType::Int32, true), - Field::new("c2", DataType::Int32, true), - Field::new("c3", DataType::Int32, true), - ])); - - let data = RecordBatch::try_new( - schema.clone(), - vec![ - Arc::new(Int32Array::from(vec![ - Some(0), - None, - Some(1), - Some(2), - None, - ])), - Arc::new(Int32Array::from(vec![ - Some(1), - Some(1), - Some(0), - None, - None, - ])), - Arc::new(Int32Array::from(vec![ - Some(10), - Some(10), - Some(10), - Some(10), - Some(10), - ])), - ], - )?; - - let ctx = SessionContext::new(); - ctx.register_batch("test", data)?; - let sql = "SELECT c3, count(c1, c2) FROM test group by c3"; - let actual = execute_to_batches(&ctx, sql).await; - - let expected = [ - "+----+------------------------+", - "| c3 | COUNT(test.c1,test.c2) |", - "+----+------------------------+", - "| 10 | 2 |", - "+----+------------------------+", - ]; - assert_batches_sorted_eq!(expected, &actual); - Ok(()) -} - -#[tokio::test] -async fn simple_avg() -> Result<()> { - let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); - - let batch1 = RecordBatch::try_new( - Arc::new(schema.clone()), - vec![Arc::new(Int32Array::from(vec![1, 2, 3]))], - )?; - let batch2 = RecordBatch::try_new( - Arc::new(schema.clone()), - vec![Arc::new(Int32Array::from(vec![4, 5]))], - )?; - - let ctx = SessionContext::new(); - - let provider = MemTable::try_new(Arc::new(schema), vec![vec![batch1], vec![batch2]])?; - ctx.register_table("t", Arc::new(provider))?; - - let result = plan_and_collect(&ctx, "SELECT AVG(a) FROM t").await?; - - let batch = &result[0]; - assert_eq!(1, batch.num_columns()); - assert_eq!(1, batch.num_rows()); - - let values = as_float64_array(batch.column(0)).expect("failed to cast version"); - assert_eq!(values.len(), 1); - // avg(1,2,3,4,5) = 3.0 - assert_eq!(values.value(0), 
3.0_f64); - Ok(()) -} - -#[tokio::test] -async fn simple_mean() -> Result<()> { - let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); - - let batch1 = RecordBatch::try_new( - Arc::new(schema.clone()), - vec![Arc::new(Int32Array::from(vec![1, 2, 3]))], - )?; - let batch2 = RecordBatch::try_new( - Arc::new(schema.clone()), - vec![Arc::new(Int32Array::from(vec![4, 5]))], - )?; - - let ctx = SessionContext::new(); - - let provider = MemTable::try_new(Arc::new(schema), vec![vec![batch1], vec![batch2]])?; - ctx.register_table("t", Arc::new(provider))?; - - let result = plan_and_collect(&ctx, "SELECT MEAN(a) FROM t").await?; - - let batch = &result[0]; - assert_eq!(1, batch.num_columns()); - assert_eq!(1, batch.num_rows()); - - let values = as_float64_array(batch.column(0)).expect("failed to cast version"); - assert_eq!(values.len(), 1); - // mean(1,2,3,4,5) = 3.0 - assert_eq!(values.value(0), 3.0_f64); - Ok(()) -} - async fn run_count_distinct_integers_aggregated_scenario( partitions: Vec>, ) -> Result> { @@ -771,31 +295,6 @@ async fn count_distinct_integers_aggregated_multiple_partitions() -> Result<()> Ok(()) } -#[tokio::test] -async fn aggregate_with_alias() -> Result<()> { - let ctx = SessionContext::new(); - let state = ctx.state(); - - let schema = Arc::new(Schema::new(vec![ - Field::new("c1", DataType::Utf8, false), - Field::new("c2", DataType::UInt32, false), - ])); - - let plan = scan_empty(None, schema.as_ref(), None)? - .aggregate(vec![col("c1")], vec![sum(col("c2"))])? - .project(vec![col("c1"), sum(col("c2")).alias("total_salary")])? - .build()?; - - let plan = state.optimize(&plan)?; - let physical_plan = state.create_physical_plan(&Arc::new(plan)).await?; - assert_eq!("c1", physical_plan.schema().field(0).name().as_str()); - assert_eq!( - "total_salary", - physical_plan.schema().field(1).name().as_str() - ); - Ok(()) -} - #[tokio::test] async fn test_accumulator_row_accumulator() -> Result<()> { let config = SessionConfig::new(); diff --git a/datafusion/core/tests/sql/arrow_files.rs b/datafusion/core/tests/sql/arrow_files.rs deleted file mode 100644 index fc90fe3c34640..0000000000000 --- a/datafusion/core/tests/sql/arrow_files.rs +++ /dev/null @@ -1,70 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. 
-use datafusion::execution::options::ArrowReadOptions; - -use super::*; - -async fn register_arrow(ctx: &mut SessionContext) { - ctx.register_arrow( - "arrow_simple", - "tests/data/example.arrow", - ArrowReadOptions::default(), - ) - .await - .unwrap(); -} - -#[tokio::test] -async fn arrow_query() { - let mut ctx = SessionContext::new(); - register_arrow(&mut ctx).await; - let sql = "SELECT * FROM arrow_simple"; - let actual = execute_to_batches(&ctx, sql).await; - let expected = [ - "+----+-----+-------+", - "| f0 | f1 | f2 |", - "+----+-----+-------+", - "| 1 | foo | true |", - "| 2 | bar | |", - "| 3 | baz | false |", - "| 4 | | true |", - "+----+-----+-------+", - ]; - - assert_batches_eq!(expected, &actual); -} - -#[tokio::test] -async fn arrow_explain() { - let mut ctx = SessionContext::new(); - register_arrow(&mut ctx).await; - let sql = "EXPLAIN SELECT * FROM arrow_simple"; - let actual = execute(&ctx, sql).await; - let actual = normalize_vec_for_explain(actual); - let expected = vec![ - vec![ - "logical_plan", - "TableScan: arrow_simple projection=[f0, f1, f2]", - ], - vec![ - "physical_plan", - "ArrowExec: file_groups={1 group: [[WORKING_DIR/tests/data/example.arrow]]}, projection=[f0, f1, f2]\n", - ], - ]; - - assert_eq!(expected, actual); -} diff --git a/datafusion/core/tests/sql/describe.rs b/datafusion/core/tests/sql/describe.rs deleted file mode 100644 index cd8e79b2c93b1..0000000000000 --- a/datafusion/core/tests/sql/describe.rs +++ /dev/null @@ -1,72 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. 
- -use datafusion::assert_batches_eq; -use datafusion::prelude::*; -use datafusion_common::test_util::parquet_test_data; - -#[tokio::test] -async fn describe_plan() { - let ctx = parquet_context().await; - - let query = "describe alltypes_tiny_pages"; - let results = ctx.sql(query).await.unwrap().collect().await.unwrap(); - - let expected = vec![ - "+-----------------+-----------------------------+-------------+", - "| column_name | data_type | is_nullable |", - "+-----------------+-----------------------------+-------------+", - "| id | Int32 | YES |", - "| bool_col | Boolean | YES |", - "| tinyint_col | Int8 | YES |", - "| smallint_col | Int16 | YES |", - "| int_col | Int32 | YES |", - "| bigint_col | Int64 | YES |", - "| float_col | Float32 | YES |", - "| double_col | Float64 | YES |", - "| date_string_col | Utf8 | YES |", - "| string_col | Utf8 | YES |", - "| timestamp_col | Timestamp(Nanosecond, None) | YES |", - "| year | Int32 | YES |", - "| month | Int32 | YES |", - "+-----------------+-----------------------------+-------------+", - ]; - - assert_batches_eq!(expected, &results); - - // also ensure we plan Describe via SessionState - let state = ctx.state(); - let plan = state.create_logical_plan(query).await.unwrap(); - let df = DataFrame::new(state, plan); - let results = df.collect().await.unwrap(); - - assert_batches_eq!(expected, &results); -} - -/// Return a SessionContext with parquet file registered -async fn parquet_context() -> SessionContext { - let ctx = SessionContext::new(); - let testdata = parquet_test_data(); - ctx.register_parquet( - "alltypes_tiny_pages", - &format!("{testdata}/alltypes_tiny_pages.parquet"), - ParquetReadOptions::default(), - ) - .await - .unwrap(); - ctx -} diff --git a/datafusion/core/tests/sql/displayable.rs b/datafusion/core/tests/sql/displayable.rs deleted file mode 100644 index 3255d514c5e4a..0000000000000 --- a/datafusion/core/tests/sql/displayable.rs +++ /dev/null @@ -1,57 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. 
- -use object_store::path::Path; - -use datafusion::prelude::*; -use datafusion_physical_plan::displayable; - -#[tokio::test] -async fn teset_displayable() { - // Hard code target_partitions as it appears in the RepartitionExec output - let config = SessionConfig::new().with_target_partitions(3); - let ctx = SessionContext::new_with_config(config); - - // register the a table - ctx.register_csv("example", "tests/data/example.csv", CsvReadOptions::new()) - .await - .unwrap(); - - // create a plan to run a SQL query - let dataframe = ctx.sql("SELECT a FROM example WHERE a < 5").await.unwrap(); - let physical_plan = dataframe.create_physical_plan().await.unwrap(); - - // Format using display string in verbose mode - let displayable_plan = displayable(physical_plan.as_ref()); - let plan_string = format!("{}", displayable_plan.indent(true)); - - let working_directory = std::env::current_dir().unwrap(); - let normalized = Path::from_filesystem_path(working_directory).unwrap(); - let plan_string = plan_string.replace(normalized.as_ref(), "WORKING_DIR"); - - assert_eq!("CoalesceBatchesExec: target_batch_size=8192\ - \n FilterExec: a@0 < 5\ - \n RepartitionExec: partitioning=RoundRobinBatch(3), input_partitions=1\ - \n CsvExec: file_groups={1 group: [[WORKING_DIR/tests/data/example.csv]]}, projection=[a], has_header=true", - plan_string.trim()); - - let one_line = format!("{}", displayable_plan.one_line()); - assert_eq!( - "CoalesceBatchesExec: target_batch_size=8192", - one_line.trim() - ); -} diff --git a/datafusion/core/tests/sql/explain_analyze.rs b/datafusion/core/tests/sql/explain_analyze.rs index 2436e82f3ce98..37f8cefc90809 100644 --- a/datafusion/core/tests/sql/explain_analyze.rs +++ b/datafusion/core/tests/sql/explain_analyze.rs @@ -560,7 +560,7 @@ async fn csv_explain_verbose_plans() { // Since the plan contains path that are environmentally // dependant(e.g. full path of the test file), only verify // important content - assert_contains!(&actual, "logical_plan after push_down_projection"); + assert_contains!(&actual, "logical_plan after optimize_projections"); assert_contains!(&actual, "physical_plan"); assert_contains!(&actual, "FilterExec: c2@1 > 10"); assert_contains!(actual, "ProjectionExec: expr=[c1@0 as c1]"); @@ -575,7 +575,7 @@ async fn explain_analyze_runs_optimizers() { // This happens as an optimization pass where count(*) can be // answered using statistics only. 
- let expected = "EmptyExec: produce_one_row=true"; + let expected = "PlaceholderRowExec"; let sql = "EXPLAIN SELECT count(*) from alltypes_plain"; let actual = execute_to_batches(&ctx, sql).await; @@ -806,7 +806,7 @@ async fn explain_physical_plan_only() { let expected = vec![vec![ "physical_plan", "ProjectionExec: expr=[2 as COUNT(*)]\ - \n EmptyExec: produce_one_row=true\ + \n PlaceholderRowExec\ \n", ]]; assert_eq!(expected, actual); @@ -827,5 +827,8 @@ async fn csv_explain_analyze_with_statistics() { .to_string(); // should contain scan statistics - assert_contains!(&formatted, ", statistics=[Rows=Absent, Bytes=Absent]"); + assert_contains!( + &formatted, + ", statistics=[Rows=Absent, Bytes=Absent, [(Col[0]:)]]" + ); } diff --git a/datafusion/core/tests/sql/expr.rs b/datafusion/core/tests/sql/expr.rs index 7d41ad4a881c5..8ac0e3e5ef190 100644 --- a/datafusion/core/tests/sql/expr.rs +++ b/datafusion/core/tests/sql/expr.rs @@ -741,6 +741,7 @@ async fn test_extract_date_part() -> Result<()> { #[tokio::test] async fn test_extract_epoch() -> Result<()> { + // timestamp test_expression!( "extract(epoch from '1870-01-01T07:29:10.256'::timestamp)", "-3155646649.744" @@ -754,6 +755,39 @@ async fn test_extract_epoch() -> Result<()> { "946684800.0" ); test_expression!("extract(epoch from NULL::timestamp)", "NULL"); + // date + test_expression!( + "extract(epoch from arrow_cast('1970-01-01', 'Date32'))", + "0.0" + ); + test_expression!( + "extract(epoch from arrow_cast('1970-01-02', 'Date32'))", + "86400.0" + ); + test_expression!( + "extract(epoch from arrow_cast('1970-01-11', 'Date32'))", + "864000.0" + ); + test_expression!( + "extract(epoch from arrow_cast('1969-12-31', 'Date32'))", + "-86400.0" + ); + test_expression!( + "extract(epoch from arrow_cast('1970-01-01', 'Date64'))", + "0.0" + ); + test_expression!( + "extract(epoch from arrow_cast('1970-01-02', 'Date64'))", + "86400.0" + ); + test_expression!( + "extract(epoch from arrow_cast('1970-01-11', 'Date64'))", + "864000.0" + ); + test_expression!( + "extract(epoch from arrow_cast('1969-12-31', 'Date64'))", + "-86400.0" + ); Ok(()) } diff --git a/datafusion/core/tests/sql/joins.rs b/datafusion/core/tests/sql/joins.rs index 528bde632355b..d1f270b540b55 100644 --- a/datafusion/core/tests/sql/joins.rs +++ b/datafusion/core/tests/sql/joins.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. 
+use datafusion::datasource::stream::{StreamConfig, StreamTable}; use datafusion::test_util::register_unbounded_file_with_ordering; use super::*; @@ -105,9 +106,7 @@ async fn join_change_in_planner() -> Result<()> { &left_file_path, "left", file_sort_order.clone(), - true, - ) - .await?; + )?; let right_file_path = tmp_dir.path().join("right.csv"); File::create(right_file_path.clone()).unwrap(); register_unbounded_file_with_ordering( @@ -116,9 +115,7 @@ async fn join_change_in_planner() -> Result<()> { &right_file_path, "right", file_sort_order, - true, - ) - .await?; + )?; let sql = "SELECT t1.a1, t1.a2, t2.a1, t2.a2 FROM left as t1 FULL JOIN right as t2 ON t1.a2 = t2.a2 AND t1.a1 > t2.a1 + 3 AND t1.a1 < t2.a1 + 10"; let dataframe = ctx.sql(sql).await?; let physical_plan = dataframe.create_physical_plan().await?; @@ -160,20 +157,13 @@ async fn join_change_in_planner_without_sort() -> Result<()> { Field::new("a1", DataType::UInt32, false), Field::new("a2", DataType::UInt32, false), ])); - ctx.register_csv( - "left", - left_file_path.as_os_str().to_str().unwrap(), - CsvReadOptions::new().schema(&schema).mark_infinite(true), - ) - .await?; + let left = StreamConfig::new_file(schema.clone(), left_file_path); + ctx.register_table("left", Arc::new(StreamTable::new(Arc::new(left))))?; + let right_file_path = tmp_dir.path().join("right.csv"); File::create(right_file_path.clone())?; - ctx.register_csv( - "right", - right_file_path.as_os_str().to_str().unwrap(), - CsvReadOptions::new().schema(&schema).mark_infinite(true), - ) - .await?; + let right = StreamConfig::new_file(schema, right_file_path); + ctx.register_table("right", Arc::new(StreamTable::new(Arc::new(right))))?; let sql = "SELECT t1.a1, t1.a2, t2.a1, t2.a2 FROM left as t1 FULL JOIN right as t2 ON t1.a2 = t2.a2 AND t1.a1 > t2.a1 + 3 AND t1.a1 < t2.a1 + 10"; let dataframe = ctx.sql(sql).await?; let physical_plan = dataframe.create_physical_plan().await?; @@ -217,20 +207,12 @@ async fn join_change_in_planner_without_sort_not_allowed() -> Result<()> { Field::new("a1", DataType::UInt32, false), Field::new("a2", DataType::UInt32, false), ])); - ctx.register_csv( - "left", - left_file_path.as_os_str().to_str().unwrap(), - CsvReadOptions::new().schema(&schema).mark_infinite(true), - ) - .await?; + let left = StreamConfig::new_file(schema.clone(), left_file_path); + ctx.register_table("left", Arc::new(StreamTable::new(Arc::new(left))))?; let right_file_path = tmp_dir.path().join("right.csv"); File::create(right_file_path.clone())?; - ctx.register_csv( - "right", - right_file_path.as_os_str().to_str().unwrap(), - CsvReadOptions::new().schema(&schema).mark_infinite(true), - ) - .await?; + let right = StreamConfig::new_file(schema.clone(), right_file_path); + ctx.register_table("right", Arc::new(StreamTable::new(Arc::new(right))))?; let df = ctx.sql("SELECT t1.a1, t1.a2, t2.a1, t2.a2 FROM left as t1 FULL JOIN right as t2 ON t1.a2 = t2.a2 AND t1.a1 > t2.a1 + 3 AND t1.a1 < t2.a1 + 10").await?; match df.create_physical_plan().await { Ok(_) => panic!("Expecting error."), diff --git a/datafusion/core/tests/sql/limit.rs b/datafusion/core/tests/sql/limit.rs deleted file mode 100644 index 1c8ea4fd3468c..0000000000000 --- a/datafusion/core/tests/sql/limit.rs +++ /dev/null @@ -1,101 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. 
The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use super::*; - -#[tokio::test] -async fn limit() -> Result<()> { - let tmp_dir = TempDir::new()?; - let ctx = create_ctx_with_partition(&tmp_dir, 1).await?; - ctx.register_table("t", table_with_sequence(1, 1000).unwrap()) - .unwrap(); - - let results = plan_and_collect(&ctx, "SELECT i FROM t ORDER BY i DESC limit 3") - .await - .unwrap(); - - #[rustfmt::skip] - let expected = ["+------+", - "| i |", - "+------+", - "| 1000 |", - "| 999 |", - "| 998 |", - "+------+"]; - - assert_batches_eq!(expected, &results); - - let results = plan_and_collect(&ctx, "SELECT i FROM t ORDER BY i limit 3") - .await - .unwrap(); - - #[rustfmt::skip] - let expected = ["+---+", - "| i |", - "+---+", - "| 1 |", - "| 2 |", - "| 3 |", - "+---+"]; - - assert_batches_eq!(expected, &results); - - let results = plan_and_collect(&ctx, "SELECT i FROM t limit 3") - .await - .unwrap(); - - // the actual rows are not guaranteed, so only check the count (should be 3) - let num_rows: usize = results.into_iter().map(|b| b.num_rows()).sum(); - assert_eq!(num_rows, 3); - - Ok(()) -} - -#[tokio::test] -async fn limit_multi_partitions() -> Result<()> { - let tmp_dir = TempDir::new()?; - let ctx = create_ctx_with_partition(&tmp_dir, 1).await?; - - let partitions = vec![ - vec![make_partition(0)], - vec![make_partition(1)], - vec![make_partition(2)], - vec![make_partition(3)], - vec![make_partition(4)], - vec![make_partition(5)], - ]; - let schema = partitions[0][0].schema(); - let provider = Arc::new(MemTable::try_new(schema, partitions).unwrap()); - - ctx.register_table("t", provider).unwrap(); - - // select all rows - let results = plan_and_collect(&ctx, "SELECT i FROM t").await.unwrap(); - - let num_rows: usize = results.into_iter().map(|b| b.num_rows()).sum(); - assert_eq!(num_rows, 15); - - for limit in 1..10 { - let query = format!("SELECT i FROM t limit {limit}"); - let results = plan_and_collect(&ctx, &query).await.unwrap(); - - let num_rows: usize = results.into_iter().map(|b| b.num_rows()).sum(); - assert_eq!(num_rows, limit, "mismatch with query {query}"); - } - - Ok(()) -} diff --git a/datafusion/core/tests/sql/mod.rs b/datafusion/core/tests/sql/mod.rs index d44513e69a9f9..849d85dec6bf1 100644 --- a/datafusion/core/tests/sql/mod.rs +++ b/datafusion/core/tests/sql/mod.rs @@ -15,7 +15,6 @@ // specific language governing permissions and limitations // under the License. -use std::convert::TryFrom; use std::sync::Arc; use arrow::{ @@ -73,27 +72,19 @@ macro_rules! 
test_expression { } pub mod aggregates; -pub mod arrow_files; pub mod create_drop; pub mod csv_files; -pub mod describe; -pub mod displayable; pub mod explain_analyze; pub mod expr; pub mod group_by; pub mod joins; -pub mod limit; pub mod order; -pub mod parquet; -pub mod parquet_schema; pub mod partitioned_csv; pub mod predicates; -pub mod projection; pub mod references; pub mod repartition; pub mod select; mod sql_api; -pub mod subqueries; pub mod timestamp; fn create_join_context( @@ -458,23 +449,6 @@ async fn register_aggregate_csv_by_sql(ctx: &SessionContext) { ); } -async fn register_aggregate_simple_csv(ctx: &SessionContext) -> Result<()> { - // It's not possible to use aggregate_test_100 as it doesn't have enough similar values to test grouping on floats. - let schema = Arc::new(Schema::new(vec![ - Field::new("c1", DataType::Float32, false), - Field::new("c2", DataType::Float64, false), - Field::new("c3", DataType::Boolean, false), - ])); - - ctx.register_csv( - "aggregate_simple", - "tests/data/aggregate_simple.csv", - CsvReadOptions::new().schema(&schema), - ) - .await?; - Ok(()) -} - async fn register_aggregate_csv(ctx: &SessionContext) -> Result<()> { let testdata = datafusion::test_util::arrow_test_data(); let schema = test_util::aggr_test_schema(); @@ -568,18 +542,6 @@ fn populate_csv_partitions( Ok(schema) } -/// Return a RecordBatch with a single Int32 array with values (0..sz) -pub fn make_partition(sz: i32) -> RecordBatch { - let seq_start = 0; - let seq_end = sz; - let values = (seq_start..seq_end).collect::>(); - let schema = Arc::new(Schema::new(vec![Field::new("i", DataType::Int32, true)])); - let arr = Arc::new(Int32Array::from(values)); - let arr = arr as ArrayRef; - - RecordBatch::try_new(schema, vec![arr]).unwrap() -} - /// Specialised String representation fn col_str(column: &ArrayRef, row_index: usize) -> String { // NullArray::is_null() does not work on NullArray. diff --git a/datafusion/core/tests/sql/parquet.rs b/datafusion/core/tests/sql/parquet.rs deleted file mode 100644 index c2844a2b762af..0000000000000 --- a/datafusion/core/tests/sql/parquet.rs +++ /dev/null @@ -1,383 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. 
- -use std::{fs, path::Path}; - -use ::parquet::arrow::ArrowWriter; -use datafusion::{datasource::listing::ListingOptions, execution::options::ReadOptions}; -use datafusion_common::cast::{as_list_array, as_primitive_array, as_string_array}; -use tempfile::TempDir; - -use super::*; - -#[tokio::test] -async fn parquet_query() { - let ctx = SessionContext::new(); - register_alltypes_parquet(&ctx).await; - // NOTE that string_col is actually a binary column and does not have the UTF8 logical type - // so we need an explicit cast - let sql = "SELECT id, CAST(string_col AS varchar) FROM alltypes_plain"; - let actual = execute_to_batches(&ctx, sql).await; - let expected = [ - "+----+---------------------------+", - "| id | alltypes_plain.string_col |", - "+----+---------------------------+", - "| 4 | 0 |", - "| 5 | 1 |", - "| 6 | 0 |", - "| 7 | 1 |", - "| 2 | 0 |", - "| 3 | 1 |", - "| 0 | 0 |", - "| 1 | 1 |", - "+----+---------------------------+", - ]; - - assert_batches_eq!(expected, &actual); -} - -#[tokio::test] -/// Test that if sort order is specified in ListingOptions, the sort -/// expressions make it all the way down to the ParquetExec -async fn parquet_with_sort_order_specified() { - let parquet_read_options = ParquetReadOptions::default(); - let session_config = SessionConfig::new().with_target_partitions(2); - - // The sort order is not specified - let options_no_sort = parquet_read_options.to_listing_options(&session_config); - - // The sort order is specified (not actually correct in this case) - let file_sort_order = [col("string_col"), col("int_col")] - .into_iter() - .map(|e| { - let ascending = true; - let nulls_first = false; - e.sort(ascending, nulls_first) - }) - .collect::>(); - - let options_sort = parquet_read_options - .to_listing_options(&session_config) - .with_file_sort_order(vec![file_sort_order]); - - // This string appears in ParquetExec if the output ordering is - // specified - let expected_output_ordering = - "output_ordering=[string_col@1 ASC NULLS LAST, int_col@0 ASC NULLS LAST]"; - - // when sort not specified, should not appear in the explain plan - let num_files = 1; - assert_not_contains!( - run_query_with_options(options_no_sort, num_files).await, - expected_output_ordering - ); - - // when sort IS specified, SHOULD appear in the explain plan - let num_files = 1; - assert_contains!( - run_query_with_options(options_sort.clone(), num_files).await, - expected_output_ordering - ); - - // when sort IS specified, but there are too many files (greater - // than the number of partitions) sort should not appear - let num_files = 3; - assert_not_contains!( - run_query_with_options(options_sort, num_files).await, - expected_output_ordering - ); -} - -/// Runs a limit query against a parquet file that was registered from -/// options on num_files copies of all_types_plain.parquet -async fn run_query_with_options(options: ListingOptions, num_files: usize) -> String { - let ctx = SessionContext::new(); - - let testdata = datafusion::test_util::parquet_test_data(); - let file_path = format!("{testdata}/alltypes_plain.parquet"); - - // Create a directory of parquet files with names - // 0.parquet - // 1.parquet - let tmpdir = TempDir::new().unwrap(); - for i in 0..num_files { - let target_file = tmpdir.path().join(format!("{i}.parquet")); - println!("Copying {file_path} to {target_file:?}"); - std::fs::copy(&file_path, target_file).unwrap(); - } - - let provided_schema = None; - let sql_definition = None; - ctx.register_listing_table( - "t", - 
tmpdir.path().to_string_lossy(), - options.clone(), - provided_schema, - sql_definition, - ) - .await - .unwrap(); - - let batches = ctx.sql("explain select int_col, string_col from t order by string_col, int_col limit 10") - .await - .expect("planing worked") - .collect() - .await - .expect("execution worked"); - - arrow::util::pretty::pretty_format_batches(&batches) - .unwrap() - .to_string() -} - -#[tokio::test] -async fn fixed_size_binary_columns() { - let ctx = SessionContext::new(); - ctx.register_parquet( - "t0", - "tests/data/test_binary.parquet", - ParquetReadOptions::default(), - ) - .await - .unwrap(); - let sql = "SELECT ids FROM t0 ORDER BY ids"; - let dataframe = ctx.sql(sql).await.unwrap(); - let results = dataframe.collect().await.unwrap(); - for batch in results { - assert_eq!(466, batch.num_rows()); - assert_eq!(1, batch.num_columns()); - } -} - -#[tokio::test] -async fn window_fn_timestamp_tz() { - let ctx = SessionContext::new(); - ctx.register_parquet( - "t0", - "tests/data/timestamp_with_tz.parquet", - ParquetReadOptions::default(), - ) - .await - .unwrap(); - - let sql = "SELECT count, LAG(timestamp, 1) OVER (ORDER BY timestamp) FROM t0"; - let dataframe = ctx.sql(sql).await.unwrap(); - let results = dataframe.collect().await.unwrap(); - - let mut num_rows = 0; - for batch in results { - num_rows += batch.num_rows(); - assert_eq!(2, batch.num_columns()); - - let ty = batch.column(0).data_type().clone(); - assert_eq!(DataType::Int64, ty); - - let ty = batch.column(1).data_type().clone(); - assert_eq!( - DataType::Timestamp(TimeUnit::Millisecond, Some("UTC".into())), - ty - ); - } - - assert_eq!(131072, num_rows); -} - -#[tokio::test] -async fn parquet_single_nan_schema() { - let ctx = SessionContext::new(); - let testdata = datafusion::test_util::parquet_test_data(); - ctx.register_parquet( - "single_nan", - &format!("{testdata}/single_nan.parquet"), - ParquetReadOptions::default(), - ) - .await - .unwrap(); - let sql = "SELECT mycol FROM single_nan"; - let dataframe = ctx.sql(sql).await.unwrap(); - let results = dataframe.collect().await.unwrap(); - for batch in results { - assert_eq!(1, batch.num_rows()); - assert_eq!(1, batch.num_columns()); - } -} - -#[tokio::test] -#[ignore = "Test ignored, will be enabled as part of the nested Parquet reader"] -async fn parquet_list_columns() { - let ctx = SessionContext::new(); - let testdata = datafusion::test_util::parquet_test_data(); - ctx.register_parquet( - "list_columns", - &format!("{testdata}/list_columns.parquet"), - ParquetReadOptions::default(), - ) - .await - .unwrap(); - - let schema = Arc::new(Schema::new(vec![ - Field::new_list( - "int64_list", - Field::new("item", DataType::Int64, true), - true, - ), - Field::new_list("utf8_list", Field::new("item", DataType::Utf8, true), true), - ])); - - let sql = "SELECT int64_list, utf8_list FROM list_columns"; - let dataframe = ctx.sql(sql).await.unwrap(); - let results = dataframe.collect().await.unwrap(); - - // int64_list utf8_list - // 0 [1, 2, 3] [abc, efg, hij] - // 1 [None, 1] None - // 2 [4] [efg, None, hij, xyz] - - assert_eq!(1, results.len()); - let batch = &results[0]; - assert_eq!(3, batch.num_rows()); - assert_eq!(2, batch.num_columns()); - assert_eq!(schema, batch.schema()); - - let int_list_array = as_list_array(batch.column(0)).unwrap(); - let utf8_list_array = as_list_array(batch.column(1)).unwrap(); - - assert_eq!( - as_primitive_array::(&int_list_array.value(0)).unwrap(), - &PrimitiveArray::::from(vec![Some(1), Some(2), Some(3),]) - ); - - assert_eq!( - 
as_string_array(&utf8_list_array.value(0)).unwrap(), - &StringArray::try_from(vec![Some("abc"), Some("efg"), Some("hij"),]).unwrap() - ); - - assert_eq!( - as_primitive_array::(&int_list_array.value(1)).unwrap(), - &PrimitiveArray::::from(vec![None, Some(1),]) - ); - - assert!(utf8_list_array.is_null(1)); - - assert_eq!( - as_primitive_array::(&int_list_array.value(2)).unwrap(), - &PrimitiveArray::::from(vec![Some(4),]) - ); - - let result = utf8_list_array.value(2); - let result = as_string_array(&result).unwrap(); - - assert_eq!(result.value(0), "efg"); - assert!(result.is_null(1)); - assert_eq!(result.value(2), "hij"); - assert_eq!(result.value(3), "xyz"); -} - -#[tokio::test] -async fn parquet_query_with_max_min() { - let tmp_dir = TempDir::new().unwrap(); - let table_dir = tmp_dir.path().join("parquet_test"); - let table_path = Path::new(&table_dir); - - let fields = vec![ - Field::new("c1", DataType::Int32, true), - Field::new("c2", DataType::Utf8, true), - Field::new("c3", DataType::Int64, true), - Field::new("c4", DataType::Date32, true), - ]; - - let schema = Arc::new(Schema::new(fields.clone())); - - if let Ok(()) = fs::create_dir(table_path) { - let filename = "foo.parquet"; - let path = table_path.join(filename); - let file = fs::File::create(path).unwrap(); - let mut writer = - ArrowWriter::try_new(file.try_clone().unwrap(), schema.clone(), None) - .unwrap(); - - // create mock record batch - let c1s = Arc::new(Int32Array::from(vec![1, 2, 3])); - let c2s = Arc::new(StringArray::from(vec!["aaa", "bbb", "ccc"])); - let c3s = Arc::new(Int64Array::from(vec![100, 200, 300])); - let c4s = Arc::new(Date32Array::from(vec![Some(1), Some(2), Some(3)])); - let rec_batch = - RecordBatch::try_new(schema.clone(), vec![c1s, c2s, c3s, c4s]).unwrap(); - - writer.write(&rec_batch).unwrap(); - writer.close().unwrap(); - } - - // query parquet - let ctx = SessionContext::new(); - - ctx.register_parquet( - "foo", - &format!("{}/foo.parquet", table_dir.to_str().unwrap()), - ParquetReadOptions::default(), - ) - .await - .unwrap(); - - let sql = "SELECT max(c1) FROM foo"; - let actual = execute_to_batches(&ctx, sql).await; - let expected = [ - "+-------------+", - "| MAX(foo.c1) |", - "+-------------+", - "| 3 |", - "+-------------+", - ]; - - assert_batches_eq!(expected, &actual); - - let sql = "SELECT min(c2) FROM foo"; - let actual = execute_to_batches(&ctx, sql).await; - let expected = [ - "+-------------+", - "| MIN(foo.c2) |", - "+-------------+", - "| aaa |", - "+-------------+", - ]; - - assert_batches_eq!(expected, &actual); - - let sql = "SELECT max(c3) FROM foo"; - let actual = execute_to_batches(&ctx, sql).await; - let expected = [ - "+-------------+", - "| MAX(foo.c3) |", - "+-------------+", - "| 300 |", - "+-------------+", - ]; - - assert_batches_eq!(expected, &actual); - - let sql = "SELECT min(c4) FROM foo"; - let actual = execute_to_batches(&ctx, sql).await; - let expected = [ - "+-------------+", - "| MIN(foo.c4) |", - "+-------------+", - "| 1970-01-02 |", - "+-------------+", - ]; - - assert_batches_eq!(expected, &actual); -} diff --git a/datafusion/core/tests/sql/partitioned_csv.rs b/datafusion/core/tests/sql/partitioned_csv.rs index d5a1c2f0b4f84..b77557a66cd89 100644 --- a/datafusion/core/tests/sql/partitioned_csv.rs +++ b/datafusion/core/tests/sql/partitioned_csv.rs @@ -19,31 +19,13 @@ use std::{io::Write, sync::Arc}; -use arrow::{ - datatypes::{DataType, Field, Schema, SchemaRef}, - record_batch::RecordBatch, -}; +use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; 
use datafusion::{ error::Result, prelude::{CsvReadOptions, SessionConfig, SessionContext}, }; use tempfile::TempDir; -/// Execute SQL and return results -async fn plan_and_collect( - ctx: &mut SessionContext, - sql: &str, -) -> Result> { - ctx.sql(sql).await?.collect().await -} - -/// Execute SQL and return results -pub async fn execute(sql: &str, partition_count: usize) -> Result> { - let tmp_dir = TempDir::new()?; - let mut ctx = create_ctx(&tmp_dir, partition_count).await?; - plan_and_collect(&mut ctx, sql).await -} - /// Generate CSV partitions within the supplied directory fn populate_csv_partitions( tmp_dir: &TempDir, diff --git a/datafusion/core/tests/sql/projection.rs b/datafusion/core/tests/sql/projection.rs deleted file mode 100644 index b31cb34f52108..0000000000000 --- a/datafusion/core/tests/sql/projection.rs +++ /dev/null @@ -1,373 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use datafusion::datasource::provider_as_source; -use datafusion::test_util::scan_empty; -use datafusion_expr::{when, LogicalPlanBuilder, UNNAMED_TABLE}; -use tempfile::TempDir; - -use super::*; - -#[tokio::test] -async fn projection_same_fields() -> Result<()> { - let ctx = SessionContext::new(); - - let sql = "select (1+1) as a from (select 1 as a) as b;"; - let actual = execute_to_batches(&ctx, sql).await; - - #[rustfmt::skip] - let expected = ["+---+", - "| a |", - "+---+", - "| 2 |", - "+---+"]; - assert_batches_eq!(expected, &actual); - - Ok(()) -} - -#[tokio::test] -async fn projection_type_alias() -> Result<()> { - let ctx = SessionContext::new(); - register_aggregate_simple_csv(&ctx).await?; - - // Query that aliases one column to the name of a different column - // that also has a different type (c1 == float32, c3 == boolean) - let sql = "SELECT c1 as c3 FROM aggregate_simple ORDER BY c3 LIMIT 2"; - let actual = execute_to_batches(&ctx, sql).await; - - let expected = [ - "+---------+", - "| c3 |", - "+---------+", - "| 0.00001 |", - "| 0.00002 |", - "+---------+", - ]; - assert_batches_eq!(expected, &actual); - - Ok(()) -} - -#[tokio::test] -async fn csv_query_group_by_avg_with_projection() -> Result<()> { - let ctx = SessionContext::new(); - register_aggregate_csv(&ctx).await?; - let sql = "SELECT avg(c12), c1 FROM aggregate_test_100 GROUP BY c1"; - let actual = execute_to_batches(&ctx, sql).await; - let expected = [ - "+-----------------------------+----+", - "| AVG(aggregate_test_100.c12) | c1 |", - "+-----------------------------+----+", - "| 0.41040709263815384 | b |", - "| 0.48600669271341534 | e |", - "| 0.48754517466109415 | a |", - "| 0.48855379387549824 | d |", - "| 0.6600456536439784 | c |", - "+-----------------------------+----+", - ]; - assert_batches_sorted_eq!(expected, &actual); - Ok(()) -} - -#[tokio::test] -async fn 
parallel_projection() -> Result<()> { - let partition_count = 4; - let results = - partitioned_csv::execute("SELECT c1, c2 FROM test", partition_count).await?; - - let expected = vec![ - "+----+----+", - "| c1 | c2 |", - "+----+----+", - "| 3 | 1 |", - "| 3 | 2 |", - "| 3 | 3 |", - "| 3 | 4 |", - "| 3 | 5 |", - "| 3 | 6 |", - "| 3 | 7 |", - "| 3 | 8 |", - "| 3 | 9 |", - "| 3 | 10 |", - "| 2 | 1 |", - "| 2 | 2 |", - "| 2 | 3 |", - "| 2 | 4 |", - "| 2 | 5 |", - "| 2 | 6 |", - "| 2 | 7 |", - "| 2 | 8 |", - "| 2 | 9 |", - "| 2 | 10 |", - "| 1 | 1 |", - "| 1 | 2 |", - "| 1 | 3 |", - "| 1 | 4 |", - "| 1 | 5 |", - "| 1 | 6 |", - "| 1 | 7 |", - "| 1 | 8 |", - "| 1 | 9 |", - "| 1 | 10 |", - "| 0 | 1 |", - "| 0 | 2 |", - "| 0 | 3 |", - "| 0 | 4 |", - "| 0 | 5 |", - "| 0 | 6 |", - "| 0 | 7 |", - "| 0 | 8 |", - "| 0 | 9 |", - "| 0 | 10 |", - "+----+----+", - ]; - assert_batches_sorted_eq!(expected, &results); - - Ok(()) -} - -#[tokio::test] -async fn subquery_alias_case_insensitive() -> Result<()> { - let partition_count = 1; - let results = - partitioned_csv::execute("SELECT V1.c1, v1.C2 FROM (SELECT test.C1, TEST.c2 FROM test) V1 ORDER BY v1.c1, V1.C2 LIMIT 1", partition_count).await?; - - let expected = [ - "+----+----+", - "| c1 | c2 |", - "+----+----+", - "| 0 | 1 |", - "+----+----+", - ]; - assert_batches_sorted_eq!(expected, &results); - - Ok(()) -} - -#[tokio::test] -async fn projection_on_table_scan() -> Result<()> { - let tmp_dir = TempDir::new()?; - let partition_count = 4; - let ctx = partitioned_csv::create_ctx(&tmp_dir, partition_count).await?; - - let table = ctx.table("test").await?; - let logical_plan = LogicalPlanBuilder::from(table.into_optimized_plan()?) - .project(vec![col("c2")])? - .build()?; - - let state = ctx.state(); - let optimized_plan = state.optimize(&logical_plan)?; - match &optimized_plan { - LogicalPlan::TableScan(TableScan { - source, - projected_schema, - .. - }) => { - assert_eq!(source.schema().fields().len(), 3); - assert_eq!(projected_schema.fields().len(), 1); - } - _ => panic!("input to projection should be TableScan"), - } - - let expected = "TableScan: test projection=[c2]"; - assert_eq!(format!("{optimized_plan:?}"), expected); - - let physical_plan = state.create_physical_plan(&optimized_plan).await?; - - assert_eq!(1, physical_plan.schema().fields().len()); - assert_eq!("c2", physical_plan.schema().field(0).name().as_str()); - let batches = collect(physical_plan, state.task_ctx()).await?; - assert_eq!(40, batches.iter().map(|x| x.num_rows()).sum::()); - - Ok(()) -} - -#[tokio::test] -async fn preserve_nullability_on_projection() -> Result<()> { - let tmp_dir = TempDir::new()?; - let ctx = partitioned_csv::create_ctx(&tmp_dir, 1).await?; - - let schema: Schema = ctx.table("test").await.unwrap().schema().clone().into(); - assert!(!schema.field_with_name("c1")?.is_nullable()); - - let plan = scan_empty(None, &schema, None)? - .project(vec![col("c1")])? 
- .build()?; - - let dataframe = DataFrame::new(ctx.state(), plan); - let physical_plan = dataframe.create_physical_plan().await?; - assert!(!physical_plan.schema().field_with_name("c1")?.is_nullable()); - Ok(()) -} - -#[tokio::test] -async fn project_cast_dictionary() { - let ctx = SessionContext::new(); - - let host: DictionaryArray = vec![Some("host1"), None, Some("host2")] - .into_iter() - .collect(); - - let batch = RecordBatch::try_from_iter(vec![("host", Arc::new(host) as _)]).unwrap(); - - let t = MemTable::try_new(batch.schema(), vec![vec![batch]]).unwrap(); - - // Note that `host` is a dictionary array but `lit("")` is a DataType::Utf8 that needs to be cast - let expr = when(col("host").is_null(), lit("")) - .otherwise(col("host")) - .unwrap(); - - let projection = None; - let builder = LogicalPlanBuilder::scan( - "cpu_load_short", - provider_as_source(Arc::new(t)), - projection, - ) - .unwrap(); - - let logical_plan = builder.project(vec![expr]).unwrap().build().unwrap(); - let df = DataFrame::new(ctx.state(), logical_plan); - let actual = df.collect().await.unwrap(); - - let expected = ["+----------------------------------------------------------------------------------+", - "| CASE WHEN cpu_load_short.host IS NULL THEN Utf8(\"\") ELSE cpu_load_short.host END |", - "+----------------------------------------------------------------------------------+", - "| host1 |", - "| |", - "| host2 |", - "+----------------------------------------------------------------------------------+"]; - assert_batches_eq!(expected, &actual); -} - -#[tokio::test] -async fn projection_on_memory_scan() -> Result<()> { - let schema = Schema::new(vec![ - Field::new("a", DataType::Int32, false), - Field::new("b", DataType::Int32, false), - Field::new("c", DataType::Int32, false), - ]); - let schema = SchemaRef::new(schema); - - let partitions = vec![vec![RecordBatch::try_new( - schema.clone(), - vec![ - Arc::new(Int32Array::from(vec![1, 10, 10, 100])), - Arc::new(Int32Array::from(vec![2, 12, 12, 120])), - Arc::new(Int32Array::from(vec![3, 12, 12, 120])), - ], - )?]]; - - let provider = Arc::new(MemTable::try_new(schema, partitions)?); - let plan = - LogicalPlanBuilder::scan(UNNAMED_TABLE, provider_as_source(provider), None)? - .project(vec![col("b")])? - .build()?; - assert_fields_eq(&plan, vec!["b"]); - - let ctx = SessionContext::new(); - let state = ctx.state(); - let optimized_plan = state.optimize(&plan)?; - match &optimized_plan { - LogicalPlan::TableScan(TableScan { - source, - projected_schema, - .. 
- }) => { - assert_eq!(source.schema().fields().len(), 3); - assert_eq!(projected_schema.fields().len(), 1); - } - _ => panic!("input to projection should be InMemoryScan"), - } - - let expected = format!("TableScan: {UNNAMED_TABLE} projection=[b]"); - assert_eq!(format!("{optimized_plan:?}"), expected); - - let physical_plan = state.create_physical_plan(&optimized_plan).await?; - - assert_eq!(1, physical_plan.schema().fields().len()); - assert_eq!("b", physical_plan.schema().field(0).name().as_str()); - - let batches = collect(physical_plan, state.task_ctx()).await?; - assert_eq!(1, batches.len()); - assert_eq!(1, batches[0].num_columns()); - assert_eq!(4, batches[0].num_rows()); - - Ok(()) -} - -fn assert_fields_eq(plan: &LogicalPlan, expected: Vec<&str>) { - let actual: Vec = plan - .schema() - .fields() - .iter() - .map(|f| f.name().clone()) - .collect(); - assert_eq!(actual, expected); -} - -#[tokio::test] -async fn project_column_with_same_name_as_relation() -> Result<()> { - let ctx = SessionContext::new(); - - let sql = "select a.a from (select 1 as a) as a;"; - let actual = execute_to_batches(&ctx, sql).await; - - let expected = ["+---+", "| a |", "+---+", "| 1 |", "+---+"]; - assert_batches_sorted_eq!(expected, &actual); - - Ok(()) -} - -#[tokio::test] -async fn project_column_with_filters_that_cant_pushed_down_always_false() -> Result<()> { - let ctx = SessionContext::new(); - - let sql = "select * from (select 1 as a) f where f.a=2;"; - let actual = execute_to_batches(&ctx, sql).await; - - let expected = ["++", "++"]; - assert_batches_sorted_eq!(expected, &actual); - - Ok(()) -} - -#[tokio::test] -async fn project_column_with_filters_that_cant_pushed_down_always_true() -> Result<()> { - let ctx = SessionContext::new(); - - let sql = "select * from (select 1 as a) f where f.a=1;"; - let actual = execute_to_batches(&ctx, sql).await; - - let expected = ["+---+", "| a |", "+---+", "| 1 |", "+---+"]; - assert_batches_sorted_eq!(expected, &actual); - - Ok(()) -} - -#[tokio::test] -async fn project_columns_in_memory_without_propagation() -> Result<()> { - let ctx = SessionContext::new(); - - let sql = "select column1 as a from (values (1), (2)) f where f.column1 = 2;"; - let actual = execute_to_batches(&ctx, sql).await; - - let expected = ["+---+", "| a |", "+---+", "| 2 |", "+---+"]; - assert_batches_sorted_eq!(expected, &actual); - - Ok(()) -} diff --git a/datafusion/core/tests/sql/select.rs b/datafusion/core/tests/sql/select.rs index 63f3e979305ab..cbdea9d729487 100644 --- a/datafusion/core/tests/sql/select.rs +++ b/datafusion/core/tests/sql/select.rs @@ -525,6 +525,53 @@ async fn test_prepare_statement() -> Result<()> { Ok(()) } +#[tokio::test] +async fn test_named_query_parameters() -> Result<()> { + let tmp_dir = TempDir::new()?; + let partition_count = 4; + let ctx = partitioned_csv::create_ctx(&tmp_dir, partition_count).await?; + + // sql to statement then to logical plan with parameters + // c1 defined as UINT32, c2 defined as UInt64 + let results = ctx + .sql("SELECT c1, c2 FROM test WHERE c1 > $coo AND c1 < $foo") + .await? + .with_param_values(vec![ + ("foo", ScalarValue::UInt32(Some(3))), + ("coo", ScalarValue::UInt32(Some(0))), + ])? 
+ .collect() + .await?; + let expected = vec![ + "+----+----+", + "| c1 | c2 |", + "+----+----+", + "| 1 | 1 |", + "| 1 | 2 |", + "| 1 | 3 |", + "| 1 | 4 |", + "| 1 | 5 |", + "| 1 | 6 |", + "| 1 | 7 |", + "| 1 | 8 |", + "| 1 | 9 |", + "| 1 | 10 |", + "| 2 | 1 |", + "| 2 | 2 |", + "| 2 | 3 |", + "| 2 | 4 |", + "| 2 | 5 |", + "| 2 | 6 |", + "| 2 | 7 |", + "| 2 | 8 |", + "| 2 | 9 |", + "| 2 | 10 |", + "+----+----+", + ]; + assert_batches_sorted_eq!(expected, &results); + Ok(()) +} + #[tokio::test] async fn parallel_query_with_filter() -> Result<()> { let tmp_dir = TempDir::new()?; diff --git a/datafusion/core/tests/sql/subqueries.rs b/datafusion/core/tests/sql/subqueries.rs deleted file mode 100644 index 01f8dd684b23c..0000000000000 --- a/datafusion/core/tests/sql/subqueries.rs +++ /dev/null @@ -1,63 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use super::*; -use crate::sql::execute_to_batches; - -#[tokio::test] -#[ignore] -async fn correlated_scalar_subquery_sum_agg_bug() -> Result<()> { - let ctx = create_join_context("t1_id", "t2_id", true)?; - - let sql = "select t1.t1_int from t1 where (select sum(t2_int) is null from t2 where t1.t1_id = t2.t2_id)"; - - let msg = format!("Creating logical plan for '{sql}'"); - let dataframe = ctx.sql(sql).await.expect(&msg); - let plan = dataframe.into_optimized_plan()?; - - let expected = vec![ - "Projection: t1.t1_int [t1_int:UInt32;N]", - " Inner Join: t1.t1_id = __scalar_sq_1.t2_id [t1_id:UInt32;N, t1_int:UInt32;N, t2_id:UInt32;N]", - " TableScan: t1 projection=[t1_id, t1_int] [t1_id:UInt32;N, t1_int:UInt32;N]", - " SubqueryAlias: __scalar_sq_1 [t2_id:UInt32;N]", - " Projection: t2.t2_id [t2_id:UInt32;N]", - " Filter: SUM(t2.t2_int) IS NULL [t2_id:UInt32;N, SUM(t2.t2_int):UInt64;N]", - " Aggregate: groupBy=[[t2.t2_id]], aggr=[[SUM(t2.t2_int)]] [t2_id:UInt32;N, SUM(t2.t2_int):UInt64;N]", - " TableScan: t2 projection=[t2_id, t2_int] [t2_id:UInt32;N, t2_int:UInt32;N]", - ]; - let formatted = plan.display_indent_schema().to_string(); - let actual: Vec<&str> = formatted.trim().lines().collect(); - assert_eq!( - expected, actual, - "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" - ); - - // assert data - let results = execute_to_batches(&ctx, sql).await; - let expected = [ - "+--------+", - "| t1_int |", - "+--------+", - "| 2 |", - "| 4 |", - "| 3 |", - "+--------+", - ]; - assert_batches_sorted_eq!(expected, &results); - - Ok(()) -} diff --git a/datafusion/core/tests/sql/timestamp.rs b/datafusion/core/tests/sql/timestamp.rs index a18e6831b6157..ada66503a1816 100644 --- a/datafusion/core/tests/sql/timestamp.rs +++ b/datafusion/core/tests/sql/timestamp.rs @@ -742,7 +742,7 @@ async fn test_arrow_typeof() -> Result<()> { 
"+-----------------------------------------------------------------------+", "| arrow_typeof(date_trunc(Utf8(\"microsecond\"),to_timestamp(Int64(61)))) |", "+-----------------------------------------------------------------------+", - "| Timestamp(Second, None) |", + "| Timestamp(Nanosecond, None) |", "+-----------------------------------------------------------------------+", ]; assert_batches_eq!(expected, &actual); diff --git a/datafusion/core/tests/user_defined/mod.rs b/datafusion/core/tests/user_defined/mod.rs index 09c7c3d3266bc..6c6d966cc3aab 100644 --- a/datafusion/core/tests/user_defined/mod.rs +++ b/datafusion/core/tests/user_defined/mod.rs @@ -26,3 +26,6 @@ mod user_defined_plan; /// Tests for User Defined Window Functions mod user_defined_window_functions; + +/// Tests for User Defined Table Functions +mod user_defined_table_functions; diff --git a/datafusion/core/tests/user_defined/user_defined_plan.rs b/datafusion/core/tests/user_defined/user_defined_plan.rs index d4a8842c0a7ad..29708c4422cac 100644 --- a/datafusion/core/tests/user_defined/user_defined_plan.rs +++ b/datafusion/core/tests/user_defined/user_defined_plan.rs @@ -91,6 +91,7 @@ use datafusion::{ }; use async_trait::async_trait; +use datafusion_common::arrow_datafusion_err; use futures::{Stream, StreamExt}; /// Execute the specified sql and return the resulting record batches @@ -99,7 +100,7 @@ async fn exec_sql(ctx: &mut SessionContext, sql: &str) -> Result { let df = ctx.sql(sql).await?; let batches = df.collect().await?; pretty_format_batches(&batches) - .map_err(DataFusionError::ArrowError) + .map_err(|e| arrow_datafusion_err!(e)) .map(|d| d.to_string()) } diff --git a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs index 1c7e7137290fc..985b0bd5bc767 100644 --- a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs +++ b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs @@ -341,6 +341,43 @@ async fn case_sensitive_identifiers_user_defined_functions() -> Result<()> { Ok(()) } +#[tokio::test] +async fn test_user_defined_functions_with_alias() -> Result<()> { + let ctx = SessionContext::new(); + let arr = Int32Array::from(vec![1]); + let batch = RecordBatch::try_from_iter(vec![("i", Arc::new(arr) as _)])?; + ctx.register_batch("t", batch).unwrap(); + + let myfunc = |args: &[ArrayRef]| Ok(Arc::clone(&args[0])); + let myfunc = make_scalar_function(myfunc); + + let udf = create_udf( + "dummy", + vec![DataType::Int32], + Arc::new(DataType::Int32), + Volatility::Immutable, + myfunc, + ) + .with_aliases(vec!["dummy_alias"]); + + ctx.register_udf(udf); + + let expected = [ + "+------------+", + "| dummy(t.i) |", + "+------------+", + "| 1 |", + "+------------+", + ]; + let result = plan_and_collect(&ctx, "SELECT dummy(i) FROM t").await?; + assert_batches_eq!(expected, &result); + + let alias_result = plan_and_collect(&ctx, "SELECT dummy_alias(i) FROM t").await?; + assert_batches_eq!(expected, &alias_result); + + Ok(()) +} + fn create_udf_context() -> SessionContext { let ctx = SessionContext::new(); // register a custom UDF diff --git a/datafusion/core/tests/user_defined/user_defined_table_functions.rs b/datafusion/core/tests/user_defined/user_defined_table_functions.rs new file mode 100644 index 0000000000000..b5d10b1c5b9ba --- /dev/null +++ b/datafusion/core/tests/user_defined/user_defined_table_functions.rs @@ -0,0 +1,219 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more 
contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use arrow::array::Int64Array;
+use arrow::csv::reader::Format;
+use arrow::csv::ReaderBuilder;
+use async_trait::async_trait;
+use datafusion::arrow::datatypes::SchemaRef;
+use datafusion::arrow::record_batch::RecordBatch;
+use datafusion::datasource::function::TableFunctionImpl;
+use datafusion::datasource::TableProvider;
+use datafusion::error::Result;
+use datafusion::execution::context::SessionState;
+use datafusion::execution::TaskContext;
+use datafusion::physical_plan::memory::MemoryExec;
+use datafusion::physical_plan::{collect, ExecutionPlan};
+use datafusion::prelude::SessionContext;
+use datafusion_common::{assert_batches_eq, DFSchema, ScalarValue};
+use datafusion_expr::{EmptyRelation, Expr, LogicalPlan, Projection, TableType};
+use std::fs::File;
+use std::io::Seek;
+use std::path::Path;
+use std::sync::Arc;
+
+/// Test a simple UDTF: `read_csv` takes a file path and an optional row limit
+#[tokio::test]
+async fn test_simple_read_csv_udtf() -> Result<()> {
+    let ctx = SessionContext::new();
+
+    ctx.register_udtf("read_csv", Arc::new(SimpleCsvTableFunc {}));
+
+    let csv_file = "tests/tpch-csv/nation.csv";
+    // read csv with at most 5 rows
+    let rbs = ctx
+        .sql(format!("SELECT * FROM read_csv('{csv_file}', 5);").as_str())
+        .await?
+        .collect()
+        .await?;
+
+    let expected = [
+        "+-------------+-----------+-------------+-------------------------------------------------------------------------------------------------------------+",
+        "| n_nationkey | n_name | n_regionkey | n_comment |",
+        "+-------------+-----------+-------------+-------------------------------------------------------------------------------------------------------------+",
+        "| 1 | ARGENTINA | 1 | al foxes promise slyly according to the regular accounts. bold requests alon |",
+        "| 2 | BRAZIL | 1 | y alongside of the pending deposits. carefully special packages are about the ironic forges. slyly special |",
+        "| 3 | CANADA | 1 | eas hang ironic, silent packages. slyly regular packages are furiously over the tithes. fluffily bold |",
+        "| 4 | EGYPT | 4 | y above the carefully unusual theodolites. final dugouts are quickly across the furiously regular d |",
+        "| 5 | ETHIOPIA | 0 | ven packages wake quickly. regu |",
+        "+-------------+-----------+-------------+-------------------------------------------------------------------------------------------------------------+",
+    ];
+    assert_batches_eq!(expected, &rbs);
+
+    // just run, return all rows
+    let rbs = ctx
+        .sql(format!("SELECT * FROM read_csv('{csv_file}');").as_str())
+        .await?
+ .collect() + .await?; + let excepted = [ + "+-------------+-----------+-------------+--------------------------------------------------------------------------------------------------------------------+", + "| n_nationkey | n_name | n_regionkey | n_comment |", + "+-------------+-----------+-------------+--------------------------------------------------------------------------------------------------------------------+", + "| 1 | ARGENTINA | 1 | al foxes promise slyly according to the regular accounts. bold requests alon |", + "| 2 | BRAZIL | 1 | y alongside of the pending deposits. carefully special packages are about the ironic forges. slyly special |", + "| 3 | CANADA | 1 | eas hang ironic, silent packages. slyly regular packages are furiously over the tithes. fluffily bold |", + "| 4 | EGYPT | 4 | y above the carefully unusual theodolites. final dugouts are quickly across the furiously regular d |", + "| 5 | ETHIOPIA | 0 | ven packages wake quickly. regu |", + "| 6 | FRANCE | 3 | refully final requests. regular, ironi |", + "| 7 | GERMANY | 3 | l platelets. regular accounts x-ray: unusual, regular acco |", + "| 8 | INDIA | 2 | ss excuses cajole slyly across the packages. deposits print aroun |", + "| 9 | INDONESIA | 2 | slyly express asymptotes. regular deposits haggle slyly. carefully ironic hockey players sleep blithely. carefull |", + "| 10 | IRAN | 4 | efully alongside of the slyly final dependencies. |", + "+-------------+-----------+-------------+--------------------------------------------------------------------------------------------------------------------+" + ]; + assert_batches_eq!(excepted, &rbs); + + Ok(()) +} + +struct SimpleCsvTable { + schema: SchemaRef, + exprs: Vec<Expr>, + batches: Vec<RecordBatch>, +} + +#[async_trait] +impl TableProvider for SimpleCsvTable { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn schema(&self) -> SchemaRef { + self.schema.clone() + } + + fn table_type(&self) -> TableType { + TableType::Base + } + + async fn scan( + &self, + state: &SessionState, + projection: Option<&Vec<usize>>, + _filters: &[Expr], + _limit: Option<usize>, + ) -> Result<Arc<dyn ExecutionPlan>> { + let batches = if !self.exprs.is_empty() { + let max_return_lines = self.interpreter_expr(state).await?; + // get max return rows from self.batches + let mut batches = vec![]; + let mut lines = 0; + for batch in &self.batches { + let batch_lines = batch.num_rows(); + if lines + batch_lines > max_return_lines as usize { + let batch_lines = max_return_lines as usize - lines; + batches.push(batch.slice(0, batch_lines)); + break; + } else { + batches.push(batch.clone()); + lines += batch_lines; + } + } + batches + } else { + self.batches.clone() + }; + Ok(Arc::new(MemoryExec::try_new( + &[batches], + TableProvider::schema(self), + projection.cloned(), + )?)) + } +} + +impl SimpleCsvTable { + async fn interpreter_expr(&self, state: &SessionState) -> Result<i64> { + use datafusion::logical_expr::expr_rewriter::normalize_col; + use datafusion::logical_expr::utils::columnize_expr; + let plan = LogicalPlan::EmptyRelation(EmptyRelation { + produce_one_row: true, + schema: Arc::new(DFSchema::empty()), + }); + let logical_plan = Projection::try_new( + vec![columnize_expr( + normalize_col(self.exprs[0].clone(), &plan)?, + plan.schema(), + )], + Arc::new(plan), + ) + .map(LogicalPlan::Projection)?; + let rbs = collect( + state.create_physical_plan(&logical_plan).await?, + Arc::new(TaskContext::from(state)), + ) + .await?; + let limit = rbs[0] + .column(0) + .as_any() + .downcast_ref::<Int64Array>() + .unwrap() + .value(0); + Ok(limit) + } +} + 
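// A minimal sketch (not taken from this patch) of the user-defined table
// function API exercised above: `TableFunctionImpl::call` returns an
// `Arc<dyn TableProvider>`, which `SessionContext::register_udtf` then exposes
// to SQL. `MemTable` is assumed to be re-exported at
// `datafusion::datasource::MemTable`; `OnesFunc` and the `ones` name are
// hypothetical.
use std::sync::Arc;

use arrow::array::Int64Array;
use arrow::datatypes::{DataType, Field, Schema};
use arrow::record_batch::RecordBatch;
use datafusion::datasource::function::TableFunctionImpl;
use datafusion::datasource::{MemTable, TableProvider};
use datafusion::error::Result;
use datafusion_expr::Expr;

struct OnesFunc;

impl TableFunctionImpl for OnesFunc {
    fn call(&self, _args: &[Expr]) -> Result<Arc<dyn TableProvider>> {
        // Build a single one-column batch; a real implementation would inspect `_args`
        let schema = Arc::new(Schema::new(vec![Field::new("x", DataType::Int64, false)]));
        let batch = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![Arc::new(Int64Array::from(vec![1_i64, 1, 1])) as _],
        )?;
        // MemTable is an in-memory TableProvider over the batches
        let table = MemTable::try_new(schema, vec![vec![batch]])?;
        Ok(Arc::new(table))
    }
}

// Registration and use then mirror the test above:
//   ctx.register_udtf("ones", Arc::new(OnesFunc));
//   ctx.sql("SELECT * FROM ones()").await?.show().await?;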
+struct SimpleCsvTableFunc {} + +impl TableFunctionImpl for SimpleCsvTableFunc { + fn call(&self, exprs: &[Expr]) -> Result<Arc<dyn TableProvider>> { + let mut new_exprs = vec![]; + let mut filepath = String::new(); + for expr in exprs { + match expr { + Expr::Literal(ScalarValue::Utf8(Some(ref path))) => { + filepath = path.clone() + } + expr => new_exprs.push(expr.clone()), + } + } + let (schema, batches) = read_csv_batches(filepath)?; + let table = SimpleCsvTable { + schema, + exprs: new_exprs.clone(), + batches, + }; + Ok(Arc::new(table)) + } +} + +fn read_csv_batches(csv_path: impl AsRef<Path>) -> Result<(SchemaRef, Vec<RecordBatch>)> { + let mut file = File::open(csv_path)?; + let (schema, _) = Format::default() + .with_header(true) + .infer_schema(&mut file, None)?; + file.rewind()?; + + let reader = ReaderBuilder::new(Arc::new(schema.clone())) + .with_header(true) + .build(file)?; + let mut batches = vec![]; + for batch in reader { + batches.push(batch?); + } + let schema = Arc::new(schema); + Ok((schema, batches)) +} diff --git a/datafusion/core/tests/user_defined/user_defined_window_functions.rs b/datafusion/core/tests/user_defined/user_defined_window_functions.rs index 5f99391572174..3040fbafe81af 100644 --- a/datafusion/core/tests/user_defined/user_defined_window_functions.rs +++ b/datafusion/core/tests/user_defined/user_defined_window_functions.rs @@ -19,6 +19,7 @@ //! user defined window functions use std::{ + any::Any, ops::Range, sync::{ atomic::{AtomicUsize, Ordering}, @@ -32,8 +33,7 @@ use arrow_schema::DataType; use datafusion::{assert_batches_eq, prelude::SessionContext}; use datafusion_common::{Result, ScalarValue}; use datafusion_expr::{ - function::PartitionEvaluatorFactory, PartitionEvaluator, ReturnTypeFunction, - Signature, Volatility, WindowUDF, + PartitionEvaluator, Signature, Volatility, WindowUDF, WindowUDFImpl, }; /// A query with a window function evaluated over the entire partition @@ -471,24 +471,48 @@ impl OddCounter { } fn register(ctx: &mut SessionContext, test_state: Arc<TestState>) { - let name = "odd_counter"; - let volatility = Volatility::Immutable; - - let signature = Signature::exact(vec![DataType::Int64], volatility); - - let return_type = Arc::new(DataType::Int64); - let return_type: ReturnTypeFunction = - Arc::new(move |_| Ok(Arc::clone(&return_type))); - - let partition_evaluator_factory: PartitionEvaluatorFactory = - Arc::new(move || Ok(Box::new(OddCounter::new(Arc::clone(&test_state))))); - - ctx.register_udwf(WindowUDF::new( - name, - &signature, - &return_type, - &partition_evaluator_factory, - )) + struct SimpleWindowUDF { + signature: Signature, + return_type: DataType, + test_state: Arc<TestState>, + } + + impl SimpleWindowUDF { + fn new(test_state: Arc<TestState>) -> Self { + let signature = + Signature::exact(vec![DataType::Float64], Volatility::Immutable); + let return_type = DataType::Int64; + Self { + signature, + return_type, + test_state, + } + } + } + + impl WindowUDFImpl for SimpleWindowUDF { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "odd_counter" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> { + Ok(self.return_type.clone()) + } + + fn partition_evaluator(&self) -> Result<Box<dyn PartitionEvaluator>> { + Ok(Box::new(OddCounter::new(Arc::clone(&self.test_state)))) + } + } + + ctx.register_udwf(WindowUDF::from(SimpleWindowUDF::new(test_state))) } } diff --git a/datafusion/execution/src/cache/cache_unit.rs b/datafusion/execution/src/cache/cache_unit.rs index 4a21dc02bd137..25f9b9fa4d687 100644 ---
a/datafusion/execution/src/cache/cache_unit.rs +++ b/datafusion/execution/src/cache/cache_unit.rs @@ -176,6 +176,7 @@ mod tests { .into(), size: 1024, e_tag: None, + version: None, }; let cache = DefaultFileStatisticsCache::default(); assert!(cache.get_with_extra(&meta.location, &meta).is_none()); @@ -219,6 +220,7 @@ mod tests { .into(), size: 1024, e_tag: None, + version: None, }; let cache = DefaultListFilesCache::default(); @@ -226,7 +228,7 @@ mod tests { cache.put(&meta.location, vec![meta.clone()].into()); assert_eq!( - cache.get(&meta.location).unwrap().get(0).unwrap().clone(), + cache.get(&meta.location).unwrap().first().unwrap().clone(), meta.clone() ); } diff --git a/datafusion/execution/src/config.rs b/datafusion/execution/src/config.rs index cfcc205b56252..8556335b395a9 100644 --- a/datafusion/execution/src/config.rs +++ b/datafusion/execution/src/config.rs @@ -86,7 +86,7 @@ impl SessionConfig { /// Set a generic `str` configuration option pub fn set_str(self, key: &str, value: &str) -> Self { - self.set(key, ScalarValue::Utf8(Some(value.to_string()))) + self.set(key, ScalarValue::from(value)) } /// Customize batch size diff --git a/datafusion/expr/Cargo.toml b/datafusion/expr/Cargo.toml index 5b1b421538772..3e05dae61954a 100644 --- a/datafusion/expr/Cargo.toml +++ b/datafusion/expr/Cargo.toml @@ -35,10 +35,13 @@ path = "src/lib.rs" [features] [dependencies] -ahash = { version = "0.8", default-features = false, features = ["runtime-rng"] } +ahash = { version = "0.8", default-features = false, features = [ + "runtime-rng", +] } arrow = { workspace = true } arrow-array = { workspace = true } datafusion-common = { workspace = true } +paste = "^1.0" sqlparser = { workspace = true } strum = { version = "0.25.0", features = ["derive"] } strum_macros = "0.25.0" diff --git a/datafusion/expr/src/aggregate_function.rs b/datafusion/expr/src/aggregate_function.rs index eaf4ff5ad806b..cea72c3cb5e6b 100644 --- a/datafusion/expr/src/aggregate_function.rs +++ b/datafusion/expr/src/aggregate_function.rs @@ -100,10 +100,12 @@ pub enum AggregateFunction { BoolAnd, /// Bool Or BoolOr, + /// string_agg + StringAgg, } impl AggregateFunction { - fn name(&self) -> &str { + pub fn name(&self) -> &str { use AggregateFunction::*; match self { Count => "COUNT", @@ -116,13 +118,13 @@ impl AggregateFunction { ArrayAgg => "ARRAY_AGG", FirstValue => "FIRST_VALUE", LastValue => "LAST_VALUE", - Variance => "VARIANCE", - VariancePop => "VARIANCE_POP", + Variance => "VAR", + VariancePop => "VAR_POP", Stddev => "STDDEV", StddevPop => "STDDEV_POP", - Covariance => "COVARIANCE", - CovariancePop => "COVARIANCE_POP", - Correlation => "CORRELATION", + Covariance => "COVAR", + CovariancePop => "COVAR_POP", + Correlation => "CORR", RegrSlope => "REGR_SLOPE", RegrIntercept => "REGR_INTERCEPT", RegrCount => "REGR_COUNT", @@ -141,6 +143,7 @@ impl AggregateFunction { BitXor => "BIT_XOR", BoolAnd => "BOOL_AND", BoolOr => "BOOL_OR", + StringAgg => "STRING_AGG", } } } @@ -171,6 +174,7 @@ impl FromStr for AggregateFunction { "array_agg" => AggregateFunction::ArrayAgg, "first_value" => AggregateFunction::FirstValue, "last_value" => AggregateFunction::LastValue, + "string_agg" => AggregateFunction::StringAgg, // statistical "corr" => AggregateFunction::Correlation, "covar" => AggregateFunction::Covariance, @@ -299,6 +303,7 @@ impl AggregateFunction { AggregateFunction::FirstValue | AggregateFunction::LastValue => { Ok(coerced_data_types[0].clone()) } + AggregateFunction::StringAgg => Ok(DataType::LargeUtf8), } } } @@ -408,6 +413,30 
@@ impl AggregateFunction { .collect(), Volatility::Immutable, ), + AggregateFunction::StringAgg => { + Signature::uniform(2, STRINGS.to_vec(), Volatility::Immutable) + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use strum::IntoEnumIterator; + + #[test] + // Test for AggregateFuncion's Display and from_str() implementations. + // For each variant in AggregateFuncion, it converts the variant to a string + // and then back to a variant. The test asserts that the original variant and + // the reconstructed variant are the same. This assertion is also necessary for + // function suggestion. See https://github.com/apache/arrow-datafusion/issues/8082 + fn test_display_and_from_str() { + for func_original in AggregateFunction::iter() { + let func_name = func_original.to_string(); + let func_from_str = + AggregateFunction::from_str(func_name.to_lowercase().as_str()).unwrap(); + assert_eq!(func_from_str, func_original); } } } diff --git a/datafusion/expr/src/built_in_function.rs b/datafusion/expr/src/built_in_function.rs index 16187572c521f..e642dae06e4fd 100644 --- a/datafusion/expr/src/built_in_function.rs +++ b/datafusion/expr/src/built_in_function.rs @@ -28,14 +28,11 @@ use crate::signature::TIMEZONE_WILDCARD; use crate::type_coercion::binary::get_wider_type; use crate::type_coercion::functions::data_types; use crate::{ - conditional_expressions, struct_expressions, utils, FuncMonotonicity, Signature, - TypeSignature, Volatility, + conditional_expressions, FuncMonotonicity, Signature, TypeSignature, Volatility, }; use arrow::datatypes::{DataType, Field, Fields, IntervalUnit, TimeUnit}; -use datafusion_common::{ - internal_err, plan_datafusion_err, plan_err, DataFusionError, Result, -}; +use datafusion_common::{internal_err, plan_err, DataFusionError, Result}; use strum::IntoEnumIterator; use strum_macros::EnumIter; @@ -132,6 +129,8 @@ pub enum BuiltinScalarFunction { // array functions /// array_append ArrayAppend, + /// array_sort + ArraySort, /// array_concat ArrayConcat, /// array_has @@ -140,10 +139,14 @@ pub enum BuiltinScalarFunction { ArrayHasAll, /// array_has_any ArrayHasAny, + /// array_pop_front + ArrayPopFront, /// array_pop_back ArrayPopBack, /// array_dims ArrayDims, + /// array_distinct + ArrayDistinct, /// array_element ArrayElement, /// array_empty @@ -176,12 +179,20 @@ pub enum BuiltinScalarFunction { ArraySlice, /// array_to_string ArrayToString, + /// array_intersect + ArrayIntersect, + /// array_union + ArrayUnion, + /// array_except + ArrayExcept, /// cardinality Cardinality, /// construct an array from columns MakeArray, /// Flatten Flatten, + /// Range + Range, // struct functions /// struct @@ -290,6 +301,14 @@ pub enum BuiltinScalarFunction { RegexpMatch, /// arrow_typeof ArrowTypeof, + /// overlay + OverLay, + /// levenshtein + Levenshtein, + /// substr_index + SubstrIndex, + /// find_in_set + FindInSet, } /// Maps the sql function name to `BuiltinScalarFunction` @@ -299,8 +318,7 @@ fn name_to_function() -> &'static HashMap<&'static str, BuiltinScalarFunction> { NAME_TO_FUNCTION_LOCK.get_or_init(|| { let mut map = HashMap::new(); BuiltinScalarFunction::iter().for_each(|func| { - let a = aliases(&func); - a.iter().for_each(|&a| { + func.aliases().iter().for_each(|&a| { map.insert(a, func); }); }); @@ -316,7 +334,7 @@ fn function_to_name() -> &'static HashMap { FUNCTION_TO_NAME_LOCK.get_or_init(|| { let mut map = HashMap::new(); BuiltinScalarFunction::iter().for_each(|func| { - map.insert(func, *aliases(&func).first().unwrap_or(&"NO_ALIAS")); + 
map.insert(func, *func.aliases().first().unwrap_or(&"NO_ALIAS")); }); map }) @@ -333,6 +351,12 @@ impl BuiltinScalarFunction { self.signature().type_signature.supports_zero_argument() } + /// Returns the name of this function + pub fn name(&self) -> &str { + // .unwrap is safe here because compiler makes sure the map will have matches for each BuiltinScalarFunction + function_to_name().get(self).unwrap() + } + /// Returns the [Volatility] of the builtin function. pub fn volatility(&self) -> Volatility { match self { @@ -377,15 +401,19 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::Tanh => Volatility::Immutable, BuiltinScalarFunction::Trunc => Volatility::Immutable, BuiltinScalarFunction::ArrayAppend => Volatility::Immutable, + BuiltinScalarFunction::ArraySort => Volatility::Immutable, BuiltinScalarFunction::ArrayConcat => Volatility::Immutable, BuiltinScalarFunction::ArrayEmpty => Volatility::Immutable, BuiltinScalarFunction::ArrayHasAll => Volatility::Immutable, BuiltinScalarFunction::ArrayHasAny => Volatility::Immutable, BuiltinScalarFunction::ArrayHas => Volatility::Immutable, BuiltinScalarFunction::ArrayDims => Volatility::Immutable, + BuiltinScalarFunction::ArrayDistinct => Volatility::Immutable, BuiltinScalarFunction::ArrayElement => Volatility::Immutable, + BuiltinScalarFunction::ArrayExcept => Volatility::Immutable, BuiltinScalarFunction::ArrayLength => Volatility::Immutable, BuiltinScalarFunction::ArrayNdims => Volatility::Immutable, + BuiltinScalarFunction::ArrayPopFront => Volatility::Immutable, BuiltinScalarFunction::ArrayPopBack => Volatility::Immutable, BuiltinScalarFunction::ArrayPosition => Volatility::Immutable, BuiltinScalarFunction::ArrayPositions => Volatility::Immutable, @@ -400,6 +428,9 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::Flatten => Volatility::Immutable, BuiltinScalarFunction::ArraySlice => Volatility::Immutable, BuiltinScalarFunction::ArrayToString => Volatility::Immutable, + BuiltinScalarFunction::ArrayIntersect => Volatility::Immutable, + BuiltinScalarFunction::ArrayUnion => Volatility::Immutable, + BuiltinScalarFunction::Range => Volatility::Immutable, BuiltinScalarFunction::Cardinality => Volatility::Immutable, BuiltinScalarFunction::MakeArray => Volatility::Immutable, BuiltinScalarFunction::Ascii => Volatility::Immutable, @@ -451,6 +482,10 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::Struct => Volatility::Immutable, BuiltinScalarFunction::FromUnixtime => Volatility::Immutable, BuiltinScalarFunction::ArrowTypeof => Volatility::Immutable, + BuiltinScalarFunction::OverLay => Volatility::Immutable, + BuiltinScalarFunction::Levenshtein => Volatility::Immutable, + BuiltinScalarFunction::SubstrIndex => Volatility::Immutable, + BuiltinScalarFunction::FindInSet => Volatility::Immutable, // Stable builtin functions BuiltinScalarFunction::Now => Volatility::Stable, @@ -483,6 +518,13 @@ impl BuiltinScalarFunction { } /// Returns the output [`DataType`] of this function + /// + /// This method should be invoked only after `input_expr_types` have been validated + /// against the function's `TypeSignature` using `type_coercion::functions::data_types()`. + /// + /// This method will: + /// 1. Perform additional checks on `input_expr_types` that are beyond the scope of `TypeSignature` validation. + /// 2. Deduce the output `DataType` based on the provided `input_expr_types`. 
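    /// As an illustration of that contract (a sketch, with `fun` and `arg_types`
    /// as hypothetical placeholders for a `BuiltinScalarFunction` value and its
    /// argument types, not names taken from this patch):
    ///
    ///     // validate / coerce the argument types against the signature first
    ///     let coerced =
    ///         crate::type_coercion::functions::data_types(&arg_types, &fun.signature())?;
    ///     // only after that is it valid to ask for the output type
    ///     let output = fun.return_type(&coerced)?;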
pub fn return_type(self, input_expr_types: &[DataType]) -> Result { use DataType::*; use TimeUnit::*; @@ -490,31 +532,6 @@ impl BuiltinScalarFunction { // Note that this function *must* return the same type that the respective physical expression returns // or the execution panics. - if input_expr_types.is_empty() - && !self.signature().type_signature.supports_zero_argument() - { - return plan_err!( - "{}", - utils::generate_signature_error_msg( - &format!("{self}"), - self.signature(), - input_expr_types - ) - ); - } - - // verify that this is a valid set of data types for this function - data_types(input_expr_types, &self.signature()).map_err(|_| { - plan_datafusion_err!( - "{}", - utils::generate_signature_error_msg( - &format!("{self}"), - self.signature(), - input_expr_types, - ) - ) - })?; - // the return type of the built in function. // Some built-in functions' return type depends on the incoming type. match self { @@ -533,6 +550,7 @@ impl BuiltinScalarFunction { Ok(data_type) } BuiltinScalarFunction::ArrayAppend => Ok(input_expr_types[0].clone()), + BuiltinScalarFunction::ArraySort => Ok(input_expr_types[0].clone()), BuiltinScalarFunction::ArrayConcat => { let mut expr_type = Null; let mut max_dims = 0; @@ -570,14 +588,17 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::ArrayDims => { Ok(List(Arc::new(Field::new("item", UInt64, true)))) } + BuiltinScalarFunction::ArrayDistinct => Ok(input_expr_types[0].clone()), BuiltinScalarFunction::ArrayElement => match &input_expr_types[0] { List(field) => Ok(field.data_type().clone()), + LargeList(field) => Ok(field.data_type().clone()), _ => plan_err!( - "The {self} function can only accept list as the first argument" + "The {self} function can only accept list or largelist as the first argument" ), }, BuiltinScalarFunction::ArrayLength => Ok(UInt64), BuiltinScalarFunction::ArrayNdims => Ok(UInt64), + BuiltinScalarFunction::ArrayPopFront => Ok(input_expr_types[0].clone()), BuiltinScalarFunction::ArrayPopBack => Ok(input_expr_types[0].clone()), BuiltinScalarFunction::ArrayPosition => Ok(UInt64), BuiltinScalarFunction::ArrayPositions => { @@ -597,6 +618,35 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::ArrayReplaceAll => Ok(input_expr_types[0].clone()), BuiltinScalarFunction::ArraySlice => Ok(input_expr_types[0].clone()), BuiltinScalarFunction::ArrayToString => Ok(Utf8), + BuiltinScalarFunction::ArrayIntersect => { + match (input_expr_types[0].clone(), input_expr_types[1].clone()) { + (DataType::Null, DataType::Null) | (DataType::Null, _) => { + Ok(DataType::Null) + } + (_, DataType::Null) => { + Ok(List(Arc::new(Field::new("item", Null, true)))) + } + (dt, _) => Ok(dt), + } + } + BuiltinScalarFunction::ArrayUnion => { + match (input_expr_types[0].clone(), input_expr_types[1].clone()) { + (DataType::Null, dt) => Ok(dt), + (dt, DataType::Null) => Ok(dt), + (dt, _) => Ok(dt), + } + } + BuiltinScalarFunction::Range => { + Ok(List(Arc::new(Field::new("item", Int64, true)))) + } + BuiltinScalarFunction::ArrayExcept => { + match (input_expr_types[0].clone(), input_expr_types[1].clone()) { + (DataType::Null, _) | (_, DataType::Null) => { + Ok(input_expr_types[0].clone()) + } + (dt, _) => Ok(dt), + } + } BuiltinScalarFunction::Cardinality => Ok(UInt64), BuiltinScalarFunction::MakeArray => match input_expr_types.len() { 0 => Ok(List(Arc::new(Field::new("item", Null, true)))), @@ -753,13 +803,16 @@ impl BuiltinScalarFunction { return plan_err!("The to_hex function can only accept integers."); } }), - BuiltinScalarFunction::ToTimestamp => 
Ok(match &input_expr_types[0] { - Int64 => Timestamp(Second, None), - _ => Timestamp(Nanosecond, None), - }), + BuiltinScalarFunction::SubstrIndex => { + utf8_to_str_type(&input_expr_types[0], "substr_index") + } + BuiltinScalarFunction::FindInSet => { + utf8_to_int_type(&input_expr_types[0], "find_in_set") + } + BuiltinScalarFunction::ToTimestamp + | BuiltinScalarFunction::ToTimestampNanos => Ok(Timestamp(Nanosecond, None)), BuiltinScalarFunction::ToTimestampMillis => Ok(Timestamp(Millisecond, None)), BuiltinScalarFunction::ToTimestampMicros => Ok(Timestamp(Microsecond, None)), - BuiltinScalarFunction::ToTimestampNanos => Ok(Timestamp(Nanosecond, None)), BuiltinScalarFunction::ToTimestampSeconds => Ok(Timestamp(Second, None)), BuiltinScalarFunction::FromUnixtime => Ok(Timestamp(Second, None)), BuiltinScalarFunction::Now => { @@ -824,6 +877,14 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::Abs => Ok(input_expr_types[0].clone()), + BuiltinScalarFunction::OverLay => { + utf8_to_str_type(&input_expr_types[0], "overlay") + } + + BuiltinScalarFunction::Levenshtein => { + utf8_to_int_type(&input_expr_types[0], "levenshtein") + } + BuiltinScalarFunction::Acos | BuiltinScalarFunction::Asin | BuiltinScalarFunction::Atan @@ -866,7 +927,18 @@ impl BuiltinScalarFunction { // for now, the list is small, as we do not have many built-in functions. match self { - BuiltinScalarFunction::ArrayAppend => Signature::any(2, self.volatility()), + BuiltinScalarFunction::ArraySort => { + Signature::variadic_any(self.volatility()) + } + BuiltinScalarFunction::ArrayAppend => Signature { + type_signature: ArrayAndElement, + volatility: self.volatility(), + }, + BuiltinScalarFunction::MakeArray => { + // 0 or more arguments of arbitrary type + Signature::one_of(vec![VariadicEqual, Any(0)], self.volatility()) + } + BuiltinScalarFunction::ArrayPopFront => Signature::any(1, self.volatility()), BuiltinScalarFunction::ArrayPopBack => Signature::any(1, self.volatility()), BuiltinScalarFunction::ArrayConcat => { Signature::variadic_any(self.volatility()) @@ -874,6 +946,7 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::ArrayDims => Signature::any(1, self.volatility()), BuiltinScalarFunction::ArrayEmpty => Signature::any(1, self.volatility()), BuiltinScalarFunction::ArrayElement => Signature::any(2, self.volatility()), + BuiltinScalarFunction::ArrayExcept => Signature::any(2, self.volatility()), BuiltinScalarFunction::Flatten => Signature::any(1, self.volatility()), BuiltinScalarFunction::ArrayHasAll | BuiltinScalarFunction::ArrayHasAny @@ -882,11 +955,15 @@ impl BuiltinScalarFunction { Signature::variadic_any(self.volatility()) } BuiltinScalarFunction::ArrayNdims => Signature::any(1, self.volatility()), + BuiltinScalarFunction::ArrayDistinct => Signature::any(1, self.volatility()), BuiltinScalarFunction::ArrayPosition => { Signature::variadic_any(self.volatility()) } BuiltinScalarFunction::ArrayPositions => Signature::any(2, self.volatility()), - BuiltinScalarFunction::ArrayPrepend => Signature::any(2, self.volatility()), + BuiltinScalarFunction::ArrayPrepend => Signature { + type_signature: ElementAndArray, + volatility: self.volatility(), + }, BuiltinScalarFunction::ArrayRepeat => Signature::any(2, self.volatility()), BuiltinScalarFunction::ArrayRemove => Signature::any(2, self.volatility()), BuiltinScalarFunction::ArrayRemoveN => Signature::any(3, self.volatility()), @@ -900,15 +977,18 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::ArrayToString => { Signature::variadic_any(self.volatility()) } + 
BuiltinScalarFunction::ArrayIntersect => Signature::any(2, self.volatility()), + BuiltinScalarFunction::ArrayUnion => Signature::any(2, self.volatility()), BuiltinScalarFunction::Cardinality => Signature::any(1, self.volatility()), - BuiltinScalarFunction::MakeArray => { - // 0 or more arguments of arbitrary type - Signature::one_of(vec![VariadicAny, Any(0)], self.volatility()) - } - BuiltinScalarFunction::Struct => Signature::variadic( - struct_expressions::SUPPORTED_STRUCT_TYPES.to_vec(), + BuiltinScalarFunction::Range => Signature::one_of( + vec![ + Exact(vec![Int64]), + Exact(vec![Int64, Int64]), + Exact(vec![Int64, Int64, Int64]), + ], self.volatility(), ), + BuiltinScalarFunction::Struct => Signature::variadic_any(self.volatility()), BuiltinScalarFunction::Concat | BuiltinScalarFunction::ConcatWithSeparator => { Signature::variadic(vec![Utf8], self.volatility()) @@ -969,6 +1049,7 @@ impl BuiltinScalarFunction { 1, vec![ Int64, + Float64, Timestamp(Nanosecond, None), Timestamp(Microsecond, None), Timestamp(Millisecond, None), @@ -1195,6 +1276,18 @@ impl BuiltinScalarFunction { self.volatility(), ), + BuiltinScalarFunction::SubstrIndex => Signature::one_of( + vec![ + Exact(vec![Utf8, Utf8, Int64]), + Exact(vec![LargeUtf8, LargeUtf8, Int64]), + ], + self.volatility(), + ), + BuiltinScalarFunction::FindInSet => Signature::one_of( + vec![Exact(vec![Utf8, Utf8]), Exact(vec![LargeUtf8, LargeUtf8])], + self.volatility(), + ), + BuiltinScalarFunction::Replace | BuiltinScalarFunction::Translate => { Signature::one_of(vec![Exact(vec![Utf8, Utf8, Utf8])], self.volatility()) } @@ -1268,7 +1361,19 @@ impl BuiltinScalarFunction { } BuiltinScalarFunction::ArrowTypeof => Signature::any(1, self.volatility()), BuiltinScalarFunction::Abs => Signature::any(1, self.volatility()), - + BuiltinScalarFunction::OverLay => Signature::one_of( + vec![ + Exact(vec![Utf8, Utf8, Int64, Int64]), + Exact(vec![LargeUtf8, LargeUtf8, Int64, Int64]), + Exact(vec![Utf8, Utf8, Int64]), + Exact(vec![LargeUtf8, LargeUtf8, Int64]), + ], + self.volatility(), + ), + BuiltinScalarFunction::Levenshtein => Signature::one_of( + vec![Exact(vec![Utf8, Utf8]), Exact(vec![LargeUtf8, LargeUtf8])], + self.volatility(), + ), BuiltinScalarFunction::Acos | BuiltinScalarFunction::Asin | BuiltinScalarFunction::Atan @@ -1354,187 +1459,210 @@ impl BuiltinScalarFunction { None } } -} -fn aliases(func: &BuiltinScalarFunction) -> &'static [&'static str] { - match func { - BuiltinScalarFunction::Abs => &["abs"], - BuiltinScalarFunction::Acos => &["acos"], - BuiltinScalarFunction::Acosh => &["acosh"], - BuiltinScalarFunction::Asin => &["asin"], - BuiltinScalarFunction::Asinh => &["asinh"], - BuiltinScalarFunction::Atan => &["atan"], - BuiltinScalarFunction::Atanh => &["atanh"], - BuiltinScalarFunction::Atan2 => &["atan2"], - BuiltinScalarFunction::Cbrt => &["cbrt"], - BuiltinScalarFunction::Ceil => &["ceil"], - BuiltinScalarFunction::Cos => &["cos"], - BuiltinScalarFunction::Cot => &["cot"], - BuiltinScalarFunction::Cosh => &["cosh"], - BuiltinScalarFunction::Degrees => &["degrees"], - BuiltinScalarFunction::Exp => &["exp"], - BuiltinScalarFunction::Factorial => &["factorial"], - BuiltinScalarFunction::Floor => &["floor"], - BuiltinScalarFunction::Gcd => &["gcd"], - BuiltinScalarFunction::Isnan => &["isnan"], - BuiltinScalarFunction::Iszero => &["iszero"], - BuiltinScalarFunction::Lcm => &["lcm"], - BuiltinScalarFunction::Ln => &["ln"], - BuiltinScalarFunction::Log => &["log"], - BuiltinScalarFunction::Log10 => &["log10"], - 
BuiltinScalarFunction::Log2 => &["log2"], - BuiltinScalarFunction::Nanvl => &["nanvl"], - BuiltinScalarFunction::Pi => &["pi"], - BuiltinScalarFunction::Power => &["power", "pow"], - BuiltinScalarFunction::Radians => &["radians"], - BuiltinScalarFunction::Random => &["random"], - BuiltinScalarFunction::Round => &["round"], - BuiltinScalarFunction::Signum => &["signum"], - BuiltinScalarFunction::Sin => &["sin"], - BuiltinScalarFunction::Sinh => &["sinh"], - BuiltinScalarFunction::Sqrt => &["sqrt"], - BuiltinScalarFunction::Tan => &["tan"], - BuiltinScalarFunction::Tanh => &["tanh"], - BuiltinScalarFunction::Trunc => &["trunc"], + /// Returns all names that can be used to call this function + pub fn aliases(&self) -> &'static [&'static str] { + match self { + BuiltinScalarFunction::Abs => &["abs"], + BuiltinScalarFunction::Acos => &["acos"], + BuiltinScalarFunction::Acosh => &["acosh"], + BuiltinScalarFunction::Asin => &["asin"], + BuiltinScalarFunction::Asinh => &["asinh"], + BuiltinScalarFunction::Atan => &["atan"], + BuiltinScalarFunction::Atanh => &["atanh"], + BuiltinScalarFunction::Atan2 => &["atan2"], + BuiltinScalarFunction::Cbrt => &["cbrt"], + BuiltinScalarFunction::Ceil => &["ceil"], + BuiltinScalarFunction::Cos => &["cos"], + BuiltinScalarFunction::Cot => &["cot"], + BuiltinScalarFunction::Cosh => &["cosh"], + BuiltinScalarFunction::Degrees => &["degrees"], + BuiltinScalarFunction::Exp => &["exp"], + BuiltinScalarFunction::Factorial => &["factorial"], + BuiltinScalarFunction::Floor => &["floor"], + BuiltinScalarFunction::Gcd => &["gcd"], + BuiltinScalarFunction::Isnan => &["isnan"], + BuiltinScalarFunction::Iszero => &["iszero"], + BuiltinScalarFunction::Lcm => &["lcm"], + BuiltinScalarFunction::Ln => &["ln"], + BuiltinScalarFunction::Log => &["log"], + BuiltinScalarFunction::Log10 => &["log10"], + BuiltinScalarFunction::Log2 => &["log2"], + BuiltinScalarFunction::Nanvl => &["nanvl"], + BuiltinScalarFunction::Pi => &["pi"], + BuiltinScalarFunction::Power => &["power", "pow"], + BuiltinScalarFunction::Radians => &["radians"], + BuiltinScalarFunction::Random => &["random"], + BuiltinScalarFunction::Round => &["round"], + BuiltinScalarFunction::Signum => &["signum"], + BuiltinScalarFunction::Sin => &["sin"], + BuiltinScalarFunction::Sinh => &["sinh"], + BuiltinScalarFunction::Sqrt => &["sqrt"], + BuiltinScalarFunction::Tan => &["tan"], + BuiltinScalarFunction::Tanh => &["tanh"], + BuiltinScalarFunction::Trunc => &["trunc"], - // conditional functions - BuiltinScalarFunction::Coalesce => &["coalesce"], - BuiltinScalarFunction::NullIf => &["nullif"], + // conditional functions + BuiltinScalarFunction::Coalesce => &["coalesce"], + BuiltinScalarFunction::NullIf => &["nullif"], - // string functions - BuiltinScalarFunction::Ascii => &["ascii"], - BuiltinScalarFunction::BitLength => &["bit_length"], - BuiltinScalarFunction::Btrim => &["btrim"], - BuiltinScalarFunction::CharacterLength => { - &["character_length", "char_length", "length"] - } - BuiltinScalarFunction::Concat => &["concat"], - BuiltinScalarFunction::ConcatWithSeparator => &["concat_ws"], - BuiltinScalarFunction::Chr => &["chr"], - BuiltinScalarFunction::InitCap => &["initcap"], - BuiltinScalarFunction::Left => &["left"], - BuiltinScalarFunction::Lower => &["lower"], - BuiltinScalarFunction::Lpad => &["lpad"], - BuiltinScalarFunction::Ltrim => &["ltrim"], - BuiltinScalarFunction::OctetLength => &["octet_length"], - BuiltinScalarFunction::Repeat => &["repeat"], - BuiltinScalarFunction::Replace => &["replace"], - 
BuiltinScalarFunction::Reverse => &["reverse"], - BuiltinScalarFunction::Right => &["right"], - BuiltinScalarFunction::Rpad => &["rpad"], - BuiltinScalarFunction::Rtrim => &["rtrim"], - BuiltinScalarFunction::SplitPart => &["split_part"], - BuiltinScalarFunction::StringToArray => &["string_to_array", "string_to_list"], - BuiltinScalarFunction::StartsWith => &["starts_with"], - BuiltinScalarFunction::Strpos => &["strpos"], - BuiltinScalarFunction::Substr => &["substr"], - BuiltinScalarFunction::ToHex => &["to_hex"], - BuiltinScalarFunction::Translate => &["translate"], - BuiltinScalarFunction::Trim => &["trim"], - BuiltinScalarFunction::Upper => &["upper"], - BuiltinScalarFunction::Uuid => &["uuid"], + // string functions + BuiltinScalarFunction::Ascii => &["ascii"], + BuiltinScalarFunction::BitLength => &["bit_length"], + BuiltinScalarFunction::Btrim => &["btrim"], + BuiltinScalarFunction::CharacterLength => { + &["character_length", "char_length", "length"] + } + BuiltinScalarFunction::Concat => &["concat"], + BuiltinScalarFunction::ConcatWithSeparator => &["concat_ws"], + BuiltinScalarFunction::Chr => &["chr"], + BuiltinScalarFunction::InitCap => &["initcap"], + BuiltinScalarFunction::Left => &["left"], + BuiltinScalarFunction::Lower => &["lower"], + BuiltinScalarFunction::Lpad => &["lpad"], + BuiltinScalarFunction::Ltrim => &["ltrim"], + BuiltinScalarFunction::OctetLength => &["octet_length"], + BuiltinScalarFunction::Repeat => &["repeat"], + BuiltinScalarFunction::Replace => &["replace"], + BuiltinScalarFunction::Reverse => &["reverse"], + BuiltinScalarFunction::Right => &["right"], + BuiltinScalarFunction::Rpad => &["rpad"], + BuiltinScalarFunction::Rtrim => &["rtrim"], + BuiltinScalarFunction::SplitPart => &["split_part"], + BuiltinScalarFunction::StringToArray => { + &["string_to_array", "string_to_list"] + } + BuiltinScalarFunction::StartsWith => &["starts_with"], + BuiltinScalarFunction::Strpos => &["strpos"], + BuiltinScalarFunction::Substr => &["substr"], + BuiltinScalarFunction::ToHex => &["to_hex"], + BuiltinScalarFunction::Translate => &["translate"], + BuiltinScalarFunction::Trim => &["trim"], + BuiltinScalarFunction::Upper => &["upper"], + BuiltinScalarFunction::Uuid => &["uuid"], + BuiltinScalarFunction::Levenshtein => &["levenshtein"], + BuiltinScalarFunction::SubstrIndex => &["substr_index", "substring_index"], + BuiltinScalarFunction::FindInSet => &["find_in_set"], - // regex functions - BuiltinScalarFunction::RegexpMatch => &["regexp_match"], - BuiltinScalarFunction::RegexpReplace => &["regexp_replace"], + // regex functions + BuiltinScalarFunction::RegexpMatch => &["regexp_match"], + BuiltinScalarFunction::RegexpReplace => &["regexp_replace"], - // time/date functions - BuiltinScalarFunction::Now => &["now"], - BuiltinScalarFunction::CurrentDate => &["current_date"], - BuiltinScalarFunction::CurrentTime => &["current_time"], - BuiltinScalarFunction::DateBin => &["date_bin"], - BuiltinScalarFunction::DateTrunc => &["date_trunc", "datetrunc"], - BuiltinScalarFunction::DatePart => &["date_part", "datepart"], - BuiltinScalarFunction::ToTimestamp => &["to_timestamp"], - BuiltinScalarFunction::ToTimestampMillis => &["to_timestamp_millis"], - BuiltinScalarFunction::ToTimestampMicros => &["to_timestamp_micros"], - BuiltinScalarFunction::ToTimestampSeconds => &["to_timestamp_seconds"], - BuiltinScalarFunction::ToTimestampNanos => &["to_timestamp_nanos"], - BuiltinScalarFunction::FromUnixtime => &["from_unixtime"], + // time/date functions + BuiltinScalarFunction::Now => 
&["now"], + BuiltinScalarFunction::CurrentDate => &["current_date", "today"], + BuiltinScalarFunction::CurrentTime => &["current_time"], + BuiltinScalarFunction::DateBin => &["date_bin"], + BuiltinScalarFunction::DateTrunc => &["date_trunc", "datetrunc"], + BuiltinScalarFunction::DatePart => &["date_part", "datepart"], + BuiltinScalarFunction::ToTimestamp => &["to_timestamp"], + BuiltinScalarFunction::ToTimestampMillis => &["to_timestamp_millis"], + BuiltinScalarFunction::ToTimestampMicros => &["to_timestamp_micros"], + BuiltinScalarFunction::ToTimestampSeconds => &["to_timestamp_seconds"], + BuiltinScalarFunction::ToTimestampNanos => &["to_timestamp_nanos"], + BuiltinScalarFunction::FromUnixtime => &["from_unixtime"], - // hashing functions - BuiltinScalarFunction::Digest => &["digest"], - BuiltinScalarFunction::MD5 => &["md5"], - BuiltinScalarFunction::SHA224 => &["sha224"], - BuiltinScalarFunction::SHA256 => &["sha256"], - BuiltinScalarFunction::SHA384 => &["sha384"], - BuiltinScalarFunction::SHA512 => &["sha512"], + // hashing functions + BuiltinScalarFunction::Digest => &["digest"], + BuiltinScalarFunction::MD5 => &["md5"], + BuiltinScalarFunction::SHA224 => &["sha224"], + BuiltinScalarFunction::SHA256 => &["sha256"], + BuiltinScalarFunction::SHA384 => &["sha384"], + BuiltinScalarFunction::SHA512 => &["sha512"], - // encode/decode - BuiltinScalarFunction::Encode => &["encode"], - BuiltinScalarFunction::Decode => &["decode"], + // encode/decode + BuiltinScalarFunction::Encode => &["encode"], + BuiltinScalarFunction::Decode => &["decode"], - // other functions - BuiltinScalarFunction::ArrowTypeof => &["arrow_typeof"], + // other functions + BuiltinScalarFunction::ArrowTypeof => &["arrow_typeof"], - // array functions - BuiltinScalarFunction::ArrayAppend => &[ - "array_append", - "list_append", - "array_push_back", - "list_push_back", - ], - BuiltinScalarFunction::ArrayConcat => { - &["array_concat", "array_cat", "list_concat", "list_cat"] - } - BuiltinScalarFunction::ArrayDims => &["array_dims", "list_dims"], - BuiltinScalarFunction::ArrayEmpty => &["empty"], - BuiltinScalarFunction::ArrayElement => &[ - "array_element", - "array_extract", - "list_element", - "list_extract", - ], - BuiltinScalarFunction::Flatten => &["flatten"], - BuiltinScalarFunction::ArrayHasAll => &["array_has_all", "list_has_all"], - BuiltinScalarFunction::ArrayHasAny => &["array_has_any", "list_has_any"], - BuiltinScalarFunction::ArrayHas => { - &["array_has", "list_has", "array_contains", "list_contains"] - } - BuiltinScalarFunction::ArrayLength => &["array_length", "list_length"], - BuiltinScalarFunction::ArrayNdims => &["array_ndims", "list_ndims"], - BuiltinScalarFunction::ArrayPopBack => &["array_pop_back", "list_pop_back"], - BuiltinScalarFunction::ArrayPosition => &[ - "array_position", - "list_position", - "array_indexof", - "list_indexof", - ], - BuiltinScalarFunction::ArrayPositions => &["array_positions", "list_positions"], - BuiltinScalarFunction::ArrayPrepend => &[ - "array_prepend", - "list_prepend", - "array_push_front", - "list_push_front", - ], - BuiltinScalarFunction::ArrayRepeat => &["array_repeat", "list_repeat"], - BuiltinScalarFunction::ArrayRemove => &["array_remove", "list_remove"], - BuiltinScalarFunction::ArrayRemoveN => &["array_remove_n", "list_remove_n"], - BuiltinScalarFunction::ArrayRemoveAll => &["array_remove_all", "list_remove_all"], - BuiltinScalarFunction::ArrayReplace => &["array_replace", "list_replace"], - BuiltinScalarFunction::ArrayReplaceN => &["array_replace_n", 
"list_replace_n"], - BuiltinScalarFunction::ArrayReplaceAll => { - &["array_replace_all", "list_replace_all"] - } - BuiltinScalarFunction::ArraySlice => &["array_slice", "list_slice"], - BuiltinScalarFunction::ArrayToString => &[ - "array_to_string", - "list_to_string", - "array_join", - "list_join", - ], - BuiltinScalarFunction::Cardinality => &["cardinality"], - BuiltinScalarFunction::MakeArray => &["make_array", "make_list"], + // array functions + BuiltinScalarFunction::ArrayAppend => &[ + "array_append", + "list_append", + "array_push_back", + "list_push_back", + ], + BuiltinScalarFunction::ArraySort => &["array_sort", "list_sort"], + BuiltinScalarFunction::ArrayConcat => { + &["array_concat", "array_cat", "list_concat", "list_cat"] + } + BuiltinScalarFunction::ArrayDims => &["array_dims", "list_dims"], + BuiltinScalarFunction::ArrayDistinct => &["array_distinct", "list_distinct"], + BuiltinScalarFunction::ArrayEmpty => &["empty"], + BuiltinScalarFunction::ArrayElement => &[ + "array_element", + "array_extract", + "list_element", + "list_extract", + ], + BuiltinScalarFunction::ArrayExcept => &["array_except", "list_except"], + BuiltinScalarFunction::Flatten => &["flatten"], + BuiltinScalarFunction::ArrayHasAll => &["array_has_all", "list_has_all"], + BuiltinScalarFunction::ArrayHasAny => &["array_has_any", "list_has_any"], + BuiltinScalarFunction::ArrayHas => { + &["array_has", "list_has", "array_contains", "list_contains"] + } + BuiltinScalarFunction::ArrayLength => &["array_length", "list_length"], + BuiltinScalarFunction::ArrayNdims => &["array_ndims", "list_ndims"], + BuiltinScalarFunction::ArrayPopFront => { + &["array_pop_front", "list_pop_front"] + } + BuiltinScalarFunction::ArrayPopBack => &["array_pop_back", "list_pop_back"], + BuiltinScalarFunction::ArrayPosition => &[ + "array_position", + "list_position", + "array_indexof", + "list_indexof", + ], + BuiltinScalarFunction::ArrayPositions => { + &["array_positions", "list_positions"] + } + BuiltinScalarFunction::ArrayPrepend => &[ + "array_prepend", + "list_prepend", + "array_push_front", + "list_push_front", + ], + BuiltinScalarFunction::ArrayRepeat => &["array_repeat", "list_repeat"], + BuiltinScalarFunction::ArrayRemove => &["array_remove", "list_remove"], + BuiltinScalarFunction::ArrayRemoveN => &["array_remove_n", "list_remove_n"], + BuiltinScalarFunction::ArrayRemoveAll => { + &["array_remove_all", "list_remove_all"] + } + BuiltinScalarFunction::ArrayReplace => &["array_replace", "list_replace"], + BuiltinScalarFunction::ArrayReplaceN => { + &["array_replace_n", "list_replace_n"] + } + BuiltinScalarFunction::ArrayReplaceAll => { + &["array_replace_all", "list_replace_all"] + } + BuiltinScalarFunction::ArraySlice => &["array_slice", "list_slice"], + BuiltinScalarFunction::ArrayToString => &[ + "array_to_string", + "list_to_string", + "array_join", + "list_join", + ], + BuiltinScalarFunction::ArrayUnion => &["array_union", "list_union"], + BuiltinScalarFunction::Cardinality => &["cardinality"], + BuiltinScalarFunction::MakeArray => &["make_array", "make_list"], + BuiltinScalarFunction::ArrayIntersect => { + &["array_intersect", "list_intersect"] + } + BuiltinScalarFunction::OverLay => &["overlay"], + BuiltinScalarFunction::Range => &["range", "generate_series"], - // struct functions - BuiltinScalarFunction::Struct => &["struct"], + // struct functions + BuiltinScalarFunction::Struct => &["struct"], + } } } impl fmt::Display for BuiltinScalarFunction { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - // .unwrap is 
safe here because compiler makes sure the map will have matches for each BuiltinScalarFunction - write!(f, "{}", function_to_name().get(self).unwrap()) + write!(f, "{}", self.name()) } } @@ -1621,7 +1749,8 @@ mod tests { // Test for BuiltinScalarFunction's Display and from_str() implementations. // For each variant in BuiltinScalarFunction, it converts the variant to a string // and then back to a variant. The test asserts that the original variant and - // the reconstructed variant are the same. + // the reconstructed variant are the same. This assertion is also necessary for + // function suggestion. See https://github.com/apache/arrow-datafusion/issues/8082 fn test_display_and_from_str() { for (_, func_original) in name_to_function().iter() { let func_name = func_original.to_string(); diff --git a/datafusion/expr/src/built_in_window_function.rs b/datafusion/expr/src/built_in_window_function.rs new file mode 100644 index 0000000000000..a03e3d2d24a9e --- /dev/null +++ b/datafusion/expr/src/built_in_window_function.rs @@ -0,0 +1,207 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Built-in functions module contains all the built-in functions definitions. + +use std::fmt; +use std::str::FromStr; + +use crate::type_coercion::functions::data_types; +use crate::utils; +use crate::{Signature, TypeSignature, Volatility}; +use datafusion_common::{plan_datafusion_err, plan_err, DataFusionError, Result}; + +use arrow::datatypes::DataType; + +use strum_macros::EnumIter; + +impl fmt::Display for BuiltInWindowFunction { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{}", self.name()) + } +} + +/// A [window function] built in to DataFusion +/// +/// [window function]: https://en.wikipedia.org/wiki/Window_function_(SQL) +#[derive(Debug, Clone, PartialEq, Eq, Hash, EnumIter)] +pub enum BuiltInWindowFunction { + /// number of the current row within its partition, counting from 1 + RowNumber, + /// rank of the current row with gaps; same as row_number of its first peer + Rank, + /// rank of the current row without gaps; this function counts peer groups + DenseRank, + /// relative rank of the current row: (rank - 1) / (total rows - 1) + PercentRank, + /// relative rank of the current row: (number of rows preceding or peer with current row) / (total rows) + CumeDist, + /// integer ranging from 1 to the argument value, dividing the partition as equally as possible + Ntile, + /// returns value evaluated at the row that is offset rows before the current row within the partition; + /// if there is no such row, instead return default (which must be of the same type as value). + /// Both offset and default are evaluated with respect to the current row. 
+ /// If omitted, offset defaults to 1 and default to null + Lag, + /// returns value evaluated at the row that is offset rows after the current row within the partition; + /// if there is no such row, instead return default (which must be of the same type as value). + /// Both offset and default are evaluated with respect to the current row. + /// If omitted, offset defaults to 1 and default to null + Lead, + /// returns value evaluated at the row that is the first row of the window frame + FirstValue, + /// returns value evaluated at the row that is the last row of the window frame + LastValue, + /// returns value evaluated at the row that is the nth row of the window frame (counting from 1); null if no such row + NthValue, +} + +impl BuiltInWindowFunction { + fn name(&self) -> &str { + use BuiltInWindowFunction::*; + match self { + RowNumber => "ROW_NUMBER", + Rank => "RANK", + DenseRank => "DENSE_RANK", + PercentRank => "PERCENT_RANK", + CumeDist => "CUME_DIST", + Ntile => "NTILE", + Lag => "LAG", + Lead => "LEAD", + FirstValue => "FIRST_VALUE", + LastValue => "LAST_VALUE", + NthValue => "NTH_VALUE", + } + } +} + +impl FromStr for BuiltInWindowFunction { + type Err = DataFusionError; + fn from_str(name: &str) -> Result { + Ok(match name.to_uppercase().as_str() { + "ROW_NUMBER" => BuiltInWindowFunction::RowNumber, + "RANK" => BuiltInWindowFunction::Rank, + "DENSE_RANK" => BuiltInWindowFunction::DenseRank, + "PERCENT_RANK" => BuiltInWindowFunction::PercentRank, + "CUME_DIST" => BuiltInWindowFunction::CumeDist, + "NTILE" => BuiltInWindowFunction::Ntile, + "LAG" => BuiltInWindowFunction::Lag, + "LEAD" => BuiltInWindowFunction::Lead, + "FIRST_VALUE" => BuiltInWindowFunction::FirstValue, + "LAST_VALUE" => BuiltInWindowFunction::LastValue, + "NTH_VALUE" => BuiltInWindowFunction::NthValue, + _ => return plan_err!("There is no built-in window function named {name}"), + }) + } +} + +/// Returns the datatype of the built-in window function +impl BuiltInWindowFunction { + pub fn return_type(&self, input_expr_types: &[DataType]) -> Result { + // Note that this function *must* return the same type that the respective physical expression returns + // or the execution panics. + + // verify that this is a valid set of data types for this function + data_types(input_expr_types, &self.signature()) + // original errors are all related to wrong function signature + // aggregate them for better error message + .map_err(|_| { + plan_datafusion_err!( + "{}", + utils::generate_signature_error_msg( + &format!("{self}"), + self.signature(), + input_expr_types, + ) + ) + })?; + + match self { + BuiltInWindowFunction::RowNumber + | BuiltInWindowFunction::Rank + | BuiltInWindowFunction::DenseRank => Ok(DataType::UInt64), + BuiltInWindowFunction::PercentRank | BuiltInWindowFunction::CumeDist => { + Ok(DataType::Float64) + } + BuiltInWindowFunction::Ntile => Ok(DataType::UInt64), + BuiltInWindowFunction::Lag + | BuiltInWindowFunction::Lead + | BuiltInWindowFunction::FirstValue + | BuiltInWindowFunction::LastValue + | BuiltInWindowFunction::NthValue => Ok(input_expr_types[0].clone()), + } + } + + /// the signatures supported by the built-in window function `fun`. + pub fn signature(&self) -> Signature { + // note: the physical expression must accept the type returned by this function or the execution panics. 
+ match self { + BuiltInWindowFunction::RowNumber + | BuiltInWindowFunction::Rank + | BuiltInWindowFunction::DenseRank + | BuiltInWindowFunction::PercentRank + | BuiltInWindowFunction::CumeDist => Signature::any(0, Volatility::Immutable), + BuiltInWindowFunction::Lag | BuiltInWindowFunction::Lead => { + Signature::one_of( + vec![ + TypeSignature::Any(1), + TypeSignature::Any(2), + TypeSignature::Any(3), + ], + Volatility::Immutable, + ) + } + BuiltInWindowFunction::FirstValue | BuiltInWindowFunction::LastValue => { + Signature::any(1, Volatility::Immutable) + } + BuiltInWindowFunction::Ntile => Signature::uniform( + 1, + vec![ + DataType::UInt64, + DataType::UInt32, + DataType::UInt16, + DataType::UInt8, + DataType::Int64, + DataType::Int32, + DataType::Int16, + DataType::Int8, + ], + Volatility::Immutable, + ), + BuiltInWindowFunction::NthValue => Signature::any(2, Volatility::Immutable), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use strum::IntoEnumIterator; + #[test] + // Test for BuiltInWindowFunction's Display and from_str() implementations. + // For each variant in BuiltInWindowFunction, it converts the variant to a string + // and then back to a variant. The test asserts that the original variant and + // the reconstructed variant are the same. This assertion is also necessary for + // function suggestion. See https://github.com/apache/arrow-datafusion/issues/8082 + fn test_display_and_from_str() { + for func_original in BuiltInWindowFunction::iter() { + let func_name = func_original.to_string(); + let func_from_str = BuiltInWindowFunction::from_str(&func_name).unwrap(); + assert_eq!(func_from_str, func_original); + } + } +} diff --git a/datafusion/expr/src/columnar_value.rs b/datafusion/expr/src/columnar_value.rs index c72aae69c8314..7a28839281697 100644 --- a/datafusion/expr/src/columnar_value.rs +++ b/datafusion/expr/src/columnar_value.rs @@ -20,7 +20,7 @@ use arrow::array::ArrayRef; use arrow::array::NullArray; use arrow::datatypes::DataType; -use datafusion_common::ScalarValue; +use datafusion_common::{Result, ScalarValue}; use std::sync::Arc; /// Represents the result of evaluating an expression: either a single @@ -47,11 +47,15 @@ impl ColumnarValue { /// Convert a columnar value into an ArrayRef. [`Self::Scalar`] is /// converted by repeating the same scalar multiple times. - pub fn into_array(self, num_rows: usize) -> ArrayRef { - match self { + /// + /// # Errors + /// + /// Errors if `self` is a Scalar that fails to be converted into an array of size + pub fn into_array(self, num_rows: usize) -> Result { + Ok(match self { ColumnarValue::Array(array) => array, - ColumnarValue::Scalar(scalar) => scalar.to_array_of_size(num_rows), - } + ColumnarValue::Scalar(scalar) => scalar.to_array_of_size(num_rows)?, + }) } /// null columnar values are implemented as a null array in order to pass batch diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 8929b21f44125..ebf4d3143c122 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -17,25 +17,28 @@ //! Expr module contains core type definition for `Expr`. 
-use crate::built_in_function; use crate::expr_fn::binary_expr; use crate::logical_plan::Subquery; -use crate::udaf; use crate::utils::{expr_to_columns, find_out_reference_exprs}; use crate::window_frame; -use crate::window_function; + use crate::Operator; use crate::{aggregate_function, ExprSchemable}; +use crate::{built_in_function, BuiltinScalarFunction}; +use crate::{built_in_window_function, udaf}; use arrow::datatypes::DataType; use datafusion_common::tree_node::{Transformed, TreeNode}; -use datafusion_common::{internal_err, DFSchema}; +use datafusion_common::{internal_err, DFSchema, OwnedTableReference}; use datafusion_common::{plan_err, Column, DataFusionError, Result, ScalarValue}; use std::collections::HashSet; use std::fmt; use std::fmt::{Display, Formatter, Write}; use std::hash::{BuildHasher, Hash, Hasher}; +use std::str::FromStr; use std::sync::Arc; +use crate::Signature; + /// `Expr` is a central struct of DataFusion's query API, and /// represent logical expressions such as `A + 1`, or `CAST(c1 AS /// int)`. @@ -148,16 +151,12 @@ pub enum Expr { TryCast(TryCast), /// A sort expression, that can be used to sort values. Sort(Sort), - /// Represents the call of a built-in scalar function with a set of arguments. + /// Represents the call of a scalar function with a set of arguments. ScalarFunction(ScalarFunction), - /// Represents the call of a user-defined scalar function with arguments. - ScalarUDF(ScalarUDF), /// Represents the call of an aggregate built-in function with arguments. AggregateFunction(AggregateFunction), /// Represents the call of a window function with arguments. WindowFunction(WindowFunction), - /// aggregate function - AggregateUDF(AggregateUDF), /// Returns whether the list contains the expr value. InList(InList), /// EXISTS subquery @@ -166,16 +165,12 @@ pub enum Expr { InSubquery(InSubquery), /// Scalar subquery ScalarSubquery(Subquery), - /// Represents a reference to all available fields. + /// Represents a reference to all available fields in a specific schema, + /// with an optional (schema) qualifier. /// /// This expr has to be resolved to a list of columns before translating logical /// plan into physical plan. - Wildcard, - /// Represents a reference to all available fields in a specific schema. - /// - /// This expr has to be resolved to a list of columns before translating logical - /// plan into physical plan. - QualifiedWildcard { qualifier: String }, + Wildcard { qualifier: Option }, /// List of grouping set expressions. Only valid in the context of an aggregate /// GROUP BY expression list GroupingSet(GroupingSet), @@ -191,13 +186,20 @@ pub enum Expr { #[derive(Clone, PartialEq, Eq, Hash, Debug)] pub struct Alias { pub expr: Box, + pub relation: Option, pub name: String, } impl Alias { - pub fn new(expr: Expr, name: impl Into) -> Self { + /// Create an alias with an optional schema/field qualifier. + pub fn new( + expr: Expr, + relation: Option>, + name: impl Into, + ) -> Self { Self { expr: Box::new(expr), + relation: relation.map(|r| r.into()), name: name.into(), } } @@ -335,35 +337,80 @@ impl Between { } } -/// ScalarFunction expression +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +/// Defines which implementation of a function for DataFusion to call. 
+pub enum ScalarFunctionDefinition { + /// Resolved to a `BuiltinScalarFunction` + /// There is plan to migrate `BuiltinScalarFunction` to UDF-based implementation (issue#8045) + /// This variant is planned to be removed in long term + BuiltIn(BuiltinScalarFunction), + /// Resolved to a user defined function + UDF(Arc), + /// A scalar function constructed with name. This variant can not be executed directly + /// and instead must be resolved to one of the other variants prior to physical planning. + Name(Arc), +} + +/// ScalarFunction expression invokes a built-in scalar function #[derive(Clone, PartialEq, Eq, Hash, Debug)] pub struct ScalarFunction { /// The function - pub fun: built_in_function::BuiltinScalarFunction, + pub func_def: ScalarFunctionDefinition, /// List of expressions to feed to the functions as arguments pub args: Vec, } impl ScalarFunction { - /// Create a new ScalarFunction expression - pub fn new(fun: built_in_function::BuiltinScalarFunction, args: Vec) -> Self { - Self { fun, args } + // return the Function's name + pub fn name(&self) -> &str { + self.func_def.name() } } -/// ScalarUDF expression -#[derive(Clone, PartialEq, Eq, Hash, Debug)] -pub struct ScalarUDF { - /// The function - pub fun: Arc, - /// List of expressions to feed to the functions as arguments - pub args: Vec, +impl ScalarFunctionDefinition { + /// Function's name for display + pub fn name(&self) -> &str { + match self { + ScalarFunctionDefinition::BuiltIn(fun) => fun.name(), + ScalarFunctionDefinition::UDF(udf) => udf.name(), + ScalarFunctionDefinition::Name(func_name) => func_name.as_ref(), + } + } + + /// Whether this function is volatile, i.e. whether it can return different results + /// when evaluated multiple times with the same input. + pub fn is_volatile(&self) -> Result { + match self { + ScalarFunctionDefinition::BuiltIn(fun) => { + Ok(fun.volatility() == crate::Volatility::Volatile) + } + ScalarFunctionDefinition::UDF(udf) => { + Ok(udf.signature().volatility == crate::Volatility::Volatile) + } + ScalarFunctionDefinition::Name(func) => { + internal_err!( + "Cannot determine volatility of unresolved function: {func}" + ) + } + } + } } -impl ScalarUDF { - /// Create a new ScalarUDF expression - pub fn new(fun: Arc, args: Vec) -> Self { - Self { fun, args } +impl ScalarFunction { + /// Create a new ScalarFunction expression + pub fn new(fun: built_in_function::BuiltinScalarFunction, args: Vec) -> Self { + Self { + func_def: ScalarFunctionDefinition::BuiltIn(fun), + args, + } + } + + /// Create a new ScalarFunction expression with a user-defined function (UDF) + pub fn new_udf(udf: Arc, args: Vec) -> Self { + Self { + func_def: ScalarFunctionDefinition::UDF(udf), + args, + } } } @@ -450,11 +497,33 @@ impl Sort { } } +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +/// Defines which implementation of an aggregate function DataFusion should call. +pub enum AggregateFunctionDefinition { + BuiltIn(aggregate_function::AggregateFunction), + /// Resolved to a user defined aggregate function + UDF(Arc), + /// A aggregation function constructed with name. This variant can not be executed directly + /// and instead must be resolved to one of the other variants prior to physical planning. 
+ Name(Arc), +} + +impl AggregateFunctionDefinition { + /// Function's name for display + pub fn name(&self) -> &str { + match self { + AggregateFunctionDefinition::BuiltIn(fun) => fun.name(), + AggregateFunctionDefinition::UDF(udf) => udf.name(), + AggregateFunctionDefinition::Name(func_name) => func_name.as_ref(), + } + } +} + /// Aggregate function #[derive(Clone, PartialEq, Eq, Hash, Debug)] pub struct AggregateFunction { /// Name of the function - pub fun: aggregate_function::AggregateFunction, + pub func_def: AggregateFunctionDefinition, /// List of expressions to feed to the functions as arguments pub args: Vec, /// Whether this is a DISTINCT aggregation or not @@ -474,7 +543,24 @@ impl AggregateFunction { order_by: Option>, ) -> Self { Self { - fun, + func_def: AggregateFunctionDefinition::BuiltIn(fun), + args, + distinct, + filter, + order_by, + } + } + + /// Create a new AggregateFunction expression with a user-defined function (UDF) + pub fn new_udf( + udf: Arc, + args: Vec, + distinct: bool, + filter: Option>, + order_by: Option>, + ) -> Self { + Self { + func_def: AggregateFunctionDefinition::UDF(udf), args, distinct, filter, @@ -483,11 +569,64 @@ impl AggregateFunction { } } +/// WindowFunction +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +/// Defines which implementation of an aggregate function DataFusion should call. +pub enum WindowFunctionDefinition { + /// A built in aggregate function that leverages an aggregate function + AggregateFunction(aggregate_function::AggregateFunction), + /// A a built-in window function + BuiltInWindowFunction(built_in_window_function::BuiltInWindowFunction), + /// A user defined aggregate function + AggregateUDF(Arc), + /// A user defined aggregate function + WindowUDF(Arc), +} + +impl WindowFunctionDefinition { + /// Returns the datatype of the window function + pub fn return_type(&self, input_expr_types: &[DataType]) -> Result { + match self { + WindowFunctionDefinition::AggregateFunction(fun) => { + fun.return_type(input_expr_types) + } + WindowFunctionDefinition::BuiltInWindowFunction(fun) => { + fun.return_type(input_expr_types) + } + WindowFunctionDefinition::AggregateUDF(fun) => { + fun.return_type(input_expr_types) + } + WindowFunctionDefinition::WindowUDF(fun) => fun.return_type(input_expr_types), + } + } + + /// the signatures supported by the function `fun`. 
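Reviewer note: a small sketch of `return_type` dispatching through a `WindowFunctionDefinition` backed by an ordinary aggregate; the expected `Int64` result mirrors the unit tests added later in this file.

```rust
use arrow::datatypes::DataType;
use datafusion_expr::aggregate_function::AggregateFunction;
use datafusion_expr::expr::WindowFunctionDefinition;

fn main() {
    // An ordinary aggregate can back a window function; `return_type`
    // forwards to the underlying implementation.
    let fun = WindowFunctionDefinition::AggregateFunction(AggregateFunction::Count);
    assert_eq!(fun.return_type(&[DataType::Utf8]).unwrap(), DataType::Int64);
}
```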
+ pub fn signature(&self) -> Signature { + match self { + WindowFunctionDefinition::AggregateFunction(fun) => fun.signature(), + WindowFunctionDefinition::BuiltInWindowFunction(fun) => fun.signature(), + WindowFunctionDefinition::AggregateUDF(fun) => fun.signature().clone(), + WindowFunctionDefinition::WindowUDF(fun) => fun.signature().clone(), + } + } +} + +impl fmt::Display for WindowFunctionDefinition { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + WindowFunctionDefinition::AggregateFunction(fun) => fun.fmt(f), + WindowFunctionDefinition::BuiltInWindowFunction(fun) => fun.fmt(f), + WindowFunctionDefinition::AggregateUDF(fun) => std::fmt::Debug::fmt(fun, f), + WindowFunctionDefinition::WindowUDF(fun) => fun.fmt(f), + } + } +} + /// Window function #[derive(Clone, PartialEq, Eq, Hash, Debug)] pub struct WindowFunction { /// Name of the function - pub fun: window_function::WindowFunction, + pub fun: WindowFunctionDefinition, /// List of expressions to feed to the functions as arguments pub args: Vec, /// List of partition by expressions @@ -501,7 +640,7 @@ pub struct WindowFunction { impl WindowFunction { /// Create a new Window expression pub fn new( - fun: window_function::WindowFunction, + fun: WindowFunctionDefinition, args: Vec, partition_by: Vec, order_by: Vec, @@ -517,6 +656,50 @@ impl WindowFunction { } } +/// Find DataFusion's built-in window function by name. +pub fn find_df_window_func(name: &str) -> Option { + let name = name.to_lowercase(); + // Code paths for window functions leveraging ordinary aggregators and + // built-in window functions are quite different, and the same function + // may have different implementations for these cases. If the sought + // function is not found among built-in window functions, we search for + // it among aggregate functions. + if let Ok(built_in_function) = + built_in_window_function::BuiltInWindowFunction::from_str(name.as_str()) + { + Some(WindowFunctionDefinition::BuiltInWindowFunction( + built_in_function, + )) + } else if let Ok(aggregate) = + aggregate_function::AggregateFunction::from_str(name.as_str()) + { + Some(WindowFunctionDefinition::AggregateFunction(aggregate)) + } else { + None + } +} + +/// Returns the datatype of the window function +#[deprecated( + since = "27.0.0", + note = "please use `WindowFunction::return_type` instead" +)] +pub fn return_type( + fun: &WindowFunctionDefinition, + input_expr_types: &[DataType], +) -> Result { + fun.return_type(input_expr_types) +} + +/// the signatures supported by the function `fun`. +#[deprecated( + since = "27.0.0", + note = "please use `WindowFunction::signature` instead" +)] +pub fn signature(fun: &WindowFunctionDefinition) -> Signature { + fun.signature() +} + // Exists expression. #[derive(Clone, PartialEq, Eq, Hash, Debug)] pub struct Exists { @@ -606,7 +789,7 @@ impl InSubquery { } } -/// Placeholder, representing bind parameter values such as `$1`. +/// Placeholder, representing bind parameter values such as `$1` or `$name`. /// /// The type of these parameters is inferred using [`Expr::infer_placeholder_types`] /// or can be specified directly using `PREPARE` statements. @@ -702,7 +885,6 @@ impl Expr { pub fn variant_name(&self) -> &str { match self { Expr::AggregateFunction { .. } => "AggregateFunction", - Expr::AggregateUDF { .. } => "AggregateUDF", Expr::Alias(..) => "Alias", Expr::Between { .. } => "Between", Expr::BinaryExpr { .. } => "BinaryExpr", @@ -729,15 +911,13 @@ impl Expr { Expr::Negative(..) => "Negative", Expr::Not(..) 
=> "Not", Expr::Placeholder(_) => "Placeholder", - Expr::QualifiedWildcard { .. } => "QualifiedWildcard", Expr::ScalarFunction(..) => "ScalarFunction", Expr::ScalarSubquery { .. } => "ScalarSubquery", - Expr::ScalarUDF(..) => "ScalarUDF", Expr::ScalarVariable(..) => "ScalarVariable", Expr::Sort { .. } => "Sort", Expr::TryCast { .. } => "TryCast", Expr::WindowFunction { .. } => "WindowFunction", - Expr::Wildcard => "Wildcard", + Expr::Wildcard { .. } => "Wildcard", } } @@ -849,14 +1029,34 @@ impl Expr { asc, nulls_first, }) => Expr::Sort(Sort::new(Box::new(expr.alias(name)), asc, nulls_first)), - _ => Expr::Alias(Alias::new(self, name.into())), + _ => Expr::Alias(Alias::new(self, None::<&str>, name.into())), + } + } + + /// Return `self AS name` alias expression with a specific qualifier + pub fn alias_qualified( + self, + relation: Option>, + name: impl Into, + ) -> Expr { + match self { + Expr::Sort(Sort { + expr, + asc, + nulls_first, + }) => Expr::Sort(Sort::new( + Box::new(expr.alias_qualified(relation, name)), + asc, + nulls_first, + )), + _ => Expr::Alias(Alias::new(self, relation, name.into())), } } /// Remove an alias from an expression if one exists. pub fn unalias(self) -> Expr { match self { - Expr::Alias(alias) => alias.expr.as_ref().clone(), + Expr::Alias(alias) => *alias.expr, _ => self, } } @@ -962,7 +1162,7 @@ impl Expr { Expr::GetIndexedField(GetIndexedField { expr: Box::new(self), field: GetFieldAccess::NamedStructField { - name: ScalarValue::Utf8(Some(name.into())), + name: ScalarValue::from(name.into()), }, }) } @@ -1174,11 +1374,8 @@ impl fmt::Display for Expr { write!(f, " NULLS LAST") } } - Expr::ScalarFunction(func) => { - fmt_function(f, &func.fun.to_string(), false, &func.args, true) - } - Expr::ScalarUDF(ScalarUDF { fun, args }) => { - fmt_function(f, &fun.name, false, args, true) + Expr::ScalarFunction(fun) => { + fmt_function(f, fun.name(), false, &fun.args, true) } Expr::WindowFunction(WindowFunction { fun, @@ -1202,30 +1399,14 @@ impl fmt::Display for Expr { Ok(()) } Expr::AggregateFunction(AggregateFunction { - fun, + func_def, distinct, ref args, filter, order_by, .. }) => { - fmt_function(f, &fun.to_string(), *distinct, args, true)?; - if let Some(fe) = filter { - write!(f, " FILTER (WHERE {fe})")?; - } - if let Some(ob) = order_by { - write!(f, " ORDER BY [{}]", expr_vec_fmt!(ob))?; - } - Ok(()) - } - Expr::AggregateUDF(AggregateUDF { - fun, - ref args, - filter, - order_by, - .. 
- }) => { - fmt_function(f, &fun.name, false, args, true)?; + fmt_function(f, func_def.name(), *distinct, args, true)?; if let Some(fe) = filter { write!(f, " FILTER (WHERE {fe})")?; } @@ -1292,8 +1473,10 @@ impl fmt::Display for Expr { write!(f, "{expr} IN ([{}])", expr_vec_fmt!(list)) } } - Expr::Wildcard => write!(f, "*"), - Expr::QualifiedWildcard { qualifier } => write!(f, "{qualifier}.*"), + Expr::Wildcard { qualifier } => match qualifier { + Some(qualifier) => write!(f, "{qualifier}.*"), + None => write!(f, "*"), + }, Expr::GetIndexedField(GetIndexedField { field, expr }) => match field { GetFieldAccess::NamedStructField { name } => { write!(f, "({expr})[{name}]") @@ -1508,12 +1691,7 @@ fn create_name(e: &Expr) -> Result { } } } - Expr::ScalarFunction(func) => { - create_function_name(&func.fun.to_string(), false, &func.args) - } - Expr::ScalarUDF(ScalarUDF { fun, args }) => { - create_function_name(&fun.name, false, args) - } + Expr::ScalarFunction(fun) => create_function_name(fun.name(), false, &fun.args), Expr::WindowFunction(WindowFunction { fun, args, @@ -1533,39 +1711,39 @@ fn create_name(e: &Expr) -> Result { Ok(parts.join(" ")) } Expr::AggregateFunction(AggregateFunction { - fun, + func_def, distinct, args, filter, order_by, }) => { - let mut name = create_function_name(&fun.to_string(), *distinct, args)?; - if let Some(fe) = filter { - name = format!("{name} FILTER (WHERE {fe})"); - }; - if let Some(order_by) = order_by { - name = format!("{name} ORDER BY [{}]", expr_vec_fmt!(order_by)); + let name = match func_def { + AggregateFunctionDefinition::BuiltIn(..) + | AggregateFunctionDefinition::Name(..) => { + create_function_name(func_def.name(), *distinct, args)? + } + AggregateFunctionDefinition::UDF(..) => { + let names: Vec = + args.iter().map(create_name).collect::>()?; + names.join(",") + } }; - Ok(name) - } - Expr::AggregateUDF(AggregateUDF { - fun, - args, - filter, - order_by, - }) => { - let mut names = Vec::with_capacity(args.len()); - for e in args { - names.push(create_name(e)?); - } let mut info = String::new(); if let Some(fe) = filter { info += &format!(" FILTER (WHERE {fe})"); + }; + if let Some(order_by) = order_by { + info += &format!(" ORDER BY [{}]", expr_vec_fmt!(order_by)); + }; + match func_def { + AggregateFunctionDefinition::BuiltIn(..) + | AggregateFunctionDefinition::Name(..) => { + Ok(format!("{}{}", name, info)) + } + AggregateFunctionDefinition::UDF(fun) => { + Ok(format!("{}({}){}", fun.name(), name, info)) + } } - if let Some(ob) = order_by { - info += &format!(" ORDER BY ([{}])", expr_vec_fmt!(ob)); - } - Ok(format!("{}({}){}", fun.name, names.join(","), info)) } Expr::GroupingSet(grouping_set) => match grouping_set { GroupingSet::Rollup(exprs) => { @@ -1613,10 +1791,12 @@ fn create_name(e: &Expr) -> Result { Expr::Sort { .. } => { internal_err!("Create name does not support sort expression") } - Expr::Wildcard => Ok("*".to_string()), - Expr::QualifiedWildcard { .. } => { - internal_err!("Create name does not support qualified wildcard") - } + Expr::Wildcard { qualifier } => match qualifier { + Some(qualifier) => internal_err!( + "Create name does not support qualified wildcard, got {qualifier}" + ), + None => Ok("*".to_string()), + }, Expr::Placeholder(Placeholder { id, .. }) => Ok((*id).to_string()), } } @@ -1630,14 +1810,28 @@ fn create_names(exprs: &[Expr]) -> Result { .join(", ")) } +/// Whether the given expression is volatile, i.e. whether it can return different results +/// when evaluated multiple times with the same input. 
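Reviewer note: a brief sketch of the `is_volatile` helper declared just below; the volatility of `RANDOM()` matches the test added at the end of this file, and anything other than a scalar function call is reported as non-volatile.

```rust
use datafusion_expr::expr::{is_volatile, ScalarFunction};
use datafusion_expr::{col, BuiltinScalarFunction, Expr};

fn main() {
    let rand = Expr::ScalarFunction(ScalarFunction::new(
        BuiltinScalarFunction::Random,
        vec![],
    ));
    // RANDOM() may change between evaluations, so it is volatile...
    assert!(is_volatile(&rand).unwrap());
    // ...while a plain column reference is not.
    assert!(!is_volatile(&col("a")).unwrap());
}
```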
+pub fn is_volatile(expr: &Expr) -> Result { + match expr { + Expr::ScalarFunction(func) => func.func_def.is_volatile(), + _ => Ok(false), + } +} + #[cfg(test)] mod test { use crate::expr::Cast; use crate::expr_fn::col; - use crate::{case, lit, Expr}; + use crate::{ + case, lit, BuiltinScalarFunction, ColumnarValue, Expr, ScalarFunctionDefinition, + ScalarUDF, ScalarUDFImpl, Signature, Volatility, + }; use arrow::datatypes::DataType; use datafusion_common::Column; use datafusion_common::{Result, ScalarValue}; + use std::any::Any; + use std::sync::Arc; #[test] fn format_case_when() -> Result<()> { @@ -1738,4 +1932,245 @@ mod test { "UInt32(1) OR UInt32(2)" ); } + + #[test] + fn test_is_volatile_scalar_func_definition() { + // BuiltIn + assert!( + ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::Random) + .is_volatile() + .unwrap() + ); + assert!( + !ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::Abs) + .is_volatile() + .unwrap() + ); + + // UDF + struct TestScalarUDF { + signature: Signature, + } + impl ScalarUDFImpl for TestScalarUDF { + fn as_any(&self) -> &dyn Any { + self + } + fn name(&self) -> &str { + "TestScalarUDF" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Utf8) + } + + fn invoke(&self, _args: &[ColumnarValue]) -> Result { + Ok(ColumnarValue::Scalar(ScalarValue::from("a"))) + } + } + let udf = Arc::new(ScalarUDF::from(TestScalarUDF { + signature: Signature::uniform(1, vec![DataType::Float32], Volatility::Stable), + })); + assert!(!ScalarFunctionDefinition::UDF(udf).is_volatile().unwrap()); + + let udf = Arc::new(ScalarUDF::from(TestScalarUDF { + signature: Signature::uniform( + 1, + vec![DataType::Float32], + Volatility::Volatile, + ), + })); + assert!(ScalarFunctionDefinition::UDF(udf).is_volatile().unwrap()); + + // Unresolved function + ScalarFunctionDefinition::Name(Arc::from("UnresolvedFunc")) + .is_volatile() + .expect_err("Shouldn't determine volatility of unresolved function"); + } + + use super::*; + + #[test] + fn test_count_return_type() -> Result<()> { + let fun = find_df_window_func("count").unwrap(); + let observed = fun.return_type(&[DataType::Utf8])?; + assert_eq!(DataType::Int64, observed); + + let observed = fun.return_type(&[DataType::UInt64])?; + assert_eq!(DataType::Int64, observed); + + Ok(()) + } + + #[test] + fn test_first_value_return_type() -> Result<()> { + let fun = find_df_window_func("first_value").unwrap(); + let observed = fun.return_type(&[DataType::Utf8])?; + assert_eq!(DataType::Utf8, observed); + + let observed = fun.return_type(&[DataType::UInt64])?; + assert_eq!(DataType::UInt64, observed); + + Ok(()) + } + + #[test] + fn test_last_value_return_type() -> Result<()> { + let fun = find_df_window_func("last_value").unwrap(); + let observed = fun.return_type(&[DataType::Utf8])?; + assert_eq!(DataType::Utf8, observed); + + let observed = fun.return_type(&[DataType::Float64])?; + assert_eq!(DataType::Float64, observed); + + Ok(()) + } + + #[test] + fn test_lead_return_type() -> Result<()> { + let fun = find_df_window_func("lead").unwrap(); + let observed = fun.return_type(&[DataType::Utf8])?; + assert_eq!(DataType::Utf8, observed); + + let observed = fun.return_type(&[DataType::Float64])?; + assert_eq!(DataType::Float64, observed); + + Ok(()) + } + + #[test] + fn test_lag_return_type() -> Result<()> { + let fun = find_df_window_func("lag").unwrap(); + let observed = fun.return_type(&[DataType::Utf8])?; + assert_eq!(DataType::Utf8, 
observed); + + let observed = fun.return_type(&[DataType::Float64])?; + assert_eq!(DataType::Float64, observed); + + Ok(()) + } + + #[test] + fn test_nth_value_return_type() -> Result<()> { + let fun = find_df_window_func("nth_value").unwrap(); + let observed = fun.return_type(&[DataType::Utf8, DataType::UInt64])?; + assert_eq!(DataType::Utf8, observed); + + let observed = fun.return_type(&[DataType::Float64, DataType::UInt64])?; + assert_eq!(DataType::Float64, observed); + + Ok(()) + } + + #[test] + fn test_percent_rank_return_type() -> Result<()> { + let fun = find_df_window_func("percent_rank").unwrap(); + let observed = fun.return_type(&[])?; + assert_eq!(DataType::Float64, observed); + + Ok(()) + } + + #[test] + fn test_cume_dist_return_type() -> Result<()> { + let fun = find_df_window_func("cume_dist").unwrap(); + let observed = fun.return_type(&[])?; + assert_eq!(DataType::Float64, observed); + + Ok(()) + } + + #[test] + fn test_ntile_return_type() -> Result<()> { + let fun = find_df_window_func("ntile").unwrap(); + let observed = fun.return_type(&[DataType::Int16])?; + assert_eq!(DataType::UInt64, observed); + + Ok(()) + } + + #[test] + fn test_window_function_case_insensitive() -> Result<()> { + let names = vec![ + "row_number", + "rank", + "dense_rank", + "percent_rank", + "cume_dist", + "ntile", + "lag", + "lead", + "first_value", + "last_value", + "nth_value", + "min", + "max", + "count", + "avg", + "sum", + ]; + for name in names { + let fun = find_df_window_func(name).unwrap(); + let fun2 = find_df_window_func(name.to_uppercase().as_str()).unwrap(); + assert_eq!(fun, fun2); + assert_eq!(fun.to_string(), name.to_uppercase()); + } + Ok(()) + } + + #[test] + fn test_find_df_window_function() { + assert_eq!( + find_df_window_func("max"), + Some(WindowFunctionDefinition::AggregateFunction( + aggregate_function::AggregateFunction::Max + )) + ); + assert_eq!( + find_df_window_func("min"), + Some(WindowFunctionDefinition::AggregateFunction( + aggregate_function::AggregateFunction::Min + )) + ); + assert_eq!( + find_df_window_func("avg"), + Some(WindowFunctionDefinition::AggregateFunction( + aggregate_function::AggregateFunction::Avg + )) + ); + assert_eq!( + find_df_window_func("cume_dist"), + Some(WindowFunctionDefinition::BuiltInWindowFunction( + built_in_window_function::BuiltInWindowFunction::CumeDist + )) + ); + assert_eq!( + find_df_window_func("first_value"), + Some(WindowFunctionDefinition::BuiltInWindowFunction( + built_in_window_function::BuiltInWindowFunction::FirstValue + )) + ); + assert_eq!( + find_df_window_func("LAST_value"), + Some(WindowFunctionDefinition::BuiltInWindowFunction( + built_in_window_function::BuiltInWindowFunction::LastValue + )) + ); + assert_eq!( + find_df_window_func("LAG"), + Some(WindowFunctionDefinition::BuiltInWindowFunction( + built_in_window_function::BuiltInWindowFunction::Lag + )) + ); + assert_eq!( + find_df_window_func("LEAD"), + Some(WindowFunctionDefinition::BuiltInWindowFunction( + built_in_window_function::BuiltInWindowFunction::Lead + )) + ); + assert_eq!(find_df_window_func("not_exist"), None) + } } diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index 5a60c2470c95b..f76fb17b38bba 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -22,15 +22,16 @@ use crate::expr::{ Placeholder, ScalarFunction, TryCast, }; use crate::function::PartitionEvaluatorFactory; -use crate::WindowUDF; use crate::{ aggregate_function, built_in_function, conditional_expressions::CaseBuilder, 
logical_plan::Subquery, AccumulatorFactoryFunction, AggregateUDF, BuiltinScalarFunction, Expr, LogicalPlan, Operator, ReturnTypeFunction, ScalarFunctionImplementation, ScalarUDF, Signature, StateTypeFunction, Volatility, }; +use crate::{ColumnarValue, ScalarUDFImpl, WindowUDF, WindowUDFImpl}; use arrow::datatypes::DataType; use datafusion_common::{Column, Result}; +use std::any::Any; use std::ops::Not; use std::sync::Arc; @@ -99,6 +100,19 @@ pub fn placeholder(id: impl Into) -> Expr { }) } +/// Create an '*' [`Expr::Wildcard`] expression that matches all columns +/// +/// # Example +/// +/// ```rust +/// # use datafusion_expr::{wildcard}; +/// let p = wildcard(); +/// assert_eq!(p.to_string(), "*") +/// ``` +pub fn wildcard() -> Expr { + Expr::Wildcard { qualifier: None } +} + /// Return a new expression `left right` pub fn binary_expr(left: Expr, op: Operator, right: Expr) -> Expr { Expr::BinaryExpr(BinaryExpr::new(Box::new(left), op, Box::new(right))) @@ -570,6 +584,8 @@ scalar_expr!( "appends an element to the end of an array." ); +scalar_expr!(ArraySort, array_sort, array desc null_first, "returns sorted array."); + scalar_expr!( ArrayPopBack, array_pop_back, @@ -577,6 +593,13 @@ scalar_expr!( "returns the array without the last element." ); +scalar_expr!( + ArrayPopFront, + array_pop_front, + array, + "returns the array without the first element." +); + nary_scalar_expr!(ArrayConcat, array_concat, "concatenates arrays."); scalar_expr!( ArrayHas, @@ -620,6 +643,12 @@ scalar_expr!( array element, "extracts the element with the index n from the array." ); +scalar_expr!( + ArrayExcept, + array_except, + first_array second_array, + "Returns an array of the elements that appear in the first array but not in the second." +); scalar_expr!( ArrayLength, array_length, @@ -632,6 +661,12 @@ scalar_expr!( array, "returns the number of dimensions of the array." ); +scalar_expr!( + ArrayDistinct, + array_distinct, + array, + "return distinct values from the array after removing duplicates." +); scalar_expr!( ArrayPosition, array_position, @@ -704,6 +739,8 @@ scalar_expr!( array delimiter, "converts each element to its text representation." ); +scalar_expr!(ArrayUnion, array_union, array1 array2, "returns an array of the elements in the union of array1 and array2 without duplicates."); + scalar_expr!( Cardinality, cardinality, @@ -715,6 +752,18 @@ nary_scalar_expr!( array, "returns an Arrow array using the specified input expressions." ); +scalar_expr!( + ArrayIntersect, + array_intersect, + first_array second_array, + "Returns an array of the elements in the intersection of array1 and array2." +); + +nary_scalar_expr!( + Range, + gen_range, + "Returns a list of values in the range between start and stop with step." 
+); // string functions scalar_expr!(Ascii, ascii, chr, "ASCII code value of the character"); @@ -817,6 +866,11 @@ nary_scalar_expr!( "concatenates several strings, placing a seperator between each one" ); nary_scalar_expr!(Concat, concat_expr, "concatenates several strings"); +nary_scalar_expr!( + OverLay, + overlay, + "replace the substring of string that starts at the start'th character and extends for count characters with new substring" +); // date functions scalar_expr!(DatePart, date_part, part date, "extracts a subfield from the date"); @@ -870,6 +924,16 @@ scalar_expr!( ); scalar_expr!(ArrowTypeof, arrow_typeof, val, "data type"); +scalar_expr!(Levenshtein, levenshtein, string1 string2, "Returns the Levenshtein distance between the two given strings"); +scalar_expr!(SubstrIndex, substr_index, string delimiter count, "Returns the substring from str before count occurrences of the delimiter"); +scalar_expr!(FindInSet, find_in_set, str strlist, "Returns a value in the range of 1 to N if the string str is in the string list strlist consisting of N substrings"); + +scalar_expr!( + Struct, + struct_fun, + val, + "returns a vector of fields from the struct" +); /// Create a CASE WHEN statement with literal WHEN expressions for comparison to the base expression. pub fn case(expr: Expr) -> CaseBuilder { @@ -881,11 +945,18 @@ pub fn when(when: Expr, then: Expr) -> CaseBuilder { CaseBuilder::new(None, vec![when], vec![then], None) } -/// Creates a new UDF with a specific signature and specific return type. -/// This is a helper function to create a new UDF. -/// The function `create_udf` returns a subset of all possible `ScalarFunction`: -/// * the UDF has a fixed return type -/// * the UDF has a fixed signature (e.g. [f64, f64]) +/// Convenience method to create a new user defined scalar function (UDF) with a +/// specific signature and specific return type. +/// +/// Note this function does not expose all available features of [`ScalarUDF`], +/// such as +/// +/// * computing return types based on input types +/// * multiple [`Signature`]s +/// * aliases +/// +/// See [`ScalarUDF`] for details and examples on how to use the full +/// functionality. pub fn create_udf( name: &str, input_types: Vec, @@ -893,13 +964,66 @@ pub fn create_udf( volatility: Volatility, fun: ScalarFunctionImplementation, ) -> ScalarUDF { - let return_type: ReturnTypeFunction = Arc::new(move |_| Ok(return_type.clone())); - ScalarUDF::new( + let return_type = Arc::try_unwrap(return_type).unwrap_or_else(|t| t.as_ref().clone()); + ScalarUDF::from(SimpleScalarUDF::new( name, - &Signature::exact(input_types, volatility), - &return_type, - &fun, - ) + input_types, + return_type, + volatility, + fun, + )) +} + +/// Implements [`ScalarUDFImpl`] for functions that have a single signature and +/// return type. +pub struct SimpleScalarUDF { + name: String, + signature: Signature, + return_type: DataType, + fun: ScalarFunctionImplementation, +} + +impl SimpleScalarUDF { + /// Create a new `SimpleScalarUDF` from a name, input types, return type and + /// implementation. 
Implementing [`ScalarUDFImpl`] allows more flexibility + pub fn new( + name: impl Into, + input_types: Vec, + return_type: DataType, + volatility: Volatility, + fun: ScalarFunctionImplementation, + ) -> Self { + let name = name.into(); + let signature = Signature::exact(input_types, volatility); + Self { + name, + signature, + return_type, + fun, + } + } +} + +impl ScalarUDFImpl for SimpleScalarUDF { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + &self.name + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(self.return_type.clone()) + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + (self.fun)(args) + } } /// Creates a new UDAF with a specific signature, state type and return type. @@ -935,13 +1059,66 @@ pub fn create_udwf( volatility: Volatility, partition_evaluator_factory: PartitionEvaluatorFactory, ) -> WindowUDF { - let return_type: ReturnTypeFunction = Arc::new(move |_| Ok(return_type.clone())); - WindowUDF::new( + let return_type = Arc::try_unwrap(return_type).unwrap_or_else(|t| t.as_ref().clone()); + WindowUDF::from(SimpleWindowUDF::new( name, - &Signature::exact(vec![input_type], volatility), - &return_type, - &partition_evaluator_factory, - ) + input_type, + return_type, + volatility, + partition_evaluator_factory, + )) +} + +/// Implements [`WindowUDFImpl`] for functions that have a single signature and +/// return type. +pub struct SimpleWindowUDF { + name: String, + signature: Signature, + return_type: DataType, + partition_evaluator_factory: PartitionEvaluatorFactory, +} + +impl SimpleWindowUDF { + /// Create a new `SimpleWindowUDF` from a name, input types, return type and + /// implementation. Implementing [`WindowUDFImpl`] allows more flexibility + pub fn new( + name: impl Into, + input_type: DataType, + return_type: DataType, + volatility: Volatility, + partition_evaluator_factory: PartitionEvaluatorFactory, + ) -> Self { + let name = name.into(); + let signature = Signature::exact([input_type].to_vec(), volatility); + Self { + name, + signature, + return_type, + partition_evaluator_factory, + } + } +} + +impl WindowUDFImpl for SimpleWindowUDF { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + &self.name + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(self.return_type.clone()) + } + + fn partition_evaluator(&self) -> Result> { + (self.partition_evaluator_factory)() + } } /// Calls a named built in function @@ -961,7 +1138,7 @@ pub fn call_fn(name: impl AsRef, args: Vec) -> Result { #[cfg(test)] mod test { use super::*; - use crate::lit; + use crate::{lit, ScalarFunctionDefinition}; #[test] fn filter_is_null_and_is_not_null() { @@ -976,8 +1153,10 @@ mod test { macro_rules! test_unary_scalar_expr { ($ENUM:ident, $FUNC:ident) => {{ - if let Expr::ScalarFunction(ScalarFunction { fun, args }) = - $FUNC(col("tableA.a")) + if let Expr::ScalarFunction(ScalarFunction { + func_def: ScalarFunctionDefinition::BuiltIn(fun), + args, + }) = $FUNC(col("tableA.a")) { let name = built_in_function::BuiltinScalarFunction::$ENUM; assert_eq!(name, fun); @@ -989,42 +1168,42 @@ mod test { } macro_rules! 
test_scalar_expr { - ($ENUM:ident, $FUNC:ident, $($arg:ident),*) => { - let expected = [$(stringify!($arg)),*]; - let result = $FUNC( + ($ENUM:ident, $FUNC:ident, $($arg:ident),*) => { + let expected = [$(stringify!($arg)),*]; + let result = $FUNC( + $( + col(stringify!($arg.to_string())) + ),* + ); + if let Expr::ScalarFunction(ScalarFunction { func_def: ScalarFunctionDefinition::BuiltIn(fun), args }) = result { + let name = built_in_function::BuiltinScalarFunction::$ENUM; + assert_eq!(name, fun); + assert_eq!(expected.len(), args.len()); + } else { + assert!(false, "unexpected: {:?}", result); + } + }; +} + + macro_rules! test_nary_scalar_expr { + ($ENUM:ident, $FUNC:ident, $($arg:ident),*) => { + let expected = [$(stringify!($arg)),*]; + let result = $FUNC( + vec![ $( col(stringify!($arg.to_string())) ),* - ); - if let Expr::ScalarFunction(ScalarFunction { fun, args }) = result { - let name = built_in_function::BuiltinScalarFunction::$ENUM; - assert_eq!(name, fun); - assert_eq!(expected.len(), args.len()); - } else { - assert!(false, "unexpected: {:?}", result); - } - }; - } - - macro_rules! test_nary_scalar_expr { - ($ENUM:ident, $FUNC:ident, $($arg:ident),*) => { - let expected = [$(stringify!($arg)),*]; - let result = $FUNC( - vec![ - $( - col(stringify!($arg.to_string())) - ),* - ] - ); - if let Expr::ScalarFunction(ScalarFunction { fun, args }) = result { - let name = built_in_function::BuiltinScalarFunction::$ENUM; - assert_eq!(name, fun); - assert_eq!(expected.len(), args.len()); - } else { - assert!(false, "unexpected: {:?}", result); - } - }; - } + ] + ); + if let Expr::ScalarFunction(ScalarFunction { func_def: ScalarFunctionDefinition::BuiltIn(fun), args }) = result { + let name = built_in_function::BuiltinScalarFunction::$ENUM; + assert_eq!(name, fun); + assert_eq!(expected.len(), args.len()); + } else { + assert!(false, "unexpected: {:?}", result); + } + }; +} #[test] fn scalar_function_definitions() { @@ -1127,6 +1306,8 @@ mod test { test_scalar_expr!(FromUnixtime, from_unixtime, unixtime); test_scalar_expr!(ArrayAppend, array_append, array, element); + test_scalar_expr!(ArraySort, array_sort, array, desc, null_first); + test_scalar_expr!(ArrayPopFront, array_pop_front, array); test_scalar_expr!(ArrayPopBack, array_pop_back, array); test_unary_scalar_expr!(ArrayDims, array_dims); test_scalar_expr!(ArrayLength, array_length, array, dimension); @@ -1146,11 +1327,20 @@ mod test { test_nary_scalar_expr!(MakeArray, array, input); test_unary_scalar_expr!(ArrowTypeof, arrow_typeof); + test_nary_scalar_expr!(OverLay, overlay, string, characters, position, len); + test_nary_scalar_expr!(OverLay, overlay, string, characters, position); + test_scalar_expr!(Levenshtein, levenshtein, string1, string2); + test_scalar_expr!(SubstrIndex, substr_index, string, delimiter, count); + test_scalar_expr!(FindInSet, find_in_set, string, stringlist); } #[test] fn uuid_function_definitions() { - if let Expr::ScalarFunction(ScalarFunction { fun, args }) = uuid() { + if let Expr::ScalarFunction(ScalarFunction { + func_def: ScalarFunctionDefinition::BuiltIn(fun), + args, + }) = uuid() + { let name = BuiltinScalarFunction::Uuid; assert_eq!(name, fun); assert_eq!(0, args.len()); @@ -1161,8 +1351,10 @@ mod test { #[test] fn digest_function_definitions() { - if let Expr::ScalarFunction(ScalarFunction { fun, args }) = - digest(col("tableA.a"), lit("md5")) + if let Expr::ScalarFunction(ScalarFunction { + func_def: ScalarFunctionDefinition::BuiltIn(fun), + args, + }) = digest(col("tableA.a"), lit("md5")) { 
let name = BuiltinScalarFunction::Digest; assert_eq!(name, fun); @@ -1174,8 +1366,10 @@ mod test { #[test] fn encode_function_definitions() { - if let Expr::ScalarFunction(ScalarFunction { fun, args }) = - encode(col("tableA.a"), lit("base64")) + if let Expr::ScalarFunction(ScalarFunction { + func_def: ScalarFunctionDefinition::BuiltIn(fun), + args, + }) = encode(col("tableA.a"), lit("base64")) { let name = BuiltinScalarFunction::Encode; assert_eq!(name, fun); @@ -1187,8 +1381,10 @@ mod test { #[test] fn decode_function_definitions() { - if let Expr::ScalarFunction(ScalarFunction { fun, args }) = - decode(col("tableA.a"), lit("hex")) + if let Expr::ScalarFunction(ScalarFunction { + func_def: ScalarFunctionDefinition::BuiltIn(fun), + args, + }) = decode(col("tableA.a"), lit("hex")) { let name = BuiltinScalarFunction::Decode; assert_eq!(name, fun); diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index 025b74eb5009a..ba21d09f06193 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -17,13 +17,14 @@ use super::{Between, Expr, Like}; use crate::expr::{ - AggregateFunction, AggregateUDF, Alias, BinaryExpr, Cast, GetFieldAccess, - GetIndexedField, InList, InSubquery, Placeholder, ScalarFunction, ScalarUDF, Sort, - TryCast, WindowFunction, + AggregateFunction, AggregateFunctionDefinition, Alias, BinaryExpr, Cast, + GetFieldAccess, GetIndexedField, InList, InSubquery, Placeholder, ScalarFunction, + ScalarFunctionDefinition, Sort, TryCast, WindowFunction, }; use crate::field_util::GetFieldAccessSchema; use crate::type_coercion::binary::get_result_type; -use crate::{LogicalPlan, Projection, Subquery}; +use crate::type_coercion::functions::data_types; +use crate::{utils, LogicalPlan, Projection, Subquery}; use arrow::compute::can_cast_types; use arrow::datatypes::{DataType, Field}; use datafusion_common::{ @@ -81,20 +82,34 @@ impl ExprSchemable for Expr { Expr::Case(case) => case.when_then_expr[0].1.get_type(schema), Expr::Cast(Cast { data_type, .. }) | Expr::TryCast(TryCast { data_type, .. }) => Ok(data_type.clone()), - Expr::ScalarUDF(ScalarUDF { fun, args }) => { - let data_types = args - .iter() - .map(|e| e.get_type(schema)) - .collect::>>()?; - Ok((fun.return_type)(&data_types)?.as_ref().clone()) - } - Expr::ScalarFunction(ScalarFunction { fun, args }) => { - let data_types = args + Expr::ScalarFunction(ScalarFunction { func_def, args }) => { + let arg_data_types = args .iter() .map(|e| e.get_type(schema)) .collect::>>()?; - - fun.return_type(&data_types) + match func_def { + ScalarFunctionDefinition::BuiltIn(fun) => { + // verify that input data types is consistent with function's `TypeSignature` + data_types(&arg_data_types, &fun.signature()).map_err(|_| { + plan_datafusion_err!( + "{}", + utils::generate_signature_error_msg( + &format!("{fun}"), + fun.signature(), + &arg_data_types, + ) + ) + })?; + + fun.return_type(&arg_data_types) + } + ScalarFunctionDefinition::UDF(fun) => { + Ok(fun.return_type(&arg_data_types)?) + } + ScalarFunctionDefinition::Name(_) => { + internal_err!("Function `Expr` with name should be resolved.") + } + } } Expr::WindowFunction(WindowFunction { fun, args, .. }) => { let data_types = args @@ -103,19 +118,22 @@ impl ExprSchemable for Expr { .collect::>>()?; fun.return_type(&data_types) } - Expr::AggregateFunction(AggregateFunction { fun, args, .. 
}) => { - let data_types = args - .iter() - .map(|e| e.get_type(schema)) - .collect::>>()?; - fun.return_type(&data_types) - } - Expr::AggregateUDF(AggregateUDF { fun, args, .. }) => { + Expr::AggregateFunction(AggregateFunction { func_def, args, .. }) => { let data_types = args .iter() .map(|e| e.get_type(schema)) .collect::>>()?; - Ok((fun.return_type)(&data_types)?.as_ref().clone()) + match func_def { + AggregateFunctionDefinition::BuiltIn(fun) => { + fun.return_type(&data_types) + } + AggregateFunctionDefinition::UDF(fun) => { + Ok(fun.return_type(&data_types)?) + } + AggregateFunctionDefinition::Name(_) => { + internal_err!("Function `Expr` with name should be resolved.") + } + } } Expr::Not(_) | Expr::IsNull(_) @@ -144,13 +162,13 @@ impl ExprSchemable for Expr { plan_datafusion_err!("Placeholder type could not be resolved") }) } - Expr::Wildcard => { + Expr::Wildcard { qualifier } => { // Wildcard do not really have a type and do not appear in projections - Ok(DataType::Null) + match qualifier { + Some(_) => internal_err!("QualifiedWildcard expressions are not valid in a logical query plan"), + None => Ok(DataType::Null) + } } - Expr::QualifiedWildcard { .. } => internal_err!( - "QualifiedWildcard expressions are not valid in a logical query plan" - ), Expr::GroupingSet(_) => { // grouping sets do not really have a type and do not appear in projections Ok(DataType::Null) @@ -230,10 +248,8 @@ impl ExprSchemable for Expr { Expr::ScalarVariable(_, _) | Expr::TryCast { .. } | Expr::ScalarFunction(..) - | Expr::ScalarUDF(..) | Expr::WindowFunction { .. } | Expr::AggregateFunction { .. } - | Expr::AggregateUDF { .. } | Expr::Placeholder(_) => Ok(true), Expr::IsNull(_) | Expr::IsNotNull(_) @@ -257,13 +273,17 @@ impl ExprSchemable for Expr { | Expr::SimilarTo(Like { expr, pattern, .. }) => { Ok(expr.nullable(input_schema)? || pattern.nullable(input_schema)?) } - Expr::Wildcard => internal_err!( + Expr::Wildcard { .. } => internal_err!( "Wildcard expressions are not valid in a logical query plan" ), - Expr::QualifiedWildcard { .. } => internal_err!( - "QualifiedWildcard expressions are not valid in a logical query plan" - ), Expr::GetIndexedField(GetIndexedField { expr, field }) => { + // If schema is nested, check if parent is nullable + // if it is, return early + if let Expr::Column(col) = expr.as_ref() { + if input_schema.nullable(col)? { + return Ok(true); + } + } field_for_index(expr, field, input_schema).map(|x| x.is_nullable()) } Expr::GroupingSet(_) => { @@ -295,6 +315,13 @@ impl ExprSchemable for Expr { self.nullable(input_schema)?, ) .with_metadata(self.metadata(input_schema)?)), + Expr::Alias(Alias { relation, name, .. 
}) => Ok(DFField::new( + relation.clone(), + name, + self.get_type(input_schema)?, + self.nullable(input_schema)?, + ) + .with_metadata(self.metadata(input_schema)?)), _ => Ok(DFField::new_unqualified( &self.display_name()?, self.get_type(input_schema)?, @@ -391,8 +418,8 @@ pub fn cast_subquery(subquery: Subquery, cast_to_type: &DataType) -> Result {{ @@ -528,6 +555,27 @@ mod tests { assert_eq!(&meta, expr.to_field(&schema).unwrap().metadata()); } + #[test] + fn test_nested_schema_nullability() { + let fields = DFField::new( + Some(TableReference::Bare { + table: "table_name".into(), + }), + "parent", + DataType::Struct(Fields::from(vec![Field::new( + "child", + DataType::Int64, + false, + )])), + true, + ); + + let schema = DFSchema::new_with_metadata(vec![fields], HashMap::new()).unwrap(); + + let expr = col("parent").field("child"); + assert!(expr.nullable(&schema).unwrap()); + } + #[derive(Debug)] struct MockExprSchema { nullable: bool, diff --git a/datafusion/expr/src/interval_arithmetic.rs b/datafusion/expr/src/interval_arithmetic.rs new file mode 100644 index 0000000000000..5d34fe91c3ace --- /dev/null +++ b/datafusion/expr/src/interval_arithmetic.rs @@ -0,0 +1,3307 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Interval arithmetic library + +use std::borrow::Borrow; +use std::fmt::{self, Display, Formatter}; +use std::ops::{AddAssign, SubAssign}; + +use crate::type_coercion::binary::get_result_type; +use crate::Operator; + +use arrow::compute::{cast_with_options, CastOptions}; +use arrow::datatypes::DataType; +use arrow::datatypes::{IntervalUnit, TimeUnit}; +use datafusion_common::rounding::{alter_fp_rounding_mode, next_down, next_up}; +use datafusion_common::{internal_err, DataFusionError, Result, ScalarValue}; + +macro_rules! 
get_extreme_value { + ($extreme:ident, $value:expr) => { + match $value { + DataType::UInt8 => ScalarValue::UInt8(Some(u8::$extreme)), + DataType::UInt16 => ScalarValue::UInt16(Some(u16::$extreme)), + DataType::UInt32 => ScalarValue::UInt32(Some(u32::$extreme)), + DataType::UInt64 => ScalarValue::UInt64(Some(u64::$extreme)), + DataType::Int8 => ScalarValue::Int8(Some(i8::$extreme)), + DataType::Int16 => ScalarValue::Int16(Some(i16::$extreme)), + DataType::Int32 => ScalarValue::Int32(Some(i32::$extreme)), + DataType::Int64 => ScalarValue::Int64(Some(i64::$extreme)), + DataType::Float32 => ScalarValue::Float32(Some(f32::$extreme)), + DataType::Float64 => ScalarValue::Float64(Some(f64::$extreme)), + DataType::Duration(TimeUnit::Second) => { + ScalarValue::DurationSecond(Some(i64::$extreme)) + } + DataType::Duration(TimeUnit::Millisecond) => { + ScalarValue::DurationMillisecond(Some(i64::$extreme)) + } + DataType::Duration(TimeUnit::Microsecond) => { + ScalarValue::DurationMicrosecond(Some(i64::$extreme)) + } + DataType::Duration(TimeUnit::Nanosecond) => { + ScalarValue::DurationNanosecond(Some(i64::$extreme)) + } + DataType::Timestamp(TimeUnit::Second, _) => { + ScalarValue::TimestampSecond(Some(i64::$extreme), None) + } + DataType::Timestamp(TimeUnit::Millisecond, _) => { + ScalarValue::TimestampMillisecond(Some(i64::$extreme), None) + } + DataType::Timestamp(TimeUnit::Microsecond, _) => { + ScalarValue::TimestampMicrosecond(Some(i64::$extreme), None) + } + DataType::Timestamp(TimeUnit::Nanosecond, _) => { + ScalarValue::TimestampNanosecond(Some(i64::$extreme), None) + } + DataType::Interval(IntervalUnit::YearMonth) => { + ScalarValue::IntervalYearMonth(Some(i32::$extreme)) + } + DataType::Interval(IntervalUnit::DayTime) => { + ScalarValue::IntervalDayTime(Some(i64::$extreme)) + } + DataType::Interval(IntervalUnit::MonthDayNano) => { + ScalarValue::IntervalMonthDayNano(Some(i128::$extreme)) + } + _ => unreachable!(), + } + }; +} + +macro_rules! 
value_transition { + ($bound:ident, $direction:expr, $value:expr) => { + match $value { + UInt8(Some(value)) if value == u8::$bound => UInt8(None), + UInt16(Some(value)) if value == u16::$bound => UInt16(None), + UInt32(Some(value)) if value == u32::$bound => UInt32(None), + UInt64(Some(value)) if value == u64::$bound => UInt64(None), + Int8(Some(value)) if value == i8::$bound => Int8(None), + Int16(Some(value)) if value == i16::$bound => Int16(None), + Int32(Some(value)) if value == i32::$bound => Int32(None), + Int64(Some(value)) if value == i64::$bound => Int64(None), + Float32(Some(value)) if value == f32::$bound => Float32(None), + Float64(Some(value)) if value == f64::$bound => Float64(None), + DurationSecond(Some(value)) if value == i64::$bound => DurationSecond(None), + DurationMillisecond(Some(value)) if value == i64::$bound => { + DurationMillisecond(None) + } + DurationMicrosecond(Some(value)) if value == i64::$bound => { + DurationMicrosecond(None) + } + DurationNanosecond(Some(value)) if value == i64::$bound => { + DurationNanosecond(None) + } + TimestampSecond(Some(value), tz) if value == i64::$bound => { + TimestampSecond(None, tz) + } + TimestampMillisecond(Some(value), tz) if value == i64::$bound => { + TimestampMillisecond(None, tz) + } + TimestampMicrosecond(Some(value), tz) if value == i64::$bound => { + TimestampMicrosecond(None, tz) + } + TimestampNanosecond(Some(value), tz) if value == i64::$bound => { + TimestampNanosecond(None, tz) + } + IntervalYearMonth(Some(value)) if value == i32::$bound => { + IntervalYearMonth(None) + } + IntervalDayTime(Some(value)) if value == i64::$bound => IntervalDayTime(None), + IntervalMonthDayNano(Some(value)) if value == i128::$bound => { + IntervalMonthDayNano(None) + } + _ => next_value_helper::<$direction>($value), + } + }; +} + +/// The `Interval` type represents a closed interval used for computing +/// reliable bounds for mathematical expressions. +/// +/// Conventions: +/// +/// 1. **Closed bounds**: The interval always encompasses its endpoints. We +/// accommodate operations resulting in open intervals by incrementing or +/// decrementing the interval endpoint value to its successor/predecessor. +/// +/// 2. **Unbounded endpoints**: If the `lower` or `upper` bounds are indeterminate, +/// they are labeled as *unbounded*. This is represented using a `NULL`. +/// +/// 3. **Overflow handling**: If the `lower` or `upper` endpoints exceed their +/// limits after any operation, they either become unbounded or they are fixed +/// to the maximum/minimum value of the datatype, depending on the direction +/// of the overflowing endpoint, opting for the safer choice. +/// +/// 4. **Floating-point special cases**: +/// - `INF` values are converted to `NULL`s while constructing an interval to +/// ensure consistency, with other data types. +/// - `NaN` (Not a Number) results are conservatively result in unbounded +/// endpoints. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct Interval { + lower: ScalarValue, + upper: ScalarValue, +} + +/// This macro handles the `NaN` and `INF` floating point values. +/// +/// - `NaN` values are always converted to unbounded i.e. `NULL` values. +/// - For lower bounds: +/// - A `NEG_INF` value is converted to a `NULL`. +/// - An `INF` value is conservatively converted to the maximum representable +/// number for the floating-point type in question. In this case, converting +/// to `NULL` doesn't make sense as it would be interpreted as a `NEG_INF`. 
+/// - For upper bounds: +/// - An `INF` value is converted to a `NULL`. +/// - An `NEG_INF` value is conservatively converted to the minimum representable +/// number for the floating-point type in question. In this case, converting +/// to `NULL` doesn't make sense as it would be interpreted as an `INF`. +macro_rules! handle_float_intervals { + ($scalar_type:ident, $primitive_type:ident, $lower:expr, $upper:expr) => {{ + let lower = match $lower { + ScalarValue::$scalar_type(Some(l_val)) + if l_val == $primitive_type::NEG_INFINITY || l_val.is_nan() => + { + ScalarValue::$scalar_type(None) + } + ScalarValue::$scalar_type(Some(l_val)) + if l_val == $primitive_type::INFINITY => + { + ScalarValue::$scalar_type(Some($primitive_type::MAX)) + } + value @ ScalarValue::$scalar_type(Some(_)) => value, + _ => ScalarValue::$scalar_type(None), + }; + + let upper = match $upper { + ScalarValue::$scalar_type(Some(r_val)) + if r_val == $primitive_type::INFINITY || r_val.is_nan() => + { + ScalarValue::$scalar_type(None) + } + ScalarValue::$scalar_type(Some(r_val)) + if r_val == $primitive_type::NEG_INFINITY => + { + ScalarValue::$scalar_type(Some($primitive_type::MIN)) + } + value @ ScalarValue::$scalar_type(Some(_)) => value, + _ => ScalarValue::$scalar_type(None), + }; + + Interval { lower, upper } + }}; +} + +/// Ordering floating-point numbers according to their binary representations +/// contradicts with their natural ordering. Floating-point number ordering +/// after unsigned integer transmutation looks like: +/// +/// ```text +/// 0, 1, 2, 3, ..., MAX, -0, -1, -2, ..., -MAX +/// ``` +/// +/// This macro applies a one-to-one map that fixes the ordering above. +macro_rules! map_floating_point_order { + ($value:expr, $ty:ty) => {{ + let num_bits = std::mem::size_of::<$ty>() * 8; + let sign_bit = 1 << (num_bits - 1); + if $value & sign_bit == sign_bit { + // Negative numbers: + !$value + } else { + // Positive numbers: + $value | sign_bit + } + }}; +} + +impl Interval { + /// Attempts to create a new `Interval` from the given lower and upper bounds. + /// + /// # Notes + /// + /// This constructor creates intervals in a "canonical" form where: + /// - **Boolean intervals**: + /// - Unboundedness (`NULL`) for boolean endpoints is converted to `false` + /// for lower and `true` for upper bounds. + /// - **Floating-point intervals**: + /// - Floating-point endpoints with `NaN`, `INF`, or `NEG_INF` are converted + /// to `NULL`s. + pub fn try_new(lower: ScalarValue, upper: ScalarValue) -> Result { + if lower.data_type() != upper.data_type() { + return internal_err!("Endpoints of an Interval should have the same type"); + } + + let interval = Self::new(lower, upper); + + if interval.lower.is_null() + || interval.upper.is_null() + || interval.lower <= interval.upper + { + Ok(interval) + } else { + internal_err!( + "Interval's lower bound {} is greater than the upper bound {}", + interval.lower, + interval.upper + ) + } + } + + /// Only for internal usage. Responsible for standardizing booleans and + /// floating-point values, as well as fixing NaNs. It doesn't validate + /// the given bounds for ordering, or verify that they have the same data + /// type. For its user-facing counterpart and more details, see + /// [`Interval::try_new`]. + fn new(lower: ScalarValue, upper: ScalarValue) -> Self { + if let ScalarValue::Boolean(lower_bool) = lower { + let ScalarValue::Boolean(upper_bool) = upper else { + // We are sure that upper and lower bounds have the same type. 
+ unreachable!(); + }; + // Standardize boolean interval endpoints: + Self { + lower: ScalarValue::Boolean(Some(lower_bool.unwrap_or(false))), + upper: ScalarValue::Boolean(Some(upper_bool.unwrap_or(true))), + } + } + // Standardize floating-point endpoints: + else if lower.data_type() == DataType::Float32 { + handle_float_intervals!(Float32, f32, lower, upper) + } else if lower.data_type() == DataType::Float64 { + handle_float_intervals!(Float64, f64, lower, upper) + } else { + // Other data types do not require standardization: + Self { lower, upper } + } + } + + /// Convenience function to create a new `Interval` from the given (optional) + /// bounds, for use in tests only. Absence of either endpoint indicates + /// unboundedness on that side. See [`Interval::try_new`] for more information. + pub fn make(lower: Option, upper: Option) -> Result + where + ScalarValue: From>, + { + Self::try_new(ScalarValue::from(lower), ScalarValue::from(upper)) + } + + /// Creates an unbounded interval from both sides if the datatype supported. + pub fn make_unbounded(data_type: &DataType) -> Result { + let unbounded_endpoint = ScalarValue::try_from(data_type)?; + Ok(Self::new(unbounded_endpoint.clone(), unbounded_endpoint)) + } + + /// Returns a reference to the lower bound. + pub fn lower(&self) -> &ScalarValue { + &self.lower + } + + /// Returns a reference to the upper bound. + pub fn upper(&self) -> &ScalarValue { + &self.upper + } + + /// Converts this `Interval` into its boundary scalar values. It's useful + /// when you need to work with the individual bounds directly. + pub fn into_bounds(self) -> (ScalarValue, ScalarValue) { + (self.lower, self.upper) + } + + /// This function returns the data type of this interval. + pub fn data_type(&self) -> DataType { + let lower_type = self.lower.data_type(); + let upper_type = self.upper.data_type(); + + // There must be no way to create an interval whose endpoints have + // different types. + assert!( + lower_type == upper_type, + "Interval bounds have different types: {lower_type} != {upper_type}" + ); + lower_type + } + + /// Casts this interval to `data_type` using `cast_options`. + pub fn cast_to( + &self, + data_type: &DataType, + cast_options: &CastOptions, + ) -> Result { + Self::try_new( + cast_scalar_value(&self.lower, data_type, cast_options)?, + cast_scalar_value(&self.upper, data_type, cast_options)?, + ) + } + + pub const CERTAINLY_FALSE: Self = Self { + lower: ScalarValue::Boolean(Some(false)), + upper: ScalarValue::Boolean(Some(false)), + }; + + pub const UNCERTAIN: Self = Self { + lower: ScalarValue::Boolean(Some(false)), + upper: ScalarValue::Boolean(Some(true)), + }; + + pub const CERTAINLY_TRUE: Self = Self { + lower: ScalarValue::Boolean(Some(true)), + upper: ScalarValue::Boolean(Some(true)), + }; + + /// Decide if this interval is certainly greater than, possibly greater than, + /// or can't be greater than `other` by returning `[true, true]`, + /// `[false, true]` or `[false, false]` respectively. + /// + /// NOTE: This function only works with intervals of the same data type. + /// Attempting to compare intervals of different data types will lead + /// to an error. 
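Reviewer note: a hedged sketch of constructing intervals with the constructors above. The `datafusion_expr::interval_arithmetic` module path assumes this new file is exported under that name.

```rust
use arrow::datatypes::DataType;
use datafusion_common::{Result, ScalarValue};
use datafusion_expr::interval_arithmetic::Interval;

fn main() -> Result<()> {
    // Closed interval [1, 10] over Int64; NULL endpoints would mean "unbounded".
    let bounded = Interval::make(Some(1_i64), Some(10_i64))?;
    assert_eq!(bounded.lower(), &ScalarValue::Int64(Some(1)));

    // Unbounded on both sides for a given data type.
    let unbounded = Interval::make_unbounded(&DataType::Int64)?;
    assert!(unbounded.lower().is_null() && unbounded.upper().is_null());
    Ok(())
}
```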
+ pub(crate) fn gt>(&self, other: T) -> Result { + let rhs = other.borrow(); + if self.data_type().ne(&rhs.data_type()) { + internal_err!( + "Only intervals with the same data type are comparable, lhs:{}, rhs:{}", + self.data_type(), + rhs.data_type() + ) + } else if !(self.upper.is_null() || rhs.lower.is_null()) + && self.upper <= rhs.lower + { + // Values in this interval are certainly less than or equal to + // those in the given interval. + Ok(Self::CERTAINLY_FALSE) + } else if !(self.lower.is_null() || rhs.upper.is_null()) + && (self.lower > rhs.upper) + { + // Values in this interval are certainly greater than those in the + // given interval. + Ok(Self::CERTAINLY_TRUE) + } else { + // All outcomes are possible. + Ok(Self::UNCERTAIN) + } + } + + /// Decide if this interval is certainly greater than or equal to, possibly + /// greater than or equal to, or can't be greater than or equal to `other` + /// by returning `[true, true]`, `[false, true]` or `[false, false]` respectively. + /// + /// NOTE: This function only works with intervals of the same data type. + /// Attempting to compare intervals of different data types will lead + /// to an error. + pub(crate) fn gt_eq>(&self, other: T) -> Result { + let rhs = other.borrow(); + if self.data_type().ne(&rhs.data_type()) { + internal_err!( + "Only intervals with the same data type are comparable, lhs:{}, rhs:{}", + self.data_type(), + rhs.data_type() + ) + } else if !(self.lower.is_null() || rhs.upper.is_null()) + && self.lower >= rhs.upper + { + // Values in this interval are certainly greater than or equal to + // those in the given interval. + Ok(Self::CERTAINLY_TRUE) + } else if !(self.upper.is_null() || rhs.lower.is_null()) + && (self.upper < rhs.lower) + { + // Values in this interval are certainly less than those in the + // given interval. + Ok(Self::CERTAINLY_FALSE) + } else { + // All outcomes are possible. + Ok(Self::UNCERTAIN) + } + } + + /// Decide if this interval is certainly less than, possibly less than, or + /// can't be less than `other` by returning `[true, true]`, `[false, true]` + /// or `[false, false]` respectively. + /// + /// NOTE: This function only works with intervals of the same data type. + /// Attempting to compare intervals of different data types will lead + /// to an error. + pub(crate) fn lt>(&self, other: T) -> Result { + other.borrow().gt(self) + } + + /// Decide if this interval is certainly less than or equal to, possibly + /// less than or equal to, or can't be less than or equal to `other` by + /// returning `[true, true]`, `[false, true]` or `[false, false]` respectively. + /// + /// NOTE: This function only works with intervals of the same data type. + /// Attempting to compare intervals of different data types will lead + /// to an error. + pub(crate) fn lt_eq>(&self, other: T) -> Result { + other.borrow().gt_eq(self) + } + + /// Decide if this interval is certainly equal to, possibly equal to, or + /// can't be equal to `other` by returning `[true, true]`, `[false, true]` + /// or `[false, false]` respectively. + /// + /// NOTE: This function only works with intervals of the same data type. + /// Attempting to compare intervals of different data types will lead + /// to an error. 
+ pub(crate) fn equal>(&self, other: T) -> Result { + let rhs = other.borrow(); + if get_result_type(&self.data_type(), &Operator::Eq, &rhs.data_type()).is_err() { + internal_err!( + "Interval data types must be compatible for equality checks, lhs:{}, rhs:{}", + self.data_type(), + rhs.data_type() + ) + } else if !self.lower.is_null() + && (self.lower == self.upper) + && (rhs.lower == rhs.upper) + && (self.lower == rhs.lower) + { + Ok(Self::CERTAINLY_TRUE) + } else if self.intersect(rhs)?.is_none() { + Ok(Self::CERTAINLY_FALSE) + } else { + Ok(Self::UNCERTAIN) + } + } + + /// Compute the logical conjunction of this (boolean) interval with the + /// given boolean interval. + pub(crate) fn and>(&self, other: T) -> Result { + let rhs = other.borrow(); + match (&self.lower, &self.upper, &rhs.lower, &rhs.upper) { + ( + &ScalarValue::Boolean(Some(self_lower)), + &ScalarValue::Boolean(Some(self_upper)), + &ScalarValue::Boolean(Some(other_lower)), + &ScalarValue::Boolean(Some(other_upper)), + ) => { + let lower = self_lower && other_lower; + let upper = self_upper && other_upper; + + Ok(Self { + lower: ScalarValue::Boolean(Some(lower)), + upper: ScalarValue::Boolean(Some(upper)), + }) + } + _ => internal_err!("Incompatible data types for logical conjunction"), + } + } + + /// Compute the logical negation of this (boolean) interval. + pub(crate) fn not(&self) -> Result { + if self.data_type().ne(&DataType::Boolean) { + internal_err!("Cannot apply logical negation to a non-boolean interval") + } else if self == &Self::CERTAINLY_TRUE { + Ok(Self::CERTAINLY_FALSE) + } else if self == &Self::CERTAINLY_FALSE { + Ok(Self::CERTAINLY_TRUE) + } else { + Ok(Self::UNCERTAIN) + } + } + + /// Compute the intersection of this interval with the given interval. + /// If the intersection is empty, return `None`. + /// + /// NOTE: This function only works with intervals of the same data type. + /// Attempting to compare intervals of different data types will lead + /// to an error. + pub fn intersect>(&self, other: T) -> Result> { + let rhs = other.borrow(); + if self.data_type().ne(&rhs.data_type()) { + return internal_err!( + "Only intervals with the same data type are intersectable, lhs:{}, rhs:{}", + self.data_type(), + rhs.data_type() + ); + }; + + // If it is evident that the result is an empty interval, short-circuit + // and directly return `None`. + if (!(self.lower.is_null() || rhs.upper.is_null()) && self.lower > rhs.upper) + || (!(self.upper.is_null() || rhs.lower.is_null()) && self.upper < rhs.lower) + { + return Ok(None); + } + + let lower = max_of_bounds(&self.lower, &rhs.lower); + let upper = min_of_bounds(&self.upper, &rhs.upper); + + // New lower and upper bounds must always construct a valid interval. + assert!( + (lower.is_null() || upper.is_null() || (lower <= upper)), + "The intersection of two intervals can not be an invalid interval" + ); + + Ok(Some(Self { lower, upper })) + } + + /// Decide if this interval certainly contains, possibly contains, or can't + /// contain a [`ScalarValue`] (`other`) by returning `[true, true]`, + /// `[false, true]` or `[false, false]` respectively. + /// + /// NOTE: This function only works with intervals of the same data type. + /// Attempting to compare intervals of different data types will lead + /// to an error. 
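Reviewer note: a short sketch of `intersect` on overlapping and disjoint integer intervals, under the same module-path assumption as above.

```rust
use datafusion_common::Result;
use datafusion_expr::interval_arithmetic::Interval;

fn main() -> Result<()> {
    let a = Interval::make(Some(1_i64), Some(5_i64))?;
    let b = Interval::make(Some(3_i64), Some(10_i64))?;
    // Overlapping intervals intersect to [3, 5].
    assert_eq!(
        a.intersect(&b)?,
        Some(Interval::make(Some(3_i64), Some(5_i64))?)
    );

    // Disjoint intervals produce no intersection at all.
    let c = Interval::make(Some(7_i64), Some(9_i64))?;
    assert_eq!(a.intersect(&c)?, None);
    Ok(())
}
```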
+ pub fn contains_value>(&self, other: T) -> Result { + let rhs = other.borrow(); + if self.data_type().ne(&rhs.data_type()) { + return internal_err!( + "Data types must be compatible for containment checks, lhs:{}, rhs:{}", + self.data_type(), + rhs.data_type() + ); + } + + // We only check the upper bound for a `None` value because `None` + // values are less than `Some` values according to Rust. + Ok(&self.lower <= rhs && (self.upper.is_null() || rhs <= &self.upper)) + } + + /// Decide if this interval is a superset of, overlaps with, or + /// disjoint with `other` by returning `[true, true]`, `[false, true]` or + /// `[false, false]` respectively. + /// + /// NOTE: This function only works with intervals of the same data type. + /// Attempting to compare intervals of different data types will lead + /// to an error. + pub fn contains>(&self, other: T) -> Result { + let rhs = other.borrow(); + if self.data_type().ne(&rhs.data_type()) { + return internal_err!( + "Interval data types must match for containment checks, lhs:{}, rhs:{}", + self.data_type(), + rhs.data_type() + ); + }; + + match self.intersect(rhs)? { + Some(intersection) => { + if &intersection == rhs { + Ok(Self::CERTAINLY_TRUE) + } else { + Ok(Self::UNCERTAIN) + } + } + None => Ok(Self::CERTAINLY_FALSE), + } + } + + /// Add the given interval (`other`) to this interval. Say we have intervals + /// `[a1, b1]` and `[a2, b2]`, then their sum is `[a1 + a2, b1 + b2]`. Note + /// that this represents all possible values the sum can take if one can + /// choose single values arbitrarily from each of the operands. + pub fn add>(&self, other: T) -> Result { + let rhs = other.borrow(); + let dt = get_result_type(&self.data_type(), &Operator::Plus, &rhs.data_type())?; + + Ok(Self::new( + add_bounds::(&dt, &self.lower, &rhs.lower), + add_bounds::(&dt, &self.upper, &rhs.upper), + )) + } + + /// Subtract the given interval (`other`) from this interval. Say we have + /// intervals `[a1, b1]` and `[a2, b2]`, then their difference is + /// `[a1 - b2, b1 - a2]`. Note that this represents all possible values the + /// difference can take if one can choose single values arbitrarily from + /// each of the operands. + pub fn sub>(&self, other: T) -> Result { + let rhs = other.borrow(); + let dt = get_result_type(&self.data_type(), &Operator::Minus, &rhs.data_type())?; + + Ok(Self::new( + sub_bounds::(&dt, &self.lower, &rhs.upper), + sub_bounds::(&dt, &self.upper, &rhs.lower), + )) + } + + /// Multiply the given interval (`other`) with this interval. Say we have + /// intervals `[a1, b1]` and `[a2, b2]`, then their product is `[min(a1 * a2, + /// a1 * b2, b1 * a2, b1 * b2), max(a1 * a2, a1 * b2, b1 * a2, b1 * b2)]`. + /// Note that this represents all possible values the product can take if + /// one can choose single values arbitrarily from each of the operands. + /// + /// NOTE: This function only works with intervals of the same data type. + /// Attempting to compare intervals of different data types will lead + /// to an error. 
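The bound formulas quoted in the `add`, `sub` and `mul` documentation translate directly into results such as the following; integer bounds are used so no rounding is involved (the wrapper function is illustrative):

```rust
use datafusion_common::Result;
use datafusion_expr::interval_arithmetic::Interval;

fn arithmetic_sketch() -> Result<()> {
    let a = Interval::make(Some(100_i64), Some(200_i64))?; // [a1, b1]
    let b = Interval::make(Some(-50_i64), Some(10_i64))?; // [a2, b2]

    // [a1 + a2, b1 + b2] = [50, 210]
    assert_eq!(a.add(&b)?, Interval::make(Some(50_i64), Some(210_i64))?);

    // [a1 - b2, b1 - a2] = [90, 250]
    assert_eq!(a.sub(&b)?, Interval::make(Some(90_i64), Some(250_i64))?);

    // min/max over the four cross products of the bounds: [-10000, 2000]
    assert_eq!(
        a.mul(&b)?,
        Interval::make(Some(-10_000_i64), Some(2_000_i64))?
    );
    Ok(())
}
```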
+ pub fn mul>(&self, other: T) -> Result { + let rhs = other.borrow(); + let dt = if self.data_type().eq(&rhs.data_type()) { + self.data_type() + } else { + return internal_err!( + "Intervals must have the same data type for multiplication, lhs:{}, rhs:{}", + self.data_type(), + rhs.data_type() + ); + }; + + let zero = ScalarValue::new_zero(&dt)?; + + let result = match ( + self.contains_value(&zero)?, + rhs.contains_value(&zero)?, + dt.is_unsigned_integer(), + ) { + (true, true, false) => mul_helper_multi_zero_inclusive(&dt, self, rhs), + (true, false, false) => { + mul_helper_single_zero_inclusive(&dt, self, rhs, zero) + } + (false, true, false) => { + mul_helper_single_zero_inclusive(&dt, rhs, self, zero) + } + _ => mul_helper_zero_exclusive(&dt, self, rhs, zero), + }; + Ok(result) + } + + /// Divide this interval by the given interval (`other`). Say we have intervals + /// `[a1, b1]` and `[a2, b2]`, then their division is `[a1, b1] * [1 / b2, 1 / a2]` + /// if `0 ∉ [a2, b2]` and `[NEG_INF, INF]` otherwise. Note that this represents + /// all possible values the quotient can take if one can choose single values + /// arbitrarily from each of the operands. + /// + /// NOTE: This function only works with intervals of the same data type. + /// Attempting to compare intervals of different data types will lead + /// to an error. + /// + /// **TODO**: Once interval sets are supported, cases where the divisor contains + /// zero should result in an interval set, not the universal set. + pub fn div>(&self, other: T) -> Result { + let rhs = other.borrow(); + let dt = if self.data_type().eq(&rhs.data_type()) { + self.data_type() + } else { + return internal_err!( + "Intervals must have the same data type for division, lhs:{}, rhs:{}", + self.data_type(), + rhs.data_type() + ); + }; + + let zero = ScalarValue::new_zero(&dt)?; + // We want 0 to be approachable from both negative and positive sides. + let zero_point = match &dt { + DataType::Float32 | DataType::Float64 => Self::new(zero.clone(), zero), + _ => Self::new(prev_value(zero.clone()), next_value(zero)), + }; + + // Exit early with an unbounded interval if zero is strictly inside the + // right hand side: + if rhs.contains(&zero_point)? == Self::CERTAINLY_TRUE && !dt.is_unsigned_integer() + { + Self::make_unbounded(&dt) + } + // At this point, we know that only one endpoint of the right hand side + // can be zero. + else if self.contains(&zero_point)? == Self::CERTAINLY_TRUE + && !dt.is_unsigned_integer() + { + Ok(div_helper_lhs_zero_inclusive(&dt, self, rhs, &zero_point)) + } else { + Ok(div_helper_zero_exclusive(&dt, self, rhs, &zero_point)) + } + } + + /// Returns the cardinality of this interval, which is the number of all + /// distinct points inside it. This function returns `None` if: + /// - The interval is unbounded from either side, or + /// - Cardinality calculations for the datatype in question is not + /// implemented yet, or + /// - An overflow occurs during the calculation: This case can only arise + /// when the calculated cardinality does not fit in an `u64`. + pub fn cardinality(&self) -> Option { + let data_type = self.data_type(); + if data_type.is_integer() { + self.upper.distance(&self.lower).map(|diff| diff as u64) + } else if data_type.is_floating() { + // Negative numbers are sorted in the reverse order. 
To + // always have a positive difference after the subtraction, + // we perform following transformation: + match (&self.lower, &self.upper) { + // Exploit IEEE 754 ordering properties to calculate the correct + // cardinality in all cases (including subnormals). + ( + ScalarValue::Float32(Some(lower)), + ScalarValue::Float32(Some(upper)), + ) => { + let lower_bits = map_floating_point_order!(lower.to_bits(), u32); + let upper_bits = map_floating_point_order!(upper.to_bits(), u32); + Some((upper_bits - lower_bits) as u64) + } + ( + ScalarValue::Float64(Some(lower)), + ScalarValue::Float64(Some(upper)), + ) => { + let lower_bits = map_floating_point_order!(lower.to_bits(), u64); + let upper_bits = map_floating_point_order!(upper.to_bits(), u64); + let count = upper_bits - lower_bits; + (count != u64::MAX).then_some(count) + } + _ => None, + } + } else { + // Cardinality calculations are not implemented for this data type yet: + None + } + .map(|result| result + 1) + } +} + +impl Display for Interval { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + write!(f, "[{}, {}]", self.lower, self.upper) + } +} + +/// Applies the given binary operator the `lhs` and `rhs` arguments. +pub fn apply_operator(op: &Operator, lhs: &Interval, rhs: &Interval) -> Result { + match *op { + Operator::Eq => lhs.equal(rhs), + Operator::NotEq => lhs.equal(rhs)?.not(), + Operator::Gt => lhs.gt(rhs), + Operator::GtEq => lhs.gt_eq(rhs), + Operator::Lt => lhs.lt(rhs), + Operator::LtEq => lhs.lt_eq(rhs), + Operator::And => lhs.and(rhs), + Operator::Plus => lhs.add(rhs), + Operator::Minus => lhs.sub(rhs), + Operator::Multiply => lhs.mul(rhs), + Operator::Divide => lhs.div(rhs), + _ => internal_err!("Interval arithmetic does not support the operator {op}"), + } +} + +/// Helper function used for adding the end-point values of intervals. +/// +/// **Caution:** This function contains multiple calls to `unwrap()`, and may +/// return non-standardized interval bounds. Therefore, it should be used +/// with caution. Currently, it is used in contexts where the `DataType` +/// (`dt`) is validated prior to calling this function, and the following +/// interval creation is standardized with `Interval::new`. +fn add_bounds( + dt: &DataType, + lhs: &ScalarValue, + rhs: &ScalarValue, +) -> ScalarValue { + if lhs.is_null() || rhs.is_null() { + return ScalarValue::try_from(dt).unwrap(); + } + + match dt { + DataType::Float64 | DataType::Float32 => { + alter_fp_rounding_mode::(lhs, rhs, |lhs, rhs| lhs.add_checked(rhs)) + } + _ => lhs.add_checked(rhs), + } + .unwrap_or_else(|_| handle_overflow::(dt, Operator::Plus, lhs, rhs)) +} + +/// Helper function used for subtracting the end-point values of intervals. +/// +/// **Caution:** This function contains multiple calls to `unwrap()`, and may +/// return non-standardized interval bounds. Therefore, it should be used +/// with caution. Currently, it is used in contexts where the `DataType` +/// (`dt`) is validated prior to calling this function, and the following +/// interval creation is standardized with `Interval::new`. 
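Most callers are expected to go through `apply_operator`, which simply dispatches to the interval methods above. A sketch, assuming the function is reachable at `datafusion_expr::interval_arithmetic::apply_operator` like the rest of this module's public items (the wrapper function is illustrative):

```rust
use datafusion_common::Result;
use datafusion_expr::interval_arithmetic::{apply_operator, Interval};
use datafusion_expr::Operator;

fn apply_operator_sketch() -> Result<()> {
    let lhs = Interval::make(Some(10_i64), Some(20_i64))?;
    let rhs = Interval::make(Some(1_i64), Some(2_i64))?;

    // Operator::Minus delegates to Interval::sub: [10 - 2, 20 - 1] = [8, 19].
    assert_eq!(
        apply_operator(&Operator::Minus, &lhs, &rhs)?,
        Interval::make(Some(8_i64), Some(19_i64))?
    );

    // Comparisons produce boolean intervals; here the answer is certain.
    assert_eq!(
        apply_operator(&Operator::Gt, &lhs, &rhs)?,
        Interval::CERTAINLY_TRUE
    );
    Ok(())
}
```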
+fn sub_bounds( + dt: &DataType, + lhs: &ScalarValue, + rhs: &ScalarValue, +) -> ScalarValue { + if lhs.is_null() || rhs.is_null() { + return ScalarValue::try_from(dt).unwrap(); + } + + match dt { + DataType::Float64 | DataType::Float32 => { + alter_fp_rounding_mode::(lhs, rhs, |lhs, rhs| lhs.sub_checked(rhs)) + } + _ => lhs.sub_checked(rhs), + } + .unwrap_or_else(|_| handle_overflow::(dt, Operator::Minus, lhs, rhs)) +} + +/// Helper function used for multiplying the end-point values of intervals. +/// +/// **Caution:** This function contains multiple calls to `unwrap()`, and may +/// return non-standardized interval bounds. Therefore, it should be used +/// with caution. Currently, it is used in contexts where the `DataType` +/// (`dt`) is validated prior to calling this function, and the following +/// interval creation is standardized with `Interval::new`. +fn mul_bounds( + dt: &DataType, + lhs: &ScalarValue, + rhs: &ScalarValue, +) -> ScalarValue { + if lhs.is_null() || rhs.is_null() { + return ScalarValue::try_from(dt).unwrap(); + } + + match dt { + DataType::Float64 | DataType::Float32 => { + alter_fp_rounding_mode::(lhs, rhs, |lhs, rhs| lhs.mul_checked(rhs)) + } + _ => lhs.mul_checked(rhs), + } + .unwrap_or_else(|_| handle_overflow::(dt, Operator::Multiply, lhs, rhs)) +} + +/// Helper function used for dividing the end-point values of intervals. +/// +/// **Caution:** This function contains multiple calls to `unwrap()`, and may +/// return non-standardized interval bounds. Therefore, it should be used +/// with caution. Currently, it is used in contexts where the `DataType` +/// (`dt`) is validated prior to calling this function, and the following +/// interval creation is standardized with `Interval::new`. +fn div_bounds( + dt: &DataType, + lhs: &ScalarValue, + rhs: &ScalarValue, +) -> ScalarValue { + let zero = ScalarValue::new_zero(dt).unwrap(); + + if (lhs.is_null() || rhs.eq(&zero)) || (dt.is_unsigned_integer() && rhs.is_null()) { + return ScalarValue::try_from(dt).unwrap(); + } else if rhs.is_null() { + return zero; + } + + match dt { + DataType::Float64 | DataType::Float32 => { + alter_fp_rounding_mode::(lhs, rhs, |lhs, rhs| lhs.div(rhs)) + } + _ => lhs.div(rhs), + } + .unwrap_or_else(|_| handle_overflow::(dt, Operator::Divide, lhs, rhs)) +} + +/// This function handles cases where an operation results in an overflow. Such +/// results are converted to an *unbounded endpoint* if: +/// - We are calculating an upper bound and we have a positive overflow. +/// - We are calculating a lower bound and we have a negative overflow. +/// Otherwise; the function sets the endpoint as: +/// - The minimum representable number with the given datatype (`dt`) if +/// we are calculating an upper bound and we have a negative overflow. +/// - The maximum representable number with the given datatype (`dt`) if +/// we are calculating a lower bound and we have a positive overflow. +/// +/// **Caution:** This function contains multiple calls to `unwrap()`, and may +/// return non-standardized interval bounds. Therefore, it should be used +/// with caution. Currently, it is used in contexts where the `DataType` +/// (`dt`) is validated prior to calling this function, `op` is supported by +/// interval library, and the following interval creation is standardized with +/// `Interval::new`. 
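Concretely, under this overflow policy an addition that overflows at both endpoints still produces a usable half-bounded interval: the overflowing upper bound becomes unbounded, while the lower bound saturates at the largest representable value. A sketch with integer bounds, where the expected values are derived from the rules above (the wrapper function is illustrative):

```rust
use datafusion_common::Result;
use datafusion_expr::interval_arithmetic::Interval;

fn overflow_sketch() -> Result<()> {
    let a = Interval::make(Some(i64::MAX), Some(i64::MAX))?;
    let b = Interval::make(Some(1_i64), Some(1_i64))?;

    // Both bound additions overflow positively: the upper bound becomes
    // unbounded (NULL), while the lower bound saturates at i64::MAX.
    assert_eq!(a.add(&b)?, Interval::make(Some(i64::MAX), None)?);
    Ok(())
}
```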
+fn handle_overflow<const UPPER: bool>(
+    dt: &DataType,
+    op: Operator,
+    lhs: &ScalarValue,
+    rhs: &ScalarValue,
+) -> ScalarValue {
+    let zero = ScalarValue::new_zero(dt).unwrap();
+    let positive_sign = match op {
+        Operator::Multiply | Operator::Divide => {
+            lhs.lt(&zero) && rhs.lt(&zero) || lhs.gt(&zero) && rhs.gt(&zero)
+        }
+        Operator::Plus => lhs.ge(&zero),
+        Operator::Minus => lhs.ge(rhs),
+        _ => {
+            unreachable!()
+        }
+    };
+    match (UPPER, positive_sign) {
+        (true, true) | (false, false) => ScalarValue::try_from(dt).unwrap(),
+        (true, false) => {
+            get_extreme_value!(MIN, dt)
+        }
+        (false, true) => {
+            get_extreme_value!(MAX, dt)
+        }
+    }
+}
+
+// This function should remain private since it may corrupt an interval if
+// used without caution.
+fn next_value(value: ScalarValue) -> ScalarValue {
+    use ScalarValue::*;
+    value_transition!(MAX, true, value)
+}
+
+// This function should remain private since it may corrupt an interval if
+// used without caution.
+fn prev_value(value: ScalarValue) -> ScalarValue {
+    use ScalarValue::*;
+    value_transition!(MIN, false, value)
+}
+
+trait OneTrait: Sized + std::ops::Add + std::ops::Sub {
+    fn one() -> Self;
+}
+macro_rules! impl_OneTrait{
+    ($($m:ty),*) => {$( impl OneTrait for $m { fn one() -> Self { 1 as $m } })*}
+}
+impl_OneTrait! {u8, u16, u32, u64, i8, i16, i32, i64, i128}
+
+/// This function either increments or decrements its argument, depending on
+/// the `INC` value (where a `true` value corresponds to the increment).
+fn increment_decrement<const INC: bool, T: OneTrait + SubAssign + AddAssign>(
+    mut value: T,
+) -> T {
+    if INC {
+        value.add_assign(T::one());
+    } else {
+        value.sub_assign(T::one());
+    }
+    value
+}
+
+/// This function returns the next/previous value depending on the `INC` value.
+/// If `true`, it returns the next value; otherwise it returns the previous value.
+fn next_value_helper<const INC: bool>(value: ScalarValue) -> ScalarValue {
+    use ScalarValue::*;
+    match value {
+        // f32/f64::NEG_INF/INF and f32/f64::NaN values should not emerge at this point.
+ Float32(Some(val)) => { + assert!(val.is_finite(), "Non-standardized floating point usage"); + Float32(Some(if INC { next_up(val) } else { next_down(val) })) + } + Float64(Some(val)) => { + assert!(val.is_finite(), "Non-standardized floating point usage"); + Float64(Some(if INC { next_up(val) } else { next_down(val) })) + } + Int8(Some(val)) => Int8(Some(increment_decrement::(val))), + Int16(Some(val)) => Int16(Some(increment_decrement::(val))), + Int32(Some(val)) => Int32(Some(increment_decrement::(val))), + Int64(Some(val)) => Int64(Some(increment_decrement::(val))), + UInt8(Some(val)) => UInt8(Some(increment_decrement::(val))), + UInt16(Some(val)) => UInt16(Some(increment_decrement::(val))), + UInt32(Some(val)) => UInt32(Some(increment_decrement::(val))), + UInt64(Some(val)) => UInt64(Some(increment_decrement::(val))), + DurationSecond(Some(val)) => { + DurationSecond(Some(increment_decrement::(val))) + } + DurationMillisecond(Some(val)) => { + DurationMillisecond(Some(increment_decrement::(val))) + } + DurationMicrosecond(Some(val)) => { + DurationMicrosecond(Some(increment_decrement::(val))) + } + DurationNanosecond(Some(val)) => { + DurationNanosecond(Some(increment_decrement::(val))) + } + TimestampSecond(Some(val), tz) => { + TimestampSecond(Some(increment_decrement::(val)), tz) + } + TimestampMillisecond(Some(val), tz) => { + TimestampMillisecond(Some(increment_decrement::(val)), tz) + } + TimestampMicrosecond(Some(val), tz) => { + TimestampMicrosecond(Some(increment_decrement::(val)), tz) + } + TimestampNanosecond(Some(val), tz) => { + TimestampNanosecond(Some(increment_decrement::(val)), tz) + } + IntervalYearMonth(Some(val)) => { + IntervalYearMonth(Some(increment_decrement::(val))) + } + IntervalDayTime(Some(val)) => { + IntervalDayTime(Some(increment_decrement::(val))) + } + IntervalMonthDayNano(Some(val)) => { + IntervalMonthDayNano(Some(increment_decrement::(val))) + } + _ => value, // Unbounded values return without change. + } +} + +/// Returns the greater of the given interval bounds. Assumes that a `NULL` +/// value represents `NEG_INF`. +fn max_of_bounds(first: &ScalarValue, second: &ScalarValue) -> ScalarValue { + if !first.is_null() && (second.is_null() || first >= second) { + first.clone() + } else { + second.clone() + } +} + +/// Returns the lesser of the given interval bounds. Assumes that a `NULL` +/// value represents `INF`. +fn min_of_bounds(first: &ScalarValue, second: &ScalarValue) -> ScalarValue { + if !first.is_null() && (second.is_null() || first <= second) { + first.clone() + } else { + second.clone() + } +} + +/// This function updates the given intervals by enforcing (i.e. propagating) +/// the inequality `left > right` (or the `left >= right` inequality, if `strict` +/// is `true`). +/// +/// Returns a `Result` wrapping an `Option` containing the tuple of resulting +/// intervals. If the comparison is infeasible, returns `None`. +/// +/// Example usage: +/// ``` +/// use datafusion_common::DataFusionError; +/// use datafusion_expr::interval_arithmetic::{satisfy_greater, Interval}; +/// +/// let left = Interval::make(Some(-1000.0_f32), Some(1000.0_f32))?; +/// let right = Interval::make(Some(500.0_f32), Some(2000.0_f32))?; +/// let strict = false; +/// assert_eq!( +/// satisfy_greater(&left, &right, strict)?, +/// Some(( +/// Interval::make(Some(500.0_f32), Some(1000.0_f32))?, +/// Interval::make(Some(500.0_f32), Some(1000.0_f32))? 
+/// )) +/// ); +/// Ok::<(), DataFusionError>(()) +/// ``` +/// +/// NOTE: This function only works with intervals of the same data type. +/// Attempting to compare intervals of different data types will lead +/// to an error. +pub fn satisfy_greater( + left: &Interval, + right: &Interval, + strict: bool, +) -> Result> { + if left.data_type().ne(&right.data_type()) { + return internal_err!( + "Intervals must have the same data type, lhs:{}, rhs:{}", + left.data_type(), + right.data_type() + ); + } + + if !left.upper.is_null() && left.upper <= right.lower { + if !strict && left.upper == right.lower { + // Singleton intervals: + return Ok(Some(( + Interval::new(left.upper.clone(), left.upper.clone()), + Interval::new(left.upper.clone(), left.upper.clone()), + ))); + } else { + // Left-hand side: <--======----0------------> + // Right-hand side: <------------0--======----> + // No intersection, infeasible to propagate: + return Ok(None); + } + } + + // Only the lower bound of left hand side and the upper bound of the right + // hand side can change after propagating the greater-than operation. + let new_left_lower = if left.lower.is_null() || left.lower <= right.lower { + if strict { + next_value(right.lower.clone()) + } else { + right.lower.clone() + } + } else { + left.lower.clone() + }; + // Below code is asymmetric relative to the above if statement, because + // `None` compares less than `Some` in Rust. + let new_right_upper = if right.upper.is_null() + || (!left.upper.is_null() && left.upper <= right.upper) + { + if strict { + prev_value(left.upper.clone()) + } else { + left.upper.clone() + } + } else { + right.upper.clone() + }; + + Ok(Some(( + Interval::new(new_left_lower, left.upper.clone()), + Interval::new(right.lower.clone(), new_right_upper), + ))) +} + +/// Multiplies two intervals that both contain zero. +/// +/// This function takes in two intervals (`lhs` and `rhs`) as arguments and +/// returns their product (whose data type is known to be `dt`). It is +/// specifically designed to handle intervals that contain zero within their +/// ranges. Returns an error if the multiplication of bounds fails. +/// +/// ```text +/// Left-hand side: <-------=====0=====-------> +/// Right-hand side: <-------=====0=====-------> +/// ``` +/// +/// **Caution:** This function contains multiple calls to `unwrap()`. Therefore, +/// it should be used with caution. Currently, it is used in contexts where the +/// `DataType` (`dt`) is validated prior to calling this function. +fn mul_helper_multi_zero_inclusive( + dt: &DataType, + lhs: &Interval, + rhs: &Interval, +) -> Interval { + if lhs.lower.is_null() + || lhs.upper.is_null() + || rhs.lower.is_null() + || rhs.upper.is_null() + { + return Interval::make_unbounded(dt).unwrap(); + } + // Since unbounded cases are handled above, we can safely + // use the utility functions here to eliminate code duplication. + let lower = min_of_bounds( + &mul_bounds::(dt, &lhs.lower, &rhs.upper), + &mul_bounds::(dt, &rhs.lower, &lhs.upper), + ); + let upper = max_of_bounds( + &mul_bounds::(dt, &lhs.upper, &rhs.upper), + &mul_bounds::(dt, &lhs.lower, &rhs.lower), + ); + // There is no possibility to create an invalid interval. + Interval::new(lower, upper) +} + +/// Multiplies two intervals when only left-hand side interval contains zero. +/// +/// This function takes in two intervals (`lhs` and `rhs`) as arguments and +/// returns their product (whose data type is known to be `dt`). 
This function +/// serves as a subroutine that handles the specific case when only `lhs` contains +/// zero within its range. The interval not containing zero, i.e. rhs, can lie +/// on either side of zero. Returns an error if the multiplication of bounds fails. +/// +/// ``` text +/// Left-hand side: <-------=====0=====-------> +/// Right-hand side: <--======----0------------> +/// +/// or +/// +/// Left-hand side: <-------=====0=====-------> +/// Right-hand side: <------------0--======----> +/// ``` +/// +/// **Caution:** This function contains multiple calls to `unwrap()`. Therefore, +/// it should be used with caution. Currently, it is used in contexts where the +/// `DataType` (`dt`) is validated prior to calling this function. +fn mul_helper_single_zero_inclusive( + dt: &DataType, + lhs: &Interval, + rhs: &Interval, + zero: ScalarValue, +) -> Interval { + // With the following interval bounds, there is no possibility to create an invalid interval. + if rhs.upper <= zero && !rhs.upper.is_null() { + // <-------=====0=====-------> + // <--======----0------------> + let lower = mul_bounds::(dt, &lhs.upper, &rhs.lower); + let upper = mul_bounds::(dt, &lhs.lower, &rhs.lower); + Interval::new(lower, upper) + } else { + // <-------=====0=====-------> + // <------------0--======----> + let lower = mul_bounds::(dt, &lhs.lower, &rhs.upper); + let upper = mul_bounds::(dt, &lhs.upper, &rhs.upper); + Interval::new(lower, upper) + } +} + +/// Multiplies two intervals when neither of them contains zero. +/// +/// This function takes in two intervals (`lhs` and `rhs`) as arguments and +/// returns their product (whose data type is known to be `dt`). It is +/// specifically designed to handle intervals that do not contain zero within +/// their ranges. Returns an error if the multiplication of bounds fails. +/// +/// ``` text +/// Left-hand side: <--======----0------------> +/// Right-hand side: <--======----0------------> +/// +/// or +/// +/// Left-hand side: <--======----0------------> +/// Right-hand side: <------------0--======----> +/// +/// or +/// +/// Left-hand side: <------------0--======----> +/// Right-hand side: <--======----0------------> +/// +/// or +/// +/// Left-hand side: <------------0--======----> +/// Right-hand side: <------------0--======----> +/// ``` +/// +/// **Caution:** This function contains multiple calls to `unwrap()`. Therefore, +/// it should be used with caution. Currently, it is used in contexts where the +/// `DataType` (`dt`) is validated prior to calling this function. +fn mul_helper_zero_exclusive( + dt: &DataType, + lhs: &Interval, + rhs: &Interval, + zero: ScalarValue, +) -> Interval { + let (lower, upper) = match ( + lhs.upper <= zero && !lhs.upper.is_null(), + rhs.upper <= zero && !rhs.upper.is_null(), + ) { + // With the following interval bounds, there is no possibility to create an invalid interval. 
+ (true, true) => ( + // <--======----0------------> + // <--======----0------------> + mul_bounds::(dt, &lhs.upper, &rhs.upper), + mul_bounds::(dt, &lhs.lower, &rhs.lower), + ), + (true, false) => ( + // <--======----0------------> + // <------------0--======----> + mul_bounds::(dt, &lhs.lower, &rhs.upper), + mul_bounds::(dt, &lhs.upper, &rhs.lower), + ), + (false, true) => ( + // <------------0--======----> + // <--======----0------------> + mul_bounds::(dt, &rhs.lower, &lhs.upper), + mul_bounds::(dt, &rhs.upper, &lhs.lower), + ), + (false, false) => ( + // <------------0--======----> + // <------------0--======----> + mul_bounds::(dt, &lhs.lower, &rhs.lower), + mul_bounds::(dt, &lhs.upper, &rhs.upper), + ), + }; + Interval::new(lower, upper) +} + +/// Divides the left-hand side interval by the right-hand side interval when +/// the former contains zero. +/// +/// This function takes in two intervals (`lhs` and `rhs`) as arguments and +/// returns their quotient (whose data type is known to be `dt`). This function +/// serves as a subroutine that handles the specific case when only `lhs` contains +/// zero within its range. Returns an error if the division of bounds fails. +/// +/// ``` text +/// Left-hand side: <-------=====0=====-------> +/// Right-hand side: <--======----0------------> +/// +/// or +/// +/// Left-hand side: <-------=====0=====-------> +/// Right-hand side: <------------0--======----> +/// ``` +/// +/// **Caution:** This function contains multiple calls to `unwrap()`. Therefore, +/// it should be used with caution. Currently, it is used in contexts where the +/// `DataType` (`dt`) is validated prior to calling this function. +fn div_helper_lhs_zero_inclusive( + dt: &DataType, + lhs: &Interval, + rhs: &Interval, + zero_point: &Interval, +) -> Interval { + // With the following interval bounds, there is no possibility to create an invalid interval. + if rhs.upper <= zero_point.lower && !rhs.upper.is_null() { + // <-------=====0=====-------> + // <--======----0------------> + let lower = div_bounds::(dt, &lhs.upper, &rhs.upper); + let upper = div_bounds::(dt, &lhs.lower, &rhs.upper); + Interval::new(lower, upper) + } else { + // <-------=====0=====-------> + // <------------0--======----> + let lower = div_bounds::(dt, &lhs.lower, &rhs.lower); + let upper = div_bounds::(dt, &lhs.upper, &rhs.lower); + Interval::new(lower, upper) + } +} + +/// Divides the left-hand side interval by the right-hand side interval when +/// neither interval contains zero. +/// +/// This function takes in two intervals (`lhs` and `rhs`) as arguments and +/// returns their quotient (whose data type is known to be `dt`). It is +/// specifically designed to handle intervals that do not contain zero within +/// their ranges. Returns an error if the division of bounds fails. +/// +/// ``` text +/// Left-hand side: <--======----0------------> +/// Right-hand side: <--======----0------------> +/// +/// or +/// +/// Left-hand side: <--======----0------------> +/// Right-hand side: <------------0--======----> +/// +/// or +/// +/// Left-hand side: <------------0--======----> +/// Right-hand side: <--======----0------------> +/// +/// or +/// +/// Left-hand side: <------------0--======----> +/// Right-hand side: <------------0--======----> +/// ``` +/// +/// **Caution:** This function contains multiple calls to `unwrap()`. Therefore, +/// it should be used with caution. Currently, it is used in contexts where the +/// `DataType` (`dt`) is validated prior to calling this function. 
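Division therefore follows the reciprocal rule quoted in the `div` documentation, with the universal interval as a fallback whenever the divisor straddles zero. A sketch using exactly representable float bounds, so directed rounding does not perturb the results (the wrapper function is illustrative):

```rust
use arrow::datatypes::DataType;
use datafusion_common::Result;
use datafusion_expr::interval_arithmetic::Interval;

fn div_sketch() -> Result<()> {
    let num = Interval::make(Some(8.0_f64), Some(16.0_f64))?;

    // Divisor excludes zero: [8, 16] * [1/4, 1/2] = [2, 8].
    let den = Interval::make(Some(2.0_f64), Some(4.0_f64))?;
    assert_eq!(num.div(&den)?, Interval::make(Some(2.0_f64), Some(8.0_f64))?);

    // Divisor straddles zero: the quotient collapses to the unbounded
    // interval (see the TODO about interval sets in the `div` docs).
    let den = Interval::make(Some(-1.0_f64), Some(1.0_f64))?;
    assert_eq!(num.div(&den)?, Interval::make_unbounded(&DataType::Float64)?);
    Ok(())
}
```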
+fn div_helper_zero_exclusive(
+    dt: &DataType,
+    lhs: &Interval,
+    rhs: &Interval,
+    zero_point: &Interval,
+) -> Interval {
+    let (lower, upper) = match (
+        lhs.upper <= zero_point.lower && !lhs.upper.is_null(),
+        rhs.upper <= zero_point.lower && !rhs.upper.is_null(),
+    ) {
+        // With the following interval bounds, there is no possibility to create an invalid interval.
+        (true, true) => (
+            // <--======----0------------>
+            // <--======----0------------>
+            div_bounds::<false>(dt, &lhs.upper, &rhs.lower),
+            div_bounds::<true>(dt, &lhs.lower, &rhs.upper),
+        ),
+        (true, false) => (
+            // <--======----0------------>
+            // <------------0--======---->
+            div_bounds::<false>(dt, &lhs.lower, &rhs.lower),
+            div_bounds::<true>(dt, &lhs.upper, &rhs.upper),
+        ),
+        (false, true) => (
+            // <------------0--======---->
+            // <--======----0------------>
+            div_bounds::<false>(dt, &lhs.upper, &rhs.upper),
+            div_bounds::<true>(dt, &lhs.lower, &rhs.lower),
+        ),
+        (false, false) => (
+            // <------------0--======---->
+            // <------------0--======---->
+            div_bounds::<false>(dt, &lhs.lower, &rhs.upper),
+            div_bounds::<true>(dt, &lhs.upper, &rhs.lower),
+        ),
+    };
+    Interval::new(lower, upper)
+}
+
+/// This function computes the selectivity of an operation by computing the
+/// cardinality ratio of the given input/output intervals. If this ratio cannot
+/// be calculated for some reason, it returns `1.0`, meaning fully selective
+/// (no filtering).
+pub fn cardinality_ratio(initial_interval: &Interval, final_interval: &Interval) -> f64 {
+    match (final_interval.cardinality(), initial_interval.cardinality()) {
+        (Some(final_interval), Some(initial_interval)) => {
+            (final_interval as f64) / (initial_interval as f64)
+        }
+        _ => 1.0,
+    }
+}
+
+/// Cast scalar value to the given data type using an arrow kernel.
+fn cast_scalar_value(
+    value: &ScalarValue,
+    data_type: &DataType,
+    cast_options: &CastOptions,
+) -> Result<ScalarValue> {
+    let cast_array = cast_with_options(&value.to_array()?, data_type, cast_options)?;
+    ScalarValue::try_from_array(&cast_array, 0)
+}
+
+/// An [Interval] that also tracks null status using a boolean interval.
+///
+/// This represents values that may be in a particular range or be null.
+///
+/// # Examples
+///
+/// ```
+/// use arrow::datatypes::DataType;
+/// use datafusion_common::ScalarValue;
+/// use datafusion_expr::interval_arithmetic::Interval;
+/// use datafusion_expr::interval_arithmetic::NullableInterval;
+///
+/// // [1, 2) U {NULL}
+/// let maybe_null = NullableInterval::MaybeNull {
+///     values: Interval::try_new(
+///         ScalarValue::Int32(Some(1)),
+///         ScalarValue::Int32(Some(2)),
+///     ).unwrap(),
+/// };
+///
+/// // (0, ∞)
+/// let not_null = NullableInterval::NotNull {
+///     values: Interval::try_new(
+///         ScalarValue::Int32(Some(0)),
+///         ScalarValue::Int32(None),
+///     ).unwrap(),
+/// };
+///
+/// // {NULL}
+/// let null_interval = NullableInterval::Null { datatype: DataType::Int32 };
+///
+/// // {4}
+/// let single_value = NullableInterval::from(ScalarValue::Int32(Some(4)));
+/// ```
+#[derive(Debug, Clone, PartialEq, Eq)]
+pub enum NullableInterval {
+    /// The value is always null. This is typed so it can be used in physical
+    /// expressions, which don't do type coercion.
+    Null { datatype: DataType },
+    /// The value may or may not be null. If it is non-null, it is within the
+    /// specified range.
+    MaybeNull { values: Interval },
+    /// The value is definitely not null, and is within the specified range.
+ NotNull { values: Interval }, +} + +impl Display for NullableInterval { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + match self { + Self::Null { .. } => write!(f, "NullableInterval: {{NULL}}"), + Self::MaybeNull { values } => { + write!(f, "NullableInterval: {} U {{NULL}}", values) + } + Self::NotNull { values } => write!(f, "NullableInterval: {}", values), + } + } +} + +impl From for NullableInterval { + /// Create an interval that represents a single value. + fn from(value: ScalarValue) -> Self { + if value.is_null() { + Self::Null { + datatype: value.data_type(), + } + } else { + Self::NotNull { + values: Interval { + lower: value.clone(), + upper: value, + }, + } + } + } +} + +impl NullableInterval { + /// Get the values interval, or None if this interval is definitely null. + pub fn values(&self) -> Option<&Interval> { + match self { + Self::Null { .. } => None, + Self::MaybeNull { values } | Self::NotNull { values } => Some(values), + } + } + + /// Get the data type + pub fn data_type(&self) -> DataType { + match self { + Self::Null { datatype } => datatype.clone(), + Self::MaybeNull { values } | Self::NotNull { values } => values.data_type(), + } + } + + /// Return true if the value is definitely true (and not null). + pub fn is_certainly_true(&self) -> bool { + match self { + Self::Null { .. } | Self::MaybeNull { .. } => false, + Self::NotNull { values } => values == &Interval::CERTAINLY_TRUE, + } + } + + /// Return true if the value is definitely false (and not null). + pub fn is_certainly_false(&self) -> bool { + match self { + Self::Null { .. } => false, + Self::MaybeNull { .. } => false, + Self::NotNull { values } => values == &Interval::CERTAINLY_FALSE, + } + } + + /// Perform logical negation on a boolean nullable interval. + fn not(&self) -> Result { + match self { + Self::Null { datatype } => Ok(Self::Null { + datatype: datatype.clone(), + }), + Self::MaybeNull { values } => Ok(Self::MaybeNull { + values: values.not()?, + }), + Self::NotNull { values } => Ok(Self::NotNull { + values: values.not()?, + }), + } + } + + /// Apply the given operator to this interval and the given interval. 
+ /// + /// # Examples + /// + /// ``` + /// use datafusion_common::ScalarValue; + /// use datafusion_expr::Operator; + /// use datafusion_expr::interval_arithmetic::Interval; + /// use datafusion_expr::interval_arithmetic::NullableInterval; + /// + /// // 4 > 3 -> true + /// let lhs = NullableInterval::from(ScalarValue::Int32(Some(4))); + /// let rhs = NullableInterval::from(ScalarValue::Int32(Some(3))); + /// let result = lhs.apply_operator(&Operator::Gt, &rhs).unwrap(); + /// assert_eq!(result, NullableInterval::from(ScalarValue::Boolean(Some(true)))); + /// + /// // [1, 3) > NULL -> NULL + /// let lhs = NullableInterval::NotNull { + /// values: Interval::try_new( + /// ScalarValue::Int32(Some(1)), + /// ScalarValue::Int32(Some(3)), + /// ).unwrap(), + /// }; + /// let rhs = NullableInterval::from(ScalarValue::Int32(None)); + /// let result = lhs.apply_operator(&Operator::Gt, &rhs).unwrap(); + /// assert_eq!(result.single_value(), Some(ScalarValue::Boolean(None))); + /// + /// // [1, 3] > [2, 4] -> [false, true] + /// let lhs = NullableInterval::NotNull { + /// values: Interval::try_new( + /// ScalarValue::Int32(Some(1)), + /// ScalarValue::Int32(Some(3)), + /// ).unwrap(), + /// }; + /// let rhs = NullableInterval::NotNull { + /// values: Interval::try_new( + /// ScalarValue::Int32(Some(2)), + /// ScalarValue::Int32(Some(4)), + /// ).unwrap(), + /// }; + /// let result = lhs.apply_operator(&Operator::Gt, &rhs).unwrap(); + /// // Both inputs are valid (non-null), so result must be non-null + /// assert_eq!(result, NullableInterval::NotNull { + /// // Uncertain whether inequality is true or false + /// values: Interval::UNCERTAIN, + /// }); + /// ``` + pub fn apply_operator(&self, op: &Operator, rhs: &Self) -> Result { + match op { + Operator::IsDistinctFrom => { + let values = match (self, rhs) { + // NULL is distinct from NULL -> False + (Self::Null { .. }, Self::Null { .. }) => Interval::CERTAINLY_FALSE, + // x is distinct from y -> x != y, + // if at least one of them is never null. + (Self::NotNull { .. }, _) | (_, Self::NotNull { .. }) => { + let lhs_values = self.values(); + let rhs_values = rhs.values(); + match (lhs_values, rhs_values) { + (Some(lhs_values), Some(rhs_values)) => { + lhs_values.equal(rhs_values)?.not()? + } + (Some(_), None) | (None, Some(_)) => Interval::CERTAINLY_TRUE, + (None, None) => unreachable!("Null case handled above"), + } + } + _ => Interval::UNCERTAIN, + }; + // IsDistinctFrom never returns null. + Ok(Self::NotNull { values }) + } + Operator::IsNotDistinctFrom => self + .apply_operator(&Operator::IsDistinctFrom, rhs) + .map(|i| i.not())?, + _ => { + if let (Some(left_values), Some(right_values)) = + (self.values(), rhs.values()) + { + let values = apply_operator(op, left_values, right_values)?; + match (self, rhs) { + (Self::NotNull { .. }, Self::NotNull { .. }) => { + Ok(Self::NotNull { values }) + } + _ => Ok(Self::MaybeNull { values }), + } + } else if op.is_comparison_operator() { + Ok(Self::Null { + datatype: DataType::Boolean, + }) + } else { + Ok(Self::Null { + datatype: self.data_type(), + }) + } + } + } + } + + /// Decide if this interval is a superset of, overlaps with, or + /// disjoint with `other` by returning `[true, true]`, `[false, true]` or + /// `[false, false]` respectively. + /// + /// NOTE: This function only works with intervals of the same data type. + /// Attempting to compare intervals of different data types will lead + /// to an error. 
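Because `contains` on `NullableInterval` also tracks nullability, two `NotNull` inputs are guaranteed to produce a `NotNull` boolean result. A small sketch (the wrapper function is illustrative):

```rust
use datafusion_common::{Result, ScalarValue};
use datafusion_expr::interval_arithmetic::{Interval, NullableInterval};

fn nullable_contains_sketch() -> Result<()> {
    let haystack = NullableInterval::NotNull {
        values: Interval::make(Some(0_i32), Some(100_i32))?,
    };
    let needle = NullableInterval::from(ScalarValue::Int32(Some(42)));

    // [0, 100] certainly contains {42}, and neither side can be null.
    assert!(haystack.contains(&needle)?.is_certainly_true());
    Ok(())
}
```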
+ pub fn contains>(&self, other: T) -> Result { + let rhs = other.borrow(); + if let (Some(left_values), Some(right_values)) = (self.values(), rhs.values()) { + left_values + .contains(right_values) + .map(|values| match (self, rhs) { + (Self::NotNull { .. }, Self::NotNull { .. }) => { + Self::NotNull { values } + } + _ => Self::MaybeNull { values }, + }) + } else { + Ok(Self::Null { + datatype: DataType::Boolean, + }) + } + } + + /// If the interval has collapsed to a single value, return that value. + /// Otherwise, returns `None`. + /// + /// # Examples + /// + /// ``` + /// use datafusion_common::ScalarValue; + /// use datafusion_expr::interval_arithmetic::Interval; + /// use datafusion_expr::interval_arithmetic::NullableInterval; + /// + /// let interval = NullableInterval::from(ScalarValue::Int32(Some(4))); + /// assert_eq!(interval.single_value(), Some(ScalarValue::Int32(Some(4)))); + /// + /// let interval = NullableInterval::from(ScalarValue::Int32(None)); + /// assert_eq!(interval.single_value(), Some(ScalarValue::Int32(None))); + /// + /// let interval = NullableInterval::MaybeNull { + /// values: Interval::try_new( + /// ScalarValue::Int32(Some(1)), + /// ScalarValue::Int32(Some(4)), + /// ).unwrap(), + /// }; + /// assert_eq!(interval.single_value(), None); + /// ``` + pub fn single_value(&self) -> Option { + match self { + Self::Null { datatype } => { + Some(ScalarValue::try_from(datatype).unwrap_or(ScalarValue::Null)) + } + Self::MaybeNull { values } | Self::NotNull { values } + if values.lower == values.upper && !values.lower.is_null() => + { + Some(values.lower.clone()) + } + _ => None, + } + } +} + +#[cfg(test)] +mod tests { + use crate::interval_arithmetic::{next_value, prev_value, satisfy_greater, Interval}; + + use arrow::datatypes::DataType; + use datafusion_common::{Result, ScalarValue}; + + #[test] + fn test_next_prev_value() -> Result<()> { + let zeros = vec![ + ScalarValue::new_zero(&DataType::UInt8)?, + ScalarValue::new_zero(&DataType::UInt16)?, + ScalarValue::new_zero(&DataType::UInt32)?, + ScalarValue::new_zero(&DataType::UInt64)?, + ScalarValue::new_zero(&DataType::Int8)?, + ScalarValue::new_zero(&DataType::Int16)?, + ScalarValue::new_zero(&DataType::Int32)?, + ScalarValue::new_zero(&DataType::Int64)?, + ]; + let ones = vec![ + ScalarValue::new_one(&DataType::UInt8)?, + ScalarValue::new_one(&DataType::UInt16)?, + ScalarValue::new_one(&DataType::UInt32)?, + ScalarValue::new_one(&DataType::UInt64)?, + ScalarValue::new_one(&DataType::Int8)?, + ScalarValue::new_one(&DataType::Int16)?, + ScalarValue::new_one(&DataType::Int32)?, + ScalarValue::new_one(&DataType::Int64)?, + ]; + zeros.into_iter().zip(ones).for_each(|(z, o)| { + assert_eq!(next_value(z.clone()), o); + assert_eq!(prev_value(o), z); + }); + + let values = vec![ + ScalarValue::new_zero(&DataType::Float32)?, + ScalarValue::new_zero(&DataType::Float64)?, + ]; + let eps = vec![ + ScalarValue::Float32(Some(1e-6)), + ScalarValue::Float64(Some(1e-6)), + ]; + values.into_iter().zip(eps).for_each(|(value, eps)| { + assert!(next_value(value.clone()) + .sub(value.clone()) + .unwrap() + .lt(&eps)); + assert!(value + .clone() + .sub(prev_value(value.clone())) + .unwrap() + .lt(&eps)); + assert_ne!(next_value(value.clone()), value); + assert_ne!(prev_value(value.clone()), value); + }); + + let min_max = vec![ + ( + ScalarValue::UInt64(Some(u64::MIN)), + ScalarValue::UInt64(Some(u64::MAX)), + ), + ( + ScalarValue::Int8(Some(i8::MIN)), + ScalarValue::Int8(Some(i8::MAX)), + ), + ( + ScalarValue::Float32(Some(f32::MIN)), 
+ ScalarValue::Float32(Some(f32::MAX)), + ), + ( + ScalarValue::Float64(Some(f64::MIN)), + ScalarValue::Float64(Some(f64::MAX)), + ), + ]; + let inf = vec![ + ScalarValue::UInt64(None), + ScalarValue::Int8(None), + ScalarValue::Float32(None), + ScalarValue::Float64(None), + ]; + min_max.into_iter().zip(inf).for_each(|((min, max), inf)| { + assert_eq!(next_value(max.clone()), inf); + assert_ne!(prev_value(max.clone()), max); + assert_ne!(prev_value(max.clone()), inf); + + assert_eq!(prev_value(min.clone()), inf); + assert_ne!(next_value(min.clone()), min); + assert_ne!(next_value(min.clone()), inf); + + assert_eq!(next_value(inf.clone()), inf); + assert_eq!(prev_value(inf.clone()), inf); + }); + + Ok(()) + } + + #[test] + fn test_new_interval() -> Result<()> { + use ScalarValue::*; + + let cases = vec![ + ( + (Boolean(None), Boolean(Some(false))), + Boolean(Some(false)), + Boolean(Some(false)), + ), + ( + (Boolean(Some(false)), Boolean(None)), + Boolean(Some(false)), + Boolean(Some(true)), + ), + ( + (Boolean(Some(false)), Boolean(Some(true))), + Boolean(Some(false)), + Boolean(Some(true)), + ), + ( + (UInt16(Some(u16::MAX)), UInt16(None)), + UInt16(Some(u16::MAX)), + UInt16(None), + ), + ( + (Int16(None), Int16(Some(-1000))), + Int16(None), + Int16(Some(-1000)), + ), + ( + (Float32(Some(f32::MAX)), Float32(Some(f32::MAX))), + Float32(Some(f32::MAX)), + Float32(Some(f32::MAX)), + ), + ( + (Float32(Some(f32::NAN)), Float32(Some(f32::MIN))), + Float32(None), + Float32(Some(f32::MIN)), + ), + ( + ( + Float64(Some(f64::NEG_INFINITY)), + Float64(Some(f64::INFINITY)), + ), + Float64(None), + Float64(None), + ), + ]; + for (inputs, lower, upper) in cases { + let result = Interval::try_new(inputs.0, inputs.1)?; + assert_eq!(result.clone().lower(), &lower); + assert_eq!(result.upper(), &upper); + } + + let invalid_intervals = vec![ + (Float32(Some(f32::INFINITY)), Float32(Some(100_f32))), + (Float64(Some(0_f64)), Float64(Some(f64::NEG_INFINITY))), + (Boolean(Some(true)), Boolean(Some(false))), + (Int32(Some(1000)), Int32(Some(-2000))), + (UInt64(Some(1)), UInt64(Some(0))), + ]; + for (lower, upper) in invalid_intervals { + Interval::try_new(lower, upper).expect_err( + "Given parameters should have given an invalid interval error", + ); + } + + Ok(()) + } + + #[test] + fn test_make_unbounded() -> Result<()> { + use ScalarValue::*; + + let unbounded_cases = vec![ + (DataType::Boolean, Boolean(Some(false)), Boolean(Some(true))), + (DataType::UInt8, UInt8(None), UInt8(None)), + (DataType::UInt16, UInt16(None), UInt16(None)), + (DataType::UInt32, UInt32(None), UInt32(None)), + (DataType::UInt64, UInt64(None), UInt64(None)), + (DataType::Int8, Int8(None), Int8(None)), + (DataType::Int16, Int16(None), Int16(None)), + (DataType::Int32, Int32(None), Int32(None)), + (DataType::Int64, Int64(None), Int64(None)), + (DataType::Float32, Float32(None), Float32(None)), + (DataType::Float64, Float64(None), Float64(None)), + ]; + for (dt, lower, upper) in unbounded_cases { + let inf = Interval::make_unbounded(&dt)?; + assert_eq!(inf.clone().lower(), &lower); + assert_eq!(inf.upper(), &upper); + } + + Ok(()) + } + + #[test] + fn gt_lt_test() -> Result<()> { + let exactly_gt_cases = vec![ + ( + Interval::make(Some(1000_i64), None)?, + Interval::make(None, Some(999_i64))?, + ), + ( + Interval::make(Some(1000_i64), Some(1000_i64))?, + Interval::make(None, Some(999_i64))?, + ), + ( + Interval::make(Some(501_i64), Some(1000_i64))?, + Interval::make(Some(500_i64), Some(500_i64))?, + ), + ( + Interval::make(Some(-1000_i64), 
Some(1000_i64))?, + Interval::make(None, Some(-1500_i64))?, + ), + ( + Interval::try_new( + next_value(ScalarValue::Float32(Some(0.0))), + next_value(ScalarValue::Float32(Some(0.0))), + )?, + Interval::make(Some(0.0_f32), Some(0.0_f32))?, + ), + ( + Interval::make(Some(-1.0_f32), Some(-1.0_f32))?, + Interval::try_new( + prev_value(ScalarValue::Float32(Some(-1.0))), + prev_value(ScalarValue::Float32(Some(-1.0))), + )?, + ), + ]; + for (first, second) in exactly_gt_cases { + assert_eq!(first.gt(second.clone())?, Interval::CERTAINLY_TRUE); + assert_eq!(second.lt(first)?, Interval::CERTAINLY_TRUE); + } + + let possibly_gt_cases = vec![ + ( + Interval::make(Some(1000_i64), Some(2000_i64))?, + Interval::make(Some(1000_i64), Some(1000_i64))?, + ), + ( + Interval::make(Some(500_i64), Some(1000_i64))?, + Interval::make(Some(500_i64), Some(1000_i64))?, + ), + ( + Interval::make(Some(1000_i64), None)?, + Interval::make(Some(1000_i64), None)?, + ), + ( + Interval::make::(None, None)?, + Interval::make::(None, None)?, + ), + ( + Interval::try_new( + ScalarValue::Float32(Some(0.0_f32)), + next_value(ScalarValue::Float32(Some(0.0_f32))), + )?, + Interval::make(Some(0.0_f32), Some(0.0_f32))?, + ), + ( + Interval::make(Some(-1.0_f32), Some(-1.0_f32))?, + Interval::try_new( + prev_value(ScalarValue::Float32(Some(-1.0_f32))), + ScalarValue::Float32(Some(-1.0_f32)), + )?, + ), + ]; + for (first, second) in possibly_gt_cases { + assert_eq!(first.gt(second.clone())?, Interval::UNCERTAIN); + assert_eq!(second.lt(first)?, Interval::UNCERTAIN); + } + + let not_gt_cases = vec![ + ( + Interval::make(Some(1000_i64), Some(1000_i64))?, + Interval::make(Some(1000_i64), Some(1000_i64))?, + ), + ( + Interval::make(Some(500_i64), Some(1000_i64))?, + Interval::make(Some(1000_i64), None)?, + ), + ( + Interval::make(None, Some(1000_i64))?, + Interval::make(Some(1000_i64), Some(1500_i64))?, + ), + ( + Interval::try_new( + prev_value(ScalarValue::Float32(Some(0.0_f32))), + ScalarValue::Float32(Some(0.0_f32)), + )?, + Interval::make(Some(0.0_f32), Some(0.0_f32))?, + ), + ( + Interval::make(Some(-1.0_f32), Some(-1.0_f32))?, + Interval::try_new( + ScalarValue::Float32(Some(-1.0_f32)), + next_value(ScalarValue::Float32(Some(-1.0_f32))), + )?, + ), + ]; + for (first, second) in not_gt_cases { + assert_eq!(first.gt(second.clone())?, Interval::CERTAINLY_FALSE); + assert_eq!(second.lt(first)?, Interval::CERTAINLY_FALSE); + } + + Ok(()) + } + + #[test] + fn gteq_lteq_test() -> Result<()> { + let exactly_gteq_cases = vec![ + ( + Interval::make(Some(1000_i64), None)?, + Interval::make(None, Some(1000_i64))?, + ), + ( + Interval::make(Some(1000_i64), Some(1000_i64))?, + Interval::make(None, Some(1000_i64))?, + ), + ( + Interval::make(Some(500_i64), Some(1000_i64))?, + Interval::make(Some(500_i64), Some(500_i64))?, + ), + ( + Interval::make(Some(-1000_i64), Some(1000_i64))?, + Interval::make(None, Some(-1500_i64))?, + ), + ( + Interval::make(Some(0.0_f32), Some(0.0_f32))?, + Interval::make(Some(0.0_f32), Some(0.0_f32))?, + ), + ( + Interval::try_new( + ScalarValue::Float32(Some(-1.0)), + next_value(ScalarValue::Float32(Some(-1.0))), + )?, + Interval::try_new( + prev_value(ScalarValue::Float32(Some(-1.0))), + ScalarValue::Float32(Some(-1.0)), + )?, + ), + ]; + for (first, second) in exactly_gteq_cases { + assert_eq!(first.gt_eq(second.clone())?, Interval::CERTAINLY_TRUE); + assert_eq!(second.lt_eq(first)?, Interval::CERTAINLY_TRUE); + } + + let possibly_gteq_cases = vec![ + ( + Interval::make(Some(999_i64), Some(2000_i64))?, + 
Interval::make(Some(1000_i64), Some(1000_i64))?, + ), + ( + Interval::make(Some(500_i64), Some(1000_i64))?, + Interval::make(Some(500_i64), Some(1001_i64))?, + ), + ( + Interval::make(Some(0_i64), None)?, + Interval::make(Some(1000_i64), None)?, + ), + ( + Interval::make::(None, None)?, + Interval::make::(None, None)?, + ), + ( + Interval::try_new( + prev_value(ScalarValue::Float32(Some(0.0))), + ScalarValue::Float32(Some(0.0)), + )?, + Interval::make(Some(0.0_f32), Some(0.0_f32))?, + ), + ( + Interval::make(Some(-1.0_f32), Some(-1.0_f32))?, + Interval::try_new( + prev_value(ScalarValue::Float32(Some(-1.0_f32))), + next_value(ScalarValue::Float32(Some(-1.0_f32))), + )?, + ), + ]; + for (first, second) in possibly_gteq_cases { + assert_eq!(first.gt_eq(second.clone())?, Interval::UNCERTAIN); + assert_eq!(second.lt_eq(first)?, Interval::UNCERTAIN); + } + + let not_gteq_cases = vec![ + ( + Interval::make(Some(1000_i64), Some(1000_i64))?, + Interval::make(Some(2000_i64), Some(2000_i64))?, + ), + ( + Interval::make(Some(500_i64), Some(999_i64))?, + Interval::make(Some(1000_i64), None)?, + ), + ( + Interval::make(None, Some(1000_i64))?, + Interval::make(Some(1001_i64), Some(1500_i64))?, + ), + ( + Interval::try_new( + prev_value(ScalarValue::Float32(Some(0.0_f32))), + prev_value(ScalarValue::Float32(Some(0.0_f32))), + )?, + Interval::make(Some(0.0_f32), Some(0.0_f32))?, + ), + ( + Interval::make(Some(-1.0_f32), Some(-1.0_f32))?, + Interval::try_new( + next_value(ScalarValue::Float32(Some(-1.0))), + next_value(ScalarValue::Float32(Some(-1.0))), + )?, + ), + ]; + for (first, second) in not_gteq_cases { + assert_eq!(first.gt_eq(second.clone())?, Interval::CERTAINLY_FALSE); + assert_eq!(second.lt_eq(first)?, Interval::CERTAINLY_FALSE); + } + + Ok(()) + } + + #[test] + fn equal_test() -> Result<()> { + let exactly_eq_cases = vec![ + ( + Interval::make(Some(1000_i64), Some(1000_i64))?, + Interval::make(Some(1000_i64), Some(1000_i64))?, + ), + ( + Interval::make(Some(0_u64), Some(0_u64))?, + Interval::make(Some(0_u64), Some(0_u64))?, + ), + ( + Interval::make(Some(f32::MAX), Some(f32::MAX))?, + Interval::make(Some(f32::MAX), Some(f32::MAX))?, + ), + ( + Interval::make(Some(f64::MIN), Some(f64::MIN))?, + Interval::make(Some(f64::MIN), Some(f64::MIN))?, + ), + ]; + for (first, second) in exactly_eq_cases { + assert_eq!(first.equal(second.clone())?, Interval::CERTAINLY_TRUE); + assert_eq!(second.equal(first)?, Interval::CERTAINLY_TRUE); + } + + let possibly_eq_cases = vec![ + ( + Interval::make::(None, None)?, + Interval::make::(None, None)?, + ), + ( + Interval::make(Some(0_i64), Some(0_i64))?, + Interval::make(Some(0_i64), Some(1000_i64))?, + ), + ( + Interval::make(Some(0_i64), Some(0_i64))?, + Interval::make(Some(0_i64), Some(1000_i64))?, + ), + ( + Interval::make(Some(100.0_f32), Some(200.0_f32))?, + Interval::make(Some(0.0_f32), Some(1000.0_f32))?, + ), + ( + Interval::try_new( + prev_value(ScalarValue::Float32(Some(0.0))), + ScalarValue::Float32(Some(0.0)), + )?, + Interval::make(Some(0.0_f32), Some(0.0_f32))?, + ), + ( + Interval::make(Some(-1.0_f32), Some(-1.0_f32))?, + Interval::try_new( + prev_value(ScalarValue::Float32(Some(-1.0))), + next_value(ScalarValue::Float32(Some(-1.0))), + )?, + ), + ]; + for (first, second) in possibly_eq_cases { + assert_eq!(first.equal(second.clone())?, Interval::UNCERTAIN); + assert_eq!(second.equal(first)?, Interval::UNCERTAIN); + } + + let not_eq_cases = vec![ + ( + Interval::make(Some(1000_i64), Some(1000_i64))?, + Interval::make(Some(2000_i64), 
Some(2000_i64))?, + ), + ( + Interval::make(Some(500_i64), Some(999_i64))?, + Interval::make(Some(1000_i64), None)?, + ), + ( + Interval::make(None, Some(1000_i64))?, + Interval::make(Some(1001_i64), Some(1500_i64))?, + ), + ( + Interval::try_new( + prev_value(ScalarValue::Float32(Some(0.0))), + prev_value(ScalarValue::Float32(Some(0.0))), + )?, + Interval::make(Some(0.0_f32), Some(0.0_f32))?, + ), + ( + Interval::make(Some(-1.0_f32), Some(-1.0_f32))?, + Interval::try_new( + next_value(ScalarValue::Float32(Some(-1.0))), + next_value(ScalarValue::Float32(Some(-1.0))), + )?, + ), + ]; + for (first, second) in not_eq_cases { + assert_eq!(first.equal(second.clone())?, Interval::CERTAINLY_FALSE); + assert_eq!(second.equal(first)?, Interval::CERTAINLY_FALSE); + } + + Ok(()) + } + + #[test] + fn and_test() -> Result<()> { + let cases = vec![ + (false, true, false, false, false, false), + (false, false, false, true, false, false), + (false, true, false, true, false, true), + (false, true, true, true, false, true), + (false, false, false, false, false, false), + (true, true, true, true, true, true), + ]; + + for case in cases { + assert_eq!( + Interval::make(Some(case.0), Some(case.1))? + .and(Interval::make(Some(case.2), Some(case.3))?)?, + Interval::make(Some(case.4), Some(case.5))? + ); + } + Ok(()) + } + + #[test] + fn not_test() -> Result<()> { + let cases = vec![ + (false, true, false, true), + (false, false, true, true), + (true, true, false, false), + ]; + + for case in cases { + assert_eq!( + Interval::make(Some(case.0), Some(case.1))?.not()?, + Interval::make(Some(case.2), Some(case.3))? + ); + } + Ok(()) + } + + #[test] + fn intersect_test() -> Result<()> { + let possible_cases = vec![ + ( + Interval::make(Some(1000_i64), None)?, + Interval::make::(None, None)?, + Interval::make(Some(1000_i64), None)?, + ), + ( + Interval::make(Some(1000_i64), None)?, + Interval::make(None, Some(1000_i64))?, + Interval::make(Some(1000_i64), Some(1000_i64))?, + ), + ( + Interval::make(Some(1000_i64), None)?, + Interval::make(None, Some(2000_i64))?, + Interval::make(Some(1000_i64), Some(2000_i64))?, + ), + ( + Interval::make(Some(1000_i64), Some(2000_i64))?, + Interval::make(Some(1000_i64), None)?, + Interval::make(Some(1000_i64), Some(2000_i64))?, + ), + ( + Interval::make(Some(1000_i64), Some(2000_i64))?, + Interval::make(Some(1000_i64), Some(1500_i64))?, + Interval::make(Some(1000_i64), Some(1500_i64))?, + ), + ( + Interval::make(Some(1000_i64), Some(2000_i64))?, + Interval::make(Some(500_i64), Some(1500_i64))?, + Interval::make(Some(1000_i64), Some(1500_i64))?, + ), + ( + Interval::make::(None, None)?, + Interval::make::(None, None)?, + Interval::make::(None, None)?, + ), + ( + Interval::make(None, Some(2000_u64))?, + Interval::make(Some(500_u64), None)?, + Interval::make(Some(500_u64), Some(2000_u64))?, + ), + ( + Interval::make(Some(0_u64), Some(0_u64))?, + Interval::make(Some(0_u64), None)?, + Interval::make(Some(0_u64), Some(0_u64))?, + ), + ( + Interval::make(Some(1000.0_f32), None)?, + Interval::make(None, Some(1000.0_f32))?, + Interval::make(Some(1000.0_f32), Some(1000.0_f32))?, + ), + ( + Interval::make(Some(1000.0_f32), Some(1500.0_f32))?, + Interval::make(Some(0.0_f32), Some(1500.0_f32))?, + Interval::make(Some(1000.0_f32), Some(1500.0_f32))?, + ), + ( + Interval::make(Some(-1000.0_f64), Some(1500.0_f64))?, + Interval::make(Some(-1500.0_f64), Some(2000.0_f64))?, + Interval::make(Some(-1000.0_f64), Some(1500.0_f64))?, + ), + ( + Interval::make(Some(16.0_f64), Some(32.0_f64))?, + 
Interval::make(Some(32.0_f64), Some(64.0_f64))?, + Interval::make(Some(32.0_f64), Some(32.0_f64))?, + ), + ]; + for (first, second, expected) in possible_cases { + assert_eq!(first.intersect(second)?.unwrap(), expected) + } + + let empty_cases = vec![ + ( + Interval::make(Some(1000_i64), None)?, + Interval::make(None, Some(0_i64))?, + ), + ( + Interval::make(Some(1000_i64), None)?, + Interval::make(None, Some(999_i64))?, + ), + ( + Interval::make(Some(1500_i64), Some(2000_i64))?, + Interval::make(Some(1000_i64), Some(1499_i64))?, + ), + ( + Interval::make(Some(0_i64), Some(1000_i64))?, + Interval::make(Some(2000_i64), Some(3000_i64))?, + ), + ( + Interval::try_new( + prev_value(ScalarValue::Float32(Some(1.0))), + prev_value(ScalarValue::Float32(Some(1.0))), + )?, + Interval::make(Some(1.0_f32), Some(1.0_f32))?, + ), + ( + Interval::try_new( + next_value(ScalarValue::Float32(Some(1.0))), + next_value(ScalarValue::Float32(Some(1.0))), + )?, + Interval::make(Some(1.0_f32), Some(1.0_f32))?, + ), + ]; + for (first, second) in empty_cases { + assert_eq!(first.intersect(second)?, None) + } + + Ok(()) + } + + #[test] + fn test_contains() -> Result<()> { + let possible_cases = vec![ + ( + Interval::make::(None, None)?, + Interval::make::(None, None)?, + Interval::CERTAINLY_TRUE, + ), + ( + Interval::make(Some(1500_i64), Some(2000_i64))?, + Interval::make(Some(1501_i64), Some(1999_i64))?, + Interval::CERTAINLY_TRUE, + ), + ( + Interval::make(Some(1000_i64), None)?, + Interval::make::(None, None)?, + Interval::UNCERTAIN, + ), + ( + Interval::make(Some(1000_i64), Some(2000_i64))?, + Interval::make(Some(500), Some(1500_i64))?, + Interval::UNCERTAIN, + ), + ( + Interval::make(Some(16.0), Some(32.0))?, + Interval::make(Some(32.0), Some(64.0))?, + Interval::UNCERTAIN, + ), + ( + Interval::make(Some(1000_i64), None)?, + Interval::make(None, Some(0_i64))?, + Interval::CERTAINLY_FALSE, + ), + ( + Interval::make(Some(1500_i64), Some(2000_i64))?, + Interval::make(Some(1000_i64), Some(1499_i64))?, + Interval::CERTAINLY_FALSE, + ), + ( + Interval::try_new( + prev_value(ScalarValue::Float32(Some(1.0))), + prev_value(ScalarValue::Float32(Some(1.0))), + )?, + Interval::make(Some(1.0_f32), Some(1.0_f32))?, + Interval::CERTAINLY_FALSE, + ), + ( + Interval::try_new( + next_value(ScalarValue::Float32(Some(1.0))), + next_value(ScalarValue::Float32(Some(1.0))), + )?, + Interval::make(Some(1.0_f32), Some(1.0_f32))?, + Interval::CERTAINLY_FALSE, + ), + ]; + for (first, second, expected) in possible_cases { + assert_eq!(first.contains(second)?, expected) + } + + Ok(()) + } + + #[test] + fn test_add() -> Result<()> { + let cases = vec![ + ( + Interval::make(Some(100_i64), Some(200_i64))?, + Interval::make(None, Some(200_i64))?, + Interval::make(None, Some(400_i64))?, + ), + ( + Interval::make(Some(100_i64), Some(200_i64))?, + Interval::make(Some(200_i64), None)?, + Interval::make(Some(300_i64), None)?, + ), + ( + Interval::make(None, Some(200_i64))?, + Interval::make(Some(100_i64), Some(200_i64))?, + Interval::make(None, Some(400_i64))?, + ), + ( + Interval::make(Some(200_i64), None)?, + Interval::make(Some(100_i64), Some(200_i64))?, + Interval::make(Some(300_i64), None)?, + ), + ( + Interval::make(Some(100_i64), Some(200_i64))?, + Interval::make(Some(-300_i64), Some(150_i64))?, + Interval::make(Some(-200_i64), Some(350_i64))?, + ), + ( + Interval::make(Some(f32::MAX), Some(f32::MAX))?, + Interval::make(Some(11_f32), Some(11_f32))?, + Interval::make(Some(f32::MAX), None)?, + ), + ( + Interval::make(Some(f32::MIN), 
Some(f32::MIN))?, + Interval::make(Some(-10_f32), Some(10_f32))?, + // Since rounding mode is up, the result would be much greater than f32::MIN + // (f32::MIN = -3.4_028_235e38, the result is -3.4_028_233e38) + Interval::make( + None, + Some(-340282330000000000000000000000000000000.0_f32), + )?, + ), + ( + Interval::make(Some(f32::MIN), Some(f32::MIN))?, + Interval::make(Some(-10_f32), Some(-10_f32))?, + Interval::make(None, Some(f32::MIN))?, + ), + ( + Interval::make(Some(1.0), Some(f32::MAX))?, + Interval::make(Some(f32::MAX), Some(f32::MAX))?, + Interval::make(Some(f32::MAX), None)?, + ), + ( + Interval::make(Some(f32::MIN), Some(f32::MIN))?, + Interval::make(Some(f32::MAX), Some(f32::MAX))?, + Interval::make(Some(-0.0_f32), Some(0.0_f32))?, + ), + ( + Interval::make(Some(100_f64), None)?, + Interval::make(None, Some(200_f64))?, + Interval::make::(None, None)?, + ), + ( + Interval::make(None, Some(100_f64))?, + Interval::make(None, Some(200_f64))?, + Interval::make(None, Some(300_f64))?, + ), + ]; + for case in cases { + let result = case.0.add(case.1)?; + if case.0.data_type().is_floating() { + assert!( + result.lower().is_null() && case.2.lower().is_null() + || result.lower().le(case.2.lower()) + ); + assert!( + result.upper().is_null() && case.2.upper().is_null() + || result.upper().ge(case.2.upper()) + ); + } else { + assert_eq!(result, case.2); + } + } + + Ok(()) + } + + #[test] + fn test_sub() -> Result<()> { + let cases = vec![ + ( + Interval::make(Some(i32::MAX), Some(i32::MAX))?, + Interval::make(Some(11_i32), Some(11_i32))?, + Interval::make(Some(i32::MAX - 11), Some(i32::MAX - 11))?, + ), + ( + Interval::make(Some(100_i64), Some(200_i64))?, + Interval::make(None, Some(200_i64))?, + Interval::make(Some(-100_i64), None)?, + ), + ( + Interval::make(Some(100_i64), Some(200_i64))?, + Interval::make(Some(200_i64), None)?, + Interval::make(None, Some(0_i64))?, + ), + ( + Interval::make(None, Some(200_i64))?, + Interval::make(Some(100_i64), Some(200_i64))?, + Interval::make(None, Some(100_i64))?, + ), + ( + Interval::make(Some(200_i64), None)?, + Interval::make(Some(100_i64), Some(200_i64))?, + Interval::make(Some(0_i64), None)?, + ), + ( + Interval::make(Some(100_i64), Some(200_i64))?, + Interval::make(Some(-300_i64), Some(150_i64))?, + Interval::make(Some(-50_i64), Some(500_i64))?, + ), + ( + Interval::make(Some(i64::MIN), Some(i64::MIN))?, + Interval::make(Some(-10_i64), Some(-10_i64))?, + Interval::make(Some(i64::MIN + 10), Some(i64::MIN + 10))?, + ), + ( + Interval::make(Some(1), Some(i64::MAX))?, + Interval::make(Some(i64::MAX), Some(i64::MAX))?, + Interval::make(Some(1 - i64::MAX), Some(0))?, + ), + ( + Interval::make(Some(i64::MIN), Some(i64::MIN))?, + Interval::make(Some(i64::MAX), Some(i64::MAX))?, + Interval::make(None, Some(i64::MIN))?, + ), + ( + Interval::make(Some(2_u32), Some(10_u32))?, + Interval::make(Some(4_u32), Some(6_u32))?, + Interval::make(None, Some(6_u32))?, + ), + ( + Interval::make(Some(2_u32), Some(10_u32))?, + Interval::make(Some(20_u32), Some(30_u32))?, + Interval::make(None, Some(0_u32))?, + ), + ( + Interval::make(Some(f32::MIN), Some(f32::MIN))?, + Interval::make(Some(-10_f32), Some(10_f32))?, + // Since rounding mode is up, the result would be much larger than f32::MIN + // (f32::MIN = -3.4_028_235e38, the result is -3.4_028_233e38) + Interval::make( + None, + Some(-340282330000000000000000000000000000000.0_f32), + )?, + ), + ( + Interval::make(Some(100_f64), None)?, + Interval::make(None, Some(200_f64))?, + Interval::make(Some(-100_f64), None)?, + 
), + ( + Interval::make(None, Some(100_f64))?, + Interval::make(None, Some(200_f64))?, + Interval::make::(None, None)?, + ), + ]; + for case in cases { + let result = case.0.sub(case.1)?; + if case.0.data_type().is_floating() { + assert!( + result.lower().is_null() && case.2.lower().is_null() + || result.lower().le(case.2.lower()) + ); + assert!( + result.upper().is_null() && case.2.upper().is_null() + || result.upper().ge(case.2.upper(),) + ); + } else { + assert_eq!(result, case.2); + } + } + + Ok(()) + } + + #[test] + fn test_mul() -> Result<()> { + let cases = vec![ + ( + Interval::make(Some(1_i64), Some(2_i64))?, + Interval::make(None, Some(2_i64))?, + Interval::make(None, Some(4_i64))?, + ), + ( + Interval::make(Some(1_i64), Some(2_i64))?, + Interval::make(Some(2_i64), None)?, + Interval::make(Some(2_i64), None)?, + ), + ( + Interval::make(None, Some(2_i64))?, + Interval::make(Some(1_i64), Some(2_i64))?, + Interval::make(None, Some(4_i64))?, + ), + ( + Interval::make(Some(2_i64), None)?, + Interval::make(Some(1_i64), Some(2_i64))?, + Interval::make(Some(2_i64), None)?, + ), + ( + Interval::make(Some(1_i64), Some(2_i64))?, + Interval::make(Some(-3_i64), Some(15_i64))?, + Interval::make(Some(-6_i64), Some(30_i64))?, + ), + ( + Interval::make(Some(-0.0), Some(0.0))?, + Interval::make(None, Some(0.0))?, + Interval::make::(None, None)?, + ), + ( + Interval::make(Some(f32::MIN), Some(f32::MIN))?, + Interval::make(Some(-10_f32), Some(10_f32))?, + Interval::make::(None, None)?, + ), + ( + Interval::make(Some(1_u32), Some(2_u32))?, + Interval::make(Some(0_u32), Some(1_u32))?, + Interval::make(Some(0_u32), Some(2_u32))?, + ), + ( + Interval::make(None, Some(2_u32))?, + Interval::make(Some(0_u32), Some(1_u32))?, + Interval::make(None, Some(2_u32))?, + ), + ( + Interval::make(None, Some(2_u32))?, + Interval::make(Some(1_u32), Some(2_u32))?, + Interval::make(None, Some(4_u32))?, + ), + ( + Interval::make(None, Some(2_u32))?, + Interval::make(Some(1_u32), None)?, + Interval::make::(None, None)?, + ), + ( + Interval::make::(None, None)?, + Interval::make(Some(0_u32), None)?, + Interval::make::(None, None)?, + ), + ( + Interval::make(Some(f32::MAX), Some(f32::MAX))?, + Interval::make(Some(11_f32), Some(11_f32))?, + Interval::make(Some(f32::MAX), None)?, + ), + ( + Interval::make(Some(f32::MIN), Some(f32::MIN))?, + Interval::make(Some(-10_f32), Some(-10_f32))?, + Interval::make(Some(f32::MAX), None)?, + ), + ( + Interval::make(Some(1.0), Some(f32::MAX))?, + Interval::make(Some(f32::MAX), Some(f32::MAX))?, + Interval::make(Some(f32::MAX), None)?, + ), + ( + Interval::make(Some(f32::MIN), Some(f32::MIN))?, + Interval::make(Some(f32::MAX), Some(f32::MAX))?, + Interval::make(None, Some(f32::MIN))?, + ), + ( + Interval::make(Some(-0.0_f32), Some(0.0_f32))?, + Interval::make(Some(f32::MAX), None)?, + Interval::make::(None, None)?, + ), + ( + Interval::make(Some(0.0_f32), Some(0.0_f32))?, + Interval::make(Some(f32::MAX), None)?, + Interval::make(Some(0.0_f32), None)?, + ), + ( + Interval::make(Some(1_f64), None)?, + Interval::make(None, Some(2_f64))?, + Interval::make::(None, None)?, + ), + ( + Interval::make(None, Some(1_f64))?, + Interval::make(None, Some(2_f64))?, + Interval::make::(None, None)?, + ), + ( + Interval::make(Some(-0.0_f64), Some(-0.0_f64))?, + Interval::make(Some(1_f64), Some(2_f64))?, + Interval::make(Some(-0.0_f64), Some(-0.0_f64))?, + ), + ( + Interval::make(Some(0.0_f64), Some(0.0_f64))?, + Interval::make(Some(1_f64), Some(2_f64))?, + Interval::make(Some(0.0_f64), Some(0.0_f64))?, + ), 
+ ( + Interval::make(Some(-0.0_f64), Some(0.0_f64))?, + Interval::make(Some(1_f64), Some(2_f64))?, + Interval::make(Some(-0.0_f64), Some(0.0_f64))?, + ), + ( + Interval::make(Some(-0.0_f64), Some(1.0_f64))?, + Interval::make(Some(1_f64), Some(2_f64))?, + Interval::make(Some(-0.0_f64), Some(2.0_f64))?, + ), + ( + Interval::make(Some(0.0_f64), Some(1.0_f64))?, + Interval::make(Some(1_f64), Some(2_f64))?, + Interval::make(Some(0.0_f64), Some(2.0_f64))?, + ), + ( + Interval::make(Some(-0.0_f64), Some(1.0_f64))?, + Interval::make(Some(-1_f64), Some(2_f64))?, + Interval::make(Some(-1.0_f64), Some(2.0_f64))?, + ), + ( + Interval::make::(None, None)?, + Interval::make(Some(-0.0_f64), Some(0.0_f64))?, + Interval::make::(None, None)?, + ), + ( + Interval::make::(None, Some(10.0_f64))?, + Interval::make(Some(-0.0_f64), Some(0.0_f64))?, + Interval::make::(None, None)?, + ), + ]; + for case in cases { + let result = case.0.mul(case.1)?; + if case.0.data_type().is_floating() { + assert!( + result.lower().is_null() && case.2.lower().is_null() + || result.lower().le(case.2.lower()) + ); + assert!( + result.upper().is_null() && case.2.upper().is_null() + || result.upper().ge(case.2.upper()) + ); + } else { + assert_eq!(result, case.2); + } + } + + Ok(()) + } + + #[test] + fn test_div() -> Result<()> { + let cases = vec![ + ( + Interval::make(Some(100_i64), Some(200_i64))?, + Interval::make(Some(1_i64), Some(2_i64))?, + Interval::make(Some(50_i64), Some(200_i64))?, + ), + ( + Interval::make(Some(-200_i64), Some(-100_i64))?, + Interval::make(Some(-2_i64), Some(-1_i64))?, + Interval::make(Some(50_i64), Some(200_i64))?, + ), + ( + Interval::make(Some(100_i64), Some(200_i64))?, + Interval::make(Some(-2_i64), Some(-1_i64))?, + Interval::make(Some(-200_i64), Some(-50_i64))?, + ), + ( + Interval::make(Some(-200_i64), Some(-100_i64))?, + Interval::make(Some(1_i64), Some(2_i64))?, + Interval::make(Some(-200_i64), Some(-50_i64))?, + ), + ( + Interval::make(Some(-200_i64), Some(100_i64))?, + Interval::make(Some(1_i64), Some(2_i64))?, + Interval::make(Some(-200_i64), Some(100_i64))?, + ), + ( + Interval::make(Some(-100_i64), Some(200_i64))?, + Interval::make(Some(1_i64), Some(2_i64))?, + Interval::make(Some(-100_i64), Some(200_i64))?, + ), + ( + Interval::make(Some(10_i64), Some(20_i64))?, + Interval::make::(None, None)?, + Interval::make::(None, None)?, + ), + ( + Interval::make(Some(-100_i64), Some(200_i64))?, + Interval::make(Some(-1_i64), Some(2_i64))?, + Interval::make::(None, None)?, + ), + ( + Interval::make(Some(-100_i64), Some(200_i64))?, + Interval::make(Some(-2_i64), Some(1_i64))?, + Interval::make::(None, None)?, + ), + ( + Interval::make(Some(100_i64), Some(200_i64))?, + Interval::make(Some(0_i64), Some(1_i64))?, + Interval::make(Some(100_i64), None)?, + ), + ( + Interval::make(Some(100_i64), Some(200_i64))?, + Interval::make(None, Some(0_i64))?, + Interval::make(None, Some(0_i64))?, + ), + ( + Interval::make(Some(100_i64), Some(200_i64))?, + Interval::make(Some(0_i64), Some(0_i64))?, + Interval::make::(None, None)?, + ), + ( + Interval::make(Some(0_i64), Some(1_i64))?, + Interval::make(Some(100_i64), Some(200_i64))?, + Interval::make(Some(0_i64), Some(0_i64))?, + ), + ( + Interval::make(Some(0_i64), Some(1_i64))?, + Interval::make(Some(100_i64), Some(200_i64))?, + Interval::make(Some(0_i64), Some(0_i64))?, + ), + ( + Interval::make(Some(1_u32), Some(2_u32))?, + Interval::make(Some(0_u32), Some(0_u32))?, + Interval::make::(None, None)?, + ), + ( + Interval::make(Some(10_u32), Some(20_u32))?, + 
Interval::make(None, Some(2_u32))?, + Interval::make(Some(5_u32), None)?, + ), + ( + Interval::make(Some(10_u32), Some(20_u32))?, + Interval::make(Some(0_u32), Some(2_u32))?, + Interval::make(Some(5_u32), None)?, + ), + ( + Interval::make(Some(10_u32), Some(20_u32))?, + Interval::make(Some(0_u32), Some(0_u32))?, + Interval::make::(None, None)?, + ), + ( + Interval::make(Some(12_u64), Some(48_u64))?, + Interval::make(Some(10_u64), Some(20_u64))?, + Interval::make(Some(0_u64), Some(4_u64))?, + ), + ( + Interval::make(Some(12_u64), Some(48_u64))?, + Interval::make(None, Some(2_u64))?, + Interval::make(Some(6_u64), None)?, + ), + ( + Interval::make(Some(12_u64), Some(48_u64))?, + Interval::make(Some(0_u64), Some(2_u64))?, + Interval::make(Some(6_u64), None)?, + ), + ( + Interval::make(None, Some(48_u64))?, + Interval::make(Some(0_u64), Some(2_u64))?, + Interval::make::(None, None)?, + ), + ( + Interval::make(Some(f32::MAX), Some(f32::MAX))?, + Interval::make(Some(-0.1_f32), Some(0.1_f32))?, + Interval::make::(None, None)?, + ), + ( + Interval::make(Some(f32::MIN), None)?, + Interval::make(Some(0.1_f32), Some(0.1_f32))?, + Interval::make::(None, None)?, + ), + ( + Interval::make(Some(-10.0_f32), Some(10.0_f32))?, + Interval::make(Some(-0.1_f32), Some(-0.1_f32))?, + Interval::make(Some(-100.0_f32), Some(100.0_f32))?, + ), + ( + Interval::make(Some(-10.0_f32), Some(f32::MAX))?, + Interval::make::(None, None)?, + Interval::make::(None, None)?, + ), + ( + Interval::make(Some(f32::MIN), Some(10.0_f32))?, + Interval::make(Some(1.0_f32), None)?, + Interval::make(Some(f32::MIN), Some(10.0_f32))?, + ), + ( + Interval::make(Some(-0.0_f32), Some(0.0_f32))?, + Interval::make(Some(f32::MAX), None)?, + Interval::make(Some(-0.0_f32), Some(0.0_f32))?, + ), + ( + Interval::make(Some(-0.0_f32), Some(0.0_f32))?, + Interval::make(None, Some(-0.0_f32))?, + Interval::make::(None, None)?, + ), + ( + Interval::make(Some(0.0_f32), Some(0.0_f32))?, + Interval::make(Some(f32::MAX), None)?, + Interval::make(Some(0.0_f32), Some(0.0_f32))?, + ), + ( + Interval::make(Some(1.0_f32), Some(2.0_f32))?, + Interval::make(Some(0.0_f32), Some(4.0_f32))?, + Interval::make(Some(0.25_f32), None)?, + ), + ( + Interval::make(Some(1.0_f32), Some(2.0_f32))?, + Interval::make(Some(-4.0_f32), Some(-0.0_f32))?, + Interval::make(None, Some(-0.25_f32))?, + ), + ( + Interval::make(Some(-4.0_f64), Some(2.0_f64))?, + Interval::make(Some(10.0_f64), Some(20.0_f64))?, + Interval::make(Some(-0.4_f64), Some(0.2_f64))?, + ), + ( + Interval::make(Some(-0.0_f64), Some(-0.0_f64))?, + Interval::make(None, Some(-0.0_f64))?, + Interval::make(Some(0.0_f64), None)?, + ), + ( + Interval::make(Some(1.0_f64), Some(2.0_f64))?, + Interval::make::(None, None)?, + Interval::make(Some(0.0_f64), None)?, + ), + ]; + for case in cases { + let result = case.0.div(case.1)?; + if case.0.data_type().is_floating() { + assert!( + result.lower().is_null() && case.2.lower().is_null() + || result.lower().le(case.2.lower()) + ); + assert!( + result.upper().is_null() && case.2.upper().is_null() + || result.upper().ge(case.2.upper()) + ); + } else { + assert_eq!(result, case.2); + } + } + + Ok(()) + } + + #[test] + fn test_cardinality_of_intervals() -> Result<()> { + // In IEEE 754 standard for floating-point arithmetic, if we keep the sign and exponent fields same, + // we can represent 4503599627370496+1 different numbers by changing the mantissa + // (4503599627370496 = 2^52, since there are 52 bits in mantissa, and 2^23 = 8388608 for f32). 
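The `+ 1` in these counts comes from including both endpoints: for positive finite floats, consecutive bit patterns encode consecutive values, so the number of representable doubles in `[0.25, 0.5]` is the difference of the endpoints' bit patterns plus one. A quick sanity check one could run (not part of the test itself):

    assert_eq!(0.5_f64.to_bits() - 0.25_f64.to_bits() + 1, (1_u64 << 52) + 1);
    assert_eq!(0.5_f32.to_bits() - 0.25_f32.to_bits() + 1, (1_u32 << 23) + 1);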
+ // TODO: Add tests for non-exponential boundary aligned intervals too. + let distinct_f64 = 4503599627370497; + let distinct_f32 = 8388609; + let intervals = [ + Interval::make(Some(0.25_f64), Some(0.50_f64))?, + Interval::make(Some(0.5_f64), Some(1.0_f64))?, + Interval::make(Some(1.0_f64), Some(2.0_f64))?, + Interval::make(Some(32.0_f64), Some(64.0_f64))?, + Interval::make(Some(-0.50_f64), Some(-0.25_f64))?, + Interval::make(Some(-32.0_f64), Some(-16.0_f64))?, + ]; + for interval in intervals { + assert_eq!(interval.cardinality().unwrap(), distinct_f64); + } + + let intervals = [ + Interval::make(Some(0.25_f32), Some(0.50_f32))?, + Interval::make(Some(-1_f32), Some(-0.5_f32))?, + ]; + for interval in intervals { + assert_eq!(interval.cardinality().unwrap(), distinct_f32); + } + + // The regular logarithmic distribution of floating-point numbers are + // only applicable outside of the `(-phi, phi)` interval where `phi` + // denotes the largest positive subnormal floating-point number. Since + // the following intervals include such subnormal points, we cannot use + // a simple powers-of-two type formula for our expectations. Therefore, + // we manually supply the actual expected cardinality. + let interval = Interval::make(Some(-0.0625), Some(0.0625))?; + assert_eq!(interval.cardinality().unwrap(), 9178336040581070850); + + let interval = Interval::try_new( + ScalarValue::UInt64(Some(u64::MIN + 1)), + ScalarValue::UInt64(Some(u64::MAX)), + )?; + assert_eq!(interval.cardinality().unwrap(), u64::MAX); + + let interval = Interval::try_new( + ScalarValue::Int64(Some(i64::MIN + 1)), + ScalarValue::Int64(Some(i64::MAX)), + )?; + assert_eq!(interval.cardinality().unwrap(), u64::MAX); + + let interval = Interval::try_new( + ScalarValue::Float32(Some(-0.0_f32)), + ScalarValue::Float32(Some(0.0_f32)), + )?; + assert_eq!(interval.cardinality().unwrap(), 2); + + Ok(()) + } + + #[test] + fn test_satisfy_comparison() -> Result<()> { + let cases = vec![ + ( + Interval::make(Some(1000_i64), None)?, + Interval::make(None, Some(1000_i64))?, + true, + Interval::make(Some(1000_i64), None)?, + Interval::make(None, Some(1000_i64))?, + ), + ( + Interval::make(None, Some(1000_i64))?, + Interval::make(Some(1000_i64), None)?, + true, + Interval::make(Some(1000_i64), Some(1000_i64))?, + Interval::make(Some(1000_i64), Some(1000_i64))?, + ), + ( + Interval::make(Some(1000_i64), None)?, + Interval::make(None, Some(1000_i64))?, + false, + Interval::make(Some(1000_i64), None)?, + Interval::make(None, Some(1000_i64))?, + ), + ( + Interval::make(Some(0_i64), Some(1000_i64))?, + Interval::make(Some(500_i64), Some(1500_i64))?, + true, + Interval::make(Some(500_i64), Some(1000_i64))?, + Interval::make(Some(500_i64), Some(1000_i64))?, + ), + ( + Interval::make(Some(500_i64), Some(1500_i64))?, + Interval::make(Some(0_i64), Some(1000_i64))?, + true, + Interval::make(Some(500_i64), Some(1500_i64))?, + Interval::make(Some(0_i64), Some(1000_i64))?, + ), + ( + Interval::make(Some(0_i64), Some(1000_i64))?, + Interval::make(Some(500_i64), Some(1500_i64))?, + false, + Interval::make(Some(501_i64), Some(1000_i64))?, + Interval::make(Some(500_i64), Some(999_i64))?, + ), + ( + Interval::make(Some(500_i64), Some(1500_i64))?, + Interval::make(Some(0_i64), Some(1000_i64))?, + false, + Interval::make(Some(500_i64), Some(1500_i64))?, + Interval::make(Some(0_i64), Some(1000_i64))?, + ), + ( + Interval::make::(None, None)?, + Interval::make(Some(1_i64), Some(1_i64))?, + false, + Interval::make(Some(2_i64), None)?, + Interval::make(Some(1_i64), 
Some(1_i64))?, + ), + ( + Interval::make::(None, None)?, + Interval::make(Some(1_i64), Some(1_i64))?, + true, + Interval::make(Some(1_i64), None)?, + Interval::make(Some(1_i64), Some(1_i64))?, + ), + ( + Interval::make(Some(1_i64), Some(1_i64))?, + Interval::make::(None, None)?, + false, + Interval::make(Some(1_i64), Some(1_i64))?, + Interval::make(None, Some(0_i64))?, + ), + ( + Interval::make(Some(1_i64), Some(1_i64))?, + Interval::make::(None, None)?, + true, + Interval::make(Some(1_i64), Some(1_i64))?, + Interval::make(None, Some(1_i64))?, + ), + ( + Interval::make(Some(1_i64), Some(1_i64))?, + Interval::make::(None, None)?, + false, + Interval::make(Some(1_i64), Some(1_i64))?, + Interval::make(None, Some(0_i64))?, + ), + ( + Interval::make(Some(1_i64), Some(1_i64))?, + Interval::make::(None, None)?, + true, + Interval::make(Some(1_i64), Some(1_i64))?, + Interval::make(None, Some(1_i64))?, + ), + ( + Interval::make::(None, None)?, + Interval::make(Some(1_i64), Some(1_i64))?, + false, + Interval::make(Some(2_i64), None)?, + Interval::make(Some(1_i64), Some(1_i64))?, + ), + ( + Interval::make::(None, None)?, + Interval::make(Some(1_i64), Some(1_i64))?, + true, + Interval::make(Some(1_i64), None)?, + Interval::make(Some(1_i64), Some(1_i64))?, + ), + ( + Interval::make(Some(-1000.0_f32), Some(1000.0_f32))?, + Interval::make(Some(-500.0_f32), Some(500.0_f32))?, + false, + Interval::try_new( + next_value(ScalarValue::Float32(Some(-500.0))), + ScalarValue::Float32(Some(1000.0)), + )?, + Interval::make(Some(-500_f32), Some(500.0_f32))?, + ), + ( + Interval::make(Some(-500.0_f32), Some(500.0_f32))?, + Interval::make(Some(-1000.0_f32), Some(1000.0_f32))?, + true, + Interval::make(Some(-500.0_f32), Some(500.0_f32))?, + Interval::make(Some(-1000.0_f32), Some(500.0_f32))?, + ), + ( + Interval::make(Some(-500.0_f32), Some(500.0_f32))?, + Interval::make(Some(-1000.0_f32), Some(1000.0_f32))?, + false, + Interval::make(Some(-500.0_f32), Some(500.0_f32))?, + Interval::try_new( + ScalarValue::Float32(Some(-1000.0_f32)), + prev_value(ScalarValue::Float32(Some(500.0_f32))), + )?, + ), + ( + Interval::make(Some(-1000.0_f64), Some(1000.0_f64))?, + Interval::make(Some(-500.0_f64), Some(500.0_f64))?, + true, + Interval::make(Some(-500.0_f64), Some(1000.0_f64))?, + Interval::make(Some(-500.0_f64), Some(500.0_f64))?, + ), + ]; + for (first, second, includes_endpoints, left_modified, right_modified) in cases { + assert_eq!( + satisfy_greater(&first, &second, !includes_endpoints)?.unwrap(), + (left_modified, right_modified) + ); + } + + let infeasible_cases = vec![ + ( + Interval::make(None, Some(1000_i64))?, + Interval::make(Some(1000_i64), None)?, + false, + ), + ( + Interval::make(Some(-1000.0_f32), Some(1000.0_f32))?, + Interval::make(Some(1500.0_f32), Some(2000.0_f32))?, + false, + ), + ]; + for (first, second, includes_endpoints) in infeasible_cases { + assert_eq!(satisfy_greater(&first, &second, !includes_endpoints)?, None); + } + + Ok(()) + } + + #[test] + fn test_interval_display() { + let interval = Interval::make(Some(0.25_f32), Some(0.50_f32)).unwrap(); + assert_eq!(format!("{}", interval), "[0.25, 0.5]"); + + let interval = Interval::try_new( + ScalarValue::Float32(Some(f32::NEG_INFINITY)), + ScalarValue::Float32(Some(f32::INFINITY)), + ) + .unwrap(); + assert_eq!(format!("{}", interval), "[NULL, NULL]"); + } + + macro_rules! capture_mode_change { + ($TYPE:ty) => { + paste::item! { + capture_mode_change_helper!([], + [], + $TYPE); + } + }; + } + + macro_rules! 
capture_mode_change_helper { + ($TEST_FN_NAME:ident, $CREATE_FN_NAME:ident, $TYPE:ty) => { + fn $CREATE_FN_NAME(lower: $TYPE, upper: $TYPE) -> Interval { + Interval::try_new( + ScalarValue::try_from(Some(lower as $TYPE)).unwrap(), + ScalarValue::try_from(Some(upper as $TYPE)).unwrap(), + ) + .unwrap() + } + + fn $TEST_FN_NAME(input: ($TYPE, $TYPE), expect_low: bool, expect_high: bool) { + assert!(expect_low || expect_high); + let interval1 = $CREATE_FN_NAME(input.0, input.0); + let interval2 = $CREATE_FN_NAME(input.1, input.1); + let result = interval1.add(&interval2).unwrap(); + let without_fe = $CREATE_FN_NAME(input.0 + input.1, input.0 + input.1); + assert!( + (!expect_low || result.lower < without_fe.lower) + && (!expect_high || result.upper > without_fe.upper) + ); + } + }; + } + + capture_mode_change!(f32); + capture_mode_change!(f64); + + #[cfg(all( + any(target_arch = "x86_64", target_arch = "aarch64"), + not(target_os = "windows") + ))] + #[test] + fn test_add_intervals_lower_affected_f32() { + // Lower is affected + let lower = f32::from_bits(1073741887); //1000000000000000000000000111111 + let upper = f32::from_bits(1098907651); //1000001100000000000000000000011 + capture_mode_change_f32((lower, upper), true, false); + + // Upper is affected + let lower = f32::from_bits(1072693248); //111111111100000000000000000000 + let upper = f32::from_bits(715827883); //101010101010101010101010101011 + capture_mode_change_f32((lower, upper), false, true); + + // Lower is affected + let lower = 1.0; // 0x3FF0000000000000 + let upper = 0.3; // 0x3FD3333333333333 + capture_mode_change_f64((lower, upper), true, false); + + // Upper is affected + let lower = 1.4999999999999998; // 0x3FF7FFFFFFFFFFFF + let upper = 0.000_000_000_000_000_022_044_604_925_031_31; // 0x3C796A6B413BB21F + capture_mode_change_f64((lower, upper), false, true); + } + + #[cfg(any( + not(any(target_arch = "x86_64", target_arch = "aarch64")), + target_os = "windows" + ))] + #[test] + fn test_next_impl_add_intervals_f64() { + let lower = 1.5; + let upper = 1.5; + capture_mode_change_f64((lower, upper), true, true); + + let lower = 1.5; + let upper = 1.5; + capture_mode_change_f32((lower, upper), true, true); + } +} diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs index 21c0d750a36d0..077681d217257 100644 --- a/datafusion/expr/src/lib.rs +++ b/datafusion/expr/src/lib.rs @@ -26,10 +26,21 @@ //! The [expr_fn] module contains functions for creating expressions. 
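The test module above exercises the interval arithmetic API that this patch also makes public from `datafusion_expr` (see the `pub mod interval_arithmetic;` addition just below). As a minimal sketch of how those operations compose, using only calls that appear in the tests above and assuming the import path implied by this patch's module layout:

    use datafusion_common::Result;
    use datafusion_expr::interval_arithmetic::Interval;

    fn interval_sketch() -> Result<()> {
        // Integer endpoints are exact: [100, 200] + [-300, 150] = [-200, 350].
        let sum = Interval::make(Some(100_i64), Some(200_i64))?
            .add(Interval::make(Some(-300_i64), Some(150_i64))?)?;
        assert_eq!(sum, Interval::make(Some(-200_i64), Some(350_i64))?);

        // Unbounded sides are `None`; intersection clips them.
        let clipped = Interval::make(Some(1000_i64), None)?
            .intersect(Interval::make(None, Some(1500_i64))?)?;
        assert_eq!(clipped, Some(Interval::make(Some(1000_i64), Some(1500_i64))?));

        // Set relations return certainty values rather than plain booleans.
        let outer = Interval::make(Some(0_i64), Some(10_i64))?;
        let inner = Interval::make(Some(2_i64), Some(3_i64))?;
        assert_eq!(outer.contains(inner)?, Interval::CERTAINLY_TRUE);

        // Cardinality counts representable values, e.g. 2^52 + 1 doubles in [0.25, 0.5].
        let card = Interval::make(Some(0.25_f64), Some(0.5_f64))?.cardinality().unwrap();
        assert_eq!(card, (1_u64 << 52) + 1);
        Ok(())
    }

Float endpoints are rounded outward by these operations, which is why the floating-point cases in the tests above only check containment of the expected bounds (`le`/`ge`) rather than exact equality.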
mod accumulator; -pub mod aggregate_function; -pub mod array_expressions; mod built_in_function; +mod built_in_window_function; mod columnar_value; +mod literal; +mod nullif; +mod operator; +mod partition_evaluator; +mod signature; +mod table_source; +mod udaf; +mod udf; +mod udwf; + +pub mod aggregate_function; +pub mod array_expressions; pub mod conditional_expressions; pub mod expr; pub mod expr_fn; @@ -37,31 +48,22 @@ pub mod expr_rewriter; pub mod expr_schema; pub mod field_util; pub mod function; -mod literal; +pub mod interval_arithmetic; pub mod logical_plan; -mod nullif; -mod operator; -mod partition_evaluator; -mod signature; -pub mod struct_expressions; -mod table_source; pub mod tree_node; pub mod type_coercion; -mod udaf; -mod udf; -mod udwf; pub mod utils; pub mod window_frame; -pub mod window_function; pub mod window_state; pub use accumulator::Accumulator; pub use aggregate_function::AggregateFunction; pub use built_in_function::BuiltinScalarFunction; +pub use built_in_window_function::BuiltInWindowFunction; pub use columnar_value::ColumnarValue; pub use expr::{ Between, BinaryExpr, Case, Cast, Expr, GetFieldAccess, GetIndexedField, GroupingSet, - Like, TryCast, + Like, ScalarFunctionDefinition, TryCast, WindowFunctionDefinition, }; pub use expr_fn::*; pub use expr_schema::ExprSchemable; @@ -79,10 +81,9 @@ pub use signature::{ }; pub use table_source::{TableProviderFilterPushDown, TableSource, TableType}; pub use udaf::AggregateUDF; -pub use udf::ScalarUDF; -pub use udwf::WindowUDF; +pub use udf::{ScalarUDF, ScalarUDFImpl}; +pub use udwf::{WindowUDF, WindowUDFImpl}; pub use window_frame::{WindowFrame, WindowFrameBound, WindowFrameUnits}; -pub use window_function::{BuiltInWindowFunction, WindowFunction}; #[cfg(test)] #[ctor::ctor] diff --git a/datafusion/expr/src/literal.rs b/datafusion/expr/src/literal.rs index effc315538192..2f04729af2edb 100644 --- a/datafusion/expr/src/literal.rs +++ b/datafusion/expr/src/literal.rs @@ -43,19 +43,19 @@ pub trait TimestampLiteral { impl Literal for &str { fn lit(&self) -> Expr { - Expr::Literal(ScalarValue::Utf8(Some((*self).to_owned()))) + Expr::Literal(ScalarValue::from(*self)) } } impl Literal for String { fn lit(&self) -> Expr { - Expr::Literal(ScalarValue::Utf8(Some((*self).to_owned()))) + Expr::Literal(ScalarValue::from(self.as_ref())) } } impl Literal for &String { fn lit(&self) -> Expr { - Expr::Literal(ScalarValue::Utf8(Some((*self).to_owned()))) + Expr::Literal(ScalarValue::from(self.as_ref())) } } diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index 162a6a959e59f..a684f3e974855 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -32,8 +32,8 @@ use crate::expr_rewriter::{ rewrite_sort_cols_by_aggs, }; use crate::logical_plan::{ - Aggregate, Analyze, CrossJoin, Distinct, EmptyRelation, Explain, Filter, Join, - JoinConstraint, JoinType, Limit, LogicalPlan, Partitioning, PlanType, Prepare, + Aggregate, Analyze, CrossJoin, Distinct, DistinctOn, EmptyRelation, Explain, Filter, + Join, JoinConstraint, JoinType, Limit, LogicalPlan, Partitioning, PlanType, Prepare, Projection, Repartition, Sort, SubqueryAlias, TableScan, Union, Unnest, Values, Window, }; @@ -50,9 +50,9 @@ use crate::{ use arrow::datatypes::{DataType, Schema, SchemaRef}; use datafusion_common::display::ToStringifiedPlan; use datafusion_common::{ - plan_datafusion_err, plan_err, Column, DFField, DFSchema, DFSchemaRef, - DataFusionError, FileType, 
OwnedTableReference, Result, ScalarValue, TableReference, - ToDFSchema, UnnestOptions, + get_target_functional_dependencies, plan_datafusion_err, plan_err, Column, DFField, + DFSchema, DFSchemaRef, DataFusionError, FileType, OwnedTableReference, Result, + ScalarValue, TableReference, ToDFSchema, UnnestOptions, }; /// Default table name for unnamed table @@ -292,7 +292,7 @@ impl LogicalPlanBuilder { window_exprs: Vec, ) -> Result { let mut plan = input; - let mut groups = group_window_expr_by_sort_keys(&window_exprs)?; + let mut groups = group_window_expr_by_sort_keys(window_exprs)?; // To align with the behavior of PostgreSQL, we want the sort_keys sorted as same rule as PostgreSQL that first // we compare the sort key themselves and if one window's sort keys are a prefix of another // put the window with more sort keys first. so more deeply sorted plans gets nested further down as children. @@ -314,7 +314,7 @@ impl LogicalPlanBuilder { key_b.len().cmp(&key_a.len()) }); for (_, exprs) in groups { - let window_exprs = exprs.into_iter().cloned().collect::>(); + let window_exprs = exprs.into_iter().collect::>(); // Partition and sorting is done at physical level, see the EnforceDistribution // and EnforceSorting rules. plan = LogicalPlanBuilder::from(plan) @@ -445,7 +445,7 @@ impl LogicalPlanBuilder { ) }) .collect::>>()?; - curr_plan.with_new_inputs(&new_inputs) + curr_plan.with_new_exprs(curr_plan.expressions(), &new_inputs) } } } @@ -551,16 +551,29 @@ impl LogicalPlanBuilder { let left_plan: LogicalPlan = self.plan; let right_plan: LogicalPlan = plan; - Ok(Self::from(LogicalPlan::Distinct(Distinct { - input: Arc::new(union(left_plan, right_plan)?), - }))) + Ok(Self::from(LogicalPlan::Distinct(Distinct::All(Arc::new( + union(left_plan, right_plan)?, + ))))) } /// Apply deduplication: Only distinct (different) values are returned) pub fn distinct(self) -> Result { - Ok(Self::from(LogicalPlan::Distinct(Distinct { - input: Arc::new(self.plan), - }))) + Ok(Self::from(LogicalPlan::Distinct(Distinct::All(Arc::new( + self.plan, + ))))) + } + + /// Project first values of the specified expression list according to the provided + /// sorting expressions grouped by the `DISTINCT ON` clause expressions. + pub fn distinct_on( + self, + on_expr: Vec, + select_expr: Vec, + sort_expr: Option>, + ) -> Result { + Ok(Self::from(LogicalPlan::Distinct(Distinct::On( + DistinctOn::try_new(on_expr, select_expr, sort_expr, Arc::new(self.plan))?, + )))) } /// Apply a join to `right` using explicitly specified columns and an @@ -893,6 +906,9 @@ impl LogicalPlanBuilder { ) -> Result { let group_expr = normalize_cols(group_expr, &self.plan)?; let aggr_expr = normalize_cols(aggr_expr, &self.plan)?; + + let group_expr = + add_group_by_exprs_from_dependencies(group_expr, self.plan.schema())?; Aggregate::try_new(Arc::new(self.plan), group_expr, aggr_expr) .map(LogicalPlan::Aggregate) .map(Self::from) @@ -1153,10 +1169,46 @@ pub fn build_join_schema( ); let mut metadata = left.metadata().clone(); metadata.extend(right.metadata().clone()); - DFSchema::new_with_metadata(fields, metadata) - .map(|schema| schema.with_functional_dependencies(func_dependencies)) + let schema = DFSchema::new_with_metadata(fields, metadata)?; + schema.with_functional_dependencies(func_dependencies) } +/// Add additional "synthetic" group by expressions based on functional +/// dependencies. 
+/// +/// For example, if we are grouping on `[c1]`, and we know from +/// functional dependencies that column `c1` determines `c2`, this function +/// adds `c2` to the group by list. +/// +/// This allows MySQL style selects like +/// `SELECT col FROM t WHERE pk = 5` if col is unique +fn add_group_by_exprs_from_dependencies( + mut group_expr: Vec, + schema: &DFSchemaRef, +) -> Result> { + // Names of the fields produced by the GROUP BY exprs for example, `GROUP BY + // c1 + 1` produces an output field named `"c1 + 1"` + let mut group_by_field_names = group_expr + .iter() + .map(|e| e.display_name()) + .collect::>>()?; + + if let Some(target_indices) = + get_target_functional_dependencies(schema, &group_by_field_names) + { + for idx in target_indices { + let field = schema.field(idx); + let expr = + Expr::Column(Column::new(field.qualifier().cloned(), field.name())); + let expr_name = expr.display_name()?; + if !group_by_field_names.contains(&expr_name) { + group_by_field_names.push(expr_name); + group_expr.push(expr); + } + } + } + Ok(group_expr) +} /// Errors if one or more expressions have equal names. pub(crate) fn validate_unique_names<'a>( node_name: &str, @@ -1287,11 +1339,16 @@ pub fn project( for e in expr { let e = e.into(); match e { - Expr::Wildcard => { + Expr::Wildcard { qualifier: None } => { projected_expr.extend(expand_wildcard(input_schema, &plan, None)?) } - Expr::QualifiedWildcard { ref qualifier } => projected_expr - .extend(expand_qualified_wildcard(qualifier, input_schema, None)?), + Expr::Wildcard { + qualifier: Some(qualifier), + } => projected_expr.extend(expand_qualified_wildcard( + &qualifier, + input_schema, + None, + )?), _ => projected_expr .push(columnize_expr(normalize_col(e, &plan)?, input_schema)), } @@ -1306,7 +1363,7 @@ pub fn subquery_alias( plan: LogicalPlan, alias: impl Into, ) -> Result { - SubqueryAlias::try_new(plan, alias).map(LogicalPlan::SubqueryAlias) + SubqueryAlias::try_new(Arc::new(plan), alias).map(LogicalPlan::SubqueryAlias) } /// Create a LogicalPlanBuilder representing a scan of a table with the provided name and schema. @@ -1473,7 +1530,7 @@ pub fn unnest_with_options( let df_schema = DFSchema::new_with_metadata(fields, metadata)?; // We can use the existing functional dependencies: let deps = input_schema.functional_dependencies().clone(); - let schema = Arc::new(df_schema.with_functional_dependencies(deps)); + let schema = Arc::new(df_schema.with_functional_dependencies(deps)?); Ok(LogicalPlan::Unnest(Unnest { input: Arc::new(input), @@ -1590,7 +1647,7 @@ mod tests { let plan = table_scan(Some("t1"), &employee_schema(), None)? .join_using(t2, JoinType::Inner, vec!["id"])? - .project(vec![Expr::Wildcard])? + .project(vec![Expr::Wildcard { qualifier: None }])? .build()?; // id column should only show up once in projection diff --git a/datafusion/expr/src/logical_plan/ddl.rs b/datafusion/expr/src/logical_plan/ddl.rs index 2c90a3aca7543..e74992d993734 100644 --- a/datafusion/expr/src/logical_plan/ddl.rs +++ b/datafusion/expr/src/logical_plan/ddl.rs @@ -194,6 +194,8 @@ pub struct CreateExternalTable { pub options: HashMap, /// The list of constraints in the schema, such as primary key, unique, etc. pub constraints: Constraints, + /// Default values for columns + pub column_defaults: HashMap, } // Hashing refers to a subset of fields considered in PartialEq. 
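The `ddl.rs` hunks here thread per-column `DEFAULT` expressions through the DDL plans, for both `CreateExternalTable` above and `CreateMemoryTable` in the next hunk. A minimal sketch of the shape of the new field as declared on `CreateMemoryTable`; the statement and column names are purely illustrative:

    use datafusion_expr::{lit, Expr};

    // e.g. for `CREATE TABLE t (a INT, b INT DEFAULT 42)` the planner records:
    fn example_column_defaults() -> Vec<(String, Expr)> {
        vec![("b".to_string(), lit(42))]
    }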
@@ -228,6 +230,8 @@ pub struct CreateMemoryTable { pub if_not_exists: bool, /// Option to replace table content if table already exists pub or_replace: bool, + /// Default values for columns + pub column_defaults: Vec<(String, Expr)>, } /// Creates a view. diff --git a/datafusion/expr/src/logical_plan/mod.rs b/datafusion/expr/src/logical_plan/mod.rs index 8316417138bd1..bc722dd69acea 100644 --- a/datafusion/expr/src/logical_plan/mod.rs +++ b/datafusion/expr/src/logical_plan/mod.rs @@ -33,10 +33,11 @@ pub use ddl::{ }; pub use dml::{DmlStatement, WriteOp}; pub use plan::{ - Aggregate, Analyze, CrossJoin, DescribeTable, Distinct, EmptyRelation, Explain, - Extension, Filter, Join, JoinConstraint, JoinType, Limit, LogicalPlan, Partitioning, - PlanType, Prepare, Projection, Repartition, Sort, StringifiedPlan, Subquery, - SubqueryAlias, TableScan, ToStringifiedPlan, Union, Unnest, Values, Window, + projection_schema, Aggregate, Analyze, CrossJoin, DescribeTable, Distinct, + DistinctOn, EmptyRelation, Explain, Extension, Filter, Join, JoinConstraint, + JoinType, Limit, LogicalPlan, Partitioning, PlanType, Prepare, Projection, + Repartition, Sort, StringifiedPlan, Subquery, SubqueryAlias, TableScan, + ToStringifiedPlan, Union, Unnest, Values, Window, }; pub use statement::{ SetVariable, Statement, TransactionAccessMode, TransactionConclusion, TransactionEnd, diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index d62ac89263288..93a38fb40df58 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -25,19 +25,22 @@ use std::sync::Arc; use super::dml::CopyTo; use super::DdlStatement; use crate::dml::CopyOptions; -use crate::expr::{Alias, Exists, InSubquery, Placeholder}; -use crate::expr_rewriter::create_col_from_scalar_expr; +use crate::expr::{ + Alias, Exists, InSubquery, Placeholder, Sort as SortExpr, WindowFunction, +}; +use crate::expr_rewriter::{create_col_from_scalar_expr, normalize_cols}; use crate::logical_plan::display::{GraphvizVisitor, IndentVisitor}; use crate::logical_plan::extension::UserDefinedLogicalNode; use crate::logical_plan::{DmlStatement, Statement}; use crate::utils::{ enumerate_grouping_sets, exprlist_to_fields, find_out_reference_exprs, grouping_set_expr_count, grouping_set_to_exprlist, inspect_expr_pre, + split_conjunction, }; use crate::{ - build_join_schema, expr_vec_fmt, BinaryExpr, CreateMemoryTable, CreateView, Expr, - ExprSchemable, LogicalPlanBuilder, Operator, TableProviderFilterPushDown, - TableSource, + build_join_schema, expr_vec_fmt, BinaryExpr, BuiltInWindowFunction, + CreateMemoryTable, CreateView, Expr, ExprSchemable, LogicalPlanBuilder, Operator, + TableProviderFilterPushDown, TableSource, WindowFunctionDefinition, }; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; @@ -47,9 +50,10 @@ use datafusion_common::tree_node::{ }; use datafusion_common::{ aggregate_functional_dependencies, internal_err, plan_err, Column, Constraints, - DFField, DFSchema, DFSchemaRef, DataFusionError, FunctionalDependencies, - OwnedTableReference, Result, ScalarValue, UnnestOptions, + DFField, DFSchema, DFSchemaRef, DataFusionError, Dependency, FunctionalDependence, + FunctionalDependencies, OwnedTableReference, ParamValues, Result, UnnestOptions, }; + // backwards compatibility pub use datafusion_common::display::{PlanType, StringifiedPlan, ToStringifiedPlan}; pub use datafusion_common::{JoinConstraint, JoinType}; @@ -163,7 +167,8 @@ impl LogicalPlan { }) => projected_schema, 
LogicalPlan::Projection(Projection { schema, .. }) => schema, LogicalPlan::Filter(Filter { input, .. }) => input.schema(), - LogicalPlan::Distinct(Distinct { input }) => input.schema(), + LogicalPlan::Distinct(Distinct::All(input)) => input.schema(), + LogicalPlan::Distinct(Distinct::On(DistinctOn { schema, .. })) => schema, LogicalPlan::Window(Window { schema, .. }) => schema, LogicalPlan::Aggregate(Aggregate { schema, .. }) => schema, LogicalPlan::Sort(Sort { input, .. }) => input.schema(), @@ -367,6 +372,16 @@ impl LogicalPlan { LogicalPlan::Unnest(Unnest { column, .. }) => { f(&Expr::Column(column.clone())) } + LogicalPlan::Distinct(Distinct::On(DistinctOn { + on_expr, + select_expr, + sort_expr, + .. + })) => on_expr + .iter() + .chain(select_expr.iter()) + .chain(sort_expr.clone().unwrap_or(vec![]).iter()) + .try_for_each(f), // plans without expressions LogicalPlan::EmptyRelation(_) | LogicalPlan::Subquery(_) @@ -377,7 +392,7 @@ impl LogicalPlan { | LogicalPlan::Analyze(_) | LogicalPlan::Explain(_) | LogicalPlan::Union(_) - | LogicalPlan::Distinct(_) + | LogicalPlan::Distinct(Distinct::All(_)) | LogicalPlan::Dml(_) | LogicalPlan::Ddl(_) | LogicalPlan::Copy(_) @@ -405,7 +420,9 @@ impl LogicalPlan { LogicalPlan::Union(Union { inputs, .. }) => { inputs.iter().map(|arc| arc.as_ref()).collect() } - LogicalPlan::Distinct(Distinct { input }) => vec![input], + LogicalPlan::Distinct( + Distinct::All(input) | Distinct::On(DistinctOn { input, .. }), + ) => vec![input], LogicalPlan::Explain(explain) => vec![&explain.plan], LogicalPlan::Analyze(analyze) => vec![&analyze.input], LogicalPlan::Dml(write) => vec![&write.input], @@ -461,8 +478,11 @@ impl LogicalPlan { Ok(Some(agg.group_expr.as_slice()[0].clone())) } } + LogicalPlan::Distinct(Distinct::On(DistinctOn { select_expr, .. })) => { + Ok(Some(select_expr[0].clone())) + } LogicalPlan::Filter(Filter { input, .. }) - | LogicalPlan::Distinct(Distinct { input, .. }) + | LogicalPlan::Distinct(Distinct::All(input)) | LogicalPlan::Sort(Sort { input, .. }) | LogicalPlan::Limit(Limit { input, .. }) | LogicalPlan::Repartition(Repartition { input, .. }) @@ -524,41 +544,9 @@ impl LogicalPlan { } /// Returns a copy of this `LogicalPlan` with the new inputs + #[deprecated(since = "35.0.0", note = "please use `with_new_exprs` instead")] pub fn with_new_inputs(&self, inputs: &[LogicalPlan]) -> Result { - // with_new_inputs use original expression, - // so we don't need to recompute Schema. - match &self { - LogicalPlan::Projection(projection) => { - // Schema of the projection may change - // when its input changes. Hence we should use - // `try_new` method instead of `try_new_with_schema`. - Projection::try_new(projection.expr.to_vec(), Arc::new(inputs[0].clone())) - .map(LogicalPlan::Projection) - } - LogicalPlan::Window(Window { - window_expr, - schema, - .. - }) => Ok(LogicalPlan::Window(Window { - input: Arc::new(inputs[0].clone()), - window_expr: window_expr.to_vec(), - schema: schema.clone(), - })), - LogicalPlan::Aggregate(Aggregate { - group_expr, - aggr_expr, - .. - }) => Aggregate::try_new( - // Schema of the aggregate may change - // when its input changes. Hence we should use - // `try_new` method instead of `try_new_with_schema`. 
- Arc::new(inputs[0].clone()), - group_expr.to_vec(), - aggr_expr.to_vec(), - ) - .map(LogicalPlan::Aggregate), - _ => self.with_new_exprs(self.expressions(), inputs), - } + self.with_new_exprs(self.expressions(), inputs) } /// Returns a new `LogicalPlan` based on `self` with inputs and @@ -580,10 +568,6 @@ impl LogicalPlan { /// // create new plan using rewritten_exprs in same position /// let new_plan = plan.new_with_exprs(rewritten_exprs, new_inputs); /// ``` - /// - /// Note: sometimes [`Self::with_new_exprs`] will use schema of - /// original plan, it will not change the scheam. Such as - /// `Projection/Aggregate/Window` pub fn with_new_exprs( &self, mut expr: Vec, @@ -695,17 +679,10 @@ impl LogicalPlan { })) } }, - LogicalPlan::Window(Window { - window_expr, - schema, - .. - }) => { + LogicalPlan::Window(Window { window_expr, .. }) => { assert_eq!(window_expr.len(), expr.len()); - Ok(LogicalPlan::Window(Window { - input: Arc::new(inputs[0].clone()), - window_expr: expr, - schema: schema.clone(), - })) + Window::try_new(expr, Arc::new(inputs[0].clone())) + .map(LogicalPlan::Window) } LogicalPlan::Aggregate(Aggregate { group_expr, .. }) => { // group exprs are the first expressions @@ -781,7 +758,7 @@ impl LogicalPlan { })) } LogicalPlan::SubqueryAlias(SubqueryAlias { alias, .. }) => { - SubqueryAlias::try_new(inputs[0].clone(), alias.clone()) + SubqueryAlias::try_new(Arc::new(inputs[0].clone()), alias.clone()) .map(LogicalPlan::SubqueryAlias) } LogicalPlan::Limit(Limit { skip, fetch, .. }) => { @@ -795,6 +772,7 @@ impl LogicalPlan { name, if_not_exists, or_replace, + column_defaults, .. })) => Ok(LogicalPlan::Ddl(DdlStatement::CreateMemoryTable( CreateMemoryTable { @@ -803,6 +781,7 @@ impl LogicalPlan { name: name.clone(), if_not_exists: *if_not_exists, or_replace: *or_replace, + column_defaults: column_defaults.clone(), }, ))), LogicalPlan::Ddl(DdlStatement::CreateView(CreateView { @@ -819,15 +798,43 @@ impl LogicalPlan { LogicalPlan::Extension(e) => Ok(LogicalPlan::Extension(Extension { node: e.node.from_template(&expr, inputs), })), - LogicalPlan::Union(Union { schema, .. }) => Ok(LogicalPlan::Union(Union { - inputs: inputs.iter().cloned().map(Arc::new).collect(), - schema: schema.clone(), - })), - LogicalPlan::Distinct(Distinct { .. }) => { - Ok(LogicalPlan::Distinct(Distinct { - input: Arc::new(inputs[0].clone()), + LogicalPlan::Union(Union { schema, .. }) => { + let input_schema = inputs[0].schema(); + // If inputs are not pruned do not change schema. + let schema = if schema.fields().len() == input_schema.fields().len() { + schema + } else { + input_schema + }; + Ok(LogicalPlan::Union(Union { + inputs: inputs.iter().cloned().map(Arc::new).collect(), + schema: schema.clone(), })) } + LogicalPlan::Distinct(distinct) => { + let distinct = match distinct { + Distinct::All(_) => Distinct::All(Arc::new(inputs[0].clone())), + Distinct::On(DistinctOn { + on_expr, + select_expr, + .. + }) => { + let sort_expr = expr.split_off(on_expr.len() + select_expr.len()); + let select_expr = expr.split_off(on_expr.len()); + Distinct::On(DistinctOn::try_new( + expr, + select_expr, + if !sort_expr.is_empty() { + Some(sort_expr) + } else { + None + }, + Arc::new(inputs[0].clone()), + )?) 
+ } + }; + Ok(LogicalPlan::Distinct(distinct)) + } LogicalPlan::Analyze(a) => { assert!(expr.is_empty()); assert_eq!(inputs.len(), 1); @@ -837,19 +844,19 @@ impl LogicalPlan { input: Arc::new(inputs[0].clone()), })) } - LogicalPlan::Explain(_) => { - // Explain should be handled specially in the optimizers; - // If this check cannot pass it means some optimizer pass is - // trying to optimize Explain directly - if expr.is_empty() { - return plan_err!("Invalid EXPLAIN command. Expression is empty"); - } - - if inputs.is_empty() { - return plan_err!("Invalid EXPLAIN command. Inputs are empty"); - } - - Ok(self.clone()) + LogicalPlan::Explain(e) => { + assert!( + expr.is_empty(), + "Invalid EXPLAIN command. Expression should empty" + ); + assert_eq!(inputs.len(), 1, "Invalid EXPLAIN command. Inputs are empty"); + Ok(LogicalPlan::Explain(Explain { + verbose: e.verbose, + plan: Arc::new(inputs[0].clone()), + stringified_plans: e.stringified_plans.clone(), + schema: e.schema.clone(), + logical_optimization_succeeded: e.logical_optimization_succeeded, + })) } LogicalPlan::Prepare(Prepare { name, data_types, .. @@ -905,7 +912,7 @@ impl LogicalPlan { // We can use the existing functional dependencies as is: .with_functional_dependencies( input.schema().functional_dependencies().clone(), - ), + )?, ); Ok(LogicalPlan::Unnest(Unnest { @@ -936,9 +943,10 @@ impl LogicalPlan { /// .filter(col("id").eq(placeholder("$1"))).unwrap() /// .build().unwrap(); /// - /// assert_eq!("Filter: t1.id = $1\ - /// \n TableScan: t1", - /// plan.display_indent().to_string() + /// assert_eq!( + /// "Filter: t1.id = $1\ + /// \n TableScan: t1", + /// plan.display_indent().to_string() /// ); /// /// // Fill in the parameter $1 with a literal 3 @@ -946,39 +954,37 @@ impl LogicalPlan { /// ScalarValue::from(3i32) // value at index 0 --> $1 /// ]).unwrap(); /// - /// assert_eq!("Filter: t1.id = Int32(3)\ - /// \n TableScan: t1", - /// plan.display_indent().to_string() + /// assert_eq!( + /// "Filter: t1.id = Int32(3)\ + /// \n TableScan: t1", + /// plan.display_indent().to_string() /// ); + /// + /// // Note you can also used named parameters + /// // Build SELECT * FROM t1 WHRERE id = $my_param + /// let plan = table_scan(Some("t1"), &schema, None).unwrap() + /// .filter(col("id").eq(placeholder("$my_param"))).unwrap() + /// .build().unwrap() + /// // Fill in the parameter $my_param with a literal 3 + /// .with_param_values(vec![ + /// ("my_param", ScalarValue::from(3i32)), + /// ]).unwrap(); + /// + /// assert_eq!( + /// "Filter: t1.id = Int32(3)\ + /// \n TableScan: t1", + /// plan.display_indent().to_string() + /// ); + /// /// ``` pub fn with_param_values( self, - param_values: Vec, + param_values: impl Into, ) -> Result { + let param_values = param_values.into(); match self { LogicalPlan::Prepare(prepare_lp) => { - // Verify if the number of params matches the number of values - if prepare_lp.data_types.len() != param_values.len() { - return plan_err!( - "Expected {} parameters, got {}", - prepare_lp.data_types.len(), - param_values.len() - ); - } - - // Verify if the types of the params matches the types of the values - let iter = prepare_lp.data_types.iter().zip(param_values.iter()); - for (i, (param_type, value)) in iter.enumerate() { - if *param_type != value.data_type() { - return plan_err!( - "Expected parameter of type {:?}, got {:?} at index {}", - param_type, - value.data_type(), - i - ); - } - } - + param_values.verify(&prepare_lp.data_types)?; let input_plan = prepare_lp.input; 
input_plan.replace_params_with_values(¶m_values) } @@ -993,7 +999,13 @@ impl LogicalPlan { pub fn max_rows(self: &LogicalPlan) -> Option { match self { LogicalPlan::Projection(Projection { input, .. }) => input.max_rows(), - LogicalPlan::Filter(Filter { input, .. }) => input.max_rows(), + LogicalPlan::Filter(filter) => { + if filter.is_scalar() { + Some(1) + } else { + filter.input.max_rows() + } + } LogicalPlan::Window(Window { input, .. }) => input.max_rows(), LogicalPlan::Aggregate(Aggregate { input, group_expr, .. @@ -1064,7 +1076,9 @@ impl LogicalPlan { LogicalPlan::Subquery(_) => None, LogicalPlan::SubqueryAlias(SubqueryAlias { input, .. }) => input.max_rows(), LogicalPlan::Limit(Limit { fetch, .. }) => *fetch, - LogicalPlan::Distinct(Distinct { input }) => input.max_rows(), + LogicalPlan::Distinct( + Distinct::All(input) | Distinct::On(DistinctOn { input, .. }), + ) => input.max_rows(), LogicalPlan::Values(v) => Some(v.values.len()), LogicalPlan::Unnest(_) => None, LogicalPlan::Ddl(_) @@ -1140,7 +1154,7 @@ impl LogicalPlan { /// See [`Self::with_param_values`] for examples and usage pub fn replace_params_with_values( &self, - param_values: &[ScalarValue], + param_values: &ParamValues, ) -> Result { let new_exprs = self .expressions() @@ -1160,7 +1174,7 @@ impl LogicalPlan { self.with_new_exprs(new_exprs, &new_inputs_with_values) } - /// Walk the logical plan, find any `PlaceHolder` tokens, and return a map of their IDs and DataTypes + /// Walk the logical plan, find any `Placeholder` tokens, and return a map of their IDs and DataTypes pub fn get_parameter_types( &self, ) -> Result>, DataFusionError> { @@ -1197,36 +1211,15 @@ impl LogicalPlan { /// corresponding values provided in the params_values fn replace_placeholders_with_values( expr: Expr, - param_values: &[ScalarValue], + param_values: &ParamValues, ) -> Result { expr.transform(&|expr| { match &expr { Expr::Placeholder(Placeholder { id, data_type }) => { - if id.is_empty() || id == "$0" { - return plan_err!("Empty placeholder id"); - } - // convert id (in format $1, $2, ..) to idx (0, 1, ..) - let idx = id[1..].parse::().map_err(|e| { - DataFusionError::Internal(format!( - "Failed to parse placeholder id: {e}" - )) - })? - 1; - // value at the idx-th position in param_values should be the value for the placeholder - let value = param_values.get(idx).ok_or_else(|| { - DataFusionError::Internal(format!( - "No value found for placeholder with id {id}" - )) - })?; - // check if the data type of the value matches the data type of the placeholder - if Some(value.data_type()) != *data_type { - return internal_err!( - "Placeholder value type mismatch: expected {:?}, got {:?}", - data_type, - value.data_type() - ); - } + let value = param_values + .get_placeholders_with_values(id, data_type.as_ref())?; // Replace the placeholder with the value - Ok(Transformed::Yes(Expr::Literal(value.clone()))) + Ok(Transformed::Yes(Expr::Literal(value))) } Expr::ScalarSubquery(qry) => { let subquery = @@ -1667,9 +1660,21 @@ impl LogicalPlan { LogicalPlan::Statement(statement) => { write!(f, "{}", statement.display()) } - LogicalPlan::Distinct(Distinct { .. }) => { - write!(f, "Distinct:") - } + LogicalPlan::Distinct(distinct) => match distinct { + Distinct::All(_) => write!(f, "Distinct:"), + Distinct::On(DistinctOn { + on_expr, + select_expr, + sort_expr, + .. 
+ }) => write!( + f, + "DistinctOn: on_expr=[[{}]], select_expr=[[{}]], sort_expr=[[{}]]", + expr_vec_fmt!(on_expr), + expr_vec_fmt!(select_expr), + if let Some(sort_expr) = sort_expr { expr_vec_fmt!(sort_expr) } else { "".to_string() }, + ), + }, LogicalPlan::Explain { .. } => write!(f, "Explain"), LogicalPlan::Analyze { .. } => write!(f, "Analyze"), LogicalPlan::Union(_) => write!(f, "Union"), @@ -1741,11 +1746,8 @@ pub struct Projection { impl Projection { /// Create a new Projection pub fn try_new(expr: Vec, input: Arc) -> Result { - let schema = Arc::new(DFSchema::new_with_metadata( - exprlist_to_fields(&expr, &input)?, - input.schema().metadata().clone(), - )?); - Self::try_new_with_schema(expr, input, schema) + let projection_schema = projection_schema(&input, &expr)?; + Self::try_new_with_schema(expr, input, projection_schema) } /// Create a new Projection using the specified output schema @@ -1757,11 +1759,6 @@ impl Projection { if expr.len() != schema.fields().len() { return plan_err!("Projection has mismatch between number of expressions ({}) and number of fields in schema ({})", expr.len(), schema.fields().len()); } - // Update functional dependencies of `input` according to projection - // expressions: - let id_key_groups = calc_func_dependencies_for_project(&expr, &input)?; - let schema = schema.as_ref().clone(); - let schema = Arc::new(schema.with_functional_dependencies(id_key_groups)); Ok(Self { expr, input, @@ -1785,6 +1782,30 @@ impl Projection { } } +/// Computes the schema of the result produced by applying a projection to the input logical plan. +/// +/// # Arguments +/// +/// * `input`: A reference to the input `LogicalPlan` for which the projection schema +/// will be computed. +/// * `exprs`: A slice of `Expr` expressions representing the projection operation to apply. +/// +/// # Returns +/// +/// A `Result` containing an `Arc` representing the schema of the result +/// produced by the projection operation. If the schema computation is successful, +/// the `Result` will contain the schema; otherwise, it will contain an error. +pub fn projection_schema(input: &LogicalPlan, exprs: &[Expr]) -> Result> { + let mut schema = DFSchema::new_with_metadata( + exprlist_to_fields(exprs, input)?, + input.schema().metadata().clone(), + )?; + schema = schema.with_functional_dependencies(calc_func_dependencies_for_project( + exprs, input, + )?)?; + Ok(Arc::new(schema)) +} + /// Aliased subquery #[derive(Clone, PartialEq, Eq, Hash)] // mark non_exhaustive to encourage use of try_new/new() @@ -1800,7 +1821,7 @@ pub struct SubqueryAlias { impl SubqueryAlias { pub fn try_new( - plan: LogicalPlan, + plan: Arc, alias: impl Into, ) -> Result { let alias = alias.into(); @@ -1810,10 +1831,10 @@ impl SubqueryAlias { let func_dependencies = plan.schema().functional_dependencies().clone(); let schema = DFSchemaRef::new( DFSchema::try_from_qualified_schema(&alias, &schema)? - .with_functional_dependencies(func_dependencies), + .with_functional_dependencies(func_dependencies)?, ); Ok(SubqueryAlias { - input: Arc::new(plan), + input: plan, alias, schema, }) @@ -1866,6 +1887,73 @@ impl Filter { Ok(Self { predicate, input }) } + + /// Is this filter guaranteed to return 0 or 1 row in a given instantiation? + /// + /// This function will return `true` if its predicate contains a conjunction of + /// `col(a) = `, where its schema has a unique filter that is covered + /// by this conjunction. 
+ /// + /// For example, for the table: + /// ```sql + /// CREATE TABLE t (a INTEGER PRIMARY KEY, b INTEGER); + /// ``` + /// `Filter(a = 2).is_scalar() == true` + /// , whereas + /// `Filter(b = 2).is_scalar() == false` + /// and + /// `Filter(a = 2 OR b = 2).is_scalar() == false` + fn is_scalar(&self) -> bool { + let schema = self.input.schema(); + + let functional_dependencies = self.input.schema().functional_dependencies(); + let unique_keys = functional_dependencies.iter().filter(|dep| { + let nullable = dep.nullable + && dep + .source_indices + .iter() + .any(|&source| schema.field(source).is_nullable()); + !nullable + && dep.mode == Dependency::Single + && dep.target_indices.len() == schema.fields().len() + }); + + let exprs = split_conjunction(&self.predicate); + let eq_pred_cols: HashSet<_> = exprs + .iter() + .filter_map(|expr| { + let Expr::BinaryExpr(BinaryExpr { + left, + op: Operator::Eq, + right, + }) = expr + else { + return None; + }; + // This is a no-op filter expression + if left == right { + return None; + } + + match (left.as_ref(), right.as_ref()) { + (Expr::Column(_), Expr::Column(_)) => None, + (Expr::Column(c), _) | (_, Expr::Column(c)) => { + Some(schema.index_of_column(c).unwrap()) + } + _ => None, + } + }) + .collect(); + + // If we have a functional dependence that is a subset of our predicate, + // this filter is scalar + for key in unique_keys { + if key.source_indices.iter().all(|c| eq_pred_cols.contains(c)) { + return true; + } + } + false + } } /// Window its input based on a set of window spec and window function (e.g. SUM or RANK) @@ -1882,9 +1970,10 @@ pub struct Window { impl Window { /// Create a new window operator. pub fn try_new(window_expr: Vec, input: Arc) -> Result { - let mut window_fields: Vec = input.schema().fields().clone(); - window_fields - .extend_from_slice(&exprlist_to_fields(window_expr.iter(), input.as_ref())?); + let fields = input.schema().fields(); + let input_len = fields.len(); + let mut window_fields = fields.clone(); + window_fields.extend_from_slice(&exprlist_to_fields(window_expr.iter(), &input)?); let metadata = input.schema().metadata().clone(); // Update functional dependencies for window: @@ -1892,12 +1981,52 @@ impl Window { input.schema().functional_dependencies().clone(); window_func_dependencies.extend_target_indices(window_fields.len()); + // Since we know that ROW_NUMBER outputs will be unique (i.e. it consists + // of consecutive numbers per partition), we can represent this fact with + // functional dependencies. + let mut new_dependencies = window_expr + .iter() + .enumerate() + .filter_map(|(idx, expr)| { + if let Expr::WindowFunction(WindowFunction { + // Function is ROW_NUMBER + fun: + WindowFunctionDefinition::BuiltInWindowFunction( + BuiltInWindowFunction::RowNumber, + ), + partition_by, + .. + }) = expr + { + // When there is no PARTITION BY, row number will be unique + // across the entire table. 
+ if partition_by.is_empty() { + return Some(idx + input_len); + } + } + None + }) + .map(|idx| { + FunctionalDependence::new(vec![idx], vec![], false) + .with_mode(Dependency::Single) + }) + .collect::>(); + + if !new_dependencies.is_empty() { + for dependence in new_dependencies.iter_mut() { + dependence.target_indices = (0..window_fields.len()).collect(); + } + // Add the dependency introduced because of ROW_NUMBER window function to the functional dependency + let new_deps = FunctionalDependencies::new(new_dependencies); + window_func_dependencies.extend(new_deps); + } + Ok(Window { input, window_expr, schema: Arc::new( DFSchema::new_with_metadata(window_fields, metadata)? - .with_functional_dependencies(window_func_dependencies), + .with_functional_dependencies(window_func_dependencies)?, ), }) } @@ -1967,7 +2096,7 @@ impl TableScan { .map(|p| { let projected_func_dependencies = func_dependencies.project_functional_dependencies(p, p.len()); - DFSchema::new_with_metadata( + let df_schema = DFSchema::new_with_metadata( p.iter() .map(|i| { DFField::from_qualified( @@ -1977,15 +2106,13 @@ impl TableScan { }) .collect(), schema.metadata().clone(), - ) - .map(|df_schema| { - df_schema.with_functional_dependencies(projected_func_dependencies) - }) + )?; + df_schema.with_functional_dependencies(projected_func_dependencies) }) .unwrap_or_else(|| { - DFSchema::try_from_qualified_schema(table_name.clone(), &schema).map( - |df_schema| df_schema.with_functional_dependencies(func_dependencies), - ) + let df_schema = + DFSchema::try_from_qualified_schema(table_name.clone(), &schema)?; + df_schema.with_functional_dependencies(func_dependencies) })?; let projected_schema = Arc::new(projected_schema); Ok(Self { @@ -2132,9 +2259,93 @@ pub struct Limit { /// Removes duplicate rows from the input #[derive(Clone, PartialEq, Eq, Hash)] -pub struct Distinct { +pub enum Distinct { + /// Plain `DISTINCT` referencing all selection expressions + All(Arc), + /// The `Postgres` addition, allowing separate control over DISTINCT'd and selected columns + On(DistinctOn), +} + +/// Removes duplicate rows from the input +#[derive(Clone, PartialEq, Eq, Hash)] +pub struct DistinctOn { + /// The `DISTINCT ON` clause expression list + pub on_expr: Vec, + /// The selected projection expression list + pub select_expr: Vec, + /// The `ORDER BY` clause, whose initial expressions must match those of the `ON` clause when + /// present. Note that those matching expressions actually wrap the `ON` expressions with + /// additional info pertaining to the sorting procedure (i.e. ASC/DESC, and NULLS FIRST/LAST). + pub sort_expr: Option>, /// The logical plan that is being DISTINCT'd pub input: Arc, + /// The schema description of the DISTINCT ON output + pub schema: DFSchemaRef, +} + +impl DistinctOn { + /// Create a new `DistinctOn` struct. + pub fn try_new( + on_expr: Vec, + select_expr: Vec, + sort_expr: Option>, + input: Arc, + ) -> Result { + if on_expr.is_empty() { + return plan_err!("No `ON` expressions provided"); + } + + let on_expr = normalize_cols(on_expr, input.as_ref())?; + + let schema = DFSchema::new_with_metadata( + exprlist_to_fields(&select_expr, &input)?, + input.schema().metadata().clone(), + )?; + + let mut distinct_on = DistinctOn { + on_expr, + select_expr, + sort_expr: None, + input, + schema: Arc::new(schema), + }; + + if let Some(sort_expr) = sort_expr { + distinct_on = distinct_on.with_sort_expr(sort_expr)?; + } + + Ok(distinct_on) + } + + /// Try to update `self` with a new sort expressions. 
+ /// + /// Validates that the sort expressions are a super-set of the `ON` expressions. + pub fn with_sort_expr(mut self, sort_expr: Vec) -> Result { + let sort_expr = normalize_cols(sort_expr, self.input.as_ref())?; + + // Check that the left-most sort expressions are the same as the `ON` expressions. + let mut matched = true; + for (on, sort) in self.on_expr.iter().zip(sort_expr.iter()) { + match sort { + Expr::Sort(SortExpr { expr, .. }) => { + if on != &**expr { + matched = false; + break; + } + } + _ => return plan_err!("Not a sort expression: {sort}"), + } + } + + if self.on_expr.len() > sort_expr.len() || !matched { + return plan_err!( + "SELECT DISTINCT ON expressions must match initial ORDER BY expressions" + ); + } + + self.sort_expr = Some(sort_expr); + Ok(self) + } } /// Aggregates its input based on a set of grouping and aggregate @@ -2161,13 +2372,25 @@ impl Aggregate { aggr_expr: Vec, ) -> Result { let group_expr = enumerate_grouping_sets(group_expr)?; + + let is_grouping_set = matches!(group_expr.as_slice(), [Expr::GroupingSet(_)]); + let grouping_expr: Vec = grouping_set_to_exprlist(group_expr.as_slice())?; - let all_expr = grouping_expr.iter().chain(aggr_expr.iter()); - let schema = DFSchema::new_with_metadata( - exprlist_to_fields(all_expr, &input)?, - input.schema().metadata().clone(), - )?; + let mut fields = exprlist_to_fields(grouping_expr.iter(), &input)?; + + // Even columns that cannot be null will become nullable when used in a grouping set. + if is_grouping_set { + fields = fields + .into_iter() + .map(|field| field.with_nullable(true)) + .collect::>(); + } + + fields.extend(exprlist_to_fields(aggr_expr.iter(), &input)?); + + let schema = + DFSchema::new_with_metadata(fields, input.schema().metadata().clone())?; Self::try_new_with_schema(input, group_expr, aggr_expr, Arc::new(schema)) } @@ -2201,7 +2424,7 @@ impl Aggregate { calc_func_dependencies_for_aggregate(&group_expr, &input, &schema)?; let new_schema = schema.as_ref().clone(); let schema = Arc::new( - new_schema.with_functional_dependencies(aggregate_func_dependencies), + new_schema.with_functional_dependencies(aggregate_func_dependencies)?, ); Ok(Self { input, @@ -2210,6 +2433,13 @@ impl Aggregate { schema, }) } + + /// Get the length of the group by expression in the output schema + /// This is not simply group by expression length. Expression may be + /// GroupingSet, etc. In these case we need to get inner expression lengths. + pub fn group_expr_len(&self) -> Result { + grouping_set_expr_count(&self.group_expr) + } } /// Checks whether any expression in `group_expr` contains `Expr::GroupingSet`. 
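The validation in `with_sort_expr` above amounts to a prefix check: every `ON` expression must appear, in order, at the front of the `ORDER BY` list, with each sort expression wrapping the matching `ON` expression. A minimal sketch of just that prefix rule, with strings standing in for expressions:

```rust
/// Illustrative check: the `DISTINCT ON` expressions must form a prefix of
/// the `ORDER BY` expressions (both sides modelled here as plain strings).
fn on_exprs_match_order_by(on_expr: &[&str], sort_expr: &[&str]) -> bool {
    on_expr.len() <= sort_expr.len()
        && on_expr.iter().zip(sort_expr).all(|(on, sort)| on == sort)
}

fn main() {
    // SELECT DISTINCT ON (a) ... ORDER BY a, b   => accepted
    assert!(on_exprs_match_order_by(&["a"], &["a", "b"]));

    // SELECT DISTINCT ON (a) ... ORDER BY b, a   => rejected: `a` is not the
    // left-most ORDER BY expression
    assert!(!on_exprs_match_order_by(&["a"], &["b", "a"]));

    // SELECT DISTINCT ON (a, b) ... ORDER BY a   => rejected: ORDER BY is shorter
    assert!(!on_exprs_match_order_by(&["a", "b"], &["a"]));
}
```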
@@ -2404,13 +2634,19 @@ pub struct Unnest { #[cfg(test)] mod tests { + use std::collections::HashMap; + use std::sync::Arc; + use super::*; + use crate::builder::LogicalTableSource; use crate::logical_plan::table_scan; - use crate::{col, exists, in_subquery, lit, placeholder}; + use crate::{col, count, exists, in_subquery, lit, placeholder, GroupingSet}; + use arrow::datatypes::{DataType, Field, Schema}; use datafusion_common::tree_node::TreeNodeVisitor; - use datafusion_common::{not_impl_err, DFSchema, TableReference}; - use std::collections::HashMap; + use datafusion_common::{ + not_impl_err, Constraint, DFSchema, ScalarValue, TableReference, + }; fn employee_schema() -> Schema { Schema::new(vec![ @@ -2857,7 +3093,8 @@ digraph { .build() .unwrap(); - plan.replace_params_with_values(&[42i32.into()]) + let param_values = vec![ScalarValue::Int32(Some(42))]; + plan.replace_params_with_values(¶m_values.clone().into()) .expect_err("unexpectedly succeeded to replace an invalid placeholder"); // test $0 placeholder @@ -2870,7 +3107,154 @@ digraph { .build() .unwrap(); - plan.replace_params_with_values(&[42i32.into()]) + plan.replace_params_with_values(¶m_values.clone().into()) + .expect_err("unexpectedly succeeded to replace an invalid placeholder"); + + // test $00 placeholder + let schema = Schema::new(vec![Field::new("id", DataType::Int32, false)]); + + let plan = table_scan(TableReference::none(), &schema, None) + .unwrap() + .filter(col("id").eq(placeholder("$00"))) + .unwrap() + .build() + .unwrap(); + + plan.replace_params_with_values(¶m_values.into()) .expect_err("unexpectedly succeeded to replace an invalid placeholder"); } + + #[test] + fn test_nullable_schema_after_grouping_set() { + let schema = Schema::new(vec![ + Field::new("foo", DataType::Int32, false), + Field::new("bar", DataType::Int32, false), + ]); + + let plan = table_scan(TableReference::none(), &schema, None) + .unwrap() + .aggregate( + vec![Expr::GroupingSet(GroupingSet::GroupingSets(vec![ + vec![col("foo")], + vec![col("bar")], + ]))], + vec![count(lit(true))], + ) + .unwrap() + .build() + .unwrap(); + + let output_schema = plan.schema(); + + assert!(output_schema + .field_with_name(None, "foo") + .unwrap() + .is_nullable(),); + assert!(output_schema + .field_with_name(None, "bar") + .unwrap() + .is_nullable()); + } + + #[test] + fn test_filter_is_scalar() { + // test empty placeholder + let schema = + Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)])); + + let source = Arc::new(LogicalTableSource::new(schema)); + let schema = Arc::new( + DFSchema::try_from_qualified_schema( + TableReference::bare("tab"), + &source.schema(), + ) + .unwrap(), + ); + let scan = Arc::new(LogicalPlan::TableScan(TableScan { + table_name: TableReference::bare("tab"), + source: source.clone(), + projection: None, + projected_schema: schema.clone(), + filters: vec![], + fetch: None, + })); + let col = schema.field(0).qualified_column(); + + let filter = Filter::try_new( + Expr::Column(col).eq(Expr::Literal(ScalarValue::Int32(Some(1)))), + scan, + ) + .unwrap(); + assert!(!filter.is_scalar()); + let unique_schema = Arc::new( + schema + .as_ref() + .clone() + .with_functional_dependencies( + FunctionalDependencies::new_from_constraints( + Some(&Constraints::new_unverified(vec![Constraint::Unique( + vec![0], + )])), + 1, + ), + ) + .unwrap(), + ); + let scan = Arc::new(LogicalPlan::TableScan(TableScan { + table_name: TableReference::bare("tab"), + source, + projection: None, + projected_schema: unique_schema.clone(), + filters: 
vec![], + fetch: None, + })); + let col = schema.field(0).qualified_column(); + + let filter = Filter::try_new( + Expr::Column(col).eq(Expr::Literal(ScalarValue::Int32(Some(1)))), + scan, + ) + .unwrap(); + assert!(filter.is_scalar()); + } + + #[test] + fn test_transform_explain() { + let schema = Schema::new(vec![ + Field::new("foo", DataType::Int32, false), + Field::new("bar", DataType::Int32, false), + ]); + + let plan = table_scan(TableReference::none(), &schema, None) + .unwrap() + .explain(false, false) + .unwrap() + .build() + .unwrap(); + + let external_filter = + col("foo").eq(Expr::Literal(ScalarValue::Boolean(Some(true)))); + + // after transformation, because plan is not the same anymore, + // the parent plan is built again with call to LogicalPlan::with_new_inputs -> with_new_exprs + let plan = plan + .transform(&|plan| match plan { + LogicalPlan::TableScan(table) => { + let filter = Filter::try_new( + external_filter.clone(), + Arc::new(LogicalPlan::TableScan(table)), + ) + .unwrap(); + Ok(Transformed::Yes(LogicalPlan::Filter(filter))) + } + x => Ok(Transformed::No(x)), + }) + .unwrap(); + + let expected = "Explain\ + \n Filter: foo = Boolean(true)\ + \n TableScan: ?table?"; + let actual = format!("{}", plan.display_indent()); + assert_eq!(expected.to_string(), actual) + } } diff --git a/datafusion/expr/src/signature.rs b/datafusion/expr/src/signature.rs index 685601523f9bb..729131bd95e13 100644 --- a/datafusion/expr/src/signature.rs +++ b/datafusion/expr/src/signature.rs @@ -91,11 +91,14 @@ pub enum TypeSignature { /// DataFusion attempts to coerce all argument types to match the first argument's type /// /// # Examples - /// A function such as `array` is `VariadicEqual` + /// Given types in signature should be coericible to the same final type. + /// A function such as `make_array` is `VariadicEqual`. + /// + /// `make_array(i32, i64) -> make_array(i64, i64)` VariadicEqual, /// One or more arguments with arbitrary types VariadicAny, - /// fixed number of arguments of an arbitrary but equal type out of a list of valid types. + /// Fixed number of arguments of an arbitrary but equal type out of a list of valid types. /// /// # Examples /// 1. A function of one argument of f64 is `Uniform(1, vec![DataType::Float64])` @@ -113,6 +116,16 @@ pub enum TypeSignature { /// Function `make_array` takes 0 or more arguments with arbitrary types, its `TypeSignature` /// is `OneOf(vec![Any(0), VariadicAny])`. OneOf(Vec), + /// Specialized Signature for ArrayAppend and similar functions + /// The first argument should be List/LargeList, and the second argument should be non-list or list. + /// The second argument's list dimension should be one dimension less than the first argument's list dimension. + /// List dimension of the List/LargeList is equivalent to the number of List. + /// List dimension of the non-list is 0. + ArrayAndElement, + /// Specialized Signature for ArrayPrepend and similar functions + /// The first argument should be non-list or list, and the second argument should be List/LargeList. + /// The first argument's list dimension should be one dimension less than the second argument's list dimension. 
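A small sketch of the list-dimension rule described for `ArrayAndElement`/`ElementAndArray`, using a toy type in place of Arrow's `DataType` (the real coercion in `get_valid_types` additionally coerces the base types, which this sketch ignores):

```rust
/// Tiny stand-in for Arrow types: either a scalar or a list of an inner type.
enum Ty {
    Int64,
    List(Box<Ty>),
}

/// List dimension as described above: 0 for a non-list, otherwise
/// 1 + the dimension of the inner type.
fn list_ndims(ty: &Ty) -> usize {
    match ty {
        Ty::List(inner) => 1 + list_ndims(inner),
        _ => 0,
    }
}

/// `ArrayAndElement`-style check: the element must sit exactly one list
/// dimension below the array argument.
fn element_fits_array(array: &Ty, element: &Ty) -> bool {
    list_ndims(array) == list_ndims(element) + 1
}

fn main() {
    let list = Ty::List(Box::new(Ty::Int64)); // List<Int64>, ndims = 1
    let nested = Ty::List(Box::new(Ty::List(Box::new(Ty::Int64)))); // ndims = 2

    assert!(element_fits_array(&list, &Ty::Int64)); // append Int64 to List<Int64>
    assert!(element_fits_array(&nested, &list)); // append List<Int64> to List<List<Int64>>
    assert!(!element_fits_array(&list, &list)); // equal dimensions: rejected here
}
```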
+ ElementAndArray, } impl TypeSignature { @@ -136,11 +149,19 @@ impl TypeSignature { .collect::>() .join(", ")] } - TypeSignature::VariadicEqual => vec!["T, .., T".to_string()], + TypeSignature::VariadicEqual => { + vec!["CoercibleT, .., CoercibleT".to_string()] + } TypeSignature::VariadicAny => vec!["Any, .., Any".to_string()], TypeSignature::OneOf(sigs) => { sigs.iter().flat_map(|s| s.to_string_repr()).collect() } + TypeSignature::ArrayAndElement => { + vec!["ArrayAndElement(List, T)".to_string()] + } + TypeSignature::ElementAndArray => { + vec!["ElementAndArray(T, List)".to_string()] + } } } diff --git a/datafusion/expr/src/struct_expressions.rs b/datafusion/expr/src/struct_expressions.rs deleted file mode 100644 index bbfcac0e2396f..0000000000000 --- a/datafusion/expr/src/struct_expressions.rs +++ /dev/null @@ -1,35 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use arrow::datatypes::DataType; - -/// Currently supported types by the struct function. -pub static SUPPORTED_STRUCT_TYPES: &[DataType] = &[ - DataType::Boolean, - DataType::UInt8, - DataType::UInt16, - DataType::UInt32, - DataType::UInt64, - DataType::Int8, - DataType::Int16, - DataType::Int32, - DataType::Int64, - DataType::Float32, - DataType::Float64, - DataType::Utf8, - DataType::LargeUtf8, -]; diff --git a/datafusion/expr/src/table_source.rs b/datafusion/expr/src/table_source.rs index 94f26d9158cd1..565f48c1c5a9e 100644 --- a/datafusion/expr/src/table_source.rs +++ b/datafusion/expr/src/table_source.rs @@ -103,4 +103,9 @@ pub trait TableSource: Sync + Send { fn get_logical_plan(&self) -> Option<&LogicalPlan> { None } + + /// Get the default value for a column, if available. + fn get_column_default(&self, _column: &str) -> Option<&Expr> { + None + } } diff --git a/datafusion/expr/src/tree_node/expr.rs b/datafusion/expr/src/tree_node/expr.rs index 764dcffbced99..56388be58b8a6 100644 --- a/datafusion/expr/src/tree_node/expr.rs +++ b/datafusion/expr/src/tree_node/expr.rs @@ -18,22 +18,20 @@ //! 
Tree node implementation for logical expr use crate::expr::{ - AggregateFunction, AggregateUDF, Alias, Between, BinaryExpr, Case, Cast, - GetIndexedField, GroupingSet, InList, InSubquery, Like, Placeholder, ScalarFunction, - ScalarUDF, Sort, TryCast, WindowFunction, + AggregateFunction, AggregateFunctionDefinition, Alias, Between, BinaryExpr, Case, + Cast, GetIndexedField, GroupingSet, InList, InSubquery, Like, Placeholder, + ScalarFunction, ScalarFunctionDefinition, Sort, TryCast, WindowFunction, }; use crate::{Expr, GetFieldAccess}; +use std::borrow::Cow; -use datafusion_common::tree_node::{TreeNode, VisitRecursion}; -use datafusion_common::Result; +use datafusion_common::tree_node::TreeNode; +use datafusion_common::{internal_err, DataFusionError, Result}; impl TreeNode for Expr { - fn apply_children(&self, op: &mut F) -> Result - where - F: FnMut(&Self) -> Result, - { - let children = match self { - Expr::Alias(Alias{expr,..}) + fn children_nodes(&self) -> Vec> { + match self { + Expr::Alias(Alias { expr, .. }) | Expr::Not(expr) | Expr::IsNotNull(expr) | Expr::IsTrue(expr) @@ -47,28 +45,26 @@ impl TreeNode for Expr { | Expr::Cast(Cast { expr, .. }) | Expr::TryCast(TryCast { expr, .. }) | Expr::Sort(Sort { expr, .. }) - | Expr::InSubquery(InSubquery{ expr, .. }) => vec![expr.as_ref().clone()], + | Expr::InSubquery(InSubquery { expr, .. }) => vec![Cow::Borrowed(expr)], Expr::GetIndexedField(GetIndexedField { expr, field }) => { - let expr = expr.as_ref().clone(); + let expr = Cow::Borrowed(expr.as_ref()); match field { - GetFieldAccess::ListIndex {key} => { - vec![key.as_ref().clone(), expr] - }, - GetFieldAccess::ListRange {start, stop} => { - vec![start.as_ref().clone(), stop.as_ref().clone(), expr] + GetFieldAccess::ListIndex { key } => { + vec![Cow::Borrowed(key.as_ref()), expr] + } + GetFieldAccess::ListRange { start, stop } => { + vec![Cow::Borrowed(start), Cow::Borrowed(stop), expr] } - GetFieldAccess::NamedStructField {name: _name} => { + GetFieldAccess::NamedStructField { name: _name } => { vec![expr] } } } Expr::GroupingSet(GroupingSet::Rollup(exprs)) - | Expr::GroupingSet(GroupingSet::Cube(exprs)) => exprs.clone(), - Expr::ScalarFunction (ScalarFunction{ args, .. } )| Expr::ScalarUDF(ScalarUDF { args, .. }) => { - args.clone() - } + | Expr::GroupingSet(GroupingSet::Cube(exprs)) => exprs.iter().map(Cow::Borrowed).collect(), + Expr::ScalarFunction(ScalarFunction { args, .. }) => args.iter().map(Cow::Borrowed).collect(), Expr::GroupingSet(GroupingSet::GroupingSets(lists_of_exprs)) => { - lists_of_exprs.clone().into_iter().flatten().collect() + lists_of_exprs.iter().flatten().map(Cow::Borrowed).collect() } Expr::Column(_) // Treat OuterReferenceColumn as a leaf expression @@ -77,46 +73,49 @@ impl TreeNode for Expr { | Expr::Literal(_) | Expr::Exists { .. } | Expr::ScalarSubquery(_) - | Expr::Wildcard - | Expr::QualifiedWildcard { .. } - | Expr::Placeholder (_) => vec![], + | Expr::Wildcard { .. } + | Expr::Placeholder(_) => vec![], Expr::BinaryExpr(BinaryExpr { left, right, .. }) => { - vec![left.as_ref().clone(), right.as_ref().clone()] + vec![Cow::Borrowed(left), Cow::Borrowed(right)] } Expr::Like(Like { expr, pattern, .. }) | Expr::SimilarTo(Like { expr, pattern, .. }) => { - vec![expr.as_ref().clone(), pattern.as_ref().clone()] + vec![Cow::Borrowed(expr), Cow::Borrowed(pattern)] } Expr::Between(Between { expr, low, high, .. 
}) => vec![ - expr.as_ref().clone(), - low.as_ref().clone(), - high.as_ref().clone(), + Cow::Borrowed(expr), + Cow::Borrowed(low), + Cow::Borrowed(high), ], Expr::Case(case) => { let mut expr_vec = vec![]; if let Some(expr) = case.expr.as_ref() { - expr_vec.push(expr.as_ref().clone()); + expr_vec.push(Cow::Borrowed(expr.as_ref())); }; for (when, then) in case.when_then_expr.iter() { - expr_vec.push(when.as_ref().clone()); - expr_vec.push(then.as_ref().clone()); + expr_vec.push(Cow::Borrowed(when)); + expr_vec.push(Cow::Borrowed(then)); } if let Some(else_expr) = case.else_expr.as_ref() { - expr_vec.push(else_expr.as_ref().clone()); + expr_vec.push(Cow::Borrowed(else_expr)); } expr_vec } - Expr::AggregateFunction(AggregateFunction { args, filter, order_by, .. }) - | Expr::AggregateUDF(AggregateUDF { args, filter, order_by, .. }) => { - let mut expr_vec = args.clone(); + Expr::AggregateFunction(AggregateFunction { + args, + filter, + order_by, + .. + }) => { + let mut expr_vec: Vec<_> = args.iter().map(Cow::Borrowed).collect(); if let Some(f) = filter { - expr_vec.push(f.as_ref().clone()); + expr_vec.push(Cow::Borrowed(f)); } if let Some(o) = order_by { - expr_vec.extend(o.clone()); + expr_vec.extend(o.iter().map(Cow::Borrowed).collect::>()); } expr_vec @@ -127,28 +126,17 @@ impl TreeNode for Expr { order_by, .. }) => { - let mut expr_vec = args.clone(); - expr_vec.extend(partition_by.clone()); - expr_vec.extend(order_by.clone()); + let mut expr_vec: Vec<_> = args.iter().map(Cow::Borrowed).collect(); + expr_vec.extend(partition_by.iter().map(Cow::Borrowed).collect::>()); + expr_vec.extend(order_by.iter().map(Cow::Borrowed).collect::>()); expr_vec } Expr::InList(InList { expr, list, .. }) => { - let mut expr_vec = vec![]; - expr_vec.push(expr.as_ref().clone()); - expr_vec.extend(list.clone()); + let mut expr_vec = vec![Cow::Borrowed(expr.as_ref())]; + expr_vec.extend(list.iter().map(Cow::Borrowed).collect::>()); expr_vec } - }; - - for child in children.iter() { - match op(child)? { - VisitRecursion::Continue => {} - VisitRecursion::Skip => return Ok(VisitRecursion::Continue), - VisitRecursion::Stop => return Ok(VisitRecursion::Stop), - } } - - Ok(VisitRecursion::Continue) } fn map_children(self, transform: F) -> Result @@ -158,9 +146,11 @@ impl TreeNode for Expr { let mut transform = transform; Ok(match self { - Expr::Alias(Alias { expr, name, .. }) => { - Expr::Alias(Alias::new(transform(*expr)?, name)) - } + Expr::Alias(Alias { + expr, + relation, + name, + }) => Expr::Alias(Alias::new(transform(*expr)?, relation, name)), Expr::Column(_) => self, Expr::OuterReferenceColumn(_, _) => self, Expr::Exists { .. } => self, @@ -275,12 +265,19 @@ impl TreeNode for Expr { asc, nulls_first, )), - Expr::ScalarFunction(ScalarFunction { args, fun }) => Expr::ScalarFunction( - ScalarFunction::new(fun, transform_vec(args, &mut transform)?), - ), - Expr::ScalarUDF(ScalarUDF { args, fun }) => { - Expr::ScalarUDF(ScalarUDF::new(fun, transform_vec(args, &mut transform)?)) - } + Expr::ScalarFunction(ScalarFunction { func_def, args }) => match func_def { + ScalarFunctionDefinition::BuiltIn(fun) => Expr::ScalarFunction( + ScalarFunction::new(fun, transform_vec(args, &mut transform)?), + ), + ScalarFunctionDefinition::UDF(fun) => Expr::ScalarFunction( + ScalarFunction::new_udf(fun, transform_vec(args, &mut transform)?), + ), + ScalarFunctionDefinition::Name(_) => { + return internal_err!( + "Function `Expr` with name should be resolved." 
+ ); + } + }, Expr::WindowFunction(WindowFunction { args, fun, @@ -296,17 +293,40 @@ impl TreeNode for Expr { )), Expr::AggregateFunction(AggregateFunction { args, - fun, + func_def, distinct, filter, order_by, - }) => Expr::AggregateFunction(AggregateFunction::new( - fun, - transform_vec(args, &mut transform)?, - distinct, - transform_option_box(filter, &mut transform)?, - transform_option_vec(order_by, &mut transform)?, - )), + }) => match func_def { + AggregateFunctionDefinition::BuiltIn(fun) => { + Expr::AggregateFunction(AggregateFunction::new( + fun, + transform_vec(args, &mut transform)?, + distinct, + transform_option_box(filter, &mut transform)?, + transform_option_vec(order_by, &mut transform)?, + )) + } + AggregateFunctionDefinition::UDF(fun) => { + let order_by = if let Some(order_by) = order_by { + Some(transform_vec(order_by, &mut transform)?) + } else { + None + }; + Expr::AggregateFunction(AggregateFunction::new_udf( + fun, + transform_vec(args, &mut transform)?, + false, + transform_option_box(filter, &mut transform)?, + transform_option_vec(order_by, &mut transform)?, + )) + } + AggregateFunctionDefinition::Name(_) => { + return internal_err!( + "Function `Expr` with name should be resolved." + ); + } + }, Expr::GroupingSet(grouping_set) => match grouping_set { GroupingSet::Rollup(exprs) => Expr::GroupingSet(GroupingSet::Rollup( transform_vec(exprs, &mut transform)?, @@ -323,24 +343,7 @@ impl TreeNode for Expr { )) } }, - Expr::AggregateUDF(AggregateUDF { - args, - fun, - filter, - order_by, - }) => { - let order_by = if let Some(order_by) = order_by { - Some(transform_vec(order_by, &mut transform)?) - } else { - None - }; - Expr::AggregateUDF(AggregateUDF::new( - fun, - transform_vec(args, &mut transform)?, - transform_option_box(filter, &mut transform)?, - transform_option_vec(order_by, &mut transform)?, - )) - } + Expr::InList(InList { expr, list, @@ -350,10 +353,7 @@ impl TreeNode for Expr { transform_vec(list, &mut transform)?, negated, )), - Expr::Wildcard => Expr::Wildcard, - Expr::QualifiedWildcard { qualifier } => { - Expr::QualifiedWildcard { qualifier } - } + Expr::Wildcard { qualifier } => Expr::Wildcard { qualifier }, Expr::GetIndexedField(GetIndexedField { expr, field }) => { Expr::GetIndexedField(GetIndexedField::new( transform_boxed(expr, &mut transform)?, diff --git a/datafusion/expr/src/tree_node/plan.rs b/datafusion/expr/src/tree_node/plan.rs index c7621bc178332..208a8b57d7b0a 100644 --- a/datafusion/expr/src/tree_node/plan.rs +++ b/datafusion/expr/src/tree_node/plan.rs @@ -20,8 +20,13 @@ use crate::LogicalPlan; use datafusion_common::tree_node::{TreeNodeVisitor, VisitRecursion}; use datafusion_common::{tree_node::TreeNode, Result}; +use std::borrow::Cow; impl TreeNode for LogicalPlan { + fn children_nodes(&self) -> Vec> { + self.inputs().into_iter().map(Cow::Borrowed).collect() + } + fn apply(&self, op: &mut F) -> Result where F: FnMut(&Self) -> Result, @@ -91,21 +96,6 @@ impl TreeNode for LogicalPlan { visitor.post_visit(self) } - fn apply_children(&self, op: &mut F) -> Result - where - F: FnMut(&Self) -> Result, - { - for child in self.inputs() { - match op(child)? 
{ - VisitRecursion::Continue => {} - VisitRecursion::Skip => return Ok(VisitRecursion::Continue), - VisitRecursion::Stop => return Ok(VisitRecursion::Stop), - } - } - - Ok(VisitRecursion::Continue) - } - fn map_children(self, transform: F) -> Result where F: FnMut(Self) -> Result, @@ -123,7 +113,7 @@ impl TreeNode for LogicalPlan { .zip(new_children.iter()) .any(|(c1, c2)| c1 != &c2) { - self.with_new_inputs(new_children.as_slice()) + self.with_new_exprs(self.expressions(), new_children.as_slice()) } else { Ok(self) } diff --git a/datafusion/expr/src/type_coercion/aggregates.rs b/datafusion/expr/src/type_coercion/aggregates.rs index 261c406d5d5e7..7128b575978a3 100644 --- a/datafusion/expr/src/type_coercion/aggregates.rs +++ b/datafusion/expr/src/type_coercion/aggregates.rs @@ -298,6 +298,23 @@ pub fn coerce_types( | AggregateFunction::FirstValue | AggregateFunction::LastValue => Ok(input_types.to_vec()), AggregateFunction::Grouping => Ok(vec![input_types[0].clone()]), + AggregateFunction::StringAgg => { + if !is_string_agg_supported_arg_type(&input_types[0]) { + return plan_err!( + "The function {:?} does not support inputs of type {:?}", + agg_fun, + input_types[0] + ); + } + if !is_string_agg_supported_arg_type(&input_types[1]) { + return plan_err!( + "The function {:?} does not support inputs of type {:?}", + agg_fun, + input_types[1] + ); + } + Ok(vec![LargeUtf8, input_types[1].clone()]) + } } } @@ -565,6 +582,15 @@ pub fn is_approx_percentile_cont_supported_arg_type(arg_type: &DataType) -> bool ) } +/// Return `true` if `arg_type` is of a [`DataType`] that the +/// [`AggregateFunction::StringAgg`] aggregation can operate on. +pub fn is_string_agg_supported_arg_type(arg_type: &DataType) -> bool { + matches!( + arg_type, + DataType::Utf8 | DataType::LargeUtf8 | DataType::Null + ) +} + #[cfg(test)] mod tests { use super::*; diff --git a/datafusion/expr/src/type_coercion/binary.rs b/datafusion/expr/src/type_coercion/binary.rs index cf93d15e23f0e..1b62c1bc05c16 100644 --- a/datafusion/expr/src/type_coercion/binary.rs +++ b/datafusion/expr/src/type_coercion/binary.rs @@ -116,7 +116,7 @@ fn signature(lhs: &DataType, op: &Operator, rhs: &DataType) -> Result }) } AtArrow | ArrowAt => { - // ArrowAt and AtArrow check for whether one array ic contained in another. + // ArrowAt and AtArrow check for whether one array is contained in another. // The result type is boolean. Signature::comparison defines this signature. 
// Operation has nothing to do with comparison array_coercion(lhs, rhs).map(Signature::comparison).ok_or_else(|| { @@ -331,26 +331,27 @@ fn string_temporal_coercion( rhs_type: &DataType, ) -> Option { use arrow::datatypes::DataType::*; - match (lhs_type, rhs_type) { - (Utf8, Date32) | (Date32, Utf8) => Some(Date32), - (Utf8, Date64) | (Date64, Utf8) => Some(Date64), - (Utf8, Time32(unit)) | (Time32(unit), Utf8) => { - match is_time_with_valid_unit(Time32(unit.clone())) { - false => None, - true => Some(Time32(unit.clone())), - } - } - (Utf8, Time64(unit)) | (Time64(unit), Utf8) => { - match is_time_with_valid_unit(Time64(unit.clone())) { - false => None, - true => Some(Time64(unit.clone())), - } - } - (Timestamp(_, tz), Utf8) | (Utf8, Timestamp(_, tz)) => { - Some(Timestamp(TimeUnit::Nanosecond, tz.clone())) + + fn match_rule(l: &DataType, r: &DataType) -> Option { + match (l, r) { + // Coerce Utf8/LargeUtf8 to Date32/Date64/Time32/Time64/Timestamp + (Utf8, temporal) | (LargeUtf8, temporal) => match temporal { + Date32 | Date64 => Some(temporal.clone()), + Time32(_) | Time64(_) => { + if is_time_with_valid_unit(temporal.to_owned()) { + Some(temporal.to_owned()) + } else { + None + } + } + Timestamp(_, tz) => Some(Timestamp(TimeUnit::Nanosecond, tz.clone())), + _ => None, + }, + _ => None, } - _ => None, } + + match_rule(lhs_type, rhs_type).or_else(|| match_rule(rhs_type, lhs_type)) } /// Coerce `lhs_type` and `rhs_type` to a common type for the purposes of a comparison operation @@ -782,9 +783,14 @@ fn temporal_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option Some(Interval(MonthDayNano)), (Date64, Date32) | (Date32, Date64) => Some(Date64), + (Timestamp(_, None), Date64) | (Date64, Timestamp(_, None)) => { + Some(Timestamp(Nanosecond, None)) + } + (Timestamp(_, _tz), Date64) | (Date64, Timestamp(_, _tz)) => { + Some(Timestamp(Nanosecond, None)) + } (Timestamp(_, None), Date32) | (Date32, Timestamp(_, None)) => { Some(Timestamp(Nanosecond, None)) } diff --git a/datafusion/expr/src/type_coercion/functions.rs b/datafusion/expr/src/type_coercion/functions.rs index b49bf37d6754d..63908d539bd01 100644 --- a/datafusion/expr/src/type_coercion/functions.rs +++ b/datafusion/expr/src/type_coercion/functions.rs @@ -21,7 +21,10 @@ use arrow::{ compute::can_cast_types, datatypes::{DataType, TimeUnit}, }; -use datafusion_common::{plan_err, DataFusionError, Result}; +use datafusion_common::utils::list_ndims; +use datafusion_common::{internal_err, plan_err, DataFusionError, Result}; + +use super::binary::comparison_coercion; /// Performs type coercion for function arguments. /// @@ -35,8 +38,17 @@ pub fn data_types( signature: &Signature, ) -> Result> { if current_types.is_empty() { - return Ok(vec![]); + if signature.type_signature.supports_zero_argument() { + return Ok(vec![]); + } else { + return plan_err!( + "Coercion from {:?} to the signature {:?} failed.", + current_types, + &signature.type_signature + ); + } } + let valid_types = get_valid_types(&signature.type_signature, current_types)?; if valid_types @@ -67,6 +79,55 @@ fn get_valid_types( signature: &TypeSignature, current_types: &[DataType], ) -> Result>> { + fn array_append_or_prepend_valid_types( + current_types: &[DataType], + is_append: bool, + ) -> Result>> { + if current_types.len() != 2 { + return Ok(vec![vec![]]); + } + + let (array_type, elem_type) = if is_append { + (¤t_types[0], ¤t_types[1]) + } else { + (¤t_types[1], ¤t_types[0]) + }; + + // We follow Postgres on `array_append(Null, T)`, which is not valid. 
+ if array_type.eq(&DataType::Null) { + return Ok(vec![vec![]]); + } + + // We need to find the coerced base type, mainly for cases like: + // `array_append(List(null), i64)` -> `List(i64)` + let array_base_type = datafusion_common::utils::base_type(array_type); + let elem_base_type = datafusion_common::utils::base_type(elem_type); + let new_base_type = comparison_coercion(&array_base_type, &elem_base_type); + + if new_base_type.is_none() { + return internal_err!( + "Coercion from {array_base_type:?} to {elem_base_type:?} not supported." + ); + } + let new_base_type = new_base_type.unwrap(); + + let array_type = datafusion_common::utils::coerced_type_with_base_type_only( + array_type, + &new_base_type, + ); + + match array_type { + DataType::List(ref field) | DataType::LargeList(ref field) => { + let elem_type = field.data_type(); + if is_append { + Ok(vec![vec![array_type.clone(), elem_type.to_owned()]]) + } else { + Ok(vec![vec![elem_type.to_owned(), array_type.clone()]]) + } + } + _ => Ok(vec![vec![]]), + } + } let valid_types = match signature { TypeSignature::Variadic(valid_types) => valid_types .iter() @@ -77,16 +138,34 @@ fn get_valid_types( .map(|valid_type| (0..*number).map(|_| valid_type.clone()).collect()) .collect(), TypeSignature::VariadicEqual => { - // one entry with the same len as current_types, whose type is `current_types[0]`. - vec![current_types - .iter() - .map(|_| current_types[0].clone()) - .collect()] + let new_type = current_types.iter().skip(1).try_fold( + current_types.first().unwrap().clone(), + |acc, x| { + let coerced_type = comparison_coercion(&acc, x); + if let Some(coerced_type) = coerced_type { + Ok(coerced_type) + } else { + internal_err!("Coercion from {acc:?} to {x:?} failed.") + } + }, + ); + + match new_type { + Ok(new_type) => vec![vec![new_type; current_types.len()]], + Err(e) => return Err(e), + } } TypeSignature::VariadicAny => { vec![current_types.to_vec()] } + TypeSignature::Exact(valid_types) => vec![valid_types.clone()], + TypeSignature::ArrayAndElement => { + return array_append_or_prepend_valid_types(current_types, true) + } + TypeSignature::ElementAndArray => { + return array_append_or_prepend_valid_types(current_types, false) + } TypeSignature::Any(number) => { if current_types.len() != *number { return plan_err!( @@ -232,6 +311,15 @@ fn coerced_from<'a>( Utf8 | LargeUtf8 => Some(type_into.clone()), Null if can_cast_types(type_from, type_into) => Some(type_into.clone()), + // Only accept list and largelist with the same number of dimensions unless the type is Null. + // List or LargeList with different dimensions should be handled in TypeSignature or other places before this. + List(_) | LargeList(_) + if datafusion_common::utils::base_type(type_from).eq(&Null) + || list_ndims(type_from) == list_ndims(type_into) => + { + Some(type_into.clone()) + } + Timestamp(unit, Some(tz)) if tz.as_ref() == TIMEZONE_WILDCARD => { match type_from { Timestamp(_, Some(from_tz)) => { diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs index 84e238a1215b2..cfbca4ab1337a 100644 --- a/datafusion/expr/src/udaf.rs +++ b/datafusion/expr/src/udaf.rs @@ -15,12 +15,14 @@ // specific language governing permissions and limitations // under the License. -//! Udaf module contains functions and structs supporting user-defined aggregate functions. +//! 
[`AggregateUDF`]: User Defined Aggregate Functions -use crate::Expr; +use crate::{Accumulator, Expr}; use crate::{ AccumulatorFactoryFunction, ReturnTypeFunction, Signature, StateTypeFunction, }; +use arrow::datatypes::DataType; +use datafusion_common::Result; use std::fmt::{self, Debug, Formatter}; use std::sync::Arc; @@ -46,15 +48,15 @@ use std::sync::Arc; #[derive(Clone)] pub struct AggregateUDF { /// name - pub name: String, + name: String, /// Signature (input arguments) - pub signature: Signature, + signature: Signature, /// Return type - pub return_type: ReturnTypeFunction, + return_type: ReturnTypeFunction, /// actual implementation - pub accumulator: AccumulatorFactoryFunction, + accumulator: AccumulatorFactoryFunction, /// the accumulator's state's description as a function of the return type - pub state_type: StateTypeFunction, + state_type: StateTypeFunction, } impl Debug for AggregateUDF { @@ -105,11 +107,43 @@ impl AggregateUDF { /// This utility allows using the UDAF without requiring access to /// the registry, such as with the DataFrame API. pub fn call(&self, args: Vec) -> Expr { - Expr::AggregateUDF(crate::expr::AggregateUDF { - fun: Arc::new(self.clone()), + Expr::AggregateFunction(crate::expr::AggregateFunction::new_udf( + Arc::new(self.clone()), args, - filter: None, - order_by: None, - }) + false, + None, + None, + )) + } + + /// Returns this function's name + pub fn name(&self) -> &str { + &self.name + } + + /// Returns this function's signature (what input types are accepted) + pub fn signature(&self) -> &Signature { + &self.signature + } + + /// Return the type of the function given its input types + pub fn return_type(&self, args: &[DataType]) -> Result { + // Old API returns an Arc of the datatype for some reason + let res = (self.return_type)(args)?; + Ok(res.as_ref().clone()) + } + + /// Return an accumualator the given aggregate, given + /// its return datatype. + pub fn accumulator(&self, return_type: &DataType) -> Result> { + (self.accumulator)(return_type) + } + + /// Return the type of the intermediate state used by this aggregator, given + /// its return datatype. Supports multi-phase aggregations + pub fn state_type(&self, return_type: &DataType) -> Result> { + // old API returns an Arc for some reason, try and unwrap it here + let res = (self.state_type)(return_type)?; + Ok(Arc::try_unwrap(res).unwrap_or_else(|res| res.as_ref().clone())) } } diff --git a/datafusion/expr/src/udf.rs b/datafusion/expr/src/udf.rs index be6c90aa5985d..2ec80a4a9ea1c 100644 --- a/datafusion/expr/src/udf.rs +++ b/datafusion/expr/src/udf.rs @@ -15,23 +15,42 @@ // specific language governing permissions and limitations // under the License. -//! Udf module contains foundational types that are used to represent UDFs in DataFusion. +//! [`ScalarUDF`]: Scalar User Defined Functions -use crate::{Expr, ReturnTypeFunction, ScalarFunctionImplementation, Signature}; +use crate::{ + ColumnarValue, Expr, ReturnTypeFunction, ScalarFunctionImplementation, Signature, +}; +use arrow::datatypes::DataType; +use datafusion_common::Result; +use std::any::Any; use std::fmt; use std::fmt::Debug; use std::fmt::Formatter; use std::sync::Arc; -/// Logical representation of a UDF. +/// Logical representation of a Scalar User Defined Function. +/// +/// A scalar function produces a single row output for each row of input. This +/// struct contains the information DataFusion needs to plan and invoke +/// functions you supply such name, type signature, return type, and actual +/// implementation. 
+/// +/// +/// 1. For simple (less performant) use cases, use [`create_udf`] and [`simple_udf.rs`]. +/// +/// 2. For advanced use cases, use [`ScalarUDFImpl`] and [`advanced_udf.rs`]. +/// +/// [`create_udf`]: crate::expr_fn::create_udf +/// [`simple_udf.rs`]: https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/simple_udf.rs +/// [`advanced_udf.rs`]: https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/advanced_udf.rs #[derive(Clone)] pub struct ScalarUDF { - /// name - pub name: String, - /// signature - pub signature: Signature, - /// Return type - pub return_type: ReturnTypeFunction, + /// The name of the function + name: String, + /// The signature (the types of arguments that are supported) + signature: Signature, + /// Function that returns the return type given the argument types + return_type: ReturnTypeFunction, /// actual implementation /// /// The fn param is the wrapped function but be aware that the function will @@ -40,7 +59,9 @@ pub struct ScalarUDF { /// will be passed. In that case the single element is a null array to indicate /// the batch's row count (so that the generative zero-argument function can know /// the result array size). - pub fun: ScalarFunctionImplementation, + fun: ScalarFunctionImplementation, + /// Optional aliases for the function. This list should NOT include the value of `name` as well + aliases: Vec, } impl Debug for ScalarUDF { @@ -69,7 +90,11 @@ impl std::hash::Hash for ScalarUDF { } impl ScalarUDF { - /// Create a new ScalarUDF + /// Create a new ScalarUDF from low level details. + /// + /// See [`ScalarUDFImpl`] for a more convenient way to create a + /// `ScalarUDF` using trait objects + #[deprecated(since = "34.0.0", note = "please implement ScalarUDFImpl instead")] pub fn new( name: &str, signature: &Signature, @@ -81,12 +106,189 @@ impl ScalarUDF { signature: signature.clone(), return_type: return_type.clone(), fun: fun.clone(), + aliases: vec![], + } + } + + /// Create a new `ScalarUDF` from a `[ScalarUDFImpl]` trait object + /// + /// Note this is the same as using the `From` impl (`ScalarUDF::from`) + pub fn new_from_impl(fun: F) -> ScalarUDF + where + F: ScalarUDFImpl + Send + Sync + 'static, + { + // TODO change the internal implementation to use the trait object + let arc_fun = Arc::new(fun); + let captured_self = arc_fun.clone(); + let return_type: ReturnTypeFunction = Arc::new(move |arg_types| { + let return_type = captured_self.return_type(arg_types)?; + Ok(Arc::new(return_type)) + }); + + let captured_self = arc_fun.clone(); + let func: ScalarFunctionImplementation = + Arc::new(move |args| captured_self.invoke(args)); + + Self { + name: arc_fun.name().to_string(), + signature: arc_fun.signature().clone(), + return_type: return_type.clone(), + fun: func, + aliases: arc_fun.aliases().to_vec(), } } - /// creates a logical expression with a call of the UDF + /// Adds additional names that can be used to invoke this function, in addition to `name` + pub fn with_aliases( + mut self, + aliases: impl IntoIterator, + ) -> Self { + self.aliases + .extend(aliases.into_iter().map(|s| s.to_string())); + self + } + + /// Returns a [`Expr`] logical expression to call this UDF with specified + /// arguments. + /// /// This utility allows using the UDF without requiring access to the registry. 
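The `new_from_impl` constructor above bridges a trait implementation onto the existing closure-based fields by putting the implementation behind an `Arc` and capturing clones of that `Arc` in each stored closure. A self-contained sketch of the same adapter pattern, with toy `FnImpl`/`FnHandle` types rather than the DataFusion ones:

```rust
use std::sync::Arc;

/// Toy counterpart of `ScalarUDFImpl`: one trait object supplies the
/// behaviour that the struct below stores as separate closures.
trait FnImpl {
    fn name(&self) -> &str;
    fn invoke(&self, arg: i64) -> i64;
}

/// Toy counterpart of `ScalarUDF`: behaviour is held as an `Arc`'d closure,
/// so an older closure-based constructor could keep working unchanged.
struct FnHandle {
    name: String,
    fun: Arc<dyn Fn(i64) -> i64 + Send + Sync>,
}

impl FnHandle {
    /// Adapter in the spirit of `new_from_impl`: wrap the impl in an `Arc`,
    /// then capture a clone of that `Arc` in the stored closure.
    fn new_from_impl<F: FnImpl + Send + Sync + 'static>(fun: F) -> Self {
        let arc_fun = Arc::new(fun);
        let captured = Arc::clone(&arc_fun);
        Self {
            name: arc_fun.name().to_string(),
            fun: Arc::new(move |arg| captured.invoke(arg)),
        }
    }
}

struct AddOne;

impl FnImpl for AddOne {
    fn name(&self) -> &str {
        "add_one"
    }
    fn invoke(&self, arg: i64) -> i64 {
        arg + 1
    }
}

fn main() {
    let handle = FnHandle::new_from_impl(AddOne);
    assert_eq!(handle.name, "add_one");

    let call = handle.fun.as_ref();
    assert_eq!(call(41), 42);
}
```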
pub fn call(&self, args: Vec) -> Expr { - Expr::ScalarUDF(crate::expr::ScalarUDF::new(Arc::new(self.clone()), args)) + Expr::ScalarFunction(crate::expr::ScalarFunction::new_udf( + Arc::new(self.clone()), + args, + )) + } + + /// Returns this function's name + pub fn name(&self) -> &str { + &self.name + } + + /// Returns the aliases for this function. See [`ScalarUDF::with_aliases`] for more details + pub fn aliases(&self) -> &[String] { + &self.aliases + } + + /// Returns this function's [`Signature`] (what input types are accepted) + pub fn signature(&self) -> &Signature { + &self.signature + } + + /// The datatype this function returns given the input argument input types + pub fn return_type(&self, args: &[DataType]) -> Result { + // Old API returns an Arc of the datatype for some reason + let res = (self.return_type)(args)?; + Ok(res.as_ref().clone()) + } + + /// Return an [`Arc`] to the function implementation + pub fn fun(&self) -> ScalarFunctionImplementation { + self.fun.clone() + } +} + +impl From for ScalarUDF +where + F: ScalarUDFImpl + Send + Sync + 'static, +{ + fn from(fun: F) -> Self { + Self::new_from_impl(fun) + } +} + +/// Trait for implementing [`ScalarUDF`]. +/// +/// This trait exposes the full API for implementing user defined functions and +/// can be used to implement any function. +/// +/// See [`advanced_udf.rs`] for a full example with complete implementation and +/// [`ScalarUDF`] for other available options. +/// +/// +/// [`advanced_udf.rs`]: https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/advanced_udf.rs +/// # Basic Example +/// ``` +/// # use std::any::Any; +/// # use arrow::datatypes::DataType; +/// # use datafusion_common::{DataFusionError, plan_err, Result}; +/// # use datafusion_expr::{col, ColumnarValue, Signature, Volatility}; +/// # use datafusion_expr::{ScalarUDFImpl, ScalarUDF}; +/// struct AddOne { +/// signature: Signature +/// }; +/// +/// impl AddOne { +/// fn new() -> Self { +/// Self { +/// signature: Signature::uniform(1, vec![DataType::Int32], Volatility::Immutable) +/// } +/// } +/// } +/// +/// /// Implement the ScalarUDFImpl trait for AddOne +/// impl ScalarUDFImpl for AddOne { +/// fn as_any(&self) -> &dyn Any { self } +/// fn name(&self) -> &str { "add_one" } +/// fn signature(&self) -> &Signature { &self.signature } +/// fn return_type(&self, args: &[DataType]) -> Result { +/// if !matches!(args.get(0), Some(&DataType::Int32)) { +/// return plan_err!("add_one only accepts Int32 arguments"); +/// } +/// Ok(DataType::Int32) +/// } +/// // The actual implementation would add one to the argument +/// fn invoke(&self, args: &[ColumnarValue]) -> Result { unimplemented!() } +/// } +/// +/// // Create a new ScalarUDF from the implementation +/// let add_one = ScalarUDF::from(AddOne::new()); +/// +/// // Call the function `add_one(col)` +/// let expr = add_one.call(vec![col("a")]); +/// ``` +pub trait ScalarUDFImpl { + /// Returns this object as an [`Any`] trait object + fn as_any(&self) -> &dyn Any; + + /// Returns this function's name + fn name(&self) -> &str; + + /// Returns the function's [`Signature`] for information about what input + /// types are accepted and the function's Volatility. 
+ fn signature(&self) -> &Signature; + + /// What [`DataType`] will be returned by this function, given the types of + /// the arguments + fn return_type(&self, arg_types: &[DataType]) -> Result; + + /// Invoke the function on `args`, returning the appropriate result + /// + /// The function will be invoked passed with the slice of [`ColumnarValue`] + /// (either scalar or array). + /// + /// # Zero Argument Functions + /// If the function has zero parameters (e.g. `now()`) it will be passed a + /// single element slice which is a a null array to indicate the batch's row + /// count (so the function can know the resulting array size). + /// + /// # Performance + /// + /// For the best performance, the implementations of `invoke` should handle + /// the common case when one or more of their arguments are constant values + /// (aka [`ColumnarValue::Scalar`]). Calling [`ColumnarValue::into_array`] + /// and treating all arguments as arrays will work, but will be slower. + fn invoke(&self, args: &[ColumnarValue]) -> Result; + + /// Returns any aliases (alternate names) for this function. + /// + /// Aliases can be used to invoke the same function using different names. + /// For example in some databases `now()` and `current_timestamp()` are + /// aliases for the same function. This behavior can be obtained by + /// returning `current_timestamp` as an alias for the `now` function. + /// + /// Note: `aliases` should only include names other than [`Self::name`]. + /// Defaults to `[]` (no aliases) + fn aliases(&self) -> &[String] { + &[] } } diff --git a/datafusion/expr/src/udwf.rs b/datafusion/expr/src/udwf.rs index c0a2a8205a080..800386bfc77b8 100644 --- a/datafusion/expr/src/udwf.rs +++ b/datafusion/expr/src/udwf.rs @@ -15,17 +15,20 @@ // specific language governing permissions and limitations // under the License. -//! Support for user-defined window (UDWF) window functions +//! [`WindowUDF`]: User Defined Window Functions +use crate::{ + Expr, PartitionEvaluator, PartitionEvaluatorFactory, ReturnTypeFunction, Signature, + WindowFrame, +}; +use arrow::datatypes::DataType; +use datafusion_common::Result; use std::{ + any::Any, fmt::{self, Debug, Display, Formatter}, sync::Arc, }; -use crate::{ - Expr, PartitionEvaluatorFactory, ReturnTypeFunction, Signature, WindowFrame, -}; - /// Logical representation of a user-defined window function (UDWF) /// A UDWF is different from a UDF in that it is stateful across batches. /// @@ -35,13 +38,13 @@ use crate::{ #[derive(Clone)] pub struct WindowUDF { /// name - pub name: String, + name: String, /// signature - pub signature: Signature, + signature: Signature, /// Return type - pub return_type: ReturnTypeFunction, + return_type: ReturnTypeFunction, /// Return the partition evaluator - pub partition_evaluator_factory: PartitionEvaluatorFactory, + partition_evaluator_factory: PartitionEvaluatorFactory, } impl Debug for WindowUDF { @@ -78,7 +81,11 @@ impl std::hash::Hash for WindowUDF { } impl WindowUDF { - /// Create a new WindowUDF + /// Create a new WindowUDF from low level details. 
+ /// + /// See [`WindowUDFImpl`] for a more convenient way to create a + /// `WindowUDF` using trait objects + #[deprecated(since = "34.0.0", note = "please implement ScalarUDFImpl instead")] pub fn new( name: &str, signature: &Signature, @@ -86,13 +93,39 @@ impl WindowUDF { partition_evaluator_factory: &PartitionEvaluatorFactory, ) -> Self { Self { - name: name.to_owned(), + name: name.to_string(), signature: signature.clone(), return_type: return_type.clone(), partition_evaluator_factory: partition_evaluator_factory.clone(), } } + /// Create a new `WindowUDF` from a `[WindowUDFImpl]` trait object + /// + /// Note this is the same as using the `From` impl (`WindowUDF::from`) + pub fn new_from_impl(fun: F) -> WindowUDF + where + F: WindowUDFImpl + Send + Sync + 'static, + { + let arc_fun = Arc::new(fun); + let captured_self = arc_fun.clone(); + let return_type: ReturnTypeFunction = Arc::new(move |arg_types| { + let return_type = captured_self.return_type(arg_types)?; + Ok(Arc::new(return_type)) + }); + + let captured_self = arc_fun.clone(); + let partition_evaluator_factory: PartitionEvaluatorFactory = + Arc::new(move || captured_self.partition_evaluator()); + + Self { + name: arc_fun.name().to_string(), + signature: arc_fun.signature().clone(), + return_type: return_type.clone(), + partition_evaluator_factory, + } + } + /// creates a [`Expr`] that calls the window function given /// the `partition_by`, `order_by`, and `window_frame` definition /// @@ -105,7 +138,7 @@ impl WindowUDF { order_by: Vec, window_frame: WindowFrame, ) -> Expr { - let fun = crate::WindowFunction::WindowUDF(Arc::new(self.clone())); + let fun = crate::WindowFunctionDefinition::WindowUDF(Arc::new(self.clone())); Expr::WindowFunction(crate::expr::WindowFunction { fun, @@ -115,4 +148,109 @@ impl WindowUDF { window_frame, }) } + + /// Returns this function's name + pub fn name(&self) -> &str { + &self.name + } + + /// Returns this function's signature (what input types are accepted) + pub fn signature(&self) -> &Signature { + &self.signature + } + + /// Return the type of the function given its input types + pub fn return_type(&self, args: &[DataType]) -> Result { + // Old API returns an Arc of the datatype for some reason + let res = (self.return_type)(args)?; + Ok(res.as_ref().clone()) + } + + /// Return a `PartitionEvaluator` for evaluating this window function + pub fn partition_evaluator_factory(&self) -> Result> { + (self.partition_evaluator_factory)() + } +} + +impl From for WindowUDF +where + F: WindowUDFImpl + Send + Sync + 'static, +{ + fn from(fun: F) -> Self { + Self::new_from_impl(fun) + } +} + +/// Trait for implementing [`WindowUDF`]. +/// +/// This trait exposes the full API for implementing user defined window functions and +/// can be used to implement any function. +/// +/// See [`advanced_udwf.rs`] for a full example with complete implementation and +/// [`WindowUDF`] for other available options. 
+/// +/// +/// [`advanced_udwf.rs`]: https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/advanced_udwf.rs +/// # Basic Example +/// ``` +/// # use std::any::Any; +/// # use arrow::datatypes::DataType; +/// # use datafusion_common::{DataFusionError, plan_err, Result}; +/// # use datafusion_expr::{col, Signature, Volatility, PartitionEvaluator, WindowFrame}; +/// # use datafusion_expr::{WindowUDFImpl, WindowUDF}; +/// struct SmoothIt { +/// signature: Signature +/// }; +/// +/// impl SmoothIt { +/// fn new() -> Self { +/// Self { +/// signature: Signature::uniform(1, vec![DataType::Int32], Volatility::Immutable) +/// } +/// } +/// } +/// +/// /// Implement the WindowUDFImpl trait for AddOne +/// impl WindowUDFImpl for SmoothIt { +/// fn as_any(&self) -> &dyn Any { self } +/// fn name(&self) -> &str { "smooth_it" } +/// fn signature(&self) -> &Signature { &self.signature } +/// fn return_type(&self, args: &[DataType]) -> Result { +/// if !matches!(args.get(0), Some(&DataType::Int32)) { +/// return plan_err!("smooth_it only accepts Int32 arguments"); +/// } +/// Ok(DataType::Int32) +/// } +/// // The actual implementation would add one to the argument +/// fn partition_evaluator(&self) -> Result> { unimplemented!() } +/// } +/// +/// // Create a new ScalarUDF from the implementation +/// let smooth_it = WindowUDF::from(SmoothIt::new()); +/// +/// // Call the function `add_one(col)` +/// let expr = smooth_it.call( +/// vec![col("speed")], // smooth_it(speed) +/// vec![col("car")], // PARTITION BY car +/// vec![col("time").sort(true, true)], // ORDER BY time ASC +/// WindowFrame::new(false), +/// ); +/// ``` +pub trait WindowUDFImpl { + /// Returns this object as an [`Any`] trait object + fn as_any(&self) -> &dyn Any; + + /// Returns this function's name + fn name(&self) -> &str; + + /// Returns the function's [`Signature`] for information about what input + /// types are accepted and the function's Volatility. + fn signature(&self) -> &Signature; + + /// What [`DataType`] will be returned by this function, given the types of + /// the arguments + fn return_type(&self, arg_types: &[DataType]) -> Result; + + /// Invoke the function, returning the [`PartitionEvaluator`] instance + fn partition_evaluator(&self) -> Result>; } diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index 5fc5b5b3f9c77..914b354d29505 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -17,19 +17,28 @@ //! 
Expression utilities +use std::cmp::Ordering; +use std::collections::HashSet; +use std::sync::Arc; + use crate::expr::{Alias, Sort, WindowFunction}; +use crate::expr_rewriter::strip_outer_reference; use crate::logical_plan::Aggregate; use crate::signature::{Signature, TypeSignature}; -use crate::{Cast, Expr, ExprSchemable, GroupingSet, LogicalPlan, TryCast}; +use crate::{ + and, BinaryExpr, Cast, Expr, ExprSchemable, Filter, GroupingSet, LogicalPlan, + Operator, TryCast, +}; + use arrow::datatypes::{DataType, TimeUnit}; use datafusion_common::tree_node::{TreeNode, VisitRecursion}; +use datafusion_common::utils::get_at_indices; use datafusion_common::{ internal_err, plan_datafusion_err, plan_err, Column, DFField, DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue, TableReference, }; + use sqlparser::ast::{ExceptSelectItem, ExcludeSelectItem, WildcardAdditionalOptions}; -use std::cmp::Ordering; -use std::collections::HashSet; /// The value to which `COUNT(*)` is expanded to in /// `COUNT()` expressions @@ -283,17 +292,14 @@ pub fn expr_to_columns(expr: &Expr, accum: &mut HashSet) -> Result<()> { | Expr::TryCast { .. } | Expr::Sort { .. } | Expr::ScalarFunction(..) - | Expr::ScalarUDF(..) | Expr::WindowFunction { .. } | Expr::AggregateFunction { .. } | Expr::GroupingSet(_) - | Expr::AggregateUDF { .. } | Expr::InList { .. } | Expr::Exists { .. } | Expr::InSubquery(_) | Expr::ScalarSubquery(_) - | Expr::Wildcard - | Expr::QualifiedWildcard { .. } + | Expr::Wildcard { .. } | Expr::GetIndexedField { .. } | Expr::Placeholder(_) | Expr::OuterReferenceColumn { .. } => {} @@ -420,18 +426,18 @@ pub fn expand_qualified_wildcard( wildcard_options: Option<&WildcardAdditionalOptions>, ) -> Result> { let qualifier = TableReference::from(qualifier); - let qualified_fields: Vec = schema - .fields_with_qualified(&qualifier) - .into_iter() - .cloned() - .collect(); + let qualified_indices = schema.fields_indices_with_qualified(&qualifier); + let projected_func_dependencies = schema + .functional_dependencies() + .project_functional_dependencies(&qualified_indices, qualified_indices.len()); + let qualified_fields = get_at_indices(schema.fields(), &qualified_indices)?; if qualified_fields.is_empty() { return plan_err!("Invalid qualifier {qualifier}"); } let qualified_schema = DFSchema::new_with_metadata(qualified_fields, schema.metadata().clone())? // We can use the functional dependencies as is, since it only stores indices: - .with_functional_dependencies(schema.functional_dependencies().clone()); + .with_functional_dependencies(projected_func_dependencies)?; let excluded_columns = if let Some(WildcardAdditionalOptions { opt_exclude, opt_except, @@ -499,7 +505,6 @@ pub fn generate_sort_key( let res = final_sort_keys .into_iter() .zip(is_partition_flag) - .map(|(lhs, rhs)| (lhs, rhs)) .collect::>(); Ok(res) } @@ -570,14 +575,14 @@ pub fn compare_sort_expr( /// group a slice of window expression expr by their order by expressions pub fn group_window_expr_by_sort_keys( - window_expr: &[Expr], -) -> Result)>> { + window_expr: Vec, +) -> Result)>> { let mut result = vec![]; - window_expr.iter().try_for_each(|expr| match expr { - Expr::WindowFunction(WindowFunction{ partition_by, order_by, .. }) => { + window_expr.into_iter().try_for_each(|expr| match &expr { + Expr::WindowFunction( WindowFunction{ partition_by, order_by, .. 
}) => { let sort_key = generate_sort_key(partition_by, order_by)?; if let Some((_, values)) = result.iter_mut().find( - |group: &&mut (WindowSortKey, Vec<&Expr>)| matches!(group, (key, _) if *key == sort_key), + |group: &&mut (WindowSortKey, Vec)| matches!(group, (key, _) if *key == sort_key), ) { values.push(expr); } else { @@ -592,15 +597,12 @@ pub fn group_window_expr_by_sort_keys( Ok(result) } -/// Collect all deeply nested `Expr::AggregateFunction` and -/// `Expr::AggregateUDF`. They are returned in order of occurrence (depth +/// Collect all deeply nested `Expr::AggregateFunction`. +/// They are returned in order of occurrence (depth /// first), with duplicates omitted. pub fn find_aggregate_exprs(exprs: &[Expr]) -> Vec { find_exprs_in_exprs(exprs, &|nested_expr| { - matches!( - nested_expr, - Expr::AggregateFunction { .. } | Expr::AggregateUDF { .. } - ) + matches!(nested_expr, Expr::AggregateFunction { .. }) }) } @@ -732,11 +734,7 @@ fn agg_cols(agg: &Aggregate) -> Vec { .collect() } -fn exprlist_to_fields_aggregate( - exprs: &[Expr], - plan: &LogicalPlan, - agg: &Aggregate, -) -> Result> { +fn exprlist_to_fields_aggregate(exprs: &[Expr], agg: &Aggregate) -> Result> { let agg_cols = agg_cols(agg); let mut fields = vec![]; for expr in exprs { @@ -745,7 +743,7 @@ fn exprlist_to_fields_aggregate( // resolve against schema of input to aggregate fields.push(expr.to_field(agg.input.schema())?); } - _ => fields.push(expr.to_field(plan.schema())?), + _ => fields.push(expr.to_field(&agg.schema)?), } } Ok(fields) @@ -762,15 +760,7 @@ pub fn exprlist_to_fields<'a>( // `GROUPING(person.state)` so in order to resolve `person.state` in this case we need to // look at the input to the aggregate instead. let fields = match plan { - LogicalPlan::Aggregate(agg) => { - Some(exprlist_to_fields_aggregate(&exprs, plan, agg)) - } - LogicalPlan::Window(window) => match window.input.as_ref() { - LogicalPlan::Aggregate(agg) => { - Some(exprlist_to_fields_aggregate(&exprs, plan, agg)) - } - _ => None, - }, + LogicalPlan::Aggregate(agg) => Some(exprlist_to_fields_aggregate(&exprs, agg)), _ => None, }; if let Some(fields) = fields { @@ -801,9 +791,11 @@ pub fn columnize_expr(e: Expr, input_schema: &DFSchema) -> Expr { match e { Expr::Column(_) => e, Expr::OuterReferenceColumn(_, _) => e, - Expr::Alias(Alias { expr, name, .. }) => { - columnize_expr(*expr, input_schema).alias(name) - } + Expr::Alias(Alias { + expr, + relation, + name, + }) => columnize_expr(*expr, input_schema).alias_qualified(relation, name), Expr::Cast(Cast { expr, data_type }) => Expr::Cast(Cast { expr: Box::new(columnize_expr(*expr, input_schema)), data_type, @@ -900,7 +892,7 @@ pub fn can_hash(data_type: &DataType) -> bool { DataType::UInt64 => true, DataType::Float32 => true, DataType::Float64 => true, - DataType::Timestamp(time_unit, None) => match time_unit { + DataType::Timestamp(time_unit, _) => match time_unit { TimeUnit::Second => true, TimeUnit::Millisecond => true, TimeUnit::Microsecond => true, @@ -1004,19 +996,251 @@ pub fn generate_signature_error_msg( ) } +/// Splits a conjunctive [`Expr`] such as `A AND B AND C` => `[A, B, C]` +/// +/// See [`split_conjunction_owned`] for more details and an example. 
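The borrowing variant declared next has no example of its own; a short usage sketch, mirroring the owned variant's doc test and assuming `split_conjunction` is exported from `datafusion_expr::utils` as in this hunk:

```rust
use datafusion_expr::utils::split_conjunction;
use datafusion_expr::{col, lit};

// a = 1 AND b = 2, split without taking ownership of the expression
let expr1 = col("a").eq(lit(1));
let expr2 = col("b").eq(lit(2));
let expr = expr1.clone().and(expr2.clone());

// References to the conjuncts, in left-to-right order
assert_eq!(split_conjunction(&expr), vec![&expr1, &expr2]);
```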
+pub fn split_conjunction(expr: &Expr) -> Vec<&Expr> { + split_conjunction_impl(expr, vec![]) +} + +fn split_conjunction_impl<'a>(expr: &'a Expr, mut exprs: Vec<&'a Expr>) -> Vec<&'a Expr> { + match expr { + Expr::BinaryExpr(BinaryExpr { + right, + op: Operator::And, + left, + }) => { + let exprs = split_conjunction_impl(left, exprs); + split_conjunction_impl(right, exprs) + } + Expr::Alias(Alias { expr, .. }) => split_conjunction_impl(expr, exprs), + other => { + exprs.push(other); + exprs + } + } +} + +/// Splits an owned conjunctive [`Expr`] such as `A AND B AND C` => `[A, B, C]` +/// +/// This is often used to "split" filter expressions such as `col1 = 5 +/// AND col2 = 10` into [`col1 = 5`, `col2 = 10`]; +/// +/// # Example +/// ``` +/// # use datafusion_expr::{col, lit}; +/// # use datafusion_expr::utils::split_conjunction_owned; +/// // a=1 AND b=2 +/// let expr = col("a").eq(lit(1)).and(col("b").eq(lit(2))); +/// +/// // [a=1, b=2] +/// let split = vec![ +/// col("a").eq(lit(1)), +/// col("b").eq(lit(2)), +/// ]; +/// +/// // use split_conjunction_owned to split them +/// assert_eq!(split_conjunction_owned(expr), split); +/// ``` +pub fn split_conjunction_owned(expr: Expr) -> Vec { + split_binary_owned(expr, Operator::And) +} + +/// Splits an owned binary operator tree [`Expr`] such as `A B C` => `[A, B, C]` +/// +/// This is often used to "split" expressions such as `col1 = 5 +/// AND col2 = 10` into [`col1 = 5`, `col2 = 10`]; +/// +/// # Example +/// ``` +/// # use datafusion_expr::{col, lit, Operator}; +/// # use datafusion_expr::utils::split_binary_owned; +/// # use std::ops::Add; +/// // a=1 + b=2 +/// let expr = col("a").eq(lit(1)).add(col("b").eq(lit(2))); +/// +/// // [a=1, b=2] +/// let split = vec![ +/// col("a").eq(lit(1)), +/// col("b").eq(lit(2)), +/// ]; +/// +/// // use split_binary_owned to split them +/// assert_eq!(split_binary_owned(expr, Operator::Plus), split); +/// ``` +pub fn split_binary_owned(expr: Expr, op: Operator) -> Vec { + split_binary_owned_impl(expr, op, vec![]) +} + +fn split_binary_owned_impl( + expr: Expr, + operator: Operator, + mut exprs: Vec, +) -> Vec { + match expr { + Expr::BinaryExpr(BinaryExpr { right, op, left }) if op == operator => { + let exprs = split_binary_owned_impl(*left, operator, exprs); + split_binary_owned_impl(*right, operator, exprs) + } + Expr::Alias(Alias { expr, .. }) => { + split_binary_owned_impl(*expr, operator, exprs) + } + other => { + exprs.push(other); + exprs + } + } +} + +/// Splits an binary operator tree [`Expr`] such as `A B C` => `[A, B, C]` +/// +/// See [`split_binary_owned`] for more details and an example. +pub fn split_binary(expr: &Expr, op: Operator) -> Vec<&Expr> { + split_binary_impl(expr, op, vec![]) +} + +fn split_binary_impl<'a>( + expr: &'a Expr, + operator: Operator, + mut exprs: Vec<&'a Expr>, +) -> Vec<&'a Expr> { + match expr { + Expr::BinaryExpr(BinaryExpr { right, op, left }) if *op == operator => { + let exprs = split_binary_impl(left, operator, exprs); + split_binary_impl(right, operator, exprs) + } + Expr::Alias(Alias { expr, .. }) => split_binary_impl(expr, operator, exprs), + other => { + exprs.push(other); + exprs + } + } +} + +/// Combines an array of filter expressions into a single filter +/// expression consisting of the input filter expressions joined with +/// logical AND. +/// +/// Returns None if the filters array is empty. 
+/// +/// # Example +/// ``` +/// # use datafusion_expr::{col, lit}; +/// # use datafusion_expr::utils::conjunction; +/// // a=1 AND b=2 +/// let expr = col("a").eq(lit(1)).and(col("b").eq(lit(2))); +/// +/// // [a=1, b=2] +/// let split = vec![ +/// col("a").eq(lit(1)), +/// col("b").eq(lit(2)), +/// ]; +/// +/// // use conjunction to join them together with `AND` +/// assert_eq!(conjunction(split), Some(expr)); +/// ``` +pub fn conjunction(filters: impl IntoIterator) -> Option { + filters.into_iter().reduce(|accum, expr| accum.and(expr)) +} + +/// Combines an array of filter expressions into a single filter +/// expression consisting of the input filter expressions joined with +/// logical OR. +/// +/// Returns None if the filters array is empty. +pub fn disjunction(filters: impl IntoIterator) -> Option { + filters.into_iter().reduce(|accum, expr| accum.or(expr)) +} + +/// returns a new [LogicalPlan] that wraps `plan` in a [LogicalPlan::Filter] with +/// its predicate be all `predicates` ANDed. +pub fn add_filter(plan: LogicalPlan, predicates: &[&Expr]) -> Result { + // reduce filters to a single filter with an AND + let predicate = predicates + .iter() + .skip(1) + .fold(predicates[0].clone(), |acc, predicate| { + and(acc, (*predicate).to_owned()) + }); + + Ok(LogicalPlan::Filter(Filter::try_new( + predicate, + Arc::new(plan), + )?)) +} + +/// Looks for correlating expressions: for example, a binary expression with one field from the subquery, and +/// one not in the subquery (closed upon from outer scope) +/// +/// # Arguments +/// +/// * `exprs` - List of expressions that may or may not be joins +/// +/// # Return value +/// +/// Tuple of (expressions containing joins, remaining non-join expressions) +pub fn find_join_exprs(exprs: Vec<&Expr>) -> Result<(Vec, Vec)> { + let mut joins = vec![]; + let mut others = vec![]; + for filter in exprs.into_iter() { + // If the expression contains correlated predicates, add it to join filters + if filter.contains_outer() { + if !matches!(filter, Expr::BinaryExpr(BinaryExpr{ left, op: Operator::Eq, right }) if left.eq(right)) + { + joins.push(strip_outer_reference((*filter).clone())); + } + } else { + others.push((*filter).clone()); + } + } + + Ok((joins, others)) +} + +/// Returns the first (and only) element in a slice, or an error +/// +/// # Arguments +/// +/// * `slice` - The slice to extract from +/// +/// # Return value +/// +/// The first element, or an error +pub fn only_or_err(slice: &[T]) -> Result<&T> { + match slice { + [it] => Ok(it), + [] => plan_err!("No items found!"), + _ => plan_err!("More than one item found!"), + } +} + +/// merge inputs schema into a single schema. 
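Before the `merge_schema` helper defined below, a brief sketch of how the newly added `disjunction` and `only_or_err` helpers behave, assuming both are exported from `datafusion_expr::utils` as in this hunk:

```rust
use datafusion_common::Result;
use datafusion_expr::utils::{disjunction, only_or_err};
use datafusion_expr::{col, lit};

fn example() -> Result<()> {
    // OR-combine a list of predicates: (a = 1) OR (b = 2)
    let filters = vec![col("a").eq(lit(1)), col("b").eq(lit(2))];
    assert_eq!(
        disjunction(filters),
        Some(col("a").eq(lit(1)).or(col("b").eq(lit(2))))
    );

    // An empty input yields None rather than an error
    assert_eq!(disjunction(vec![]), None);

    // only_or_err returns the single element of a slice, or a plan error
    let exprs = vec![col("a")];
    assert_eq!(only_or_err(&exprs)?, &col("a"));
    Ok(())
}
```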
+pub fn merge_schema(inputs: Vec<&LogicalPlan>) -> DFSchema { + if inputs.len() == 1 { + inputs[0].schema().clone().as_ref().clone() + } else { + inputs.iter().map(|input| input.schema()).fold( + DFSchema::empty(), + |mut lhs, rhs| { + lhs.merge(rhs); + lhs + }, + ) + } +} + #[cfg(test)] mod tests { use super::*; - use crate::expr_vec_fmt; use crate::{ - col, cube, expr, grouping_set, rollup, AggregateFunction, WindowFrame, - WindowFunction, + col, cube, expr, expr_vec_fmt, grouping_set, lit, rollup, AggregateFunction, + WindowFrame, WindowFunctionDefinition, }; #[test] fn test_group_window_expr_by_sort_keys_empty_case() -> Result<()> { - let result = group_window_expr_by_sort_keys(&[])?; - let expected: Vec<(WindowSortKey, Vec<&Expr>)> = vec![]; + let result = group_window_expr_by_sort_keys(vec![])?; + let expected: Vec<(WindowSortKey, Vec)> = vec![]; assert_eq!(expected, result); Ok(()) } @@ -1024,38 +1248,38 @@ mod tests { #[test] fn test_group_window_expr_by_sort_keys_empty_window() -> Result<()> { let max1 = Expr::WindowFunction(expr::WindowFunction::new( - WindowFunction::AggregateFunction(AggregateFunction::Max), + WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max), vec![col("name")], vec![], vec![], WindowFrame::new(false), )); let max2 = Expr::WindowFunction(expr::WindowFunction::new( - WindowFunction::AggregateFunction(AggregateFunction::Max), + WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max), vec![col("name")], vec![], vec![], WindowFrame::new(false), )); let min3 = Expr::WindowFunction(expr::WindowFunction::new( - WindowFunction::AggregateFunction(AggregateFunction::Min), + WindowFunctionDefinition::AggregateFunction(AggregateFunction::Min), vec![col("name")], vec![], vec![], WindowFrame::new(false), )); let sum4 = Expr::WindowFunction(expr::WindowFunction::new( - WindowFunction::AggregateFunction(AggregateFunction::Sum), + WindowFunctionDefinition::AggregateFunction(AggregateFunction::Sum), vec![col("age")], vec![], vec![], WindowFrame::new(false), )); let exprs = &[max1.clone(), max2.clone(), min3.clone(), sum4.clone()]; - let result = group_window_expr_by_sort_keys(exprs)?; + let result = group_window_expr_by_sort_keys(exprs.to_vec())?; let key = vec![]; - let expected: Vec<(WindowSortKey, Vec<&Expr>)> = - vec![(key, vec![&max1, &max2, &min3, &sum4])]; + let expected: Vec<(WindowSortKey, Vec)> = + vec![(key, vec![max1, max2, min3, sum4])]; assert_eq!(expected, result); Ok(()) } @@ -1067,28 +1291,28 @@ mod tests { let created_at_desc = Expr::Sort(expr::Sort::new(Box::new(col("created_at")), false, true)); let max1 = Expr::WindowFunction(expr::WindowFunction::new( - WindowFunction::AggregateFunction(AggregateFunction::Max), + WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max), vec![col("name")], vec![], vec![age_asc.clone(), name_desc.clone()], WindowFrame::new(true), )); let max2 = Expr::WindowFunction(expr::WindowFunction::new( - WindowFunction::AggregateFunction(AggregateFunction::Max), + WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max), vec![col("name")], vec![], vec![], WindowFrame::new(false), )); let min3 = Expr::WindowFunction(expr::WindowFunction::new( - WindowFunction::AggregateFunction(AggregateFunction::Min), + WindowFunctionDefinition::AggregateFunction(AggregateFunction::Min), vec![col("name")], vec![], vec![age_asc.clone(), name_desc.clone()], WindowFrame::new(true), )); let sum4 = Expr::WindowFunction(expr::WindowFunction::new( - WindowFunction::AggregateFunction(AggregateFunction::Sum), + 
WindowFunctionDefinition::AggregateFunction(AggregateFunction::Sum), vec![col("age")], vec![], vec![name_desc.clone(), age_asc.clone(), created_at_desc.clone()], @@ -1096,7 +1320,7 @@ mod tests { )); // FIXME use as_ref let exprs = &[max1.clone(), max2.clone(), min3.clone(), sum4.clone()]; - let result = group_window_expr_by_sort_keys(exprs)?; + let result = group_window_expr_by_sort_keys(exprs.to_vec())?; let key1 = vec![(age_asc.clone(), false), (name_desc.clone(), false)]; let key2 = vec![]; @@ -1106,10 +1330,10 @@ mod tests { (created_at_desc, false), ]; - let expected: Vec<(WindowSortKey, Vec<&Expr>)> = vec![ - (key1, vec![&max1, &min3]), - (key2, vec![&max2]), - (key3, vec![&sum4]), + let expected: Vec<(WindowSortKey, Vec)> = vec![ + (key1, vec![max1, min3]), + (key2, vec![max2]), + (key3, vec![sum4]), ]; assert_eq!(expected, result); Ok(()) @@ -1119,7 +1343,7 @@ mod tests { fn test_find_sort_exprs() -> Result<()> { let exprs = &[ Expr::WindowFunction(expr::WindowFunction::new( - WindowFunction::AggregateFunction(AggregateFunction::Max), + WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max), vec![col("name")], vec![], vec![ @@ -1129,7 +1353,7 @@ mod tests { WindowFrame::new(true), )), Expr::WindowFunction(expr::WindowFunction::new( - WindowFunction::AggregateFunction(AggregateFunction::Sum), + WindowFunctionDefinition::AggregateFunction(AggregateFunction::Sum), vec![col("age")], vec![], vec![ @@ -1322,4 +1546,143 @@ mod tests { Ok(()) } + #[test] + fn test_split_conjunction() { + let expr = col("a"); + let result = split_conjunction(&expr); + assert_eq!(result, vec![&expr]); + } + + #[test] + fn test_split_conjunction_two() { + let expr = col("a").eq(lit(5)).and(col("b")); + let expr1 = col("a").eq(lit(5)); + let expr2 = col("b"); + + let result = split_conjunction(&expr); + assert_eq!(result, vec![&expr1, &expr2]); + } + + #[test] + fn test_split_conjunction_alias() { + let expr = col("a").eq(lit(5)).and(col("b").alias("the_alias")); + let expr1 = col("a").eq(lit(5)); + let expr2 = col("b"); // has no alias + + let result = split_conjunction(&expr); + assert_eq!(result, vec![&expr1, &expr2]); + } + + #[test] + fn test_split_conjunction_or() { + let expr = col("a").eq(lit(5)).or(col("b")); + let result = split_conjunction(&expr); + assert_eq!(result, vec![&expr]); + } + + #[test] + fn test_split_binary_owned() { + let expr = col("a"); + assert_eq!(split_binary_owned(expr.clone(), Operator::And), vec![expr]); + } + + #[test] + fn test_split_binary_owned_two() { + assert_eq!( + split_binary_owned(col("a").eq(lit(5)).and(col("b")), Operator::And), + vec![col("a").eq(lit(5)), col("b")] + ); + } + + #[test] + fn test_split_binary_owned_different_op() { + let expr = col("a").eq(lit(5)).or(col("b")); + assert_eq!( + // expr is connected by OR, but pass in AND + split_binary_owned(expr.clone(), Operator::And), + vec![expr] + ); + } + + #[test] + fn test_split_conjunction_owned() { + let expr = col("a"); + assert_eq!(split_conjunction_owned(expr.clone()), vec![expr]); + } + + #[test] + fn test_split_conjunction_owned_two() { + assert_eq!( + split_conjunction_owned(col("a").eq(lit(5)).and(col("b"))), + vec![col("a").eq(lit(5)), col("b")] + ); + } + + #[test] + fn test_split_conjunction_owned_alias() { + assert_eq!( + split_conjunction_owned(col("a").eq(lit(5)).and(col("b").alias("the_alias"))), + vec![ + col("a").eq(lit(5)), + // no alias on b + col("b"), + ] + ); + } + + #[test] + fn test_conjunction_empty() { + assert_eq!(conjunction(vec![]), None); + } + + #[test] + fn 
test_conjunction() { + // `[A, B, C]` + let expr = conjunction(vec![col("a"), col("b"), col("c")]); + + // --> `(A AND B) AND C` + assert_eq!(expr, Some(col("a").and(col("b")).and(col("c")))); + + // which is different than `A AND (B AND C)` + assert_ne!(expr, Some(col("a").and(col("b").and(col("c"))))); + } + + #[test] + fn test_disjunction_empty() { + assert_eq!(disjunction(vec![]), None); + } + + #[test] + fn test_disjunction() { + // `[A, B, C]` + let expr = disjunction(vec![col("a"), col("b"), col("c")]); + + // --> `(A OR B) OR C` + assert_eq!(expr, Some(col("a").or(col("b")).or(col("c")))); + + // which is different than `A OR (B OR C)` + assert_ne!(expr, Some(col("a").or(col("b").or(col("c"))))); + } + + #[test] + fn test_split_conjunction_owned_or() { + let expr = col("a").eq(lit(5)).or(col("b")); + assert_eq!(split_conjunction_owned(expr.clone()), vec![expr]); + } + + #[test] + fn test_collect_expr() -> Result<()> { + let mut accum: HashSet = HashSet::new(); + expr_to_columns( + &Expr::Cast(Cast::new(Box::new(col("a")), DataType::Float64)), + &mut accum, + )?; + expr_to_columns( + &Expr::Cast(Cast::new(Box::new(col("a")), DataType::Float64)), + &mut accum, + )?; + assert_eq!(1, accum.len()); + assert!(accum.contains(&Column::from_name("a"))); + Ok(()) + } } diff --git a/datafusion/expr/src/window_frame.rs b/datafusion/expr/src/window_frame.rs index 5f161b85dd9ac..2701ca1ecf3b1 100644 --- a/datafusion/expr/src/window_frame.rs +++ b/datafusion/expr/src/window_frame.rs @@ -23,6 +23,8 @@ //! - An ending frame boundary, //! - An EXCLUDE clause. +use crate::expr::Sort; +use crate::Expr; use datafusion_common::{plan_err, sql_err, DataFusionError, Result, ScalarValue}; use sqlparser::ast; use sqlparser::parser::ParserError::ParserError; @@ -142,31 +144,57 @@ impl WindowFrame { } } -/// Construct equivalent explicit window frames for implicit corner cases. -/// With this processing, we may assume in downstream code that RANGE/GROUPS -/// frames contain an appropriate ORDER BY clause. -pub fn regularize(mut frame: WindowFrame, order_bys: usize) -> Result { - if frame.units == WindowFrameUnits::Range && order_bys != 1 { +/// Regularizes ORDER BY clause for window definition for implicit corner cases. +pub fn regularize_window_order_by( + frame: &WindowFrame, + order_by: &mut Vec, +) -> Result<()> { + if frame.units == WindowFrameUnits::Range && order_by.len() != 1 { // Normally, RANGE frames require an ORDER BY clause with exactly one - // column. However, an ORDER BY clause may be absent in two edge cases. + // column. However, an ORDER BY clause may be absent or present but with + // more than one column in two edge cases: + // 1. start bound is UNBOUNDED or CURRENT ROW + // 2. end bound is CURRENT ROW or UNBOUNDED. + // In these cases, we regularize the ORDER BY clause if the ORDER BY clause + // is absent. If an ORDER BY clause is present but has more than one column, + // the ORDER BY clause is unchanged. Note that this follows Postgres behavior. if (frame.start_bound.is_unbounded() || frame.start_bound == WindowFrameBound::CurrentRow) && (frame.end_bound == WindowFrameBound::CurrentRow || frame.end_bound.is_unbounded()) { - if order_bys == 0 { - frame.units = WindowFrameUnits::Rows; - frame.start_bound = - WindowFrameBound::Preceding(ScalarValue::UInt64(None)); - frame.end_bound = WindowFrameBound::Following(ScalarValue::UInt64(None)); + // If an ORDER BY clause is absent, it is equivalent to a ORDER BY clause + // with constant value as sort key. 
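Callers are expected to pair the new `regularize_window_order_by` with `check_window_frame`, introduced alongside it below. A hedged sketch of that call order, assuming both functions stay public in `datafusion_expr::window_frame`; `normalize_frame` is a hypothetical helper name, not part of this change:

```rust
use datafusion_common::Result;
use datafusion_expr::window_frame::{check_window_frame, regularize_window_order_by};
use datafusion_expr::{Expr, WindowFrame};

/// Hypothetical planner-side helper: rewrite the ORDER BY for the implicit
/// corner cases first, then validate the frame against the final length.
fn normalize_frame(frame: &WindowFrame, order_by: &mut Vec<Expr>) -> Result<()> {
    regularize_window_order_by(frame, order_by)?;
    check_window_frame(frame, order_by.len())
}
```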
+ // If an ORDER BY clause is present but has more than one column, it is + // unchanged. + if order_by.is_empty() { + order_by.push(Expr::Sort(Sort::new( + Box::new(Expr::Literal(ScalarValue::UInt64(Some(1)))), + true, + false, + ))); } - } else { + } + } + Ok(()) +} + +/// Checks if given window frame is valid. In particular, if the frame is RANGE +/// with offset PRECEDING/FOLLOWING, it must have exactly one ORDER BY column. +pub fn check_window_frame(frame: &WindowFrame, order_bys: usize) -> Result<()> { + if frame.units == WindowFrameUnits::Range && order_bys != 1 { + // See `regularize_window_order_by`. + if !(frame.start_bound.is_unbounded() + || frame.start_bound == WindowFrameBound::CurrentRow) + || !(frame.end_bound == WindowFrameBound::CurrentRow + || frame.end_bound.is_unbounded()) + { plan_err!("RANGE requires exactly one ORDER BY column")? } } else if frame.units == WindowFrameUnits::Groups && order_bys == 0 { plan_err!("GROUPS requires an ORDER BY clause")? }; - Ok(frame) + Ok(()) } /// There are five ways to describe starting and ending frame boundaries: diff --git a/datafusion/expr/src/window_function.rs b/datafusion/expr/src/window_function.rs deleted file mode 100644 index e5b00c8f298be..0000000000000 --- a/datafusion/expr/src/window_function.rs +++ /dev/null @@ -1,450 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! Window functions provide the ability to perform calculations across -//! sets of rows that are related to the current query row. -//! -//! see also - -use crate::aggregate_function::AggregateFunction; -use crate::type_coercion::functions::data_types; -use crate::utils; -use crate::{AggregateUDF, Signature, TypeSignature, Volatility, WindowUDF}; -use arrow::datatypes::DataType; -use datafusion_common::{plan_datafusion_err, plan_err, DataFusionError, Result}; -use std::sync::Arc; -use std::{fmt, str::FromStr}; -use strum_macros::EnumIter; - -/// WindowFunction -#[derive(Debug, Clone, PartialEq, Eq, Hash)] -pub enum WindowFunction { - /// A built in aggregate function that leverages an aggregate function - AggregateFunction(AggregateFunction), - /// A a built-in window function - BuiltInWindowFunction(BuiltInWindowFunction), - /// A user defined aggregate function - AggregateUDF(Arc), - /// A user defined aggregate function - WindowUDF(Arc), -} - -/// Find DataFusion's built-in window function by name. -pub fn find_df_window_func(name: &str) -> Option { - let name = name.to_lowercase(); - // Code paths for window functions leveraging ordinary aggregators and - // built-in window functions are quite different, and the same function - // may have different implementations for these cases. 
If the sought - // function is not found among built-in window functions, we search for - // it among aggregate functions. - if let Ok(built_in_function) = BuiltInWindowFunction::from_str(name.as_str()) { - Some(WindowFunction::BuiltInWindowFunction(built_in_function)) - } else if let Ok(aggregate) = AggregateFunction::from_str(name.as_str()) { - Some(WindowFunction::AggregateFunction(aggregate)) - } else { - None - } -} - -impl fmt::Display for BuiltInWindowFunction { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "{}", self.name()) - } -} - -impl fmt::Display for WindowFunction { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match self { - WindowFunction::AggregateFunction(fun) => fun.fmt(f), - WindowFunction::BuiltInWindowFunction(fun) => fun.fmt(f), - WindowFunction::AggregateUDF(fun) => std::fmt::Debug::fmt(fun, f), - WindowFunction::WindowUDF(fun) => fun.fmt(f), - } - } -} - -/// A [window function] built in to DataFusion -/// -/// [window function]: https://en.wikipedia.org/wiki/Window_function_(SQL) -#[derive(Debug, Clone, PartialEq, Eq, Hash, EnumIter)] -pub enum BuiltInWindowFunction { - /// number of the current row within its partition, counting from 1 - RowNumber, - /// rank of the current row with gaps; same as row_number of its first peer - Rank, - /// rank of the current row without gaps; this function counts peer groups - DenseRank, - /// relative rank of the current row: (rank - 1) / (total rows - 1) - PercentRank, - /// relative rank of the current row: (number of rows preceding or peer with current row) / (total rows) - CumeDist, - /// integer ranging from 1 to the argument value, dividing the partition as equally as possible - Ntile, - /// returns value evaluated at the row that is offset rows before the current row within the partition; - /// if there is no such row, instead return default (which must be of the same type as value). - /// Both offset and default are evaluated with respect to the current row. - /// If omitted, offset defaults to 1 and default to null - Lag, - /// returns value evaluated at the row that is offset rows after the current row within the partition; - /// if there is no such row, instead return default (which must be of the same type as value). - /// Both offset and default are evaluated with respect to the current row. 
- /// If omitted, offset defaults to 1 and default to null - Lead, - /// returns value evaluated at the row that is the first row of the window frame - FirstValue, - /// returns value evaluated at the row that is the last row of the window frame - LastValue, - /// returns value evaluated at the row that is the nth row of the window frame (counting from 1); null if no such row - NthValue, -} - -impl BuiltInWindowFunction { - fn name(&self) -> &str { - use BuiltInWindowFunction::*; - match self { - RowNumber => "ROW_NUMBER", - Rank => "RANK", - DenseRank => "DENSE_RANK", - PercentRank => "PERCENT_RANK", - CumeDist => "CUME_DIST", - Ntile => "NTILE", - Lag => "LAG", - Lead => "LEAD", - FirstValue => "FIRST_VALUE", - LastValue => "LAST_VALUE", - NthValue => "NTH_VALUE", - } - } -} - -impl FromStr for BuiltInWindowFunction { - type Err = DataFusionError; - fn from_str(name: &str) -> Result { - Ok(match name.to_uppercase().as_str() { - "ROW_NUMBER" => BuiltInWindowFunction::RowNumber, - "RANK" => BuiltInWindowFunction::Rank, - "DENSE_RANK" => BuiltInWindowFunction::DenseRank, - "PERCENT_RANK" => BuiltInWindowFunction::PercentRank, - "CUME_DIST" => BuiltInWindowFunction::CumeDist, - "NTILE" => BuiltInWindowFunction::Ntile, - "LAG" => BuiltInWindowFunction::Lag, - "LEAD" => BuiltInWindowFunction::Lead, - "FIRST_VALUE" => BuiltInWindowFunction::FirstValue, - "LAST_VALUE" => BuiltInWindowFunction::LastValue, - "NTH_VALUE" => BuiltInWindowFunction::NthValue, - _ => return plan_err!("There is no built-in window function named {name}"), - }) - } -} - -/// Returns the datatype of the window function -#[deprecated( - since = "27.0.0", - note = "please use `WindowFunction::return_type` instead" -)] -pub fn return_type( - fun: &WindowFunction, - input_expr_types: &[DataType], -) -> Result { - fun.return_type(input_expr_types) -} - -impl WindowFunction { - /// Returns the datatype of the window function - pub fn return_type(&self, input_expr_types: &[DataType]) -> Result { - match self { - WindowFunction::AggregateFunction(fun) => fun.return_type(input_expr_types), - WindowFunction::BuiltInWindowFunction(fun) => { - fun.return_type(input_expr_types) - } - WindowFunction::AggregateUDF(fun) => { - Ok((*(fun.return_type)(input_expr_types)?).clone()) - } - WindowFunction::WindowUDF(fun) => { - Ok((*(fun.return_type)(input_expr_types)?).clone()) - } - } - } -} - -/// Returns the datatype of the built-in window function -impl BuiltInWindowFunction { - pub fn return_type(&self, input_expr_types: &[DataType]) -> Result { - // Note that this function *must* return the same type that the respective physical expression returns - // or the execution panics. 
- - // verify that this is a valid set of data types for this function - data_types(input_expr_types, &self.signature()) - // original errors are all related to wrong function signature - // aggregate them for better error message - .map_err(|_| { - plan_datafusion_err!( - "{}", - utils::generate_signature_error_msg( - &format!("{self}"), - self.signature(), - input_expr_types, - ) - ) - })?; - - match self { - BuiltInWindowFunction::RowNumber - | BuiltInWindowFunction::Rank - | BuiltInWindowFunction::DenseRank => Ok(DataType::UInt64), - BuiltInWindowFunction::PercentRank | BuiltInWindowFunction::CumeDist => { - Ok(DataType::Float64) - } - BuiltInWindowFunction::Ntile => Ok(DataType::UInt32), - BuiltInWindowFunction::Lag - | BuiltInWindowFunction::Lead - | BuiltInWindowFunction::FirstValue - | BuiltInWindowFunction::LastValue - | BuiltInWindowFunction::NthValue => Ok(input_expr_types[0].clone()), - } - } -} - -/// the signatures supported by the function `fun`. -#[deprecated( - since = "27.0.0", - note = "please use `WindowFunction::signature` instead" -)] -pub fn signature(fun: &WindowFunction) -> Signature { - fun.signature() -} - -impl WindowFunction { - /// the signatures supported by the function `fun`. - pub fn signature(&self) -> Signature { - match self { - WindowFunction::AggregateFunction(fun) => fun.signature(), - WindowFunction::BuiltInWindowFunction(fun) => fun.signature(), - WindowFunction::AggregateUDF(fun) => fun.signature.clone(), - WindowFunction::WindowUDF(fun) => fun.signature.clone(), - } - } -} - -/// the signatures supported by the built-in window function `fun`. -#[deprecated( - since = "27.0.0", - note = "please use `BuiltInWindowFunction::signature` instead" -)] -pub fn signature_for_built_in(fun: &BuiltInWindowFunction) -> Signature { - fun.signature() -} - -impl BuiltInWindowFunction { - /// the signatures supported by the built-in window function `fun`. - pub fn signature(&self) -> Signature { - // note: the physical expression must accept the type returned by this function or the execution panics. 
- match self { - BuiltInWindowFunction::RowNumber - | BuiltInWindowFunction::Rank - | BuiltInWindowFunction::DenseRank - | BuiltInWindowFunction::PercentRank - | BuiltInWindowFunction::CumeDist => Signature::any(0, Volatility::Immutable), - BuiltInWindowFunction::Lag | BuiltInWindowFunction::Lead => { - Signature::one_of( - vec![ - TypeSignature::Any(1), - TypeSignature::Any(2), - TypeSignature::Any(3), - ], - Volatility::Immutable, - ) - } - BuiltInWindowFunction::FirstValue | BuiltInWindowFunction::LastValue => { - Signature::any(1, Volatility::Immutable) - } - BuiltInWindowFunction::Ntile => Signature::any(1, Volatility::Immutable), - BuiltInWindowFunction::NthValue => Signature::any(2, Volatility::Immutable), - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_count_return_type() -> Result<()> { - let fun = find_df_window_func("count").unwrap(); - let observed = fun.return_type(&[DataType::Utf8])?; - assert_eq!(DataType::Int64, observed); - - let observed = fun.return_type(&[DataType::UInt64])?; - assert_eq!(DataType::Int64, observed); - - Ok(()) - } - - #[test] - fn test_first_value_return_type() -> Result<()> { - let fun = find_df_window_func("first_value").unwrap(); - let observed = fun.return_type(&[DataType::Utf8])?; - assert_eq!(DataType::Utf8, observed); - - let observed = fun.return_type(&[DataType::UInt64])?; - assert_eq!(DataType::UInt64, observed); - - Ok(()) - } - - #[test] - fn test_last_value_return_type() -> Result<()> { - let fun = find_df_window_func("last_value").unwrap(); - let observed = fun.return_type(&[DataType::Utf8])?; - assert_eq!(DataType::Utf8, observed); - - let observed = fun.return_type(&[DataType::Float64])?; - assert_eq!(DataType::Float64, observed); - - Ok(()) - } - - #[test] - fn test_lead_return_type() -> Result<()> { - let fun = find_df_window_func("lead").unwrap(); - let observed = fun.return_type(&[DataType::Utf8])?; - assert_eq!(DataType::Utf8, observed); - - let observed = fun.return_type(&[DataType::Float64])?; - assert_eq!(DataType::Float64, observed); - - Ok(()) - } - - #[test] - fn test_lag_return_type() -> Result<()> { - let fun = find_df_window_func("lag").unwrap(); - let observed = fun.return_type(&[DataType::Utf8])?; - assert_eq!(DataType::Utf8, observed); - - let observed = fun.return_type(&[DataType::Float64])?; - assert_eq!(DataType::Float64, observed); - - Ok(()) - } - - #[test] - fn test_nth_value_return_type() -> Result<()> { - let fun = find_df_window_func("nth_value").unwrap(); - let observed = fun.return_type(&[DataType::Utf8, DataType::UInt64])?; - assert_eq!(DataType::Utf8, observed); - - let observed = fun.return_type(&[DataType::Float64, DataType::UInt64])?; - assert_eq!(DataType::Float64, observed); - - Ok(()) - } - - #[test] - fn test_percent_rank_return_type() -> Result<()> { - let fun = find_df_window_func("percent_rank").unwrap(); - let observed = fun.return_type(&[])?; - assert_eq!(DataType::Float64, observed); - - Ok(()) - } - - #[test] - fn test_cume_dist_return_type() -> Result<()> { - let fun = find_df_window_func("cume_dist").unwrap(); - let observed = fun.return_type(&[])?; - assert_eq!(DataType::Float64, observed); - - Ok(()) - } - - #[test] - fn test_window_function_case_insensitive() -> Result<()> { - let names = vec![ - "row_number", - "rank", - "dense_rank", - "percent_rank", - "cume_dist", - "ntile", - "lag", - "lead", - "first_value", - "last_value", - "nth_value", - "min", - "max", - "count", - "avg", - "sum", - ]; - for name in names { - let fun = 
find_df_window_func(name).unwrap(); - let fun2 = find_df_window_func(name.to_uppercase().as_str()).unwrap(); - assert_eq!(fun, fun2); - assert_eq!(fun.to_string(), name.to_uppercase()); - } - Ok(()) - } - - #[test] - fn test_find_df_window_function() { - assert_eq!( - find_df_window_func("max"), - Some(WindowFunction::AggregateFunction(AggregateFunction::Max)) - ); - assert_eq!( - find_df_window_func("min"), - Some(WindowFunction::AggregateFunction(AggregateFunction::Min)) - ); - assert_eq!( - find_df_window_func("avg"), - Some(WindowFunction::AggregateFunction(AggregateFunction::Avg)) - ); - assert_eq!( - find_df_window_func("cume_dist"), - Some(WindowFunction::BuiltInWindowFunction( - BuiltInWindowFunction::CumeDist - )) - ); - assert_eq!( - find_df_window_func("first_value"), - Some(WindowFunction::BuiltInWindowFunction( - BuiltInWindowFunction::FirstValue - )) - ); - assert_eq!( - find_df_window_func("LAST_value"), - Some(WindowFunction::BuiltInWindowFunction( - BuiltInWindowFunction::LastValue - )) - ); - assert_eq!( - find_df_window_func("LAG"), - Some(WindowFunction::BuiltInWindowFunction( - BuiltInWindowFunction::Lag - )) - ); - assert_eq!( - find_df_window_func("LEAD"), - Some(WindowFunction::BuiltInWindowFunction( - BuiltInWindowFunction::Lead - )) - ); - assert_eq!(find_df_window_func("not_exist"), None) - } -} diff --git a/datafusion/expr/src/window_state.rs b/datafusion/expr/src/window_state.rs index 4ea9ecea5fc62..de88396d9b0e7 100644 --- a/datafusion/expr/src/window_state.rs +++ b/datafusion/expr/src/window_state.rs @@ -98,7 +98,7 @@ impl WindowAggState { } pub fn new(out_type: &DataType) -> Result { - let empty_out_col = ScalarValue::try_from(out_type)?.to_array_of_size(0); + let empty_out_col = ScalarValue::try_from(out_type)?.to_array_of_size(0)?; Ok(Self { window_frame_range: Range { start: 0, end: 0 }, window_frame_ctx: None, diff --git a/datafusion/optimizer/Cargo.toml b/datafusion/optimizer/Cargo.toml index fac880867fefd..b350d41d3fe38 100644 --- a/datafusion/optimizer/Cargo.toml +++ b/datafusion/optimizer/Cargo.toml @@ -44,7 +44,7 @@ async-trait = { workspace = true } chrono = { workspace = true } datafusion-common = { workspace = true } datafusion-expr = { workspace = true } -datafusion-physical-expr = { path = "../physical-expr", version = "33.0.0", default-features = false } +datafusion-physical-expr = { path = "../physical-expr", version = "34.0.0", default-features = false } hashbrown = { version = "0.14", features = ["raw"] } itertools = { workspace = true } log = { workspace = true } @@ -52,5 +52,5 @@ regex-syntax = "0.8.0" [dev-dependencies] ctor = { workspace = true } -datafusion-sql = { path = "../sql", version = "33.0.0" } +datafusion-sql = { path = "../sql", version = "34.0.0" } env_logger = "0.10.0" diff --git a/datafusion/optimizer/README.md b/datafusion/optimizer/README.md index b8e5b93e6692c..4f9e0fb98526f 100644 --- a/datafusion/optimizer/README.md +++ b/datafusion/optimizer/README.md @@ -153,7 +153,7 @@ Looking at the `EXPLAIN` output we can see that the optimizer has effectively re | logical_plan | Projection: Int64(3) AS Int64(1) + Int64(2) | | | EmptyRelation | | physical_plan | ProjectionExec: expr=[3 as Int64(1) + Int64(2)] | -| | EmptyExec: produce_one_row=true | +| | PlaceholderRowExec | | | | +---------------+-------------------------------------------------+ ``` @@ -318,7 +318,7 @@ In the following example, the `type_coercion` and `simplify_expressions` passes | logical_plan | Projection: Utf8("3.2") AS foo | | | EmptyRelation | | 
initial_physical_plan | ProjectionExec: expr=[3.2 as foo] | -| | EmptyExec: produce_one_row=true | +| | PlaceholderRowExec | | | | | physical_plan after aggregate_statistics | SAME TEXT AS ABOVE | | physical_plan after join_selection | SAME TEXT AS ABOVE | @@ -326,7 +326,7 @@ In the following example, the `type_coercion` and `simplify_expressions` passes | physical_plan after repartition | SAME TEXT AS ABOVE | | physical_plan after add_merge_exec | SAME TEXT AS ABOVE | | physical_plan | ProjectionExec: expr=[3.2 as foo] | -| | EmptyExec: produce_one_row=true | +| | PlaceholderRowExec | | | | +------------------------------------------------------------+---------------------------------------------------------------------------+ ``` diff --git a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs index 912ac069e0b64..953716713e41c 100644 --- a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs +++ b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs @@ -19,12 +19,12 @@ use crate::analyzer::AnalyzerRule; use datafusion_common::config::ConfigOptions; use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRewriter}; use datafusion_common::Result; -use datafusion_expr::expr::{AggregateFunction, InSubquery}; +use datafusion_expr::expr::{AggregateFunction, AggregateFunctionDefinition, InSubquery}; use datafusion_expr::expr_rewriter::rewrite_preserving_name; use datafusion_expr::utils::COUNT_STAR_EXPANSION; use datafusion_expr::Expr::ScalarSubquery; use datafusion_expr::{ - aggregate_function, expr, lit, window_function, Aggregate, Expr, Filter, LogicalPlan, + aggregate_function, expr, lit, Aggregate, Expr, Filter, LogicalPlan, LogicalPlanBuilder, Projection, Sort, Subquery, }; use std::sync::Arc; @@ -121,7 +121,7 @@ impl TreeNodeRewriter for CountWildcardRewriter { let new_expr = match old_expr.clone() { Expr::WindowFunction(expr::WindowFunction { fun: - window_function::WindowFunction::AggregateFunction( + expr::WindowFunctionDefinition::AggregateFunction( aggregate_function::AggregateFunction::Count, ), args, @@ -129,32 +129,39 @@ impl TreeNodeRewriter for CountWildcardRewriter { order_by, window_frame, }) if args.len() == 1 => match args[0] { - Expr::Wildcard => Expr::WindowFunction(expr::WindowFunction { - fun: window_function::WindowFunction::AggregateFunction( - aggregate_function::AggregateFunction::Count, - ), - args: vec![lit(COUNT_STAR_EXPANSION)], - partition_by, - order_by, - window_frame, - }), + Expr::Wildcard { qualifier: None } => { + Expr::WindowFunction(expr::WindowFunction { + fun: expr::WindowFunctionDefinition::AggregateFunction( + aggregate_function::AggregateFunction::Count, + ), + args: vec![lit(COUNT_STAR_EXPANSION)], + partition_by, + order_by, + window_frame, + }) + } _ => old_expr, }, Expr::AggregateFunction(AggregateFunction { - fun: aggregate_function::AggregateFunction::Count, + func_def: + AggregateFunctionDefinition::BuiltIn( + aggregate_function::AggregateFunction::Count, + ), args, distinct, filter, order_by, }) if args.len() == 1 => match args[0] { - Expr::Wildcard => Expr::AggregateFunction(AggregateFunction { - fun: aggregate_function::AggregateFunction::Count, - args: vec![lit(COUNT_STAR_EXPANSION)], - distinct, - filter, - order_by, - }), + Expr::Wildcard { qualifier: None } => { + Expr::AggregateFunction(AggregateFunction::new( + aggregate_function::AggregateFunction::Count, + vec![lit(COUNT_STAR_EXPANSION)], + distinct, + filter, + order_by, + )) + } _ => old_expr, }, 
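For reference, the rewrite above replaces the wildcard argument with the `COUNT_STAR_EXPANSION` literal. A standalone sketch of the before/after expressions, built with the same helpers the tests use (illustrative only, not part of the rule itself):

```rust
use datafusion_expr::utils::COUNT_STAR_EXPANSION;
use datafusion_expr::{count, lit, wildcard};

// What the analyzer sees: COUNT(*)
let before = count(wildcard());

// What the rule produces: COUNT(UInt8(1)), later aliased back to COUNT(*)
let after = count(lit(COUNT_STAR_EXPANSION));
```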
@@ -221,8 +228,8 @@ mod tests { use datafusion_expr::expr::Sort; use datafusion_expr::{ col, count, exists, expr, in_subquery, lit, logical_plan::LogicalPlanBuilder, - max, out_ref_col, scalar_subquery, AggregateFunction, Expr, WindowFrame, - WindowFrameBound, WindowFrameUnits, WindowFunction, + max, out_ref_col, scalar_subquery, wildcard, AggregateFunction, Expr, + WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, }; fn assert_plan_eq(plan: &LogicalPlan, expected: &str) -> Result<()> { @@ -237,9 +244,9 @@ mod tests { fn test_count_wildcard_on_sort() -> Result<()> { let table_scan = test_table_scan()?; let plan = LogicalPlanBuilder::from(table_scan) - .aggregate(vec![col("b")], vec![count(Expr::Wildcard)])? - .project(vec![count(Expr::Wildcard)])? - .sort(vec![count(Expr::Wildcard).sort(true, false)])? + .aggregate(vec![col("b")], vec![count(wildcard())])? + .project(vec![count(wildcard())])? + .sort(vec![count(wildcard()).sort(true, false)])? .build()?; let expected = "Sort: COUNT(*) ASC NULLS LAST [COUNT(*):Int64;N]\ \n Projection: COUNT(*) [COUNT(*):Int64;N]\ @@ -258,8 +265,8 @@ mod tests { col("a"), Arc::new( LogicalPlanBuilder::from(table_scan_t2) - .aggregate(Vec::::new(), vec![count(Expr::Wildcard)])? - .project(vec![count(Expr::Wildcard)])? + .aggregate(Vec::::new(), vec![count(wildcard())])? + .project(vec![count(wildcard())])? .build()?, ), ))? @@ -282,8 +289,8 @@ mod tests { let plan = LogicalPlanBuilder::from(table_scan_t1) .filter(exists(Arc::new( LogicalPlanBuilder::from(table_scan_t2) - .aggregate(Vec::::new(), vec![count(Expr::Wildcard)])? - .project(vec![count(Expr::Wildcard)])? + .aggregate(Vec::::new(), vec![count(wildcard())])? + .project(vec![count(wildcard())])? .build()?, )))? .build()?; @@ -335,8 +342,8 @@ mod tests { let plan = LogicalPlanBuilder::from(table_scan) .window(vec![Expr::WindowFunction(expr::WindowFunction::new( - WindowFunction::AggregateFunction(AggregateFunction::Count), - vec![Expr::Wildcard], + WindowFunctionDefinition::AggregateFunction(AggregateFunction::Count), + vec![wildcard()], vec![], vec![Expr::Sort(Sort::new(Box::new(col("a")), false, true))], WindowFrame { @@ -347,7 +354,7 @@ mod tests { end_bound: WindowFrameBound::Following(ScalarValue::UInt32(Some(2))), }, ))])? - .project(vec![count(Expr::Wildcard)])? + .project(vec![count(wildcard())])? .build()?; let expected = "Projection: COUNT(UInt8(1)) AS COUNT(*) [COUNT(*):Int64;N]\ @@ -360,8 +367,8 @@ mod tests { fn test_count_wildcard_on_aggregate() -> Result<()> { let table_scan = test_table_scan()?; let plan = LogicalPlanBuilder::from(table_scan) - .aggregate(Vec::::new(), vec![count(Expr::Wildcard)])? - .project(vec![count(Expr::Wildcard)])? + .aggregate(Vec::::new(), vec![count(wildcard())])? + .project(vec![count(wildcard())])? .build()?; let expected = "Projection: COUNT(*) [COUNT(*):Int64;N]\ @@ -374,8 +381,8 @@ mod tests { fn test_count_wildcard_on_nesting() -> Result<()> { let table_scan = test_table_scan()?; let plan = LogicalPlanBuilder::from(table_scan) - .aggregate(Vec::::new(), vec![max(count(Expr::Wildcard))])? - .project(vec![count(Expr::Wildcard)])? + .aggregate(Vec::::new(), vec![max(count(wildcard()))])? + .project(vec![count(wildcard())])? 
.build()?; let expected = "Projection: COUNT(UInt8(1)) AS COUNT(*) [COUNT(*):Int64;N]\ diff --git a/datafusion/optimizer/src/analyzer/inline_table_scan.rs b/datafusion/optimizer/src/analyzer/inline_table_scan.rs index 3d0dabdd377ce..90af7aec82935 100644 --- a/datafusion/optimizer/src/analyzer/inline_table_scan.rs +++ b/datafusion/optimizer/src/analyzer/inline_table_scan.rs @@ -126,7 +126,7 @@ fn generate_projection_expr( )); } } else { - exprs.push(Expr::Wildcard); + exprs.push(Expr::Wildcard { qualifier: None }); } Ok(exprs) } diff --git a/datafusion/optimizer/src/analyzer/subquery.rs b/datafusion/optimizer/src/analyzer/subquery.rs index 6b8b1020cd6d8..7c5b70b19af0a 100644 --- a/datafusion/optimizer/src/analyzer/subquery.rs +++ b/datafusion/optimizer/src/analyzer/subquery.rs @@ -16,10 +16,11 @@ // under the License. use crate::analyzer::check_plan; -use crate::utils::{collect_subquery_cols, split_conjunction}; +use crate::utils::collect_subquery_cols; use datafusion_common::tree_node::{TreeNode, VisitRecursion}; use datafusion_common::{plan_err, DataFusionError, Result}; use datafusion_expr::expr_rewriter::strip_outer_reference; +use datafusion_expr::utils::split_conjunction; use datafusion_expr::{ Aggregate, BinaryExpr, Cast, Expr, Filter, Join, JoinType, LogicalPlan, Operator, Window, diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index bfdbec390199c..4d54dad996703 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -28,8 +28,8 @@ use datafusion_common::{ DataFusionError, Result, ScalarValue, }; use datafusion_expr::expr::{ - self, Between, BinaryExpr, Case, Exists, InList, InSubquery, Like, ScalarFunction, - ScalarUDF, WindowFunction, + self, AggregateFunctionDefinition, Between, BinaryExpr, Case, Exists, InList, + InSubquery, Like, ScalarFunction, WindowFunction, }; use datafusion_expr::expr_rewriter::rewrite_preserving_name; use datafusion_expr::expr_schema::cast_subquery; @@ -42,15 +42,15 @@ use datafusion_expr::type_coercion::other::{ get_coerce_type_for_case_expression, get_coerce_type_for_list, }; use datafusion_expr::type_coercion::{is_datetime, is_utf8_or_large_utf8}; +use datafusion_expr::utils::merge_schema; use datafusion_expr::{ is_false, is_not_false, is_not_true, is_not_unknown, is_true, is_unknown, - type_coercion, window_function, AggregateFunction, BuiltinScalarFunction, Expr, - LogicalPlan, Operator, Projection, WindowFrame, WindowFrameBound, WindowFrameUnits, + type_coercion, AggregateFunction, BuiltinScalarFunction, Expr, ExprSchemable, + LogicalPlan, Operator, Projection, ScalarFunctionDefinition, Signature, WindowFrame, + WindowFrameBound, WindowFrameUnits, }; -use datafusion_expr::{ExprSchemable, Signature}; use crate::analyzer::AnalyzerRule; -use crate::utils::merge_schema; #[derive(Default)] pub struct TypeCoercion {} @@ -319,58 +319,66 @@ impl TreeNodeRewriter for TypeCoercionRewriter { let case = coerce_case_expression(case, &self.schema)?; Ok(Expr::Case(case)) } - Expr::ScalarUDF(ScalarUDF { fun, args }) => { - let new_expr = coerce_arguments_for_signature( - args.as_slice(), - &self.schema, - &fun.signature, - )?; - Ok(Expr::ScalarUDF(ScalarUDF::new(fun, new_expr))) - } - Expr::ScalarFunction(ScalarFunction { fun, args }) => { - let new_args = coerce_arguments_for_signature( - args.as_slice(), - &self.schema, - &fun.signature(), - )?; - let new_args = - coerce_arguments_for_fun(new_args.as_slice(), &self.schema, 
&fun)?; - Ok(Expr::ScalarFunction(ScalarFunction::new(fun, new_args))) - } + Expr::ScalarFunction(ScalarFunction { func_def, args }) => match func_def { + ScalarFunctionDefinition::BuiltIn(fun) => { + let new_args = coerce_arguments_for_signature( + args.as_slice(), + &self.schema, + &fun.signature(), + )?; + let new_args = coerce_arguments_for_fun( + new_args.as_slice(), + &self.schema, + &fun, + )?; + Ok(Expr::ScalarFunction(ScalarFunction::new(fun, new_args))) + } + ScalarFunctionDefinition::UDF(fun) => { + let new_expr = coerce_arguments_for_signature( + args.as_slice(), + &self.schema, + fun.signature(), + )?; + Ok(Expr::ScalarFunction(ScalarFunction::new_udf(fun, new_expr))) + } + ScalarFunctionDefinition::Name(_) => { + internal_err!("Function `Expr` with name should be resolved.") + } + }, Expr::AggregateFunction(expr::AggregateFunction { - fun, + func_def, args, distinct, filter, order_by, - }) => { - let new_expr = coerce_agg_exprs_for_signature( - &fun, - &args, - &self.schema, - &fun.signature(), - )?; - let expr = Expr::AggregateFunction(expr::AggregateFunction::new( - fun, new_expr, distinct, filter, order_by, - )); - Ok(expr) - } - Expr::AggregateUDF(expr::AggregateUDF { - fun, - args, - filter, - order_by, - }) => { - let new_expr = coerce_arguments_for_signature( - args.as_slice(), - &self.schema, - &fun.signature, - )?; - let expr = Expr::AggregateUDF(expr::AggregateUDF::new( - fun, new_expr, filter, order_by, - )); - Ok(expr) - } + }) => match func_def { + AggregateFunctionDefinition::BuiltIn(fun) => { + let new_expr = coerce_agg_exprs_for_signature( + &fun, + &args, + &self.schema, + &fun.signature(), + )?; + let expr = Expr::AggregateFunction(expr::AggregateFunction::new( + fun, new_expr, distinct, filter, order_by, + )); + Ok(expr) + } + AggregateFunctionDefinition::UDF(fun) => { + let new_expr = coerce_arguments_for_signature( + args.as_slice(), + &self.schema, + fun.signature(), + )?; + let expr = Expr::AggregateFunction(expr::AggregateFunction::new_udf( + fun, new_expr, false, filter, order_by, + )); + Ok(expr) + } + AggregateFunctionDefinition::Name(_) => { + internal_err!("Function `Expr` with name should be resolved.") + } + }, Expr::WindowFunction(WindowFunction { fun, args, @@ -382,7 +390,7 @@ impl TreeNodeRewriter for TypeCoercionRewriter { coerce_window_frame(window_frame, &self.schema, &order_by)?; let args = match &fun { - window_function::WindowFunction::AggregateFunction(fun) => { + expr::WindowFunctionDefinition::AggregateFunction(fun) => { coerce_agg_exprs_for_signature( fun, &args, @@ -495,7 +503,10 @@ fn coerce_window_frame( let target_type = match window_frame.units { WindowFrameUnits::Range => { if let Some(col_type) = current_types.first() { - if col_type.is_numeric() || is_utf8_or_large_utf8(col_type) { + if col_type.is_numeric() + || is_utf8_or_large_utf8(col_type) + || matches!(col_type, DataType::Null) + { col_type } else if is_datetime(col_type) { &DataType::Interval(IntervalUnit::MonthDayNano) @@ -579,26 +590,6 @@ fn coerce_arguments_for_fun( .collect::>>()?; } - if *fun == BuiltinScalarFunction::MakeArray { - // Find the final data type for the function arguments - let current_types = expressions - .iter() - .map(|e| e.get_type(schema)) - .collect::>>()?; - - let new_type = current_types - .iter() - .skip(1) - .fold(current_types.first().unwrap().clone(), |acc, x| { - comparison_coercion(&acc, x).unwrap_or(acc) - }); - - return expressions - .iter() - .zip(current_types) - .map(|(expr, from_type)| cast_array_expr(expr, &from_type, 
&new_type, schema)) - .collect(); - } Ok(expressions) } @@ -607,20 +598,6 @@ fn cast_expr(expr: &Expr, to_type: &DataType, schema: &DFSchema) -> Result expr.clone().cast_to(to_type, schema) } -/// Cast array `expr` to the specified type, if possible -fn cast_array_expr( - expr: &Expr, - from_type: &DataType, - to_type: &DataType, - schema: &DFSchema, -) -> Result { - if from_type.equals_datatype(&DataType::Null) { - Ok(expr.clone()) - } else { - cast_expr(expr, to_type, schema) - } -} - /// Returns the coerced exprs for each `input_exprs`. /// Get the coerced data type from `aggregate_rule::coerce_types` and add `try_cast` if the /// data type of `input_exprs` need to be coerced. @@ -761,8 +738,10 @@ fn coerce_case_expression(case: Case, schema: &DFSchemaRef) -> Result { #[cfg(test)] mod test { - use std::sync::Arc; + use std::any::Any; + use std::sync::{Arc, OnceLock}; + use arrow::array::{FixedSizeListArray, Int32Array}; use arrow::datatypes::{DataType, TimeUnit}; use arrow::datatypes::Field; @@ -772,13 +751,13 @@ mod test { use datafusion_expr::{ cast, col, concat, concat_ws, create_udaf, is_true, AccumulatorFactoryFunction, AggregateFunction, AggregateUDF, BinaryExpr, BuiltinScalarFunction, Case, - ColumnarValue, ExprSchemable, Filter, Operator, StateTypeFunction, Subquery, + ColumnarValue, ExprSchemable, Filter, Operator, ScalarUDFImpl, StateTypeFunction, + Subquery, }; use datafusion_expr::{ lit, logical_plan::{EmptyRelation, Projection}, - Expr, LogicalPlan, ReturnTypeFunction, ScalarFunctionImplementation, ScalarUDF, - Signature, Volatility, + Expr, LogicalPlan, ReturnTypeFunction, ScalarUDF, Signature, Volatility, }; use datafusion_physical_expr::expressions::AvgAccumulator; @@ -830,22 +809,36 @@ mod test { assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), &plan, expected) } + static TEST_SIGNATURE: OnceLock = OnceLock::new(); + + struct TestScalarUDF {} + impl ScalarUDFImpl for TestScalarUDF { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "TestScalarUDF" + } + fn signature(&self) -> &Signature { + TEST_SIGNATURE.get_or_init(|| { + Signature::uniform(1, vec![DataType::Float32], Volatility::Stable) + }) + } + fn return_type(&self, _args: &[DataType]) -> Result { + Ok(DataType::Utf8) + } + + fn invoke(&self, _args: &[ColumnarValue]) -> Result { + Ok(ColumnarValue::Scalar(ScalarValue::from("a"))) + } + } + #[test] fn scalar_udf() -> Result<()> { let empty = empty(); - let return_type: ReturnTypeFunction = - Arc::new(move |_| Ok(Arc::new(DataType::Utf8))); - let fun: ScalarFunctionImplementation = - Arc::new(move |_| Ok(ColumnarValue::Scalar(ScalarValue::new_utf8("a")))); - let udf = Expr::ScalarUDF(expr::ScalarUDF::new( - Arc::new(ScalarUDF::new( - "TestScalarUDF", - &Signature::uniform(1, vec![DataType::Float32], Volatility::Stable), - &return_type, - &fun, - )), - vec![lit(123_i32)], - )); + + let udf = ScalarUDF::from(TestScalarUDF {}).call(vec![lit(123_i32)]); let plan = LogicalPlan::Projection(Projection::try_new(vec![udf], empty)?); let expected = "Projection: TestScalarUDF(CAST(Int32(123) AS Float32))\n EmptyRelation"; @@ -855,26 +848,15 @@ mod test { #[test] fn scalar_udf_invalid_input() -> Result<()> { let empty = empty(); - let return_type: ReturnTypeFunction = - Arc::new(move |_| Ok(Arc::new(DataType::Utf8))); - let fun: ScalarFunctionImplementation = Arc::new(move |_| unimplemented!()); - let udf = Expr::ScalarUDF(expr::ScalarUDF::new( - Arc::new(ScalarUDF::new( - "TestScalarUDF", - &Signature::uniform(1, vec![DataType::Int32], 
Volatility::Stable), - &return_type, - &fun, - )), - vec![lit("Apple")], - )); + let udf = ScalarUDF::from(TestScalarUDF {}).call(vec![lit("Apple")]); let plan = LogicalPlan::Projection(Projection::try_new(vec![udf], empty)?); let err = assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), &plan, "") .err() .unwrap(); assert_eq!( - "type_coercion\ncaused by\nError during planning: Coercion from [Utf8] to the signature Uniform(1, [Int32]) failed.", - err.strip_backtrace() - ); + "type_coercion\ncaused by\nError during planning: Coercion from [Utf8] to the signature Uniform(1, [Float32]) failed.", + err.strip_backtrace() + ); Ok(()) } @@ -905,9 +887,10 @@ mod test { Arc::new(|_| Ok(Box::::default())), Arc::new(vec![DataType::UInt64, DataType::Float64]), ); - let udaf = Expr::AggregateUDF(expr::AggregateUDF::new( + let udaf = Expr::AggregateFunction(expr::AggregateFunction::new_udf( Arc::new(my_avg), vec![lit(10i64)], + false, None, None, )); @@ -932,9 +915,10 @@ mod test { &accumulator, &state_type, ); - let udaf = Expr::AggregateUDF(expr::AggregateUDF::new( + let udaf = Expr::AggregateFunction(expr::AggregateFunction::new_udf( Arc::new(my_avg), vec![lit("10")], + false, None, None, )); @@ -1237,19 +1221,18 @@ mod test { #[test] fn test_casting_for_fixed_size_list() -> Result<()> { - let val = lit(ScalarValue::Fixedsizelist( - Some(vec![ - ScalarValue::from(1i32), - ScalarValue::from(2i32), - ScalarValue::from(3i32), - ]), - Arc::new(Field::new("item", DataType::Int32, true)), - 3, + let val = lit(ScalarValue::FixedSizeList(Arc::new( + FixedSizeListArray::new( + Arc::new(Field::new("item", DataType::Int32, true)), + 3, + Arc::new(Int32Array::from(vec![1, 2, 3])), + None, + ), + ))); + let expr = Expr::ScalarFunction(ScalarFunction::new( + BuiltinScalarFunction::MakeArray, + vec![val.clone()], )); - let expr = Expr::ScalarFunction(ScalarFunction { - fun: BuiltinScalarFunction::MakeArray, - args: vec![val.clone()], - }); let schema = Arc::new(DFSchema::new_with_metadata( vec![DFField::new_unqualified( "item", @@ -1278,10 +1261,10 @@ mod test { &schema, )?; - let expected = Expr::ScalarFunction(ScalarFunction { - fun: BuiltinScalarFunction::MakeArray, - args: vec![expected_casted_expr], - }); + let expected = Expr::ScalarFunction(ScalarFunction::new( + BuiltinScalarFunction::MakeArray, + vec![expected_casted_expr], + )); assert_eq!(result, expected); Ok(()) diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index 68a6a5607a1da..1e089257c61ad 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -29,7 +29,7 @@ use datafusion_common::tree_node::{ use datafusion_common::{ internal_err, Column, DFField, DFSchema, DFSchemaRef, DataFusionError, Result, }; -use datafusion_expr::expr::Alias; +use datafusion_expr::expr::{is_volatile, Alias}; use datafusion_expr::logical_plan::{ Aggregate, Filter, LogicalPlan, Projection, Sort, Window, }; @@ -113,6 +113,8 @@ impl CommonSubexprEliminate { let Projection { expr, input, .. } = projection; let input_schema = Arc::clone(input.schema()); let mut expr_set = ExprSet::new(); + + // Visit expr list and build expr identifier to occuring count map (`expr_set`). 
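The `ExprMask::ignores` change below additionally skips volatile expressions so they are never deduplicated. A small sketch of the `is_volatile` helper this relies on, assuming the `random()` expression builder is available from `datafusion_expr` and that `is_volatile` behaves as imported in this hunk:

```rust
use datafusion_common::Result;
use datafusion_expr::expr::is_volatile;
use datafusion_expr::{col, lit, random};

fn cse_candidates_example() -> Result<()> {
    // random() may change value on every evaluation, so the rule must never
    // share a single occurrence of it between expressions.
    assert!(is_volatile(&random())?);
    // A deterministic expression remains a valid common-subexpression candidate.
    assert!(!is_volatile(&(col("a") + lit(1)))?);
    Ok(())
}
```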
let arrays = to_arrays(expr, input_schema, &mut expr_set, ExprMask::Normal)?; let (mut new_expr, new_input) = @@ -238,6 +240,14 @@ impl CommonSubexprEliminate { let rewritten = pop_expr(&mut rewritten)?; if affected_id.is_empty() { + // Alias aggregation expressions if they have changed + let new_aggr_expr = new_aggr_expr + .iter() + .zip(aggr_expr.iter()) + .map(|(new_expr, old_expr)| { + new_expr.clone().alias_if_changed(old_expr.display_name()?) + }) + .collect::>>()?; // Since group_epxr changes, schema changes also. Use try_new method. Aggregate::try_new(Arc::new(new_input), new_group_expr, new_aggr_expr) .map(LogicalPlan::Aggregate) @@ -367,7 +377,7 @@ impl OptimizerRule for CommonSubexprEliminate { Ok(Some(build_recover_project_plan( &original_schema, optimized_plan, - ))) + )?)) } plan => Ok(plan), } @@ -458,16 +468,19 @@ fn build_common_expr_project_plan( /// the "intermediate" projection plan built in [build_common_expr_project_plan]. /// /// This is for those plans who don't keep its own output schema like `Filter` or `Sort`. -fn build_recover_project_plan(schema: &DFSchema, input: LogicalPlan) -> LogicalPlan { +fn build_recover_project_plan( + schema: &DFSchema, + input: LogicalPlan, +) -> Result { let col_exprs = schema .fields() .iter() .map(|field| Expr::Column(field.qualified_column())) .collect(); - LogicalPlan::Projection( - Projection::try_new(col_exprs, Arc::new(input)) - .expect("Cannot build projection plan from an invalid schema"), - ) + Ok(LogicalPlan::Projection(Projection::try_new( + col_exprs, + Arc::new(input), + )?)) } fn extract_expressions( @@ -498,15 +511,14 @@ enum ExprMask { /// - [`Sort`](Expr::Sort) /// - [`Wildcard`](Expr::Wildcard) /// - [`AggregateFunction`](Expr::AggregateFunction) - /// - [`AggregateUDF`](Expr::AggregateUDF) Normal, - /// Like [`Normal`](Self::Normal), but includes [`AggregateFunction`](Expr::AggregateFunction) and [`AggregateUDF`](Expr::AggregateUDF). + /// Like [`Normal`](Self::Normal), but includes [`AggregateFunction`](Expr::AggregateFunction). NormalAndAggregates, } impl ExprMask { - fn ignores(&self, expr: &Expr) -> bool { + fn ignores(&self, expr: &Expr) -> Result { let is_normal_minus_aggregates = matches!( expr, Expr::Literal(..) @@ -514,18 +526,17 @@ impl ExprMask { | Expr::ScalarVariable(..) | Expr::Alias(..) | Expr::Sort { .. } - | Expr::Wildcard + | Expr::Wildcard { .. } ); - let is_aggr = matches!( - expr, - Expr::AggregateFunction(..) | Expr::AggregateUDF { .. } - ); + let is_volatile = is_volatile(expr)?; - match self { - Self::Normal => is_normal_minus_aggregates || is_aggr, - Self::NormalAndAggregates => is_normal_minus_aggregates, - } + let is_aggr = matches!(expr, Expr::AggregateFunction(..)); + + Ok(match self { + Self::Normal => is_volatile || is_normal_minus_aggregates || is_aggr, + Self::NormalAndAggregates => is_volatile || is_normal_minus_aggregates, + }) } } @@ -617,7 +628,7 @@ impl TreeNodeVisitor for ExprIdentifierVisitor<'_> { let (idx, sub_expr_desc) = self.pop_enter_mark(); // skip exprs should not be recognize. - if self.expr_mask.ignores(expr) { + if self.expr_mask.ignores(expr)? 
{ self.id_array[idx].0 = self.series_number; let desc = Self::desc_expr(expr); self.visit_stack.push(VisitRecord::ExprItem(desc)); @@ -897,7 +908,7 @@ mod test { let accumulator: AccumulatorFactoryFunction = Arc::new(|_| unimplemented!()); let state_type: StateTypeFunction = Arc::new(|_| unimplemented!()); let udf_agg = |inner: Expr| { - Expr::AggregateUDF(datafusion_expr::expr::AggregateUDF::new( + Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction::new_udf( Arc::new(AggregateUDF::new( "my_agg", &Signature::exact(vec![DataType::UInt32], Volatility::Stable), @@ -906,6 +917,7 @@ mod test { &state_type, )), vec![inner], + false, None, None, )) diff --git a/datafusion/optimizer/src/decorrelate.rs b/datafusion/optimizer/src/decorrelate.rs index b5cf737338969..b1000f042c987 100644 --- a/datafusion/optimizer/src/decorrelate.rs +++ b/datafusion/optimizer/src/decorrelate.rs @@ -16,15 +16,14 @@ // under the License. use crate::simplify_expressions::{ExprSimplifier, SimplifyContext}; -use crate::utils::{ - collect_subquery_cols, conjunction, find_join_exprs, split_conjunction, -}; +use crate::utils::collect_subquery_cols; use datafusion_common::tree_node::{ RewriteRecursion, Transformed, TreeNode, TreeNodeRewriter, }; use datafusion_common::{plan_err, Result}; use datafusion_common::{Column, DFSchemaRef, DataFusionError, ScalarValue}; -use datafusion_expr::expr::Alias; +use datafusion_expr::expr::{AggregateFunctionDefinition, Alias}; +use datafusion_expr::utils::{conjunction, find_join_exprs, split_conjunction}; use datafusion_expr::{expr, EmptyRelation, Expr, LogicalPlan, LogicalPlanBuilder}; use datafusion_physical_expr::execution_props::ExecutionProps; use std::collections::{BTreeSet, HashMap}; @@ -227,10 +226,9 @@ impl TreeNodeRewriter for PullUpCorrelatedExpr { )?; if !expr_result_map_for_count_bug.is_empty() { // has count bug - let un_matched_row = Expr::Alias(Alias::new( - Expr::Literal(ScalarValue::Boolean(Some(true))), - UN_MATCHED_ROW_INDICATOR.to_string(), - )); + let un_matched_row = + Expr::Literal(ScalarValue::Boolean(Some(true))) + .alias(UN_MATCHED_ROW_INDICATOR); // add the unmatched rows indicator to the Aggregation's group expressions missing_exprs.push(un_matched_row); } @@ -374,16 +372,25 @@ fn agg_exprs_evaluation_result_on_empty_batch( for e in agg_expr.iter() { let result_expr = e.clone().transform_up(&|expr| { let new_expr = match expr { - Expr::AggregateFunction(expr::AggregateFunction { fun, .. }) => { - if matches!(fun, datafusion_expr::AggregateFunction::Count) { - Transformed::Yes(Expr::Literal(ScalarValue::Int64(Some(0)))) - } else { - Transformed::Yes(Expr::Literal(ScalarValue::Null)) + Expr::AggregateFunction(expr::AggregateFunction { func_def, .. }) => { + match func_def { + AggregateFunctionDefinition::BuiltIn(fun) => { + if matches!(fun, datafusion_expr::AggregateFunction::Count) { + Transformed::Yes(Expr::Literal(ScalarValue::Int64(Some( + 0, + )))) + } else { + Transformed::Yes(Expr::Literal(ScalarValue::Null)) + } + } + AggregateFunctionDefinition::UDF { .. 
} => { + Transformed::Yes(Expr::Literal(ScalarValue::Null)) + } + AggregateFunctionDefinition::Name(_) => { + Transformed::Yes(Expr::Literal(ScalarValue::Null)) + } } } - Expr::AggregateUDF(_) => { - Transformed::Yes(Expr::Literal(ScalarValue::Null)) - } _ => Transformed::No(expr), }; Ok(new_expr) diff --git a/datafusion/optimizer/src/decorrelate_predicate_subquery.rs b/datafusion/optimizer/src/decorrelate_predicate_subquery.rs index 96b46663d8e47..450336376a239 100644 --- a/datafusion/optimizer/src/decorrelate_predicate_subquery.rs +++ b/datafusion/optimizer/src/decorrelate_predicate_subquery.rs @@ -17,7 +17,7 @@ use crate::decorrelate::PullUpCorrelatedExpr; use crate::optimizer::ApplyOrder; -use crate::utils::{conjunction, replace_qualified_name, split_conjunction}; +use crate::utils::replace_qualified_name; use crate::{OptimizerConfig, OptimizerRule}; use datafusion_common::alias::AliasGenerator; use datafusion_common::tree_node::TreeNode; @@ -25,6 +25,7 @@ use datafusion_common::{plan_err, DataFusionError, Result}; use datafusion_expr::expr::{Exists, InSubquery}; use datafusion_expr::expr_rewriter::create_col_from_scalar_expr; use datafusion_expr::logical_plan::{JoinType, Subquery}; +use datafusion_expr::utils::{conjunction, split_conjunction}; use datafusion_expr::{ exists, in_subquery, not_exists, not_in_subquery, BinaryExpr, Expr, Filter, LogicalPlan, LogicalPlanBuilder, Operator, diff --git a/datafusion/optimizer/src/eliminate_cross_join.rs b/datafusion/optimizer/src/eliminate_cross_join.rs index cf9a59d6b892f..d9e96a9f2543a 100644 --- a/datafusion/optimizer/src/eliminate_cross_join.rs +++ b/datafusion/optimizer/src/eliminate_cross_join.rs @@ -20,6 +20,7 @@ use std::collections::HashSet; use std::sync::Arc; use crate::{utils, OptimizerConfig, OptimizerRule}; + use datafusion_common::{plan_err, DataFusionError, Result}; use datafusion_expr::expr::{BinaryExpr, Expr}; use datafusion_expr::logical_plan::{ @@ -44,84 +45,97 @@ impl EliminateCrossJoin { /// 'select ... from a, b where (a.x = b.y and b.xx = 100) or (a.x = b.y and b.xx = 200);' /// 'select ... from a, b, c where (a.x = b.y and b.xx = 100 and a.z = c.z) /// or (a.x = b.y and b.xx = 200 and a.z=c.z);' +/// 'select ... from a, b where a.x > b.y' /// For above queries, the join predicate is available in filters and they are moved to /// join nodes appropriately /// This fix helps to improve the performance of TPCH Q19. issue#78 -/// impl OptimizerRule for EliminateCrossJoin { fn try_optimize( &self, plan: &LogicalPlan, config: &dyn OptimizerConfig, ) -> Result> { - match plan { + let mut possible_join_keys: Vec<(Expr, Expr)> = vec![]; + let mut all_inputs: Vec = vec![]; + let parent_predicate = match plan { LogicalPlan::Filter(filter) => { - let input = filter.input.as_ref().clone(); - - let mut possible_join_keys: Vec<(Expr, Expr)> = vec![]; - let mut all_inputs: Vec = vec![]; - let did_flat_successfully = match &input { + let input = filter.input.as_ref(); + match input { LogicalPlan::Join(Join { join_type: JoinType::Inner, .. }) - | LogicalPlan::CrossJoin(_) => try_flatten_join_inputs( - &input, - &mut possible_join_keys, - &mut all_inputs, - )?, + | LogicalPlan::CrossJoin(_) => { + if !try_flatten_join_inputs( + input, + &mut possible_join_keys, + &mut all_inputs, + )? 
{ + return Ok(None); + } + extract_possible_join_keys( + &filter.predicate, + &mut possible_join_keys, + )?; + Some(&filter.predicate) + } _ => { return utils::optimize_children(self, plan, config); } - }; - - if !did_flat_successfully { + } + } + LogicalPlan::Join(Join { + join_type: JoinType::Inner, + .. + }) => { + if !try_flatten_join_inputs( + plan, + &mut possible_join_keys, + &mut all_inputs, + )? { return Ok(None); } + None + } + _ => return utils::optimize_children(self, plan, config), + }; - let predicate = &filter.predicate; - // join keys are handled locally - let mut all_join_keys: HashSet<(Expr, Expr)> = HashSet::new(); - - extract_possible_join_keys(predicate, &mut possible_join_keys)?; + // Join keys are handled locally: + let mut all_join_keys = HashSet::<(Expr, Expr)>::new(); + let mut left = all_inputs.remove(0); + while !all_inputs.is_empty() { + left = find_inner_join( + &left, + &mut all_inputs, + &mut possible_join_keys, + &mut all_join_keys, + )?; + } - let mut left = all_inputs.remove(0); - while !all_inputs.is_empty() { - left = find_inner_join( - &left, - &mut all_inputs, - &mut possible_join_keys, - &mut all_join_keys, - )?; - } + left = utils::optimize_children(self, &left, config)?.unwrap_or(left); - left = utils::optimize_children(self, &left, config)?.unwrap_or(left); + if plan.schema() != left.schema() { + left = LogicalPlan::Projection(Projection::new_from_schema( + Arc::new(left), + plan.schema().clone(), + )); + } - if plan.schema() != left.schema() { - left = LogicalPlan::Projection(Projection::new_from_schema( - Arc::new(left.clone()), - plan.schema().clone(), - )); - } + let Some(predicate) = parent_predicate else { + return Ok(Some(left)); + }; - // if there are no join keys then do nothing. - if all_join_keys.is_empty() { - Ok(Some(LogicalPlan::Filter(Filter::try_new( - predicate.clone(), - Arc::new(left), - )?))) - } else { - // remove join expressions from filter - match remove_join_expressions(predicate, &all_join_keys)? { - Some(filter_expr) => Ok(Some(LogicalPlan::Filter( - Filter::try_new(filter_expr, Arc::new(left))?, - ))), - _ => Ok(Some(left)), - } - } + // If there are no join keys then do nothing: + if all_join_keys.is_empty() { + Filter::try_new(predicate.clone(), Arc::new(left)) + .map(|f| Some(LogicalPlan::Filter(f))) + } else { + // Remove join expressions from filter: + match remove_join_expressions(predicate, &all_join_keys)? { + Some(filter_expr) => Filter::try_new(filter_expr, Arc::new(left)) + .map(|f| Some(LogicalPlan::Filter(f))), + _ => Ok(Some(left)), } - - _ => utils::optimize_children(self, plan, config), } } @@ -325,17 +339,16 @@ fn remove_join_expressions( #[cfg(test)] mod tests { + use super::*; + use crate::optimizer::OptimizerContext; + use crate::test::*; + use datafusion_expr::{ binary_expr, col, lit, logical_plan::builder::LogicalPlanBuilder, Operator::{And, Or}, }; - use crate::optimizer::OptimizerContext; - use crate::test::*; - - use super::*; - fn assert_optimized_plan_eq(plan: &LogicalPlan, expected: Vec<&str>) { let rule = EliminateCrossJoin::new(); let optimized_plan = rule diff --git a/datafusion/optimizer/src/eliminate_filter.rs b/datafusion/optimizer/src/eliminate_filter.rs index c97906a81adf1..fea14342ca774 100644 --- a/datafusion/optimizer/src/eliminate_filter.rs +++ b/datafusion/optimizer/src/eliminate_filter.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -//! Optimizer rule to replace `where false` on a plan with an empty relation. +//! 
Optimizer rule to replace `where false or null` on a plan with an empty relation. //! This saves time in planning and executing the query. //! Note that this rule should be applied after simplify expressions optimizer rule. use crate::optimizer::ApplyOrder; @@ -27,7 +27,7 @@ use datafusion_expr::{ use crate::{OptimizerConfig, OptimizerRule}; -/// Optimization rule that eliminate the scalar value (true/false) filter with an [LogicalPlan::EmptyRelation] +/// Optimization rule that eliminate the scalar value (true/false/null) filter with an [LogicalPlan::EmptyRelation] #[derive(Default)] pub struct EliminateFilter; @@ -46,20 +46,22 @@ impl OptimizerRule for EliminateFilter { ) -> Result> { match plan { LogicalPlan::Filter(Filter { - predicate: Expr::Literal(ScalarValue::Boolean(Some(v))), + predicate: Expr::Literal(ScalarValue::Boolean(v)), input, .. }) => { match *v { // input also can be filter, apply again - true => Ok(Some( + Some(true) => Ok(Some( self.try_optimize(input, _config)? .unwrap_or_else(|| input.as_ref().clone()), )), - false => Ok(Some(LogicalPlan::EmptyRelation(EmptyRelation { - produce_one_row: false, - schema: input.schema().clone(), - }))), + Some(false) | None => { + Ok(Some(LogicalPlan::EmptyRelation(EmptyRelation { + produce_one_row: false, + schema: input.schema().clone(), + }))) + } } } _ => Ok(None), @@ -105,6 +107,21 @@ mod tests { assert_optimized_plan_equal(&plan, expected) } + #[test] + fn filter_null() -> Result<()> { + let filter_expr = Expr::Literal(ScalarValue::Boolean(None)); + + let table_scan = test_table_scan().unwrap(); + let plan = LogicalPlanBuilder::from(table_scan) + .aggregate(vec![col("a")], vec![sum(col("b"))])? + .filter(filter_expr)? + .build()?; + + // No aggregate / scan / limit + let expected = "EmptyRelation"; + assert_optimized_plan_equal(&plan, expected) + } + #[test] fn filter_false_nested() -> Result<()> { let filter_expr = Expr::Literal(ScalarValue::Boolean(Some(false))); diff --git a/datafusion/optimizer/src/eliminate_limit.rs b/datafusion/optimizer/src/eliminate_limit.rs index 7844ca7909fce..4386253740aaa 100644 --- a/datafusion/optimizer/src/eliminate_limit.rs +++ b/datafusion/optimizer/src/eliminate_limit.rs @@ -97,7 +97,7 @@ mod tests { let optimizer = Optimizer::with_rules(vec![Arc::new(EliminateLimit::new())]); let optimized_plan = optimizer .optimize_recursively( - optimizer.rules.get(0).unwrap(), + optimizer.rules.first().unwrap(), plan, &OptimizerContext::new(), )? 
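// Editorial note on the EliminateFilter change above, illustrative only: a literal
// NULL predicate can never evaluate to true, so it is now folded away exactly like
// `false`. For example, a plan along the lines of
//
//     Filter: NULL
//       Aggregate: groupBy=[[test.a]], aggr=[[SUM(test.b)]]
//         TableScan: test
//
// collapses to a bare `EmptyRelation`, which is what the new `filter_null` test asserts.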
diff --git a/datafusion/optimizer/src/eliminate_nested_union.rs b/datafusion/optimizer/src/eliminate_nested_union.rs index 89bcc90bc0752..5771ea2e19a29 100644 --- a/datafusion/optimizer/src/eliminate_nested_union.rs +++ b/datafusion/optimizer/src/eliminate_nested_union.rs @@ -52,7 +52,7 @@ impl OptimizerRule for EliminateNestedUnion { schema: schema.clone(), }))) } - LogicalPlan::Distinct(Distinct { input: plan }) => match plan.as_ref() { + LogicalPlan::Distinct(Distinct::All(plan)) => match plan.as_ref() { LogicalPlan::Union(Union { inputs, schema }) => { let inputs = inputs .iter() @@ -60,12 +60,12 @@ impl OptimizerRule for EliminateNestedUnion { .flat_map(extract_plans_from_union) .collect::>(); - Ok(Some(LogicalPlan::Distinct(Distinct { - input: Arc::new(LogicalPlan::Union(Union { + Ok(Some(LogicalPlan::Distinct(Distinct::All(Arc::new( + LogicalPlan::Union(Union { inputs, schema: schema.clone(), - })), - }))) + }), + ))))) } _ => Ok(None), }, @@ -94,7 +94,7 @@ fn extract_plans_from_union(plan: &Arc) -> Vec> { fn extract_plan_from_distinct(plan: &Arc) -> &Arc { match plan.as_ref() { - LogicalPlan::Distinct(Distinct { input: plan }) => plan, + LogicalPlan::Distinct(Distinct::All(plan)) => plan, _ => plan, } } diff --git a/datafusion/optimizer/src/eliminate_outer_join.rs b/datafusion/optimizer/src/eliminate_outer_join.rs index e4d57f0209a46..53c4b3702b1e7 100644 --- a/datafusion/optimizer/src/eliminate_outer_join.rs +++ b/datafusion/optimizer/src/eliminate_outer_join.rs @@ -106,7 +106,8 @@ impl OptimizerRule for EliminateOuterJoin { schema: join.schema.clone(), null_equals_null: join.null_equals_null, }); - let new_plan = plan.with_new_inputs(&[new_join])?; + let new_plan = + plan.with_new_exprs(plan.expressions(), &[new_join])?; Ok(Some(new_plan)) } _ => Ok(None), diff --git a/datafusion/optimizer/src/eliminate_project.rs b/datafusion/optimizer/src/eliminate_project.rs deleted file mode 100644 index d3226eaa78cf8..0000000000000 --- a/datafusion/optimizer/src/eliminate_project.rs +++ /dev/null @@ -1,94 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use crate::optimizer::ApplyOrder; -use crate::{OptimizerConfig, OptimizerRule}; -use datafusion_common::{DFSchemaRef, Result}; -use datafusion_expr::logical_plan::LogicalPlan; -use datafusion_expr::{Expr, Projection}; - -/// Optimization rule that eliminate unnecessary [LogicalPlan::Projection]. 
-#[derive(Default)] -pub struct EliminateProjection; - -impl EliminateProjection { - #[allow(missing_docs)] - pub fn new() -> Self { - Self {} - } -} - -impl OptimizerRule for EliminateProjection { - fn try_optimize( - &self, - plan: &LogicalPlan, - _config: &dyn OptimizerConfig, - ) -> Result> { - match plan { - LogicalPlan::Projection(projection) => { - let child_plan = projection.input.as_ref(); - match child_plan { - LogicalPlan::Union(_) - | LogicalPlan::Filter(_) - | LogicalPlan::TableScan(_) - | LogicalPlan::SubqueryAlias(_) - | LogicalPlan::Sort(_) => { - if can_eliminate(projection, child_plan.schema()) { - Ok(Some(child_plan.clone())) - } else { - Ok(None) - } - } - _ => { - if plan.schema() == child_plan.schema() { - Ok(Some(child_plan.clone())) - } else { - Ok(None) - } - } - } - } - _ => Ok(None), - } - } - - fn name(&self) -> &str { - "eliminate_projection" - } - - fn apply_order(&self) -> Option { - Some(ApplyOrder::TopDown) - } -} - -pub(crate) fn can_eliminate(projection: &Projection, schema: &DFSchemaRef) -> bool { - if projection.expr.len() != schema.fields().len() { - return false; - } - for (i, e) in projection.expr.iter().enumerate() { - match e { - Expr::Column(c) => { - let d = schema.fields().get(i).unwrap(); - if c != &d.qualified_column() && c != &d.unqualified_column() { - return false; - } - } - _ => return false, - } - } - true -} diff --git a/datafusion/optimizer/src/extract_equijoin_predicate.rs b/datafusion/optimizer/src/extract_equijoin_predicate.rs index 575969fbf73cf..24664d57c38d8 100644 --- a/datafusion/optimizer/src/extract_equijoin_predicate.rs +++ b/datafusion/optimizer/src/extract_equijoin_predicate.rs @@ -17,11 +17,10 @@ //! [`ExtractEquijoinPredicate`] rule that extracts equijoin predicates use crate::optimizer::ApplyOrder; -use crate::utils::split_conjunction; use crate::{OptimizerConfig, OptimizerRule}; use datafusion_common::DFSchema; use datafusion_common::Result; -use datafusion_expr::utils::{can_hash, find_valid_equijoin_key_pair}; +use datafusion_expr::utils::{can_hash, find_valid_equijoin_key_pair, split_conjunction}; use datafusion_expr::{BinaryExpr, Expr, ExprSchemable, Join, LogicalPlan, Operator}; use std::sync::Arc; diff --git a/datafusion/optimizer/src/lib.rs b/datafusion/optimizer/src/lib.rs index ede0ac5c71643..b54facc5d6825 100644 --- a/datafusion/optimizer/src/lib.rs +++ b/datafusion/optimizer/src/lib.rs @@ -27,10 +27,9 @@ pub mod eliminate_limit; pub mod eliminate_nested_union; pub mod eliminate_one_union; pub mod eliminate_outer_join; -pub mod eliminate_project; pub mod extract_equijoin_predicate; pub mod filter_null_join_keys; -pub mod merge_projection; +pub mod optimize_projections; pub mod optimizer; pub mod propagate_empty_relation; pub mod push_down_filter; diff --git a/datafusion/optimizer/src/merge_projection.rs b/datafusion/optimizer/src/merge_projection.rs deleted file mode 100644 index ec040cba6fe4e..0000000000000 --- a/datafusion/optimizer/src/merge_projection.rs +++ /dev/null @@ -1,169 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. 
You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use std::collections::HashMap; - -use crate::optimizer::ApplyOrder; -use crate::push_down_filter::replace_cols_by_name; -use crate::{OptimizerConfig, OptimizerRule}; - -use datafusion_common::Result; -use datafusion_expr::{Expr, LogicalPlan, Projection}; - -/// Optimization rule that merge [LogicalPlan::Projection]. -#[derive(Default)] -pub struct MergeProjection; - -impl MergeProjection { - #[allow(missing_docs)] - pub fn new() -> Self { - Self {} - } -} - -impl OptimizerRule for MergeProjection { - fn try_optimize( - &self, - plan: &LogicalPlan, - _config: &dyn OptimizerConfig, - ) -> Result> { - match plan { - LogicalPlan::Projection(parent_projection) => { - match parent_projection.input.as_ref() { - LogicalPlan::Projection(child_projection) => { - let new_plan = - merge_projection(parent_projection, child_projection)?; - Ok(Some( - self.try_optimize(&new_plan, _config)?.unwrap_or(new_plan), - )) - } - _ => Ok(None), - } - } - _ => Ok(None), - } - } - - fn name(&self) -> &str { - "merge_projection" - } - - fn apply_order(&self) -> Option { - Some(ApplyOrder::TopDown) - } -} - -pub(super) fn merge_projection( - parent_projection: &Projection, - child_projection: &Projection, -) -> Result { - let replace_map = collect_projection_expr(child_projection); - let new_exprs = parent_projection - .expr - .iter() - .map(|expr| replace_cols_by_name(expr.clone(), &replace_map)) - .enumerate() - .map(|(i, e)| match e { - Ok(e) => { - let parent_expr = parent_projection.schema.fields()[i].qualified_name(); - e.alias_if_changed(parent_expr) - } - Err(e) => Err(e), - }) - .collect::>>()?; - // Use try_new, since schema changes with changing expressions. - let new_plan = LogicalPlan::Projection(Projection::try_new( - new_exprs, - child_projection.input.clone(), - )?); - Ok(new_plan) -} - -pub fn collect_projection_expr(projection: &Projection) -> HashMap { - projection - .schema - .fields() - .iter() - .enumerate() - .flat_map(|(i, field)| { - // strip alias - let expr = projection.expr[i].clone().unalias(); - // Convert both qualified and unqualified fields - [ - (field.name().clone(), expr.clone()), - (field.qualified_name(), expr), - ] - }) - .collect::>() -} - -#[cfg(test)] -mod tests { - use crate::merge_projection::MergeProjection; - use datafusion_common::Result; - use datafusion_expr::{ - binary_expr, col, lit, logical_plan::builder::LogicalPlanBuilder, LogicalPlan, - Operator, - }; - use std::sync::Arc; - - use crate::test::*; - - fn assert_optimized_plan_equal(plan: &LogicalPlan, expected: &str) -> Result<()> { - assert_optimized_plan_eq(Arc::new(MergeProjection::new()), plan, expected) - } - - #[test] - fn merge_two_projection() -> Result<()> { - let table_scan = test_table_scan()?; - let plan = LogicalPlanBuilder::from(table_scan) - .project(vec![col("a")])? - .project(vec![binary_expr(lit(1), Operator::Plus, col("a"))])? 
- .build()?; - - let expected = "Projection: Int32(1) + test.a\ - \n TableScan: test"; - assert_optimized_plan_equal(&plan, expected) - } - - #[test] - fn merge_three_projection() -> Result<()> { - let table_scan = test_table_scan()?; - let plan = LogicalPlanBuilder::from(table_scan) - .project(vec![col("a"), col("b")])? - .project(vec![col("a")])? - .project(vec![binary_expr(lit(1), Operator::Plus, col("a"))])? - .build()?; - - let expected = "Projection: Int32(1) + test.a\ - \n TableScan: test"; - assert_optimized_plan_equal(&plan, expected) - } - - #[test] - fn merge_alias() -> Result<()> { - let table_scan = test_table_scan()?; - let plan = LogicalPlanBuilder::from(table_scan) - .project(vec![col("a")])? - .project(vec![col("a").alias("alias")])? - .build()?; - - let expected = "Projection: test.a AS alias\ - \n TableScan: test"; - assert_optimized_plan_equal(&plan, expected) - } -} diff --git a/datafusion/optimizer/src/optimize_projections.rs b/datafusion/optimizer/src/optimize_projections.rs new file mode 100644 index 0000000000000..891a909a3378b --- /dev/null +++ b/datafusion/optimizer/src/optimize_projections.rs @@ -0,0 +1,1063 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Optimizer rule to prune unnecessary columns from intermediate schemas +//! inside the [`LogicalPlan`]. This rule: +//! - Removes unnecessary columns that do not appear at the output and/or are +//! not used during any computation step. +//! - Adds projections to decrease table column size before operators that +//! benefit from a smaller memory footprint at its input. +//! - Removes unnecessary [`LogicalPlan::Projection`]s from the [`LogicalPlan`]. + +use std::collections::HashSet; +use std::sync::Arc; + +use crate::optimizer::ApplyOrder; +use crate::{OptimizerConfig, OptimizerRule}; + +use arrow::datatypes::SchemaRef; +use datafusion_common::{ + get_required_group_by_exprs_indices, Column, DFSchema, DFSchemaRef, JoinType, Result, +}; +use datafusion_expr::expr::{Alias, ScalarFunction, ScalarFunctionDefinition}; +use datafusion_expr::{ + logical_plan::LogicalPlan, projection_schema, Aggregate, BinaryExpr, Cast, Distinct, + Expr, GroupingSet, Projection, TableScan, Window, +}; + +use hashbrown::HashMap; +use itertools::{izip, Itertools}; + +/// A rule for optimizing logical plans by removing unused columns/fields. +/// +/// `OptimizeProjections` is an optimizer rule that identifies and eliminates +/// columns from a logical plan that are not used by downstream operations. +/// This can improve query performance and reduce unnecessary data processing. +/// +/// The rule analyzes the input logical plan, determines the necessary column +/// indices, and then removes any unnecessary columns. 
It also removes any +/// unnecessary projections from the plan tree. +#[derive(Default)] +pub struct OptimizeProjections {} + +impl OptimizeProjections { + #[allow(missing_docs)] + pub fn new() -> Self { + Self {} + } +} + +impl OptimizerRule for OptimizeProjections { + fn try_optimize( + &self, + plan: &LogicalPlan, + config: &dyn OptimizerConfig, + ) -> Result> { + // All output fields are necessary: + let indices = (0..plan.schema().fields().len()).collect::>(); + optimize_projections(plan, config, &indices) + } + + fn name(&self) -> &str { + "optimize_projections" + } + + fn apply_order(&self) -> Option { + None + } +} + +/// Removes unnecessary columns (e.g. columns that do not appear in the output +/// schema and/or are not used during any computation step such as expression +/// evaluation) from the logical plan and its inputs. +/// +/// # Parameters +/// +/// - `plan`: A reference to the input `LogicalPlan` to optimize. +/// - `config`: A reference to the optimizer configuration. +/// - `indices`: A slice of column indices that represent the necessary column +/// indices for downstream operations. +/// +/// # Returns +/// +/// A `Result` object with the following semantics: +/// +/// - `Ok(Some(LogicalPlan))`: An optimized `LogicalPlan` without unnecessary +/// columns. +/// - `Ok(None)`: Signal that the given logical plan did not require any change. +/// - `Err(error)`: An error occured during the optimization process. +fn optimize_projections( + plan: &LogicalPlan, + config: &dyn OptimizerConfig, + indices: &[usize], +) -> Result> { + // `child_required_indices` stores + // - indices of the columns required for each child + // - a flag indicating whether putting a projection above children is beneficial for the parent. + // As an example LogicalPlan::Filter benefits from small tables. Hence for filter child this flag would be `true`. + let child_required_indices: Vec<(Vec, bool)> = match plan { + LogicalPlan::Sort(_) + | LogicalPlan::Filter(_) + | LogicalPlan::Repartition(_) + | LogicalPlan::Unnest(_) + | LogicalPlan::Union(_) + | LogicalPlan::SubqueryAlias(_) + | LogicalPlan::Distinct(Distinct::On(_)) => { + // Pass index requirements from the parent as well as column indices + // that appear in this plan's expressions to its child. All these + // operators benefit from "small" inputs, so the projection_beneficial + // flag is `true`. + let exprs = plan.expressions(); + plan.inputs() + .into_iter() + .map(|input| { + get_all_required_indices(indices, input, exprs.iter()) + .map(|idxs| (idxs, true)) + }) + .collect::>()? + } + LogicalPlan::Limit(_) | LogicalPlan::Prepare(_) => { + // Pass index requirements from the parent as well as column indices + // that appear in this plan's expressions to its child. These operators + // do not benefit from "small" inputs, so the projection_beneficial + // flag is `false`. + let exprs = plan.expressions(); + plan.inputs() + .into_iter() + .map(|input| { + get_all_required_indices(indices, input, exprs.iter()) + .map(|idxs| (idxs, false)) + }) + .collect::>()? + } + LogicalPlan::Copy(_) + | LogicalPlan::Ddl(_) + | LogicalPlan::Dml(_) + | LogicalPlan::Explain(_) + | LogicalPlan::Analyze(_) + | LogicalPlan::Subquery(_) + | LogicalPlan::Distinct(Distinct::All(_)) => { + // These plans require all their fields, and their children should + // be treated as final plans -- otherwise, we may have schema a + // mismatch. + // TODO: For some subquery variants (e.g. 
a subquery arising from an + // EXISTS expression), we may not need to require all indices. + plan.inputs() + .iter() + .map(|input| ((0..input.schema().fields().len()).collect_vec(), false)) + .collect::>() + } + LogicalPlan::EmptyRelation(_) + | LogicalPlan::Statement(_) + | LogicalPlan::Values(_) + | LogicalPlan::Extension(_) + | LogicalPlan::DescribeTable(_) => { + // These operators have no inputs, so stop the optimization process. + // TODO: Add support for `LogicalPlan::Extension`. + return Ok(None); + } + LogicalPlan::Projection(proj) => { + return if let Some(proj) = merge_consecutive_projections(proj)? { + Ok(Some( + rewrite_projection_given_requirements(&proj, config, indices)? + // Even if we cannot optimize the projection, merge if possible: + .unwrap_or_else(|| LogicalPlan::Projection(proj)), + )) + } else { + rewrite_projection_given_requirements(proj, config, indices) + }; + } + LogicalPlan::Aggregate(aggregate) => { + // Split parent requirements to GROUP BY and aggregate sections: + let n_group_exprs = aggregate.group_expr_len()?; + let (group_by_reqs, mut aggregate_reqs): (Vec, Vec) = + indices.iter().partition(|&&idx| idx < n_group_exprs); + // Offset aggregate indices so that they point to valid indices at + // `aggregate.aggr_expr`: + for idx in aggregate_reqs.iter_mut() { + *idx -= n_group_exprs; + } + + // Get absolutely necessary GROUP BY fields: + let group_by_expr_existing = aggregate + .group_expr + .iter() + .map(|group_by_expr| group_by_expr.display_name()) + .collect::>>()?; + let new_group_bys = if let Some(simplest_groupby_indices) = + get_required_group_by_exprs_indices( + aggregate.input.schema(), + &group_by_expr_existing, + ) { + // Some of the fields in the GROUP BY may be required by the + // parent even if these fields are unnecessary in terms of + // functional dependency. + let required_indices = + merge_slices(&simplest_groupby_indices, &group_by_reqs); + get_at_indices(&aggregate.group_expr, &required_indices) + } else { + aggregate.group_expr.clone() + }; + + // Only use the absolutely necessary aggregate expressions required + // by the parent: + let mut new_aggr_expr = get_at_indices(&aggregate.aggr_expr, &aggregate_reqs); + let all_exprs_iter = new_group_bys.iter().chain(new_aggr_expr.iter()); + let schema = aggregate.input.schema(); + let necessary_indices = indices_referred_by_exprs(schema, all_exprs_iter)?; + + let aggregate_input = if let Some(input) = + optimize_projections(&aggregate.input, config, &necessary_indices)? + { + input + } else { + aggregate.input.as_ref().clone() + }; + + // Simplify the input of the aggregation by adding a projection so + // that its input only contains absolutely necessary columns for + // the aggregate expressions. Note that necessary_indices refer to + // fields in `aggregate.input.schema()`. + let necessary_exprs = get_required_exprs(schema, &necessary_indices); + let (aggregate_input, _) = + add_projection_on_top_if_helpful(aggregate_input, necessary_exprs)?; + + // Aggregations always need at least one aggregate expression. + // With a nested count, we don't require any column as input, but + // still need to create a correct aggregate, which may be optimized + // out later. As an example, consider the following query: + // + // SELECT COUNT(*) FROM (SELECT COUNT(*) FROM [...]) + // + // which always returns 1. 
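// Editorial note, illustrative only: in the nested-count example above the parent
// needs no column from this aggregate, so `aggregate_reqs` is empty and both
// `new_group_bys` and `new_aggr_expr` come back empty. Re-adding one of the original
// aggregate expressions keeps the node a valid, single-row-producing Aggregate; a
// later pass may still simplify it further.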
+ if new_aggr_expr.is_empty() + && new_group_bys.is_empty() + && !aggregate.aggr_expr.is_empty() + { + new_aggr_expr = vec![aggregate.aggr_expr[0].clone()]; + } + + // Create a new aggregate plan with the updated input and only the + // absolutely necessary fields: + return Aggregate::try_new( + Arc::new(aggregate_input), + new_group_bys, + new_aggr_expr, + ) + .map(|aggregate| Some(LogicalPlan::Aggregate(aggregate))); + } + LogicalPlan::Window(window) => { + // Split parent requirements to child and window expression sections: + let n_input_fields = window.input.schema().fields().len(); + let (child_reqs, mut window_reqs): (Vec, Vec) = + indices.iter().partition(|&&idx| idx < n_input_fields); + // Offset window expression indices so that they point to valid + // indices at `window.window_expr`: + for idx in window_reqs.iter_mut() { + *idx -= n_input_fields; + } + + // Only use window expressions that are absolutely necessary according + // to parent requirements: + let new_window_expr = get_at_indices(&window.window_expr, &window_reqs); + + // Get all the required column indices at the input, either by the + // parent or window expression requirements. + let required_indices = get_all_required_indices( + &child_reqs, + &window.input, + new_window_expr.iter(), + )?; + let window_child = if let Some(new_window_child) = + optimize_projections(&window.input, config, &required_indices)? + { + new_window_child + } else { + window.input.as_ref().clone() + }; + + return if new_window_expr.is_empty() { + // When no window expression is necessary, use the input directly: + Ok(Some(window_child)) + } else { + // Calculate required expressions at the input of the window. + // Please note that we use `old_child`, because `required_indices` + // refers to `old_child`. + let required_exprs = + get_required_exprs(window.input.schema(), &required_indices); + let (window_child, _) = + add_projection_on_top_if_helpful(window_child, required_exprs)?; + Window::try_new(new_window_expr, Arc::new(window_child)) + .map(|window| Some(LogicalPlan::Window(window))) + }; + } + LogicalPlan::Join(join) => { + let left_len = join.left.schema().fields().len(); + let (left_req_indices, right_req_indices) = + split_join_requirements(left_len, indices, &join.join_type); + let exprs = plan.expressions(); + let left_indices = + get_all_required_indices(&left_req_indices, &join.left, exprs.iter())?; + let right_indices = + get_all_required_indices(&right_req_indices, &join.right, exprs.iter())?; + // Joins benefit from "small" input tables (lower memory usage). + // Therefore, each child benefits from projection: + vec![(left_indices, true), (right_indices, true)] + } + LogicalPlan::CrossJoin(cross_join) => { + let left_len = cross_join.left.schema().fields().len(); + let (left_child_indices, right_child_indices) = + split_join_requirements(left_len, indices, &JoinType::Inner); + // Joins benefit from "small" input tables (lower memory usage). + // Therefore, each child benefits from projection: + vec![(left_child_indices, true), (right_child_indices, true)] + } + LogicalPlan::TableScan(table_scan) => { + let schema = table_scan.source.schema(); + // Get indices referred to in the original (schema with all fields) + // given projected indices. 
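// Editorial example, illustrative only: if the scan already projects source columns
// [2, 5, 7] and the parent requires output indices [0, 2], the remapping below
// produces [2, 7]; that is, indices into the full source schema rather than into
// the scan's current output.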
+ let projection = with_indices(&table_scan.projection, schema, |map| { + indices.iter().map(|&idx| map[idx]).collect() + }); + + return TableScan::try_new( + table_scan.table_name.clone(), + table_scan.source.clone(), + Some(projection), + table_scan.filters.clone(), + table_scan.fetch, + ) + .map(|table| Some(LogicalPlan::TableScan(table))); + } + }; + + let new_inputs = izip!(child_required_indices, plan.inputs().into_iter()) + .map(|((required_indices, projection_beneficial), child)| { + let (input, is_changed) = if let Some(new_input) = + optimize_projections(child, config, &required_indices)? + { + (new_input, true) + } else { + (child.clone(), false) + }; + let project_exprs = get_required_exprs(child.schema(), &required_indices); + let (input, proj_added) = if projection_beneficial { + add_projection_on_top_if_helpful(input, project_exprs)? + } else { + (input, false) + }; + Ok((is_changed || proj_added).then_some(input)) + }) + .collect::>>()?; + if new_inputs.iter().all(|child| child.is_none()) { + // All children are the same in this case, no need to change the plan: + Ok(None) + } else { + // At least one of the children is changed: + let new_inputs = izip!(new_inputs, plan.inputs()) + // If new_input is `None`, this means child is not changed, so use + // `old_child` during construction: + .map(|(new_input, old_child)| new_input.unwrap_or_else(|| old_child.clone())) + .collect::>(); + plan.with_new_exprs(plan.expressions(), &new_inputs) + .map(Some) + } +} + +/// This function applies the given function `f` to the projection indices +/// `proj_indices` if they exist. Otherwise, applies `f` to a default set +/// of indices according to `schema`. +fn with_indices( + proj_indices: &Option>, + schema: SchemaRef, + mut f: F, +) -> Vec +where + F: FnMut(&[usize]) -> Vec, +{ + match proj_indices { + Some(indices) => f(indices.as_slice()), + None => { + let range: Vec = (0..schema.fields.len()).collect(); + f(range.as_slice()) + } + } +} + +/// Merges consecutive projections. +/// +/// Given a projection `proj`, this function attempts to merge it with a previous +/// projection if it exists and if merging is beneficial. Merging is considered +/// beneficial when expressions in the current projection are non-trivial and +/// appear more than once in its input fields. This can act as a caching mechanism +/// for non-trivial computations. +/// +/// # Parameters +/// +/// * `proj` - A reference to the `Projection` to be merged. +/// +/// # Returns +/// +/// A `Result` object with the following semantics: +/// +/// - `Ok(Some(Projection))`: Merge was beneficial and successful. Contains the +/// merged projection. +/// - `Ok(None)`: Signals that merge is not beneficial (and has not taken place). +/// - `Err(error)`: An error occured during the function call. +fn merge_consecutive_projections(proj: &Projection) -> Result> { + let LogicalPlan::Projection(prev_projection) = proj.input.as_ref() else { + return Ok(None); + }; + + // Count usages (referrals) of each projection expression in its input fields: + let mut column_referral_map = HashMap::::new(); + for columns in proj.expr.iter().flat_map(|expr| expr.to_columns()) { + for col in columns.into_iter() { + *column_referral_map.entry(col.clone()).or_default() += 1; + } + } + + // If an expression is non-trivial and appears more than once, consecutive + // projections will benefit from a compute-once approach. 
For details, see: + // https://github.com/apache/arrow-datafusion/issues/8296 + if column_referral_map.into_iter().any(|(col, usage)| { + usage > 1 + && !is_expr_trivial( + &prev_projection.expr + [prev_projection.schema.index_of_column(&col).unwrap()], + ) + }) { + return Ok(None); + } + + // If all the expression of the top projection can be rewritten, do so and + // create a new projection: + let new_exprs = proj + .expr + .iter() + .map(|expr| rewrite_expr(expr, prev_projection)) + .collect::>>>()?; + if let Some(new_exprs) = new_exprs { + let new_exprs = new_exprs + .into_iter() + .zip(proj.expr.iter()) + .map(|(new_expr, old_expr)| { + new_expr.alias_if_changed(old_expr.name_for_alias()?) + }) + .collect::>>()?; + Projection::try_new(new_exprs, prev_projection.input.clone()).map(Some) + } else { + Ok(None) + } +} + +/// Trim the given expression by removing any unnecessary layers of aliasing. +/// If the expression is an alias, the function returns the underlying expression. +/// Otherwise, it returns the given expression as is. +/// +/// Without trimming, we can end up with unnecessary indirections inside expressions +/// during projection merges. +/// +/// Consider: +/// +/// ```text +/// Projection(a1 + b1 as sum1) +/// --Projection(a as a1, b as b1) +/// ----Source(a, b) +/// ``` +/// +/// After merge, we want to produce: +/// +/// ```text +/// Projection(a + b as sum1) +/// --Source(a, b) +/// ``` +/// +/// Without trimming, we would end up with: +/// +/// ```text +/// Projection((a as a1 + b as b1) as sum1) +/// --Source(a, b) +/// ``` +fn trim_expr(expr: Expr) -> Expr { + match expr { + Expr::Alias(alias) => trim_expr(*alias.expr), + _ => expr, + } +} + +// Check whether `expr` is trivial; i.e. it doesn't imply any computation. +fn is_expr_trivial(expr: &Expr) -> bool { + matches!(expr, Expr::Column(_) | Expr::Literal(_)) +} + +// Exit early when there is no rewrite to do. +macro_rules! rewrite_expr_with_check { + ($expr:expr, $input:expr) => { + if let Some(value) = rewrite_expr($expr, $input)? { + value + } else { + return Ok(None); + } + }; +} + +/// Rewrites a projection expression using the projection before it (i.e. its input) +/// This is a subroutine to the `merge_consecutive_projections` function. +/// +/// # Parameters +/// +/// * `expr` - A reference to the expression to rewrite. +/// * `input` - A reference to the input of the projection expression (itself +/// a projection). +/// +/// # Returns +/// +/// A `Result` object with the following semantics: +/// +/// - `Ok(Some(Expr))`: Rewrite was successful. Contains the rewritten result. +/// - `Ok(None)`: Signals that `expr` can not be rewritten. +/// - `Err(error)`: An error occured during the function call. 
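// Editorial example, illustrative only: with `input` projecting `a AS a1, b AS b1`,
// rewriting the parent expression `a1 + b1` resolves each column through `input.expr`
// and, after `trim_expr` strips the intermediate aliases, yields `a + b`. Expression
// variants this function does not handle (including user-defined scalar functions,
// per the TODO below) make it return `Ok(None)`, leaving the projections unmerged.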
+fn rewrite_expr(expr: &Expr, input: &Projection) -> Result> { + let result = match expr { + Expr::Column(col) => { + // Find index of column: + let idx = input.schema.index_of_column(col)?; + input.expr[idx].clone() + } + Expr::BinaryExpr(binary) => Expr::BinaryExpr(BinaryExpr::new( + Box::new(trim_expr(rewrite_expr_with_check!(&binary.left, input))), + binary.op, + Box::new(trim_expr(rewrite_expr_with_check!(&binary.right, input))), + )), + Expr::Alias(alias) => Expr::Alias(Alias::new( + trim_expr(rewrite_expr_with_check!(&alias.expr, input)), + alias.relation.clone(), + alias.name.clone(), + )), + Expr::Literal(_) => expr.clone(), + Expr::Cast(cast) => { + let new_expr = rewrite_expr_with_check!(&cast.expr, input); + Expr::Cast(Cast::new(Box::new(new_expr), cast.data_type.clone())) + } + Expr::ScalarFunction(scalar_fn) => { + // TODO: Support UDFs. + let ScalarFunctionDefinition::BuiltIn(fun) = scalar_fn.func_def else { + return Ok(None); + }; + return Ok(scalar_fn + .args + .iter() + .map(|expr| rewrite_expr(expr, input)) + .collect::>>()? + .map(|new_args| { + Expr::ScalarFunction(ScalarFunction::new(fun, new_args)) + })); + } + // Unsupported type for consecutive projection merge analysis. + _ => return Ok(None), + }; + Ok(Some(result)) +} + +/// Retrieves a set of outer-referenced columns by the given expression, `expr`. +/// Note that the `Expr::to_columns()` function doesn't return these columns. +/// +/// # Parameters +/// +/// * `expr` - The expression to analyze for outer-referenced columns. +/// +/// # Returns +/// +/// If the function can safely infer all outer-referenced columns, returns a +/// `Some(HashSet)` containing these columns. Otherwise, returns `None`. +fn outer_columns(expr: &Expr) -> Option> { + let mut columns = HashSet::new(); + outer_columns_helper(expr, &mut columns).then_some(columns) +} + +/// A recursive subroutine that accumulates outer-referenced columns by the +/// given expression, `expr`. +/// +/// # Parameters +/// +/// * `expr` - The expression to analyze for outer-referenced columns. +/// * `columns` - A mutable reference to a `HashSet` where detected +/// columns are collected. +/// +/// Returns `true` if it can safely collect all outer-referenced columns. +/// Otherwise, returns `false`. 
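// Editorial example, illustrative only: for a correlated predicate such as
// `EXISTS (SELECT 1 FROM inner_t WHERE inner_t.x = outer_t.a)`, the column
// `outer_t.a` reaches this helper as an `Expr::OuterReferenceColumn` (via the
// subquery's `outer_ref_columns`), so it is added to `columns`. Any expression
// variant not covered by the match below returns `false`, and the caller then
// conservatively requires every column of the input schema.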
+fn outer_columns_helper(expr: &Expr, columns: &mut HashSet) -> bool { + match expr { + Expr::OuterReferenceColumn(_, col) => { + columns.insert(col.clone()); + true + } + Expr::BinaryExpr(binary_expr) => { + outer_columns_helper(&binary_expr.left, columns) + && outer_columns_helper(&binary_expr.right, columns) + } + Expr::ScalarSubquery(subquery) => { + let exprs = subquery.outer_ref_columns.iter(); + outer_columns_helper_multi(exprs, columns) + } + Expr::Exists(exists) => { + let exprs = exists.subquery.outer_ref_columns.iter(); + outer_columns_helper_multi(exprs, columns) + } + Expr::Alias(alias) => outer_columns_helper(&alias.expr, columns), + Expr::InSubquery(insubquery) => { + let exprs = insubquery.subquery.outer_ref_columns.iter(); + outer_columns_helper_multi(exprs, columns) + } + Expr::IsNotNull(expr) | Expr::IsNull(expr) => outer_columns_helper(expr, columns), + Expr::Cast(cast) => outer_columns_helper(&cast.expr, columns), + Expr::Sort(sort) => outer_columns_helper(&sort.expr, columns), + Expr::AggregateFunction(aggregate_fn) => { + outer_columns_helper_multi(aggregate_fn.args.iter(), columns) + && aggregate_fn + .order_by + .as_ref() + .map_or(true, |obs| outer_columns_helper_multi(obs.iter(), columns)) + && aggregate_fn + .filter + .as_ref() + .map_or(true, |filter| outer_columns_helper(filter, columns)) + } + Expr::WindowFunction(window_fn) => { + outer_columns_helper_multi(window_fn.args.iter(), columns) + && outer_columns_helper_multi(window_fn.order_by.iter(), columns) + && outer_columns_helper_multi(window_fn.partition_by.iter(), columns) + } + Expr::GroupingSet(groupingset) => match groupingset { + GroupingSet::GroupingSets(multi_exprs) => multi_exprs + .iter() + .all(|e| outer_columns_helper_multi(e.iter(), columns)), + GroupingSet::Cube(exprs) | GroupingSet::Rollup(exprs) => { + outer_columns_helper_multi(exprs.iter(), columns) + } + }, + Expr::ScalarFunction(scalar_fn) => { + outer_columns_helper_multi(scalar_fn.args.iter(), columns) + } + Expr::Like(like) => { + outer_columns_helper(&like.expr, columns) + && outer_columns_helper(&like.pattern, columns) + } + Expr::InList(in_list) => { + outer_columns_helper(&in_list.expr, columns) + && outer_columns_helper_multi(in_list.list.iter(), columns) + } + Expr::Case(case) => { + let when_then_exprs = case + .when_then_expr + .iter() + .flat_map(|(first, second)| [first.as_ref(), second.as_ref()]); + outer_columns_helper_multi(when_then_exprs, columns) + && case + .expr + .as_ref() + .map_or(true, |expr| outer_columns_helper(expr, columns)) + && case + .else_expr + .as_ref() + .map_or(true, |expr| outer_columns_helper(expr, columns)) + } + Expr::Column(_) | Expr::Literal(_) | Expr::Wildcard { .. } => true, + _ => false, + } +} + +/// A recursive subroutine that accumulates outer-referenced columns by the +/// given expressions (`exprs`). +/// +/// # Parameters +/// +/// * `exprs` - The expressions to analyze for outer-referenced columns. +/// * `columns` - A mutable reference to a `HashSet` where detected +/// columns are collected. +/// +/// Returns `true` if it can safely collect all outer-referenced columns. +/// Otherwise, returns `false`. +fn outer_columns_helper_multi<'a>( + mut exprs: impl Iterator, + columns: &mut HashSet, +) -> bool { + exprs.all(|e| outer_columns_helper(e, columns)) +} + +/// Generates the required expressions (columns) that reside at `indices` of +/// the given `input_schema`. +/// +/// # Arguments +/// +/// * `input_schema` - A reference to the input schema. 
+/// * `indices` - A slice of `usize` indices specifying required columns. +/// +/// # Returns +/// +/// A vector of `Expr::Column` expressions residing at `indices` of the `input_schema`. +fn get_required_exprs(input_schema: &Arc, indices: &[usize]) -> Vec { + let fields = input_schema.fields(); + indices + .iter() + .map(|&idx| Expr::Column(fields[idx].qualified_column())) + .collect() +} + +/// Get indices of the fields referred to by any expression in `exprs` within +/// the given schema (`input_schema`). +/// +/// # Arguments +/// +/// * `input_schema`: The input schema to analyze for index requirements. +/// * `exprs`: An iterator of expressions for which we want to find necessary +/// field indices. +/// +/// # Returns +/// +/// A [`Result`] object containing the indices of all required fields in +/// `input_schema` to calculate all `exprs` successfully. +fn indices_referred_by_exprs<'a>( + input_schema: &DFSchemaRef, + exprs: impl Iterator, +) -> Result> { + let indices = exprs + .map(|expr| indices_referred_by_expr(input_schema, expr)) + .collect::>>()?; + Ok(indices + .into_iter() + .flatten() + // Make sure no duplicate entries exist and indices are ordered: + .sorted() + .dedup() + .collect()) +} + +/// Get indices of the fields referred to by the given expression `expr` within +/// the given schema (`input_schema`). +/// +/// # Parameters +/// +/// * `input_schema`: The input schema to analyze for index requirements. +/// * `expr`: An expression for which we want to find necessary field indices. +/// +/// # Returns +/// +/// A [`Result`] object containing the indices of all required fields in +/// `input_schema` to calculate `expr` successfully. +fn indices_referred_by_expr( + input_schema: &DFSchemaRef, + expr: &Expr, +) -> Result> { + let mut cols = expr.to_columns()?; + // Get outer-referenced columns: + if let Some(outer_cols) = outer_columns(expr) { + cols.extend(outer_cols); + } else { + // Expression is not known to contain outer columns or not. Hence, do + // not assume anything and require all the schema indices at the input: + return Ok((0..input_schema.fields().len()).collect()); + } + Ok(cols + .iter() + .flat_map(|col| input_schema.index_of_column(col)) + .collect()) +} + +/// Gets all required indices for the input; i.e. those required by the parent +/// and those referred to by `exprs`. +/// +/// # Parameters +/// +/// * `parent_required_indices` - A slice of indices required by the parent plan. +/// * `input` - The input logical plan to analyze for index requirements. +/// * `exprs` - An iterator of expressions used to determine required indices. +/// +/// # Returns +/// +/// A `Result` containing a vector of `usize` indices containing all the required +/// indices. +fn get_all_required_indices<'a>( + parent_required_indices: &[usize], + input: &LogicalPlan, + exprs: impl Iterator, +) -> Result> { + indices_referred_by_exprs(input.schema(), exprs) + .map(|indices| merge_slices(parent_required_indices, &indices)) +} + +/// Retrieves the expressions at specified indices within the given slice. Ignores +/// any invalid indices. +/// +/// # Parameters +/// +/// * `exprs` - A slice of expressions to index into. +/// * `indices` - A slice of indices specifying the positions of expressions sought. +/// +/// # Returns +/// +/// A vector of expressions corresponding to specified indices. +fn get_at_indices(exprs: &[Expr], indices: &[usize]) -> Vec { + indices + .iter() + // Indices may point to further places than `exprs` len. 
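// (Editorial example, illustrative only: with `exprs.len() == 3` and
// `indices == [1, 4]`, only `exprs[1]` is returned; the out-of-range index 4
// is silently dropped by the `filter_map` below.)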
+ .filter_map(|&idx| exprs.get(idx).cloned()) + .collect() +} + +/// Merges two slices into a single vector with sorted (ascending) and +/// deduplicated elements. For example, merging `[3, 2, 4]` and `[3, 6, 1]` +/// will produce `[1, 2, 3, 6]`. +fn merge_slices(left: &[T], right: &[T]) -> Vec { + // Make sure to sort before deduping, which removes the duplicates: + left.iter() + .cloned() + .chain(right.iter().cloned()) + .sorted() + .dedup() + .collect() +} + +/// Splits requirement indices for a join into left and right children based on +/// the join type. +/// +/// This function takes the length of the left child, a slice of requirement +/// indices, and the type of join (e.g. `INNER`, `LEFT`, `RIGHT`) as arguments. +/// Depending on the join type, it divides the requirement indices into those +/// that apply to the left child and those that apply to the right child. +/// +/// - For `INNER`, `LEFT`, `RIGHT` and `FULL` joins, the requirements are split +/// between left and right children. The right child indices are adjusted to +/// point to valid positions within the right child by subtracting the length +/// of the left child. +/// +/// - For `LEFT ANTI`, `LEFT SEMI`, `RIGHT SEMI` and `RIGHT ANTI` joins, all +/// requirements are re-routed to either the left child or the right child +/// directly, depending on the join type. +/// +/// # Parameters +/// +/// * `left_len` - The length of the left child. +/// * `indices` - A slice of requirement indices. +/// * `join_type` - The type of join (e.g. `INNER`, `LEFT`, `RIGHT`). +/// +/// # Returns +/// +/// A tuple containing two vectors of `usize` indices: The first vector represents +/// the requirements for the left child, and the second vector represents the +/// requirements for the right child. The indices are appropriately split and +/// adjusted based on the join type. +fn split_join_requirements( + left_len: usize, + indices: &[usize], + join_type: &JoinType, +) -> (Vec, Vec) { + match join_type { + // In these cases requirements are split between left/right children: + JoinType::Inner | JoinType::Left | JoinType::Right | JoinType::Full => { + let (left_reqs, mut right_reqs): (Vec, Vec) = + indices.iter().partition(|&&idx| idx < left_len); + // Decrease right side indices by `left_len` so that they point to valid + // positions within the right child: + for idx in right_reqs.iter_mut() { + *idx -= left_len; + } + (left_reqs, right_reqs) + } + // All requirements can be re-routed to left child directly. + JoinType::LeftAnti | JoinType::LeftSemi => (indices.to_vec(), vec![]), + // All requirements can be re-routed to right side directly. + // No need to change index, join schema is right child schema. + JoinType::RightSemi | JoinType::RightAnti => (vec![], indices.to_vec()), + } +} + +/// Adds a projection on top of a logical plan if doing so reduces the number +/// of columns for the parent operator. +/// +/// This function takes a `LogicalPlan` and a list of projection expressions. +/// If the projection is beneficial (it reduces the number of columns in the +/// plan) a new `LogicalPlan` with the projection is created and returned, along +/// with a `true` flag. If the projection doesn't reduce the number of columns, +/// the original plan is returned with a `false` flag. +/// +/// # Parameters +/// +/// * `plan` - The input `LogicalPlan` to potentially add a projection to. +/// * `project_exprs` - A list of expressions for the projection. 
+///
+/// # Returns
+///
+/// A `Result` containing a tuple with two values: The resulting `LogicalPlan`
+/// (with or without the added projection) and a `bool` flag indicating if a
+/// projection was added (`true`) or not (`false`).
+fn add_projection_on_top_if_helpful(
+    plan: LogicalPlan,
+    project_exprs: Vec<Expr>,
+) -> Result<(LogicalPlan, bool)> {
+    // Make sure the projection decreases the number of columns, otherwise it is unnecessary.
+    if project_exprs.len() >= plan.schema().fields().len() {
+        Ok((plan, false))
+    } else {
+        Projection::try_new(project_exprs, Arc::new(plan))
+            .map(|proj| (LogicalPlan::Projection(proj), true))
+    }
+}
+
+/// Rewrite the given projection according to the fields required by its
+/// ancestors.
+///
+/// # Parameters
+///
+/// * `proj` - A reference to the original projection to rewrite.
+/// * `config` - A reference to the optimizer configuration.
+/// * `indices` - A slice of indices representing the columns required by the
+///   ancestors of the given projection.
+///
+/// # Returns
+///
+/// A `Result` object with the following semantics:
+///
+/// - `Ok(Some(LogicalPlan))`: Contains the rewritten projection.
+/// - `Ok(None)`: No rewrite necessary.
+/// - `Err(error)`: An error occurred during the function call.
+fn rewrite_projection_given_requirements(
+    proj: &Projection,
+    config: &dyn OptimizerConfig,
+    indices: &[usize],
+) -> Result<Option<LogicalPlan>> {
+    let exprs_used = get_at_indices(&proj.expr, indices);
+    let required_indices =
+        indices_referred_by_exprs(proj.input.schema(), exprs_used.iter())?;
+    return if let Some(input) =
+        optimize_projections(&proj.input, config, &required_indices)?
+    {
+        if &projection_schema(&input, &exprs_used)? == input.schema() {
+            Ok(Some(input))
+        } else {
+            Projection::try_new(exprs_used, Arc::new(input))
+                .map(|proj| Some(LogicalPlan::Projection(proj)))
+        }
+    } else if exprs_used.len() < proj.expr.len() {
+        // The rewritten projection uses fewer expressions than the existing one.
+        // In this case, even if the child doesn't change, we should update the
+        // projection to use fewer columns:
+        if &projection_schema(&proj.input, &exprs_used)? == proj.input.schema() {
+            Ok(Some(proj.input.as_ref().clone()))
+        } else {
+            Projection::try_new(exprs_used, proj.input.clone())
+                .map(|proj| Some(LogicalPlan::Projection(proj)))
+        }
+    } else {
+        // Projection doesn't change.
+        Ok(None)
+    };
+}
+
+#[cfg(test)]
+mod tests {
+    use std::sync::Arc;
+
+    use crate::optimize_projections::OptimizeProjections;
+    use crate::test::{assert_optimized_plan_eq, test_table_scan};
+    use arrow::datatypes::{DataType, Field, Schema};
+    use datafusion_common::{Result, TableReference};
+    use datafusion_expr::{
+        binary_expr, col, count, lit, logical_plan::builder::LogicalPlanBuilder,
+        table_scan, Expr, LogicalPlan, Operator,
+    };
+
+    fn assert_optimized_plan_equal(plan: &LogicalPlan, expected: &str) -> Result<()> {
+        assert_optimized_plan_eq(Arc::new(OptimizeProjections::new()), plan, expected)
+    }
+
+    #[test]
+    fn merge_two_projection() -> Result<()> {
+        let table_scan = test_table_scan()?;
+        let plan = LogicalPlanBuilder::from(table_scan)
+            .project(vec![col("a")])?
+            .project(vec![binary_expr(lit(1), Operator::Plus, col("a"))])?
+            .build()?;
+
+        let expected = "Projection: Int32(1) + test.a\
+        \n  TableScan: test projection=[a]";
+        assert_optimized_plan_equal(&plan, expected)
+    }
+
+    #[test]
+    fn merge_three_projection() -> Result<()> {
+        let table_scan = test_table_scan()?;
+        let plan = LogicalPlanBuilder::from(table_scan)
+            .project(vec![col("a"), col("b")])?
+ .project(vec![col("a")])? + .project(vec![binary_expr(lit(1), Operator::Plus, col("a"))])? + .build()?; + + let expected = "Projection: Int32(1) + test.a\ + \n TableScan: test projection=[a]"; + assert_optimized_plan_equal(&plan, expected) + } + + #[test] + fn merge_alias() -> Result<()> { + let table_scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![col("a")])? + .project(vec![col("a").alias("alias")])? + .build()?; + + let expected = "Projection: test.a AS alias\ + \n TableScan: test projection=[a]"; + assert_optimized_plan_equal(&plan, expected) + } + + #[test] + fn merge_nested_alias() -> Result<()> { + let table_scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![col("a").alias("alias1").alias("alias2")])? + .project(vec![col("alias2").alias("alias")])? + .build()?; + + let expected = "Projection: test.a AS alias\ + \n TableScan: test projection=[a]"; + assert_optimized_plan_equal(&plan, expected) + } + + #[test] + fn test_nested_count() -> Result<()> { + let schema = Schema::new(vec![Field::new("foo", DataType::Int32, false)]); + + let groups: Vec = vec![]; + + let plan = table_scan(TableReference::none(), &schema, None) + .unwrap() + .aggregate(groups.clone(), vec![count(lit(1))]) + .unwrap() + .aggregate(groups, vec![count(lit(1))]) + .unwrap() + .build() + .unwrap(); + + let expected = "Aggregate: groupBy=[[]], aggr=[[COUNT(Int32(1))]]\ + \n Projection: \ + \n Aggregate: groupBy=[[]], aggr=[[COUNT(Int32(1))]]\ + \n TableScan: ?table? projection=[]"; + assert_optimized_plan_equal(&plan, expected) + } +} diff --git a/datafusion/optimizer/src/optimizer.rs b/datafusion/optimizer/src/optimizer.rs index 5231dc8698751..2cb59d511ccf5 100644 --- a/datafusion/optimizer/src/optimizer.rs +++ b/datafusion/optimizer/src/optimizer.rs @@ -17,6 +17,10 @@ //! 
Query optimizer traits +use std::collections::HashSet; +use std::sync::Arc; +use std::time::Instant; + use crate::common_subexpr_eliminate::CommonSubexprEliminate; use crate::decorrelate_predicate_subquery::DecorrelatePredicateSubquery; use crate::eliminate_cross_join::EliminateCrossJoin; @@ -27,15 +31,13 @@ use crate::eliminate_limit::EliminateLimit; use crate::eliminate_nested_union::EliminateNestedUnion; use crate::eliminate_one_union::EliminateOneUnion; use crate::eliminate_outer_join::EliminateOuterJoin; -use crate::eliminate_project::EliminateProjection; use crate::extract_equijoin_predicate::ExtractEquijoinPredicate; use crate::filter_null_join_keys::FilterNullJoinKeys; -use crate::merge_projection::MergeProjection; +use crate::optimize_projections::OptimizeProjections; use crate::plan_signature::LogicalPlanSignature; use crate::propagate_empty_relation::PropagateEmptyRelation; use crate::push_down_filter::PushDownFilter; use crate::push_down_limit::PushDownLimit; -use crate::push_down_projection::PushDownProjection; use crate::replace_distinct_aggregate::ReplaceDistinctWithAggregate; use crate::rewrite_disjunctive_predicate::RewriteDisjunctivePredicate; use crate::scalar_subquery_to_join::ScalarSubqueryToJoin; @@ -43,15 +45,14 @@ use crate::simplify_expressions::SimplifyExpressions; use crate::single_distinct_to_groupby::SingleDistinctToGroupBy; use crate::unwrap_cast_in_comparison::UnwrapCastInComparison; use crate::utils::log_plan; -use chrono::{DateTime, Utc}; + use datafusion_common::alias::AliasGenerator; use datafusion_common::config::ConfigOptions; use datafusion_common::{DataFusionError, Result}; use datafusion_expr::logical_plan::LogicalPlan; + +use chrono::{DateTime, Utc}; use log::{debug, warn}; -use std::collections::HashSet; -use std::sync::Arc; -use std::time::Instant; /// `OptimizerRule` transforms one [`LogicalPlan`] into another which /// computes the same results, but in a potentially more efficient @@ -234,7 +235,6 @@ impl Optimizer { // run it again after running the optimizations that potentially converted // subqueries to joins Arc::new(SimplifyExpressions::new()), - Arc::new(MergeProjection::new()), Arc::new(RewriteDisjunctivePredicate::new()), Arc::new(EliminateDuplicatedExpr::new()), Arc::new(EliminateFilter::new()), @@ -255,10 +255,7 @@ impl Optimizer { Arc::new(SimplifyExpressions::new()), Arc::new(UnwrapCastInComparison::new()), Arc::new(CommonSubexprEliminate::new()), - Arc::new(PushDownProjection::new()), - Arc::new(EliminateProjection::new()), - // PushDownProjection can pushdown Projections through Limits, do PushDownLimit again. - Arc::new(PushDownLimit::new()), + Arc::new(OptimizeProjections::new()), ]; Self::with_rules(rules) @@ -385,7 +382,7 @@ impl Optimizer { }) .collect::>(); - Ok(Some(plan.with_new_inputs(&new_inputs)?)) + Ok(Some(plan.with_new_exprs(plan.expressions(), &new_inputs)?)) } /// Use a rule to optimize the whole plan. @@ -427,7 +424,7 @@ impl Optimizer { /// Returns an error if plans have different schemas. /// /// It ignores metadata and nullability. 
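Editorial aside (not part of this patch): the rule-list change above replaces the `MergeProjection`, `PushDownProjection` and `EliminateProjection` passes (plus the extra `PushDownLimit` re-run) with the single `OptimizeProjections` rule. The sketch below exercises just that rule on a hand-built plan; the table name `t`, the `prune` helper and the `main`/`observe` scaffolding are illustrative assumptions, and it presumes the `optimize_projections` module is publicly reachable like the other rule modules in this crate.

```rust
use std::sync::Arc;

use arrow::datatypes::{DataType, Field, Schema};
use datafusion_common::Result;
use datafusion_expr::{col, logical_plan::table_scan, LogicalPlan};
use datafusion_optimizer::optimize_projections::OptimizeProjections;
use datafusion_optimizer::optimizer::Optimizer;
use datafusion_optimizer::{OptimizerContext, OptimizerRule};

// No-op observer, mirroring the pattern used in the optimizer tests.
fn observe(_plan: &LogicalPlan, _rule: &dyn OptimizerRule) {}

// Hypothetical helper: run only the consolidated rule over a plan.
fn prune(plan: &LogicalPlan) -> Result<LogicalPlan> {
    let optimizer = Optimizer::with_rules(vec![Arc::new(OptimizeProjections::new())]);
    optimizer.optimize(plan, &OptimizerContext::new(), observe)
}

fn main() -> Result<()> {
    // Two-column table, but only column "a" is referenced by the projection.
    let schema = Schema::new(vec![
        Field::new("a", DataType::Int32, false),
        Field::new("b", DataType::Int32, false),
    ]);
    let plan = table_scan(Some("t"), &schema, None)?
        .project(vec![col("a")])?
        .build()?;

    // The scan should come back pruned, e.g. `TableScan: t projection=[a]`,
    // the same effect the three removed rules previously achieved in combination.
    println!("{}", prune(&plan)?.display_indent());
    Ok(())
}
```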
-fn assert_schema_is_the_same( +pub(crate) fn assert_schema_is_the_same( rule_name: &str, prev_plan: &LogicalPlan, new_plan: &LogicalPlan, @@ -438,7 +435,7 @@ fn assert_schema_is_the_same( if !equivalent { let e = DataFusionError::Internal(format!( - "Failed due to generate a different schema, original schema: {:?}, new schema: {:?}", + "Failed due to a difference in schemas, original schema: {:?}, new schema: {:?}", prev_plan.schema(), new_plan.schema() )); @@ -453,17 +450,18 @@ fn assert_schema_is_the_same( #[cfg(test)] mod tests { + use std::sync::{Arc, Mutex}; + + use super::ApplyOrder; use crate::optimizer::Optimizer; use crate::test::test_table_scan; use crate::{OptimizerConfig, OptimizerContext, OptimizerRule}; + use datafusion_common::{ plan_err, DFField, DFSchema, DFSchemaRef, DataFusionError, Result, }; use datafusion_expr::logical_plan::EmptyRelation; use datafusion_expr::{col, lit, LogicalPlan, LogicalPlanBuilder, Projection}; - use std::sync::{Arc, Mutex}; - - use super::ApplyOrder; #[test] fn skip_failing_rule() { @@ -503,7 +501,7 @@ mod tests { let err = opt.optimize(&plan, &config, &observe).unwrap_err(); assert_eq!( "Optimizer rule 'get table_scan rule' failed\ncaused by\nget table_scan rule\ncaused by\n\ - Internal error: Failed due to generate a different schema, \ + Internal error: Failed due to a difference in schemas, \ original schema: DFSchema { fields: [], metadata: {}, functional_dependencies: FunctionalDependencies { deps: [] } }, \ new schema: DFSchema { fields: [\ DFField { qualifier: Some(Bare { table: \"test\" }), field: Field { name: \"a\", data_type: UInt32, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, \ diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index ae986b3c84dde..4eb925ac06292 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -15,24 +15,29 @@ //! [`PushDownFilter`] Moves filters so they are applied as early as possible in //! the plan. +use std::collections::{HashMap, HashSet}; +use std::sync::Arc; + use crate::optimizer::ApplyOrder; -use crate::utils::{conjunction, split_conjunction, split_conjunction_owned}; -use crate::{utils, OptimizerConfig, OptimizerRule}; +use crate::{OptimizerConfig, OptimizerRule}; + use datafusion_common::tree_node::{Transformed, TreeNode, VisitRecursion}; use datafusion_common::{ - internal_err, plan_datafusion_err, Column, DFSchema, DataFusionError, Result, + internal_err, plan_datafusion_err, Column, DFSchema, DFSchemaRef, DataFusionError, + JoinConstraint, Result, }; use datafusion_expr::expr::Alias; -use datafusion_expr::Volatility; +use datafusion_expr::expr_rewriter::replace_col; +use datafusion_expr::logical_plan::{ + CrossJoin, Join, JoinType, LogicalPlan, TableScan, Union, +}; +use datafusion_expr::utils::{conjunction, split_conjunction, split_conjunction_owned}; use datafusion_expr::{ - and, - expr_rewriter::replace_col, - logical_plan::{CrossJoin, Join, JoinType, LogicalPlan, TableScan, Union}, - or, BinaryExpr, Expr, Filter, Operator, TableProviderFilterPushDown, + and, build_join_schema, or, BinaryExpr, Expr, Filter, LogicalPlanBuilder, Operator, + ScalarFunctionDefinition, TableProviderFilterPushDown, Volatility, }; + use itertools::Itertools; -use std::collections::{HashMap, HashSet}; -use std::sync::Arc; /// Optimizer rule for pushing (moving) filter expressions down in a plan so /// they are applied as early as possible. 
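For orientation, here is a small editorial sketch (not part of this patch) of the behavior described above: it builds a filter on top of a projection and runs only `PushDownFilter`, which should commute the predicate below the projection so it is applied closer to the scan. The table name `t`, the column types and the `main`/`observe` scaffolding are illustrative assumptions.

```rust
use std::sync::Arc;

use arrow::datatypes::{DataType, Field, Schema};
use datafusion_common::Result;
use datafusion_expr::{col, lit, logical_plan::table_scan, LogicalPlan};
use datafusion_optimizer::optimizer::Optimizer;
use datafusion_optimizer::push_down_filter::PushDownFilter;
use datafusion_optimizer::{OptimizerContext, OptimizerRule};

fn observe(_plan: &LogicalPlan, _rule: &dyn OptimizerRule) {}

fn main() -> Result<()> {
    let schema = Schema::new(vec![
        Field::new("a", DataType::Int32, false),
        Field::new("b", DataType::Utf8, false),
    ]);
    // The filter sits above the projection when the plan is built...
    let plan = table_scan(Some("t"), &schema, None)?
        .project(vec![col("a"), col("b")])?
        .filter(col("a").gt(lit(10)))?
        .build()?;

    // ...and the rule moves it below the projection, closer to the scan.
    let optimizer = Optimizer::with_rules(vec![Arc::new(PushDownFilter::new())]);
    let optimized = optimizer.optimize(&plan, &OptimizerContext::new(), observe)?;

    println!("before:\n{}", plan.display_indent());
    println!("after:\n{}", optimized.display_indent());
    Ok(())
}
```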
@@ -221,7 +226,10 @@ fn can_evaluate_as_join_condition(predicate: &Expr) -> Result { | Expr::InSubquery(_) | Expr::ScalarSubquery(_) | Expr::OuterReferenceColumn(_, _) - | Expr::ScalarUDF(..) => { + | Expr::ScalarFunction(datafusion_expr::expr::ScalarFunction { + func_def: ScalarFunctionDefinition::UDF(_), + .. + }) => { is_evaluate = false; Ok(VisitRecursion::Stop) } @@ -249,9 +257,7 @@ fn can_evaluate_as_join_condition(predicate: &Expr) -> Result { Expr::Sort(_) | Expr::AggregateFunction(_) | Expr::WindowFunction(_) - | Expr::AggregateUDF { .. } - | Expr::Wildcard - | Expr::QualifiedWildcard { .. } + | Expr::Wildcard { .. } | Expr::GroupingSet(_) => internal_err!("Unsupported predicate type"), })?; Ok(is_evaluate) @@ -543,9 +549,7 @@ fn push_down_join( parent_predicate: Option<&Expr>, ) -> Result> { let predicates = match parent_predicate { - Some(parent_predicate) => { - utils::split_conjunction_owned(parent_predicate.clone()) - } + Some(parent_predicate) => split_conjunction_owned(parent_predicate.clone()), None => vec![], }; @@ -553,12 +557,21 @@ fn push_down_join( let on_filters = join .filter .as_ref() - .map(|e| utils::split_conjunction_owned(e.clone())) + .map(|e| split_conjunction_owned(e.clone())) .unwrap_or_default(); let mut is_inner_join = false; let infer_predicates = if join.join_type == JoinType::Inner { is_inner_join = true; + // Only allow both side key is column. + let join_col_keys = join + .on + .iter() + .flat_map(|(l, r)| match (l.try_into_col(), r.try_into_col()) { + (Ok(l_col), Ok(r_col)) => Some((l_col, r_col)), + _ => None, + }) + .collect::>(); // TODO refine the logic, introduce EquivalenceProperties to logical plan and infer additional filters to push down // For inner joins, duplicate filters for joined columns so filters can be pushed down // to both sides. Take the following query as an example: @@ -583,16 +596,6 @@ fn push_down_join( Err(e) => return Some(Err(e)), }; - // Only allow both side key is column. - let join_col_keys = join - .on - .iter() - .flat_map(|(l, r)| match (l.try_into_col(), r.try_into_col()) { - (Ok(l_col), Ok(r_col)) => Some((l_col, r_col)), - _ => None, - }) - .collect::>(); - for col in columns.iter() { for (l, r) in join_col_keys.iter() { if col == l { @@ -688,9 +691,11 @@ impl OptimizerRule for PushDownFilter { | LogicalPlan::Distinct(_) | LogicalPlan::Sort(_) => { // commutable - let new_filter = - plan.with_new_inputs(&[child_plan.inputs()[0].clone()])?; - child_plan.with_new_inputs(&[new_filter])? + let new_filter = plan.with_new_exprs( + plan.expressions(), + &[child_plan.inputs()[0].clone()], + )?; + child_plan.with_new_exprs(child_plan.expressions(), &[new_filter])? } LogicalPlan::SubqueryAlias(subquery_alias) => { let mut replace_map = HashMap::new(); @@ -713,7 +718,7 @@ impl OptimizerRule for PushDownFilter { new_predicate, subquery_alias.input.clone(), )?); - child_plan.with_new_inputs(&[new_filter])? + child_plan.with_new_exprs(child_plan.expressions(), &[new_filter])? 
} LogicalPlan::Projection(projection) => { // A projection is filter-commutable if it do not contain volatile predicates or contain volatile @@ -757,10 +762,15 @@ impl OptimizerRule for PushDownFilter { )?); match conjunction(keep_predicates) { - None => child_plan.with_new_inputs(&[new_filter])?, + None => child_plan.with_new_exprs( + child_plan.expressions(), + &[new_filter], + )?, Some(keep_predicate) => { - let child_plan = - child_plan.with_new_inputs(&[new_filter])?; + let child_plan = child_plan.with_new_exprs( + child_plan.expressions(), + &[new_filter], + )?; LogicalPlan::Filter(Filter::try_new( keep_predicate, Arc::new(child_plan), @@ -802,7 +812,7 @@ impl OptimizerRule for PushDownFilter { .map(|e| Ok(Column::from_qualified_name(e.display_name()?))) .collect::>>()?; - let predicates = utils::split_conjunction_owned(filter.predicate.clone()); + let predicates = split_conjunction_owned(filter.predicate.clone()); let mut keep_predicates = vec![]; let mut push_predicates = vec![]; @@ -834,7 +844,9 @@ impl OptimizerRule for PushDownFilter { )?), None => (*agg.input).clone(), }; - let new_agg = filter.input.with_new_inputs(&vec![child])?; + let new_agg = filter + .input + .with_new_exprs(filter.input.expressions(), &vec![child])?; match conjunction(keep_predicates) { Some(predicate) => LogicalPlan::Filter(Filter::try_new( predicate, @@ -849,17 +861,23 @@ impl OptimizerRule for PushDownFilter { None => return Ok(None), } } - LogicalPlan::CrossJoin(CrossJoin { left, right, .. }) => { - let predicates = utils::split_conjunction_owned(filter.predicate.clone()); - push_down_all_join( + LogicalPlan::CrossJoin(cross_join) => { + let predicates = split_conjunction_owned(filter.predicate.clone()); + let join = convert_cross_join_to_inner_join(cross_join.clone())?; + let join_plan = LogicalPlan::Join(join); + let inputs = join_plan.inputs(); + let left = inputs[0]; + let right = inputs[1]; + let plan = push_down_all_join( predicates, vec![], - &filter.input, + &join_plan, left, right, vec![], - false, - )? + true, + )?; + convert_to_cross_join_if_beneficial(plan)? } LogicalPlan::TableScan(scan) => { let filter_predicates = split_conjunction(&filter.predicate); @@ -905,7 +923,7 @@ impl OptimizerRule for PushDownFilter { let prevent_cols = extension_plan.node.prevent_predicate_push_down_columns(); - let predicates = utils::split_conjunction_owned(filter.predicate.clone()); + let predicates = split_conjunction_owned(filter.predicate.clone()); let mut keep_predicates = vec![]; let mut push_predicates = vec![]; @@ -933,7 +951,8 @@ impl OptimizerRule for PushDownFilter { None => extension_plan.node.inputs().into_iter().cloned().collect(), }; // extension with new inputs. - let new_extension = child_plan.with_new_inputs(&new_children)?; + let new_extension = + child_plan.with_new_exprs(child_plan.expressions(), &new_children)?; match conjunction(keep_predicates) { Some(predicate) => LogicalPlan::Filter(Filter::try_new( @@ -956,6 +975,42 @@ impl PushDownFilter { } } +/// Converts the given cross join to an inner join with an empty equality +/// predicate and an empty filter condition. +fn convert_cross_join_to_inner_join(cross_join: CrossJoin) -> Result { + let CrossJoin { left, right, .. 
} = cross_join; + let join_schema = build_join_schema(left.schema(), right.schema(), &JoinType::Inner)?; + Ok(Join { + left, + right, + join_type: JoinType::Inner, + join_constraint: JoinConstraint::On, + on: vec![], + filter: None, + schema: DFSchemaRef::new(join_schema), + null_equals_null: true, + }) +} + +/// Converts the given inner join with an empty equality predicate and an +/// empty filter condition to a cross join. +fn convert_to_cross_join_if_beneficial(plan: LogicalPlan) -> Result { + if let LogicalPlan::Join(join) = &plan { + // Can be converted back to cross join + if join.on.is_empty() && join.filter.is_none() { + return LogicalPlanBuilder::from(join.left.as_ref().clone()) + .cross_join(join.right.as_ref().clone())? + .build(); + } + } else if let LogicalPlan::Filter(filter) = &plan { + let new_input = + convert_to_cross_join_if_beneficial(filter.input.as_ref().clone())?; + return Filter::try_new(filter.predicate.clone(), Arc::new(new_input)) + .map(LogicalPlan::Filter); + } + Ok(plan) +} + /// replaces columns by its name on the projection. pub fn replace_cols_by_name( e: Expr, @@ -978,10 +1033,26 @@ fn is_volatile_expression(e: &Expr) -> bool { let mut is_volatile = false; e.apply(&mut |expr| { Ok(match expr { - Expr::ScalarFunction(f) if f.fun.volatility() == Volatility::Volatile => { - is_volatile = true; - VisitRecursion::Stop - } + Expr::ScalarFunction(f) => match &f.func_def { + ScalarFunctionDefinition::BuiltIn(fun) + if fun.volatility() == Volatility::Volatile => + { + is_volatile = true; + VisitRecursion::Stop + } + ScalarFunctionDefinition::UDF(fun) + if fun.signature().volatility == Volatility::Volatile => + { + is_volatile = true; + VisitRecursion::Stop + } + ScalarFunctionDefinition::Name(_) => { + return internal_err!( + "Function `Expr` with name should be resolved." + ); + } + _ => VisitRecursion::Continue, + }, _ => VisitRecursion::Continue, }) }) @@ -1011,13 +1082,16 @@ fn contain(e: &Expr, check_map: &HashMap) -> bool { #[cfg(test)] mod tests { + use std::fmt::{Debug, Formatter}; + use std::sync::Arc; + use super::*; use crate::optimizer::Optimizer; use crate::rewrite_disjunctive_predicate::RewriteDisjunctivePredicate; use crate::test::*; use crate::OptimizerContext; + use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; - use async_trait::async_trait; use datafusion_common::{DFSchema, DFSchemaRef}; use datafusion_expr::logical_plan::table_scan; use datafusion_expr::{ @@ -1025,8 +1099,8 @@ mod tests { BinaryExpr, Expr, Extension, LogicalPlanBuilder, Operator, TableSource, TableType, UserDefinedLogicalNodeCore, }; - use std::fmt::{Debug, Formatter}; - use std::sync::Arc; + + use async_trait::async_trait; fn assert_optimized_plan_eq(plan: &LogicalPlan, expected: &str) -> Result<()> { crate::test::assert_optimized_plan_eq( @@ -1046,7 +1120,7 @@ mod tests { ]); let mut optimized_plan = optimizer .optimize_recursively( - optimizer.rules.get(0).unwrap(), + optimizer.rules.first().unwrap(), plan, &OptimizerContext::new(), )? @@ -2650,14 +2724,12 @@ Projection: a, b .cross_join(right)? .filter(filter)? 
.build()?; - let expected = "\ - Filter: test.a = d AND test.b > UInt32(1) OR test.b = e AND test.c < UInt32(10)\ - \n CrossJoin:\ - \n Projection: test.a, test.b, test.c\ - \n TableScan: test, full_filters=[test.b > UInt32(1) OR test.c < UInt32(10)]\ - \n Projection: test1.a AS d, test1.a AS e\ - \n TableScan: test1"; + Inner Join: Filter: test.a = d AND test.b > UInt32(1) OR test.b = e AND test.c < UInt32(10)\ + \n Projection: test.a, test.b, test.c\ + \n TableScan: test, full_filters=[test.b > UInt32(1) OR test.c < UInt32(10)]\ + \n Projection: test1.a AS d, test1.a AS e\ + \n TableScan: test1"; assert_optimized_plan_eq_with_rewrite_predicate(&plan, expected)?; // Originally global state which can help to avoid duplicate Filters been generated and pushed down. diff --git a/datafusion/optimizer/src/push_down_limit.rs b/datafusion/optimizer/src/push_down_limit.rs index 6703a1d787a73..c2f35a7906169 100644 --- a/datafusion/optimizer/src/push_down_limit.rs +++ b/datafusion/optimizer/src/push_down_limit.rs @@ -126,7 +126,7 @@ impl OptimizerRule for PushDownLimit { fetch: scan.fetch.map(|x| min(x, limit)).or(Some(limit)), projected_schema: scan.projected_schema.clone(), }); - Some(plan.with_new_inputs(&[new_input])?) + Some(plan.with_new_exprs(plan.expressions(), &[new_input])?) } } LogicalPlan::Union(union) => { @@ -145,7 +145,7 @@ impl OptimizerRule for PushDownLimit { inputs: new_inputs, schema: union.schema.clone(), }); - Some(plan.with_new_inputs(&[union])?) + Some(plan.with_new_exprs(plan.expressions(), &[union])?) } LogicalPlan::CrossJoin(cross_join) => { @@ -166,15 +166,16 @@ impl OptimizerRule for PushDownLimit { right: Arc::new(new_right), schema: plan.schema().clone(), }); - Some(plan.with_new_inputs(&[new_cross_join])?) + Some(plan.with_new_exprs(plan.expressions(), &[new_cross_join])?) } LogicalPlan::Join(join) => { let new_join = push_down_join(join, fetch + skip); match new_join { - Some(new_join) => { - Some(plan.with_new_inputs(&[LogicalPlan::Join(new_join)])?) - } + Some(new_join) => Some(plan.with_new_exprs( + plan.expressions(), + &[LogicalPlan::Join(new_join)], + )?), None => None, } } @@ -192,14 +193,16 @@ impl OptimizerRule for PushDownLimit { input: Arc::new((*sort.input).clone()), fetch: new_fetch, }); - Some(plan.with_new_inputs(&[new_sort])?) + Some(plan.with_new_exprs(plan.expressions(), &[new_sort])?) } } LogicalPlan::Projection(_) | LogicalPlan::SubqueryAlias(_) => { // commute - let new_limit = - plan.with_new_inputs(&[child_plan.inputs()[0].clone()])?; - Some(child_plan.with_new_inputs(&[new_limit])?) + let new_limit = plan.with_new_exprs( + plan.expressions(), + &[child_plan.inputs()[0].clone()], + )?; + Some(child_plan.with_new_exprs(child_plan.expressions(), &[new_limit])?) } _ => None, }; diff --git a/datafusion/optimizer/src/push_down_projection.rs b/datafusion/optimizer/src/push_down_projection.rs index b05d811cb4819..4ee4f7e417a6a 100644 --- a/datafusion/optimizer/src/push_down_projection.rs +++ b/datafusion/optimizer/src/push_down_projection.rs @@ -18,530 +18,26 @@ //! Projection Push Down optimizer rule ensures that only referenced columns are //! 
loaded into memory -use std::collections::{BTreeSet, HashMap, HashSet}; -use std::sync::Arc; - -use crate::eliminate_project::can_eliminate; -use crate::merge_projection::merge_projection; -use crate::optimizer::ApplyOrder; -use crate::push_down_filter::replace_cols_by_name; -use crate::{OptimizerConfig, OptimizerRule}; -use arrow::error::Result as ArrowResult; -use datafusion_common::ScalarValue::UInt8; -use datafusion_common::{ - plan_err, Column, DFSchema, DFSchemaRef, DataFusionError, Result, -}; -use datafusion_expr::expr::{AggregateFunction, Alias}; -use datafusion_expr::{ - logical_plan::{Aggregate, LogicalPlan, Projection, TableScan, Union}, - utils::{expr_to_columns, exprlist_to_columns, exprlist_to_fields}, - Expr, LogicalPlanBuilder, SubqueryAlias, -}; - -// if projection is empty return projection-new_plan, else return new_plan. -#[macro_export] -macro_rules! generate_plan { - ($projection_is_empty:expr, $plan:expr, $new_plan:expr) => { - if $projection_is_empty { - $new_plan - } else { - $plan.with_new_inputs(&[$new_plan])? - } - }; -} - -/// Optimizer that removes unused projections and aggregations from plans -/// This reduces both scans and -#[derive(Default)] -pub struct PushDownProjection {} - -impl OptimizerRule for PushDownProjection { - fn try_optimize( - &self, - plan: &LogicalPlan, - _config: &dyn OptimizerConfig, - ) -> Result> { - let projection = match plan { - LogicalPlan::Projection(projection) => projection, - LogicalPlan::Aggregate(agg) => { - let mut required_columns = HashSet::new(); - for e in agg.aggr_expr.iter().chain(agg.group_expr.iter()) { - expr_to_columns(e, &mut required_columns)? - } - let new_expr = get_expr(&required_columns, agg.input.schema())?; - let projection = LogicalPlan::Projection(Projection::try_new( - new_expr, - agg.input.clone(), - )?); - let optimized_child = self - .try_optimize(&projection, _config)? - .unwrap_or(projection); - return Ok(Some(plan.with_new_inputs(&[optimized_child])?)); - } - LogicalPlan::TableScan(scan) if scan.projection.is_none() => { - return Ok(Some(push_down_scan(&HashSet::new(), scan, false)?)); - } - _ => return Ok(None), - }; - - let child_plan = &*projection.input; - let projection_is_empty = projection.expr.is_empty(); - - let new_plan = match child_plan { - LogicalPlan::Projection(child_projection) => { - let new_plan = merge_projection(projection, child_projection)?; - self.try_optimize(&new_plan, _config)?.unwrap_or(new_plan) - } - LogicalPlan::Join(join) => { - // collect column in on/filter in join and projection. - let mut push_columns: HashSet = HashSet::new(); - for e in projection.expr.iter() { - expr_to_columns(e, &mut push_columns)?; - } - for (l, r) in join.on.iter() { - expr_to_columns(l, &mut push_columns)?; - expr_to_columns(r, &mut push_columns)?; - } - if let Some(expr) = &join.filter { - expr_to_columns(expr, &mut push_columns)?; - } - - let new_left = generate_projection( - &push_columns, - join.left.schema(), - join.left.clone(), - )?; - let new_right = generate_projection( - &push_columns, - join.right.schema(), - join.right.clone(), - )?; - let new_join = child_plan.with_new_inputs(&[new_left, new_right])?; - - generate_plan!(projection_is_empty, plan, new_join) - } - LogicalPlan::CrossJoin(join) => { - // collect column in on/filter in join and projection. 
- let mut push_columns: HashSet = HashSet::new(); - for e in projection.expr.iter() { - expr_to_columns(e, &mut push_columns)?; - } - let new_left = generate_projection( - &push_columns, - join.left.schema(), - join.left.clone(), - )?; - let new_right = generate_projection( - &push_columns, - join.right.schema(), - join.right.clone(), - )?; - let new_join = child_plan.with_new_inputs(&[new_left, new_right])?; - - generate_plan!(projection_is_empty, plan, new_join) - } - LogicalPlan::TableScan(scan) - if !scan.projected_schema.fields().is_empty() => - { - let mut used_columns: HashSet = HashSet::new(); - if projection_is_empty { - push_down_scan(&used_columns, scan, true)? - } else { - for expr in projection.expr.iter() { - expr_to_columns(expr, &mut used_columns)?; - } - let new_scan = push_down_scan(&used_columns, scan, true)?; - - plan.with_new_inputs(&[new_scan])? - } - } - LogicalPlan::Union(union) => { - let mut required_columns = HashSet::new(); - exprlist_to_columns(&projection.expr, &mut required_columns)?; - // When there is no projection, we need to add the first column to the projection - // Because if push empty down, children may output different columns. - if required_columns.is_empty() { - required_columns.insert(union.schema.fields()[0].qualified_column()); - } - // we don't push down projection expr, we just prune columns, so we just push column - // because push expr may cause more cost. - let projection_column_exprs = get_expr(&required_columns, &union.schema)?; - let mut inputs = Vec::with_capacity(union.inputs.len()); - for input in &union.inputs { - let mut replace_map = HashMap::new(); - for (i, field) in input.schema().fields().iter().enumerate() { - replace_map.insert( - union.schema.fields()[i].qualified_name(), - Expr::Column(field.qualified_column()), - ); - } - - let exprs = projection_column_exprs - .iter() - .map(|expr| replace_cols_by_name(expr.clone(), &replace_map)) - .collect::>>()?; - - inputs.push(Arc::new(LogicalPlan::Projection(Projection::try_new( - exprs, - input.clone(), - )?))) - } - // create schema of all used columns - let schema = DFSchema::new_with_metadata( - exprlist_to_fields(&projection_column_exprs, child_plan)?, - union.schema.metadata().clone(), - )?; - let new_union = LogicalPlan::Union(Union { - inputs, - schema: Arc::new(schema), - }); - - generate_plan!(projection_is_empty, plan, new_union) - } - LogicalPlan::SubqueryAlias(subquery_alias) => { - let replace_map = generate_column_replace_map(subquery_alias); - let mut required_columns = HashSet::new(); - exprlist_to_columns(&projection.expr, &mut required_columns)?; - - let new_required_columns = required_columns - .iter() - .map(|c| { - replace_map.get(c).cloned().ok_or_else(|| { - DataFusionError::Internal("replace column failed".to_string()) - }) - }) - .collect::>>()?; - - let new_expr = - get_expr(&new_required_columns, subquery_alias.input.schema())?; - let new_projection = LogicalPlan::Projection(Projection::try_new( - new_expr, - subquery_alias.input.clone(), - )?); - let new_alias = child_plan.with_new_inputs(&[new_projection])?; - - generate_plan!(projection_is_empty, plan, new_alias) - } - LogicalPlan::Aggregate(agg) => { - let mut required_columns = HashSet::new(); - exprlist_to_columns(&projection.expr, &mut required_columns)?; - // Gather all columns needed for expressions in this Aggregate - let mut new_aggr_expr = vec![]; - for e in agg.aggr_expr.iter() { - let column = Column::from_name(e.display_name()?); - if required_columns.contains(&column) { - 
new_aggr_expr.push(e.clone()); - } - } - - // if new_aggr_expr emtpy and aggr is COUNT(UInt8(1)), push it - if new_aggr_expr.is_empty() && agg.aggr_expr.len() == 1 { - if let Expr::AggregateFunction(AggregateFunction { - fun, args, .. - }) = &agg.aggr_expr[0] - { - if matches!(fun, datafusion_expr::AggregateFunction::Count) - && args.len() == 1 - && args[0] == Expr::Literal(UInt8(Some(1))) - { - new_aggr_expr.push(agg.aggr_expr[0].clone()); - } - } - } - - let new_agg = LogicalPlan::Aggregate(Aggregate::try_new( - agg.input.clone(), - agg.group_expr.clone(), - new_aggr_expr, - )?); - - generate_plan!(projection_is_empty, plan, new_agg) - } - LogicalPlan::Window(window) => { - let mut required_columns = HashSet::new(); - exprlist_to_columns(&projection.expr, &mut required_columns)?; - // Gather all columns needed for expressions in this Window - let mut new_window_expr = vec![]; - for e in window.window_expr.iter() { - let column = Column::from_name(e.display_name()?); - if required_columns.contains(&column) { - new_window_expr.push(e.clone()); - } - } - - if new_window_expr.is_empty() { - // none columns in window expr are needed, remove the window expr - let input = window.input.clone(); - let new_window = restrict_outputs(input.clone(), &required_columns)? - .unwrap_or((*input).clone()); - - generate_plan!(projection_is_empty, plan, new_window) - } else { - let mut referenced_inputs = HashSet::new(); - exprlist_to_columns(&new_window_expr, &mut referenced_inputs)?; - window - .input - .schema() - .fields() - .iter() - .filter(|f| required_columns.contains(&f.qualified_column())) - .for_each(|f| { - referenced_inputs.insert(f.qualified_column()); - }); - - let input = window.input.clone(); - let new_input = restrict_outputs(input.clone(), &referenced_inputs)? - .unwrap_or((*input).clone()); - let new_window = LogicalPlanBuilder::from(new_input) - .window(new_window_expr)? - .build()?; - - generate_plan!(projection_is_empty, plan, new_window) - } - } - LogicalPlan::Filter(filter) => { - if can_eliminate(projection, child_plan.schema()) { - // when projection schema == filter schema, we can commute directly. - let new_proj = - plan.with_new_inputs(&[filter.input.as_ref().clone()])?; - child_plan.with_new_inputs(&[new_proj])? - } else { - let mut required_columns = HashSet::new(); - exprlist_to_columns(&projection.expr, &mut required_columns)?; - exprlist_to_columns( - &[filter.predicate.clone()], - &mut required_columns, - )?; - - let new_expr = get_expr(&required_columns, filter.input.schema())?; - let new_projection = LogicalPlan::Projection(Projection::try_new( - new_expr, - filter.input.clone(), - )?); - let new_filter = child_plan.with_new_inputs(&[new_projection])?; - - generate_plan!(projection_is_empty, plan, new_filter) - } - } - LogicalPlan::Sort(sort) => { - if can_eliminate(projection, child_plan.schema()) { - // can commute - let new_proj = plan.with_new_inputs(&[(*sort.input).clone()])?; - child_plan.with_new_inputs(&[new_proj])? 
- } else { - let mut required_columns = HashSet::new(); - exprlist_to_columns(&projection.expr, &mut required_columns)?; - exprlist_to_columns(&sort.expr, &mut required_columns)?; - - let new_expr = get_expr(&required_columns, sort.input.schema())?; - let new_projection = LogicalPlan::Projection(Projection::try_new( - new_expr, - sort.input.clone(), - )?); - let new_sort = child_plan.with_new_inputs(&[new_projection])?; - - generate_plan!(projection_is_empty, plan, new_sort) - } - } - LogicalPlan::Limit(limit) => { - // can commute - let new_proj = plan.with_new_inputs(&[limit.input.as_ref().clone()])?; - child_plan.with_new_inputs(&[new_proj])? - } - _ => return Ok(None), - }; - - Ok(Some(new_plan)) - } - - fn name(&self) -> &str { - "push_down_projection" - } - - fn apply_order(&self) -> Option { - Some(ApplyOrder::TopDown) - } -} - -impl PushDownProjection { - #[allow(missing_docs)] - pub fn new() -> Self { - Self {} - } -} - -fn generate_column_replace_map( - subquery_alias: &SubqueryAlias, -) -> HashMap { - subquery_alias - .input - .schema() - .fields() - .iter() - .enumerate() - .map(|(i, field)| { - ( - subquery_alias.schema.fields()[i].qualified_column(), - field.qualified_column(), - ) - }) - .collect() -} - -pub fn collect_projection_expr(projection: &Projection) -> HashMap { - projection - .schema - .fields() - .iter() - .enumerate() - .flat_map(|(i, field)| { - // strip alias, as they should not be part of filters - let expr = match &projection.expr[i] { - Expr::Alias(Alias { expr, .. }) => expr.as_ref().clone(), - expr => expr.clone(), - }; - - // Convert both qualified and unqualified fields - [ - (field.name().clone(), expr.clone()), - (field.qualified_name(), expr), - ] - }) - .collect::>() -} - -/// Get the projection exprs from columns in the order of the schema -fn get_expr(columns: &HashSet, schema: &DFSchemaRef) -> Result> { - let expr = schema - .fields() - .iter() - .flat_map(|field| { - let qc = field.qualified_column(); - let uqc = field.unqualified_column(); - if columns.contains(&qc) || columns.contains(&uqc) { - Some(Expr::Column(qc)) - } else { - None - } - }) - .collect::>(); - if columns.len() != expr.len() { - plan_err!("required columns can't push down, columns: {columns:?}") - } else { - Ok(expr) - } -} - -fn generate_projection( - used_columns: &HashSet, - schema: &DFSchemaRef, - input: Arc, -) -> Result { - let expr = schema - .fields() - .iter() - .flat_map(|field| { - let column = field.qualified_column(); - if used_columns.contains(&column) { - Some(Expr::Column(column)) - } else { - None - } - }) - .collect::>(); - - Ok(LogicalPlan::Projection(Projection::try_new(expr, input)?)) -} - -fn push_down_scan( - used_columns: &HashSet, - scan: &TableScan, - has_projection: bool, -) -> Result { - // once we reach the table scan, we can use the accumulated set of column - // names to construct the set of column indexes in the scan - // - // we discard non-existing columns because some column names are not part of the schema, - // e.g. when the column derives from an aggregation - // - // Use BTreeSet to remove potential duplicates (e.g. 
union) as - // well as to sort the projection to ensure deterministic behavior - let schema = scan.source.schema(); - let mut projection: BTreeSet = used_columns - .iter() - .filter(|c| { - c.relation.is_none() || c.relation.as_ref().unwrap() == &scan.table_name - }) - .map(|c| schema.index_of(&c.name)) - .filter_map(ArrowResult::ok) - .collect(); - - if !has_projection && projection.is_empty() { - // for table scan without projection, we default to return all columns - projection = schema - .fields() - .iter() - .enumerate() - .map(|(i, _)| i) - .collect::>(); - } - - // Building new projection from BTreeSet - // preserving source projection order if it exists - let projection = if let Some(original_projection) = &scan.projection { - original_projection - .clone() - .into_iter() - .filter(|idx| projection.contains(idx)) - .collect::>() - } else { - projection.into_iter().collect::>() - }; - - TableScan::try_new( - scan.table_name.clone(), - scan.source.clone(), - Some(projection), - scan.filters.clone(), - scan.fetch, - ) - .map(LogicalPlan::TableScan) -} - -fn restrict_outputs( - plan: Arc, - permitted_outputs: &HashSet, -) -> Result> { - let schema = plan.schema(); - if permitted_outputs.len() == schema.fields().len() { - return Ok(None); - } - Ok(Some(generate_projection( - permitted_outputs, - schema, - plan.clone(), - )?)) -} - #[cfg(test)] mod tests { use std::collections::HashMap; + use std::sync::Arc; use std::vec; - use super::*; - use crate::eliminate_project::EliminateProjection; + use crate::optimize_projections::OptimizeProjections; use crate::optimizer::Optimizer; use crate::test::*; use crate::OptimizerContext; use arrow::datatypes::{DataType, Field, Schema}; - use datafusion_common::{DFField, DFSchema}; + use datafusion_common::{Column, DFField, DFSchema, Result}; use datafusion_expr::builder::table_scan_with_filters; use datafusion_expr::expr::{self, Cast}; use datafusion_expr::logical_plan::{ builder::LogicalPlanBuilder, table_scan, JoinType, }; use datafusion_expr::{ - col, count, lit, max, min, AggregateFunction, Expr, WindowFrame, WindowFunction, + col, count, lit, max, min, AggregateFunction, Expr, LogicalPlan, Projection, + WindowFrame, WindowFunctionDefinition, }; #[test] @@ -605,6 +101,31 @@ mod tests { assert_optimized_plan_eq(&plan, expected) } + #[test] + fn aggregate_with_periods() -> Result<()> { + let schema = Schema::new(vec![Field::new("tag.one", DataType::Utf8, false)]); + + // Build a plan that looks as follows (note "tag.one" is a column named + // "tag.one", not a column named "one" in a table named "tag"): + // + // Projection: tag.one + // Aggregate: groupBy=[], aggr=[MAX("tag.one") AS "tag.one"] + // TableScan + let plan = table_scan(Some("m4"), &schema, None)? + .aggregate( + Vec::::new(), + vec![max(col(Column::new_unqualified("tag.one"))).alias("tag.one")], + )? + .project([col(Column::new_unqualified("tag.one"))])? + .build()?; + + let expected = "\ + Aggregate: groupBy=[[]], aggr=[[MAX(m4.tag.one) AS tag.one]]\ + \n TableScan: m4 projection=[tag.one]"; + + assert_optimized_plan_eq(&plan, expected) + } + #[test] fn redundant_project() -> Result<()> { let table_scan = test_table_scan()?; @@ -842,7 +363,7 @@ mod tests { // Build the LogicalPlan directly (don't use PlanBuilder), so // that the Column references are unqualified (e.g. their // relation is `None`). 
PlanBuilder resolves the expressions - let expr = vec![col("a"), col("b")]; + let expr = vec![col("test.a"), col("test.b")]; let plan = LogicalPlan::Projection(Projection::try_new(expr, Arc::new(table_scan))?); @@ -1061,7 +582,7 @@ mod tests { let table_scan = test_table_scan()?; let max1 = Expr::WindowFunction(expr::WindowFunction::new( - WindowFunction::AggregateFunction(AggregateFunction::Max), + WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max), vec![col("test.a")], vec![col("test.b")], vec![], @@ -1069,7 +590,7 @@ mod tests { )); let max2 = Expr::WindowFunction(expr::WindowFunction::new( - WindowFunction::AggregateFunction(AggregateFunction::Max), + WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max), vec![col("test.b")], vec![], vec![], @@ -1101,24 +622,14 @@ mod tests { } fn optimize(plan: &LogicalPlan) -> Result { - let optimizer = Optimizer::with_rules(vec![ - Arc::new(PushDownProjection::new()), - Arc::new(EliminateProjection::new()), - ]); - let mut optimized_plan = optimizer + let optimizer = Optimizer::with_rules(vec![Arc::new(OptimizeProjections::new())]); + let optimized_plan = optimizer .optimize_recursively( - optimizer.rules.get(0).unwrap(), + optimizer.rules.first().unwrap(), plan, &OptimizerContext::new(), )? .unwrap_or_else(|| plan.clone()); - optimized_plan = optimizer - .optimize_recursively( - optimizer.rules.get(1).unwrap(), - &optimized_plan, - &OptimizerContext::new(), - )? - .unwrap_or(optimized_plan); Ok(optimized_plan) } } diff --git a/datafusion/optimizer/src/replace_distinct_aggregate.rs b/datafusion/optimizer/src/replace_distinct_aggregate.rs index 540617b770845..187e510e557db 100644 --- a/datafusion/optimizer/src/replace_distinct_aggregate.rs +++ b/datafusion/optimizer/src/replace_distinct_aggregate.rs @@ -20,7 +20,11 @@ use crate::{OptimizerConfig, OptimizerRule}; use datafusion_common::Result; use datafusion_expr::utils::expand_wildcard; -use datafusion_expr::{Aggregate, Distinct, LogicalPlan}; +use datafusion_expr::{ + aggregate_function::AggregateFunction as AggregateFunctionFunc, col, + expr::AggregateFunction, LogicalPlanBuilder, +}; +use datafusion_expr::{Aggregate, Distinct, DistinctOn, Expr, LogicalPlan}; /// Optimizer that replaces logical [[Distinct]] with a logical [[Aggregate]] /// @@ -32,6 +36,22 @@ use datafusion_expr::{Aggregate, Distinct, LogicalPlan}; /// ```text /// SELECT a, b FROM tab GROUP BY a, b /// ``` +/// +/// On the other hand, for a `DISTINCT ON` query the replacement is +/// a bit more involved and effectively converts +/// ```text +/// SELECT DISTINCT ON (a) b FROM tab ORDER BY a DESC, c +/// ``` +/// +/// into +/// ```text +/// SELECT b FROM ( +/// SELECT a, FIRST_VALUE(b ORDER BY a DESC, c) AS b +/// FROM tab +/// GROUP BY a +/// ) +/// ORDER BY a DESC +/// ``` /// Optimizer that replaces logical [[Distinct]] with a logical [[Aggregate]] #[derive(Default)] @@ -51,7 +71,7 @@ impl OptimizerRule for ReplaceDistinctWithAggregate { _config: &dyn OptimizerConfig, ) -> Result> { match plan { - LogicalPlan::Distinct(Distinct { input }) => { + LogicalPlan::Distinct(Distinct::All(input)) => { let group_expr = expand_wildcard(input.schema(), input, None)?; let aggregate = LogicalPlan::Aggregate(Aggregate::try_new( input.clone(), @@ -60,6 +80,65 @@ impl OptimizerRule for ReplaceDistinctWithAggregate { )?); Ok(Some(aggregate)) } + LogicalPlan::Distinct(Distinct::On(DistinctOn { + select_expr, + on_expr, + sort_expr, + input, + schema, + })) => { + // Construct the aggregation expression to be used to 
fetch the selected expressions. + let aggr_expr = select_expr + .iter() + .map(|e| { + Expr::AggregateFunction(AggregateFunction::new( + AggregateFunctionFunc::FirstValue, + vec![e.clone()], + false, + None, + sort_expr.clone(), + )) + }) + .collect::>(); + + // Build the aggregation plan + let plan = LogicalPlanBuilder::from(input.as_ref().clone()) + .aggregate(on_expr.clone(), aggr_expr.to_vec())? + .build()?; + + let plan = if let Some(sort_expr) = sort_expr { + // While sort expressions were used in the `FIRST_VALUE` aggregation itself above, + // this on it's own isn't enough to guarantee the proper output order of the grouping + // (`ON`) expression, so we need to sort those as well. + LogicalPlanBuilder::from(plan) + .sort(sort_expr[..on_expr.len()].to_vec())? + .build()? + } else { + plan + }; + + // Whereas the aggregation plan by default outputs both the grouping and the aggregation + // expressions, for `DISTINCT ON` we only need to emit the original selection expressions. + let project_exprs = plan + .schema() + .fields() + .iter() + .skip(on_expr.len()) + .zip(schema.fields().iter()) + .map(|(new_field, old_field)| { + Ok(col(new_field.qualified_column()).alias_qualified( + old_field.qualifier().cloned(), + old_field.name(), + )) + }) + .collect::>>()?; + + let plan = LogicalPlanBuilder::from(plan) + .project(project_exprs)? + .build()?; + + Ok(Some(plan)) + } _ => Ok(None), } } @@ -98,4 +177,27 @@ mod tests { expected, ) } + + #[test] + fn replace_distinct_on() -> datafusion_common::Result<()> { + let table_scan = test_table_scan().unwrap(); + let plan = LogicalPlanBuilder::from(table_scan) + .distinct_on( + vec![col("a")], + vec![col("b")], + Some(vec![col("a").sort(false, true), col("c").sort(true, false)]), + )? + .build()?; + + let expected = "Projection: FIRST_VALUE(test.b) ORDER BY [test.a DESC NULLS FIRST, test.c ASC NULLS LAST] AS b\ + \n Sort: test.a DESC NULLS FIRST\ + \n Aggregate: groupBy=[[test.a]], aggr=[[FIRST_VALUE(test.b) ORDER BY [test.a DESC NULLS FIRST, test.c ASC NULLS LAST]]]\ + \n TableScan: test"; + + assert_optimized_plan_eq( + Arc::new(ReplaceDistinctWithAggregate::new()), + &plan, + expected, + ) + } } diff --git a/datafusion/optimizer/src/scalar_subquery_to_join.rs b/datafusion/optimizer/src/scalar_subquery_to_join.rs index 7ac0c25119c36..34ed4a9475cba 100644 --- a/datafusion/optimizer/src/scalar_subquery_to_join.rs +++ b/datafusion/optimizer/src/scalar_subquery_to_join.rs @@ -17,7 +17,7 @@ use crate::decorrelate::{PullUpCorrelatedExpr, UN_MATCHED_ROW_INDICATOR}; use crate::optimizer::ApplyOrder; -use crate::utils::{conjunction, replace_qualified_name}; +use crate::utils::replace_qualified_name; use crate::{OptimizerConfig, OptimizerRule}; use datafusion_common::alias::AliasGenerator; use datafusion_common::tree_node::{ @@ -26,6 +26,7 @@ use datafusion_common::tree_node::{ use datafusion_common::{plan_err, Column, DataFusionError, Result, ScalarValue}; use datafusion_expr::expr_rewriter::create_col_from_scalar_expr; use datafusion_expr::logical_plan::{JoinType, Subquery}; +use datafusion_expr::utils::conjunction; use datafusion_expr::{expr, EmptyRelation, Expr, LogicalPlan, LogicalPlanBuilder}; use std::collections::{BTreeSet, HashMap}; use std::sync::Arc; diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index 04fdcca0a994d..7d09aec7e748a 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ 
b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -21,34 +21,33 @@ use std::ops::Not; use super::or_in_list_simplifier::OrInListSimplifier; use super::utils::*; - use crate::analyzer::type_coercion::TypeCoercionRewriter; +use crate::simplify_expressions::guarantees::GuaranteeRewriter; use crate::simplify_expressions::regex::simplify_regex_expr; +use crate::simplify_expressions::SimplifyInfo; + use arrow::{ array::new_null_array, datatypes::{DataType, Field, Schema}, - error::ArrowError, record_batch::RecordBatch, }; use datafusion_common::{ cast::{as_large_list_array, as_list_array}, + plan_err, tree_node::{RewriteRecursion, TreeNode, TreeNodeRewriter}, }; use datafusion_common::{ exec_err, internal_err, DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue, }; -use datafusion_expr::expr::{InList, InSubquery, ScalarFunction}; use datafusion_expr::{ - and, expr, lit, or, BinaryExpr, BuiltinScalarFunction, Case, ColumnarValue, Expr, - Like, Volatility, + and, lit, or, BinaryExpr, BuiltinScalarFunction, Case, ColumnarValue, Expr, Like, + ScalarFunctionDefinition, Volatility, }; -use datafusion_physical_expr::{ - create_physical_expr, execution_props::ExecutionProps, intervals::NullableInterval, +use datafusion_expr::{ + expr::{InList, InSubquery, ScalarFunction}, + interval_arithmetic::NullableInterval, }; - -use crate::simplify_expressions::SimplifyInfo; - -use crate::simplify_expressions::guarantees::GuaranteeRewriter; +use datafusion_physical_expr::{create_physical_expr, execution_props::ExecutionProps}; /// This structure handles API for expression simplification pub struct ExprSimplifier { @@ -178,9 +177,9 @@ impl ExprSimplifier { /// ```rust /// use arrow::datatypes::{DataType, Field, Schema}; /// use datafusion_expr::{col, lit, Expr}; + /// use datafusion_expr::interval_arithmetic::{Interval, NullableInterval}; /// use datafusion_common::{Result, ScalarValue, ToDFSchema}; /// use datafusion_physical_expr::execution_props::ExecutionProps; - /// use datafusion_physical_expr::intervals::{Interval, NullableInterval}; /// use datafusion_optimizer::simplify_expressions::{ /// ExprSimplifier, SimplifyContext}; /// @@ -207,7 +206,7 @@ impl ExprSimplifier { /// ( /// col("x"), /// NullableInterval::NotNull { - /// values: Interval::make(Some(3_i64), Some(5_i64), (false, false)), + /// values: Interval::make(Some(3_i64), Some(5_i64)).unwrap() /// } /// ), /// // y = 3 @@ -333,7 +332,6 @@ impl<'a> ConstEvaluator<'a> { // Has no runtime cost, but needed during planning Expr::Alias(..) | Expr::AggregateFunction { .. } - | Expr::AggregateUDF { .. } | Expr::ScalarVariable(_, _) | Expr::Column(_) | Expr::OuterReferenceColumn(_, _) @@ -343,15 +341,17 @@ impl<'a> ConstEvaluator<'a> { | Expr::WindowFunction { .. } | Expr::Sort { .. } | Expr::GroupingSet(_) - | Expr::Wildcard - | Expr::QualifiedWildcard { .. } + | Expr::Wildcard { .. } | Expr::Placeholder(_) => false, - Expr::ScalarFunction(ScalarFunction { fun, .. }) => { - Self::volatility_ok(fun.volatility()) - } - Expr::ScalarUDF(expr::ScalarUDF { fun, .. }) => { - Self::volatility_ok(fun.signature.volatility) - } + Expr::ScalarFunction(ScalarFunction { func_def, .. }) => match func_def { + ScalarFunctionDefinition::BuiltIn(fun) => { + Self::volatility_ok(fun.volatility()) + } + ScalarFunctionDefinition::UDF(fun) => { + Self::volatility_ok(fun.signature().volatility) + } + ScalarFunctionDefinition::Name(_) => false, + }, Expr::Literal(_) | Expr::BinaryExpr { .. 
} | Expr::Not(_) @@ -481,6 +481,14 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { lit(negated) } + // null in (x, y, z) --> null + // null not in (x, y, z) --> null + Expr::InList(InList { + expr, + list: _, + negated: _, + }) if is_null(&expr) => lit_bool_null(), + // expr IN ((subquery)) -> expr IN (subquery), see ##5529 Expr::InList(InList { expr, @@ -792,7 +800,7 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { op: Divide, right, }) if is_null(&right) => *right, - // A / 0 -> DivideByZero Error if A is not null and not floating + // A / 0 -> Divide by zero error if A is not null and not floating // (float / 0 -> inf | -inf | NAN) Expr::BinaryExpr(BinaryExpr { left, @@ -802,7 +810,7 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { && !info.get_data_type(&left)?.is_floating() && is_zero(&right) => { - return Err(DataFusionError::ArrowError(ArrowError::DivideByZero)); + return plan_err!("Divide by zero"); } // @@ -832,7 +840,7 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { { lit(0) } - // A % 0 --> DivideByZero Error (if A is not floating and not null) + // A % 0 --> Divide by zero Error (if A is not floating and not null) // A % 0 --> NAN (if A is floating and not null) Expr::BinaryExpr(BinaryExpr { left, @@ -843,9 +851,7 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { DataType::Float32 => lit(f32::NAN), DataType::Float64 => lit(f64::NAN), _ => { - return Err(DataFusionError::ArrowError( - ArrowError::DivideByZero, - )); + return plan_err!("Divide by zero"); } } } @@ -1202,25 +1208,28 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { // log Expr::ScalarFunction(ScalarFunction { - fun: BuiltinScalarFunction::Log, + func_def: ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::Log), args, }) => simpl_log(args, <&S>::clone(&info))?, // power Expr::ScalarFunction(ScalarFunction { - fun: BuiltinScalarFunction::Power, + func_def: ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::Power), args, }) => simpl_power(args, <&S>::clone(&info))?, // concat Expr::ScalarFunction(ScalarFunction { - fun: BuiltinScalarFunction::Concat, + func_def: ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::Concat), args, }) => simpl_concat(args)?, // concat_ws Expr::ScalarFunction(ScalarFunction { - fun: BuiltinScalarFunction::ConcatWithSeparator, + func_def: + ScalarFunctionDefinition::BuiltIn( + BuiltinScalarFunction::ConcatWithSeparator, + ), args, }) => match &args[..] { [delimiter, vals @ ..] 
=> simpl_concat_ws(delimiter, vals)?, @@ -1301,26 +1310,27 @@ mod tests { sync::Arc, }; + use super::*; use crate::simplify_expressions::{ utils::for_test::{cast_to_int64_expr, now_expr, to_timestamp_expr}, SimplifyContext, }; - - use super::*; use crate::test::test_table_scan_with_name; + use arrow::{ array::{ArrayRef, Int32Array}, datatypes::{DataType, Field, Schema}, }; - use chrono::{DateTime, TimeZone, Utc}; - use datafusion_common::{assert_contains, cast::as_int32_array, DFField, ToDFSchema}; - use datafusion_expr::*; + use datafusion_common::{ + assert_contains, cast::as_int32_array, plan_datafusion_err, DFField, ToDFSchema, + }; + use datafusion_expr::{interval_arithmetic::Interval, *}; use datafusion_physical_expr::{ - execution_props::ExecutionProps, - functions::make_scalar_function, - intervals::{Interval, NullableInterval}, + execution_props::ExecutionProps, functions::make_scalar_function, }; + use chrono::{DateTime, TimeZone, Utc}; + // ------------------------------ // --- ExprSimplifier tests ----- // ------------------------------ @@ -1553,7 +1563,7 @@ mod tests { // immutable UDF should get folded // udf_add(1+2, 30+40) --> 73 - let expr = Expr::ScalarUDF(expr::ScalarUDF::new( + let expr = Expr::ScalarFunction(expr::ScalarFunction::new_udf( make_udf_add(Volatility::Immutable), args.clone(), )); @@ -1562,15 +1572,21 @@ mod tests { // stable UDF should be entirely folded // udf_add(1+2, 30+40) --> 73 let fun = make_udf_add(Volatility::Stable); - let expr = Expr::ScalarUDF(expr::ScalarUDF::new(Arc::clone(&fun), args.clone())); + let expr = Expr::ScalarFunction(expr::ScalarFunction::new_udf( + Arc::clone(&fun), + args.clone(), + )); test_evaluate(expr, lit(73)); // volatile UDF should have args folded // udf_add(1+2, 30+40) --> udf_add(3, 70) let fun = make_udf_add(Volatility::Volatile); - let expr = Expr::ScalarUDF(expr::ScalarUDF::new(Arc::clone(&fun), args)); - let expected_expr = - Expr::ScalarUDF(expr::ScalarUDF::new(Arc::clone(&fun), folded_args)); + let expr = + Expr::ScalarFunction(expr::ScalarFunction::new_udf(Arc::clone(&fun), args)); + let expected_expr = Expr::ScalarFunction(expr::ScalarFunction::new_udf( + Arc::clone(&fun), + folded_args, + )); test_evaluate(expr, expected_expr); } @@ -1763,25 +1779,23 @@ mod tests { #[test] fn test_simplify_divide_zero_by_zero() { - // 0 / 0 -> DivideByZero + // 0 / 0 -> Divide by zero let expr = lit(0) / lit(0); let err = try_simplify(expr).unwrap_err(); - assert!( - matches!(err, DataFusionError::ArrowError(ArrowError::DivideByZero)), - "{err}" - ); + let _expected = plan_datafusion_err!("Divide by zero"); + + assert!(matches!(err, ref _expected), "{err}"); } #[test] - #[should_panic( - expected = "called `Result::unwrap()` on an `Err` value: ArrowError(DivideByZero)" - )] fn test_simplify_divide_by_zero() { // A / 0 -> DivideByZeroError let expr = col("c2_non_null") / lit(0); - - simplify(expr); + assert_eq!( + try_simplify(expr).unwrap_err().strip_backtrace(), + "Error during planning: Divide by zero" + ); } #[test] @@ -2201,12 +2215,12 @@ mod tests { } #[test] - #[should_panic( - expected = "called `Result::unwrap()` on an `Err` value: ArrowError(DivideByZero)" - )] fn test_simplify_modulo_by_zero_non_null() { let expr = col("c2_non_null") % lit(0); - simplify(expr); + assert_eq!( + try_simplify(expr).unwrap_err().strip_backtrace(), + "Error during planning: Divide by zero" + ); } #[test] @@ -3090,6 +3104,18 @@ mod tests { assert_eq!(simplify(in_list(col("c1"), vec![], false)), lit(false)); 
assert_eq!(simplify(in_list(col("c1"), vec![], true)), lit(true)); + // null in (...) --> null + assert_eq!( + simplify(in_list(lit_bool_null(), vec![col("c1"), lit(1)], false)), + lit_bool_null() + ); + + // null not in (...) --> null + assert_eq!( + simplify(in_list(lit_bool_null(), vec![col("c1"), lit(1)], true)), + lit_bool_null() + ); + assert_eq!( simplify(in_list(col("c1"), vec![lit(1)], false)), col("c1").eq(lit(1)) @@ -3282,17 +3308,14 @@ mod tests { ( col("c3"), NullableInterval::NotNull { - values: Interval::make(Some(0_i64), Some(2_i64), (false, false)), + values: Interval::make(Some(0_i64), Some(2_i64)).unwrap(), }, ), ( col("c4"), NullableInterval::from(ScalarValue::UInt32(Some(9))), ), - ( - col("c1"), - NullableInterval::from(ScalarValue::Utf8(Some("a".to_string()))), - ), + (col("c1"), NullableInterval::from(ScalarValue::from("a"))), ]; let output = simplify_with_guarantee(expr.clone(), guarantees); assert_eq!(output, lit(false)); @@ -3302,19 +3325,23 @@ mod tests { ( col("c3"), NullableInterval::MaybeNull { - values: Interval::make(Some(0_i64), Some(2_i64), (false, false)), + values: Interval::make(Some(0_i64), Some(2_i64)).unwrap(), }, ), ( col("c4"), NullableInterval::MaybeNull { - values: Interval::make(Some(9_u32), Some(9_u32), (false, false)), + values: Interval::make(Some(9_u32), Some(9_u32)).unwrap(), }, ), ( col("c1"), NullableInterval::NotNull { - values: Interval::make(Some("d"), Some("f"), (false, false)), + values: Interval::try_new( + ScalarValue::from("d"), + ScalarValue::from("f"), + ) + .unwrap(), }, ), ]; diff --git a/datafusion/optimizer/src/simplify_expressions/guarantees.rs b/datafusion/optimizer/src/simplify_expressions/guarantees.rs index 5504d7d76e359..aa7bb4f78a93f 100644 --- a/datafusion/optimizer/src/simplify_expressions/guarantees.rs +++ b/datafusion/optimizer/src/simplify_expressions/guarantees.rs @@ -18,11 +18,12 @@ //! Simplifier implementation for [`ExprSimplifier::with_guarantees()`] //! //! [`ExprSimplifier::with_guarantees()`]: crate::simplify_expressions::expr_simplifier::ExprSimplifier::with_guarantees + +use std::{borrow::Cow, collections::HashMap}; + use datafusion_common::{tree_node::TreeNodeRewriter, DataFusionError, Result}; +use datafusion_expr::interval_arithmetic::{Interval, NullableInterval}; use datafusion_expr::{expr::InList, lit, Between, BinaryExpr, Expr}; -use std::collections::HashMap; - -use datafusion_physical_expr::intervals::{Interval, IntervalBound, NullableInterval}; /// Rewrite expressions to incorporate guarantees. /// @@ -46,6 +47,10 @@ impl<'a> GuaranteeRewriter<'a> { guarantees: impl IntoIterator, ) -> Self { Self { + // TODO: Clippy wants the "map" call removed, but doing so generates + // a compilation error. Remove the clippy directive once this + // issue is fixed. 
+ #[allow(clippy::map_identity)] guarantees: guarantees.into_iter().map(|(k, v)| (k, v)).collect(), } } @@ -82,10 +87,7 @@ impl<'a> TreeNodeRewriter for GuaranteeRewriter<'a> { high.as_ref(), ) { let expr_interval = NullableInterval::NotNull { - values: Interval::new( - IntervalBound::new(low.clone(), false), - IntervalBound::new(high.clone(), false), - ), + values: Interval::try_new(low.clone(), high.clone())?, }; let contains = expr_interval.contains(*interval)?; @@ -103,48 +105,51 @@ impl<'a> TreeNodeRewriter for GuaranteeRewriter<'a> { } Expr::BinaryExpr(BinaryExpr { left, op, right }) => { - // We only support comparisons for now - if !op.is_comparison_operator() { - return Ok(expr); - }; - - // Check if this is a comparison between a column and literal - let (col, op, value) = match (left.as_ref(), right.as_ref()) { - (Expr::Column(_), Expr::Literal(value)) => (left, *op, value), - (Expr::Literal(value), Expr::Column(_)) => { - // If we can swap the op, we can simplify the expression - if let Some(op) = op.swap() { - (right, op, value) + // The left or right side of expression might either have a guarantee + // or be a literal. Either way, we can resolve them to a NullableInterval. + let left_interval = self + .guarantees + .get(left.as_ref()) + .map(|interval| Cow::Borrowed(*interval)) + .or_else(|| { + if let Expr::Literal(value) = left.as_ref() { + Some(Cow::Owned(value.clone().into())) } else { - return Ok(expr); + None + } + }); + let right_interval = self + .guarantees + .get(right.as_ref()) + .map(|interval| Cow::Borrowed(*interval)) + .or_else(|| { + if let Expr::Literal(value) = right.as_ref() { + Some(Cow::Owned(value.clone().into())) + } else { + None + } + }); + + match (left_interval, right_interval) { + (Some(left_interval), Some(right_interval)) => { + let result = + left_interval.apply_operator(op, right_interval.as_ref())?; + if result.is_certainly_true() { + Ok(lit(true)) + } else if result.is_certainly_false() { + Ok(lit(false)) + } else { + Ok(expr) } } - _ => return Ok(expr), - }; - - if let Some(col_interval) = self.guarantees.get(col.as_ref()) { - let result = - col_interval.apply_operator(&op, &value.clone().into())?; - if result.is_certainly_true() { - Ok(lit(true)) - } else if result.is_certainly_false() { - Ok(lit(false)) - } else { - Ok(expr) - } - } else { - Ok(expr) + _ => Ok(expr), } } // Columns (if interval is collapsed to a single value) Expr::Column(_) => { - if let Some(col_interval) = self.guarantees.get(&expr) { - if let Some(value) = col_interval.single_value() { - Ok(lit(value)) - } else { - Ok(expr) - } + if let Some(interval) = self.guarantees.get(&expr) { + Ok(interval.single_value().map_or(expr, lit)) } else { Ok(expr) } @@ -208,7 +213,7 @@ mod tests { ( col("x"), NullableInterval::NotNull { - values: Default::default(), + values: Interval::make_unbounded(&DataType::Boolean).unwrap(), }, ), ]; @@ -255,11 +260,18 @@ mod tests { #[test] fn test_inequalities_non_null_bounded() { let guarantees = vec![ - // x ∈ (1, 3] (not null) + // x ∈ [1, 3] (not null) ( col("x"), NullableInterval::NotNull { - values: Interval::make(Some(1_i32), Some(3_i32), (true, false)), + values: Interval::make(Some(1_i32), Some(3_i32)).unwrap(), + }, + ), + // s.y ∈ [1, 3] (not null) + ( + col("s").field("y"), + NullableInterval::NotNull { + values: Interval::make(Some(1_i32), Some(3_i32)).unwrap(), }, ), ]; @@ -268,17 +280,16 @@ mod tests { // (original_expr, expected_simplification) let simplified_cases = &[ - (col("x").lt_eq(lit(1)), false), + (col("x").lt(lit(0)), 
false), + (col("s").field("y").lt(lit(0)), false), (col("x").lt_eq(lit(3)), true), (col("x").gt(lit(3)), false), - (col("x").gt(lit(1)), true), + (col("x").gt(lit(0)), true), (col("x").eq(lit(0)), false), (col("x").not_eq(lit(0)), true), - (col("x").between(lit(2), lit(5)), true), - (col("x").between(lit(2), lit(3)), true), + (col("x").between(lit(0), lit(5)), true), (col("x").between(lit(5), lit(10)), false), - (col("x").not_between(lit(2), lit(5)), false), - (col("x").not_between(lit(2), lit(3)), false), + (col("x").not_between(lit(0), lit(5)), false), (col("x").not_between(lit(5), lit(10)), true), ( Expr::BinaryExpr(BinaryExpr { @@ -319,10 +330,11 @@ mod tests { ( col("x"), NullableInterval::NotNull { - values: Interval::new( - IntervalBound::new(ScalarValue::Date32(Some(18628)), false), - IntervalBound::make_unbounded(DataType::Date32).unwrap(), - ), + values: Interval::try_new( + ScalarValue::Date32(Some(18628)), + ScalarValue::Date32(None), + ) + .unwrap(), }, ), ]; @@ -397,7 +409,11 @@ mod tests { ( col("x"), NullableInterval::MaybeNull { - values: Interval::make(Some("abc"), Some("def"), (true, false)), + values: Interval::try_new( + ScalarValue::from("abc"), + ScalarValue::from("def"), + ) + .unwrap(), }, ), ]; @@ -451,7 +467,7 @@ mod tests { ScalarValue::Int32(Some(1)), ScalarValue::Boolean(Some(true)), ScalarValue::Boolean(None), - ScalarValue::Utf8(Some("abc".to_string())), + ScalarValue::from("abc"), ScalarValue::LargeUtf8(Some("def".to_string())), ScalarValue::Date32(Some(18628)), ScalarValue::Date32(None), @@ -470,11 +486,15 @@ mod tests { #[test] fn test_in_list() { let guarantees = vec![ - // x ∈ [1, 10) (not null) + // x ∈ [1, 10] (not null) ( col("x"), NullableInterval::NotNull { - values: Interval::make(Some(1_i32), Some(10_i32), (false, true)), + values: Interval::try_new( + ScalarValue::Int32(Some(1)), + ScalarValue::Int32(Some(10)), + ) + .unwrap(), }, ), ]; @@ -486,8 +506,8 @@ mod tests { let cases = &[ // x IN (9, 11) => x IN (9) ("x", vec![9, 11], false, vec![9]), - // x IN (10, 2) => x IN (2) - ("x", vec![10, 2], false, vec![2]), + // x IN (10, 2) => x IN (10, 2) + ("x", vec![10, 2], false, vec![10, 2]), // x NOT IN (9, 11) => x NOT IN (9) ("x", vec![9, 11], true, vec![9]), // x NOT IN (0, 22) => x NOT IN () diff --git a/datafusion/optimizer/src/simplify_expressions/regex.rs b/datafusion/optimizer/src/simplify_expressions/regex.rs index b9d9821b43f09..175b70f2b10e4 100644 --- a/datafusion/optimizer/src/simplify_expressions/regex.rs +++ b/datafusion/optimizer/src/simplify_expressions/regex.rs @@ -84,7 +84,7 @@ impl OperatorMode { let like = Like { negated: self.not, expr, - pattern: Box::new(Expr::Literal(ScalarValue::Utf8(Some(pattern)))), + pattern: Box::new(Expr::Literal(ScalarValue::from(pattern))), escape_char: None, case_insensitive: self.i, }; diff --git a/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs b/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs index 9dc83e0fadf57..43a41b1185a33 100644 --- a/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs +++ b/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs @@ -20,10 +20,10 @@ use std::sync::Arc; use super::{ExprSimplifier, SimplifyContext}; -use crate::utils::merge_schema; use crate::{OptimizerConfig, OptimizerRule}; use datafusion_common::{DFSchema, DFSchemaRef, Result}; use datafusion_expr::logical_plan::LogicalPlan; +use datafusion_expr::utils::merge_schema; use datafusion_physical_expr::execution_props::ExecutionProps; /// Optimizer Pass that 
simplifies [`LogicalPlan`]s by rewriting diff --git a/datafusion/optimizer/src/simplify_expressions/utils.rs b/datafusion/optimizer/src/simplify_expressions/utils.rs index 17e5d97c30062..fa91a3ace2a25 100644 --- a/datafusion/optimizer/src/simplify_expressions/utils.rs +++ b/datafusion/optimizer/src/simplify_expressions/utils.rs @@ -23,7 +23,7 @@ use datafusion_expr::expr::ScalarFunction; use datafusion_expr::{ expr::{Between, BinaryExpr, InList}, expr_fn::{and, bitwise_and, bitwise_or, concat_ws, or}, - lit, BuiltinScalarFunction, Expr, Like, Operator, + lit, BuiltinScalarFunction, Expr, Like, Operator, ScalarFunctionDefinition, }; pub static POWS_OF_TEN: [i128; 38] = [ @@ -365,7 +365,7 @@ pub fn simpl_log(current_args: Vec, info: &dyn SimplifyInfo) -> Result Ok(args[1].clone()), _ => { @@ -405,7 +405,7 @@ pub fn simpl_power(current_args: Vec, info: &dyn SimplifyInfo) -> Result Ok(args[1].clone()), _ => Ok(Expr::ScalarFunction(ScalarFunction::new( diff --git a/datafusion/optimizer/src/single_distinct_to_groupby.rs b/datafusion/optimizer/src/single_distinct_to_groupby.rs index be76c069f0b73..7e6fb6b355ab1 100644 --- a/datafusion/optimizer/src/single_distinct_to_groupby.rs +++ b/datafusion/optimizer/src/single_distinct_to_groupby.rs @@ -23,7 +23,9 @@ use crate::optimizer::ApplyOrder; use crate::{OptimizerConfig, OptimizerRule}; use datafusion_common::{DFSchema, Result}; +use datafusion_expr::expr::AggregateFunctionDefinition; use datafusion_expr::{ + aggregate_function::AggregateFunction::{Max, Min, Sum}, col, expr::AggregateFunction, logical_plan::{Aggregate, LogicalPlan, Projection}, @@ -35,17 +37,19 @@ use hashbrown::HashSet; /// single distinct to group by optimizer rule /// ```text -/// SELECT F1(DISTINCT s),F2(DISTINCT s) -/// ... -/// GROUP BY k +/// Before: +/// SELECT a, COUNT(DISTINCT b), SUM(c) +/// FROM t +/// GROUP BY a /// -/// Into -/// -/// SELECT F1(alias1),F2(alias1) +/// After: +/// SELECT a, COUNT(alias1), SUM(alias2) /// FROM ( -/// SELECT s as alias1, k ... GROUP BY s, k +/// SELECT a, b as alias1, SUM(c) as alias2 +/// FROM t +/// GROUP BY a, b /// ) -/// GROUP BY k +/// GROUP BY a /// ``` #[derive(Default)] pub struct SingleDistinctToGroupBy {} @@ -64,22 +68,30 @@ fn is_single_distinct_agg(plan: &LogicalPlan) -> Result { match plan { LogicalPlan::Aggregate(Aggregate { aggr_expr, .. }) => { let mut fields_set = HashSet::new(); - let mut distinct_count = 0; + let mut aggregate_count = 0; for expr in aggr_expr { if let Expr::AggregateFunction(AggregateFunction { - distinct, args, ..
+ func_def: AggregateFunctionDefinition::BuiltIn(fun), + distinct, + args, + filter, + order_by, }) = expr { - if *distinct { - distinct_count += 1; + if filter.is_some() || order_by.is_some() { + return Ok(false); } - for e in args { - fields_set.insert(e.canonical_name()); + aggregate_count += 1; + if *distinct { + for e in args { + fields_set.insert(e.canonical_name()); + } + } else if !matches!(fun, Sum | Min | Max) { + return Ok(false); } } } - let res = distinct_count == aggr_expr.len() && fields_set.len() == 1; - Ok(res) + Ok(aggregate_count == aggr_expr.len() && fields_set.len() == 1) } _ => Ok(false), } @@ -152,30 +164,57 @@ impl OptimizerRule for SingleDistinctToGroupBy { .collect::>(); // replace the distinct arg with alias + let mut index = 1; let mut group_fields_set = HashSet::new(); - let new_aggr_exprs = aggr_expr + let mut inner_aggr_exprs = vec![]; + let outer_aggr_exprs = aggr_expr .iter() .map(|aggr_expr| match aggr_expr { Expr::AggregateFunction(AggregateFunction { - fun, + func_def: AggregateFunctionDefinition::BuiltIn(fun), args, - filter, - order_by, + distinct, .. }) => { // is_single_distinct_agg ensure args.len=1 - if group_fields_set.insert(args[0].display_name()?) { + if *distinct + && group_fields_set.insert(args[0].display_name()?) + { inner_group_exprs.push( args[0].clone().alias(SINGLE_DISTINCT_ALIAS), ); } - Ok(Expr::AggregateFunction(AggregateFunction::new( - fun.clone(), - vec![col(SINGLE_DISTINCT_ALIAS)], - false, // intentional to remove distinct here - filter.clone(), - order_by.clone(), - ))) + + // if the aggregate function is not distinct, we need to rewrite it like two phase aggregation + if !(*distinct) { + index += 1; + let alias_str = format!("alias{}", index); + inner_aggr_exprs.push( + Expr::AggregateFunction(AggregateFunction::new( + fun.clone(), + args.clone(), + false, + None, + None, + )) + .alias(&alias_str), + ); + Ok(Expr::AggregateFunction(AggregateFunction::new( + fun.clone(), + vec![col(&alias_str)], + false, + None, + None, + ))) + } else { + Ok(Expr::AggregateFunction(AggregateFunction::new( + fun.clone(), + vec![col(SINGLE_DISTINCT_ALIAS)], + false, // intentional to remove distinct here + None, + None, + ))) + } } _ => Ok(aggr_expr.clone()), }) @@ -184,6 +223,7 @@ impl OptimizerRule for SingleDistinctToGroupBy { // construct the inner AggrPlan let inner_fields = inner_group_exprs .iter() + .chain(inner_aggr_exprs.iter()) .map(|expr| expr.to_field(input.schema())) .collect::>>()?; let inner_schema = DFSchema::new_with_metadata( @@ -193,12 +233,12 @@ impl OptimizerRule for SingleDistinctToGroupBy { let inner_agg = LogicalPlan::Aggregate(Aggregate::try_new( input.clone(), inner_group_exprs, - Vec::new(), + inner_aggr_exprs, )?); let outer_fields = outer_group_exprs .iter() - .chain(new_aggr_exprs.iter()) + .chain(outer_aggr_exprs.iter()) .map(|expr| expr.to_field(&inner_schema)) .collect::>>()?; let outer_aggr_schema = Arc::new(DFSchema::new_with_metadata( @@ -220,7 +260,7 @@ impl OptimizerRule for SingleDistinctToGroupBy { group_expr } }) - .chain(new_aggr_exprs.iter().enumerate().map(|(idx, expr)| { + .chain(outer_aggr_exprs.iter().enumerate().map(|(idx, expr)| { let idx = idx + group_size; let name = fields[idx].qualified_name(); columnize_expr(expr.clone().alias(name), &outer_aggr_schema) @@ -230,7 +270,7 @@ impl OptimizerRule for SingleDistinctToGroupBy { let outer_aggr = LogicalPlan::Aggregate(Aggregate::try_new( Arc::new(inner_agg), outer_group_exprs, - new_aggr_exprs, + outer_aggr_exprs, )?); 
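// ---------------------------------------------------------------------------
// Editor's note (illustration, not part of the patch): for a mixed query such
// as `SELECT a, SUM(c), COUNT(DISTINCT b) FROM t GROUP BY a`, the rewrite above
// is expected to produce, schematically (alias1/alias2 being the rule's own
// generated names):
//
//   Aggregate: groupBy=[[t.a]], aggr=[[SUM(alias2), COUNT(alias1)]]
//     Aggregate: groupBy=[[t.a, t.b AS alias1]], aggr=[[SUM(t.c) AS alias2]]
//
// That is, each non-distinct aggregate is evaluated in two phases (a partial
// SUM in the inner aggregate, then a SUM over the partials in the outer one),
// while the distinct argument becomes an extra inner group-by key. The tests
// added further below exercise exactly this plan shape.
// ---------------------------------------------------------------------------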
Ok(Some(LogicalPlan::Projection(Projection::try_new( @@ -262,7 +302,7 @@ mod tests { use datafusion_expr::expr::GroupingSet; use datafusion_expr::{ col, count, count_distinct, lit, logical_plan::builder::LogicalPlanBuilder, max, - AggregateFunction, + min, sum, AggregateFunction, }; fn assert_optimized_plan_equal(plan: &LogicalPlan, expected: &str) -> Result<()> { @@ -322,7 +362,7 @@ mod tests { .build()?; // Should not be optimized - let expected = "Aggregate: groupBy=[[GROUPING SETS ((test.a), (test.b))]], aggr=[[COUNT(DISTINCT test.c)]] [a:UInt32, b:UInt32, COUNT(DISTINCT test.c):Int64;N]\ + let expected = "Aggregate: groupBy=[[GROUPING SETS ((test.a), (test.b))]], aggr=[[COUNT(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, COUNT(DISTINCT test.c):Int64;N]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; assert_optimized_plan_equal(&plan, expected) @@ -340,7 +380,7 @@ mod tests { .build()?; // Should not be optimized - let expected = "Aggregate: groupBy=[[CUBE (test.a, test.b)]], aggr=[[COUNT(DISTINCT test.c)]] [a:UInt32, b:UInt32, COUNT(DISTINCT test.c):Int64;N]\ + let expected = "Aggregate: groupBy=[[CUBE (test.a, test.b)]], aggr=[[COUNT(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, COUNT(DISTINCT test.c):Int64;N]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; assert_optimized_plan_equal(&plan, expected) @@ -359,7 +399,7 @@ mod tests { .build()?; // Should not be optimized - let expected = "Aggregate: groupBy=[[ROLLUP (test.a, test.b)]], aggr=[[COUNT(DISTINCT test.c)]] [a:UInt32, b:UInt32, COUNT(DISTINCT test.c):Int64;N]\ + let expected = "Aggregate: groupBy=[[ROLLUP (test.a, test.b)]], aggr=[[COUNT(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, COUNT(DISTINCT test.c):Int64;N]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; assert_optimized_plan_equal(&plan, expected) @@ -478,4 +518,181 @@ mod tests { assert_optimized_plan_equal(&plan, expected) } + + #[test] + fn two_distinct_and_one_common() -> Result<()> { + let table_scan = test_table_scan()?; + + let plan = LogicalPlanBuilder::from(table_scan) + .aggregate( + vec![col("a")], + vec![ + sum(col("c")), + count_distinct(col("b")), + Expr::AggregateFunction(expr::AggregateFunction::new( + AggregateFunction::Max, + vec![col("b")], + true, + None, + None, + )), + ], + )? + .build()?; + // Should work + let expected = "Projection: test.a, SUM(alias2) AS SUM(test.c), COUNT(alias1) AS COUNT(DISTINCT test.b), MAX(alias1) AS MAX(DISTINCT test.b) [a:UInt32, SUM(test.c):UInt64;N, COUNT(DISTINCT test.b):Int64;N, MAX(DISTINCT test.b):UInt32;N]\ + \n Aggregate: groupBy=[[test.a]], aggr=[[SUM(alias2), COUNT(alias1), MAX(alias1)]] [a:UInt32, SUM(alias2):UInt64;N, COUNT(alias1):Int64;N, MAX(alias1):UInt32;N]\ + \n Aggregate: groupBy=[[test.a, test.b AS alias1]], aggr=[[SUM(test.c) AS alias2]] [a:UInt32, alias1:UInt32, alias2:UInt64;N]\ + \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; + + assert_optimized_plan_equal(&plan, expected) + } + + #[test] + fn one_distinctand_and_two_common() -> Result<()> { + let table_scan = test_table_scan()?; + + let plan = LogicalPlanBuilder::from(table_scan) + .aggregate( + vec![col("a")], + vec![sum(col("c")), max(col("c")), count_distinct(col("b"))], + )? 
+ .build()?; + // Should work + let expected = "Projection: test.a, SUM(alias2) AS SUM(test.c), MAX(alias3) AS MAX(test.c), COUNT(alias1) AS COUNT(DISTINCT test.b) [a:UInt32, SUM(test.c):UInt64;N, MAX(test.c):UInt32;N, COUNT(DISTINCT test.b):Int64;N]\ + \n Aggregate: groupBy=[[test.a]], aggr=[[SUM(alias2), MAX(alias3), COUNT(alias1)]] [a:UInt32, SUM(alias2):UInt64;N, MAX(alias3):UInt32;N, COUNT(alias1):Int64;N]\ + \n Aggregate: groupBy=[[test.a, test.b AS alias1]], aggr=[[SUM(test.c) AS alias2, MAX(test.c) AS alias3]] [a:UInt32, alias1:UInt32, alias2:UInt64;N, alias3:UInt32;N]\ + \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; + + assert_optimized_plan_equal(&plan, expected) + } + + #[test] + fn one_distinct_and_one_common() -> Result<()> { + let table_scan = test_table_scan()?; + + let plan = LogicalPlanBuilder::from(table_scan) + .aggregate( + vec![col("c")], + vec![min(col("a")), count_distinct(col("b"))], + )? + .build()?; + // Should work + let expected = "Projection: test.c, MIN(alias2) AS MIN(test.a), COUNT(alias1) AS COUNT(DISTINCT test.b) [c:UInt32, MIN(test.a):UInt32;N, COUNT(DISTINCT test.b):Int64;N]\ + \n Aggregate: groupBy=[[test.c]], aggr=[[MIN(alias2), COUNT(alias1)]] [c:UInt32, MIN(alias2):UInt32;N, COUNT(alias1):Int64;N]\ + \n Aggregate: groupBy=[[test.c, test.b AS alias1]], aggr=[[MIN(test.a) AS alias2]] [c:UInt32, alias1:UInt32, alias2:UInt32;N]\ + \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; + + assert_optimized_plan_equal(&plan, expected) + } + + #[test] + fn common_with_filter() -> Result<()> { + let table_scan = test_table_scan()?; + + // SUM(a) FILTER (WHERE a > 5) + let expr = Expr::AggregateFunction(expr::AggregateFunction::new( + AggregateFunction::Sum, + vec![col("a")], + false, + Some(Box::new(col("a").gt(lit(5)))), + None, + )); + let plan = LogicalPlanBuilder::from(table_scan) + .aggregate(vec![col("c")], vec![expr, count_distinct(col("b"))])? + .build()?; + // Do nothing + let expected = "Aggregate: groupBy=[[test.c]], aggr=[[SUM(test.a) FILTER (WHERE test.a > Int32(5)), COUNT(DISTINCT test.b)]] [c:UInt32, SUM(test.a) FILTER (WHERE test.a > Int32(5)):UInt64;N, COUNT(DISTINCT test.b):Int64;N]\ + \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; + + assert_optimized_plan_equal(&plan, expected) + } + + #[test] + fn distinct_with_filter() -> Result<()> { + let table_scan = test_table_scan()?; + + // COUNT(DISTINCT a) FILTER (WHERE a > 5) + let expr = Expr::AggregateFunction(expr::AggregateFunction::new( + AggregateFunction::Count, + vec![col("a")], + true, + Some(Box::new(col("a").gt(lit(5)))), + None, + )); + let plan = LogicalPlanBuilder::from(table_scan) + .aggregate(vec![col("c")], vec![sum(col("a")), expr])? + .build()?; + // Do nothing + let expected = "Aggregate: groupBy=[[test.c]], aggr=[[SUM(test.a), COUNT(DISTINCT test.a) FILTER (WHERE test.a > Int32(5))]] [c:UInt32, SUM(test.a):UInt64;N, COUNT(DISTINCT test.a) FILTER (WHERE test.a > Int32(5)):Int64;N]\ + \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; + + assert_optimized_plan_equal(&plan, expected) + } + + #[test] + fn common_with_order_by() -> Result<()> { + let table_scan = test_table_scan()?; + + // SUM(a ORDER BY a) + let expr = Expr::AggregateFunction(expr::AggregateFunction::new( + AggregateFunction::Sum, + vec![col("a")], + false, + None, + Some(vec![col("a")]), + )); + let plan = LogicalPlanBuilder::from(table_scan) + .aggregate(vec![col("c")], vec![expr, count_distinct(col("b"))])? 
+ .build()?; + // Do nothing + let expected = "Aggregate: groupBy=[[test.c]], aggr=[[SUM(test.a) ORDER BY [test.a], COUNT(DISTINCT test.b)]] [c:UInt32, SUM(test.a) ORDER BY [test.a]:UInt64;N, COUNT(DISTINCT test.b):Int64;N]\ + \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; + + assert_optimized_plan_equal(&plan, expected) + } + + #[test] + fn distinct_with_order_by() -> Result<()> { + let table_scan = test_table_scan()?; + + // COUNT(DISTINCT a ORDER BY a) + let expr = Expr::AggregateFunction(expr::AggregateFunction::new( + AggregateFunction::Count, + vec![col("a")], + true, + None, + Some(vec![col("a")]), + )); + let plan = LogicalPlanBuilder::from(table_scan) + .aggregate(vec![col("c")], vec![sum(col("a")), expr])? + .build()?; + // Do nothing + let expected = "Aggregate: groupBy=[[test.c]], aggr=[[SUM(test.a), COUNT(DISTINCT test.a) ORDER BY [test.a]]] [c:UInt32, SUM(test.a):UInt64;N, COUNT(DISTINCT test.a) ORDER BY [test.a]:Int64;N]\ + \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; + + assert_optimized_plan_equal(&plan, expected) + } + + #[test] + fn aggregate_with_filter_and_order_by() -> Result<()> { + let table_scan = test_table_scan()?; + + // COUNT(DISTINCT a ORDER BY a) FILTER (WHERE a > 5) + let expr = Expr::AggregateFunction(expr::AggregateFunction::new( + AggregateFunction::Count, + vec![col("a")], + true, + Some(Box::new(col("a").gt(lit(5)))), + Some(vec![col("a")]), + )); + let plan = LogicalPlanBuilder::from(table_scan) + .aggregate(vec![col("c")], vec![sum(col("a")), expr])? + .build()?; + // Do nothing + let expected = "Aggregate: groupBy=[[test.c]], aggr=[[SUM(test.a), COUNT(DISTINCT test.a) FILTER (WHERE test.a > Int32(5)) ORDER BY [test.a]]] [c:UInt32, SUM(test.a):UInt64;N, COUNT(DISTINCT test.a) FILTER (WHERE test.a > Int32(5)) ORDER BY [test.a]:Int64;N]\ + \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; + + assert_optimized_plan_equal(&plan, expected) + } } diff --git a/datafusion/optimizer/src/test/mod.rs b/datafusion/optimizer/src/test/mod.rs index 3eac2317b849f..e691fe9a53516 100644 --- a/datafusion/optimizer/src/test/mod.rs +++ b/datafusion/optimizer/src/test/mod.rs @@ -16,7 +16,7 @@ // under the License. use crate::analyzer::{Analyzer, AnalyzerRule}; -use crate::optimizer::Optimizer; +use crate::optimizer::{assert_schema_is_the_same, Optimizer}; use crate::{OptimizerContext, OptimizerRule}; use arrow::datatypes::{DataType, Field, Schema}; use datafusion_common::config::ConfigOptions; @@ -155,14 +155,17 @@ pub fn assert_optimized_plan_eq( plan: &LogicalPlan, expected: &str, ) -> Result<()> { - let optimizer = Optimizer::with_rules(vec![rule]); + let optimizer = Optimizer::with_rules(vec![rule.clone()]); let optimized_plan = optimizer .optimize_recursively( - optimizer.rules.get(0).unwrap(), + optimizer.rules.first().unwrap(), plan, &OptimizerContext::new(), )? 
.unwrap_or_else(|| plan.clone()); + + // Ensure schemas always match after an optimization + assert_schema_is_the_same(rule.name(), plan, &optimized_plan)?; let formatted_plan = format!("{optimized_plan:?}"); assert_eq!(formatted_plan, expected); @@ -196,7 +199,7 @@ pub fn assert_optimized_plan_eq_display_indent( let optimizer = Optimizer::with_rules(vec![rule]); let optimized_plan = optimizer .optimize_recursively( - optimizer.rules.get(0).unwrap(), + optimizer.rules.first().unwrap(), plan, &OptimizerContext::new(), ) @@ -230,7 +233,7 @@ pub fn assert_optimizer_err( ) { let optimizer = Optimizer::with_rules(vec![rule]); let res = optimizer.optimize_recursively( - optimizer.rules.get(0).unwrap(), + optimizer.rules.first().unwrap(), plan, &OptimizerContext::new(), ); @@ -252,7 +255,7 @@ pub fn assert_optimization_skipped( let optimizer = Optimizer::with_rules(vec![rule]); let new_plan = optimizer .optimize_recursively( - optimizer.rules.get(0).unwrap(), + optimizer.rules.first().unwrap(), plan, &OptimizerContext::new(), )? diff --git a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs index 468981a5fb0c8..91603e82a54fc 100644 --- a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs +++ b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs @@ -19,7 +19,6 @@ //! of expr can be added if needed. //! This rule can reduce adding the `Expr::Cast` the expr instead of adding the `Expr::Cast` to literal expr. use crate::optimizer::ApplyOrder; -use crate::utils::merge_schema; use crate::{OptimizerConfig, OptimizerRule}; use arrow::datatypes::{ DataType, TimeUnit, MAX_DECIMAL_FOR_EACH_PRECISION, MIN_DECIMAL_FOR_EACH_PRECISION, @@ -31,6 +30,7 @@ use datafusion_common::{ }; use datafusion_expr::expr::{BinaryExpr, Cast, InList, TryCast}; use datafusion_expr::expr_rewriter::rewrite_preserving_name; +use datafusion_expr::utils::merge_schema; use datafusion_expr::{ binary_expr, in_list, lit, Expr, ExprSchemable, LogicalPlan, Operator, }; @@ -1089,8 +1089,12 @@ mod tests { // Verify that calling the arrow // cast kernel yields the same results // input array - let literal_array = literal.to_array_of_size(1); - let expected_array = expected_value.to_array_of_size(1); + let literal_array = literal + .to_array_of_size(1) + .expect("Failed to convert to array of size"); + let expected_array = expected_value + .to_array_of_size(1) + .expect("Failed to convert to array of size"); let cast_array = cast_with_options( &literal_array, &target_type, diff --git a/datafusion/optimizer/src/utils.rs b/datafusion/optimizer/src/utils.rs index a3e7e42875d7e..44f2404afade8 100644 --- a/datafusion/optimizer/src/utils.rs +++ b/datafusion/optimizer/src/utils.rs @@ -18,19 +18,13 @@ //! 
Collection of utility functions that are leveraged by the query optimizer rules use crate::{OptimizerConfig, OptimizerRule}; -use datafusion_common::DataFusionError; -use datafusion_common::{plan_err, Column, DFSchemaRef}; +use datafusion_common::{Column, DFSchemaRef}; use datafusion_common::{DFSchema, Result}; -use datafusion_expr::expr::{Alias, BinaryExpr}; -use datafusion_expr::expr_rewriter::{replace_col, strip_outer_reference}; -use datafusion_expr::{ - and, - logical_plan::{Filter, LogicalPlan}, - Expr, Operator, -}; +use datafusion_expr::expr_rewriter::replace_col; +use datafusion_expr::utils as expr_utils; +use datafusion_expr::{logical_plan::LogicalPlan, Expr, Operator}; use log::{debug, trace}; use std::collections::{BTreeSet, HashMap}; -use std::sync::Arc; /// Convenience rule for writing optimizers: recursively invoke /// optimize on plan's children and then return a node of the same @@ -52,35 +46,61 @@ pub fn optimize_children( new_inputs.push(new_input.unwrap_or_else(|| input.clone())) } if plan_is_changed { - Ok(Some(plan.with_new_inputs(&new_inputs)?)) + Ok(Some(plan.with_new_exprs(plan.expressions(), &new_inputs)?)) } else { Ok(None) } } +pub(crate) fn collect_subquery_cols( + exprs: &[Expr], + subquery_schema: DFSchemaRef, +) -> Result> { + exprs.iter().try_fold(BTreeSet::new(), |mut cols, expr| { + let mut using_cols: Vec = vec![]; + for col in expr.to_columns()?.into_iter() { + if subquery_schema.has_column(&col) { + using_cols.push(col); + } + } + + cols.extend(using_cols); + Result::<_>::Ok(cols) + }) +} + +pub(crate) fn replace_qualified_name( + expr: Expr, + cols: &BTreeSet, + subquery_alias: &str, +) -> Result { + let alias_cols: Vec = cols + .iter() + .map(|col| { + Column::from_qualified_name(format!("{}.{}", subquery_alias, col.name)) + }) + .collect(); + let replace_map: HashMap<&Column, &Column> = + cols.iter().zip(alias_cols.iter()).collect(); + + replace_col(expr, &replace_map) +} + +/// Log the plan in debug/tracing mode after some part of the optimizer runs +pub fn log_plan(description: &str, plan: &LogicalPlan) { + debug!("{description}:\n{}\n", plan.display_indent()); + trace!("{description}::\n{}\n", plan.display_indent_schema()); +} + /// Splits a conjunctive [`Expr`] such as `A AND B AND C` => `[A, B, C]` /// /// See [`split_conjunction_owned`] for more details and an example. +#[deprecated( + since = "34.0.0", + note = "use `datafusion_expr::utils::split_conjunction` instead" +)] pub fn split_conjunction(expr: &Expr) -> Vec<&Expr> { - split_conjunction_impl(expr, vec![]) -} - -fn split_conjunction_impl<'a>(expr: &'a Expr, mut exprs: Vec<&'a Expr>) -> Vec<&'a Expr> { - match expr { - Expr::BinaryExpr(BinaryExpr { - right, - op: Operator::And, - left, - }) => { - let exprs = split_conjunction_impl(left, exprs); - split_conjunction_impl(right, exprs) - } - Expr::Alias(Alias { expr, .. 
}) => split_conjunction_impl(expr, exprs), - other => { - exprs.push(other); - exprs - } - } + expr_utils::split_conjunction(expr) } /// Splits an owned conjunctive [`Expr`] such as `A AND B AND C` => `[A, B, C]` @@ -104,8 +124,12 @@ fn split_conjunction_impl<'a>(expr: &'a Expr, mut exprs: Vec<&'a Expr>) -> Vec<& /// // use split_conjunction_owned to split them /// assert_eq!(split_conjunction_owned(expr), split); /// ``` +#[deprecated( + since = "34.0.0", + note = "use `datafusion_expr::utils::split_conjunction_owned` instead" +)] pub fn split_conjunction_owned(expr: Expr) -> Vec { - split_binary_owned(expr, Operator::And) + expr_utils::split_conjunction_owned(expr) } /// Splits an owned binary operator tree [`Expr`] such as `A B C` => `[A, B, C]` @@ -130,53 +154,23 @@ pub fn split_conjunction_owned(expr: Expr) -> Vec { /// // use split_binary_owned to split them /// assert_eq!(split_binary_owned(expr, Operator::Plus), split); /// ``` +#[deprecated( + since = "34.0.0", + note = "use `datafusion_expr::utils::split_binary_owned` instead" +)] pub fn split_binary_owned(expr: Expr, op: Operator) -> Vec { - split_binary_owned_impl(expr, op, vec![]) -} - -fn split_binary_owned_impl( - expr: Expr, - operator: Operator, - mut exprs: Vec, -) -> Vec { - match expr { - Expr::BinaryExpr(BinaryExpr { right, op, left }) if op == operator => { - let exprs = split_binary_owned_impl(*left, operator, exprs); - split_binary_owned_impl(*right, operator, exprs) - } - Expr::Alias(Alias { expr, .. }) => { - split_binary_owned_impl(*expr, operator, exprs) - } - other => { - exprs.push(other); - exprs - } - } + expr_utils::split_binary_owned(expr, op) } /// Splits an binary operator tree [`Expr`] such as `A B C` => `[A, B, C]` /// /// See [`split_binary_owned`] for more details and an example. +#[deprecated( + since = "34.0.0", + note = "use `datafusion_expr::utils::split_binary` instead" +)] pub fn split_binary(expr: &Expr, op: Operator) -> Vec<&Expr> { - split_binary_impl(expr, op, vec![]) -} - -fn split_binary_impl<'a>( - expr: &'a Expr, - operator: Operator, - mut exprs: Vec<&'a Expr>, -) -> Vec<&'a Expr> { - match expr { - Expr::BinaryExpr(BinaryExpr { right, op, left }) if *op == operator => { - let exprs = split_binary_impl(left, operator, exprs); - split_binary_impl(right, operator, exprs) - } - Expr::Alias(Alias { expr, .. }) => split_binary_impl(expr, operator, exprs), - other => { - exprs.push(other); - exprs - } - } + expr_utils::split_binary(expr, op) } /// Combines an array of filter expressions into a single filter @@ -201,8 +195,12 @@ fn split_binary_impl<'a>( /// // use conjunction to join them together with `AND` /// assert_eq!(conjunction(split), Some(expr)); /// ``` +#[deprecated( + since = "34.0.0", + note = "use `datafusion_expr::utils::conjunction` instead" +)] pub fn conjunction(filters: impl IntoIterator) -> Option { - filters.into_iter().reduce(|accum, expr| accum.and(expr)) + expr_utils::conjunction(filters) } /// Combines an array of filter expressions into a single filter @@ -210,25 +208,22 @@ pub fn conjunction(filters: impl IntoIterator) -> Option { /// logical OR. /// /// Returns None if the filters array is empty. 
+#[deprecated( + since = "34.0.0", + note = "use `datafusion_expr::utils::disjunction` instead" +)] pub fn disjunction(filters: impl IntoIterator) -> Option { - filters.into_iter().reduce(|accum, expr| accum.or(expr)) + expr_utils::disjunction(filters) } /// returns a new [LogicalPlan] that wraps `plan` in a [LogicalPlan::Filter] with /// its predicate be all `predicates` ANDed. +#[deprecated( + since = "34.0.0", + note = "use `datafusion_expr::utils::add_filter` instead" +)] pub fn add_filter(plan: LogicalPlan, predicates: &[&Expr]) -> Result { - // reduce filters to a single filter with an AND - let predicate = predicates - .iter() - .skip(1) - .fold(predicates[0].clone(), |acc, predicate| { - and(acc, (*predicate).to_owned()) - }); - - Ok(LogicalPlan::Filter(Filter::try_new( - predicate, - Arc::new(plan), - )?)) + expr_utils::add_filter(plan, predicates) } /// Looks for correlating expressions: for example, a binary expression with one field from the subquery, and @@ -241,22 +236,12 @@ pub fn add_filter(plan: LogicalPlan, predicates: &[&Expr]) -> Result) -> Result<(Vec, Vec)> { - let mut joins = vec![]; - let mut others = vec![]; - for filter in exprs.into_iter() { - // If the expression contains correlated predicates, add it to join filters - if filter.contains_outer() { - if !matches!(filter, Expr::BinaryExpr(BinaryExpr{ left, op: Operator::Eq, right }) if left.eq(right)) - { - joins.push(strip_outer_reference((*filter).clone())); - } - } else { - others.push((*filter).clone()); - } - } - - Ok((joins, others)) + expr_utils::find_join_exprs(exprs) } /// Returns the first (and only) element in a slice, or an error @@ -268,215 +253,19 @@ pub fn find_join_exprs(exprs: Vec<&Expr>) -> Result<(Vec, Vec)> { /// # Return value /// /// The first element, or an error +#[deprecated( + since = "34.0.0", + note = "use `datafusion_expr::utils::only_or_err` instead" +)] pub fn only_or_err(slice: &[T]) -> Result<&T> { - match slice { - [it] => Ok(it), - [] => plan_err!("No items found!"), - _ => plan_err!("More than one item found!"), - } + expr_utils::only_or_err(slice) } /// merge inputs schema into a single schema. 
+#[deprecated( + since = "34.0.0", + note = "use `datafusion_expr::utils::merge_schema` instead" +)] pub fn merge_schema(inputs: Vec<&LogicalPlan>) -> DFSchema { - if inputs.len() == 1 { - inputs[0].schema().clone().as_ref().clone() - } else { - inputs.iter().map(|input| input.schema()).fold( - DFSchema::empty(), - |mut lhs, rhs| { - lhs.merge(rhs); - lhs - }, - ) - } -} - -pub(crate) fn collect_subquery_cols( - exprs: &[Expr], - subquery_schema: DFSchemaRef, -) -> Result> { - exprs.iter().try_fold(BTreeSet::new(), |mut cols, expr| { - let mut using_cols: Vec = vec![]; - for col in expr.to_columns()?.into_iter() { - if subquery_schema.has_column(&col) { - using_cols.push(col); - } - } - - cols.extend(using_cols); - Result::<_>::Ok(cols) - }) -} - -pub(crate) fn replace_qualified_name( - expr: Expr, - cols: &BTreeSet, - subquery_alias: &str, -) -> Result { - let alias_cols: Vec = cols - .iter() - .map(|col| { - Column::from_qualified_name(format!("{}.{}", subquery_alias, col.name)) - }) - .collect(); - let replace_map: HashMap<&Column, &Column> = - cols.iter().zip(alias_cols.iter()).collect(); - - replace_col(expr, &replace_map) -} - -/// Log the plan in debug/tracing mode after some part of the optimizer runs -pub fn log_plan(description: &str, plan: &LogicalPlan) { - debug!("{description}:\n{}\n", plan.display_indent()); - trace!("{description}::\n{}\n", plan.display_indent_schema()); -} - -#[cfg(test)] -mod tests { - use super::*; - use arrow::datatypes::DataType; - use datafusion_common::Column; - use datafusion_expr::expr::Cast; - use datafusion_expr::{col, lit, utils::expr_to_columns}; - use std::collections::HashSet; - - #[test] - fn test_split_conjunction() { - let expr = col("a"); - let result = split_conjunction(&expr); - assert_eq!(result, vec![&expr]); - } - - #[test] - fn test_split_conjunction_two() { - let expr = col("a").eq(lit(5)).and(col("b")); - let expr1 = col("a").eq(lit(5)); - let expr2 = col("b"); - - let result = split_conjunction(&expr); - assert_eq!(result, vec![&expr1, &expr2]); - } - - #[test] - fn test_split_conjunction_alias() { - let expr = col("a").eq(lit(5)).and(col("b").alias("the_alias")); - let expr1 = col("a").eq(lit(5)); - let expr2 = col("b"); // has no alias - - let result = split_conjunction(&expr); - assert_eq!(result, vec![&expr1, &expr2]); - } - - #[test] - fn test_split_conjunction_or() { - let expr = col("a").eq(lit(5)).or(col("b")); - let result = split_conjunction(&expr); - assert_eq!(result, vec![&expr]); - } - - #[test] - fn test_split_binary_owned() { - let expr = col("a"); - assert_eq!(split_binary_owned(expr.clone(), Operator::And), vec![expr]); - } - - #[test] - fn test_split_binary_owned_two() { - assert_eq!( - split_binary_owned(col("a").eq(lit(5)).and(col("b")), Operator::And), - vec![col("a").eq(lit(5)), col("b")] - ); - } - - #[test] - fn test_split_binary_owned_different_op() { - let expr = col("a").eq(lit(5)).or(col("b")); - assert_eq!( - // expr is connected by OR, but pass in AND - split_binary_owned(expr.clone(), Operator::And), - vec![expr] - ); - } - - #[test] - fn test_split_conjunction_owned() { - let expr = col("a"); - assert_eq!(split_conjunction_owned(expr.clone()), vec![expr]); - } - - #[test] - fn test_split_conjunction_owned_two() { - assert_eq!( - split_conjunction_owned(col("a").eq(lit(5)).and(col("b"))), - vec![col("a").eq(lit(5)), col("b")] - ); - } - - #[test] - fn test_split_conjunction_owned_alias() { - assert_eq!( - split_conjunction_owned(col("a").eq(lit(5)).and(col("b").alias("the_alias"))), - vec![ - 
col("a").eq(lit(5)), - // no alias on b - col("b"), - ] - ); - } - - #[test] - fn test_conjunction_empty() { - assert_eq!(conjunction(vec![]), None); - } - - #[test] - fn test_conjunction() { - // `[A, B, C]` - let expr = conjunction(vec![col("a"), col("b"), col("c")]); - - // --> `(A AND B) AND C` - assert_eq!(expr, Some(col("a").and(col("b")).and(col("c")))); - - // which is different than `A AND (B AND C)` - assert_ne!(expr, Some(col("a").and(col("b").and(col("c"))))); - } - - #[test] - fn test_disjunction_empty() { - assert_eq!(disjunction(vec![]), None); - } - - #[test] - fn test_disjunction() { - // `[A, B, C]` - let expr = disjunction(vec![col("a"), col("b"), col("c")]); - - // --> `(A OR B) OR C` - assert_eq!(expr, Some(col("a").or(col("b")).or(col("c")))); - - // which is different than `A OR (B OR C)` - assert_ne!(expr, Some(col("a").or(col("b").or(col("c"))))); - } - - #[test] - fn test_split_conjunction_owned_or() { - let expr = col("a").eq(lit(5)).or(col("b")); - assert_eq!(split_conjunction_owned(expr.clone()), vec![expr]); - } - - #[test] - fn test_collect_expr() -> Result<()> { - let mut accum: HashSet = HashSet::new(); - expr_to_columns( - &Expr::Cast(Cast::new(Box::new(col("a")), DataType::Float64)), - &mut accum, - )?; - expr_to_columns( - &Expr::Cast(Cast::new(Box::new(col("a")), DataType::Float64)), - &mut accum, - )?; - assert_eq!(1, accum.len()); - assert!(accum.contains(&Column::from_name("a"))); - Ok(()) - } + expr_utils::merge_schema(inputs) } diff --git a/datafusion/optimizer/tests/optimizer_integration.rs b/datafusion/optimizer/tests/optimizer_integration.rs index 872071e52fa7a..d857c6154ea97 100644 --- a/datafusion/optimizer/tests/optimizer_integration.rs +++ b/datafusion/optimizer/tests/optimizer_integration.rs @@ -15,8 +15,11 @@ // specific language governing permissions and limitations // under the License. 
+use std::any::Any; +use std::collections::HashMap; +use std::sync::Arc; + use arrow::datatypes::{DataType, Field, Schema, SchemaRef, TimeUnit}; -use chrono::{DateTime, NaiveDateTime, Utc}; use datafusion_common::config::ConfigOptions; use datafusion_common::{plan_err, DataFusionError, Result}; use datafusion_expr::{AggregateUDF, LogicalPlan, ScalarUDF, TableSource, WindowUDF}; @@ -28,9 +31,8 @@ use datafusion_sql::sqlparser::ast::Statement; use datafusion_sql::sqlparser::dialect::GenericDialect; use datafusion_sql::sqlparser::parser::Parser; use datafusion_sql::TableReference; -use std::any::Any; -use std::collections::HashMap; -use std::sync::Arc; + +use chrono::{DateTime, NaiveDateTime, Utc}; #[cfg(test)] #[ctor::ctor] @@ -185,8 +187,9 @@ fn between_date32_plus_interval() -> Result<()> { let plan = test_sql(sql)?; let expected = "Aggregate: groupBy=[[]], aggr=[[COUNT(Int64(1))]]\ - \n Filter: test.col_date32 >= Date32(\"10303\") AND test.col_date32 <= Date32(\"10393\")\ - \n TableScan: test projection=[col_date32]"; + \n Projection: \ + \n Filter: test.col_date32 >= Date32(\"10303\") AND test.col_date32 <= Date32(\"10393\")\ + \n TableScan: test projection=[col_date32]"; assert_eq!(expected, format!("{plan:?}")); Ok(()) } @@ -198,8 +201,9 @@ fn between_date64_plus_interval() -> Result<()> { let plan = test_sql(sql)?; let expected = "Aggregate: groupBy=[[]], aggr=[[COUNT(Int64(1))]]\ - \n Filter: test.col_date64 >= Date64(\"890179200000\") AND test.col_date64 <= Date64(\"897955200000\")\ - \n TableScan: test projection=[col_date64]"; + \n Projection: \ + \n Filter: test.col_date64 >= Date64(\"890179200000\") AND test.col_date64 <= Date64(\"897955200000\")\ + \n TableScan: test projection=[col_date64]"; assert_eq!(expected, format!("{plan:?}")); Ok(()) } @@ -322,11 +326,10 @@ fn push_down_filter_groupby_expr_contains_alias() { fn test_same_name_but_not_ambiguous() { let sql = "SELECT t1.col_int32 AS col_int32 FROM test t1 intersect SELECT col_int32 FROM test t2"; let plan = test_sql(sql).unwrap(); - let expected = "LeftSemi Join: col_int32 = t2.col_int32\ - \n Aggregate: groupBy=[[col_int32]], aggr=[[]]\ - \n Projection: t1.col_int32 AS col_int32\ - \n SubqueryAlias: t1\ - \n TableScan: test projection=[col_int32]\ + let expected = "LeftSemi Join: t1.col_int32 = t2.col_int32\ + \n Aggregate: groupBy=[[t1.col_int32]], aggr=[[]]\ + \n SubqueryAlias: t1\ + \n TableScan: test projection=[col_int32]\ \n SubqueryAlias: t2\ \n TableScan: test projection=[col_int32]"; assert_eq!(expected, format!("{plan:?}")); diff --git a/datafusion/physical-expr/Cargo.toml b/datafusion/physical-expr/Cargo.toml index 4496e72152049..d237c68657a1f 100644 --- a/datafusion/physical-expr/Cargo.toml +++ b/datafusion/physical-expr/Cargo.toml @@ -34,13 +34,16 @@ path = "src/lib.rs" [features] crypto_expressions = ["md-5", "sha2", "blake2", "blake3"] -default = ["crypto_expressions", "regex_expressions", "unicode_expressions", "encoding_expressions"] +default = ["crypto_expressions", "regex_expressions", "unicode_expressions", "encoding_expressions", +] encoding_expressions = ["base64", "hex"] regex_expressions = ["regex"] unicode_expressions = ["unicode-segmentation"] [dependencies] -ahash = { version = "0.8", default-features = false, features = ["runtime-rng"] } +ahash = { version = "0.8", default-features = false, features = [ + "runtime-rng", +] } arrow = { workspace = true } arrow-array = { workspace = true } arrow-buffer = { workspace = true } @@ -56,8 +59,7 @@ half = { version = "2.1", default-features = false } 
hashbrown = { version = "0.14", features = ["raw"] } hex = { version = "0.4", optional = true } indexmap = { workspace = true } -itertools = { version = "0.11", features = ["use_std"] } -libc = "0.2.140" +itertools = { version = "0.12", features = ["use_std"] } log = { workspace = true } md-5 = { version = "^0.10.0", optional = true } paste = "^1.0" diff --git a/datafusion/physical-expr/benches/in_list.rs b/datafusion/physical-expr/benches/in_list.rs index db017326083ab..90bfc5efb61e8 100644 --- a/datafusion/physical-expr/benches/in_list.rs +++ b/datafusion/physical-expr/benches/in_list.rs @@ -57,7 +57,7 @@ fn do_benches( .collect(); let in_list: Vec<_> = (0..in_list_length) - .map(|_| ScalarValue::Utf8(Some(random_string(&mut rng, string_length)))) + .map(|_| ScalarValue::from(random_string(&mut rng, string_length))) .collect(); do_bench( diff --git a/datafusion/physical-expr/src/aggregate/approx_percentile_cont.rs b/datafusion/physical-expr/src/aggregate/approx_percentile_cont.rs index aa4749f64ae9c..15c0fb3ace4d9 100644 --- a/datafusion/physical-expr/src/aggregate/approx_percentile_cont.rs +++ b/datafusion/physical-expr/src/aggregate/approx_percentile_cont.rs @@ -18,7 +18,7 @@ use crate::aggregate::tdigest::TryIntoF64; use crate::aggregate::tdigest::{TDigest, DEFAULT_MAX_SIZE}; use crate::aggregate::utils::down_cast_any_ref; -use crate::expressions::{format_state_name, Literal}; +use crate::expressions::format_state_name; use crate::{AggregateExpr, PhysicalExpr}; use arrow::{ array::{ @@ -27,11 +27,13 @@ use arrow::{ }, datatypes::{DataType, Field}, }; +use arrow_array::RecordBatch; +use arrow_schema::Schema; use datafusion_common::{ downcast_value, exec_err, internal_err, not_impl_err, plan_err, DataFusionError, Result, ScalarValue, }; -use datafusion_expr::Accumulator; +use datafusion_expr::{Accumulator, ColumnarValue}; use std::{any::Any, iter, sync::Arc}; /// APPROX_PERCENTILE_CONT aggregate expression @@ -131,18 +133,22 @@ impl PartialEq for ApproxPercentileCont { } } +fn get_lit_value(expr: &Arc) -> Result { + let empty_schema = Schema::empty(); + let empty_batch = RecordBatch::new_empty(Arc::new(empty_schema)); + let result = expr.evaluate(&empty_batch)?; + match result { + ColumnarValue::Array(_) => Err(DataFusionError::Internal(format!( + "The expr {:?} can't be evaluated to scalar value", + expr + ))), + ColumnarValue::Scalar(scalar_value) => Ok(scalar_value), + } +} + fn validate_input_percentile_expr(expr: &Arc) -> Result { - // Extract the desired percentile literal - let lit = expr - .as_any() - .downcast_ref::() - .ok_or_else(|| { - DataFusionError::Internal( - "desired percentile argument must be float literal".to_string(), - ) - })? - .value(); - let percentile = match lit { + let lit = get_lit_value(expr)?; + let percentile = match &lit { ScalarValue::Float32(Some(q)) => *q as f64, ScalarValue::Float64(Some(q)) => *q, got => return not_impl_err!( @@ -161,17 +167,8 @@ fn validate_input_percentile_expr(expr: &Arc) -> Result { } fn validate_input_max_size_expr(expr: &Arc) -> Result { - // Extract the desired percentile literal - let lit = expr - .as_any() - .downcast_ref::() - .ok_or_else(|| { - DataFusionError::Internal( - "desired percentile argument must be float literal".to_string(), - ) - })? 
- .value(); - let max_size = match lit { + let lit = get_lit_value(expr)?; + let max_size = match &lit { ScalarValue::UInt8(Some(q)) => *q as usize, ScalarValue::UInt16(Some(q)) => *q as usize, ScalarValue::UInt32(Some(q)) => *q as usize, diff --git a/datafusion/physical-expr/src/aggregate/array_agg.rs b/datafusion/physical-expr/src/aggregate/array_agg.rs index 4dccbfef07f87..91d5c867d3125 100644 --- a/datafusion/physical-expr/src/aggregate/array_agg.rs +++ b/datafusion/physical-expr/src/aggregate/array_agg.rs @@ -34,9 +34,14 @@ use std::sync::Arc; /// ARRAY_AGG aggregate expression #[derive(Debug)] pub struct ArrayAgg { + /// Column name name: String, + /// The DataType for the input expression input_data_type: DataType, + /// The input expression expr: Arc, + /// If the input expression can have NULLs + nullable: bool, } impl ArrayAgg { @@ -45,11 +50,13 @@ impl ArrayAgg { expr: Arc, name: impl Into, data_type: DataType, + nullable: bool, ) -> Self { Self { name: name.into(), - expr, input_data_type: data_type, + expr, + nullable, } } } @@ -62,8 +69,9 @@ impl AggregateExpr for ArrayAgg { fn field(&self) -> Result { Ok(Field::new_list( &self.name, + // This should be the same as return type of AggregateFunction::ArrayAgg Field::new("item", self.input_data_type.clone(), true), - false, + self.nullable, )) } @@ -77,7 +85,7 @@ impl AggregateExpr for ArrayAgg { Ok(vec![Field::new_list( format_state_name(&self.name, "array_agg"), Field::new("item", self.input_data_type.clone(), true), - false, + self.nullable, )]) } @@ -184,7 +192,6 @@ mod tests { use super::*; use crate::expressions::col; use crate::expressions::tests::aggregate; - use crate::generic_test_op; use arrow::array::ArrayRef; use arrow::array::Int32Array; use arrow::datatypes::*; @@ -195,6 +202,30 @@ mod tests { use datafusion_common::DataFusionError; use datafusion_common::Result; + macro_rules! 
test_op { + ($ARRAY:expr, $DATATYPE:expr, $OP:ident, $EXPECTED:expr) => { + test_op!($ARRAY, $DATATYPE, $OP, $EXPECTED, $EXPECTED.data_type()) + }; + ($ARRAY:expr, $DATATYPE:expr, $OP:ident, $EXPECTED:expr, $EXPECTED_DATATYPE:expr) => {{ + let schema = Schema::new(vec![Field::new("a", $DATATYPE, true)]); + + let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![$ARRAY])?; + + let agg = Arc::new(<$OP>::new( + col("a", &schema)?, + "bla".to_string(), + $EXPECTED_DATATYPE, + true, + )); + let actual = aggregate(&batch, agg)?; + let expected = ScalarValue::from($EXPECTED); + + assert_eq!(expected, actual); + + Ok(()) as Result<(), DataFusionError> + }}; + } + #[test] fn array_agg_i32() -> Result<()> { let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5])); @@ -208,7 +239,7 @@ mod tests { ])]); let list = ScalarValue::List(Arc::new(list)); - generic_test_op!(a, DataType::Int32, ArrayAgg, list, DataType::Int32) + test_op!(a, DataType::Int32, ArrayAgg, list, DataType::Int32) } #[test] @@ -264,7 +295,7 @@ mod tests { let array = ScalarValue::iter_to_array(vec![l1, l2, l3]).unwrap(); - generic_test_op!( + test_op!( array, DataType::List(Arc::new(Field::new_list( "item", diff --git a/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs b/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs index 9b391b0c42cf0..1efae424cc699 100644 --- a/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs +++ b/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs @@ -40,6 +40,8 @@ pub struct DistinctArrayAgg { input_data_type: DataType, /// The input expression expr: Arc, + /// If the input expression can have NULLs + nullable: bool, } impl DistinctArrayAgg { @@ -48,12 +50,14 @@ impl DistinctArrayAgg { expr: Arc, name: impl Into, input_data_type: DataType, + nullable: bool, ) -> Self { let name = name.into(); Self { name, - expr, input_data_type, + expr, + nullable, } } } @@ -67,8 +71,9 @@ impl AggregateExpr for DistinctArrayAgg { fn field(&self) -> Result { Ok(Field::new_list( &self.name, + // This should be the same as return type of AggregateFunction::ArrayAgg Field::new("item", self.input_data_type.clone(), true), - false, + self.nullable, )) } @@ -82,7 +87,7 @@ impl AggregateExpr for DistinctArrayAgg { Ok(vec![Field::new_list( format_state_name(&self.name, "distinct_array_agg"), Field::new("item", self.input_data_type.clone(), true), - false, + self.nullable, )]) } @@ -238,6 +243,7 @@ mod tests { col("a", &schema)?, "bla".to_string(), datatype, + true, )); let actual = aggregate(&batch, agg)?; @@ -255,6 +261,7 @@ mod tests { col("a", &schema)?, "bla".to_string(), datatype, + true, )); let mut accum1 = agg.create_accumulator()?; diff --git a/datafusion/physical-expr/src/aggregate/array_agg_ordered.rs b/datafusion/physical-expr/src/aggregate/array_agg_ordered.rs index a53d53107addb..eb5ae8b0b0c3f 100644 --- a/datafusion/physical-expr/src/aggregate/array_agg_ordered.rs +++ b/datafusion/physical-expr/src/aggregate/array_agg_ordered.rs @@ -30,9 +30,9 @@ use crate::{AggregateExpr, LexOrdering, PhysicalExpr, PhysicalSortExpr}; use arrow::array::ArrayRef; use arrow::datatypes::{DataType, Field}; +use arrow_array::cast::AsArray; use arrow_array::Array; use arrow_schema::{Fields, SortOptions}; -use datafusion_common::cast::as_list_array; use datafusion_common::utils::{compare_rows, get_row_at_idx}; use datafusion_common::{exec_err, DataFusionError, Result, ScalarValue}; use datafusion_expr::Accumulator; @@ -48,10 +48,17 @@ use itertools::izip; /// and that can merge 
aggregations from multiple partitions. #[derive(Debug)] pub struct OrderSensitiveArrayAgg { + /// Column name name: String, + /// The DataType for the input expression input_data_type: DataType, - order_by_data_types: Vec, + /// The input expression expr: Arc, + /// If the input expression can have NULLs + nullable: bool, + /// Ordering data types + order_by_data_types: Vec, + /// Ordering requirement ordering_req: LexOrdering, } @@ -61,13 +68,15 @@ impl OrderSensitiveArrayAgg { expr: Arc, name: impl Into, input_data_type: DataType, + nullable: bool, order_by_data_types: Vec, ordering_req: LexOrdering, ) -> Self { Self { name: name.into(), - expr, input_data_type, + expr, + nullable, order_by_data_types, ordering_req, } @@ -82,8 +91,9 @@ impl AggregateExpr for OrderSensitiveArrayAgg { fn field(&self) -> Result { Ok(Field::new_list( &self.name, + // This should be the same as return type of AggregateFunction::ArrayAgg Field::new("item", self.input_data_type.clone(), true), - false, + self.nullable, )) } @@ -99,13 +109,13 @@ impl AggregateExpr for OrderSensitiveArrayAgg { let mut fields = vec![Field::new_list( format_state_name(&self.name, "array_agg"), Field::new("item", self.input_data_type.clone(), true), - false, + self.nullable, // This should be the same as field() )]; let orderings = ordering_fields(&self.ordering_req, &self.order_by_data_types); fields.push(Field::new_list( format_state_name(&self.name, "array_agg_orderings"), Field::new("item", DataType::Struct(Fields::from(orderings)), true), - false, + self.nullable, )); Ok(fields) } @@ -204,7 +214,7 @@ impl Accumulator for OrderSensitiveArrayAggAccumulator { // values received from its ordering requirement expression. (This information is necessary for during merging). let agg_orderings = &states[1]; - if as_list_array(agg_orderings).is_ok() { + if let Some(agg_orderings) = agg_orderings.as_list_opt::() { // Stores ARRAY_AGG results coming from each partition let mut partition_values = vec![]; // Stores ordering requirement expression results coming from each partition @@ -222,10 +232,21 @@ impl Accumulator for OrderSensitiveArrayAggAccumulator { } let orderings = ScalarValue::convert_array_to_scalar_vec(agg_orderings)?; - // Ordering requirement expression values for each entry in the ARRAY_AGG list - let other_ordering_values = self.convert_array_agg_to_orderings(orderings)?; - for v in other_ordering_values.into_iter() { - partition_ordering_values.push(v); + + for partition_ordering_rows in orderings.into_iter() { + // Extract value from struct to ordering_rows for each group/partition + let ordering_value = partition_ordering_rows.into_iter().map(|ordering_row| { + if let ScalarValue::Struct(Some(ordering_columns_per_row), _) = ordering_row { + Ok(ordering_columns_per_row) + } else { + exec_err!( + "Expects to receive ScalarValue::Struct(Some(..), _) but got:{:?}", + ordering_row.data_type() + ) + } + }).collect::>>()?; + + partition_ordering_values.push(ordering_value); } let sort_options = self @@ -283,33 +304,10 @@ impl Accumulator for OrderSensitiveArrayAggAccumulator { } impl OrderSensitiveArrayAggAccumulator { - /// Inner Vec\ in the ordering_values can be thought as ordering information for the each ScalarValue in the values array. - /// See [`merge_ordered_arrays`] for more information. 
- fn convert_array_agg_to_orderings( - &self, - array_agg: Vec>, - ) -> Result>>> { - let mut orderings = vec![]; - // in_data is Vec where ScalarValue does not include ScalarValue::List - for in_data in array_agg.into_iter() { - let ordering = in_data.into_iter().map(|struct_vals| { - if let ScalarValue::Struct(Some(orderings), _) = struct_vals { - Ok(orderings) - } else { - exec_err!( - "Expects to receive ScalarValue::Struct(Some(..), _) but got:{:?}", - struct_vals.data_type() - ) - } - }).collect::>>()?; - orderings.push(ordering); - } - Ok(orderings) - } - fn evaluate_orderings(&self) -> Result { let fields = ordering_fields(&self.ordering_req, &self.datatypes[1..]); let struct_field = Fields::from(fields.clone()); + let orderings: Vec = self .ordering_values .iter() @@ -319,6 +317,7 @@ impl OrderSensitiveArrayAggAccumulator { .collect(); let struct_type = DataType::Struct(Fields::from(fields)); + // Wrap in List, so we have the same data structure ListArray(StructArray..) for group by cases let arr = ScalarValue::new_list(&orderings, &struct_type); Ok(ScalarValue::List(arr)) } diff --git a/datafusion/physical-expr/src/aggregate/build_in.rs b/datafusion/physical-expr/src/aggregate/build_in.rs index 6568457bc234a..c40f0db194055 100644 --- a/datafusion/physical-expr/src/aggregate/build_in.rs +++ b/datafusion/physical-expr/src/aggregate/build_in.rs @@ -114,13 +114,16 @@ pub fn create_aggregate_expr( ), (AggregateFunction::ArrayAgg, false) => { let expr = input_phy_exprs[0].clone(); + let nullable = expr.nullable(input_schema)?; + if ordering_req.is_empty() { - Arc::new(expressions::ArrayAgg::new(expr, name, data_type)) + Arc::new(expressions::ArrayAgg::new(expr, name, data_type, nullable)) } else { Arc::new(expressions::OrderSensitiveArrayAgg::new( expr, name, data_type, + nullable, ordering_types, ordering_req.to_vec(), )) @@ -132,10 +135,13 @@ pub fn create_aggregate_expr( "ARRAY_AGG(DISTINCT ORDER BY a ASC) order-sensitive aggregations are not available" ); } + let expr = input_phy_exprs[0].clone(); + let is_expr_nullable = expr.nullable(input_schema)?; Arc::new(expressions::DistinctArrayAgg::new( - input_phy_exprs[0].clone(), + expr, name, data_type, + is_expr_nullable, )) } (AggregateFunction::Min, _) => Arc::new(expressions::Min::new( @@ -363,6 +369,22 @@ pub fn create_aggregate_expr( ordering_req.to_vec(), ordering_types, )), + (AggregateFunction::StringAgg, false) => { + if !ordering_req.is_empty() { + return not_impl_err!( + "STRING_AGG(ORDER BY a ASC) order-sensitive aggregations are not available" + ); + } + Arc::new(expressions::StringAgg::new( + input_phy_exprs[0].clone(), + input_phy_exprs[1].clone(), + name, + data_type, + )) + } + (AggregateFunction::StringAgg, true) => { + return not_impl_err!("STRING_AGG(DISTINCT) aggregations are not available"); + } }) } @@ -432,8 +454,8 @@ mod tests { assert_eq!( Field::new_list( "c1", - Field::new("item", data_type.clone(), true,), - false, + Field::new("item", data_type.clone(), true), + true, ), result_agg_phy_exprs.field().unwrap() ); @@ -471,8 +493,8 @@ mod tests { assert_eq!( Field::new_list( "c1", - Field::new("item", data_type.clone(), true,), - false, + Field::new("item", data_type.clone(), true), + true, ), result_agg_phy_exprs.field().unwrap() ); diff --git a/datafusion/physical-expr/src/aggregate/correlation.rs b/datafusion/physical-expr/src/aggregate/correlation.rs index 475bfa4ce0da2..61f2db5c8ef93 100644 --- a/datafusion/physical-expr/src/aggregate/correlation.rs +++ 
b/datafusion/physical-expr/src/aggregate/correlation.rs @@ -505,13 +505,17 @@ mod tests { let values1 = expr1 .iter() - .map(|e| e.evaluate(batch1)) - .map(|r| r.map(|v| v.into_array(batch1.num_rows()))) + .map(|e| { + e.evaluate(batch1) + .and_then(|v| v.into_array(batch1.num_rows())) + }) .collect::>>()?; let values2 = expr2 .iter() - .map(|e| e.evaluate(batch2)) - .map(|r| r.map(|v| v.into_array(batch2.num_rows()))) + .map(|e| { + e.evaluate(batch2) + .and_then(|v| v.into_array(batch2.num_rows())) + }) .collect::>>()?; accum1.update_batch(&values1)?; accum2.update_batch(&values2)?; diff --git a/datafusion/physical-expr/src/aggregate/count.rs b/datafusion/physical-expr/src/aggregate/count.rs index 738ca4e915f7d..8e9ae5cea36b3 100644 --- a/datafusion/physical-expr/src/aggregate/count.rs +++ b/datafusion/physical-expr/src/aggregate/count.rs @@ -123,7 +123,7 @@ impl GroupsAccumulator for CountGroupsAccumulator { self.counts.resize(total_num_groups, 0); accumulate_indices( group_indices, - values.nulls(), // ignore values + values.logical_nulls().as_ref(), opt_filter, |group_index| { self.counts[group_index] += 1; @@ -198,16 +198,18 @@ fn null_count_for_multiple_cols(values: &[ArrayRef]) -> usize { if values.len() > 1 { let result_bool_buf: Option = values .iter() - .map(|a| a.nulls()) + .map(|a| a.logical_nulls()) .fold(None, |acc, b| match (acc, b) { (Some(acc), Some(b)) => Some(acc.bitand(b.inner())), (Some(acc), None) => Some(acc), - (None, Some(b)) => Some(b.inner().clone()), + (None, Some(b)) => Some(b.into_inner()), _ => None, }); result_bool_buf.map_or(0, |b| values[0].len() - b.count_set_bits()) } else { - values[0].null_count() + values[0] + .logical_nulls() + .map_or(0, |nulls| nulls.null_count()) } } diff --git a/datafusion/physical-expr/src/aggregate/count_distinct.rs b/datafusion/physical-expr/src/aggregate/count_distinct.rs index f5242d983d4cf..c2fd32a96c4fb 100644 --- a/datafusion/physical-expr/src/aggregate/count_distinct.rs +++ b/datafusion/physical-expr/src/aggregate/count_distinct.rs @@ -152,7 +152,12 @@ impl Accumulator for DistinctCountAccumulator { if values.is_empty() { return Ok(()); } + let arr = &values[0]; + if arr.data_type() == &DataType::Null { + return Ok(()); + } + (0..arr.len()).try_for_each(|index| { if !arr.is_null(index) { let scalar = ScalarValue::try_from_array(arr, index)?; diff --git a/datafusion/physical-expr/src/aggregate/covariance.rs b/datafusion/physical-expr/src/aggregate/covariance.rs index 5e589d4e39fd3..0f838eb6fa1cf 100644 --- a/datafusion/physical-expr/src/aggregate/covariance.rs +++ b/datafusion/physical-expr/src/aggregate/covariance.rs @@ -754,13 +754,17 @@ mod tests { let values1 = expr1 .iter() - .map(|e| e.evaluate(batch1)) - .map(|r| r.map(|v| v.into_array(batch1.num_rows()))) + .map(|e| { + e.evaluate(batch1) + .and_then(|v| v.into_array(batch1.num_rows())) + }) .collect::>>()?; let values2 = expr2 .iter() - .map(|e| e.evaluate(batch2)) - .map(|r| r.map(|v| v.into_array(batch2.num_rows()))) + .map(|e| { + e.evaluate(batch2) + .and_then(|v| v.into_array(batch2.num_rows())) + }) .collect::>>()?; accum1.update_batch(&values1)?; accum2.update_batch(&values2)?; diff --git a/datafusion/physical-expr/src/aggregate/first_last.rs b/datafusion/physical-expr/src/aggregate/first_last.rs index a4e0a6dc49a9a..4afa8d0dd5eca 100644 --- a/datafusion/physical-expr/src/aggregate/first_last.rs +++ b/datafusion/physical-expr/src/aggregate/first_last.rs @@ -20,7 +20,7 @@ use std::any::Any; use std::sync::Arc; -use 
crate::aggregate::utils::{down_cast_any_ref, ordering_fields}; +use crate::aggregate::utils::{down_cast_any_ref, get_sort_options, ordering_fields}; use crate::expressions::format_state_name; use crate::{ reverse_order_bys, AggregateExpr, LexOrdering, PhysicalExpr, PhysicalSortExpr, @@ -29,19 +29,21 @@ use crate::{ use arrow::array::{Array, ArrayRef, AsArray, BooleanArray}; use arrow::compute::{self, lexsort_to_indices, SortColumn}; use arrow::datatypes::{DataType, Field}; -use arrow_schema::SortOptions; use datafusion_common::utils::{compare_rows, get_arrayref_at_indices, get_row_at_idx}; -use datafusion_common::{DataFusionError, Result, ScalarValue}; +use datafusion_common::{ + arrow_datafusion_err, internal_err, DataFusionError, Result, ScalarValue, +}; use datafusion_expr::Accumulator; /// FIRST_VALUE aggregate expression -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct FirstValue { name: String, input_data_type: DataType, order_by_data_types: Vec, expr: Arc, ordering_req: LexOrdering, + requirement_satisfied: bool, } impl FirstValue { @@ -53,14 +55,68 @@ impl FirstValue { ordering_req: LexOrdering, order_by_data_types: Vec, ) -> Self { + let requirement_satisfied = ordering_req.is_empty(); Self { name: name.into(), input_data_type, order_by_data_types, expr, ordering_req, + requirement_satisfied, } } + + /// Returns the name of the aggregate expression. + pub fn name(&self) -> &str { + &self.name + } + + /// Returns the input data type of the aggregate expression. + pub fn input_data_type(&self) -> &DataType { + &self.input_data_type + } + + /// Returns the data types of the order-by columns. + pub fn order_by_data_types(&self) -> &Vec { + &self.order_by_data_types + } + + /// Returns the expression associated with the aggregate function. + pub fn expr(&self) -> &Arc { + &self.expr + } + + /// Returns the lexical ordering requirements of the aggregate expression. + pub fn ordering_req(&self) -> &LexOrdering { + &self.ordering_req + } + + pub fn with_requirement_satisfied(mut self, requirement_satisfied: bool) -> Self { + self.requirement_satisfied = requirement_satisfied; + self + } + + pub fn convert_to_last(self) -> LastValue { + let name = if self.name.starts_with("FIRST") { + format!("LAST{}", &self.name[5..]) + } else { + format!("LAST_VALUE({})", self.expr) + }; + let FirstValue { + expr, + input_data_type, + ordering_req, + order_by_data_types, + .. 
+ } = self; + LastValue::new( + expr, + name, + input_data_type, + reverse_order_bys(&ordering_req), + order_by_data_types, + ) + } } impl AggregateExpr for FirstValue { @@ -74,11 +130,14 @@ impl AggregateExpr for FirstValue { } fn create_accumulator(&self) -> Result> { - Ok(Box::new(FirstValueAccumulator::try_new( + FirstValueAccumulator::try_new( &self.input_data_type, &self.order_by_data_types, self.ordering_req.clone(), - )?)) + ) + .map(|acc| { + Box::new(acc.with_requirement_satisfied(self.requirement_satisfied)) as _ + }) } fn state_fields(&self) -> Result> { @@ -104,11 +163,7 @@ impl AggregateExpr for FirstValue { } fn order_bys(&self) -> Option<&[PhysicalSortExpr]> { - if self.ordering_req.is_empty() { - None - } else { - Some(&self.ordering_req) - } + (!self.ordering_req.is_empty()).then_some(&self.ordering_req) } fn name(&self) -> &str { @@ -116,26 +171,18 @@ impl AggregateExpr for FirstValue { } fn reverse_expr(&self) -> Option> { - let name = if self.name.starts_with("FIRST") { - format!("LAST{}", &self.name[5..]) - } else { - format!("LAST_VALUE({})", self.expr) - }; - Some(Arc::new(LastValue::new( - self.expr.clone(), - name, - self.input_data_type.clone(), - reverse_order_bys(&self.ordering_req), - self.order_by_data_types.clone(), - ))) + Some(Arc::new(self.clone().convert_to_last())) } fn create_sliding_accumulator(&self) -> Result> { - Ok(Box::new(FirstValueAccumulator::try_new( + FirstValueAccumulator::try_new( &self.input_data_type, &self.order_by_data_types, self.ordering_req.clone(), - )?)) + ) + .map(|acc| { + Box::new(acc.with_requirement_satisfied(self.requirement_satisfied)) as _ + }) } } @@ -164,6 +211,8 @@ struct FirstValueAccumulator { orderings: Vec, // Stores the applicable ordering requirement. ordering_req: LexOrdering, + // Stores whether incoming data already satisfies the ordering requirement. + requirement_satisfied: bool, } impl FirstValueAccumulator { @@ -177,11 +226,13 @@ impl FirstValueAccumulator { .iter() .map(ScalarValue::try_from) .collect::>>()?; - ScalarValue::try_from(data_type).map(|value| Self { - first: value, + let requirement_satisfied = ordering_req.is_empty(); + ScalarValue::try_from(data_type).map(|first| Self { + first, is_set: false, orderings, ordering_req, + requirement_satisfied, }) } @@ -191,6 +242,31 @@ impl FirstValueAccumulator { self.orderings = row[1..].to_vec(); self.is_set = true; } + + fn get_first_idx(&self, values: &[ArrayRef]) -> Result> { + let [value, ordering_values @ ..] 
= values else { + return internal_err!("Empty row in FIRST_VALUE"); + }; + if self.requirement_satisfied { + // Get first entry according to the pre-existing ordering (0th index): + return Ok((!value.is_empty()).then_some(0)); + } + let sort_columns = ordering_values + .iter() + .zip(self.ordering_req.iter()) + .map(|(values, req)| SortColumn { + values: values.clone(), + options: Some(req.options), + }) + .collect::>(); + let indices = lexsort_to_indices(&sort_columns, Some(1))?; + Ok((!indices.is_empty()).then_some(indices.value(0) as _)) + } + + fn with_requirement_satisfied(mut self, requirement_satisfied: bool) -> Self { + self.requirement_satisfied = requirement_satisfied; + self + } } impl Accumulator for FirstValueAccumulator { @@ -202,11 +278,25 @@ impl Accumulator for FirstValueAccumulator { } fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - // If we have seen first value, we shouldn't update it - if !values[0].is_empty() && !self.is_set { - let row = get_row_at_idx(values, 0)?; - // Update with first value in the array. - self.update_with_new_row(&row); + if !self.is_set { + if let Some(first_idx) = self.get_first_idx(values)? { + let row = get_row_at_idx(values, first_idx)?; + self.update_with_new_row(&row); + } + } else if !self.requirement_satisfied { + if let Some(first_idx) = self.get_first_idx(values)? { + let row = get_row_at_idx(values, first_idx)?; + let orderings = &row[1..]; + if compare_rows( + &self.orderings, + orderings, + &get_sort_options(&self.ordering_req), + )? + .is_gt() + { + self.update_with_new_row(&row); + } + } } Ok(()) } @@ -235,7 +325,7 @@ impl Accumulator for FirstValueAccumulator { let sort_options = get_sort_options(&self.ordering_req); // Either there is no existing value, or there is an earlier version in new data. if !self.is_set - || compare_rows(first_ordering, &self.orderings, &sort_options)?.is_lt() + || compare_rows(&self.orderings, first_ordering, &sort_options)?.is_gt() { // Update with first value in the state. Note that we should exclude the // is_set flag from the state. Otherwise, we will end up with a state @@ -259,13 +349,14 @@ impl Accumulator for FirstValueAccumulator { } /// LAST_VALUE aggregate expression -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct LastValue { name: String, input_data_type: DataType, order_by_data_types: Vec, expr: Arc, ordering_req: LexOrdering, + requirement_satisfied: bool, } impl LastValue { @@ -277,14 +368,68 @@ impl LastValue { ordering_req: LexOrdering, order_by_data_types: Vec, ) -> Self { + let requirement_satisfied = ordering_req.is_empty(); Self { name: name.into(), input_data_type, order_by_data_types, expr, ordering_req, + requirement_satisfied, } } + + /// Returns the name of the aggregate expression. + pub fn name(&self) -> &str { + &self.name + } + + /// Returns the input data type of the aggregate expression. + pub fn input_data_type(&self) -> &DataType { + &self.input_data_type + } + + /// Returns the data types of the order-by columns. + pub fn order_by_data_types(&self) -> &Vec { + &self.order_by_data_types + } + + /// Returns the expression associated with the aggregate function. + pub fn expr(&self) -> &Arc { + &self.expr + } + + /// Returns the lexical ordering requirements of the aggregate expression. 
+ pub fn ordering_req(&self) -> &LexOrdering { + &self.ordering_req + } + + pub fn with_requirement_satisfied(mut self, requirement_satisfied: bool) -> Self { + self.requirement_satisfied = requirement_satisfied; + self + } + + pub fn convert_to_first(self) -> FirstValue { + let name = if self.name.starts_with("LAST") { + format!("FIRST{}", &self.name[4..]) + } else { + format!("FIRST_VALUE({})", self.expr) + }; + let LastValue { + expr, + input_data_type, + ordering_req, + order_by_data_types, + .. + } = self; + FirstValue::new( + expr, + name, + input_data_type, + reverse_order_bys(&ordering_req), + order_by_data_types, + ) + } } impl AggregateExpr for LastValue { @@ -298,11 +443,14 @@ impl AggregateExpr for LastValue { } fn create_accumulator(&self) -> Result> { - Ok(Box::new(LastValueAccumulator::try_new( + LastValueAccumulator::try_new( &self.input_data_type, &self.order_by_data_types, self.ordering_req.clone(), - )?)) + ) + .map(|acc| { + Box::new(acc.with_requirement_satisfied(self.requirement_satisfied)) as _ + }) } fn state_fields(&self) -> Result> { @@ -328,11 +476,7 @@ impl AggregateExpr for LastValue { } fn order_bys(&self) -> Option<&[PhysicalSortExpr]> { - if self.ordering_req.is_empty() { - None - } else { - Some(&self.ordering_req) - } + (!self.ordering_req.is_empty()).then_some(&self.ordering_req) } fn name(&self) -> &str { @@ -340,26 +484,18 @@ impl AggregateExpr for LastValue { } fn reverse_expr(&self) -> Option> { - let name = if self.name.starts_with("LAST") { - format!("FIRST{}", &self.name[4..]) - } else { - format!("FIRST_VALUE({})", self.expr) - }; - Some(Arc::new(FirstValue::new( - self.expr.clone(), - name, - self.input_data_type.clone(), - reverse_order_bys(&self.ordering_req), - self.order_by_data_types.clone(), - ))) + Some(Arc::new(self.clone().convert_to_first())) } fn create_sliding_accumulator(&self) -> Result> { - Ok(Box::new(LastValueAccumulator::try_new( + LastValueAccumulator::try_new( &self.input_data_type, &self.order_by_data_types, self.ordering_req.clone(), - )?)) + ) + .map(|acc| { + Box::new(acc.with_requirement_satisfied(self.requirement_satisfied)) as _ + }) } } @@ -387,6 +523,8 @@ struct LastValueAccumulator { orderings: Vec, // Stores the applicable ordering requirement. ordering_req: LexOrdering, + // Stores whether incoming data already satisfies the ordering requirement. + requirement_satisfied: bool, } impl LastValueAccumulator { @@ -400,11 +538,13 @@ impl LastValueAccumulator { .iter() .map(ScalarValue::try_from) .collect::>>()?; - Ok(Self { - last: ScalarValue::try_from(data_type)?, + let requirement_satisfied = ordering_req.is_empty(); + ScalarValue::try_from(data_type).map(|last| Self { + last, is_set: false, orderings, ordering_req, + requirement_satisfied, }) } @@ -414,6 +554,35 @@ impl LastValueAccumulator { self.orderings = row[1..].to_vec(); self.is_set = true; } + + fn get_last_idx(&self, values: &[ArrayRef]) -> Result> { + let [value, ordering_values @ ..] = values else { + return internal_err!("Empty row in LAST_VALUE"); + }; + if self.requirement_satisfied { + // Get last entry according to the order of data: + return Ok((!value.is_empty()).then_some(value.len() - 1)); + } + let sort_columns = ordering_values + .iter() + .zip(self.ordering_req.iter()) + .map(|(values, req)| { + // Take the reverse ordering requirement. This enables us to + // use "fetch = 1" to get the last value. 
+ SortColumn { + values: values.clone(), + options: Some(!req.options), + } + }) + .collect::>(); + let indices = lexsort_to_indices(&sort_columns, Some(1))?; + Ok((!indices.is_empty()).then_some(indices.value(0) as _)) + } + + fn with_requirement_satisfied(mut self, requirement_satisfied: bool) -> Self { + self.requirement_satisfied = requirement_satisfied; + self + } } impl Accumulator for LastValueAccumulator { @@ -425,11 +594,26 @@ impl Accumulator for LastValueAccumulator { } fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - if !values[0].is_empty() { - let row = get_row_at_idx(values, values[0].len() - 1)?; - // Update with last value in the array. - self.update_with_new_row(&row); + if !self.is_set || self.requirement_satisfied { + if let Some(last_idx) = self.get_last_idx(values)? { + let row = get_row_at_idx(values, last_idx)?; + self.update_with_new_row(&row); + } + } else if let Some(last_idx) = self.get_last_idx(values)? { + let row = get_row_at_idx(values, last_idx)?; + let orderings = &row[1..]; + // Update when there is a more recent entry + if compare_rows( + &self.orderings, + orderings, + &get_sort_options(&self.ordering_req), + )? + .is_lt() + { + self.update_with_new_row(&row); + } } + Ok(()) } @@ -460,7 +644,7 @@ impl Accumulator for LastValueAccumulator { // Either there is no existing value, or there is a newer (latest) // version in the new data: if !self.is_set - || compare_rows(last_ordering, &self.orderings, &sort_options)?.is_gt() + || compare_rows(&self.orderings, last_ordering, &sort_options)?.is_lt() { // Update with last value in the state. Note that we should exclude the // is_set flag from the state. Otherwise, we will end up with a state @@ -491,7 +675,7 @@ fn filter_states_according_to_is_set( ) -> Result> { states .iter() - .map(|state| compute::filter(state, flags).map_err(DataFusionError::ArrowError)) + .map(|state| compute::filter(state, flags).map_err(|e| arrow_datafusion_err!(e))) .collect::>>() } @@ -509,26 +693,18 @@ fn convert_to_sort_cols( .collect::>() } -/// Selects the sort option attribute from all the given `PhysicalSortExpr`s. 
-fn get_sort_options(ordering_req: &[PhysicalSortExpr]) -> Vec { - ordering_req - .iter() - .map(|item| item.options) - .collect::>() -} - #[cfg(test)] mod tests { + use std::sync::Arc; + use crate::aggregate::first_last::{FirstValueAccumulator, LastValueAccumulator}; + use arrow::compute::concat; use arrow_array::{ArrayRef, Int64Array}; use arrow_schema::DataType; use datafusion_common::{Result, ScalarValue}; use datafusion_expr::Accumulator; - use arrow::compute::concat; - use std::sync::Arc; - #[test] fn test_first_last_value_value() -> Result<()> { let mut first_accumulator = @@ -587,7 +763,10 @@ mod tests { let mut states = vec![]; for idx in 0..state1.len() { - states.push(concat(&[&state1[idx].to_array(), &state2[idx].to_array()])?); + states.push(concat(&[ + &state1[idx].to_array()?, + &state2[idx].to_array()?, + ])?); } let mut first_accumulator = @@ -614,7 +793,10 @@ mod tests { let mut states = vec![]; for idx in 0..state1.len() { - states.push(concat(&[&state1[idx].to_array(), &state2[idx].to_array()])?); + states.push(concat(&[ + &state1[idx].to_array()?, + &state2[idx].to_array()?, + ])?); } let mut last_accumulator = diff --git a/datafusion/physical-expr/src/aggregate/groups_accumulator/adapter.rs b/datafusion/physical-expr/src/aggregate/groups_accumulator/adapter.rs index dcc8c37e7484b..c6fd17a69b394 100644 --- a/datafusion/physical-expr/src/aggregate/groups_accumulator/adapter.rs +++ b/datafusion/physical-expr/src/aggregate/groups_accumulator/adapter.rs @@ -25,7 +25,8 @@ use arrow::{ }; use arrow_array::{ArrayRef, BooleanArray, PrimitiveArray}; use datafusion_common::{ - utils::get_arrayref_at_indices, DataFusionError, Result, ScalarValue, + arrow_datafusion_err, utils::get_arrayref_at_indices, DataFusionError, Result, + ScalarValue, }; use datafusion_expr::Accumulator; @@ -309,7 +310,7 @@ impl GroupsAccumulator for GroupsAccumulatorAdapter { // double check each array has the same length (aka the // accumulator was implemented correctly - if let Some(first_col) = arrays.get(0) { + if let Some(first_col) = arrays.first() { for arr in &arrays { assert_eq!(arr.len(), first_col.len()) } @@ -372,7 +373,7 @@ fn get_filter_at_indices( ) }) .transpose() - .map_err(DataFusionError::ArrowError) + .map_err(|e| arrow_datafusion_err!(e)) } // Copied from physical-plan @@ -394,7 +395,7 @@ pub(crate) fn slice_and_maybe_filter( sliced_arrays .iter() .map(|array| { - compute::filter(array, filter_array).map_err(DataFusionError::ArrowError) + compute::filter(array, filter_array).map_err(|e| arrow_datafusion_err!(e)) }) .collect() } else { diff --git a/datafusion/physical-expr/src/aggregate/min_max.rs b/datafusion/physical-expr/src/aggregate/min_max.rs index f5b708e8894e7..7e3ef2a2ababb 100644 --- a/datafusion/physical-expr/src/aggregate/min_max.rs +++ b/datafusion/physical-expr/src/aggregate/min_max.rs @@ -1297,12 +1297,7 @@ mod tests { #[test] fn max_utf8() -> Result<()> { let a: ArrayRef = Arc::new(StringArray::from(vec!["d", "a", "c", "b"])); - generic_test_op!( - a, - DataType::Utf8, - Max, - ScalarValue::Utf8(Some("d".to_string())) - ) + generic_test_op!(a, DataType::Utf8, Max, ScalarValue::from("d")) } #[test] @@ -1319,12 +1314,7 @@ mod tests { #[test] fn min_utf8() -> Result<()> { let a: ArrayRef = Arc::new(StringArray::from(vec!["d", "a", "c", "b"])); - generic_test_op!( - a, - DataType::Utf8, - Min, - ScalarValue::Utf8(Some("a".to_string())) - ) + generic_test_op!(a, DataType::Utf8, Min, ScalarValue::from("a")) } #[test] diff --git a/datafusion/physical-expr/src/aggregate/mod.rs 
b/datafusion/physical-expr/src/aggregate/mod.rs index 442d018b87d55..5bd1fca385b11 100644 --- a/datafusion/physical-expr/src/aggregate/mod.rs +++ b/datafusion/physical-expr/src/aggregate/mod.rs @@ -15,16 +15,20 @@ // specific language governing permissions and limitations // under the License. -use crate::expressions::{FirstValue, LastValue, OrderSensitiveArrayAgg}; -use crate::{PhysicalExpr, PhysicalSortExpr}; -use arrow::datatypes::Field; -use datafusion_common::{not_impl_err, DataFusionError, Result}; -use datafusion_expr::Accumulator; use std::any::Any; use std::fmt::Debug; use std::sync::Arc; use self::groups_accumulator::GroupsAccumulator; +use crate::expressions::OrderSensitiveArrayAgg; +use crate::{PhysicalExpr, PhysicalSortExpr}; + +use arrow::datatypes::Field; +use datafusion_common::{not_impl_err, DataFusionError, Result}; +use datafusion_expr::Accumulator; + +mod hyperloglog; +mod tdigest; pub(crate) mod approx_distinct; pub(crate) mod approx_median; @@ -43,21 +47,21 @@ pub(crate) mod covariance; pub(crate) mod first_last; pub(crate) mod grouping; pub(crate) mod median; +pub(crate) mod string_agg; #[macro_use] pub(crate) mod min_max; -pub mod build_in; pub(crate) mod groups_accumulator; -mod hyperloglog; -pub mod moving_min_max; pub(crate) mod regr; pub(crate) mod stats; pub(crate) mod stddev; pub(crate) mod sum; pub(crate) mod sum_distinct; -mod tdigest; -pub mod utils; pub(crate) mod variance; +pub mod build_in; +pub mod moving_min_max; +pub mod utils; + /// An aggregate expression that: /// * knows its resulting field /// * knows how to create its accumulator @@ -133,10 +137,7 @@ pub trait AggregateExpr: Send + Sync + Debug + PartialEq { /// Checks whether the given aggregate expression is order-sensitive. /// For instance, a `SUM` aggregation doesn't depend on the order of its inputs. -/// However, a `FirstValue` depends on the input ordering (if the order changes, -/// the first value in the list would change). +/// However, an `ARRAY_AGG` with `ORDER BY` depends on the input ordering. pub fn is_order_sensitive(aggr_expr: &Arc) -> bool { - aggr_expr.as_any().is::() - || aggr_expr.as_any().is::() - || aggr_expr.as_any().is::() + aggr_expr.as_any().is::() } diff --git a/datafusion/physical-expr/src/aggregate/stddev.rs b/datafusion/physical-expr/src/aggregate/stddev.rs index 330507d6ffa63..64e19ef502c7b 100644 --- a/datafusion/physical-expr/src/aggregate/stddev.rs +++ b/datafusion/physical-expr/src/aggregate/stddev.rs @@ -445,13 +445,17 @@ mod tests { let values1 = expr1 .iter() - .map(|e| e.evaluate(batch1)) - .map(|r| r.map(|v| v.into_array(batch1.num_rows()))) + .map(|e| { + e.evaluate(batch1) + .and_then(|v| v.into_array(batch1.num_rows())) + }) .collect::>>()?; let values2 = expr2 .iter() - .map(|e| e.evaluate(batch2)) - .map(|r| r.map(|v| v.into_array(batch2.num_rows()))) + .map(|e| { + e.evaluate(batch2) + .and_then(|v| v.into_array(batch2.num_rows())) + }) .collect::>>()?; accum1.update_batch(&values1)?; accum2.update_batch(&values2)?; diff --git a/datafusion/physical-expr/src/aggregate/string_agg.rs b/datafusion/physical-expr/src/aggregate/string_agg.rs new file mode 100644 index 0000000000000..7adc736932ad7 --- /dev/null +++ b/datafusion/physical-expr/src/aggregate/string_agg.rs @@ -0,0 +1,246 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! [`StringAgg`] and [`StringAggAccumulator`] accumulator for the `string_agg` function
+
+use crate::aggregate::utils::down_cast_any_ref;
+use crate::expressions::{format_state_name, Literal};
+use crate::{AggregateExpr, PhysicalExpr};
+use arrow::array::ArrayRef;
+use arrow::datatypes::{DataType, Field};
+use datafusion_common::cast::as_generic_string_array;
+use datafusion_common::{not_impl_err, DataFusionError, Result, ScalarValue};
+use datafusion_expr::Accumulator;
+use std::any::Any;
+use std::sync::Arc;
+
+/// STRING_AGG aggregate expression
+#[derive(Debug)]
+pub struct StringAgg {
+    name: String,
+    data_type: DataType,
+    expr: Arc<dyn PhysicalExpr>,
+    delimiter: Arc<dyn PhysicalExpr>,
+    nullable: bool,
+}
+
+impl StringAgg {
+    /// Create a new StringAgg aggregate function
+    pub fn new(
+        expr: Arc<dyn PhysicalExpr>,
+        delimiter: Arc<dyn PhysicalExpr>,
+        name: impl Into<String>,
+        data_type: DataType,
+    ) -> Self {
+        Self {
+            name: name.into(),
+            data_type,
+            delimiter,
+            expr,
+            nullable: true,
+        }
+    }
+}
+
+impl AggregateExpr for StringAgg {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn field(&self) -> Result<Field> {
+        Ok(Field::new(
+            &self.name,
+            self.data_type.clone(),
+            self.nullable,
+        ))
+    }
+
+    fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> {
+        if let Some(delimiter) = self.delimiter.as_any().downcast_ref::<Literal>() {
+            match delimiter.value() {
+                ScalarValue::Utf8(Some(delimiter))
+                | ScalarValue::LargeUtf8(Some(delimiter)) => {
+                    return Ok(Box::new(StringAggAccumulator::new(delimiter)));
+                }
+                ScalarValue::Null => {
+                    return Ok(Box::new(StringAggAccumulator::new("")));
+                }
+                _ => return not_impl_err!("StringAgg not supported for {}", self.name),
+            }
+        }
+        not_impl_err!("StringAgg not supported for {}", self.name)
+    }
+
+    fn state_fields(&self) -> Result<Vec<Field>> {
+        Ok(vec![Field::new(
+            format_state_name(&self.name, "string_agg"),
+            self.data_type.clone(),
+            self.nullable,
+        )])
+    }
+
+    fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>> {
+        vec![self.expr.clone(), self.delimiter.clone()]
+    }
+
+    fn name(&self) -> &str {
+        &self.name
+    }
+}
+
+impl PartialEq<dyn Any> for StringAgg {
+    fn eq(&self, other: &dyn Any) -> bool {
+        down_cast_any_ref(other)
+            .downcast_ref::<Self>()
+            .map(|x| {
+                self.name == x.name
+                    && self.data_type == x.data_type
+                    && self.expr.eq(&x.expr)
+                    && self.delimiter.eq(&x.delimiter)
+            })
+            .unwrap_or(false)
+    }
+}
+
+#[derive(Debug)]
+pub(crate) struct StringAggAccumulator {
+    values: Option<String>,
+    delimiter: String,
+}
+
+impl StringAggAccumulator {
+    pub fn new(delimiter: &str) -> Self {
+        Self {
+            values: None,
+            delimiter: delimiter.to_string(),
+        }
+    }
+}
+
+impl Accumulator for StringAggAccumulator {
+    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
+        let string_array: Vec<_> = as_generic_string_array::<i64>(&values[0])?
+            .iter()
+            .filter_map(|v| v.as_ref().map(ToString::to_string))
+            .collect();
+        if !string_array.is_empty() {
+            let s = string_array.join(self.delimiter.as_str());
+            let v = self.values.get_or_insert("".to_string());
+            if !v.is_empty() {
+                v.push_str(self.delimiter.as_str());
+            }
+            v.push_str(s.as_str());
+        }
+        Ok(())
+    }
+
+    fn merge_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
+        self.update_batch(values)?;
+        Ok(())
+    }
+
+    fn state(&self) -> Result<Vec<ScalarValue>> {
+        Ok(vec![self.evaluate()?])
+    }
+
+    fn evaluate(&self) -> Result<ScalarValue> {
+        Ok(ScalarValue::LargeUtf8(self.values.clone()))
+    }
+
+    fn size(&self) -> usize {
+        std::mem::size_of_val(self)
+            + self.values.as_ref().map(|v| v.capacity()).unwrap_or(0)
+            + self.delimiter.capacity()
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::expressions::tests::aggregate;
+    use crate::expressions::{col, create_aggregate_expr, try_cast};
+    use arrow::array::ArrayRef;
+    use arrow::datatypes::*;
+    use arrow::record_batch::RecordBatch;
+    use arrow_array::LargeStringArray;
+    use arrow_array::StringArray;
+    use datafusion_expr::type_coercion::aggregates::coerce_types;
+    use datafusion_expr::AggregateFunction;
+
+    fn assert_string_aggregate(
+        array: ArrayRef,
+        function: AggregateFunction,
+        distinct: bool,
+        expected: ScalarValue,
+        delimiter: String,
+    ) {
+        let data_type = array.data_type();
+        let sig = function.signature();
+        let coerced =
+            coerce_types(&function, &[data_type.clone(), DataType::Utf8], &sig).unwrap();
+
+        let input_schema = Schema::new(vec![Field::new("a", data_type.clone(), true)]);
+        let batch =
+            RecordBatch::try_new(Arc::new(input_schema.clone()), vec![array]).unwrap();
+
+        let input = try_cast(
+            col("a", &input_schema).unwrap(),
+            &input_schema,
+            coerced[0].clone(),
+        )
+        .unwrap();
+
+        let delimiter = Arc::new(Literal::new(ScalarValue::from(delimiter)));
+        let schema = Schema::new(vec![Field::new("a", coerced[0].clone(), true)]);
+        let agg = create_aggregate_expr(
+            &function,
+            distinct,
+            &[input, delimiter],
+            &[],
+            &schema,
+            "agg",
+        )
+        .unwrap();
+
+        let result = aggregate(&batch, agg).unwrap();
+        assert_eq!(expected, result);
+    }
+
+    #[test]
+    fn string_agg_utf8() {
+        let a: ArrayRef = Arc::new(StringArray::from(vec!["h", "e", "l", "l", "o"]));
+        assert_string_aggregate(
+            a,
+            AggregateFunction::StringAgg,
+            false,
+            ScalarValue::LargeUtf8(Some("h,e,l,l,o".to_owned())),
+            ",".to_owned(),
+        );
+    }
+
+    #[test]
+    fn string_agg_largeutf8() {
+        let a: ArrayRef = Arc::new(LargeStringArray::from(vec!["h", "e", "l", "l", "o"]));
+        assert_string_aggregate(
+            a,
+            AggregateFunction::StringAgg,
+            false,
+            ScalarValue::LargeUtf8(Some("h|e|l|l|o".to_owned())),
+            "|".to_owned(),
+        );
+    }
+}
diff --git a/datafusion/physical-expr/src/aggregate/sum.rs b/datafusion/physical-expr/src/aggregate/sum.rs
index d6c23d0dfafde..03f666cc4e5d5 100644
--- a/datafusion/physical-expr/src/aggregate/sum.rs
+++ b/datafusion/physical-expr/src/aggregate/sum.rs
@@ -41,7 +41,10 @@ use datafusion_expr::Accumulator;
 #[derive(Debug, Clone)]
 pub struct Sum {
     name: String,
+    // The DataType for the input expression
     data_type: DataType,
+    // The DataType for the final sum
+    return_type: DataType,
     expr: Arc<dyn PhysicalExpr>,
     nullable: bool,
 }
@@ -53,11 +56,12 @@ impl Sum {
         name: impl Into<String>,
         data_type: DataType,
     ) -> Self {
-        let data_type = sum_return_type(&data_type).unwrap();
+        let return_type = sum_return_type(&data_type).unwrap();
         Self {
             name: name.into(),
-            expr,
             data_type,
+            return_type,
+            expr,
             nullable: true,
         }
     }
@@ -70,13 +74,13 @@ impl Sum {
 /// `s` is a
`Sum`, `helper` is a macro accepting (ArrowPrimitiveType, DataType) macro_rules! downcast_sum { ($s:ident, $helper:ident) => { - match $s.data_type { - DataType::UInt64 => $helper!(UInt64Type, $s.data_type), - DataType::Int64 => $helper!(Int64Type, $s.data_type), - DataType::Float64 => $helper!(Float64Type, $s.data_type), - DataType::Decimal128(_, _) => $helper!(Decimal128Type, $s.data_type), - DataType::Decimal256(_, _) => $helper!(Decimal256Type, $s.data_type), - _ => not_impl_err!("Sum not supported for {}: {}", $s.name, $s.data_type), + match $s.return_type { + DataType::UInt64 => $helper!(UInt64Type, $s.return_type), + DataType::Int64 => $helper!(Int64Type, $s.return_type), + DataType::Float64 => $helper!(Float64Type, $s.return_type), + DataType::Decimal128(_, _) => $helper!(Decimal128Type, $s.return_type), + DataType::Decimal256(_, _) => $helper!(Decimal256Type, $s.return_type), + _ => not_impl_err!("Sum not supported for {}: {}", $s.name, $s.return_type), } }; } @@ -91,7 +95,7 @@ impl AggregateExpr for Sum { fn field(&self) -> Result { Ok(Field::new( &self.name, - self.data_type.clone(), + self.return_type.clone(), self.nullable, )) } @@ -108,7 +112,7 @@ impl AggregateExpr for Sum { fn state_fields(&self) -> Result> { Ok(vec![Field::new( format_state_name(&self.name, "sum"), - self.data_type.clone(), + self.return_type.clone(), self.nullable, )]) } diff --git a/datafusion/physical-expr/src/aggregate/sum_distinct.rs b/datafusion/physical-expr/src/aggregate/sum_distinct.rs index ef1bd039a5ea3..0cf4a90ab8cc4 100644 --- a/datafusion/physical-expr/src/aggregate/sum_distinct.rs +++ b/datafusion/physical-expr/src/aggregate/sum_distinct.rs @@ -40,8 +40,10 @@ use datafusion_expr::Accumulator; pub struct DistinctSum { /// Column name name: String, - /// The DataType for the final sum + // The DataType for the input expression data_type: DataType, + // The DataType for the final sum + return_type: DataType, /// The input arguments, only contains 1 item for sum exprs: Vec>, } @@ -53,10 +55,11 @@ impl DistinctSum { name: String, data_type: DataType, ) -> Self { - let data_type = sum_return_type(&data_type).unwrap(); + let return_type = sum_return_type(&data_type).unwrap(); Self { name, data_type, + return_type, exprs, } } @@ -68,14 +71,14 @@ impl AggregateExpr for DistinctSum { } fn field(&self) -> Result { - Ok(Field::new(&self.name, self.data_type.clone(), true)) + Ok(Field::new(&self.name, self.return_type.clone(), true)) } fn state_fields(&self) -> Result> { // State field is a List which stores items to rebuild hash set. Ok(vec![Field::new_list( format_state_name(&self.name, "sum distinct"), - Field::new("item", self.data_type.clone(), true), + Field::new("item", self.return_type.clone(), true), false, )]) } diff --git a/datafusion/physical-expr/src/aggregate/utils.rs b/datafusion/physical-expr/src/aggregate/utils.rs index da3a527132316..9777158da133a 100644 --- a/datafusion/physical-expr/src/aggregate/utils.rs +++ b/datafusion/physical-expr/src/aggregate/utils.rs @@ -17,30 +17,31 @@ //! 
Utilities used in aggregates +use std::any::Any; +use std::sync::Arc; + use crate::{AggregateExpr, PhysicalSortExpr}; -use arrow::array::ArrayRef; + +use arrow::array::{ArrayRef, ArrowNativeTypeOp}; use arrow_array::cast::AsArray; use arrow_array::types::{ Decimal128Type, DecimalType, TimestampMicrosecondType, TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, }; -use arrow_array::ArrowNativeTypeOp; use arrow_buffer::ArrowNativeType; -use arrow_schema::{DataType, Field}; +use arrow_schema::{DataType, Field, SortOptions}; use datafusion_common::{exec_err, DataFusionError, Result}; use datafusion_expr::Accumulator; -use std::any::Any; -use std::sync::Arc; /// Convert scalar values from an accumulator into arrays. pub fn get_accum_scalar_values_as_arrays( accum: &dyn Accumulator, ) -> Result> { - Ok(accum + accum .state()? .iter() .map(|s| s.to_array_of_size(1)) - .collect::>()) + .collect() } /// Computes averages for `Decimal128`/`Decimal256` values, checking for overflow @@ -205,3 +206,8 @@ pub(crate) fn ordering_fields( }) .collect() } + +/// Selects the sort option attribute from all the given `PhysicalSortExpr`s. +pub fn get_sort_options(ordering_req: &[PhysicalSortExpr]) -> Vec { + ordering_req.iter().map(|item| item.options).collect() +} diff --git a/datafusion/physical-expr/src/aggregate/variance.rs b/datafusion/physical-expr/src/aggregate/variance.rs index a720dd833a87a..d82c5ad5626f4 100644 --- a/datafusion/physical-expr/src/aggregate/variance.rs +++ b/datafusion/physical-expr/src/aggregate/variance.rs @@ -519,13 +519,17 @@ mod tests { let values1 = expr1 .iter() - .map(|e| e.evaluate(batch1)) - .map(|r| r.map(|v| v.into_array(batch1.num_rows()))) + .map(|e| { + e.evaluate(batch1) + .and_then(|v| v.into_array(batch1.num_rows())) + }) .collect::>>()?; let values2 = expr2 .iter() - .map(|e| e.evaluate(batch2)) - .map(|r| r.map(|v| v.into_array(batch2.num_rows()))) + .map(|e| { + e.evaluate(batch2) + .and_then(|v| v.into_array(batch2.num_rows())) + }) .collect::>>()?; accum1.update_batch(&values1)?; accum2.update_batch(&values2)?; diff --git a/datafusion/physical-expr/src/analysis.rs b/datafusion/physical-expr/src/analysis.rs index 93c24014fd3e3..6d36e2233cdd9 100644 --- a/datafusion/physical-expr/src/analysis.rs +++ b/datafusion/physical-expr/src/analysis.rs @@ -21,8 +21,7 @@ use std::fmt::Debug; use std::sync::Arc; use crate::expressions::Column; -use crate::intervals::cp_solver::PropagationResult; -use crate::intervals::{cardinality_ratio, ExprIntervalGraph, Interval, IntervalBound}; +use crate::intervals::cp_solver::{ExprIntervalGraph, PropagationResult}; use crate::utils::collect_columns; use crate::PhysicalExpr; @@ -31,6 +30,7 @@ use datafusion_common::stats::Precision; use datafusion_common::{ internal_err, ColumnStatistics, DataFusionError, Result, ScalarValue, }; +use datafusion_expr::interval_arithmetic::{cardinality_ratio, Interval}; /// The shared context used during the analysis of an expression. Includes /// the boundaries for all known columns. @@ -72,8 +72,12 @@ impl AnalysisContext { } } -/// Represents the boundaries of the resulting value from a physical expression, -/// if it were to be an expression, if it were to be evaluated. +/// Represents the boundaries (e.g. min and max values) of a particular column +/// +/// This is used range analysis of expressions, to determine if the expression +/// limits the value of particular columns (e.g. 
analyzing an expression such as +/// `time < 50` would result in a boundary interval for `time` having a max +/// value of `50`). #[derive(Clone, Debug, PartialEq)] pub struct ExprBoundaries { pub column: Column, @@ -91,23 +95,20 @@ impl ExprBoundaries { col_index: usize, ) -> Result { let field = &schema.fields()[col_index]; - let empty_field = ScalarValue::try_from(field.data_type())?; - let interval = Interval::new( - IntervalBound::new_closed( - col_stats - .min_value - .get_value() - .cloned() - .unwrap_or(empty_field.clone()), - ), - IntervalBound::new_closed( - col_stats - .max_value - .get_value() - .cloned() - .unwrap_or(empty_field), - ), - ); + let empty_field = + ScalarValue::try_from(field.data_type()).unwrap_or(ScalarValue::Null); + let interval = Interval::try_new( + col_stats + .min_value + .get_value() + .cloned() + .unwrap_or(empty_field.clone()), + col_stats + .max_value + .get_value() + .cloned() + .unwrap_or(empty_field), + )?; let column = Column::new(field.name(), col_index); Ok(ExprBoundaries { column, @@ -115,6 +116,23 @@ impl ExprBoundaries { distinct_count: col_stats.distinct_count.clone(), }) } + + /// Create `ExprBoundaries` that represent no known bounds for all the + /// columns in `schema` + pub fn try_new_unbounded(schema: &Schema) -> Result> { + schema + .fields() + .iter() + .enumerate() + .map(|(i, field)| { + Ok(Self { + column: Column::new(field.name(), i), + interval: Interval::make_unbounded(field.data_type())?, + distinct_count: Precision::Absent, + }) + }) + .collect() + } } /// Attempts to refine column boundaries and compute a selectivity value. @@ -135,47 +153,44 @@ impl ExprBoundaries { pub fn analyze( expr: &Arc, context: AnalysisContext, + schema: &Schema, ) -> Result { let target_boundaries = context.boundaries; - let mut graph = ExprIntervalGraph::try_new(expr.clone())?; + let mut graph = ExprIntervalGraph::try_new(expr.clone(), schema)?; - let columns: Vec> = collect_columns(expr) + let columns = collect_columns(expr) .into_iter() - .map(|c| Arc::new(c) as Arc) - .collect(); - - let target_expr_and_indices: Vec<(Arc, usize)> = - graph.gather_node_indices(columns.as_slice()); - - let mut target_indices_and_boundaries: Vec<(usize, Interval)> = - target_expr_and_indices - .iter() - .filter_map(|(expr, i)| { - target_boundaries.iter().find_map(|bound| { - expr.as_any() - .downcast_ref::() - .filter(|expr_column| bound.column.eq(*expr_column)) - .map(|_| (*i, bound.interval.clone())) - }) + .map(|c| Arc::new(c) as _) + .collect::>(); + + let target_expr_and_indices = graph.gather_node_indices(columns.as_slice()); + + let mut target_indices_and_boundaries = target_expr_and_indices + .iter() + .filter_map(|(expr, i)| { + target_boundaries.iter().find_map(|bound| { + expr.as_any() + .downcast_ref::() + .filter(|expr_column| bound.column.eq(*expr_column)) + .map(|_| (*i, bound.interval.clone())) }) - .collect(); - Ok( - match graph.update_ranges(&mut target_indices_and_boundaries)? { - PropagationResult::Success => shrink_boundaries( - expr, - graph, - target_boundaries, - target_expr_and_indices, - )?, - PropagationResult::Infeasible => { - AnalysisContext::new(target_boundaries).with_selectivity(0.0) - } - PropagationResult::CannotPropagate => { - AnalysisContext::new(target_boundaries).with_selectivity(1.0) - } - }, - ) + }) + .collect::>(); + + match graph + .update_ranges(&mut target_indices_and_boundaries, Interval::CERTAINLY_TRUE)? 
+ { + PropagationResult::Success => { + shrink_boundaries(graph, target_boundaries, target_expr_and_indices) + } + PropagationResult::Infeasible => { + Ok(AnalysisContext::new(target_boundaries).with_selectivity(0.0)) + } + PropagationResult::CannotPropagate => { + Ok(AnalysisContext::new(target_boundaries).with_selectivity(1.0)) + } + } } /// If the `PropagationResult` indicates success, this function calculates the @@ -183,8 +198,7 @@ pub fn analyze( /// Following this, it constructs and returns a new `AnalysisContext` with the /// updated parameters. fn shrink_boundaries( - expr: &Arc, - mut graph: ExprIntervalGraph, + graph: ExprIntervalGraph, mut target_boundaries: Vec, target_expr_and_indices: Vec<(Arc, usize)>, ) -> Result { @@ -199,20 +213,12 @@ fn shrink_boundaries( }; } }); - let graph_nodes = graph.gather_node_indices(&[expr.clone()]); - let Some((_, root_index)) = graph_nodes.get(0) else { - return internal_err!( - "The ExprIntervalGraph under investigation does not have any nodes." - ); - }; - let final_result = graph.get_interval(*root_index); - - let selectivity = calculate_selectivity( - &final_result.lower.value, - &final_result.upper.value, - &target_boundaries, - &initial_boundaries, - )?; + + let selectivity = calculate_selectivity(&target_boundaries, &initial_boundaries); + + if !(0.0..=1.0).contains(&selectivity) { + return internal_err!("Selectivity is out of limit: {}", selectivity); + } Ok(AnalysisContext::new(target_boundaries).with_selectivity(selectivity)) } @@ -220,33 +226,17 @@ fn shrink_boundaries( /// This function calculates the filter predicate's selectivity by comparing /// the initial and pruned column boundaries. Selectivity is defined as the /// ratio of rows in a table that satisfy the filter's predicate. -/// -/// An exact propagation result at the root, i.e. `[true, true]` or `[false, false]`, -/// leads to early exit (returning a selectivity value of either 1.0 or 0.0). In such -/// a case, `[true, true]` indicates that all data values satisfy the predicate (hence, -/// selectivity is 1.0), and `[false, false]` suggests that no data value meets the -/// predicate (therefore, selectivity is 0.0). fn calculate_selectivity( - lower_value: &ScalarValue, - upper_value: &ScalarValue, target_boundaries: &[ExprBoundaries], initial_boundaries: &[ExprBoundaries], -) -> Result { - match (lower_value, upper_value) { - (ScalarValue::Boolean(Some(true)), ScalarValue::Boolean(Some(true))) => Ok(1.0), - (ScalarValue::Boolean(Some(false)), ScalarValue::Boolean(Some(false))) => Ok(0.0), - _ => { - // Since the intervals are assumed uniform and the values - // are not correlated, we need to multiply the selectivities - // of multiple columns to get the overall selectivity. - target_boundaries.iter().enumerate().try_fold( - 1.0, - |acc, (i, ExprBoundaries { interval, .. })| { - let temp = - cardinality_ratio(&initial_boundaries[i].interval, interval)?; - Ok(acc * temp) - }, - ) - } - } +) -> f64 { + // Since the intervals are assumed uniform and the values + // are not correlated, we need to multiply the selectivities + // of multiple columns to get the overall selectivity. 
+ initial_boundaries + .iter() + .zip(target_boundaries.iter()) + .fold(1.0, |acc, (initial, target)| { + acc * cardinality_ratio(&initial.interval, &target.interval) + }) } diff --git a/datafusion/physical-expr/src/array_expressions.rs b/datafusion/physical-expr/src/array_expressions.rs index e296e9c96fadc..15330af640aeb 100644 --- a/datafusion/physical-expr/src/array_expressions.rs +++ b/datafusion/physical-expr/src/array_expressions.rs @@ -18,21 +18,25 @@ //! Array expressions use std::any::type_name; +use std::collections::HashSet; +use std::fmt::{Display, Formatter}; use std::sync::Arc; use arrow::array::*; use arrow::buffer::OffsetBuffer; use arrow::compute; use arrow::datatypes::{DataType, Field, UInt64Type}; +use arrow::row::{RowConverter, SortField}; use arrow_buffer::NullBuffer; +use arrow_schema::{FieldRef, SortOptions}; use datafusion_common::cast::{ - as_generic_string_array, as_int64_array, as_list_array, as_string_array, + as_generic_list_array, as_generic_string_array, as_int64_array, as_large_list_array, + as_list_array, as_null_array, as_string_array, }; -use datafusion_common::utils::array_into_list_array; +use datafusion_common::utils::{array_into_list_array, list_ndims}; use datafusion_common::{ - exec_err, internal_datafusion_err, internal_err, not_impl_err, plan_err, - DataFusionError, Result, + exec_err, internal_err, not_impl_err, plan_err, DataFusionError, Result, }; use itertools::Itertools; @@ -48,84 +52,105 @@ macro_rules! downcast_arg { }}; } -/// Downcasts multiple arguments into a single concrete type -/// $ARGS: &[ArrayRef] -/// $ARRAY_TYPE: type to downcast to +/// Computes a BooleanArray indicating equality or inequality between elements in a list array and a specified element array. /// -/// $returns a Vec<$ARRAY_TYPE> -macro_rules! downcast_vec { - ($ARGS:expr, $ARRAY_TYPE:ident) => {{ - $ARGS - .iter() - .map(|e| match e.as_any().downcast_ref::<$ARRAY_TYPE>() { - Some(array) => Ok(array), - _ => internal_err!("failed to downcast"), - }) - }}; -} - -macro_rules! new_builder { - (BooleanBuilder, $len:expr) => { - BooleanBuilder::with_capacity($len) - }; - (StringBuilder, $len:expr) => { - StringBuilder::new() - }; - (LargeStringBuilder, $len:expr) => { - LargeStringBuilder::new() - }; - ($el:ident, $len:expr) => {{ - <$el>::with_capacity($len) - }}; -} - -/// Combines multiple arrays into a single ListArray +/// # Arguments +/// +/// * `list_array_row` - A reference to a trait object implementing the Arrow `Array` trait. It represents the list array for which the equality or inequality will be compared. +/// +/// * `element_array` - A reference to a trait object implementing the Arrow `Array` trait. It represents the array with which each element in the `list_array_row` will be compared. +/// +/// * `row_index` - The index of the row in the `element_array` and `list_array` to use for the comparison. +/// +/// * `eq` - A boolean flag. If `true`, the function computes equality; if `false`, it computes inequality. +/// +/// # Returns /// -/// $ARGS: slice of arrays, each with $ARRAY_TYPE -/// $ARRAY_TYPE: the type of the list elements -/// $BUILDER_TYPE: the type of ArrayBuilder for the list elements +/// Returns a `Result` representing the comparison results. The result may contain an error if there are issues with the computation. /// -/// Returns: a ListArray where the elements each have the same type as -/// $ARRAY_TYPE and each element have a length of $ARGS.len() -macro_rules! 
array { - ($ARGS:expr, $ARRAY_TYPE:ident, $BUILDER_TYPE:ident) => {{ - let builder = new_builder!($BUILDER_TYPE, $ARGS[0].len()); - let mut builder = - ListBuilder::<$BUILDER_TYPE>::with_capacity(builder, $ARGS.len()); - - let num_rows = $ARGS[0].len(); - assert!( - $ARGS.iter().all(|a| a.len() == num_rows), - "all arguments must have the same number of rows" +/// # Example +/// +/// ```text +/// compare_element_to_list( +/// [1, 2, 3], [1, 2, 3], 0, true => [true, false, false] +/// [1, 2, 3, 3, 2, 1], [1, 2, 3], 1, true => [false, true, false, false, true, false] +/// +/// [[1, 2, 3], [2, 3, 4], [3, 4, 5]], [[1, 2, 3], [2, 3, 4], [3, 4, 5]], 0, true => [true, false, false] +/// [[1, 2, 3], [2, 3, 4], [2, 3, 4]], [[1, 2, 3], [2, 3, 4], [3, 4, 5]], 1, false => [true, false, false] +/// ) +/// ``` +fn compare_element_to_list( + list_array_row: &dyn Array, + element_array: &dyn Array, + row_index: usize, + eq: bool, +) -> Result { + if list_array_row.data_type() != element_array.data_type() { + return exec_err!( + "compare_element_to_list received incompatible types: '{:?}' and '{:?}'.", + list_array_row.data_type(), + element_array.data_type() ); + } + + let indices = UInt32Array::from(vec![row_index as u32]); + let element_array_row = arrow::compute::take(element_array, &indices, None)?; + + // Compute all positions in list_row_array (that is itself an + // array) that are equal to `from_array_row` + let res = match element_array_row.data_type() { + // arrow_ord::cmp::eq does not support ListArray, so we need to compare it by loop + DataType::List(_) => { + // compare each element of the from array + let element_array_row_inner = as_list_array(&element_array_row)?.value(0); + let list_array_row_inner = as_list_array(list_array_row)?; - // for each entry in the array - for index in 0..num_rows { - // for each column - for arg in $ARGS { - match arg.as_any().downcast_ref::<$ARRAY_TYPE>() { - // Copy the source array value into the target ListArray - Some(arr) => { - if arr.is_valid(index) { - builder.values().append_value(arr.value(index)); + list_array_row_inner + .iter() + // compare element by element the current row of list_array + .map(|row| { + row.map(|row| { + if eq { + row.eq(&element_array_row_inner) } else { - builder.values().append_null(); + row.ne(&element_array_row_inner) } - } - None => match arg.as_any().downcast_ref::() { - Some(arr) => { - for _ in 0..arr.len() { - builder.values().append_null(); - } + }) + }) + .collect::() + } + DataType::LargeList(_) => { + // compare each element of the from array + let element_array_row_inner = + as_large_list_array(&element_array_row)?.value(0); + let list_array_row_inner = as_large_list_array(list_array_row)?; + + list_array_row_inner + .iter() + // compare element by element the current row of list_array + .map(|row| { + row.map(|row| { + if eq { + row.eq(&element_array_row_inner) + } else { + row.ne(&element_array_row_inner) } - None => return internal_err!("failed to downcast"), - }, - } + }) + }) + .collect::() + } + _ => { + let element_arr = Scalar::new(element_array_row); + // use not_distinct so we can compare NULL + if eq { + arrow_ord::cmp::not_distinct(&list_array_row, &element_arr)? + } else { + arrow_ord::cmp::distinct(&list_array_row, &element_arr)? 
} - builder.append(true); } - Arc::new(builder.finish()) - }}; + }; + + Ok(res) } /// Returns the length of a concrete array dimension @@ -159,36 +184,11 @@ fn compute_array_length( value = downcast_arg!(value, ListArray).value(0); current_dimension += 1; } - _ => return Ok(None), - } - } -} - -/// Returns the dimension of the array -fn compute_array_ndims(arr: Option) -> Result> { - Ok(compute_array_ndims_with_datatype(arr)?.0) -} - -/// Returns the dimension and the datatype of elements of the array -fn compute_array_ndims_with_datatype( - arr: Option, -) -> Result<(Option, DataType)> { - let mut res: u64 = 1; - let mut value = match arr { - Some(arr) => arr, - None => return Ok((None, DataType::Null)), - }; - if value.is_empty() { - return Ok((None, DataType::Null)); - } - - loop { - match value.data_type() { - DataType::List(..) => { - value = downcast_arg!(value, ListArray).value(0); - res += 1; + DataType::LargeList(..) => { + value = downcast_arg!(value, LargeListArray).value(0); + current_dimension += 1; } - data_type => return Ok((Some(res), data_type.clone())), + _ => return Ok(None), } } } @@ -217,10 +217,10 @@ fn compute_array_dims(arr: Option) -> Result>>> fn check_datatypes(name: &str, args: &[&ArrayRef]) -> Result<()> { let data_type = args[0].data_type(); - if !args - .iter() - .all(|arg| arg.data_type().equals_datatype(data_type)) - { + if !args.iter().all(|arg| { + arg.data_type().equals_datatype(data_type) + || arg.data_type().equals_datatype(&DataType::Null) + }) { let types = args.iter().map(|arg| arg.data_type()).collect::>(); return plan_err!("{name} received incompatible types: '{types:?}'."); } @@ -269,7 +269,7 @@ macro_rules! call_array_function { } /// Convert one or more [`ArrayRef`] of the same type into a -/// `ListArray` +/// `ListArray` or 'LargeListArray' depending on the offset size. /// /// # Example (non nested) /// @@ -308,94 +308,56 @@ macro_rules! call_array_function { /// └──────────────┘ └──────────────┘ └─────────────────────────────┘ /// col1 col2 output /// ``` -fn array_array(args: &[ArrayRef], data_type: DataType) -> Result { +fn array_array( + args: &[ArrayRef], + data_type: DataType, +) -> Result { // do not accept 0 arguments. if args.is_empty() { return plan_err!("Array requires at least one argument"); } - let res = match data_type { - DataType::List(..) 
=> { - let row_count = args[0].len(); - let column_count = args.len(); - let mut list_arrays = vec![]; - let mut list_array_lengths = vec![]; - let mut list_valid = BooleanBufferBuilder::new(row_count); - // Construct ListArray per row - for index in 0..row_count { - let mut arrays = vec![]; - let mut array_lengths = vec![]; - let mut valid = BooleanBufferBuilder::new(column_count); - for arg in args { - if arg.as_any().downcast_ref::().is_some() { - array_lengths.push(0); - valid.append(false); - } else { - let list_arr = as_list_array(arg)?; - let arr = list_arr.value(index); - array_lengths.push(arr.len()); - arrays.push(arr); - valid.append(true); - } - } - if arrays.is_empty() { - list_valid.append(false); - list_array_lengths.push(0); - } else { - let buffer = valid.finish(); - // Assume all list arrays have the same data type - let data_type = arrays[0].data_type(); - let field = Arc::new(Field::new("item", data_type.to_owned(), true)); - let elements = arrays.iter().map(|x| x.as_ref()).collect::>(); - let values = arrow::compute::concat(elements.as_slice())?; - let list_arr = ListArray::new( - field, - OffsetBuffer::from_lengths(array_lengths), - values, - Some(NullBuffer::new(buffer)), - ); - list_valid.append(true); - list_array_lengths.push(list_arr.len()); - list_arrays.push(list_arr); - } + let mut data = vec![]; + let mut total_len = 0; + for arg in args { + let arg_data = if arg.as_any().is::() { + ArrayData::new_empty(&data_type) + } else { + arg.to_data() + }; + total_len += arg_data.len(); + data.push(arg_data); + } + + let mut offsets: Vec = Vec::with_capacity(total_len); + offsets.push(O::usize_as(0)); + + let capacity = Capacities::Array(total_len); + let data_ref = data.iter().collect::>(); + let mut mutable = MutableArrayData::with_capacities(data_ref, true, capacity); + + let num_rows = args[0].len(); + for row_idx in 0..num_rows { + for (arr_idx, arg) in args.iter().enumerate() { + if !arg.as_any().is::() + && !arg.is_null(row_idx) + && arg.is_valid(row_idx) + { + mutable.extend(arr_idx, row_idx, row_idx + 1); + } else { + mutable.extend_nulls(1); } - // Construct ListArray for all rows - let buffer = list_valid.finish(); - // Assume all list arrays have the same data type - let data_type = list_arrays[0].data_type(); - let field = Arc::new(Field::new("item", data_type.to_owned(), true)); - let elements = list_arrays - .iter() - .map(|x| x as &dyn Array) - .collect::>(); - let values = arrow::compute::concat(elements.as_slice())?; - let list_arr = ListArray::new( - field, - OffsetBuffer::from_lengths(list_array_lengths), - values, - Some(NullBuffer::new(buffer)), - ); - Arc::new(list_arr) - } - DataType::Utf8 => array!(args, StringArray, StringBuilder), - DataType::LargeUtf8 => array!(args, LargeStringArray, LargeStringBuilder), - DataType::Boolean => array!(args, BooleanArray, BooleanBuilder), - DataType::Float32 => array!(args, Float32Array, Float32Builder), - DataType::Float64 => array!(args, Float64Array, Float64Builder), - DataType::Int8 => array!(args, Int8Array, Int8Builder), - DataType::Int16 => array!(args, Int16Array, Int16Builder), - DataType::Int32 => array!(args, Int32Array, Int32Builder), - DataType::Int64 => array!(args, Int64Array, Int64Builder), - DataType::UInt8 => array!(args, UInt8Array, UInt8Builder), - DataType::UInt16 => array!(args, UInt16Array, UInt16Builder), - DataType::UInt32 => array!(args, UInt32Array, UInt32Builder), - DataType::UInt64 => array!(args, UInt64Array, UInt64Builder), - data_type => { - return not_impl_err!("Array is not 
implemented for type '{data_type:?}'.") } - }; + offsets.push(O::usize_as(mutable.len())); + } + let data = mutable.freeze(); - Ok(res) + Ok(Arc::new(GenericListArray::::try_new( + Arc::new(Field::new("item", data_type, true)), + OffsetBuffer::new(offsets.into()), + arrow_array::make_array(data), + None, + )?)) } /// `make_array` SQL function @@ -412,326 +374,709 @@ pub fn make_array(arrays: &[ArrayRef]) -> Result { match data_type { // Either an empty array or all nulls: DataType::Null => { - let array = new_null_array(&DataType::Null, arrays.len()); + let array = + new_null_array(&DataType::Null, arrays.iter().map(|a| a.len()).sum()); Ok(Arc::new(array_into_list_array(array))) } - data_type => array_array(arrays, data_type), + DataType::LargeList(..) => array_array::(arrays, data_type), + _ => array_array::(arrays, data_type), } } -fn return_empty(return_null: bool, data_type: DataType) -> Arc { - if return_null { - new_null_array(&data_type, 1) - } else { - new_empty_array(&data_type) +fn general_array_element( + array: &GenericListArray, + indexes: &Int64Array, +) -> Result +where + i64: TryInto, +{ + let values = array.values(); + let original_data = values.to_data(); + let capacity = Capacities::Array(original_data.len()); + + // use_nulls: true, we don't construct List for array_element, so we need explicit nulls. + let mut mutable = + MutableArrayData::with_capacities(vec![&original_data], true, capacity); + + fn adjusted_array_index(index: i64, len: O) -> Result> + where + i64: TryInto, + { + let index: O = index.try_into().map_err(|_| { + DataFusionError::Execution(format!( + "array_element got invalid index: {}", + index + )) + })?; + // 0 ~ len - 1 + let adjusted_zero_index = if index < O::usize_as(0) { + index + len + } else { + index - O::usize_as(1) + }; + + if O::usize_as(0) <= adjusted_zero_index && adjusted_zero_index < len { + Ok(Some(adjusted_zero_index)) + } else { + // Out of bounds + Ok(None) + } + } + + for (row_index, offset_window) in array.offsets().windows(2).enumerate() { + let start = offset_window[0]; + let end = offset_window[1]; + let len = end - start; + + // array is null + if len == O::usize_as(0) { + mutable.extend_nulls(1); + continue; + } + + let index = adjusted_array_index::(indexes.value(row_index), len)?; + + if let Some(index) = index { + let start = start.as_usize() + index.as_usize(); + mutable.extend(0, start, start + 1_usize); + } else { + // Index out of bounds + mutable.extend_nulls(1); + } } + + let data = mutable.freeze(); + Ok(arrow_array::make_array(data)) } -macro_rules! list_slice { - ($ARRAY:expr, $I:expr, $J:expr, $RETURN_ELEMENT:expr, $ARRAY_TYPE:ident) => {{ - let array = $ARRAY.as_any().downcast_ref::<$ARRAY_TYPE>().unwrap(); - if $I == 0 && $J == 0 || $ARRAY.is_empty() { - return return_empty($RETURN_ELEMENT, $ARRAY.data_type().clone()); +/// array_element SQL function +/// +/// There are two arguments for array_element, the first one is the array, the second one is the 1-indexed index. 
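///
/// A simplified sketch of the index rule that `adjusted_array_index` above
/// implements (the real helper is generic over the list offset type; the
/// plain-`i64` helper below is hypothetical and only shows the arithmetic):
///
/// ```
/// fn adjusted_zero_index(index: i64, len: i64) -> Option<i64> {
///     // positive indexes are 1-based, negative indexes count from the back
///     let adjusted = if index < 0 { index + len } else { index - 1 };
///     (0..len).contains(&adjusted).then_some(adjusted)
/// }
/// assert_eq!(adjusted_zero_index(1, 3), Some(0));  // first element
/// assert_eq!(adjusted_zero_index(-1, 3), Some(2)); // last element
/// assert_eq!(adjusted_zero_index(5, 3), None);     // out of bounds, yields NULL
/// ```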
+/// `array_element(array, index)` +/// +/// For example: +/// > array_element(\[1, 2, 3], 2) -> 2 +pub fn array_element(args: &[ArrayRef]) -> Result { + if args.len() != 2 { + return exec_err!("array_element needs two arguments"); + } + + match &args[0].data_type() { + DataType::List(_) => { + let array = as_list_array(&args[0])?; + let indexes = as_int64_array(&args[1])?; + general_array_element::(array, indexes) + } + DataType::LargeList(_) => { + let array = as_large_list_array(&args[0])?; + let indexes = as_int64_array(&args[1])?; + general_array_element::(array, indexes) } + _ => exec_err!( + "array_element does not support type: {:?}", + args[0].data_type() + ), + } +} - let i = if $I < 0 { - if $I.abs() as usize > array.len() { - return return_empty(true, $ARRAY.data_type().clone()); +fn general_except( + l: &GenericListArray, + r: &GenericListArray, + field: &FieldRef, +) -> Result> { + let converter = RowConverter::new(vec![SortField::new(l.value_type())])?; + + let l_values = l.values().to_owned(); + let r_values = r.values().to_owned(); + let l_values = converter.convert_columns(&[l_values])?; + let r_values = converter.convert_columns(&[r_values])?; + + let mut offsets = Vec::::with_capacity(l.len() + 1); + offsets.push(OffsetSize::usize_as(0)); + + let mut rows = Vec::with_capacity(l_values.num_rows()); + let mut dedup = HashSet::new(); + + for (l_w, r_w) in l.offsets().windows(2).zip(r.offsets().windows(2)) { + let l_slice = l_w[0].as_usize()..l_w[1].as_usize(); + let r_slice = r_w[0].as_usize()..r_w[1].as_usize(); + for i in r_slice { + let right_row = r_values.row(i); + dedup.insert(right_row); + } + for i in l_slice { + let left_row = l_values.row(i); + if dedup.insert(left_row) { + rows.push(left_row); } + } + + offsets.push(OffsetSize::usize_as(rows.len())); + dedup.clear(); + } + + if let Some(values) = converter.convert_rows(rows)?.first() { + Ok(GenericListArray::::new( + field.to_owned(), + OffsetBuffer::new(offsets.into()), + values.to_owned(), + l.nulls().cloned(), + )) + } else { + internal_err!("array_except failed to convert rows") + } +} + +pub fn array_except(args: &[ArrayRef]) -> Result { + if args.len() != 2 { + return internal_err!("array_except needs two arguments"); + } + + let array1 = &args[0]; + let array2 = &args[1]; - (array.len() as i64 + $I + 1) as usize + match (array1.data_type(), array2.data_type()) { + (DataType::Null, _) | (_, DataType::Null) => Ok(array1.to_owned()), + (DataType::List(field), DataType::List(_)) => { + check_datatypes("array_except", &[array1, array2])?; + let list1 = array1.as_list::(); + let list2 = array2.as_list::(); + let result = general_except::(list1, list2, field)?; + Ok(Arc::new(result)) + } + (DataType::LargeList(field), DataType::LargeList(_)) => { + check_datatypes("array_except", &[array1, array2])?; + let list1 = array1.as_list::(); + let list2 = array2.as_list::(); + let result = general_except::(list1, list2, field)?; + Ok(Arc::new(result)) + } + (dt1, dt2) => { + internal_err!("array_except got unexpected types: {dt1:?} and {dt2:?}") + } + } +} + +/// array_slice SQL function +/// +/// We follow the behavior of array_slice in DuckDB +/// Note that array_slice is 1-indexed. And there are two additional arguments `from` and `to` in array_slice. +/// +/// > array_slice(array, from, to) +/// +/// Positive index is treated as the index from the start of the array. If the +/// `from` index is smaller than 1, it is treated as 1. 
If the `to` index is larger than the +/// length of the array, it is treated as the length of the array. +/// +/// Negative index is treated as the index from the end of the array. If the index +/// is larger than the length of the array, it is NOT VALID, either in `from` or `to`. +/// The `to` index is exclusive like python slice syntax. +/// +/// See test cases in `array.slt` for more details. +pub fn array_slice(args: &[ArrayRef]) -> Result { + if args.len() != 3 { + return exec_err!("array_slice needs three arguments"); + } + + let array_data_type = args[0].data_type(); + match array_data_type { + DataType::List(_) => { + let array = as_list_array(&args[0])?; + let from_array = as_int64_array(&args[1])?; + let to_array = as_int64_array(&args[2])?; + general_array_slice::(array, from_array, to_array) + } + DataType::LargeList(_) => { + let array = as_large_list_array(&args[0])?; + let from_array = as_int64_array(&args[1])?; + let to_array = as_int64_array(&args[2])?; + general_array_slice::(array, from_array, to_array) + } + _ => exec_err!("array_slice does not support type: {:?}", array_data_type), + } +} + +fn general_array_slice( + array: &GenericListArray, + from_array: &Int64Array, + to_array: &Int64Array, +) -> Result +where + i64: TryInto, +{ + let values = array.values(); + let original_data = values.to_data(); + let capacity = Capacities::Array(original_data.len()); + + // use_nulls: false, we don't need nulls but empty array for array_slice, so we don't need explicit nulls but adjust offset to indicate nulls. + let mut mutable = + MutableArrayData::with_capacities(vec![&original_data], false, capacity); + + // We have the slice syntax compatible with DuckDB v0.8.1. + // The rule `adjusted_from_index` and `adjusted_to_index` follows the rule of array_slice in duckdb. 
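    // A plain-i64 sketch of those two rules (hypothetical helpers that mirror the
    // generic `adjusted_from_index` / `adjusted_to_index` defined just below):
    // positive indexes are 1-based, `from` clamps to the start of the array, a
    // positive `to` is inclusive and clamps to the end, and a negative `to` is
    // exclusive, Python-style.
    fn simple_from(index: i64, len: i64) -> Option<i64> {
        let i = if index < 0 { index + len } else { (index - 1).max(0) };
        (0..len).contains(&i).then_some(i)
    }
    fn simple_to(index: i64, len: i64) -> Option<i64> {
        let i = if index < 0 { index + len - 1 } else { (index - 1).min(len - 1) };
        (0..len).contains(&i).then_some(i)
    }
    // simple_from(2, 4) == Some(1), simple_to(3, 4) == Some(2): the slice keeps positions 2..=3
    // simple_to(10, 4)  == Some(3): a `to` past the end clamps to the last element
    // simple_to(-1, 4)  == Some(2): a negative `to` excludes the element it points at
    // simple_from(-5, 4) == None:  a negative index beyond the front is invalid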
+ + fn adjusted_from_index(index: i64, len: O) -> Result> + where + i64: TryInto, + { + // 0 ~ len - 1 + let adjusted_zero_index = if index < 0 { + if let Ok(index) = index.try_into() { + index + len + } else { + return exec_err!("array_slice got invalid index: {}", index); + } } else { - if $I == 0 { - 1 + // array_slice(arr, 1, to) is the same as array_slice(arr, 0, to) + if let Ok(index) = index.try_into() { + std::cmp::max(index - O::usize_as(1), O::usize_as(0)) } else { - $I as usize + return exec_err!("array_slice got invalid index: {}", index); } }; - let j = if $J < 0 { - if $J.abs() as usize > array.len() { - return return_empty(true, $ARRAY.data_type().clone()); - } - if $RETURN_ELEMENT { - (array.len() as i64 + $J + 1) as usize + if O::usize_as(0) <= adjusted_zero_index && adjusted_zero_index < len { + Ok(Some(adjusted_zero_index)) + } else { + // Out of bounds + Ok(None) + } + } + + fn adjusted_to_index(index: i64, len: O) -> Result> + where + i64: TryInto, + { + // 0 ~ len - 1 + let adjusted_zero_index = if index < 0 { + // array_slice in duckdb with negative to_index is python-like, so index itself is exclusive + if let Ok(index) = index.try_into() { + index + len - O::usize_as(1) } else { - (array.len() as i64 + $J) as usize + return exec_err!("array_slice got invalid index: {}", index); } } else { - if $J == 0 { - 1 + // array_slice(arr, from, len + 1) is the same as array_slice(arr, from, len) + if let Ok(index) = index.try_into() { + std::cmp::min(index - O::usize_as(1), len - O::usize_as(1)) } else { - if $J as usize > array.len() { - array.len() - } else { - $J as usize - } + return exec_err!("array_slice got invalid index: {}", index); } }; - if i > j || i as usize > $ARRAY.len() { - return_empty($RETURN_ELEMENT, $ARRAY.data_type().clone()) + if O::usize_as(0) <= adjusted_zero_index && adjusted_zero_index < len { + Ok(Some(adjusted_zero_index)) } else { - Arc::new(array.slice((i - 1), (j + 1 - i))) + // Out of bounds + Ok(None) } - }}; -} + } -macro_rules! slice { - ($ARRAY:expr, $KEY:expr, $EXTRA_KEY:expr, $RETURN_ELEMENT:expr, $ARRAY_TYPE:ident) => {{ - let sliced_array: Vec> = $ARRAY - .iter() - .zip($KEY.iter()) - .zip($EXTRA_KEY.iter()) - .map(|((arr, i), j)| match (arr, i, j) { - (Some(arr), Some(i), Some(j)) => { - list_slice!(arr, i, j, $RETURN_ELEMENT, $ARRAY_TYPE) - } - (Some(arr), None, Some(j)) => { - list_slice!(arr, 1i64, j, $RETURN_ELEMENT, $ARRAY_TYPE) - } - (Some(arr), Some(i), None) => { - list_slice!(arr, i, arr.len() as i64, $RETURN_ELEMENT, $ARRAY_TYPE) - } - (Some(arr), None, None) if !$RETURN_ELEMENT => arr, - _ => return_empty($RETURN_ELEMENT, $ARRAY.value_type().clone()), - }) - .collect(); + let mut offsets = vec![O::usize_as(0)]; + + for (row_index, offset_window) in array.offsets().windows(2).enumerate() { + let start = offset_window[0]; + let end = offset_window[1]; + let len = end - start; - // concat requires input of at least one array - if sliced_array.is_empty() { - Ok(return_empty($RETURN_ELEMENT, $ARRAY.value_type())) + // len 0 indicate array is null, return empty array in this row. + if len == O::usize_as(0) { + offsets.push(offsets[row_index]); + continue; + } + + // If index is null, we consider it as the minimum / maximum index of the array. 
+ let from_index = if from_array.is_null(row_index) { + Some(O::usize_as(0)) } else { - let vec = sliced_array - .iter() - .map(|a| a.as_ref()) - .collect::>(); - let mut i: i32 = 0; - let mut offsets = vec![i]; - offsets.extend( - vec.iter() - .map(|a| { - i += a.len() as i32; - i - }) - .collect::>(), - ); - let values = compute::concat(vec.as_slice()).unwrap(); + adjusted_from_index::(from_array.value(row_index), len)? + }; - if $RETURN_ELEMENT { - Ok(values) + let to_index = if to_array.is_null(row_index) { + Some(len - O::usize_as(1)) + } else { + adjusted_to_index::(to_array.value(row_index), len)? + }; + + if let (Some(from), Some(to)) = (from_index, to_index) { + if from <= to { + assert!(start + to <= end); + mutable.extend( + 0, + (start + from).to_usize().unwrap(), + (start + to + O::usize_as(1)).to_usize().unwrap(), + ); + offsets.push(offsets[row_index] + (to - from + O::usize_as(1))); } else { - let field = - Arc::new(Field::new("item", $ARRAY.value_type().clone(), true)); - Ok(Arc::new(ListArray::try_new( - field, - OffsetBuffer::new(offsets.into()), - values, - None, - )?)) + // invalid range, return empty array + offsets.push(offsets[row_index]); } + } else { + // invalid range, return empty array + offsets.push(offsets[row_index]); } - }}; + } + + let data = mutable.freeze(); + + Ok(Arc::new(GenericListArray::::try_new( + Arc::new(Field::new("item", array.value_type(), true)), + OffsetBuffer::::new(offsets.into()), + arrow_array::make_array(data), + None, + )?)) } -fn define_array_slice( - list_array: &ListArray, - key: &Int64Array, - extra_key: &Int64Array, - return_element: bool, -) -> Result { - macro_rules! array_function { - ($ARRAY_TYPE:ident) => { - slice!(list_array, key, extra_key, return_element, $ARRAY_TYPE) - }; - } - call_array_function!(list_array.value_type(), true) +fn general_pop_front_list( + array: &GenericListArray, +) -> Result +where + i64: TryInto, +{ + let from_array = Int64Array::from(vec![2; array.len()]); + let to_array = Int64Array::from( + array + .iter() + .map(|arr| arr.map_or(0, |arr| arr.len() as i64)) + .collect::>(), + ); + general_array_slice::(array, &from_array, &to_array) } -pub fn array_element(args: &[ArrayRef]) -> Result { - let list_array = as_list_array(&args[0])?; - let key = as_int64_array(&args[1])?; - define_array_slice(list_array, key, key, true) +fn general_pop_back_list( + array: &GenericListArray, +) -> Result +where + i64: TryInto, +{ + let from_array = Int64Array::from(vec![1; array.len()]); + let to_array = Int64Array::from( + array + .iter() + .map(|arr| arr.map_or(0, |arr| arr.len() as i64 - 1)) + .collect::>(), + ); + general_array_slice::(array, &from_array, &to_array) } -pub fn array_slice(args: &[ArrayRef]) -> Result { - let list_array = as_list_array(&args[0])?; - let key = as_int64_array(&args[1])?; - let extra_key = as_int64_array(&args[2])?; - define_array_slice(list_array, key, extra_key, false) +/// array_pop_front SQL function +pub fn array_pop_front(args: &[ArrayRef]) -> Result { + let array_data_type = args[0].data_type(); + match array_data_type { + DataType::List(_) => { + let array = as_list_array(&args[0])?; + general_pop_front_list::(array) + } + DataType::LargeList(_) => { + let array = as_large_list_array(&args[0])?; + general_pop_front_list::(array) + } + _ => exec_err!( + "array_pop_front does not support type: {:?}", + array_data_type + ), + } } +/// array_pop_back SQL function pub fn array_pop_back(args: &[ArrayRef]) -> Result { - let list_array = as_list_array(&args[0])?; - let key = vec![0; 
list_array.len()]; - let extra_key: Vec<_> = list_array - .iter() - .map(|x| x.map_or(0, |arr| arr.len() as i64 - 1)) - .collect(); + if args.len() != 1 { + return exec_err!("array_pop_back needs one argument"); + } + + let array_data_type = args[0].data_type(); + match array_data_type { + DataType::List(_) => { + let array = as_list_array(&args[0])?; + general_pop_back_list::(array) + } + DataType::LargeList(_) => { + let array = as_large_list_array(&args[0])?; + general_pop_back_list::(array) + } + _ => exec_err!( + "array_pop_back does not support type: {:?}", + array_data_type + ), + } +} - define_array_slice( - list_array, - &Int64Array::from(key), - &Int64Array::from(extra_key), +/// Appends or prepends elements to a ListArray. +/// +/// This function takes a ListArray, an ArrayRef, a FieldRef, and a boolean flag +/// indicating whether to append or prepend the elements. It returns a `Result` +/// representing the resulting ListArray after the operation. +/// +/// # Arguments +/// +/// * `list_array` - A reference to the ListArray to which elements will be appended/prepended. +/// * `element_array` - A reference to the Array containing elements to be appended/prepended. +/// * `field` - A reference to the Field describing the data type of the arrays. +/// * `is_append` - A boolean flag indicating whether to append (`true`) or prepend (`false`) elements. +/// +/// # Examples +/// +/// generic_append_and_prepend( +/// [1, 2, 3], 4, append => [1, 2, 3, 4] +/// 5, [6, 7, 8], prepend => [5, 6, 7, 8] +/// ) +fn generic_append_and_prepend( + list_array: &GenericListArray, + element_array: &ArrayRef, + data_type: &DataType, + is_append: bool, +) -> Result +where + i64: TryInto, +{ + let mut offsets = vec![O::usize_as(0)]; + let values = list_array.values(); + let original_data = values.to_data(); + let element_data = element_array.to_data(); + let capacity = Capacities::Array(original_data.len() + element_data.len()); + + let mut mutable = MutableArrayData::with_capacities( + vec![&original_data, &element_data], false, - ) -} - -macro_rules! append { - ($ARRAY:expr, $ELEMENT:expr, $ARRAY_TYPE:ident) => {{ - let mut offsets: Vec = vec![0]; - let mut values = - downcast_arg!(new_empty_array($ELEMENT.data_type()), $ARRAY_TYPE).clone(); - - let element = downcast_arg!($ELEMENT, $ARRAY_TYPE); - for (arr, el) in $ARRAY.iter().zip(element.iter()) { - let last_offset: i32 = offsets.last().copied().ok_or_else(|| { - DataFusionError::Internal(format!("offsets should not be empty")) - })?; - match arr { - Some(arr) => { - let child_array = downcast_arg!(arr, $ARRAY_TYPE); - values = downcast_arg!( - compute::concat(&[ - &values, - child_array, - &$ARRAY_TYPE::from(vec![el]) - ])? - .clone(), - $ARRAY_TYPE - ) - .clone(); - offsets.push(last_offset + child_array.len() as i32 + 1i32); - } - None => { - values = downcast_arg!( - compute::concat(&[ - &values, - &$ARRAY_TYPE::from(vec![el.clone()]) - ])? 
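// A row-wise sketch of the behavior documented for `generic_append_and_prepend`
// above, with plain `Vec<i64>` standing in for the Arrow arrays (the helper name
// is illustrative; the real function splices rows through MutableArrayData):
fn append_or_prepend_row(list: &[i64], element: i64, is_append: bool) -> Vec<i64> {
    let mut out = Vec::with_capacity(list.len() + 1);
    if is_append {
        out.extend_from_slice(list); // [1, 2, 3], 4, append  => [1, 2, 3, 4]
        out.push(element);
    } else {
        out.push(element); // 5, [6, 7, 8], prepend => [5, 6, 7, 8]
        out.extend_from_slice(list);
    }
    out
}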
- .clone(), - $ARRAY_TYPE - ) - .clone(); - offsets.push(last_offset + 1i32); - } - } + capacity, + ); + + let values_index = 0; + let element_index = 1; + + for (row_index, offset_window) in list_array.offsets().windows(2).enumerate() { + let start = offset_window[0].to_usize().unwrap(); + let end = offset_window[1].to_usize().unwrap(); + if is_append { + mutable.extend(values_index, start, end); + mutable.extend(element_index, row_index, row_index + 1); + } else { + mutable.extend(element_index, row_index, row_index + 1); + mutable.extend(values_index, start, end); } + offsets.push(offsets[row_index] + O::usize_as(end - start + 1)); + } - let field = Arc::new(Field::new("item", $ELEMENT.data_type().clone(), true)); + let data = mutable.freeze(); - Arc::new(ListArray::try_new( - field, - OffsetBuffer::new(offsets.into()), - Arc::new(values), - None, - )?) - }}; + Ok(Arc::new(GenericListArray::::try_new( + Arc::new(Field::new("item", data_type.to_owned(), true)), + OffsetBuffer::new(offsets.into()), + arrow_array::make_array(data), + None, + )?)) } -/// Array_append SQL function -pub fn array_append(args: &[ArrayRef]) -> Result { - let arr = as_list_array(&args[0])?; - let element = &args[1]; +/// Generates an array of integers from start to stop with a given step. +/// +/// This function takes 1 to 3 ArrayRefs as arguments, representing start, stop, and step values. +/// It returns a `Result` representing the resulting ListArray after the operation. +/// +/// # Arguments +/// +/// * `args` - An array of 1 to 3 ArrayRefs representing start, stop, and step(step value can not be zero.) values. +/// +/// # Examples +/// +/// gen_range(3) => [0, 1, 2] +/// gen_range(1, 4) => [1, 2, 3] +/// gen_range(1, 7, 2) => [1, 3, 5] +pub fn gen_range(args: &[ArrayRef]) -> Result { + let (start_array, stop_array, step_array) = match args.len() { + 1 => (None, as_int64_array(&args[0])?, None), + 2 => ( + Some(as_int64_array(&args[0])?), + as_int64_array(&args[1])?, + None, + ), + 3 => ( + Some(as_int64_array(&args[0])?), + as_int64_array(&args[1])?, + Some(as_int64_array(&args[2])?), + ), + _ => return internal_err!("gen_range expects 1 to 3 arguments"), + }; - check_datatypes("array_append", &[arr.values(), element])?; - let res = match arr.value_type() { - DataType::List(_) => concat_internal(args)?, - DataType::Null => return make_array(&[element.to_owned()]), - data_type => { - macro_rules! array_function { - ($ARRAY_TYPE:ident) => { - append!(arr, element, $ARRAY_TYPE) - }; - } - call_array_function!(data_type, false) + let mut values = vec![]; + let mut offsets = vec![0]; + for (idx, stop) in stop_array.iter().enumerate() { + let stop = stop.unwrap_or(0); + let start = start_array.as_ref().map(|arr| arr.value(idx)).unwrap_or(0); + let step = step_array.as_ref().map(|arr| arr.value(idx)).unwrap_or(1); + if step == 0 { + return exec_err!("step can't be 0 for function range(start [, stop, step]"); + } + if step < 0 { + // Decreasing range + values.extend((stop + 1..start + 1).rev().step_by((-step) as usize)); + } else { + // Increasing range + values.extend((start..stop).step_by(step as usize)); } - }; - Ok(res) + offsets.push(values.len() as i32); + } + let arr = Arc::new(ListArray::try_new( + Arc::new(Field::new("item", DataType::Int64, true)), + OffsetBuffer::new(offsets.into()), + Arc::new(Int64Array::from(values)), + None, + )?); + Ok(arr) } -macro_rules! 
prepend { - ($ARRAY:expr, $ELEMENT:expr, $ARRAY_TYPE:ident) => {{ - let mut offsets: Vec = vec![0]; - let mut values = - downcast_arg!(new_empty_array($ELEMENT.data_type()), $ARRAY_TYPE).clone(); - - let element = downcast_arg!($ELEMENT, $ARRAY_TYPE); - for (arr, el) in $ARRAY.iter().zip(element.iter()) { - let last_offset: i32 = offsets.last().copied().ok_or_else(|| { - DataFusionError::Internal(format!("offsets should not be empty")) - })?; - match arr { - Some(arr) => { - let child_array = downcast_arg!(arr, $ARRAY_TYPE); - values = downcast_arg!( - compute::concat(&[ - &values, - &$ARRAY_TYPE::from(vec![el]), - child_array - ])? - .clone(), - $ARRAY_TYPE - ) - .clone(); - offsets.push(last_offset + child_array.len() as i32 + 1i32); - } - None => { - values = downcast_arg!( - compute::concat(&[ - &values, - &$ARRAY_TYPE::from(vec![el.clone()]) - ])? - .clone(), - $ARRAY_TYPE - ) - .clone(); - offsets.push(last_offset + 1i32); - } - } +/// Array_sort SQL function +pub fn array_sort(args: &[ArrayRef]) -> Result { + if args.is_empty() || args.len() > 3 { + return exec_err!("array_sort expects one to three arguments"); + } + + let sort_option = match args.len() { + 1 => None, + 2 => { + let sort = as_string_array(&args[1])?.value(0); + Some(SortOptions { + descending: order_desc(sort)?, + nulls_first: true, + }) + } + 3 => { + let sort = as_string_array(&args[1])?.value(0); + let nulls_first = as_string_array(&args[2])?.value(0); + Some(SortOptions { + descending: order_desc(sort)?, + nulls_first: order_nulls_first(nulls_first)?, + }) } + _ => return internal_err!("array_sort expects 1 to 3 arguments"), + }; - let field = Arc::new(Field::new("item", $ELEMENT.data_type().clone(), true)); + let list_array = as_list_array(&args[0])?; + let row_count = list_array.len(); - Arc::new(ListArray::try_new( - field, - OffsetBuffer::new(offsets.into()), - Arc::new(values), - None, - )?) 
- }}; + let mut array_lengths = vec![]; + let mut arrays = vec![]; + let mut valid = BooleanBufferBuilder::new(row_count); + for i in 0..row_count { + if list_array.is_null(i) { + array_lengths.push(0); + valid.append(false); + } else { + let arr_ref = list_array.value(i); + let arr_ref = arr_ref.as_ref(); + + let sorted_array = compute::sort(arr_ref, sort_option)?; + array_lengths.push(sorted_array.len()); + arrays.push(sorted_array); + valid.append(true); + } + } + + // Assume all arrays have the same data type + let data_type = list_array.value_type(); + let buffer = valid.finish(); + + let elements = arrays + .iter() + .map(|a| a.as_ref()) + .collect::>(); + + let list_arr = ListArray::new( + Arc::new(Field::new("item", data_type, true)), + OffsetBuffer::from_lengths(array_lengths), + Arc::new(compute::concat(elements.as_slice())?), + Some(NullBuffer::new(buffer)), + ); + Ok(Arc::new(list_arr)) } -/// Array_prepend SQL function -pub fn array_prepend(args: &[ArrayRef]) -> Result { - let element = &args[0]; - let arr = as_list_array(&args[1])?; +fn order_desc(modifier: &str) -> Result { + match modifier.to_uppercase().as_str() { + "DESC" => Ok(true), + "ASC" => Ok(false), + _ => internal_err!("the second parameter of array_sort expects DESC or ASC"), + } +} + +fn order_nulls_first(modifier: &str) -> Result { + match modifier.to_uppercase().as_str() { + "NULLS FIRST" => Ok(true), + "NULLS LAST" => Ok(false), + _ => internal_err!( + "the third parameter of array_sort expects NULLS FIRST or NULLS LAST" + ), + } +} - check_datatypes("array_prepend", &[element, arr.values()])?; - let res = match arr.value_type() { - DataType::List(_) => concat_internal(args)?, - DataType::Null => return make_array(&[element.to_owned()]), +fn general_append_and_prepend( + args: &[ArrayRef], + is_append: bool, +) -> Result +where + i64: TryInto, +{ + let (list_array, element_array) = if is_append { + let list_array = as_generic_list_array::(&args[0])?; + let element_array = &args[1]; + check_datatypes("array_append", &[element_array, list_array.values()])?; + (list_array, element_array) + } else { + let list_array = as_generic_list_array::(&args[1])?; + let element_array = &args[0]; + check_datatypes("array_prepend", &[list_array.values(), element_array])?; + (list_array, element_array) + }; + + let res = match list_array.value_type() { + DataType::List(_) => concat_internal::(args)?, + DataType::LargeList(_) => concat_internal::(args)?, + DataType::Null => { + return make_array(&[ + list_array.values().to_owned(), + element_array.to_owned(), + ]); + } data_type => { - macro_rules! 
array_function { - ($ARRAY_TYPE:ident) => { - prepend!(arr, element, $ARRAY_TYPE) - }; - } - call_array_function!(data_type, false) + return generic_append_and_prepend::( + list_array, + element_array, + &data_type, + is_append, + ); } }; Ok(res) } +/// Array_append SQL function +pub fn array_append(args: &[ArrayRef]) -> Result { + if args.len() != 2 { + return exec_err!("array_append expects two arguments"); + } + + match args[0].data_type() { + DataType::LargeList(_) => general_append_and_prepend::(args, true), + _ => general_append_and_prepend::(args, true), + } +} + +/// Array_prepend SQL function +pub fn array_prepend(args: &[ArrayRef]) -> Result { + if args.len() != 2 { + return exec_err!("array_prepend expects two arguments"); + } + + match args[1].data_type() { + DataType::LargeList(_) => general_append_and_prepend::(args, false), + _ => general_append_and_prepend::(args, false), + } +} + fn align_array_dimensions(args: Vec) -> Result> { let args_ndim = args .iter() - .map(|arg| compute_array_ndims(Some(arg.to_owned()))) - .collect::>>()? - .into_iter() - .map(|x| x.unwrap_or(0)) + .map(|arg| datafusion_common::utils::list_ndims(arg.data_type())) .collect::>(); let max_ndim = args_ndim.iter().max().unwrap_or(&0); @@ -765,11 +1110,13 @@ fn align_array_dimensions(args: Vec) -> Result> { } // Concatenate arrays on the same row. -fn concat_internal(args: &[ArrayRef]) -> Result { +fn concat_internal(args: &[ArrayRef]) -> Result { let args = align_array_dimensions(args.to_vec())?; - let list_arrays = - downcast_vec!(args, ListArray).collect::>>()?; + let list_arrays = args + .iter() + .map(|arg| as_generic_list_array::(arg)) + .collect::>>()?; // Assume number of rows is the same for all arrays let row_count = list_arrays[0].len(); @@ -801,7 +1148,7 @@ fn concat_internal(args: &[ArrayRef]) -> Result { .collect::>(); // Concatenated array on i-th row - let concated_array = arrow::compute::concat(elements.as_slice())?; + let concated_array = compute::concat(elements.as_slice())?; array_lengths.push(concated_array.len()); arrays.push(concated_array); valid.append(true); @@ -816,158 +1163,57 @@ fn concat_internal(args: &[ArrayRef]) -> Result { .map(|a| a.as_ref()) .collect::>(); - let list_arr = ListArray::new( + let list_arr = GenericListArray::::new( Arc::new(Field::new("item", data_type, true)), OffsetBuffer::from_lengths(array_lengths), - Arc::new(arrow::compute::concat(elements.as_slice())?), + Arc::new(compute::concat(elements.as_slice())?), Some(NullBuffer::new(buffer)), ); + Ok(Arc::new(list_arr)) } /// Array_concat/Array_cat SQL function pub fn array_concat(args: &[ArrayRef]) -> Result { + if args.is_empty() { + return exec_err!("array_concat expects at least one arguments"); + } + let mut new_args = vec![]; for arg in args { - let (ndim, lower_data_type) = - compute_array_ndims_with_datatype(Some(arg.clone()))?; - if ndim.is_none() || ndim == Some(1) { - return not_impl_err!("Array is not type '{lower_data_type:?}'."); - } else if !lower_data_type.equals_datatype(&DataType::Null) { + let ndim = list_ndims(arg.data_type()); + let base_type = datafusion_common::utils::base_type(arg.data_type()); + if ndim == 0 { + return not_impl_err!("Array is not type '{base_type:?}'."); + } else if !base_type.eq(&DataType::Null) { new_args.push(arg.clone()); } } - concat_internal(new_args.as_slice()) -} - -macro_rules! 
general_repeat { - ($ELEMENT:expr, $COUNT:expr, $ARRAY_TYPE:ident) => {{ - let mut offsets: Vec = vec![0]; - let mut values = - downcast_arg!(new_empty_array($ELEMENT.data_type()), $ARRAY_TYPE).clone(); - - let element_array = downcast_arg!($ELEMENT, $ARRAY_TYPE); - for (el, c) in element_array.iter().zip($COUNT.iter()) { - let last_offset: i32 = offsets.last().copied().ok_or_else(|| { - DataFusionError::Internal(format!("offsets should not be empty")) - })?; - match el { - Some(el) => { - let c = if c < Some(0) { 0 } else { c.unwrap() } as usize; - let repeated_array = - [Some(el.clone())].repeat(c).iter().collect::<$ARRAY_TYPE>(); - - values = downcast_arg!( - compute::concat(&[&values, &repeated_array])?.clone(), - $ARRAY_TYPE - ) - .clone(); - offsets.push(last_offset + repeated_array.len() as i32); - } - None => { - offsets.push(last_offset); - } - } - } - - let field = Arc::new(Field::new("item", $ELEMENT.data_type().clone(), true)); - - Arc::new(ListArray::try_new( - field, - OffsetBuffer::new(offsets.into()), - Arc::new(values), - None, - )?) - }}; -} - -macro_rules! general_repeat_list { - ($ELEMENT:expr, $COUNT:expr, $ARRAY_TYPE:ident) => {{ - let mut offsets: Vec = vec![0]; - let mut values = - downcast_arg!(new_empty_array($ELEMENT.data_type()), ListArray).clone(); - - let element_array = downcast_arg!($ELEMENT, ListArray); - for (el, c) in element_array.iter().zip($COUNT.iter()) { - let last_offset: i32 = offsets.last().copied().ok_or_else(|| { - DataFusionError::Internal(format!("offsets should not be empty")) - })?; - match el { - Some(el) => { - let c = if c < Some(0) { 0 } else { c.unwrap() } as usize; - let repeated_vec = vec![el; c]; - - let mut i: i32 = 0; - let mut repeated_offsets = vec![i]; - repeated_offsets.extend( - repeated_vec - .clone() - .into_iter() - .map(|a| { - i += a.len() as i32; - i - }) - .collect::>(), - ); - - let mut repeated_values = downcast_arg!( - new_empty_array(&element_array.value_type()), - $ARRAY_TYPE - ) - .clone(); - for repeated_list in repeated_vec { - repeated_values = downcast_arg!( - compute::concat(&[&repeated_values, &repeated_list])?, - $ARRAY_TYPE - ) - .clone(); - } - - let field = Arc::new(Field::new( - "item", - element_array.value_type().clone(), - true, - )); - let repeated_array = ListArray::try_new( - field, - OffsetBuffer::new(repeated_offsets.clone().into()), - Arc::new(repeated_values), - None, - )?; - - values = downcast_arg!( - compute::concat(&[&values, &repeated_array,])?.clone(), - ListArray - ) - .clone(); - offsets.push(last_offset + repeated_array.len() as i32); - } - None => { - offsets.push(last_offset); - } - } - } - - let field = Arc::new(Field::new("item", $ELEMENT.data_type().clone(), true)); - - Arc::new(ListArray::try_new( - field, - OffsetBuffer::new(offsets.into()), - Arc::new(values), - None, - )?) - }}; + concat_internal::(new_args.as_slice()) } /// Array_empty SQL function pub fn array_empty(args: &[ArrayRef]) -> Result { - if args[0].as_any().downcast_ref::().is_some() { + if args.len() != 1 { + return exec_err!("array_empty expects one argument"); + } + + if as_null_array(&args[0]).is_ok() { // Make sure to return Boolean type. 
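    // A sketch of the per-row rule that `array_empty_dispatch` below applies, with
    // `Option<&[Option<i64>]>` standing in for a nullable list row (the helper name
    // is illustrative): a NULL row stays NULL, and a row counts as empty when it has
    // no non-null elements, i.e. len == null_count.
    fn is_empty_row(row: Option<&[Option<i64>]>) -> Option<bool> {
        row.map(|values| values.iter().all(|v| v.is_none()))
    }
    // is_empty_row(Some(&[]))              == Some(true)
    // is_empty_row(Some(&[None]))          == Some(true)
    // is_empty_row(Some(&[Some(1), None])) == Some(false)
    // is_empty_row(None)                   == None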
return Ok(Arc::new(BooleanArray::new_null(args[0].len()))); } + let array_type = args[0].data_type(); + + match array_type { + DataType::List(_) => array_empty_dispatch::(&args[0]), + DataType::LargeList(_) => array_empty_dispatch::(&args[0]), + _ => internal_err!("array_empty does not support type '{array_type:?}'."), + } +} - let array = as_list_array(&args[0])?; +fn array_empty_dispatch(array: &ArrayRef) -> Result { + let array = as_generic_list_array::(array)?; let builder = array .iter() .map(|arr| arr.map(|arr| arr.len() == arr.null_count())) @@ -977,364 +1223,588 @@ pub fn array_empty(args: &[ArrayRef]) -> Result { /// Array_repeat SQL function pub fn array_repeat(args: &[ArrayRef]) -> Result { + if args.len() != 2 { + return exec_err!("array_repeat expects two arguments"); + } + let element = &args[0]; - let count = as_int64_array(&args[1])?; + let count_array = as_int64_array(&args[1])?; - let res = match element.data_type() { - DataType::List(field) => { - macro_rules! array_function { - ($ARRAY_TYPE:ident) => { - general_repeat_list!(element, count, $ARRAY_TYPE) - }; - } - call_array_function!(field.data_type(), true) + match element.data_type() { + DataType::List(_) => { + let list_array = as_list_array(element)?; + general_list_repeat::(list_array, count_array) } - data_type => { - macro_rules! array_function { - ($ARRAY_TYPE:ident) => { - general_repeat!(element, count, $ARRAY_TYPE) - }; - } - call_array_function!(data_type, false) + DataType::LargeList(_) => { + let list_array = as_large_list_array(element)?; + general_list_repeat::(list_array, count_array) } - }; - - Ok(res) + _ => general_repeat(element, count_array), + } } -macro_rules! position { - ($ARRAY:expr, $ELEMENT:expr, $INDEX:expr, $ARRAY_TYPE:ident) => {{ - let element = downcast_arg!($ELEMENT, $ARRAY_TYPE); - $ARRAY - .iter() - .zip(element.iter()) - .zip($INDEX.iter()) - .map(|((arr, el), i)| { - let index = match i { - Some(i) => { - if i <= 0 { - 0 - } else { - i - 1 - } - } - None => return exec_err!("initial position must not be null"), - }; +/// For each element of `array[i]` repeat `count_array[i]` times. +/// +/// Assumption for the input: +/// 1. `count[i] >= 0` +/// 2. `array.len() == count_array.len()` +/// +/// For example, +/// ```text +/// array_repeat( +/// [1, 2, 3], [2, 0, 1] => [[1, 1], [], [3]] +/// ) +/// ``` +fn general_repeat(array: &ArrayRef, count_array: &Int64Array) -> Result { + let data_type = array.data_type(); + let mut new_values = vec![]; - match arr { - Some(arr) => { - let child_array = downcast_arg!(arr, $ARRAY_TYPE); - - match child_array - .iter() - .skip(index as usize) - .position(|x| x == el) - { - Some(value) => Ok(Some(value as u64 + index as u64 + 1u64)), - None => Ok(None), - } - } - None => Ok(None), - } - }) - .collect::>()? 
- }}; -} + let count_vec = count_array + .values() + .to_vec() + .iter() + .map(|x| *x as usize) + .collect::>(); -/// Array_position SQL function -pub fn array_position(args: &[ArrayRef]) -> Result { - let arr = as_list_array(&args[0])?; - let element = &args[1]; + for (row_index, &count) in count_vec.iter().enumerate() { + let repeated_array = if array.is_null(row_index) { + new_null_array(data_type, count) + } else { + let original_data = array.to_data(); + let capacity = Capacities::Array(count); + let mut mutable = + MutableArrayData::with_capacities(vec![&original_data], false, capacity); - let index = if args.len() == 3 { - as_int64_array(&args[2])?.clone() - } else { - Int64Array::from_value(0, arr.len()) - }; + for _ in 0..count { + mutable.extend(0, row_index, row_index + 1); + } - check_datatypes("array_position", &[arr.values(), element])?; - macro_rules! array_function { - ($ARRAY_TYPE:ident) => { - position!(arr, element, index, $ARRAY_TYPE) + let data = mutable.freeze(); + arrow_array::make_array(data) }; + new_values.push(repeated_array); } - let res = call_array_function!(arr.value_type(), true); - Ok(Arc::new(res)) + let new_values: Vec<_> = new_values.iter().map(|a| a.as_ref()).collect(); + let values = compute::concat(&new_values)?; + + Ok(Arc::new(ListArray::try_new( + Arc::new(Field::new("item", data_type.to_owned(), true)), + OffsetBuffer::from_lengths(count_vec), + values, + None, + )?)) } -macro_rules! positions { - ($ARRAY:expr, $ELEMENT:expr, $ARRAY_TYPE:ident) => {{ - let element = downcast_arg!($ELEMENT, $ARRAY_TYPE); - let mut offsets: Vec = vec![0]; - let mut values = - downcast_arg!(new_empty_array(&DataType::UInt64), UInt64Array).clone(); - for comp in $ARRAY - .iter() - .zip(element.iter()) - .map(|(arr, el)| match arr { - Some(arr) => { - let child_array = downcast_arg!(arr, $ARRAY_TYPE); - let res = child_array - .iter() - .enumerate() - .filter(|(_, x)| *x == el) - .flat_map(|(i, _)| Some((i + 1) as u64)) - .collect::(); +/// Handle List version of `general_repeat` +/// +/// For each element of `list_array[i]` repeat `count_array[i]` times. +/// +/// For example, +/// ```text +/// array_repeat( +/// [[1, 2, 3], [4, 5], [6]], [2, 0, 1] => [[[1, 2, 3], [1, 2, 3]], [], [[6]]] +/// ) +/// ``` +fn general_list_repeat( + list_array: &GenericListArray, + count_array: &Int64Array, +) -> Result { + let data_type = list_array.data_type(); + let value_type = list_array.value_type(); + let mut new_values = vec![]; - Ok(res) - } - None => Ok(downcast_arg!( - new_empty_array(&DataType::UInt64), - UInt64Array - ) - .clone()), - }) - .collect::>>()? 
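// A row-wise sketch of the repeat rule implemented by `general_repeat` above and
// `general_list_repeat` below (plain `Vec` stand-ins; in the real code a NULL
// element becomes a row of `count` NULLs and rows are copied via MutableArrayData):
fn repeat_row<T: Clone>(value: &T, count: usize) -> Vec<T> {
    vec![value.clone(); count]
}
// repeat_row(&1, 2)          == vec![1, 1]
// repeat_row(&3, 1)          == vec![3]
// repeat_row(&vec![1, 2], 2) == vec![vec![1, 2], vec![1, 2]]  (the nested-list case)
// a count of 0 yields an empty row, matching the [2, 0, 1] example above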
- { - let last_offset: i32 = offsets.last().copied().ok_or_else(|| { - DataFusionError::Internal(format!("offsets should not be empty",)) - })?; - values = - downcast_arg!(compute::concat(&[&values, &comp,])?.clone(), UInt64Array) - .clone(); - offsets.push(last_offset + comp.len() as i32); - } + let count_vec = count_array + .values() + .to_vec() + .iter() + .map(|x| *x as usize) + .collect::>(); - let field = Arc::new(Field::new("item", DataType::UInt64, true)); + for (list_array_row, &count) in list_array.iter().zip(count_vec.iter()) { + let list_arr = match list_array_row { + Some(list_array_row) => { + let original_data = list_array_row.to_data(); + let capacity = Capacities::Array(original_data.len() * count); + let mut mutable = MutableArrayData::with_capacities( + vec![&original_data], + false, + capacity, + ); - Arc::new(ListArray::try_new( - field, - OffsetBuffer::new(offsets.into()), - Arc::new(values), - None, - )?) - }}; -} + for _ in 0..count { + mutable.extend(0, 0, original_data.len()); + } -/// Array_positions SQL function -pub fn array_positions(args: &[ArrayRef]) -> Result { - let arr = as_list_array(&args[0])?; - let element = &args[1]; + let data = mutable.freeze(); + let repeated_array = arrow_array::make_array(data); - check_datatypes("array_positions", &[arr.values(), element])?; - macro_rules! array_function { - ($ARRAY_TYPE:ident) => { - positions!(arr, element, $ARRAY_TYPE) + let list_arr = GenericListArray::::try_new( + Arc::new(Field::new("item", value_type.clone(), true)), + OffsetBuffer::::from_lengths(vec![original_data.len(); count]), + repeated_array, + None, + )?; + Arc::new(list_arr) as ArrayRef + } + None => new_null_array(data_type, count), }; + new_values.push(list_arr); } - let res = call_array_function!(arr.value_type(), true); - Ok(res) + let lengths = new_values.iter().map(|a| a.len()).collect::>(); + let new_values: Vec<_> = new_values.iter().map(|a| a.as_ref()).collect(); + let values = compute::concat(&new_values)?; + + Ok(Arc::new(ListArray::try_new( + Arc::new(Field::new("item", data_type.to_owned(), true)), + OffsetBuffer::::from_lengths(lengths), + values, + None, + )?)) } -macro_rules! 
general_remove { - ($ARRAY:expr, $ELEMENT:expr, $MAX:expr, $ARRAY_TYPE:ident) => {{ - let mut offsets: Vec = vec![0]; - let mut values = - downcast_arg!(new_empty_array($ELEMENT.data_type()), $ARRAY_TYPE).clone(); - - let element = downcast_arg!($ELEMENT, $ARRAY_TYPE); - for ((arr, el), max) in $ARRAY.iter().zip(element.iter()).zip($MAX.iter()) { - let last_offset: i32 = offsets.last().copied().ok_or_else(|| { - DataFusionError::Internal(format!("offsets should not be empty")) - })?; - match arr { - Some(arr) => { - let child_array = downcast_arg!(arr, $ARRAY_TYPE); - let mut counter = 0; - let max = if max < Some(1) { 1 } else { max.unwrap() }; - - let filter_array = child_array - .iter() - .map(|element| { - if counter != max && element == el { - counter += 1; - Some(false) - } else { - Some(true) - } - }) - .collect::(); +/// Array_position SQL function +pub fn array_position(args: &[ArrayRef]) -> Result { + if args.len() < 2 || args.len() > 3 { + return exec_err!("array_position expects two or three arguments"); + } + match &args[0].data_type() { + DataType::List(_) => general_position_dispatch::(args), + DataType::LargeList(_) => general_position_dispatch::(args), + array_type => exec_err!("array_position does not support type '{array_type:?}'."), + } +} +fn general_position_dispatch(args: &[ArrayRef]) -> Result { + let list_array = as_generic_list_array::(&args[0])?; + let element_array = &args[1]; - let filtered_array = compute::filter(&child_array, &filter_array)?; - values = downcast_arg!( - compute::concat(&[&values, &filtered_array,])?.clone(), - $ARRAY_TYPE - ) - .clone(); - offsets.push(last_offset + filtered_array.len() as i32); - } - None => offsets.push(last_offset), + check_datatypes("array_position", &[list_array.values(), element_array])?; + + let arr_from = if args.len() == 3 { + as_int64_array(&args[2])? + .values() + .to_vec() + .iter() + .map(|&x| x - 1) + .collect::>() + } else { + vec![0; list_array.len()] + }; + + // if `start_from` index is out of bounds, return error + for (arr, &from) in list_array.iter().zip(arr_from.iter()) { + if let Some(arr) = arr { + if from < 0 || from as usize >= arr.len() { + return internal_err!("start_from index out of bounds"); } + } else { + // We will get null if we got null in the array, so we don't need to check } + } - let field = Arc::new(Field::new("item", $ELEMENT.data_type().clone(), true)); - - Arc::new(ListArray::try_new( - field, - OffsetBuffer::new(offsets.into()), - Arc::new(values), - None, - )?) - }}; + generic_position::(list_array, element_array, arr_from) } -macro_rules! array_removement_function { - ($FUNC:ident, $MAX_FUNC:expr, $DOC:expr) => { - #[doc = $DOC] - pub fn $FUNC(args: &[ArrayRef]) -> Result { - let arr = as_list_array(&args[0])?; - let element = &args[1]; - let max = $MAX_FUNC(args)?; +fn generic_position( + list_array: &GenericListArray, + element_array: &ArrayRef, + arr_from: Vec, // 0-indexed +) -> Result { + let mut data = Vec::with_capacity(list_array.len()); - check_datatypes(stringify!($FUNC), &[arr.values(), element])?; - macro_rules! 
array_function { - ($ARRAY_TYPE:ident) => { - general_remove!(arr, element, max, $ARRAY_TYPE) - }; - } - let res = call_array_function!(arr.value_type(), true); + for (row_index, (list_array_row, &from)) in + list_array.iter().zip(arr_from.iter()).enumerate() + { + let from = from as usize; + + if let Some(list_array_row) = list_array_row { + let eq_array = + compare_element_to_list(&list_array_row, element_array, row_index, true)?; + + // Collect `true`s in 1-indexed positions + let index = eq_array + .iter() + .skip(from) + .position(|e| e == Some(true)) + .map(|index| (from + index + 1) as u64); - Ok(res) + data.push(index); + } else { + data.push(None); } - }; -} + } -fn remove_one(args: &[ArrayRef]) -> Result { - Ok(Int64Array::from_value(1, args[0].len())) + Ok(Arc::new(UInt64Array::from(data))) } -fn remove_n(args: &[ArrayRef]) -> Result { - as_int64_array(&args[2]).cloned() -} +/// Array_positions SQL function +pub fn array_positions(args: &[ArrayRef]) -> Result { + if args.len() != 2 { + return exec_err!("array_positions expects two arguments"); + } + + let element = &args[1]; -fn remove_all(args: &[ArrayRef]) -> Result { - Ok(Int64Array::from_value(i64::MAX, args[0].len())) + match &args[0].data_type() { + DataType::List(_) => { + let arr = as_list_array(&args[0])?; + check_datatypes("array_positions", &[arr.values(), element])?; + general_positions::(arr, element) + } + DataType::LargeList(_) => { + let arr = as_large_list_array(&args[0])?; + check_datatypes("array_positions", &[arr.values(), element])?; + general_positions::(arr, element) + } + array_type => { + exec_err!("array_positions does not support type '{array_type:?}'.") + } + } } -// array removement functions -array_removement_function!(array_remove, remove_one, "Array_remove SQL function"); -array_removement_function!(array_remove_n, remove_n, "Array_remove_n SQL function"); -array_removement_function!( - array_remove_all, - remove_all, - "Array_remove_all SQL function" -); +fn general_positions( + list_array: &GenericListArray, + element_array: &ArrayRef, +) -> Result { + let mut data = Vec::with_capacity(list_array.len()); -fn general_replace(args: &[ArrayRef], arr_n: Vec) -> Result { - let list_array = as_list_array(&args[0])?; - let from_array = &args[1]; - let to_array = &args[2]; + for (row_index, list_array_row) in list_array.iter().enumerate() { + if let Some(list_array_row) = list_array_row { + let eq_array = + compare_element_to_list(&list_array_row, element_array, row_index, true)?; - let mut offsets: Vec = vec![0]; - let data_type = list_array.value_type(); - let mut values = new_empty_array(&data_type); - - for (row_index, (arr, n)) in list_array.iter().zip(arr_n.iter()).enumerate() { - let last_offset: i32 = offsets - .last() - .copied() - .ok_or_else(|| internal_datafusion_err!("offsets should not be empty"))?; - match arr { - Some(arr) => { - let indices = UInt32Array::from(vec![row_index as u32]); - let from_arr = arrow::compute::take(from_array, &indices, None)?; - - let eq_array = match from_arr.data_type() { - // arrow_ord::cmp_eq does not support ListArray, so we need to compare it by loop - DataType::List(_) => { - let from_a = as_list_array(&from_arr)?.value(0); - let list_arr = as_list_array(&arr)?; - - let mut bool_values = vec![]; - for arr in list_arr.iter() { - if let Some(a) = arr { - bool_values.push(Some(a.eq(&from_a))); - } else { - return internal_err!( - "Null value is not supported in array_replace" - ); - } - } - BooleanArray::from(bool_values) - } - _ => { - let from_arr = 
Scalar::new(from_arr); - arrow_ord::cmp::eq(&arr, &from_arr)? - } - }; + // Collect `true`s in 1-indexed positions + let indexes = eq_array + .iter() + .positions(|e| e == Some(true)) + .map(|index| Some(index as u64 + 1)) + .collect::>(); - // Use MutableArrayData to build the replaced array - // First array is the original array, second array is the element to replace with. - let arrays = vec![arr, to_array.clone()]; - let arrays_data = arrays - .iter() - .map(|a| a.to_data()) - .collect::>(); - let arrays_data = arrays_data.iter().collect::>(); + data.push(Some(indexes)); + } else { + data.push(None); + } + } - let arrays = arrays - .iter() - .map(|arr| arr.as_ref()) - .collect::>(); - let capacity = Capacities::Array(arrays.iter().map(|a| a.len()).sum()); - - let mut mutable = - MutableArrayData::with_capacities(arrays_data, false, capacity); - - let mut counter = 0; - for (i, to_replace) in eq_array.iter().enumerate() { - if let Some(to_replace) = to_replace { - if to_replace { - mutable.extend(1, row_index, row_index + 1); - counter += 1; - if counter == *n { - // extend the rest of the array - mutable.extend(0, i + 1, eq_array.len()); - break; - } - } else { - mutable.extend(0, i, i + 1); - } - } else { - return internal_err!("eq_array should not contain None"); - } - } + Ok(Arc::new( + ListArray::from_iter_primitive::(data), + )) +} - let data = mutable.freeze(); - let replaced_array = arrow_array::make_array(data); +/// For each element of `list_array[i]`, removed up to `arr_n[i]` occurences +/// of `element_array[i]`. +/// +/// The type of each **element** in `list_array` must be the same as the type of +/// `element_array`. This function also handles nested arrays +/// ([`ListArray`] of [`ListArray`]s) +/// +/// For example, when called to remove a list array (where each element is a +/// list of int32s, the second argument are int32 arrays, and the +/// third argument is the number of occurrences to remove +/// +/// ```text +/// general_remove( +/// [1, 2, 3, 2], 2, 1 ==> [1, 3, 2] (only the first 2 is removed) +/// [4, 5, 6, 5], 5, 2 ==> [4, 6] (both 5s are removed) +/// ) +/// ``` +fn general_remove( + list_array: &GenericListArray, + element_array: &ArrayRef, + arr_n: Vec, +) -> Result { + let data_type = list_array.value_type(); + let mut new_values = vec![]; + // Build up the offsets for the final output array + let mut offsets = Vec::::with_capacity(arr_n.len() + 1); + offsets.push(OffsetSize::zero()); + + // n is the number of elements to remove in this row + for (row_index, (list_array_row, n)) in + list_array.iter().zip(arr_n.iter()).enumerate() + { + match list_array_row { + Some(list_array_row) => { + let eq_array = compare_element_to_list( + &list_array_row, + element_array, + row_index, + false, + )?; + + // We need to keep at most first n elements as `false`, which represent the elements to remove. + let eq_array = if eq_array.false_count() < *n as usize { + eq_array + } else { + let mut count = 0; + eq_array + .iter() + .map(|e| { + // Keep first n `false` elements, and reverse other elements to `true`. 
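                // Equivalent mask logic sketched over plain booleans (the helper name is
                // illustrative): `true` in the input marks an element equal to the removal
                // target, `true` in the output means "keep", and only the first `n` matches
                // are allowed to stay removable.
                fn keep_mask(matches_target: &[bool], n: usize) -> Vec<bool> {
                    let mut removed = 0;
                    matches_target
                        .iter()
                        .map(|&is_match| {
                            if is_match && removed < n {
                                removed += 1;
                                false // drop this occurrence
                            } else {
                                true // keep everything else
                            }
                        })
                        .collect()
                }
                // [1, 2, 3, 2], target 2, n = 1: matches [F, T, F, T] -> keep [T, F, T, T] -> [1, 3, 2]
                // [4, 5, 6, 5], target 5, n = 2: matches [F, T, F, T] -> keep [T, F, T, F] -> [4, 6]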
+                        if let Some(false) = e {
+                            if count < *n {
+                                count += 1;
+                                e
+                            } else {
+                                Some(true)
+                            }
+                        } else {
+                            e
+                        }
+                    })
+                    .collect::<BooleanArray>()
+                };
-                    let filtered_array = compute::filter(&child_array, &filter_array)?;
-                    values = downcast_arg!(
-                        compute::concat(&[&values, &filtered_array,])?.clone(),
-                        $ARRAY_TYPE
-                    )
-                    .clone();
-                    offsets.push(last_offset + filtered_array.len() as i32);
-                }
-                None => offsets.push(last_offset),
+                let filtered_array = arrow::compute::filter(&list_array_row, &eq_array)?;
+                offsets.push(
+                    offsets[row_index] + OffsetSize::usize_as(filtered_array.len()),
+                );
+                new_values.push(filtered_array);
+            }
+            None => {
+                // Null element results in a null row (no new offsets)
+                offsets.push(offsets[row_index]);
+            }
+        }
+    }
-        let field = Arc::new(Field::new("item", $ELEMENT.data_type().clone(), true));
-
-        Arc::new(ListArray::try_new(
-            field,
-            OffsetBuffer::new(offsets.into()),
-            Arc::new(values),
-            None,
-        )?)
-    }};
+    let values = if new_values.is_empty() {
+        new_empty_array(&data_type)
+    } else {
+        let new_values = new_values.iter().map(|x| x.as_ref()).collect::<Vec<_>>();
+        arrow::compute::concat(&new_values)?
+    };
+
+    Ok(Arc::new(GenericListArray::<OffsetSize>::try_new(
+        Arc::new(Field::new("item", data_type, true)),
+        OffsetBuffer::new(offsets.into()),
+        values,
+        list_array.nulls().cloned(),
+    )?))
+}
+
+fn array_remove_internal(
+    array: &ArrayRef,
+    element_array: &ArrayRef,
+    arr_n: Vec<i64>,
+) -> Result<ArrayRef> {
+    match array.data_type() {
+        DataType::List(_) => {
+            let list_array = array.as_list::<i32>();
+            general_remove::<i32>(list_array, element_array, arr_n)
+        }
+        DataType::LargeList(_) => {
+            let list_array = array.as_list::<i64>();
+            general_remove::<i64>(list_array, element_array, arr_n)
+        }
+        _ => internal_err!("array_remove_all expects a list array"),
+    }
+}
+
+pub fn array_remove_all(args: &[ArrayRef]) -> Result<ArrayRef> {
+    if args.len() != 2 {
+        return exec_err!("array_remove_all expects two arguments");
+    }
+
+    let arr_n = vec![i64::MAX; args[0].len()];
+    array_remove_internal(&args[0], &args[1], arr_n)
+}
+
+pub fn array_remove(args: &[ArrayRef]) -> Result<ArrayRef> {
+    if args.len() != 2 {
+        return exec_err!("array_remove expects two arguments");
+    }
+
+    let arr_n = vec![1; args[0].len()];
+    array_remove_internal(&args[0], &args[1], arr_n)
+}
+
+pub fn array_remove_n(args: &[ArrayRef]) -> Result<ArrayRef> {
+    if args.len() != 3 {
+        return exec_err!("array_remove_n expects three arguments");
+    }
+
+    let arr_n = as_int64_array(&args[2])?.values().to_vec();
+    array_remove_internal(&args[0], &args[1], arr_n)
+}
+
+/// For each element of `list_array[i]`, replaces up to `arr_n[i]` occurrences
+/// of `from_array[i]` with `to_array[i]`.
+///
+/// The type of each **element** in `list_array` must be the same as the type of
+/// `from_array` and `to_array`. This function also handles nested arrays
+/// ([`ListArray`] of [`ListArray`]s).
+///
+/// For example, when called to replace a list array (where each element is a
+/// list of int32s), the second and third arguments are int32 arrays, and the
+/// fourth argument is the number of occurrences to replace:
+///
+/// ```text
+/// general_replace(
+///     [1, 2, 3, 2], 2, 10, 1 ==> [1, 10, 3, 2] (only the first 2 is replaced)
+///     [4, 5, 6, 5], 5, 20, 2 ==> [4, 20, 6, 20] (both 5s are replaced)
+/// )
+/// ```
+fn general_replace<O: OffsetSizeTrait>(
+    list_array: &GenericListArray<O>,
+    from_array: &ArrayRef,
+    to_array: &ArrayRef,
+    arr_n: Vec<i64>,
+) -> Result<ArrayRef> {
+    // Build up the offsets for the final output array
+    let mut offsets: Vec<O> = vec![O::usize_as(0)];
+    let values = list_array.values();
+    let original_data = values.to_data();
+    let to_data = to_array.to_data();
+    let capacity = Capacities::Array(original_data.len());
+
+    // First array is the original array, second array is the element to replace with.
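    // A row-wise sketch of the replacement rule documented above, with plain
    // `Vec<i64>` standing in for the Arrow arrays (the real code splices rows out
    // of MutableArrayData instead of copying scalar values):
    fn replace_n_row(values: &[i64], from: i64, to: i64, n: i64) -> Vec<i64> {
        let mut replaced = 0;
        values
            .iter()
            .map(|&v| {
                if v == from && replaced < n {
                    replaced += 1;
                    to // replace this occurrence
                } else {
                    v // keep the original value, including matches past the first n
                }
            })
            .collect()
    }
    // replace_n_row(&[1, 2, 3, 2], 2, 10, 1) == vec![1, 10, 3, 2]
    // replace_n_row(&[4, 5, 6, 5], 5, 20, 2) == vec![4, 20, 6, 20]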
+ let mut mutable = MutableArrayData::with_capacities( + vec![&original_data, &to_data], + false, + capacity, + ); + + let mut valid = BooleanBufferBuilder::new(list_array.len()); + + for (row_index, offset_window) in list_array.offsets().windows(2).enumerate() { + if list_array.is_null(row_index) { + offsets.push(offsets[row_index]); + valid.append(false); + continue; + } + + let start = offset_window[0]; + let end = offset_window[1]; + + let list_array_row = list_array.value(row_index); + + // Compute all positions in list_row_array (that is itself an + // array) that are equal to `from_array_row` + let eq_array = + compare_element_to_list(&list_array_row, &from_array, row_index, true)?; + + let original_idx = O::usize_as(0); + let replace_idx = O::usize_as(1); + let n = arr_n[row_index]; + let mut counter = 0; + + // All elements are false, no need to replace, just copy original data + if eq_array.false_count() == eq_array.len() { + mutable.extend( + original_idx.to_usize().unwrap(), + start.to_usize().unwrap(), + end.to_usize().unwrap(), + ); + offsets.push(offsets[row_index] + (end - start)); + valid.append(true); + continue; + } + + for (i, to_replace) in eq_array.iter().enumerate() { + let i = O::usize_as(i); + if let Some(true) = to_replace { + mutable.extend(replace_idx.to_usize().unwrap(), row_index, row_index + 1); + counter += 1; + if counter == n { + // copy original data for any matches past n + mutable.extend( + original_idx.to_usize().unwrap(), + (start + i).to_usize().unwrap() + 1, + end.to_usize().unwrap(), + ); + break; + } + } else { + // copy original data for false / null matches + mutable.extend( + original_idx.to_usize().unwrap(), + (start + i).to_usize().unwrap(), + (start + i).to_usize().unwrap() + 1, + ); + } + } + + offsets.push(offsets[row_index] + (end - start)); + valid.append(true); + } + + let data = mutable.freeze(); + + Ok(Arc::new(GenericListArray::::try_new( + Arc::new(Field::new("item", list_array.value_type(), true)), + OffsetBuffer::::new(offsets.into()), + arrow_array::make_array(data), + Some(NullBuffer::new(valid.finish())), )?)) } pub fn array_replace(args: &[ArrayRef]) -> Result { - general_replace(args, vec![1; args[0].len()]) + if args.len() != 3 { + return exec_err!("array_replace expects three arguments"); + } + + // replace at most one occurence for each element + let arr_n = vec![1; args[0].len()]; + let array = &args[0]; + match array.data_type() { + DataType::List(_) => { + let list_array = array.as_list::(); + general_replace::(list_array, &args[1], &args[2], arr_n) + } + DataType::LargeList(_) => { + let list_array = array.as_list::(); + general_replace::(list_array, &args[1], &args[2], arr_n) + } + array_type => exec_err!("array_replace does not support type '{array_type:?}'."), + } } pub fn array_replace_n(args: &[ArrayRef]) -> Result { - let arr = as_int64_array(&args[3])?; - let arr_n = arr.values().to_vec(); - general_replace(args, arr_n) + if args.len() != 4 { + return exec_err!("array_replace_n expects four arguments"); + } + + // replace the specified number of occurences + let arr_n = as_int64_array(&args[3])?.values().to_vec(); + let array = &args[0]; + match array.data_type() { + DataType::List(_) => { + let list_array = array.as_list::(); + general_replace::(list_array, &args[1], &args[2], arr_n) + } + DataType::LargeList(_) => { + let list_array = array.as_list::(); + general_replace::(list_array, &args[1], &args[2], arr_n) + } + array_type => { + exec_err!("array_replace_n does not support type '{array_type:?}'.") + } + } 
} pub fn array_replace_all(args: &[ArrayRef]) -> Result { - general_replace(args, vec![i64::MAX; args[0].len()]) + if args.len() != 3 { + return exec_err!("array_replace_all expects three arguments"); + } + + // replace all occurrences (up to "i64::MAX") + let arr_n = vec![i64::MAX; args[0].len()]; + let array = &args[0]; + match array.data_type() { + DataType::List(_) => { + let list_array = array.as_list::(); + general_replace::(list_array, &args[1], &args[2], arr_n) + } + DataType::LargeList(_) => { + let list_array = array.as_list::(); + general_replace::(list_array, &args[1], &args[2], arr_n) + } + array_type => { + exec_err!("array_replace_all does not support type '{array_type:?}'.") + } + } } macro_rules! to_string { @@ -1358,8 +1828,179 @@ macro_rules! to_string { }}; } +#[derive(Debug, PartialEq)] +enum SetOp { + Union, + Intersect, +} + +impl Display for SetOp { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + SetOp::Union => write!(f, "array_union"), + SetOp::Intersect => write!(f, "array_intersect"), + } + } +} + +fn generic_set_lists( + l: &GenericListArray, + r: &GenericListArray, + field: Arc, + set_op: SetOp, +) -> Result { + if matches!(l.value_type(), DataType::Null) { + let field = Arc::new(Field::new("item", r.value_type(), true)); + return general_array_distinct::(r, &field); + } else if matches!(r.value_type(), DataType::Null) { + let field = Arc::new(Field::new("item", l.value_type(), true)); + return general_array_distinct::(l, &field); + } + + if l.value_type() != r.value_type() { + return internal_err!("{set_op:?} is not implemented for '{l:?}' and '{r:?}'"); + } + + let dt = l.value_type(); + + let mut offsets = vec![OffsetSize::usize_as(0)]; + let mut new_arrays = vec![]; + + let converter = RowConverter::new(vec![SortField::new(dt)])?; + for (first_arr, second_arr) in l.iter().zip(r.iter()) { + if let (Some(first_arr), Some(second_arr)) = (first_arr, second_arr) { + let l_values = converter.convert_columns(&[first_arr])?; + let r_values = converter.convert_columns(&[second_arr])?; + + let l_iter = l_values.iter().sorted().dedup(); + let values_set: HashSet<_> = l_iter.clone().collect(); + let mut rows = if set_op == SetOp::Union { + l_iter.collect::>() + } else { + vec![] + }; + for r_val in r_values.iter().sorted().dedup() { + match set_op { + SetOp::Union => { + if !values_set.contains(&r_val) { + rows.push(r_val); + } + } + SetOp::Intersect => { + if values_set.contains(&r_val) { + rows.push(r_val); + } + } + } + } + + let last_offset = match offsets.last().copied() { + Some(offset) => offset, + None => return internal_err!("offsets should not be empty"), + }; + offsets.push(last_offset + OffsetSize::usize_as(rows.len())); + let arrays = converter.convert_rows(rows)?; + let array = match arrays.first() { + Some(array) => array.clone(), + None => { + return internal_err!("{set_op}: failed to get array from rows"); + } + }; + new_arrays.push(array); + } + } + + let offsets = OffsetBuffer::new(offsets.into()); + let new_arrays_ref = new_arrays.iter().map(|v| v.as_ref()).collect::>(); + let values = compute::concat(&new_arrays_ref)?; + let arr = GenericListArray::::try_new(field, offsets, values, None)?; + Ok(Arc::new(arr)) +} + +fn general_set_op( + array1: &ArrayRef, + array2: &ArrayRef, + set_op: SetOp, +) -> Result { + match (array1.data_type(), array2.data_type()) { + (DataType::Null, DataType::List(field)) => { + if set_op == SetOp::Intersect { + return Ok(new_empty_array(&DataType::Null)); + } + let array = 
as_list_array(&array2)?; + general_array_distinct::(array, field) + } + + (DataType::List(field), DataType::Null) => { + if set_op == SetOp::Intersect { + return make_array(&[]); + } + let array = as_list_array(&array1)?; + general_array_distinct::(array, field) + } + (DataType::Null, DataType::LargeList(field)) => { + if set_op == SetOp::Intersect { + return Ok(new_empty_array(&DataType::Null)); + } + let array = as_large_list_array(&array2)?; + general_array_distinct::(array, field) + } + (DataType::LargeList(field), DataType::Null) => { + if set_op == SetOp::Intersect { + return make_array(&[]); + } + let array = as_large_list_array(&array1)?; + general_array_distinct::(array, field) + } + (DataType::Null, DataType::Null) => Ok(new_empty_array(&DataType::Null)), + + (DataType::List(field), DataType::List(_)) => { + let array1 = as_list_array(&array1)?; + let array2 = as_list_array(&array2)?; + generic_set_lists::(array1, array2, field.clone(), set_op) + } + (DataType::LargeList(field), DataType::LargeList(_)) => { + let array1 = as_large_list_array(&array1)?; + let array2 = as_large_list_array(&array2)?; + generic_set_lists::(array1, array2, field.clone(), set_op) + } + (data_type1, data_type2) => { + internal_err!( + "{set_op} does not support types '{data_type1:?}' and '{data_type2:?}'" + ) + } + } +} + +/// Array_union SQL function +pub fn array_union(args: &[ArrayRef]) -> Result { + if args.len() != 2 { + return exec_err!("array_union needs two arguments"); + } + let array1 = &args[0]; + let array2 = &args[1]; + + general_set_op(array1, array2, SetOp::Union) +} + +/// array_intersect SQL function +pub fn array_intersect(args: &[ArrayRef]) -> Result { + if args.len() != 2 { + return exec_err!("array_intersect needs two arguments"); + } + + let array1 = &args[0]; + let array2 = &args[1]; + + general_set_op(array1, array2, SetOp::Intersect) +} + /// Array_to_string SQL function pub fn array_to_string(args: &[ArrayRef]) -> Result { + if args.len() < 2 || args.len() > 3 { + return exec_err!("array_to_string expects two or three arguments"); + } + let arr = &args[0]; let delimiters = as_string_array(&args[1])?; @@ -1469,6 +2110,10 @@ pub fn array_to_string(args: &[ArrayRef]) -> Result { /// Cardinality SQL function pub fn cardinality(args: &[ArrayRef]) -> Result { + if args.len() != 1 { + return exec_err!("cardinality expects one argument"); + } + let list_array = as_list_array(&args[0])?.clone(); let result = list_array @@ -1525,15 +2170,19 @@ fn flatten_internal( /// Flatten SQL function pub fn flatten(args: &[ArrayRef]) -> Result { + if args.len() != 1 { + return exec_err!("flatten expects one argument"); + } + let flattened_array = flatten_internal(&args[0], None)?; Ok(Arc::new(flattened_array) as ArrayRef) } -/// Array_length SQL function -pub fn array_length(args: &[ArrayRef]) -> Result { - let list_array = as_list_array(&args[0])?; - let dimension = if args.len() == 2 { - as_int64_array(&args[1])?.clone() +/// Dispatch array length computation based on the offset type. 
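Most functions in this file now share the same shape: a helper generic over `OffsetSizeTrait` does the work, and a thin wrapper matches on `DataType::List` vs `DataType::LargeList` to instantiate it with `i32` or `i64` offsets. A hedged sketch of that dispatch pattern (the `row_lengths` / `dispatch_row_lengths` names are invented for illustration, not part of the patch):

use std::sync::Arc;

use arrow::array::{Array, ArrayRef, GenericListArray, OffsetSizeTrait, UInt64Array};
use arrow::datatypes::DataType;
use datafusion_common::cast::as_generic_list_array;
use datafusion_common::{exec_err, Result};

/// Count the elements of each list row; `O` is i32 for List, i64 for LargeList.
fn row_lengths<O: OffsetSizeTrait>(list: &GenericListArray<O>) -> ArrayRef {
    let lengths: UInt64Array = list
        .iter()
        .map(|row| row.map(|values| values.len() as u64))
        .collect();
    Arc::new(lengths)
}

fn dispatch_row_lengths(array: &ArrayRef) -> Result<ArrayRef> {
    match array.data_type() {
        DataType::List(_) => Ok(row_lengths(as_generic_list_array::<i32>(array)?)),
        DataType::LargeList(_) => Ok(row_lengths(as_generic_list_array::<i64>(array)?)),
        other => exec_err!("unsupported type '{other:?}'"),
    }
}

Keeping the generic helper separate from the match keeps the LargeList support additive: only the two match arms differ between the offset widths.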
+fn array_length_dispatch(array: &[ArrayRef]) -> Result { + let list_array = as_generic_list_array::(&array[0])?; + let dimension = if array.len() == 2 { + as_int64_array(&array[1])?.clone() } else { Int64Array::from_value(1, list_array.len()) }; @@ -1547,14 +2196,51 @@ pub fn array_length(args: &[ArrayRef]) -> Result { Ok(Arc::new(result) as ArrayRef) } +/// Array_length SQL function +pub fn array_length(args: &[ArrayRef]) -> Result { + if args.len() != 1 && args.len() != 2 { + return exec_err!("array_length expects one or two arguments"); + } + + match &args[0].data_type() { + DataType::List(_) => array_length_dispatch::(args), + DataType::LargeList(_) => array_length_dispatch::(args), + _ => internal_err!( + "array_length does not support type '{:?}'", + args[0].data_type() + ), + } +} + /// Array_dims SQL function pub fn array_dims(args: &[ArrayRef]) -> Result { - let list_array = as_list_array(&args[0])?; + if args.len() != 1 { + return exec_err!("array_dims needs one argument"); + } + + let data = match args[0].data_type() { + DataType::List(_) => { + let array = as_list_array(&args[0])?; + array + .iter() + .map(compute_array_dims) + .collect::>>()? + } + DataType::LargeList(_) => { + let array = as_large_list_array(&args[0])?; + array + .iter() + .map(compute_array_dims) + .collect::>>()? + } + _ => { + return exec_err!( + "array_dims does not support type '{:?}'", + args[0].data_type() + ); + } + }; - let data = list_array - .iter() - .map(compute_array_dims) - .collect::>>()?; let result = ListArray::from_iter_primitive::(data); Ok(Arc::new(result) as ArrayRef) @@ -1562,171 +2248,165 @@ pub fn array_dims(args: &[ArrayRef]) -> Result { /// Array_ndims SQL function pub fn array_ndims(args: &[ArrayRef]) -> Result { - let list_array = as_list_array(&args[0])?; - - let result = list_array - .iter() - .map(compute_array_ndims) - .collect::>()?; - - Ok(Arc::new(result) as ArrayRef) -} + if args.len() != 1 { + return exec_err!("array_ndims needs one argument"); + } -macro_rules! 
non_list_contains { - ($ARRAY:expr, $SUB_ARRAY:expr, $ARRAY_TYPE:ident) => {{ - let sub_array = downcast_arg!($SUB_ARRAY, $ARRAY_TYPE); - let mut boolean_builder = BooleanArray::builder($ARRAY.len()); + fn general_list_ndims( + array: &GenericListArray, + ) -> Result { + let mut data = Vec::new(); + let ndims = datafusion_common::utils::list_ndims(array.data_type()); - for (arr, elem) in $ARRAY.iter().zip(sub_array.iter()) { - if let (Some(arr), Some(elem)) = (arr, elem) { - let arr = downcast_arg!(arr, $ARRAY_TYPE); - let res = arr.iter().dedup().flatten().any(|x| x == elem); - boolean_builder.append_value(res); + for arr in array.iter() { + if arr.is_some() { + data.push(Some(ndims)) + } else { + data.push(None) } } - Ok(Arc::new(boolean_builder.finish())) - }}; -} -/// Array_has SQL function -pub fn array_has(args: &[ArrayRef]) -> Result { - let array = as_list_array(&args[0])?; - let element = &args[1]; + Ok(Arc::new(UInt64Array::from(data)) as ArrayRef) + } - check_datatypes("array_has", &[array.values(), element])?; - match element.data_type() { + match args[0].data_type() { DataType::List(_) => { - let sub_array = as_list_array(element)?; - let mut boolean_builder = BooleanArray::builder(array.len()); - - for (arr, elem) in array.iter().zip(sub_array.iter()) { - if let (Some(arr), Some(elem)) = (arr, elem) { - let list_arr = as_list_array(&arr)?; - let res = list_arr.iter().dedup().flatten().any(|x| *x == *elem); - boolean_builder.append_value(res); - } - } - Ok(Arc::new(boolean_builder.finish())) + let array = as_list_array(&args[0])?; + general_list_ndims::(array) } - data_type => { - macro_rules! array_function { - ($ARRAY_TYPE:ident) => { - non_list_contains!(array, element, $ARRAY_TYPE) - }; - } - call_array_function!(data_type, false) + DataType::LargeList(_) => { + let array = as_large_list_array(&args[0])?; + general_list_ndims::(array) } + _ => Ok(Arc::new(UInt64Array::from(vec![0; args[0].len()])) as ArrayRef), } } -macro_rules! array_has_any_non_list_check { - ($ARRAY:expr, $SUB_ARRAY:expr, $ARRAY_TYPE:ident) => {{ - let arr = downcast_arg!($ARRAY, $ARRAY_TYPE); - let sub_arr = downcast_arg!($SUB_ARRAY, $ARRAY_TYPE); - - let mut res = false; - for elem in sub_arr.iter().dedup() { - if let Some(elem) = elem { - res |= arr.iter().dedup().flatten().any(|x| x == elem); - } else { - return internal_err!( - "array_has_any does not support Null type for element in sub_array" - ); - } - } - res - }}; +/// Represents the type of comparison for array_has. +#[derive(Debug, PartialEq)] +enum ComparisonType { + // array_has_all + All, + // array_has_any + Any, + // array_has + Single, } -/// Array_has_any SQL function -pub fn array_has_any(args: &[ArrayRef]) -> Result { - check_datatypes("array_has_any", &[&args[0], &args[1]])?; - - let array = as_list_array(&args[0])?; - let sub_array = as_list_array(&args[1])?; +fn general_array_has_dispatch( + array: &ArrayRef, + sub_array: &ArrayRef, + comparison_type: ComparisonType, +) -> Result { + let array = if comparison_type == ComparisonType::Single { + let arr = as_generic_list_array::(array)?; + check_datatypes("array_has", &[arr.values(), sub_array])?; + arr + } else { + check_datatypes("array_has", &[array, sub_array])?; + as_generic_list_array::(array)? 
+ }; let mut boolean_builder = BooleanArray::builder(array.len()); - for (arr, sub_arr) in array.iter().zip(sub_array.iter()) { + + let converter = RowConverter::new(vec![SortField::new(array.value_type())])?; + + let element = sub_array.clone(); + let sub_array = if comparison_type != ComparisonType::Single { + as_generic_list_array::(sub_array)? + } else { + array + }; + + for (row_idx, (arr, sub_arr)) in array.iter().zip(sub_array.iter()).enumerate() { if let (Some(arr), Some(sub_arr)) = (arr, sub_arr) { - let res = match arr.data_type() { - DataType::List(_) => { - let arr = downcast_arg!(arr, ListArray); - let sub_arr = downcast_arg!(sub_arr, ListArray); - - let mut res = false; - for elem in sub_arr.iter().dedup().flatten() { - res |= arr.iter().dedup().flatten().any(|x| *x == *elem); - } - res - } - data_type => { - macro_rules! array_function { - ($ARRAY_TYPE:ident) => { - array_has_any_non_list_check!(arr, sub_arr, $ARRAY_TYPE) - }; - } - call_array_function!(data_type, false) - } + let arr_values = converter.convert_columns(&[arr])?; + let sub_arr_values = if comparison_type != ComparisonType::Single { + converter.convert_columns(&[sub_arr])? + } else { + converter.convert_columns(&[element.clone()])? + }; + + let mut res = match comparison_type { + ComparisonType::All => sub_arr_values + .iter() + .dedup() + .all(|elem| arr_values.iter().dedup().any(|x| x == elem)), + ComparisonType::Any => sub_arr_values + .iter() + .dedup() + .any(|elem| arr_values.iter().dedup().any(|x| x == elem)), + ComparisonType::Single => arr_values + .iter() + .dedup() + .any(|x| x == sub_arr_values.row(row_idx)), }; + + if comparison_type == ComparisonType::Any { + res |= res; + } + boolean_builder.append_value(res); } } Ok(Arc::new(boolean_builder.finish())) } -macro_rules! 
array_has_all_non_list_check { - ($ARRAY:expr, $SUB_ARRAY:expr, $ARRAY_TYPE:ident) => {{ - let arr = downcast_arg!($ARRAY, $ARRAY_TYPE); - let sub_arr = downcast_arg!($SUB_ARRAY, $ARRAY_TYPE); +/// Array_has SQL function +pub fn array_has(args: &[ArrayRef]) -> Result { + if args.len() != 2 { + return exec_err!("array_has needs two arguments"); + } - let mut res = true; - for elem in sub_arr.iter().dedup() { - if let Some(elem) = elem { - res &= arr.iter().dedup().flatten().any(|x| x == elem); - } else { - return internal_err!( - "array_has_all does not support Null type for element in sub_array" - ); - } + let array_type = args[0].data_type(); + + match array_type { + DataType::List(_) => { + general_array_has_dispatch::(&args[0], &args[1], ComparisonType::Single) } - res - }}; + DataType::LargeList(_) => { + general_array_has_dispatch::(&args[0], &args[1], ComparisonType::Single) + } + _ => exec_err!("array_has does not support type '{array_type:?}'."), + } +} + +/// Array_has_any SQL function +pub fn array_has_any(args: &[ArrayRef]) -> Result { + if args.len() != 2 { + return exec_err!("array_has_any needs two arguments"); + } + + let array_type = args[0].data_type(); + + match array_type { + DataType::List(_) => { + general_array_has_dispatch::(&args[0], &args[1], ComparisonType::Any) + } + DataType::LargeList(_) => { + general_array_has_dispatch::(&args[0], &args[1], ComparisonType::Any) + } + _ => internal_err!("array_has_any does not support type '{array_type:?}'."), + } } /// Array_has_all SQL function pub fn array_has_all(args: &[ArrayRef]) -> Result { - check_datatypes("array_has_all", &[&args[0], &args[1]])?; + if args.len() != 2 { + return exec_err!("array_has_all needs two arguments"); + } - let array = as_list_array(&args[0])?; - let sub_array = as_list_array(&args[1])?; + let array_type = args[0].data_type(); - let mut boolean_builder = BooleanArray::builder(array.len()); - for (arr, sub_arr) in array.iter().zip(sub_array.iter()) { - if let (Some(arr), Some(sub_arr)) = (arr, sub_arr) { - let res = match arr.data_type() { - DataType::List(_) => { - let arr = downcast_arg!(arr, ListArray); - let sub_arr = downcast_arg!(sub_arr, ListArray); - - let mut res = true; - for elem in sub_arr.iter().dedup().flatten() { - res &= arr.iter().dedup().flatten().any(|x| *x == *elem); - } - res - } - data_type => { - macro_rules! 
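`array_has`, `array_has_any` and `array_has_all` (like the set operations earlier in the file) now compare elements through arrow's row format, so a single code path covers every element type instead of one macro per primitive type. A small standalone sketch of the idea, written as an `array_has_all`-style subset check (`contains_all` is an illustrative name, not the patch's helper):

use arrow::array::ArrayRef;
use arrow::datatypes::DataType;
use arrow::row::{RowConverter, SortField};
use datafusion_common::Result;
use itertools::Itertools;

/// Does `haystack` contain every element of `needles`?
/// Int64 is hard-coded here; the real code derives the SortField
/// from the list array's value type.
fn contains_all(haystack: ArrayRef, needles: ArrayRef) -> Result<bool> {
    let converter = RowConverter::new(vec![SortField::new(DataType::Int64)])?;
    let hay_rows = converter.convert_columns(&[haystack])?;
    let needle_rows = converter.convert_columns(&[needles])?;
    // Rows are byte-comparable, so plain equality works for any element type.
    Ok(needle_rows
        .iter()
        .dedup()
        .all(|needle| hay_rows.iter().any(|x| x == needle)))
}

// e.g. contains_all over Int64 columns [1, 2, 3, 4] and [2, 4] -> Ok(true)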
array_function { - ($ARRAY_TYPE:ident) => { - array_has_all_non_list_check!(arr, sub_arr, $ARRAY_TYPE) - }; - } - call_array_function!(data_type, false) - } - }; - boolean_builder.append_value(res); + match array_type { + DataType::List(_) => { + general_array_has_dispatch::(&args[0], &args[1], ComparisonType::All) + } + DataType::LargeList(_) => { + general_array_has_dispatch::(&args[0], &args[1], ComparisonType::All) } + _ => internal_err!("array_has_all does not support type '{array_type:?}'."), } - Ok(Arc::new(boolean_builder.finish())) } /// Splits string at occurrences of delimiter and returns an array of parts @@ -1818,12 +2498,74 @@ pub fn string_to_array(args: &[ArrayRef]) -> Result( + array: &GenericListArray, + field: &FieldRef, +) -> Result { + let dt = array.value_type(); + let mut offsets = Vec::with_capacity(array.len()); + offsets.push(OffsetSize::usize_as(0)); + let mut new_arrays = Vec::with_capacity(array.len()); + let converter = RowConverter::new(vec![SortField::new(dt)])?; + // distinct for each list in ListArray + for arr in array.iter().flatten() { + let values = converter.convert_columns(&[arr])?; + // sort elements in list and remove duplicates + let rows = values.iter().sorted().dedup().collect::>(); + let last_offset: OffsetSize = offsets.last().copied().unwrap(); + offsets.push(last_offset + OffsetSize::usize_as(rows.len())); + let arrays = converter.convert_rows(rows)?; + let array = match arrays.first() { + Some(array) => array.clone(), + None => { + return internal_err!("array_distinct: failed to get array from rows") + } + }; + new_arrays.push(array); + } + let offsets = OffsetBuffer::new(offsets.into()); + let new_arrays_ref = new_arrays.iter().map(|v| v.as_ref()).collect::>(); + let values = compute::concat(&new_arrays_ref)?; + Ok(Arc::new(GenericListArray::::try_new( + field.clone(), + offsets, + values, + None, + )?)) +} + +/// array_distinct SQL function +/// example: from list [1, 3, 2, 3, 1, 2, 4] to [1, 2, 3, 4] +pub fn array_distinct(args: &[ArrayRef]) -> Result { + if args.len() != 1 { + return exec_err!("array_distinct needs one argument"); + } + + // handle null + if args[0].data_type() == &DataType::Null { + return Ok(args[0].clone()); + } + + // handle for list & largelist + match args[0].data_type() { + DataType::List(field) => { + let array = as_list_array(&args[0])?; + general_array_distinct(array, field) + } + DataType::LargeList(field) => { + let array = as_large_list_array(&args[0])?; + general_array_distinct(array, field) + } + _ => internal_err!("array_distinct only support list array"), + } +} + #[cfg(test)] mod tests { use super::*; use arrow::datatypes::Int64Type; - use datafusion_common::cast::as_uint64_array; + /// Only test internal functions, array-related sql functions will be tested in sqllogictest `array.slt` #[test] fn test_align_array_dimensions() { let array1d_1 = @@ -1844,10 +2586,10 @@ mod tests { .unwrap(); let expected = as_list_array(&array2d_1).unwrap(); - let expected_dim = compute_array_ndims(Some(array2d_1.to_owned())).unwrap(); + let expected_dim = datafusion_common::utils::list_ndims(array2d_1.data_type()); assert_ne!(as_list_array(&res[0]).unwrap(), expected); assert_eq!( - compute_array_ndims(Some(res[0].clone())).unwrap(), + datafusion_common::utils::list_ndims(res[0].data_type()), expected_dim ); @@ -1857,1475 +2599,11 @@ mod tests { align_array_dimensions(vec![array1d_1, Arc::new(array3d_2.clone())]).unwrap(); let expected = as_list_array(&array3d_1).unwrap(); - let expected_dim = 
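The updated test derives nesting depth from the type alone via `datafusion_common::utils::list_ndims(data_type)` instead of walking array values. A tiny illustration of what that helper reports for a few assumed types (return values inferred from its usage in the test above):

use std::sync::Arc;

use arrow::datatypes::{DataType, Field};
use datafusion_common::utils::list_ndims;

fn ndims_examples() {
    let int = DataType::Int64;
    let list = DataType::List(Arc::new(Field::new("item", int.clone(), true)));
    let list_of_list = DataType::List(Arc::new(Field::new("item", list.clone(), true)));

    assert_eq!(list_ndims(&int), 0);          // not a list
    assert_eq!(list_ndims(&list), 1);         // e.g. [1, 2, 3]
    assert_eq!(list_ndims(&list_of_list), 2); // e.g. [[1, 2], [3]]
}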
compute_array_ndims(Some(array3d_1.to_owned())).unwrap(); + let expected_dim = datafusion_common::utils::list_ndims(array3d_1.data_type()); assert_ne!(as_list_array(&res[0]).unwrap(), expected); assert_eq!( - compute_array_ndims(Some(res[0].clone())).unwrap(), + datafusion_common::utils::list_ndims(res[0].data_type()), expected_dim ); } - - #[test] - fn test_array() { - // make_array(1, 2, 3) = [1, 2, 3] - let args = [ - Arc::new(Int64Array::from(vec![1])) as ArrayRef, - Arc::new(Int64Array::from(vec![2])), - Arc::new(Int64Array::from(vec![3])), - ]; - let array = make_array(&args).expect("failed to initialize function array"); - let result = as_list_array(&array).expect("failed to initialize function array"); - assert_eq!(result.len(), 1); - assert_eq!( - &[1, 2, 3], - as_int64_array(&result.value(0)) - .expect("failed to cast to primitive array") - .values() - ) - } - - #[test] - fn test_nested_array() { - // make_array([1, 3, 5], [2, 4, 6]) = [[1, 3, 5], [2, 4, 6]] - let args = [ - Arc::new(Int64Array::from(vec![1, 2])) as ArrayRef, - Arc::new(Int64Array::from(vec![3, 4])), - Arc::new(Int64Array::from(vec![5, 6])), - ]; - let array = make_array(&args).expect("failed to initialize function array"); - let result = as_list_array(&array).expect("failed to initialize function array"); - assert_eq!(result.len(), 2); - assert_eq!( - &[1, 3, 5], - as_int64_array(&result.value(0)) - .expect("failed to cast to primitive array") - .values() - ); - assert_eq!( - &[2, 4, 6], - as_int64_array(&result.value(1)) - .expect("failed to cast to primitive array") - .values() - ); - } - - #[test] - fn test_array_element() { - // array_element([1, 2, 3, 4], 1) = 1 - let list_array = return_array(); - let arr = array_element(&[list_array, Arc::new(Int64Array::from_value(1, 1))]) - .expect("failed to initialize function array_element"); - let result = - as_int64_array(&arr).expect("failed to initialize function array_element"); - - assert_eq!(result, &Int64Array::from_value(1, 1)); - - // array_element([1, 2, 3, 4], 3) = 3 - let list_array = return_array(); - let arr = array_element(&[list_array, Arc::new(Int64Array::from_value(3, 1))]) - .expect("failed to initialize function array_element"); - let result = - as_int64_array(&arr).expect("failed to initialize function array_element"); - - assert_eq!(result, &Int64Array::from_value(3, 1)); - - // array_element([1, 2, 3, 4], 0) = NULL - let list_array = return_array(); - let arr = array_element(&[list_array, Arc::new(Int64Array::from_value(0, 1))]) - .expect("failed to initialize function array_element"); - let result = - as_int64_array(&arr).expect("failed to initialize function array_element"); - - assert_eq!(result, &Int64Array::from(vec![None])); - - // array_element([1, 2, 3, 4], NULL) = NULL - let list_array = return_array(); - let arr = array_element(&[list_array, Arc::new(Int64Array::from(vec![None]))]) - .expect("failed to initialize function array_element"); - let result = - as_int64_array(&arr).expect("failed to initialize function array_element"); - - assert_eq!(result, &Int64Array::from(vec![None])); - - // array_element([1, 2, 3, 4], -1) = 4 - let list_array = return_array(); - let arr = array_element(&[list_array, Arc::new(Int64Array::from_value(-1, 1))]) - .expect("failed to initialize function array_element"); - let result = - as_int64_array(&arr).expect("failed to initialize function array_element"); - - assert_eq!(result, &Int64Array::from_value(4, 1)); - - // array_element([1, 2, 3, 4], -3) = 2 - let list_array = return_array(); - let arr = 
array_element(&[list_array, Arc::new(Int64Array::from_value(-3, 1))]) - .expect("failed to initialize function array_element"); - let result = - as_int64_array(&arr).expect("failed to initialize function array_element"); - - assert_eq!(result, &Int64Array::from_value(2, 1)); - - // array_element([1, 2, 3, 4], 10) = NULL - let list_array = return_array(); - let arr = array_element(&[list_array, Arc::new(Int64Array::from_value(10, 1))]) - .expect("failed to initialize function array_element"); - let result = - as_int64_array(&arr).expect("failed to initialize function array_element"); - - assert_eq!(result, &Int64Array::from(vec![None])); - } - - #[test] - fn test_nested_array_element() { - // array_element([[1, 2, 3, 4], [5, 6, 7, 8]], 2) = [5, 6, 7, 8] - let list_array = return_nested_array(); - let arr = array_element(&[list_array, Arc::new(Int64Array::from_value(2, 1))]) - .expect("failed to initialize function array_element"); - let result = - as_list_array(&arr).expect("failed to initialize function array_element"); - - assert_eq!( - &[5, 6, 7, 8], - result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .values() - ); - } - - #[test] - fn test_array_pop_back() { - // array_pop_back([1, 2, 3, 4]) = [1, 2, 3] - let list_array = return_array(); - let arr = array_pop_back(&[list_array]) - .expect("failed to initialize function array_pop_back"); - let result = - as_list_array(&arr).expect("failed to initialize function array_pop_back"); - assert_eq!( - &[1, 2, 3], - result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .values() - ); - - // array_pop_back([1, 2, 3]) = [1, 2] - let list_array = Arc::new(result.clone()); - let arr = array_pop_back(&[list_array]) - .expect("failed to initialize function array_pop_back"); - let result = - as_list_array(&arr).expect("failed to initialize function array_pop_back"); - assert_eq!( - &[1, 2], - result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .values() - ); - - // array_pop_back([1, 2]) = [1] - let list_array = Arc::new(result.clone()); - let arr = array_pop_back(&[list_array]) - .expect("failed to initialize function array_pop_back"); - let result = - as_list_array(&arr).expect("failed to initialize function array_pop_back"); - assert_eq!( - &[1], - result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .values() - ); - - // array_pop_back([1]) = [] - let list_array = Arc::new(result.clone()); - let arr = array_pop_back(&[list_array]) - .expect("failed to initialize function array_pop_back"); - let result = - as_list_array(&arr).expect("failed to initialize function array_pop_back"); - assert_eq!( - &[], - result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .values() - ); - // array_pop_back([]) = [] - let list_array = Arc::new(result.clone()); - let arr = array_pop_back(&[list_array]) - .expect("failed to initialize function array_pop_back"); - let result = - as_list_array(&arr).expect("failed to initialize function array_pop_back"); - assert_eq!( - &[], - result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .values() - ); - - // array_pop_back([1, NULL, 3, NULL]) = [1, NULL, 3] - let list_array = return_array_with_nulls(); - let arr = array_pop_back(&[list_array]) - .expect("failed to initialize function array_pop_back"); - let result = - as_list_array(&arr).expect("failed to initialize function array_pop_back"); - assert_eq!(3, result.values().len()); - assert_eq!( - &[false, true, false], - &[ - result.values().is_null(0), - result.values().is_null(1), - 
result.values().is_null(2) - ] - ); - } - #[test] - fn test_nested_array_pop_back() { - // array_pop_back([[1, 2, 3, 4], [5, 6, 7, 8]]) = [[1, 2, 3, 4]] - let list_array = return_nested_array(); - let arr = array_pop_back(&[list_array]) - .expect("failed to initialize function array_slice"); - let result = - as_list_array(&arr).expect("failed to initialize function array_slice"); - assert_eq!( - &[1, 2, 3, 4], - result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .values() - ); - - // array_pop_back([[1, 2, 3, 4]]) = [] - let list_array = Arc::new(result.clone()); - let arr = array_pop_back(&[list_array]) - .expect("failed to initialize function array_pop_back"); - let result = - as_list_array(&arr).expect("failed to initialize function array_pop_back"); - assert!(result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .is_empty()); - // array_pop_back([]) = [] - let list_array = Arc::new(result.clone()); - let arr = array_pop_back(&[list_array]) - .expect("failed to initialize function array_pop_back"); - let result = - as_list_array(&arr).expect("failed to initialize function array_pop_back"); - assert!(result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .is_empty()); - } - - #[test] - fn test_array_slice() { - // array_slice([1, 2, 3, 4], 1, 3) = [1, 2, 3] - let list_array = return_array(); - let arr = array_slice(&[ - list_array, - Arc::new(Int64Array::from_value(1, 1)), - Arc::new(Int64Array::from_value(3, 1)), - ]) - .expect("failed to initialize function array_slice"); - let result = - as_list_array(&arr).expect("failed to initialize function array_slice"); - - assert_eq!( - &[1, 2, 3], - result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .values() - ); - - // array_slice([1, 2, 3, 4], 2, 2) = [2] - let list_array = return_array(); - let arr = array_slice(&[ - list_array, - Arc::new(Int64Array::from_value(2, 1)), - Arc::new(Int64Array::from_value(2, 1)), - ]) - .expect("failed to initialize function array_slice"); - let result = - as_list_array(&arr).expect("failed to initialize function array_slice"); - - assert_eq!( - &[2], - result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .values() - ); - - // array_slice([1, 2, 3, 4], 0, 0) = [] - let list_array = return_array(); - let arr = array_slice(&[ - list_array, - Arc::new(Int64Array::from_value(0, 1)), - Arc::new(Int64Array::from_value(0, 1)), - ]) - .expect("failed to initialize function array_slice"); - let result = - as_list_array(&arr).expect("failed to initialize function array_slice"); - - assert!(result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .is_empty()); - - // array_slice([1, 2, 3, 4], 0, 6) = [1, 2, 3, 4] - let list_array = return_array(); - let arr = array_slice(&[ - list_array, - Arc::new(Int64Array::from_value(0, 1)), - Arc::new(Int64Array::from_value(6, 1)), - ]) - .expect("failed to initialize function array_slice"); - let result = - as_list_array(&arr).expect("failed to initialize function array_slice"); - - assert_eq!( - &[1, 2, 3, 4], - result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .values() - ); - - // array_slice([1, 2, 3, 4], -2, -2) = [] - let list_array = return_array(); - let arr = array_slice(&[ - list_array, - Arc::new(Int64Array::from_value(-2, 1)), - Arc::new(Int64Array::from_value(-2, 1)), - ]) - .expect("failed to initialize function array_slice"); - let result = - as_list_array(&arr).expect("failed to initialize function array_slice"); - - 
assert!(result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .is_empty()); - - // array_slice([1, 2, 3, 4], -3, -1) = [2, 3] - let list_array = return_array(); - let arr = array_slice(&[ - list_array, - Arc::new(Int64Array::from_value(-3, 1)), - Arc::new(Int64Array::from_value(-1, 1)), - ]) - .expect("failed to initialize function array_slice"); - let result = - as_list_array(&arr).expect("failed to initialize function array_slice"); - - assert_eq!( - &[2, 3], - result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .values() - ); - - // array_slice([1, 2, 3, 4], -3, 2) = [2] - let list_array = return_array(); - let arr = array_slice(&[ - list_array, - Arc::new(Int64Array::from_value(-3, 1)), - Arc::new(Int64Array::from_value(2, 1)), - ]) - .expect("failed to initialize function array_slice"); - let result = - as_list_array(&arr).expect("failed to initialize function array_slice"); - - assert_eq!( - &[2], - result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .values() - ); - - // array_slice([1, 2, 3, 4], 2, 11) = [2, 3, 4] - let list_array = return_array(); - let arr = array_slice(&[ - list_array, - Arc::new(Int64Array::from_value(2, 1)), - Arc::new(Int64Array::from_value(11, 1)), - ]) - .expect("failed to initialize function array_slice"); - let result = - as_list_array(&arr).expect("failed to initialize function array_slice"); - - assert_eq!( - &[2, 3, 4], - result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .values() - ); - - // array_slice([1, 2, 3, 4], 3, 1) = [] - let list_array = return_array(); - let arr = array_slice(&[ - list_array, - Arc::new(Int64Array::from_value(3, 1)), - Arc::new(Int64Array::from_value(1, 1)), - ]) - .expect("failed to initialize function array_slice"); - let result = - as_list_array(&arr).expect("failed to initialize function array_slice"); - - assert!(result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .is_empty()); - - // array_slice([1, 2, 3, 4], -7, -2) = NULL - let list_array = return_array(); - let arr = array_slice(&[ - list_array, - Arc::new(Int64Array::from_value(-7, 1)), - Arc::new(Int64Array::from_value(-2, 1)), - ]) - .expect("failed to initialize function array_slice"); - let result = - as_list_array(&arr).expect("failed to initialize function array_slice"); - - assert!(result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .is_null(0)); - } - - #[test] - fn test_nested_array_slice() { - // array_slice([[1, 2, 3, 4], [5, 6, 7, 8]], 1, 1) = [[1, 2, 3, 4]] - let list_array = return_nested_array(); - let arr = array_slice(&[ - list_array, - Arc::new(Int64Array::from_value(1, 1)), - Arc::new(Int64Array::from_value(1, 1)), - ]) - .expect("failed to initialize function array_slice"); - let result = - as_list_array(&arr).expect("failed to initialize function array_slice"); - - assert_eq!( - &[1, 2, 3, 4], - result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .values() - ); - - // array_slice([[1, 2, 3, 4], [5, 6, 7, 8]], -1, -1) = [] - let list_array = return_nested_array(); - let arr = array_slice(&[ - list_array, - Arc::new(Int64Array::from_value(-1, 1)), - Arc::new(Int64Array::from_value(-1, 1)), - ]) - .expect("failed to initialize function array_slice"); - let result = - as_list_array(&arr).expect("failed to initialize function array_slice"); - - assert!(result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .is_empty()); - - // array_slice([[1, 2, 3, 4], [5, 6, 7, 8]], -1, 2) = [[5, 6, 7, 8]] - 
let list_array = return_nested_array(); - let arr = array_slice(&[ - list_array, - Arc::new(Int64Array::from_value(-1, 1)), - Arc::new(Int64Array::from_value(2, 1)), - ]) - .expect("failed to initialize function array_slice"); - let result = - as_list_array(&arr).expect("failed to initialize function array_slice"); - - assert_eq!( - &[5, 6, 7, 8], - result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .values() - ); - } - - #[test] - fn test_array_append() { - // array_append([1, 2, 3], 4) = [1, 2, 3, 4] - let data = vec![Some(vec![Some(1), Some(2), Some(3)])]; - let list_array = - Arc::new(ListArray::from_iter_primitive::(data)) as ArrayRef; - let int64_array = Arc::new(Int64Array::from(vec![Some(4)])) as ArrayRef; - - let args = [list_array, int64_array]; - - let array = - array_append(&args).expect("failed to initialize function array_append"); - let result = - as_list_array(&array).expect("failed to initialize function array_append"); - - assert_eq!( - &[1, 2, 3, 4], - result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .values() - ); - } - - #[test] - fn test_array_prepend() { - // array_prepend(1, [2, 3, 4]) = [1, 2, 3, 4] - let data = vec![Some(vec![Some(2), Some(3), Some(4)])]; - let list_array = - Arc::new(ListArray::from_iter_primitive::(data)) as ArrayRef; - let int64_array = Arc::new(Int64Array::from(vec![Some(1)])) as ArrayRef; - - let args = [int64_array, list_array]; - - let array = - array_prepend(&args).expect("failed to initialize function array_append"); - let result = - as_list_array(&array).expect("failed to initialize function array_append"); - - assert_eq!( - &[1, 2, 3, 4], - result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .values() - ); - } - - #[test] - fn test_array_concat() { - // array_concat([1, 2, 3], [4, 5, 6], [7, 8, 9]) = [1, 2, 3, 4, 5, 6, 7, 8, 9] - let data = vec![Some(vec![Some(1), Some(2), Some(3)])]; - let list_array1 = - Arc::new(ListArray::from_iter_primitive::(data)) as ArrayRef; - let data = vec![Some(vec![Some(4), Some(5), Some(6)])]; - let list_array2 = - Arc::new(ListArray::from_iter_primitive::(data)) as ArrayRef; - let data = vec![Some(vec![Some(7), Some(8), Some(9)])]; - let list_array3 = - Arc::new(ListArray::from_iter_primitive::(data)) as ArrayRef; - - let args = [list_array1, list_array2, list_array3]; - - let array = - array_concat(&args).expect("failed to initialize function array_concat"); - let result = - as_list_array(&array).expect("failed to initialize function array_concat"); - - assert_eq!( - &[1, 2, 3, 4, 5, 6, 7, 8, 9], - result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .values() - ); - } - - #[test] - fn test_nested_array_concat() { - // array_concat([1, 2, 3, 4], [1, 2, 3, 4]) = [1, 2, 3, 4, 1, 2, 3, 4] - let list_array = return_array(); - let arr = array_concat(&[list_array.clone(), list_array.clone()]) - .expect("failed to initialize function array_concat"); - let result = - as_list_array(&arr).expect("failed to initialize function array_concat"); - - assert_eq!( - &[1, 2, 3, 4, 1, 2, 3, 4], - result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .values() - ); - - // array_concat([[1, 2, 3, 4], [5, 6, 7, 8]], [1, 2, 3, 4]) = [[1, 2, 3, 4], [5, 6, 7, 8], [1, 2, 3, 4]] - let list_nested_array = return_nested_array(); - let list_array = return_array(); - let arr = array_concat(&[list_nested_array, list_array]) - .expect("failed to initialize function array_concat"); - let result = - 
as_list_array(&arr).expect("failed to initialize function array_concat"); - - assert_eq!( - &[1, 2, 3, 4], - result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .value(2) - .as_any() - .downcast_ref::() - .unwrap() - .values() - ); - } - - #[test] - fn test_array_position() { - // array_position([1, 2, 3, 4], 3) = 3 - let list_array = return_array(); - let array = array_position(&[list_array, Arc::new(Int64Array::from_value(3, 1))]) - .expect("failed to initialize function array_position"); - let result = as_uint64_array(&array) - .expect("failed to initialize function array_position"); - - assert_eq!(result, &UInt64Array::from(vec![3])); - } - - #[test] - fn test_array_positions() { - // array_positions([1, 2, 3, 4], 3) = [3] - let list_array = return_array(); - let array = - array_positions(&[list_array, Arc::new(Int64Array::from_value(3, 1))]) - .expect("failed to initialize function array_position"); - let result = - as_list_array(&array).expect("failed to initialize function array_position"); - - assert_eq!(result.len(), 1); - assert_eq!( - &[3], - result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .values() - ); - } - - #[test] - fn test_array_remove() { - // array_remove([3, 1, 2, 3, 2, 3], 3) = [1, 2, 3, 2, 3] - let list_array = return_array_with_repeating_elements(); - let array = array_remove(&[list_array, Arc::new(Int64Array::from_value(3, 1))]) - .expect("failed to initialize function array_remove"); - let result = - as_list_array(&array).expect("failed to initialize function array_remove"); - - assert_eq!(result.len(), 1); - assert_eq!( - &[1, 2, 3, 2, 3], - result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .values() - ); - } - - #[test] - fn test_nested_array_remove() { - // array_remove( - // [[1, 2, 3, 4], [5, 6, 7, 8], [1, 2, 3, 4], [9, 10, 11, 12], [5, 6, 7, 8]], - // [1, 2, 3, 4], - // ) = [[5, 6, 7, 8], [1, 2, 3, 4], [9, 10, 11, 12], [5, 6, 7, 8]] - let list_array = return_nested_array_with_repeating_elements(); - let element_array = return_array(); - let array = array_remove(&[list_array, element_array]) - .expect("failed to initialize function array_remove"); - let result = - as_list_array(&array).expect("failed to initialize function array_remove"); - - assert_eq!(result.len(), 1); - let data = vec![ - Some(vec![Some(5), Some(6), Some(7), Some(8)]), - Some(vec![Some(1), Some(2), Some(3), Some(4)]), - Some(vec![Some(9), Some(10), Some(11), Some(12)]), - Some(vec![Some(5), Some(6), Some(7), Some(8)]), - ]; - let expected = ListArray::from_iter_primitive::(data); - assert_eq!( - expected, - result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .clone() - ); - } - - #[test] - fn test_array_remove_n() { - // array_remove_n([3, 1, 2, 3, 2, 3], 3, 2) = [1, 2, 2, 3] - let list_array = return_array_with_repeating_elements(); - let array = array_remove_n(&[ - list_array, - Arc::new(Int64Array::from_value(3, 1)), - Arc::new(Int64Array::from_value(2, 1)), - ]) - .expect("failed to initialize function array_remove_n"); - let result = - as_list_array(&array).expect("failed to initialize function array_remove_n"); - - assert_eq!(result.len(), 1); - assert_eq!( - &[1, 2, 2, 3], - result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .values() - ); - } - - #[test] - fn test_nested_array_remove_n() { - // array_remove_n( - // [[1, 2, 3, 4], [5, 6, 7, 8], [1, 2, 3, 4], [9, 10, 11, 12], [5, 6, 7, 8]], - // [1, 2, 3, 4], - // 3, - // ) = [[5, 6, 7, 8], [9, 10, 11, 12], [5, 6, 7, 8]] - let list_array = 
return_nested_array_with_repeating_elements(); - let element_array = return_array(); - let array = array_remove_n(&[ - list_array, - element_array, - Arc::new(Int64Array::from_value(3, 1)), - ]) - .expect("failed to initialize function array_remove_n"); - let result = - as_list_array(&array).expect("failed to initialize function array_remove_n"); - - assert_eq!(result.len(), 1); - let data = vec![ - Some(vec![Some(5), Some(6), Some(7), Some(8)]), - Some(vec![Some(9), Some(10), Some(11), Some(12)]), - Some(vec![Some(5), Some(6), Some(7), Some(8)]), - ]; - let expected = ListArray::from_iter_primitive::(data); - assert_eq!( - expected, - result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .clone() - ); - } - - #[test] - fn test_array_remove_all() { - // array_remove_all([3, 1, 2, 3, 2, 3], 3) = [1, 2, 2] - let list_array = return_array_with_repeating_elements(); - let array = - array_remove_all(&[list_array, Arc::new(Int64Array::from_value(3, 1))]) - .expect("failed to initialize function array_remove_all"); - let result = as_list_array(&array) - .expect("failed to initialize function array_remove_all"); - - assert_eq!(result.len(), 1); - assert_eq!( - &[1, 2, 2], - result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .values() - ); - } - - #[test] - fn test_nested_array_remove_all() { - // array_remove_all( - // [[1, 2, 3, 4], [5, 6, 7, 8], [1, 2, 3, 4], [9, 10, 11, 12], [5, 6, 7, 8]], - // [1, 2, 3, 4], - // ) = [[5, 6, 7, 8], [9, 10, 11, 12], [5, 6, 7, 8]] - let list_array = return_nested_array_with_repeating_elements(); - let element_array = return_array(); - let array = array_remove_all(&[list_array, element_array]) - .expect("failed to initialize function array_remove_all"); - let result = as_list_array(&array) - .expect("failed to initialize function array_remove_all"); - - assert_eq!(result.len(), 1); - let data = vec![ - Some(vec![Some(5), Some(6), Some(7), Some(8)]), - Some(vec![Some(9), Some(10), Some(11), Some(12)]), - Some(vec![Some(5), Some(6), Some(7), Some(8)]), - ]; - let expected = ListArray::from_iter_primitive::(data); - assert_eq!( - expected, - result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .clone() - ); - } - - #[test] - fn test_array_replace() { - // array_replace([3, 1, 2, 3, 2, 3], 3, 4) = [4, 1, 2, 3, 2, 3] - let list_array = return_array_with_repeating_elements(); - let array = array_replace(&[ - list_array, - Arc::new(Int64Array::from_value(3, 1)), - Arc::new(Int64Array::from_value(4, 1)), - ]) - .expect("failed to initialize function array_replace"); - let result = - as_list_array(&array).expect("failed to initialize function array_replace"); - - assert_eq!(result.len(), 1); - assert_eq!( - &[4, 1, 2, 3, 2, 3], - result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .values() - ); - } - - #[test] - fn test_nested_array_replace() { - // array_replace( - // [[1, 2, 3, 4], [5, 6, 7, 8], [1, 2, 3, 4], [9, 10, 11, 12], [5, 6, 7, 8]], - // [1, 2, 3, 4], - // [11, 12, 13, 14], - // ) = [[11, 12, 13, 14], [5, 6, 7, 8], [1, 2, 3, 4], [9, 10, 11, 12], [5, 6, 7, 8]] - let list_array = return_nested_array_with_repeating_elements(); - let from_array = return_array(); - let to_array = return_extra_array(); - let array = array_replace(&[list_array, from_array, to_array]) - .expect("failed to initialize function array_replace"); - let result = - as_list_array(&array).expect("failed to initialize function array_replace"); - - assert_eq!(result.len(), 1); - let data = vec![ - Some(vec![Some(11), Some(12), Some(13), Some(14)]), 
- Some(vec![Some(5), Some(6), Some(7), Some(8)]), - Some(vec![Some(1), Some(2), Some(3), Some(4)]), - Some(vec![Some(9), Some(10), Some(11), Some(12)]), - Some(vec![Some(5), Some(6), Some(7), Some(8)]), - ]; - let expected = ListArray::from_iter_primitive::(data); - assert_eq!( - expected, - result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .clone() - ); - } - - #[test] - fn test_array_replace_n() { - // array_replace_n([3, 1, 2, 3, 2, 3], 3, 4, 2) = [4, 1, 2, 4, 2, 3] - let list_array = return_array_with_repeating_elements(); - let array = array_replace_n(&[ - list_array, - Arc::new(Int64Array::from_value(3, 1)), - Arc::new(Int64Array::from_value(4, 1)), - Arc::new(Int64Array::from_value(2, 1)), - ]) - .expect("failed to initialize function array_replace_n"); - let result = - as_list_array(&array).expect("failed to initialize function array_replace_n"); - - assert_eq!(result.len(), 1); - assert_eq!( - &[4, 1, 2, 4, 2, 3], - result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .values() - ); - } - - #[test] - fn test_nested_array_replace_n() { - // array_replace_n( - // [[1, 2, 3, 4], [5, 6, 7, 8], [1, 2, 3, 4], [9, 10, 11, 12], [5, 6, 7, 8]], - // [1, 2, 3, 4], - // [11, 12, 13, 14], - // 2, - // ) = [[11, 12, 13, 14], [5, 6, 7, 8], [11, 12, 13, 14], [9, 10, 11, 12], [5, 6, 7, 8]] - let list_array = return_nested_array_with_repeating_elements(); - let from_array = return_array(); - let to_array = return_extra_array(); - let array = array_replace_n(&[ - list_array, - from_array, - to_array, - Arc::new(Int64Array::from_value(2, 1)), - ]) - .expect("failed to initialize function array_replace_n"); - let result = - as_list_array(&array).expect("failed to initialize function array_replace_n"); - - assert_eq!(result.len(), 1); - let data = vec![ - Some(vec![Some(11), Some(12), Some(13), Some(14)]), - Some(vec![Some(5), Some(6), Some(7), Some(8)]), - Some(vec![Some(11), Some(12), Some(13), Some(14)]), - Some(vec![Some(9), Some(10), Some(11), Some(12)]), - Some(vec![Some(5), Some(6), Some(7), Some(8)]), - ]; - let expected = ListArray::from_iter_primitive::(data); - assert_eq!( - expected, - result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .clone() - ); - } - - #[test] - fn test_array_replace_all() { - // array_replace_all([3, 1, 2, 3, 2, 3], 3, 4) = [4, 1, 2, 4, 2, 4] - let list_array = return_array_with_repeating_elements(); - let array = array_replace_all(&[ - list_array, - Arc::new(Int64Array::from_value(3, 1)), - Arc::new(Int64Array::from_value(4, 1)), - ]) - .expect("failed to initialize function array_replace_all"); - let result = as_list_array(&array) - .expect("failed to initialize function array_replace_all"); - - assert_eq!(result.len(), 1); - assert_eq!( - &[4, 1, 2, 4, 2, 4], - result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .values() - ); - } - - #[test] - fn test_nested_array_replace_all() { - // array_replace_all( - // [[1, 2, 3, 4], [5, 6, 7, 8], [1, 2, 3, 4], [9, 10, 11, 12], [5, 6, 7, 8]], - // [1, 2, 3, 4], - // [11, 12, 13, 14], - // ) = [[11, 12, 13, 14], [5, 6, 7, 8], [11, 12, 13, 14], [9, 10, 11, 12], [5, 6, 7, 8]] - let list_array = return_nested_array_with_repeating_elements(); - let from_array = return_array(); - let to_array = return_extra_array(); - let array = array_replace_all(&[list_array, from_array, to_array]) - .expect("failed to initialize function array_replace_all"); - let result = as_list_array(&array) - .expect("failed to initialize function array_replace_all"); - - assert_eq!(result.len(), 1); - 
let data = vec![ - Some(vec![Some(11), Some(12), Some(13), Some(14)]), - Some(vec![Some(5), Some(6), Some(7), Some(8)]), - Some(vec![Some(11), Some(12), Some(13), Some(14)]), - Some(vec![Some(9), Some(10), Some(11), Some(12)]), - Some(vec![Some(5), Some(6), Some(7), Some(8)]), - ]; - let expected = ListArray::from_iter_primitive::(data); - assert_eq!( - expected, - result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .clone() - ); - } - - #[test] - fn test_array_repeat() { - // array_repeat(3, 5) = [3, 3, 3, 3, 3] - let array = array_repeat(&[ - Arc::new(Int64Array::from_value(3, 1)), - Arc::new(Int64Array::from_value(5, 1)), - ]) - .expect("failed to initialize function array_repeat"); - let result = - as_list_array(&array).expect("failed to initialize function array_repeat"); - - assert_eq!(result.len(), 1); - assert_eq!( - &[3, 3, 3, 3, 3], - result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .values() - ); - } - - #[test] - fn test_nested_array_repeat() { - // array_repeat([1, 2, 3, 4], 3) = [[1, 2, 3, 4], [1, 2, 3, 4], [1, 2, 3, 4]] - let element = return_array(); - let array = array_repeat(&[element, Arc::new(Int64Array::from_value(3, 1))]) - .expect("failed to initialize function array_repeat"); - let result = - as_list_array(&array).expect("failed to initialize function array_repeat"); - - assert_eq!(result.len(), 1); - let data = vec![ - Some(vec![Some(1), Some(2), Some(3), Some(4)]), - Some(vec![Some(1), Some(2), Some(3), Some(4)]), - Some(vec![Some(1), Some(2), Some(3), Some(4)]), - ]; - let expected = ListArray::from_iter_primitive::(data); - assert_eq!( - expected, - result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .clone() - ); - } - #[test] - fn test_array_to_string() { - // array_to_string([1, 2, 3, 4], ',') = 1,2,3,4 - let list_array = return_array(); - let array = - array_to_string(&[list_array, Arc::new(StringArray::from(vec![Some(",")]))]) - .expect("failed to initialize function array_to_string"); - let result = as_string_array(&array) - .expect("failed to initialize function array_to_string"); - - assert_eq!(result.len(), 1); - assert_eq!("1,2,3,4", result.value(0)); - - // array_to_string([1, NULL, 3, NULL], ',', '*') = 1,*,3,* - let list_array = return_array_with_nulls(); - let array = array_to_string(&[ - list_array, - Arc::new(StringArray::from(vec![Some(",")])), - Arc::new(StringArray::from(vec![Some("*")])), - ]) - .expect("failed to initialize function array_to_string"); - let result = as_string_array(&array) - .expect("failed to initialize function array_to_string"); - - assert_eq!(result.len(), 1); - assert_eq!("1,*,3,*", result.value(0)); - } - - #[test] - fn test_nested_array_to_string() { - // array_to_string([[1, 2, 3, 4], [5, 6, 7, 8]], '-') = 1-2-3-4-5-6-7-8 - let list_array = return_nested_array(); - let array = - array_to_string(&[list_array, Arc::new(StringArray::from(vec![Some("-")]))]) - .expect("failed to initialize function array_to_string"); - let result = as_string_array(&array) - .expect("failed to initialize function array_to_string"); - - assert_eq!(result.len(), 1); - assert_eq!("1-2-3-4-5-6-7-8", result.value(0)); - - // array_to_string([[1, NULL, 3, NULL], [NULL, 6, 7, NULL]], '-', '*') = 1-*-3-*-*-6-7-* - let list_array = return_nested_array_with_nulls(); - let array = array_to_string(&[ - list_array, - Arc::new(StringArray::from(vec![Some("-")])), - Arc::new(StringArray::from(vec![Some("*")])), - ]) - .expect("failed to initialize function array_to_string"); - let result = 
as_string_array(&array) - .expect("failed to initialize function array_to_string"); - - assert_eq!(result.len(), 1); - assert_eq!("1-*-3-*-*-6-7-*", result.value(0)); - } - - #[test] - fn test_cardinality() { - // cardinality([1, 2, 3, 4]) = 4 - let list_array = return_array(); - let arr = cardinality(&[list_array]) - .expect("failed to initialize function cardinality"); - let result = - as_uint64_array(&arr).expect("failed to initialize function cardinality"); - - assert_eq!(result, &UInt64Array::from(vec![4])); - } - - #[test] - fn test_nested_cardinality() { - // cardinality([[1, 2, 3, 4], [5, 6, 7, 8]]) = 8 - let list_array = return_nested_array(); - let arr = cardinality(&[list_array]) - .expect("failed to initialize function cardinality"); - let result = - as_uint64_array(&arr).expect("failed to initialize function cardinality"); - - assert_eq!(result, &UInt64Array::from(vec![8])); - } - - #[test] - fn test_array_length() { - // array_length([1, 2, 3, 4]) = 4 - let list_array = return_array(); - let arr = array_length(&[list_array.clone()]) - .expect("failed to initialize function array_ndims"); - let result = - as_uint64_array(&arr).expect("failed to initialize function array_ndims"); - - assert_eq!(result, &UInt64Array::from_value(4, 1)); - - // array_length([1, 2, 3, 4], 1) = 4 - let array = array_length(&[list_array, Arc::new(Int64Array::from_value(1, 1))]) - .expect("failed to initialize function array_ndims"); - let result = - as_uint64_array(&array).expect("failed to initialize function array_ndims"); - - assert_eq!(result, &UInt64Array::from_value(4, 1)); - } - - #[test] - fn test_nested_array_length() { - let list_array = return_nested_array(); - - // array_length([[1, 2, 3, 4], [5, 6, 7, 8]]) = 2 - let arr = array_length(&[list_array.clone()]) - .expect("failed to initialize function array_length"); - let result = - as_uint64_array(&arr).expect("failed to initialize function array_length"); - - assert_eq!(result, &UInt64Array::from_value(2, 1)); - - // array_length([[1, 2, 3, 4], [5, 6, 7, 8]], 1) = 2 - let arr = - array_length(&[list_array.clone(), Arc::new(Int64Array::from_value(1, 1))]) - .expect("failed to initialize function array_length"); - let result = - as_uint64_array(&arr).expect("failed to initialize function array_length"); - - assert_eq!(result, &UInt64Array::from_value(2, 1)); - - // array_length([[1, 2, 3, 4], [5, 6, 7, 8]], 2) = 4 - let arr = - array_length(&[list_array.clone(), Arc::new(Int64Array::from_value(2, 1))]) - .expect("failed to initialize function array_length"); - let result = - as_uint64_array(&arr).expect("failed to initialize function array_length"); - - assert_eq!(result, &UInt64Array::from_value(4, 1)); - - // array_length([[1, 2, 3, 4], [5, 6, 7, 8]], 3) = NULL - let arr = array_length(&[list_array, Arc::new(Int64Array::from_value(3, 1))]) - .expect("failed to initialize function array_length"); - let result = - as_uint64_array(&arr).expect("failed to initialize function array_length"); - - assert_eq!(result, &UInt64Array::from(vec![None])); - } - - #[test] - fn test_array_dims() { - // array_dims([1, 2, 3, 4]) = [4] - let list_array = return_array(); - - let array = - array_dims(&[list_array]).expect("failed to initialize function array_dims"); - let result = - as_list_array(&array).expect("failed to initialize function array_dims"); - - assert_eq!( - &[4], - result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .values() - ); - } - - #[test] - fn test_nested_array_dims() { - // array_dims([[1, 2, 3, 4], [5, 6, 7, 8]]) = [2, 
4] - let list_array = return_nested_array(); - - let array = - array_dims(&[list_array]).expect("failed to initialize function array_dims"); - let result = - as_list_array(&array).expect("failed to initialize function array_dims"); - - assert_eq!( - &[2, 4], - result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .values() - ); - } - - #[test] - fn test_array_ndims() { - // array_ndims([1, 2, 3, 4]) = 1 - let list_array = return_array(); - - let array = array_ndims(&[list_array]) - .expect("failed to initialize function array_ndims"); - let result = - as_uint64_array(&array).expect("failed to initialize function array_ndims"); - - assert_eq!(result, &UInt64Array::from_value(1, 1)); - } - - #[test] - fn test_nested_array_ndims() { - // array_ndims([[1, 2, 3, 4], [5, 6, 7, 8]]) = 2 - let list_array = return_nested_array(); - - let array = array_ndims(&[list_array]) - .expect("failed to initialize function array_ndims"); - let result = - as_uint64_array(&array).expect("failed to initialize function array_ndims"); - - assert_eq!(result, &UInt64Array::from_value(2, 1)); - } - - #[test] - fn test_check_invalid_datatypes() { - let data = vec![Some(vec![Some(1), Some(2), Some(3)])]; - let list_array = - Arc::new(ListArray::from_iter_primitive::(data)) as ArrayRef; - let int64_array = Arc::new(StringArray::from(vec![Some("string")])) as ArrayRef; - - let args = [list_array.clone(), int64_array.clone()]; - - let array = array_append(&args); - - assert_eq!(array.unwrap_err().strip_backtrace(), "Error during planning: array_append received incompatible types: '[Int64, Utf8]'."); - } - - fn return_array() -> ArrayRef { - // Returns: [1, 2, 3, 4] - let args = [ - Arc::new(Int64Array::from(vec![Some(1)])) as ArrayRef, - Arc::new(Int64Array::from(vec![Some(2)])) as ArrayRef, - Arc::new(Int64Array::from(vec![Some(3)])) as ArrayRef, - Arc::new(Int64Array::from(vec![Some(4)])) as ArrayRef, - ]; - make_array(&args).expect("failed to initialize function array") - } - - fn return_extra_array() -> ArrayRef { - // Returns: [11, 12, 13, 14] - let args = [ - Arc::new(Int64Array::from(vec![Some(11)])) as ArrayRef, - Arc::new(Int64Array::from(vec![Some(12)])) as ArrayRef, - Arc::new(Int64Array::from(vec![Some(13)])) as ArrayRef, - Arc::new(Int64Array::from(vec![Some(14)])) as ArrayRef, - ]; - make_array(&args).expect("failed to initialize function array") - } - - fn return_nested_array() -> ArrayRef { - // Returns: [[1, 2, 3, 4], [5, 6, 7, 8]] - let args = [ - Arc::new(Int64Array::from(vec![Some(1)])) as ArrayRef, - Arc::new(Int64Array::from(vec![Some(2)])) as ArrayRef, - Arc::new(Int64Array::from(vec![Some(3)])) as ArrayRef, - Arc::new(Int64Array::from(vec![Some(4)])) as ArrayRef, - ]; - let arr1 = make_array(&args).expect("failed to initialize function array"); - - let args = [ - Arc::new(Int64Array::from(vec![Some(5)])) as ArrayRef, - Arc::new(Int64Array::from(vec![Some(6)])) as ArrayRef, - Arc::new(Int64Array::from(vec![Some(7)])) as ArrayRef, - Arc::new(Int64Array::from(vec![Some(8)])) as ArrayRef, - ]; - let arr2 = make_array(&args).expect("failed to initialize function array"); - - make_array(&[arr1, arr2]).expect("failed to initialize function array") - } - - fn return_array_with_nulls() -> ArrayRef { - // Returns: [1, NULL, 3, NULL] - let args = [ - Arc::new(Int64Array::from(vec![Some(1)])) as ArrayRef, - Arc::new(Int64Array::from(vec![None])) as ArrayRef, - Arc::new(Int64Array::from(vec![Some(3)])) as ArrayRef, - Arc::new(Int64Array::from(vec![None])) as ArrayRef, - ]; - 
make_array(&args).expect("failed to initialize function array") - } - - fn return_nested_array_with_nulls() -> ArrayRef { - // Returns: [[1, NULL, 3, NULL], [NULL, 6, 7, NULL]] - let args = [ - Arc::new(Int64Array::from(vec![Some(1)])) as ArrayRef, - Arc::new(Int64Array::from(vec![None])) as ArrayRef, - Arc::new(Int64Array::from(vec![Some(3)])) as ArrayRef, - Arc::new(Int64Array::from(vec![None])) as ArrayRef, - ]; - let arr1 = make_array(&args).expect("failed to initialize function array"); - - let args = [ - Arc::new(Int64Array::from(vec![None])) as ArrayRef, - Arc::new(Int64Array::from(vec![Some(6)])) as ArrayRef, - Arc::new(Int64Array::from(vec![Some(7)])) as ArrayRef, - Arc::new(Int64Array::from(vec![None])) as ArrayRef, - ]; - let arr2 = make_array(&args).expect("failed to initialize function array"); - - make_array(&[arr1, arr2]).expect("failed to initialize function array") - } - - fn return_array_with_repeating_elements() -> ArrayRef { - // Returns: [3, 1, 2, 3, 2, 3] - let args = [ - Arc::new(Int64Array::from(vec![Some(3)])) as ArrayRef, - Arc::new(Int64Array::from(vec![Some(1)])) as ArrayRef, - Arc::new(Int64Array::from(vec![Some(2)])) as ArrayRef, - Arc::new(Int64Array::from(vec![Some(3)])) as ArrayRef, - Arc::new(Int64Array::from(vec![Some(2)])) as ArrayRef, - Arc::new(Int64Array::from(vec![Some(3)])) as ArrayRef, - ]; - make_array(&args).expect("failed to initialize function array") - } - - fn return_nested_array_with_repeating_elements() -> ArrayRef { - // Returns: [[1, 2, 3, 4], [5, 6, 7, 8], [1, 2, 3, 4], [9, 10, 11, 12], [5, 6, 7, 8]] - let args = [ - Arc::new(Int64Array::from(vec![Some(1)])) as ArrayRef, - Arc::new(Int64Array::from(vec![Some(2)])) as ArrayRef, - Arc::new(Int64Array::from(vec![Some(3)])) as ArrayRef, - Arc::new(Int64Array::from(vec![Some(4)])) as ArrayRef, - ]; - let arr1 = make_array(&args).expect("failed to initialize function array"); - - let args = [ - Arc::new(Int64Array::from(vec![Some(5)])) as ArrayRef, - Arc::new(Int64Array::from(vec![Some(6)])) as ArrayRef, - Arc::new(Int64Array::from(vec![Some(7)])) as ArrayRef, - Arc::new(Int64Array::from(vec![Some(8)])) as ArrayRef, - ]; - let arr2 = make_array(&args).expect("failed to initialize function array"); - - let args = [ - Arc::new(Int64Array::from(vec![Some(1)])) as ArrayRef, - Arc::new(Int64Array::from(vec![Some(2)])) as ArrayRef, - Arc::new(Int64Array::from(vec![Some(3)])) as ArrayRef, - Arc::new(Int64Array::from(vec![Some(4)])) as ArrayRef, - ]; - let arr3 = make_array(&args).expect("failed to initialize function array"); - - let args = [ - Arc::new(Int64Array::from(vec![Some(9)])) as ArrayRef, - Arc::new(Int64Array::from(vec![Some(10)])) as ArrayRef, - Arc::new(Int64Array::from(vec![Some(11)])) as ArrayRef, - Arc::new(Int64Array::from(vec![Some(12)])) as ArrayRef, - ]; - let arr4 = make_array(&args).expect("failed to initialize function array"); - - let args = [ - Arc::new(Int64Array::from(vec![Some(5)])) as ArrayRef, - Arc::new(Int64Array::from(vec![Some(6)])) as ArrayRef, - Arc::new(Int64Array::from(vec![Some(7)])) as ArrayRef, - Arc::new(Int64Array::from(vec![Some(8)])) as ArrayRef, - ]; - let arr5 = make_array(&args).expect("failed to initialize function array"); - - make_array(&[arr1, arr2, arr3, arr4, arr5]) - .expect("failed to initialize function array") - } } diff --git a/datafusion/physical-expr/src/conditional_expressions.rs b/datafusion/physical-expr/src/conditional_expressions.rs index 37adb2d71ce83..a9a25ffe2ec18 100644 --- a/datafusion/physical-expr/src/conditional_expressions.rs 
+++ b/datafusion/physical-expr/src/conditional_expressions.rs @@ -54,7 +54,7 @@ pub fn coalesce(args: &[ColumnarValue]) -> Result { if value.is_null() { continue; } else { - let last_value = value.to_array_of_size(size); + let last_value = value.to_array_of_size(size)?; current_value = zip(&remainder, &last_value, current_value.as_ref())?; break; diff --git a/datafusion/physical-expr/src/datetime_expressions.rs b/datafusion/physical-expr/src/datetime_expressions.rs index bb8720cb8d00a..589bbc8a952bd 100644 --- a/datafusion/physical-expr/src/datetime_expressions.rs +++ b/datafusion/physical-expr/src/datetime_expressions.rs @@ -17,14 +17,9 @@ //! DateTime expressions -use arrow::array::Float64Builder; +use crate::datetime_expressions; +use crate::expressions::cast_column; use arrow::compute::cast; -use arrow::{ - array::TimestampNanosecondArray, - compute::kernels::temporal, - datatypes::TimeUnit, - temporal_conversions::{as_datetime_with_timezone, timestamp_ns_to_datetime}, -}; use arrow::{ array::{Array, ArrayRef, Float64Array, OffsetSizeTrait, PrimitiveArray}, compute::kernels::cast_utils::string_to_timestamp_nanos, @@ -34,14 +29,18 @@ use arrow::{ TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, }, }; -use arrow_array::{ - timezone::Tz, TimestampMicrosecondArray, TimestampMillisecondArray, - TimestampSecondArray, +use arrow::{ + compute::kernels::temporal, + datatypes::TimeUnit, + temporal_conversions::{as_datetime_with_timezone, timestamp_ns_to_datetime}, }; +use arrow_array::temporal_conversions::NANOSECONDS; +use arrow_array::timezone::Tz; +use arrow_array::types::ArrowTimestampType; use chrono::prelude::*; use chrono::{Duration, Months, NaiveDate}; use datafusion_common::cast::{ - as_date32_array, as_date64_array, as_generic_string_array, + as_date32_array, as_date64_array, as_generic_string_array, as_primitive_array, as_timestamp_microsecond_array, as_timestamp_millisecond_array, as_timestamp_nanosecond_array, as_timestamp_second_array, }; @@ -128,6 +127,10 @@ fn string_to_timestamp_nanos_shim(s: &str) -> Result { } /// to_timestamp SQL function +/// +/// Note: `to_timestamp` returns `Timestamp(Nanosecond)` though its arguments are interpreted as **seconds**. The supported range for integer input is between `-9223372037` and `9223372036`. +/// Supported range for string input is between `1677-09-21T00:12:44.0` and `2262-04-11T23:47:16.0`. +/// Please use `to_timestamp_seconds` for the input outside of supported bounds. pub fn to_timestamp(args: &[ColumnarValue]) -> Result { handle::( args, @@ -329,7 +332,7 @@ fn date_trunc_coarse(granularity: &str, value: i64, tz: Option) -> Result, tz: Option, @@ -397,123 +400,61 @@ pub fn date_trunc(args: &[ColumnarValue]) -> Result { return exec_err!("Granularity of `date_trunc` must be non-null scalar Utf8"); }; + fn process_array( + array: &dyn Array, + granularity: String, + tz_opt: &Option>, + ) -> Result { + let parsed_tz = parse_tz(tz_opt)?; + let array = as_primitive_array::(array)?; + let array = array + .iter() + .map(|x| general_date_trunc(T::UNIT, &x, parsed_tz, granularity.as_str())) + .collect::>>()? 
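The documentation note added to `to_timestamp` above says integer arguments are read as seconds while the result is a nanosecond timestamp, which is where the `9223372036` upper bound comes from. A minimal, dependency-free sketch of that overflow check (the function name is mine, not part of the patch):

```rust
/// Illustrative only: `to_timestamp` treats an i64 as *seconds* but stores the
/// result as *nanoseconds* in an i64, so the value must survive `* 1_000_000_000`.
fn seconds_to_nanos(seconds: i64) -> Option<i64> {
    seconds.checked_mul(1_000_000_000)
}

fn main() {
    // i64::MAX / 1_000_000_000 == 9_223_372_036, matching the documented upper bound.
    assert_eq!(i64::MAX / 1_000_000_000, 9_223_372_036);
    assert!(seconds_to_nanos(9_223_372_036).is_some());
    // One second more no longer fits into nanoseconds.
    assert!(seconds_to_nanos(9_223_372_037).is_none());
}
```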
+ .with_timezone_opt(tz_opt.clone()); + Ok(ColumnarValue::Array(Arc::new(array))) + } + + fn process_scalar( + v: &Option, + granularity: String, + tz_opt: &Option>, + ) -> Result { + let parsed_tz = parse_tz(tz_opt)?; + let value = general_date_trunc(T::UNIT, v, parsed_tz, granularity.as_str())?; + let value = ScalarValue::new_timestamp::(value, tz_opt.clone()); + Ok(ColumnarValue::Scalar(value)) + } + Ok(match array { ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(v, tz_opt)) => { - let parsed_tz = parse_tz(tz_opt)?; - let value = - _date_trunc(TimeUnit::Nanosecond, v, parsed_tz, granularity.as_str())?; - let value = ScalarValue::TimestampNanosecond(value, tz_opt.clone()); - ColumnarValue::Scalar(value) + process_scalar::(v, granularity, tz_opt)? } ColumnarValue::Scalar(ScalarValue::TimestampMicrosecond(v, tz_opt)) => { - let parsed_tz = parse_tz(tz_opt)?; - let value = - _date_trunc(TimeUnit::Microsecond, v, parsed_tz, granularity.as_str())?; - let value = ScalarValue::TimestampMicrosecond(value, tz_opt.clone()); - ColumnarValue::Scalar(value) + process_scalar::(v, granularity, tz_opt)? } ColumnarValue::Scalar(ScalarValue::TimestampMillisecond(v, tz_opt)) => { - let parsed_tz = parse_tz(tz_opt)?; - let value = - _date_trunc(TimeUnit::Millisecond, v, parsed_tz, granularity.as_str())?; - let value = ScalarValue::TimestampMillisecond(value, tz_opt.clone()); - ColumnarValue::Scalar(value) + process_scalar::(v, granularity, tz_opt)? } ColumnarValue::Scalar(ScalarValue::TimestampSecond(v, tz_opt)) => { - let parsed_tz = parse_tz(tz_opt)?; - let value = - _date_trunc(TimeUnit::Second, v, parsed_tz, granularity.as_str())?; - let value = ScalarValue::TimestampSecond(value, tz_opt.clone()); - ColumnarValue::Scalar(value) + process_scalar::(v, granularity, tz_opt)? } ColumnarValue::Array(array) => { let array_type = array.data_type(); match array_type { DataType::Timestamp(TimeUnit::Second, tz_opt) => { - let parsed_tz = parse_tz(tz_opt)?; - let array = as_timestamp_second_array(array)?; - let array = array - .iter() - .map(|x| { - _date_trunc( - TimeUnit::Second, - &x, - parsed_tz, - granularity.as_str(), - ) - }) - .collect::>()? - .with_timezone_opt(tz_opt.clone()); - ColumnarValue::Array(Arc::new(array)) + process_array::(array, granularity, tz_opt)? } DataType::Timestamp(TimeUnit::Millisecond, tz_opt) => { - let parsed_tz = parse_tz(tz_opt)?; - let array = as_timestamp_millisecond_array(array)?; - let array = array - .iter() - .map(|x| { - _date_trunc( - TimeUnit::Millisecond, - &x, - parsed_tz, - granularity.as_str(), - ) - }) - .collect::>()? - .with_timezone_opt(tz_opt.clone()); - ColumnarValue::Array(Arc::new(array)) + process_array::(array, granularity, tz_opt)? } DataType::Timestamp(TimeUnit::Microsecond, tz_opt) => { - let parsed_tz = parse_tz(tz_opt)?; - let array = as_timestamp_microsecond_array(array)?; - let array = array - .iter() - .map(|x| { - _date_trunc( - TimeUnit::Microsecond, - &x, - parsed_tz, - granularity.as_str(), - ) - }) - .collect::>()? - .with_timezone_opt(tz_opt.clone()); - ColumnarValue::Array(Arc::new(array)) + process_array::(array, granularity, tz_opt)? } DataType::Timestamp(TimeUnit::Nanosecond, tz_opt) => { - let parsed_tz = parse_tz(tz_opt)?; - let array = as_timestamp_nanosecond_array(array)?; - let array = array - .iter() - .map(|x| { - _date_trunc( - TimeUnit::Nanosecond, - &x, - parsed_tz, - granularity.as_str(), - ) - }) - .collect::>()? 
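The `date_trunc` hunk above folds four nearly identical per-unit branches into the generic `process_array`/`process_scalar` helpers. Below is a rough sketch of the same "one generic helper instead of four copies" shape, generic over a hand-rolled unit trait rather than Arrow's `ArrowTimestampType`; all names and the day-truncation rule are illustrative, not the patch's code:

```rust
// Stand-ins for Arrow's timestamp types; illustrative only.
trait TimestampUnit {
    const NANOS_PER_UNIT: i64;
}
struct Second;
struct Nanosecond;
impl TimestampUnit for Second {
    const NANOS_PER_UNIT: i64 = 1_000_000_000;
}
impl TimestampUnit for Nanosecond {
    const NANOS_PER_UNIT: i64 = 1;
}

/// One generic routine that truncates every value to whole days,
/// whatever the unit, instead of one hand-written copy per unit.
fn truncate_to_day<T: TimestampUnit>(values: &[Option<i64>]) -> Vec<Option<i64>> {
    let units_per_day = 86_400 * 1_000_000_000 / T::NANOS_PER_UNIT;
    values
        .iter()
        .map(|v| v.map(|x| x - x.rem_euclid(units_per_day)))
        .collect()
}

fn main() {
    // 1.5 days expressed in seconds truncates to exactly 1 day.
    assert_eq!(truncate_to_day::<Second>(&[Some(129_600)]), vec![Some(86_400)]);
    // Nulls pass through untouched, as in the Arrow kernels.
    assert_eq!(truncate_to_day::<Nanosecond>(&[None]), vec![None]);
}
```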
- .with_timezone_opt(tz_opt.clone()); - ColumnarValue::Array(Arc::new(array)) - } - _ => { - let parsed_tz = None; - let array = as_timestamp_nanosecond_array(array)?; - let array = array - .iter() - .map(|x| { - _date_trunc( - TimeUnit::Nanosecond, - &x, - parsed_tz, - granularity.as_str(), - ) - }) - .collect::>()?; - - ColumnarValue::Array(Arc::new(array)) + process_array::(array, granularity, tz_opt)? } + _ => process_array::(array, granularity, &None)?, } } _ => { @@ -702,89 +643,104 @@ fn date_bin_impl( return exec_err!("DATE_BIN stride must be non-zero"); } - let f_nanos = |x: Option| x.map(|x| stride_fn(stride, x, origin)); - let f_micros = |x: Option| { - let scale = 1_000; - x.map(|x| stride_fn(stride, x * scale, origin) / scale) - }; - let f_millis = |x: Option| { - let scale = 1_000_000; - x.map(|x| stride_fn(stride, x * scale, origin) / scale) - }; - let f_secs = |x: Option| { - let scale = 1_000_000_000; - x.map(|x| stride_fn(stride, x * scale, origin) / scale) - }; + fn stride_map_fn( + origin: i64, + stride: i64, + stride_fn: fn(i64, i64, i64) -> i64, + ) -> impl Fn(Option) -> Option { + let scale = match T::UNIT { + TimeUnit::Nanosecond => 1, + TimeUnit::Microsecond => NANOSECONDS / 1_000_000, + TimeUnit::Millisecond => NANOSECONDS / 1_000, + TimeUnit::Second => NANOSECONDS, + }; + move |x: Option| x.map(|x| stride_fn(stride, x * scale, origin) / scale) + } Ok(match array { ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(v, tz_opt)) => { + let apply_stride_fn = + stride_map_fn::(origin, stride, stride_fn); ColumnarValue::Scalar(ScalarValue::TimestampNanosecond( - f_nanos(*v), + apply_stride_fn(*v), tz_opt.clone(), )) } ColumnarValue::Scalar(ScalarValue::TimestampMicrosecond(v, tz_opt)) => { + let apply_stride_fn = + stride_map_fn::(origin, stride, stride_fn); ColumnarValue::Scalar(ScalarValue::TimestampMicrosecond( - f_micros(*v), + apply_stride_fn(*v), tz_opt.clone(), )) } ColumnarValue::Scalar(ScalarValue::TimestampMillisecond(v, tz_opt)) => { + let apply_stride_fn = + stride_map_fn::(origin, stride, stride_fn); ColumnarValue::Scalar(ScalarValue::TimestampMillisecond( - f_millis(*v), + apply_stride_fn(*v), tz_opt.clone(), )) } ColumnarValue::Scalar(ScalarValue::TimestampSecond(v, tz_opt)) => { + let apply_stride_fn = + stride_map_fn::(origin, stride, stride_fn); ColumnarValue::Scalar(ScalarValue::TimestampSecond( - f_secs(*v), + apply_stride_fn(*v), tz_opt.clone(), )) } - ColumnarValue::Array(array) => match array.data_type() { - DataType::Timestamp(TimeUnit::Nanosecond, tz_opt) => { - let array = as_timestamp_nanosecond_array(array)? - .iter() - .map(f_nanos) - .collect::() - .with_timezone_opt(tz_opt.clone()); - ColumnarValue::Array(Arc::new(array)) - } - DataType::Timestamp(TimeUnit::Microsecond, tz_opt) => { - let array = as_timestamp_microsecond_array(array)? - .iter() - .map(f_micros) - .collect::() - .with_timezone_opt(tz_opt.clone()); - - ColumnarValue::Array(Arc::new(array)) - } - DataType::Timestamp(TimeUnit::Millisecond, tz_opt) => { - let array = as_timestamp_millisecond_array(array)? - .iter() - .map(f_millis) - .collect::() - .with_timezone_opt(tz_opt.clone()); - - ColumnarValue::Array(Arc::new(array)) - } - DataType::Timestamp(TimeUnit::Second, tz_opt) => { - let array = as_timestamp_second_array(array)? 
+ ColumnarValue::Array(array) => { + fn transform_array_with_stride( + origin: i64, + stride: i64, + stride_fn: fn(i64, i64, i64) -> i64, + array: &ArrayRef, + tz_opt: &Option>, + ) -> Result + where + T: ArrowTimestampType, + { + let array = as_primitive_array::(array)?; + let apply_stride_fn = stride_map_fn::(origin, stride, stride_fn); + let array = array .iter() - .map(f_secs) - .collect::() + .map(apply_stride_fn) + .collect::>() .with_timezone_opt(tz_opt.clone()); - ColumnarValue::Array(Arc::new(array)) + Ok(ColumnarValue::Array(Arc::new(array))) } - _ => { - return exec_err!( - "DATE_BIN expects source argument to be a TIMESTAMP but got {}", - array.data_type() - ) + match array.data_type() { + DataType::Timestamp(TimeUnit::Nanosecond, tz_opt) => { + transform_array_with_stride::( + origin, stride, stride_fn, array, tz_opt, + )? + } + DataType::Timestamp(TimeUnit::Microsecond, tz_opt) => { + transform_array_with_stride::( + origin, stride, stride_fn, array, tz_opt, + )? + } + DataType::Timestamp(TimeUnit::Millisecond, tz_opt) => { + transform_array_with_stride::( + origin, stride, stride_fn, array, tz_opt, + )? + } + DataType::Timestamp(TimeUnit::Second, tz_opt) => { + transform_array_with_stride::( + origin, stride, stride_fn, array, tz_opt, + )? + } + _ => { + return exec_err!( + "DATE_BIN expects source argument to be a TIMESTAMP but got {}", + array.data_type() + ) + } } - }, + } _ => { return exec_err!( "DATE_BIN expects source argument to be a TIMESTAMP scalar or array" @@ -850,7 +806,7 @@ pub fn date_part(args: &[ColumnarValue]) -> Result { let array = match array { ColumnarValue::Array(array) => array.clone(), - ColumnarValue::Scalar(scalar) => scalar.to_array(), + ColumnarValue::Scalar(scalar) => scalar.to_array()?, }; let arr = match date_part.to_lowercase().as_str() { @@ -930,28 +886,188 @@ where T: ArrowTemporalType + ArrowNumericType, i64: From, { - let mut b = Float64Builder::with_capacity(array.len()); - match array.data_type() { + let b = match array.data_type() { DataType::Timestamp(tu, _) => { - for i in 0..array.len() { - if array.is_null(i) { - b.append_null(); - } else { - let scale = match tu { - TimeUnit::Second => 1, - TimeUnit::Millisecond => 1_000, - TimeUnit::Microsecond => 1_000_000, - TimeUnit::Nanosecond => 1_000_000_000, - }; - - let n: i64 = array.value(i).into(); - b.append_value(n as f64 / scale as f64); - } - } + let scale = match tu { + TimeUnit::Second => 1, + TimeUnit::Millisecond => 1_000, + TimeUnit::Microsecond => 1_000_000, + TimeUnit::Nanosecond => 1_000_000_000, + } as f64; + array.unary(|n| { + let n: i64 = n.into(); + n as f64 / scale + }) + } + DataType::Date32 => { + let seconds_in_a_day = 86400_f64; + array.unary(|n| { + let n: i64 = n.into(); + n as f64 * seconds_in_a_day + }) } + DataType::Date64 => array.unary(|n| { + let n: i64 = n.into(); + n as f64 / 1_000_f64 + }), _ => return internal_err!("Can not convert {:?} to epoch", array.data_type()), + }; + Ok(b) +} + +/// to_timestammp() SQL function implementation +pub fn to_timestamp_invoke(args: &[ColumnarValue]) -> Result { + if args.len() != 1 { + return internal_err!( + "to_timestamp function requires 1 arguments, got {}", + args.len() + ); + } + + match args[0].data_type() { + DataType::Int64 => cast_column( + &cast_column(&args[0], &DataType::Timestamp(TimeUnit::Second, None), None)?, + &DataType::Timestamp(TimeUnit::Nanosecond, None), + None, + ), + DataType::Float64 => cast_column( + &args[0], + &DataType::Timestamp(TimeUnit::Nanosecond, None), + None, + ), + 
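The new `stride_map_fn` above chooses a per-unit scale so `date_bin` can always bin in nanoseconds and scale the result back down. A simplified sketch of that scaling with plain integers; the `bin_nanos` rule here is a toy stand-in, not DataFusion's actual stride logic:

```rust
#[derive(Clone, Copy)]
enum TimeUnit {
    Second,
    Millisecond,
    Microsecond,
    Nanosecond,
}

const NANOSECONDS: i64 = 1_000_000_000;

/// Scale factor from the given unit to nanoseconds.
fn scale(unit: TimeUnit) -> i64 {
    match unit {
        TimeUnit::Nanosecond => 1,
        TimeUnit::Microsecond => NANOSECONDS / 1_000_000,
        TimeUnit::Millisecond => NANOSECONDS / 1_000,
        TimeUnit::Second => NANOSECONDS,
    }
}

/// Toy bin function working purely in nanoseconds.
fn bin_nanos(stride: i64, ts: i64, origin: i64) -> i64 {
    origin + (ts - origin).div_euclid(stride) * stride
}

/// Mirror of the `stride_map_fn` shape: lift a value from its own unit to
/// nanoseconds, bin it, then scale the result back down.
fn bin_in_unit(unit: TimeUnit, stride: i64, origin: i64, value: Option<i64>) -> Option<i64> {
    let s = scale(unit);
    value.map(|x| bin_nanos(stride, x * s, origin) / s)
}

fn main() {
    // 95 seconds with a 60-second stride from origin 0 bins to 60 seconds.
    assert_eq!(bin_in_unit(TimeUnit::Second, 60 * NANOSECONDS, 0, Some(95)), Some(60));
    // Nulls stay null, as in the kernel above.
    assert_eq!(bin_in_unit(TimeUnit::Second, 60 * NANOSECONDS, 0, None), None);
}
```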
DataType::Timestamp(_, None) => cast_column( + &args[0], + &DataType::Timestamp(TimeUnit::Nanosecond, None), + None, + ), + DataType::Utf8 => datetime_expressions::to_timestamp(args), + other => { + internal_err!( + "Unsupported data type {:?} for function to_timestamp", + other + ) + } + } +} + +/// to_timestamp_millis() SQL function implementation +pub fn to_timestamp_millis_invoke(args: &[ColumnarValue]) -> Result { + if args.len() != 1 { + return internal_err!( + "to_timestamp_millis function requires 1 argument, got {}", + args.len() + ); + } + + match args[0].data_type() { + DataType::Int64 | DataType::Timestamp(_, None) => cast_column( + &args[0], + &DataType::Timestamp(TimeUnit::Millisecond, None), + None, + ), + DataType::Utf8 => datetime_expressions::to_timestamp_millis(args), + other => { + internal_err!( + "Unsupported data type {:?} for function to_timestamp_millis", + other + ) + } + } +} + +/// to_timestamp_micros() SQL function implementation +pub fn to_timestamp_micros_invoke(args: &[ColumnarValue]) -> Result { + if args.len() != 1 { + return internal_err!( + "to_timestamp_micros function requires 1 argument, got {}", + args.len() + ); + } + + match args[0].data_type() { + DataType::Int64 | DataType::Timestamp(_, None) => cast_column( + &args[0], + &DataType::Timestamp(TimeUnit::Microsecond, None), + None, + ), + DataType::Utf8 => datetime_expressions::to_timestamp_micros(args), + other => { + internal_err!( + "Unsupported data type {:?} for function to_timestamp_micros", + other + ) + } + } +} + +/// to_timestamp_nanos() SQL function implementation +pub fn to_timestamp_nanos_invoke(args: &[ColumnarValue]) -> Result { + if args.len() != 1 { + return internal_err!( + "to_timestamp_nanos function requires 1 argument, got {}", + args.len() + ); + } + + match args[0].data_type() { + DataType::Int64 | DataType::Timestamp(_, None) => cast_column( + &args[0], + &DataType::Timestamp(TimeUnit::Nanosecond, None), + None, + ), + DataType::Utf8 => datetime_expressions::to_timestamp_nanos(args), + other => { + internal_err!( + "Unsupported data type {:?} for function to_timestamp_nanos", + other + ) + } + } +} + +/// to_timestamp_seconds() SQL function implementation +pub fn to_timestamp_seconds_invoke(args: &[ColumnarValue]) -> Result { + if args.len() != 1 { + return internal_err!( + "to_timestamp_seconds function requires 1 argument, got {}", + args.len() + ); + } + + match args[0].data_type() { + DataType::Int64 | DataType::Timestamp(_, None) => { + cast_column(&args[0], &DataType::Timestamp(TimeUnit::Second, None), None) + } + DataType::Utf8 => datetime_expressions::to_timestamp_seconds(args), + other => { + internal_err!( + "Unsupported data type {:?} for function to_timestamp_seconds", + other + ) + } + } +} + +/// from_unixtime() SQL function implementation +pub fn from_unixtime_invoke(args: &[ColumnarValue]) -> Result { + if args.len() != 1 { + return internal_err!( + "from_unixtime function requires 1 argument, got {}", + args.len() + ); + } + + match args[0].data_type() { + DataType::Int64 => { + cast_column(&args[0], &DataType::Timestamp(TimeUnit::Second, None), None) + } + other => { + internal_err!( + "Unsupported data type {:?} for function from_unixtime", + other + ) + } } - Ok(b.finish()) } #[cfg(test)] @@ -961,6 +1077,7 @@ mod tests { use arrow::array::{ as_primitive_array, ArrayRef, Int64Array, IntervalDayTimeArray, StringBuilder, }; + use arrow_array::TimestampNanosecondArray; use super::*; @@ -1188,7 +1305,7 @@ mod tests { .collect::() 
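The `to_timestamp_*_invoke` and `from_unixtime_invoke` wrappers above mostly delegate to `cast_column` between timestamp units, with `Utf8` input routed to the string parsers. Below is a rough numeric model of what a cast between two timestamp units does to the underlying `i64`; the enum and function are illustrative only and do not capture Arrow's exact rounding or overflow behavior:

```rust
#[derive(Clone, Copy)]
enum TimeUnit {
    Second,
    Millisecond,
    Microsecond,
    Nanosecond,
}

/// Sub-second ticks per second for each unit.
fn ticks_per_second(unit: TimeUnit) -> i64 {
    match unit {
        TimeUnit::Second => 1,
        TimeUnit::Millisecond => 1_000,
        TimeUnit::Microsecond => 1_000_000,
        TimeUnit::Nanosecond => 1_000_000_000,
    }
}

/// Rough numeric model of rescaling a timestamp value between units.
fn convert(value: i64, from: TimeUnit, to: TimeUnit) -> i64 {
    let (f, t) = (ticks_per_second(from), ticks_per_second(to));
    if t >= f {
        value * (t / f) // finer target unit: multiply
    } else {
        value.div_euclid(f / t) // coarser target unit: divide
    }
}

fn main() {
    // Timestamp(Second) -> Timestamp(Millisecond): 12 -> 12_000.
    assert_eq!(convert(12, TimeUnit::Second, TimeUnit::Millisecond), 12_000);
    // Timestamp(Nanosecond) -> Timestamp(Second): 1_500_000_000 -> 1.
    assert_eq!(convert(1_500_000_000, TimeUnit::Nanosecond, TimeUnit::Second), 1);
}
```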
.with_timezone_opt(tz_opt.clone()); let result = date_trunc(&[ - ColumnarValue::Scalar(ScalarValue::Utf8(Some("day".to_string()))), + ColumnarValue::Scalar(ScalarValue::from("day")), ColumnarValue::Array(Arc::new(input)), ]) .unwrap(); diff --git a/datafusion/physical-expr/src/equivalence.rs b/datafusion/physical-expr/src/equivalence.rs deleted file mode 100644 index d8aa09b904605..0000000000000 --- a/datafusion/physical-expr/src/equivalence.rs +++ /dev/null @@ -1,2858 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use std::collections::HashSet; -use std::hash::Hash; -use std::sync::Arc; - -use crate::expressions::Column; -use crate::physical_expr::{deduplicate_physical_exprs, have_common_entries}; -use crate::sort_properties::{ExprOrdering, SortProperties}; -use crate::{ - physical_exprs_contains, LexOrdering, LexOrderingRef, LexRequirement, - LexRequirementRef, PhysicalExpr, PhysicalSortExpr, PhysicalSortRequirement, -}; - -use arrow::datatypes::SchemaRef; -use arrow_schema::SortOptions; -use datafusion_common::tree_node::{Transformed, TreeNode}; -use datafusion_common::{JoinSide, JoinType, Result}; - -use indexmap::map::Entry; -use indexmap::IndexMap; - -/// An `EquivalenceClass` is a set of [`Arc`]s that are known -/// to have the same value for all tuples in a relation. These are generated by -/// equality predicates, typically equi-join conditions and equality conditions -/// in filters. -pub type EquivalenceClass = Vec>; - -/// Stores the mapping between source expressions and target expressions for a -/// projection. -#[derive(Debug, Clone)] -pub struct ProjectionMapping { - /// `(source expression)` --> `(target expression)` - /// Indices in the vector corresponds to the indices after projection. - inner: Vec<(Arc, Arc)>, -} - -impl ProjectionMapping { - /// Constructs the mapping between a projection's input and output - /// expressions. - /// - /// For example, given the input projection expressions (`a+b`, `c+d`) - /// and an output schema with two columns `"c+d"` and `"a+b"` - /// the projection mapping would be - /// ```text - /// [0]: (c+d, col("c+d")) - /// [1]: (a+b, col("a+b")) - /// ``` - /// where `col("c+d")` means the column named "c+d". - pub fn try_new( - expr: &[(Arc, String)], - input_schema: &SchemaRef, - ) -> Result { - // Construct a map from the input expressions to the output expression of the projection: - let mut inner = vec![]; - for (expr_idx, (expression, name)) in expr.iter().enumerate() { - let target_expr = Arc::new(Column::new(name, expr_idx)) as _; - - let source_expr = expression.clone().transform_down(&|e| match e - .as_any() - .downcast_ref::( - ) { - Some(col) => { - // Sometimes, expression and its name in the input_schema doesn't match. - // This can cause problems. 
Hence in here we make sure that expression name - // matches with the name in the inout_schema. - // Conceptually, source_expr and expression should be same. - let idx = col.index(); - let matching_input_field = input_schema.field(idx); - let matching_input_column = - Column::new(matching_input_field.name(), idx); - Ok(Transformed::Yes(Arc::new(matching_input_column))) - } - None => Ok(Transformed::No(e)), - })?; - - inner.push((source_expr, target_expr)); - } - Ok(Self { inner }) - } - - /// Iterate over pairs of (source, target) expressions - pub fn iter( - &self, - ) -> impl Iterator, Arc)> + '_ { - self.inner.iter() - } -} - -/// An `EquivalenceGroup` is a collection of `EquivalenceClass`es where each -/// class represents a distinct equivalence class in a relation. -#[derive(Debug, Clone)] -pub struct EquivalenceGroup { - classes: Vec, -} - -impl EquivalenceGroup { - /// Creates an empty equivalence group. - fn empty() -> Self { - Self { classes: vec![] } - } - - /// Creates an equivalence group from the given equivalence classes. - fn new(classes: Vec) -> Self { - let mut result = EquivalenceGroup { classes }; - result.remove_redundant_entries(); - result - } - - /// Returns how many equivalence classes there are in this group. - fn len(&self) -> usize { - self.classes.len() - } - - /// Checks whether this equivalence group is empty. - pub fn is_empty(&self) -> bool { - self.len() == 0 - } - - /// Returns an iterator over the equivalence classes in this group. - fn iter(&self) -> impl Iterator { - self.classes.iter() - } - - /// Adds the equality `left` = `right` to this equivalence group. - /// New equality conditions often arise after steps like `Filter(a = b)`, - /// `Alias(a, a as b)` etc. - fn add_equal_conditions( - &mut self, - left: &Arc, - right: &Arc, - ) { - let mut first_class = None; - let mut second_class = None; - for (idx, cls) in self.classes.iter().enumerate() { - if physical_exprs_contains(cls, left) { - first_class = Some(idx); - } - if physical_exprs_contains(cls, right) { - second_class = Some(idx); - } - } - match (first_class, second_class) { - (Some(mut first_idx), Some(mut second_idx)) => { - // If the given left and right sides belong to different classes, - // we should unify/bridge these classes. - if first_idx != second_idx { - // By convention make sure second_idx is larger than first_idx. - if first_idx > second_idx { - (first_idx, second_idx) = (second_idx, first_idx); - } - // Remove second_idx from self.classes then merge its values with class at first_idx. - // Convention above makes sure that first_idx is still valid after second_idx removal. - let other_class = self.classes.swap_remove(second_idx); - self.classes[first_idx].extend(other_class); - } - } - (Some(group_idx), None) => { - // Right side is new, extend left side's class: - self.classes[group_idx].push(right.clone()); - } - (None, Some(group_idx)) => { - // Left side is new, extend right side's class: - self.classes[group_idx].push(left.clone()); - } - (None, None) => { - // None of the expressions is among existing classes. - // Create a new equivalence class and extend the group. - self.classes.push(vec![left.clone(), right.clone()]); - } - } - } - - /// Removes redundant entries from this group. - fn remove_redundant_entries(&mut self) { - // Remove duplicate entries from each equivalence class: - self.classes.retain_mut(|cls| { - // Keep groups that have at least two entries as singleton class is - // meaningless (i.e. 
it contains no non-trivial information): - deduplicate_physical_exprs(cls); - cls.len() > 1 - }); - // Unify/bridge groups that have common expressions: - self.bridge_classes() - } - - /// This utility function unifies/bridges classes that have common expressions. - /// For example, assume that we have [`EquivalenceClass`]es `[a, b]` and `[b, c]`. - /// Since both classes contain `b`, columns `a`, `b` and `c` are actually all - /// equal and belong to one class. This utility converts merges such classes. - fn bridge_classes(&mut self) { - let mut idx = 0; - while idx < self.classes.len() { - let mut next_idx = idx + 1; - let start_size = self.classes[idx].len(); - while next_idx < self.classes.len() { - if have_common_entries(&self.classes[idx], &self.classes[next_idx]) { - let extension = self.classes.swap_remove(next_idx); - self.classes[idx].extend(extension); - } else { - next_idx += 1; - } - } - if self.classes[idx].len() > start_size { - deduplicate_physical_exprs(&mut self.classes[idx]); - if self.classes[idx].len() > start_size { - continue; - } - } - idx += 1; - } - } - - /// Extends this equivalence group with the `other` equivalence group. - fn extend(&mut self, other: Self) { - self.classes.extend(other.classes); - self.remove_redundant_entries(); - } - - /// Normalizes the given physical expression according to this group. - /// The expression is replaced with the first expression in the equivalence - /// class it matches with (if any). - pub fn normalize_expr(&self, expr: Arc) -> Arc { - expr.clone() - .transform(&|expr| { - for cls in self.iter() { - if physical_exprs_contains(cls, &expr) { - return Ok(Transformed::Yes(cls[0].clone())); - } - } - Ok(Transformed::No(expr)) - }) - .unwrap_or(expr) - } - - /// Normalizes the given sort expression according to this group. - /// The underlying physical expression is replaced with the first expression - /// in the equivalence class it matches with (if any). If the underlying - /// expression does not belong to any equivalence class in this group, returns - /// the sort expression as is. - pub fn normalize_sort_expr( - &self, - mut sort_expr: PhysicalSortExpr, - ) -> PhysicalSortExpr { - sort_expr.expr = self.normalize_expr(sort_expr.expr); - sort_expr - } - - /// Normalizes the given sort requirement according to this group. - /// The underlying physical expression is replaced with the first expression - /// in the equivalence class it matches with (if any). If the underlying - /// expression does not belong to any equivalence class in this group, returns - /// the given sort requirement as is. - pub fn normalize_sort_requirement( - &self, - mut sort_requirement: PhysicalSortRequirement, - ) -> PhysicalSortRequirement { - sort_requirement.expr = self.normalize_expr(sort_requirement.expr); - sort_requirement - } - - /// This function applies the `normalize_expr` function for all expressions - /// in `exprs` and returns the corresponding normalized physical expressions. - pub fn normalize_exprs( - &self, - exprs: impl IntoIterator>, - ) -> Vec> { - exprs - .into_iter() - .map(|expr| self.normalize_expr(expr)) - .collect() - } - - /// This function applies the `normalize_sort_expr` function for all sort - /// expressions in `sort_exprs` and returns the corresponding normalized - /// sort expressions. 
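For reference, `bridge_classes` in the removed `equivalence.rs` merges equivalence classes that share a member, so `[a, b]` and `[b, c]` collapse into one class `[a, b, c]`. A self-contained sketch of that merge, with strings standing in for `Arc<dyn PhysicalExpr>`:

```rust
/// Merge equivalence classes that share at least one member, so that
/// [a, b] and [b, c] end up in a single class [a, b, c]. Illustrative only.
fn bridge_classes(mut classes: Vec<Vec<String>>) -> Vec<Vec<String>> {
    let mut idx = 0;
    while idx < classes.len() {
        let mut next = idx + 1;
        while next < classes.len() {
            let overlaps = classes[next].iter().any(|e| classes[idx].contains(e));
            if overlaps {
                // Merge the overlapping class into the current one, dropping
                // duplicates, then rescan in case the merge created new overlaps.
                let merged = classes.swap_remove(next);
                for e in merged {
                    if !classes[idx].contains(&e) {
                        classes[idx].push(e);
                    }
                }
                next = idx + 1;
            } else {
                next += 1;
            }
        }
        idx += 1;
    }
    classes
}

fn main() {
    let classes = vec![
        vec!["a".to_string(), "b".to_string()],
        vec!["b".to_string(), "c".to_string()],
        vec!["x".to_string(), "y".to_string()],
    ];
    let bridged = bridge_classes(classes);
    assert_eq!(
        bridged,
        vec![
            vec!["a".to_string(), "b".to_string(), "c".to_string()],
            vec!["x".to_string(), "y".to_string()],
        ]
    );
}
```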
- pub fn normalize_sort_exprs(&self, sort_exprs: LexOrderingRef) -> LexOrdering { - // Convert sort expressions to sort requirements: - let sort_reqs = PhysicalSortRequirement::from_sort_exprs(sort_exprs.iter()); - // Normalize the requirements: - let normalized_sort_reqs = self.normalize_sort_requirements(&sort_reqs); - // Convert sort requirements back to sort expressions: - PhysicalSortRequirement::to_sort_exprs(normalized_sort_reqs) - } - - /// This function applies the `normalize_sort_requirement` function for all - /// requirements in `sort_reqs` and returns the corresponding normalized - /// sort requirements. - pub fn normalize_sort_requirements( - &self, - sort_reqs: LexRequirementRef, - ) -> LexRequirement { - collapse_lex_req( - sort_reqs - .iter() - .map(|sort_req| self.normalize_sort_requirement(sort_req.clone())) - .collect(), - ) - } - - /// Projects `expr` according to the given projection mapping. - /// If the resulting expression is invalid after projection, returns `None`. - fn project_expr( - &self, - mapping: &ProjectionMapping, - expr: &Arc, - ) -> Option> { - let children = expr.children(); - if children.is_empty() { - for (source, target) in mapping.iter() { - // If we match the source, or an equivalent expression to source, - // then we can project. For example, if we have the mapping - // (a as a1, a + c) and the equivalence class (a, b), expression - // b also projects to a1. - if source.eq(expr) - || self - .get_equivalence_class(source) - .map_or(false, |group| physical_exprs_contains(group, expr)) - { - return Some(target.clone()); - } - } - } - // Project a non-leaf expression by projecting its children. - else if let Some(children) = children - .into_iter() - .map(|child| self.project_expr(mapping, &child)) - .collect::>>() - { - return Some(expr.clone().with_new_children(children).unwrap()); - } - // Arriving here implies the expression was invalid after projection. - None - } - - /// Projects `ordering` according to the given projection mapping. - /// If the resulting ordering is invalid after projection, returns `None`. - fn project_ordering( - &self, - mapping: &ProjectionMapping, - ordering: LexOrderingRef, - ) -> Option { - // If any sort expression is invalid after projection, rest of the - // ordering shouldn't be projected either. For example, if input ordering - // is [a ASC, b ASC, c ASC], and column b is not valid after projection, - // the result should be [a ASC], not [a ASC, c ASC], even if column c is - // valid after projection. - let result = ordering - .iter() - .map_while(|sort_expr| { - self.project_expr(mapping, &sort_expr.expr) - .map(|expr| PhysicalSortExpr { - expr, - options: sort_expr.options, - }) - }) - .collect::>(); - (!result.is_empty()).then_some(result) - } - - /// Projects this equivalence group according to the given projection mapping. - pub fn project(&self, mapping: &ProjectionMapping) -> Self { - let projected_classes = self.iter().filter_map(|cls| { - let new_class = cls - .iter() - .filter_map(|expr| self.project_expr(mapping, expr)) - .collect::>(); - (new_class.len() > 1).then_some(new_class) - }); - // TODO: Convert the algorithm below to a version that uses `HashMap`. - // once `Arc` can be stored in `HashMap`. 
- // See issue: https://github.com/apache/arrow-datafusion/issues/8027 - let mut new_classes = vec![]; - for (source, target) in mapping.iter() { - if new_classes.is_empty() { - new_classes.push((source, vec![target.clone()])); - } - if let Some((_, values)) = - new_classes.iter_mut().find(|(key, _)| key.eq(source)) - { - if !physical_exprs_contains(values, target) { - values.push(target.clone()); - } - } - } - // Only add equivalence classes with at least two members as singleton - // equivalence classes are meaningless. - let new_classes = new_classes - .into_iter() - .filter_map(|(_, values)| (values.len() > 1).then_some(values)); - let classes = projected_classes.chain(new_classes).collect(); - Self::new(classes) - } - - /// Returns the equivalence class that contains `expr`. - /// If none of the equivalence classes contains `expr`, returns `None`. - fn get_equivalence_class( - &self, - expr: &Arc, - ) -> Option<&[Arc]> { - self.iter() - .map(|cls| cls.as_slice()) - .find(|cls| physical_exprs_contains(cls, expr)) - } - - /// Combine equivalence groups of the given join children. - pub fn join( - &self, - right_equivalences: &Self, - join_type: &JoinType, - left_size: usize, - on: &[(Column, Column)], - ) -> Self { - match join_type { - JoinType::Inner | JoinType::Left | JoinType::Full | JoinType::Right => { - let mut result = Self::new( - self.iter() - .cloned() - .chain(right_equivalences.iter().map(|item| { - item.iter() - .cloned() - .map(|expr| add_offset_to_expr(expr, left_size)) - .collect() - })) - .collect(), - ); - // In we have an inner join, expressions in the "on" condition - // are equal in the resulting table. - if join_type == &JoinType::Inner { - for (lhs, rhs) in on.iter() { - let index = rhs.index() + left_size; - let new_lhs = Arc::new(lhs.clone()) as _; - let new_rhs = Arc::new(Column::new(rhs.name(), index)) as _; - result.add_equal_conditions(&new_lhs, &new_rhs); - } - } - result - } - JoinType::LeftSemi | JoinType::LeftAnti => self.clone(), - JoinType::RightSemi | JoinType::RightAnti => right_equivalences.clone(), - } - } -} - -/// This function constructs a duplicate-free `LexOrderingReq` by filtering out -/// duplicate entries that have same physical expression inside. For example, -/// `vec![a Some(Asc), a Some(Desc)]` collapses to `vec![a Some(Asc)]`. -pub fn collapse_lex_req(input: LexRequirement) -> LexRequirement { - let mut output = Vec::::new(); - for item in input { - if !output.iter().any(|req| req.expr.eq(&item.expr)) { - output.push(item); - } - } - output -} - -/// An `OrderingEquivalenceClass` object keeps track of different alternative -/// orderings than can describe a schema. For example, consider the following table: -/// -/// ```text -/// |a|b|c|d| -/// |1|4|3|1| -/// |2|3|3|2| -/// |3|1|2|2| -/// |3|2|1|3| -/// ``` -/// -/// Here, both `vec![a ASC, b ASC]` and `vec![c DESC, d ASC]` describe the table -/// ordering. In this case, we say that these orderings are equivalent. -#[derive(Debug, Clone, Eq, PartialEq, Hash)] -pub struct OrderingEquivalenceClass { - orderings: Vec, -} - -impl OrderingEquivalenceClass { - /// Creates new empty ordering equivalence class. - fn empty() -> Self { - Self { orderings: vec![] } - } - - /// Clears (empties) this ordering equivalence class. - pub fn clear(&mut self) { - self.orderings.clear(); - } - - /// Creates new ordering equivalence class from the given orderings. 
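`collapse_lex_req` above keeps only the first requirement seen for each distinct expression, so `[a ASC, a DESC]` collapses to `[a ASC]`. A tiny sketch with `(name, descending)` pairs standing in for `PhysicalSortRequirement`:

```rust
/// Keep only the first sort requirement seen for each expression.
/// `(expr, descending)` tuples stand in for `PhysicalSortRequirement`.
fn collapse_lex_req(input: Vec<(&str, bool)>) -> Vec<(&str, bool)> {
    let mut output: Vec<(&str, bool)> = Vec::new();
    for item in input {
        if !output.iter().any(|(expr, _)| *expr == item.0) {
            output.push(item);
        }
    }
    output
}

fn main() {
    // [a ASC, a DESC, b ASC] collapses to [a ASC, b ASC].
    let collapsed = collapse_lex_req(vec![("a", false), ("a", true), ("b", false)]);
    assert_eq!(collapsed, vec![("a", false), ("b", false)]);
}
```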
- pub fn new(orderings: Vec) -> Self { - let mut result = Self { orderings }; - result.remove_redundant_entries(); - result - } - - /// Checks whether `ordering` is a member of this equivalence class. - pub fn contains(&self, ordering: &LexOrdering) -> bool { - self.orderings.contains(ordering) - } - - /// Adds `ordering` to this equivalence class. - #[allow(dead_code)] - fn push(&mut self, ordering: LexOrdering) { - self.orderings.push(ordering); - // Make sure that there are no redundant orderings: - self.remove_redundant_entries(); - } - - /// Checks whether this ordering equivalence class is empty. - pub fn is_empty(&self) -> bool { - self.len() == 0 - } - - /// Returns an iterator over the equivalent orderings in this class. - pub fn iter(&self) -> impl Iterator { - self.orderings.iter() - } - - /// Returns how many equivalent orderings there are in this class. - pub fn len(&self) -> usize { - self.orderings.len() - } - - /// Extend this ordering equivalence class with the `other` class. - pub fn extend(&mut self, other: Self) { - self.orderings.extend(other.orderings); - // Make sure that there are no redundant orderings: - self.remove_redundant_entries(); - } - - /// Adds new orderings into this ordering equivalence class. - pub fn add_new_orderings( - &mut self, - orderings: impl IntoIterator, - ) { - self.orderings.extend(orderings); - // Make sure that there are no redundant orderings: - self.remove_redundant_entries(); - } - - /// Removes redundant orderings from this equivalence class. - /// For instance, If we already have the ordering [a ASC, b ASC, c DESC], - /// then there is no need to keep ordering [a ASC, b ASC] in the state. - fn remove_redundant_entries(&mut self) { - let mut idx = 0; - while idx < self.orderings.len() { - let mut removal = false; - for (ordering_idx, ordering) in self.orderings[0..idx].iter().enumerate() { - if let Some(right_finer) = finer_side(ordering, &self.orderings[idx]) { - if right_finer { - self.orderings.swap(ordering_idx, idx); - } - removal = true; - break; - } - } - if removal { - self.orderings.swap_remove(idx); - } else { - idx += 1; - } - } - } - - /// Gets the first ordering entry in this ordering equivalence class. - /// This is one of the many valid orderings (if there are multiple). - pub fn output_ordering(&self) -> Option { - self.orderings.first().cloned() - } - - // Append orderings in `other` to all existing orderings in this equivalence - // class. - pub fn join_suffix(mut self, other: &Self) -> Self { - for ordering in other.iter() { - for idx in 0..self.orderings.len() { - self.orderings[idx].extend(ordering.iter().cloned()); - } - } - self - } - - /// Adds `offset` value to the index of each expression inside this - /// ordering equivalence class. - pub fn add_offset(&mut self, offset: usize) { - for ordering in self.orderings.iter_mut() { - for sort_expr in ordering { - sort_expr.expr = add_offset_to_expr(sort_expr.expr.clone(), offset); - } - } - } - - /// Gets sort options associated with this expression if it is a leading - /// ordering expression. Otherwise, returns `None`. - fn get_options(&self, expr: &Arc) -> Option { - for ordering in self.iter() { - let leading_ordering = &ordering[0]; - if leading_ordering.expr.eq(expr) { - return Some(leading_ordering.options); - } - } - None - } -} - -/// Adds the `offset` value to `Column` indices inside `expr`. This function is -/// generally used during the update of the right table schema in join operations. 
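`remove_redundant_entries` above drops an ordering that is a prefix of a finer ordering already in the class, e.g. keeping `[a ASC, b ASC, c DESC]` makes `[a ASC, b ASC]` redundant. A simplified sketch that only handles the prefix case, with `(expr, descending)` tuples standing in for `PhysicalSortExpr`:

```rust
/// Drop orderings that are prefixes of another ordering in the set,
/// keeping whichever of two comparable orderings is longer (finer).
fn remove_redundant_orderings(orderings: Vec<Vec<(&str, bool)>>) -> Vec<Vec<(&str, bool)>> {
    let mut kept: Vec<Vec<(&str, bool)>> = Vec::new();
    for ordering in orderings {
        match kept.iter().position(|o| {
            o.starts_with(ordering.as_slice()) || ordering.starts_with(o.as_slice())
        }) {
            // A comparable ordering already exists: keep the finer of the two.
            Some(i) => {
                if ordering.len() > kept[i].len() {
                    kept[i] = ordering;
                }
            }
            None => kept.push(ordering),
        }
    }
    kept
}

fn main() {
    let orderings = vec![
        vec![("a", false), ("b", false)],              // [a ASC, b ASC]
        vec![("a", false), ("b", false), ("c", true)], // [a ASC, b ASC, c DESC]
        vec![("d", false)],                            // [d ASC]
    ];
    let kept = remove_redundant_orderings(orderings);
    assert_eq!(
        kept,
        vec![
            vec![("a", false), ("b", false), ("c", true)],
            vec![("d", false)],
        ]
    );
}
```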
-pub fn add_offset_to_expr( - expr: Arc, - offset: usize, -) -> Arc { - expr.transform_down(&|e| match e.as_any().downcast_ref::() { - Some(col) => Ok(Transformed::Yes(Arc::new(Column::new( - col.name(), - offset + col.index(), - )))), - None => Ok(Transformed::No(e)), - }) - .unwrap() - // Note that we can safely unwrap here since our transform always returns - // an `Ok` value. -} - -/// Returns `true` if the ordering `rhs` is strictly finer than the ordering `rhs`, -/// `false` if the ordering `lhs` is at least as fine as the ordering `lhs`, and -/// `None` otherwise (i.e. when given orderings are incomparable). -fn finer_side(lhs: LexOrderingRef, rhs: LexOrderingRef) -> Option { - let all_equal = lhs.iter().zip(rhs.iter()).all(|(lhs, rhs)| lhs.eq(rhs)); - all_equal.then_some(lhs.len() < rhs.len()) -} - -/// A `EquivalenceProperties` object stores useful information related to a schema. -/// Currently, it keeps track of: -/// - Equivalent expressions, e.g expressions that have same value. -/// - Valid sort expressions (orderings) for the schema. -/// - Constants expressions (e.g expressions that are known to have constant values). -/// -/// Consider table below: -/// -/// ```text -/// ┌-------┐ -/// | a | b | -/// |---|---| -/// | 1 | 9 | -/// | 2 | 8 | -/// | 3 | 7 | -/// | 5 | 5 | -/// └---┴---┘ -/// ``` -/// -/// where both `a ASC` and `b DESC` can describe the table ordering. With -/// `EquivalenceProperties`, we can keep track of these different valid sort -/// expressions and treat `a ASC` and `b DESC` on an equal footing. -/// -/// Similarly, consider the table below: -/// -/// ```text -/// ┌-------┐ -/// | a | b | -/// |---|---| -/// | 1 | 1 | -/// | 2 | 2 | -/// | 3 | 3 | -/// | 5 | 5 | -/// └---┴---┘ -/// ``` -/// -/// where columns `a` and `b` always have the same value. We keep track of such -/// equivalences inside this object. With this information, we can optimize -/// things like partitioning. For example, if the partition requirement is -/// `Hash(a)` and output partitioning is `Hash(b)`, then we can deduce that -/// the existing partitioning satisfies the requirement. -#[derive(Debug, Clone)] -pub struct EquivalenceProperties { - /// Collection of equivalence classes that store expressions with the same - /// value. - eq_group: EquivalenceGroup, - /// Equivalent sort expressions for this table. - oeq_class: OrderingEquivalenceClass, - /// Expressions whose values are constant throughout the table. - /// TODO: We do not need to track constants separately, they can be tracked - /// inside `eq_groups` as `Literal` expressions. - constants: Vec>, - /// Schema associated with this object. - schema: SchemaRef, -} - -impl EquivalenceProperties { - /// Creates an empty `EquivalenceProperties` object. - pub fn new(schema: SchemaRef) -> Self { - Self { - eq_group: EquivalenceGroup::empty(), - oeq_class: OrderingEquivalenceClass::empty(), - constants: vec![], - schema, - } - } - - /// Creates a new `EquivalenceProperties` object with the given orderings. - pub fn new_with_orderings(schema: SchemaRef, orderings: &[LexOrdering]) -> Self { - Self { - eq_group: EquivalenceGroup::empty(), - oeq_class: OrderingEquivalenceClass::new(orderings.to_vec()), - constants: vec![], - schema, - } - } - - /// Returns the associated schema. - pub fn schema(&self) -> &SchemaRef { - &self.schema - } - - /// Returns a reference to the ordering equivalence class within. 
- pub fn oeq_class(&self) -> &OrderingEquivalenceClass { - &self.oeq_class - } - - /// Returns a reference to the equivalence group within. - pub fn eq_group(&self) -> &EquivalenceGroup { - &self.eq_group - } - - /// Returns the normalized version of the ordering equivalence class within. - /// Normalization removes constants and duplicates as well as standardizing - /// expressions according to the equivalence group within. - pub fn normalized_oeq_class(&self) -> OrderingEquivalenceClass { - OrderingEquivalenceClass::new( - self.oeq_class - .iter() - .map(|ordering| self.normalize_sort_exprs(ordering)) - .collect(), - ) - } - - /// Extends this `EquivalenceProperties` with the `other` object. - pub fn extend(mut self, other: Self) -> Self { - self.eq_group.extend(other.eq_group); - self.oeq_class.extend(other.oeq_class); - self.add_constants(other.constants) - } - - /// Clears (empties) the ordering equivalence class within this object. - /// Call this method when existing orderings are invalidated. - pub fn clear_orderings(&mut self) { - self.oeq_class.clear(); - } - - /// Extends this `EquivalenceProperties` by adding the orderings inside the - /// ordering equivalence class `other`. - pub fn add_ordering_equivalence_class(&mut self, other: OrderingEquivalenceClass) { - self.oeq_class.extend(other); - } - - /// Adds new orderings into the existing ordering equivalence class. - pub fn add_new_orderings( - &mut self, - orderings: impl IntoIterator, - ) { - self.oeq_class.add_new_orderings(orderings); - } - - /// Incorporates the given equivalence group to into the existing - /// equivalence group within. - pub fn add_equivalence_group(&mut self, other_eq_group: EquivalenceGroup) { - self.eq_group.extend(other_eq_group); - } - - /// Adds a new equality condition into the existing equivalence group. - /// If the given equality defines a new equivalence class, adds this new - /// equivalence class to the equivalence group. - pub fn add_equal_conditions( - &mut self, - left: &Arc, - right: &Arc, - ) { - self.eq_group.add_equal_conditions(left, right); - } - - /// Track/register physical expressions with constant values. - pub fn add_constants( - mut self, - constants: impl IntoIterator>, - ) -> Self { - for expr in self.eq_group.normalize_exprs(constants) { - if !physical_exprs_contains(&self.constants, &expr) { - self.constants.push(expr); - } - } - self - } - - /// Updates the ordering equivalence group within assuming that the table - /// is re-sorted according to the argument `sort_exprs`. Note that constants - /// and equivalence classes are unchanged as they are unaffected by a re-sort. - pub fn with_reorder(mut self, sort_exprs: Vec) -> Self { - // TODO: In some cases, existing ordering equivalences may still be valid add this analysis. - self.oeq_class = OrderingEquivalenceClass::new(vec![sort_exprs]); - self - } - - /// Normalizes the given sort expressions (i.e. `sort_exprs`) using the - /// equivalence group and the ordering equivalence class within. - /// - /// Assume that `self.eq_group` states column `a` and `b` are aliases. - /// Also assume that `self.oeq_class` states orderings `d ASC` and `a ASC, c ASC` - /// are equivalent (in the sense that both describe the ordering of the table). - /// If the `sort_exprs` argument were `vec![b ASC, c ASC, a ASC]`, then this - /// function would return `vec![a ASC, c ASC]`. Internally, it would first - /// normalize to `vec![a ASC, c ASC, a ASC]` and end up with the final result - /// after deduplication. 
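The comment above walks through normalizing `[b ASC, c ASC, a ASC]` to `[a ASC, c ASC]` when `a` and `b` are aliases. A small sketch of that normalize-then-deduplicate step, with strings for expressions and equivalence classes as string lists:

```rust
/// Replace every expression with the representative (first member) of its
/// equivalence class, then keep only the first occurrence of each expression.
/// Strings stand in for physical expressions; illustrative only.
fn normalize_sort_exprs(
    sort_exprs: &[(&str, bool)],
    eq_classes: &[Vec<&str>],
) -> Vec<(String, bool)> {
    let mut result: Vec<(String, bool)> = Vec::new();
    for &(expr, descending) in sort_exprs {
        // Normalize: map the expression to its class representative, if any.
        let normalized = eq_classes
            .iter()
            .find(|class| class.contains(&expr))
            .map(|class| class[0])
            .unwrap_or(expr);
        // Deduplicate: drop the entry if the representative was already seen.
        if !result.iter().any(|(e, _)| e.as_str() == normalized) {
            result.push((normalized.to_string(), descending));
        }
    }
    result
}

fn main() {
    // `a` and `b` are aliases; [b ASC, c ASC, a ASC] normalizes to [a ASC, c ASC].
    let eq_classes = vec![vec!["a", "b"]];
    let normalized =
        normalize_sort_exprs(&[("b", false), ("c", false), ("a", false)], &eq_classes);
    assert_eq!(
        normalized,
        vec![("a".to_string(), false), ("c".to_string(), false)]
    );
}
```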
- fn normalize_sort_exprs(&self, sort_exprs: LexOrderingRef) -> LexOrdering { - // Convert sort expressions to sort requirements: - let sort_reqs = PhysicalSortRequirement::from_sort_exprs(sort_exprs.iter()); - // Normalize the requirements: - let normalized_sort_reqs = self.normalize_sort_requirements(&sort_reqs); - // Convert sort requirements back to sort expressions: - PhysicalSortRequirement::to_sort_exprs(normalized_sort_reqs) - } - - /// Normalizes the given sort requirements (i.e. `sort_reqs`) using the - /// equivalence group and the ordering equivalence class within. It works by: - /// - Removing expressions that have a constant value from the given requirement. - /// - Replacing sections that belong to some equivalence class in the equivalence - /// group with the first entry in the matching equivalence class. - /// - /// Assume that `self.eq_group` states column `a` and `b` are aliases. - /// Also assume that `self.oeq_class` states orderings `d ASC` and `a ASC, c ASC` - /// are equivalent (in the sense that both describe the ordering of the table). - /// If the `sort_reqs` argument were `vec![b ASC, c ASC, a ASC]`, then this - /// function would return `vec![a ASC, c ASC]`. Internally, it would first - /// normalize to `vec![a ASC, c ASC, a ASC]` and end up with the final result - /// after deduplication. - fn normalize_sort_requirements( - &self, - sort_reqs: LexRequirementRef, - ) -> LexRequirement { - let normalized_sort_reqs = self.eq_group.normalize_sort_requirements(sort_reqs); - let constants_normalized = self.eq_group.normalize_exprs(self.constants.clone()); - // Prune redundant sections in the requirement: - collapse_lex_req( - normalized_sort_reqs - .iter() - .filter(|&order| { - !physical_exprs_contains(&constants_normalized, &order.expr) - }) - .cloned() - .collect(), - ) - } - - /// Checks whether the given ordering is satisfied by any of the existing - /// orderings. - pub fn ordering_satisfy(&self, given: LexOrderingRef) -> bool { - // Convert the given sort expressions to sort requirements: - let sort_requirements = PhysicalSortRequirement::from_sort_exprs(given.iter()); - self.ordering_satisfy_requirement(&sort_requirements) - } - - /// Checks whether the given sort requirements are satisfied by any of the - /// existing orderings. - pub fn ordering_satisfy_requirement(&self, reqs: LexRequirementRef) -> bool { - // First, standardize the given requirement: - let normalized_reqs = self.normalize_sort_requirements(reqs); - if normalized_reqs.is_empty() { - // Requirements are tautologically satisfied if empty. - return true; - } - let mut indices = HashSet::new(); - for ordering in self.normalized_oeq_class().iter() { - let match_indices = ordering - .iter() - .map(|sort_expr| { - normalized_reqs - .iter() - .position(|sort_req| sort_expr.satisfy(sort_req, &self.schema)) - }) - .collect::>(); - // Find the largest contiguous increasing sequence starting from the first index: - if let Some(&Some(first)) = match_indices.first() { - indices.insert(first); - let mut iter = match_indices.windows(2); - while let Some([Some(current), Some(next)]) = iter.next() { - if next > current { - indices.insert(*next); - } else { - break; - } - } - } - } - indices.len() == normalized_reqs.len() - } - - /// Checks whether the `given`` sort requirements are equal or more specific - /// than the `reference` sort requirements. 
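`ordering_satisfy_requirement` above checks whether the normalized requirement is covered by the known orderings, and treats an empty requirement as trivially satisfied. A simplified sketch that only handles the common prefix case (the real code can also stitch coverage together from several equivalent orderings):

```rust
/// `(expr, descending)` pairs stand in for sort expressions; illustrative only.
/// A requirement is satisfied if it is empty or a prefix of some known ordering.
fn ordering_satisfy(
    known_orderings: &[Vec<(&str, bool)>],
    requirement: &[(&str, bool)],
) -> bool {
    requirement.is_empty()
        || known_orderings
            .iter()
            .any(|ordering| ordering.starts_with(requirement))
}

fn main() {
    let known = vec![
        vec![("a", false), ("b", false)], // [a ASC, b ASC]
        vec![("c", true)],                // [c DESC]
    ];
    assert!(ordering_satisfy(&known, &[("a", false)]));
    assert!(ordering_satisfy(&known, &[("a", false), ("b", false)]));
    assert!(!ordering_satisfy(&known, &[("b", false)])); // b alone is not a prefix
}
```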
- pub fn requirements_compatible( - &self, - given: LexRequirementRef, - reference: LexRequirementRef, - ) -> bool { - let normalized_given = self.normalize_sort_requirements(given); - let normalized_reference = self.normalize_sort_requirements(reference); - - (normalized_reference.len() <= normalized_given.len()) - && normalized_reference - .into_iter() - .zip(normalized_given) - .all(|(reference, given)| given.compatible(&reference)) - } - - /// Returns the finer ordering among the orderings `lhs` and `rhs`, breaking - /// any ties by choosing `lhs`. - /// - /// The finer ordering is the ordering that satisfies both of the orderings. - /// If the orderings are incomparable, returns `None`. - /// - /// For example, the finer ordering among `[a ASC]` and `[a ASC, b ASC]` is - /// the latter. - pub fn get_finer_ordering( - &self, - lhs: LexOrderingRef, - rhs: LexOrderingRef, - ) -> Option { - // Convert the given sort expressions to sort requirements: - let lhs = PhysicalSortRequirement::from_sort_exprs(lhs); - let rhs = PhysicalSortRequirement::from_sort_exprs(rhs); - let finer = self.get_finer_requirement(&lhs, &rhs); - // Convert the chosen sort requirements back to sort expressions: - finer.map(PhysicalSortRequirement::to_sort_exprs) - } - - /// Returns the finer ordering among the requirements `lhs` and `rhs`, - /// breaking any ties by choosing `lhs`. - /// - /// The finer requirements are the ones that satisfy both of the given - /// requirements. If the requirements are incomparable, returns `None`. - /// - /// For example, the finer requirements among `[a ASC]` and `[a ASC, b ASC]` - /// is the latter. - pub fn get_finer_requirement( - &self, - req1: LexRequirementRef, - req2: LexRequirementRef, - ) -> Option { - let mut lhs = self.normalize_sort_requirements(req1); - let mut rhs = self.normalize_sort_requirements(req2); - lhs.iter_mut() - .zip(rhs.iter_mut()) - .all(|(lhs, rhs)| { - lhs.expr.eq(&rhs.expr) - && match (lhs.options, rhs.options) { - (Some(lhs_opt), Some(rhs_opt)) => lhs_opt == rhs_opt, - (Some(options), None) => { - rhs.options = Some(options); - true - } - (None, Some(options)) => { - lhs.options = Some(options); - true - } - (None, None) => true, - } - }) - .then_some(if lhs.len() >= rhs.len() { lhs } else { rhs }) - } - - /// Calculates the "meet" of the given orderings (`lhs` and `rhs`). - /// The meet of a set of orderings is the finest ordering that is satisfied - /// by all the orderings in that set. For details, see: - /// - /// - /// - /// If there is no ordering that satisfies both `lhs` and `rhs`, returns - /// `None`. As an example, the meet of orderings `[a ASC]` and `[a ASC, b ASC]` - /// is `[a ASC]`. - pub fn get_meet_ordering( - &self, - lhs: LexOrderingRef, - rhs: LexOrderingRef, - ) -> Option { - let lhs = self.normalize_sort_exprs(lhs); - let rhs = self.normalize_sort_exprs(rhs); - let mut meet = vec![]; - for (lhs, rhs) in lhs.into_iter().zip(rhs.into_iter()) { - if lhs.eq(&rhs) { - meet.push(lhs); - } else { - break; - } - } - (!meet.is_empty()).then_some(meet) - } - - /// Projects argument `expr` according to `projection_mapping`, taking - /// equivalences into account. - /// - /// For example, assume that columns `a` and `c` are always equal, and that - /// `projection_mapping` encodes following mapping: - /// - /// ```text - /// a -> a1 - /// b -> b1 - /// ``` - /// - /// Then, this function projects `a + b` to `Some(a1 + b1)`, `c + b` to - /// `Some(a1 + b1)` and `d` to `None`, meaning that it cannot be projected. 
- pub fn project_expr( - &self, - expr: &Arc, - projection_mapping: &ProjectionMapping, - ) -> Option> { - self.eq_group.project_expr(projection_mapping, expr) - } - - /// Projects the equivalences within according to `projection_mapping` - /// and `output_schema`. - pub fn project( - &self, - projection_mapping: &ProjectionMapping, - output_schema: SchemaRef, - ) -> Self { - let mut projected_orderings = self - .oeq_class - .iter() - .filter_map(|order| self.eq_group.project_ordering(projection_mapping, order)) - .collect::>(); - for (source, target) in projection_mapping.iter() { - let expr_ordering = ExprOrdering::new(source.clone()) - .transform_up(&|expr| update_ordering(expr, self)) - .unwrap(); - if let SortProperties::Ordered(options) = expr_ordering.state { - // Push new ordering to the state. - projected_orderings.push(vec![PhysicalSortExpr { - expr: target.clone(), - options, - }]); - } - } - Self { - eq_group: self.eq_group.project(projection_mapping), - oeq_class: OrderingEquivalenceClass::new(projected_orderings), - constants: vec![], - schema: output_schema, - } - } - - /// Returns the longest (potentially partial) permutation satisfying the - /// existing ordering. For example, if we have the equivalent orderings - /// `[a ASC, b ASC]` and `[c DESC]`, with `exprs` containing `[c, b, a, d]`, - /// then this function returns `([a ASC, b ASC, c DESC], [2, 1, 0])`. - /// This means that the specification `[a ASC, b ASC, c DESC]` is satisfied - /// by the existing ordering, and `[a, b, c]` resides at indices: `2, 1, 0` - /// inside the argument `exprs` (respectively). For the mathematical - /// definition of "partial permutation", see: - /// - /// - pub fn find_longest_permutation( - &self, - exprs: &[Arc], - ) -> (LexOrdering, Vec) { - let normalized_exprs = self.eq_group.normalize_exprs(exprs.to_vec()); - // Use a map to associate expression indices with sort options: - let mut ordered_exprs = IndexMap::::new(); - for ordering in self.normalized_oeq_class().iter() { - for sort_expr in ordering { - if let Some(idx) = normalized_exprs - .iter() - .position(|expr| sort_expr.expr.eq(expr)) - { - if let Entry::Vacant(e) = ordered_exprs.entry(idx) { - e.insert(sort_expr.options); - } - } else { - // We only consider expressions that correspond to a prefix - // of one of the equivalent orderings we have. - break; - } - } - } - // Construct the lexicographical ordering according to the permutation: - ordered_exprs - .into_iter() - .map(|(idx, options)| { - ( - PhysicalSortExpr { - expr: exprs[idx].clone(), - options, - }, - idx, - ) - }) - .unzip() - } -} - -/// Calculate ordering equivalence properties for the given join operation. -pub fn join_equivalence_properties( - left: EquivalenceProperties, - right: EquivalenceProperties, - join_type: &JoinType, - join_schema: SchemaRef, - maintains_input_order: &[bool], - probe_side: Option, - on: &[(Column, Column)], -) -> EquivalenceProperties { - let left_size = left.schema.fields.len(); - let mut result = EquivalenceProperties::new(join_schema); - result.add_equivalence_group(left.eq_group().join( - right.eq_group(), - join_type, - left_size, - on, - )); - - let left_oeq_class = left.oeq_class; - let mut right_oeq_class = right.oeq_class; - match maintains_input_order { - [true, false] => { - // In this special case, right side ordering can be prefixed with - // the left side ordering. 
- if let (Some(JoinSide::Left), JoinType::Inner) = (probe_side, join_type) { - updated_right_ordering_equivalence_class( - &mut right_oeq_class, - join_type, - left_size, - ); - - // Right side ordering equivalence properties should be prepended - // with those of the left side while constructing output ordering - // equivalence properties since stream side is the left side. - // - // For example, if the right side ordering equivalences contain - // `b ASC`, and the left side ordering equivalences contain `a ASC`, - // then we should add `a ASC, b ASC` to the ordering equivalences - // of the join output. - let out_oeq_class = left_oeq_class.join_suffix(&right_oeq_class); - result.add_ordering_equivalence_class(out_oeq_class); - } else { - result.add_ordering_equivalence_class(left_oeq_class); - } - } - [false, true] => { - updated_right_ordering_equivalence_class( - &mut right_oeq_class, - join_type, - left_size, - ); - // In this special case, left side ordering can be prefixed with - // the right side ordering. - if let (Some(JoinSide::Right), JoinType::Inner) = (probe_side, join_type) { - // Left side ordering equivalence properties should be prepended - // with those of the right side while constructing output ordering - // equivalence properties since stream side is the right side. - // - // For example, if the left side ordering equivalences contain - // `a ASC`, and the right side ordering equivalences contain `b ASC`, - // then we should add `b ASC, a ASC` to the ordering equivalences - // of the join output. - let out_oeq_class = right_oeq_class.join_suffix(&left_oeq_class); - result.add_ordering_equivalence_class(out_oeq_class); - } else { - result.add_ordering_equivalence_class(right_oeq_class); - } - } - [false, false] => {} - [true, true] => unreachable!("Cannot maintain ordering of both sides"), - _ => unreachable!("Join operators can not have more than two children"), - } - result -} - -/// In the context of a join, update the right side `OrderingEquivalenceClass` -/// so that they point to valid indices in the join output schema. -/// -/// To do so, we increment column indices by the size of the left table when -/// join schema consists of a combination of the left and right schemas. This -/// is the case for `Inner`, `Left`, `Full` and `Right` joins. For other cases, -/// indices do not change. -fn updated_right_ordering_equivalence_class( - right_oeq_class: &mut OrderingEquivalenceClass, - join_type: &JoinType, - left_size: usize, -) { - if matches!( - join_type, - JoinType::Inner | JoinType::Left | JoinType::Full | JoinType::Right - ) { - right_oeq_class.add_offset(left_size); - } -} - -/// Calculates the [`SortProperties`] of a given [`ExprOrdering`] node. -/// The node can either be a leaf node, or an intermediate node: -/// - If it is a leaf node, we directly find the order of the node by looking -/// at the given sort expression and equivalence properties if it is a `Column` -/// leaf, or we mark it as unordered. In the case of a `Literal` leaf, we mark -/// it as singleton so that it can cooperate with all ordered columns. -/// - If it is an intermediate node, the children states matter. Each `PhysicalExpr` -/// and operator has its own rules on how to propagate the children orderings. -/// However, before we engage in recursion, we check whether this intermediate -/// node directly matches with the sort expression. If there is a match, the -/// sort expression emerges at that node immediately, discarding the recursive -/// result coming from its children. 
-fn update_ordering( - mut node: ExprOrdering, - eq_properties: &EquivalenceProperties, -) -> Result> { - if !node.expr.children().is_empty() { - // We have an intermediate (non-leaf) node, account for its children: - node.state = node.expr.get_ordering(&node.children_states); - Ok(Transformed::Yes(node)) - } else if node.expr.as_any().is::() { - // We have a Column, which is one of the two possible leaf node types: - let eq_group = &eq_properties.eq_group; - let normalized_expr = eq_group.normalize_expr(node.expr.clone()); - let oeq_class = &eq_properties.oeq_class; - if let Some(options) = oeq_class.get_options(&normalized_expr) { - node.state = SortProperties::Ordered(options); - Ok(Transformed::Yes(node)) - } else { - Ok(Transformed::No(node)) - } - } else { - // We have a Literal, which is the other possible leaf node type: - node.state = node.expr.get_ordering(&[]); - Ok(Transformed::Yes(node)) - } -} - -#[cfg(test)] -mod tests { - use std::ops::Not; - use std::sync::Arc; - - use super::*; - use crate::expressions::{col, lit, BinaryExpr, Column}; - use crate::physical_expr::{physical_exprs_bag_equal, physical_exprs_equal}; - - use arrow::compute::{lexsort_to_indices, SortColumn}; - use arrow::datatypes::{DataType, Field, Schema}; - use arrow_array::{ArrayRef, RecordBatch, UInt32Array, UInt64Array}; - use arrow_schema::{Fields, SortOptions}; - use datafusion_common::Result; - use datafusion_expr::Operator; - - use itertools::{izip, Itertools}; - use rand::rngs::StdRng; - use rand::seq::SliceRandom; - use rand::{Rng, SeedableRng}; - - // Generate a schema which consists of 8 columns (a, b, c, d, e, f, g, h) - fn create_test_schema() -> Result { - let a = Field::new("a", DataType::Int32, true); - let b = Field::new("b", DataType::Int32, true); - let c = Field::new("c", DataType::Int32, true); - let d = Field::new("d", DataType::Int32, true); - let e = Field::new("e", DataType::Int32, true); - let f = Field::new("f", DataType::Int32, true); - let g = Field::new("g", DataType::Int32, true); - let h = Field::new("h", DataType::Int32, true); - let schema = Arc::new(Schema::new(vec![a, b, c, d, e, f, g, h])); - - Ok(schema) - } - - /// Construct a schema with following properties - /// Schema satisfies following orderings: - /// [a ASC], [d ASC, b ASC], [e DESC, f ASC, g ASC] - /// and - /// Column [a=c] (e.g they are aliases). 
- fn create_test_params() -> Result<(SchemaRef, EquivalenceProperties)> { - let test_schema = create_test_schema()?; - let col_a = &col("a", &test_schema)?; - let col_b = &col("b", &test_schema)?; - let col_c = &col("c", &test_schema)?; - let col_d = &col("d", &test_schema)?; - let col_e = &col("e", &test_schema)?; - let col_f = &col("f", &test_schema)?; - let col_g = &col("g", &test_schema)?; - let mut eq_properties = EquivalenceProperties::new(test_schema.clone()); - eq_properties.add_equal_conditions(col_a, col_c); - - let option_asc = SortOptions { - descending: false, - nulls_first: false, - }; - let option_desc = SortOptions { - descending: true, - nulls_first: true, - }; - let orderings = vec![ - // [a ASC] - vec![(col_a, option_asc)], - // [d ASC, b ASC] - vec![(col_d, option_asc), (col_b, option_asc)], - // [e DESC, f ASC, g ASC] - vec![ - (col_e, option_desc), - (col_f, option_asc), - (col_g, option_asc), - ], - ]; - let orderings = convert_to_orderings(&orderings); - eq_properties.add_new_orderings(orderings); - Ok((test_schema, eq_properties)) - } - - // Generate a schema which consists of 6 columns (a, b, c, d, e, f) - fn create_test_schema_2() -> Result { - let a = Field::new("a", DataType::Int32, true); - let b = Field::new("b", DataType::Int32, true); - let c = Field::new("c", DataType::Int32, true); - let d = Field::new("d", DataType::Int32, true); - let e = Field::new("e", DataType::Int32, true); - let f = Field::new("f", DataType::Int32, true); - let schema = Arc::new(Schema::new(vec![a, b, c, d, e, f])); - - Ok(schema) - } - - /// Construct a schema with random ordering - /// among column a, b, c, d - /// where - /// Column [a=f] (e.g they are aliases). - /// Column e is constant. - fn create_random_schema(seed: u64) -> Result<(SchemaRef, EquivalenceProperties)> { - let test_schema = create_test_schema_2()?; - let col_a = &col("a", &test_schema)?; - let col_b = &col("b", &test_schema)?; - let col_c = &col("c", &test_schema)?; - let col_d = &col("d", &test_schema)?; - let col_e = &col("e", &test_schema)?; - let col_f = &col("f", &test_schema)?; - let col_exprs = [col_a, col_b, col_c, col_d, col_e, col_f]; - - let mut eq_properties = EquivalenceProperties::new(test_schema.clone()); - // Define a and f are aliases - eq_properties.add_equal_conditions(col_a, col_f); - // Column e has constant value. 
- eq_properties = eq_properties.add_constants([col_e.clone()]); - - // Randomly order columns for sorting - let mut rng = StdRng::seed_from_u64(seed); - let mut remaining_exprs = col_exprs[0..4].to_vec(); // only a, b, c, d are sorted - - let options_asc = SortOptions { - descending: false, - nulls_first: false, - }; - - while !remaining_exprs.is_empty() { - let n_sort_expr = rng.gen_range(0..remaining_exprs.len() + 1); - remaining_exprs.shuffle(&mut rng); - - let ordering = remaining_exprs - .drain(0..n_sort_expr) - .map(|expr| PhysicalSortExpr { - expr: expr.clone(), - options: options_asc, - }) - .collect(); - - eq_properties.add_new_orderings([ordering]); - } - - Ok((test_schema, eq_properties)) - } - - // Convert each tuple to PhysicalSortRequirement - fn convert_to_sort_reqs( - in_data: &[(&Arc, Option)], - ) -> Vec { - in_data - .iter() - .map(|(expr, options)| { - PhysicalSortRequirement::new((*expr).clone(), *options) - }) - .collect::>() - } - - // Convert each tuple to PhysicalSortExpr - fn convert_to_sort_exprs( - in_data: &[(&Arc, SortOptions)], - ) -> Vec { - in_data - .iter() - .map(|(expr, options)| PhysicalSortExpr { - expr: (*expr).clone(), - options: *options, - }) - .collect::>() - } - - // Convert each inner tuple to PhysicalSortExpr - fn convert_to_orderings( - orderings: &[Vec<(&Arc, SortOptions)>], - ) -> Vec> { - orderings - .iter() - .map(|sort_exprs| convert_to_sort_exprs(sort_exprs)) - .collect() - } - - #[test] - fn add_equal_conditions_test() -> Result<()> { - let schema = Arc::new(Schema::new(vec![ - Field::new("a", DataType::Int64, true), - Field::new("b", DataType::Int64, true), - Field::new("c", DataType::Int64, true), - Field::new("x", DataType::Int64, true), - Field::new("y", DataType::Int64, true), - ])); - - let mut eq_properties = EquivalenceProperties::new(schema); - let col_a_expr = Arc::new(Column::new("a", 0)) as Arc; - let col_b_expr = Arc::new(Column::new("b", 1)) as Arc; - let col_c_expr = Arc::new(Column::new("c", 2)) as Arc; - let col_x_expr = Arc::new(Column::new("x", 3)) as Arc; - let col_y_expr = Arc::new(Column::new("y", 4)) as Arc; - - // a and b are aliases - eq_properties.add_equal_conditions(&col_a_expr, &col_b_expr); - assert_eq!(eq_properties.eq_group().len(), 1); - - // This new entry is redundant, size shouldn't increase - eq_properties.add_equal_conditions(&col_b_expr, &col_a_expr); - assert_eq!(eq_properties.eq_group().len(), 1); - let eq_groups = &eq_properties.eq_group().classes[0]; - assert_eq!(eq_groups.len(), 2); - assert!(physical_exprs_contains(eq_groups, &col_a_expr)); - assert!(physical_exprs_contains(eq_groups, &col_b_expr)); - - // b and c are aliases. Exising equivalence class should expand, - // however there shouldn't be any new equivalence class - eq_properties.add_equal_conditions(&col_b_expr, &col_c_expr); - assert_eq!(eq_properties.eq_group().len(), 1); - let eq_groups = &eq_properties.eq_group().classes[0]; - assert_eq!(eq_groups.len(), 3); - assert!(physical_exprs_contains(eq_groups, &col_a_expr)); - assert!(physical_exprs_contains(eq_groups, &col_b_expr)); - assert!(physical_exprs_contains(eq_groups, &col_c_expr)); - - // This is a new set of equality. Hence equivalent class count should be 2. - eq_properties.add_equal_conditions(&col_x_expr, &col_y_expr); - assert_eq!(eq_properties.eq_group().len(), 2); - - // This equality bridges distinct equality sets. - // Hence equivalent class count should decrease from 2 to 1. 
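The assertions that follow walk through the bridging behaviour of `add_equal_conditions`: adding `x = a` merges the `{a, b, c}` and `{x, y}` classes into a single class of five members. Here is a standalone sketch of that merge logic over plain string columns; the names are illustrative and this is not the real `EquivalenceGroup` API.

```rust
// Standalone sketch; plain string columns instead of physical expressions.
/// Returns the index of the class containing `col`, if any.
fn position_of(classes: &[Vec<&'static str>], col: &'static str) -> Option<usize> {
    classes.iter().position(|class| class.contains(&col))
}

/// Adds the equality `left = right`, merging ("bridging") two classes whenever
/// the new equality connects them.
fn add_equal_condition(
    classes: &mut Vec<Vec<&'static str>>,
    left: &'static str,
    right: &'static str,
) {
    match (position_of(classes, left), position_of(classes, right)) {
        (Some(mut i), Some(mut j)) => {
            if i != j {
                // Keep `i < j` so `i` stays valid after removing index `j`.
                if i > j {
                    std::mem::swap(&mut i, &mut j);
                }
                // Bridge: merge class `j` into class `i`, skipping duplicates.
                let other = classes.swap_remove(j);
                for col in other {
                    if !classes[i].contains(&col) {
                        classes[i].push(col);
                    }
                }
            }
        }
        (Some(i), None) => classes[i].push(right),
        (None, Some(j)) => classes[j].push(left),
        (None, None) => classes.push(vec![left, right]),
    }
}

fn main() {
    let mut classes: Vec<Vec<&str>> = vec![];
    add_equal_condition(&mut classes, "a", "b"); // {a, b}
    add_equal_condition(&mut classes, "b", "c"); // {a, b, c}
    add_equal_condition(&mut classes, "x", "y"); // {a, b, c}, {x, y}
    assert_eq!(classes.len(), 2);
    add_equal_condition(&mut classes, "x", "a"); // bridges into one class of five
    assert_eq!(classes.len(), 1);
    assert_eq!(classes[0].len(), 5);
    println!("{classes:?}");
}
```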
- eq_properties.add_equal_conditions(&col_x_expr, &col_a_expr); - assert_eq!(eq_properties.eq_group().len(), 1); - let eq_groups = &eq_properties.eq_group().classes[0]; - assert_eq!(eq_groups.len(), 5); - assert!(physical_exprs_contains(eq_groups, &col_a_expr)); - assert!(physical_exprs_contains(eq_groups, &col_b_expr)); - assert!(physical_exprs_contains(eq_groups, &col_c_expr)); - assert!(physical_exprs_contains(eq_groups, &col_x_expr)); - assert!(physical_exprs_contains(eq_groups, &col_y_expr)); - - Ok(()) - } - - #[test] - fn project_equivalence_properties_test() -> Result<()> { - let input_schema = Arc::new(Schema::new(vec![ - Field::new("a", DataType::Int64, true), - Field::new("b", DataType::Int64, true), - Field::new("c", DataType::Int64, true), - ])); - - let input_properties = EquivalenceProperties::new(input_schema.clone()); - let col_a = col("a", &input_schema)?; - - let out_schema = Arc::new(Schema::new(vec![ - Field::new("a1", DataType::Int64, true), - Field::new("a2", DataType::Int64, true), - Field::new("a3", DataType::Int64, true), - Field::new("a4", DataType::Int64, true), - ])); - - // a as a1, a as a2, a as a3, a as a3 - let col_a1 = &col("a1", &out_schema)?; - let col_a2 = &col("a2", &out_schema)?; - let col_a3 = &col("a3", &out_schema)?; - let col_a4 = &col("a4", &out_schema)?; - let projection_mapping = ProjectionMapping { - inner: vec![ - (col_a.clone(), col_a1.clone()), - (col_a.clone(), col_a2.clone()), - (col_a.clone(), col_a3.clone()), - (col_a.clone(), col_a4.clone()), - ], - }; - let out_properties = input_properties.project(&projection_mapping, out_schema); - - // At the output a1=a2=a3=a4 - assert_eq!(out_properties.eq_group().len(), 1); - let eq_class = &out_properties.eq_group().classes[0]; - assert_eq!(eq_class.len(), 4); - assert!(physical_exprs_contains(eq_class, col_a1)); - assert!(physical_exprs_contains(eq_class, col_a2)); - assert!(physical_exprs_contains(eq_class, col_a3)); - assert!(physical_exprs_contains(eq_class, col_a4)); - - Ok(()) - } - - #[test] - fn test_ordering_satisfy() -> Result<()> { - let crude = vec![PhysicalSortExpr { - expr: Arc::new(Column::new("a", 0)), - options: SortOptions::default(), - }]; - let finer = vec![ - PhysicalSortExpr { - expr: Arc::new(Column::new("a", 0)), - options: SortOptions::default(), - }, - PhysicalSortExpr { - expr: Arc::new(Column::new("b", 1)), - options: SortOptions::default(), - }, - ]; - // finer ordering satisfies, crude ordering should return true - let empty_schema = &Arc::new(Schema::empty()); - let mut eq_properties_finer = EquivalenceProperties::new(empty_schema.clone()); - eq_properties_finer.oeq_class.push(finer.clone()); - assert!(eq_properties_finer.ordering_satisfy(&crude)); - - // Crude ordering doesn't satisfy finer ordering. should return false - let mut eq_properties_crude = EquivalenceProperties::new(empty_schema.clone()); - eq_properties_crude.oeq_class.push(crude.clone()); - assert!(!eq_properties_crude.ordering_satisfy(&finer)); - Ok(()) - } - - #[test] - fn test_ordering_satisfy_with_equivalence() -> Result<()> { - // Schema satisfies following orderings: - // [a ASC], [d ASC, b ASC], [e DESC, f ASC, g ASC] - // and - // Column [a=c] (e.g they are aliases). 
- let (test_schema, eq_properties) = create_test_params()?; - let col_a = &col("a", &test_schema)?; - let col_b = &col("b", &test_schema)?; - let col_c = &col("c", &test_schema)?; - let col_d = &col("d", &test_schema)?; - let col_e = &col("e", &test_schema)?; - let col_f = &col("f", &test_schema)?; - let col_g = &col("g", &test_schema)?; - let option_asc = SortOptions { - descending: false, - nulls_first: false, - }; - let option_desc = SortOptions { - descending: true, - nulls_first: true, - }; - let table_data_with_properties = - generate_table_for_eq_properties(&eq_properties, 625, 5)?; - - // First element in the tuple stores vector of requirement, second element is the expected return value for ordering_satisfy function - let requirements = vec![ - // `a ASC NULLS LAST`, expects `ordering_satisfy` to be `true`, since existing ordering `a ASC NULLS LAST, b ASC NULLS LAST` satisfies it - (vec![(col_a, option_asc)], true), - (vec![(col_a, option_desc)], false), - // Test whether equivalence works as expected - (vec![(col_c, option_asc)], true), - (vec![(col_c, option_desc)], false), - // Test whether ordering equivalence works as expected - (vec![(col_d, option_asc)], true), - (vec![(col_d, option_asc), (col_b, option_asc)], true), - (vec![(col_d, option_desc), (col_b, option_asc)], false), - ( - vec![ - (col_e, option_desc), - (col_f, option_asc), - (col_g, option_asc), - ], - true, - ), - (vec![(col_e, option_desc), (col_f, option_asc)], true), - (vec![(col_e, option_asc), (col_f, option_asc)], false), - (vec![(col_e, option_desc), (col_b, option_asc)], false), - (vec![(col_e, option_asc), (col_b, option_asc)], false), - ( - vec![ - (col_d, option_asc), - (col_b, option_asc), - (col_d, option_asc), - (col_b, option_asc), - ], - true, - ), - ( - vec![ - (col_d, option_asc), - (col_b, option_asc), - (col_e, option_desc), - (col_f, option_asc), - ], - true, - ), - ( - vec![ - (col_d, option_asc), - (col_b, option_asc), - (col_e, option_desc), - (col_b, option_asc), - ], - true, - ), - ( - vec![ - (col_d, option_asc), - (col_b, option_asc), - (col_d, option_desc), - (col_b, option_asc), - ], - true, - ), - ( - vec![ - (col_d, option_asc), - (col_b, option_asc), - (col_e, option_asc), - (col_f, option_asc), - ], - false, - ), - ( - vec![ - (col_d, option_asc), - (col_b, option_asc), - (col_e, option_asc), - (col_b, option_asc), - ], - false, - ), - (vec![(col_d, option_asc), (col_e, option_desc)], true), - ( - vec![ - (col_d, option_asc), - (col_c, option_asc), - (col_b, option_asc), - ], - true, - ), - ( - vec![ - (col_d, option_asc), - (col_e, option_desc), - (col_f, option_asc), - (col_b, option_asc), - ], - true, - ), - ( - vec![ - (col_d, option_asc), - (col_e, option_desc), - (col_c, option_asc), - (col_b, option_asc), - ], - true, - ), - ( - vec![ - (col_d, option_asc), - (col_e, option_desc), - (col_b, option_asc), - (col_f, option_asc), - ], - true, - ), - ]; - - for (cols, expected) in requirements { - let err_msg = format!("Error in test case:{cols:?}"); - let required = cols - .into_iter() - .map(|(expr, options)| PhysicalSortExpr { - expr: expr.clone(), - options, - }) - .collect::>(); - - // Check expected result with experimental result. 
- assert_eq!( - is_table_same_after_sort( - required.clone(), - table_data_with_properties.clone() - )?, - expected - ); - assert_eq!( - eq_properties.ordering_satisfy(&required), - expected, - "{err_msg}" - ); - } - Ok(()) - } - - #[test] - fn test_ordering_satisfy_with_equivalence_random() -> Result<()> { - const N_RANDOM_SCHEMA: usize = 5; - const N_ELEMENTS: usize = 125; - const N_DISTINCT: usize = 5; - const SORT_OPTIONS: SortOptions = SortOptions { - descending: false, - nulls_first: false, - }; - - for seed in 0..N_RANDOM_SCHEMA { - // Create a random schema with random properties - let (test_schema, eq_properties) = create_random_schema(seed as u64)?; - // Generate a data that satisfies properties given - let table_data_with_properties = - generate_table_for_eq_properties(&eq_properties, N_ELEMENTS, N_DISTINCT)?; - let col_exprs = vec![ - col("a", &test_schema)?, - col("b", &test_schema)?, - col("c", &test_schema)?, - col("d", &test_schema)?, - col("e", &test_schema)?, - col("f", &test_schema)?, - ]; - - for n_req in 0..=col_exprs.len() { - for exprs in col_exprs.iter().combinations(n_req) { - let requirement = exprs - .into_iter() - .map(|expr| PhysicalSortExpr { - expr: expr.clone(), - options: SORT_OPTIONS, - }) - .collect::>(); - let expected = is_table_same_after_sort( - requirement.clone(), - table_data_with_properties.clone(), - )?; - let err_msg = format!( - "Error in test case requirement:{:?}, expected: {:?}", - requirement, expected - ); - // Check whether ordering_satisfy API result and - // experimental result matches. - assert_eq!( - eq_properties.ordering_satisfy(&requirement), - expected, - "{}", - err_msg - ); - } - } - } - - Ok(()) - } - - #[test] - fn test_ordering_satisfy_different_lengths() -> Result<()> { - let test_schema = create_test_schema()?; - let col_a = &col("a", &test_schema)?; - let col_b = &col("b", &test_schema)?; - let col_c = &col("c", &test_schema)?; - let col_d = &col("d", &test_schema)?; - let col_e = &col("e", &test_schema)?; - let col_f = &col("f", &test_schema)?; - let options = SortOptions { - descending: false, - nulls_first: false, - }; - // a=c (e.g they are aliases). - let mut eq_properties = EquivalenceProperties::new(test_schema); - eq_properties.add_equal_conditions(col_a, col_c); - - let orderings = vec![ - vec![(col_a, options)], - vec![(col_e, options)], - vec![(col_d, options), (col_f, options)], - ]; - let orderings = convert_to_orderings(&orderings); - - // Column [a ASC], [e ASC], [d ASC, f ASC] are all valid orderings for the schema. - eq_properties.add_new_orderings(orderings); - - // First entry in the tuple is required ordering, second entry is the expected flag - // that indicates whether this required ordering is satisfied. - // ([a ASC], true) indicate a ASC requirement is already satisfied by existing orderings. 
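The test below lists `(required ordering, expected bool)` pairs for `ordering_satisfy`. As a rough standalone illustration of only the simplest part of that check, the sketch here normalizes columns through their alias classes and then tests whether the requirement is a prefix of a known ordering. The real implementation is considerably richer (it can also stitch requirements together across different orderings and handle constants, as the surrounding test cases show), so treat this as the prefix-plus-alias core only; all names are mine.

```rust
// Standalone sketch; a deliberately simplified ordering_satisfy.
/// A sort key: column name plus a `descending` flag.
type SortKey = (&'static str, bool);

/// Maps a column to the first (canonical) member of its alias class.
fn canonical(col: &'static str, aliases: &[Vec<&'static str>]) -> &'static str {
    aliases
        .iter()
        .find(|class| class.contains(&col))
        .map(|class| class[0])
        .unwrap_or(col)
}

fn normalize(keys: &[SortKey], aliases: &[Vec<&'static str>]) -> Vec<SortKey> {
    keys.iter().map(|&(c, d)| (canonical(c, aliases), d)).collect()
}

/// Simplified check: the (normalized) requirement must be a prefix of one of
/// the (normalized) known orderings.
fn ordering_satisfy(
    known: &[Vec<SortKey>],
    required: &[SortKey],
    aliases: &[Vec<&'static str>],
) -> bool {
    let required = normalize(required, aliases);
    known
        .iter()
        .any(|ordering| normalize(ordering, aliases).starts_with(&required))
}

fn main() {
    let aliases = vec![vec!["a", "c"]]; // a and c are aliases
    let known = vec![
        vec![("a", false), ("b", false)], // [a ASC, b ASC]
        vec![("e", true), ("f", false)],  // [e DESC, f ASC]
    ];
    assert!(ordering_satisfy(&known, &[("a", false)], &aliases));
    // `c ASC` is satisfied because `c` normalizes to its alias `a`.
    assert!(ordering_satisfy(&known, &[("c", false)], &aliases));
    assert!(!ordering_satisfy(&known, &[("a", true)], &aliases));
    assert!(ordering_satisfy(&known, &[("e", true), ("f", false)], &aliases));
    println!("prefix-based satisfaction checks passed");
}
```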
- let test_cases = vec![ - // [c ASC, a ASC, e ASC], expected represents this requirement is satisfied - ( - vec![(col_c, options), (col_a, options), (col_e, options)], - true, - ), - (vec![(col_c, options), (col_b, options)], false), - (vec![(col_c, options), (col_d, options)], true), - ( - vec![(col_d, options), (col_f, options), (col_b, options)], - false, - ), - (vec![(col_d, options), (col_f, options)], true), - ]; - - for (reqs, expected) in test_cases { - let err_msg = - format!("error in test reqs: {:?}, expected: {:?}", reqs, expected,); - let reqs = convert_to_sort_exprs(&reqs); - assert_eq!( - eq_properties.ordering_satisfy(&reqs), - expected, - "{}", - err_msg - ); - } - - Ok(()) - } - - #[test] - fn test_bridge_groups() -> Result<()> { - // First entry in the tuple is argument, second entry is the bridged result - let test_cases = vec![ - // ------- TEST CASE 1 -----------// - ( - vec![vec![1, 2, 3], vec![2, 4, 5], vec![11, 12, 9], vec![7, 6, 5]], - // Expected is compared with set equality. Order of the specific results may change. - vec![vec![1, 2, 3, 4, 5, 6, 7], vec![9, 11, 12]], - ), - // ------- TEST CASE 2 -----------// - ( - vec![vec![1, 2, 3], vec![3, 4, 5], vec![9, 8, 7], vec![7, 6, 5]], - // Expected - vec![vec![1, 2, 3, 4, 5, 6, 7, 8, 9]], - ), - ]; - for (entries, expected) in test_cases { - let entries = entries - .into_iter() - .map(|entry| entry.into_iter().map(lit).collect::>()) - .collect::>(); - let expected = expected - .into_iter() - .map(|entry| entry.into_iter().map(lit).collect::>()) - .collect::>(); - let mut eq_groups = EquivalenceGroup::new(entries.clone()); - eq_groups.bridge_classes(); - let eq_groups = eq_groups.classes; - let err_msg = format!( - "error in test entries: {:?}, expected: {:?}, actual:{:?}", - entries, expected, eq_groups - ); - assert_eq!(eq_groups.len(), expected.len(), "{}", err_msg); - for idx in 0..eq_groups.len() { - assert!( - physical_exprs_bag_equal(&eq_groups[idx], &expected[idx]), - "{}", - err_msg - ); - } - } - Ok(()) - } - - #[test] - fn test_remove_redundant_entries_eq_group() -> Result<()> { - let entries = vec![ - vec![lit(1), lit(1), lit(2)], - // This group is meaningless should be removed - vec![lit(3), lit(3)], - vec![lit(4), lit(5), lit(6)], - ]; - // Given equivalences classes are not in succinct form. - // Expected form is the most plain representation that is functionally same. - let expected = vec![vec![lit(1), lit(2)], vec![lit(4), lit(5), lit(6)]]; - let mut eq_groups = EquivalenceGroup::new(entries); - eq_groups.remove_redundant_entries(); - - let eq_groups = eq_groups.classes; - assert_eq!(eq_groups.len(), expected.len()); - assert_eq!(eq_groups.len(), 2); - - assert!(physical_exprs_equal(&eq_groups[0], &expected[0])); - assert!(physical_exprs_equal(&eq_groups[1], &expected[1])); - Ok(()) - } - - #[test] - fn test_remove_redundant_entries_oeq_class() -> Result<()> { - let schema = create_test_schema()?; - let col_a = &col("a", &schema)?; - let col_b = &col("b", &schema)?; - let col_c = &col("c", &schema)?; - - let option_asc = SortOptions { - descending: false, - nulls_first: false, - }; - let option_desc = SortOptions { - descending: true, - nulls_first: true, - }; - - // First entry in the tuple is the given orderings for the table - // Second entry is the simplest version of the given orderings that is functionally equivalent. 
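The test cases that follow show how redundant orderings are pruned: an ordering that is a prefix of another known ordering (such as `[a ASC]` next to `[a ASC, b ASC, c ASC]`) carries no extra information. Below is a minimal sketch of that prefix-based pruning with my own names; the real `OrderingEquivalenceClass::remove_redundant_entries` does more than this.

```rust
// Standalone sketch; prefix-based pruning only.
/// One sort key: column name plus a `descending` flag.
type SortKey = (&'static str, bool);

/// Drops orderings that are redundant because they are a prefix of another
/// known ordering.
fn remove_redundant_orderings(orderings: Vec<Vec<SortKey>>) -> Vec<Vec<SortKey>> {
    let mut result: Vec<Vec<SortKey>> = Vec::new();
    for ordering in orderings {
        if result.iter().any(|kept| kept.starts_with(&ordering)) {
            // Already implied by a kept (longer or equal) ordering.
            continue;
        }
        // Remove previously kept orderings implied by the new one.
        result.retain(|kept| !ordering.starts_with(kept));
        result.push(ordering);
    }
    result
}

fn main() {
    let asc = false;
    let orderings = vec![
        vec![("a", asc), ("b", asc)],
        vec![("a", asc), ("b", asc), ("c", asc)],
        vec![("a", asc)],
    ];
    let simplified = remove_redundant_orderings(orderings);
    // Only the longest chain `[a ASC, b ASC, c ASC]` survives.
    assert_eq!(simplified, vec![vec![("a", asc), ("b", asc), ("c", asc)]]);
    println!("{simplified:?}");
}
```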
- let test_cases = vec![ - // ------- TEST CASE 1 --------- - ( - // ORDERINGS GIVEN - vec![ - // [a ASC, b ASC] - vec![(col_a, option_asc), (col_b, option_asc)], - ], - // EXPECTED orderings that is succinct. - vec![ - // [a ASC, b ASC] - vec![(col_a, option_asc), (col_b, option_asc)], - ], - ), - // ------- TEST CASE 2 --------- - ( - // ORDERINGS GIVEN - vec![ - // [a ASC, b ASC] - vec![(col_a, option_asc), (col_b, option_asc)], - // [a ASC, b ASC, c ASC] - vec![ - (col_a, option_asc), - (col_b, option_asc), - (col_c, option_asc), - ], - ], - // EXPECTED orderings that is succinct. - vec![ - // [a ASC, b ASC, c ASC] - vec![ - (col_a, option_asc), - (col_b, option_asc), - (col_c, option_asc), - ], - ], - ), - // ------- TEST CASE 3 --------- - ( - // ORDERINGS GIVEN - vec![ - // [a ASC, b DESC] - vec![(col_a, option_asc), (col_b, option_desc)], - // [a ASC] - vec![(col_a, option_asc)], - // [a ASC, c ASC] - vec![(col_a, option_asc), (col_c, option_asc)], - ], - // EXPECTED orderings that is succinct. - vec![ - // [a ASC, b DESC] - vec![(col_a, option_asc), (col_b, option_desc)], - // [a ASC, c ASC] - vec![(col_a, option_asc), (col_c, option_asc)], - ], - ), - // ------- TEST CASE 4 --------- - ( - // ORDERINGS GIVEN - vec![ - // [a ASC, b ASC] - vec![(col_a, option_asc), (col_b, option_asc)], - // [a ASC, b ASC, c ASC] - vec![ - (col_a, option_asc), - (col_b, option_asc), - (col_c, option_asc), - ], - // [a ASC] - vec![(col_a, option_asc)], - ], - // EXPECTED orderings that is succinct. - vec![ - // [a ASC, b ASC, c ASC] - vec![ - (col_a, option_asc), - (col_b, option_asc), - (col_c, option_asc), - ], - ], - ), - ]; - for (orderings, expected) in test_cases { - let orderings = convert_to_orderings(&orderings); - let expected = convert_to_orderings(&expected); - let actual = OrderingEquivalenceClass::new(orderings.clone()); - let actual = actual.orderings; - let err_msg = format!( - "orderings: {:?}, expected: {:?}, actual :{:?}", - orderings, expected, actual - ); - assert_eq!(actual.len(), expected.len(), "{}", err_msg); - for elem in actual { - assert!(expected.contains(&elem), "{}", err_msg); - } - } - - Ok(()) - } - - #[test] - fn test_get_updated_right_ordering_equivalence_properties() -> Result<()> { - let join_type = JoinType::Inner; - // Join right child schema - let child_fields: Fields = ["x", "y", "z", "w"] - .into_iter() - .map(|name| Field::new(name, DataType::Int32, true)) - .collect(); - let child_schema = Schema::new(child_fields); - let col_x = &col("x", &child_schema)?; - let col_y = &col("y", &child_schema)?; - let col_z = &col("z", &child_schema)?; - let col_w = &col("w", &child_schema)?; - let option_asc = SortOptions { - descending: false, - nulls_first: false, - }; - // [x ASC, y ASC], [z ASC, w ASC] - let orderings = vec![ - vec![(col_x, option_asc), (col_y, option_asc)], - vec![(col_z, option_asc), (col_w, option_asc)], - ]; - let orderings = convert_to_orderings(&orderings); - // Right child ordering equivalences - let mut right_oeq_class = OrderingEquivalenceClass::new(orderings); - - let left_columns_len = 4; - - let fields: Fields = ["a", "b", "c", "d", "x", "y", "z", "w"] - .into_iter() - .map(|name| Field::new(name, DataType::Int32, true)) - .collect(); - - // Join Schema - let schema = Schema::new(fields); - let col_a = &col("a", &schema)?; - let col_d = &col("d", &schema)?; - let col_x = &col("x", &schema)?; - let col_y = &col("y", &schema)?; - let col_z = &col("z", &schema)?; - let col_w = &col("w", &schema)?; - - let mut join_eq_properties = 
EquivalenceProperties::new(Arc::new(schema)); - // a=x and d=w - join_eq_properties.add_equal_conditions(col_a, col_x); - join_eq_properties.add_equal_conditions(col_d, col_w); - - updated_right_ordering_equivalence_class( - &mut right_oeq_class, - &join_type, - left_columns_len, - ); - join_eq_properties.add_ordering_equivalence_class(right_oeq_class); - let result = join_eq_properties.oeq_class().clone(); - - // [x ASC, y ASC], [z ASC, w ASC] - let orderings = vec![ - vec![(col_x, option_asc), (col_y, option_asc)], - vec![(col_z, option_asc), (col_w, option_asc)], - ]; - let orderings = convert_to_orderings(&orderings); - let expected = OrderingEquivalenceClass::new(orderings); - - assert_eq!(result, expected); - - Ok(()) - } - - /// Checks if the table (RecordBatch) remains unchanged when sorted according to the provided `required_ordering`. - /// - /// The function works by adding a unique column of ascending integers to the original table. This column ensures - /// that rows that are otherwise indistinguishable (e.g., if they have the same values in all other columns) can - /// still be differentiated. When sorting the extended table, the unique column acts as a tie-breaker to produce - /// deterministic sorting results. - /// - /// If the table remains the same after sorting with the added unique column, it indicates that the table was - /// already sorted according to `required_ordering` to begin with. - fn is_table_same_after_sort( - mut required_ordering: Vec, - batch: RecordBatch, - ) -> Result { - // Clone the original schema and columns - let original_schema = batch.schema(); - let mut columns = batch.columns().to_vec(); - - // Create a new unique column - let n_row = batch.num_rows() as u64; - let unique_col = Arc::new(UInt64Array::from_iter_values(0..n_row)) as ArrayRef; - columns.push(unique_col.clone()); - - // Create a new schema with the added unique column - let unique_col_name = "unique"; - let unique_field = Arc::new(Field::new(unique_col_name, DataType::UInt64, false)); - let fields: Vec<_> = original_schema - .fields() - .iter() - .cloned() - .chain(std::iter::once(unique_field)) - .collect(); - let schema = Arc::new(Schema::new(fields)); - - // Create a new batch with the added column - let new_batch = RecordBatch::try_new(schema.clone(), columns)?; - - // Add the unique column to the required ordering to ensure deterministic results - required_ordering.push(PhysicalSortExpr { - expr: Arc::new(Column::new(unique_col_name, original_schema.fields().len())), - options: Default::default(), - }); - - // Convert the required ordering to a list of SortColumn - let sort_columns: Vec<_> = required_ordering - .iter() - .filter_map(|order_expr| { - let col = order_expr.expr.as_any().downcast_ref::()?; - let col_index = schema.column_with_name(col.name())?.0; - Some(SortColumn { - values: new_batch.column(col_index).clone(), - options: Some(order_expr.options), - }) - }) - .collect(); - - // Check if the indices after sorting match the initial ordering - let sorted_indices = lexsort_to_indices(&sort_columns, None)?; - let original_indices = UInt32Array::from_iter_values(0..n_row as u32); - - Ok(sorted_indices == original_indices) - } - - // If we already generated a random result for one of the - // expressions in the equivalence classes. For other expressions in the same - // equivalence class use same result. This util gets already calculated result, when available. 
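The `is_table_same_after_sort` helper above appends a unique, strictly increasing column as a tie-breaker so that rows which tie on the requested sort keys still sort deterministically, and then checks that sorting leaves every row in place. A compact standalone version of the same idea over plain integer rows (illustrative only, no Arrow):

```rust
// Standalone sketch; plain integer rows instead of a RecordBatch.
/// Checks whether `rows` are already sorted by the given key columns; a unique,
/// strictly increasing tie-breaker (the original row position) is appended so
/// that ties on the key columns do not make the check order-dependent.
fn is_sorted_with_tie_breaker(rows: &[Vec<i64>], key_columns: &[usize]) -> bool {
    // Extend each row's key with its original position as the tie-breaker.
    let keys: Vec<Vec<i64>> = rows
        .iter()
        .enumerate()
        .map(|(pos, row)| {
            let mut key: Vec<i64> = key_columns.iter().map(|&c| row[c]).collect();
            key.push(pos as i64);
            key
        })
        .collect();
    // If sorting the extended keys leaves every row in place, the table was
    // already sorted by the requested key columns.
    let mut order: Vec<usize> = (0..keys.len()).collect();
    order.sort_by(|&i, &j| keys[i].cmp(&keys[j]));
    order.iter().enumerate().all(|(expected, &actual)| expected == actual)
}

fn main() {
    // Two columns; rows 1 and 2 tie on column 0.
    let rows = vec![vec![1, 9], vec![2, 5], vec![2, 7], vec![3, 1]];
    assert!(is_sorted_with_tie_breaker(&rows, &[0]));
    assert!(!is_sorted_with_tie_breaker(&rows, &[1]));
    println!("tie-breaker check passed");
}
```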
- fn get_representative_arr( - eq_group: &[Arc], - existing_vec: &[Option], - schema: SchemaRef, - ) -> Option { - for expr in eq_group.iter() { - let col = expr.as_any().downcast_ref::().unwrap(); - let (idx, _field) = schema.column_with_name(col.name()).unwrap(); - if let Some(res) = &existing_vec[idx] { - return Some(res.clone()); - } - } - None - } - - // Generate a table that satisfies the given equivalence properties; i.e. - // equivalences, ordering equivalences, and constants. - fn generate_table_for_eq_properties( - eq_properties: &EquivalenceProperties, - n_elem: usize, - n_distinct: usize, - ) -> Result { - let mut rng = StdRng::seed_from_u64(23); - - let schema = eq_properties.schema(); - let mut schema_vec = vec![None; schema.fields.len()]; - - // Utility closure to generate random array - let mut generate_random_array = |num_elems: usize, max_val: usize| -> ArrayRef { - let values: Vec = (0..num_elems) - .map(|_| rng.gen_range(0..max_val) as u64) - .collect(); - Arc::new(UInt64Array::from_iter_values(values)) - }; - - // Fill constant columns - for constant in &eq_properties.constants { - let col = constant.as_any().downcast_ref::().unwrap(); - let (idx, _field) = schema.column_with_name(col.name()).unwrap(); - let arr = - Arc::new(UInt64Array::from_iter_values(vec![0; n_elem])) as ArrayRef; - schema_vec[idx] = Some(arr); - } - - // Fill columns based on ordering equivalences - for ordering in eq_properties.oeq_class.iter() { - let (sort_columns, indices): (Vec<_>, Vec<_>) = ordering - .iter() - .map(|PhysicalSortExpr { expr, options }| { - let col = expr.as_any().downcast_ref::().unwrap(); - let (idx, _field) = schema.column_with_name(col.name()).unwrap(); - let arr = generate_random_array(n_elem, n_distinct); - ( - SortColumn { - values: arr, - options: Some(*options), - }, - idx, - ) - }) - .unzip(); - - let sort_arrs = arrow::compute::lexsort(&sort_columns, None)?; - for (idx, arr) in izip!(indices, sort_arrs) { - schema_vec[idx] = Some(arr); - } - } - - // Fill columns based on equivalence groups - for eq_group in eq_properties.eq_group.iter() { - let representative_array = - get_representative_arr(eq_group, &schema_vec, schema.clone()) - .unwrap_or_else(|| generate_random_array(n_elem, n_distinct)); - - for expr in eq_group { - let col = expr.as_any().downcast_ref::().unwrap(); - let (idx, _field) = schema.column_with_name(col.name()).unwrap(); - schema_vec[idx] = Some(representative_array.clone()); - } - } - - let res: Vec<_> = schema_vec - .into_iter() - .zip(schema.fields.iter()) - .map(|(elem, field)| { - ( - field.name(), - // Generate random values for columns that do not occur in any of the groups (equivalence, ordering equivalence, constants) - elem.unwrap_or_else(|| generate_random_array(n_elem, n_distinct)), - ) - }) - .collect(); - - Ok(RecordBatch::try_from_iter(res)?) - } - - #[test] - fn test_schema_normalize_expr_with_equivalence() -> Result<()> { - let col_a = &Column::new("a", 0); - let col_b = &Column::new("b", 1); - let col_c = &Column::new("c", 2); - // Assume that column a and c are aliases. - let (_test_schema, eq_properties) = create_test_params()?; - - let col_a_expr = Arc::new(col_a.clone()) as Arc; - let col_b_expr = Arc::new(col_b.clone()) as Arc; - let col_c_expr = Arc::new(col_c.clone()) as Arc; - // Test cases for equivalence normalization, - // First entry in the tuple is argument, second entry is expected result after normalization. 
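Before the normalization tests below, here is a simplified, dependency-free sketch of the generation strategy used by `generate_table_for_eq_properties` above: constant columns get a fixed value, each ordering's columns are filled from lexicographically sorted rows, and aliased columns copy a representative array. The toy types and names are mine; no Arrow arrays are involved.

```rust
// Standalone sketch; toy table generation, not the Arrow-based helper above.
use std::collections::HashMap;

fn generate_table(
    columns: &[&'static str],
    orderings: &[Vec<&'static str>], // all-ascending, for simplicity
    aliases: &[Vec<&'static str>],
    constants: &[&'static str],
    n_rows: usize,
) -> HashMap<&'static str, Vec<u64>> {
    // Small deterministic value generator so the sketch needs no external crates.
    let value = |i: usize, salt: usize| ((i * 37 + salt * 101) % 7) as u64;
    let mut table: HashMap<&'static str, Vec<u64>> = HashMap::new();
    for &c in constants {
        table.insert(c, vec![0; n_rows]);
    }
    for (salt, ordering) in orderings.iter().enumerate() {
        // Generate one tuple per row, then sort rows lexicographically so the
        // ordering's columns jointly satisfy it.
        let mut rows: Vec<Vec<u64>> = (0..n_rows)
            .map(|i| (0..ordering.len()).map(|j| value(i, salt + j)).collect())
            .collect();
        rows.sort();
        for (j, &col) in ordering.iter().enumerate() {
            table.insert(col, rows.iter().map(|r| r[j]).collect());
        }
    }
    // Aliased columns copy the values of the first member that already has data.
    for class in aliases {
        if let Some(representative) = class.iter().find_map(|c| table.get(c).cloned()) {
            for &c in class {
                table.insert(c, representative.clone());
            }
        }
    }
    // Any remaining column is unconstrained and gets arbitrary values.
    for &c in columns {
        table
            .entry(c)
            .or_insert_with(|| (0..n_rows).map(|i| value(i, 13)).collect());
    }
    table
}

fn main() {
    let table = generate_table(
        &["a", "b", "e", "f", "g"],
        &[vec!["a", "b"]], // data sorted by [a ASC, b ASC]
        &[vec!["a", "f"]], // a and f are aliases
        &["e"],            // e is constant
        8,
    );
    assert_eq!(table["a"], table["f"]);
    assert!(table["e"].iter().all(|&v| v == 0));
    assert!(table["a"].windows(2).all(|w| w[0] <= w[1]));
    println!("a = {:?}", table["a"]);
}
```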
- let expressions = vec![ - // Normalized version of the column a and c should go to a - // (by convention all the expressions inside equivalence class are mapped to the first entry - // in this case a is the first entry in the equivalence class.) - (&col_a_expr, &col_a_expr), - (&col_c_expr, &col_a_expr), - // Cannot normalize column b - (&col_b_expr, &col_b_expr), - ]; - let eq_group = eq_properties.eq_group(); - for (expr, expected_eq) in expressions { - assert!( - expected_eq.eq(&eq_group.normalize_expr(expr.clone())), - "error in test: expr: {expr:?}" - ); - } - - Ok(()) - } - - #[test] - fn test_schema_normalize_sort_requirement_with_equivalence() -> Result<()> { - let option1 = SortOptions { - descending: false, - nulls_first: false, - }; - // Assume that column a and c are aliases. - let (test_schema, eq_properties) = create_test_params()?; - let col_a = &col("a", &test_schema)?; - let col_c = &col("c", &test_schema)?; - let col_d = &col("d", &test_schema)?; - - // Test cases for equivalence normalization - // First entry in the tuple is PhysicalSortRequirement, second entry in the tuple is - // expected PhysicalSortRequirement after normalization. - let test_cases = vec![ - (vec![(col_a, Some(option1))], vec![(col_a, Some(option1))]), - // In the normalized version column c should be replace with column a - (vec![(col_c, Some(option1))], vec![(col_a, Some(option1))]), - (vec![(col_c, None)], vec![(col_a, None)]), - (vec![(col_d, Some(option1))], vec![(col_d, Some(option1))]), - ]; - for (reqs, expected) in test_cases.into_iter() { - let reqs = convert_to_sort_reqs(&reqs); - let expected = convert_to_sort_reqs(&expected); - - let normalized = eq_properties.normalize_sort_requirements(&reqs); - assert!( - expected.eq(&normalized), - "error in test: reqs: {reqs:?}, expected: {expected:?}, normalized: {normalized:?}" - ); - } - - Ok(()) - } - - #[test] - fn test_normalize_sort_reqs() -> Result<()> { - // Schema satisfies following properties - // a=c - // and following orderings are valid - // [a ASC], [d ASC, b ASC], [e DESC, f ASC, g ASC] - let (test_schema, eq_properties) = create_test_params()?; - let col_a = &col("a", &test_schema)?; - let col_b = &col("b", &test_schema)?; - let col_c = &col("c", &test_schema)?; - let col_d = &col("d", &test_schema)?; - let col_e = &col("e", &test_schema)?; - let col_f = &col("f", &test_schema)?; - let option_asc = SortOptions { - descending: false, - nulls_first: false, - }; - let option_desc = SortOptions { - descending: true, - nulls_first: true, - }; - // First element in the tuple stores vector of requirement, second element is the expected return value for ordering_satisfy function - let requirements = vec![ - ( - vec![(col_a, Some(option_asc))], - vec![(col_a, Some(option_asc))], - ), - ( - vec![(col_a, Some(option_desc))], - vec![(col_a, Some(option_desc))], - ), - (vec![(col_a, None)], vec![(col_a, None)]), - // Test whether equivalence works as expected - ( - vec![(col_c, Some(option_asc))], - vec![(col_a, Some(option_asc))], - ), - (vec![(col_c, None)], vec![(col_a, None)]), - // Test whether ordering equivalence works as expected - ( - vec![(col_d, Some(option_asc)), (col_b, Some(option_asc))], - vec![(col_d, Some(option_asc)), (col_b, Some(option_asc))], - ), - ( - vec![(col_d, None), (col_b, None)], - vec![(col_d, None), (col_b, None)], - ), - ( - vec![(col_e, Some(option_desc)), (col_f, Some(option_asc))], - vec![(col_e, Some(option_desc)), (col_f, Some(option_asc))], - ), - // We should be able to normalize in compatible 
requirements also (not exactly equal) - ( - vec![(col_e, Some(option_desc)), (col_f, None)], - vec![(col_e, Some(option_desc)), (col_f, None)], - ), - ( - vec![(col_e, None), (col_f, None)], - vec![(col_e, None), (col_f, None)], - ), - ]; - - for (reqs, expected_normalized) in requirements.into_iter() { - let req = convert_to_sort_reqs(&reqs); - let expected_normalized = convert_to_sort_reqs(&expected_normalized); - - assert_eq!( - eq_properties.normalize_sort_requirements(&req), - expected_normalized - ); - } - - Ok(()) - } - - #[test] - fn test_get_finer() -> Result<()> { - let schema = create_test_schema()?; - let col_a = &col("a", &schema)?; - let col_b = &col("b", &schema)?; - let col_c = &col("c", &schema)?; - let eq_properties = EquivalenceProperties::new(schema); - let option_asc = SortOptions { - descending: false, - nulls_first: false, - }; - let option_desc = SortOptions { - descending: true, - nulls_first: true, - }; - // First entry, and second entry are the physical sort requirement that are argument for get_finer_requirement. - // Third entry is the expected result. - let tests_cases = vec![ - // Get finer requirement between [a Some(ASC)] and [a None, b Some(ASC)] - // result should be [a Some(ASC), b Some(ASC)] - ( - vec![(col_a, Some(option_asc))], - vec![(col_a, None), (col_b, Some(option_asc))], - Some(vec![(col_a, Some(option_asc)), (col_b, Some(option_asc))]), - ), - // Get finer requirement between [a Some(ASC), b Some(ASC), c Some(ASC)] and [a Some(ASC), b Some(ASC)] - // result should be [a Some(ASC), b Some(ASC), c Some(ASC)] - ( - vec![ - (col_a, Some(option_asc)), - (col_b, Some(option_asc)), - (col_c, Some(option_asc)), - ], - vec![(col_a, Some(option_asc)), (col_b, Some(option_asc))], - Some(vec![ - (col_a, Some(option_asc)), - (col_b, Some(option_asc)), - (col_c, Some(option_asc)), - ]), - ), - // Get finer requirement between [a Some(ASC), b Some(ASC)] and [a Some(ASC), b Some(DESC)] - // result should be None - ( - vec![(col_a, Some(option_asc)), (col_b, Some(option_asc))], - vec![(col_a, Some(option_asc)), (col_b, Some(option_desc))], - None, - ), - ]; - for (lhs, rhs, expected) in tests_cases { - let lhs = convert_to_sort_reqs(&lhs); - let rhs = convert_to_sort_reqs(&rhs); - let expected = expected.map(|expected| convert_to_sort_reqs(&expected)); - let finer = eq_properties.get_finer_requirement(&lhs, &rhs); - assert_eq!(finer, expected) - } - - Ok(()) - } - - #[test] - fn test_get_meet_ordering() -> Result<()> { - let schema = create_test_schema()?; - let col_a = &col("a", &schema)?; - let col_b = &col("b", &schema)?; - let eq_properties = EquivalenceProperties::new(schema); - let option_asc = SortOptions { - descending: false, - nulls_first: false, - }; - let option_desc = SortOptions { - descending: true, - nulls_first: true, - }; - let tests_cases = vec![ - // Get meet ordering between [a ASC] and [a ASC, b ASC] - // result should be [a ASC] - ( - vec![(col_a, option_asc)], - vec![(col_a, option_asc), (col_b, option_asc)], - Some(vec![(col_a, option_asc)]), - ), - // Get meet ordering between [a ASC] and [a DESC] - // result should be None. - (vec![(col_a, option_asc)], vec![(col_a, option_desc)], None), - // Get meet ordering between [a ASC, b ASC] and [a ASC, b DESC] - // result should be [a ASC]. 
- ( - vec![(col_a, option_asc), (col_b, option_asc)], - vec![(col_a, option_asc), (col_b, option_desc)], - Some(vec![(col_a, option_asc)]), - ), - ]; - for (lhs, rhs, expected) in tests_cases { - let lhs = convert_to_sort_exprs(&lhs); - let rhs = convert_to_sort_exprs(&rhs); - let expected = expected.map(|expected| convert_to_sort_exprs(&expected)); - let finer = eq_properties.get_meet_ordering(&lhs, &rhs); - assert_eq!(finer, expected) - } - - Ok(()) - } - - #[test] - fn test_find_longest_permutation() -> Result<()> { - // Schema satisfies following orderings: - // [a ASC], [d ASC, b ASC], [e DESC, f ASC, g ASC] - // and - // Column [a=c] (e.g they are aliases). - // At below we add [d ASC, h DESC] also, for test purposes - let (test_schema, mut eq_properties) = create_test_params()?; - let col_a = &col("a", &test_schema)?; - let col_b = &col("b", &test_schema)?; - let col_c = &col("c", &test_schema)?; - let col_d = &col("d", &test_schema)?; - let col_e = &col("e", &test_schema)?; - let col_h = &col("h", &test_schema)?; - - let option_asc = SortOptions { - descending: false, - nulls_first: false, - }; - let option_desc = SortOptions { - descending: true, - nulls_first: true, - }; - // [d ASC, h ASC] also satisfies schema. - eq_properties.add_new_orderings([vec![ - PhysicalSortExpr { - expr: col_d.clone(), - options: option_asc, - }, - PhysicalSortExpr { - expr: col_h.clone(), - options: option_desc, - }, - ]]); - let test_cases = vec![ - // TEST CASE 1 - (vec![col_a], vec![(col_a, option_asc)]), - // TEST CASE 2 - (vec![col_c], vec![(col_c, option_asc)]), - // TEST CASE 3 - ( - vec![col_d, col_e, col_b], - vec![ - (col_d, option_asc), - (col_b, option_asc), - (col_e, option_desc), - ], - ), - // TEST CASE 4 - (vec![col_b], vec![]), - // TEST CASE 5 - (vec![col_d], vec![(col_d, option_asc)]), - ]; - for (exprs, expected) in test_cases { - let exprs = exprs.into_iter().cloned().collect::>(); - let expected = convert_to_sort_exprs(&expected); - let (actual, _) = eq_properties.find_longest_permutation(&exprs); - assert_eq!(actual, expected); - } - - Ok(()) - } - - #[test] - fn test_update_ordering() -> Result<()> { - let schema = Schema::new(vec![ - Field::new("a", DataType::Int32, true), - Field::new("b", DataType::Int32, true), - Field::new("c", DataType::Int32, true), - Field::new("d", DataType::Int32, true), - ]); - - let mut eq_properties = EquivalenceProperties::new(Arc::new(schema.clone())); - let col_a = &col("a", &schema)?; - let col_b = &col("b", &schema)?; - let col_c = &col("c", &schema)?; - let col_d = &col("d", &schema)?; - let option_asc = SortOptions { - descending: false, - nulls_first: false, - }; - // b=a (e.g they are aliases) - eq_properties.add_equal_conditions(col_b, col_a); - // [b ASC], [d ASC] - eq_properties.add_new_orderings(vec![ - vec![PhysicalSortExpr { - expr: col_b.clone(), - options: option_asc, - }], - vec![PhysicalSortExpr { - expr: col_d.clone(), - options: option_asc, - }], - ]); - - let test_cases = vec![ - // d + b - ( - Arc::new(BinaryExpr::new( - col_d.clone(), - Operator::Plus, - col_b.clone(), - )) as Arc, - SortProperties::Ordered(option_asc), - ), - // b - (col_b.clone(), SortProperties::Ordered(option_asc)), - // a - (col_a.clone(), SortProperties::Ordered(option_asc)), - // a + c - ( - Arc::new(BinaryExpr::new( - col_a.clone(), - Operator::Plus, - col_c.clone(), - )), - SortProperties::Unordered, - ), - ]; - for (expr, expected) in test_cases { - let expr_ordering = ExprOrdering::new(expr.clone()); - let expr_ordering = expr_ordering - 
.transform_up(&|expr| update_ordering(expr, &eq_properties))?; - let err_msg = format!( - "expr:{:?}, expected: {:?}, actual: {:?}", - expr, expected, expr_ordering.state - ); - assert_eq!(expr_ordering.state, expected, "{}", err_msg); - } - - Ok(()) - } - - #[test] - fn test_get_indices_of_matching_sort_exprs_with_order_eq() -> Result<()> { - let sort_options = SortOptions::default(); - let sort_options_not = SortOptions::default().not(); - - let schema = Schema::new(vec![ - Field::new("a", DataType::Int32, true), - Field::new("b", DataType::Int32, true), - ]); - let col_a = &col("a", &schema)?; - let col_b = &col("b", &schema)?; - let required_columns = [col_b.clone(), col_a.clone()]; - let mut eq_properties = EquivalenceProperties::new(Arc::new(schema)); - eq_properties.add_new_orderings([vec![ - PhysicalSortExpr { - expr: Arc::new(Column::new("b", 1)), - options: sort_options_not, - }, - PhysicalSortExpr { - expr: Arc::new(Column::new("a", 0)), - options: sort_options, - }, - ]]); - let (result, idxs) = eq_properties.find_longest_permutation(&required_columns); - assert_eq!(idxs, vec![0, 1]); - assert_eq!( - result, - vec![ - PhysicalSortExpr { - expr: col_b.clone(), - options: sort_options_not - }, - PhysicalSortExpr { - expr: col_a.clone(), - options: sort_options - } - ] - ); - - let schema = Schema::new(vec![ - Field::new("a", DataType::Int32, true), - Field::new("b", DataType::Int32, true), - Field::new("c", DataType::Int32, true), - ]); - let col_a = &col("a", &schema)?; - let col_b = &col("b", &schema)?; - let required_columns = [col_b.clone(), col_a.clone()]; - let mut eq_properties = EquivalenceProperties::new(Arc::new(schema)); - eq_properties.add_new_orderings([ - vec![PhysicalSortExpr { - expr: Arc::new(Column::new("c", 2)), - options: sort_options, - }], - vec![ - PhysicalSortExpr { - expr: Arc::new(Column::new("b", 1)), - options: sort_options_not, - }, - PhysicalSortExpr { - expr: Arc::new(Column::new("a", 0)), - options: sort_options, - }, - ], - ]); - let (result, idxs) = eq_properties.find_longest_permutation(&required_columns); - assert_eq!(idxs, vec![0, 1]); - assert_eq!( - result, - vec![ - PhysicalSortExpr { - expr: col_b.clone(), - options: sort_options_not - }, - PhysicalSortExpr { - expr: col_a.clone(), - options: sort_options - } - ] - ); - - let required_columns = [ - Arc::new(Column::new("b", 1)) as _, - Arc::new(Column::new("a", 0)) as _, - ]; - let schema = Schema::new(vec![ - Field::new("a", DataType::Int32, true), - Field::new("b", DataType::Int32, true), - Field::new("c", DataType::Int32, true), - ]); - let mut eq_properties = EquivalenceProperties::new(Arc::new(schema)); - - // not satisfied orders - eq_properties.add_new_orderings([vec![ - PhysicalSortExpr { - expr: Arc::new(Column::new("b", 1)), - options: sort_options_not, - }, - PhysicalSortExpr { - expr: Arc::new(Column::new("c", 2)), - options: sort_options, - }, - PhysicalSortExpr { - expr: Arc::new(Column::new("a", 0)), - options: sort_options, - }, - ]]); - let (_, idxs) = eq_properties.find_longest_permutation(&required_columns); - assert_eq!(idxs, vec![0]); - - Ok(()) - } - - #[test] - fn test_normalize_ordering_equivalence_classes() -> Result<()> { - let sort_options = SortOptions::default(); - - let schema = Schema::new(vec![ - Field::new("a", DataType::Int32, true), - Field::new("b", DataType::Int32, true), - Field::new("c", DataType::Int32, true), - ]); - let col_a_expr = col("a", &schema)?; - let col_b_expr = col("b", &schema)?; - let col_c_expr = col("c", &schema)?; - let mut 
eq_properties = EquivalenceProperties::new(Arc::new(schema.clone())); - - eq_properties.add_equal_conditions(&col_a_expr, &col_c_expr); - let others = vec![ - vec![PhysicalSortExpr { - expr: col_b_expr.clone(), - options: sort_options, - }], - vec![PhysicalSortExpr { - expr: col_c_expr.clone(), - options: sort_options, - }], - ]; - eq_properties.add_new_orderings(others); - - let mut expected_eqs = EquivalenceProperties::new(Arc::new(schema)); - expected_eqs.add_new_orderings([ - vec![PhysicalSortExpr { - expr: col_b_expr.clone(), - options: sort_options, - }], - vec![PhysicalSortExpr { - expr: col_c_expr.clone(), - options: sort_options, - }], - ]); - - let oeq_class = eq_properties.oeq_class().clone(); - let expected = expected_eqs.oeq_class(); - assert!(oeq_class.eq(expected)); - - Ok(()) - } - - #[test] - fn project_empty_output_ordering() -> Result<()> { - let schema = Schema::new(vec![ - Field::new("a", DataType::Int32, true), - Field::new("b", DataType::Int32, true), - Field::new("c", DataType::Int32, true), - ]); - let mut eq_properties = EquivalenceProperties::new(Arc::new(schema.clone())); - let ordering = vec![PhysicalSortExpr { - expr: Arc::new(Column::new("b", 1)), - options: SortOptions::default(), - }]; - eq_properties.add_new_orderings([ordering]); - let projection_mapping = ProjectionMapping { - inner: vec![ - ( - Arc::new(Column::new("b", 1)) as _, - Arc::new(Column::new("b_new", 0)) as _, - ), - ( - Arc::new(Column::new("a", 0)) as _, - Arc::new(Column::new("a_new", 1)) as _, - ), - ], - }; - let projection_schema = Arc::new(Schema::new(vec![ - Field::new("b_new", DataType::Int32, true), - Field::new("a_new", DataType::Int32, true), - ])); - let orderings = eq_properties - .project(&projection_mapping, projection_schema) - .oeq_class() - .output_ordering() - .unwrap_or_default(); - - assert_eq!( - vec![PhysicalSortExpr { - expr: Arc::new(Column::new("b_new", 0)), - options: SortOptions::default(), - }], - orderings - ); - - let schema = Schema::new(vec![ - Field::new("a", DataType::Int32, true), - Field::new("b", DataType::Int32, true), - Field::new("c", DataType::Int32, true), - ]); - let eq_properties = EquivalenceProperties::new(Arc::new(schema)); - let projection_mapping = ProjectionMapping { - inner: vec![ - ( - Arc::new(Column::new("c", 2)) as _, - Arc::new(Column::new("c_new", 0)) as _, - ), - ( - Arc::new(Column::new("b", 1)) as _, - Arc::new(Column::new("b_new", 1)) as _, - ), - ], - }; - let projection_schema = Arc::new(Schema::new(vec![ - Field::new("c_new", DataType::Int32, true), - Field::new("b_new", DataType::Int32, true), - ])); - let projected = eq_properties.project(&projection_mapping, projection_schema); - // After projection there is no ordering. - assert!(projected.oeq_class().output_ordering().is_none()); - - Ok(()) - } -} diff --git a/datafusion/physical-expr/src/equivalence/class.rs b/datafusion/physical-expr/src/equivalence/class.rs new file mode 100644 index 0000000000000..f0bd1740d5d2d --- /dev/null +++ b/datafusion/physical-expr/src/equivalence/class.rs @@ -0,0 +1,598 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use super::{add_offset_to_expr, collapse_lex_req, ProjectionMapping};
+use crate::{
+    expressions::Column, physical_expr::deduplicate_physical_exprs,
+    physical_exprs_bag_equal, physical_exprs_contains, LexOrdering, LexOrderingRef,
+    LexRequirement, LexRequirementRef, PhysicalExpr, PhysicalSortExpr,
+    PhysicalSortRequirement,
+};
+use datafusion_common::tree_node::TreeNode;
+use datafusion_common::{tree_node::Transformed, JoinType};
+use std::sync::Arc;
+
+/// An `EquivalenceClass` is a set of [`Arc<dyn PhysicalExpr>`]s that are known
+/// to have the same value for all tuples in a relation. These are generated by
+/// equality predicates (e.g. `a = b`), typically equi-join conditions and
+/// equality conditions in filters.
+///
+/// Two `EquivalenceClass`es are equal if they contain the same expressions,
+/// disregarding their order.
+#[derive(Debug, Clone)]
+pub struct EquivalenceClass {
+    /// The expressions in this equivalence class. The order doesn't
+    /// matter for equivalence purposes
+    ///
+    /// TODO: use a HashSet for this instead of a Vec
+    exprs: Vec<Arc<dyn PhysicalExpr>>,
+}
+
+impl PartialEq for EquivalenceClass {
+    /// Returns true if `other` is equal in the sense
+    /// of bags (multi-sets), disregarding their orderings.
+    fn eq(&self, other: &Self) -> bool {
+        physical_exprs_bag_equal(&self.exprs, &other.exprs)
+    }
+}
+
+impl EquivalenceClass {
+    /// Create a new empty equivalence class
+    pub fn new_empty() -> Self {
+        Self { exprs: vec![] }
+    }
+
+    // Create a new equivalence class from a pre-existing `Vec`
+    pub fn new(mut exprs: Vec<Arc<dyn PhysicalExpr>>) -> Self {
+        deduplicate_physical_exprs(&mut exprs);
+        Self { exprs }
+    }
+
+    /// Return the inner vector of expressions
+    pub fn into_vec(self) -> Vec<Arc<dyn PhysicalExpr>> {
+        self.exprs
+    }
+
+    /// Return the "canonical" expression for this class (the first element)
+    /// if any
+    fn canonical_expr(&self) -> Option<Arc<dyn PhysicalExpr>> {
+        self.exprs.first().cloned()
+    }
+
+    /// Insert the expression into this class, meaning it is known to be equal to
+    /// all other expressions in this class
+    pub fn push(&mut self, expr: Arc<dyn PhysicalExpr>) {
+        if !self.contains(&expr) {
+            self.exprs.push(expr);
+        }
+    }
+
+    /// Inserts all the expressions from `other` into this class
+    pub fn extend(&mut self, other: Self) {
+        for expr in other.exprs {
+            // use push so entries are deduplicated
+            self.push(expr);
+        }
+    }
+
+    /// Returns true if this equivalence class contains the given expression
+    pub fn contains(&self, expr: &Arc<dyn PhysicalExpr>) -> bool {
+        physical_exprs_contains(&self.exprs, expr)
+    }
+
+    /// Returns true if this equivalence class has any entries in common with `other`
+    pub fn contains_any(&self, other: &Self) -> bool {
+        self.exprs.iter().any(|e| other.contains(e))
+    }
+
+    /// Returns the number of items in this class
+    pub fn len(&self) -> usize {
+        self.exprs.len()
+    }
+
+    /// Returns true if this class is empty
+    pub fn is_empty(&self) -> bool {
+        self.exprs.is_empty()
+    }
+
+    /// Iterate over all elements in this class, in some arbitrary order
+    pub fn iter(&self) -> impl Iterator<Item = &Arc<dyn PhysicalExpr>> {
+        self.exprs.iter()
+    }
+
+    /// Return a new equivalence class that has the specified offset added to
+    /// each expression (used when schemas are
appended such as in joins) + pub fn with_offset(&self, offset: usize) -> Self { + let new_exprs = self + .exprs + .iter() + .cloned() + .map(|e| add_offset_to_expr(e, offset)) + .collect(); + Self::new(new_exprs) + } +} + +/// An `EquivalenceGroup` is a collection of `EquivalenceClass`es where each +/// class represents a distinct equivalence class in a relation. +#[derive(Debug, Clone)] +pub struct EquivalenceGroup { + pub classes: Vec, +} + +impl EquivalenceGroup { + /// Creates an empty equivalence group. + pub fn empty() -> Self { + Self { classes: vec![] } + } + + /// Creates an equivalence group from the given equivalence classes. + pub fn new(classes: Vec) -> Self { + let mut result = Self { classes }; + result.remove_redundant_entries(); + result + } + + /// Returns how many equivalence classes there are in this group. + pub fn len(&self) -> usize { + self.classes.len() + } + + /// Checks whether this equivalence group is empty. + pub fn is_empty(&self) -> bool { + self.len() == 0 + } + + /// Returns an iterator over the equivalence classes in this group. + pub fn iter(&self) -> impl Iterator { + self.classes.iter() + } + + /// Adds the equality `left` = `right` to this equivalence group. + /// New equality conditions often arise after steps like `Filter(a = b)`, + /// `Alias(a, a as b)` etc. + pub fn add_equal_conditions( + &mut self, + left: &Arc, + right: &Arc, + ) { + let mut first_class = None; + let mut second_class = None; + for (idx, cls) in self.classes.iter().enumerate() { + if cls.contains(left) { + first_class = Some(idx); + } + if cls.contains(right) { + second_class = Some(idx); + } + } + match (first_class, second_class) { + (Some(mut first_idx), Some(mut second_idx)) => { + // If the given left and right sides belong to different classes, + // we should unify/bridge these classes. + if first_idx != second_idx { + // By convention, make sure `second_idx` is larger than `first_idx`. + if first_idx > second_idx { + (first_idx, second_idx) = (second_idx, first_idx); + } + // Remove the class at `second_idx` and merge its values with + // the class at `first_idx`. The convention above makes sure + // that `first_idx` is still valid after removing `second_idx`. + let other_class = self.classes.swap_remove(second_idx); + self.classes[first_idx].extend(other_class); + } + } + (Some(group_idx), None) => { + // Right side is new, extend left side's class: + self.classes[group_idx].push(right.clone()); + } + (None, Some(group_idx)) => { + // Left side is new, extend right side's class: + self.classes[group_idx].push(left.clone()); + } + (None, None) => { + // None of the expressions is among existing classes. + // Create a new equivalence class and extend the group. + self.classes + .push(EquivalenceClass::new(vec![left.clone(), right.clone()])); + } + } + } + + /// Removes redundant entries from this group. + fn remove_redundant_entries(&mut self) { + // Remove duplicate entries from each equivalence class: + self.classes.retain_mut(|cls| { + // Keep groups that have at least two entries as singleton class is + // meaningless (i.e. it contains no non-trivial information): + cls.len() > 1 + }); + // Unify/bridge groups that have common expressions: + self.bridge_classes() + } + + /// This utility function unifies/bridges classes that have common expressions. + /// For example, assume that we have [`EquivalenceClass`]es `[a, b]` and `[b, c]`. + /// Since both classes contain `b`, columns `a`, `b` and `c` are actually all + /// equal and belong to one class. 
This utility converts merges such classes. + fn bridge_classes(&mut self) { + let mut idx = 0; + while idx < self.classes.len() { + let mut next_idx = idx + 1; + let start_size = self.classes[idx].len(); + while next_idx < self.classes.len() { + if self.classes[idx].contains_any(&self.classes[next_idx]) { + let extension = self.classes.swap_remove(next_idx); + self.classes[idx].extend(extension); + } else { + next_idx += 1; + } + } + if self.classes[idx].len() > start_size { + continue; + } + idx += 1; + } + } + + /// Extends this equivalence group with the `other` equivalence group. + pub fn extend(&mut self, other: Self) { + self.classes.extend(other.classes); + self.remove_redundant_entries(); + } + + /// Normalizes the given physical expression according to this group. + /// The expression is replaced with the first expression in the equivalence + /// class it matches with (if any). + pub fn normalize_expr(&self, expr: Arc) -> Arc { + expr.clone() + .transform(&|expr| { + for cls in self.iter() { + if cls.contains(&expr) { + return Ok(Transformed::Yes(cls.canonical_expr().unwrap())); + } + } + Ok(Transformed::No(expr)) + }) + .unwrap_or(expr) + } + + /// Normalizes the given sort expression according to this group. + /// The underlying physical expression is replaced with the first expression + /// in the equivalence class it matches with (if any). If the underlying + /// expression does not belong to any equivalence class in this group, returns + /// the sort expression as is. + pub fn normalize_sort_expr( + &self, + mut sort_expr: PhysicalSortExpr, + ) -> PhysicalSortExpr { + sort_expr.expr = self.normalize_expr(sort_expr.expr); + sort_expr + } + + /// Normalizes the given sort requirement according to this group. + /// The underlying physical expression is replaced with the first expression + /// in the equivalence class it matches with (if any). If the underlying + /// expression does not belong to any equivalence class in this group, returns + /// the given sort requirement as is. + pub fn normalize_sort_requirement( + &self, + mut sort_requirement: PhysicalSortRequirement, + ) -> PhysicalSortRequirement { + sort_requirement.expr = self.normalize_expr(sort_requirement.expr); + sort_requirement + } + + /// This function applies the `normalize_expr` function for all expressions + /// in `exprs` and returns the corresponding normalized physical expressions. + pub fn normalize_exprs( + &self, + exprs: impl IntoIterator>, + ) -> Vec> { + exprs + .into_iter() + .map(|expr| self.normalize_expr(expr)) + .collect() + } + + /// This function applies the `normalize_sort_expr` function for all sort + /// expressions in `sort_exprs` and returns the corresponding normalized + /// sort expressions. + pub fn normalize_sort_exprs(&self, sort_exprs: LexOrderingRef) -> LexOrdering { + // Convert sort expressions to sort requirements: + let sort_reqs = PhysicalSortRequirement::from_sort_exprs(sort_exprs.iter()); + // Normalize the requirements: + let normalized_sort_reqs = self.normalize_sort_requirements(&sort_reqs); + // Convert sort requirements back to sort expressions: + PhysicalSortRequirement::to_sort_exprs(normalized_sort_reqs) + } + + /// This function applies the `normalize_sort_requirement` function for all + /// requirements in `sort_reqs` and returns the corresponding normalized + /// sort requirements. 
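The doc comment just above describes `normalize_sort_requirements`: each requirement is rewritten to the canonical member of its equivalence class and the result is collapsed (the `collapse_lex_req` import) so that repeated columns add no duplicate entries. A minimal standalone sketch of that normalize-then-collapse step, with illustrative names of my own:

```rust
// Standalone sketch; toy requirement type, not DataFusion's PhysicalSortRequirement.
/// A sort requirement: column name plus an optional direction (`None` = any).
type SortReq = (&'static str, Option<bool>);

/// Maps a column to the first (canonical) member of its equivalence class.
fn canonical(col: &'static str, classes: &[Vec<&'static str>]) -> &'static str {
    classes
        .iter()
        .find(|class| class.contains(&col))
        .map(|class| class[0])
        .unwrap_or(col)
}

/// Normalizes each requirement to its canonical column, then collapses entries
/// whose column already appeared earlier (they add no new constraint).
fn normalize_sort_requirements(
    reqs: &[SortReq],
    classes: &[Vec<&'static str>],
) -> Vec<SortReq> {
    let mut result: Vec<SortReq> = Vec::new();
    for &(col, dir) in reqs {
        let col = canonical(col, classes);
        if !result.iter().any(|(c, _)| *c == col) {
            result.push((col, dir));
        }
    }
    result
}

fn main() {
    let classes = vec![vec!["a", "c"]]; // a = c
    // `c ASC` normalizes to `a ASC`; the later `a` entry collapses away.
    let normalized = normalize_sort_requirements(
        &[("c", Some(false)), ("b", None), ("a", Some(false))],
        &classes,
    );
    assert_eq!(normalized, vec![("a", Some(false)), ("b", None)]);
    println!("{normalized:?}");
}
```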
+ pub fn normalize_sort_requirements( + &self, + sort_reqs: LexRequirementRef, + ) -> LexRequirement { + collapse_lex_req( + sort_reqs + .iter() + .map(|sort_req| self.normalize_sort_requirement(sort_req.clone())) + .collect(), + ) + } + + /// Projects `expr` according to the given projection mapping. + /// If the resulting expression is invalid after projection, returns `None`. + pub fn project_expr( + &self, + mapping: &ProjectionMapping, + expr: &Arc, + ) -> Option> { + // First, we try to project expressions with an exact match. If we are + // unable to do this, we consult equivalence classes. + if let Some(target) = mapping.target_expr(expr) { + // If we match the source, we can project directly: + return Some(target); + } else { + // If the given expression is not inside the mapping, try to project + // expressions considering the equivalence classes. + for (source, target) in mapping.iter() { + // If we match an equivalent expression to `source`, then we can + // project. For example, if we have the mapping `(a as a1, a + c)` + // and the equivalence class `(a, b)`, expression `b` projects to `a1`. + if self + .get_equivalence_class(source) + .map_or(false, |group| group.contains(expr)) + { + return Some(target.clone()); + } + } + } + // Project a non-leaf expression by projecting its children. + let children = expr.children(); + if children.is_empty() { + // Leaf expression should be inside mapping. + return None; + } + children + .into_iter() + .map(|child| self.project_expr(mapping, &child)) + .collect::>>() + .map(|children| expr.clone().with_new_children(children).unwrap()) + } + + /// Projects this equivalence group according to the given projection mapping. + pub fn project(&self, mapping: &ProjectionMapping) -> Self { + let projected_classes = self.iter().filter_map(|cls| { + let new_class = cls + .iter() + .filter_map(|expr| self.project_expr(mapping, expr)) + .collect::>(); + (new_class.len() > 1).then_some(EquivalenceClass::new(new_class)) + }); + // TODO: Convert the algorithm below to a version that uses `HashMap`. + // once `Arc` can be stored in `HashMap`. + // See issue: https://github.com/apache/arrow-datafusion/issues/8027 + let mut new_classes = vec![]; + for (source, target) in mapping.iter() { + if new_classes.is_empty() { + new_classes.push((source, vec![target.clone()])); + } + if let Some((_, values)) = + new_classes.iter_mut().find(|(key, _)| key.eq(source)) + { + if !physical_exprs_contains(values, target) { + values.push(target.clone()); + } + } + } + // Only add equivalence classes with at least two members as singleton + // equivalence classes are meaningless. + let new_classes = new_classes + .into_iter() + .filter_map(|(_, values)| (values.len() > 1).then_some(values)) + .map(EquivalenceClass::new); + + let classes = projected_classes.chain(new_classes).collect(); + Self::new(classes) + } + + /// Returns the equivalence class containing `expr`. If no equivalence class + /// contains `expr`, returns `None`. + fn get_equivalence_class( + &self, + expr: &Arc, + ) -> Option<&EquivalenceClass> { + self.iter().find(|cls| cls.contains(expr)) + } + + /// Combine equivalence groups of the given join children. 
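The projection logic above tries an exact mapping hit first, then consults the equivalence classes, and finally rebuilds the expression from projected children. The standalone sketch below mirrors just the exact-match-then-recurse part on a toy expression type; `Expr` and `project` are illustrative stand-ins, not DataFusion APIs.

```rust
/// A toy expression type standing in for `Arc<dyn PhysicalExpr>`.
#[derive(Clone, Debug, PartialEq)]
enum Expr {
    Col(&'static str),
    Add(Box<Expr>, Box<Expr>),
}

/// Project `expr` through `mapping` (source -> target), recursing into
/// children when there is no direct match (cf. `project_expr` above, minus
/// the equivalence-class lookup).
fn project(mapping: &[(Expr, Expr)], expr: &Expr) -> Option<Expr> {
    // 1. Exact match against a mapping source.
    if let Some((_, target)) = mapping.iter().find(|(source, _)| source == expr) {
        return Some(target.clone());
    }
    // 2. Otherwise, try to rebuild the expression from projected children.
    match expr {
        Expr::Col(_) => None, // a leaf that is not in the mapping cannot be projected
        Expr::Add(l, r) => {
            let l = project(mapping, l)?;
            let r = project(mapping, r)?;
            Some(Expr::Add(Box::new(l), Box::new(r)))
        }
    }
}

fn main() {
    // SELECT a AS a1, b AS b1 FROM t
    let mapping = vec![
        (Expr::Col("a"), Expr::Col("a1")),
        (Expr::Col("b"), Expr::Col("b1")),
    ];
    // `a + b` survives the projection as `a1 + b1`.
    let expr = Expr::Add(Box::new(Expr::Col("a")), Box::new(Expr::Col("b")));
    assert_eq!(
        project(&mapping, &expr),
        Some(Expr::Add(Box::new(Expr::Col("a1")), Box::new(Expr::Col("b1"))))
    );
    // `c` is not computable from the projected columns.
    assert_eq!(project(&mapping, &Expr::Col("c")), None);
}
```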
+    pub fn join(
+        &self,
+        right_equivalences: &Self,
+        join_type: &JoinType,
+        left_size: usize,
+        on: &[(Column, Column)],
+    ) -> Self {
+        match join_type {
+            JoinType::Inner | JoinType::Left | JoinType::Full | JoinType::Right => {
+                let mut result = Self::new(
+                    self.iter()
+                        .cloned()
+                        .chain(
+                            right_equivalences
+                                .iter()
+                                .map(|cls| cls.with_offset(left_size)),
+                        )
+                        .collect(),
+                );
+                // If we have an inner join, expressions in the "on" condition
+                // are equal in the resulting table.
+                if join_type == &JoinType::Inner {
+                    for (lhs, rhs) in on.iter() {
+                        let index = rhs.index() + left_size;
+                        let new_lhs = Arc::new(lhs.clone()) as _;
+                        let new_rhs = Arc::new(Column::new(rhs.name(), index)) as _;
+                        result.add_equal_conditions(&new_lhs, &new_rhs);
+                    }
+                }
+                result
+            }
+            JoinType::LeftSemi | JoinType::LeftAnti => self.clone(),
+            JoinType::RightSemi | JoinType::RightAnti => right_equivalences.clone(),
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::equivalence::tests::create_test_params;
+    use crate::equivalence::{EquivalenceClass, EquivalenceGroup};
+    use crate::expressions::lit;
+    use crate::expressions::Column;
+    use crate::expressions::Literal;
+    use datafusion_common::Result;
+    use datafusion_common::ScalarValue;
+    use std::sync::Arc;
+
+    #[test]
+    fn test_bridge_groups() -> Result<()> {
+        // First entry in the tuple is the argument, second entry is the bridged result
+        let test_cases = vec![
+            // ------- TEST CASE 1 -----------//
+            (
+                vec![vec![1, 2, 3], vec![2, 4, 5], vec![11, 12, 9], vec![7, 6, 5]],
+                // Expected is compared with set equality. Order of the specific results may change.
+                vec![vec![1, 2, 3, 4, 5, 6, 7], vec![9, 11, 12]],
+            ),
+            // ------- TEST CASE 2 -----------//
+            (
+                vec![vec![1, 2, 3], vec![3, 4, 5], vec![9, 8, 7], vec![7, 6, 5]],
+                // Expected
+                vec![vec![1, 2, 3, 4, 5, 6, 7, 8, 9]],
+            ),
+        ];
+        for (entries, expected) in test_cases {
+            let entries = entries
+                .into_iter()
+                .map(|entry| entry.into_iter().map(lit).collect::<Vec<_>>())
+                .map(EquivalenceClass::new)
+                .collect::<Vec<_>>();
+            let expected = expected
+                .into_iter()
+                .map(|entry| entry.into_iter().map(lit).collect::<Vec<_>>())
+                .map(EquivalenceClass::new)
+                .collect::<Vec<_>>();
+            let mut eq_groups = EquivalenceGroup::new(entries.clone());
+            eq_groups.bridge_classes();
+            let eq_groups = eq_groups.classes;
+            let err_msg = format!(
+                "error in test entries: {:?}, expected: {:?}, actual:{:?}",
+                entries, expected, eq_groups
+            );
+            assert_eq!(eq_groups.len(), expected.len(), "{}", err_msg);
+            for idx in 0..eq_groups.len() {
+                assert_eq!(&eq_groups[idx], &expected[idx], "{}", err_msg);
+            }
+        }
+        Ok(())
+    }
+
+    #[test]
+    fn test_remove_redundant_entries_eq_group() -> Result<()> {
+        let entries = vec![
+            EquivalenceClass::new(vec![lit(1), lit(1), lit(2)]),
+            // This group is meaningless and should be removed
+            EquivalenceClass::new(vec![lit(3), lit(3)]),
+            EquivalenceClass::new(vec![lit(4), lit(5), lit(6)]),
+        ];
+        // The given equivalence classes are not in succinct form.
+        // The expected form is the simplest representation that is functionally the same.
+ let expected = vec![ + EquivalenceClass::new(vec![lit(1), lit(2)]), + EquivalenceClass::new(vec![lit(4), lit(5), lit(6)]), + ]; + let mut eq_groups = EquivalenceGroup::new(entries); + eq_groups.remove_redundant_entries(); + + let eq_groups = eq_groups.classes; + assert_eq!(eq_groups.len(), expected.len()); + assert_eq!(eq_groups.len(), 2); + + assert_eq!(eq_groups[0], expected[0]); + assert_eq!(eq_groups[1], expected[1]); + Ok(()) + } + + #[test] + fn test_schema_normalize_expr_with_equivalence() -> Result<()> { + let col_a = &Column::new("a", 0); + let col_b = &Column::new("b", 1); + let col_c = &Column::new("c", 2); + // Assume that column a and c are aliases. + let (_test_schema, eq_properties) = create_test_params()?; + + let col_a_expr = Arc::new(col_a.clone()) as Arc; + let col_b_expr = Arc::new(col_b.clone()) as Arc; + let col_c_expr = Arc::new(col_c.clone()) as Arc; + // Test cases for equivalence normalization, + // First entry in the tuple is argument, second entry is expected result after normalization. + let expressions = vec![ + // Normalized version of the column a and c should go to a + // (by convention all the expressions inside equivalence class are mapped to the first entry + // in this case a is the first entry in the equivalence class.) + (&col_a_expr, &col_a_expr), + (&col_c_expr, &col_a_expr), + // Cannot normalize column b + (&col_b_expr, &col_b_expr), + ]; + let eq_group = eq_properties.eq_group(); + for (expr, expected_eq) in expressions { + assert!( + expected_eq.eq(&eq_group.normalize_expr(expr.clone())), + "error in test: expr: {expr:?}" + ); + } + + Ok(()) + } + + #[test] + fn test_contains_any() { + let lit_true = Arc::new(Literal::new(ScalarValue::Boolean(Some(true)))) + as Arc; + let lit_false = Arc::new(Literal::new(ScalarValue::Boolean(Some(false)))) + as Arc; + let lit2 = + Arc::new(Literal::new(ScalarValue::Int32(Some(2)))) as Arc; + let lit1 = + Arc::new(Literal::new(ScalarValue::Int32(Some(1)))) as Arc; + let col_b_expr = Arc::new(Column::new("b", 1)) as Arc; + + let cls1 = EquivalenceClass::new(vec![lit_true.clone(), lit_false.clone()]); + let cls2 = EquivalenceClass::new(vec![lit_true.clone(), col_b_expr.clone()]); + let cls3 = EquivalenceClass::new(vec![lit2.clone(), lit1.clone()]); + + // lit_true is common + assert!(cls1.contains_any(&cls2)); + // there is no common entry + assert!(!cls1.contains_any(&cls3)); + assert!(!cls2.contains_any(&cls3)); + } +} diff --git a/datafusion/physical-expr/src/equivalence/mod.rs b/datafusion/physical-expr/src/equivalence/mod.rs new file mode 100644 index 0000000000000..387dce2cdc8b2 --- /dev/null +++ b/datafusion/physical-expr/src/equivalence/mod.rs @@ -0,0 +1,533 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
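The `EquivalenceGroup::join` method shown earlier shifts right-side column indices by the width of the left schema and, for inner joins, records each ON pair as an equality. A minimal sketch of that bookkeeping, assuming columns are plain `usize` indices (no DataFusion types involved; `join_classes` is an invented helper name):

```rust
/// Toy stand-ins: a "column" is just its index in the joined schema.
type ColumnIndex = usize;
type Class = Vec<ColumnIndex>;

/// Combine equivalence classes of two join inputs, loosely following
/// `EquivalenceGroup::join` for inner joins with the expression machinery
/// stripped away.
fn join_classes(
    left: &[Class],
    right: &[Class],
    left_size: usize,
    on: &[(ColumnIndex, ColumnIndex)],
) -> Vec<Class> {
    // Right-side columns appear after the left schema, so shift their indices.
    let mut classes: Vec<Class> = left.to_vec();
    classes.extend(
        right
            .iter()
            .map(|cls| cls.iter().map(|idx| idx + left_size).collect()),
    );
    // Each ON pair is equal in the join output; record it as a two-member class.
    // (The real code funnels this through `add_equal_conditions`, which also
    // bridges overlapping classes.)
    for (l, r) in on {
        classes.push(vec![*l, r + left_size]);
    }
    classes
}

fn main() {
    // Left has columns {0, 1} equal; right has no equalities and starts at index 3.
    // Join condition: left column 2 = right column 0.
    let classes = join_classes(&[vec![0, 1]], &[], 3, &[(2, 0)]);
    assert_eq!(classes, vec![vec![0, 1], vec![2, 3]]);
}
```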
+ +mod class; +mod ordering; +mod projection; +mod properties; +use crate::expressions::Column; +use crate::{LexRequirement, PhysicalExpr, PhysicalSortRequirement}; +pub use class::{EquivalenceClass, EquivalenceGroup}; +use datafusion_common::tree_node::{Transformed, TreeNode}; +pub use ordering::OrderingEquivalenceClass; +pub use projection::ProjectionMapping; +pub use properties::{join_equivalence_properties, EquivalenceProperties}; +use std::sync::Arc; + +/// This function constructs a duplicate-free `LexOrderingReq` by filtering out +/// duplicate entries that have same physical expression inside. For example, +/// `vec![a Some(ASC), a Some(DESC)]` collapses to `vec![a Some(ASC)]`. +pub fn collapse_lex_req(input: LexRequirement) -> LexRequirement { + let mut output = Vec::::new(); + for item in input { + if !output.iter().any(|req| req.expr.eq(&item.expr)) { + output.push(item); + } + } + output +} + +/// Adds the `offset` value to `Column` indices inside `expr`. This function is +/// generally used during the update of the right table schema in join operations. +pub fn add_offset_to_expr( + expr: Arc, + offset: usize, +) -> Arc { + expr.transform_down(&|e| match e.as_any().downcast_ref::() { + Some(col) => Ok(Transformed::Yes(Arc::new(Column::new( + col.name(), + offset + col.index(), + )))), + None => Ok(Transformed::No(e)), + }) + .unwrap() + // Note that we can safely unwrap here since our transform always returns + // an `Ok` value. +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::expressions::{col, Column}; + use crate::PhysicalSortExpr; + use arrow::compute::{lexsort_to_indices, SortColumn}; + use arrow::datatypes::{DataType, Field, Schema}; + use arrow_array::{ArrayRef, Float64Array, RecordBatch, UInt32Array}; + use arrow_schema::{SchemaRef, SortOptions}; + use datafusion_common::{plan_datafusion_err, DataFusionError, Result}; + use itertools::izip; + use rand::rngs::StdRng; + use rand::seq::SliceRandom; + use rand::{Rng, SeedableRng}; + use std::sync::Arc; + + pub fn output_schema( + mapping: &ProjectionMapping, + input_schema: &Arc, + ) -> Result { + // Calculate output schema + let fields: Result> = mapping + .iter() + .map(|(source, target)| { + let name = target + .as_any() + .downcast_ref::() + .ok_or_else(|| plan_datafusion_err!("Expects to have column"))? + .name(); + let field = Field::new( + name, + source.data_type(input_schema)?, + source.nullable(input_schema)?, + ); + + Ok(field) + }) + .collect(); + + let output_schema = Arc::new(Schema::new_with_metadata( + fields?, + input_schema.metadata().clone(), + )); + + Ok(output_schema) + } + + // Generate a schema which consists of 8 columns (a, b, c, d, e, f, g, h) + pub fn create_test_schema() -> Result { + let a = Field::new("a", DataType::Int32, true); + let b = Field::new("b", DataType::Int32, true); + let c = Field::new("c", DataType::Int32, true); + let d = Field::new("d", DataType::Int32, true); + let e = Field::new("e", DataType::Int32, true); + let f = Field::new("f", DataType::Int32, true); + let g = Field::new("g", DataType::Int32, true); + let h = Field::new("h", DataType::Int32, true); + let schema = Arc::new(Schema::new(vec![a, b, c, d, e, f, g, h])); + + Ok(schema) + } + + /// Construct a schema with following properties + /// Schema satisfies following orderings: + /// [a ASC], [d ASC, b ASC], [e DESC, f ASC, g ASC] + /// and + /// Column [a=c] (e.g they are aliases). 
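The duplicate-removal rule implemented by `collapse_lex_req` above keeps only the first requirement seen for each expression. Below is a hedged, standalone sketch of the same rule on a toy requirement type; `SortReq` and `collapse` are invented for illustration and are not part of the patch.

```rust
/// A toy sort requirement: a column name plus an optional direction
/// (a stand-in for `PhysicalSortRequirement`).
#[derive(Clone, Debug, PartialEq)]
struct SortReq {
    column: &'static str,
    descending: Option<bool>,
}

/// Keep only the first requirement seen for each expression
/// (cf. `collapse_lex_req` above).
fn collapse(input: Vec<SortReq>) -> Vec<SortReq> {
    let mut output: Vec<SortReq> = Vec::new();
    for item in input {
        if !output.iter().any(|req| req.column == item.column) {
            output.push(item);
        }
    }
    output
}

fn main() {
    let reqs = vec![
        SortReq { column: "a", descending: Some(false) },
        SortReq { column: "b", descending: None },
        // Duplicate expression: dropped, the first entry for `a` wins.
        SortReq { column: "a", descending: Some(true) },
    ];
    let collapsed = collapse(reqs);
    assert_eq!(collapsed.len(), 2);
    assert_eq!(collapsed[0], SortReq { column: "a", descending: Some(false) });
}
```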
+ pub fn create_test_params() -> Result<(SchemaRef, EquivalenceProperties)> { + let test_schema = create_test_schema()?; + let col_a = &col("a", &test_schema)?; + let col_b = &col("b", &test_schema)?; + let col_c = &col("c", &test_schema)?; + let col_d = &col("d", &test_schema)?; + let col_e = &col("e", &test_schema)?; + let col_f = &col("f", &test_schema)?; + let col_g = &col("g", &test_schema)?; + let mut eq_properties = EquivalenceProperties::new(test_schema.clone()); + eq_properties.add_equal_conditions(col_a, col_c); + + let option_asc = SortOptions { + descending: false, + nulls_first: false, + }; + let option_desc = SortOptions { + descending: true, + nulls_first: true, + }; + let orderings = vec![ + // [a ASC] + vec![(col_a, option_asc)], + // [d ASC, b ASC] + vec![(col_d, option_asc), (col_b, option_asc)], + // [e DESC, f ASC, g ASC] + vec![ + (col_e, option_desc), + (col_f, option_asc), + (col_g, option_asc), + ], + ]; + let orderings = convert_to_orderings(&orderings); + eq_properties.add_new_orderings(orderings); + Ok((test_schema, eq_properties)) + } + + // Generate a schema which consists of 6 columns (a, b, c, d, e, f) + fn create_test_schema_2() -> Result { + let a = Field::new("a", DataType::Float64, true); + let b = Field::new("b", DataType::Float64, true); + let c = Field::new("c", DataType::Float64, true); + let d = Field::new("d", DataType::Float64, true); + let e = Field::new("e", DataType::Float64, true); + let f = Field::new("f", DataType::Float64, true); + let schema = Arc::new(Schema::new(vec![a, b, c, d, e, f])); + + Ok(schema) + } + + /// Construct a schema with random ordering + /// among column a, b, c, d + /// where + /// Column [a=f] (e.g they are aliases). + /// Column e is constant. + pub fn create_random_schema(seed: u64) -> Result<(SchemaRef, EquivalenceProperties)> { + let test_schema = create_test_schema_2()?; + let col_a = &col("a", &test_schema)?; + let col_b = &col("b", &test_schema)?; + let col_c = &col("c", &test_schema)?; + let col_d = &col("d", &test_schema)?; + let col_e = &col("e", &test_schema)?; + let col_f = &col("f", &test_schema)?; + let col_exprs = [col_a, col_b, col_c, col_d, col_e, col_f]; + + let mut eq_properties = EquivalenceProperties::new(test_schema.clone()); + // Define a and f are aliases + eq_properties.add_equal_conditions(col_a, col_f); + // Column e has constant value. 
+ eq_properties = eq_properties.add_constants([col_e.clone()]); + + // Randomly order columns for sorting + let mut rng = StdRng::seed_from_u64(seed); + let mut remaining_exprs = col_exprs[0..4].to_vec(); // only a, b, c, d are sorted + + let options_asc = SortOptions { + descending: false, + nulls_first: false, + }; + + while !remaining_exprs.is_empty() { + let n_sort_expr = rng.gen_range(0..remaining_exprs.len() + 1); + remaining_exprs.shuffle(&mut rng); + + let ordering = remaining_exprs + .drain(0..n_sort_expr) + .map(|expr| PhysicalSortExpr { + expr: expr.clone(), + options: options_asc, + }) + .collect(); + + eq_properties.add_new_orderings([ordering]); + } + + Ok((test_schema, eq_properties)) + } + + // Convert each tuple to PhysicalSortRequirement + pub fn convert_to_sort_reqs( + in_data: &[(&Arc, Option)], + ) -> Vec { + in_data + .iter() + .map(|(expr, options)| { + PhysicalSortRequirement::new((*expr).clone(), *options) + }) + .collect() + } + + // Convert each tuple to PhysicalSortExpr + pub fn convert_to_sort_exprs( + in_data: &[(&Arc, SortOptions)], + ) -> Vec { + in_data + .iter() + .map(|(expr, options)| PhysicalSortExpr { + expr: (*expr).clone(), + options: *options, + }) + .collect() + } + + // Convert each inner tuple to PhysicalSortExpr + pub fn convert_to_orderings( + orderings: &[Vec<(&Arc, SortOptions)>], + ) -> Vec> { + orderings + .iter() + .map(|sort_exprs| convert_to_sort_exprs(sort_exprs)) + .collect() + } + + // Convert each tuple to PhysicalSortExpr + pub fn convert_to_sort_exprs_owned( + in_data: &[(Arc, SortOptions)], + ) -> Vec { + in_data + .iter() + .map(|(expr, options)| PhysicalSortExpr { + expr: (*expr).clone(), + options: *options, + }) + .collect() + } + + // Convert each inner tuple to PhysicalSortExpr + pub fn convert_to_orderings_owned( + orderings: &[Vec<(Arc, SortOptions)>], + ) -> Vec> { + orderings + .iter() + .map(|sort_exprs| convert_to_sort_exprs_owned(sort_exprs)) + .collect() + } + + // Apply projection to the input_data, return projected equivalence properties and record batch + pub fn apply_projection( + proj_exprs: Vec<(Arc, String)>, + input_data: &RecordBatch, + input_eq_properties: &EquivalenceProperties, + ) -> Result<(RecordBatch, EquivalenceProperties)> { + let input_schema = input_data.schema(); + let projection_mapping = ProjectionMapping::try_new(&proj_exprs, &input_schema)?; + + let output_schema = output_schema(&projection_mapping, &input_schema)?; + let num_rows = input_data.num_rows(); + // Apply projection to the input record batch. + let projected_values = projection_mapping + .iter() + .map(|(source, _target)| source.evaluate(input_data)?.into_array(num_rows)) + .collect::>>()?; + let projected_batch = if projected_values.is_empty() { + RecordBatch::new_empty(output_schema.clone()) + } else { + RecordBatch::try_new(output_schema.clone(), projected_values)? 
+ }; + + let projected_eq = + input_eq_properties.project(&projection_mapping, output_schema); + Ok((projected_batch, projected_eq)) + } + + #[test] + fn add_equal_conditions_test() -> Result<()> { + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int64, true), + Field::new("b", DataType::Int64, true), + Field::new("c", DataType::Int64, true), + Field::new("x", DataType::Int64, true), + Field::new("y", DataType::Int64, true), + ])); + + let mut eq_properties = EquivalenceProperties::new(schema); + let col_a_expr = Arc::new(Column::new("a", 0)) as Arc; + let col_b_expr = Arc::new(Column::new("b", 1)) as Arc; + let col_c_expr = Arc::new(Column::new("c", 2)) as Arc; + let col_x_expr = Arc::new(Column::new("x", 3)) as Arc; + let col_y_expr = Arc::new(Column::new("y", 4)) as Arc; + + // a and b are aliases + eq_properties.add_equal_conditions(&col_a_expr, &col_b_expr); + assert_eq!(eq_properties.eq_group().len(), 1); + + // This new entry is redundant, size shouldn't increase + eq_properties.add_equal_conditions(&col_b_expr, &col_a_expr); + assert_eq!(eq_properties.eq_group().len(), 1); + let eq_groups = &eq_properties.eq_group().classes[0]; + assert_eq!(eq_groups.len(), 2); + assert!(eq_groups.contains(&col_a_expr)); + assert!(eq_groups.contains(&col_b_expr)); + + // b and c are aliases. Exising equivalence class should expand, + // however there shouldn't be any new equivalence class + eq_properties.add_equal_conditions(&col_b_expr, &col_c_expr); + assert_eq!(eq_properties.eq_group().len(), 1); + let eq_groups = &eq_properties.eq_group().classes[0]; + assert_eq!(eq_groups.len(), 3); + assert!(eq_groups.contains(&col_a_expr)); + assert!(eq_groups.contains(&col_b_expr)); + assert!(eq_groups.contains(&col_c_expr)); + + // This is a new set of equality. Hence equivalent class count should be 2. + eq_properties.add_equal_conditions(&col_x_expr, &col_y_expr); + assert_eq!(eq_properties.eq_group().len(), 2); + + // This equality bridges distinct equality sets. + // Hence equivalent class count should decrease from 2 to 1. + eq_properties.add_equal_conditions(&col_x_expr, &col_a_expr); + assert_eq!(eq_properties.eq_group().len(), 1); + let eq_groups = &eq_properties.eq_group().classes[0]; + assert_eq!(eq_groups.len(), 5); + assert!(eq_groups.contains(&col_a_expr)); + assert!(eq_groups.contains(&col_b_expr)); + assert!(eq_groups.contains(&col_c_expr)); + assert!(eq_groups.contains(&col_x_expr)); + assert!(eq_groups.contains(&col_y_expr)); + + Ok(()) + } + + /// Checks if the table (RecordBatch) remains unchanged when sorted according to the provided `required_ordering`. + /// + /// The function works by adding a unique column of ascending integers to the original table. This column ensures + /// that rows that are otherwise indistinguishable (e.g., if they have the same values in all other columns) can + /// still be differentiated. When sorting the extended table, the unique column acts as a tie-breaker to produce + /// deterministic sorting results. + /// + /// If the table remains the same after sorting with the added unique column, it indicates that the table was + /// already sorted according to `required_ordering` to begin with. 
+ pub fn is_table_same_after_sort( + mut required_ordering: Vec, + batch: RecordBatch, + ) -> Result { + // Clone the original schema and columns + let original_schema = batch.schema(); + let mut columns = batch.columns().to_vec(); + + // Create a new unique column + let n_row = batch.num_rows(); + let vals: Vec = (0..n_row).collect::>(); + let vals: Vec = vals.into_iter().map(|val| val as f64).collect(); + let unique_col = Arc::new(Float64Array::from_iter_values(vals)) as ArrayRef; + columns.push(unique_col.clone()); + + // Create a new schema with the added unique column + let unique_col_name = "unique"; + let unique_field = + Arc::new(Field::new(unique_col_name, DataType::Float64, false)); + let fields: Vec<_> = original_schema + .fields() + .iter() + .cloned() + .chain(std::iter::once(unique_field)) + .collect(); + let schema = Arc::new(Schema::new(fields)); + + // Create a new batch with the added column + let new_batch = RecordBatch::try_new(schema.clone(), columns)?; + + // Add the unique column to the required ordering to ensure deterministic results + required_ordering.push(PhysicalSortExpr { + expr: Arc::new(Column::new(unique_col_name, original_schema.fields().len())), + options: Default::default(), + }); + + // Convert the required ordering to a list of SortColumn + let sort_columns = required_ordering + .iter() + .map(|order_expr| { + let expr_result = order_expr.expr.evaluate(&new_batch)?; + let values = expr_result.into_array(new_batch.num_rows())?; + Ok(SortColumn { + values, + options: Some(order_expr.options), + }) + }) + .collect::>>()?; + + // Check if the indices after sorting match the initial ordering + let sorted_indices = lexsort_to_indices(&sort_columns, None)?; + let original_indices = UInt32Array::from_iter_values(0..n_row as u32); + + Ok(sorted_indices == original_indices) + } + + // If we already generated a random result for one of the + // expressions in the equivalence classes. For other expressions in the same + // equivalence class use same result. This util gets already calculated result, when available. + fn get_representative_arr( + eq_group: &EquivalenceClass, + existing_vec: &[Option], + schema: SchemaRef, + ) -> Option { + for expr in eq_group.iter() { + let col = expr.as_any().downcast_ref::().unwrap(); + let (idx, _field) = schema.column_with_name(col.name()).unwrap(); + if let Some(res) = &existing_vec[idx] { + return Some(res.clone()); + } + } + None + } + + // Generate a table that satisfies the given equivalence properties; i.e. + // equivalences, ordering equivalences, and constants. 
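The tie-breaker trick used by `is_table_same_after_sort` above can be reproduced with plain vectors: tag every row with its position, sort by (key, position), and check that the positions come back unchanged. The sketch below is a simplified illustration of that idea, not the helper itself.

```rust
/// Check whether `rows` are already sorted by their first field, using a
/// unique tie-breaker (the original row position) so that rows which compare
/// equal on the key keep a deterministic order (cf. `is_table_same_after_sort`).
fn is_already_sorted_by_key(rows: &[(i32, &str)]) -> bool {
    // Tag every row with its position; this plays the role of the appended
    // "unique" column in the helper above.
    let mut tagged: Vec<(i32, usize)> = rows
        .iter()
        .enumerate()
        .map(|(pos, (key, _))| (*key, pos))
        .collect();
    // Sort by (key, position). If the input was already ordered by `key`,
    // the positions come back unchanged.
    tagged.sort();
    tagged
        .iter()
        .enumerate()
        .all(|(new_pos, (_, old_pos))| new_pos == *old_pos)
}

fn main() {
    assert!(is_already_sorted_by_key(&[(1, "x"), (1, "y"), (2, "z")]));
    assert!(!is_already_sorted_by_key(&[(2, "x"), (1, "y")]));
}
```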
+ pub fn generate_table_for_eq_properties( + eq_properties: &EquivalenceProperties, + n_elem: usize, + n_distinct: usize, + ) -> Result { + let mut rng = StdRng::seed_from_u64(23); + + let schema = eq_properties.schema(); + let mut schema_vec = vec![None; schema.fields.len()]; + + // Utility closure to generate random array + let mut generate_random_array = |num_elems: usize, max_val: usize| -> ArrayRef { + let values: Vec = (0..num_elems) + .map(|_| rng.gen_range(0..max_val) as f64 / 2.0) + .collect(); + Arc::new(Float64Array::from_iter_values(values)) + }; + + // Fill constant columns + for constant in &eq_properties.constants { + let col = constant.as_any().downcast_ref::().unwrap(); + let (idx, _field) = schema.column_with_name(col.name()).unwrap(); + let arr = Arc::new(Float64Array::from_iter_values(vec![0 as f64; n_elem])) + as ArrayRef; + schema_vec[idx] = Some(arr); + } + + // Fill columns based on ordering equivalences + for ordering in eq_properties.oeq_class.iter() { + let (sort_columns, indices): (Vec<_>, Vec<_>) = ordering + .iter() + .map(|PhysicalSortExpr { expr, options }| { + let col = expr.as_any().downcast_ref::().unwrap(); + let (idx, _field) = schema.column_with_name(col.name()).unwrap(); + let arr = generate_random_array(n_elem, n_distinct); + ( + SortColumn { + values: arr, + options: Some(*options), + }, + idx, + ) + }) + .unzip(); + + let sort_arrs = arrow::compute::lexsort(&sort_columns, None)?; + for (idx, arr) in izip!(indices, sort_arrs) { + schema_vec[idx] = Some(arr); + } + } + + // Fill columns based on equivalence groups + for eq_group in eq_properties.eq_group.iter() { + let representative_array = + get_representative_arr(eq_group, &schema_vec, schema.clone()) + .unwrap_or_else(|| generate_random_array(n_elem, n_distinct)); + + for expr in eq_group.iter() { + let col = expr.as_any().downcast_ref::().unwrap(); + let (idx, _field) = schema.column_with_name(col.name()).unwrap(); + schema_vec[idx] = Some(representative_array.clone()); + } + } + + let res: Vec<_> = schema_vec + .into_iter() + .zip(schema.fields.iter()) + .map(|(elem, field)| { + ( + field.name(), + // Generate random values for columns that do not occur in any of the groups (equivalence, ordering equivalence, constants) + elem.unwrap_or_else(|| generate_random_array(n_elem, n_distinct)), + ) + }) + .collect(); + + Ok(RecordBatch::try_from_iter(res)?) + } +} diff --git a/datafusion/physical-expr/src/equivalence/ordering.rs b/datafusion/physical-expr/src/equivalence/ordering.rs new file mode 100644 index 0000000000000..1a414592ce4c8 --- /dev/null +++ b/datafusion/physical-expr/src/equivalence/ordering.rs @@ -0,0 +1,1159 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
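For intuition about how `generate_table_for_eq_properties` (defined just above) produces ordered columns, here is a much simplified, std-only sketch: draw values with a small number of distinct values and then sort the rows by the target lexicographic ordering. The tiny linear-congruential generator merely stands in for `StdRng` and is purely illustrative.

```rust
/// Generate `n` rows whose columns `(a, b)` satisfy the lexicographic
/// ordering [a ASC, b ASC], loosely mirroring how ordered columns are filled:
/// draw values with a limited distinct count, then sort by the ordering.
fn generate_sorted_rows(n: usize, n_distinct: u64) -> Vec<(u64, u64)> {
    let mut state = 23_u64;
    let mut next = move || {
        state = state.wrapping_mul(6364136223846793005).wrapping_add(1);
        (state >> 33) % n_distinct
    };
    let mut rows: Vec<(u64, u64)> = (0..n).map(|_| (next(), next())).collect();
    rows.sort(); // tuple ordering == lexicographic [a ASC, b ASC]
    rows
}

fn main() {
    let rows = generate_sorted_rows(100, 5);
    assert!(rows.windows(2).all(|w| w[0] <= w[1]));
}
```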
+ +use arrow_schema::SortOptions; +use std::hash::Hash; +use std::sync::Arc; + +use crate::equivalence::add_offset_to_expr; +use crate::{LexOrdering, PhysicalExpr, PhysicalSortExpr}; + +/// An `OrderingEquivalenceClass` object keeps track of different alternative +/// orderings than can describe a schema. For example, consider the following table: +/// +/// ```text +/// |a|b|c|d| +/// |1|4|3|1| +/// |2|3|3|2| +/// |3|1|2|2| +/// |3|2|1|3| +/// ``` +/// +/// Here, both `vec![a ASC, b ASC]` and `vec![c DESC, d ASC]` describe the table +/// ordering. In this case, we say that these orderings are equivalent. +#[derive(Debug, Clone, Eq, PartialEq, Hash)] +pub struct OrderingEquivalenceClass { + pub orderings: Vec, +} + +impl OrderingEquivalenceClass { + /// Creates new empty ordering equivalence class. + pub fn empty() -> Self { + Self { orderings: vec![] } + } + + /// Clears (empties) this ordering equivalence class. + pub fn clear(&mut self) { + self.orderings.clear(); + } + + /// Creates new ordering equivalence class from the given orderings. + pub fn new(orderings: Vec) -> Self { + let mut result = Self { orderings }; + result.remove_redundant_entries(); + result + } + + /// Checks whether `ordering` is a member of this equivalence class. + pub fn contains(&self, ordering: &LexOrdering) -> bool { + self.orderings.contains(ordering) + } + + /// Adds `ordering` to this equivalence class. + #[allow(dead_code)] + fn push(&mut self, ordering: LexOrdering) { + self.orderings.push(ordering); + // Make sure that there are no redundant orderings: + self.remove_redundant_entries(); + } + + /// Checks whether this ordering equivalence class is empty. + pub fn is_empty(&self) -> bool { + self.len() == 0 + } + + /// Returns an iterator over the equivalent orderings in this class. + pub fn iter(&self) -> impl Iterator { + self.orderings.iter() + } + + /// Returns how many equivalent orderings there are in this class. + pub fn len(&self) -> usize { + self.orderings.len() + } + + /// Extend this ordering equivalence class with the `other` class. + pub fn extend(&mut self, other: Self) { + self.orderings.extend(other.orderings); + // Make sure that there are no redundant orderings: + self.remove_redundant_entries(); + } + + /// Adds new orderings into this ordering equivalence class. + pub fn add_new_orderings( + &mut self, + orderings: impl IntoIterator, + ) { + self.orderings.extend(orderings); + // Make sure that there are no redundant orderings: + self.remove_redundant_entries(); + } + + /// Removes redundant orderings from this equivalence class. For instance, + /// if we already have the ordering `[a ASC, b ASC, c DESC]`, then there is + /// no need to keep ordering `[a ASC, b ASC]` in the state. + fn remove_redundant_entries(&mut self) { + let mut work = true; + while work { + work = false; + let mut idx = 0; + while idx < self.orderings.len() { + let mut ordering_idx = idx + 1; + let mut removal = self.orderings[idx].is_empty(); + while ordering_idx < self.orderings.len() { + work |= resolve_overlap(&mut self.orderings, idx, ordering_idx); + if self.orderings[idx].is_empty() { + removal = true; + break; + } + work |= resolve_overlap(&mut self.orderings, ordering_idx, idx); + if self.orderings[ordering_idx].is_empty() { + self.orderings.swap_remove(ordering_idx); + } else { + ordering_idx += 1; + } + } + if removal { + self.orderings.swap_remove(idx); + } else { + idx += 1; + } + } + } + } + + /// Returns the concatenation of all the orderings. 
This enables merge + /// operations to preserve all equivalent orderings simultaneously. + pub fn output_ordering(&self) -> Option { + let output_ordering = self.orderings.iter().flatten().cloned().collect(); + let output_ordering = collapse_lex_ordering(output_ordering); + (!output_ordering.is_empty()).then_some(output_ordering) + } + + // Append orderings in `other` to all existing orderings in this equivalence + // class. + pub fn join_suffix(mut self, other: &Self) -> Self { + let n_ordering = self.orderings.len(); + // Replicate entries before cross product + let n_cross = std::cmp::max(n_ordering, other.len() * n_ordering); + self.orderings = self + .orderings + .iter() + .cloned() + .cycle() + .take(n_cross) + .collect(); + // Suffix orderings of other to the current orderings. + for (outer_idx, ordering) in other.iter().enumerate() { + for idx in 0..n_ordering { + // Calculate cross product index + let idx = outer_idx * n_ordering + idx; + self.orderings[idx].extend(ordering.iter().cloned()); + } + } + self + } + + /// Adds `offset` value to the index of each expression inside this + /// ordering equivalence class. + pub fn add_offset(&mut self, offset: usize) { + for ordering in self.orderings.iter_mut() { + for sort_expr in ordering { + sort_expr.expr = add_offset_to_expr(sort_expr.expr.clone(), offset); + } + } + } + + /// Gets sort options associated with this expression if it is a leading + /// ordering expression. Otherwise, returns `None`. + pub fn get_options(&self, expr: &Arc) -> Option { + for ordering in self.iter() { + let leading_ordering = &ordering[0]; + if leading_ordering.expr.eq(expr) { + return Some(leading_ordering.options); + } + } + None + } +} + +/// This function constructs a duplicate-free `LexOrdering` by filtering out +/// duplicate entries that have same physical expression inside. For example, +/// `vec![a ASC, a DESC]` collapses to `vec![a ASC]`. +pub fn collapse_lex_ordering(input: LexOrdering) -> LexOrdering { + let mut output = Vec::::new(); + for item in input { + if !output.iter().any(|req| req.expr.eq(&item.expr)) { + output.push(item); + } + } + output +} + +/// Trims `orderings[idx]` if some suffix of it overlaps with a prefix of +/// `orderings[pre_idx]`. Returns `true` if there is any overlap, `false` otherwise. +fn resolve_overlap(orderings: &mut [LexOrdering], idx: usize, pre_idx: usize) -> bool { + let length = orderings[idx].len(); + let other_length = orderings[pre_idx].len(); + for overlap in 1..=length.min(other_length) { + if orderings[idx][length - overlap..] 
== orderings[pre_idx][..overlap] { + orderings[idx].truncate(length - overlap); + return true; + } + } + false +} + +#[cfg(test)] +mod tests { + use crate::equivalence::tests::{ + convert_to_orderings, convert_to_sort_exprs, create_random_schema, + create_test_params, generate_table_for_eq_properties, is_table_same_after_sort, + }; + use crate::equivalence::{tests::create_test_schema, EquivalenceProperties}; + use crate::equivalence::{ + EquivalenceClass, EquivalenceGroup, OrderingEquivalenceClass, + }; + use crate::execution_props::ExecutionProps; + use crate::expressions::Column; + use crate::expressions::{col, BinaryExpr}; + use crate::functions::create_physical_expr; + use crate::{PhysicalExpr, PhysicalSortExpr}; + use arrow::datatypes::{DataType, Field, Schema}; + use arrow_schema::SortOptions; + use datafusion_common::Result; + use datafusion_expr::{BuiltinScalarFunction, Operator}; + use itertools::Itertools; + use std::sync::Arc; + + #[test] + fn test_ordering_satisfy() -> Result<()> { + let input_schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int64, true), + Field::new("b", DataType::Int64, true), + ])); + let crude = vec![PhysicalSortExpr { + expr: Arc::new(Column::new("a", 0)), + options: SortOptions::default(), + }]; + let finer = vec![ + PhysicalSortExpr { + expr: Arc::new(Column::new("a", 0)), + options: SortOptions::default(), + }, + PhysicalSortExpr { + expr: Arc::new(Column::new("b", 1)), + options: SortOptions::default(), + }, + ]; + // finer ordering satisfies, crude ordering should return true + let mut eq_properties_finer = EquivalenceProperties::new(input_schema.clone()); + eq_properties_finer.oeq_class.push(finer.clone()); + assert!(eq_properties_finer.ordering_satisfy(&crude)); + + // Crude ordering doesn't satisfy finer ordering. should return false + let mut eq_properties_crude = EquivalenceProperties::new(input_schema.clone()); + eq_properties_crude.oeq_class.push(crude.clone()); + assert!(!eq_properties_crude.ordering_satisfy(&finer)); + Ok(()) + } + + #[test] + fn test_ordering_satisfy_with_equivalence2() -> Result<()> { + let test_schema = create_test_schema()?; + let col_a = &col("a", &test_schema)?; + let col_b = &col("b", &test_schema)?; + let col_c = &col("c", &test_schema)?; + let col_d = &col("d", &test_schema)?; + let col_e = &col("e", &test_schema)?; + let col_f = &col("f", &test_schema)?; + let floor_a = &create_physical_expr( + &BuiltinScalarFunction::Floor, + &[col("a", &test_schema)?], + &test_schema, + &ExecutionProps::default(), + )?; + let floor_f = &create_physical_expr( + &BuiltinScalarFunction::Floor, + &[col("f", &test_schema)?], + &test_schema, + &ExecutionProps::default(), + )?; + let exp_a = &create_physical_expr( + &BuiltinScalarFunction::Exp, + &[col("a", &test_schema)?], + &test_schema, + &ExecutionProps::default(), + )?; + let a_plus_b = Arc::new(BinaryExpr::new( + col_a.clone(), + Operator::Plus, + col_b.clone(), + )) as Arc; + let options = SortOptions { + descending: false, + nulls_first: false, + }; + + let test_cases = vec![ + // ------------ TEST CASE 1 ------------ + ( + // orderings + vec![ + // [a ASC, d ASC, b ASC] + vec![(col_a, options), (col_d, options), (col_b, options)], + // [c ASC] + vec![(col_c, options)], + ], + // equivalence classes + vec![vec![col_a, col_f]], + // constants + vec![col_e], + // requirement [a ASC, b ASC], requirement is not satisfied. + vec![(col_a, options), (col_b, options)], + // expected: requirement is not satisfied. 
+ false, + ), + // ------------ TEST CASE 2 ------------ + ( + // orderings + vec![ + // [a ASC, c ASC, b ASC] + vec![(col_a, options), (col_c, options), (col_b, options)], + // [d ASC] + vec![(col_d, options)], + ], + // equivalence classes + vec![vec![col_a, col_f]], + // constants + vec![col_e], + // requirement [floor(a) ASC], + vec![(floor_a, options)], + // expected: requirement is satisfied. + true, + ), + // ------------ TEST CASE 2.1 ------------ + ( + // orderings + vec![ + // [a ASC, c ASC, b ASC] + vec![(col_a, options), (col_c, options), (col_b, options)], + // [d ASC] + vec![(col_d, options)], + ], + // equivalence classes + vec![vec![col_a, col_f]], + // constants + vec![col_e], + // requirement [floor(f) ASC], (Please note that a=f) + vec![(floor_f, options)], + // expected: requirement is satisfied. + true, + ), + // ------------ TEST CASE 3 ------------ + ( + // orderings + vec![ + // [a ASC, c ASC, b ASC] + vec![(col_a, options), (col_c, options), (col_b, options)], + // [d ASC] + vec![(col_d, options)], + ], + // equivalence classes + vec![vec![col_a, col_f]], + // constants + vec![col_e], + // requirement [a ASC, c ASC, a+b ASC], + vec![(col_a, options), (col_c, options), (&a_plus_b, options)], + // expected: requirement is satisfied. + true, + ), + // ------------ TEST CASE 4 ------------ + ( + // orderings + vec![ + // [a ASC, b ASC, c ASC, d ASC] + vec![ + (col_a, options), + (col_b, options), + (col_c, options), + (col_d, options), + ], + ], + // equivalence classes + vec![vec![col_a, col_f]], + // constants + vec![col_e], + // requirement [floor(a) ASC, a+b ASC], + vec![(floor_a, options), (&a_plus_b, options)], + // expected: requirement is satisfied. + false, + ), + // ------------ TEST CASE 5 ------------ + ( + // orderings + vec![ + // [a ASC, b ASC, c ASC, d ASC] + vec![ + (col_a, options), + (col_b, options), + (col_c, options), + (col_d, options), + ], + ], + // equivalence classes + vec![vec![col_a, col_f]], + // constants + vec![col_e], + // requirement [exp(a) ASC, a+b ASC], + vec![(exp_a, options), (&a_plus_b, options)], + // expected: requirement is not satisfied. + // TODO: If we know that exp function is 1-to-1 function. + // we could have deduced that above requirement is satisfied. + false, + ), + // ------------ TEST CASE 6 ------------ + ( + // orderings + vec![ + // [a ASC, d ASC, b ASC] + vec![(col_a, options), (col_d, options), (col_b, options)], + // [c ASC] + vec![(col_c, options)], + ], + // equivalence classes + vec![vec![col_a, col_f]], + // constants + vec![col_e], + // requirement [a ASC, d ASC, floor(a) ASC], + vec![(col_a, options), (col_d, options), (floor_a, options)], + // expected: requirement is satisfied. + true, + ), + // ------------ TEST CASE 7 ------------ + ( + // orderings + vec![ + // [a ASC, c ASC, b ASC] + vec![(col_a, options), (col_c, options), (col_b, options)], + // [d ASC] + vec![(col_d, options)], + ], + // equivalence classes + vec![vec![col_a, col_f]], + // constants + vec![col_e], + // requirement [a ASC, floor(a) ASC, a + b ASC], + vec![(col_a, options), (floor_a, options), (&a_plus_b, options)], + // expected: requirement is not satisfied. 
+ false, + ), + // ------------ TEST CASE 8 ------------ + ( + // orderings + vec![ + // [a ASC, b ASC, c ASC] + vec![(col_a, options), (col_b, options), (col_c, options)], + // [d ASC] + vec![(col_d, options)], + ], + // equivalence classes + vec![vec![col_a, col_f]], + // constants + vec![col_e], + // requirement [a ASC, c ASC, floor(a) ASC, a + b ASC], + vec![ + (col_a, options), + (col_c, options), + (&floor_a, options), + (&a_plus_b, options), + ], + // expected: requirement is not satisfied. + false, + ), + // ------------ TEST CASE 9 ------------ + ( + // orderings + vec![ + // [a ASC, b ASC, c ASC, d ASC] + vec![ + (col_a, options), + (col_b, options), + (col_c, options), + (col_d, options), + ], + ], + // equivalence classes + vec![vec![col_a, col_f]], + // constants + vec![col_e], + // requirement [a ASC, b ASC, c ASC, floor(a) ASC], + vec![ + (col_a, options), + (col_b, options), + (&col_c, options), + (&floor_a, options), + ], + // expected: requirement is satisfied. + true, + ), + // ------------ TEST CASE 10 ------------ + ( + // orderings + vec![ + // [d ASC, b ASC] + vec![(col_d, options), (col_b, options)], + // [c ASC, a ASC] + vec![(col_c, options), (col_a, options)], + ], + // equivalence classes + vec![vec![col_a, col_f]], + // constants + vec![col_e], + // requirement [c ASC, d ASC, a + b ASC], + vec![(col_c, options), (col_d, options), (&a_plus_b, options)], + // expected: requirement is satisfied. + true, + ), + ]; + + for (orderings, eq_group, constants, reqs, expected) in test_cases { + let err_msg = + format!("error in test orderings: {orderings:?}, eq_group: {eq_group:?}, constants: {constants:?}, reqs: {reqs:?}, expected: {expected:?}"); + let mut eq_properties = EquivalenceProperties::new(test_schema.clone()); + let orderings = convert_to_orderings(&orderings); + eq_properties.add_new_orderings(orderings); + let eq_group = eq_group + .into_iter() + .map(|eq_class| { + let eq_classes = eq_class.into_iter().cloned().collect::>(); + EquivalenceClass::new(eq_classes) + }) + .collect::>(); + let eq_group = EquivalenceGroup::new(eq_group); + eq_properties.add_equivalence_group(eq_group); + + let constants = constants.into_iter().cloned(); + eq_properties = eq_properties.add_constants(constants); + + let reqs = convert_to_sort_exprs(&reqs); + assert_eq!( + eq_properties.ordering_satisfy(&reqs), + expected, + "{}", + err_msg + ); + } + + Ok(()) + } + + #[test] + fn test_ordering_satisfy_with_equivalence() -> Result<()> { + // Schema satisfies following orderings: + // [a ASC], [d ASC, b ASC], [e DESC, f ASC, g ASC] + // and + // Column [a=c] (e.g they are aliases). 
+ let (test_schema, eq_properties) = create_test_params()?; + let col_a = &col("a", &test_schema)?; + let col_b = &col("b", &test_schema)?; + let col_c = &col("c", &test_schema)?; + let col_d = &col("d", &test_schema)?; + let col_e = &col("e", &test_schema)?; + let col_f = &col("f", &test_schema)?; + let col_g = &col("g", &test_schema)?; + let option_asc = SortOptions { + descending: false, + nulls_first: false, + }; + let option_desc = SortOptions { + descending: true, + nulls_first: true, + }; + let table_data_with_properties = + generate_table_for_eq_properties(&eq_properties, 625, 5)?; + + // First element in the tuple stores vector of requirement, second element is the expected return value for ordering_satisfy function + let requirements = vec![ + // `a ASC NULLS LAST`, expects `ordering_satisfy` to be `true`, since existing ordering `a ASC NULLS LAST, b ASC NULLS LAST` satisfies it + (vec![(col_a, option_asc)], true), + (vec![(col_a, option_desc)], false), + // Test whether equivalence works as expected + (vec![(col_c, option_asc)], true), + (vec![(col_c, option_desc)], false), + // Test whether ordering equivalence works as expected + (vec![(col_d, option_asc)], true), + (vec![(col_d, option_asc), (col_b, option_asc)], true), + (vec![(col_d, option_desc), (col_b, option_asc)], false), + ( + vec![ + (col_e, option_desc), + (col_f, option_asc), + (col_g, option_asc), + ], + true, + ), + (vec![(col_e, option_desc), (col_f, option_asc)], true), + (vec![(col_e, option_asc), (col_f, option_asc)], false), + (vec![(col_e, option_desc), (col_b, option_asc)], false), + (vec![(col_e, option_asc), (col_b, option_asc)], false), + ( + vec![ + (col_d, option_asc), + (col_b, option_asc), + (col_d, option_asc), + (col_b, option_asc), + ], + true, + ), + ( + vec![ + (col_d, option_asc), + (col_b, option_asc), + (col_e, option_desc), + (col_f, option_asc), + ], + true, + ), + ( + vec![ + (col_d, option_asc), + (col_b, option_asc), + (col_e, option_desc), + (col_b, option_asc), + ], + true, + ), + ( + vec![ + (col_d, option_asc), + (col_b, option_asc), + (col_d, option_desc), + (col_b, option_asc), + ], + true, + ), + ( + vec![ + (col_d, option_asc), + (col_b, option_asc), + (col_e, option_asc), + (col_f, option_asc), + ], + false, + ), + ( + vec![ + (col_d, option_asc), + (col_b, option_asc), + (col_e, option_asc), + (col_b, option_asc), + ], + false, + ), + (vec![(col_d, option_asc), (col_e, option_desc)], true), + ( + vec![ + (col_d, option_asc), + (col_c, option_asc), + (col_b, option_asc), + ], + true, + ), + ( + vec![ + (col_d, option_asc), + (col_e, option_desc), + (col_f, option_asc), + (col_b, option_asc), + ], + true, + ), + ( + vec![ + (col_d, option_asc), + (col_e, option_desc), + (col_c, option_asc), + (col_b, option_asc), + ], + true, + ), + ( + vec![ + (col_d, option_asc), + (col_e, option_desc), + (col_b, option_asc), + (col_f, option_asc), + ], + true, + ), + ]; + + for (cols, expected) in requirements { + let err_msg = format!("Error in test case:{cols:?}"); + let required = cols + .into_iter() + .map(|(expr, options)| PhysicalSortExpr { + expr: expr.clone(), + options, + }) + .collect::>(); + + // Check expected result with experimental result. 
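+            // `is_table_same_after_sort` acts as a brute-force oracle: it sorts
+            // the generated data and reports whether the row order is unchanged.
+            // Both it and `ordering_satisfy` are checked against `expected` below.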
+ assert_eq!( + is_table_same_after_sort( + required.clone(), + table_data_with_properties.clone() + )?, + expected + ); + assert_eq!( + eq_properties.ordering_satisfy(&required), + expected, + "{err_msg}" + ); + } + Ok(()) + } + + #[test] + fn test_ordering_satisfy_with_equivalence_random() -> Result<()> { + const N_RANDOM_SCHEMA: usize = 5; + const N_ELEMENTS: usize = 125; + const N_DISTINCT: usize = 5; + const SORT_OPTIONS: SortOptions = SortOptions { + descending: false, + nulls_first: false, + }; + + for seed in 0..N_RANDOM_SCHEMA { + // Create a random schema with random properties + let (test_schema, eq_properties) = create_random_schema(seed as u64)?; + // Generate a data that satisfies properties given + let table_data_with_properties = + generate_table_for_eq_properties(&eq_properties, N_ELEMENTS, N_DISTINCT)?; + let col_exprs = vec![ + col("a", &test_schema)?, + col("b", &test_schema)?, + col("c", &test_schema)?, + col("d", &test_schema)?, + col("e", &test_schema)?, + col("f", &test_schema)?, + ]; + + for n_req in 0..=col_exprs.len() { + for exprs in col_exprs.iter().combinations(n_req) { + let requirement = exprs + .into_iter() + .map(|expr| PhysicalSortExpr { + expr: expr.clone(), + options: SORT_OPTIONS, + }) + .collect::>(); + let expected = is_table_same_after_sort( + requirement.clone(), + table_data_with_properties.clone(), + )?; + let err_msg = format!( + "Error in test case requirement:{:?}, expected: {:?}, eq_properties.oeq_class: {:?}, eq_properties.eq_group: {:?}, eq_properties.constants: {:?}", + requirement, expected, eq_properties.oeq_class, eq_properties.eq_group, eq_properties.constants + ); + // Check whether ordering_satisfy API result and + // experimental result matches. + assert_eq!( + eq_properties.ordering_satisfy(&requirement), + expected, + "{}", + err_msg + ); + } + } + } + + Ok(()) + } + + #[test] + fn test_ordering_satisfy_with_equivalence_complex_random() -> Result<()> { + const N_RANDOM_SCHEMA: usize = 100; + const N_ELEMENTS: usize = 125; + const N_DISTINCT: usize = 5; + const SORT_OPTIONS: SortOptions = SortOptions { + descending: false, + nulls_first: false, + }; + + for seed in 0..N_RANDOM_SCHEMA { + // Create a random schema with random properties + let (test_schema, eq_properties) = create_random_schema(seed as u64)?; + // Generate a data that satisfies properties given + let table_data_with_properties = + generate_table_for_eq_properties(&eq_properties, N_ELEMENTS, N_DISTINCT)?; + + let floor_a = create_physical_expr( + &BuiltinScalarFunction::Floor, + &[col("a", &test_schema)?], + &test_schema, + &ExecutionProps::default(), + )?; + let a_plus_b = Arc::new(BinaryExpr::new( + col("a", &test_schema)?, + Operator::Plus, + col("b", &test_schema)?, + )) as Arc; + let exprs = vec![ + col("a", &test_schema)?, + col("b", &test_schema)?, + col("c", &test_schema)?, + col("d", &test_schema)?, + col("e", &test_schema)?, + col("f", &test_schema)?, + floor_a, + a_plus_b, + ]; + + for n_req in 0..=exprs.len() { + for exprs in exprs.iter().combinations(n_req) { + let requirement = exprs + .into_iter() + .map(|expr| PhysicalSortExpr { + expr: expr.clone(), + options: SORT_OPTIONS, + }) + .collect::>(); + let expected = is_table_same_after_sort( + requirement.clone(), + table_data_with_properties.clone(), + )?; + let err_msg = format!( + "Error in test case requirement:{:?}, expected: {:?}, eq_properties.oeq_class: {:?}, eq_properties.eq_group: {:?}, eq_properties.constants: {:?}", + requirement, expected, eq_properties.oeq_class, eq_properties.eq_group, 
eq_properties.constants + ); + // Check whether ordering_satisfy API result and + // experimental result matches. + + assert_eq!( + eq_properties.ordering_satisfy(&requirement), + (expected | false), + "{}", + err_msg + ); + } + } + } + + Ok(()) + } + + #[test] + fn test_ordering_satisfy_different_lengths() -> Result<()> { + let test_schema = create_test_schema()?; + let col_a = &col("a", &test_schema)?; + let col_b = &col("b", &test_schema)?; + let col_c = &col("c", &test_schema)?; + let col_d = &col("d", &test_schema)?; + let col_e = &col("e", &test_schema)?; + let col_f = &col("f", &test_schema)?; + let options = SortOptions { + descending: false, + nulls_first: false, + }; + // a=c (e.g they are aliases). + let mut eq_properties = EquivalenceProperties::new(test_schema); + eq_properties.add_equal_conditions(col_a, col_c); + + let orderings = vec![ + vec![(col_a, options)], + vec![(col_e, options)], + vec![(col_d, options), (col_f, options)], + ]; + let orderings = convert_to_orderings(&orderings); + + // Column [a ASC], [e ASC], [d ASC, f ASC] are all valid orderings for the schema. + eq_properties.add_new_orderings(orderings); + + // First entry in the tuple is required ordering, second entry is the expected flag + // that indicates whether this required ordering is satisfied. + // ([a ASC], true) indicate a ASC requirement is already satisfied by existing orderings. + let test_cases = vec![ + // [c ASC, a ASC, e ASC], expected represents this requirement is satisfied + ( + vec![(col_c, options), (col_a, options), (col_e, options)], + true, + ), + (vec![(col_c, options), (col_b, options)], false), + (vec![(col_c, options), (col_d, options)], true), + ( + vec![(col_d, options), (col_f, options), (col_b, options)], + false, + ), + (vec![(col_d, options), (col_f, options)], true), + ]; + + for (reqs, expected) in test_cases { + let err_msg = + format!("error in test reqs: {:?}, expected: {:?}", reqs, expected,); + let reqs = convert_to_sort_exprs(&reqs); + assert_eq!( + eq_properties.ordering_satisfy(&reqs), + expected, + "{}", + err_msg + ); + } + + Ok(()) + } + + #[test] + fn test_remove_redundant_entries_oeq_class() -> Result<()> { + let schema = create_test_schema()?; + let col_a = &col("a", &schema)?; + let col_b = &col("b", &schema)?; + let col_c = &col("c", &schema)?; + let col_d = &col("d", &schema)?; + let col_e = &col("e", &schema)?; + + let option_asc = SortOptions { + descending: false, + nulls_first: false, + }; + let option_desc = SortOptions { + descending: true, + nulls_first: true, + }; + + // First entry in the tuple is the given orderings for the table + // Second entry is the simplest version of the given orderings that is functionally equivalent. + let test_cases = vec![ + // ------- TEST CASE 1 --------- + ( + // ORDERINGS GIVEN + vec![ + // [a ASC, b ASC] + vec![(col_a, option_asc), (col_b, option_asc)], + ], + // EXPECTED orderings that is succinct. + vec![ + // [a ASC, b ASC] + vec![(col_a, option_asc), (col_b, option_asc)], + ], + ), + // ------- TEST CASE 2 --------- + ( + // ORDERINGS GIVEN + vec![ + // [a ASC, b ASC] + vec![(col_a, option_asc), (col_b, option_asc)], + // [a ASC, b ASC, c ASC] + vec![ + (col_a, option_asc), + (col_b, option_asc), + (col_c, option_asc), + ], + ], + // EXPECTED orderings that is succinct. 
+ vec![ + // [a ASC, b ASC, c ASC] + vec![ + (col_a, option_asc), + (col_b, option_asc), + (col_c, option_asc), + ], + ], + ), + // ------- TEST CASE 3 --------- + ( + // ORDERINGS GIVEN + vec![ + // [a ASC, b DESC] + vec![(col_a, option_asc), (col_b, option_desc)], + // [a ASC] + vec![(col_a, option_asc)], + // [a ASC, c ASC] + vec![(col_a, option_asc), (col_c, option_asc)], + ], + // EXPECTED orderings that is succinct. + vec![ + // [a ASC, b DESC] + vec![(col_a, option_asc), (col_b, option_desc)], + // [a ASC, c ASC] + vec![(col_a, option_asc), (col_c, option_asc)], + ], + ), + // ------- TEST CASE 4 --------- + ( + // ORDERINGS GIVEN + vec![ + // [a ASC, b ASC] + vec![(col_a, option_asc), (col_b, option_asc)], + // [a ASC, b ASC, c ASC] + vec![ + (col_a, option_asc), + (col_b, option_asc), + (col_c, option_asc), + ], + // [a ASC] + vec![(col_a, option_asc)], + ], + // EXPECTED orderings that is succinct. + vec![ + // [a ASC, b ASC, c ASC] + vec![ + (col_a, option_asc), + (col_b, option_asc), + (col_c, option_asc), + ], + ], + ), + // ------- TEST CASE 5 --------- + // Empty ordering + ( + vec![vec![]], + // No ordering in the state (empty ordering is ignored). + vec![], + ), + // ------- TEST CASE 6 --------- + ( + // ORDERINGS GIVEN + vec![ + // [a ASC, b ASC] + vec![(col_a, option_asc), (col_b, option_asc)], + // [b ASC] + vec![(col_b, option_asc)], + ], + // EXPECTED orderings that is succinct. + vec![ + // [a ASC] + vec![(col_a, option_asc)], + // [b ASC] + vec![(col_b, option_asc)], + ], + ), + // ------- TEST CASE 7 --------- + // b, a + // c, a + // d, b, c + ( + // ORDERINGS GIVEN + vec![ + // [b ASC, a ASC] + vec![(col_b, option_asc), (col_a, option_asc)], + // [c ASC, a ASC] + vec![(col_c, option_asc), (col_a, option_asc)], + // [d ASC, b ASC, c ASC] + vec![ + (col_d, option_asc), + (col_b, option_asc), + (col_c, option_asc), + ], + ], + // EXPECTED orderings that is succinct. + vec![ + // [b ASC, a ASC] + vec![(col_b, option_asc), (col_a, option_asc)], + // [c ASC, a ASC] + vec![(col_c, option_asc), (col_a, option_asc)], + // [d ASC] + vec![(col_d, option_asc)], + ], + ), + // ------- TEST CASE 8 --------- + // b, e + // c, a + // d, b, e, c, a + ( + // ORDERINGS GIVEN + vec![ + // [b ASC, e ASC] + vec![(col_b, option_asc), (col_e, option_asc)], + // [c ASC, a ASC] + vec![(col_c, option_asc), (col_a, option_asc)], + // [d ASC, b ASC, e ASC, c ASC, a ASC] + vec![ + (col_d, option_asc), + (col_b, option_asc), + (col_e, option_asc), + (col_c, option_asc), + (col_a, option_asc), + ], + ], + // EXPECTED orderings that is succinct. + vec![ + // [b ASC, e ASC] + vec![(col_b, option_asc), (col_e, option_asc)], + // [c ASC, a ASC] + vec![(col_c, option_asc), (col_a, option_asc)], + // [d ASC] + vec![(col_d, option_asc)], + ], + ), + // ------- TEST CASE 9 --------- + // b + // a, b, c + // d, a, b + ( + // ORDERINGS GIVEN + vec![ + // [b ASC] + vec![(col_b, option_asc)], + // [a ASC, b ASC, c ASC] + vec![ + (col_a, option_asc), + (col_b, option_asc), + (col_c, option_asc), + ], + // [d ASC, a ASC, b ASC] + vec![ + (col_d, option_asc), + (col_a, option_asc), + (col_b, option_asc), + ], + ], + // EXPECTED orderings that is succinct. 
+ vec![ + // [b ASC] + vec![(col_b, option_asc)], + // [a ASC, b ASC, c ASC] + vec![ + (col_a, option_asc), + (col_b, option_asc), + (col_c, option_asc), + ], + // [d ASC] + vec![(col_d, option_asc)], + ], + ), + ]; + for (orderings, expected) in test_cases { + let orderings = convert_to_orderings(&orderings); + let expected = convert_to_orderings(&expected); + let actual = OrderingEquivalenceClass::new(orderings.clone()); + let actual = actual.orderings; + let err_msg = format!( + "orderings: {:?}, expected: {:?}, actual :{:?}", + orderings, expected, actual + ); + assert_eq!(actual.len(), expected.len(), "{}", err_msg); + for elem in actual { + assert!(expected.contains(&elem), "{}", err_msg); + } + } + + Ok(()) + } +} diff --git a/datafusion/physical-expr/src/equivalence/projection.rs b/datafusion/physical-expr/src/equivalence/projection.rs new file mode 100644 index 0000000000000..0f92b2c2f431d --- /dev/null +++ b/datafusion/physical-expr/src/equivalence/projection.rs @@ -0,0 +1,1153 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::sync::Arc; + +use crate::expressions::Column; +use crate::PhysicalExpr; + +use arrow::datatypes::SchemaRef; +use datafusion_common::tree_node::{Transformed, TreeNode}; +use datafusion_common::Result; + +/// Stores the mapping between source expressions and target expressions for a +/// projection. +#[derive(Debug, Clone)] +pub struct ProjectionMapping { + /// Mapping between source expressions and target expressions. + /// Vector indices correspond to the indices after projection. + pub map: Vec<(Arc, Arc)>, +} + +impl ProjectionMapping { + /// Constructs the mapping between a projection's input and output + /// expressions. + /// + /// For example, given the input projection expressions (`a + b`, `c + d`) + /// and an output schema with two columns `"c + d"` and `"a + b"`, the + /// projection mapping would be: + /// + /// ```text + /// [0]: (c + d, col("c + d")) + /// [1]: (a + b, col("a + b")) + /// ``` + /// + /// where `col("c + d")` means the column named `"c + d"`. + pub fn try_new( + expr: &[(Arc, String)], + input_schema: &SchemaRef, + ) -> Result { + // Construct a map from the input expressions to the output expression of the projection: + expr.iter() + .enumerate() + .map(|(expr_idx, (expression, name))| { + let target_expr = Arc::new(Column::new(name, expr_idx)) as _; + expression + .clone() + .transform_down(&|e| match e.as_any().downcast_ref::() { + Some(col) => { + // Sometimes, an expression and its name in the input_schema + // doesn't match. This can cause problems, so we make sure + // that the expression name matches with the name in `input_schema`. + // Conceptually, `source_expr` and `expression` should be the same. 
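+                            // The rebuilt column below keeps the original index
+                            // but takes its name from `input_schema`, so later
+                            // name-based lookups agree with the input schema.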
+ let idx = col.index(); + let matching_input_field = input_schema.field(idx); + let matching_input_column = + Column::new(matching_input_field.name(), idx); + Ok(Transformed::Yes(Arc::new(matching_input_column))) + } + None => Ok(Transformed::No(e)), + }) + .map(|source_expr| (source_expr, target_expr)) + }) + .collect::>>() + .map(|map| Self { map }) + } + + /// Iterate over pairs of (source, target) expressions + pub fn iter( + &self, + ) -> impl Iterator, Arc)> + '_ { + self.map.iter() + } + + /// This function returns the target expression for a given source expression. + /// + /// # Arguments + /// + /// * `expr` - Source physical expression. + /// + /// # Returns + /// + /// An `Option` containing the target for the given source expression, + /// where a `None` value means that `expr` is not inside the mapping. + pub fn target_expr( + &self, + expr: &Arc, + ) -> Option> { + self.map + .iter() + .find(|(source, _)| source.eq(expr)) + .map(|(_, target)| target.clone()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::equivalence::tests::{ + apply_projection, convert_to_orderings, convert_to_orderings_owned, + create_random_schema, generate_table_for_eq_properties, is_table_same_after_sort, + output_schema, + }; + use crate::equivalence::EquivalenceProperties; + use crate::execution_props::ExecutionProps; + use crate::expressions::{col, BinaryExpr, Literal}; + use crate::functions::create_physical_expr; + use crate::PhysicalSortExpr; + use arrow::datatypes::{DataType, Field, Schema}; + use arrow_schema::{SortOptions, TimeUnit}; + use datafusion_common::{Result, ScalarValue}; + use datafusion_expr::{BuiltinScalarFunction, Operator}; + use itertools::Itertools; + use std::sync::Arc; + + #[test] + fn project_orderings() -> Result<()> { + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + Field::new("c", DataType::Int32, true), + Field::new("d", DataType::Int32, true), + Field::new("e", DataType::Int32, true), + Field::new("ts", DataType::Timestamp(TimeUnit::Nanosecond, None), true), + ])); + let col_a = &col("a", &schema)?; + let col_b = &col("b", &schema)?; + let col_c = &col("c", &schema)?; + let col_d = &col("d", &schema)?; + let col_e = &col("e", &schema)?; + let col_ts = &col("ts", &schema)?; + let interval = Arc::new(Literal::new(ScalarValue::IntervalDayTime(Some(2)))) + as Arc; + let date_bin_func = &create_physical_expr( + &BuiltinScalarFunction::DateBin, + &[interval, col_ts.clone()], + &schema, + &ExecutionProps::default(), + )?; + let a_plus_b = Arc::new(BinaryExpr::new( + col_a.clone(), + Operator::Plus, + col_b.clone(), + )) as Arc; + let b_plus_d = Arc::new(BinaryExpr::new( + col_b.clone(), + Operator::Plus, + col_d.clone(), + )) as Arc; + let b_plus_e = Arc::new(BinaryExpr::new( + col_b.clone(), + Operator::Plus, + col_e.clone(), + )) as Arc; + let c_plus_d = Arc::new(BinaryExpr::new( + col_c.clone(), + Operator::Plus, + col_d.clone(), + )) as Arc; + + let option_asc = SortOptions { + descending: false, + nulls_first: false, + }; + let option_desc = SortOptions { + descending: true, + nulls_first: true, + }; + + let test_cases = vec![ + // ---------- TEST CASE 1 ------------ + ( + // orderings + vec![ + // [b ASC] + vec![(col_b, option_asc)], + ], + // projection exprs + vec![(col_b, "b_new".to_string()), (col_a, "a_new".to_string())], + // expected + vec![ + // [b_new ASC] + vec![("b_new", option_asc)], + ], + ), + // ---------- TEST CASE 2 ------------ + ( + // orderings + vec![ + 
// empty ordering + ], + // projection exprs + vec![(col_c, "c_new".to_string()), (col_b, "b_new".to_string())], + // expected + vec![ + // no ordering at the output + ], + ), + // ---------- TEST CASE 3 ------------ + ( + // orderings + vec![ + // [ts ASC] + vec![(col_ts, option_asc)], + ], + // projection exprs + vec![ + (col_b, "b_new".to_string()), + (col_a, "a_new".to_string()), + (col_ts, "ts_new".to_string()), + (date_bin_func, "date_bin_res".to_string()), + ], + // expected + vec![ + // [date_bin_res ASC] + vec![("date_bin_res", option_asc)], + // [ts_new ASC] + vec![("ts_new", option_asc)], + ], + ), + // ---------- TEST CASE 4 ------------ + ( + // orderings + vec![ + // [a ASC, ts ASC] + vec![(col_a, option_asc), (col_ts, option_asc)], + // [b ASC, ts ASC] + vec![(col_b, option_asc), (col_ts, option_asc)], + ], + // projection exprs + vec![ + (col_b, "b_new".to_string()), + (col_a, "a_new".to_string()), + (col_ts, "ts_new".to_string()), + (date_bin_func, "date_bin_res".to_string()), + ], + // expected + vec![ + // [a_new ASC, ts_new ASC] + vec![("a_new", option_asc), ("ts_new", option_asc)], + // [a_new ASC, date_bin_res ASC] + vec![("a_new", option_asc), ("date_bin_res", option_asc)], + // [b_new ASC, ts_new ASC] + vec![("b_new", option_asc), ("ts_new", option_asc)], + // [b_new ASC, date_bin_res ASC] + vec![("b_new", option_asc), ("date_bin_res", option_asc)], + ], + ), + // ---------- TEST CASE 5 ------------ + ( + // orderings + vec![ + // [a + b ASC] + vec![(&a_plus_b, option_asc)], + ], + // projection exprs + vec![ + (col_b, "b_new".to_string()), + (col_a, "a_new".to_string()), + (&a_plus_b, "a+b".to_string()), + ], + // expected + vec![ + // [a + b ASC] + vec![("a+b", option_asc)], + ], + ), + // ---------- TEST CASE 6 ------------ + ( + // orderings + vec![ + // [a + b ASC, c ASC] + vec![(&a_plus_b, option_asc), (&col_c, option_asc)], + ], + // projection exprs + vec![ + (col_b, "b_new".to_string()), + (col_a, "a_new".to_string()), + (col_c, "c_new".to_string()), + (&a_plus_b, "a+b".to_string()), + ], + // expected + vec![ + // [a + b ASC, c_new ASC] + vec![("a+b", option_asc), ("c_new", option_asc)], + ], + ), + // ------- TEST CASE 7 ---------- + ( + vec![ + // [a ASC, b ASC, c ASC] + vec![(col_a, option_asc), (col_b, option_asc)], + // [a ASC, d ASC] + vec![(col_a, option_asc), (col_d, option_asc)], + ], + // b as b_new, a as a_new, d as d_new b+d + vec![ + (col_b, "b_new".to_string()), + (col_a, "a_new".to_string()), + (col_d, "d_new".to_string()), + (&b_plus_d, "b+d".to_string()), + ], + // expected + vec![ + // [a_new ASC, b_new ASC] + vec![("a_new", option_asc), ("b_new", option_asc)], + // [a_new ASC, d_new ASC] + vec![("a_new", option_asc), ("d_new", option_asc)], + // [a_new ASC, b+d ASC] + vec![("a_new", option_asc), ("b+d", option_asc)], + ], + ), + // ------- TEST CASE 8 ---------- + ( + // orderings + vec![ + // [b+d ASC] + vec![(&b_plus_d, option_asc)], + ], + // proj exprs + vec![ + (col_b, "b_new".to_string()), + (col_a, "a_new".to_string()), + (col_d, "d_new".to_string()), + (&b_plus_d, "b+d".to_string()), + ], + // expected + vec![ + // [b+d ASC] + vec![("b+d", option_asc)], + ], + ), + // ------- TEST CASE 9 ---------- + ( + // orderings + vec![ + // [a ASC, d ASC, b ASC] + vec![ + (col_a, option_asc), + (col_d, option_asc), + (col_b, option_asc), + ], + // [c ASC] + vec![(col_c, option_asc)], + ], + // proj exprs + vec![ + (col_b, "b_new".to_string()), + (col_a, "a_new".to_string()), + (col_d, "d_new".to_string()), + (col_c, "c_new".to_string()), + 
], + // expected + vec![ + // [a_new ASC, d_new ASC, b_new ASC] + vec![ + ("a_new", option_asc), + ("d_new", option_asc), + ("b_new", option_asc), + ], + // [c_new ASC], + vec![("c_new", option_asc)], + ], + ), + // ------- TEST CASE 10 ---------- + ( + vec![ + // [a ASC, b ASC, c ASC] + vec![ + (col_a, option_asc), + (col_b, option_asc), + (col_c, option_asc), + ], + // [a ASC, d ASC] + vec![(col_a, option_asc), (col_d, option_asc)], + ], + // proj exprs + vec![ + (col_b, "b_new".to_string()), + (col_a, "a_new".to_string()), + (col_c, "c_new".to_string()), + (&c_plus_d, "c+d".to_string()), + ], + // expected + vec![ + // [a_new ASC, b_new ASC, c_new ASC] + vec![ + ("a_new", option_asc), + ("b_new", option_asc), + ("c_new", option_asc), + ], + // [a_new ASC, b_new ASC, c+d ASC] + vec![ + ("a_new", option_asc), + ("b_new", option_asc), + ("c+d", option_asc), + ], + ], + ), + // ------- TEST CASE 11 ---------- + ( + // orderings + vec![ + // [a ASC, b ASC] + vec![(col_a, option_asc), (col_b, option_asc)], + // [a ASC, d ASC] + vec![(col_a, option_asc), (col_d, option_asc)], + ], + // proj exprs + vec![ + (col_b, "b_new".to_string()), + (col_a, "a_new".to_string()), + (&b_plus_d, "b+d".to_string()), + ], + // expected + vec![ + // [a_new ASC, b_new ASC] + vec![("a_new", option_asc), ("b_new", option_asc)], + // [a_new ASC, b + d ASC] + vec![("a_new", option_asc), ("b+d", option_asc)], + ], + ), + // ------- TEST CASE 12 ---------- + ( + // orderings + vec![ + // [a ASC, b ASC, c ASC] + vec![ + (col_a, option_asc), + (col_b, option_asc), + (col_c, option_asc), + ], + ], + // proj exprs + vec![(col_c, "c_new".to_string()), (col_a, "a_new".to_string())], + // expected + vec![ + // [a_new ASC] + vec![("a_new", option_asc)], + ], + ), + // ------- TEST CASE 13 ---------- + ( + // orderings + vec![ + // [a ASC, b ASC, c ASC] + vec![ + (col_a, option_asc), + (col_b, option_asc), + (col_c, option_asc), + ], + // [a ASC, a + b ASC, c ASC] + vec![ + (col_a, option_asc), + (&a_plus_b, option_asc), + (col_c, option_asc), + ], + ], + // proj exprs + vec![ + (col_c, "c_new".to_string()), + (col_b, "b_new".to_string()), + (col_a, "a_new".to_string()), + (&a_plus_b, "a+b".to_string()), + ], + // expected + vec![ + // [a_new ASC, b_new ASC, c_new ASC] + vec![ + ("a_new", option_asc), + ("b_new", option_asc), + ("c_new", option_asc), + ], + // [a_new ASC, a+b ASC, c_new ASC] + vec![ + ("a_new", option_asc), + ("a+b", option_asc), + ("c_new", option_asc), + ], + ], + ), + // ------- TEST CASE 14 ---------- + ( + // orderings + vec![ + // [a ASC, b ASC] + vec![(col_a, option_asc), (col_b, option_asc)], + // [c ASC, b ASC] + vec![(col_c, option_asc), (col_b, option_asc)], + // [d ASC, e ASC] + vec![(col_d, option_asc), (col_e, option_asc)], + ], + // proj exprs + vec![ + (col_c, "c_new".to_string()), + (col_d, "d_new".to_string()), + (col_a, "a_new".to_string()), + (&b_plus_e, "b+e".to_string()), + ], + // expected + vec![ + // [a_new ASC, d_new ASC, b+e ASC] + vec![ + ("a_new", option_asc), + ("d_new", option_asc), + ("b+e", option_asc), + ], + // [d_new ASC, a_new ASC, b+e ASC] + vec![ + ("d_new", option_asc), + ("a_new", option_asc), + ("b+e", option_asc), + ], + // [c_new ASC, d_new ASC, b+e ASC] + vec![ + ("c_new", option_asc), + ("d_new", option_asc), + ("b+e", option_asc), + ], + // [d_new ASC, c_new ASC, b+e ASC] + vec![ + ("d_new", option_asc), + ("c_new", option_asc), + ("b+e", option_asc), + ], + ], + ), + // ------- TEST CASE 15 ---------- + ( + // orderings + vec![ + // [a ASC, c ASC, b ASC] + vec![ 
+ (col_a, option_asc), + (col_c, option_asc), + (&col_b, option_asc), + ], + ], + // proj exprs + vec![ + (col_c, "c_new".to_string()), + (col_a, "a_new".to_string()), + (&a_plus_b, "a+b".to_string()), + ], + // expected + vec![ + // [a_new ASC, d_new ASC, b+e ASC] + vec![ + ("a_new", option_asc), + ("c_new", option_asc), + ("a+b", option_asc), + ], + ], + ), + // ------- TEST CASE 16 ---------- + ( + // orderings + vec![ + // [a ASC, b ASC] + vec![(col_a, option_asc), (col_b, option_asc)], + // [c ASC, b DESC] + vec![(col_c, option_asc), (col_b, option_desc)], + // [e ASC] + vec![(col_e, option_asc)], + ], + // proj exprs + vec![ + (col_c, "c_new".to_string()), + (col_a, "a_new".to_string()), + (col_b, "b_new".to_string()), + (&b_plus_e, "b+e".to_string()), + ], + // expected + vec![ + // [a_new ASC, b_new ASC] + vec![("a_new", option_asc), ("b_new", option_asc)], + // [a_new ASC, b_new ASC] + vec![("a_new", option_asc), ("b+e", option_asc)], + // [c_new ASC, b_new DESC] + vec![("c_new", option_asc), ("b_new", option_desc)], + ], + ), + ]; + + for (idx, (orderings, proj_exprs, expected)) in test_cases.into_iter().enumerate() + { + let mut eq_properties = EquivalenceProperties::new(schema.clone()); + + let orderings = convert_to_orderings(&orderings); + eq_properties.add_new_orderings(orderings); + + let proj_exprs = proj_exprs + .into_iter() + .map(|(expr, name)| (expr.clone(), name)) + .collect::>(); + let projection_mapping = ProjectionMapping::try_new(&proj_exprs, &schema)?; + let output_schema = output_schema(&projection_mapping, &schema)?; + + let expected = expected + .into_iter() + .map(|ordering| { + ordering + .into_iter() + .map(|(name, options)| { + (col(name, &output_schema).unwrap(), options) + }) + .collect::>() + }) + .collect::>(); + let expected = convert_to_orderings_owned(&expected); + + let projected_eq = eq_properties.project(&projection_mapping, output_schema); + let orderings = projected_eq.oeq_class(); + + let err_msg = format!( + "test_idx: {:?}, actual: {:?}, expected: {:?}, projection_mapping: {:?}", + idx, orderings.orderings, expected, projection_mapping + ); + + assert_eq!(orderings.len(), expected.len(), "{}", err_msg); + for expected_ordering in &expected { + assert!(orderings.contains(expected_ordering), "{}", err_msg) + } + } + + Ok(()) + } + + #[test] + fn project_orderings2() -> Result<()> { + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + Field::new("c", DataType::Int32, true), + Field::new("d", DataType::Int32, true), + Field::new("ts", DataType::Timestamp(TimeUnit::Nanosecond, None), true), + ])); + let col_a = &col("a", &schema)?; + let col_b = &col("b", &schema)?; + let col_c = &col("c", &schema)?; + let col_ts = &col("ts", &schema)?; + let a_plus_b = Arc::new(BinaryExpr::new( + col_a.clone(), + Operator::Plus, + col_b.clone(), + )) as Arc; + let interval = Arc::new(Literal::new(ScalarValue::IntervalDayTime(Some(2)))) + as Arc; + let date_bin_ts = &create_physical_expr( + &BuiltinScalarFunction::DateBin, + &[interval, col_ts.clone()], + &schema, + &ExecutionProps::default(), + )?; + + let round_c = &create_physical_expr( + &BuiltinScalarFunction::Round, + &[col_c.clone()], + &schema, + &ExecutionProps::default(), + )?; + + let option_asc = SortOptions { + descending: false, + nulls_first: false, + }; + + let proj_exprs = vec![ + (col_b, "b_new".to_string()), + (col_a, "a_new".to_string()), + (col_c, "c_new".to_string()), + (date_bin_ts, "date_bin_res".to_string()), + 
(round_c, "round_c_res".to_string()), + ]; + let proj_exprs = proj_exprs + .into_iter() + .map(|(expr, name)| (expr.clone(), name)) + .collect::>(); + let projection_mapping = ProjectionMapping::try_new(&proj_exprs, &schema)?; + let output_schema = output_schema(&projection_mapping, &schema)?; + + let col_a_new = &col("a_new", &output_schema)?; + let col_b_new = &col("b_new", &output_schema)?; + let col_c_new = &col("c_new", &output_schema)?; + let col_date_bin_res = &col("date_bin_res", &output_schema)?; + let col_round_c_res = &col("round_c_res", &output_schema)?; + let a_new_plus_b_new = Arc::new(BinaryExpr::new( + col_a_new.clone(), + Operator::Plus, + col_b_new.clone(), + )) as Arc; + + let test_cases = vec![ + // ---------- TEST CASE 1 ------------ + ( + // orderings + vec![ + // [a ASC] + vec![(col_a, option_asc)], + ], + // expected + vec![ + // [b_new ASC] + vec![(col_a_new, option_asc)], + ], + ), + // ---------- TEST CASE 2 ------------ + ( + // orderings + vec![ + // [a+b ASC] + vec![(&a_plus_b, option_asc)], + ], + // expected + vec![ + // [b_new ASC] + vec![(&a_new_plus_b_new, option_asc)], + ], + ), + // ---------- TEST CASE 3 ------------ + ( + // orderings + vec![ + // [a ASC, ts ASC] + vec![(col_a, option_asc), (col_ts, option_asc)], + ], + // expected + vec![ + // [a_new ASC, date_bin_res ASC] + vec![(col_a_new, option_asc), (col_date_bin_res, option_asc)], + ], + ), + // ---------- TEST CASE 4 ------------ + ( + // orderings + vec![ + // [a ASC, ts ASC, b ASC] + vec![ + (col_a, option_asc), + (col_ts, option_asc), + (col_b, option_asc), + ], + ], + // expected + vec![ + // [a_new ASC, date_bin_res ASC] + // Please note that result is not [a_new ASC, date_bin_res ASC, b_new ASC] + // because, datebin_res may not be 1-1 function. Hence without introducing ts + // dependency we cannot guarantee any ordering after date_bin_res column. 
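+                    // (Illustrative note: distinct `ts` values can land in the same
+                    // bin, so rows that are equal on `date_bin_res` need not be
+                    // ordered by `b_new`.)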
+ vec![(col_a_new, option_asc), (col_date_bin_res, option_asc)], + ], + ), + // ---------- TEST CASE 5 ------------ + ( + // orderings + vec![ + // [a ASC, c ASC] + vec![(col_a, option_asc), (col_c, option_asc)], + ], + // expected + vec![ + // [a_new ASC, round_c_res ASC, c_new ASC] + vec![(col_a_new, option_asc), (col_round_c_res, option_asc)], + // [a_new ASC, c_new ASC] + vec![(col_a_new, option_asc), (col_c_new, option_asc)], + ], + ), + // ---------- TEST CASE 6 ------------ + ( + // orderings + vec![ + // [c ASC, b ASC] + vec![(col_c, option_asc), (col_b, option_asc)], + ], + // expected + vec![ + // [round_c_res ASC] + vec![(col_round_c_res, option_asc)], + // [c_new ASC, b_new ASC] + vec![(col_c_new, option_asc), (col_b_new, option_asc)], + ], + ), + // ---------- TEST CASE 7 ------------ + ( + // orderings + vec![ + // [a+b ASC, c ASC] + vec![(&a_plus_b, option_asc), (col_c, option_asc)], + ], + // expected + vec![ + // [a+b ASC, round(c) ASC, c_new ASC] + vec![ + (&a_new_plus_b_new, option_asc), + (&col_round_c_res, option_asc), + ], + // [a+b ASC, c_new ASC] + vec![(&a_new_plus_b_new, option_asc), (col_c_new, option_asc)], + ], + ), + ]; + + for (idx, (orderings, expected)) in test_cases.iter().enumerate() { + let mut eq_properties = EquivalenceProperties::new(schema.clone()); + + let orderings = convert_to_orderings(orderings); + eq_properties.add_new_orderings(orderings); + + let expected = convert_to_orderings(expected); + + let projected_eq = + eq_properties.project(&projection_mapping, output_schema.clone()); + let orderings = projected_eq.oeq_class(); + + let err_msg = format!( + "test idx: {:?}, actual: {:?}, expected: {:?}, projection_mapping: {:?}", + idx, orderings.orderings, expected, projection_mapping + ); + + assert_eq!(orderings.len(), expected.len(), "{}", err_msg); + for expected_ordering in &expected { + assert!(orderings.contains(expected_ordering), "{}", err_msg) + } + } + Ok(()) + } + + #[test] + fn project_orderings3() -> Result<()> { + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + Field::new("c", DataType::Int32, true), + Field::new("d", DataType::Int32, true), + Field::new("e", DataType::Int32, true), + Field::new("f", DataType::Int32, true), + ])); + let col_a = &col("a", &schema)?; + let col_b = &col("b", &schema)?; + let col_c = &col("c", &schema)?; + let col_d = &col("d", &schema)?; + let col_e = &col("e", &schema)?; + let col_f = &col("f", &schema)?; + let a_plus_b = Arc::new(BinaryExpr::new( + col_a.clone(), + Operator::Plus, + col_b.clone(), + )) as Arc; + + let option_asc = SortOptions { + descending: false, + nulls_first: false, + }; + + let proj_exprs = vec![ + (col_c, "c_new".to_string()), + (col_d, "d_new".to_string()), + (&a_plus_b, "a+b".to_string()), + ]; + let proj_exprs = proj_exprs + .into_iter() + .map(|(expr, name)| (expr.clone(), name)) + .collect::>(); + let projection_mapping = ProjectionMapping::try_new(&proj_exprs, &schema)?; + let output_schema = output_schema(&projection_mapping, &schema)?; + + let col_a_plus_b_new = &col("a+b", &output_schema)?; + let col_c_new = &col("c_new", &output_schema)?; + let col_d_new = &col("d_new", &output_schema)?; + + let test_cases = vec![ + // ---------- TEST CASE 1 ------------ + ( + // orderings + vec![ + // [d ASC, b ASC] + vec![(col_d, option_asc), (col_b, option_asc)], + // [c ASC, a ASC] + vec![(col_c, option_asc), (col_a, option_asc)], + ], + // equal conditions + vec![], + // expected + vec![ + // [d_new ASC, 
c_new ASC, a+b ASC] + vec![ + (col_d_new, option_asc), + (col_c_new, option_asc), + (col_a_plus_b_new, option_asc), + ], + // [c_new ASC, d_new ASC, a+b ASC] + vec![ + (col_c_new, option_asc), + (col_d_new, option_asc), + (col_a_plus_b_new, option_asc), + ], + ], + ), + // ---------- TEST CASE 2 ------------ + ( + // orderings + vec![ + // [d ASC, b ASC] + vec![(col_d, option_asc), (col_b, option_asc)], + // [c ASC, e ASC], Please note that a=e + vec![(col_c, option_asc), (col_e, option_asc)], + ], + // equal conditions + vec![(col_e, col_a)], + // expected + vec![ + // [d_new ASC, c_new ASC, a+b ASC] + vec![ + (col_d_new, option_asc), + (col_c_new, option_asc), + (col_a_plus_b_new, option_asc), + ], + // [c_new ASC, d_new ASC, a+b ASC] + vec![ + (col_c_new, option_asc), + (col_d_new, option_asc), + (col_a_plus_b_new, option_asc), + ], + ], + ), + // ---------- TEST CASE 3 ------------ + ( + // orderings + vec![ + // [d ASC, b ASC] + vec![(col_d, option_asc), (col_b, option_asc)], + // [c ASC, e ASC], Please note that a=f + vec![(col_c, option_asc), (col_e, option_asc)], + ], + // equal conditions + vec![(col_a, col_f)], + // expected + vec![ + // [d_new ASC] + vec![(col_d_new, option_asc)], + // [c_new ASC] + vec![(col_c_new, option_asc)], + ], + ), + ]; + for (orderings, equal_columns, expected) in test_cases { + let mut eq_properties = EquivalenceProperties::new(schema.clone()); + for (lhs, rhs) in equal_columns { + eq_properties.add_equal_conditions(lhs, rhs); + } + + let orderings = convert_to_orderings(&orderings); + eq_properties.add_new_orderings(orderings); + + let expected = convert_to_orderings(&expected); + + let projected_eq = + eq_properties.project(&projection_mapping, output_schema.clone()); + let orderings = projected_eq.oeq_class(); + + let err_msg = format!( + "actual: {:?}, expected: {:?}, projection_mapping: {:?}", + orderings.orderings, expected, projection_mapping + ); + + assert_eq!(orderings.len(), expected.len(), "{}", err_msg); + for expected_ordering in &expected { + assert!(orderings.contains(expected_ordering), "{}", err_msg) + } + } + + Ok(()) + } + + #[test] + fn project_orderings_random() -> Result<()> { + const N_RANDOM_SCHEMA: usize = 20; + const N_ELEMENTS: usize = 125; + const N_DISTINCT: usize = 5; + + for seed in 0..N_RANDOM_SCHEMA { + // Create a random schema with random properties + let (test_schema, eq_properties) = create_random_schema(seed as u64)?; + // Generate a data that satisfies properties given + let table_data_with_properties = + generate_table_for_eq_properties(&eq_properties, N_ELEMENTS, N_DISTINCT)?; + // Floor(a) + let floor_a = create_physical_expr( + &BuiltinScalarFunction::Floor, + &[col("a", &test_schema)?], + &test_schema, + &ExecutionProps::default(), + )?; + // a + b + let a_plus_b = Arc::new(BinaryExpr::new( + col("a", &test_schema)?, + Operator::Plus, + col("b", &test_schema)?, + )) as Arc; + let proj_exprs = vec![ + (col("a", &test_schema)?, "a_new"), + (col("b", &test_schema)?, "b_new"), + (col("c", &test_schema)?, "c_new"), + (col("d", &test_schema)?, "d_new"), + (col("e", &test_schema)?, "e_new"), + (col("f", &test_schema)?, "f_new"), + (floor_a, "floor(a)"), + (a_plus_b, "a+b"), + ]; + + for n_req in 0..=proj_exprs.len() { + for proj_exprs in proj_exprs.iter().combinations(n_req) { + let proj_exprs = proj_exprs + .into_iter() + .map(|(expr, name)| (expr.clone(), name.to_string())) + .collect::>(); + let (projected_batch, projected_eq) = apply_projection( + proj_exprs.clone(), + &table_data_with_properties, + 
&eq_properties, + )?; + + // Make sure each ordering after projection is valid. + for ordering in projected_eq.oeq_class().iter() { + let err_msg = format!( + "Error in test case ordering:{:?}, eq_properties.oeq_class: {:?}, eq_properties.eq_group: {:?}, eq_properties.constants: {:?}, proj_exprs: {:?}", + ordering, eq_properties.oeq_class, eq_properties.eq_group, eq_properties.constants, proj_exprs + ); + // Since ordered section satisfies schema, we expect + // that result will be same after sort (e.g sort was unnecessary). + assert!( + is_table_same_after_sort( + ordering.clone(), + projected_batch.clone(), + )?, + "{}", + err_msg + ); + } + } + } + } + + Ok(()) + } + + #[test] + fn ordering_satisfy_after_projection_random() -> Result<()> { + const N_RANDOM_SCHEMA: usize = 20; + const N_ELEMENTS: usize = 125; + const N_DISTINCT: usize = 5; + const SORT_OPTIONS: SortOptions = SortOptions { + descending: false, + nulls_first: false, + }; + + for seed in 0..N_RANDOM_SCHEMA { + // Create a random schema with random properties + let (test_schema, eq_properties) = create_random_schema(seed as u64)?; + // Generate a data that satisfies properties given + let table_data_with_properties = + generate_table_for_eq_properties(&eq_properties, N_ELEMENTS, N_DISTINCT)?; + // Floor(a) + let floor_a = create_physical_expr( + &BuiltinScalarFunction::Floor, + &[col("a", &test_schema)?], + &test_schema, + &ExecutionProps::default(), + )?; + // a + b + let a_plus_b = Arc::new(BinaryExpr::new( + col("a", &test_schema)?, + Operator::Plus, + col("b", &test_schema)?, + )) as Arc; + let proj_exprs = vec![ + (col("a", &test_schema)?, "a_new"), + (col("b", &test_schema)?, "b_new"), + (col("c", &test_schema)?, "c_new"), + (col("d", &test_schema)?, "d_new"), + (col("e", &test_schema)?, "e_new"), + (col("f", &test_schema)?, "f_new"), + (floor_a, "floor(a)"), + (a_plus_b, "a+b"), + ]; + + for n_req in 0..=proj_exprs.len() { + for proj_exprs in proj_exprs.iter().combinations(n_req) { + let proj_exprs = proj_exprs + .into_iter() + .map(|(expr, name)| (expr.clone(), name.to_string())) + .collect::>(); + let (projected_batch, projected_eq) = apply_projection( + proj_exprs.clone(), + &table_data_with_properties, + &eq_properties, + )?; + + let projection_mapping = + ProjectionMapping::try_new(&proj_exprs, &test_schema)?; + + let projected_exprs = projection_mapping + .iter() + .map(|(_source, target)| target.clone()) + .collect::>(); + + for n_req in 0..=projected_exprs.len() { + for exprs in projected_exprs.iter().combinations(n_req) { + let requirement = exprs + .into_iter() + .map(|expr| PhysicalSortExpr { + expr: expr.clone(), + options: SORT_OPTIONS, + }) + .collect::>(); + let expected = is_table_same_after_sort( + requirement.clone(), + projected_batch.clone(), + )?; + let err_msg = format!( + "Error in test case requirement:{:?}, expected: {:?}, eq_properties.oeq_class: {:?}, eq_properties.eq_group: {:?}, eq_properties.constants: {:?}, projected_eq.oeq_class: {:?}, projected_eq.eq_group: {:?}, projected_eq.constants: {:?}, projection_mapping: {:?}", + requirement, expected, eq_properties.oeq_class, eq_properties.eq_group, eq_properties.constants, projected_eq.oeq_class, projected_eq.eq_group, projected_eq.constants, projection_mapping + ); + // Check whether ordering_satisfy API result and + // experimental result matches. 
+ assert_eq!( + projected_eq.ordering_satisfy(&requirement), + expected, + "{}", + err_msg + ); + } + } + } + } + } + + Ok(()) + } +} diff --git a/datafusion/physical-expr/src/equivalence/properties.rs b/datafusion/physical-expr/src/equivalence/properties.rs new file mode 100644 index 0000000000000..31c1cf61193a2 --- /dev/null +++ b/datafusion/physical-expr/src/equivalence/properties.rs @@ -0,0 +1,2062 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::expressions::Column; +use arrow_schema::SchemaRef; +use datafusion_common::{JoinSide, JoinType}; +use indexmap::IndexSet; +use itertools::Itertools; +use std::collections::{HashMap, HashSet}; +use std::hash::{Hash, Hasher}; +use std::sync::Arc; + +use crate::equivalence::{ + collapse_lex_req, EquivalenceGroup, OrderingEquivalenceClass, ProjectionMapping, +}; + +use crate::expressions::Literal; +use crate::sort_properties::{ExprOrdering, SortProperties}; +use crate::{ + physical_exprs_contains, LexOrdering, LexOrderingRef, LexRequirement, + LexRequirementRef, PhysicalExpr, PhysicalSortExpr, PhysicalSortRequirement, +}; +use datafusion_common::tree_node::{Transformed, TreeNode}; + +use super::ordering::collapse_lex_ordering; + +/// A `EquivalenceProperties` object stores useful information related to a schema. +/// Currently, it keeps track of: +/// - Equivalent expressions, e.g expressions that have same value. +/// - Valid sort expressions (orderings) for the schema. +/// - Constants expressions (e.g expressions that are known to have constant values). +/// +/// Consider table below: +/// +/// ```text +/// ┌-------┐ +/// | a | b | +/// |---|---| +/// | 1 | 9 | +/// | 2 | 8 | +/// | 3 | 7 | +/// | 5 | 5 | +/// └---┴---┘ +/// ``` +/// +/// where both `a ASC` and `b DESC` can describe the table ordering. With +/// `EquivalenceProperties`, we can keep track of these different valid sort +/// expressions and treat `a ASC` and `b DESC` on an equal footing. +/// +/// Similarly, consider the table below: +/// +/// ```text +/// ┌-------┐ +/// | a | b | +/// |---|---| +/// | 1 | 1 | +/// | 2 | 2 | +/// | 3 | 3 | +/// | 5 | 5 | +/// └---┴---┘ +/// ``` +/// +/// where columns `a` and `b` always have the same value. We keep track of such +/// equivalences inside this object. With this information, we can optimize +/// things like partitioning. For example, if the partition requirement is +/// `Hash(a)` and output partitioning is `Hash(b)`, then we can deduce that +/// the existing partitioning satisfies the requirement. +#[derive(Debug, Clone)] +pub struct EquivalenceProperties { + /// Collection of equivalence classes that store expressions with the same + /// value. + pub eq_group: EquivalenceGroup, + /// Equivalent sort expressions for this table. 
+ pub oeq_class: OrderingEquivalenceClass, + /// Expressions whose values are constant throughout the table. + /// TODO: We do not need to track constants separately, they can be tracked + /// inside `eq_groups` as `Literal` expressions. + pub constants: Vec>, + /// Schema associated with this object. + schema: SchemaRef, +} + +impl EquivalenceProperties { + /// Creates an empty `EquivalenceProperties` object. + pub fn new(schema: SchemaRef) -> Self { + Self { + eq_group: EquivalenceGroup::empty(), + oeq_class: OrderingEquivalenceClass::empty(), + constants: vec![], + schema, + } + } + + /// Creates a new `EquivalenceProperties` object with the given orderings. + pub fn new_with_orderings(schema: SchemaRef, orderings: &[LexOrdering]) -> Self { + Self { + eq_group: EquivalenceGroup::empty(), + oeq_class: OrderingEquivalenceClass::new(orderings.to_vec()), + constants: vec![], + schema, + } + } + + /// Returns the associated schema. + pub fn schema(&self) -> &SchemaRef { + &self.schema + } + + /// Returns a reference to the ordering equivalence class within. + pub fn oeq_class(&self) -> &OrderingEquivalenceClass { + &self.oeq_class + } + + /// Returns a reference to the equivalence group within. + pub fn eq_group(&self) -> &EquivalenceGroup { + &self.eq_group + } + + /// Returns a reference to the constant expressions + pub fn constants(&self) -> &[Arc] { + &self.constants + } + + /// Returns the normalized version of the ordering equivalence class within. + /// Normalization removes constants and duplicates as well as standardizing + /// expressions according to the equivalence group within. + pub fn normalized_oeq_class(&self) -> OrderingEquivalenceClass { + OrderingEquivalenceClass::new( + self.oeq_class + .iter() + .map(|ordering| self.normalize_sort_exprs(ordering)) + .collect(), + ) + } + + /// Extends this `EquivalenceProperties` with the `other` object. + pub fn extend(mut self, other: Self) -> Self { + self.eq_group.extend(other.eq_group); + self.oeq_class.extend(other.oeq_class); + self.add_constants(other.constants) + } + + /// Clears (empties) the ordering equivalence class within this object. + /// Call this method when existing orderings are invalidated. + pub fn clear_orderings(&mut self) { + self.oeq_class.clear(); + } + + /// Extends this `EquivalenceProperties` by adding the orderings inside the + /// ordering equivalence class `other`. + pub fn add_ordering_equivalence_class(&mut self, other: OrderingEquivalenceClass) { + self.oeq_class.extend(other); + } + + /// Adds new orderings into the existing ordering equivalence class. + pub fn add_new_orderings( + &mut self, + orderings: impl IntoIterator, + ) { + self.oeq_class.add_new_orderings(orderings); + } + + /// Incorporates the given equivalence group to into the existing + /// equivalence group within. + pub fn add_equivalence_group(&mut self, other_eq_group: EquivalenceGroup) { + self.eq_group.extend(other_eq_group); + } + + /// Adds a new equality condition into the existing equivalence group. + /// If the given equality defines a new equivalence class, adds this new + /// equivalence class to the equivalence group. + pub fn add_equal_conditions( + &mut self, + left: &Arc, + right: &Arc, + ) { + self.eq_group.add_equal_conditions(left, right); + } + + /// Track/register physical expressions with constant values. 
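+    ///
+    /// # Example
+    ///
+    /// An illustrative sketch (the expressions are hypothetical handles):
+    ///
+    /// ```text
+    /// // After a predicate such as `b = 5`, column `b` can be registered:
+    /// // eq_properties = eq_properties.add_constants(std::iter::once(col_b.clone()));
+    /// // Subsequent normalization then prunes `b` from sort requirements.
+    /// ```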
+ pub fn add_constants( + mut self, + constants: impl IntoIterator>, + ) -> Self { + for expr in self.eq_group.normalize_exprs(constants) { + if !physical_exprs_contains(&self.constants, &expr) { + self.constants.push(expr); + } + } + self + } + + /// Updates the ordering equivalence group within assuming that the table + /// is re-sorted according to the argument `sort_exprs`. Note that constants + /// and equivalence classes are unchanged as they are unaffected by a re-sort. + pub fn with_reorder(mut self, sort_exprs: Vec) -> Self { + // TODO: In some cases, existing ordering equivalences may still be valid add this analysis. + self.oeq_class = OrderingEquivalenceClass::new(vec![sort_exprs]); + self + } + + /// Normalizes the given sort expressions (i.e. `sort_exprs`) using the + /// equivalence group and the ordering equivalence class within. + /// + /// Assume that `self.eq_group` states column `a` and `b` are aliases. + /// Also assume that `self.oeq_class` states orderings `d ASC` and `a ASC, c ASC` + /// are equivalent (in the sense that both describe the ordering of the table). + /// If the `sort_exprs` argument were `vec![b ASC, c ASC, a ASC]`, then this + /// function would return `vec![a ASC, c ASC]`. Internally, it would first + /// normalize to `vec![a ASC, c ASC, a ASC]` and end up with the final result + /// after deduplication. + fn normalize_sort_exprs(&self, sort_exprs: LexOrderingRef) -> LexOrdering { + // Convert sort expressions to sort requirements: + let sort_reqs = PhysicalSortRequirement::from_sort_exprs(sort_exprs.iter()); + // Normalize the requirements: + let normalized_sort_reqs = self.normalize_sort_requirements(&sort_reqs); + // Convert sort requirements back to sort expressions: + PhysicalSortRequirement::to_sort_exprs(normalized_sort_reqs) + } + + /// Normalizes the given sort requirements (i.e. `sort_reqs`) using the + /// equivalence group and the ordering equivalence class within. It works by: + /// - Removing expressions that have a constant value from the given requirement. + /// - Replacing sections that belong to some equivalence class in the equivalence + /// group with the first entry in the matching equivalence class. + /// + /// Assume that `self.eq_group` states column `a` and `b` are aliases. + /// Also assume that `self.oeq_class` states orderings `d ASC` and `a ASC, c ASC` + /// are equivalent (in the sense that both describe the ordering of the table). + /// If the `sort_reqs` argument were `vec![b ASC, c ASC, a ASC]`, then this + /// function would return `vec![a ASC, c ASC]`. Internally, it would first + /// normalize to `vec![a ASC, c ASC, a ASC]` and end up with the final result + /// after deduplication. + fn normalize_sort_requirements( + &self, + sort_reqs: LexRequirementRef, + ) -> LexRequirement { + let normalized_sort_reqs = self.eq_group.normalize_sort_requirements(sort_reqs); + let constants_normalized = self.eq_group.normalize_exprs(self.constants.clone()); + // Prune redundant sections in the requirement: + collapse_lex_req( + normalized_sort_reqs + .iter() + .filter(|&order| { + !physical_exprs_contains(&constants_normalized, &order.expr) + }) + .cloned() + .collect(), + ) + } + + /// Checks whether the given ordering is satisfied by any of the existing + /// orderings. 
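+    ///
+    /// # Example
+    ///
+    /// An illustrative sketch (hypothetical columns, assuming no other
+    /// equivalences or constants): if the known orderings contain
+    /// `[a ASC, b ASC]`, then
+    ///
+    /// ```text
+    /// ordering_satisfy([a ASC])        -> true   // prefix of a known ordering
+    /// ordering_satisfy([a ASC, b ASC]) -> true
+    /// ordering_satisfy([b ASC])        -> false  // `b` is not a leading key
+    /// ```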
+ pub fn ordering_satisfy(&self, given: LexOrderingRef) -> bool { + // Convert the given sort expressions to sort requirements: + let sort_requirements = PhysicalSortRequirement::from_sort_exprs(given.iter()); + self.ordering_satisfy_requirement(&sort_requirements) + } + + /// Checks whether the given sort requirements are satisfied by any of the + /// existing orderings. + pub fn ordering_satisfy_requirement(&self, reqs: LexRequirementRef) -> bool { + let mut eq_properties = self.clone(); + // First, standardize the given requirement: + let normalized_reqs = eq_properties.normalize_sort_requirements(reqs); + for normalized_req in normalized_reqs { + // Check whether given ordering is satisfied + if !eq_properties.ordering_satisfy_single(&normalized_req) { + return false; + } + // Treat satisfied keys as constants in subsequent iterations. We + // can do this because the "next" key only matters in a lexicographical + // ordering when the keys to its left have the same values. + // + // Note that these expressions are not properly "constants". This is just + // an implementation strategy confined to this function. + // + // For example, assume that the requirement is `[a ASC, (b + c) ASC]`, + // and existing equivalent orderings are `[a ASC, b ASC]` and `[c ASC]`. + // From the analysis above, we know that `[a ASC]` is satisfied. Then, + // we add column `a` as constant to the algorithm state. This enables us + // to deduce that `(b + c) ASC` is satisfied, given `a` is constant. + eq_properties = + eq_properties.add_constants(std::iter::once(normalized_req.expr)); + } + true + } + + /// Determines whether the ordering specified by the given sort requirement + /// is satisfied based on the orderings within, equivalence classes, and + /// constant expressions. + /// + /// # Arguments + /// + /// - `req`: A reference to a `PhysicalSortRequirement` for which the ordering + /// satisfaction check will be done. + /// + /// # Returns + /// + /// Returns `true` if the specified ordering is satisfied, `false` otherwise. + fn ordering_satisfy_single(&self, req: &PhysicalSortRequirement) -> bool { + let expr_ordering = self.get_expr_ordering(req.expr.clone()); + let ExprOrdering { expr, state, .. } = expr_ordering; + match state { + SortProperties::Ordered(options) => { + let sort_expr = PhysicalSortExpr { expr, options }; + sort_expr.satisfy(req, self.schema()) + } + // Singleton expressions satisfies any ordering. + SortProperties::Singleton => true, + SortProperties::Unordered => false, + } + } + + /// Checks whether the `given`` sort requirements are equal or more specific + /// than the `reference` sort requirements. + pub fn requirements_compatible( + &self, + given: LexRequirementRef, + reference: LexRequirementRef, + ) -> bool { + let normalized_given = self.normalize_sort_requirements(given); + let normalized_reference = self.normalize_sort_requirements(reference); + + (normalized_reference.len() <= normalized_given.len()) + && normalized_reference + .into_iter() + .zip(normalized_given) + .all(|(reference, given)| given.compatible(&reference)) + } + + /// Returns the finer ordering among the orderings `lhs` and `rhs`, breaking + /// any ties by choosing `lhs`. + /// + /// The finer ordering is the ordering that satisfies both of the orderings. + /// If the orderings are incomparable, returns `None`. + /// + /// For example, the finer ordering among `[a ASC]` and `[a ASC, b ASC]` is + /// the latter. 
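+    ///
+    /// ```text
+    /// // Illustrative sketch (hypothetical columns, no known equivalences):
+    /// get_finer_ordering([a ASC],        [a ASC, b ASC]) -> Some([a ASC, b ASC])
+    /// get_finer_ordering([a ASC, b ASC], [a ASC])        -> Some([a ASC, b ASC])
+    /// get_finer_ordering([a ASC],        [b ASC])        -> None (incomparable)
+    /// ```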
+ pub fn get_finer_ordering( + &self, + lhs: LexOrderingRef, + rhs: LexOrderingRef, + ) -> Option { + // Convert the given sort expressions to sort requirements: + let lhs = PhysicalSortRequirement::from_sort_exprs(lhs); + let rhs = PhysicalSortRequirement::from_sort_exprs(rhs); + let finer = self.get_finer_requirement(&lhs, &rhs); + // Convert the chosen sort requirements back to sort expressions: + finer.map(PhysicalSortRequirement::to_sort_exprs) + } + + /// Returns the finer ordering among the requirements `lhs` and `rhs`, + /// breaking any ties by choosing `lhs`. + /// + /// The finer requirements are the ones that satisfy both of the given + /// requirements. If the requirements are incomparable, returns `None`. + /// + /// For example, the finer requirements among `[a ASC]` and `[a ASC, b ASC]` + /// is the latter. + pub fn get_finer_requirement( + &self, + req1: LexRequirementRef, + req2: LexRequirementRef, + ) -> Option { + let mut lhs = self.normalize_sort_requirements(req1); + let mut rhs = self.normalize_sort_requirements(req2); + lhs.iter_mut() + .zip(rhs.iter_mut()) + .all(|(lhs, rhs)| { + lhs.expr.eq(&rhs.expr) + && match (lhs.options, rhs.options) { + (Some(lhs_opt), Some(rhs_opt)) => lhs_opt == rhs_opt, + (Some(options), None) => { + rhs.options = Some(options); + true + } + (None, Some(options)) => { + lhs.options = Some(options); + true + } + (None, None) => true, + } + }) + .then_some(if lhs.len() >= rhs.len() { lhs } else { rhs }) + } + + /// Calculates the "meet" of the given orderings (`lhs` and `rhs`). + /// The meet of a set of orderings is the finest ordering that is satisfied + /// by all the orderings in that set. For details, see: + /// + /// + /// + /// If there is no ordering that satisfies both `lhs` and `rhs`, returns + /// `None`. As an example, the meet of orderings `[a ASC]` and `[a ASC, b ASC]` + /// is `[a ASC]`. + pub fn get_meet_ordering( + &self, + lhs: LexOrderingRef, + rhs: LexOrderingRef, + ) -> Option { + let lhs = self.normalize_sort_exprs(lhs); + let rhs = self.normalize_sort_exprs(rhs); + let mut meet = vec![]; + for (lhs, rhs) in lhs.into_iter().zip(rhs.into_iter()) { + if lhs.eq(&rhs) { + meet.push(lhs); + } else { + break; + } + } + (!meet.is_empty()).then_some(meet) + } + + /// Projects argument `expr` according to `projection_mapping`, taking + /// equivalences into account. + /// + /// For example, assume that columns `a` and `c` are always equal, and that + /// `projection_mapping` encodes following mapping: + /// + /// ```text + /// a -> a1 + /// b -> b1 + /// ``` + /// + /// Then, this function projects `a + b` to `Some(a1 + b1)`, `c + b` to + /// `Some(a1 + b1)` and `d` to `None`, meaning that it cannot be projected. + pub fn project_expr( + &self, + expr: &Arc, + projection_mapping: &ProjectionMapping, + ) -> Option> { + self.eq_group.project_expr(projection_mapping, expr) + } + + /// Constructs a dependency map based on existing orderings referred to in + /// the projection. + /// + /// This function analyzes the orderings in the normalized order-equivalence + /// class and builds a dependency map. The dependency map captures relationships + /// between expressions within the orderings, helping to identify dependencies + /// and construct valid projected orderings during projection operations. + /// + /// # Parameters + /// + /// - `mapping`: A reference to the `ProjectionMapping` that defines the + /// relationship between source and target expressions. 
+ /// + /// # Returns + /// + /// A [`DependencyMap`] representing the dependency map, where each + /// [`DependencyNode`] contains dependencies for the key [`PhysicalSortExpr`]. + /// + /// # Example + /// + /// Assume we have two equivalent orderings: `[a ASC, b ASC]` and `[a ASC, c ASC]`, + /// and the projection mapping is `[a -> a_new, b -> b_new, b + c -> b + c]`. + /// Then, the dependency map will be: + /// + /// ```text + /// a ASC: Node {Some(a_new ASC), HashSet{}} + /// b ASC: Node {Some(b_new ASC), HashSet{a ASC}} + /// c ASC: Node {None, HashSet{a ASC}} + /// ``` + fn construct_dependency_map(&self, mapping: &ProjectionMapping) -> DependencyMap { + let mut dependency_map = HashMap::new(); + for ordering in self.normalized_oeq_class().iter() { + for (idx, sort_expr) in ordering.iter().enumerate() { + let target_sort_expr = + self.project_expr(&sort_expr.expr, mapping).map(|expr| { + PhysicalSortExpr { + expr, + options: sort_expr.options, + } + }); + let is_projected = target_sort_expr.is_some(); + if is_projected + || mapping + .iter() + .any(|(source, _)| expr_refers(source, &sort_expr.expr)) + { + // Previous ordering is a dependency. Note that there is no, + // dependency for a leading ordering (i.e. the first sort + // expression). + let dependency = idx.checked_sub(1).map(|a| &ordering[a]); + // Add sort expressions that can be projected or referred to + // by any of the projection expressions to the dependency map: + dependency_map + .entry(sort_expr.clone()) + .or_insert_with(|| DependencyNode { + target_sort_expr: target_sort_expr.clone(), + dependencies: HashSet::new(), + }) + .insert_dependency(dependency); + } + if !is_projected { + // If we can not project, stop constructing the dependency + // map as remaining dependencies will be invalid after projection. + break; + } + } + } + dependency_map + } + + /// Returns a new `ProjectionMapping` where source expressions are normalized. + /// + /// This normalization ensures that source expressions are transformed into a + /// consistent representation. This is beneficial for algorithms that rely on + /// exact equalities, as it allows for more precise and reliable comparisons. + /// + /// # Parameters + /// + /// - `mapping`: A reference to the original `ProjectionMapping` to be normalized. + /// + /// # Returns + /// + /// A new `ProjectionMapping` with normalized source expressions. + fn normalized_mapping(&self, mapping: &ProjectionMapping) -> ProjectionMapping { + // Construct the mapping where source expressions are normalized. In this way + // In the algorithms below we can work on exact equalities + ProjectionMapping { + map: mapping + .iter() + .map(|(source, target)| { + let normalized_source = self.eq_group.normalize_expr(source.clone()); + (normalized_source, target.clone()) + }) + .collect(), + } + } + + /// Computes projected orderings based on a given projection mapping. + /// + /// This function takes a `ProjectionMapping` and computes the possible + /// orderings for the projected expressions. It considers dependencies + /// between expressions and generates valid orderings according to the + /// specified sort properties. + /// + /// # Parameters + /// + /// - `mapping`: A reference to the `ProjectionMapping` that defines the + /// relationship between source and target expressions. + /// + /// # Returns + /// + /// A vector of `LexOrdering` containing all valid orderings after projection. 
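+    ///
+    /// # Example
+    ///
+    /// An illustrative sketch (names are hypothetical): with an existing
+    /// ordering `[a ASC, b ASC]` and the mapping `a -> a_new, b -> b_new`,
+    /// the result contains `[a_new ASC, b_new ASC]`. If only `b -> b_new`
+    /// were projected, no ordering could be derived, since `b` is ordered
+    /// only after `a`, which is no longer available.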
+ fn projected_orderings(&self, mapping: &ProjectionMapping) -> Vec { + let mapping = self.normalized_mapping(mapping); + + // Get dependency map for existing orderings: + let dependency_map = self.construct_dependency_map(&mapping); + + let orderings = mapping.iter().flat_map(|(source, target)| { + referred_dependencies(&dependency_map, source) + .into_iter() + .filter_map(|relevant_deps| { + if let SortProperties::Ordered(options) = + get_expr_ordering(source, &relevant_deps) + { + Some((options, relevant_deps)) + } else { + // Do not consider unordered cases + None + } + }) + .flat_map(|(options, relevant_deps)| { + let sort_expr = PhysicalSortExpr { + expr: target.clone(), + options, + }; + // Generate dependent orderings (i.e. prefixes for `sort_expr`): + let mut dependency_orderings = + generate_dependency_orderings(&relevant_deps, &dependency_map); + // Append `sort_expr` to the dependent orderings: + for ordering in dependency_orderings.iter_mut() { + ordering.push(sort_expr.clone()); + } + dependency_orderings + }) + }); + + // Add valid projected orderings. For example, if existing ordering is + // `a + b` and projection is `[a -> a_new, b -> b_new]`, we need to + // preserve `a_new + b_new` as ordered. Please note that `a_new` and + // `b_new` themselves need not be ordered. Such dependencies cannot be + // deduced via the pass above. + let projected_orderings = dependency_map.iter().flat_map(|(sort_expr, node)| { + let mut prefixes = construct_prefix_orderings(sort_expr, &dependency_map); + if prefixes.is_empty() { + // If prefix is empty, there is no dependency. Insert + // empty ordering: + prefixes = vec![vec![]]; + } + // Append current ordering on top its dependencies: + for ordering in prefixes.iter_mut() { + if let Some(target) = &node.target_sort_expr { + ordering.push(target.clone()) + } + } + prefixes + }); + + // Simplify each ordering by removing redundant sections: + orderings + .chain(projected_orderings) + .map(collapse_lex_ordering) + .collect() + } + + /// Projects constants based on the provided `ProjectionMapping`. + /// + /// This function takes a `ProjectionMapping` and identifies/projects + /// constants based on the existing constants and the mapping. It ensures + /// that constants are appropriately propagated through the projection. + /// + /// # Arguments + /// + /// - `mapping`: A reference to a `ProjectionMapping` representing the + /// mapping of source expressions to target expressions in the projection. + /// + /// # Returns + /// + /// Returns a `Vec>` containing the projected constants. + fn projected_constants( + &self, + mapping: &ProjectionMapping, + ) -> Vec> { + // First, project existing constants. For example, assume that `a + b` + // is known to be constant. If the projection were `a as a_new`, `b as b_new`, + // then we would project constant `a + b` as `a_new + b_new`. + let mut projected_constants = self + .constants + .iter() + .flat_map(|expr| self.eq_group.project_expr(mapping, expr)) + .collect::>(); + // Add projection expressions that are known to be constant: + for (source, target) in mapping.iter() { + if self.is_expr_constant(source) + && !physical_exprs_contains(&projected_constants, target) + { + projected_constants.push(target.clone()); + } + } + projected_constants + } + + /// Projects the equivalences within according to `projection_mapping` + /// and `output_schema`. 
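+    ///
+    /// # Example
+    ///
+    /// An illustrative usage sketch (identifiers are hypothetical):
+    ///
+    /// ```text
+    /// let mapping = ProjectionMapping::try_new(&proj_exprs, &input_schema)?;
+    /// let projected = eq_properties.project(&mapping, output_schema);
+    /// // `projected` carries the orderings, equivalences and constants that
+    /// // remain valid after the projection.
+    /// ```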
+ pub fn project( + &self, + projection_mapping: &ProjectionMapping, + output_schema: SchemaRef, + ) -> Self { + let projected_constants = self.projected_constants(projection_mapping); + let projected_eq_group = self.eq_group.project(projection_mapping); + let projected_orderings = self.projected_orderings(projection_mapping); + Self { + eq_group: projected_eq_group, + oeq_class: OrderingEquivalenceClass::new(projected_orderings), + constants: projected_constants, + schema: output_schema, + } + } + + /// Returns the longest (potentially partial) permutation satisfying the + /// existing ordering. For example, if we have the equivalent orderings + /// `[a ASC, b ASC]` and `[c DESC]`, with `exprs` containing `[c, b, a, d]`, + /// then this function returns `([a ASC, b ASC, c DESC], [2, 1, 0])`. + /// This means that the specification `[a ASC, b ASC, c DESC]` is satisfied + /// by the existing ordering, and `[a, b, c]` resides at indices: `2, 1, 0` + /// inside the argument `exprs` (respectively). For the mathematical + /// definition of "partial permutation", see: + /// + /// + pub fn find_longest_permutation( + &self, + exprs: &[Arc], + ) -> (LexOrdering, Vec) { + let mut eq_properties = self.clone(); + let mut result = vec![]; + // The algorithm is as follows: + // - Iterate over all the expressions and insert ordered expressions + // into the result. + // - Treat inserted expressions as constants (i.e. add them as constants + // to the state). + // - Continue the above procedure until no expression is inserted; i.e. + // the algorithm reaches a fixed point. + // This algorithm should reach a fixed point in at most `exprs.len()` + // iterations. + let mut search_indices = (0..exprs.len()).collect::>(); + for _idx in 0..exprs.len() { + // Get ordered expressions with their indices. + let ordered_exprs = search_indices + .iter() + .flat_map(|&idx| { + let ExprOrdering { expr, state, .. } = + eq_properties.get_expr_ordering(exprs[idx].clone()); + if let SortProperties::Ordered(options) = state { + Some((PhysicalSortExpr { expr, options }, idx)) + } else { + None + } + }) + .collect::>(); + // We reached a fixed point, exit. + if ordered_exprs.is_empty() { + break; + } + // Remove indices that have an ordering from `search_indices`, and + // treat ordered expressions as constants in subsequent iterations. + // We can do this because the "next" key only matters in a lexicographical + // ordering when the keys to its left have the same values. + // + // Note that these expressions are not properly "constants". This is just + // an implementation strategy confined to this function. + for (PhysicalSortExpr { expr, .. }, idx) in &ordered_exprs { + eq_properties = + eq_properties.add_constants(std::iter::once(expr.clone())); + search_indices.remove(idx); + } + // Add new ordered section to the state. + result.extend(ordered_exprs); + } + result.into_iter().unzip() + } + + /// This function determines whether the provided expression is constant + /// based on the known constants. + /// + /// # Arguments + /// + /// - `expr`: A reference to a `Arc` representing the + /// expression to be checked. + /// + /// # Returns + /// + /// Returns `true` if the expression is constant according to equivalence + /// group, `false` otherwise. + fn is_expr_constant(&self, expr: &Arc) -> bool { + // As an example, assume that we know columns `a` and `b` are constant. + // Then, `a`, `b` and `a + b` will all return `true` whereas `c` will + // return `false`. 
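+        // Normalize both the known constants and the candidate expression with
+        // respect to the equivalence group so that aliases compare equal, then
+        // check the expression tree recursively.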
+ let normalized_constants = self.eq_group.normalize_exprs(self.constants.to_vec()); + let normalized_expr = self.eq_group.normalize_expr(expr.clone()); + is_constant_recurse(&normalized_constants, &normalized_expr) + } + + /// Retrieves the ordering information for a given physical expression. + /// + /// This function constructs an `ExprOrdering` object for the provided + /// expression, which encapsulates information about the expression's + /// ordering, including its [`SortProperties`]. + /// + /// # Arguments + /// + /// - `expr`: An `Arc` representing the physical expression + /// for which ordering information is sought. + /// + /// # Returns + /// + /// Returns an `ExprOrdering` object containing the ordering information for + /// the given expression. + pub fn get_expr_ordering(&self, expr: Arc) -> ExprOrdering { + ExprOrdering::new(expr.clone()) + .transform_up(&|expr| Ok(update_ordering(expr, self))) + // Guaranteed to always return `Ok`. + .unwrap() + } +} + +/// Calculates the [`SortProperties`] of a given [`ExprOrdering`] node. +/// The node can either be a leaf node, or an intermediate node: +/// - If it is a leaf node, we directly find the order of the node by looking +/// at the given sort expression and equivalence properties if it is a `Column` +/// leaf, or we mark it as unordered. In the case of a `Literal` leaf, we mark +/// it as singleton so that it can cooperate with all ordered columns. +/// - If it is an intermediate node, the children states matter. Each `PhysicalExpr` +/// and operator has its own rules on how to propagate the children orderings. +/// However, before we engage in recursion, we check whether this intermediate +/// node directly matches with the sort expression. If there is a match, the +/// sort expression emerges at that node immediately, discarding the recursive +/// result coming from its children. +fn update_ordering( + mut node: ExprOrdering, + eq_properties: &EquivalenceProperties, +) -> Transformed { + // We have a Column, which is one of the two possible leaf node types: + let normalized_expr = eq_properties.eq_group.normalize_expr(node.expr.clone()); + if eq_properties.is_expr_constant(&normalized_expr) { + node.state = SortProperties::Singleton; + } else if let Some(options) = eq_properties + .normalized_oeq_class() + .get_options(&normalized_expr) + { + node.state = SortProperties::Ordered(options); + } else if !node.expr.children().is_empty() { + // We have an intermediate (non-leaf) node, account for its children: + node.state = node.expr.get_ordering(&node.children_state()); + } else if node.expr.as_any().is::() { + // We have a Literal, which is the other possible leaf node type: + node.state = node.expr.get_ordering(&[]); + } else { + return Transformed::No(node); + } + Transformed::Yes(node) +} + +/// This function determines whether the provided expression is constant +/// based on the known constants. +/// +/// # Arguments +/// +/// - `constants`: A `&[Arc]` containing expressions known to +/// be a constant. +/// - `expr`: A reference to a `Arc` representing the expression +/// to check. +/// +/// # Returns +/// +/// Returns `true` if the expression is constant according to equivalence +/// group, `false` otherwise. 
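+///
+/// ```text
+/// // Illustrative sketch: with constants = [a, b]
+/// is_constant_recurse(constants, a)     -> true   // direct member
+/// is_constant_recurse(constants, a + b) -> true   // all children are constant
+/// is_constant_recurse(constants, a + c) -> false  // `c` is not constant
+/// ```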
+fn is_constant_recurse( + constants: &[Arc], + expr: &Arc, +) -> bool { + if physical_exprs_contains(constants, expr) { + return true; + } + let children = expr.children(); + !children.is_empty() && children.iter().all(|c| is_constant_recurse(constants, c)) +} + +/// This function examines whether a referring expression directly refers to a +/// given referred expression or if any of its children in the expression tree +/// refer to the specified expression. +/// +/// # Parameters +/// +/// - `referring_expr`: A reference to the referring expression (`Arc`). +/// - `referred_expr`: A reference to the referred expression (`Arc`) +/// +/// # Returns +/// +/// A boolean value indicating whether `referring_expr` refers (needs it to evaluate its result) +/// `referred_expr` or not. +fn expr_refers( + referring_expr: &Arc, + referred_expr: &Arc, +) -> bool { + referring_expr.eq(referred_expr) + || referring_expr + .children() + .iter() + .any(|child| expr_refers(child, referred_expr)) +} + +/// This function analyzes the dependency map to collect referred dependencies for +/// a given source expression. +/// +/// # Parameters +/// +/// - `dependency_map`: A reference to the `DependencyMap` where each +/// `PhysicalSortExpr` is associated with a `DependencyNode`. +/// - `source`: A reference to the source expression (`Arc`) +/// for which relevant dependencies need to be identified. +/// +/// # Returns +/// +/// A `Vec` containing the dependencies for the given source +/// expression. These dependencies are expressions that are referred to by +/// the source expression based on the provided dependency map. +fn referred_dependencies( + dependency_map: &DependencyMap, + source: &Arc, +) -> Vec { + // Associate `PhysicalExpr`s with `PhysicalSortExpr`s that contain them: + let mut expr_to_sort_exprs = HashMap::::new(); + for sort_expr in dependency_map + .keys() + .filter(|sort_expr| expr_refers(source, &sort_expr.expr)) + { + let key = ExprWrapper(sort_expr.expr.clone()); + expr_to_sort_exprs + .entry(key) + .or_default() + .insert(sort_expr.clone()); + } + + // Generate all valid dependencies for the source. For example, if the source + // is `a + b` and the map is `[a -> (a ASC, a DESC), b -> (b ASC)]`, we get + // `vec![HashSet(a ASC, b ASC), HashSet(a DESC, b ASC)]`. + expr_to_sort_exprs + .values() + .multi_cartesian_product() + .map(|referred_deps| referred_deps.into_iter().cloned().collect()) + .collect() +} + +/// This function retrieves the dependencies of the given relevant sort expression +/// from the given dependency map. It then constructs prefix orderings by recursively +/// analyzing the dependencies and include them in the orderings. +/// +/// # Parameters +/// +/// - `relevant_sort_expr`: A reference to the relevant sort expression +/// (`PhysicalSortExpr`) for which prefix orderings are to be constructed. +/// - `dependency_map`: A reference to the `DependencyMap` containing dependencies. +/// +/// # Returns +/// +/// A vector of prefix orderings (`Vec`) based on the given relevant +/// sort expression and its dependencies. +fn construct_prefix_orderings( + relevant_sort_expr: &PhysicalSortExpr, + dependency_map: &DependencyMap, +) -> Vec { + dependency_map[relevant_sort_expr] + .dependencies + .iter() + .flat_map(|dep| construct_orderings(dep, dependency_map)) + .collect() +} + +/// Given a set of relevant dependencies (`relevant_deps`) and a map of dependencies +/// (`dependency_map`), this function generates all possible prefix orderings +/// based on the given dependencies. 
+/// +/// # Parameters +/// +/// * `dependencies` - A reference to the dependencies. +/// * `dependency_map` - A reference to the map of dependencies for expressions. +/// +/// # Returns +/// +/// A vector of lexical orderings (`Vec`) representing all valid orderings +/// based on the given dependencies. +fn generate_dependency_orderings( + dependencies: &Dependencies, + dependency_map: &DependencyMap, +) -> Vec { + // Construct all the valid prefix orderings for each expression appearing + // in the projection: + let relevant_prefixes = dependencies + .iter() + .flat_map(|dep| { + let prefixes = construct_prefix_orderings(dep, dependency_map); + (!prefixes.is_empty()).then_some(prefixes) + }) + .collect::>(); + + // No dependency, dependent is a leading ordering. + if relevant_prefixes.is_empty() { + // Return an empty ordering: + return vec![vec![]]; + } + + // Generate all possible orderings where dependencies are satisfied for the + // current projection expression. For example, if expression is `a + b ASC`, + // and the dependency for `a ASC` is `[c ASC]`, the dependency for `b ASC` + // is `[d DESC]`, then we generate `[c ASC, d DESC, a + b ASC]` and + // `[d DESC, c ASC, a + b ASC]`. + relevant_prefixes + .into_iter() + .multi_cartesian_product() + .flat_map(|prefix_orderings| { + prefix_orderings + .iter() + .permutations(prefix_orderings.len()) + .map(|prefixes| prefixes.into_iter().flatten().cloned().collect()) + .collect::>() + }) + .collect() +} + +/// This function examines the given expression and the sort expressions it +/// refers to determine the ordering properties of the expression. +/// +/// # Parameters +/// +/// - `expr`: A reference to the source expression (`Arc`) for +/// which ordering properties need to be determined. +/// - `dependencies`: A reference to `Dependencies`, containing sort expressions +/// referred to by `expr`. +/// +/// # Returns +/// +/// A `SortProperties` indicating the ordering information of the given expression. +fn get_expr_ordering( + expr: &Arc, + dependencies: &Dependencies, +) -> SortProperties { + if let Some(column_order) = dependencies.iter().find(|&order| expr.eq(&order.expr)) { + // If exact match is found, return its ordering. + SortProperties::Ordered(column_order.options) + } else { + // Find orderings of its children + let child_states = expr + .children() + .iter() + .map(|child| get_expr_ordering(child, dependencies)) + .collect::>(); + // Calculate expression ordering using ordering of its children. + expr.get_ordering(&child_states) + } +} + +/// Represents a node in the dependency map used to construct projected orderings. +/// +/// A `DependencyNode` contains information about a particular sort expression, +/// including its target sort expression and a set of dependencies on other sort +/// expressions. +/// +/// # Fields +/// +/// - `target_sort_expr`: An optional `PhysicalSortExpr` representing the target +/// sort expression associated with the node. It is `None` if the sort expression +/// cannot be projected. +/// - `dependencies`: A [`Dependencies`] containing dependencies on other sort +/// expressions that are referred to by the target sort expression. +#[derive(Debug, Clone, PartialEq, Eq)] +struct DependencyNode { + target_sort_expr: Option, + dependencies: Dependencies, +} + +impl DependencyNode { + // Insert dependency to the state (if exists). 
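// Illustrative example of the dependency structures used here (assumed columns, not
// part of this patch): projecting `[a, b, a + b AS s]` from an input ordered by
// `[a ASC, b ASC]` produces a map in which `a ASC` has an empty dependency set while
// `b ASC` depends on `a ASC`, its predecessor in that ordering; `referred_dependencies`
// and `generate_dependency_orderings` above then combine these entries into valid
// prefixes such as `[a ASC, b ASC]` for the projected expression `s`.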
+ fn insert_dependency(&mut self, dependency: Option<&PhysicalSortExpr>) { + if let Some(dep) = dependency { + self.dependencies.insert(dep.clone()); + } + } +} + +type DependencyMap = HashMap; +type Dependencies = HashSet; + +/// This function recursively analyzes the dependencies of the given sort +/// expression within the given dependency map to construct lexicographical +/// orderings that include the sort expression and its dependencies. +/// +/// # Parameters +/// +/// - `referred_sort_expr`: A reference to the sort expression (`PhysicalSortExpr`) +/// for which lexicographical orderings satisfying its dependencies are to be +/// constructed. +/// - `dependency_map`: A reference to the `DependencyMap` that contains +/// dependencies for different `PhysicalSortExpr`s. +/// +/// # Returns +/// +/// A vector of lexicographical orderings (`Vec`) based on the given +/// sort expression and its dependencies. +fn construct_orderings( + referred_sort_expr: &PhysicalSortExpr, + dependency_map: &DependencyMap, +) -> Vec { + // We are sure that `referred_sort_expr` is inside `dependency_map`. + let node = &dependency_map[referred_sort_expr]; + // Since we work on intermediate nodes, we are sure `val.target_sort_expr` + // exists. + let target_sort_expr = node.target_sort_expr.clone().unwrap(); + if node.dependencies.is_empty() { + vec![vec![target_sort_expr]] + } else { + node.dependencies + .iter() + .flat_map(|dep| { + let mut orderings = construct_orderings(dep, dependency_map); + for ordering in orderings.iter_mut() { + ordering.push(target_sort_expr.clone()) + } + orderings + }) + .collect() + } +} + +/// Calculate ordering equivalence properties for the given join operation. +pub fn join_equivalence_properties( + left: EquivalenceProperties, + right: EquivalenceProperties, + join_type: &JoinType, + join_schema: SchemaRef, + maintains_input_order: &[bool], + probe_side: Option, + on: &[(Column, Column)], +) -> EquivalenceProperties { + let left_size = left.schema.fields.len(); + let mut result = EquivalenceProperties::new(join_schema); + result.add_equivalence_group(left.eq_group().join( + right.eq_group(), + join_type, + left_size, + on, + )); + + let left_oeq_class = left.oeq_class; + let mut right_oeq_class = right.oeq_class; + match maintains_input_order { + [true, false] => { + // In this special case, right side ordering can be prefixed with + // the left side ordering. + if let (Some(JoinSide::Left), JoinType::Inner) = (probe_side, join_type) { + updated_right_ordering_equivalence_class( + &mut right_oeq_class, + join_type, + left_size, + ); + + // Right side ordering equivalence properties should be prepended + // with those of the left side while constructing output ordering + // equivalence properties since stream side is the left side. + // + // For example, if the right side ordering equivalences contain + // `b ASC`, and the left side ordering equivalences contain `a ASC`, + // then we should add `a ASC, b ASC` to the ordering equivalences + // of the join output. + let out_oeq_class = left_oeq_class.join_suffix(&right_oeq_class); + result.add_ordering_equivalence_class(out_oeq_class); + } else { + result.add_ordering_equivalence_class(left_oeq_class); + } + } + [false, true] => { + updated_right_ordering_equivalence_class( + &mut right_oeq_class, + join_type, + left_size, + ); + // In this special case, left side ordering can be prefixed with + // the right side ordering. 
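// Reviewer note (illustrative, not part of the patch): at this point
// `updated_right_ordering_equivalence_class` has already shifted the right-side
// column indices by `left_size`, so a right-side column at index 0 now refers to
// index `left_size` in the combined join schema before the orderings are merged below.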
+ if let (Some(JoinSide::Right), JoinType::Inner) = (probe_side, join_type) { + // Left side ordering equivalence properties should be prepended + // with those of the right side while constructing output ordering + // equivalence properties since stream side is the right side. + // + // For example, if the left side ordering equivalences contain + // `a ASC`, and the right side ordering equivalences contain `b ASC`, + // then we should add `b ASC, a ASC` to the ordering equivalences + // of the join output. + let out_oeq_class = right_oeq_class.join_suffix(&left_oeq_class); + result.add_ordering_equivalence_class(out_oeq_class); + } else { + result.add_ordering_equivalence_class(right_oeq_class); + } + } + [false, false] => {} + [true, true] => unreachable!("Cannot maintain ordering of both sides"), + _ => unreachable!("Join operators can not have more than two children"), + } + result +} + +/// In the context of a join, update the right side `OrderingEquivalenceClass` +/// so that they point to valid indices in the join output schema. +/// +/// To do so, we increment column indices by the size of the left table when +/// join schema consists of a combination of the left and right schemas. This +/// is the case for `Inner`, `Left`, `Full` and `Right` joins. For other cases, +/// indices do not change. +fn updated_right_ordering_equivalence_class( + right_oeq_class: &mut OrderingEquivalenceClass, + join_type: &JoinType, + left_size: usize, +) { + if matches!( + join_type, + JoinType::Inner | JoinType::Left | JoinType::Full | JoinType::Right + ) { + right_oeq_class.add_offset(left_size); + } +} + +/// Wrapper struct for `Arc` to use them as keys in a hash map. +#[derive(Debug, Clone)] +struct ExprWrapper(Arc); + +impl PartialEq for ExprWrapper { + fn eq(&self, other: &Self) -> bool { + self.0.eq(&other.0) + } +} + +impl Eq for ExprWrapper {} + +impl Hash for ExprWrapper { + fn hash(&self, state: &mut H) { + self.0.hash(state); + } +} + +#[cfg(test)] +mod tests { + use std::ops::Not; + use std::sync::Arc; + + use super::*; + use crate::equivalence::add_offset_to_expr; + use crate::equivalence::tests::{ + convert_to_orderings, convert_to_sort_exprs, convert_to_sort_reqs, + create_random_schema, create_test_params, create_test_schema, + generate_table_for_eq_properties, is_table_same_after_sort, output_schema, + }; + use crate::execution_props::ExecutionProps; + use crate::expressions::{col, BinaryExpr, Column}; + use crate::functions::create_physical_expr; + use crate::PhysicalSortExpr; + use arrow::datatypes::{DataType, Field, Schema}; + use arrow_schema::{Fields, SortOptions, TimeUnit}; + use datafusion_common::Result; + use datafusion_expr::{BuiltinScalarFunction, Operator}; + use itertools::Itertools; + + #[test] + fn project_equivalence_properties_test() -> Result<()> { + let input_schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int64, true), + Field::new("b", DataType::Int64, true), + Field::new("c", DataType::Int64, true), + ])); + + let input_properties = EquivalenceProperties::new(input_schema.clone()); + let col_a = col("a", &input_schema)?; + + // a as a1, a as a2, a as a3, a as a3 + let proj_exprs = vec![ + (col_a.clone(), "a1".to_string()), + (col_a.clone(), "a2".to_string()), + (col_a.clone(), "a3".to_string()), + (col_a.clone(), "a4".to_string()), + ]; + let projection_mapping = ProjectionMapping::try_new(&proj_exprs, &input_schema)?; + + let out_schema = output_schema(&projection_mapping, &input_schema)?; + // a as a1, a as a2, a as a3, a as a3 + let proj_exprs 
= vec![ + (col_a.clone(), "a1".to_string()), + (col_a.clone(), "a2".to_string()), + (col_a.clone(), "a3".to_string()), + (col_a.clone(), "a4".to_string()), + ]; + let projection_mapping = ProjectionMapping::try_new(&proj_exprs, &input_schema)?; + + // a as a1, a as a2, a as a3, a as a3 + let col_a1 = &col("a1", &out_schema)?; + let col_a2 = &col("a2", &out_schema)?; + let col_a3 = &col("a3", &out_schema)?; + let col_a4 = &col("a4", &out_schema)?; + let out_properties = input_properties.project(&projection_mapping, out_schema); + + // At the output a1=a2=a3=a4 + assert_eq!(out_properties.eq_group().len(), 1); + let eq_class = &out_properties.eq_group().classes[0]; + assert_eq!(eq_class.len(), 4); + assert!(eq_class.contains(col_a1)); + assert!(eq_class.contains(col_a2)); + assert!(eq_class.contains(col_a3)); + assert!(eq_class.contains(col_a4)); + + Ok(()) + } + + #[test] + fn test_join_equivalence_properties() -> Result<()> { + let schema = create_test_schema()?; + let col_a = &col("a", &schema)?; + let col_b = &col("b", &schema)?; + let col_c = &col("c", &schema)?; + let offset = schema.fields.len(); + let col_a2 = &add_offset_to_expr(col_a.clone(), offset); + let col_b2 = &add_offset_to_expr(col_b.clone(), offset); + let option_asc = SortOptions { + descending: false, + nulls_first: false, + }; + let test_cases = vec![ + // ------- TEST CASE 1 -------- + // [a ASC], [b ASC] + ( + // [a ASC], [b ASC] + vec![vec![(col_a, option_asc)], vec![(col_b, option_asc)]], + // [a ASC], [b ASC] + vec![vec![(col_a, option_asc)], vec![(col_b, option_asc)]], + // expected [a ASC, a2 ASC], [a ASC, b2 ASC], [b ASC, a2 ASC], [b ASC, b2 ASC] + vec![ + vec![(col_a, option_asc), (col_a2, option_asc)], + vec![(col_a, option_asc), (col_b2, option_asc)], + vec![(col_b, option_asc), (col_a2, option_asc)], + vec![(col_b, option_asc), (col_b2, option_asc)], + ], + ), + // ------- TEST CASE 2 -------- + // [a ASC], [b ASC] + ( + // [a ASC], [b ASC], [c ASC] + vec![ + vec![(col_a, option_asc)], + vec![(col_b, option_asc)], + vec![(col_c, option_asc)], + ], + // [a ASC], [b ASC] + vec![vec![(col_a, option_asc)], vec![(col_b, option_asc)]], + // expected [a ASC, a2 ASC], [a ASC, b2 ASC], [b ASC, a2 ASC], [b ASC, b2 ASC], [c ASC, a2 ASC], [c ASC, b2 ASC] + vec![ + vec![(col_a, option_asc), (col_a2, option_asc)], + vec![(col_a, option_asc), (col_b2, option_asc)], + vec![(col_b, option_asc), (col_a2, option_asc)], + vec![(col_b, option_asc), (col_b2, option_asc)], + vec![(col_c, option_asc), (col_a2, option_asc)], + vec![(col_c, option_asc), (col_b2, option_asc)], + ], + ), + ]; + for (left_orderings, right_orderings, expected) in test_cases { + let mut left_eq_properties = EquivalenceProperties::new(schema.clone()); + let mut right_eq_properties = EquivalenceProperties::new(schema.clone()); + let left_orderings = convert_to_orderings(&left_orderings); + let right_orderings = convert_to_orderings(&right_orderings); + let expected = convert_to_orderings(&expected); + left_eq_properties.add_new_orderings(left_orderings); + right_eq_properties.add_new_orderings(right_orderings); + let join_eq = join_equivalence_properties( + left_eq_properties, + right_eq_properties, + &JoinType::Inner, + Arc::new(Schema::empty()), + &[true, false], + Some(JoinSide::Left), + &[], + ); + let orderings = &join_eq.oeq_class.orderings; + let err_msg = format!("expected: {:?}, actual:{:?}", expected, orderings); + assert_eq!( + join_eq.oeq_class.orderings.len(), + expected.len(), + "{}", + err_msg + ); + for ordering in orderings { + assert!( + 
expected.contains(ordering), + "{}, ordering: {:?}", + err_msg, + ordering + ); + } + } + Ok(()) + } + + #[test] + fn test_expr_consists_of_constants() -> Result<()> { + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + Field::new("c", DataType::Int32, true), + Field::new("d", DataType::Int32, true), + Field::new("ts", DataType::Timestamp(TimeUnit::Nanosecond, None), true), + ])); + let col_a = col("a", &schema)?; + let col_b = col("b", &schema)?; + let col_d = col("d", &schema)?; + let b_plus_d = Arc::new(BinaryExpr::new( + col_b.clone(), + Operator::Plus, + col_d.clone(), + )) as Arc; + + let constants = vec![col_a.clone(), col_b.clone()]; + let expr = b_plus_d.clone(); + assert!(!is_constant_recurse(&constants, &expr)); + + let constants = vec![col_a.clone(), col_b.clone(), col_d.clone()]; + let expr = b_plus_d.clone(); + assert!(is_constant_recurse(&constants, &expr)); + Ok(()) + } + + #[test] + fn test_get_updated_right_ordering_equivalence_properties() -> Result<()> { + let join_type = JoinType::Inner; + // Join right child schema + let child_fields: Fields = ["x", "y", "z", "w"] + .into_iter() + .map(|name| Field::new(name, DataType::Int32, true)) + .collect(); + let child_schema = Schema::new(child_fields); + let col_x = &col("x", &child_schema)?; + let col_y = &col("y", &child_schema)?; + let col_z = &col("z", &child_schema)?; + let col_w = &col("w", &child_schema)?; + let option_asc = SortOptions { + descending: false, + nulls_first: false, + }; + // [x ASC, y ASC], [z ASC, w ASC] + let orderings = vec![ + vec![(col_x, option_asc), (col_y, option_asc)], + vec![(col_z, option_asc), (col_w, option_asc)], + ]; + let orderings = convert_to_orderings(&orderings); + // Right child ordering equivalences + let mut right_oeq_class = OrderingEquivalenceClass::new(orderings); + + let left_columns_len = 4; + + let fields: Fields = ["a", "b", "c", "d", "x", "y", "z", "w"] + .into_iter() + .map(|name| Field::new(name, DataType::Int32, true)) + .collect(); + + // Join Schema + let schema = Schema::new(fields); + let col_a = &col("a", &schema)?; + let col_d = &col("d", &schema)?; + let col_x = &col("x", &schema)?; + let col_y = &col("y", &schema)?; + let col_z = &col("z", &schema)?; + let col_w = &col("w", &schema)?; + + let mut join_eq_properties = EquivalenceProperties::new(Arc::new(schema)); + // a=x and d=w + join_eq_properties.add_equal_conditions(col_a, col_x); + join_eq_properties.add_equal_conditions(col_d, col_w); + + updated_right_ordering_equivalence_class( + &mut right_oeq_class, + &join_type, + left_columns_len, + ); + join_eq_properties.add_ordering_equivalence_class(right_oeq_class); + let result = join_eq_properties.oeq_class().clone(); + + // [x ASC, y ASC], [z ASC, w ASC] + let orderings = vec![ + vec![(col_x, option_asc), (col_y, option_asc)], + vec![(col_z, option_asc), (col_w, option_asc)], + ]; + let orderings = convert_to_orderings(&orderings); + let expected = OrderingEquivalenceClass::new(orderings); + + assert_eq!(result, expected); + + Ok(()) + } + + #[test] + fn test_normalize_ordering_equivalence_classes() -> Result<()> { + let sort_options = SortOptions::default(); + + let schema = Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + Field::new("c", DataType::Int32, true), + ]); + let col_a_expr = col("a", &schema)?; + let col_b_expr = col("b", &schema)?; + let col_c_expr = col("c", &schema)?; + let mut eq_properties = 
EquivalenceProperties::new(Arc::new(schema.clone())); + + eq_properties.add_equal_conditions(&col_a_expr, &col_c_expr); + let others = vec![ + vec![PhysicalSortExpr { + expr: col_b_expr.clone(), + options: sort_options, + }], + vec![PhysicalSortExpr { + expr: col_c_expr.clone(), + options: sort_options, + }], + ]; + eq_properties.add_new_orderings(others); + + let mut expected_eqs = EquivalenceProperties::new(Arc::new(schema)); + expected_eqs.add_new_orderings([ + vec![PhysicalSortExpr { + expr: col_b_expr.clone(), + options: sort_options, + }], + vec![PhysicalSortExpr { + expr: col_c_expr.clone(), + options: sort_options, + }], + ]); + + let oeq_class = eq_properties.oeq_class().clone(); + let expected = expected_eqs.oeq_class(); + assert!(oeq_class.eq(expected)); + + Ok(()) + } + + #[test] + fn test_get_indices_of_matching_sort_exprs_with_order_eq() -> Result<()> { + let sort_options = SortOptions::default(); + let sort_options_not = SortOptions::default().not(); + + let schema = Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + ]); + let col_a = &col("a", &schema)?; + let col_b = &col("b", &schema)?; + let required_columns = [col_b.clone(), col_a.clone()]; + let mut eq_properties = EquivalenceProperties::new(Arc::new(schema)); + eq_properties.add_new_orderings([vec![ + PhysicalSortExpr { + expr: Arc::new(Column::new("b", 1)), + options: sort_options_not, + }, + PhysicalSortExpr { + expr: Arc::new(Column::new("a", 0)), + options: sort_options, + }, + ]]); + let (result, idxs) = eq_properties.find_longest_permutation(&required_columns); + assert_eq!(idxs, vec![0, 1]); + assert_eq!( + result, + vec![ + PhysicalSortExpr { + expr: col_b.clone(), + options: sort_options_not + }, + PhysicalSortExpr { + expr: col_a.clone(), + options: sort_options + } + ] + ); + + let schema = Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + Field::new("c", DataType::Int32, true), + ]); + let col_a = &col("a", &schema)?; + let col_b = &col("b", &schema)?; + let required_columns = [col_b.clone(), col_a.clone()]; + let mut eq_properties = EquivalenceProperties::new(Arc::new(schema)); + eq_properties.add_new_orderings([ + vec![PhysicalSortExpr { + expr: Arc::new(Column::new("c", 2)), + options: sort_options, + }], + vec![ + PhysicalSortExpr { + expr: Arc::new(Column::new("b", 1)), + options: sort_options_not, + }, + PhysicalSortExpr { + expr: Arc::new(Column::new("a", 0)), + options: sort_options, + }, + ], + ]); + let (result, idxs) = eq_properties.find_longest_permutation(&required_columns); + assert_eq!(idxs, vec![0, 1]); + assert_eq!( + result, + vec![ + PhysicalSortExpr { + expr: col_b.clone(), + options: sort_options_not + }, + PhysicalSortExpr { + expr: col_a.clone(), + options: sort_options + } + ] + ); + + let required_columns = [ + Arc::new(Column::new("b", 1)) as _, + Arc::new(Column::new("a", 0)) as _, + ]; + let schema = Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + Field::new("c", DataType::Int32, true), + ]); + let mut eq_properties = EquivalenceProperties::new(Arc::new(schema)); + + // not satisfied orders + eq_properties.add_new_orderings([vec![ + PhysicalSortExpr { + expr: Arc::new(Column::new("b", 1)), + options: sort_options_not, + }, + PhysicalSortExpr { + expr: Arc::new(Column::new("c", 2)), + options: sort_options, + }, + PhysicalSortExpr { + expr: Arc::new(Column::new("a", 0)), + options: sort_options, + }, + ]]); + let (_, 
idxs) = eq_properties.find_longest_permutation(&required_columns); + assert_eq!(idxs, vec![0]); + + Ok(()) + } + + #[test] + fn test_update_ordering() -> Result<()> { + let schema = Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + Field::new("c", DataType::Int32, true), + Field::new("d", DataType::Int32, true), + ]); + + let mut eq_properties = EquivalenceProperties::new(Arc::new(schema.clone())); + let col_a = &col("a", &schema)?; + let col_b = &col("b", &schema)?; + let col_c = &col("c", &schema)?; + let col_d = &col("d", &schema)?; + let option_asc = SortOptions { + descending: false, + nulls_first: false, + }; + // b=a (e.g they are aliases) + eq_properties.add_equal_conditions(col_b, col_a); + // [b ASC], [d ASC] + eq_properties.add_new_orderings(vec![ + vec![PhysicalSortExpr { + expr: col_b.clone(), + options: option_asc, + }], + vec![PhysicalSortExpr { + expr: col_d.clone(), + options: option_asc, + }], + ]); + + let test_cases = vec![ + // d + b + ( + Arc::new(BinaryExpr::new( + col_d.clone(), + Operator::Plus, + col_b.clone(), + )) as Arc, + SortProperties::Ordered(option_asc), + ), + // b + (col_b.clone(), SortProperties::Ordered(option_asc)), + // a + (col_a.clone(), SortProperties::Ordered(option_asc)), + // a + c + ( + Arc::new(BinaryExpr::new( + col_a.clone(), + Operator::Plus, + col_c.clone(), + )), + SortProperties::Unordered, + ), + ]; + for (expr, expected) in test_cases { + let leading_orderings = eq_properties + .oeq_class() + .iter() + .flat_map(|ordering| ordering.first().cloned()) + .collect::>(); + let expr_ordering = eq_properties.get_expr_ordering(expr.clone()); + let err_msg = format!( + "expr:{:?}, expected: {:?}, actual: {:?}, leading_orderings: {leading_orderings:?}", + expr, expected, expr_ordering.state + ); + assert_eq!(expr_ordering.state, expected, "{}", err_msg); + } + + Ok(()) + } + + #[test] + fn test_find_longest_permutation_random() -> Result<()> { + const N_RANDOM_SCHEMA: usize = 100; + const N_ELEMENTS: usize = 125; + const N_DISTINCT: usize = 5; + + for seed in 0..N_RANDOM_SCHEMA { + // Create a random schema with random properties + let (test_schema, eq_properties) = create_random_schema(seed as u64)?; + // Generate a data that satisfies properties given + let table_data_with_properties = + generate_table_for_eq_properties(&eq_properties, N_ELEMENTS, N_DISTINCT)?; + + let floor_a = create_physical_expr( + &BuiltinScalarFunction::Floor, + &[col("a", &test_schema)?], + &test_schema, + &ExecutionProps::default(), + )?; + let a_plus_b = Arc::new(BinaryExpr::new( + col("a", &test_schema)?, + Operator::Plus, + col("b", &test_schema)?, + )) as Arc; + let exprs = vec![ + col("a", &test_schema)?, + col("b", &test_schema)?, + col("c", &test_schema)?, + col("d", &test_schema)?, + col("e", &test_schema)?, + col("f", &test_schema)?, + floor_a, + a_plus_b, + ]; + + for n_req in 0..=exprs.len() { + for exprs in exprs.iter().combinations(n_req) { + let exprs = exprs.into_iter().cloned().collect::>(); + let (ordering, indices) = + eq_properties.find_longest_permutation(&exprs); + // Make sure that find_longest_permutation return values are consistent + let ordering2 = indices + .iter() + .zip(ordering.iter()) + .map(|(&idx, sort_expr)| PhysicalSortExpr { + expr: exprs[idx].clone(), + options: sort_expr.options, + }) + .collect::>(); + assert_eq!( + ordering, ordering2, + "indices and lexicographical ordering do not match" + ); + + let err_msg = format!( + "Error in test case ordering:{:?}, 
eq_properties.oeq_class: {:?}, eq_properties.eq_group: {:?}, eq_properties.constants: {:?}", + ordering, eq_properties.oeq_class, eq_properties.eq_group, eq_properties.constants + ); + assert_eq!(ordering.len(), indices.len(), "{}", err_msg); + // Since ordered section satisfies schema, we expect + // that result will be same after sort (e.g sort was unnecessary). + assert!( + is_table_same_after_sort( + ordering.clone(), + table_data_with_properties.clone(), + )?, + "{}", + err_msg + ); + } + } + } + + Ok(()) + } + #[test] + fn test_find_longest_permutation() -> Result<()> { + // Schema satisfies following orderings: + // [a ASC], [d ASC, b ASC], [e DESC, f ASC, g ASC] + // and + // Column [a=c] (e.g they are aliases). + // At below we add [d ASC, h DESC] also, for test purposes + let (test_schema, mut eq_properties) = create_test_params()?; + let col_a = &col("a", &test_schema)?; + let col_b = &col("b", &test_schema)?; + let col_c = &col("c", &test_schema)?; + let col_d = &col("d", &test_schema)?; + let col_e = &col("e", &test_schema)?; + let col_h = &col("h", &test_schema)?; + // a + d + let a_plus_d = Arc::new(BinaryExpr::new( + col_a.clone(), + Operator::Plus, + col_d.clone(), + )) as Arc; + + let option_asc = SortOptions { + descending: false, + nulls_first: false, + }; + let option_desc = SortOptions { + descending: true, + nulls_first: true, + }; + // [d ASC, h ASC] also satisfies schema. + eq_properties.add_new_orderings([vec![ + PhysicalSortExpr { + expr: col_d.clone(), + options: option_asc, + }, + PhysicalSortExpr { + expr: col_h.clone(), + options: option_desc, + }, + ]]); + let test_cases = vec![ + // TEST CASE 1 + (vec![col_a], vec![(col_a, option_asc)]), + // TEST CASE 2 + (vec![col_c], vec![(col_c, option_asc)]), + // TEST CASE 3 + ( + vec![col_d, col_e, col_b], + vec![ + (col_d, option_asc), + (col_e, option_desc), + (col_b, option_asc), + ], + ), + // TEST CASE 4 + (vec![col_b], vec![]), + // TEST CASE 5 + (vec![col_d], vec![(col_d, option_asc)]), + // TEST CASE 5 + (vec![&a_plus_d], vec![(&a_plus_d, option_asc)]), + // TEST CASE 6 + ( + vec![col_b, col_d], + vec![(col_d, option_asc), (col_b, option_asc)], + ), + // TEST CASE 6 + ( + vec![col_c, col_e], + vec![(col_c, option_asc), (col_e, option_desc)], + ), + ]; + for (exprs, expected) in test_cases { + let exprs = exprs.into_iter().cloned().collect::>(); + let expected = convert_to_sort_exprs(&expected); + let (actual, _) = eq_properties.find_longest_permutation(&exprs); + assert_eq!(actual, expected); + } + + Ok(()) + } + #[test] + fn test_get_meet_ordering() -> Result<()> { + let schema = create_test_schema()?; + let col_a = &col("a", &schema)?; + let col_b = &col("b", &schema)?; + let eq_properties = EquivalenceProperties::new(schema); + let option_asc = SortOptions { + descending: false, + nulls_first: false, + }; + let option_desc = SortOptions { + descending: true, + nulls_first: true, + }; + let tests_cases = vec![ + // Get meet ordering between [a ASC] and [a ASC, b ASC] + // result should be [a ASC] + ( + vec![(col_a, option_asc)], + vec![(col_a, option_asc), (col_b, option_asc)], + Some(vec![(col_a, option_asc)]), + ), + // Get meet ordering between [a ASC] and [a DESC] + // result should be None. + (vec![(col_a, option_asc)], vec![(col_a, option_desc)], None), + // Get meet ordering between [a ASC, b ASC] and [a ASC, b DESC] + // result should be [a ASC]. 
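// (Illustrative note, not part of the patch) The "meet" is the finest ordering that
// both inputs are guaranteed to satisfy, i.e. their longest common prefix, which is
// why diverging on `b` truncates the result to `[a ASC]`.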
+ ( + vec![(col_a, option_asc), (col_b, option_asc)], + vec![(col_a, option_asc), (col_b, option_desc)], + Some(vec![(col_a, option_asc)]), + ), + ]; + for (lhs, rhs, expected) in tests_cases { + let lhs = convert_to_sort_exprs(&lhs); + let rhs = convert_to_sort_exprs(&rhs); + let expected = expected.map(|expected| convert_to_sort_exprs(&expected)); + let finer = eq_properties.get_meet_ordering(&lhs, &rhs); + assert_eq!(finer, expected) + } + + Ok(()) + } + + #[test] + fn test_get_finer() -> Result<()> { + let schema = create_test_schema()?; + let col_a = &col("a", &schema)?; + let col_b = &col("b", &schema)?; + let col_c = &col("c", &schema)?; + let eq_properties = EquivalenceProperties::new(schema); + let option_asc = SortOptions { + descending: false, + nulls_first: false, + }; + let option_desc = SortOptions { + descending: true, + nulls_first: true, + }; + // First entry, and second entry are the physical sort requirement that are argument for get_finer_requirement. + // Third entry is the expected result. + let tests_cases = vec![ + // Get finer requirement between [a Some(ASC)] and [a None, b Some(ASC)] + // result should be [a Some(ASC), b Some(ASC)] + ( + vec![(col_a, Some(option_asc))], + vec![(col_a, None), (col_b, Some(option_asc))], + Some(vec![(col_a, Some(option_asc)), (col_b, Some(option_asc))]), + ), + // Get finer requirement between [a Some(ASC), b Some(ASC), c Some(ASC)] and [a Some(ASC), b Some(ASC)] + // result should be [a Some(ASC), b Some(ASC), c Some(ASC)] + ( + vec![ + (col_a, Some(option_asc)), + (col_b, Some(option_asc)), + (col_c, Some(option_asc)), + ], + vec![(col_a, Some(option_asc)), (col_b, Some(option_asc))], + Some(vec![ + (col_a, Some(option_asc)), + (col_b, Some(option_asc)), + (col_c, Some(option_asc)), + ]), + ), + // Get finer requirement between [a Some(ASC), b Some(ASC)] and [a Some(ASC), b Some(DESC)] + // result should be None + ( + vec![(col_a, Some(option_asc)), (col_b, Some(option_asc))], + vec![(col_a, Some(option_asc)), (col_b, Some(option_desc))], + None, + ), + ]; + for (lhs, rhs, expected) in tests_cases { + let lhs = convert_to_sort_reqs(&lhs); + let rhs = convert_to_sort_reqs(&rhs); + let expected = expected.map(|expected| convert_to_sort_reqs(&expected)); + let finer = eq_properties.get_finer_requirement(&lhs, &rhs); + assert_eq!(finer, expected) + } + + Ok(()) + } + + #[test] + fn test_normalize_sort_reqs() -> Result<()> { + // Schema satisfies following properties + // a=c + // and following orderings are valid + // [a ASC], [d ASC, b ASC], [e DESC, f ASC, g ASC] + let (test_schema, eq_properties) = create_test_params()?; + let col_a = &col("a", &test_schema)?; + let col_b = &col("b", &test_schema)?; + let col_c = &col("c", &test_schema)?; + let col_d = &col("d", &test_schema)?; + let col_e = &col("e", &test_schema)?; + let col_f = &col("f", &test_schema)?; + let option_asc = SortOptions { + descending: false, + nulls_first: false, + }; + let option_desc = SortOptions { + descending: true, + nulls_first: true, + }; + // First element in the tuple stores vector of requirement, second element is the expected return value for ordering_satisfy function + let requirements = vec![ + ( + vec![(col_a, Some(option_asc))], + vec![(col_a, Some(option_asc))], + ), + ( + vec![(col_a, Some(option_desc))], + vec![(col_a, Some(option_desc))], + ), + (vec![(col_a, None)], vec![(col_a, None)]), + // Test whether equivalence works as expected + ( + vec![(col_c, Some(option_asc))], + vec![(col_a, Some(option_asc))], + ), + (vec![(col_c, None)], 
vec![(col_a, None)]), + // Test whether ordering equivalence works as expected + ( + vec![(col_d, Some(option_asc)), (col_b, Some(option_asc))], + vec![(col_d, Some(option_asc)), (col_b, Some(option_asc))], + ), + ( + vec![(col_d, None), (col_b, None)], + vec![(col_d, None), (col_b, None)], + ), + ( + vec![(col_e, Some(option_desc)), (col_f, Some(option_asc))], + vec![(col_e, Some(option_desc)), (col_f, Some(option_asc))], + ), + // We should be able to normalize in compatible requirements also (not exactly equal) + ( + vec![(col_e, Some(option_desc)), (col_f, None)], + vec![(col_e, Some(option_desc)), (col_f, None)], + ), + ( + vec![(col_e, None), (col_f, None)], + vec![(col_e, None), (col_f, None)], + ), + ]; + + for (reqs, expected_normalized) in requirements.into_iter() { + let req = convert_to_sort_reqs(&reqs); + let expected_normalized = convert_to_sort_reqs(&expected_normalized); + + assert_eq!( + eq_properties.normalize_sort_requirements(&req), + expected_normalized + ); + } + + Ok(()) + } + + #[test] + fn test_schema_normalize_sort_requirement_with_equivalence() -> Result<()> { + let option1 = SortOptions { + descending: false, + nulls_first: false, + }; + // Assume that column a and c are aliases. + let (test_schema, eq_properties) = create_test_params()?; + let col_a = &col("a", &test_schema)?; + let col_c = &col("c", &test_schema)?; + let col_d = &col("d", &test_schema)?; + + // Test cases for equivalence normalization + // First entry in the tuple is PhysicalSortRequirement, second entry in the tuple is + // expected PhysicalSortRequirement after normalization. + let test_cases = vec![ + (vec![(col_a, Some(option1))], vec![(col_a, Some(option1))]), + // In the normalized version column c should be replace with column a + (vec![(col_c, Some(option1))], vec![(col_a, Some(option1))]), + (vec![(col_c, None)], vec![(col_a, None)]), + (vec![(col_d, Some(option1))], vec![(col_d, Some(option1))]), + ]; + for (reqs, expected) in test_cases.into_iter() { + let reqs = convert_to_sort_reqs(&reqs); + let expected = convert_to_sort_reqs(&expected); + + let normalized = eq_properties.normalize_sort_requirements(&reqs); + assert!( + expected.eq(&normalized), + "error in test: reqs: {reqs:?}, expected: {expected:?}, normalized: {normalized:?}" + ); + } + + Ok(()) + } +} diff --git a/datafusion/physical-expr/src/expressions/binary.rs b/datafusion/physical-expr/src/expressions/binary.rs index 63fa98011fdd3..c17081398cb8f 100644 --- a/datafusion/physical-expr/src/expressions/binary.rs +++ b/datafusion/physical-expr/src/expressions/binary.rs @@ -23,8 +23,8 @@ use std::{any::Any, sync::Arc}; use crate::array_expressions::{ array_append, array_concat, array_has_all, array_prepend, }; +use crate::expressions::datum::{apply, apply_cmp}; use crate::intervals::cp_solver::{propagate_arithmetic, propagate_comparison}; -use crate::intervals::{apply_operator, Interval}; use crate::physical_expr::down_cast_any_ref; use crate::sort_properties::SortProperties; use crate::PhysicalExpr; @@ -38,12 +38,13 @@ use arrow::compute::kernels::comparison::regexp_is_match_utf8_scalar; use arrow::compute::kernels::concat_elements::concat_elements_utf8; use arrow::datatypes::*; use arrow::record_batch::RecordBatch; + use datafusion_common::cast::as_boolean_array; use datafusion_common::{internal_err, DataFusionError, Result, ScalarValue}; +use datafusion_expr::interval_arithmetic::{apply_operator, Interval}; use datafusion_expr::type_coercion::binary::get_result_type; use datafusion_expr::{ColumnarValue, Operator}; -use 
crate::expressions::datum::{apply, apply_cmp}; use kernels::{ bitwise_and_dyn, bitwise_and_dyn_scalar, bitwise_or_dyn, bitwise_or_dyn_scalar, bitwise_shift_left_dyn, bitwise_shift_left_dyn_scalar, bitwise_shift_right_dyn, @@ -304,8 +305,8 @@ impl PhysicalExpr for BinaryExpr { // if both arrays or both literals - extract arrays and continue execution let (left, right) = ( - lhs.into_array(batch.num_rows()), - rhs.into_array(batch.num_rows()), + lhs.into_array(batch.num_rows())?, + rhs.into_array(batch.num_rows())?, ); self.evaluate_with_resolved_args(left, &left_data_type, right, &right_data_type) .map(ColumnarValue::Array) @@ -338,32 +339,102 @@ impl PhysicalExpr for BinaryExpr { &self, interval: &Interval, children: &[&Interval], - ) -> Result>> { + ) -> Result>> { // Get children intervals. let left_interval = children[0]; let right_interval = children[1]; - let (left, right) = if self.op.is_logic_operator() { - // TODO: Currently, this implementation only supports the AND operator - // and does not require any further propagation. In the future, - // upon adding support for additional logical operators, this - // method will require modification to support propagating the - // changes accordingly. - return Ok(vec![]); - } else if self.op.is_comparison_operator() { - if interval == &Interval::CERTAINLY_FALSE { - // TODO: We will handle strictly false clauses by negating - // the comparison operator (e.g. GT to LE, LT to GE) - // once open/closed intervals are supported. - return Ok(vec![]); + if self.op.eq(&Operator::And) { + if interval.eq(&Interval::CERTAINLY_TRUE) { + // A certainly true logical conjunction can only derive from possibly + // true operands. Otherwise, we prove infeasability. + Ok((!left_interval.eq(&Interval::CERTAINLY_FALSE) + && !right_interval.eq(&Interval::CERTAINLY_FALSE)) + .then(|| vec![Interval::CERTAINLY_TRUE, Interval::CERTAINLY_TRUE])) + } else if interval.eq(&Interval::CERTAINLY_FALSE) { + // If the logical conjunction is certainly false, one of the + // operands must be false. However, it's not always possible to + // determine which operand is false, leading to different scenarios. + + // If one operand is certainly true and the other one is uncertain, + // then the latter must be certainly false. + if left_interval.eq(&Interval::CERTAINLY_TRUE) + && right_interval.eq(&Interval::UNCERTAIN) + { + Ok(Some(vec![ + Interval::CERTAINLY_TRUE, + Interval::CERTAINLY_FALSE, + ])) + } else if right_interval.eq(&Interval::CERTAINLY_TRUE) + && left_interval.eq(&Interval::UNCERTAIN) + { + Ok(Some(vec![ + Interval::CERTAINLY_FALSE, + Interval::CERTAINLY_TRUE, + ])) + } + // If both children are uncertain, or if one is certainly false, + // we cannot conclusively refine their intervals. In this case, + // propagation does not result in any interval changes. + else { + Ok(Some(vec![])) + } + } else { + // An uncertain logical conjunction result can not shrink the + // end-points of its children. + Ok(Some(vec![])) + } + } else if self.op.eq(&Operator::Or) { + if interval.eq(&Interval::CERTAINLY_FALSE) { + // A certainly false logical conjunction can only derive from certainly + // false operands. Otherwise, we prove infeasability. + Ok((!left_interval.eq(&Interval::CERTAINLY_TRUE) + && !right_interval.eq(&Interval::CERTAINLY_TRUE)) + .then(|| vec![Interval::CERTAINLY_FALSE, Interval::CERTAINLY_FALSE])) + } else if interval.eq(&Interval::CERTAINLY_TRUE) { + // If the logical disjunction is certainly true, one of the + // operands must be true. 
However, it's not always possible to + // determine which operand is true, leading to different scenarios. + + // If one operand is certainly false and the other one is uncertain, + // then the latter must be certainly true. + if left_interval.eq(&Interval::CERTAINLY_FALSE) + && right_interval.eq(&Interval::UNCERTAIN) + { + Ok(Some(vec![ + Interval::CERTAINLY_FALSE, + Interval::CERTAINLY_TRUE, + ])) + } else if right_interval.eq(&Interval::CERTAINLY_FALSE) + && left_interval.eq(&Interval::UNCERTAIN) + { + Ok(Some(vec![ + Interval::CERTAINLY_TRUE, + Interval::CERTAINLY_FALSE, + ])) + } + // If both children are uncertain, or if one is certainly true, + // we cannot conclusively refine their intervals. In this case, + // propagation does not result in any interval changes. + else { + Ok(Some(vec![])) + } + } else { + // An uncertain logical disjunction result can not shrink the + // end-points of its children. + Ok(Some(vec![])) } - // Propagate the comparison operator. - propagate_comparison(&self.op, left_interval, right_interval)? + } else if self.op.is_comparison_operator() { + Ok( + propagate_comparison(&self.op, interval, left_interval, right_interval)? + .map(|(left, right)| vec![left, right]), + ) } else { - // Propagate the arithmetic operator. - propagate_arithmetic(&self.op, interval, left_interval, right_interval)? - }; - Ok(vec![left, right]) + Ok( + propagate_arithmetic(&self.op, interval, left_interval, right_interval)? + .map(|(left, right)| vec![left, right]), + ) + } } fn dyn_hash(&self, state: &mut dyn Hasher) { @@ -380,7 +451,7 @@ impl PhysicalExpr for BinaryExpr { Operator::Minus => left_child.sub(right_child), Operator::Gt | Operator::GtEq => left_child.gt_or_gteq(right_child), Operator::Lt | Operator::LtEq => right_child.gt_or_gteq(left_child), - Operator::And => left_child.and(right_child), + Operator::And | Operator::Or => left_child.and_or(right_child), _ => SortProperties::Unordered, } } @@ -558,8 +629,7 @@ mod tests { use arrow::datatypes::{ ArrowNumericType, Decimal128Type, Field, Int32Type, SchemaRef, }; - use arrow_schema::ArrowError; - use datafusion_common::Result; + use datafusion_common::{plan_datafusion_err, Result}; use datafusion_expr::type_coercion::binary::get_input_types; /// Performs a binary operation, applying any type coercion necessary @@ -597,7 +667,10 @@ mod tests { let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)])?; - let result = lt.evaluate(&batch)?.into_array(batch.num_rows()); + let result = lt + .evaluate(&batch)? + .into_array(batch.num_rows()) + .expect("Failed to convert to array"); assert_eq!(result.len(), 5); let expected = [false, false, true, true, true]; @@ -641,7 +714,10 @@ mod tests { assert_eq!("a@0 < b@1 OR a@0 = b@1", format!("{expr}")); - let result = expr.evaluate(&batch)?.into_array(batch.num_rows()); + let result = expr + .evaluate(&batch)? 
+ .into_array(batch.num_rows()) + .expect("Failed to convert to array"); assert_eq!(result.len(), 5); let expected = [true, true, false, true, false]; @@ -685,7 +761,7 @@ mod tests { assert_eq!(expression.data_type(&schema)?, $C_TYPE); // compute - let result = expression.evaluate(&batch)?.into_array(batch.num_rows()); + let result = expression.evaluate(&batch)?.into_array(batch.num_rows()).expect("Failed to convert to array"); // verify that the array's data_type is correct assert_eq!(*result.data_type(), $C_TYPE); @@ -2138,7 +2214,10 @@ mod tests { let arithmetic_op = binary_op(col("a", &schema)?, op, col("b", &schema)?, &schema)?; let batch = RecordBatch::try_new(schema, data)?; - let result = arithmetic_op.evaluate(&batch)?.into_array(batch.num_rows()); + let result = arithmetic_op + .evaluate(&batch)? + .into_array(batch.num_rows()) + .expect("Failed to convert to array"); assert_eq!(result.as_ref(), &expected); Ok(()) @@ -2154,7 +2233,10 @@ mod tests { let lit = Arc::new(Literal::new(literal)); let arithmetic_op = binary_op(col("a", &schema)?, op, lit, &schema)?; let batch = RecordBatch::try_new(schema, data)?; - let result = arithmetic_op.evaluate(&batch)?.into_array(batch.num_rows()); + let result = arithmetic_op + .evaluate(&batch)? + .into_array(batch.num_rows()) + .expect("Failed to convert to array"); assert_eq!(&result, &expected); Ok(()) @@ -2170,7 +2252,10 @@ mod tests { let op = binary_op(col("a", schema)?, op, col("b", schema)?, schema)?; let data: Vec = vec![left.clone(), right.clone()]; let batch = RecordBatch::try_new(schema.clone(), data)?; - let result = op.evaluate(&batch)?.into_array(batch.num_rows()); + let result = op + .evaluate(&batch)? + .into_array(batch.num_rows()) + .expect("Failed to convert to array"); assert_eq!(result.as_ref(), &expected); Ok(()) @@ -2187,7 +2272,10 @@ mod tests { let scalar = lit(scalar.clone()); let op = binary_op(scalar, op, col("a", schema)?, schema)?; let batch = RecordBatch::try_new(Arc::clone(schema), vec![Arc::clone(arr)])?; - let result = op.evaluate(&batch)?.into_array(batch.num_rows()); + let result = op + .evaluate(&batch)? + .into_array(batch.num_rows()) + .expect("Failed to convert to array"); assert_eq!(result.as_ref(), expected); Ok(()) @@ -2204,7 +2292,10 @@ mod tests { let scalar = lit(scalar.clone()); let op = binary_op(col("a", schema)?, op, scalar, schema)?; let batch = RecordBatch::try_new(Arc::clone(schema), vec![Arc::clone(arr)])?; - let result = op.evaluate(&batch)?.into_array(batch.num_rows()); + let result = op + .evaluate(&batch)? + .into_array(batch.num_rows()) + .expect("Failed to convert to array"); assert_eq!(result.as_ref(), expected); Ok(()) @@ -2776,7 +2867,8 @@ mod tests { let result = expr .evaluate(&batch) .expect("evaluation") - .into_array(batch.num_rows()); + .into_array(batch.num_rows()) + .expect("Failed to convert to array"); let expected: Int32Array = input .into_iter() @@ -3255,7 +3347,10 @@ mod tests { let arithmetic_op = binary_op(col("a", schema)?, op, col("b", schema)?, schema)?; let data: Vec = vec![left.clone(), right.clone()]; let batch = RecordBatch::try_new(schema.clone(), data)?; - let result = arithmetic_op.evaluate(&batch)?.into_array(batch.num_rows()); + let result = arithmetic_op + .evaluate(&batch)? 
+ .into_array(batch.num_rows()) + .expect("Failed to convert to array"); assert_eq!(result.as_ref(), expected.as_ref()); Ok(()) @@ -3512,10 +3607,9 @@ mod tests { ) .unwrap_err(); - assert!( - matches!(err, DataFusionError::ArrowError(ArrowError::DivideByZero)), - "{err}" - ); + let _expected = plan_datafusion_err!("Divide by zero"); + + assert!(matches!(err, ref _expected), "{err}"); // decimal let schema = Arc::new(Schema::new(vec![ @@ -3537,10 +3631,7 @@ mod tests { ) .unwrap_err(); - assert!( - matches!(err, DataFusionError::ArrowError(ArrowError::DivideByZero)), - "{err}" - ); + assert!(matches!(err, ref _expected), "{err}"); Ok(()) } diff --git a/datafusion/physical-expr/src/expressions/case.rs b/datafusion/physical-expr/src/expressions/case.rs index a2395c4a0ca2c..52fb85657f4e4 100644 --- a/datafusion/physical-expr/src/expressions/case.rs +++ b/datafusion/physical-expr/src/expressions/case.rs @@ -126,7 +126,7 @@ impl CaseExpr { let return_type = self.data_type(&batch.schema())?; let expr = self.expr.as_ref().unwrap(); let base_value = expr.evaluate(batch)?; - let base_value = base_value.into_array(batch.num_rows()); + let base_value = base_value.into_array(batch.num_rows())?; let base_nulls = is_null(base_value.as_ref())?; // start with nulls as default output @@ -137,7 +137,7 @@ impl CaseExpr { let when_value = self.when_then_expr[i] .0 .evaluate_selection(batch, &remainder)?; - let when_value = when_value.into_array(batch.num_rows()); + let when_value = when_value.into_array(batch.num_rows())?; // build boolean array representing which rows match the "when" value let when_match = eq(&when_value, &base_value)?; // Treat nulls as false @@ -145,6 +145,8 @@ impl CaseExpr { 0 => Cow::Borrowed(&when_match), _ => Cow::Owned(prep_null_mask_filter(&when_match)), }; + // Make sure we only consider rows that have not been matched yet + let when_match = and(&when_match, &remainder)?; let then_value = self.when_then_expr[i] .1 @@ -153,7 +155,7 @@ impl CaseExpr { ColumnarValue::Scalar(value) if value.is_null() => { new_null_array(&return_type, batch.num_rows()) } - _ => then_value.into_array(batch.num_rows()), + _ => then_value.into_array(batch.num_rows())?, }; current_value = @@ -170,7 +172,7 @@ impl CaseExpr { remainder = or(&base_nulls, &remainder)?; let else_ = expr .evaluate_selection(batch, &remainder)? - .into_array(batch.num_rows()); + .into_array(batch.num_rows())?; current_value = zip(&remainder, else_.as_ref(), current_value.as_ref())?; } @@ -194,7 +196,7 @@ impl CaseExpr { let when_value = self.when_then_expr[i] .0 .evaluate_selection(batch, &remainder)?; - let when_value = when_value.into_array(batch.num_rows()); + let when_value = when_value.into_array(batch.num_rows())?; let when_value = as_boolean_array(&when_value).map_err(|e| { DataFusionError::Context( "WHEN expression did not return a BooleanArray".to_string(), @@ -206,6 +208,8 @@ impl CaseExpr { 0 => Cow::Borrowed(when_value), _ => Cow::Owned(prep_null_mask_filter(when_value)), }; + // Make sure we only consider rows that have not been matched yet + let when_value = and(&when_value, &remainder)?; let then_value = self.when_then_expr[i] .1 @@ -214,7 +218,7 @@ impl CaseExpr { ColumnarValue::Scalar(value) if value.is_null() => { new_null_array(&return_type, batch.num_rows()) } - _ => then_value.into_array(batch.num_rows()), + _ => then_value.into_array(batch.num_rows())?, }; current_value = @@ -231,7 +235,7 @@ impl CaseExpr { .unwrap_or_else(|_| e.clone()); let else_ = expr .evaluate_selection(batch, &remainder)? 
- .into_array(batch.num_rows()); + .into_array(batch.num_rows())?; current_value = zip(&remainder, else_.as_ref(), current_value.as_ref())?; } @@ -425,7 +429,10 @@ mod tests { None, schema.as_ref(), )?; - let result = expr.evaluate(&batch)?.into_array(batch.num_rows()); + let result = expr + .evaluate(&batch)? + .into_array(batch.num_rows()) + .expect("Failed to convert to array"); let result = as_int32_array(&result)?; let expected = &Int32Array::from(vec![Some(123), None, None, Some(456)]); @@ -453,7 +460,10 @@ mod tests { Some(else_value), schema.as_ref(), )?; - let result = expr.evaluate(&batch)?.into_array(batch.num_rows()); + let result = expr + .evaluate(&batch)? + .into_array(batch.num_rows()) + .expect("Failed to convert to array"); let result = as_int32_array(&result)?; let expected = @@ -485,7 +495,10 @@ mod tests { Some(else_value), schema.as_ref(), )?; - let result = expr.evaluate(&batch)?.into_array(batch.num_rows()); + let result = expr + .evaluate(&batch)? + .into_array(batch.num_rows()) + .expect("Failed to convert to array"); let result = as_float64_array(&result).expect("failed to downcast to Float64Array"); @@ -523,7 +536,10 @@ mod tests { None, schema.as_ref(), )?; - let result = expr.evaluate(&batch)?.into_array(batch.num_rows()); + let result = expr + .evaluate(&batch)? + .into_array(batch.num_rows()) + .expect("Failed to convert to array"); let result = as_int32_array(&result)?; let expected = &Int32Array::from(vec![Some(123), None, None, Some(456)]); @@ -551,7 +567,10 @@ mod tests { Some(else_value), schema.as_ref(), )?; - let result = expr.evaluate(&batch)?.into_array(batch.num_rows()); + let result = expr + .evaluate(&batch)? + .into_array(batch.num_rows()) + .expect("Failed to convert to array"); let result = as_int32_array(&result)?; let expected = @@ -583,7 +602,10 @@ mod tests { Some(x), schema.as_ref(), )?; - let result = expr.evaluate(&batch)?.into_array(batch.num_rows()); + let result = expr + .evaluate(&batch)? + .into_array(batch.num_rows()) + .expect("Failed to convert to array"); let result = as_float64_array(&result).expect("failed to downcast to Float64Array"); @@ -629,7 +651,10 @@ mod tests { Some(else_value), schema.as_ref(), )?; - let result = expr.evaluate(&batch)?.into_array(batch.num_rows()); + let result = expr + .evaluate(&batch)? + .into_array(batch.num_rows()) + .expect("Failed to convert to array"); let result = as_int32_array(&result)?; let expected = @@ -661,7 +686,10 @@ mod tests { Some(else_value), schema.as_ref(), )?; - let result = expr.evaluate(&batch)?.into_array(batch.num_rows()); + let result = expr + .evaluate(&batch)? + .into_array(batch.num_rows()) + .expect("Failed to convert to array"); let result = as_float64_array(&result).expect("failed to downcast to Float64Array"); @@ -693,7 +721,10 @@ mod tests { None, schema.as_ref(), )?; - let result = expr.evaluate(&batch)?.into_array(batch.num_rows()); + let result = expr + .evaluate(&batch)? + .into_array(batch.num_rows()) + .expect("Failed to convert to array"); let result = as_float64_array(&result).expect("failed to downcast to Float64Array"); @@ -721,7 +752,10 @@ mod tests { None, schema.as_ref(), )?; - let result = expr.evaluate(&batch)?.into_array(batch.num_rows()); + let result = expr + .evaluate(&batch)? 
+ .into_array(batch.num_rows()) + .expect("Failed to convert to array"); let result = as_float64_array(&result).expect("failed to downcast to Float64Array"); diff --git a/datafusion/physical-expr/src/expressions/cast.rs b/datafusion/physical-expr/src/expressions/cast.rs index 9390089063a0e..0c4ed3c125498 100644 --- a/datafusion/physical-expr/src/expressions/cast.rs +++ b/datafusion/physical-expr/src/expressions/cast.rs @@ -20,18 +20,16 @@ use std::fmt; use std::hash::{Hash, Hasher}; use std::sync::Arc; -use crate::intervals::Interval; use crate::physical_expr::down_cast_any_ref; use crate::sort_properties::SortProperties; use crate::PhysicalExpr; -use arrow::compute; -use arrow::compute::{kernels, CastOptions}; +use arrow::compute::{can_cast_types, kernels, CastOptions}; use arrow::datatypes::{DataType, Schema}; use arrow::record_batch::RecordBatch; -use compute::can_cast_types; use datafusion_common::format::DEFAULT_FORMAT_OPTIONS; use datafusion_common::{not_impl_err, DataFusionError, Result, ScalarValue}; +use datafusion_expr::interval_arithmetic::Interval; use datafusion_expr::ColumnarValue; const DEFAULT_CAST_OPTIONS: CastOptions<'static> = CastOptions { @@ -73,6 +71,11 @@ impl CastExpr { pub fn cast_type(&self) -> &DataType { &self.cast_type } + + /// The cast options + pub fn cast_options(&self) -> &CastOptions<'static> { + &self.cast_options + } } impl fmt::Display for CastExpr { @@ -124,21 +127,20 @@ impl PhysicalExpr for CastExpr { &self, interval: &Interval, children: &[&Interval], - ) -> Result>> { + ) -> Result>> { let child_interval = children[0]; // Get child's datatype: - let cast_type = child_interval.get_datatype()?; - Ok(vec![Some( - interval.cast_to(&cast_type, &self.cast_options)?, - )]) + let cast_type = child_interval.data_type(); + Ok(Some( + vec![interval.cast_to(&cast_type, &self.cast_options)?], + )) } fn dyn_hash(&self, state: &mut dyn Hasher) { let mut s = state; self.expr.hash(&mut s); self.cast_type.hash(&mut s); - // Add `self.cast_options` when hash is available - // https://github.com/apache/arrow-rs/pull/4395 + self.cast_options.hash(&mut s); } /// A [`CastExpr`] preserves the ordering of its child. @@ -154,8 +156,7 @@ impl PartialEq for CastExpr { .map(|x| { self.expr.eq(&x.expr) && self.cast_type == x.cast_type - // TODO: Use https://github.com/apache/arrow-rs/issues/2966 when available - && self.cast_options.safe == x.cast_options.safe + && self.cast_options == x.cast_options }) .unwrap_or(false) } @@ -173,7 +174,20 @@ pub fn cast_column( kernels::cast::cast_with_options(array, cast_type, &cast_options)?, )), ColumnarValue::Scalar(scalar) => { - let scalar_array = scalar.to_array(); + let scalar_array = if cast_type + == &DataType::Timestamp(arrow_schema::TimeUnit::Nanosecond, None) + { + if let ScalarValue::Float64(Some(float_ts)) = scalar { + ScalarValue::Int64( + Some((float_ts * 1_000_000_000_f64).trunc() as i64), + ) + .to_array()? + } else { + scalar.to_array()? + } + } else { + scalar.to_array()? 
+ }; let cast_array = kernels::cast::cast_with_options( &scalar_array, cast_type, @@ -198,7 +212,10 @@ pub fn cast_with_options( let expr_type = expr.data_type(input_schema)?; if expr_type == cast_type { Ok(expr.clone()) - } else if can_cast_types(&expr_type, &cast_type) { + } else if can_cast_types(&expr_type, &cast_type) + || (expr_type == DataType::Float64 + && cast_type == DataType::Timestamp(arrow_schema::TimeUnit::Nanosecond, None)) + { Ok(Arc::new(CastExpr::new(expr, cast_type, cast_options))) } else { not_impl_err!("Unsupported CAST from {expr_type:?} to {cast_type:?}") @@ -221,6 +238,7 @@ pub fn cast( mod tests { use super::*; use crate::expressions::col; + use arrow::{ array::{ Array, Decimal128Array, Float32Array, Float64Array, Int16Array, Int32Array, @@ -229,6 +247,7 @@ mod tests { }, datatypes::*, }; + use datafusion_common::Result; // runs an end-to-end test of physical type cast @@ -258,7 +277,10 @@ mod tests { assert_eq!(expression.data_type(&schema)?, $TYPE); // compute - let result = expression.evaluate(&batch)?.into_array(batch.num_rows()); + let result = expression + .evaluate(&batch)? + .into_array(batch.num_rows()) + .expect("Failed to convert to array"); // verify that the array's data_type is correct assert_eq!(*result.data_type(), $TYPE); @@ -307,7 +329,10 @@ mod tests { assert_eq!(expression.data_type(&schema)?, $TYPE); // compute - let result = expression.evaluate(&batch)?.into_array(batch.num_rows()); + let result = expression + .evaluate(&batch)? + .into_array(batch.num_rows()) + .expect("Failed to convert to array"); // verify that the array's data_type is correct assert_eq!(*result.data_type(), $TYPE); @@ -669,7 +694,11 @@ mod tests { // Ensure a useful error happens at plan time if invalid casts are used let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); - let result = cast(col("a", &schema).unwrap(), &schema, DataType::LargeBinary); + let result = cast( + col("a", &schema).unwrap(), + &schema, + DataType::Interval(IntervalUnit::MonthDayNano), + ); result.expect_err("expected Invalid CAST"); } diff --git a/datafusion/physical-expr/src/expressions/datum.rs b/datafusion/physical-expr/src/expressions/datum.rs index f57cbbd4ffa3a..2bb79922cfecc 100644 --- a/datafusion/physical-expr/src/expressions/datum.rs +++ b/datafusion/physical-expr/src/expressions/datum.rs @@ -34,14 +34,14 @@ pub(crate) fn apply( (ColumnarValue::Array(left), ColumnarValue::Array(right)) => { Ok(ColumnarValue::Array(f(&left.as_ref(), &right.as_ref())?)) } - (ColumnarValue::Scalar(left), ColumnarValue::Array(right)) => { - Ok(ColumnarValue::Array(f(&left.to_scalar(), &right.as_ref())?)) - } - (ColumnarValue::Array(left), ColumnarValue::Scalar(right)) => { - Ok(ColumnarValue::Array(f(&left.as_ref(), &right.to_scalar())?)) - } + (ColumnarValue::Scalar(left), ColumnarValue::Array(right)) => Ok( + ColumnarValue::Array(f(&left.to_scalar()?, &right.as_ref())?), + ), + (ColumnarValue::Array(left), ColumnarValue::Scalar(right)) => Ok( + ColumnarValue::Array(f(&left.as_ref(), &right.to_scalar()?)?), + ), (ColumnarValue::Scalar(left), ColumnarValue::Scalar(right)) => { - let array = f(&left.to_scalar(), &right.to_scalar())?; + let array = f(&left.to_scalar()?, &right.to_scalar()?)?; let scalar = ScalarValue::try_from_array(array.as_ref(), 0)?; Ok(ColumnarValue::Scalar(scalar)) } diff --git a/datafusion/physical-expr/src/expressions/get_indexed_field.rs b/datafusion/physical-expr/src/expressions/get_indexed_field.rs index df79e28358203..43fd5a812a16c 100644 --- 
a/datafusion/physical-expr/src/expressions/get_indexed_field.rs +++ b/datafusion/physical-expr/src/expressions/get_indexed_field.rs @@ -110,7 +110,7 @@ impl GetIndexedFieldExpr { Self::new( arg, GetFieldAccessExpr::NamedStructField { - name: ScalarValue::Utf8(Some(name.into())), + name: ScalarValue::from(name.into()), }, ) } @@ -183,7 +183,7 @@ impl PhysicalExpr for GetIndexedFieldExpr { } fn evaluate(&self, batch: &RecordBatch) -> Result { - let array = self.arg.evaluate(batch)?.into_array(batch.num_rows()); + let array = self.arg.evaluate(batch)?.into_array(batch.num_rows())?; match &self.field { GetFieldAccessExpr::NamedStructField{name} => match (array.data_type(), name) { (DataType::Map(_, _), ScalarValue::Utf8(Some(k))) => { @@ -210,7 +210,7 @@ impl PhysicalExpr for GetIndexedFieldExpr { with utf8 indexes. Tried {dt:?} with {name:?} index"), }, GetFieldAccessExpr::ListIndex{key} => { - let key = key.evaluate(batch)?.into_array(batch.num_rows()); + let key = key.evaluate(batch)?.into_array(batch.num_rows())?; match (array.data_type(), key.data_type()) { (DataType::List(_), DataType::Int64) => Ok(ColumnarValue::Array(array_element(&[ array, key @@ -224,8 +224,8 @@ impl PhysicalExpr for GetIndexedFieldExpr { } }, GetFieldAccessExpr::ListRange{start, stop} => { - let start = start.evaluate(batch)?.into_array(batch.num_rows()); - let stop = stop.evaluate(batch)?.into_array(batch.num_rows()); + let start = start.evaluate(batch)?.into_array(batch.num_rows())?; + let stop = stop.evaluate(batch)?.into_array(batch.num_rows())?; match (array.data_type(), start.data_type(), stop.data_type()) { (DataType::List(_), DataType::Int64, DataType::Int64) => Ok(ColumnarValue::Array(array_slice(&[ array, start, stop @@ -326,7 +326,10 @@ mod tests { // only one row should be processed let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(struct_array)])?; let expr = Arc::new(GetIndexedFieldExpr::new_field(expr, "a")); - let result = expr.evaluate(&batch)?.into_array(1); + let result = expr + .evaluate(&batch)? + .into_array(1) + .expect("Failed to convert to array"); let result = as_boolean_array(&result).expect("failed to downcast to BooleanArray"); assert_eq!(boolean, result.clone()); @@ -383,7 +386,10 @@ mod tests { vec![Arc::new(list_col), Arc::new(key_col)], )?; let expr = Arc::new(GetIndexedFieldExpr::new_index(expr, key)); - let result = expr.evaluate(&batch)?.into_array(1); + let result = expr + .evaluate(&batch)? + .into_array(1) + .expect("Failed to convert to array"); let result = as_string_array(&result).expect("failed to downcast to ListArray"); let expected = StringArray::from(expected_list); assert_eq!(expected, result.clone()); @@ -419,7 +425,10 @@ mod tests { vec![Arc::new(list_col), Arc::new(start_col), Arc::new(stop_col)], )?; let expr = Arc::new(GetIndexedFieldExpr::new_range(expr, start, stop)); - let result = expr.evaluate(&batch)?.into_array(1); + let result = expr + .evaluate(&batch)? + .into_array(1) + .expect("Failed to convert to array"); let result = as_list_array(&result).expect("failed to downcast to ListArray"); let (expected, _, _) = build_list_arguments(expected_list, vec![None], vec![None]); @@ -440,8 +449,11 @@ mod tests { vec![Arc::new(list_builder.finish()), key_array], )?; let expr = Arc::new(GetIndexedFieldExpr::new_index(expr, key)); - let result = expr.evaluate(&batch)?.into_array(batch.num_rows()); - assert!(result.is_null(0)); + let result = expr + .evaluate(&batch)? 
+ .into_array(batch.num_rows()) + .expect("Failed to convert to array"); + assert!(result.is_empty()); Ok(()) } @@ -461,7 +473,10 @@ mod tests { vec![Arc::new(list_builder.finish()), Arc::new(key_array)], )?; let expr = Arc::new(GetIndexedFieldExpr::new_index(expr, key)); - let result = expr.evaluate(&batch)?.into_array(1); + let result = expr + .evaluate(&batch)? + .into_array(1) + .expect("Failed to convert to array"); assert!(result.is_null(0)); Ok(()) } diff --git a/datafusion/physical-expr/src/expressions/in_list.rs b/datafusion/physical-expr/src/expressions/in_list.rs index 8d55fb70bd9e3..1a1634081c381 100644 --- a/datafusion/physical-expr/src/expressions/in_list.rs +++ b/datafusion/physical-expr/src/expressions/in_list.rs @@ -349,17 +349,18 @@ impl PhysicalExpr for InListExpr { } fn evaluate(&self, batch: &RecordBatch) -> Result { + let num_rows = batch.num_rows(); let value = self.expr.evaluate(batch)?; let r = match &self.static_filter { - Some(f) => f.contains(value.into_array(1).as_ref(), self.negated)?, + Some(f) => f.contains(value.into_array(num_rows)?.as_ref(), self.negated)?, None => { - let value = value.into_array(batch.num_rows()); + let value = value.into_array(num_rows)?; let found = self.list.iter().map(|expr| expr.evaluate(batch)).try_fold( - BooleanArray::new(BooleanBuffer::new_unset(batch.num_rows()), None), + BooleanArray::new(BooleanBuffer::new_unset(num_rows), None), |result, expr| -> Result { Ok(or_kleene( &result, - &eq(&value, &expr?.into_array(batch.num_rows()))?, + &eq(&value, &expr?.into_array(num_rows)?)?, )?) }, )?; @@ -501,7 +502,10 @@ mod tests { ($BATCH:expr, $LIST:expr, $NEGATED:expr, $EXPECTED:expr, $COL:expr, $SCHEMA:expr) => {{ let (cast_expr, cast_list_exprs) = in_list_cast($COL, $LIST, $SCHEMA)?; let expr = in_list(cast_expr, cast_list_exprs, $NEGATED, $SCHEMA).unwrap(); - let result = expr.evaluate(&$BATCH)?.into_array($BATCH.num_rows()); + let result = expr + .evaluate(&$BATCH)? 
+ .into_array($BATCH.num_rows()) + .expect("Failed to convert to array"); let result = as_boolean_array(&result).expect("failed to downcast to BooleanArray"); let expected = &BooleanArray::from($EXPECTED); @@ -1264,4 +1268,52 @@ mod tests { Ok(()) } + + #[test] + fn in_list_no_cols() -> Result<()> { + // test logic when the in_list expression doesn't have any columns + let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]); + let a = Int32Array::from(vec![Some(1), Some(2), None]); + let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; + + let list = vec![lit(ScalarValue::from(1i32)), lit(ScalarValue::from(6i32))]; + + // 1 IN (1, 6) + let expr = lit(ScalarValue::Int32(Some(1))); + in_list!( + batch, + list.clone(), + &false, + // should have three outputs, as the input batch has three rows + vec![Some(true), Some(true), Some(true)], + expr, + &schema + ); + + // 2 IN (1, 6) + let expr = lit(ScalarValue::Int32(Some(2))); + in_list!( + batch, + list.clone(), + &false, + // should have three outputs, as the input batch has three rows + vec![Some(false), Some(false), Some(false)], + expr, + &schema + ); + + // NULL IN (1, 6) + let expr = lit(ScalarValue::Int32(None)); + in_list!( + batch, + list.clone(), + &false, + // should have three outputs, as the input batch has three rows + vec![None, None, None], + expr, + &schema + ); + + Ok(()) + } } diff --git a/datafusion/physical-expr/src/expressions/is_not_null.rs b/datafusion/physical-expr/src/expressions/is_not_null.rs index da717a517fb37..2e6a2bec9cab5 100644 --- a/datafusion/physical-expr/src/expressions/is_not_null.rs +++ b/datafusion/physical-expr/src/expressions/is_not_null.rs @@ -132,7 +132,10 @@ mod tests { let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)])?; // expression: "a is not null" - let result = expr.evaluate(&batch)?.into_array(batch.num_rows()); + let result = expr + .evaluate(&batch)? + .into_array(batch.num_rows()) + .expect("Failed to convert to array"); let result = as_boolean_array(&result).expect("failed to downcast to BooleanArray"); diff --git a/datafusion/physical-expr/src/expressions/is_null.rs b/datafusion/physical-expr/src/expressions/is_null.rs index ee7897edd4de6..3ad4058dd6493 100644 --- a/datafusion/physical-expr/src/expressions/is_null.rs +++ b/datafusion/physical-expr/src/expressions/is_null.rs @@ -134,7 +134,10 @@ mod tests { let expr = is_null(col("a", &schema)?).unwrap(); let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)])?; - let result = expr.evaluate(&batch)?.into_array(batch.num_rows()); + let result = expr + .evaluate(&batch)? + .into_array(batch.num_rows()) + .expect("Failed to convert to array"); let result = as_boolean_array(&result).expect("failed to downcast to BooleanArray"); diff --git a/datafusion/physical-expr/src/expressions/like.rs b/datafusion/physical-expr/src/expressions/like.rs index e833eabbfff26..37452e278484a 100644 --- a/datafusion/physical-expr/src/expressions/like.rs +++ b/datafusion/physical-expr/src/expressions/like.rs @@ -201,7 +201,10 @@ mod test { )?; // compute - let result = expression.evaluate(&batch)?.into_array(batch.num_rows()); + let result = expression + .evaluate(&batch)? 
+ .into_array(batch.num_rows()) + .expect("Failed to convert to array"); let result = as_boolean_array(&result).expect("failed to downcast to BooleanArray"); let expected = &BooleanArray::from($VEC); diff --git a/datafusion/physical-expr/src/expressions/literal.rs b/datafusion/physical-expr/src/expressions/literal.rs index 91cb23d5864e6..cd3b51f09105a 100644 --- a/datafusion/physical-expr/src/expressions/literal.rs +++ b/datafusion/physical-expr/src/expressions/literal.rs @@ -131,7 +131,10 @@ mod tests { let literal_expr = lit(42i32); assert_eq!("42", format!("{literal_expr}")); - let literal_array = literal_expr.evaluate(&batch)?.into_array(batch.num_rows()); + let literal_array = literal_expr + .evaluate(&batch)? + .into_array(batch.num_rows()) + .expect("Failed to convert to array"); let literal_array = as_int32_array(&literal_array)?; // note that the contents of the literal array are unrelated to the batch contents except for the length of the array diff --git a/datafusion/physical-expr/src/expressions/mod.rs b/datafusion/physical-expr/src/expressions/mod.rs index c44b3cf01d36c..b6d0ad5b91043 100644 --- a/datafusion/physical-expr/src/expressions/mod.rs +++ b/datafusion/physical-expr/src/expressions/mod.rs @@ -63,6 +63,7 @@ pub use crate::aggregate::min_max::{MaxAccumulator, MinAccumulator}; pub use crate::aggregate::regr::{Regr, RegrType}; pub use crate::aggregate::stats::StatsType; pub use crate::aggregate::stddev::{Stddev, StddevPop}; +pub use crate::aggregate::string_agg::StringAgg; pub use crate::aggregate::sum::Sum; pub use crate::aggregate::sum_distinct::DistinctSum; pub use crate::aggregate::variance::{Variance, VariancePop}; @@ -247,8 +248,10 @@ pub(crate) mod tests { let expr = agg.expressions(); let values = expr .iter() - .map(|e| e.evaluate(batch)) - .map(|r| r.map(|v| v.into_array(batch.num_rows()))) + .map(|e| { + e.evaluate(batch) + .and_then(|v| v.into_array(batch.num_rows())) + }) .collect::>>()?; accum.update_batch(&values)?; accum.evaluate() @@ -262,8 +265,10 @@ pub(crate) mod tests { let expr = agg.expressions(); let values = expr .iter() - .map(|e| e.evaluate(batch)) - .map(|r| r.map(|v| v.into_array(batch.num_rows()))) + .map(|e| { + e.evaluate(batch) + .and_then(|v| v.into_array(batch.num_rows())) + }) .collect::>>()?; let indices = vec![0; batch.num_rows()]; accum.update_batch(&values, &indices, None, 1)?; diff --git a/datafusion/physical-expr/src/expressions/negative.rs b/datafusion/physical-expr/src/expressions/negative.rs index 86b000e76a321..b64b4a0c86def 100644 --- a/datafusion/physical-expr/src/expressions/negative.rs +++ b/datafusion/physical-expr/src/expressions/negative.rs @@ -17,25 +17,26 @@ //! 
Negation (-) expression -use crate::intervals::Interval; +use std::any::Any; +use std::hash::{Hash, Hasher}; +use std::sync::Arc; + use crate::physical_expr::down_cast_any_ref; use crate::sort_properties::SortProperties; use crate::PhysicalExpr; + use arrow::{ compute::kernels::numeric::neg_wrapping, datatypes::{DataType, Schema}, record_batch::RecordBatch, }; use datafusion_common::{internal_err, DataFusionError, Result}; +use datafusion_expr::interval_arithmetic::Interval; use datafusion_expr::{ - type_coercion::{is_interval, is_null, is_signed_numeric}, + type_coercion::{is_interval, is_null, is_signed_numeric, is_timestamp}, ColumnarValue, }; -use std::any::Any; -use std::hash::{Hash, Hasher}; -use std::sync::Arc; - /// Negative expression #[derive(Debug, Hash)] pub struct NegativeExpr { @@ -108,10 +109,10 @@ impl PhysicalExpr for NegativeExpr { /// It replaces the upper and lower bounds after multiplying them with -1. /// Ex: `(a, b]` => `[-b, -a)` fn evaluate_bounds(&self, children: &[&Interval]) -> Result { - Ok(Interval::new( - children[0].upper.negate()?, - children[0].lower.negate()?, - )) + Interval::try_new( + children[0].upper().arithmetic_negate()?, + children[0].lower().arithmetic_negate()?, + ) } /// Returns a new [`Interval`] of a NegativeExpr that has the existing `interval` given that @@ -120,12 +121,16 @@ impl PhysicalExpr for NegativeExpr { &self, interval: &Interval, children: &[&Interval], - ) -> Result>> { + ) -> Result>> { let child_interval = children[0]; - let negated_interval = - Interval::new(interval.upper.negate()?, interval.lower.negate()?); + let negated_interval = Interval::try_new( + interval.upper().arithmetic_negate()?, + interval.lower().arithmetic_negate()?, + )?; - Ok(vec![child_interval.intersect(negated_interval)?]) + Ok(child_interval + .intersect(negated_interval)? + .map(|result| vec![result])) } /// The ordering of a [`NegativeExpr`] is simply the reverse of its child. @@ -155,7 +160,10 @@ pub fn negative( let data_type = arg.data_type(input_schema)?; if is_null(&data_type) { Ok(arg) - } else if !is_signed_numeric(&data_type) && !is_interval(&data_type) { + } else if !is_signed_numeric(&data_type) + && !is_interval(&data_type) + && !is_timestamp(&data_type) + { internal_err!( "Can't create negative physical expr for (- '{arg:?}'), the type of child expr is {data_type}, not signed numeric" ) @@ -167,14 +175,14 @@ pub fn negative( #[cfg(test)] mod tests { use super::*; - use crate::{ - expressions::{col, Column}, - intervals::Interval, - }; + use crate::expressions::{col, Column}; + use arrow::array::*; use arrow::datatypes::*; use arrow_schema::DataType::{Float32, Float64, Int16, Int32, Int64, Int8}; - use datafusion_common::{cast::as_primitive_array, Result}; + use datafusion_common::cast::as_primitive_array; + use datafusion_common::Result; + use paste::paste; macro_rules! 
test_array_negative_op { @@ -195,7 +203,7 @@ mod tests { let expected = &paste!{[<$DATA_TY Array>]::from(arr_expected)}; let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(input)])?; - let result = expr.evaluate(&batch)?.into_array(batch.num_rows()); + let result = expr.evaluate(&batch)?.into_array(batch.num_rows()).expect("Failed to convert to array"); let result = as_primitive_array(&result).expect(format!("failed to downcast to {:?}Array", $DATA_TY).as_str()); assert_eq!(result, expected); @@ -218,8 +226,8 @@ mod tests { let negative_expr = NegativeExpr { arg: Arc::new(Column::new("a", 0)), }; - let child_interval = Interval::make(Some(-2), Some(1), (true, false)); - let negative_expr_interval = Interval::make(Some(-1), Some(2), (false, true)); + let child_interval = Interval::make(Some(-2), Some(1))?; + let negative_expr_interval = Interval::make(Some(-1), Some(2))?; assert_eq!( negative_expr.evaluate_bounds(&[&child_interval])?, negative_expr_interval @@ -232,10 +240,9 @@ mod tests { let negative_expr = NegativeExpr { arg: Arc::new(Column::new("a", 0)), }; - let original_child_interval = Interval::make(Some(-2), Some(3), (false, false)); - let negative_expr_interval = Interval::make(Some(0), Some(4), (true, false)); - let after_propagation = - vec![Some(Interval::make(Some(-2), Some(0), (false, true)))]; + let original_child_interval = Interval::make(Some(-2), Some(3))?; + let negative_expr_interval = Interval::make(Some(0), Some(4))?; + let after_propagation = Some(vec![Interval::make(Some(-2), Some(0))?]); assert_eq!( negative_expr.propagate_constraints( &negative_expr_interval, diff --git a/datafusion/physical-expr/src/expressions/not.rs b/datafusion/physical-expr/src/expressions/not.rs index c154fad100371..4ceccc6932fe4 100644 --- a/datafusion/physical-expr/src/expressions/not.rs +++ b/datafusion/physical-expr/src/expressions/not.rs @@ -150,7 +150,10 @@ mod tests { let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(input)])?; - let result = expr.evaluate(&batch)?.into_array(batch.num_rows()); + let result = expr + .evaluate(&batch)? 
+ .into_array(batch.num_rows()) + .expect("Failed to convert to array"); let result = as_boolean_array(&result).expect("failed to downcast to BooleanArray"); assert_eq!(result, expected); diff --git a/datafusion/physical-expr/src/expressions/nullif.rs b/datafusion/physical-expr/src/expressions/nullif.rs index 7bbe9d73d4358..dcd883f92965b 100644 --- a/datafusion/physical-expr/src/expressions/nullif.rs +++ b/datafusion/physical-expr/src/expressions/nullif.rs @@ -37,7 +37,7 @@ pub fn nullif_func(args: &[ColumnarValue]) -> Result { match (lhs, rhs) { (ColumnarValue::Array(lhs), ColumnarValue::Scalar(rhs)) => { - let rhs = rhs.to_scalar(); + let rhs = rhs.to_scalar()?; let array = nullif(lhs, &eq(&lhs, &rhs)?)?; Ok(ColumnarValue::Array(array)) @@ -47,7 +47,7 @@ pub fn nullif_func(args: &[ColumnarValue]) -> Result { Ok(ColumnarValue::Array(array)) } (ColumnarValue::Scalar(lhs), ColumnarValue::Array(rhs)) => { - let lhs = lhs.to_array_of_size(rhs.len()); + let lhs = lhs.to_array_of_size(rhs.len())?; let array = nullif(&lhs, &eq(&lhs, &rhs)?)?; Ok(ColumnarValue::Array(array)) } @@ -89,7 +89,7 @@ mod tests { let lit_array = ColumnarValue::Scalar(ScalarValue::Int32(Some(2i32))); let result = nullif_func(&[a, lit_array])?; - let result = result.into_array(0); + let result = result.into_array(0).expect("Failed to convert to array"); let expected = Arc::new(Int32Array::from(vec![ Some(1), @@ -115,7 +115,7 @@ mod tests { let lit_array = ColumnarValue::Scalar(ScalarValue::Int32(Some(1i32))); let result = nullif_func(&[a, lit_array])?; - let result = result.into_array(0); + let result = result.into_array(0).expect("Failed to convert to array"); let expected = Arc::new(Int32Array::from(vec![ None, @@ -140,7 +140,7 @@ mod tests { let lit_array = ColumnarValue::Scalar(ScalarValue::Boolean(Some(false))); let result = nullif_func(&[a, lit_array])?; - let result = result.into_array(0); + let result = result.into_array(0).expect("Failed to convert to array"); let expected = Arc::new(BooleanArray::from(vec![Some(true), None, None])) as ArrayRef; @@ -154,10 +154,10 @@ mod tests { let a = StringArray::from(vec![Some("foo"), Some("bar"), None, Some("baz")]); let a = ColumnarValue::Array(Arc::new(a)); - let lit_array = ColumnarValue::Scalar(ScalarValue::Utf8(Some("bar".to_string()))); + let lit_array = ColumnarValue::Scalar(ScalarValue::from("bar")); let result = nullif_func(&[a, lit_array])?; - let result = result.into_array(0); + let result = result.into_array(0).expect("Failed to convert to array"); let expected = Arc::new(StringArray::from(vec![ Some("foo"), @@ -178,7 +178,7 @@ mod tests { let lit_array = ColumnarValue::Scalar(ScalarValue::Int32(Some(2i32))); let result = nullif_func(&[lit_array, a])?; - let result = result.into_array(0); + let result = result.into_array(0).expect("Failed to convert to array"); let expected = Arc::new(Int32Array::from(vec![ Some(2), @@ -198,7 +198,7 @@ mod tests { let b_eq = ColumnarValue::Scalar(ScalarValue::Int32(Some(2i32))); let result_eq = nullif_func(&[a_eq, b_eq])?; - let result_eq = result_eq.into_array(1); + let result_eq = result_eq.into_array(1).expect("Failed to convert to array"); let expected_eq = Arc::new(Int32Array::from(vec![None])) as ArrayRef; @@ -208,7 +208,9 @@ mod tests { let b_neq = ColumnarValue::Scalar(ScalarValue::Int32(Some(1i32))); let result_neq = nullif_func(&[a_neq, b_neq])?; - let result_neq = result_neq.into_array(1); + let result_neq = result_neq + .into_array(1) + .expect("Failed to convert to array"); let expected_neq = 
Arc::new(Int32Array::from(vec![Some(2i32)])) as ArrayRef; assert_eq!(expected_neq.as_ref(), result_neq.as_ref()); diff --git a/datafusion/physical-expr/src/expressions/try_cast.rs b/datafusion/physical-expr/src/expressions/try_cast.rs index cba026c565134..0f7909097a106 100644 --- a/datafusion/physical-expr/src/expressions/try_cast.rs +++ b/datafusion/physical-expr/src/expressions/try_cast.rs @@ -89,7 +89,7 @@ impl PhysicalExpr for TryCastExpr { Ok(ColumnarValue::Array(cast)) } ColumnarValue::Scalar(scalar) => { - let array = scalar.to_array(); + let array = scalar.to_array()?; let cast_array = cast_with_options(&array, &self.cast_type, &options)?; let cast_scalar = ScalarValue::try_from_array(&cast_array, 0)?; Ok(ColumnarValue::Scalar(cast_scalar)) @@ -187,7 +187,10 @@ mod tests { assert_eq!(expression.data_type(&schema)?, $TYPE); // compute - let result = expression.evaluate(&batch)?.into_array(batch.num_rows()); + let result = expression + .evaluate(&batch)? + .into_array(batch.num_rows()) + .expect("Failed to convert to array"); // verify that the array's data_type is correct assert_eq!(*result.data_type(), $TYPE); @@ -235,7 +238,10 @@ mod tests { assert_eq!(expression.data_type(&schema)?, $TYPE); // compute - let result = expression.evaluate(&batch)?.into_array(batch.num_rows()); + let result = expression + .evaluate(&batch)? + .into_array(batch.num_rows()) + .expect("Failed to convert to array"); // verify that the array's data_type is correct assert_eq!(*result.data_type(), $TYPE); @@ -549,7 +555,11 @@ mod tests { // Ensure a useful error happens at plan time if invalid casts are used let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); - let result = try_cast(col("a", &schema).unwrap(), &schema, DataType::LargeBinary); + let result = try_cast( + col("a", &schema).unwrap(), + &schema, + DataType::Interval(IntervalUnit::MonthDayNano), + ); result.expect_err("expected Invalid TRY_CAST"); } diff --git a/datafusion/physical-expr/src/functions.rs b/datafusion/physical-expr/src/functions.rs index b66bac41014da..53de858439190 100644 --- a/datafusion/physical-expr/src/functions.rs +++ b/datafusion/physical-expr/src/functions.rs @@ -34,20 +34,19 @@ use crate::execution_props::ExecutionProps; use crate::sort_properties::SortProperties; use crate::{ array_expressions, conditional_expressions, datetime_expressions, - expressions::{cast_column, nullif_func}, - math_expressions, string_expressions, struct_expressions, PhysicalExpr, - ScalarFunctionExpr, + expressions::nullif_func, math_expressions, string_expressions, struct_expressions, + PhysicalExpr, ScalarFunctionExpr, }; use arrow::{ array::ArrayRef, compute::kernels::length::{bit_length, length}, - datatypes::TimeUnit, datatypes::{DataType, Int32Type, Int64Type, Schema}, }; use datafusion_common::{internal_err, DataFusionError, Result, ScalarValue}; pub use datafusion_expr::FuncMonotonicity; use datafusion_expr::{ - BuiltinScalarFunction, ColumnarValue, ScalarFunctionImplementation, + type_coercion::functions::data_types, BuiltinScalarFunction, ColumnarValue, + ScalarFunctionImplementation, }; use std::ops::Neg; use std::sync::Arc; @@ -65,145 +64,13 @@ pub fn create_physical_expr( .map(|e| e.data_type(input_schema)) .collect::>>()?; + // verify that input data types is consistent with function's `TypeSignature` + data_types(&input_expr_types, &fun.signature())?; + let data_type = fun.return_type(&input_expr_types)?; - let fun_expr: ScalarFunctionImplementation = match fun { - // These functions need args and input schema 
to pick an implementation - // Unlike the string functions, which actually figure out the function to use with each array, - // here we return either a cast fn or string timestamp translation based on the expression data type - // so we don't have to pay a per-array/batch cost. - BuiltinScalarFunction::ToTimestamp => { - Arc::new(match input_phy_exprs[0].data_type(input_schema) { - Ok(DataType::Int64) => |col_values: &[ColumnarValue]| { - cast_column( - &col_values[0], - &DataType::Timestamp(TimeUnit::Second, None), - None, - ) - }, - Ok(DataType::Timestamp(_, None)) => |col_values: &[ColumnarValue]| { - cast_column( - &col_values[0], - &DataType::Timestamp(TimeUnit::Nanosecond, None), - None, - ) - }, - Ok(DataType::Utf8) => datetime_expressions::to_timestamp, - other => { - return internal_err!( - "Unsupported data type {other:?} for function to_timestamp" - ); - } - }) - } - BuiltinScalarFunction::ToTimestampMillis => { - Arc::new(match input_phy_exprs[0].data_type(input_schema) { - Ok(DataType::Int64) | Ok(DataType::Timestamp(_, None)) => { - |col_values: &[ColumnarValue]| { - cast_column( - &col_values[0], - &DataType::Timestamp(TimeUnit::Millisecond, None), - None, - ) - } - } - Ok(DataType::Utf8) => datetime_expressions::to_timestamp_millis, - other => { - return internal_err!( - "Unsupported data type {other:?} for function to_timestamp_millis" - ); - } - }) - } - BuiltinScalarFunction::ToTimestampMicros => { - Arc::new(match input_phy_exprs[0].data_type(input_schema) { - Ok(DataType::Int64) | Ok(DataType::Timestamp(_, None)) => { - |col_values: &[ColumnarValue]| { - cast_column( - &col_values[0], - &DataType::Timestamp(TimeUnit::Microsecond, None), - None, - ) - } - } - Ok(DataType::Utf8) => datetime_expressions::to_timestamp_micros, - other => { - return internal_err!( - "Unsupported data type {other:?} for function to_timestamp_micros" - ); - } - }) - } - BuiltinScalarFunction::ToTimestampNanos => { - Arc::new(match input_phy_exprs[0].data_type(input_schema) { - Ok(DataType::Int64) | Ok(DataType::Timestamp(_, None)) => { - |col_values: &[ColumnarValue]| { - cast_column( - &col_values[0], - &DataType::Timestamp(TimeUnit::Nanosecond, None), - None, - ) - } - } - Ok(DataType::Utf8) => datetime_expressions::to_timestamp_nanos, - other => { - return internal_err!( - "Unsupported data type {other:?} for function to_timestamp_nanos" - ); - } - }) - } - BuiltinScalarFunction::ToTimestampSeconds => Arc::new({ - match input_phy_exprs[0].data_type(input_schema) { - Ok(DataType::Int64) | Ok(DataType::Timestamp(_, None)) => { - |col_values: &[ColumnarValue]| { - cast_column( - &col_values[0], - &DataType::Timestamp(TimeUnit::Second, None), - None, - ) - } - } - Ok(DataType::Utf8) => datetime_expressions::to_timestamp_seconds, - other => { - return internal_err!( - "Unsupported data type {other:?} for function to_timestamp_seconds" - ); - } - } - }), - BuiltinScalarFunction::FromUnixtime => Arc::new({ - match input_phy_exprs[0].data_type(input_schema) { - Ok(DataType::Int64) => |col_values: &[ColumnarValue]| { - cast_column( - &col_values[0], - &DataType::Timestamp(TimeUnit::Second, None), - None, - ) - }, - other => { - return internal_err!( - "Unsupported data type {other:?} for function from_unixtime" - ); - } - } - }), - BuiltinScalarFunction::ArrowTypeof => { - let input_data_type = input_phy_exprs[0].data_type(input_schema)?; - Arc::new(move |_| { - Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some(format!( - "{input_data_type}" - ))))) - }) - } - BuiltinScalarFunction::Abs => { - let 
input_data_type = input_phy_exprs[0].data_type(input_schema)?; - let abs_fun = math_expressions::create_abs_function(&input_data_type)?; - Arc::new(move |args| make_scalar_function(abs_fun)(args)) - } - // These don't need args and input schema - _ => create_physical_fun(fun, execution_props)?, - }; + let fun_expr: ScalarFunctionImplementation = + create_physical_fun(fun, execution_props)?; let monotonicity = fun.monotonicity(); @@ -211,7 +78,7 @@ pub fn create_physical_expr( &format!("{fun}"), fun_expr, input_phy_exprs.to_vec(), - &data_type, + data_type, monotonicity, ))) } @@ -372,7 +239,7 @@ where }; arg.clone().into_array(expansion_len) }) - .collect::>(); + .collect::>>()?; let result = (inner)(&args); @@ -393,6 +260,9 @@ pub fn create_physical_fun( ) -> Result { Ok(match fun { // math functions + BuiltinScalarFunction::Abs => { + Arc::new(|args| make_scalar_function(math_expressions::abs_invoke)(args)) + } BuiltinScalarFunction::Acos => Arc::new(math_expressions::acos), BuiltinScalarFunction::Asin => Arc::new(math_expressions::asin), BuiltinScalarFunction::Atan => Arc::new(math_expressions::atan), @@ -459,6 +329,9 @@ pub fn create_physical_fun( BuiltinScalarFunction::ArrayAppend => { Arc::new(|args| make_scalar_function(array_expressions::array_append)(args)) } + BuiltinScalarFunction::ArraySort => { + Arc::new(|args| make_scalar_function(array_expressions::array_sort)(args)) + } BuiltinScalarFunction::ArrayConcat => { Arc::new(|args| make_scalar_function(array_expressions::array_concat)(args)) } @@ -477,9 +350,15 @@ pub fn create_physical_fun( BuiltinScalarFunction::ArrayDims => { Arc::new(|args| make_scalar_function(array_expressions::array_dims)(args)) } + BuiltinScalarFunction::ArrayDistinct => { + Arc::new(|args| make_scalar_function(array_expressions::array_distinct)(args)) + } BuiltinScalarFunction::ArrayElement => { Arc::new(|args| make_scalar_function(array_expressions::array_element)(args)) } + BuiltinScalarFunction::ArrayExcept => { + Arc::new(|args| make_scalar_function(array_expressions::array_except)(args)) + } BuiltinScalarFunction::ArrayLength => { Arc::new(|args| make_scalar_function(array_expressions::array_length)(args)) } @@ -489,6 +368,9 @@ pub fn create_physical_fun( BuiltinScalarFunction::ArrayNdims => { Arc::new(|args| make_scalar_function(array_expressions::array_ndims)(args)) } + BuiltinScalarFunction::ArrayPopFront => Arc::new(|args| { + make_scalar_function(array_expressions::array_pop_front)(args) + }), BuiltinScalarFunction::ArrayPopBack => { Arc::new(|args| make_scalar_function(array_expressions::array_pop_back)(args)) } @@ -528,13 +410,21 @@ pub fn create_physical_fun( BuiltinScalarFunction::ArrayToString => Arc::new(|args| { make_scalar_function(array_expressions::array_to_string)(args) }), + BuiltinScalarFunction::ArrayIntersect => Arc::new(|args| { + make_scalar_function(array_expressions::array_intersect)(args) + }), + BuiltinScalarFunction::Range => { + Arc::new(|args| make_scalar_function(array_expressions::gen_range)(args)) + } BuiltinScalarFunction::Cardinality => { Arc::new(|args| make_scalar_function(array_expressions::cardinality)(args)) } BuiltinScalarFunction::MakeArray => { Arc::new(|args| make_scalar_function(array_expressions::make_array)(args)) } - + BuiltinScalarFunction::ArrayUnion => { + Arc::new(|args| make_scalar_function(array_expressions::array_union)(args)) + } // struct functions BuiltinScalarFunction::Struct => Arc::new(struct_expressions::struct_expr), @@ -621,6 +511,24 @@ pub fn create_physical_fun( 
execution_props.query_execution_start_time, )) } + BuiltinScalarFunction::ToTimestamp => { + Arc::new(datetime_expressions::to_timestamp_invoke) + } + BuiltinScalarFunction::ToTimestampMillis => { + Arc::new(datetime_expressions::to_timestamp_millis_invoke) + } + BuiltinScalarFunction::ToTimestampMicros => { + Arc::new(datetime_expressions::to_timestamp_micros_invoke) + } + BuiltinScalarFunction::ToTimestampNanos => { + Arc::new(datetime_expressions::to_timestamp_nanos_invoke) + } + BuiltinScalarFunction::ToTimestampSeconds => { + Arc::new(datetime_expressions::to_timestamp_seconds_invoke) + } + BuiltinScalarFunction::FromUnixtime => { + Arc::new(datetime_expressions::from_unixtime_invoke) + } BuiltinScalarFunction::InitCap => Arc::new(|args| match args[0].data_type() { DataType::Utf8 => { make_scalar_function(string_expressions::initcap::)(args) @@ -923,11 +831,87 @@ pub fn create_physical_fun( }), BuiltinScalarFunction::Upper => Arc::new(string_expressions::upper), BuiltinScalarFunction::Uuid => Arc::new(string_expressions::uuid), - _ => { - return internal_err!( - "create_physical_fun: Unsupported scalar function {fun:?}" - ); + BuiltinScalarFunction::ArrowTypeof => Arc::new(move |args| { + if args.len() != 1 { + return internal_err!( + "arrow_typeof function requires 1 arguments, got {}", + args.len() + ); + } + + let input_data_type = args[0].data_type(); + Ok(ColumnarValue::Scalar(ScalarValue::from(format!( + "{input_data_type}" + )))) + }), + BuiltinScalarFunction::OverLay => Arc::new(|args| match args[0].data_type() { + DataType::Utf8 => { + make_scalar_function(string_expressions::overlay::)(args) + } + DataType::LargeUtf8 => { + make_scalar_function(string_expressions::overlay::)(args) + } + other => Err(DataFusionError::Internal(format!( + "Unsupported data type {other:?} for function overlay", + ))), + }), + BuiltinScalarFunction::Levenshtein => { + Arc::new(|args| match args[0].data_type() { + DataType::Utf8 => { + make_scalar_function(string_expressions::levenshtein::)(args) + } + DataType::LargeUtf8 => { + make_scalar_function(string_expressions::levenshtein::)(args) + } + other => Err(DataFusionError::Internal(format!( + "Unsupported data type {other:?} for function levenshtein", + ))), + }) } + BuiltinScalarFunction::SubstrIndex => { + Arc::new(|args| match args[0].data_type() { + DataType::Utf8 => { + let func = invoke_if_unicode_expressions_feature_flag!( + substr_index, + i32, + "substr_index" + ); + make_scalar_function(func)(args) + } + DataType::LargeUtf8 => { + let func = invoke_if_unicode_expressions_feature_flag!( + substr_index, + i64, + "substr_index" + ); + make_scalar_function(func)(args) + } + other => Err(DataFusionError::Internal(format!( + "Unsupported data type {other:?} for function substr_index", + ))), + }) + } + BuiltinScalarFunction::FindInSet => Arc::new(|args| match args[0].data_type() { + DataType::Utf8 => { + let func = invoke_if_unicode_expressions_feature_flag!( + find_in_set, + Int32Type, + "find_in_set" + ); + make_scalar_function(func)(args) + } + DataType::LargeUtf8 => { + let func = invoke_if_unicode_expressions_feature_flag!( + find_in_set, + Int64Type, + "find_in_set" + ); + make_scalar_function(func)(args) + } + other => Err(DataFusionError::Internal(format!( + "Unsupported data type {other:?} for function find_in_set", + ))), + }), }) } @@ -1038,7 +1022,7 @@ mod tests { match expected { Ok(expected) => { let result = expr.evaluate(&batch)?; - let result = result.into_array(batch.num_rows()); + let result = 
result.into_array(batch.num_rows()).expect("Failed to convert to array"); let result = result.as_any().downcast_ref::<$ARRAY_TYPE>().unwrap(); // value is correct @@ -2952,13 +2936,8 @@ mod tests { "Builtin scalar function {fun} does not support empty arguments" ); } - Err(DataFusionError::Plan(err)) => { - if !err - .contains("No function matches the given name and argument types") - { - return plan_err!( - "Builtin scalar function {fun} didn't got the right error message with empty arguments"); - } + Err(DataFusionError::Plan(_)) => { + // Continue the loop } Err(..) => { return internal_err!( @@ -3012,7 +2991,10 @@ mod tests { // evaluate works let batch = RecordBatch::try_new(Arc::new(schema.clone()), columns)?; - let result = expr.evaluate(&batch)?.into_array(batch.num_rows()); + let result = expr + .evaluate(&batch)? + .into_array(batch.num_rows()) + .expect("Failed to convert to array"); // downcast works let result = as_list_array(&result)?; @@ -3051,7 +3033,10 @@ mod tests { // evaluate works let batch = RecordBatch::try_new(Arc::new(schema.clone()), columns)?; - let result = expr.evaluate(&batch)?.into_array(batch.num_rows()); + let result = expr + .evaluate(&batch)? + .into_array(batch.num_rows()) + .expect("Failed to convert to array"); // downcast works let result = as_list_array(&result)?; @@ -3123,8 +3108,11 @@ mod tests { let adapter_func = make_scalar_function(dummy_function); let scalar_arg = ColumnarValue::Scalar(ScalarValue::Int64(Some(1))); - let array_arg = - ColumnarValue::Array(ScalarValue::Int64(Some(1)).to_array_of_size(5)); + let array_arg = ColumnarValue::Array( + ScalarValue::Int64(Some(1)) + .to_array_of_size(5) + .expect("Failed to convert to array of size"), + ); let result = unpack_uint64_array(adapter_func(&[array_arg, scalar_arg]))?; assert_eq!(result, vec![5, 5]); @@ -3136,8 +3124,11 @@ mod tests { let adapter_func = make_scalar_function_with_hints(dummy_function, vec![]); let scalar_arg = ColumnarValue::Scalar(ScalarValue::Int64(Some(1))); - let array_arg = - ColumnarValue::Array(ScalarValue::Int64(Some(1)).to_array_of_size(5)); + let array_arg = ColumnarValue::Array( + ScalarValue::Int64(Some(1)) + .to_array_of_size(5) + .expect("Failed to convert to array of size"), + ); let result = unpack_uint64_array(adapter_func(&[array_arg, scalar_arg]))?; assert_eq!(result, vec![5, 5]); @@ -3152,8 +3143,11 @@ mod tests { ); let scalar_arg = ColumnarValue::Scalar(ScalarValue::Int64(Some(1))); - let array_arg = - ColumnarValue::Array(ScalarValue::Int64(Some(1)).to_array_of_size(5)); + let array_arg = ColumnarValue::Array( + ScalarValue::Int64(Some(1)) + .to_array_of_size(5) + .expect("Failed to convert to array of size"), + ); let result = unpack_uint64_array(adapter_func(&[array_arg, scalar_arg]))?; assert_eq!(result, vec![5, 1]); @@ -3162,8 +3156,11 @@ mod tests { #[test] fn test_make_scalar_function_with_hints_on_arrays() -> Result<()> { - let array_arg = - ColumnarValue::Array(ScalarValue::Int64(Some(1)).to_array_of_size(5)); + let array_arg = ColumnarValue::Array( + ScalarValue::Int64(Some(1)) + .to_array_of_size(5) + .expect("Failed to convert to array of size"), + ); let adapter_func = make_scalar_function_with_hints( dummy_function, vec![Hint::Pad, Hint::AcceptsSingular], @@ -3183,8 +3180,11 @@ mod tests { ); let scalar_arg = ColumnarValue::Scalar(ScalarValue::Int64(Some(1))); - let array_arg = - ColumnarValue::Array(ScalarValue::Int64(Some(1)).to_array_of_size(5)); + let array_arg = ColumnarValue::Array( + ScalarValue::Int64(Some(1)) + 
.to_array_of_size(5) + .expect("Failed to convert to array of size"), + ); let result = unpack_uint64_array(adapter_func(&[ array_arg, scalar_arg.clone(), @@ -3203,8 +3203,11 @@ mod tests { ); let scalar_arg = ColumnarValue::Scalar(ScalarValue::Int64(Some(1))); - let array_arg = - ColumnarValue::Array(ScalarValue::Int64(Some(1)).to_array_of_size(5)); + let array_arg = ColumnarValue::Array( + ScalarValue::Int64(Some(1)) + .to_array_of_size(5) + .expect("Failed to convert to array of size"), + ); let result = unpack_uint64_array(adapter_func(&[ array_arg.clone(), scalar_arg.clone(), @@ -3231,8 +3234,11 @@ mod tests { ); let scalar_arg = ColumnarValue::Scalar(ScalarValue::Int64(Some(1))); - let array_arg = - ColumnarValue::Array(ScalarValue::Int64(Some(1)).to_array_of_size(5)); + let array_arg = ColumnarValue::Array( + ScalarValue::Int64(Some(1)) + .to_array_of_size(5) + .expect("Failed to convert to array of size"), + ); let result = unpack_uint64_array(adapter_func(&[array_arg, scalar_arg]))?; assert_eq!(result, vec![5, 1]); diff --git a/datafusion/physical-expr/src/intervals/cp_solver.rs b/datafusion/physical-expr/src/intervals/cp_solver.rs index e7515341c52cf..5064ad8d5c487 100644 --- a/datafusion/physical-expr/src/intervals/cp_solver.rs +++ b/datafusion/physical-expr/src/intervals/cp_solver.rs @@ -24,15 +24,13 @@ use std::sync::Arc; use super::utils::{ convert_duration_type_to_interval, convert_interval_type_to_duration, get_inverse_op, }; -use super::IntervalBound; use crate::expressions::Literal; -use crate::intervals::interval_aritmetic::{apply_operator, Interval}; use crate::utils::{build_dag, ExprTreeNode}; use crate::PhysicalExpr; -use arrow_schema::DataType; -use datafusion_common::{DataFusionError, Result, ScalarValue}; -use datafusion_expr::type_coercion::binary::get_result_type; +use arrow_schema::{DataType, Schema}; +use datafusion_common::{internal_err, DataFusionError, Result}; +use datafusion_expr::interval_arithmetic::{apply_operator, satisfy_greater, Interval}; use datafusion_expr::Operator; use petgraph::graph::NodeIndex; @@ -148,7 +146,7 @@ pub enum PropagationResult { } /// This is a node in the DAEG; it encapsulates a reference to the actual -/// [PhysicalExpr] as well as an interval containing expression bounds. +/// [`PhysicalExpr`] as well as an interval containing expression bounds. #[derive(Clone, Debug)] pub struct ExprIntervalGraphNode { expr: Arc, @@ -163,11 +161,9 @@ impl Display for ExprIntervalGraphNode { impl ExprIntervalGraphNode { /// Constructs a new DAEG node with an [-∞, ∞] range. - pub fn new(expr: Arc) -> Self { - ExprIntervalGraphNode { - expr, - interval: Interval::default(), - } + pub fn new_unbounded(expr: Arc, dt: &DataType) -> Result { + Interval::make_unbounded(dt) + .map(|interval| ExprIntervalGraphNode { expr, interval }) } /// Constructs a new DAEG node with the given range. @@ -180,26 +176,24 @@ impl ExprIntervalGraphNode { &self.interval } - /// This function creates a DAEG node from Datafusion's [ExprTreeNode] + /// This function creates a DAEG node from Datafusion's [`ExprTreeNode`] /// object. Literals are created with definite, singleton intervals while /// any other expression starts with an indefinite interval ([-∞, ∞]). 
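
The doc comment above describes the rule for seeding the DAEG: a literal gets a definite, singleton interval, while every other expression starts at [-∞, ∞]. A minimal standalone sketch of that rule follows; the `SketchInterval`/`SketchExpr` names are illustrative stand-ins and not the DataFusion `Interval` API, with `None` used here to model an unbounded endpoint.

```rust
#[derive(Debug, Clone, PartialEq)]
struct SketchInterval {
    lower: Option<f64>, // `None` stands in for -inf
    upper: Option<f64>, // `None` stands in for +inf
}

enum SketchExpr {
    Literal(f64),
    Column(&'static str),
}

fn initial_interval(expr: &SketchExpr) -> SketchInterval {
    match expr {
        // A literal's value is known exactly, so both bounds collapse onto it.
        SketchExpr::Literal(v) => SketchInterval { lower: Some(*v), upper: Some(*v) },
        // Any other node is unknown until propagation runs, so it starts unbounded.
        SketchExpr::Column(_) => SketchInterval { lower: None, upper: None },
    }
}

fn main() {
    let lit = SketchExpr::Literal(10.0);
    let col = SketchExpr::Column("gnz");
    assert_eq!(
        initial_interval(&lit),
        SketchInterval { lower: Some(10.0), upper: Some(10.0) }
    );
    assert_eq!(initial_interval(&col), SketchInterval { lower: None, upper: None });
}
```
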
- pub fn make_node(node: &ExprTreeNode) -> ExprIntervalGraphNode { + pub fn make_node(node: &ExprTreeNode, schema: &Schema) -> Result { let expr = node.expression().clone(); if let Some(literal) = expr.as_any().downcast_ref::() { let value = literal.value(); - let interval = Interval::new( - IntervalBound::new_closed(value.clone()), - IntervalBound::new_closed(value.clone()), - ); - ExprIntervalGraphNode::new_with_interval(expr, interval) + Interval::try_new(value.clone(), value.clone()) + .map(|interval| Self::new_with_interval(expr, interval)) } else { - ExprIntervalGraphNode::new(expr) + expr.data_type(schema) + .and_then(|dt| Self::new_unbounded(expr, &dt)) } } } impl PartialEq for ExprIntervalGraphNode { - fn eq(&self, other: &ExprIntervalGraphNode) -> bool { + fn eq(&self, other: &Self) -> bool { self.expr.eq(&other.expr) } } @@ -216,16 +210,23 @@ impl PartialEq for ExprIntervalGraphNode { /// - For minus operation, specifically, we would first do /// - [xL, xU] <- ([yL, yU] + [pL, pU]) ∩ [xL, xU], and then /// - [yL, yU] <- ([xL, xU] - [pL, pU]) ∩ [yL, yU]. +/// - For multiplication operation, specifically, we would first do +/// - [xL, xU] <- ([pL, pU] / [yL, yU]) ∩ [xL, xU], and then +/// - [yL, yU] <- ([pL, pU] / [xL, xU]) ∩ [yL, yU]. +/// - For division operation, specifically, we would first do +/// - [xL, xU] <- ([yL, yU] * [pL, pU]) ∩ [xL, xU], and then +/// - [yL, yU] <- ([xL, xU] / [pL, pU]) ∩ [yL, yU]. pub fn propagate_arithmetic( op: &Operator, parent: &Interval, left_child: &Interval, right_child: &Interval, -) -> Result<(Option, Option)> { - let inverse_op = get_inverse_op(*op); - match (left_child.get_datatype()?, right_child.get_datatype()?) { - // If we have a child whose type is a time interval (i.e. DataType::Interval), we need special handling - // since timestamp differencing results in a Duration type. +) -> Result> { + let inverse_op = get_inverse_op(*op)?; + match (left_child.data_type(), right_child.data_type()) { + // If we have a child whose type is a time interval (i.e. DataType::Interval), + // we need special handling since timestamp differencing results in a + // Duration type. (DataType::Timestamp(..), DataType::Interval(_)) => { propagate_time_interval_at_right( left_child, @@ -250,87 +251,109 @@ pub fn propagate_arithmetic( .intersect(left_child)? { // Left is feasible: - Some(value) => { + Some(value) => Ok( // Propagate to the right using the new left. - let right = - propagate_right(&value, parent, right_child, op, &inverse_op)?; - - // Return intervals for both children: - Ok((Some(value), right)) - } + propagate_right(&value, parent, right_child, op, &inverse_op)? + .map(|right| (value, right)), + ), // If the left child is infeasible, short-circuit. - None => Ok((None, None)), + None => Ok(None), } } } } -/// This function provides a target parent interval for comparison operators. -/// If we have expression > 0, expression must have the range (0, ∞). -/// If we have expression >= 0, expression must have the range [0, ∞). -/// If we have expression < 0, expression must have the range (-∞, 0). -/// If we have expression <= 0, expression must have the range (-∞, 0]. 
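
To make the update rules listed in the `propagate_arithmetic` doc comment concrete, here is a self-contained numeric sketch of the `x + y = p` case over closed, finite intervals. The `Iv`, `sub`, `intersect`, and `propagate_plus` names are illustrative only; the real implementation works on DataFusion's `Interval` type and also handles open and unbounded endpoints as well as the minus, multiply, and divide variants described above.

```rust
#[derive(Debug, Clone, Copy, PartialEq)]
struct Iv {
    lo: f64,
    hi: f64,
}

// Interval subtraction: [a, b] - [c, d] = [a - d, b - c].
fn sub(a: Iv, b: Iv) -> Iv {
    Iv { lo: a.lo - b.hi, hi: a.hi - b.lo }
}

// Intersection; `None` means the constraint is infeasible.
fn intersect(a: Iv, b: Iv) -> Option<Iv> {
    let lo = a.lo.max(b.lo);
    let hi = a.hi.min(b.hi);
    (lo <= hi).then_some(Iv { lo, hi })
}

// Shrink `x` and `y` given that `x + y` must lie in `parent`.
fn propagate_plus(parent: Iv, x: Iv, y: Iv) -> Option<(Iv, Iv)> {
    let new_x = intersect(sub(parent, y), x)?; // x <- (p - y) ∩ x
    let new_y = intersect(sub(parent, new_x), y)?; // y <- (p - x) ∩ y, using the refined x
    Some((new_x, new_y))
}

fn main() {
    let parent = Iv { lo: 5.0, hi: 7.0 };
    let x = Iv { lo: 0.0, hi: 10.0 };
    let y = Iv { lo: 2.0, hi: 3.0 };
    // x shrinks from [0, 10] to [2, 5]; y stays at [2, 3].
    assert_eq!(
        propagate_plus(parent, x, y),
        Some((Iv { lo: 2.0, hi: 5.0 }, Iv { lo: 2.0, hi: 3.0 }))
    );
}
```

Note that the left child is refined first and the refined value is then reused when propagating to the right child, mirroring the two-step order stated in the doc comment.
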
-fn comparison_operator_target( - left_datatype: &DataType, - op: &Operator, - right_datatype: &DataType, -) -> Result { - let datatype = get_result_type(left_datatype, &Operator::Minus, right_datatype)?; - let unbounded = IntervalBound::make_unbounded(&datatype)?; - let zero = ScalarValue::new_zero(&datatype)?; - Ok(match *op { - Operator::GtEq => Interval::new(IntervalBound::new_closed(zero), unbounded), - Operator::Gt => Interval::new(IntervalBound::new_open(zero), unbounded), - Operator::LtEq => Interval::new(unbounded, IntervalBound::new_closed(zero)), - Operator::Lt => Interval::new(unbounded, IntervalBound::new_open(zero)), - Operator::Eq => Interval::new( - IntervalBound::new_closed(zero.clone()), - IntervalBound::new_closed(zero), - ), - _ => unreachable!(), - }) -} - -/// This function propagates constraints arising from comparison operators. -/// The main idea is that we can analyze an inequality like x > y through the -/// equivalent inequality x - y > 0. Assuming that x and y has ranges [xL, xU] -/// and [yL, yU], we simply apply constraint propagation across [xL, xU], -/// [yL, yH] and [0, ∞]. Specifically, we would first do -/// - [xL, xU] <- ([yL, yU] + [0, ∞]) ∩ [xL, xU], and then -/// - [yL, yU] <- ([xL, xU] - [0, ∞]) ∩ [yL, yU]. +/// This function refines intervals `left_child` and `right_child` by applying +/// comparison propagation through `parent` via operation. The main idea is +/// that we can shrink ranges of variables x and y using parent interval p. +/// Two intervals can be ordered in 6 ways for a Gt `>` operator: +/// ```text +/// (1): Infeasible, short-circuit +/// left: | ================ | +/// right: | ======================== | +/// +/// (2): Update both interval +/// left: | ====================== | +/// right: | ====================== | +/// | +/// V +/// left: | ======= | +/// right: | ======= | +/// +/// (3): Update left interval +/// left: | ============================== | +/// right: | ========== | +/// | +/// V +/// left: | ===================== | +/// right: | ========== | +/// +/// (4): Update right interval +/// left: | ========== | +/// right: | =========================== | +/// | +/// V +/// left: | ========== | +/// right | ================== | +/// +/// (5): No change +/// left: | ============================ | +/// right: | =================== | +/// +/// (6): No change +/// left: | ==================== | +/// right: | =============== | +/// +/// -inf --------------------------------------------------------------- +inf +/// ``` pub fn propagate_comparison( op: &Operator, + parent: &Interval, left_child: &Interval, right_child: &Interval, -) -> Result<(Option, Option)> { - let left_type = left_child.get_datatype()?; - let right_type = right_child.get_datatype()?; - let parent = comparison_operator_target(&left_type, op, &right_type)?; - match (&left_type, &right_type) { - // We can not compare a Duration type with a time interval type - // without a reference timestamp unless the latter has a zero month field. 
- (DataType::Interval(_), DataType::Duration(_)) => { - propagate_comparison_to_time_interval_at_left( - left_child, - &parent, - right_child, - ) +) -> Result> { + if parent == &Interval::CERTAINLY_TRUE { + match op { + Operator::Eq => left_child.intersect(right_child).map(|result| { + result.map(|intersection| (intersection.clone(), intersection)) + }), + Operator::Gt => satisfy_greater(left_child, right_child, true), + Operator::GtEq => satisfy_greater(left_child, right_child, false), + Operator::Lt => satisfy_greater(right_child, left_child, true) + .map(|t| t.map(reverse_tuple)), + Operator::LtEq => satisfy_greater(right_child, left_child, false) + .map(|t| t.map(reverse_tuple)), + _ => internal_err!( + "The operator must be a comparison operator to propagate intervals" + ), } - (DataType::Duration(_), DataType::Interval(_)) => { - propagate_comparison_to_time_interval_at_left( - left_child, - &parent, - right_child, - ) + } else if parent == &Interval::CERTAINLY_FALSE { + match op { + Operator::Eq => { + // TODO: Propagation is not possible until we support interval sets. + Ok(None) + } + Operator::Gt => satisfy_greater(right_child, left_child, false), + Operator::GtEq => satisfy_greater(right_child, left_child, true), + Operator::Lt => satisfy_greater(left_child, right_child, false) + .map(|t| t.map(reverse_tuple)), + Operator::LtEq => satisfy_greater(left_child, right_child, true) + .map(|t| t.map(reverse_tuple)), + _ => internal_err!( + "The operator must be a comparison operator to propagate intervals" + ), } - _ => propagate_arithmetic(&Operator::Minus, &parent, left_child, right_child), + } else { + // Uncertainty cannot change any end-point of the intervals. + Ok(None) } } impl ExprIntervalGraph { - pub fn try_new(expr: Arc) -> Result { + pub fn try_new(expr: Arc, schema: &Schema) -> Result { // Build the full graph: - let (root, graph) = build_dag(expr, &ExprIntervalGraphNode::make_node)?; + let (root, graph) = + build_dag(expr, &|node| ExprIntervalGraphNode::make_node(node, schema))?; Ok(Self { graph, root }) } @@ -383,7 +406,7 @@ impl ExprIntervalGraph { // // ``` - /// This function associates stable node indices with [PhysicalExpr]s so + /// This function associates stable node indices with [`PhysicalExpr`]s so /// that we can match `Arc` and NodeIndex objects during /// membership tests. pub fn gather_node_indices( @@ -437,6 +460,33 @@ impl ExprIntervalGraph { nodes } + /// Updates intervals for all expressions in the DAEG by successive + /// bottom-up and top-down traversals. + pub fn update_ranges( + &mut self, + leaf_bounds: &mut [(usize, Interval)], + given_range: Interval, + ) -> Result { + self.assign_intervals(leaf_bounds); + let bounds = self.evaluate_bounds()?; + // There are three possible cases to consider: + // (1) given_range ⊇ bounds => Nothing to propagate + // (2) ∅ ⊂ (given_range ∩ bounds) ⊂ bounds => Can propagate + // (3) Disjoint sets => Infeasible + if given_range.contains(bounds)? == Interval::CERTAINLY_TRUE { + // First case: + Ok(PropagationResult::CannotPropagate) + } else if bounds.contains(&given_range)? != Interval::CERTAINLY_FALSE { + // Second case: + let result = self.propagate_constraints(given_range); + self.update_intervals(leaf_bounds); + result + } else { + // Third case: + Ok(PropagationResult::Infeasible) + } + } + /// This function assigns given ranges to expressions in the DAEG. /// The argument `assignments` associates indices of sought expressions /// with their corresponding new ranges. 
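
The `Gt` branch of the new `propagate_comparison` above (and the six-case picture in its doc comment) amounts to raising the left operand's lower bound and lowering the right operand's upper bound whenever `left > right` must hold, short-circuiting when that is impossible. Below is a standalone numeric sketch under the assumption of closed, finite real intervals; it deliberately ignores the open/closed endpoint bookkeeping that the real `satisfy_greater` performs, and the `Iv` type is illustrative, not the DataFusion API.

```rust
#[derive(Debug, Clone, Copy, PartialEq)]
struct Iv {
    lo: f64,
    hi: f64,
}

/// Shrink `left` and `right` so that `left > right` can still hold.
/// Returns `None` when the constraint is infeasible (case (1) in the diagram).
fn satisfy_strict_greater(left: Iv, right: Iv) -> Option<(Iv, Iv)> {
    if left.hi <= right.lo {
        // `left` can never exceed `right`: infeasible, short-circuit.
        return None;
    }
    // Cases (2)-(4): raise left's lower bound, lower right's upper bound.
    // Cases (5)-(6) fall out naturally because max/min leave the bounds unchanged.
    let new_left = Iv { lo: left.lo.max(right.lo), hi: left.hi };
    let new_right = Iv { lo: right.lo, hi: right.hi.min(left.hi) };
    Some((new_left, new_right))
}

fn main() {
    // Case (2): both sides shrink toward the overlap.
    let (l, r) =
        satisfy_strict_greater(Iv { lo: 0.0, hi: 10.0 }, Iv { lo: 5.0, hi: 20.0 }).unwrap();
    assert_eq!((l, r), (Iv { lo: 5.0, hi: 10.0 }, Iv { lo: 5.0, hi: 10.0 }));

    // Case (1): infeasible, nothing can be propagated.
    assert!(satisfy_strict_greater(Iv { lo: 0.0, hi: 3.0 }, Iv { lo: 4.0, hi: 9.0 }).is_none());
}
```

The `Lt`/`LtEq` arms above reuse the same helper with the operands swapped, which is why the results are reversed back into `(left, right)` order afterwards.
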
@@ -466,34 +516,43 @@ impl ExprIntervalGraph { /// # Examples /// /// ``` - /// use std::sync::Arc; - /// use datafusion_common::ScalarValue; - /// use datafusion_expr::Operator; - /// use datafusion_physical_expr::expressions::{BinaryExpr, Column, Literal}; - /// use datafusion_physical_expr::intervals::{Interval, IntervalBound, ExprIntervalGraph}; - /// use datafusion_physical_expr::PhysicalExpr; - /// let expr = Arc::new(BinaryExpr::new( - /// Arc::new(Column::new("gnz", 0)), - /// Operator::Plus, - /// Arc::new(Literal::new(ScalarValue::Int32(Some(10)))), - /// )); - /// let mut graph = ExprIntervalGraph::try_new(expr).unwrap(); - /// // Do it once, while constructing. - /// let node_indices = graph + /// use arrow::datatypes::DataType; + /// use arrow::datatypes::Field; + /// use arrow::datatypes::Schema; + /// use datafusion_common::ScalarValue; + /// use datafusion_expr::interval_arithmetic::Interval; + /// use datafusion_expr::Operator; + /// use datafusion_physical_expr::expressions::{BinaryExpr, Column, Literal}; + /// use datafusion_physical_expr::intervals::cp_solver::ExprIntervalGraph; + /// use datafusion_physical_expr::PhysicalExpr; + /// use std::sync::Arc; + /// + /// let expr = Arc::new(BinaryExpr::new( + /// Arc::new(Column::new("gnz", 0)), + /// Operator::Plus, + /// Arc::new(Literal::new(ScalarValue::Int32(Some(10)))), + /// )); + /// + /// let schema = Schema::new(vec![Field::new("gnz".to_string(), DataType::Int32, true)]); + /// + /// let mut graph = ExprIntervalGraph::try_new(expr, &schema).unwrap(); + /// // Do it once, while constructing. + /// let node_indices = graph /// .gather_node_indices(&[Arc::new(Column::new("gnz", 0))]); - /// let left_index = node_indices.get(0).unwrap().1; - /// // Provide intervals for leaf variables (here, there is only one). - /// let intervals = vec![( + /// let left_index = node_indices.get(0).unwrap().1; + /// + /// // Provide intervals for leaf variables (here, there is only one). + /// let intervals = vec![( /// left_index, - /// Interval::make(Some(10), Some(20), (true, true)), - /// )]; - /// // Evaluate bounds for the composite expression: - /// graph.assign_intervals(&intervals); - /// assert_eq!( - /// graph.evaluate_bounds().unwrap(), - /// &Interval::make(Some(20), Some(30), (true, true)), - /// ) + /// Interval::make(Some(10), Some(20)).unwrap(), + /// )]; /// + /// // Evaluate bounds for the composite expression: + /// graph.assign_intervals(&intervals); + /// assert_eq!( + /// graph.evaluate_bounds().unwrap(), + /// &Interval::make(Some(20), Some(30)).unwrap(), + /// ) /// ``` pub fn evaluate_bounds(&mut self) -> Result<&Interval> { let mut dfs = DfsPostOrder::new(&self.graph, self.root); @@ -505,7 +564,7 @@ impl ExprIntervalGraph { // If the current expression is a leaf, its interval should already // be set externally, just continue with the evaluation procedure: if !children_intervals.is_empty() { - // Reverse to align with [PhysicalExpr]'s children: + // Reverse to align with `PhysicalExpr`'s children: children_intervals.reverse(); self.graph[node].interval = self.graph[node].expr.evaluate_bounds(&children_intervals)?; @@ -516,8 +575,19 @@ impl ExprIntervalGraph { /// Updates/shrinks bounds for leaf expressions using interval arithmetic /// via a top-down traversal. 
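
The three-way decision that the new `update_ranges` makes (given range contains the computed bounds, partially overlaps them, or is disjoint from them) can be sketched independently of the graph machinery. The `Iv` type and `classify` helper below are illustrative assumptions; the real code compares DataFusion `Interval`s via `contains` and `intersect` before running the top-down pass.

```rust
#[derive(Debug, PartialEq)]
enum Propagation {
    CannotPropagate, // computed bounds already lie inside the given range
    CanPropagate,    // partial overlap: a top-down pass can tighten the leaves
    Infeasible,      // disjoint: the constraint can never be satisfied
}

struct Iv {
    lo: f64,
    hi: f64,
}

fn classify(given: Iv, bounds: Iv) -> Propagation {
    let overlaps = given.lo <= bounds.hi && bounds.lo <= given.hi;
    if given.lo <= bounds.lo && bounds.hi <= given.hi {
        // (1) given ⊇ bounds: every reachable value already satisfies the range.
        Propagation::CannotPropagate
    } else if overlaps {
        // (2) non-empty strict overlap: propagation is worthwhile.
        Propagation::CanPropagate
    } else {
        // (3) disjoint sets.
        Propagation::Infeasible
    }
}

fn main() {
    let bounds = || Iv { lo: 20.0, hi: 30.0 };
    assert_eq!(classify(Iv { lo: 0.0, hi: 100.0 }, bounds()), Propagation::CannotPropagate);
    assert_eq!(classify(Iv { lo: 25.0, hi: 100.0 }, bounds()), Propagation::CanPropagate);
    assert_eq!(classify(Iv { lo: 50.0, hi: 100.0 }, bounds()), Propagation::Infeasible);
}
```
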
- fn propagate_constraints(&mut self) -> Result { + fn propagate_constraints( + &mut self, + given_range: Interval, + ) -> Result { let mut bfs = Bfs::new(&self.graph, self.root); + + // Adjust the root node with the given range: + if let Some(interval) = self.graph[self.root].interval.intersect(given_range)? { + self.graph[self.root].interval = interval; + } else { + return Ok(PropagationResult::Infeasible); + } + while let Some(node) = bfs.next(&self.graph) { let neighbors = self.graph.neighbors_directed(node, Outgoing); let mut children = neighbors.collect::>(); @@ -526,7 +596,7 @@ impl ExprIntervalGraph { if children.is_empty() { continue; } - // Reverse to align with [PhysicalExpr]'s children: + // Reverse to align with `PhysicalExpr`'s children: children.reverse(); let children_intervals = children .iter() @@ -536,164 +606,132 @@ impl ExprIntervalGraph { let propagated_intervals = self.graph[node] .expr .propagate_constraints(node_interval, &children_intervals)?; - for (child, interval) in children.into_iter().zip(propagated_intervals) { - if let Some(interval) = interval { + if let Some(propagated_intervals) = propagated_intervals { + for (child, interval) in children.into_iter().zip(propagated_intervals) { self.graph[child].interval = interval; - } else { - // The constraint is infeasible, report: - return Ok(PropagationResult::Infeasible); } + } else { + // The constraint is infeasible, report: + return Ok(PropagationResult::Infeasible); } } Ok(PropagationResult::Success) } - /// Updates intervals for all expressions in the DAEG by successive - /// bottom-up and top-down traversals. - pub fn update_ranges( - &mut self, - leaf_bounds: &mut [(usize, Interval)], - ) -> Result { - self.assign_intervals(leaf_bounds); - let bounds = self.evaluate_bounds()?; - if bounds == &Interval::CERTAINLY_FALSE { - Ok(PropagationResult::Infeasible) - } else if bounds == &Interval::UNCERTAIN { - let result = self.propagate_constraints(); - self.update_intervals(leaf_bounds); - result - } else { - Ok(PropagationResult::CannotPropagate) - } - } - /// Returns the interval associated with the node at the given `index`. pub fn get_interval(&self, index: usize) -> Interval { self.graph[NodeIndex::new(index)].interval.clone() } } -/// During the propagation of [`Interval`] values on an [`ExprIntervalGraph`], if there exists a `timestamp - timestamp` -/// operation, the result would be of type `Duration`. However, we may encounter a situation where a time interval -/// is involved in an arithmetic operation with a `Duration` type. This function offers special handling for such cases, -/// where the time interval resides on the left side of the operation. +/// This is a subfunction of the `propagate_arithmetic` function that propagates to the right child. +fn propagate_right( + left: &Interval, + parent: &Interval, + right: &Interval, + op: &Operator, + inverse_op: &Operator, +) -> Result> { + match op { + Operator::Minus => apply_operator(op, left, parent), + Operator::Plus => apply_operator(inverse_op, parent, left), + Operator::Divide => apply_operator(op, left, parent), + Operator::Multiply => apply_operator(inverse_op, parent, left), + _ => internal_err!("Interval arithmetic does not support the operator {}", op), + }? + .intersect(right) +} + +/// During the propagation of [`Interval`] values on an [`ExprIntervalGraph`], +/// if there exists a `timestamp - timestamp` operation, the result would be +/// of type `Duration`. 
However, we may encounter a situation where a time interval +/// is involved in an arithmetic operation with a `Duration` type. This function +/// offers special handling for such cases, where the time interval resides on +/// the left side of the operation. fn propagate_time_interval_at_left( left_child: &Interval, right_child: &Interval, parent: &Interval, op: &Operator, inverse_op: &Operator, -) -> Result<(Option, Option)> { +) -> Result> { // We check if the child's time interval(s) has a non-zero month or day field(s). // If so, we return it as is without propagating. Otherwise, we first convert - // the time intervals to the Duration type, then propagate, and then convert the bounds to time intervals again. - if let Some(duration) = convert_interval_type_to_duration(left_child) { + // the time intervals to the `Duration` type, then propagate, and then convert + // the bounds to time intervals again. + let result = if let Some(duration) = convert_interval_type_to_duration(left_child) { match apply_operator(inverse_op, parent, right_child)?.intersect(duration)? { Some(value) => { + let left = convert_duration_type_to_interval(&value); let right = propagate_right(&value, parent, right_child, op, inverse_op)?; - let new_interval = convert_duration_type_to_interval(&value); - Ok((new_interval, right)) + match (left, right) { + (Some(left), Some(right)) => Some((left, right)), + _ => None, + } } - None => Ok((None, None)), + None => None, } } else { - let right = propagate_right(left_child, parent, right_child, op, inverse_op)?; - Ok((Some(left_child.clone()), right)) - } + propagate_right(left_child, parent, right_child, op, inverse_op)? + .map(|right| (left_child.clone(), right)) + }; + Ok(result) } -/// During the propagation of [`Interval`] values on an [`ExprIntervalGraph`], if there exists a `timestamp - timestamp` -/// operation, the result would be of type `Duration`. However, we may encounter a situation where a time interval -/// is involved in an arithmetic operation with a `Duration` type. This function offers special handling for such cases, -/// where the time interval resides on the right side of the operation. +/// During the propagation of [`Interval`] values on an [`ExprIntervalGraph`], +/// if there exists a `timestamp - timestamp` operation, the result would be +/// of type `Duration`. However, we may encounter a situation where a time interval +/// is involved in an arithmetic operation with a `Duration` type. This function +/// offers special handling for such cases, where the time interval resides on +/// the right side of the operation. fn propagate_time_interval_at_right( left_child: &Interval, right_child: &Interval, parent: &Interval, op: &Operator, inverse_op: &Operator, -) -> Result<(Option, Option)> { +) -> Result> { // We check if the child's time interval(s) has a non-zero month or day field(s). // If so, we return it as is without propagating. Otherwise, we first convert - // the time intervals to the Duration type, then propagate, and then convert the bounds to time intervals again. - if let Some(duration) = convert_interval_type_to_duration(right_child) { + // the time intervals to the `Duration` type, then propagate, and then convert + // the bounds to time intervals again. + let result = if let Some(duration) = convert_interval_type_to_duration(right_child) { match apply_operator(inverse_op, parent, &duration)?.intersect(left_child)? 
{ Some(value) => { - let right = - propagate_right(left_child, parent, &duration, op, inverse_op)?; - let right = - right.and_then(|right| convert_duration_type_to_interval(&right)); - Ok((Some(value), right)) + propagate_right(left_child, parent, &duration, op, inverse_op)? + .and_then(|right| convert_duration_type_to_interval(&right)) + .map(|right| (value, right)) } - None => Ok((None, None)), + None => None, } } else { - match apply_operator(inverse_op, parent, right_child)?.intersect(left_child)? { - Some(value) => Ok((Some(value), Some(right_child.clone()))), - None => Ok((None, None)), - } - } + apply_operator(inverse_op, parent, right_child)? + .intersect(left_child)? + .map(|value| (value, right_child.clone())) + }; + Ok(result) } -/// This is a subfunction of the `propagate_arithmetic` function that propagates to the right child. -fn propagate_right( - left: &Interval, - parent: &Interval, - right: &Interval, - op: &Operator, - inverse_op: &Operator, -) -> Result> { - match op { - Operator::Minus => apply_operator(op, left, parent), - Operator::Plus => apply_operator(inverse_op, parent, left), - _ => unreachable!(), - }? - .intersect(right) -} - -/// Converts the `time interval` (as the left child) to duration, then performs the propagation rule for comparison operators. -pub fn propagate_comparison_to_time_interval_at_left( - left_child: &Interval, - parent: &Interval, - right_child: &Interval, -) -> Result<(Option, Option)> { - if let Some(converted) = convert_interval_type_to_duration(left_child) { - propagate_arithmetic(&Operator::Minus, parent, &converted, right_child) - } else { - Err(DataFusionError::Internal( - "Interval type has a non-zero month field, cannot compare with a Duration type".to_string(), - )) - } -} - -/// Converts the `time interval` (as the right child) to duration, then performs the propagation rule for comparison operators. 
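Both of the time-interval helpers above ultimately defer to `propagate_right`, which recovers the right child from the parent using either the operator or its inverse (`parent = left - right` gives `right = left - parent`; `parent = left + right` gives `right = parent - left`) and then intersects with the child's previous range. A minimal sketch of that rule on closed `i64` intervals, with made-up helper names (`Iv`, `sub`, `intersect`):

```rust
// Minimal sketch of the `propagate_right` rule on closed i64 intervals.
// `Iv`, `sub`, and `intersect` are toy helpers, not the crate's API; only
// Plus and Minus are shown.
#[derive(Debug, Clone, Copy, PartialEq)]
struct Iv(i64, i64);

fn sub(a: Iv, b: Iv) -> Iv {
    // Interval subtraction: [a.lo - b.hi, a.hi - b.lo]
    Iv(a.0 - b.1, a.1 - b.0)
}

fn intersect(a: Iv, b: Iv) -> Option<Iv> {
    let (lo, hi) = (a.0.max(b.0), a.1.min(b.1));
    (lo <= hi).then_some(Iv(lo, hi))
}

enum Op {
    Plus,
    Minus,
}

fn propagate_right(left: Iv, parent: Iv, right: Iv, op: Op) -> Option<Iv> {
    let candidate = match op {
        // parent = left - right  =>  right = left - parent
        Op::Minus => sub(left, parent),
        // parent = left + right  =>  right = parent - left
        Op::Plus => sub(parent, left),
    };
    intersect(candidate, right)
}

fn main() {
    // left ∈ [0, 10], right ∈ [-100, 100], and left + right ∈ [50, 60]:
    let shrunk = propagate_right(Iv(0, 10), Iv(50, 60), Iv(-100, 100), Op::Plus);
    assert_eq!(shrunk, Some(Iv(40, 60)));
}
```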
-pub fn propagate_comparison_to_time_interval_at_right( - left_child: &Interval, - parent: &Interval, - right_child: &Interval, -) -> Result<(Option, Option)> { - if let Some(converted) = convert_interval_type_to_duration(right_child) { - propagate_arithmetic(&Operator::Minus, parent, left_child, &converted) - } else { - Err(DataFusionError::Internal( - "Interval type has a non-zero month field, cannot compare with a Duration type".to_string(), - )) - } +fn reverse_tuple((first, second): (T, U)) -> (U, T) { + (second, first) } #[cfg(test)] mod tests { use super::*; - use itertools::Itertools; - use crate::expressions::{BinaryExpr, Column}; use crate::intervals::test_utils::gen_conjunctive_numerical_expr; + use arrow::datatypes::TimeUnit; + use arrow_schema::{DataType, Field}; use datafusion_common::ScalarValue; + + use itertools::Itertools; use rand::rngs::StdRng; use rand::{Rng, SeedableRng}; use rstest::*; + #[allow(clippy::too_many_arguments)] fn experiment( expr: Arc, exprs_with_interval: (Arc, Arc), @@ -702,6 +740,7 @@ mod tests { left_expected: Interval, right_expected: Interval, result: PropagationResult, + schema: &Schema, ) -> Result<()> { let col_stats = vec![ (exprs_with_interval.0.clone(), left_interval), @@ -711,7 +750,7 @@ mod tests { (exprs_with_interval.0.clone(), left_expected), (exprs_with_interval.1.clone(), right_expected), ]; - let mut graph = ExprIntervalGraph::try_new(expr)?; + let mut graph = ExprIntervalGraph::try_new(expr, schema)?; let expr_indexes = graph .gather_node_indices(&col_stats.iter().map(|(e, _)| e.clone()).collect_vec()); @@ -726,14 +765,37 @@ mod tests { .map(|((_, interval), (_, index))| (*index, interval.clone())) .collect_vec(); - let exp_result = graph.update_ranges(&mut col_stat_nodes[..])?; + let exp_result = + graph.update_ranges(&mut col_stat_nodes[..], Interval::CERTAINLY_TRUE)?; assert_eq!(exp_result, result); col_stat_nodes.iter().zip(expected_nodes.iter()).for_each( |((_, calculated_interval_node), (_, expected))| { // NOTE: These randomized tests only check for conservative containment, // not openness/closedness of endpoints. - assert!(calculated_interval_node.lower.value <= expected.lower.value); - assert!(calculated_interval_node.upper.value >= expected.upper.value); + + // Calculated bounds are relaxed by 1 to cover all strict and + // and non-strict comparison cases since we have only closed bounds. 
+ let one = ScalarValue::new_one(&expected.data_type()).unwrap(); + assert!( + calculated_interval_node.lower() + <= &expected.lower().add(&one).unwrap(), + "{}", + format!( + "Calculated {} must be less than or equal {}", + calculated_interval_node.lower(), + expected.lower() + ) + ); + assert!( + calculated_interval_node.upper() + >= &expected.upper().sub(&one).unwrap(), + "{}", + format!( + "Calculated {} must be greater than or equal {}", + calculated_interval_node.upper(), + expected.upper() + ) + ); }, ); Ok(()) @@ -773,12 +835,24 @@ mod tests { experiment( expr, - (left_col, right_col), - Interval::make(left_given.0, left_given.1, (true, true)), - Interval::make(right_given.0, right_given.1, (true, true)), - Interval::make(left_expected.0, left_expected.1, (true, true)), - Interval::make(right_expected.0, right_expected.1, (true, true)), + (left_col.clone(), right_col.clone()), + Interval::make(left_given.0, left_given.1).unwrap(), + Interval::make(right_given.0, right_given.1).unwrap(), + Interval::make(left_expected.0, left_expected.1).unwrap(), + Interval::make(right_expected.0, right_expected.1).unwrap(), PropagationResult::Success, + &Schema::new(vec![ + Field::new( + left_col.as_any().downcast_ref::().unwrap().name(), + DataType::$SCALAR, + true, + ), + Field::new( + right_col.as_any().downcast_ref::().unwrap().name(), + DataType::$SCALAR, + true, + ), + ]), ) } }; @@ -802,12 +876,24 @@ mod tests { let expr = Arc::new(BinaryExpr::new(left_and_1, Operator::Gt, right_col.clone())); experiment( expr, - (left_col, right_col), - Interval::make(Some(10), Some(20), (true, true)), - Interval::make(Some(100), None, (true, true)), - Interval::make(Some(10), Some(20), (true, true)), - Interval::make(Some(100), None, (true, true)), + (left_col.clone(), right_col.clone()), + Interval::make(Some(10_i32), Some(20_i32))?, + Interval::make(Some(100), None)?, + Interval::make(Some(10), Some(20))?, + Interval::make(Some(100), None)?, PropagationResult::Infeasible, + &Schema::new(vec![ + Field::new( + left_col.as_any().downcast_ref::().unwrap().name(), + DataType::Int32, + true, + ), + Field::new( + right_col.as_any().downcast_ref::().unwrap().name(), + DataType::Int32, + true, + ), + ]), ) } @@ -1112,7 +1198,14 @@ mod tests { Arc::new(Column::new("b", 1)), )); let expr = Arc::new(BinaryExpr::new(left_expr, Operator::Gt, right_expr)); - let mut graph = ExprIntervalGraph::try_new(expr).unwrap(); + let mut graph = ExprIntervalGraph::try_new( + expr, + &Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + ]), + ) + .unwrap(); // Define a test leaf node. let leaf_node = Arc::new(BinaryExpr::new( Arc::new(Column::new("a", 0)), @@ -1151,7 +1244,16 @@ mod tests { Arc::new(Column::new("z", 1)), )); let expr = Arc::new(BinaryExpr::new(left_expr, Operator::Gt, right_expr)); - let mut graph = ExprIntervalGraph::try_new(expr).unwrap(); + let mut graph = ExprIntervalGraph::try_new( + expr, + &Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + Field::new("y", DataType::Int32, true), + Field::new("z", DataType::Int32, true), + ]), + ) + .unwrap(); // Define a test leaf node. 
let leaf_node = Arc::new(BinaryExpr::new( Arc::new(Column::new("a", 0)), @@ -1190,7 +1292,15 @@ mod tests { Arc::new(Column::new("z", 1)), )); let expr = Arc::new(BinaryExpr::new(left_expr, Operator::Gt, right_expr)); - let mut graph = ExprIntervalGraph::try_new(expr).unwrap(); + let mut graph = ExprIntervalGraph::try_new( + expr, + &Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + Field::new("z", DataType::Int32, true), + ]), + ) + .unwrap(); // Define a test leaf node. let leaf_node = Arc::new(BinaryExpr::new( Arc::new(Column::new("a", 0)), @@ -1213,9 +1323,9 @@ mod tests { fn test_gather_node_indices_cannot_provide() -> Result<()> { // Expression: a@0 + 1 + b@1 > y@0 - z@1 -> provide a@0 + b@1 // TODO: We expect nodes a@0 and b@1 to be pruned, and intervals to be provided from the a@0 + b@1 node. - // However, we do not have an exact node for a@0 + b@1 due to the binary tree structure of the expressions. - // Pruning and interval providing for BinaryExpr expressions are more challenging without exact matches. - // Currently, we only support exact matches for BinaryExprs, but we plan to extend support beyond exact matches in the future. + // However, we do not have an exact node for a@0 + b@1 due to the binary tree structure of the expressions. + // Pruning and interval providing for BinaryExpr expressions are more challenging without exact matches. + // Currently, we only support exact matches for BinaryExprs, but we plan to extend support beyond exact matches in the future. let left_expr = Arc::new(BinaryExpr::new( Arc::new(BinaryExpr::new( Arc::new(Column::new("a", 0)), @@ -1232,7 +1342,16 @@ mod tests { Arc::new(Column::new("z", 1)), )); let expr = Arc::new(BinaryExpr::new(left_expr, Operator::Gt, right_expr)); - let mut graph = ExprIntervalGraph::try_new(expr).unwrap(); + let mut graph = ExprIntervalGraph::try_new( + expr, + &Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + Field::new("y", DataType::Int32, true), + Field::new("z", DataType::Int32, true), + ]), + ) + .unwrap(); // Define a test leaf node. 
let leaf_node = Arc::new(BinaryExpr::new( Arc::new(Column::new("a", 0)), @@ -1257,80 +1376,51 @@ mod tests { Operator::Plus, Arc::new(Literal::new(ScalarValue::new_interval_mdn(0, 1, 321))), ); - let parent = Interval::new( - IntervalBound::new( - // 15.10.2020 - 10:11:12.000_000_321 AM - ScalarValue::TimestampNanosecond(Some(1_602_756_672_000_000_321), None), - false, - ), - IntervalBound::new( - // 16.10.2020 - 10:11:12.000_000_321 AM - ScalarValue::TimestampNanosecond(Some(1_602_843_072_000_000_321), None), - false, - ), - ); - let left_child = Interval::new( - IntervalBound::new( - // 10.10.2020 - 10:11:12 AM - ScalarValue::TimestampNanosecond(Some(1_602_324_672_000_000_000), None), - false, - ), - IntervalBound::new( - // 20.10.2020 - 10:11:12 AM - ScalarValue::TimestampNanosecond(Some(1_603_188_672_000_000_000), None), - false, - ), - ); - let right_child = Interval::new( - IntervalBound::new( - // 1 day 321 ns - ScalarValue::IntervalMonthDayNano(Some(0x1_0000_0000_0000_0141)), - false, - ), - IntervalBound::new( - // 1 day 321 ns - ScalarValue::IntervalMonthDayNano(Some(0x1_0000_0000_0000_0141)), - false, - ), - ); + let parent = Interval::try_new( + // 15.10.2020 - 10:11:12.000_000_321 AM + ScalarValue::TimestampNanosecond(Some(1_602_756_672_000_000_321), None), + // 16.10.2020 - 10:11:12.000_000_321 AM + ScalarValue::TimestampNanosecond(Some(1_602_843_072_000_000_321), None), + )?; + let left_child = Interval::try_new( + // 10.10.2020 - 10:11:12 AM + ScalarValue::TimestampNanosecond(Some(1_602_324_672_000_000_000), None), + // 20.10.2020 - 10:11:12 AM + ScalarValue::TimestampNanosecond(Some(1_603_188_672_000_000_000), None), + )?; + let right_child = Interval::try_new( + // 1 day 321 ns + ScalarValue::IntervalMonthDayNano(Some(0x1_0000_0000_0000_0141)), + // 1 day 321 ns + ScalarValue::IntervalMonthDayNano(Some(0x1_0000_0000_0000_0141)), + )?; let children = vec![&left_child, &right_child]; - let result = expression.propagate_constraints(&parent, &children)?; + let result = expression + .propagate_constraints(&parent, &children)? + .unwrap(); assert_eq!( - Some(Interval::new( - // 14.10.2020 - 10:11:12 AM - IntervalBound::new( + vec![ + Interval::try_new( + // 14.10.2020 - 10:11:12 AM ScalarValue::TimestampNanosecond( Some(1_602_670_272_000_000_000), None ), - false, - ), - // 15.10.2020 - 10:11:12 AM - IntervalBound::new( + // 15.10.2020 - 10:11:12 AM ScalarValue::TimestampNanosecond( Some(1_602_756_672_000_000_000), None ), - false, - ), - )), - result[0] - ); - assert_eq!( - Some(Interval::new( - // 1 day 321 ns in Duration type - IntervalBound::new( + )?, + Interval::try_new( + // 1 day 321 ns in Duration type ScalarValue::IntervalMonthDayNano(Some(0x1_0000_0000_0000_0141)), - false, - ), - // 1 day 321 ns in Duration type - IntervalBound::new( + // 1 day 321 ns in Duration type ScalarValue::IntervalMonthDayNano(Some(0x1_0000_0000_0000_0141)), - false, - ), - )), - result[1] + )? 
+ ], + result ); Ok(()) @@ -1343,206 +1433,216 @@ mod tests { Operator::Plus, Arc::new(Column::new("ts_column", 0)), ); - let parent = Interval::new( - IntervalBound::new( - // 15.10.2020 - 10:11:12 AM - ScalarValue::TimestampMillisecond(Some(1_602_756_672_000), None), - false, - ), - IntervalBound::new( - // 16.10.2020 - 10:11:12 AM - ScalarValue::TimestampMillisecond(Some(1_602_843_072_000), None), - false, - ), - ); - let right_child = Interval::new( - IntervalBound::new( - // 10.10.2020 - 10:11:12 AM - ScalarValue::TimestampMillisecond(Some(1_602_324_672_000), None), - false, - ), - IntervalBound::new( - // 20.10.2020 - 10:11:12 AM - ScalarValue::TimestampMillisecond(Some(1_603_188_672_000), None), - false, - ), - ); - let left_child = Interval::new( - IntervalBound::new( - // 2 days - ScalarValue::IntervalDayTime(Some(172_800_000)), - false, - ), - IntervalBound::new( - // 10 days - ScalarValue::IntervalDayTime(Some(864_000_000)), - false, - ), - ); + let parent = Interval::try_new( + // 15.10.2020 - 10:11:12 AM + ScalarValue::TimestampMillisecond(Some(1_602_756_672_000), None), + // 16.10.2020 - 10:11:12 AM + ScalarValue::TimestampMillisecond(Some(1_602_843_072_000), None), + )?; + let right_child = Interval::try_new( + // 10.10.2020 - 10:11:12 AM + ScalarValue::TimestampMillisecond(Some(1_602_324_672_000), None), + // 20.10.2020 - 10:11:12 AM + ScalarValue::TimestampMillisecond(Some(1_603_188_672_000), None), + )?; + let left_child = Interval::try_new( + // 2 days + ScalarValue::IntervalDayTime(Some(172_800_000)), + // 10 days + ScalarValue::IntervalDayTime(Some(864_000_000)), + )?; let children = vec![&left_child, &right_child]; - let result = expression.propagate_constraints(&parent, &children)?; + let result = expression + .propagate_constraints(&parent, &children)? + .unwrap(); assert_eq!( - Some(Interval::new( - // 10.10.2020 - 10:11:12 AM - IntervalBound::new( - ScalarValue::TimestampMillisecond(Some(1_602_324_672_000), None), - false, - ), - // 14.10.2020 - 10:11:12 AM - IntervalBound::new( - ScalarValue::TimestampMillisecond(Some(1_602_670_272_000), None), - false, - ) - )), - result[1] - ); - assert_eq!( - Some(Interval::new( - IntervalBound::new( + vec![ + Interval::try_new( // 2 days ScalarValue::IntervalDayTime(Some(172_800_000)), - false, - ), - IntervalBound::new( // 6 days ScalarValue::IntervalDayTime(Some(518_400_000)), - false, - ), - )), - result[0] + )?, + Interval::try_new( + // 10.10.2020 - 10:11:12 AM + ScalarValue::TimestampMillisecond(Some(1_602_324_672_000), None), + // 14.10.2020 - 10:11:12 AM + ScalarValue::TimestampMillisecond(Some(1_602_670_272_000), None), + )? 
+ ], + result ); Ok(()) } #[test] - fn test_propagate_comparison() { + fn test_propagate_comparison() -> Result<()> { // In the examples below: // `left` is unbounded: [?, ?], // `right` is known to be [1000,1000] - // so `left` < `right` results in no new knowledge of `right` but knowing that `left` is now < 1000:` [?, 1000) - let left = Interval::new( - IntervalBound::make_unbounded(DataType::Int64).unwrap(), - IntervalBound::make_unbounded(DataType::Int64).unwrap(), - ); - let right = Interval::new( - IntervalBound::new(ScalarValue::Int64(Some(1000)), false), - IntervalBound::new(ScalarValue::Int64(Some(1000)), false), - ); + // so `left` < `right` results in no new knowledge of `right` but knowing that `left` is now < 1000:` [?, 999] + let left = Interval::make_unbounded(&DataType::Int64)?; + let right = Interval::make(Some(1000_i64), Some(1000_i64))?; assert_eq!( - ( - Some(Interval::new( - IntervalBound::make_unbounded(DataType::Int64).unwrap(), - IntervalBound::new(ScalarValue::Int64(Some(1000)), true) - )), - Some(Interval::new( - IntervalBound::new(ScalarValue::Int64(Some(1000)), false), - IntervalBound::new(ScalarValue::Int64(Some(1000)), false) - )), - ), - propagate_comparison(&Operator::Lt, &left, &right).unwrap() + (Some(( + Interval::make(None, Some(999_i64))?, + Interval::make(Some(1000_i64), Some(1000_i64))?, + ))), + propagate_comparison( + &Operator::Lt, + &Interval::CERTAINLY_TRUE, + &left, + &right + )? ); - let left = Interval::new( - IntervalBound::make_unbounded(DataType::Timestamp( - TimeUnit::Nanosecond, - None, - )) - .unwrap(), - IntervalBound::make_unbounded(DataType::Timestamp( - TimeUnit::Nanosecond, - None, - )) - .unwrap(), - ); - let right = Interval::new( - IntervalBound::new(ScalarValue::TimestampNanosecond(Some(1000), None), false), - IntervalBound::new(ScalarValue::TimestampNanosecond(Some(1000), None), false), - ); + let left = + Interval::make_unbounded(&DataType::Timestamp(TimeUnit::Nanosecond, None))?; + let right = Interval::try_new( + ScalarValue::TimestampNanosecond(Some(1000), None), + ScalarValue::TimestampNanosecond(Some(1000), None), + )?; assert_eq!( - ( - Some(Interval::new( - IntervalBound::make_unbounded(DataType::Timestamp( + (Some(( + Interval::try_new( + ScalarValue::try_from(&DataType::Timestamp( TimeUnit::Nanosecond, None )) .unwrap(), - IntervalBound::new( - ScalarValue::TimestampNanosecond(Some(1000), None), - true - ) - )), - Some(Interval::new( - IntervalBound::new( - ScalarValue::TimestampNanosecond(Some(1000), None), - false - ), - IntervalBound::new( - ScalarValue::TimestampNanosecond(Some(1000), None), - false - ) - )), - ), - propagate_comparison(&Operator::Lt, &left, &right).unwrap() + ScalarValue::TimestampNanosecond(Some(999), None), + )?, + Interval::try_new( + ScalarValue::TimestampNanosecond(Some(1000), None), + ScalarValue::TimestampNanosecond(Some(1000), None), + )? + ))), + propagate_comparison( + &Operator::Lt, + &Interval::CERTAINLY_TRUE, + &left, + &right + )? 
); - let left = Interval::new( - IntervalBound::make_unbounded(DataType::Timestamp( - TimeUnit::Nanosecond, - Some("+05:00".into()), - )) - .unwrap(), - IntervalBound::make_unbounded(DataType::Timestamp( - TimeUnit::Nanosecond, - Some("+05:00".into()), - )) - .unwrap(), - ); - let right = Interval::new( - IntervalBound::new( - ScalarValue::TimestampNanosecond(Some(1000), Some("+05:00".into())), - false, - ), - IntervalBound::new( - ScalarValue::TimestampNanosecond(Some(1000), Some("+05:00".into())), - false, - ), - ); + let left = Interval::make_unbounded(&DataType::Timestamp( + TimeUnit::Nanosecond, + Some("+05:00".into()), + ))?; + let right = Interval::try_new( + ScalarValue::TimestampNanosecond(Some(1000), Some("+05:00".into())), + ScalarValue::TimestampNanosecond(Some(1000), Some("+05:00".into())), + )?; assert_eq!( - ( - Some(Interval::new( - IntervalBound::make_unbounded(DataType::Timestamp( + (Some(( + Interval::try_new( + ScalarValue::try_from(&DataType::Timestamp( TimeUnit::Nanosecond, Some("+05:00".into()), )) .unwrap(), - IntervalBound::new( - ScalarValue::TimestampNanosecond( - Some(1000), - Some("+05:00".into()) - ), - true - ) - )), - Some(Interval::new( - IntervalBound::new( - ScalarValue::TimestampNanosecond( - Some(1000), - Some("+05:00".into()) - ), - false - ), - IntervalBound::new( - ScalarValue::TimestampNanosecond( - Some(1000), - Some("+05:00".into()) - ), - false - ) - )), - ), - propagate_comparison(&Operator::Lt, &left, &right).unwrap() + ScalarValue::TimestampNanosecond(Some(999), Some("+05:00".into())), + )?, + Interval::try_new( + ScalarValue::TimestampNanosecond(Some(1000), Some("+05:00".into())), + ScalarValue::TimestampNanosecond(Some(1000), Some("+05:00".into())), + )? + ))), + propagate_comparison( + &Operator::Lt, + &Interval::CERTAINLY_TRUE, + &left, + &right + )? ); + + Ok(()) + } + + #[test] + fn test_propagate_or() -> Result<()> { + let expr = Arc::new(BinaryExpr::new( + Arc::new(Column::new("a", 0)), + Operator::Or, + Arc::new(Column::new("b", 1)), + )); + let parent = Interval::CERTAINLY_FALSE; + let children_set = vec![ + vec![&Interval::CERTAINLY_FALSE, &Interval::UNCERTAIN], + vec![&Interval::UNCERTAIN, &Interval::CERTAINLY_FALSE], + vec![&Interval::CERTAINLY_FALSE, &Interval::CERTAINLY_FALSE], + vec![&Interval::UNCERTAIN, &Interval::UNCERTAIN], + ]; + for children in children_set { + assert_eq!( + expr.propagate_constraints(&parent, &children)?.unwrap(), + vec![Interval::CERTAINLY_FALSE, Interval::CERTAINLY_FALSE], + ); + } + + let parent = Interval::CERTAINLY_FALSE; + let children_set = vec![ + vec![&Interval::CERTAINLY_TRUE, &Interval::UNCERTAIN], + vec![&Interval::UNCERTAIN, &Interval::CERTAINLY_TRUE], + ]; + for children in children_set { + assert_eq!(expr.propagate_constraints(&parent, &children)?, None,); + } + + let parent = Interval::CERTAINLY_TRUE; + let children = vec![&Interval::CERTAINLY_FALSE, &Interval::UNCERTAIN]; + assert_eq!( + expr.propagate_constraints(&parent, &children)?.unwrap(), + vec![Interval::CERTAINLY_FALSE, Interval::CERTAINLY_TRUE] + ); + + let parent = Interval::CERTAINLY_TRUE; + let children = vec![&Interval::UNCERTAIN, &Interval::UNCERTAIN]; + assert_eq!( + expr.propagate_constraints(&parent, &children)?.unwrap(), + // Empty means unchanged intervals. 
+ vec![] + ); + + Ok(()) + } + + #[test] + fn test_propagate_certainly_false_and() -> Result<()> { + let expr = Arc::new(BinaryExpr::new( + Arc::new(Column::new("a", 0)), + Operator::And, + Arc::new(Column::new("b", 1)), + )); + let parent = Interval::CERTAINLY_FALSE; + let children_and_results_set = vec![ + ( + vec![&Interval::CERTAINLY_TRUE, &Interval::UNCERTAIN], + vec![Interval::CERTAINLY_TRUE, Interval::CERTAINLY_FALSE], + ), + ( + vec![&Interval::UNCERTAIN, &Interval::CERTAINLY_TRUE], + vec![Interval::CERTAINLY_FALSE, Interval::CERTAINLY_TRUE], + ), + ( + vec![&Interval::UNCERTAIN, &Interval::UNCERTAIN], + // Empty means unchanged intervals. + vec![], + ), + ( + vec![&Interval::CERTAINLY_FALSE, &Interval::UNCERTAIN], + vec![], + ), + ]; + for (children, result) in children_and_results_set { + assert_eq!( + expr.propagate_constraints(&parent, &children)?.unwrap(), + result + ); + } + + Ok(()) } } diff --git a/datafusion/physical-expr/src/intervals/interval_aritmetic.rs b/datafusion/physical-expr/src/intervals/interval_aritmetic.rs deleted file mode 100644 index 1ea9b2d9aee60..0000000000000 --- a/datafusion/physical-expr/src/intervals/interval_aritmetic.rs +++ /dev/null @@ -1,1886 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -//http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! Interval arithmetic library - -use std::borrow::Borrow; -use std::fmt::{self, Display, Formatter}; -use std::ops::{AddAssign, SubAssign}; - -use crate::aggregate::min_max::{max, min}; -use crate::intervals::rounding::{alter_fp_rounding_mode, next_down, next_up}; - -use arrow::compute::{cast_with_options, CastOptions}; -use arrow::datatypes::DataType; -use arrow_array::ArrowNativeTypeOp; -use datafusion_common::{internal_err, DataFusionError, Result, ScalarValue}; -use datafusion_expr::type_coercion::binary::get_result_type; -use datafusion_expr::Operator; - -/// This type represents a single endpoint of an [`Interval`]. An -/// endpoint can be open (does not include the endpoint) or closed -/// (includes the endpoint). -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct IntervalBound { - pub value: ScalarValue, - /// If true, interval does not include `value` - pub open: bool, -} - -impl IntervalBound { - /// Creates a new `IntervalBound` object using the given value. - pub const fn new(value: ScalarValue, open: bool) -> IntervalBound { - IntervalBound { value, open } - } - - /// Creates a new "open" interval (does not include the `value` - /// bound) - pub const fn new_open(value: ScalarValue) -> IntervalBound { - IntervalBound::new(value, true) - } - - /// Creates a new "closed" interval (includes the `value` - /// bound) - pub const fn new_closed(value: ScalarValue) -> IntervalBound { - IntervalBound::new(value, false) - } - - /// This convenience function creates an unbounded interval endpoint. 
- pub fn make_unbounded>(data_type: T) -> Result { - ScalarValue::try_from(data_type.borrow()).map(|v| IntervalBound::new(v, true)) - } - - /// This convenience function returns the data type associated with this - /// `IntervalBound`. - pub fn get_datatype(&self) -> DataType { - self.value.data_type() - } - - /// This convenience function checks whether the `IntervalBound` represents - /// an unbounded interval endpoint. - pub fn is_unbounded(&self) -> bool { - self.value.is_null() - } - - /// This function casts the `IntervalBound` to the given data type. - pub(crate) fn cast_to( - &self, - data_type: &DataType, - cast_options: &CastOptions, - ) -> Result { - cast_scalar_value(&self.value, data_type, cast_options) - .map(|value| IntervalBound::new(value, self.open)) - } - - /// Returns a new bound with a negated value, if any, and the same open/closed. - /// For example negating `[5` would return `[-5`, or `-1)` would return `1)`. - pub fn negate(&self) -> Result { - self.value.arithmetic_negate().map(|value| IntervalBound { - value, - open: self.open, - }) - } - - /// This function adds the given `IntervalBound` to this `IntervalBound`. - /// The result is unbounded if either is; otherwise, their values are - /// added. The result is closed if both original bounds are closed, or open - /// otherwise. - pub fn add>( - &self, - other: T, - ) -> Result { - let rhs = other.borrow(); - if self.is_unbounded() || rhs.is_unbounded() { - return IntervalBound::make_unbounded(get_result_type( - &self.get_datatype(), - &Operator::Plus, - &rhs.get_datatype(), - )?); - } - match self.get_datatype() { - DataType::Float64 | DataType::Float32 => { - alter_fp_rounding_mode::(&self.value, &rhs.value, |lhs, rhs| { - lhs.add(rhs) - }) - } - _ => self.value.add(&rhs.value), - } - .map(|v| IntervalBound::new(v, self.open || rhs.open)) - } - - /// This function subtracts the given `IntervalBound` from `self`. - /// The result is unbounded if either is; otherwise, their values are - /// subtracted. The result is closed if both original bounds are closed, - /// or open otherwise. - pub fn sub>( - &self, - other: T, - ) -> Result { - let rhs = other.borrow(); - if self.is_unbounded() || rhs.is_unbounded() { - return IntervalBound::make_unbounded(get_result_type( - &self.get_datatype(), - &Operator::Minus, - &rhs.get_datatype(), - )?); - } - match self.get_datatype() { - DataType::Float64 | DataType::Float32 => { - alter_fp_rounding_mode::(&self.value, &rhs.value, |lhs, rhs| { - lhs.sub(rhs) - }) - } - _ => self.value.sub(&rhs.value), - } - .map(|v| IntervalBound::new(v, self.open || rhs.open)) - } - - /// This function chooses one of the given `IntervalBound`s according to - /// the given function `decide`. The result is unbounded if both are. If - /// only one of the arguments is unbounded, the other one is chosen by - /// default. If neither is unbounded, the function `decide` is used. 
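The removed `IntervalBound::add`/`sub` combine endpoint values and openness: the result is unbounded if either operand is, and it is open unless both operands are closed. A toy model of that rule (the `Bound` struct below is a stand-in for the deleted type, not its actual definition):

```rust
// Minimal model of the old IntervalBound::add semantics described above:
// `None` stands for an unbounded endpoint, and the result is open unless
// both inputs are closed.
#[derive(Debug, Clone, Copy, PartialEq)]
struct Bound {
    value: Option<i64>, // None == unbounded
    open: bool,
}

fn add(lhs: Bound, rhs: Bound) -> Bound {
    match (lhs.value, rhs.value) {
        (Some(a), Some(b)) => Bound {
            value: Some(a + b),
            open: lhs.open || rhs.open,
        },
        // Adding anything to an unbounded endpoint stays unbounded (and open).
        _ => Bound { value: None, open: true },
    }
}

fn main() {
    let closed_5 = Bound { value: Some(5), open: false };
    let open_7 = Bound { value: Some(7), open: true };
    let unbounded = Bound { value: None, open: true };

    assert_eq!(add(closed_5, open_7), Bound { value: Some(12), open: true });
    assert_eq!(add(closed_5, unbounded), Bound { value: None, open: true });
}
```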
- pub fn choose( - first: &IntervalBound, - second: &IntervalBound, - decide: fn(&ScalarValue, &ScalarValue) -> Result, - ) -> Result { - Ok(if first.is_unbounded() { - second.clone() - } else if second.is_unbounded() { - first.clone() - } else if first.value != second.value { - let chosen = decide(&first.value, &second.value)?; - if chosen.eq(&first.value) { - first.clone() - } else { - second.clone() - } - } else { - IntervalBound::new(second.value.clone(), first.open || second.open) - }) - } -} - -impl Display for IntervalBound { - fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - write!(f, "IntervalBound [{}]", self.value) - } -} - -/// This type represents an interval, which is used to calculate reliable -/// bounds for expressions: -/// -/// * An *open* interval does not include the endpoint and is written using a -/// `(` or `)`. -/// -/// * A *closed* interval does include the endpoint and is written using `[` or -/// `]`. -/// -/// * If the interval's `lower` and/or `upper` bounds are not known, they are -/// called *unbounded* endpoint and represented using a `NULL` and written using -/// `∞`. -/// -/// # Examples -/// -/// A `Int64` `Interval` of `[10, 20)` represents the values `10, 11, ... 18, -/// 19` (includes 10, but does not include 20). -/// -/// A `Int64` `Interval` of `[10, ∞)` represents a value known to be either -/// `10` or higher. -/// -/// An `Interval` of `(-∞, ∞)` represents that the range is entirely unknown. -/// -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct Interval { - pub lower: IntervalBound, - pub upper: IntervalBound, -} - -impl Default for Interval { - fn default() -> Self { - Interval::new( - IntervalBound::new(ScalarValue::Null, true), - IntervalBound::new(ScalarValue::Null, true), - ) - } -} - -impl Display for Interval { - fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - write!( - f, - "{}{}, {}{}", - if self.lower.open { "(" } else { "[" }, - self.lower.value, - self.upper.value, - if self.upper.open { ")" } else { "]" } - ) - } -} - -impl Interval { - /// Creates a new interval object using the given bounds. - /// - /// # Boolean intervals need special handling - /// - /// For boolean intervals, having an open false lower bound is equivalent to - /// having a true closed lower bound. Similarly, open true upper bound is - /// equivalent to having a false closed upper bound. Also for boolean - /// intervals, having an unbounded left endpoint is equivalent to having a - /// false closed lower bound, while having an unbounded right endpoint is - /// equivalent to having a true closed upper bound. Therefore; input - /// parameters to construct an Interval can have different types, but they - /// all result in `[false, false]`, `[false, true]` or `[true, true]`. - pub fn new(lower: IntervalBound, upper: IntervalBound) -> Interval { - // Boolean intervals need a special handling. - if let ScalarValue::Boolean(_) = lower.value { - let standardized_lower = match lower.value { - ScalarValue::Boolean(None) if lower.open => { - ScalarValue::Boolean(Some(false)) - } - ScalarValue::Boolean(Some(false)) if lower.open => { - ScalarValue::Boolean(Some(true)) - } - // The rest may include some invalid interval cases. The validation of - // interval construction parameters will be implemented later. - // For now, let's return them unchanged. 
- _ => lower.value, - }; - let standardized_upper = match upper.value { - ScalarValue::Boolean(None) if upper.open => { - ScalarValue::Boolean(Some(true)) - } - ScalarValue::Boolean(Some(true)) if upper.open => { - ScalarValue::Boolean(Some(false)) - } - _ => upper.value, - }; - Interval { - lower: IntervalBound::new(standardized_lower, false), - upper: IntervalBound::new(standardized_upper, false), - } - } else { - Interval { lower, upper } - } - } - - pub fn make(lower: Option, upper: Option, open: (bool, bool)) -> Interval - where - ScalarValue: From>, - { - Interval::new( - IntervalBound::new(ScalarValue::from(lower), open.0), - IntervalBound::new(ScalarValue::from(upper), open.1), - ) - } - - /// Casts this interval to `data_type` using `cast_options`. - pub(crate) fn cast_to( - &self, - data_type: &DataType, - cast_options: &CastOptions, - ) -> Result { - let lower = self.lower.cast_to(data_type, cast_options)?; - let upper = self.upper.cast_to(data_type, cast_options)?; - Ok(Interval::new(lower, upper)) - } - - /// This function returns the data type of this interval. If both endpoints - /// do not have the same data type, returns an error. - pub fn get_datatype(&self) -> Result { - let lower_type = self.lower.get_datatype(); - let upper_type = self.upper.get_datatype(); - if lower_type == upper_type { - Ok(lower_type) - } else { - internal_err!( - "Interval bounds have different types: {lower_type} != {upper_type}" - ) - } - } - - /// Decide if this interval is certainly greater than, possibly greater than, - /// or can't be greater than `other` by returning [true, true], - /// [false, true] or [false, false] respectively. - pub(crate) fn gt>(&self, other: T) -> Interval { - let rhs = other.borrow(); - let flags = if !self.upper.is_unbounded() - && !rhs.lower.is_unbounded() - && self.upper.value <= rhs.lower.value - { - // Values in this interval are certainly less than or equal to those - // in the given interval. - (false, false) - } else if !self.lower.is_unbounded() - && !rhs.upper.is_unbounded() - && self.lower.value >= rhs.upper.value - && (self.lower.value > rhs.upper.value || self.lower.open || rhs.upper.open) - { - // Values in this interval are certainly greater than those in the - // given interval. - (true, true) - } else { - // All outcomes are possible. - (false, true) - }; - - Interval::make(Some(flags.0), Some(flags.1), (false, false)) - } - - /// Decide if this interval is certainly greater than or equal to, possibly greater than - /// or equal to, or can't be greater than or equal to `other` by returning [true, true], - /// [false, true] or [false, false] respectively. - pub(crate) fn gt_eq>(&self, other: T) -> Interval { - let rhs = other.borrow(); - let flags = if !self.lower.is_unbounded() - && !rhs.upper.is_unbounded() - && self.lower.value >= rhs.upper.value - { - // Values in this interval are certainly greater than or equal to those - // in the given interval. - (true, true) - } else if !self.upper.is_unbounded() - && !rhs.lower.is_unbounded() - && self.upper.value <= rhs.lower.value - && (self.upper.value < rhs.lower.value || self.upper.open || rhs.lower.open) - { - // Values in this interval are certainly less than those in the - // given interval. - (false, false) - } else { - // All outcomes are possible. 
- (false, true) - }; - - Interval::make(Some(flags.0), Some(flags.1), (false, false)) - } - - /// Decide if this interval is certainly less than, possibly less than, - /// or can't be less than `other` by returning [true, true], - /// [false, true] or [false, false] respectively. - pub(crate) fn lt>(&self, other: T) -> Interval { - other.borrow().gt(self) - } - - /// Decide if this interval is certainly less than or equal to, possibly - /// less than or equal to, or can't be less than or equal to `other` by returning - /// [true, true], [false, true] or [false, false] respectively. - pub(crate) fn lt_eq>(&self, other: T) -> Interval { - other.borrow().gt_eq(self) - } - - /// Decide if this interval is certainly equal to, possibly equal to, - /// or can't be equal to `other` by returning [true, true], - /// [false, true] or [false, false] respectively. - pub(crate) fn equal>(&self, other: T) -> Interval { - let rhs = other.borrow(); - let flags = if !self.lower.is_unbounded() - && (self.lower.value == self.upper.value) - && (rhs.lower.value == rhs.upper.value) - && (self.lower.value == rhs.lower.value) - { - (true, true) - } else if self.gt(rhs) == Interval::CERTAINLY_TRUE - || self.lt(rhs) == Interval::CERTAINLY_TRUE - { - (false, false) - } else { - (false, true) - }; - - Interval::make(Some(flags.0), Some(flags.1), (false, false)) - } - - /// Compute the logical conjunction of this (boolean) interval with the given boolean interval. - pub(crate) fn and>(&self, other: T) -> Result { - let rhs = other.borrow(); - match ( - &self.lower.value, - &self.upper.value, - &rhs.lower.value, - &rhs.upper.value, - ) { - ( - ScalarValue::Boolean(Some(self_lower)), - ScalarValue::Boolean(Some(self_upper)), - ScalarValue::Boolean(Some(other_lower)), - ScalarValue::Boolean(Some(other_upper)), - ) => { - let lower = *self_lower && *other_lower; - let upper = *self_upper && *other_upper; - - Ok(Interval { - lower: IntervalBound::new(ScalarValue::Boolean(Some(lower)), false), - upper: IntervalBound::new(ScalarValue::Boolean(Some(upper)), false), - }) - } - _ => internal_err!("Incompatible types for logical conjunction"), - } - } - - /// Compute the logical negation of this (boolean) interval. - pub(crate) fn not(&self) -> Result { - if !matches!(self.get_datatype()?, DataType::Boolean) { - return internal_err!( - "Cannot apply logical negation to non-boolean interval" - ); - } - if self == &Interval::CERTAINLY_TRUE { - Ok(Interval::CERTAINLY_FALSE) - } else if self == &Interval::CERTAINLY_FALSE { - Ok(Interval::CERTAINLY_TRUE) - } else { - Ok(Interval::UNCERTAIN) - } - } - - /// Compute the intersection of the interval with the given interval. - /// If the intersection is empty, return None. - pub(crate) fn intersect>( - &self, - other: T, - ) -> Result> { - let rhs = other.borrow(); - // If it is evident that the result is an empty interval, - // do not make any calculation and directly return None. - if (!self.lower.is_unbounded() - && !rhs.upper.is_unbounded() - && self.lower.value > rhs.upper.value) - || (!self.upper.is_unbounded() - && !rhs.lower.is_unbounded() - && self.upper.value < rhs.lower.value) - { - // This None value signals an empty interval. 
- return Ok(None); - } - - let lower = IntervalBound::choose(&self.lower, &rhs.lower, max)?; - let upper = IntervalBound::choose(&self.upper, &rhs.upper, min)?; - - let non_empty = lower.is_unbounded() - || upper.is_unbounded() - || lower.value != upper.value - || (!lower.open && !upper.open); - Ok(non_empty.then_some(Interval::new(lower, upper))) - } - - /// Decide if this interval is certainly contains, possibly contains, - /// or can't can't `other` by returning [true, true], - /// [false, true] or [false, false] respectively. - pub fn contains>(&self, other: T) -> Result { - match self.intersect(other.borrow())? { - Some(intersection) => { - // Need to compare with same bounds close-ness. - if intersection.close_bounds() == other.borrow().clone().close_bounds() { - Ok(Interval::CERTAINLY_TRUE) - } else { - Ok(Interval::UNCERTAIN) - } - } - None => Ok(Interval::CERTAINLY_FALSE), - } - } - - /// Add the given interval (`other`) to this interval. Say we have - /// intervals [a1, b1] and [a2, b2], then their sum is [a1 + a2, b1 + b2]. - /// Note that this represents all possible values the sum can take if - /// one can choose single values arbitrarily from each of the operands. - pub fn add>(&self, other: T) -> Result { - let rhs = other.borrow(); - Ok(Interval::new( - self.lower.add::(&rhs.lower)?, - self.upper.add::(&rhs.upper)?, - )) - } - - /// Subtract the given interval (`other`) from this interval. Say we have - /// intervals [a1, b1] and [a2, b2], then their sum is [a1 - b2, b1 - a2]. - /// Note that this represents all possible values the difference can take - /// if one can choose single values arbitrarily from each of the operands. - pub fn sub>(&self, other: T) -> Result { - let rhs = other.borrow(); - Ok(Interval::new( - self.lower.sub::(&rhs.upper)?, - self.upper.sub::(&rhs.lower)?, - )) - } - - pub const CERTAINLY_FALSE: Interval = Interval { - lower: IntervalBound::new_closed(ScalarValue::Boolean(Some(false))), - upper: IntervalBound::new_closed(ScalarValue::Boolean(Some(false))), - }; - - pub const UNCERTAIN: Interval = Interval { - lower: IntervalBound::new_closed(ScalarValue::Boolean(Some(false))), - upper: IntervalBound::new_closed(ScalarValue::Boolean(Some(true))), - }; - - pub const CERTAINLY_TRUE: Interval = Interval { - lower: IntervalBound::new_closed(ScalarValue::Boolean(Some(true))), - upper: IntervalBound::new_closed(ScalarValue::Boolean(Some(true))), - }; - - /// Returns the cardinality of this interval, which is the number of all - /// distinct points inside it. This function returns `None` if: - /// - The interval is unbounded from either side, or - /// - Cardinality calculations for the datatype in question is not - /// implemented yet, or - /// - An overflow occurs during the calculation. - /// - /// This function returns an error if the given interval is malformed. - pub fn cardinality(&self) -> Result> { - let data_type = self.get_datatype()?; - if data_type.is_integer() { - Ok(self.upper.value.distance(&self.lower.value).map(|diff| { - calculate_cardinality_based_on_bounds( - self.lower.open, - self.upper.open, - diff as u64, - ) - })) - } - // Ordering floating-point numbers according to their binary representations - // coincide with their natural ordering. Therefore, we can consider their - // binary representations as "indices" and subtract them. 
For details, see: - // https://stackoverflow.com/questions/8875064/how-many-distinct-floating-point-numbers-in-a-specific-range - else if data_type.is_floating() { - match (&self.lower.value, &self.upper.value) { - ( - ScalarValue::Float32(Some(lower)), - ScalarValue::Float32(Some(upper)), - ) => { - // Negative numbers are sorted in the reverse order. To always have a positive difference after the subtraction, - // we perform following transformation: - let lower_bits = lower.to_bits() as i32; - let upper_bits = upper.to_bits() as i32; - let transformed_lower = - lower_bits ^ ((lower_bits >> 31) & 0x7fffffff); - let transformed_upper = - upper_bits ^ ((upper_bits >> 31) & 0x7fffffff); - let Ok(count) = transformed_upper.sub_checked(transformed_lower) - else { - return Ok(None); - }; - Ok(Some(calculate_cardinality_based_on_bounds( - self.lower.open, - self.upper.open, - count as u64, - ))) - } - ( - ScalarValue::Float64(Some(lower)), - ScalarValue::Float64(Some(upper)), - ) => { - let lower_bits = lower.to_bits() as i64; - let upper_bits = upper.to_bits() as i64; - let transformed_lower = - lower_bits ^ ((lower_bits >> 63) & 0x7fffffffffffffff); - let transformed_upper = - upper_bits ^ ((upper_bits >> 63) & 0x7fffffffffffffff); - let Ok(count) = transformed_upper.sub_checked(transformed_lower) - else { - return Ok(None); - }; - Ok(Some(calculate_cardinality_based_on_bounds( - self.lower.open, - self.upper.open, - count as u64, - ))) - } - _ => Ok(None), - } - } else { - // Cardinality calculations are not implemented for this data type yet: - Ok(None) - } - } - - /// This function "closes" this interval; i.e. it modifies the endpoints so - /// that we end up with the narrowest possible closed interval containing - /// the original interval. - pub fn close_bounds(mut self) -> Interval { - if self.lower.open { - // Get next value - self.lower.value = next_value::(self.lower.value); - self.lower.open = false; - } - - if self.upper.open { - // Get previous value - self.upper.value = next_value::(self.upper.value); - self.upper.open = false; - } - - self - } -} - -trait OneTrait: Sized + std::ops::Add + std::ops::Sub { - fn one() -> Self; -} - -macro_rules! impl_OneTrait{ - ($($m:ty),*) => {$( impl OneTrait for $m { fn one() -> Self { 1 as $m } })*} -} -impl_OneTrait! {u8, u16, u32, u64, i8, i16, i32, i64} - -/// This function either increments or decrements its argument, depending on the `INC` value. -/// If `true`, it increments; otherwise it decrements the argument. -fn increment_decrement( - mut val: T, -) -> T { - if INC { - val.add_assign(T::one()); - } else { - val.sub_assign(T::one()); - } - val -} - -macro_rules! check_infinite_bounds { - ($value:expr, $val:expr, $type:ident, $inc:expr) => { - if ($val == $type::MAX && $inc) || ($val == $type::MIN && !$inc) { - return $value; - } - }; -} - -/// This function returns the next/previous value depending on the `ADD` value. -/// If `true`, it returns the next value; otherwise it returns the previous value. 
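The floating-point branch of `cardinality` above relies on the fact that reinterpreting a float's bits (after flipping the magnitude bits of negative values) yields integers whose order matches the floats' natural order, so the number of representable values in a range is just a subtraction. A standalone sketch of that trick for `f32` (a simplified illustration, not the crate's exact implementation):

```rust
// Map each f32 to an i32 that increases with the float's natural order,
// then subtract to count representable values.
fn ordered_bits(x: f32) -> i32 {
    let bits = x.to_bits() as i32;
    // Negative floats sort in reverse bit order; flip their magnitude bits so
    // the integer ordering matches the floating-point ordering.
    bits ^ ((bits >> 31) & 0x7fff_ffff)
}

/// Number of distinct f32 values in the closed interval [lo, hi].
fn f32_cardinality(lo: f32, hi: f32) -> u64 {
    (ordered_bits(hi) - ordered_bits(lo)) as u64 + 1
}

fn main() {
    // There are 2^23 representable floats in [1.0, 2.0), hence 2^23 + 1 in
    // the closed interval [1.0, 2.0].
    assert_eq!(f32_cardinality(1.0, 2.0), (1 << 23) + 1);
    // The mapping also works across zero:
    assert!(f32_cardinality(-1.0, 1.0) > 0);
}
```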
-fn next_value(value: ScalarValue) -> ScalarValue { - use ScalarValue::*; - match value { - Float32(Some(val)) => { - let new_float = if INC { next_up(val) } else { next_down(val) }; - Float32(Some(new_float)) - } - Float64(Some(val)) => { - let new_float = if INC { next_up(val) } else { next_down(val) }; - Float64(Some(new_float)) - } - Int8(Some(val)) => { - check_infinite_bounds!(value, val, i8, INC); - Int8(Some(increment_decrement::(val))) - } - Int16(Some(val)) => { - check_infinite_bounds!(value, val, i16, INC); - Int16(Some(increment_decrement::(val))) - } - Int32(Some(val)) => { - check_infinite_bounds!(value, val, i32, INC); - Int32(Some(increment_decrement::(val))) - } - Int64(Some(val)) => { - check_infinite_bounds!(value, val, i64, INC); - Int64(Some(increment_decrement::(val))) - } - UInt8(Some(val)) => { - check_infinite_bounds!(value, val, u8, INC); - UInt8(Some(increment_decrement::(val))) - } - UInt16(Some(val)) => { - check_infinite_bounds!(value, val, u16, INC); - UInt16(Some(increment_decrement::(val))) - } - UInt32(Some(val)) => { - check_infinite_bounds!(value, val, u32, INC); - UInt32(Some(increment_decrement::(val))) - } - UInt64(Some(val)) => { - check_infinite_bounds!(value, val, u64, INC); - UInt64(Some(increment_decrement::(val))) - } - _ => value, // Unsupported datatypes - } -} - -/// This function computes the selectivity of an operation by computing the -/// cardinality ratio of the given input/output intervals. If this can not be -/// calculated for some reason, it returns `1.0` meaning fullly selective (no -/// filtering). -pub fn cardinality_ratio( - initial_interval: &Interval, - final_interval: &Interval, -) -> Result { - Ok( - match ( - final_interval.cardinality()?, - initial_interval.cardinality()?, - ) { - (Some(final_interval), Some(initial_interval)) => { - final_interval as f64 / initial_interval as f64 - } - _ => 1.0, - }, - ) -} - -pub fn apply_operator(op: &Operator, lhs: &Interval, rhs: &Interval) -> Result { - match *op { - Operator::Eq => Ok(lhs.equal(rhs)), - Operator::NotEq => Ok(lhs.equal(rhs).not()?), - Operator::Gt => Ok(lhs.gt(rhs)), - Operator::GtEq => Ok(lhs.gt_eq(rhs)), - Operator::Lt => Ok(lhs.lt(rhs)), - Operator::LtEq => Ok(lhs.lt_eq(rhs)), - Operator::And => lhs.and(rhs), - Operator::Plus => lhs.add(rhs), - Operator::Minus => lhs.sub(rhs), - _ => Ok(Interval::default()), - } -} - -/// Cast scalar value to the given data type using an arrow kernel. -fn cast_scalar_value( - value: &ScalarValue, - data_type: &DataType, - cast_options: &CastOptions, -) -> Result { - let cast_array = cast_with_options(&value.to_array(), data_type, cast_options)?; - ScalarValue::try_from_array(&cast_array, 0) -} - -/// This function calculates the final cardinality result by inspecting the endpoints of the interval. -fn calculate_cardinality_based_on_bounds( - lower_open: bool, - upper_open: bool, - diff: u64, -) -> u64 { - match (lower_open, upper_open) { - (false, false) => diff + 1, - (true, true) => diff - 1, - _ => diff, - } -} - -/// An [Interval] that also tracks null status using a boolean interval. -/// -/// This represents values that may be in a particular range or be null. 
-/// -/// # Examples -/// -/// ``` -/// use arrow::datatypes::DataType; -/// use datafusion_physical_expr::intervals::{Interval, NullableInterval}; -/// use datafusion_common::ScalarValue; -/// -/// // [1, 2) U {NULL} -/// NullableInterval::MaybeNull { -/// values: Interval::make(Some(1), Some(2), (false, true)), -/// }; -/// -/// // (0, ∞) -/// NullableInterval::NotNull { -/// values: Interval::make(Some(0), None, (true, true)), -/// }; -/// -/// // {NULL} -/// NullableInterval::Null { datatype: DataType::Int32 }; -/// -/// // {4} -/// NullableInterval::from(ScalarValue::Int32(Some(4))); -/// ``` -#[derive(Debug, Clone, PartialEq, Eq)] -pub enum NullableInterval { - /// The value is always null in this interval - /// - /// This is typed so it can be used in physical expressions, which don't do - /// type coercion. - Null { datatype: DataType }, - /// The value may or may not be null in this interval. If it is non null its value is within - /// the specified values interval - MaybeNull { values: Interval }, - /// The value is definitely not null in this interval and is within values - NotNull { values: Interval }, -} - -impl Default for NullableInterval { - fn default() -> Self { - NullableInterval::MaybeNull { - values: Interval::default(), - } - } -} - -impl Display for NullableInterval { - fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - match self { - Self::Null { .. } => write!(f, "NullableInterval: {{NULL}}"), - Self::MaybeNull { values } => { - write!(f, "NullableInterval: {} U {{NULL}}", values) - } - Self::NotNull { values } => write!(f, "NullableInterval: {}", values), - } - } -} - -impl From for NullableInterval { - /// Create an interval that represents a single value. - fn from(value: ScalarValue) -> Self { - if value.is_null() { - Self::Null { - datatype: value.data_type(), - } - } else { - Self::NotNull { - values: Interval::new( - IntervalBound::new(value.clone(), false), - IntervalBound::new(value, false), - ), - } - } - } -} - -impl NullableInterval { - /// Get the values interval, or None if this interval is definitely null. - pub fn values(&self) -> Option<&Interval> { - match self { - Self::Null { .. } => None, - Self::MaybeNull { values } | Self::NotNull { values } => Some(values), - } - } - - /// Get the data type - pub fn get_datatype(&self) -> Result { - match self { - Self::Null { datatype } => Ok(datatype.clone()), - Self::MaybeNull { values } | Self::NotNull { values } => { - values.get_datatype() - } - } - } - - /// Return true if the value is definitely true (and not null). - pub fn is_certainly_true(&self) -> bool { - match self { - Self::Null { .. } | Self::MaybeNull { .. } => false, - Self::NotNull { values } => values == &Interval::CERTAINLY_TRUE, - } - } - - /// Return true if the value is definitely false (and not null). - pub fn is_certainly_false(&self) -> bool { - match self { - Self::Null { .. } => false, - Self::MaybeNull { .. } => false, - Self::NotNull { values } => values == &Interval::CERTAINLY_FALSE, - } - } - - /// Perform logical negation on a boolean nullable interval. - fn not(&self) -> Result { - match self { - Self::Null { datatype } => Ok(Self::Null { - datatype: datatype.clone(), - }), - Self::MaybeNull { values } => Ok(Self::MaybeNull { - values: values.not()?, - }), - Self::NotNull { values } => Ok(Self::NotNull { - values: values.not()?, - }), - } - } - - /// Apply the given operator to this interval and the given interval. 
- /// - /// # Examples - /// - /// ``` - /// use datafusion_common::ScalarValue; - /// use datafusion_expr::Operator; - /// use datafusion_physical_expr::intervals::{Interval, NullableInterval}; - /// - /// // 4 > 3 -> true - /// let lhs = NullableInterval::from(ScalarValue::Int32(Some(4))); - /// let rhs = NullableInterval::from(ScalarValue::Int32(Some(3))); - /// let result = lhs.apply_operator(&Operator::Gt, &rhs).unwrap(); - /// assert_eq!(result, NullableInterval::from(ScalarValue::Boolean(Some(true)))); - /// - /// // [1, 3) > NULL -> NULL - /// let lhs = NullableInterval::NotNull { - /// values: Interval::make(Some(1), Some(3), (false, true)), - /// }; - /// let rhs = NullableInterval::from(ScalarValue::Int32(None)); - /// let result = lhs.apply_operator(&Operator::Gt, &rhs).unwrap(); - /// assert_eq!(result.single_value(), Some(ScalarValue::Boolean(None))); - /// - /// // [1, 3] > [2, 4] -> [false, true] - /// let lhs = NullableInterval::NotNull { - /// values: Interval::make(Some(1), Some(3), (false, false)), - /// }; - /// let rhs = NullableInterval::NotNull { - /// values: Interval::make(Some(2), Some(4), (false, false)), - /// }; - /// let result = lhs.apply_operator(&Operator::Gt, &rhs).unwrap(); - /// // Both inputs are valid (non-null), so result must be non-null - /// assert_eq!(result, NullableInterval::NotNull { - /// // Uncertain whether inequality is true or false - /// values: Interval::UNCERTAIN, - /// }); - /// - /// ``` - pub fn apply_operator(&self, op: &Operator, rhs: &Self) -> Result { - match op { - Operator::IsDistinctFrom => { - let values = match (self, rhs) { - // NULL is distinct from NULL -> False - (Self::Null { .. }, Self::Null { .. }) => Interval::CERTAINLY_FALSE, - // x is distinct from y -> x != y, - // if at least one of them is never null. - (Self::NotNull { .. }, _) | (_, Self::NotNull { .. }) => { - let lhs_values = self.values(); - let rhs_values = rhs.values(); - match (lhs_values, rhs_values) { - (Some(lhs_values), Some(rhs_values)) => { - lhs_values.equal(rhs_values).not()? - } - (Some(_), None) | (None, Some(_)) => Interval::CERTAINLY_TRUE, - (None, None) => unreachable!("Null case handled above"), - } - } - _ => Interval::UNCERTAIN, - }; - // IsDistinctFrom never returns null. - Ok(Self::NotNull { values }) - } - Operator::IsNotDistinctFrom => self - .apply_operator(&Operator::IsDistinctFrom, rhs) - .map(|i| i.not())?, - _ => { - if let (Some(left_values), Some(right_values)) = - (self.values(), rhs.values()) - { - let values = apply_operator(op, left_values, right_values)?; - match (self, rhs) { - (Self::NotNull { .. }, Self::NotNull { .. }) => { - Ok(Self::NotNull { values }) - } - _ => Ok(Self::MaybeNull { values }), - } - } else if op.is_comparison_operator() { - Ok(Self::Null { - datatype: DataType::Boolean, - }) - } else { - Ok(Self::Null { - datatype: self.get_datatype()?, - }) - } - } - } - } - - /// Determine if this interval contains the given interval. Returns a boolean - /// interval that is [true, true] if this interval is a superset of the - /// given interval, [false, false] if this interval is disjoint from the - /// given interval, and [false, true] otherwise. - pub fn contains>(&self, other: T) -> Result { - let rhs = other.borrow(); - if let (Some(left_values), Some(right_values)) = (self.values(), rhs.values()) { - let values = left_values.contains(right_values)?; - match (self, rhs) { - (Self::NotNull { .. }, Self::NotNull { .. 
}) => { - Ok(Self::NotNull { values }) - } - _ => Ok(Self::MaybeNull { values }), - } - } else { - Ok(Self::Null { - datatype: DataType::Boolean, - }) - } - } - - /// If the interval has collapsed to a single value, return that value. - /// - /// Otherwise returns None. - /// - /// # Examples - /// - /// ``` - /// use datafusion_common::ScalarValue; - /// use datafusion_physical_expr::intervals::{Interval, NullableInterval}; - /// - /// let interval = NullableInterval::from(ScalarValue::Int32(Some(4))); - /// assert_eq!(interval.single_value(), Some(ScalarValue::Int32(Some(4)))); - /// - /// let interval = NullableInterval::from(ScalarValue::Int32(None)); - /// assert_eq!(interval.single_value(), Some(ScalarValue::Int32(None))); - /// - /// let interval = NullableInterval::MaybeNull { - /// values: Interval::make(Some(1), Some(4), (false, true)), - /// }; - /// assert_eq!(interval.single_value(), None); - /// ``` - pub fn single_value(&self) -> Option { - match self { - Self::Null { datatype } => { - Some(ScalarValue::try_from(datatype).unwrap_or(ScalarValue::Null)) - } - Self::MaybeNull { values } | Self::NotNull { values } - if values.lower.value == values.upper.value - && !values.lower.is_unbounded() => - { - Some(values.lower.value.clone()) - } - _ => None, - } - } -} - -#[cfg(test)] -mod tests { - use super::next_value; - use crate::intervals::{Interval, IntervalBound}; - use arrow_schema::DataType; - use datafusion_common::{Result, ScalarValue}; - - fn open_open(lower: Option, upper: Option) -> Interval - where - ScalarValue: From>, - { - Interval::make(lower, upper, (true, true)) - } - - fn open_closed(lower: Option, upper: Option) -> Interval - where - ScalarValue: From>, - { - Interval::make(lower, upper, (true, false)) - } - - fn closed_open(lower: Option, upper: Option) -> Interval - where - ScalarValue: From>, - { - Interval::make(lower, upper, (false, true)) - } - - fn closed_closed(lower: Option, upper: Option) -> Interval - where - ScalarValue: From>, - { - Interval::make(lower, upper, (false, false)) - } - - #[test] - fn intersect_test() -> Result<()> { - let possible_cases = vec![ - (Some(1000_i64), None, None, None, Some(1000_i64), None), - (None, Some(1000_i64), None, None, None, Some(1000_i64)), - (None, None, Some(1000_i64), None, Some(1000_i64), None), - (None, None, None, Some(1000_i64), None, Some(1000_i64)), - ( - Some(1000_i64), - None, - Some(1000_i64), - None, - Some(1000_i64), - None, - ), - ( - None, - Some(1000_i64), - Some(999_i64), - Some(1002_i64), - Some(999_i64), - Some(1000_i64), - ), - (None, None, None, None, None, None), - ]; - - for case in possible_cases { - assert_eq!( - open_open(case.0, case.1).intersect(open_open(case.2, case.3))?, - Some(open_open(case.4, case.5)) - ) - } - - let empty_cases = vec![ - (None, Some(1000_i64), Some(1001_i64), None), - (Some(1001_i64), None, None, Some(1000_i64)), - (None, Some(1000_i64), Some(1001_i64), Some(1002_i64)), - (Some(1001_i64), Some(1002_i64), None, Some(1000_i64)), - ]; - - for case in empty_cases { - assert_eq!( - open_open(case.0, case.1).intersect(open_open(case.2, case.3))?, - None - ) - } - - Ok(()) - } - - #[test] - fn gt_test() { - let cases = vec![ - (Some(1000_i64), None, None, None, false, true), - (None, Some(1000_i64), None, None, false, true), - (None, None, Some(1000_i64), None, false, true), - (None, None, None, Some(1000_i64), false, true), - (None, Some(1000_i64), Some(1000_i64), None, false, false), - (None, Some(1000_i64), Some(1001_i64), None, false, false), - (Some(1000_i64), 
None, Some(1000_i64), None, false, true), - ( - None, - Some(1000_i64), - Some(1001_i64), - Some(1002_i64), - false, - false, - ), - ( - None, - Some(1000_i64), - Some(999_i64), - Some(1002_i64), - false, - true, - ), - ( - Some(1002_i64), - None, - Some(999_i64), - Some(1002_i64), - true, - true, - ), - ( - Some(1003_i64), - None, - Some(999_i64), - Some(1002_i64), - true, - true, - ), - (None, None, None, None, false, true), - ]; - - for case in cases { - assert_eq!( - open_open(case.0, case.1).gt(open_open(case.2, case.3)), - closed_closed(Some(case.4), Some(case.5)) - ); - } - } - - #[test] - fn lt_test() { - let cases = vec![ - (Some(1000_i64), None, None, None, false, true), - (None, Some(1000_i64), None, None, false, true), - (None, None, Some(1000_i64), None, false, true), - (None, None, None, Some(1000_i64), false, true), - (None, Some(1000_i64), Some(1000_i64), None, true, true), - (None, Some(1000_i64), Some(1001_i64), None, true, true), - (Some(1000_i64), None, Some(1000_i64), None, false, true), - ( - None, - Some(1000_i64), - Some(1001_i64), - Some(1002_i64), - true, - true, - ), - ( - None, - Some(1000_i64), - Some(999_i64), - Some(1002_i64), - false, - true, - ), - (None, None, None, None, false, true), - ]; - - for case in cases { - assert_eq!( - open_open(case.0, case.1).lt(open_open(case.2, case.3)), - closed_closed(Some(case.4), Some(case.5)) - ); - } - } - - #[test] - fn and_test() -> Result<()> { - let cases = vec![ - (false, true, false, false, false, false), - (false, false, false, true, false, false), - (false, true, false, true, false, true), - (false, true, true, true, false, true), - (false, false, false, false, false, false), - (true, true, true, true, true, true), - ]; - - for case in cases { - assert_eq!( - open_open(Some(case.0), Some(case.1)) - .and(open_open(Some(case.2), Some(case.3)))?, - open_open(Some(case.4), Some(case.5)) - ); - } - Ok(()) - } - - #[test] - fn add_test() -> Result<()> { - let cases = vec![ - (Some(1000_i64), None, None, None, None, None), - (None, Some(1000_i64), None, None, None, None), - (None, None, Some(1000_i64), None, None, None), - (None, None, None, Some(1000_i64), None, None), - ( - Some(1000_i64), - None, - Some(1000_i64), - None, - Some(2000_i64), - None, - ), - ( - None, - Some(1000_i64), - Some(999_i64), - Some(1002_i64), - None, - Some(2002_i64), - ), - (None, Some(1000_i64), Some(1000_i64), None, None, None), - ( - Some(2001_i64), - Some(1_i64), - Some(1005_i64), - Some(-999_i64), - Some(3006_i64), - Some(-998_i64), - ), - (None, None, None, None, None, None), - ]; - - for case in cases { - assert_eq!( - open_open(case.0, case.1).add(open_open(case.2, case.3))?, - open_open(case.4, case.5) - ); - } - Ok(()) - } - - #[test] - fn sub_test() -> Result<()> { - let cases = vec![ - (Some(1000_i64), None, None, None, None, None), - (None, Some(1000_i64), None, None, None, None), - (None, None, Some(1000_i64), None, None, None), - (None, None, None, Some(1000_i64), None, None), - (Some(1000_i64), None, Some(1000_i64), None, None, None), - ( - None, - Some(1000_i64), - Some(999_i64), - Some(1002_i64), - None, - Some(1_i64), - ), - ( - None, - Some(1000_i64), - Some(1000_i64), - None, - None, - Some(0_i64), - ), - ( - Some(2001_i64), - Some(1000_i64), - Some(1005), - Some(999_i64), - Some(1002_i64), - Some(-5_i64), - ), - (None, None, None, None, None, None), - ]; - - for case in cases { - assert_eq!( - open_open(case.0, case.1).sub(open_open(case.2, case.3))?, - open_open(case.4, case.5) - ); - } - Ok(()) - } - - #[test] - fn 
sub_test_various_bounds() -> Result<()> { - let cases = vec![ - ( - closed_closed(Some(100_i64), Some(200_i64)), - closed_open(Some(200_i64), None), - open_closed(None, Some(0_i64)), - ), - ( - closed_open(Some(100_i64), Some(200_i64)), - open_closed(Some(300_i64), Some(150_i64)), - closed_open(Some(-50_i64), Some(-100_i64)), - ), - ( - closed_open(Some(100_i64), Some(200_i64)), - open_open(Some(200_i64), None), - open_open(None, Some(0_i64)), - ), - ( - closed_closed(Some(1_i64), Some(1_i64)), - closed_closed(Some(11_i64), Some(11_i64)), - closed_closed(Some(-10_i64), Some(-10_i64)), - ), - ]; - for case in cases { - assert_eq!(case.0.sub(case.1)?, case.2) - } - Ok(()) - } - - #[test] - fn add_test_various_bounds() -> Result<()> { - let cases = vec![ - ( - closed_closed(Some(100_i64), Some(200_i64)), - open_closed(None, Some(200_i64)), - open_closed(None, Some(400_i64)), - ), - ( - closed_open(Some(100_i64), Some(200_i64)), - closed_open(Some(-300_i64), Some(150_i64)), - closed_open(Some(-200_i64), Some(350_i64)), - ), - ( - closed_open(Some(100_i64), Some(200_i64)), - open_open(Some(200_i64), None), - open_open(Some(300_i64), None), - ), - ( - closed_closed(Some(1_i64), Some(1_i64)), - closed_closed(Some(11_i64), Some(11_i64)), - closed_closed(Some(12_i64), Some(12_i64)), - ), - ]; - for case in cases { - assert_eq!(case.0.add(case.1)?, case.2) - } - Ok(()) - } - - #[test] - fn lt_test_various_bounds() -> Result<()> { - let cases = vec![ - ( - closed_closed(Some(100_i64), Some(200_i64)), - open_closed(None, Some(100_i64)), - closed_closed(Some(false), Some(false)), - ), - ( - closed_closed(Some(100_i64), Some(200_i64)), - open_open(None, Some(100_i64)), - closed_closed(Some(false), Some(false)), - ), - ( - open_open(Some(100_i64), Some(200_i64)), - closed_closed(Some(0_i64), Some(100_i64)), - closed_closed(Some(false), Some(false)), - ), - ( - closed_closed(Some(2_i64), Some(2_i64)), - closed_closed(Some(1_i64), Some(2_i64)), - closed_closed(Some(false), Some(false)), - ), - ( - closed_closed(Some(2_i64), Some(2_i64)), - closed_open(Some(1_i64), Some(2_i64)), - closed_closed(Some(false), Some(false)), - ), - ( - closed_closed(Some(1_i64), Some(1_i64)), - open_open(Some(1_i64), Some(2_i64)), - closed_closed(Some(true), Some(true)), - ), - ]; - for case in cases { - assert_eq!(case.0.lt(case.1), case.2) - } - Ok(()) - } - - #[test] - fn gt_test_various_bounds() -> Result<()> { - let cases = vec![ - ( - closed_closed(Some(100_i64), Some(200_i64)), - open_closed(None, Some(100_i64)), - closed_closed(Some(false), Some(true)), - ), - ( - closed_closed(Some(100_i64), Some(200_i64)), - open_open(None, Some(100_i64)), - closed_closed(Some(true), Some(true)), - ), - ( - open_open(Some(100_i64), Some(200_i64)), - closed_closed(Some(0_i64), Some(100_i64)), - closed_closed(Some(true), Some(true)), - ), - ( - closed_closed(Some(2_i64), Some(2_i64)), - closed_closed(Some(1_i64), Some(2_i64)), - closed_closed(Some(false), Some(true)), - ), - ( - closed_closed(Some(2_i64), Some(2_i64)), - closed_open(Some(1_i64), Some(2_i64)), - closed_closed(Some(true), Some(true)), - ), - ( - closed_closed(Some(1_i64), Some(1_i64)), - open_open(Some(1_i64), Some(2_i64)), - closed_closed(Some(false), Some(false)), - ), - ]; - for case in cases { - assert_eq!(case.0.gt(case.1), case.2) - } - Ok(()) - } - - #[test] - fn lt_eq_test_various_bounds() -> Result<()> { - let cases = vec![ - ( - closed_closed(Some(100_i64), Some(200_i64)), - open_closed(None, Some(100_i64)), - closed_closed(Some(false), Some(true)), - ), - ( - 
closed_closed(Some(100_i64), Some(200_i64)), - open_open(None, Some(100_i64)), - closed_closed(Some(false), Some(false)), - ), - ( - closed_closed(Some(2_i64), Some(2_i64)), - closed_closed(Some(1_i64), Some(2_i64)), - closed_closed(Some(false), Some(true)), - ), - ( - closed_closed(Some(2_i64), Some(2_i64)), - closed_open(Some(1_i64), Some(2_i64)), - closed_closed(Some(false), Some(false)), - ), - ( - closed_closed(Some(1_i64), Some(1_i64)), - closed_open(Some(1_i64), Some(2_i64)), - closed_closed(Some(true), Some(true)), - ), - ( - closed_closed(Some(1_i64), Some(1_i64)), - open_open(Some(1_i64), Some(2_i64)), - closed_closed(Some(true), Some(true)), - ), - ]; - for case in cases { - assert_eq!(case.0.lt_eq(case.1), case.2) - } - Ok(()) - } - - #[test] - fn gt_eq_test_various_bounds() -> Result<()> { - let cases = vec![ - ( - closed_closed(Some(100_i64), Some(200_i64)), - open_closed(None, Some(100_i64)), - closed_closed(Some(true), Some(true)), - ), - ( - closed_closed(Some(100_i64), Some(200_i64)), - open_open(None, Some(100_i64)), - closed_closed(Some(true), Some(true)), - ), - ( - closed_closed(Some(2_i64), Some(2_i64)), - closed_closed(Some(1_i64), Some(2_i64)), - closed_closed(Some(true), Some(true)), - ), - ( - closed_closed(Some(2_i64), Some(2_i64)), - closed_open(Some(1_i64), Some(2_i64)), - closed_closed(Some(true), Some(true)), - ), - ( - closed_closed(Some(1_i64), Some(1_i64)), - closed_open(Some(1_i64), Some(2_i64)), - closed_closed(Some(false), Some(true)), - ), - ( - closed_closed(Some(1_i64), Some(1_i64)), - open_open(Some(1_i64), Some(2_i64)), - closed_closed(Some(false), Some(false)), - ), - ]; - for case in cases { - assert_eq!(case.0.gt_eq(case.1), case.2) - } - Ok(()) - } - - #[test] - fn intersect_test_various_bounds() -> Result<()> { - let cases = vec![ - ( - closed_closed(Some(100_i64), Some(200_i64)), - open_closed(None, Some(100_i64)), - Some(closed_closed(Some(100_i64), Some(100_i64))), - ), - ( - closed_closed(Some(100_i64), Some(200_i64)), - open_open(None, Some(100_i64)), - None, - ), - ( - open_open(Some(100_i64), Some(200_i64)), - closed_closed(Some(0_i64), Some(100_i64)), - None, - ), - ( - closed_closed(Some(2_i64), Some(2_i64)), - closed_closed(Some(1_i64), Some(2_i64)), - Some(closed_closed(Some(2_i64), Some(2_i64))), - ), - ( - closed_closed(Some(2_i64), Some(2_i64)), - closed_open(Some(1_i64), Some(2_i64)), - None, - ), - ( - closed_closed(Some(1_i64), Some(1_i64)), - open_open(Some(1_i64), Some(2_i64)), - None, - ), - ( - closed_closed(Some(1_i64), Some(3_i64)), - open_open(Some(1_i64), Some(2_i64)), - Some(open_open(Some(1_i64), Some(2_i64))), - ), - ]; - for case in cases { - assert_eq!(case.0.intersect(case.1)?, case.2) - } - Ok(()) - } - - // This function tests if valid constructions produce standardized objects - // ([false, false], [false, true], [true, true]) for boolean intervals. 
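// Illustrative sketch of the boolean-interval convention the comparison tests
// above encode, and of the standardized results ([false, false], [false, true],
// [true, true]) mentioned in the comment: a comparison yields a pair
// (lower, upper) where `lower` means "certainly true" and `upper` means
// "possibly true". Closed i64 bounds stand in for the real `Interval` type.
fn gt_bool_interval(lhs: (i64, i64), rhs: (i64, i64)) -> (bool, bool) {
    let certainly_true = lhs.0 > rhs.1; // every lhs value exceeds every rhs value
    let possibly_true = lhs.1 > rhs.0;  // some lhs value exceeds some rhs value
    (certainly_true, possibly_true)
}

fn main() {
    // [2, 2] > [1, 2] may be false (2 > 2) or true (2 > 1): [false, true].
    assert_eq!(gt_bool_interval((2, 2), (1, 2)), (false, true));
    // [1003, 1003] > [999, 1002] always holds: [true, true].
    assert_eq!(gt_bool_interval((1003, 1003), (999, 1002)), (true, true));
    // [0, 100] > [200, 300] never holds: [false, false].
    assert_eq!(gt_bool_interval((0, 100), (200, 300)), (false, false));
}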
- #[test] - fn non_standard_interval_constructs() { - use ScalarValue::Boolean; - let cases = vec![ - ( - IntervalBound::new(Boolean(None), true), - IntervalBound::new(Boolean(Some(true)), false), - closed_closed(Some(false), Some(true)), - ), - ( - IntervalBound::new(Boolean(None), true), - IntervalBound::new(Boolean(Some(true)), true), - closed_closed(Some(false), Some(false)), - ), - ( - IntervalBound::new(Boolean(Some(false)), false), - IntervalBound::new(Boolean(None), true), - closed_closed(Some(false), Some(true)), - ), - ( - IntervalBound::new(Boolean(Some(true)), false), - IntervalBound::new(Boolean(None), true), - closed_closed(Some(true), Some(true)), - ), - ( - IntervalBound::new(Boolean(None), true), - IntervalBound::new(Boolean(None), true), - closed_closed(Some(false), Some(true)), - ), - ( - IntervalBound::new(Boolean(Some(false)), true), - IntervalBound::new(Boolean(None), true), - closed_closed(Some(true), Some(true)), - ), - ]; - - for case in cases { - assert_eq!(Interval::new(case.0, case.1), case.2) - } - } - - macro_rules! capture_mode_change { - ($TYPE:ty) => { - paste::item! { - capture_mode_change_helper!([], - [], - $TYPE); - } - }; - } - - macro_rules! capture_mode_change_helper { - ($TEST_FN_NAME:ident, $CREATE_FN_NAME:ident, $TYPE:ty) => { - fn $CREATE_FN_NAME(lower: $TYPE, upper: $TYPE) -> Interval { - Interval::make(Some(lower as $TYPE), Some(upper as $TYPE), (true, true)) - } - - fn $TEST_FN_NAME(input: ($TYPE, $TYPE), expect_low: bool, expect_high: bool) { - assert!(expect_low || expect_high); - let interval1 = $CREATE_FN_NAME(input.0, input.0); - let interval2 = $CREATE_FN_NAME(input.1, input.1); - let result = interval1.add(&interval2).unwrap(); - let without_fe = $CREATE_FN_NAME(input.0 + input.1, input.0 + input.1); - assert!( - (!expect_low || result.lower.value < without_fe.lower.value) - && (!expect_high || result.upper.value > without_fe.upper.value) - ); - } - }; - } - - capture_mode_change!(f32); - capture_mode_change!(f64); - - #[cfg(all( - any(target_arch = "x86_64", target_arch = "aarch64"), - not(target_os = "windows") - ))] - #[test] - fn test_add_intervals_lower_affected_f32() { - // Lower is affected - let lower = f32::from_bits(1073741887); //1000000000000000000000000111111 - let upper = f32::from_bits(1098907651); //1000001100000000000000000000011 - capture_mode_change_f32((lower, upper), true, false); - - // Upper is affected - let lower = f32::from_bits(1072693248); //111111111100000000000000000000 - let upper = f32::from_bits(715827883); //101010101010101010101010101011 - capture_mode_change_f32((lower, upper), false, true); - - // Lower is affected - let lower = 1.0; // 0x3FF0000000000000 - let upper = 0.3; // 0x3FD3333333333333 - capture_mode_change_f64((lower, upper), true, false); - - // Upper is affected - let lower = 1.4999999999999998; // 0x3FF7FFFFFFFFFFFF - let upper = 0.000_000_000_000_000_022_044_604_925_031_31; // 0x3C796A6B413BB21F - capture_mode_change_f64((lower, upper), false, true); - } - - #[cfg(any( - not(any(target_arch = "x86_64", target_arch = "aarch64")), - target_os = "windows" - ))] - #[test] - fn test_next_impl_add_intervals_f64() { - let lower = 1.5; - let upper = 1.5; - capture_mode_change_f64((lower, upper), true, true); - - let lower = 1.5; - let upper = 1.5; - capture_mode_change_f32((lower, upper), true, true); - } - - #[test] - fn test_cardinality_of_intervals() -> Result<()> { - // In IEEE 754 standard for floating-point arithmetic, if we keep the sign and exponent fields same, - // we can represent 
4503599627370496 different numbers by changing the mantissa - // (4503599627370496 = 2^52, since there are 52 bits in mantissa, and 2^23 = 8388608 for f32). - let distinct_f64 = 4503599627370496; - let distinct_f32 = 8388608; - let intervals = [ - Interval::new( - IntervalBound::new(ScalarValue::from(0.25), false), - IntervalBound::new(ScalarValue::from(0.50), true), - ), - Interval::new( - IntervalBound::new(ScalarValue::from(0.5), false), - IntervalBound::new(ScalarValue::from(1.0), true), - ), - Interval::new( - IntervalBound::new(ScalarValue::from(1.0), false), - IntervalBound::new(ScalarValue::from(2.0), true), - ), - Interval::new( - IntervalBound::new(ScalarValue::from(32.0), false), - IntervalBound::new(ScalarValue::from(64.0), true), - ), - Interval::new( - IntervalBound::new(ScalarValue::from(-0.50), false), - IntervalBound::new(ScalarValue::from(-0.25), true), - ), - Interval::new( - IntervalBound::new(ScalarValue::from(-32.0), false), - IntervalBound::new(ScalarValue::from(-16.0), true), - ), - ]; - for interval in intervals { - assert_eq!(interval.cardinality()?.unwrap(), distinct_f64); - } - - let intervals = [ - Interval::new( - IntervalBound::new(ScalarValue::from(0.25_f32), false), - IntervalBound::new(ScalarValue::from(0.50_f32), true), - ), - Interval::new( - IntervalBound::new(ScalarValue::from(-1_f32), false), - IntervalBound::new(ScalarValue::from(-0.5_f32), true), - ), - ]; - for interval in intervals { - assert_eq!(interval.cardinality()?.unwrap(), distinct_f32); - } - - // The regular logarithmic distribution of floating-point numbers are - // only applicable outside of the `(-phi, phi)` interval where `phi` - // denotes the largest positive subnormal floating-point number. Since - // the following intervals include such subnormal points, we cannot use - // a simple powers-of-two type formula for our expectations. Therefore, - // we manually supply the actual expected cardinality. 
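// Quick check of the claim in the comment above: within a single binade (fixed
// sign and exponent), f32 has exactly 2^23 distinct mantissa patterns. For
// finite positive floats, counting the representable values in [lo, hi) is a
// subtraction of IEEE-754 bit patterns; this is illustration only, not the
// `Interval::cardinality` implementation.
fn f32_values_between(lo: f32, hi: f32) -> u32 {
    hi.to_bits() - lo.to_bits()
}

fn main() {
    // [0.25, 0.5) spans exactly one f32 binade.
    assert_eq!(f32_values_between(0.25, 0.5), 1 << 23);
    assert_eq!(f32_values_between(0.25, 0.5), 8_388_608);
}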
- let interval = Interval::new( - IntervalBound::new(ScalarValue::from(-0.0625), false), - IntervalBound::new(ScalarValue::from(0.0625), true), - ); - assert_eq!(interval.cardinality()?.unwrap(), 9178336040581070849); - - let interval = Interval::new( - IntervalBound::new(ScalarValue::from(-0.0625_f32), false), - IntervalBound::new(ScalarValue::from(0.0625_f32), true), - ); - assert_eq!(interval.cardinality()?.unwrap(), 2063597569); - - Ok(()) - } - - #[test] - fn test_next_value() -> Result<()> { - // integer increment / decrement - let zeros = vec![ - ScalarValue::new_zero(&DataType::UInt8)?, - ScalarValue::new_zero(&DataType::UInt16)?, - ScalarValue::new_zero(&DataType::UInt32)?, - ScalarValue::new_zero(&DataType::UInt64)?, - ScalarValue::new_zero(&DataType::Int8)?, - ScalarValue::new_zero(&DataType::Int8)?, - ScalarValue::new_zero(&DataType::Int8)?, - ScalarValue::new_zero(&DataType::Int8)?, - ]; - - let ones = vec![ - ScalarValue::new_one(&DataType::UInt8)?, - ScalarValue::new_one(&DataType::UInt16)?, - ScalarValue::new_one(&DataType::UInt32)?, - ScalarValue::new_one(&DataType::UInt64)?, - ScalarValue::new_one(&DataType::Int8)?, - ScalarValue::new_one(&DataType::Int8)?, - ScalarValue::new_one(&DataType::Int8)?, - ScalarValue::new_one(&DataType::Int8)?, - ]; - - zeros.into_iter().zip(ones).for_each(|(z, o)| { - assert_eq!(next_value::(z.clone()), o); - assert_eq!(next_value::(o), z); - }); - - // floating value increment / decrement - let values = vec![ - ScalarValue::new_zero(&DataType::Float32)?, - ScalarValue::new_zero(&DataType::Float64)?, - ]; - - let eps = vec![ - ScalarValue::Float32(Some(1e-6)), - ScalarValue::Float64(Some(1e-6)), - ]; - - values.into_iter().zip(eps).for_each(|(v, e)| { - assert!(next_value::(v.clone()).sub(v.clone()).unwrap().lt(&e)); - assert!(v.clone().sub(next_value::(v)).unwrap().lt(&e)); - }); - - // Min / Max values do not change for integer values - let min = vec![ - ScalarValue::UInt64(Some(u64::MIN)), - ScalarValue::Int8(Some(i8::MIN)), - ]; - let max = vec![ - ScalarValue::UInt64(Some(u64::MAX)), - ScalarValue::Int8(Some(i8::MAX)), - ]; - - min.into_iter().zip(max).for_each(|(min, max)| { - assert_eq!(next_value::(max.clone()), max); - assert_eq!(next_value::(min.clone()), min); - }); - - // Min / Max values results in infinity for floating point values - assert_eq!( - next_value::(ScalarValue::Float32(Some(f32::MAX))), - ScalarValue::Float32(Some(f32::INFINITY)) - ); - assert_eq!( - next_value::(ScalarValue::Float64(Some(f64::MIN))), - ScalarValue::Float64(Some(f64::NEG_INFINITY)) - ); - - Ok(()) - } - - #[test] - fn test_interval_display() { - let interval = Interval::new( - IntervalBound::new(ScalarValue::from(0.25_f32), true), - IntervalBound::new(ScalarValue::from(0.50_f32), false), - ); - assert_eq!(format!("{}", interval), "(0.25, 0.5]"); - - let interval = Interval::new( - IntervalBound::new(ScalarValue::from(0.25_f32), false), - IntervalBound::new(ScalarValue::from(0.50_f32), true), - ); - assert_eq!(format!("{}", interval), "[0.25, 0.5)"); - - let interval = Interval::new( - IntervalBound::new(ScalarValue::from(0.25_f32), true), - IntervalBound::new(ScalarValue::from(0.50_f32), true), - ); - assert_eq!(format!("{}", interval), "(0.25, 0.5)"); - - let interval = Interval::new( - IntervalBound::new(ScalarValue::from(0.25_f32), false), - IntervalBound::new(ScalarValue::from(0.50_f32), false), - ); - assert_eq!(format!("{}", interval), "[0.25, 0.5]"); - } -} diff --git a/datafusion/physical-expr/src/intervals/mod.rs 
b/datafusion/physical-expr/src/intervals/mod.rs index b89d1c59dc64e..9752ca27b5a38 100644 --- a/datafusion/physical-expr/src/intervals/mod.rs +++ b/datafusion/physical-expr/src/intervals/mod.rs @@ -18,10 +18,5 @@ //! Interval arithmetic and constraint propagation library pub mod cp_solver; -pub mod interval_aritmetic; -pub mod rounding; pub mod test_utils; pub mod utils; - -pub use cp_solver::ExprIntervalGraph; -pub use interval_aritmetic::*; diff --git a/datafusion/physical-expr/src/intervals/utils.rs b/datafusion/physical-expr/src/intervals/utils.rs index 7a4ccff950e6f..03d13632104dd 100644 --- a/datafusion/physical-expr/src/intervals/utils.rs +++ b/datafusion/physical-expr/src/intervals/utils.rs @@ -19,14 +19,16 @@ use std::sync::Arc; -use super::{Interval, IntervalBound}; use crate::{ expressions::{BinaryExpr, CastExpr, Column, Literal, NegativeExpr}, PhysicalExpr, }; use arrow_schema::{DataType, SchemaRef}; -use datafusion_common::{DataFusionError, Result, ScalarValue}; +use datafusion_common::{ + internal_datafusion_err, internal_err, DataFusionError, Result, ScalarValue, +}; +use datafusion_expr::interval_arithmetic::Interval; use datafusion_expr::Operator; const MDN_DAY_MASK: i128 = 0xFFFF_FFFF_0000_0000_0000_0000; @@ -66,11 +68,13 @@ pub fn check_support(expr: &Arc, schema: &SchemaRef) -> bool { } // This function returns the inverse operator of the given operator. -pub fn get_inverse_op(op: Operator) -> Operator { +pub fn get_inverse_op(op: Operator) -> Result { match op { - Operator::Plus => Operator::Minus, - Operator::Minus => Operator::Plus, - _ => unreachable!(), + Operator::Plus => Ok(Operator::Minus), + Operator::Minus => Ok(Operator::Plus), + Operator::Multiply => Ok(Operator::Divide), + Operator::Divide => Ok(Operator::Multiply), + _ => internal_err!("Interval arithmetic does not support the operator {}", op), } } @@ -86,6 +90,8 @@ pub fn is_operator_supported(op: &Operator) -> bool { | &Operator::Lt | &Operator::LtEq | &Operator::Eq + | &Operator::Multiply + | &Operator::Divide ) } @@ -109,36 +115,26 @@ pub fn is_datatype_supported(data_type: &DataType) -> bool { /// Converts an [`Interval`] of time intervals to one of `Duration`s, if applicable. Otherwise, returns [`None`]. pub fn convert_interval_type_to_duration(interval: &Interval) -> Option { if let (Some(lower), Some(upper)) = ( - convert_interval_bound_to_duration(&interval.lower), - convert_interval_bound_to_duration(&interval.upper), + convert_interval_bound_to_duration(interval.lower()), + convert_interval_bound_to_duration(interval.upper()), ) { - Some(Interval::new(lower, upper)) + Interval::try_new(lower, upper).ok() } else { None } } -/// Converts an [`IntervalBound`] containing a time interval to one containing a `Duration`, if applicable. Otherwise, returns [`None`]. +/// Converts an [`ScalarValue`] containing a time interval to one containing a `Duration`, if applicable. Otherwise, returns [`None`]. 
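// Sketch of the convertibility rule used by `convert_interval_bound_to_duration`
// and `interval_mdn_to_duration_ns` below: a MonthDayNano interval only maps to
// a plain nanosecond duration when its month and day components are zero, since
// months and days have no fixed nanosecond length. The unpacked (months, days,
// nanos) triple here stands in for Arrow's packed i128 representation.
fn mdn_to_duration_ns(months: i32, days: i32, nanos: i64) -> Option<i64> {
    (months == 0 && days == 0).then_some(nanos)
}

fn main() {
    assert_eq!(mdn_to_duration_ns(0, 0, 1_500), Some(1_500));
    // A one-day interval is not convertible to a fixed duration.
    assert_eq!(mdn_to_duration_ns(0, 1, 0), None);
}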
fn convert_interval_bound_to_duration( - interval_bound: &IntervalBound, -) -> Option { - match interval_bound.value { - ScalarValue::IntervalMonthDayNano(Some(mdn)) => { - interval_mdn_to_duration_ns(&mdn).ok().map(|duration| { - IntervalBound::new( - ScalarValue::DurationNanosecond(Some(duration)), - interval_bound.open, - ) - }) - } - ScalarValue::IntervalDayTime(Some(dt)) => { - interval_dt_to_duration_ms(&dt).ok().map(|duration| { - IntervalBound::new( - ScalarValue::DurationMillisecond(Some(duration)), - interval_bound.open, - ) - }) - } + interval_bound: &ScalarValue, +) -> Option { + match interval_bound { + ScalarValue::IntervalMonthDayNano(Some(mdn)) => interval_mdn_to_duration_ns(mdn) + .ok() + .map(|duration| ScalarValue::DurationNanosecond(Some(duration))), + ScalarValue::IntervalDayTime(Some(dt)) => interval_dt_to_duration_ms(dt) + .ok() + .map(|duration| ScalarValue::DurationMillisecond(Some(duration))), _ => None, } } @@ -146,28 +142,32 @@ fn convert_interval_bound_to_duration( /// Converts an [`Interval`] of `Duration`s to one of time intervals, if applicable. Otherwise, returns [`None`]. pub fn convert_duration_type_to_interval(interval: &Interval) -> Option { if let (Some(lower), Some(upper)) = ( - convert_duration_bound_to_interval(&interval.lower), - convert_duration_bound_to_interval(&interval.upper), + convert_duration_bound_to_interval(interval.lower()), + convert_duration_bound_to_interval(interval.upper()), ) { - Some(Interval::new(lower, upper)) + Interval::try_new(lower, upper).ok() } else { None } } -/// Converts an [`IntervalBound`] containing a `Duration` to one containing a time interval, if applicable. Otherwise, returns [`None`]. +/// Converts a [`ScalarValue`] containing a `Duration` to one containing a time interval, if applicable. Otherwise, returns [`None`]. 
fn convert_duration_bound_to_interval( - interval_bound: &IntervalBound, -) -> Option { - match interval_bound.value { - ScalarValue::DurationNanosecond(Some(duration)) => Some(IntervalBound::new( - ScalarValue::new_interval_mdn(0, 0, duration), - interval_bound.open, - )), - ScalarValue::DurationMillisecond(Some(duration)) => Some(IntervalBound::new( - ScalarValue::new_interval_dt(0, duration as i32), - interval_bound.open, - )), + interval_bound: &ScalarValue, +) -> Option { + match interval_bound { + ScalarValue::DurationNanosecond(Some(duration)) => { + Some(ScalarValue::new_interval_mdn(0, 0, *duration)) + } + ScalarValue::DurationMicrosecond(Some(duration)) => { + Some(ScalarValue::new_interval_mdn(0, 0, *duration * 1000)) + } + ScalarValue::DurationMillisecond(Some(duration)) => { + Some(ScalarValue::new_interval_dt(0, *duration as i32)) + } + ScalarValue::DurationSecond(Some(duration)) => { + Some(ScalarValue::new_interval_dt(0, *duration as i32 * 1000)) + } _ => None, } } @@ -180,14 +180,13 @@ fn interval_mdn_to_duration_ns(mdn: &i128) -> Result { let nanoseconds = mdn & MDN_NS_MASK; if months == 0 && days == 0 { - nanoseconds.try_into().map_err(|_| { - DataFusionError::Internal("Resulting duration exceeds i64::MAX".to_string()) - }) + nanoseconds + .try_into() + .map_err(|_| internal_datafusion_err!("Resulting duration exceeds i64::MAX")) } else { - Err(DataFusionError::Internal( + internal_err!( "The interval cannot have a non-zero month or day value for duration convertibility" - .to_string(), - )) + ) } } @@ -200,9 +199,8 @@ fn interval_dt_to_duration_ms(dt: &i64) -> Result { if days == 0 { Ok(milliseconds) } else { - Err(DataFusionError::Internal( + internal_err!( "The interval cannot have a non-zero day value for duration convertibility" - .to_string(), - )) + ) } } diff --git a/datafusion/physical-expr/src/math_expressions.rs b/datafusion/physical-expr/src/math_expressions.rs index 96f611e2b7b49..af66862aecc5a 100644 --- a/datafusion/physical-expr/src/math_expressions.rs +++ b/datafusion/physical-expr/src/math_expressions.rs @@ -743,6 +743,18 @@ pub(super) fn create_abs_function( } } +/// abs() SQL function implementation +pub fn abs_invoke(args: &[ArrayRef]) -> Result { + if args.len() != 1 { + return internal_err!("abs function requires 1 argument, got {}", args.len()); + } + + let input_data_type = args[0].data_type(); + let abs_fun = create_abs_function(input_data_type)?; + + abs_fun(args) +} + #[cfg(test)] mod tests { @@ -757,7 +769,8 @@ mod tests { let args = vec![ColumnarValue::Array(Arc::new(NullArray::new(1)))]; let array = random(&args) .expect("failed to initialize function random") - .into_array(1); + .into_array(1) + .expect("Failed to convert to array"); let floats = as_float64_array(&array).expect("failed to initialize function random"); diff --git a/datafusion/physical-expr/src/partitioning.rs b/datafusion/physical-expr/src/partitioning.rs index cbacb7a8a906b..301f12e9aa2ea 100644 --- a/datafusion/physical-expr/src/partitioning.rs +++ b/datafusion/physical-expr/src/partitioning.rs @@ -26,7 +26,7 @@ use crate::{physical_exprs_equal, EquivalenceProperties, PhysicalExpr}; /// /// When `executed`, `ExecutionPlan`s produce one or more independent stream of /// data batches in parallel, referred to as partitions. The streams are Rust -/// `aync` [`Stream`]s (a special kind of future). The number of output +/// `async` [`Stream`]s (a special kind of future). The number of output /// partitions varies based on the input and the operation performed. 
/// /// For example, an `ExecutionPlan` that has output partitioning of 3 will @@ -64,7 +64,7 @@ use crate::{physical_exprs_equal, EquivalenceProperties, PhysicalExpr}; /// ``` /// /// It is common (but not required) that an `ExecutionPlan` has the same number -/// of input partitions as output partitons. However, some plans have different +/// of input partitions as output partitions. However, some plans have different /// numbers such as the `RepartitionExec` that redistributes batches from some /// number of inputs to some number of outputs /// diff --git a/datafusion/physical-expr/src/physical_expr.rs b/datafusion/physical-expr/src/physical_expr.rs index 79cbe6828b64b..a8d1e3638a177 100644 --- a/datafusion/physical-expr/src/physical_expr.rs +++ b/datafusion/physical-expr/src/physical_expr.rs @@ -20,7 +20,6 @@ use std::fmt::{Debug, Display}; use std::hash::{Hash, Hasher}; use std::sync::Arc; -use crate::intervals::Interval; use crate::sort_properties::SortProperties; use crate::utils::scatter; @@ -30,6 +29,7 @@ use arrow::datatypes::{DataType, Schema}; use arrow::record_batch::RecordBatch; use datafusion_common::utils::DataPtr; use datafusion_common::{internal_err, not_impl_err, DataFusionError, Result}; +use datafusion_expr::interval_arithmetic::Interval; use datafusion_expr::ColumnarValue; use itertools::izip; @@ -95,36 +95,34 @@ pub trait PhysicalExpr: Send + Sync + Display + Debug + PartialEq { /// Updates bounds for child expressions, given a known interval for this /// expression. /// - /// This is used to propagate constraints down through an - /// expression tree. + /// This is used to propagate constraints down through an expression tree. /// /// # Arguments /// /// * `interval` is the currently known interval for this expression. - /// * `children` are the current intervals for the children of this expression + /// * `children` are the current intervals for the children of this expression. /// /// # Returns /// - /// A Vec of new intervals for the children, in order. + /// A `Vec` of new intervals for the children, in order. /// - /// If constraint propagation reveals an infeasibility, returns [None] for - /// the child causing infeasibility. - /// - /// If none of the child intervals change as a result of propagation, may - /// return an empty vector instead of cloning `children`. + /// If constraint propagation reveals an infeasibility for any child, returns + /// [`None`]. If none of the children intervals change as a result of propagation, + /// may return an empty vector instead of cloning `children`. This is the default + /// (and conservative) return value. /// /// # Example /// - /// If the expression is `a + b`, the current `interval` is `[4, 5] and the - /// inputs are given [`a: [0, 2], `b: [-∞, 4]]`, then propagation would - /// would return `[a: [0, 2], b: [2, 4]]` as `b` must be at least 2 to - /// make the output at least `4`. + /// If the expression is `a + b`, the current `interval` is `[4, 5]` and the + /// inputs `a` and `b` are respectively given as `[0, 2]` and `[-∞, 4]`, then + /// propagation would would return `[0, 2]` and `[2, 4]` as `b` must be at + /// least `2` to make the output at least `4`. 
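// Standalone version of the `a + b` example in the doc comment above: with the
// output known to be in [4, 5], `a` in [0, 2] and `b` in [-inf, 4], the rhs can
// be tightened to [4 - 2, 5 - 0] intersected with [-inf, 4] = [2, 4]. Plain i64
// pairs (with i64::MIN standing in for -inf) replace the real `Interval` type.
fn shrink_rhs(output: (i64, i64), lhs: (i64, i64), rhs: (i64, i64)) -> (i64, i64) {
    let lower = (output.0 - lhs.1).max(rhs.0); // b >= out.lower - a.upper
    let upper = (output.1 - lhs.0).min(rhs.1); // b <= out.upper - a.lower
    (lower, upper)
}

fn main() {
    assert_eq!(shrink_rhs((4, 5), (0, 2), (i64::MIN, 4)), (2, 4));
}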
fn propagate_constraints( &self, _interval: &Interval, _children: &[&Interval], - ) -> Result>> { - not_impl_err!("Not implemented for {self}") + ) -> Result>> { + Ok(Some(vec![])) } /// Update the hash `state` with this expression requirements from @@ -228,14 +226,6 @@ pub fn physical_exprs_contains( .any(|physical_expr| physical_expr.eq(expr)) } -/// Checks whether the given slices have any common entries. -pub fn have_common_entries( - lhs: &[Arc], - rhs: &[Arc], -) -> bool { - lhs.iter().any(|expr| physical_exprs_contains(rhs, expr)) -} - /// Checks whether the given physical expression slices are equal. pub fn physical_exprs_equal( lhs: &[Arc], @@ -293,8 +283,8 @@ mod tests { use crate::expressions::{Column, Literal}; use crate::physical_expr::{ - deduplicate_physical_exprs, have_common_entries, physical_exprs_bag_equal, - physical_exprs_contains, physical_exprs_equal, PhysicalExpr, + deduplicate_physical_exprs, physical_exprs_bag_equal, physical_exprs_contains, + physical_exprs_equal, PhysicalExpr, }; use datafusion_common::ScalarValue; @@ -334,29 +324,6 @@ mod tests { assert!(!physical_exprs_contains(&physical_exprs, &lit1)); } - #[test] - fn test_have_common_entries() { - let lit_true = Arc::new(Literal::new(ScalarValue::Boolean(Some(true)))) - as Arc; - let lit_false = Arc::new(Literal::new(ScalarValue::Boolean(Some(false)))) - as Arc; - let lit2 = - Arc::new(Literal::new(ScalarValue::Int32(Some(2)))) as Arc; - let lit1 = - Arc::new(Literal::new(ScalarValue::Int32(Some(1)))) as Arc; - let col_b_expr = Arc::new(Column::new("b", 1)) as Arc; - - let vec1 = vec![lit_true.clone(), lit_false.clone()]; - let vec2 = vec![lit_true.clone(), col_b_expr.clone()]; - let vec3 = vec![lit2.clone(), lit1.clone()]; - - // lit_true is common - assert!(have_common_entries(&vec1, &vec2)); - // there is no common entry - assert!(!have_common_entries(&vec1, &vec3)); - assert!(!have_common_entries(&vec2, &vec3)); - } - #[test] fn test_physical_exprs_equal() { let lit_true = Arc::new(Literal::new(ScalarValue::Boolean(Some(true)))) diff --git a/datafusion/physical-expr/src/planner.rs b/datafusion/physical-expr/src/planner.rs index 64c1d0be04558..9c212cb81f6b3 100644 --- a/datafusion/physical-expr/src/planner.rs +++ b/datafusion/physical-expr/src/planner.rs @@ -29,10 +29,10 @@ use datafusion_common::{ exec_err, internal_err, not_impl_err, plan_err, DFSchema, DataFusionError, Result, ScalarValue, }; -use datafusion_expr::expr::{Alias, Cast, InList, ScalarFunction, ScalarUDF}; +use datafusion_expr::expr::{Alias, Cast, InList, ScalarFunction}; use datafusion_expr::{ binary_expr, Between, BinaryExpr, Expr, GetFieldAccess, GetIndexedField, Like, - Operator, TryCast, + Operator, ScalarFunctionDefinition, TryCast, }; use std::sync::Arc; @@ -348,35 +348,37 @@ pub fn create_physical_expr( ))) } - Expr::ScalarFunction(ScalarFunction { fun, args }) => { - let physical_args = args + Expr::ScalarFunction(ScalarFunction { func_def, args }) => { + let mut physical_args = args .iter() .map(|e| { create_physical_expr(e, input_dfschema, input_schema, execution_props) }) .collect::>>()?; - functions::create_physical_expr( - fun, - &physical_args, - input_schema, - execution_props, - ) - } - Expr::ScalarUDF(ScalarUDF { fun, args }) => { - let mut physical_args = vec![]; - for e in args { - physical_args.push(create_physical_expr( - e, - input_dfschema, - input_schema, - execution_props, - )?); - } - // udfs with zero params expect null array as input - if args.is_empty() { - 
physical_args.push(Arc::new(Literal::new(ScalarValue::Null))); + match func_def { + ScalarFunctionDefinition::BuiltIn(fun) => { + functions::create_physical_expr( + fun, + &physical_args, + input_schema, + execution_props, + ) + } + ScalarFunctionDefinition::UDF(fun) => { + // udfs with zero params expect null array as input + if args.is_empty() { + physical_args.push(Arc::new(Literal::new(ScalarValue::Null))); + } + udf::create_physical_expr( + fun.clone().as_ref(), + &physical_args, + input_schema, + ) + } + ScalarFunctionDefinition::Name(_) => { + internal_err!("Function `Expr` with name should be resolved.") + } } - udf::create_physical_expr(fun.clone().as_ref(), &physical_args, input_schema) } Expr::Between(Between { expr, @@ -472,7 +474,7 @@ mod tests { ]))], )?; let result = p.evaluate(&batch)?; - let result = result.into_array(4); + let result = result.into_array(4).expect("Failed to convert to array"); assert_eq!( &result, diff --git a/datafusion/physical-expr/src/regex_expressions.rs b/datafusion/physical-expr/src/regex_expressions.rs index 41cd01949595a..b778fd86c24b1 100644 --- a/datafusion/physical-expr/src/regex_expressions.rs +++ b/datafusion/physical-expr/src/regex_expressions.rs @@ -25,8 +25,9 @@ use arrow::array::{ new_null_array, Array, ArrayDataBuilder, ArrayRef, BufferBuilder, GenericStringArray, OffsetSizeTrait, }; -use arrow::compute; -use datafusion_common::plan_err; +use arrow_array::builder::{GenericStringBuilder, ListBuilder}; +use arrow_schema::ArrowError; +use datafusion_common::{arrow_datafusion_err, plan_err}; use datafusion_common::{ cast::as_generic_string_array, internal_err, DataFusionError, Result, }; @@ -58,7 +59,7 @@ pub fn regexp_match(args: &[ArrayRef]) -> Result { 2 => { let values = as_generic_string_array::(&args[0])?; let regex = as_generic_string_array::(&args[1])?; - compute::regexp_match(values, regex, None).map_err(DataFusionError::ArrowError) + _regexp_match(values, regex, None).map_err(|e| arrow_datafusion_err!(e)) } 3 => { let values = as_generic_string_array::(&args[0])?; @@ -69,7 +70,7 @@ pub fn regexp_match(args: &[ArrayRef]) -> Result { Some(f) if f.iter().any(|s| s == Some("g")) => { plan_err!("regexp_match() does not support the \"global\" option") }, - _ => compute::regexp_match(values, regex, flags).map_err(DataFusionError::ArrowError), + _ => _regexp_match(values, regex, flags).map_err(|e| arrow_datafusion_err!(e)), } } other => internal_err!( @@ -78,6 +79,83 @@ pub fn regexp_match(args: &[ArrayRef]) -> Result { } } +/// TODO: Remove this once it is included in arrow-rs new release. 
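// Illustration of the capture handling inside `_regexp_match` below: when the
// pattern contains capture groups, group 0 (the full match) is skipped so only
// the groups are emitted, mirroring PostgreSQL's regexp_match. Assumes the
// `regex` crate, which the surrounding code already uses.
use regex::Regex;

fn main() {
    let re = Regex::new(r"(\w+)@(\w+)").unwrap();
    let caps = re.captures("user@example").unwrap();
    let mut iter = caps.iter();
    if caps.len() > 1 {
        iter.next(); // drop the full match, keep the capture groups
    }
    let groups: Vec<&str> = iter.flatten().map(|m| m.as_str()).collect();
    assert_eq!(groups, vec!["user", "example"]);
}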
+/// +fn _regexp_match( + array: &GenericStringArray, + regex_array: &GenericStringArray, + flags_array: Option<&GenericStringArray>, +) -> std::result::Result { + let mut patterns: std::collections::HashMap = + std::collections::HashMap::new(); + let builder: GenericStringBuilder = + GenericStringBuilder::with_capacity(0, 0); + let mut list_builder = ListBuilder::new(builder); + + let complete_pattern = match flags_array { + Some(flags) => Box::new(regex_array.iter().zip(flags.iter()).map( + |(pattern, flags)| { + pattern.map(|pattern| match flags { + Some(value) => format!("(?{value}){pattern}"), + None => pattern.to_string(), + }) + }, + )) as Box>>, + None => Box::new( + regex_array + .iter() + .map(|pattern| pattern.map(|pattern| pattern.to_string())), + ), + }; + + array + .iter() + .zip(complete_pattern) + .map(|(value, pattern)| { + match (value, pattern) { + // Required for Postgres compatibility: + // SELECT regexp_match('foobarbequebaz', ''); = {""} + (Some(_), Some(pattern)) if pattern == *"" => { + list_builder.values().append_value(""); + list_builder.append(true); + } + (Some(value), Some(pattern)) => { + let existing_pattern = patterns.get(&pattern); + let re = match existing_pattern { + Some(re) => re, + None => { + let re = Regex::new(pattern.as_str()).map_err(|e| { + ArrowError::ComputeError(format!( + "Regular expression did not compile: {e:?}" + )) + })?; + patterns.insert(pattern.clone(), re); + patterns.get(&pattern).unwrap() + } + }; + match re.captures(value) { + Some(caps) => { + let mut iter = caps.iter(); + if caps.len() > 1 { + iter.next(); + } + for m in iter.flatten() { + list_builder.values().append_value(m.as_str()); + } + + list_builder.append(true); + } + None => list_builder.append(false), + } + } + _ => list_builder.append(false), + } + Ok(()) + }) + .collect::, ArrowError>>()?; + Ok(Arc::new(list_builder.finish())) +} + /// replace POSIX capture groups (like \1) with Rust Regex group (like ${1}) /// used by regexp_replace fn regex_replace_posix_groups(replacement: &str) -> String { @@ -116,12 +194,12 @@ pub fn regexp_replace(args: &[ArrayRef]) -> Result // if patterns hashmap already has regexp then use else else create and return let re = match patterns.get(pattern) { - Some(re) => Ok(re.clone()), + Some(re) => Ok(re), None => { match Regex::new(pattern) { Ok(re) => { - patterns.insert(pattern.to_string(), re.clone()); - Ok(re) + patterns.insert(pattern.to_string(), re); + Ok(patterns.get(pattern).unwrap()) }, Err(err) => Err(DataFusionError::External(Box::new(err))), } @@ -162,12 +240,12 @@ pub fn regexp_replace(args: &[ArrayRef]) -> Result // if patterns hashmap already has regexp then use else else create and return let re = match patterns.get(&pattern) { - Some(re) => Ok(re.clone()), + Some(re) => Ok(re), None => { match Regex::new(pattern.as_str()) { Ok(re) => { - patterns.insert(pattern, re.clone()); - Ok(re) + patterns.insert(pattern.clone(), re); + Ok(patterns.get(&pattern).unwrap()) }, Err(err) => Err(DataFusionError::External(Box::new(err))), } diff --git a/datafusion/physical-expr/src/scalar_function.rs b/datafusion/physical-expr/src/scalar_function.rs index 768aa04dd9c1b..0a9d69720e19a 100644 --- a/datafusion/physical-expr/src/scalar_function.rs +++ b/datafusion/physical-expr/src/scalar_function.rs @@ -77,14 +77,14 @@ impl ScalarFunctionExpr { name: &str, fun: ScalarFunctionImplementation, args: Vec>, - return_type: &DataType, + return_type: DataType, monotonicity: Option, ) -> Self { Self { fun, name: name.to_owned(), args, - return_type: 
return_type.clone(), + return_type, monotonicity, } } @@ -108,6 +108,11 @@ impl ScalarFunctionExpr { pub fn return_type(&self) -> &DataType { &self.return_type } + + /// Monotonicity information of the function + pub fn monotonicity(&self) -> &Option { + &self.monotonicity + } } impl fmt::Display for ScalarFunctionExpr { @@ -168,7 +173,7 @@ impl PhysicalExpr for ScalarFunctionExpr { &self.name, self.fun.clone(), children, - self.return_type(), + self.return_type().clone(), self.monotonicity.clone(), ))) } diff --git a/datafusion/physical-expr/src/sort_expr.rs b/datafusion/physical-expr/src/sort_expr.rs index 664a6b65b7f7b..914d76f9261a1 100644 --- a/datafusion/physical-expr/src/sort_expr.rs +++ b/datafusion/physical-expr/src/sort_expr.rs @@ -26,7 +26,7 @@ use crate::PhysicalExpr; use arrow::compute::kernels::sort::{SortColumn, SortOptions}; use arrow::record_batch::RecordBatch; use arrow_schema::Schema; -use datafusion_common::{exec_err, DataFusionError, Result}; +use datafusion_common::Result; use datafusion_expr::ColumnarValue; /// Represents Sort operation for a column in a RecordBatch @@ -65,11 +65,7 @@ impl PhysicalSortExpr { let value_to_sort = self.expr.evaluate(batch)?; let array_to_sort = match value_to_sort { ColumnarValue::Array(array) => array, - ColumnarValue::Scalar(scalar) => { - return exec_err!( - "Sort operation is not applicable to scalar value {scalar}" - ); - } + ColumnarValue::Scalar(scalar) => scalar.to_array_of_size(batch.num_rows())?, }; Ok(SortColumn { values: array_to_sort, diff --git a/datafusion/physical-expr/src/sort_properties.rs b/datafusion/physical-expr/src/sort_properties.rs index a3b201f84e9db..0205f85dced40 100644 --- a/datafusion/physical-expr/src/sort_properties.rs +++ b/datafusion/physical-expr/src/sort_properties.rs @@ -15,15 +15,14 @@ // specific language governing permissions and limitations // under the License. +use std::borrow::Cow; use std::{ops::Neg, sync::Arc}; -use crate::PhysicalExpr; - use arrow_schema::SortOptions; -use datafusion_common::tree_node::{TreeNode, VisitRecursion}; -use datafusion_common::Result; -use itertools::Itertools; +use crate::PhysicalExpr; +use datafusion_common::tree_node::TreeNode; +use datafusion_common::Result; /// To propagate [`SortOptions`] across the [`PhysicalExpr`], it is insufficient /// to simply use `Option`: There must be a differentiation between @@ -36,11 +35,12 @@ use itertools::Itertools; /// sorted data; however the ((a_ordered + 999) + c_ordered) expression can. Therefore, /// we need two different variants for literals and unordered columns as literals are /// often more ordering-friendly under most mathematical operations. -#[derive(PartialEq, Debug, Clone, Copy)] +#[derive(PartialEq, Debug, Clone, Copy, Default)] pub enum SortProperties { /// Use the ordinary [`SortOptions`] struct to represent ordered data: Ordered(SortOptions), // This alternative represents unordered data: + #[default] Unordered, // Singleton is used for single-valued literal numbers: Singleton, @@ -99,7 +99,7 @@ impl SortProperties { } } - pub fn and(&self, rhs: &Self) -> Self { + pub fn and_or(&self, rhs: &Self) -> Self { match (self, rhs) { (Self::Ordered(lhs), Self::Ordered(rhs)) if lhs.descending == rhs.descending => @@ -148,75 +148,47 @@ impl Neg for SortProperties { /// It encapsulates the orderings (`state`) associated with the expression (`expr`), and /// orderings of the children expressions (`children_states`). 
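// Hedged sketch of the restructuring visible below: rather than keeping only a
// flat list of child states, each node now owns full child nodes, so the tree
// can be built and traversed recursively. The names here are stand-ins, not the
// DataFusion types.
#[derive(Debug, Clone, Copy, PartialEq, Default)]
enum SortState {
    Ordered,
    #[default]
    Unordered,
    Singleton,
}

#[derive(Debug, Clone)]
struct OrderingNode {
    state: SortState,
    children: Vec<OrderingNode>,
}

impl OrderingNode {
    // Mirrors the new `children_state()` helper: collect each child's state.
    fn children_state(&self) -> Vec<SortState> {
        self.children.iter().map(|c| c.state).collect()
    }
}

fn main() {
    let leaf = OrderingNode { state: SortState::Ordered, children: vec![] };
    let root = OrderingNode { state: SortState::Unordered, children: vec![leaf.clone(), leaf] };
    assert_eq!(root.children_state(), vec![SortState::Ordered; 2]);
}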
The [`ExprOrdering`] of a parent /// expression is determined based on the [`ExprOrdering`] states of its children expressions. -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct ExprOrdering { pub expr: Arc, pub state: SortProperties, - pub children_states: Vec, + pub children: Vec, } impl ExprOrdering { /// Creates a new [`ExprOrdering`] with [`SortProperties::Unordered`] states /// for `expr` and its children. pub fn new(expr: Arc) -> Self { - let size = expr.children().len(); + let children = expr.children(); Self { expr, - state: SortProperties::Unordered, - children_states: vec![SortProperties::Unordered; size], + state: Default::default(), + children: children.into_iter().map(Self::new).collect(), } } - /// Updates this [`ExprOrdering`]'s children states with the given states. - pub fn with_new_children(mut self, children_states: Vec) -> Self { - self.children_states = children_states; - self - } - - /// Creates new [`ExprOrdering`] objects for each child of the expression. - pub fn children_expr_orderings(&self) -> Vec { - self.expr - .children() - .into_iter() - .map(ExprOrdering::new) - .collect() + /// Get a reference to each child state. + pub fn children_state(&self) -> Vec { + self.children.iter().map(|c| c.state).collect() } } impl TreeNode for ExprOrdering { - fn apply_children(&self, op: &mut F) -> Result - where - F: FnMut(&Self) -> Result, - { - for child in self.children_expr_orderings() { - match op(&child)? { - VisitRecursion::Continue => {} - VisitRecursion::Skip => return Ok(VisitRecursion::Continue), - VisitRecursion::Stop => return Ok(VisitRecursion::Stop), - } - } - Ok(VisitRecursion::Continue) + fn children_nodes(&self) -> Vec> { + self.children.iter().map(Cow::Borrowed).collect() } - fn map_children(self, transform: F) -> Result + fn map_children(mut self, transform: F) -> Result where F: FnMut(Self) -> Result, { - if self.children_states.is_empty() { - Ok(self) - } else { - let child_expr_orderings = self.children_expr_orderings(); - // After mapping over the children, the function `F` applies to the - // current object and updates its state. - Ok(self.with_new_children( - child_expr_orderings - .into_iter() - // Update children states after this transformation: - .map(transform) - // Extract the state (i.e. 
sort properties) information: - .map_ok(|c| c.state) - .collect::>>()?, - )) + if !self.children.is_empty() { + self.children = self + .children + .into_iter() + .map(transform) + .collect::>()?; } + Ok(self) } } diff --git a/datafusion/physical-expr/src/string_expressions.rs b/datafusion/physical-expr/src/string_expressions.rs index e6a3d5c331a54..7d9fecf614075 100644 --- a/datafusion/physical-expr/src/string_expressions.rs +++ b/datafusion/physical-expr/src/string_expressions.rs @@ -23,11 +23,12 @@ use arrow::{ array::{ - Array, ArrayRef, BooleanArray, GenericStringArray, Int32Array, OffsetSizeTrait, - StringArray, + Array, ArrayRef, BooleanArray, GenericStringArray, Int32Array, Int64Array, + OffsetSizeTrait, StringArray, }, datatypes::{ArrowNativeType, ArrowPrimitiveType, DataType}, }; +use datafusion_common::utils::datafusion_strsim; use datafusion_common::{ cast::{ as_generic_string_array, as_int64_array, as_primitive_array, as_string_array, @@ -36,8 +37,11 @@ use datafusion_common::{ }; use datafusion_common::{internal_err, DataFusionError, Result}; use datafusion_expr::ColumnarValue; -use std::iter; use std::sync::Arc; +use std::{ + fmt::{Display, Formatter}, + iter, +}; use uuid::Uuid; /// applies a unary expression to `args[0]` that is expected to be downcastable to @@ -132,53 +136,6 @@ pub fn ascii(args: &[ArrayRef]) -> Result { Ok(Arc::new(result) as ArrayRef) } -/// Removes the longest string containing only characters in characters (a space by default) from the start and end of string. -/// btrim('xyxtrimyyx', 'xyz') = 'trim' -pub fn btrim(args: &[ArrayRef]) -> Result { - match args.len() { - 1 => { - let string_array = as_generic_string_array::(&args[0])?; - - let result = string_array - .iter() - .map(|string| { - string.map(|string: &str| { - string.trim_start_matches(' ').trim_end_matches(' ') - }) - }) - .collect::>(); - - Ok(Arc::new(result) as ArrayRef) - } - 2 => { - let string_array = as_generic_string_array::(&args[0])?; - let characters_array = as_generic_string_array::(&args[1])?; - - let result = string_array - .iter() - .zip(characters_array.iter()) - .map(|(string, characters)| match (string, characters) { - (None, _) => None, - (_, None) => None, - (Some(string), Some(characters)) => { - let chars: Vec = characters.chars().collect(); - Some( - string - .trim_start_matches(&chars[..]) - .trim_end_matches(&chars[..]), - ) - } - }) - .collect::>(); - - Ok(Arc::new(result) as ArrayRef) - } - other => internal_err!( - "btrim was called with {other} arguments. It requires at least 1 and at most 2." - ), - } -} - /// Returns the character with the given code. chr(0) is disallowed because text data types cannot store that character. /// chr(65) = 'A' pub fn chr(args: &[ArrayRef]) -> Result { @@ -345,44 +302,95 @@ pub fn lower(args: &[ColumnarValue]) -> Result { handle(args, |string| string.to_ascii_lowercase(), "lower") } -/// Removes the longest string containing only characters in characters (a space by default) from the start of string. 
-/// ltrim('zzzytest', 'xyz') = 'test' -pub fn ltrim(args: &[ArrayRef]) -> Result { +enum TrimType { + Left, + Right, + Both, +} + +impl Display for TrimType { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + TrimType::Left => write!(f, "ltrim"), + TrimType::Right => write!(f, "rtrim"), + TrimType::Both => write!(f, "btrim"), + } + } +} + +fn general_trim( + args: &[ArrayRef], + trim_type: TrimType, +) -> Result { + let func = match trim_type { + TrimType::Left => |input, pattern: &str| { + let pattern = pattern.chars().collect::>(); + str::trim_start_matches::<&[char]>(input, pattern.as_ref()) + }, + TrimType::Right => |input, pattern: &str| { + let pattern = pattern.chars().collect::>(); + str::trim_end_matches::<&[char]>(input, pattern.as_ref()) + }, + TrimType::Both => |input, pattern: &str| { + let pattern = pattern.chars().collect::>(); + str::trim_end_matches::<&[char]>( + str::trim_start_matches::<&[char]>(input, pattern.as_ref()), + pattern.as_ref(), + ) + }, + }; + + let string_array = as_generic_string_array::(&args[0])?; + match args.len() { 1 => { - let string_array = as_generic_string_array::(&args[0])?; - let result = string_array .iter() - .map(|string| string.map(|string: &str| string.trim_start_matches(' '))) + .map(|string| string.map(|string: &str| func(string, " "))) .collect::>(); Ok(Arc::new(result) as ArrayRef) } 2 => { - let string_array = as_generic_string_array::(&args[0])?; let characters_array = as_generic_string_array::(&args[1])?; let result = string_array .iter() .zip(characters_array.iter()) .map(|(string, characters)| match (string, characters) { - (Some(string), Some(characters)) => { - let chars: Vec = characters.chars().collect(); - Some(string.trim_start_matches(&chars[..])) - } + (Some(string), Some(characters)) => Some(func(string, characters)), _ => None, }) .collect::>(); Ok(Arc::new(result) as ArrayRef) } - other => internal_err!( - "ltrim was called with {other} arguments. It requires at least 1 and at most 2." - ), + other => { + internal_err!( + "{trim_type} was called with {other} arguments. It requires at least 1 and at most 2." + ) + } } } +/// Returns the longest string with leading and trailing characters removed. If the characters are not specified, whitespace is removed. +/// btrim('xyxtrimyyx', 'xyz') = 'trim' +pub fn btrim(args: &[ArrayRef]) -> Result { + general_trim::(args, TrimType::Both) +} + +/// Returns the longest string with leading characters removed. If the characters are not specified, whitespace is removed. +/// ltrim('zzzytest', 'xyz') = 'test' +pub fn ltrim(args: &[ArrayRef]) -> Result { + general_trim::(args, TrimType::Left) +} + +/// Returns the longest string with trailing characters removed. If the characters are not specified, whitespace is removed. +/// rtrim('testxxzx', 'xyz') = 'test' +pub fn rtrim(args: &[ArrayRef]) -> Result { + general_trim::(args, TrimType::Right) +} + /// Repeats string the specified number of times. /// repeat('Pg', 4) = 'PgPgPgPg' pub fn repeat(args: &[ArrayRef]) -> Result { @@ -421,44 +429,6 @@ pub fn replace(args: &[ArrayRef]) -> Result { Ok(Arc::new(result) as ArrayRef) } -/// Removes the longest string containing only characters in characters (a space by default) from the end of string. 
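// Small illustration of the std building blocks the unified `general_trim`
// above reduces to: trimming against a set of characters via a `&[char]`
// pattern, which is all that the SQL btrim/ltrim/rtrim functions need.
fn main() {
    let pattern: Vec<char> = "xyz".chars().collect();
    assert_eq!("zzzytest".trim_start_matches(&pattern[..]), "test"); // ltrim
    assert_eq!("testxxzx".trim_end_matches(&pattern[..]), "test");   // rtrim
    assert_eq!(
        "xyxtrimyyx"
            .trim_start_matches(&pattern[..])
            .trim_end_matches(&pattern[..]),
        "trim" // btrim
    );
}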
-/// rtrim('testxxzx', 'xyz') = 'test' -pub fn rtrim(args: &[ArrayRef]) -> Result { - match args.len() { - 1 => { - let string_array = as_generic_string_array::(&args[0])?; - - let result = string_array - .iter() - .map(|string| string.map(|string: &str| string.trim_end_matches(' '))) - .collect::>(); - - Ok(Arc::new(result) as ArrayRef) - } - 2 => { - let string_array = as_generic_string_array::(&args[0])?; - let characters_array = as_generic_string_array::(&args[1])?; - - let result = string_array - .iter() - .zip(characters_array.iter()) - .map(|(string, characters)| match (string, characters) { - (Some(string), Some(characters)) => { - let chars: Vec = characters.chars().collect(); - Some(string.trim_end_matches(&chars[..])) - } - _ => None, - }) - .collect::>(); - - Ok(Arc::new(result) as ArrayRef) - } - other => internal_err!( - "rtrim was called with {other} arguments. It requires at least 1 and at most 2." - ), - } -} - /// Splits string at occurrences of delimiter and returns the n'th field (counting from one). /// split_part('abc~@~def~@~ghi', '~@~', 2) = 'def' pub fn split_part(args: &[ArrayRef]) -> Result { @@ -553,11 +523,149 @@ pub fn uuid(args: &[ColumnarValue]) -> Result { Ok(ColumnarValue::Array(Arc::new(array))) } +/// OVERLAY(string1 PLACING string2 FROM integer FOR integer2) +/// Replaces a substring of string1 with string2 starting at the integer bit +/// pgsql overlay('Txxxxas' placing 'hom' from 2 for 4) → Thomas +/// overlay('Txxxxas' placing 'hom' from 2) -> Thomxas, without for option, str2's len is instead +pub fn overlay(args: &[ArrayRef]) -> Result { + match args.len() { + 3 => { + let string_array = as_generic_string_array::(&args[0])?; + let characters_array = as_generic_string_array::(&args[1])?; + let pos_num = as_int64_array(&args[2])?; + + let result = string_array + .iter() + .zip(characters_array.iter()) + .zip(pos_num.iter()) + .map(|((string, characters), start_pos)| { + match (string, characters, start_pos) { + (Some(string), Some(characters), Some(start_pos)) => { + let string_len = string.chars().count(); + let characters_len = characters.chars().count(); + let replace_len = characters_len as i64; + let mut res = + String::with_capacity(string_len.max(characters_len)); + + //as sql replace index start from 1 while string index start from 0 + if start_pos > 1 && start_pos - 1 < string_len as i64 { + let start = (start_pos - 1) as usize; + res.push_str(&string[..start]); + } + res.push_str(characters); + // if start + replace_len - 1 >= string_length, just to string end + if start_pos + replace_len - 1 < string_len as i64 { + let end = (start_pos + replace_len - 1) as usize; + res.push_str(&string[end..]); + } + Ok(Some(res)) + } + _ => Ok(None), + } + }) + .collect::>>()?; + Ok(Arc::new(result) as ArrayRef) + } + 4 => { + let string_array = as_generic_string_array::(&args[0])?; + let characters_array = as_generic_string_array::(&args[1])?; + let pos_num = as_int64_array(&args[2])?; + let len_num = as_int64_array(&args[3])?; + + let result = string_array + .iter() + .zip(characters_array.iter()) + .zip(pos_num.iter()) + .zip(len_num.iter()) + .map(|(((string, characters), start_pos), len)| { + match (string, characters, start_pos, len) { + (Some(string), Some(characters), Some(start_pos), Some(len)) => { + let string_len = string.chars().count(); + let characters_len = characters.chars().count(); + let replace_len = len.min(string_len as i64); + let mut res = + String::with_capacity(string_len.max(characters_len)); + + //as sql replace index start 
from 1 while string index start from 0 + if start_pos > 1 && start_pos - 1 < string_len as i64 { + let start = (start_pos - 1) as usize; + res.push_str(&string[..start]); + } + res.push_str(characters); + // if start + replace_len - 1 >= string_length, just to string end + if start_pos + replace_len - 1 < string_len as i64 { + let end = (start_pos + replace_len - 1) as usize; + res.push_str(&string[end..]); + } + Ok(Some(res)) + } + _ => Ok(None), + } + }) + .collect::>>()?; + Ok(Arc::new(result) as ArrayRef) + } + other => { + internal_err!( + "overlay was called with {other} arguments. It requires 3 or 4." + ) + } + } +} + +///Returns the Levenshtein distance between the two given strings. +/// LEVENSHTEIN('kitten', 'sitting') = 3 +pub fn levenshtein(args: &[ArrayRef]) -> Result { + if args.len() != 2 { + return Err(DataFusionError::Internal(format!( + "levenshtein function requires two arguments, got {}", + args.len() + ))); + } + let str1_array = as_generic_string_array::(&args[0])?; + let str2_array = as_generic_string_array::(&args[1])?; + match args[0].data_type() { + DataType::Utf8 => { + let result = str1_array + .iter() + .zip(str2_array.iter()) + .map(|(string1, string2)| match (string1, string2) { + (Some(string1), Some(string2)) => { + Some(datafusion_strsim::levenshtein(string1, string2) as i32) + } + _ => None, + }) + .collect::(); + Ok(Arc::new(result) as ArrayRef) + } + DataType::LargeUtf8 => { + let result = str1_array + .iter() + .zip(str2_array.iter()) + .map(|(string1, string2)| match (string1, string2) { + (Some(string1), Some(string2)) => { + Some(datafusion_strsim::levenshtein(string1, string2) as i64) + } + _ => None, + }) + .collect::(); + Ok(Arc::new(result) as ArrayRef) + } + other => { + internal_err!( + "levenshtein was called with {other} datatype arguments. It requires Utf8 or LargeUtf8." + ) + } + } +} + #[cfg(test)] mod tests { use crate::string_expressions; use arrow::{array::Int32Array, datatypes::Int32Type}; + use arrow_array::Int64Array; + use datafusion_common::cast::as_int32_array; use super::*; @@ -599,4 +707,36 @@ mod tests { Ok(()) } + + #[test] + fn to_overlay() -> Result<()> { + let string = + Arc::new(StringArray::from(vec!["123", "abcdefg", "xyz", "Txxxxas"])); + let replace_string = + Arc::new(StringArray::from(vec!["abc", "qwertyasdfg", "ijk", "hom"])); + let start = Arc::new(Int64Array::from(vec![4, 1, 1, 2])); // start + let end = Arc::new(Int64Array::from(vec![5, 7, 2, 4])); // replace len + + let res = overlay::(&[string, replace_string, start, end]).unwrap(); + let result = as_generic_string_array::(&res).unwrap(); + let expected = StringArray::from(vec!["abc", "qwertyasdfg", "ijkz", "Thomas"]); + assert_eq!(&expected, result); + + Ok(()) + } + + #[test] + fn to_levenshtein() -> Result<()> { + let string1_array = + Arc::new(StringArray::from(vec!["123", "abc", "xyz", "kitten"])); + let string2_array = + Arc::new(StringArray::from(vec!["321", "def", "zyx", "sitting"])); + let res = levenshtein::(&[string1_array, string2_array]).unwrap(); + let result = + as_int32_array(&res).expect("failed to initialized function levenshtein"); + let expected = Int32Array::from(vec![2, 3, 2, 3]); + assert_eq!(&expected, result); + + Ok(()) + } } diff --git a/datafusion/physical-expr/src/struct_expressions.rs b/datafusion/physical-expr/src/struct_expressions.rs index baa29d668e902..b0ccb2a3ccb68 100644 --- a/datafusion/physical-expr/src/struct_expressions.rs +++ b/datafusion/physical-expr/src/struct_expressions.rs @@ -18,8 +18,8 @@ //! 
Struct expressions use arrow::array::*; -use arrow::datatypes::{DataType, Field}; -use datafusion_common::{exec_err, not_impl_err, DataFusionError, Result}; +use arrow::datatypes::Field; +use datafusion_common::{exec_err, DataFusionError, Result}; use datafusion_expr::ColumnarValue; use std::sync::Arc; @@ -34,31 +34,14 @@ fn array_struct(args: &[ArrayRef]) -> Result { .enumerate() .map(|(i, arg)| { let field_name = format!("c{i}"); - match arg.data_type() { - DataType::Utf8 - | DataType::LargeUtf8 - | DataType::Boolean - | DataType::Float32 - | DataType::Float64 - | DataType::Int8 - | DataType::Int16 - | DataType::Int32 - | DataType::Int64 - | DataType::UInt8 - | DataType::UInt16 - | DataType::UInt32 - | DataType::UInt64 => Ok(( - Arc::new(Field::new( - field_name.as_str(), - arg.data_type().clone(), - true, - )), - arg.clone(), + Ok(( + Arc::new(Field::new( + field_name.as_str(), + arg.data_type().clone(), + true, )), - data_type => { - not_impl_err!("Struct is not implemented for type '{data_type:?}'.") - } - } + arg.clone(), + )) }) .collect::>>()?; @@ -67,13 +50,15 @@ fn array_struct(args: &[ArrayRef]) -> Result { /// put values in a struct array. pub fn struct_expr(values: &[ColumnarValue]) -> Result { - let arrays: Vec = values + let arrays = values .iter() - .map(|x| match x { - ColumnarValue::Array(array) => array.clone(), - ColumnarValue::Scalar(scalar) => scalar.to_array().clone(), + .map(|x| { + Ok(match x { + ColumnarValue::Array(array) => array.clone(), + ColumnarValue::Scalar(scalar) => scalar.to_array()?.clone(), + }) }) - .collect(); + .collect::>>()?; Ok(ColumnarValue::Array(array_struct(arrays.as_slice())?)) } @@ -93,7 +78,8 @@ mod tests { ]; let struc = struct_expr(&args) .expect("failed to initialize function struct") - .into_array(1); + .into_array(1) + .expect("Failed to convert to array"); let result = as_struct_array(&struc).expect("failed to initialize function struct"); assert_eq!( diff --git a/datafusion/physical-expr/src/udf.rs b/datafusion/physical-expr/src/udf.rs index af1e77cbf566d..0ec1cf3f256b0 100644 --- a/datafusion/physical-expr/src/udf.rs +++ b/datafusion/physical-expr/src/udf.rs @@ -35,10 +35,10 @@ pub fn create_physical_expr( .collect::>>()?; Ok(Arc::new(ScalarFunctionExpr::new( - &fun.name, - fun.fun.clone(), + fun.name(), + fun.fun().clone(), input_phy_exprs.to_vec(), - (fun.return_type)(&input_exprs_types)?.as_ref(), + fun.return_type(&input_exprs_types)?, None, ))) } diff --git a/datafusion/physical-expr/src/unicode_expressions.rs b/datafusion/physical-expr/src/unicode_expressions.rs index e28700a25ce47..240efe4223c33 100644 --- a/datafusion/physical-expr/src/unicode_expressions.rs +++ b/datafusion/physical-expr/src/unicode_expressions.rs @@ -455,3 +455,107 @@ pub fn translate(args: &[ArrayRef]) -> Result { Ok(Arc::new(result) as ArrayRef) } + +/// Returns the substring from str before count occurrences of the delimiter delim. If count is positive, everything to the left of the final delimiter (counting from the left) is returned. If count is negative, everything to the right of the final delimiter (counting from the right) is returned. +/// SUBSTRING_INDEX('www.apache.org', '.', 1) = www +/// SUBSTRING_INDEX('www.apache.org', '.', 2) = www.apache +/// SUBSTRING_INDEX('www.apache.org', '.', -2) = apache.org +/// SUBSTRING_INDEX('www.apache.org', '.', -1) = org +pub fn substr_index(args: &[ArrayRef]) -> Result { + if args.len() != 3 { + return internal_err!( + "substr_index was called with {} arguments. 
It requires 3.", + args.len() + ); + } + + let string_array = as_generic_string_array::(&args[0])?; + let delimiter_array = as_generic_string_array::(&args[1])?; + let count_array = as_int64_array(&args[2])?; + + let result = string_array + .iter() + .zip(delimiter_array.iter()) + .zip(count_array.iter()) + .map(|((string, delimiter), n)| match (string, delimiter, n) { + (Some(string), Some(delimiter), Some(n)) => { + let mut res = String::new(); + match n { + 0 => { + "".to_string(); + } + _other => { + if n > 0 { + let idx = string + .split(delimiter) + .take(n as usize) + .fold(0, |len, x| len + x.len() + delimiter.len()) + - delimiter.len(); + res.push_str(if idx >= string.len() { + string + } else { + &string[..idx] + }); + } else { + let idx = (string.split(delimiter).take((-n) as usize).fold( + string.len() as isize, + |len, x| { + len - x.len() as isize - delimiter.len() as isize + }, + ) + delimiter.len() as isize) + as usize; + res.push_str(if idx >= string.len() { + string + } else { + &string[idx..] + }); + } + } + } + Some(res) + } + _ => None, + }) + .collect::>(); + + Ok(Arc::new(result) as ArrayRef) +} + +///Returns a value in the range of 1 to N if the string str is in the string list strlist consisting of N substrings +///A string list is a string composed of substrings separated by , characters. +pub fn find_in_set(args: &[ArrayRef]) -> Result +where + T::Native: OffsetSizeTrait, +{ + if args.len() != 2 { + return internal_err!( + "find_in_set was called with {} arguments. It requires 2.", + args.len() + ); + } + + let str_array: &GenericStringArray = + as_generic_string_array::(&args[0])?; + let str_list_array: &GenericStringArray = + as_generic_string_array::(&args[1])?; + + let result = str_array + .iter() + .zip(str_list_array.iter()) + .map(|(string, str_list)| match (string, str_list) { + (Some(string), Some(str_list)) => { + let mut res = 0; + let str_set: Vec<&str> = str_list.split(',').collect(); + for (idx, str) in str_set.iter().enumerate() { + if str == &string { + res = idx + 1; + break; + } + } + T::Native::from_usize(res) + } + _ => None, + }) + .collect::>(); + Ok(Arc::new(result) as ArrayRef) +} diff --git a/datafusion/physical-expr/src/utils/guarantee.rs b/datafusion/physical-expr/src/utils/guarantee.rs new file mode 100644 index 0000000000000..0aee2af67fdd3 --- /dev/null +++ b/datafusion/physical-expr/src/utils/guarantee.rs @@ -0,0 +1,856 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`LiteralGuarantee`] predicate analysis to determine if a column is a +//! constant. 
+ +use crate::utils::split_disjunction; +use crate::{split_conjunction, PhysicalExpr}; +use datafusion_common::{Column, ScalarValue}; +use datafusion_expr::Operator; +use std::collections::{HashMap, HashSet}; +use std::sync::Arc; + +/// Represents a guarantee that must be true for a boolean expression to +/// evaluate to `true`. +/// +/// The guarantee takes the form of a column and a set of literal (constant) +/// [`ScalarValue`]s. For the expression to evaluate to `true`, the column *must +/// satisfy* the guarantee(s). +/// +/// To satisfy the guarantee, depending on [`Guarantee`], the values in the +/// column must either: +/// +/// 1. be ONLY one of that set +/// 2. NOT be ANY of that set +/// +/// # Uses `LiteralGuarantee`s +/// +/// `LiteralGuarantee`s can be used to simplify filter expressions and skip data +/// files (e.g. row groups in parquet files) by proving expressions can not +/// possibly evaluate to `true`. For example, if we have a guarantee that `a` +/// must be in (`1`) for a filter to evaluate to `true`, then we can skip any +/// partition where we know that `a` never has the value of `1`. +/// +/// **Important**: If a `LiteralGuarantee` is not satisfied, the relevant +/// expression is *guaranteed* to evaluate to `false` or `null`. **However**, +/// the opposite does not hold. Even if all `LiteralGuarantee`s are satisfied, +/// that does **not** guarantee that the predicate will actually evaluate to +/// `true`: it may still evaluate to `true`, `false` or `null`. +/// +/// # Creating `LiteralGuarantee`s +/// +/// Use [`LiteralGuarantee::analyze`] to extract literal guarantees from a +/// filter predicate. +/// +/// # Details +/// A guarantee can be one of two forms: +/// +/// 1. The column must be one the values for the predicate to be `true`. If the +/// column takes on any other value, the predicate can not evaluate to `true`. +/// For example, +/// `(a = 1)`, `(a = 1 OR a = 2) or `a IN (1, 2, 3)` +/// +/// 2. The column must NOT be one of the values for the predicate to be `true`. +/// If the column can ONLY take one of these values, the predicate can not +/// evaluate to `true`. For example, +/// `(a != 1)`, `(a != 1 AND a != 2)` or `a NOT IN (1, 2, 3)` +#[derive(Debug, Clone, PartialEq)] +pub struct LiteralGuarantee { + pub column: Column, + pub guarantee: Guarantee, + pub literals: HashSet, +} + +/// What is guaranteed about the values for a [`LiteralGuarantee`]? +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum Guarantee { + /// Guarantee that the expression is `true` if `column` is one of the values. If + /// `column` is not one of the values, the expression can not be `true`. + In, + /// Guarantee that the expression is `true` if `column` is not ANY of the + /// values. If `column` only takes one of these values, the expression can + /// not be `true`. + NotIn, +} + +impl LiteralGuarantee { + /// Create a new instance of the guarantee if the provided operator is + /// supported. Returns None otherwise. See [`LiteralGuarantee::analyze`] to + /// create these structures from an predicate (boolean expression). + fn try_new<'a>( + column_name: impl Into, + guarantee: Guarantee, + literals: impl IntoIterator, + ) -> Option { + let literals: HashSet<_> = literals.into_iter().cloned().collect(); + + Some(Self { + column: Column::from_name(column_name), + guarantee, + literals, + }) + } + + /// Return a list of [`LiteralGuarantee`]s that must be satisfied for `expr` + /// to evaluate to `true`. 
+ /// + /// If more than one `LiteralGuarantee` is returned, they must **all** hold + /// for the expression to possibly be `true`. If any is not satisfied, the + /// expression is guaranteed to be `null` or `false`. + /// + /// # Notes: + /// 1. `expr` must be a boolean expression or inlist expression. + /// 2. `expr` is not simplified prior to analysis. + pub fn analyze(expr: &Arc) -> Vec { + // split conjunction: AND AND ... + split_conjunction(expr) + .into_iter() + // for an `AND` conjunction to be true, all terms individually must be true + .fold(GuaranteeBuilder::new(), |builder, expr| { + if let Some(cel) = ColOpLit::try_new(expr) { + return builder.aggregate_conjunct(cel); + } else if let Some(inlist) = expr + .as_any() + .downcast_ref::() + { + // Only support single-column inlist currently, multi-column inlist is not supported + let col = inlist + .expr() + .as_any() + .downcast_ref::(); + let Some(col) = col else { + return builder; + }; + + let literals = inlist + .list() + .iter() + .map(|e| e.as_any().downcast_ref::()) + .collect::>>(); + let Some(literals) = literals else { + return builder; + }; + + let guarantee = if inlist.negated() { + Guarantee::NotIn + } else { + Guarantee::In + }; + + builder.aggregate_multi_conjunct( + col, + guarantee, + literals.iter().map(|e| e.value()), + ) + } else { + // split disjunction: OR OR ... + let disjunctions = split_disjunction(expr); + + // We are trying to add a guarantee that a column must be + // in/not in a particular set of values for the expression + // to evaluate to true. + // + // A disjunction is true, if at least one of the terms is be + // true. + // + // Thus, we can infer a guarantee if all terms are of the + // form `(col literal) OR (col literal) OR ...`. + // + // For example, we can infer that `a = 1 OR a = 2 OR a = 3` + // is guaranteed to be true ONLY if a is in (`1`, `2` or `3`). + // + // However, for something like `a = 1 OR a = 2 OR a < 0` we + // **can't** guarantee that the predicate is only true if a + // is in (`1`, `2`), as it could also be true if `a` were less + // than zero. + let terms = disjunctions + .iter() + .filter_map(|expr| ColOpLit::try_new(expr)) + .collect::>(); + + if terms.is_empty() { + return builder; + } + + // if not all terms are of the form (col literal), + // can't infer any guarantees + if terms.len() != disjunctions.len() { + return builder; + } + + // if all terms are 'col literal' with the same column + // and operation we can infer any guarantees + // + // For those like (a != foo AND (a != bar OR a != baz)). + // We can't combine the (a != bar OR a != baz) part, but + // it also doesn't invalidate our knowledge that a != + // foo is required for the expression to be true. + // So we can only create a multi value guarantee for `=` + // (or a single value). (e.g. ignore `a != foo OR a != bar`) + let first_term = &terms[0]; + if terms.iter().all(|term| { + term.col.name() == first_term.col.name() + && term.guarantee == Guarantee::In + }) { + builder.aggregate_multi_conjunct( + first_term.col, + Guarantee::In, + terms.iter().map(|term| term.lit.value()), + ) + } else { + // can't infer anything + builder + } + } + }) + .build() + } +} + +/// Combines conjuncts (aka terms `AND`ed together) into [`LiteralGuarantee`]s, +/// preserving insert order +#[derive(Debug, Default)] +struct GuaranteeBuilder<'a> { + /// List of guarantees that have been created so far + /// if we have determined a subsequent conjunct invalidates a guarantee + /// e.g. 
`a = foo AND a = bar` then the relevant guarantee will be None + guarantees: Vec>, + + /// Key is the (column name, guarantee type) + /// Value is the index into `guarantees` + map: HashMap<(&'a crate::expressions::Column, Guarantee), usize>, +} + +impl<'a> GuaranteeBuilder<'a> { + fn new() -> Self { + Default::default() + } + + /// Aggregate a new single `AND col literal` term to this builder + /// combining with existing guarantees if possible. + /// + /// # Examples + /// * `AND (a = 1)`: `a` is guaranteed to be 1 + /// * `AND (a != 1)`: a is guaranteed to not be 1 + fn aggregate_conjunct(self, col_op_lit: ColOpLit<'a>) -> Self { + self.aggregate_multi_conjunct( + col_op_lit.col, + col_op_lit.guarantee, + [col_op_lit.lit.value()], + ) + } + + /// Aggregates a new single column, multi literal term to ths builder + /// combining with previously known guarantees if possible. + /// + /// # Examples + /// For the following examples, we can guarantee the expression is `true` if: + /// * `AND (a = 1 OR a = 2 OR a = 3)`: a is in (1, 2, or 3) + /// * `AND (a IN (1,2,3))`: a is in (1, 2, or 3) + /// * `AND (a != 1 OR a != 2 OR a != 3)`: a is not in (1, 2, or 3) + /// * `AND (a NOT IN (1,2,3))`: a is not in (1, 2, or 3) + fn aggregate_multi_conjunct( + mut self, + col: &'a crate::expressions::Column, + guarantee: Guarantee, + new_values: impl IntoIterator, + ) -> Self { + let key = (col, guarantee); + if let Some(index) = self.map.get(&key) { + // already have a guarantee for this column + let entry = &mut self.guarantees[*index]; + + let Some(existing) = entry else { + // determined the previous guarantee for this column has been + // invalidated, nothing to do + return self; + }; + + // Combine conjuncts if we have `a != foo AND a != bar`. `a = foo + // AND a = bar` doesn't make logical sense so we don't optimize this + // case + match existing.guarantee { + // knew that the column could not be a set of values + // + // For example, if we previously had `a != 5` and now we see + // another `AND a != 6` we know that a must not be either 5 or 6 + // for the expression to be true + Guarantee::NotIn => { + let new_values: HashSet<_> = new_values.into_iter().collect(); + existing.literals.extend(new_values.into_iter().cloned()); + } + Guarantee::In => { + let intersection = new_values + .into_iter() + .filter(|new_value| existing.literals.contains(*new_value)) + .collect::>(); + // for an In guarantee, if the intersection is not empty, we can extend the guarantee + // e.g. `a IN (1,2,3) AND a IN (2,3,4)` is `a IN (2,3)` + // otherwise, we invalidate the guarantee + // e.g. 
`a IN (1,2,3) AND a IN (4,5,6)` is `a IN ()`, which is invalid + if !intersection.is_empty() { + existing.literals = intersection.into_iter().cloned().collect(); + } else { + // at least one was not, so invalidate the guarantee + *entry = None; + } + } + } + } else { + // This is a new guarantee + let new_values: HashSet<_> = new_values.into_iter().collect(); + + if let Some(guarantee) = + LiteralGuarantee::try_new(col.name(), guarantee, new_values) + { + // add it to the list of guarantees + self.guarantees.push(Some(guarantee)); + self.map.insert(key, self.guarantees.len() - 1); + } + } + + self + } + + /// Return all guarantees that have been created so far + fn build(self) -> Vec { + // filter out any guarantees that have been invalidated + self.guarantees.into_iter().flatten().collect() + } +} + +/// Represents a single `col [not]in literal` expression +struct ColOpLit<'a> { + col: &'a crate::expressions::Column, + guarantee: Guarantee, + lit: &'a crate::expressions::Literal, +} + +impl<'a> ColOpLit<'a> { + /// Returns Some(ColEqLit) if the expression is either: + /// 1. `col literal` + /// 2. `literal col` + /// 3. operator is `=` or `!=` + /// Returns None otherwise + fn try_new(expr: &'a Arc) -> Option { + let binary_expr = expr + .as_any() + .downcast_ref::()?; + + let (left, op, right) = ( + binary_expr.left().as_any(), + binary_expr.op(), + binary_expr.right().as_any(), + ); + let guarantee = match op { + Operator::Eq => Guarantee::In, + Operator::NotEq => Guarantee::NotIn, + _ => return None, + }; + // col literal + if let (Some(col), Some(lit)) = ( + left.downcast_ref::(), + right.downcast_ref::(), + ) { + Some(Self { + col, + guarantee, + lit, + }) + } + // literal col + else if let (Some(lit), Some(col)) = ( + left.downcast_ref::(), + right.downcast_ref::(), + ) { + Some(Self { + col, + guarantee, + lit, + }) + } else { + None + } + } +} + +#[cfg(test)] +mod test { + use super::*; + use crate::create_physical_expr; + use crate::execution_props::ExecutionProps; + use arrow_schema::{DataType, Field, Schema, SchemaRef}; + use datafusion_common::ToDFSchema; + use datafusion_expr::expr_fn::*; + use datafusion_expr::{lit, Expr}; + use std::sync::OnceLock; + + #[test] + fn test_literal() { + // a single literal offers no guarantee + test_analyze(lit(true), vec![]) + } + + #[test] + fn test_single() { + // a = "foo" + test_analyze(col("a").eq(lit("foo")), vec![in_guarantee("a", ["foo"])]); + // "foo" = a + test_analyze(lit("foo").eq(col("a")), vec![in_guarantee("a", ["foo"])]); + // a != "foo" + test_analyze( + col("a").not_eq(lit("foo")), + vec![not_in_guarantee("a", ["foo"])], + ); + // "foo" != a + test_analyze( + lit("foo").not_eq(col("a")), + vec![not_in_guarantee("a", ["foo"])], + ); + } + + #[test] + fn test_conjunction_single_column() { + // b = 1 AND b = 2. This is impossible. Ideally this expression could be simplified to false + test_analyze(col("b").eq(lit(1)).and(col("b").eq(lit(2))), vec![]); + // b = 1 AND b != 2 . In theory, this could be simplified to `b = 1`. + test_analyze( + col("b").eq(lit(1)).and(col("b").not_eq(lit(2))), + vec![ + // can only be true of b is 1 and b is not 2 (even though it is redundant) + in_guarantee("b", [1]), + not_in_guarantee("b", [2]), + ], + ); + // b != 1 AND b = 2. In theory, this could be simplified to `b = 2`. 
+ test_analyze( + col("b").not_eq(lit(1)).and(col("b").eq(lit(2))), + vec![ + // can only be true of b is not 1 and b is is 2 (even though it is redundant) + not_in_guarantee("b", [1]), + in_guarantee("b", [2]), + ], + ); + // b != 1 AND b != 2 + test_analyze( + col("b").not_eq(lit(1)).and(col("b").not_eq(lit(2))), + vec![not_in_guarantee("b", [1, 2])], + ); + // b != 1 AND b != 2 and b != 3 + test_analyze( + col("b") + .not_eq(lit(1)) + .and(col("b").not_eq(lit(2))) + .and(col("b").not_eq(lit(3))), + vec![not_in_guarantee("b", [1, 2, 3])], + ); + // b != 1 AND b = 2 and b != 3. Can only be true if b is 2 and b is not in (1, 3) + test_analyze( + col("b") + .not_eq(lit(1)) + .and(col("b").eq(lit(2))) + .and(col("b").not_eq(lit(3))), + vec![not_in_guarantee("b", [1, 3]), in_guarantee("b", [2])], + ); + // b != 1 AND b != 2 and b = 3 (in theory could determine b = 3) + test_analyze( + col("b") + .not_eq(lit(1)) + .and(col("b").not_eq(lit(2))) + .and(col("b").eq(lit(3))), + vec![not_in_guarantee("b", [1, 2]), in_guarantee("b", [3])], + ); + // b != 1 AND b != 2 and b > 3 (to be true, b can't be either 1 or 2 + test_analyze( + col("b") + .not_eq(lit(1)) + .and(col("b").not_eq(lit(2))) + .and(col("b").gt(lit(3))), + vec![not_in_guarantee("b", [1, 2])], + ); + } + + #[test] + fn test_conjunction_multi_column() { + // a = "foo" AND b = 1 + test_analyze( + col("a").eq(lit("foo")).and(col("b").eq(lit(1))), + vec![ + // should find both column guarantees + in_guarantee("a", ["foo"]), + in_guarantee("b", [1]), + ], + ); + // a != "foo" AND b != 1 + test_analyze( + col("a").not_eq(lit("foo")).and(col("b").not_eq(lit(1))), + // should find both column guarantees + vec![not_in_guarantee("a", ["foo"]), not_in_guarantee("b", [1])], + ); + // a = "foo" AND a = "bar" + test_analyze( + col("a").eq(lit("foo")).and(col("a").eq(lit("bar"))), + // this predicate is impossible ( can't be both foo and bar), + vec![], + ); + // a = "foo" AND b != "bar" + test_analyze( + col("a").eq(lit("foo")).and(col("a").not_eq(lit("bar"))), + vec![in_guarantee("a", ["foo"]), not_in_guarantee("a", ["bar"])], + ); + // a != "foo" AND a != "bar" + test_analyze( + col("a").not_eq(lit("foo")).and(col("a").not_eq(lit("bar"))), + // know it isn't "foo" or "bar" + vec![not_in_guarantee("a", ["foo", "bar"])], + ); + // a != "foo" AND a != "bar" and a != "baz" + test_analyze( + col("a") + .not_eq(lit("foo")) + .and(col("a").not_eq(lit("bar"))) + .and(col("a").not_eq(lit("baz"))), + // know it isn't "foo" or "bar" or "baz" + vec![not_in_guarantee("a", ["foo", "bar", "baz"])], + ); + // a = "foo" AND a = "foo" + let expr = col("a").eq(lit("foo")); + test_analyze(expr.clone().and(expr), vec![in_guarantee("a", ["foo"])]); + // b > 5 AND b = 10 (should get an b = 10 guarantee) + test_analyze( + col("b").gt(lit(5)).and(col("b").eq(lit(10))), + vec![in_guarantee("b", [10])], + ); + // b > 10 AND b = 10 (this is impossible) + test_analyze( + col("b").gt(lit(10)).and(col("b").eq(lit(10))), + vec![ + // if b isn't 10, it can not be true (though the expression actually can never be true) + in_guarantee("b", [10]), + ], + ); + // a != "foo" and (a != "bar" OR a != "baz") + test_analyze( + col("a") + .not_eq(lit("foo")) + .and(col("a").not_eq(lit("bar")).or(col("a").not_eq(lit("baz")))), + // a is not foo (we can't represent other knowledge about a) + vec![not_in_guarantee("a", ["foo"])], + ); + } + + #[test] + fn test_conjunction_and_disjunction_single_column() { + // b != 1 AND (b > 2) + test_analyze( + 
col("b").not_eq(lit(1)).and(col("b").gt(lit(2))), + vec![ + // for the expression to be true, b can not be one + not_in_guarantee("b", [1]), + ], + ); + + // b = 1 AND (b = 2 OR b = 3). Could be simplified to false. + test_analyze( + col("b") + .eq(lit(1)) + .and(col("b").eq(lit(2)).or(col("b").eq(lit(3)))), + vec![ + // in theory, b must be 1 and one of 2,3 for this expression to be true + // which is a logical contradiction + ], + ); + } + + #[test] + fn test_disjunction_single_column() { + // b = 1 OR b = 2 + test_analyze( + col("b").eq(lit(1)).or(col("b").eq(lit(2))), + vec![in_guarantee("b", [1, 2])], + ); + // b != 1 OR b = 2 + test_analyze(col("b").not_eq(lit(1)).or(col("b").eq(lit(2))), vec![]); + // b = 1 OR b != 2 + test_analyze(col("b").eq(lit(1)).or(col("b").not_eq(lit(2))), vec![]); + // b != 1 OR b != 2 + test_analyze(col("b").not_eq(lit(1)).or(col("b").not_eq(lit(2))), vec![]); + // b != 1 OR b != 2 OR b = 3 -- in theory could guarantee that b = 3 + test_analyze( + col("b") + .not_eq(lit(1)) + .or(col("b").not_eq(lit(2))) + .or(lit("b").eq(lit(3))), + vec![], + ); + // b = 1 OR b = 2 OR b = 3 + test_analyze( + col("b") + .eq(lit(1)) + .or(col("b").eq(lit(2))) + .or(col("b").eq(lit(3))), + vec![in_guarantee("b", [1, 2, 3])], + ); + // b = 1 OR b = 2 OR b > 3 -- can't guarantee that the expression is only true if a is in (1, 2) + test_analyze( + col("b") + .eq(lit(1)) + .or(col("b").eq(lit(2))) + .or(lit("b").eq(lit(3))), + vec![], + ); + } + + #[test] + fn test_disjunction_multi_column() { + // a = "foo" OR b = 1 + test_analyze( + col("a").eq(lit("foo")).or(col("b").eq(lit(1))), + // no can't have a single column guarantee (if a = "foo" then b != 1) etc + vec![], + ); + // a != "foo" OR b != 1 + test_analyze( + col("a").not_eq(lit("foo")).or(col("b").not_eq(lit(1))), + // No single column guarantee + vec![], + ); + // a = "foo" OR a = "bar" + test_analyze( + col("a").eq(lit("foo")).or(col("a").eq(lit("bar"))), + vec![in_guarantee("a", ["foo", "bar"])], + ); + // a = "foo" OR a = "foo" + test_analyze( + col("a").eq(lit("foo")).or(col("a").eq(lit("foo"))), + vec![in_guarantee("a", ["foo"])], + ); + // a != "foo" OR a != "bar" + test_analyze( + col("a").not_eq(lit("foo")).or(col("a").not_eq(lit("bar"))), + // can't represent knowledge about a in this case + vec![], + ); + // a = "foo" OR a = "bar" OR a = "baz" + test_analyze( + col("a") + .eq(lit("foo")) + .or(col("a").eq(lit("bar"))) + .or(col("a").eq(lit("baz"))), + vec![in_guarantee("a", ["foo", "bar", "baz"])], + ); + // (a = "foo" OR a = "bar") AND (a = "baz)" + test_analyze( + (col("a").eq(lit("foo")).or(col("a").eq(lit("bar")))) + .and(col("a").eq(lit("baz"))), + // this could potentially be represented as 2 constraints with a more + // sophisticated analysis + vec![], + ); + // (a = "foo" OR a = "bar") AND (b = 1) + test_analyze( + (col("a").eq(lit("foo")).or(col("a").eq(lit("bar")))) + .and(col("b").eq(lit(1))), + vec![in_guarantee("a", ["foo", "bar"]), in_guarantee("b", [1])], + ); + // (a = "foo" OR a = "bar") OR (b = 1) + test_analyze( + col("a") + .eq(lit("foo")) + .or(col("a").eq(lit("bar"))) + .or(col("b").eq(lit(1))), + // can't represent knowledge about a or b in this case + vec![], + ); + } + + #[test] + fn test_single_inlist() { + // b IN (1, 2, 3) + test_analyze( + col("b").in_list(vec![lit(1), lit(2), lit(3)], false), + vec![in_guarantee("b", [1, 2, 3])], + ); + // b NOT IN (1, 2, 3) + test_analyze( + col("b").in_list(vec![lit(1), lit(2), lit(3)], true), + vec![not_in_guarantee("b", [1, 2, 3])], + ); + } + + 
#[test] + fn test_inlist_conjunction() { + // b IN (1, 2, 3) AND b IN (2, 3, 4) + test_analyze( + col("b") + .in_list(vec![lit(1), lit(2), lit(3)], false) + .and(col("b").in_list(vec![lit(2), lit(3), lit(4)], false)), + vec![in_guarantee("b", [2, 3])], + ); + // b NOT IN (1, 2, 3) AND b IN (2, 3, 4) + test_analyze( + col("b") + .in_list(vec![lit(1), lit(2), lit(3)], true) + .and(col("b").in_list(vec![lit(2), lit(3), lit(4)], false)), + vec![ + not_in_guarantee("b", [1, 2, 3]), + in_guarantee("b", [2, 3, 4]), + ], + ); + // b NOT IN (1, 2, 3) AND b NOT IN (2, 3, 4) + test_analyze( + col("b") + .in_list(vec![lit(1), lit(2), lit(3)], true) + .and(col("b").in_list(vec![lit(2), lit(3), lit(4)], true)), + vec![not_in_guarantee("b", [1, 2, 3, 4])], + ); + // b IN (1, 2, 3) AND b = 4 + test_analyze( + col("b") + .in_list(vec![lit(1), lit(2), lit(3)], false) + .and(col("b").eq(lit(4))), + vec![], + ); + // b IN (1, 2, 3) AND b = 2 + test_analyze( + col("b") + .in_list(vec![lit(1), lit(2), lit(3)], false) + .and(col("b").eq(lit(2))), + vec![in_guarantee("b", [2])], + ); + // b IN (1, 2, 3) AND b != 2 + test_analyze( + col("b") + .in_list(vec![lit(1), lit(2), lit(3)], false) + .and(col("b").not_eq(lit(2))), + vec![in_guarantee("b", [1, 2, 3]), not_in_guarantee("b", [2])], + ); + // b NOT IN (1, 2, 3) AND b != 4 + test_analyze( + col("b") + .in_list(vec![lit(1), lit(2), lit(3)], true) + .and(col("b").not_eq(lit(4))), + vec![not_in_guarantee("b", [1, 2, 3, 4])], + ); + // b NOT IN (1, 2, 3) AND b != 2 + test_analyze( + col("b") + .in_list(vec![lit(1), lit(2), lit(3)], true) + .and(col("b").not_eq(lit(2))), + vec![not_in_guarantee("b", [1, 2, 3])], + ); + } + + #[test] + fn test_inlist_with_disjunction() { + // b IN (1, 2, 3) AND (b = 3 OR b = 4) + test_analyze( + col("b") + .in_list(vec![lit(1), lit(2), lit(3)], false) + .and(col("b").eq(lit(3)).or(col("b").eq(lit(4)))), + vec![in_guarantee("b", [3])], + ); + // b IN (1, 2, 3) AND (b = 4 OR b = 5) + test_analyze( + col("b") + .in_list(vec![lit(1), lit(2), lit(3)], false) + .and(col("b").eq(lit(4)).or(col("b").eq(lit(5)))), + vec![], + ); + // b NOT IN (1, 2, 3) AND (b = 3 OR b = 4) + test_analyze( + col("b") + .in_list(vec![lit(1), lit(2), lit(3)], true) + .and(col("b").eq(lit(3)).or(col("b").eq(lit(4)))), + vec![not_in_guarantee("b", [1, 2, 3]), in_guarantee("b", [3, 4])], + ); + // b IN (1, 2, 3) OR b = 2 + // TODO this should be in_guarantee("b", [1, 2, 3]) but currently we don't support to anylize this kind of disjunction. Only `ColOpLit OR ColOpLit` is supported. 
+ test_analyze( + col("b") + .in_list(vec![lit(1), lit(2), lit(3)], false) + .or(col("b").eq(lit(2))), + vec![], + ); + // b IN (1, 2, 3) OR b != 3 + test_analyze( + col("b") + .in_list(vec![lit(1), lit(2), lit(3)], false) + .or(col("b").not_eq(lit(3))), + vec![], + ); + } + + /// Tests that analyzing expr results in the expected guarantees + fn test_analyze(expr: Expr, expected: Vec) { + println!("Begin analyze of {expr}"); + let schema = schema(); + let physical_expr = logical2physical(&expr, &schema); + + let actual = LiteralGuarantee::analyze(&physical_expr); + assert_eq!( + expected, actual, + "expr: {expr}\ + \n\nexpected: {expected:#?}\ + \n\nactual: {actual:#?}\ + \n\nexpr: {expr:#?}\ + \n\nphysical_expr: {physical_expr:#?}" + ); + } + + /// Guarantee that the expression is true if the column is one of the specified values + fn in_guarantee<'a, I, S>(column: &str, literals: I) -> LiteralGuarantee + where + I: IntoIterator, + S: Into + 'a, + { + let literals: Vec<_> = literals.into_iter().map(|s| s.into()).collect(); + LiteralGuarantee::try_new(column, Guarantee::In, literals.iter()).unwrap() + } + + /// Guarantee that the expression is true if the column is NOT any of the specified values + fn not_in_guarantee<'a, I, S>(column: &str, literals: I) -> LiteralGuarantee + where + I: IntoIterator, + S: Into + 'a, + { + let literals: Vec<_> = literals.into_iter().map(|s| s.into()).collect(); + LiteralGuarantee::try_new(column, Guarantee::NotIn, literals.iter()).unwrap() + } + + /// Convert a logical expression to a physical expression (without any simplification, etc) + fn logical2physical(expr: &Expr, schema: &Schema) -> Arc { + let df_schema = schema.clone().to_dfschema().unwrap(); + let execution_props = ExecutionProps::new(); + create_physical_expr(expr, &df_schema, schema, &execution_props).unwrap() + } + + // Schema for testing + fn schema() -> SchemaRef { + SCHEMA + .get_or_init(|| { + Arc::new(Schema::new(vec![ + Field::new("a", DataType::Utf8, false), + Field::new("b", DataType::Int32, false), + ])) + }) + .clone() + } + + static SCHEMA: OnceLock = OnceLock::new(); +} diff --git a/datafusion/physical-expr/src/utils.rs b/datafusion/physical-expr/src/utils/mod.rs similarity index 93% rename from datafusion/physical-expr/src/utils.rs rename to datafusion/physical-expr/src/utils/mod.rs index 2f4ee89463a85..64a62dc7820d8 100644 --- a/datafusion/physical-expr/src/utils.rs +++ b/datafusion/physical-expr/src/utils/mod.rs @@ -15,7 +15,10 @@ // specific language governing permissions and limitations // under the License. -use std::borrow::Borrow; +mod guarantee; +pub use guarantee::{Guarantee, LiteralGuarantee}; + +use std::borrow::{Borrow, Cow}; use std::collections::{HashMap, HashSet}; use std::sync::Arc; @@ -41,25 +44,29 @@ use petgraph::stable_graph::StableGraph; pub fn split_conjunction( predicate: &Arc, ) -> Vec<&Arc> { - split_conjunction_impl(predicate, vec![]) + split_impl(Operator::And, predicate, vec![]) } -fn split_conjunction_impl<'a>( +/// Assume the predicate is in the form of DNF, split the predicate to a Vec of PhysicalExprs. 
+/// +/// For example, split "a1 = a2 OR b1 <= b2 OR c1 != c2" into ["a1 = a2", "b1 <= b2", "c1 != c2"] +pub fn split_disjunction( + predicate: &Arc, +) -> Vec<&Arc> { + split_impl(Operator::Or, predicate, vec![]) +} + +fn split_impl<'a>( + operator: Operator, predicate: &'a Arc, mut exprs: Vec<&'a Arc>, ) -> Vec<&'a Arc> { match predicate.as_any().downcast_ref::() { - Some(binary) => match binary.op() { - Operator::And => { - let exprs = split_conjunction_impl(binary.left(), exprs); - split_conjunction_impl(binary.right(), exprs) - } - _ => { - exprs.push(predicate); - exprs - } - }, - None => { + Some(binary) if binary.op() == &operator => { + let exprs = split_impl(operator, binary.left(), exprs); + split_impl(operator, binary.right(), exprs) + } + Some(_) | None => { exprs.push(predicate); exprs } @@ -129,10 +136,11 @@ pub struct ExprTreeNode { impl ExprTreeNode { pub fn new(expr: Arc) -> Self { + let children = expr.children(); ExprTreeNode { expr, data: None, - child_nodes: vec![], + child_nodes: children.into_iter().map(Self::new).collect_vec(), } } @@ -140,29 +148,14 @@ impl ExprTreeNode { &self.expr } - pub fn children(&self) -> Vec> { - self.expr - .children() - .into_iter() - .map(ExprTreeNode::new) - .collect() + pub fn children(&self) -> &[ExprTreeNode] { + &self.child_nodes } } impl TreeNode for ExprTreeNode { - fn apply_children(&self, op: &mut F) -> Result - where - F: FnMut(&Self) -> Result, - { - for child in self.children() { - match op(&child)? { - VisitRecursion::Continue => {} - VisitRecursion::Skip => return Ok(VisitRecursion::Continue), - VisitRecursion::Stop => return Ok(VisitRecursion::Stop), - } - } - - Ok(VisitRecursion::Continue) + fn children_nodes(&self) -> Vec> { + self.children().iter().map(Cow::Borrowed).collect() } fn map_children(mut self, transform: F) -> Result @@ -170,7 +163,7 @@ impl TreeNode for ExprTreeNode { F: FnMut(Self) -> Result, { self.child_nodes = self - .children() + .child_nodes .into_iter() .map(transform) .collect::>>()?; @@ -183,7 +176,7 @@ impl TreeNode for ExprTreeNode { /// identical expressions in one node. Caller specifies the node type in the /// DAEG via the `constructor` argument, which constructs nodes in the DAEG /// from the [ExprTreeNode] ancillary object. -struct PhysicalExprDAEGBuilder<'a, T, F: Fn(&ExprTreeNode) -> T> { +struct PhysicalExprDAEGBuilder<'a, T, F: Fn(&ExprTreeNode) -> Result> { // The resulting DAEG (expression DAG). graph: StableGraph, // A vector of visited expression nodes and their corresponding node indices. @@ -192,7 +185,7 @@ struct PhysicalExprDAEGBuilder<'a, T, F: Fn(&ExprTreeNode) -> T> { constructor: &'a F, } -impl<'a, T, F: Fn(&ExprTreeNode) -> T> TreeNodeRewriter +impl<'a, T, F: Fn(&ExprTreeNode) -> Result> TreeNodeRewriter for PhysicalExprDAEGBuilder<'a, T, F> { type N = ExprTreeNode; @@ -213,7 +206,7 @@ impl<'a, T, F: Fn(&ExprTreeNode) -> T> TreeNodeRewriter // add edges to its child nodes. Add the visited expression to the vector // of visited expressions and return the newly created node index. None => { - let node_idx = self.graph.add_node((self.constructor)(&node)); + let node_idx = self.graph.add_node((self.constructor)(&node)?); for expr_node in node.child_nodes.iter() { self.graph.add_edge(node_idx, expr_node.data.unwrap(), 0); } @@ -234,7 +227,7 @@ pub fn build_dag( constructor: &F, ) -> Result<(NodeIndex, StableGraph)> where - F: Fn(&ExprTreeNode) -> T, + F: Fn(&ExprTreeNode) -> Result, { // Create a new expression tree node from the input expression. 
let init = ExprTreeNode::new(expr); @@ -394,7 +387,7 @@ mod tests { } } - fn make_dummy_node(node: &ExprTreeNode) -> PhysicalExprDummyNode { + fn make_dummy_node(node: &ExprTreeNode) -> Result { let expr = node.expression().clone(); let dummy_property = if expr.as_any().is::() { "Binary" @@ -406,12 +399,12 @@ mod tests { "Other" } .to_owned(); - PhysicalExprDummyNode { + Ok(PhysicalExprDummyNode { expr, property: DummyProperty { expr_type: dummy_property, }, - } + }) } #[test] diff --git a/datafusion/physical-expr/src/window/built_in_window_function_expr.rs b/datafusion/physical-expr/src/window/built_in_window_function_expr.rs index 66ffa990b78bc..7aa4f6536a6e4 100644 --- a/datafusion/physical-expr/src/window/built_in_window_function_expr.rs +++ b/datafusion/physical-expr/src/window/built_in_window_function_expr.rs @@ -60,8 +60,10 @@ pub trait BuiltInWindowFunctionExpr: Send + Sync + std::fmt::Debug { fn evaluate_args(&self, batch: &RecordBatch) -> Result> { self.expressions() .iter() - .map(|e| e.evaluate(batch)) - .map(|r| r.map(|v| v.into_array(batch.num_rows()))) + .map(|e| { + e.evaluate(batch) + .and_then(|v| v.into_array(batch.num_rows())) + }) .collect() } diff --git a/datafusion/physical-expr/src/window/lead_lag.rs b/datafusion/physical-expr/src/window/lead_lag.rs index f55f1600b9cae..7ee736ce9caab 100644 --- a/datafusion/physical-expr/src/window/lead_lag.rs +++ b/datafusion/physical-expr/src/window/lead_lag.rs @@ -23,7 +23,7 @@ use crate::PhysicalExpr; use arrow::array::ArrayRef; use arrow::compute::cast; use arrow::datatypes::{DataType, Field}; -use datafusion_common::ScalarValue; +use datafusion_common::{arrow_datafusion_err, ScalarValue}; use datafusion_common::{internal_err, DataFusionError, Result}; use datafusion_expr::PartitionEvaluator; use std::any::Any; @@ -139,9 +139,10 @@ fn create_empty_array( let array = value .as_ref() .map(|scalar| scalar.to_array_of_size(size)) + .transpose()? .unwrap_or_else(|| new_null_array(data_type, size)); if array.data_type() != data_type { - cast(&array, data_type).map_err(DataFusionError::ArrowError) + cast(&array, data_type).map_err(|e| arrow_datafusion_err!(e)) } else { Ok(array) } @@ -171,10 +172,10 @@ fn shift_with_default_value( // Concatenate both arrays, add nulls after if shift > 0 else before if offset > 0 { concat(&[default_values.as_ref(), slice.as_ref()]) - .map_err(DataFusionError::ArrowError) + .map_err(|e| arrow_datafusion_err!(e)) } else { concat(&[slice.as_ref(), default_values.as_ref()]) - .map_err(DataFusionError::ArrowError) + .map_err(|e| arrow_datafusion_err!(e)) } } } diff --git a/datafusion/physical-expr/src/window/nth_value.rs b/datafusion/physical-expr/src/window/nth_value.rs index 262a50969b820..b3c89122ebad2 100644 --- a/datafusion/physical-expr/src/window/nth_value.rs +++ b/datafusion/physical-expr/src/window/nth_value.rs @@ -15,21 +15,24 @@ // specific language governing permissions and limitations // under the License. -//! Defines physical expressions for `first_value`, `last_value`, and `nth_value` -//! that can evaluated at runtime during query execution +//! Defines physical expressions for `FIRST_VALUE`, `LAST_VALUE`, and `NTH_VALUE` +//! functions that can be evaluated at run time during query execution. 
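
The changes below generalize `NthValueKind::Nth` from `u32` to `i64` so that negative indices select rows counting back from the end of the frame. Here is a standalone sketch of the offset arithmetic the evaluator applies relative to the frame start; `nth_offset` is an illustrative name, not code from this patch.

use std::cmp::Ordering;

/// Maps the SQL-level `n` of NTH_VALUE to a 0-based offset inside a frame of
/// `n_range` rows, returning `None` when the requested row falls outside the
/// frame (which produces NULL). `n = 0` is rejected before evaluation.
fn nth_offset(n: i64, n_range: usize) -> Option<usize> {
    match n.cmp(&0) {
        // Positive n: SQL indices are 1-based, so n = 1 is the frame start.
        Ordering::Greater => {
            let idx = (n as usize) - 1;
            (idx < n_range).then_some(idx)
        }
        // Negative n: count backwards from the frame end, n = -1 is the last row.
        Ordering::Less => {
            let back = (-n) as usize;
            (n_range >= back).then_some(n_range - back)
        }
        // n = 0 is invalid for NTH_VALUE and is rejected at construction time.
        Ordering::Equal => None,
    }
}

// e.g. for a 5-row frame: nth_offset(1, 5) == Some(0), nth_offset(-1, 5) == Some(4),
// and nth_offset(7, 5) == None (NULL).
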
+ +use std::any::Any; +use std::cmp::Ordering; +use std::ops::Range; +use std::sync::Arc; use crate::window::window_expr::{NthValueKind, NthValueState}; use crate::window::BuiltInWindowFunctionExpr; use crate::PhysicalExpr; + use arrow::array::{Array, ArrayRef}; use arrow::datatypes::{DataType, Field}; use datafusion_common::{exec_err, ScalarValue}; use datafusion_common::{DataFusionError, Result}; use datafusion_expr::window_state::WindowAggState; use datafusion_expr::PartitionEvaluator; -use std::any::Any; -use std::ops::Range; -use std::sync::Arc; /// nth_value expression #[derive(Debug)] @@ -77,17 +80,17 @@ impl NthValue { n: u32, ) -> Result { match n { - 0 => exec_err!("nth_value expect n to be > 0"), + 0 => exec_err!("NTH_VALUE expects n to be non-zero"), _ => Ok(Self { name: name.into(), expr, data_type, - kind: NthValueKind::Nth(n), + kind: NthValueKind::Nth(n as i64), }), } } - /// Get nth_value kind + /// Get the NTH_VALUE kind pub fn get_kind(&self) -> NthValueKind { self.kind } @@ -125,7 +128,7 @@ impl BuiltInWindowFunctionExpr for NthValue { let reversed_kind = match self.kind { NthValueKind::First => NthValueKind::Last, NthValueKind::Last => NthValueKind::First, - NthValueKind::Nth(_) => return None, + NthValueKind::Nth(idx) => NthValueKind::Nth(-idx), }; Some(Arc::new(Self { name: self.name.clone(), @@ -143,16 +146,17 @@ pub(crate) struct NthValueEvaluator { } impl PartitionEvaluator for NthValueEvaluator { - /// When the window frame has a fixed beginning (e.g UNBOUNDED - /// PRECEDING), for some functions such as FIRST_VALUE, LAST_VALUE and - /// NTH_VALUE we can memoize result. Once result is calculated it - /// will always stay same. Hence, we do not need to keep past data - /// as we process the entire dataset. This feature enables us to - /// prune rows from table. The default implementation does nothing + /// When the window frame has a fixed beginning (e.g UNBOUNDED PRECEDING), + /// for some functions such as FIRST_VALUE, LAST_VALUE and NTH_VALUE, we + /// can memoize the result. Once result is calculated, it will always stay + /// same. Hence, we do not need to keep past data as we process the entire + /// dataset. fn memoize(&mut self, state: &mut WindowAggState) -> Result<()> { let out = &state.out_col; let size = out.len(); - let (is_prunable, is_last) = match self.state.kind { + let mut buffer_size = 1; + // Decide if we arrived at a final result yet: + let (is_prunable, is_reverse_direction) = match self.state.kind { NthValueKind::First => { let n_range = state.window_frame_range.end - state.window_frame_range.start; @@ -162,16 +166,30 @@ impl PartitionEvaluator for NthValueEvaluator { NthValueKind::Nth(n) => { let n_range = state.window_frame_range.end - state.window_frame_range.start; - (n_range >= (n as usize) && size >= (n as usize), false) + match n.cmp(&0) { + Ordering::Greater => { + (n_range >= (n as usize) && size > (n as usize), false) + } + Ordering::Less => { + let reverse_index = (-n) as usize; + buffer_size = reverse_index; + // Negative index represents reverse direction. + (n_range >= reverse_index, true) + } + Ordering::Equal => { + // The case n = 0 is not valid for the NTH_VALUE function. 
+ unreachable!(); + } + } } }; if is_prunable { - if self.state.finalized_result.is_none() && !is_last { + if self.state.finalized_result.is_none() && !is_reverse_direction { let result = ScalarValue::try_from_array(out, size - 1)?; self.state.finalized_result = Some(result); } state.window_frame_range.start = - state.window_frame_range.end.saturating_sub(1); + state.window_frame_range.end.saturating_sub(buffer_size); } Ok(()) } @@ -195,12 +213,33 @@ impl PartitionEvaluator for NthValueEvaluator { NthValueKind::First => ScalarValue::try_from_array(arr, range.start), NthValueKind::Last => ScalarValue::try_from_array(arr, range.end - 1), NthValueKind::Nth(n) => { - // We are certain that n > 0. - let index = (n as usize) - 1; - if index >= n_range { - ScalarValue::try_from(arr.data_type()) - } else { - ScalarValue::try_from_array(arr, range.start + index) + match n.cmp(&0) { + Ordering::Greater => { + // SQL indices are not 0-based. + let index = (n as usize) - 1; + if index >= n_range { + // Outside the range, return NULL: + ScalarValue::try_from(arr.data_type()) + } else { + ScalarValue::try_from_array(arr, range.start + index) + } + } + Ordering::Less => { + let reverse_index = (-n) as usize; + if n_range >= reverse_index { + ScalarValue::try_from_array( + arr, + range.start + n_range - reverse_index, + ) + } else { + // Outside the range, return NULL: + ScalarValue::try_from(arr.data_type()) + } + } + Ordering::Equal => { + // The case n = 0 is not valid for the NTH_VALUE function. + unreachable!(); + } } } } diff --git a/datafusion/physical-expr/src/window/ntile.rs b/datafusion/physical-expr/src/window/ntile.rs index 49aac0877ab33..f5442e1b0fee4 100644 --- a/datafusion/physical-expr/src/window/ntile.rs +++ b/datafusion/physical-expr/src/window/ntile.rs @@ -96,8 +96,9 @@ impl PartitionEvaluator for NtileEvaluator { ) -> Result { let num_rows = num_rows as u64; let mut vec: Vec = Vec::new(); + let n = u64::min(self.n, num_rows); for i in 0..num_rows { - let res = i * self.n / num_rows; + let res = i * n / num_rows; vec.push(res + 1) } Ok(Arc::new(UInt64Array::from(vec))) diff --git a/datafusion/physical-expr/src/window/rank.rs b/datafusion/physical-expr/src/window/rank.rs index 9bc36728f46ef..86af5b322133c 100644 --- a/datafusion/physical-expr/src/window/rank.rs +++ b/datafusion/physical-expr/src/window/rank.rs @@ -141,9 +141,16 @@ impl PartitionEvaluator for RankEvaluator { // There is no argument, values are order by column values (where rank is calculated) let range_columns = values; let last_rank_data = get_row_at_idx(range_columns, row_idx)?; - let empty = self.state.last_rank_data.is_empty(); - if empty || self.state.last_rank_data != last_rank_data { - self.state.last_rank_data = last_rank_data; + let new_rank_encountered = + if let Some(state_last_rank_data) = &self.state.last_rank_data { + // if rank data changes, new rank is encountered + state_last_rank_data != &last_rank_data + } else { + // First rank seen + true + }; + if new_rank_encountered { + self.state.last_rank_data = Some(last_rank_data); self.state.last_rank_boundary += self.state.current_group_count; self.state.current_group_count = 1; self.state.n_rank += 1; diff --git a/datafusion/physical-expr/src/window/window_expr.rs b/datafusion/physical-expr/src/window/window_expr.rs index 9b0a02d329c43..548fae75bd977 100644 --- a/datafusion/physical-expr/src/window/window_expr.rs +++ b/datafusion/physical-expr/src/window/window_expr.rs @@ -15,7 +15,13 @@ // specific language governing permissions and limitations // under the 
License. +use std::any::Any; +use std::fmt::Debug; +use std::ops::Range; +use std::sync::Arc; + use crate::{PhysicalExpr, PhysicalSortExpr}; + use arrow::array::{new_empty_array, Array, ArrayRef}; use arrow::compute::kernels::sort::SortColumn; use arrow::compute::SortOptions; @@ -25,13 +31,9 @@ use datafusion_common::{internal_err, DataFusionError, Result, ScalarValue}; use datafusion_expr::window_state::{ PartitionBatchState, WindowAggState, WindowFrameContext, }; -use datafusion_expr::PartitionEvaluator; -use datafusion_expr::{Accumulator, WindowFrame}; +use datafusion_expr::{Accumulator, PartitionEvaluator, WindowFrame}; + use indexmap::IndexMap; -use std::any::Any; -use std::fmt::Debug; -use std::ops::Range; -use std::sync::Arc; /// Common trait for [window function] implementations /// @@ -82,8 +84,10 @@ pub trait WindowExpr: Send + Sync + Debug { fn evaluate_args(&self, batch: &RecordBatch) -> Result> { self.expressions() .iter() - .map(|e| e.evaluate(batch)) - .map(|r| r.map(|v| v.into_array(batch.num_rows()))) + .map(|e| { + e.evaluate(batch) + .and_then(|v| v.into_array(batch.num_rows())) + }) .collect() } @@ -270,7 +274,7 @@ pub enum WindowFn { #[derive(Debug, Clone, Default)] pub struct RankState { /// The last values for rank as these values change, we increase n_rank - pub last_rank_data: Vec, + pub last_rank_data: Option>, /// The index where last_rank_boundary is started pub last_rank_boundary: usize, /// Keep the number of entries in current rank @@ -290,7 +294,7 @@ pub struct NumRowsState { pub enum NthValueKind { First, Last, - Nth(u32), + Nth(i64), } #[derive(Debug, Clone)] diff --git a/datafusion/physical-plan/Cargo.toml b/datafusion/physical-plan/Cargo.toml index b8d8b6d2d61b8..c5b689496e90a 100644 --- a/datafusion/physical-plan/Cargo.toml +++ b/datafusion/physical-plan/Cargo.toml @@ -48,7 +48,7 @@ futures = { workspace = true } half = { version = "2.1", default-features = false } hashbrown = { version = "0.14", features = ["raw"] } indexmap = { workspace = true } -itertools = { version = "0.11", features = ["use_std"] } +itertools = { version = "0.12", features = ["use_std"] } log = { workspace = true } once_cell = "1.18.0" parking_lot = { workspace = true } diff --git a/datafusion/physical-plan/src/aggregates/group_values/row.rs b/datafusion/physical-plan/src/aggregates/group_values/row.rs index 10ff9edb8912f..e7c7a42cf9029 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/row.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/row.rs @@ -17,22 +17,18 @@ use crate::aggregates::group_values::GroupValues; use ahash::RandomState; -use arrow::compute::cast; use arrow::record_batch::RecordBatch; use arrow::row::{RowConverter, Rows, SortField}; -use arrow_array::{Array, ArrayRef}; -use arrow_schema::{DataType, SchemaRef}; +use arrow_array::ArrayRef; +use arrow_schema::SchemaRef; use datafusion_common::hash_utils::create_hashes; -use datafusion_common::{DataFusionError, Result}; +use datafusion_common::Result; use datafusion_execution::memory_pool::proxy::{RawTableAllocExt, VecAllocExt}; use datafusion_physical_expr::EmitTo; use hashbrown::raw::RawTable; /// A [`GroupValues`] making use of [`Rows`] pub struct GroupValuesRows { - /// The output schema - schema: SchemaRef, - /// Converter for the group values row_converter: RowConverter, @@ -79,7 +75,6 @@ impl GroupValuesRows { let map = RawTable::with_capacity(0); Ok(Self { - schema, row_converter, map, map_size: 0, @@ -170,7 +165,7 @@ impl GroupValues for GroupValuesRows { .take() .expect("Can not 
emit from empty rows"); - let mut output = match emit_to { + let output = match emit_to { EmitTo::All => { let output = self.row_converter.convert_rows(&group_values)?; group_values.clear(); @@ -203,20 +198,6 @@ impl GroupValues for GroupValuesRows { } }; - // TODO: Materialize dictionaries in group keys (#7647) - for (field, array) in self.schema.fields.iter().zip(&mut output) { - let expected = field.data_type(); - if let DataType::Dictionary(_, v) = expected { - let actual = array.data_type(); - if v.as_ref() != actual { - return Err(DataFusionError::Internal(format!( - "Converted group rows expected dictionary of {v} got {actual}" - ))); - } - *array = cast(array.as_ref(), expected)?; - } - } - self.group_values = Some(group_values); Ok(output) } diff --git a/datafusion/physical-plan/src/aggregates/mod.rs b/datafusion/physical-plan/src/aggregates/mod.rs index 9cbf12aeeb88f..a38044de02e38 100644 --- a/datafusion/physical-plan/src/aggregates/mod.rs +++ b/datafusion/physical-plan/src/aggregates/mod.rs @@ -27,30 +27,29 @@ use crate::aggregates::{ }; use crate::metrics::{ExecutionPlanMetricsSet, MetricsSet}; -use crate::windows::{ - get_ordered_partition_by_indices, get_window_mode, PartitionSearchMode, -}; +use crate::windows::get_ordered_partition_by_indices; use crate::{ - DisplayFormatType, Distribution, ExecutionPlan, Partitioning, + DisplayFormatType, Distribution, ExecutionPlan, InputOrderMode, Partitioning, SendableRecordBatchStream, Statistics, }; use arrow::array::ArrayRef; use arrow::datatypes::{Field, Schema, SchemaRef}; use arrow::record_batch::RecordBatch; +use arrow_schema::DataType; use datafusion_common::stats::Precision; use datafusion_common::{not_impl_err, plan_err, DataFusionError, Result}; use datafusion_execution::TaskContext; use datafusion_expr::Accumulator; use datafusion_physical_expr::{ aggregate::is_order_sensitive, - equivalence::collapse_lex_req, - expressions::{Column, Max, Min, UnKnownColumn}, + equivalence::{collapse_lex_req, ProjectionMapping}, + expressions::{Column, FirstValue, LastValue, Max, Min, UnKnownColumn}, physical_exprs_contains, reverse_order_bys, AggregateExpr, EquivalenceProperties, LexOrdering, LexRequirement, PhysicalExpr, PhysicalSortExpr, PhysicalSortRequirement, }; -use itertools::{izip, Itertools}; +use itertools::Itertools; mod group_values; mod no_grouping; @@ -60,7 +59,6 @@ mod topk; mod topk_stream; pub use datafusion_expr::AggregateFunction; -use datafusion_physical_expr::equivalence::ProjectionMapping; pub use datafusion_physical_expr::expressions::create_aggregate_expr; /// Hash aggregate modes @@ -103,34 +101,6 @@ impl AggregateMode { } } -/// Group By expression modes -/// -/// `PartiallyOrdered` and `FullyOrdered` are used to reason about -/// when certain group by keys will never again be seen (and thus can -/// be emitted by the grouping operator). -/// -/// Specifically, each distinct combination of the relevant columns -/// are contiguous in the input, and once a new combination is seen -/// previous combinations are guaranteed never to appear again -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum GroupByOrderMode { - /// The input is known to be ordered by a preset (prefix but - /// possibly reordered) of the expressions in the `GROUP BY` clause. - /// - /// For example, if the input is ordered by `a, b, c` and we group - /// by `b, a, d`, `PartiallyOrdered` means a subset of group `b, - /// a, d` defines a preset for the existing ordering, in this case - /// `a, b`. 
- PartiallyOrdered, - /// The input is known to be ordered by *all* the expressions in the - /// `GROUP BY` clause. - /// - /// For example, if the input is ordered by `a, b, c, d` and we group by b, a, - /// `Ordered` means that all of the of group by expressions appear - /// as a preset for the existing ordering, in this case `a, b`. - FullyOrdered, -} - /// Represents `GROUP BY` clause in the plan (including the more general GROUPING SET) /// In the case of a simple `GROUP BY a, b` clause, this will contain the expression [a, b] /// and a single group [false, false]. @@ -280,12 +250,13 @@ pub struct AggregateExec { aggr_expr: Vec>, /// FILTER (WHERE clause) expression for each aggregate expression filter_expr: Vec>>, - /// (ORDER BY clause) expression for each aggregate expression - order_by_expr: Vec>, /// Set if the output of this aggregation is truncated by a upstream sort/limit clause limit: Option, /// Input plan, could be a partial aggregate or the input to the aggregate pub input: Arc, + /// Original aggregation schema, could be different from `schema` before dictionary group + /// keys get materialized + original_schema: SchemaRef, /// Schema after the aggregate is applied schema: SchemaRef, /// Input schema before any aggregation is applied. For partial aggregate this will be the @@ -300,176 +271,23 @@ pub struct AggregateExec { /// Execution metrics metrics: ExecutionPlanMetricsSet, required_input_ordering: Option, - partition_search_mode: PartitionSearchMode, + /// Describes how the input is ordered relative to the group by columns + input_order_mode: InputOrderMode, + /// Describe how the output is ordered output_ordering: Option, } -/// This function returns the ordering requirement of the first non-reversible -/// order-sensitive aggregate function such as ARRAY_AGG. This requirement serves -/// as the initial requirement while calculating the finest requirement among all -/// aggregate functions. If this function returns `None`, it means there is no -/// hard ordering requirement for the aggregate functions (in terms of direction). -/// Then, we can generate two alternative requirements with opposite directions. -fn get_init_req( - aggr_expr: &[Arc], - order_by_expr: &[Option], -) -> Option { - for (aggr_expr, fn_reqs) in aggr_expr.iter().zip(order_by_expr.iter()) { - // If the aggregation function is a non-reversible order-sensitive function - // and there is a hard requirement, choose first such requirement: - if is_order_sensitive(aggr_expr) - && aggr_expr.reverse_expr().is_none() - && fn_reqs.is_some() - { - return fn_reqs.clone(); - } - } - None -} - -/// This function gets the finest ordering requirement among all the aggregation -/// functions. If requirements are conflicting, (i.e. we can not compute the -/// aggregations in a single [`AggregateExec`]), the function returns an error. -fn get_finest_requirement( - aggr_expr: &mut [Arc], - order_by_expr: &mut [Option], - eq_properties: &EquivalenceProperties, -) -> Result> { - // First, we check if all the requirements are satisfied by the existing - // ordering. If so, we return `None` to indicate this. 
- let mut all_satisfied = true; - for (aggr_expr, fn_req) in aggr_expr.iter_mut().zip(order_by_expr.iter_mut()) { - if eq_properties.ordering_satisfy(fn_req.as_deref().unwrap_or(&[])) { - continue; - } - if let Some(reverse) = aggr_expr.reverse_expr() { - let reverse_req = fn_req.as_ref().map(|item| reverse_order_bys(item)); - if eq_properties.ordering_satisfy(reverse_req.as_deref().unwrap_or(&[])) { - // We need to update `aggr_expr` with its reverse since only its - // reverse requirement is compatible with the existing requirements: - *aggr_expr = reverse; - *fn_req = reverse_req; - continue; - } - } - // Requirement is not satisfied: - all_satisfied = false; - } - if all_satisfied { - // All of the requirements are already satisfied. - return Ok(None); - } - let mut finest_req = get_init_req(aggr_expr, order_by_expr); - for (aggr_expr, fn_req) in aggr_expr.iter_mut().zip(order_by_expr.iter_mut()) { - let Some(fn_req) = fn_req else { - continue; - }; - - if let Some(finest_req) = &mut finest_req { - if let Some(finer) = eq_properties.get_finer_ordering(finest_req, fn_req) { - *finest_req = finer; - continue; - } - // If an aggregate function is reversible, analyze whether its reverse - // direction is compatible with existing requirements: - if let Some(reverse) = aggr_expr.reverse_expr() { - let fn_req_reverse = reverse_order_bys(fn_req); - if let Some(finer) = - eq_properties.get_finer_ordering(finest_req, &fn_req_reverse) - { - // We need to update `aggr_expr` with its reverse, since only its - // reverse requirement is compatible with existing requirements: - *aggr_expr = reverse; - *finest_req = finer; - *fn_req = fn_req_reverse; - continue; - } - } - // If neither of the requirements satisfy the other, this means - // requirements are conflicting. Currently, we do not support - // conflicting requirements. - return not_impl_err!( - "Conflicting ordering requirements in aggregate functions is not supported" - ); - } else { - finest_req = Some(fn_req.clone()); - } - } - Ok(finest_req) -} - -/// Calculates search_mode for the aggregation -fn get_aggregate_search_mode( - group_by: &PhysicalGroupBy, - input: &Arc, - aggr_expr: &mut [Arc], - order_by_expr: &mut [Option], - ordering_req: &mut Vec, -) -> Result { - let groupby_exprs = group_by - .expr - .iter() - .map(|(item, _)| item.clone()) - .collect::>(); - let mut partition_search_mode = PartitionSearchMode::Linear; - if !group_by.is_single() || groupby_exprs.is_empty() { - return Ok(partition_search_mode); - } - - if let Some((should_reverse, mode)) = - get_window_mode(&groupby_exprs, ordering_req, input)? 
- { - let all_reversible = aggr_expr - .iter() - .all(|expr| !is_order_sensitive(expr) || expr.reverse_expr().is_some()); - if should_reverse && all_reversible { - izip!(aggr_expr.iter_mut(), order_by_expr.iter_mut()).for_each( - |(aggr, order_by)| { - if let Some(reverse) = aggr.reverse_expr() { - *aggr = reverse; - } else { - unreachable!(); - } - *order_by = order_by.as_ref().map(|ob| reverse_order_bys(ob)); - }, - ); - *ordering_req = reverse_order_bys(ordering_req); - } - partition_search_mode = mode; - } - Ok(partition_search_mode) -} - -/// Check whether group by expression contains all of the expression inside `requirement` -// As an example Group By (c,b,a) contains all of the expressions in the `requirement`: (a ASC, b DESC) -fn group_by_contains_all_requirements( - group_by: &PhysicalGroupBy, - requirement: &LexOrdering, -) -> bool { - let physical_exprs = group_by.input_exprs(); - // When we have multiple groups (grouping set) - // since group by may be calculated on the subset of the group_by.expr() - // it is not guaranteed to have all of the requirements among group by expressions. - // Hence do the analysis: whether group by contains all requirements in the single group case. - group_by.is_single() - && requirement - .iter() - .all(|req| physical_exprs_contains(&physical_exprs, &req.expr)) -} - impl AggregateExec { /// Create a new hash aggregate execution plan pub fn try_new( mode: AggregateMode, group_by: PhysicalGroupBy, - mut aggr_expr: Vec>, + aggr_expr: Vec>, filter_expr: Vec>>, - // Ordering requirement of each aggregate expression - mut order_by_expr: Vec>, input: Arc, input_schema: SchemaRef, ) -> Result { - let schema = create_schema( + let original_schema = create_schema( &input.schema(), &group_by.expr, &aggr_expr, @@ -477,44 +295,43 @@ impl AggregateExec { mode, )?; - let schema = Arc::new(schema); - // Reset ordering requirement to `None` if aggregator is not order-sensitive - order_by_expr = aggr_expr - .iter() - .zip(order_by_expr) - .map(|(aggr_expr, fn_reqs)| { - // If - // - aggregation function is order-sensitive and - // - aggregation is performing a "first stage" calculation, and - // - at least one of the aggregate function requirement is not inside group by expression - // keep the ordering requirement as is; otherwise ignore the ordering requirement. - // In non-first stage modes, we accumulate data (using `merge_batch`) - // from different partitions (i.e. merge partial results). During - // this merge, we consider the ordering of each partial result. - // Hence, we do not need to use the ordering requirement in such - // modes as long as partial results are generated with the - // correct ordering. 
- fn_reqs.filter(|req| { - is_order_sensitive(aggr_expr) - && mode.is_first_stage() - && !group_by_contains_all_requirements(&group_by, req) - }) - }) - .collect::>(); - let requirement = get_finest_requirement( - &mut aggr_expr, - &mut order_by_expr, - &input.equivalence_properties(), - )?; - let mut ordering_req = requirement.unwrap_or(vec![]); - let partition_search_mode = get_aggregate_search_mode( - &group_by, - &input, - &mut aggr_expr, - &mut order_by_expr, - &mut ordering_req, - )?; + let schema = Arc::new(materialize_dict_group_keys( + &original_schema, + group_by.expr.len(), + )); + let original_schema = Arc::new(original_schema); + AggregateExec::try_new_with_schema( + mode, + group_by, + aggr_expr, + filter_expr, + input, + input_schema, + schema, + original_schema, + ) + } + /// Create a new hash aggregate execution plan with the given schema. + /// This constructor isn't part of the public API, it is used internally + /// by Datafusion to enforce schema consistency during when re-creating + /// `AggregateExec`s inside optimization rules. Schema field names of an + /// `AggregateExec` depends on the names of aggregate expressions. Since + /// a rule may re-write aggregate expressions (e.g. reverse them) during + /// initialization, field names may change inadvertently if one re-creates + /// the schema in such cases. + #[allow(clippy::too_many_arguments)] + fn try_new_with_schema( + mode: AggregateMode, + group_by: PhysicalGroupBy, + mut aggr_expr: Vec>, + filter_expr: Vec>>, + input: Arc, + input_schema: SchemaRef, + schema: SchemaRef, + original_schema: SchemaRef, + ) -> Result { + let input_eq_properties = input.equivalence_properties(); // Get GROUP BY expressions: let groupby_exprs = group_by.input_exprs(); // If existing ordering satisfies a prefix of the GROUP BY expressions, @@ -522,17 +339,32 @@ impl AggregateExec { // work more efficiently. let indices = get_ordered_partition_by_indices(&groupby_exprs, &input); let mut new_requirement = indices - .into_iter() - .map(|idx| PhysicalSortRequirement { + .iter() + .map(|&idx| PhysicalSortRequirement { expr: groupby_exprs[idx].clone(), options: None, }) .collect::>(); - // Postfix ordering requirement of the aggregation to the requirement. 
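// Illustrative sketch (not from this patch; names are hypothetical): why
// `try_new_with_schema` re-uses the schema that was computed when the plan was
// first built. Output field names are derived from aggregate expression names,
// so a rule that rewrites an expression (e.g. reversing FIRST_VALUE into
// LAST_VALUE) would otherwise rename output fields when the `AggregateExec`
// is re-created.
fn output_field_name(agg_name: &str, arg: &str) -> String {
    format!("{agg_name}({arg})")
}

#[test]
fn rewriting_an_aggregate_changes_its_field_name() {
    let original = output_field_name("FIRST_VALUE", "b");
    let rewritten = output_field_name("LAST_VALUE", "b");
    // Passing the original schema explicitly keeps exposing "FIRST_VALUE(b)",
    // so downstream plans that reference that name stay valid.
    assert_ne!(original, rewritten);
    assert_eq!(original, "FIRST_VALUE(b)");
}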
- let req = PhysicalSortRequirement::from_sort_exprs(&ordering_req); + + let req = get_aggregate_exprs_requirement( + &new_requirement, + &mut aggr_expr, + &group_by, + &input_eq_properties, + &mode, + )?; new_requirement.extend(req); new_requirement = collapse_lex_req(new_requirement); + let input_order_mode = + if indices.len() == groupby_exprs.len() && !indices.is_empty() { + InputOrderMode::Sorted + } else if !indices.is_empty() { + InputOrderMode::PartiallySorted(indices) + } else { + InputOrderMode::Linear + }; + // construct a map from the input expression to the output expression of the Aggregation group by let projection_mapping = ProjectionMapping::try_new(&group_by.expr, &input.schema())?; @@ -540,9 +372,8 @@ impl AggregateExec { let required_input_ordering = (!new_requirement.is_empty()).then_some(new_requirement); - let aggregate_eqs = input - .equivalence_properties() - .project(&projection_mapping, schema.clone()); + let aggregate_eqs = + input_eq_properties.project(&projection_mapping, schema.clone()); let output_ordering = aggregate_eqs.oeq_class().output_ordering(); Ok(AggregateExec { @@ -550,15 +381,15 @@ impl AggregateExec { group_by, aggr_expr, filter_expr, - order_by_expr, input, + original_schema, schema, input_schema, projection_mapping, metrics: ExecutionPlanMetricsSet::new(), required_input_ordering, limit: None, - partition_search_mode, + input_order_mode, output_ordering, }) } @@ -593,11 +424,6 @@ impl AggregateExec { &self.filter_expr } - /// ORDER BY clause expression for each aggregate expression - pub fn order_by_expr(&self) -> &[Option] { - &self.order_by_expr - } - /// Input plan pub fn input(&self) -> &Arc { &self.input @@ -608,6 +434,11 @@ impl AggregateExec { self.input_schema.clone() } + /// number of rows soft limit of the AggregateExec + pub fn limit(&self) -> Option { + self.limit + } + fn execute_typed( &self, partition: usize, @@ -622,9 +453,11 @@ impl AggregateExec { // grouping by an expression that has a sort/limit upstream if let Some(limit) = self.limit { - return Ok(StreamType::GroupedPriorityQueue( - GroupedTopKAggregateStream::new(self, context, partition, limit)?, - )); + if !self.is_unordered_unfiltered_group_by_distinct() { + return Ok(StreamType::GroupedPriorityQueue( + GroupedTopKAggregateStream::new(self, context, partition, limit)?, + )); + } } // grouping by something else and we need to just materialize all results @@ -648,6 +481,39 @@ impl AggregateExec { pub fn group_by(&self) -> &PhysicalGroupBy { &self.group_by } + + /// true, if this Aggregate has a group-by with no required or explicit ordering, + /// no filtering and no aggregate expressions + /// This method qualifies the use of the LimitedDistinctAggregation rewrite rule + /// on an AggregateExec. + pub fn is_unordered_unfiltered_group_by_distinct(&self) -> bool { + // ensure there is a group by + if self.group_by().is_empty() { + return false; + } + // ensure there are no aggregate expressions + if !self.aggr_expr().is_empty() { + return false; + } + // ensure there are no filters on aggregate expressions; the above check + // may preclude this case + if self.filter_expr().iter().any(|e| e.is_some()) { + return false; + } + // ensure there are no order by expressions + if self.aggr_expr().iter().any(|e| e.order_bys().is_some()) { + return false; + } + // ensure there is no output ordering; can this rule be relaxed? 
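// Illustrative sketch (not from this patch; simplified standalone types): how the
// new `input_order_mode` is derived from `indices`, the GROUP BY expressions that
// are already covered by the input ordering.
#[derive(Debug, PartialEq, Eq)]
enum OrderMode {
    Linear,
    PartiallySorted(Vec<usize>),
    Sorted,
}

fn choose_order_mode(ordered_indices: Vec<usize>, group_expr_count: usize) -> OrderMode {
    if !ordered_indices.is_empty() && ordered_indices.len() == group_expr_count {
        // Every GROUP BY expression is covered by the existing input ordering.
        OrderMode::Sorted
    } else if !ordered_indices.is_empty() {
        // Only a (possibly reordered) prefix is covered.
        OrderMode::PartiallySorted(ordered_indices)
    } else {
        OrderMode::Linear
    }
}

// e.g. GROUP BY a, b, c over an input sorted on (a, b) gives PartiallySorted([0, 1]);
// sorted on (a, b, c) it gives Sorted; with no useful ordering it gives Linear.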
+ if self.output_ordering().is_some() { + return false; + } + // ensure no ordering is required on the input + if self.required_input_ordering()[0].is_some() { + return false; + } + true + } } impl DisplayAs for AggregateExec { @@ -718,8 +584,8 @@ impl DisplayAs for AggregateExec { write!(f, ", lim=[{limit}]")?; } - if self.partition_search_mode != PartitionSearchMode::Linear { - write!(f, ", ordering_mode={:?}", self.partition_search_mode)?; + if self.input_order_mode != InputOrderMode::Linear { + write!(f, ", ordering_mode={:?}", self.input_order_mode)?; } } } @@ -770,7 +636,7 @@ impl ExecutionPlan for AggregateExec { /// infinite, returns an error to indicate this. fn unbounded_output(&self, children: &[bool]) -> Result { if children[0] { - if self.partition_search_mode == PartitionSearchMode::Linear { + if self.input_order_mode == InputOrderMode::Linear { // Cannot run without breaking pipeline. plan_err!( "Aggregate Error: `GROUP BY` clauses with columns without ordering and GROUPING SETS are not supported for unbounded inputs." @@ -819,14 +685,15 @@ impl ExecutionPlan for AggregateExec { self: Arc, children: Vec>, ) -> Result> { - let mut me = AggregateExec::try_new( + let mut me = AggregateExec::try_new_with_schema( self.mode, self.group_by.clone(), self.aggr_expr.clone(), self.filter_expr.clone(), - self.order_by_expr.clone(), children[0].clone(), self.input_schema.clone(), + self.schema.clone(), + self.original_schema.clone(), )?; me.limit = self.limit; Ok(Arc::new(me)) @@ -933,11 +800,199 @@ fn create_schema( Ok(Schema::new(fields)) } +/// returns schema with dictionary group keys materialized as their value types +/// The actual convertion happens in `RowConverter` and we don't do unnecessary +/// conversion back into dictionaries +fn materialize_dict_group_keys(schema: &Schema, group_count: usize) -> Schema { + let fields = schema + .fields + .iter() + .enumerate() + .map(|(i, field)| match field.data_type() { + DataType::Dictionary(_, value_data_type) if i < group_count => { + Field::new(field.name(), *value_data_type.clone(), field.is_nullable()) + } + _ => Field::clone(field), + }) + .collect::>(); + Schema::new(fields) +} + fn group_schema(schema: &Schema, group_count: usize) -> SchemaRef { let group_fields = schema.fields()[0..group_count].to_vec(); Arc::new(Schema::new(group_fields)) } +/// Determines the lexical ordering requirement for an aggregate expression. +/// +/// # Parameters +/// +/// - `aggr_expr`: A reference to an `Arc` representing the +/// aggregate expression. +/// - `group_by`: A reference to a `PhysicalGroupBy` instance representing the +/// physical GROUP BY expression. +/// - `agg_mode`: A reference to an `AggregateMode` instance representing the +/// mode of aggregation. +/// +/// # Returns +/// +/// A `LexOrdering` instance indicating the lexical ordering requirement for +/// the aggregate expression. +fn get_aggregate_expr_req( + aggr_expr: &Arc, + group_by: &PhysicalGroupBy, + agg_mode: &AggregateMode, +) -> LexOrdering { + // If the aggregation function is not order sensitive, or the aggregation + // is performing a "second stage" calculation, or all aggregate function + // requirements are inside the GROUP BY expression, then ignore the ordering + // requirement. + if !is_order_sensitive(aggr_expr) || !agg_mode.is_first_stage() { + return vec![]; + } + + let mut req = aggr_expr.order_bys().unwrap_or_default().to_vec(); + + // In non-first stage modes, we accumulate data (using `merge_batch`) from + // different partitions (i.e. 
merge partial results). During this merge, we + // consider the ordering of each partial result. Hence, we do not need to + // use the ordering requirement in such modes as long as partial results are + // generated with the correct ordering. + if group_by.is_single() { + // Remove all orderings that occur in the group by. These requirements + // will definitely be satisfied -- Each group by expression will have + // distinct values per group, hence all requirements are satisfied. + let physical_exprs = group_by.input_exprs(); + req.retain(|sort_expr| { + !physical_exprs_contains(&physical_exprs, &sort_expr.expr) + }); + } + req +} + +/// Computes the finer ordering for between given existing ordering requirement +/// of aggregate expression. +/// +/// # Parameters +/// +/// * `existing_req` - The existing lexical ordering that needs refinement. +/// * `aggr_expr` - A reference to an aggregate expression trait object. +/// * `group_by` - Information about the physical grouping (e.g group by expression). +/// * `eq_properties` - Equivalence properties relevant to the computation. +/// * `agg_mode` - The mode of aggregation (e.g., Partial, Final, etc.). +/// +/// # Returns +/// +/// An `Option` representing the computed finer lexical ordering, +/// or `None` if there is no finer ordering; e.g. the existing requirement and +/// the aggregator requirement is incompatible. +fn finer_ordering( + existing_req: &LexOrdering, + aggr_expr: &Arc, + group_by: &PhysicalGroupBy, + eq_properties: &EquivalenceProperties, + agg_mode: &AggregateMode, +) -> Option { + let aggr_req = get_aggregate_expr_req(aggr_expr, group_by, agg_mode); + eq_properties.get_finer_ordering(existing_req, &aggr_req) +} + +/// Concatenates the given slices. +fn concat_slices(lhs: &[T], rhs: &[T]) -> Vec { + [lhs, rhs].concat() +} + +/// Get the common requirement that satisfies all the aggregate expressions. +/// +/// # Parameters +/// +/// - `aggr_exprs`: A slice of `Arc` containing all the +/// aggregate expressions. +/// - `group_by`: A reference to a `PhysicalGroupBy` instance representing the +/// physical GROUP BY expression. +/// - `eq_properties`: A reference to an `EquivalenceProperties` instance +/// representing equivalence properties for ordering. +/// - `agg_mode`: A reference to an `AggregateMode` instance representing the +/// mode of aggregation. +/// +/// # Returns +/// +/// A `LexRequirement` instance, which is the requirement that satisfies all the +/// aggregate requirements. Returns an error in case of conflicting requirements. 
+fn get_aggregate_exprs_requirement( + prefix_requirement: &[PhysicalSortRequirement], + aggr_exprs: &mut [Arc], + group_by: &PhysicalGroupBy, + eq_properties: &EquivalenceProperties, + agg_mode: &AggregateMode, +) -> Result { + let mut requirement = vec![]; + for aggr_expr in aggr_exprs.iter_mut() { + let aggr_req = aggr_expr.order_bys().unwrap_or(&[]); + let reverse_aggr_req = reverse_order_bys(aggr_req); + let aggr_req = PhysicalSortRequirement::from_sort_exprs(aggr_req); + let reverse_aggr_req = + PhysicalSortRequirement::from_sort_exprs(&reverse_aggr_req); + if let Some(first_value) = aggr_expr.as_any().downcast_ref::() { + let mut first_value = first_value.clone(); + if eq_properties.ordering_satisfy_requirement(&concat_slices( + prefix_requirement, + &aggr_req, + )) { + first_value = first_value.with_requirement_satisfied(true); + *aggr_expr = Arc::new(first_value) as _; + } else if eq_properties.ordering_satisfy_requirement(&concat_slices( + prefix_requirement, + &reverse_aggr_req, + )) { + // Converting to LAST_VALUE enables more efficient execution + // given the existing ordering: + let mut last_value = first_value.convert_to_last(); + last_value = last_value.with_requirement_satisfied(true); + *aggr_expr = Arc::new(last_value) as _; + } else { + // Requirement is not satisfied with existing ordering. + first_value = first_value.with_requirement_satisfied(false); + *aggr_expr = Arc::new(first_value) as _; + } + } else if let Some(last_value) = aggr_expr.as_any().downcast_ref::() { + let mut last_value = last_value.clone(); + if eq_properties.ordering_satisfy_requirement(&concat_slices( + prefix_requirement, + &aggr_req, + )) { + last_value = last_value.with_requirement_satisfied(true); + *aggr_expr = Arc::new(last_value) as _; + } else if eq_properties.ordering_satisfy_requirement(&concat_slices( + prefix_requirement, + &reverse_aggr_req, + )) { + // Converting to FIRST_VALUE enables more efficient execution + // given the existing ordering: + let mut first_value = last_value.convert_to_first(); + first_value = first_value.with_requirement_satisfied(true); + *aggr_expr = Arc::new(first_value) as _; + } else { + // Requirement is not satisfied with existing ordering. + last_value = last_value.with_requirement_satisfied(false); + *aggr_expr = Arc::new(last_value) as _; + } + } else if let Some(finer_ordering) = + finer_ordering(&requirement, aggr_expr, group_by, eq_properties, agg_mode) + { + requirement = finer_ordering; + } else { + // If neither of the requirements satisfy the other, this means + // requirements are conflicting. Currently, we do not support + // conflicting requirements. + return not_impl_err!( + "Conflicting ordering requirements in aggregate functions is not supported" + ); + } + } + Ok(PhysicalSortRequirement::from_sort_exprs(&requirement)) +} + /// returns physical expressions for arguments to evaluate against a batch /// The expressions are different depending on `mode`: /// * Partial: AggregateExpr::expressions @@ -953,33 +1008,27 @@ fn aggregate_expressions( | AggregateMode::SinglePartitioned => Ok(aggr_expr .iter() .map(|agg| { - let mut result = agg.expressions().clone(); - // In partial mode, append ordering requirements to expressions' results. - // Ordering requirements are used by subsequent executors to satisfy the required - // ordering for `AggregateMode::FinalPartitioned`/`AggregateMode::Final` modes. 
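// Illustrative sketch (not from this patch; simplified standalone types): the
// FIRST_VALUE / LAST_VALUE handling in `get_aggregate_exprs_requirement` above.
// When only the reverse of an aggregate's ORDER BY is satisfied by the existing
// ordering, the aggregate is swapped for its mirror so the requirement counts as
// satisfied instead of forcing an extra sort.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Direction {
    Asc,
    Desc,
}

#[derive(Debug, PartialEq, Eq)]
enum FirstLastPlan {
    First { requirement_satisfied: bool },
    Last { requirement_satisfied: bool },
}

fn plan_first_value(required: Direction, existing: Option<Direction>) -> FirstLastPlan {
    match existing {
        // Existing ordering matches the requirement: keep FIRST_VALUE as is.
        Some(dir) if dir == required => FirstLastPlan::First { requirement_satisfied: true },
        // Only the reverse ordering is available: the first row under the required
        // order is the last row under the existing order, so convert to LAST_VALUE.
        Some(_) => FirstLastPlan::Last { requirement_satisfied: true },
        // No usable ordering: keep FIRST_VALUE and mark the requirement unsatisfied.
        None => FirstLastPlan::First { requirement_satisfied: false },
    }
}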
- if matches!(mode, AggregateMode::Partial) { - if let Some(ordering_req) = agg.order_bys() { - let ordering_exprs = ordering_req - .iter() - .map(|item| item.expr.clone()) - .collect::>(); - result.extend(ordering_exprs); - } + let mut result = agg.expressions(); + // Append ordering requirements to expressions' results. This + // way order sensitive aggregators can satisfy requirement + // themselves. + if let Some(ordering_req) = agg.order_bys() { + result.extend(ordering_req.iter().map(|item| item.expr.clone())); } result }) .collect()), - // in this mode, we build the merge expressions of the aggregation + // In this mode, we build the merge expressions of the aggregation. AggregateMode::Final | AggregateMode::FinalPartitioned => { let mut col_idx_base = col_idx_base; - Ok(aggr_expr + aggr_expr .iter() .map(|agg| { let exprs = merge_expressions(col_idx_base, agg)?; col_idx_base += exprs.len(); Ok(exprs) }) - .collect::>>()?) + .collect() } } } @@ -992,14 +1041,13 @@ fn merge_expressions( index_base: usize, expr: &Arc, ) -> Result>> { - Ok(expr - .state_fields()? - .iter() - .enumerate() - .map(|(idx, f)| { - Arc::new(Column::new(f.name(), index_base + idx)) as Arc - }) - .collect::>()) + expr.state_fields().map(|fields| { + fields + .iter() + .enumerate() + .map(|(idx, f)| Arc::new(Column::new(f.name(), index_base + idx)) as _) + .collect() + }) } pub(crate) type AccumulatorItem = Box; @@ -1010,7 +1058,7 @@ fn create_accumulators( aggr_expr .iter() .map(|expr| expr.create_accumulator()) - .collect::>>() + .collect() } /// returns a vector of ArrayRefs, where each entry corresponds to either the @@ -1021,27 +1069,28 @@ fn finalize_aggregation( ) -> Result> { match mode { AggregateMode::Partial => { - // build the vector of states - let a = accumulators + // Build the vector of states + accumulators .iter() - .map(|accumulator| accumulator.state()) - .map(|value| { - value.map(|e| { - e.iter().map(|v| v.to_array()).collect::>() + .map(|accumulator| { + accumulator.state().and_then(|e| { + e.iter() + .map(|v| v.to_array()) + .collect::>>() }) }) - .collect::>>()?; - Ok(a.iter().flatten().cloned().collect::>()) + .flatten_ok() + .collect() } AggregateMode::Final | AggregateMode::FinalPartitioned | AggregateMode::Single | AggregateMode::SinglePartitioned => { - // merge the state to the final value + // Merge the state to the final value accumulators .iter() - .map(|accumulator| accumulator.evaluate().map(|v| v.to_array())) - .collect::>>() + .map(|accumulator| accumulator.evaluate().and_then(|v| v.to_array())) + .collect() } } } @@ -1052,9 +1101,11 @@ fn evaluate( batch: &RecordBatch, ) -> Result> { expr.iter() - .map(|expr| expr.evaluate(batch)) - .map(|r| r.map(|v| v.into_array(batch.num_rows()))) - .collect::>>() + .map(|expr| { + expr.evaluate(batch) + .and_then(|v| v.into_array(batch.num_rows())) + }) + .collect() } /// Evaluates expressions against a record batch. 
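// Illustrative sketch (not from this patch): several hunks in this file replace two
// chained `.map(...)` calls with a single closure using `and_then`, so the evaluation
// error and the array-conversion error flow through one `Result` per element before
// `collect()` folds them into `Result<Vec<_>>`. The same pattern with plain std types:
fn parse_then_halve(inputs: &[&str]) -> Result<Vec<u32>, String> {
    inputs
        .iter()
        .map(|s| {
            // First fallible step (cf. `expr.evaluate(batch)`) ...
            s.parse::<u32>()
                .map_err(|e| e.to_string())
                // ... chained into a second fallible step
                // (cf. `v.into_array(batch.num_rows())`).
                .and_then(|v| (v % 2 == 0).then(|| v / 2).ok_or_else(|| format!("{v} is odd")))
        })
        .collect()
}

// parse_then_halve(&["4", "10"]) == Ok(vec![2, 5]);
// parse_then_halve(&["4", "x"]) surfaces the parse error instead of panicking.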
@@ -1062,9 +1113,7 @@ pub(crate) fn evaluate_many( expr: &[Vec>], batch: &RecordBatch, ) -> Result>> { - expr.iter() - .map(|expr| evaluate(expr, batch)) - .collect::>>() + expr.iter().map(|expr| evaluate(expr, batch)).collect() } fn evaluate_optional( @@ -1074,11 +1123,13 @@ fn evaluate_optional( expr.iter() .map(|expr| { expr.as_ref() - .map(|expr| expr.evaluate(batch)) + .map(|expr| { + expr.evaluate(batch) + .and_then(|v| v.into_array(batch.num_rows())) + }) .transpose() - .map(|r| r.map(|v| v.into_array(batch.num_rows()))) }) - .collect::>>() + .collect() } /// Evaluate a group by expression against a `RecordBatch` @@ -1100,7 +1151,7 @@ pub(crate) fn evaluate_group_by( .iter() .map(|(expr, _)| { let value = expr.evaluate(batch)?; - Ok(value.into_array(batch.num_rows())) + value.into_array(batch.num_rows()) }) .collect::>>()?; @@ -1109,7 +1160,7 @@ pub(crate) fn evaluate_group_by( .iter() .map(|(expr, _)| { let value = expr.evaluate(batch)?; - Ok(value.into_array(batch.num_rows())) + value.into_array(batch.num_rows()) }) .collect::>>()?; @@ -1139,9 +1190,7 @@ mod tests { use std::task::{Context, Poll}; use super::*; - use crate::aggregates::{ - get_finest_requirement, AggregateExec, AggregateMode, PhysicalGroupBy, - }; + use crate::aggregates::{AggregateExec, AggregateMode, PhysicalGroupBy}; use crate::coalesce_batches::CoalesceBatchesExec; use crate::coalesce_partitions::CoalescePartitionsExec; use crate::common; @@ -1163,15 +1212,16 @@ mod tests { Result, ScalarValue, }; use datafusion_execution::config::SessionConfig; + use datafusion_execution::memory_pool::FairSpillPool; use datafusion_execution::runtime_env::{RuntimeConfig, RuntimeEnv}; use datafusion_physical_expr::expressions::{ - lit, ApproxDistinct, Count, FirstValue, LastValue, Median, + lit, ApproxDistinct, Count, FirstValue, LastValue, Median, OrderSensitiveArrayAgg, }; use datafusion_physical_expr::{ - AggregateExpr, EquivalenceProperties, PhysicalExpr, PhysicalSortExpr, + reverse_order_bys, AggregateExpr, EquivalenceProperties, PhysicalExpr, + PhysicalSortExpr, }; - use datafusion_execution::memory_pool::FairSpillPool; use futures::{FutureExt, Stream}; // Generate a schema which consists of 5 columns (a, b, c, d, e) @@ -1323,7 +1373,6 @@ mod tests { grouping_set.clone(), aggregates.clone(), vec![None], - vec![None], input, input_schema.clone(), )?); @@ -1402,7 +1451,6 @@ mod tests { final_grouping_set, aggregates, vec![None], - vec![None], merge, input_schema, )?); @@ -1468,7 +1516,6 @@ mod tests { grouping_set.clone(), aggregates.clone(), vec![None], - vec![None], input, input_schema.clone(), )?); @@ -1516,7 +1563,6 @@ mod tests { final_grouping_set, aggregates, vec![None], - vec![None], merge, input_schema, )?); @@ -1783,7 +1829,6 @@ mod tests { groups, aggregates, vec![None; 3], - vec![None; 3], input.clone(), input_schema.clone(), )?); @@ -1839,7 +1884,6 @@ mod tests { groups.clone(), aggregates.clone(), vec![None], - vec![None], blocking_exec, schema, )?); @@ -1878,7 +1922,6 @@ mod tests { groups, aggregates.clone(), vec![None], - vec![None], blocking_exec, schema, )?); @@ -1980,7 +2023,6 @@ mod tests { groups.clone(), aggregates.clone(), vec![None], - vec![Some(ordering_req.clone())], memory_exec, schema.clone(), )?); @@ -1996,7 +2038,6 @@ mod tests { groups, aggregates.clone(), vec![None], - vec![Some(ordering_req)], coalesce, schema, )?) 
as Arc; @@ -2037,11 +2078,6 @@ mod tests { descending: false, nulls_first: false, }; - // This is the reverse requirement of options1 - let options2 = SortOptions { - descending: true, - nulls_first: true, - }; let col_a = &col("a", &test_schema)?; let col_b = &col("b", &test_schema)?; let col_c = &col("c", &test_schema)?; @@ -2050,7 +2086,7 @@ mod tests { eq_properties.add_equal_conditions(col_a, col_b); // Aggregate requirements are // [None], [a ASC], [a ASC, b ASC, c ASC], [a ASC, b ASC] respectively - let mut order_by_exprs = vec![ + let order_by_exprs = vec![ None, Some(vec![PhysicalSortExpr { expr: col_a.clone(), @@ -2080,14 +2116,8 @@ mod tests { options: options1, }, ]), - // Since aggregate expression is reversible (FirstValue), we should be able to resolve below - // contradictory requirement by reversing it. - Some(vec![PhysicalSortExpr { - expr: col_b.clone(), - options: options2, - }]), ]; - let common_requirement = Some(vec![ + let common_requirement = vec![ PhysicalSortExpr { expr: col_a.clone(), options: options1, @@ -2096,18 +2126,82 @@ mod tests { expr: col_c.clone(), options: options1, }, - ]); - let aggr_expr = Arc::new(FirstValue::new( - col_a.clone(), - "first1", - DataType::Int32, - vec![], - vec![], - )) as _; - let mut aggr_exprs = vec![aggr_expr; order_by_exprs.len()]; - let res = - get_finest_requirement(&mut aggr_exprs, &mut order_by_exprs, &eq_properties)?; + ]; + let mut aggr_exprs = order_by_exprs + .into_iter() + .map(|order_by_expr| { + Arc::new(OrderSensitiveArrayAgg::new( + col_a.clone(), + "array_agg", + DataType::Int32, + false, + vec![], + order_by_expr.unwrap_or_default(), + )) as _ + }) + .collect::>(); + let group_by = PhysicalGroupBy::new_single(vec![]); + let res = get_aggregate_exprs_requirement( + &[], + &mut aggr_exprs, + &group_by, + &eq_properties, + &AggregateMode::Partial, + )?; + let res = PhysicalSortRequirement::to_sort_exprs(res); assert_eq!(res, common_requirement); Ok(()) } + + #[test] + fn test_agg_exec_same_schema() -> Result<()> { + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Float32, true), + Field::new("b", DataType::Float32, true), + ])); + + let col_a = col("a", &schema)?; + let col_b = col("b", &schema)?; + let option_desc = SortOptions { + descending: true, + nulls_first: true, + }; + let sort_expr = vec![PhysicalSortExpr { + expr: col_b.clone(), + options: option_desc, + }]; + let sort_expr_reverse = reverse_order_bys(&sort_expr); + let groups = PhysicalGroupBy::new_single(vec![(col_a, "a".to_string())]); + + let aggregates: Vec> = vec![ + Arc::new(FirstValue::new( + col_b.clone(), + "FIRST_VALUE(b)".to_string(), + DataType::Float64, + sort_expr_reverse.clone(), + vec![DataType::Float64], + )), + Arc::new(LastValue::new( + col_b.clone(), + "LAST_VALUE(b)".to_string(), + DataType::Float64, + sort_expr.clone(), + vec![DataType::Float64], + )), + ]; + let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 1)); + let aggregate_exec = Arc::new(AggregateExec::try_new( + AggregateMode::Partial, + groups, + aggregates.clone(), + vec![None, None], + blocking_exec.clone(), + schema, + )?); + let new_agg = aggregate_exec + .clone() + .with_new_children(vec![blocking_exec])?; + assert_eq!(new_agg.schema(), aggregate_exec.schema()); + Ok(()) + } } diff --git a/datafusion/physical-plan/src/aggregates/no_grouping.rs b/datafusion/physical-plan/src/aggregates/no_grouping.rs index 32c0bbc78a5de..90eb488a2ead2 100644 --- a/datafusion/physical-plan/src/aggregates/no_grouping.rs +++ 
b/datafusion/physical-plan/src/aggregates/no_grouping.rs @@ -217,8 +217,10 @@ fn aggregate_batch( // 1.3 let values = &expr .iter() - .map(|e| e.evaluate(&batch)) - .map(|r| r.map(|v| v.into_array(batch.num_rows()))) + .map(|e| { + e.evaluate(&batch) + .and_then(|v| v.into_array(batch.num_rows())) + }) .collect::>>()?; // 1.4 diff --git a/datafusion/physical-plan/src/aggregates/order/mod.rs b/datafusion/physical-plan/src/aggregates/order/mod.rs index f72d2f06e459f..b258b97a9e84f 100644 --- a/datafusion/physical-plan/src/aggregates/order/mod.rs +++ b/datafusion/physical-plan/src/aggregates/order/mod.rs @@ -23,7 +23,7 @@ use datafusion_physical_expr::{EmitTo, PhysicalSortExpr}; mod full; mod partial; -use crate::windows::PartitionSearchMode; +use crate::InputOrderMode; pub(crate) use full::GroupOrderingFull; pub(crate) use partial::GroupOrderingPartial; @@ -42,18 +42,16 @@ impl GroupOrdering { /// Create a `GroupOrdering` for the the specified ordering pub fn try_new( input_schema: &Schema, - mode: &PartitionSearchMode, + mode: &InputOrderMode, ordering: &[PhysicalSortExpr], ) -> Result { match mode { - PartitionSearchMode::Linear => Ok(GroupOrdering::None), - PartitionSearchMode::PartiallySorted(order_indices) => { + InputOrderMode::Linear => Ok(GroupOrdering::None), + InputOrderMode::PartiallySorted(order_indices) => { GroupOrderingPartial::try_new(input_schema, order_indices, ordering) .map(GroupOrdering::Partial) } - PartitionSearchMode::Sorted => { - Ok(GroupOrdering::Full(GroupOrderingFull::new())) - } + InputOrderMode::Sorted => Ok(GroupOrdering::Full(GroupOrderingFull::new())), } } diff --git a/datafusion/physical-plan/src/aggregates/row_hash.rs b/datafusion/physical-plan/src/aggregates/row_hash.rs index 7cee4a3e7cfc0..89614fd3020ce 100644 --- a/datafusion/physical-plan/src/aggregates/row_hash.rs +++ b/datafusion/physical-plan/src/aggregates/row_hash.rs @@ -267,6 +267,12 @@ pub(crate) struct GroupedHashAggregateStream { /// The spill state object spill_state: SpillState, + + /// Optional soft limit on the number of `group_values` in a batch + /// If the number of `group_values` in a single batch exceeds this value, + /// the `GroupedHashAggregateStream` operation immediately switches to + /// output mode and emits all groups. 
+ group_values_soft_limit: Option, } impl GroupedHashAggregateStream { @@ -318,7 +324,9 @@ impl GroupedHashAggregateStream { .map(create_group_accumulator) .collect::>()?; - let group_schema = group_schema(&agg_schema, agg_group_by.expr.len()); + // we need to use original schema so RowConverter in group_values below + // will do the proper coversion of dictionaries into value types + let group_schema = group_schema(&agg.original_schema, agg_group_by.expr.len()); let spill_expr = group_schema .fields .into_iter() @@ -338,7 +346,7 @@ impl GroupedHashAggregateStream { .find_longest_permutation(&agg_group_by.output_exprs()); let group_ordering = GroupOrdering::try_new( &group_schema, - &agg.partition_search_mode, + &agg.input_order_mode, ordering.as_slice(), )?; @@ -374,6 +382,7 @@ impl GroupedHashAggregateStream { input_done: false, runtime: context.runtime_env(), spill_state, + group_values_soft_limit: agg.limit, }) } } @@ -419,7 +428,7 @@ impl Stream for GroupedHashAggregateStream { loop { match &self.exec_state { - ExecutionState::ReadingInput => { + ExecutionState::ReadingInput => 'reading_input: { match ready!(self.input.poll_next_unpin(cx)) { // new batch to aggregate Some(Ok(batch)) => { @@ -434,9 +443,21 @@ impl Stream for GroupedHashAggregateStream { // otherwise keep consuming input assert!(!self.input_done); + // If the number of group values equals or exceeds the soft limit, + // emit all groups and switch to producing output + if self.hit_soft_group_limit() { + timer.done(); + extract_ok!(self.set_input_done_and_produce_output()); + // make sure the exec_state just set is not overwritten below + break 'reading_input; + } + if let Some(to_emit) = self.group_ordering.emit_to() { let batch = extract_ok!(self.emit(to_emit, false)); self.exec_state = ExecutionState::ProducingOutput(batch); + timer.done(); + // make sure the exec_state just set is not overwritten below + break 'reading_input; } extract_ok!(self.emit_early_if_necessary()); @@ -449,18 +470,7 @@ impl Stream for GroupedHashAggregateStream { } None => { // inner is done, emit all rows and switch to producing output - self.input_done = true; - self.group_ordering.input_done(); - let timer = elapsed_compute.timer(); - self.exec_state = if self.spill_state.spills.is_empty() { - let batch = extract_ok!(self.emit(EmitTo::All, false)); - ExecutionState::ProducingOutput(batch) - } else { - // If spill files exist, stream-merge them. 
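// Illustrative sketch (not from this patch; assumes the `arrow_schema` crate): what
// `materialize_dict_group_keys` means for the schemas involved. A dictionary-encoded
// group-by key is exposed with its dictionary *value* type, since `RowConverter`
// materializes the values anyway, while aggregate columns keep their original types;
// the stream above still builds `group_schema` from `agg.original_schema`.
use arrow_schema::{DataType, Field, Schema};

fn dictionary_group_key_schemas() -> (Schema, Schema) {
    let dict = DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8));
    // Original schema: one dictionary-encoded group key plus one aggregate column.
    let original = Schema::new(vec![
        Field::new("k", dict, true),
        Field::new("COUNT(v)", DataType::Int64, true),
    ]);
    // Materialized schema: the group key becomes plain Utf8, the aggregate is untouched.
    let materialized = Schema::new(vec![
        Field::new("k", DataType::Utf8, true),
        Field::new("COUNT(v)", DataType::Int64, true),
    ]);
    (original, materialized)
}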
- extract_ok!(self.update_merged_stream()); - ExecutionState::ReadingInput - }; - timer.done(); + extract_ok!(self.set_input_done_and_produce_output()); } } } @@ -759,4 +769,31 @@ impl GroupedHashAggregateStream { self.group_ordering = GroupOrdering::Full(GroupOrderingFull::new()); Ok(()) } + + /// returns true if there is a soft groups limit and the number of distinct + /// groups we have seen is over that limit + fn hit_soft_group_limit(&self) -> bool { + let Some(group_values_soft_limit) = self.group_values_soft_limit else { + return false; + }; + group_values_soft_limit <= self.group_values.len() + } + + /// common function for signalling end of processing of the input stream + fn set_input_done_and_produce_output(&mut self) -> Result<()> { + self.input_done = true; + self.group_ordering.input_done(); + let elapsed_compute = self.baseline_metrics.elapsed_compute().clone(); + let timer = elapsed_compute.timer(); + self.exec_state = if self.spill_state.spills.is_empty() { + let batch = self.emit(EmitTo::All, false)?; + ExecutionState::ProducingOutput(batch) + } else { + // If spill files exist, stream-merge them. + self.update_merged_stream()?; + ExecutionState::ReadingInput + }; + timer.done(); + Ok(()) + } } diff --git a/datafusion/physical-plan/src/analyze.rs b/datafusion/physical-plan/src/analyze.rs index ded37983bb211..4f1578e220ddd 100644 --- a/datafusion/physical-plan/src/analyze.rs +++ b/datafusion/physical-plan/src/analyze.rs @@ -115,8 +115,12 @@ impl ExecutionPlan for AnalyzeExec { /// Specifies whether this plan generates an infinite stream of records. /// If the plan does not support pipelining, but its input(s) are /// infinite, returns an error to indicate this. - fn unbounded_output(&self, _children: &[bool]) -> Result { - internal_err!("Optimization not supported for ANALYZE") + fn unbounded_output(&self, children: &[bool]) -> Result { + if children[0] { + internal_err!("Streaming execution of AnalyzeExec is not possible") + } else { + Ok(false) + } } /// Get the output partitioning of this plan diff --git a/datafusion/physical-plan/src/common.rs b/datafusion/physical-plan/src/common.rs index 649f3a31aa7ef..e83dc2525b9fe 100644 --- a/datafusion/physical-plan/src/common.rs +++ b/datafusion/physical-plan/src/common.rs @@ -30,6 +30,7 @@ use crate::{ColumnStatistics, ExecutionPlan, Statistics}; use arrow::datatypes::Schema; use arrow::ipc::writer::{FileWriter, IpcWriteOptions}; use arrow::record_batch::RecordBatch; +use arrow_array::Array; use datafusion_common::stats::Precision; use datafusion_common::{plan_err, DataFusionError, Result}; use datafusion_execution::memory_pool::MemoryReservation; @@ -139,17 +140,22 @@ pub fn compute_record_batch_statistics( ) -> Statistics { let nb_rows = batches.iter().flatten().map(RecordBatch::num_rows).sum(); - let total_byte_size = batches - .iter() - .flatten() - .map(|b| b.get_array_memory_size()) - .sum(); - let projection = match projection { Some(p) => p, None => (0..schema.fields().len()).collect(), }; + let total_byte_size = batches + .iter() + .flatten() + .map(|b| { + projection + .iter() + .map(|index| b.column(*index).get_array_memory_size()) + .sum::() + }) + .sum(); + let mut column_statistics = vec![ColumnStatistics::new_unknown(); projection.len()]; for partition in batches.iter() { @@ -388,6 +394,7 @@ mod tests { datatypes::{DataType, Field, Schema}, record_batch::RecordBatch, }; + use arrow_array::UInt64Array; use datafusion_expr::Operator; use datafusion_physical_expr::expressions::{col, Column}; @@ -685,20 +692,30 @@ 
mod tests { let schema = Arc::new(Schema::new(vec![ Field::new("f32", DataType::Float32, false), Field::new("f64", DataType::Float64, false), + Field::new("u64", DataType::UInt64, false), ])); let batch = RecordBatch::try_new( Arc::clone(&schema), vec![ Arc::new(Float32Array::from(vec![1., 2., 3.])), Arc::new(Float64Array::from(vec![9., 8., 7.])), + Arc::new(UInt64Array::from(vec![4, 5, 6])), ], )?; + + // just select f32,f64 + let select_projection = Some(vec![0, 1]); + let byte_size = batch + .project(&select_projection.clone().unwrap()) + .unwrap() + .get_array_memory_size(); + let actual = - compute_record_batch_statistics(&[vec![batch]], &schema, Some(vec![0, 1])); + compute_record_batch_statistics(&[vec![batch]], &schema, select_projection); - let mut expected = Statistics { + let expected = Statistics { num_rows: Precision::Exact(3), - total_byte_size: Precision::Exact(464), // this might change a bit if the way we compute the size changes + total_byte_size: Precision::Exact(byte_size), column_statistics: vec![ ColumnStatistics { distinct_count: Precision::Absent, @@ -715,9 +732,6 @@ mod tests { ], }; - // Prevent test flakiness due to undefined / changing implementation details - expected.total_byte_size = actual.total_byte_size.clone(); - assert_eq!(actual, expected); Ok(()) } diff --git a/datafusion/physical-plan/src/display.rs b/datafusion/physical-plan/src/display.rs index aa368251ebf32..19c2847b09dc8 100644 --- a/datafusion/physical-plan/src/display.rs +++ b/datafusion/physical-plan/src/display.rs @@ -132,7 +132,7 @@ impl<'a> DisplayableExecutionPlan<'a> { /// ```dot /// strict digraph dot_plan { // 0[label="ProjectionExec: expr=[id@0 + 2 as employee.id + Int32(2)]",tooltip=""] - // 1[label="EmptyExec: produce_one_row=false",tooltip=""] + // 1[label="EmptyExec",tooltip=""] // 0 -> 1 // } /// ``` @@ -260,8 +260,8 @@ impl<'a, 'b> ExecutionPlanVisitor for IndentVisitor<'a, 'b> { } } } - let stats = plan.statistics().map_err(|_e| fmt::Error)?; if self.show_statistics { + let stats = plan.statistics().map_err(|_e| fmt::Error)?; write!(self.f, ", statistics=[{}]", stats)?; } writeln!(self.f)?; @@ -341,8 +341,8 @@ impl ExecutionPlanVisitor for GraphvizVisitor<'_, '_> { } }; - let stats = plan.statistics().map_err(|_e| fmt::Error)?; let statistics = if self.show_statistics { + let stats = plan.statistics().map_err(|_e| fmt::Error)?; format!("statistics=[{}]", stats) } else { "".to_string() @@ -436,3 +436,126 @@ impl<'a> fmt::Display for OutputOrderingDisplay<'a> { write!(f, "]") } } + +#[cfg(test)] +mod tests { + use std::fmt::Write; + use std::sync::Arc; + + use datafusion_common::DataFusionError; + + use crate::{DisplayAs, ExecutionPlan}; + + use super::DisplayableExecutionPlan; + + #[derive(Debug, Clone, Copy)] + enum TestStatsExecPlan { + Panic, + Error, + Ok, + } + + impl DisplayAs for TestStatsExecPlan { + fn fmt_as( + &self, + _t: crate::DisplayFormatType, + f: &mut std::fmt::Formatter, + ) -> std::fmt::Result { + write!(f, "TestStatsExecPlan") + } + } + + impl ExecutionPlan for TestStatsExecPlan { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn schema(&self) -> arrow_schema::SchemaRef { + Arc::new(arrow_schema::Schema::empty()) + } + + fn output_partitioning(&self) -> datafusion_physical_expr::Partitioning { + datafusion_physical_expr::Partitioning::UnknownPartitioning(1) + } + + fn output_ordering( + &self, + ) -> Option<&[datafusion_physical_expr::PhysicalSortExpr]> { + None + } + + fn children(&self) -> Vec> { + vec![] + } + + fn with_new_children( + self: Arc, 
+ _: Vec>, + ) -> datafusion_common::Result> { + unimplemented!() + } + + fn execute( + &self, + _: usize, + _: Arc, + ) -> datafusion_common::Result + { + todo!() + } + + fn statistics(&self) -> datafusion_common::Result { + match self { + Self::Panic => panic!("expected panic"), + Self::Error => { + Err(DataFusionError::Internal("expected error".to_string())) + } + Self::Ok => Ok(datafusion_common::Statistics::new_unknown( + self.schema().as_ref(), + )), + } + } + } + + fn test_stats_display(exec: TestStatsExecPlan, show_stats: bool) { + let display = + DisplayableExecutionPlan::new(&exec).set_show_statistics(show_stats); + + let mut buf = String::new(); + write!(&mut buf, "{}", display.one_line()).unwrap(); + let buf = buf.trim(); + assert_eq!(buf, "TestStatsExecPlan"); + } + + #[test] + fn test_display_when_stats_panic_with_no_show_stats() { + test_stats_display(TestStatsExecPlan::Panic, false); + } + + #[test] + fn test_display_when_stats_error_with_no_show_stats() { + test_stats_display(TestStatsExecPlan::Error, false); + } + + #[test] + fn test_display_when_stats_ok_with_no_show_stats() { + test_stats_display(TestStatsExecPlan::Ok, false); + } + + #[test] + #[should_panic(expected = "expected panic")] + fn test_display_when_stats_panic_with_show_stats() { + test_stats_display(TestStatsExecPlan::Panic, true); + } + + #[test] + #[should_panic(expected = "Error")] // fmt::Error + fn test_display_when_stats_error_with_show_stats() { + test_stats_display(TestStatsExecPlan::Error, true); + } + + #[test] + fn test_display_when_stats_ok_with_show_stats() { + test_stats_display(TestStatsExecPlan::Ok, false); + } +} diff --git a/datafusion/physical-plan/src/empty.rs b/datafusion/physical-plan/src/empty.rs index a3e1fb79edb59..41c8dbed14536 100644 --- a/datafusion/physical-plan/src/empty.rs +++ b/datafusion/physical-plan/src/empty.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -//! EmptyRelation execution plan +//! 
EmptyRelation with produce_one_row=false execution plan use std::any::Any; use std::sync::Arc; @@ -24,19 +24,16 @@ use super::expressions::PhysicalSortExpr; use super::{common, DisplayAs, SendableRecordBatchStream, Statistics}; use crate::{memory::MemoryStream, DisplayFormatType, ExecutionPlan, Partitioning}; -use arrow::array::{ArrayRef, NullArray}; -use arrow::datatypes::{DataType, Field, Fields, Schema, SchemaRef}; +use arrow::datatypes::SchemaRef; use arrow::record_batch::RecordBatch; use datafusion_common::{internal_err, DataFusionError, Result}; use datafusion_execution::TaskContext; use log::trace; -/// Execution plan for empty relation (produces no rows) +/// Execution plan for empty relation with produce_one_row=false #[derive(Debug)] pub struct EmptyExec { - /// Specifies whether this exec produces a row or not - produce_one_row: bool, /// The schema for the produced row schema: SchemaRef, /// Number of partitions @@ -45,9 +42,8 @@ pub struct EmptyExec { impl EmptyExec { /// Create a new EmptyExec - pub fn new(produce_one_row: bool, schema: SchemaRef) -> Self { + pub fn new(schema: SchemaRef) -> Self { EmptyExec { - produce_one_row, schema, partitions: 1, } @@ -59,36 +55,8 @@ impl EmptyExec { self } - /// Specifies whether this exec produces a row or not - pub fn produce_one_row(&self) -> bool { - self.produce_one_row - } - fn data(&self) -> Result> { - let batch = if self.produce_one_row { - let n_field = self.schema.fields.len(); - // hack for https://github.com/apache/arrow-datafusion/pull/3242 - let n_field = if n_field == 0 { 1 } else { n_field }; - vec![RecordBatch::try_new( - Arc::new(Schema::new( - (0..n_field) - .map(|i| { - Field::new(format!("placeholder_{i}"), DataType::Null, true) - }) - .collect::(), - )), - (0..n_field) - .map(|_i| { - let ret: ArrayRef = Arc::new(NullArray::new(1)); - ret - }) - .collect(), - )?] 
- } else { - vec![] - }; - - Ok(batch) + Ok(vec![]) } } @@ -100,7 +68,7 @@ impl DisplayAs for EmptyExec { ) -> std::fmt::Result { match t { DisplayFormatType::Default | DisplayFormatType::Verbose => { - write!(f, "EmptyExec: produce_one_row={}", self.produce_one_row) + write!(f, "EmptyExec") } } } @@ -133,10 +101,7 @@ impl ExecutionPlan for EmptyExec { self: Arc, _: Vec>, ) -> Result> { - Ok(Arc::new(EmptyExec::new( - self.produce_one_row, - self.schema.clone(), - ))) + Ok(Arc::new(EmptyExec::new(self.schema.clone()))) } fn execute( @@ -184,7 +149,7 @@ mod tests { let task_ctx = Arc::new(TaskContext::default()); let schema = test::aggr_test_schema(); - let empty = EmptyExec::new(false, schema.clone()); + let empty = EmptyExec::new(schema.clone()); assert_eq!(empty.schema(), schema); // we should have no results @@ -198,16 +163,11 @@ mod tests { #[test] fn with_new_children() -> Result<()> { let schema = test::aggr_test_schema(); - let empty = Arc::new(EmptyExec::new(false, schema.clone())); - let empty_with_row = Arc::new(EmptyExec::new(true, schema)); + let empty = Arc::new(EmptyExec::new(schema.clone())); let empty2 = with_new_children_if_necessary(empty.clone(), vec![])?.into(); assert_eq!(empty.schema(), empty2.schema()); - let empty_with_row_2 = - with_new_children_if_necessary(empty_with_row.clone(), vec![])?.into(); - assert_eq!(empty_with_row.schema(), empty_with_row_2.schema()); - let too_many_kids = vec![empty2]; assert!( with_new_children_if_necessary(empty, too_many_kids).is_err(), @@ -220,44 +180,11 @@ mod tests { async fn invalid_execute() -> Result<()> { let task_ctx = Arc::new(TaskContext::default()); let schema = test::aggr_test_schema(); - let empty = EmptyExec::new(false, schema); + let empty = EmptyExec::new(schema); // ask for the wrong partition assert!(empty.execute(1, task_ctx.clone()).is_err()); assert!(empty.execute(20, task_ctx).is_err()); Ok(()) } - - #[tokio::test] - async fn produce_one_row() -> Result<()> { - let task_ctx = Arc::new(TaskContext::default()); - let schema = test::aggr_test_schema(); - let empty = EmptyExec::new(true, schema); - - let iter = empty.execute(0, task_ctx)?; - let batches = common::collect(iter).await?; - - // should have one item - assert_eq!(batches.len(), 1); - - Ok(()) - } - - #[tokio::test] - async fn produce_one_row_multiple_partition() -> Result<()> { - let task_ctx = Arc::new(TaskContext::default()); - let schema = test::aggr_test_schema(); - let partitions = 3; - let empty = EmptyExec::new(true, schema).with_partitions(partitions); - - for n in 0..partitions { - let iter = empty.execute(n, task_ctx.clone())?; - let batches = common::collect(iter).await?; - - // should have one item - assert_eq!(batches.len(), 1); - } - - Ok(()) - } } diff --git a/datafusion/physical-plan/src/filter.rs b/datafusion/physical-plan/src/filter.rs index ce66d614721c5..56a1b4e178219 100644 --- a/datafusion/physical-plan/src/filter.rs +++ b/datafusion/physical-plan/src/filter.rs @@ -27,7 +27,6 @@ use super::expressions::PhysicalSortExpr; use super::{ ColumnStatistics, DisplayAs, RecordBatchStream, SendableRecordBatchStream, Statistics, }; - use crate::{ metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet}, Column, DisplayFormatType, ExecutionPlan, Partitioning, @@ -62,6 +61,8 @@ pub struct FilterExec { input: Arc, /// Execution metrics metrics: ExecutionPlanMetricsSet, + /// Selectivity for statistics. 
0 = no rows, 100 all rows + default_selectivity: u8, } impl FilterExec { @@ -75,6 +76,7 @@ impl FilterExec { predicate, input: input.clone(), metrics: ExecutionPlanMetricsSet::new(), + default_selectivity: 20, }), other => { plan_err!("Filter predicate must return boolean values, not {other:?}") @@ -82,6 +84,17 @@ impl FilterExec { } } + pub fn with_default_selectivity( + mut self, + default_selectivity: u8, + ) -> Result { + if default_selectivity > 100 { + return plan_err!("Default flter selectivity needs to be less than 100"); + } + self.default_selectivity = default_selectivity; + Ok(self) + } + /// The expression to filter on. This expression must evaluate to a boolean value. pub fn predicate(&self) -> &Arc { &self.predicate @@ -91,6 +104,11 @@ impl FilterExec { pub fn input(&self) -> &Arc { &self.input } + + /// The default selectivity + pub fn default_selectivity(&self) -> u8 { + self.default_selectivity + } } impl DisplayAs for FilterExec { @@ -167,6 +185,10 @@ impl ExecutionPlan for FilterExec { mut children: Vec>, ) -> Result> { FilterExec::try_new(self.predicate.clone(), children.swap_remove(0)) + .and_then(|e| { + let selectivity = e.default_selectivity(); + e.with_default_selectivity(selectivity) + }) .map(|e| Arc::new(e) as _) } @@ -194,11 +216,17 @@ impl ExecutionPlan for FilterExec { fn statistics(&self) -> Result { let predicate = self.predicate(); + let input_stats = self.input.statistics()?; let schema = self.schema(); if !check_support(predicate, &schema) { - return Ok(Statistics::new_unknown(&schema)); + let selectivity = self.default_selectivity as f64 / 100.0; + let mut stats = input_stats.into_inexact(); + stats.num_rows = stats.num_rows.with_estimated_selectivity(selectivity); + stats.total_byte_size = stats + .total_byte_size + .with_estimated_selectivity(selectivity); + return Ok(stats); } - let input_stats = self.input.statistics()?; let num_rows = input_stats.num_rows; let total_byte_size = input_stats.total_byte_size; @@ -206,18 +234,13 @@ impl ExecutionPlan for FilterExec { &self.input.schema(), &input_stats.column_statistics, )?; - let analysis_ctx = analyze(predicate, input_analysis_ctx)?; + + let analysis_ctx = analyze(predicate, input_analysis_ctx, &self.schema())?; // Estimate (inexact) selectivity of predicate let selectivity = analysis_ctx.selectivity.unwrap_or(1.0); - let num_rows = match num_rows.get_value() { - Some(nr) => Precision::Inexact((*nr as f64 * selectivity).ceil() as usize), - None => Precision::Absent, - }; - let total_byte_size = match total_byte_size.get_value() { - Some(tbs) => Precision::Inexact((*tbs as f64 * selectivity).ceil() as usize), - None => Precision::Absent, - }; + let num_rows = num_rows.with_estimated_selectivity(selectivity); + let total_byte_size = total_byte_size.with_estimated_selectivity(selectivity); let column_statistics = collect_new_statistics( &input_stats.column_statistics, @@ -251,18 +274,17 @@ fn collect_new_statistics( .. 
}, )| { - let closed_interval = interval.close_bounds(); + let (lower, upper) = interval.into_bounds(); + let (min_value, max_value) = if lower.eq(&upper) { + (Precision::Exact(lower), Precision::Exact(upper)) + } else { + (Precision::Inexact(lower), Precision::Inexact(upper)) + }; ColumnStatistics { - null_count: match input_column_stats[idx].null_count.get_value() { - Some(nc) => Precision::Inexact(*nc), - None => Precision::Absent, - }, - max_value: Precision::Inexact(closed_interval.upper.value), - min_value: Precision::Inexact(closed_interval.lower.value), - distinct_count: match distinct_count.get_value() { - Some(dc) => Precision::Inexact(*dc), - None => Precision::Absent, - }, + null_count: input_column_stats[idx].null_count.clone().to_inexact(), + max_value, + min_value, + distinct_count: distinct_count.to_inexact(), } }, ) @@ -288,7 +310,7 @@ pub(crate) fn batch_filter( ) -> Result { predicate .evaluate(batch) - .map(|v| v.into_array(batch.num_rows())) + .and_then(|v| v.into_array(batch.num_rows())) .and_then(|array| { Ok(as_boolean_array(&array)?) // apply filter array to record batch @@ -963,4 +985,76 @@ mod tests { Ok(()) } + + #[tokio::test] + async fn test_statistics_with_constant_column() -> Result<()> { + let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); + let input = Arc::new(StatisticsExec::new( + Statistics::new_unknown(&schema), + schema, + )); + // WHERE a = 10 + let predicate = Arc::new(BinaryExpr::new( + Arc::new(Column::new("a", 0)), + Operator::Eq, + Arc::new(Literal::new(ScalarValue::Int32(Some(10)))), + )); + let filter: Arc = + Arc::new(FilterExec::try_new(predicate, input)?); + let filter_statistics = filter.statistics()?; + // First column is "a", and it is a column with only one value after the filter. 
+ assert!(filter_statistics.column_statistics[0].is_singleton()); + + Ok(()) + } + + #[tokio::test] + async fn test_validation_filter_selectivity() -> Result<()> { + let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); + let input = Arc::new(StatisticsExec::new( + Statistics::new_unknown(&schema), + schema, + )); + // WHERE a = 10 + let predicate = Arc::new(BinaryExpr::new( + Arc::new(Column::new("a", 0)), + Operator::Eq, + Arc::new(Literal::new(ScalarValue::Int32(Some(10)))), + )); + let filter = FilterExec::try_new(predicate, input)?; + assert!(filter.with_default_selectivity(120).is_err()); + Ok(()) + } + + #[tokio::test] + async fn test_custom_filter_selectivity() -> Result<()> { + // Need a decimal to trigger inexact selectivity + let schema = + Schema::new(vec![Field::new("a", DataType::Decimal128(2, 3), false)]); + let input = Arc::new(StatisticsExec::new( + Statistics { + num_rows: Precision::Inexact(1000), + total_byte_size: Precision::Inexact(4000), + column_statistics: vec![ColumnStatistics { + ..Default::default() + }], + }, + schema, + )); + // WHERE a = 10 + let predicate = Arc::new(BinaryExpr::new( + Arc::new(Column::new("a", 0)), + Operator::Eq, + Arc::new(Literal::new(ScalarValue::Decimal128(Some(10), 10, 10))), + )); + let filter = FilterExec::try_new(predicate, input)?; + let statistics = filter.statistics()?; + assert_eq!(statistics.num_rows, Precision::Inexact(200)); + assert_eq!(statistics.total_byte_size, Precision::Inexact(800)); + let filter = filter.with_default_selectivity(40)?; + let statistics = filter.statistics()?; + assert_eq!(statistics.num_rows, Precision::Inexact(400)); + assert_eq!(statistics.total_byte_size, Precision::Inexact(1600)); + Ok(()) + } } diff --git a/datafusion/physical-plan/src/insert.rs b/datafusion/physical-plan/src/insert.rs index 627d58e137816..81cdfd753fe69 100644 --- a/datafusion/physical-plan/src/insert.rs +++ b/datafusion/physical-plan/src/insert.rs @@ -151,11 +151,21 @@ impl FileSinkExec { } } + /// Input execution plan + pub fn input(&self) -> &Arc { + &self.input + } + /// Returns insert sink pub fn sink(&self) -> &dyn DataSink { self.sink.as_ref() } + /// Optional sort order for output data + pub fn sort_order(&self) -> &Option> { + &self.sort_order + } + /// Returns the metrics of the underlying [DataSink] pub fn metrics(&self) -> Option { self.sink.metrics() @@ -170,7 +180,7 @@ impl DisplayAs for FileSinkExec { ) -> std::fmt::Result { match t { DisplayFormatType::Default | DisplayFormatType::Verbose => { - write!(f, "InsertExec: sink=")?; + write!(f, "FileSinkExec: sink=")?; self.sink.fmt_as(t, f) } } @@ -209,24 +219,17 @@ impl ExecutionPlan for FileSinkExec { } fn required_input_ordering(&self) -> Vec>> { - // The input order is either exlicitly set (such as by a ListingTable), - // or require that the [FileSinkExec] gets the data in the order the - // input produced it (otherwise the optimizer may chose to reorder - // the input which could result in unintended / poor UX) - // - // More rationale: - // https://github.com/apache/arrow-datafusion/pull/6354#discussion_r1195284178 - match &self.sort_order { - Some(requirements) => vec![Some(requirements.clone())], - None => vec![self - .input - .output_ordering() - .map(PhysicalSortRequirement::from_sort_exprs)], - } + // The required input ordering is set externally (e.g. by a `ListingTable`). + // Otherwise, there is no specific requirement (i.e. `sort_expr` is `None`). 
+ vec![self.sort_order.as_ref().cloned()] } fn maintains_input_order(&self) -> Vec { - vec![false] + // Maintains ordering in the sense that the written file will reflect + // the ordering of the input. For more context, see: + // + // https://github.com/apache/arrow-datafusion/pull/6354#discussion_r1195284178 + vec![true] } fn children(&self) -> Vec> { diff --git a/datafusion/physical-plan/src/joins/cross_join.rs b/datafusion/physical-plan/src/joins/cross_join.rs index 102f0c42e90c9..938c9e4d343d6 100644 --- a/datafusion/physical-plan/src/joins/cross_join.rs +++ b/datafusion/physical-plan/src/joins/cross_join.rs @@ -344,7 +344,7 @@ fn build_batch( .iter() .map(|arr| { let scalar = ScalarValue::try_from_array(arr, left_index)?; - Ok(scalar.to_array_of_size(batch.num_rows())) + scalar.to_array_of_size(batch.num_rows()) }) .collect::>>()?; @@ -476,12 +476,8 @@ mod tests { }, ColumnStatistics { distinct_count: Precision::Exact(1), - max_value: Precision::Exact(ScalarValue::Utf8(Some(String::from( - "x", - )))), - min_value: Precision::Exact(ScalarValue::Utf8(Some(String::from( - "a", - )))), + max_value: Precision::Exact(ScalarValue::from("x")), + min_value: Precision::Exact(ScalarValue::from("a")), null_count: Precision::Exact(3), }, ], @@ -512,12 +508,8 @@ mod tests { }, ColumnStatistics { distinct_count: Precision::Exact(1), - max_value: Precision::Exact(ScalarValue::Utf8(Some(String::from( - "x", - )))), - min_value: Precision::Exact(ScalarValue::Utf8(Some(String::from( - "a", - )))), + max_value: Precision::Exact(ScalarValue::from("x")), + min_value: Precision::Exact(ScalarValue::from("a")), null_count: Precision::Exact(3 * right_row_count), }, ColumnStatistics { @@ -548,12 +540,8 @@ mod tests { }, ColumnStatistics { distinct_count: Precision::Exact(1), - max_value: Precision::Exact(ScalarValue::Utf8(Some(String::from( - "x", - )))), - min_value: Precision::Exact(ScalarValue::Utf8(Some(String::from( - "a", - )))), + max_value: Precision::Exact(ScalarValue::from("x")), + min_value: Precision::Exact(ScalarValue::from("a")), null_count: Precision::Exact(3), }, ], @@ -584,12 +572,8 @@ mod tests { }, ColumnStatistics { distinct_count: Precision::Exact(1), - max_value: Precision::Exact(ScalarValue::Utf8(Some(String::from( - "x", - )))), - min_value: Precision::Exact(ScalarValue::Utf8(Some(String::from( - "a", - )))), + max_value: Precision::Exact(ScalarValue::from("x")), + min_value: Precision::Exact(ScalarValue::from("a")), null_count: Precision::Absent, // we don't know the row count on the right }, ColumnStatistics { diff --git a/datafusion/physical-plan/src/joins/hash_join.rs b/datafusion/physical-plan/src/joins/hash_join.rs index 3cc26e7edafae..7111cec2829bc 100644 --- a/datafusion/physical-plan/src/joins/hash_join.rs +++ b/datafusion/physical-plan/src/joins/hash_join.rs @@ -27,25 +27,23 @@ use std::{any::Any, usize, vec}; use crate::joins::utils::{ adjust_indices_by_join_type, apply_join_filter_to_indices, build_batch_from_indices, calculate_join_output_ordering, get_final_indices_from_bit_map, - need_produce_result_in_final, + need_produce_result_in_final, JoinHashMap, JoinHashMapType, }; -use crate::DisplayAs; use crate::{ - coalesce_batches::concat_batches, coalesce_partitions::CoalescePartitionsExec, expressions::Column, expressions::PhysicalSortExpr, hash_utils::create_hashes, - joins::hash_join_utils::{JoinHashMap, JoinHashMapType}, joins::utils::{ adjust_right_output_partitioning, build_join_schema, check_join_is_valid, estimate_join_statistics, 
partitioned_join_output_partitioning, - BuildProbeJoinMetrics, ColumnIndex, JoinFilter, JoinOn, + BuildProbeJoinMetrics, ColumnIndex, JoinFilter, JoinOn, StatefulStreamResult, }, metrics::{ExecutionPlanMetricsSet, MetricsSet}, DisplayFormatType, Distribution, ExecutionPlan, Partitioning, PhysicalExpr, RecordBatchStream, SendableRecordBatchStream, Statistics, }; +use crate::{handle_state, DisplayAs}; use super::{ utils::{OnceAsync, OnceFut}, @@ -54,10 +52,10 @@ use super::{ use arrow::array::{ Array, ArrayRef, BooleanArray, BooleanBufferBuilder, PrimitiveArray, UInt32Array, - UInt32BufferBuilder, UInt64Array, UInt64BufferBuilder, + UInt64Array, }; use arrow::compute::kernels::cmp::{eq, not_distinct}; -use arrow::compute::{and, take, FilterBuilder}; +use arrow::compute::{and, concat_batches, take, FilterBuilder}; use arrow::datatypes::{Schema, SchemaRef}; use arrow::record_batch::RecordBatch; use arrow::util::bit_util; @@ -74,7 +72,47 @@ use datafusion_physical_expr::EquivalenceProperties; use ahash::RandomState; use futures::{ready, Stream, StreamExt, TryStreamExt}; -type JoinLeftData = (JoinHashMap, RecordBatch, MemoryReservation); +/// HashTable and input data for the left (build side) of a join +struct JoinLeftData { + /// The hash table with indices into `batch` + hash_map: JoinHashMap, + /// The input rows for the build side + batch: RecordBatch, + /// Memory reservation that tracks memory used by `hash_map` hash table + /// `batch`. Cleared on drop. + #[allow(dead_code)] + reservation: MemoryReservation, +} + +impl JoinLeftData { + /// Create a new `JoinLeftData` from its parts + fn new( + hash_map: JoinHashMap, + batch: RecordBatch, + reservation: MemoryReservation, + ) -> Self { + Self { + hash_map, + batch, + reservation, + } + } + + /// Returns the number of rows in the build side + fn num_rows(&self) -> usize { + self.batch.num_rows() + } + + /// return a reference to the hash map + fn hash_map(&self) -> &JoinHashMap { + &self.hash_map + } + + /// returns a reference to the build side batch + fn batch(&self) -> &RecordBatch { + &self.batch + } +} /// Join execution plan: Evaluates eqijoin predicates in parallel on multiple /// partitions using a hash table and an optional filter list to apply post @@ -118,8 +156,48 @@ type JoinLeftData = (JoinHashMap, RecordBatch, MemoryReservation); /// /// Execution proceeds in 2 stages: /// -/// 1. the **build phase** where a hash table is created from the tuples of the -/// build side. +/// 1. the **build phase** creates a hash table from the tuples of the build side, +/// and single concatenated batch containing data from all fetched record batches. +/// Resulting hash table stores hashed join-key fields for each row as a key, and +/// indices of corresponding rows in concatenated batch. +/// +/// Hash join uses LIFO data structure as a hash table, and in order to retain +/// original build-side input order while obtaining data during probe phase, hash +/// table is updated by iterating batch sequence in reverse order -- it allows to +/// keep rows with smaller indices "on the top" of hash table, and still maintain +/// correct indexing for concatenated build-side data batch. 
+/// +/// Example of build phase for 3 record batches: +/// +/// +/// ```text +/// +/// Original build-side data Inserting build-side values into hashmap Concatenated build-side batch +/// ┌───────────────────────────┐ +/// hasmap.insert(row-hash, row-idx + offset) │ idx │ +/// ┌───────┐ │ ┌───────┐ │ +/// │ Row 1 │ 1) update_hash for batch 3 with offset 0 │ │ Row 6 │ 0 │ +/// Batch 1 │ │ - hashmap.insert(Row 7, idx 1) │ Batch 3 │ │ │ +/// │ Row 2 │ - hashmap.insert(Row 6, idx 0) │ │ Row 7 │ 1 │ +/// └───────┘ │ └───────┘ │ +/// │ │ +/// ┌───────┐ │ ┌───────┐ │ +/// │ Row 3 │ 2) update_hash for batch 2 with offset 2 │ │ Row 3 │ 2 │ +/// │ │ - hashmap.insert(Row 5, idx 4) │ │ │ │ +/// Batch 2 │ Row 4 │ - hashmap.insert(Row 4, idx 3) │ Batch 2 │ Row 4 │ 3 │ +/// │ │ - hashmap.insert(Row 3, idx 2) │ │ │ │ +/// │ Row 5 │ │ │ Row 5 │ 4 │ +/// └───────┘ │ └───────┘ │ +/// │ │ +/// ┌───────┐ │ ┌───────┐ │ +/// │ Row 6 │ 3) update_hash for batch 1 with offset 5 │ │ Row 1 │ 5 │ +/// Batch 3 │ │ - hashmap.insert(Row 2, idx 5) │ Batch 1 │ │ │ +/// │ Row 7 │ - hashmap.insert(Row 1, idx 6) │ │ Row 2 │ 6 │ +/// └───────┘ │ └───────┘ │ +/// │ │ +/// └───────────────────────────┘ +/// +/// ``` /// /// 2. the **probe phase** where the tuples of the probe side are streamed /// through, checking for matches of the join keys in the hash table. @@ -582,18 +660,15 @@ impl ExecutionPlan for HashJoinExec { on_right, filter: self.filter.clone(), join_type: self.join_type, - left_fut, - visited_left_side: None, right: right_stream, column_indices: self.column_indices.clone(), random_state: self.random_state.clone(), join_metrics, null_equals_null: self.null_equals_null, - is_exhausted: false, reservation, + state: HashJoinStreamState::WaitBuildSide, + build_side: BuildSide::Initial(BuildSideInitialState { left_fut }), batch_size, - probe_batch: None, - output_state: HashJoinOutputState::default(), })) } @@ -615,6 +690,8 @@ impl ExecutionPlan for HashJoinExec { } } +/// Reads the left (build) side of the input, buffering it in memory, to build a +/// hash table (`LeftJoinData`) async fn collect_left_input( partition: Option, random_state: RandomState, @@ -681,7 +758,10 @@ async fn collect_left_input( let mut hashmap = JoinHashMap::with_capacity(num_rows); let mut hashes_buffer = Vec::new(); let mut offset = 0; - for batch in batches.iter() { + + // Updating hashmap starting from the last batch + let batches_iter = batches.iter().rev(); + for batch in batches_iter.clone() { hashes_buffer.clear(); hashes_buffer.resize(batch.num_rows(), 0); update_hash( @@ -692,18 +772,25 @@ async fn collect_left_input( &random_state, &mut hashes_buffer, 0, + true, )?; offset += batch.num_rows(); } // Merge all batches into a single batch, so we // can directly index into the arrays - let single_batch = concat_batches(&schema, &batches, num_rows)?; + let single_batch = concat_batches(&schema, batches_iter)?; + let data = JoinLeftData::new(hashmap, single_batch, reservation); - Ok((hashmap, single_batch, reservation)) + Ok(data) } -/// Updates `hash` with new entries from [RecordBatch] evaluated against the expressions `on`, -/// assuming that the [RecordBatch] corresponds to the `index`th +/// Updates `hash_map` with new entries from `batch` evaluated against the expressions `on` +/// using `offset` as a start value for `batch` row indices. 
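As a rough illustration of the build phase sketched above (a chained-list hash table storing `row index + 1`, with batches inserted in reverse order so the smallest index sits at the chain head), here is a small standalone sketch. It is deliberately simplified: it keys the map by the join value itself rather than a hash, so there is no collision handling, and `TinyJoinMap` and its methods are illustrative names only.

```rust
use std::collections::HashMap;

struct TinyJoinMap {
    /// join key -> (last inserted row index + 1); 0 is reserved for "no entry"
    map: HashMap<u64, u64>,
    /// next[i] = (previous row index with the same key + 1); 0 ends the chain
    next: Vec<u64>,
}

impl TinyJoinMap {
    fn new(total_rows: usize) -> Self {
        Self { map: HashMap::new(), next: vec![0; total_rows] }
    }

    /// Insert one batch whose rows occupy `offset..offset + keys.len()` of the
    /// concatenated build-side batch. Rows are visited in reverse so that the
    /// smallest row index ends up at the head of each chain (FIFO behaviour).
    fn update(&mut self, keys: &[u64], offset: usize) {
        for (i, key) in keys.iter().enumerate().rev() {
            let row = i + offset;
            let prev = self.map.insert(*key, (row + 1) as u64).unwrap_or(0);
            self.next[row] = prev;
        }
    }

    /// Follow the chain for `key`, starting at the stored head.
    fn matches(&self, key: u64) -> Vec<usize> {
        let mut out = Vec::new();
        let mut idx = self.map.get(&key).copied().unwrap_or(0);
        while idx != 0 {
            out.push((idx - 1) as usize);
            idx = self.next[(idx - 1) as usize];
        }
        out
    }
}

fn main() {
    // Batch 1 occupies rows 0..3 and batch 2 rows 3..5 of the concatenated batch.
    let batch1 = [10, 20, 10];
    let batch2 = [10, 30];
    let mut map = TinyJoinMap::new(5);
    // As in the build phase above, the *last* batch is inserted first.
    map.update(&batch2, 3);
    map.update(&batch1, 0);
    // Probing key 10 now yields build rows in their original, ascending order.
    assert_eq!(map.matches(10), vec![0, 2, 3]);
}
```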
+/// +/// `fifo_hashmap` sets the order of iteration over `batch` rows while updating hashmap, +/// which allows to keep either first (if set to true) or last (if set to false) row index +/// as a chain head for rows with equal hash values. +#[allow(clippy::too_many_arguments)] pub fn update_hash( on: &[Column], batch: &RecordBatch, @@ -712,6 +799,7 @@ pub fn update_hash( random_state: &RandomState, hashes_buffer: &mut Vec, deleted_offset: usize, + fifo_hashmap: bool, ) -> Result<()> where T: JoinHashMapType, @@ -719,7 +807,7 @@ where // evaluate the keys let keys_values = on .iter() - .map(|c| Ok(c.evaluate(batch)?.into_array(batch.num_rows()))) + .map(|c| c.evaluate(batch)?.into_array(batch.num_rows())) .collect::>>()?; // calculate the hash values @@ -728,119 +816,116 @@ where // For usual JoinHashmap, the implementation is void. hash_map.extend_zero(batch.num_rows()); - // insert hashes to key of the hashmap - let (mut_map, mut_list) = hash_map.get_mut(); - for (row, hash_value) in hash_values.iter().enumerate() { - let item = mut_map.get_mut(*hash_value, |(hash, _)| *hash_value == *hash); - if let Some((_, index)) = item { - // Already exists: add index to next array - let prev_index = *index; - // Store new value inside hashmap - *index = (row + offset + 1) as u64; - // Update chained Vec at row + offset with previous value - mut_list[row + offset - deleted_offset] = prev_index; - } else { - mut_map.insert( - *hash_value, - // store the value + 1 as 0 value reserved for end of list - (*hash_value, (row + offset + 1) as u64), - |(hash, _)| *hash, - ); - // chained list at (row + offset) is already initialized with 0 - // meaning end of list - } + // Updating JoinHashMap from hash values iterator + let hash_values_iter = hash_values + .iter() + .enumerate() + .map(|(i, val)| (i + offset, val)); + + if fifo_hashmap { + hash_map.update_from_iter(hash_values_iter.rev(), deleted_offset); + } else { + hash_map.update_from_iter(hash_values_iter, deleted_offset); } + Ok(()) } -// State for storing left/right side indices used for partial batch output -// & producing ranges for adjusting indices -#[derive(Debug, Default)] -pub(crate) struct HashJoinOutputState { - // total rows in current probe batch - probe_rows: usize, - // saved probe-build indices to resume matching from - last_matched_indices: Option<(usize, usize)>, - // current iteration has been updated - matched_indices_updated: bool, - // last probe side index, joined during current iteration - last_joined_probe_index: Option, - // last probe side index, joined during previous iteration - prev_joined_probe_index: Option, +/// Represents build-side of hash join. 
+enum BuildSide { + /// Indicates that build-side not collected yet + Initial(BuildSideInitialState), + /// Indicates that build-side data has been collected + Ready(BuildSideReadyState), } -impl HashJoinOutputState { - // set total probe rows to process - pub(crate) fn set_probe_rows(&mut self, probe_rows: usize) { - self.probe_rows = probe_rows; - } - // obtain last_matched_indices -- initial point to resume matching - fn start_mathching_iteration(&self) -> (usize, usize) { - self.last_matched_indices - .map_or_else(|| (0, 0), |pair| pair) - } - - // if current probe batch processing is in partial-output state - fn partial_output(&self) -> bool { - self.last_matched_indices.is_some() - } +/// Container for BuildSide::Initial related data +struct BuildSideInitialState { + /// Future for building hash table from build-side input + left_fut: OnceFut, +} - // if current probe batch processing completed -- all probe rows have been joined to build rows - pub(crate) fn is_completed(&self) -> bool { - self.last_matched_indices - .is_some_and(|(probe, build)| probe + 1 >= self.probe_rows && build == 0) - } +/// Container for BuildSide::Ready related data +struct BuildSideReadyState { + /// Collected build-side data + left_data: Arc, + /// Which build-side rows have been matched while creating output. + /// For some OUTER joins, we need to know which rows have not been matched + /// to produce the correct output. + visited_left_side: BooleanBufferBuilder, +} - // saving next probe-build indices to start next iteration of matching - fn update_matching_iteration(&mut self, probe_idx: usize, build_idx: usize) { - self.last_matched_indices = Some((probe_idx, build_idx)); - self.matched_indices_updated = true; +impl BuildSide { + /// Tries to extract BuildSideInitialState from BuildSide enum. + /// Returns an error if state is not Initial. + fn try_as_initial_mut(&mut self) -> Result<&mut BuildSideInitialState> { + match self { + BuildSide::Initial(state) => Ok(state), + _ => internal_err!("Expected build side in initial state"), + } } - // updating state after matching iteration has been performed - fn finalize_matching_iteration(&mut self, joined_right_side: &UInt32Array) { - // if there were no intermediate updates of matched inidices, during current iteration, - // setting indices like whole current batch has been scanned - if !self.matched_indices_updated { - self.last_matched_indices = Some((self.probe_rows, 0)); + /// Tries to extract BuildSideReadyState from BuildSide enum. + /// Returns an error if state is not Ready. + fn try_as_ready(&self) -> Result<&BuildSideReadyState> { + match self { + BuildSide::Ready(state) => Ok(state), + _ => internal_err!("Expected build side in ready state"), } - self.matched_indices_updated = false; + } - // advancing joined probe-side indices - self.prev_joined_probe_index = self.last_joined_probe_index; - if !joined_right_side.is_empty() { - self.last_joined_probe_index = - Some(joined_right_side.value(joined_right_side.len() - 1) as usize); + /// Tries to extract BuildSideReadyState from BuildSide enum. + /// Returns an error if state is not Ready. 
+ fn try_as_ready_mut(&mut self) -> Result<&mut BuildSideReadyState> { + match self { + BuildSide::Ready(state) => Ok(state), + _ => internal_err!("Expected build side in ready state"), } } +} - pub(crate) fn reset_state(&mut self) { - self.probe_rows = 0; - self.last_matched_indices = None; - self.last_joined_probe_index = None; - self.matched_indices_updated = false; - } - - // The goals for different join types are: - // 1) Right & FullJoin -- to append all missing probe-side indices between - // previous (excluding) and current joined indices. - // 2) SemiJoin -- deduplicate probe indices in range between previous - // (excluding) and current joined indices. - // 3) AntiJoin -- return only missing indices in range between - // previous and current joined indices. - // Inclusion/exclusion of the indices themselves don't matter - // As a result -- partial adjustment range can be produced based only on - // joined (matched with filters applied) probe side indices, excluding starting one - // (left from previous iteration) - pub(crate) fn adjust_range(&self) -> Range { - let rg_start = self.prev_joined_probe_index.map_or(0, |v| v + 1); - let rg_end = if self.is_completed() { - self.probe_rows - } else { - self.last_joined_probe_index.map_or(0, |v| v + 1) - }; +/// Represents state of HashJoinStream +/// +/// Expected state transitions performed by HashJoinStream are: +/// +/// ```text +/// +/// WaitBuildSide +/// │ +/// ▼ +/// ┌─► FetchProbeBatch ───► ExhaustedProbeSide ───► Completed +/// │ │ +/// │ ▼ +/// └─ ProcessProbeBatch +/// +/// ``` +enum HashJoinStreamState { + /// Initial state for HashJoinStream indicating that build-side data not collected yet + WaitBuildSide, + /// Indicates that build-side has been collected, and stream is ready for fetching probe-side + FetchProbeBatch, + /// Indicates that non-empty batch has been fetched from probe-side, and is ready to be processed + ProcessProbeBatch(ProcessProbeBatchState), + /// Indicates that probe-side has been fully processed + ExhaustedProbeSide, + /// Indicates that HashJoinStream execution is completed + Completed, +} + +/// Container for HashJoinStreamState::ProcessProbeBatch related data +struct ProcessProbeBatchState { + /// Current probe-side batch + batch: RecordBatch, +} - rg_start..rg_end +impl HashJoinStreamState { + /// Tries to extract ProcessProbeBatchState from HashJoinStreamState enum. + /// Returns an error if state is not ProcessProbeBatchState. + fn try_as_process_probe_batch(&self) -> Result<&ProcessProbeBatchState> { + match self { + HashJoinStreamState::ProcessProbeBatch(state) => Ok(state), + _ => internal_err!("Expected hash join stream in ProcessProbeBatch state"), + } } } @@ -863,20 +948,10 @@ struct HashJoinStream { filter: Option, /// type of the join (left, right, semi, etc) join_type: JoinType, - /// future which builds hash table from left side - left_fut: OnceFut, - /// Which left (probe) side rows have been matches while creating output. - /// For some OUTER joins, we need to know which rows have not been matched - /// to produce the correct. - visited_left_side: Option, /// right (probe) input right: SendableRecordBatchStream, /// Random state used for hashing initialization random_state: RandomState, - /// The join output is complete. For outer joins, this is used to - /// distinguish when the input stream is exhausted and when any unmatched - /// rows are output. 
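Read as ordinary control flow, the `HashJoinStreamState` transitions above amount to a simple state-machine loop. The following standalone, synchronous sketch (no `async`/`Poll`; integers stand in for record batches, and all names are illustrative) walks the documented WaitBuildSide, FetchProbeBatch, ProcessProbeBatch, ExhaustedProbeSide, Completed path.

```rust
/// Illustrative states, mirroring the diagram above.
enum State {
    WaitBuildSide,
    FetchProbeBatch,
    ProcessProbeBatch(u32),
    ExhaustedProbeSide,
    Completed,
}

fn drive(mut probe: impl Iterator<Item = u32>) -> Vec<String> {
    let mut output = Vec::new();
    let mut state = State::WaitBuildSide;
    loop {
        state = match state {
            // Build the hash table once, then start pulling probe batches.
            State::WaitBuildSide => {
                output.push("collected build side".to_string());
                State::FetchProbeBatch
            }
            // A probe batch moves us to ProcessProbeBatch; exhaustion moves on.
            State::FetchProbeBatch => match probe.next() {
                Some(batch) => State::ProcessProbeBatch(batch),
                None => State::ExhaustedProbeSide,
            },
            // After joining one probe batch, fetch the next one.
            State::ProcessProbeBatch(batch) => {
                output.push(format!("joined probe batch {batch}"));
                State::FetchProbeBatch
            }
            // Emit unmatched build rows (for outer joins), then finish.
            State::ExhaustedProbeSide => {
                output.push("emitted unmatched build rows".to_string());
                State::Completed
            }
            State::Completed => return output,
        };
    }
}

fn main() {
    let log = drive([1, 2].into_iter());
    assert_eq!(log.len(), 4); // build + two probe batches + final pass
}
```

In the actual stream the same loop is driven asynchronously, with `handle_state!` dispatching on each state's `Poll` result, but the transition order is the same as in this sketch.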
- is_exhausted: bool, /// Metrics join_metrics: BuildProbeJoinMetrics, /// Information of index and left / right placement of columns @@ -885,15 +960,12 @@ struct HashJoinStream { null_equals_null: bool, /// Memory reservation reservation: MemoryReservation, - /// Batch size + /// State of the stream + state: HashJoinStreamState, + /// Build side + build_side: BuildSide, + /// Max output batch size batch_size: usize, - /// Current probe batch - probe_batch: Option, - /// In case joining current probe batch with build side may produce more than `batch_size` records - /// (cross-join due to key duplication on build side) `HashJoinStream` saves its state - /// and emits result batch to upstream operator. - /// On next poll these indices are used to skip already matched rows and adjusted probe-side indices. - output_state: HashJoinOutputState, } impl RecordBatchStream for HashJoinStream { @@ -907,7 +979,7 @@ impl RecordBatchStream for HashJoinStream { /// # Example /// /// For `LEFT.b1 = RIGHT.b2`: -/// LEFT Table: +/// LEFT (build) Table: /// ```text /// a1 b1 c1 /// 1 1 10 @@ -919,7 +991,7 @@ impl RecordBatchStream for HashJoinStream { /// 13 10 130 /// ``` /// -/// RIGHT Table: +/// RIGHT (probe) Table: /// ```text /// a2 b2 c2 /// 2 2 20 @@ -960,27 +1032,25 @@ pub(crate) fn build_equal_condition_join_indices( filter: Option<&JoinFilter>, build_side: JoinSide, deleted_offset: Option, - output_limit: usize, - state: &mut HashJoinOutputState, + fifo_hashmap: bool, ) -> Result<(UInt64Array, UInt32Array)> { let keys_values = probe_on .iter() - .map(|c| Ok(c.evaluate(probe_batch)?.into_array(probe_batch.num_rows()))) + .map(|c| c.evaluate(probe_batch)?.into_array(probe_batch.num_rows())) .collect::>>()?; let build_join_values = build_on .iter() .map(|c| { - Ok(c.evaluate(build_input_buffer)? - .into_array(build_input_buffer.num_rows())) + c.evaluate(build_input_buffer)? + .into_array(build_input_buffer.num_rows()) }) .collect::>>()?; hashes_buffer.clear(); hashes_buffer.resize(probe_batch.num_rows(), 0); let hash_values = create_hashes(&keys_values, random_state, hashes_buffer)?; - // Using a buffer builder to avoid slower normal builder - let mut build_indices = UInt64BufferBuilder::new(0); - let mut probe_indices = UInt32BufferBuilder::new(0); - // The chained list algorithm generates build indices for each probe row in a reversed sequence as such: + + // In case build-side input has not been inverted while JoinHashMap creation, the chained list algorithm + // will return build indices for each probe row in a reverse order as such: // Build Indices: [5, 4, 3] // Probe Indices: [1, 1, 1] // @@ -1009,68 +1079,17 @@ pub(crate) fn build_equal_condition_join_indices( // (5,1) // // With this approach, the lexicographic order on both the probe side and the build side is preserved. 
- let hash_map = build_hashmap.get_map(); - let next_chain = build_hashmap.get_list(); - - let mut output_tuples = 0_usize; - - // Get starting point in case resuming current probe-batch - let (initial_probe, initial_build) = state.start_mathching_iteration(); - - 'probe: for (row, hash_value) in hash_values.iter().enumerate().skip(initial_probe) { - let index = if state.partial_output() && row == initial_probe { - // using build index from state for the first row - // in case of partially skipped input - if initial_build == 0 { - continue; - } - Some(initial_build as u64) - } else if let Some((_, index)) = - hash_map.get(*hash_value, |(hash, _)| *hash_value == *hash) - { - // otherwise -- checking build hashmap for presence of current hash_value - Some(*index) - } else { - None - }; - - // For every item on the build and probe we check if it matches - // This possibly contains rows with hash collisions, - // So we have to check here whether rows are equal or not - if let Some(index) = index { - let mut i = index - 1; - - loop { - let build_row_value = if let Some(offset) = deleted_offset { - // This arguments means that we prune the next index way before here. - if i < offset as u64 { - // End of the list due to pruning - break; - } - i - offset as u64 - } else { - i - }; - build_indices.append(build_row_value); - probe_indices.append(row as u32); - output_tuples += 1; - - // Follow the chain to get the next index value - let next = next_chain[build_row_value as usize]; + let (mut probe_indices, mut build_indices) = if fifo_hashmap { + build_hashmap.get_matched_indices(hash_values.iter().enumerate(), deleted_offset) + } else { + let (mut matched_probe, mut matched_build) = build_hashmap + .get_matched_indices(hash_values.iter().enumerate().rev(), deleted_offset); - if output_tuples >= output_limit { - state.update_matching_iteration(row, next as usize); - break 'probe; - } + matched_probe.as_slice_mut().reverse(); + matched_build.as_slice_mut().reverse(); - if next == 0 { - // end of list - break; - } - i = next - 1; - } - } - } + (matched_probe, matched_build) + }; let left: UInt64Array = PrimitiveArray::new(build_indices.finish().into(), None); let right: UInt32Array = PrimitiveArray::new(probe_indices.finish().into(), None); @@ -1164,163 +1183,213 @@ impl HashJoinStream { &mut self, cx: &mut std::task::Context<'_>, ) -> Poll>> { + loop { + return match self.state { + HashJoinStreamState::WaitBuildSide => { + handle_state!(ready!(self.collect_build_side(cx))) + } + HashJoinStreamState::FetchProbeBatch => { + handle_state!(ready!(self.fetch_probe_batch(cx))) + } + HashJoinStreamState::ProcessProbeBatch(_) => { + handle_state!(self.process_probe_batch()) + } + HashJoinStreamState::ExhaustedProbeSide => { + handle_state!(self.process_unmatched_build_batch()) + } + HashJoinStreamState::Completed => Poll::Ready(None), + }; + } + } + + /// Collects build-side data by polling `OnceFut` future from initialized build-side + /// + /// Updates build-side to `Ready`, and state to `FetchProbeSide` + fn collect_build_side( + &mut self, + cx: &mut std::task::Context<'_>, + ) -> Poll>>> { let build_timer = self.join_metrics.build_time.timer(); // build hash table from left (build) side, if not yet done - let left_data = match ready!(self.left_fut.get(cx)) { - Ok(left_data) => left_data, - Err(e) => return Poll::Ready(Some(Err(e))), - }; + let left_data = ready!(self + .build_side + .try_as_initial_mut()? 
+ .left_fut + .get_shared(cx))?; build_timer.done(); // Reserving memory for visited_left_side bitmap in case it hasn't been initialized yet // and join_type requires to store it - if self.visited_left_side.is_none() - && need_produce_result_in_final(self.join_type) - { + if need_produce_result_in_final(self.join_type) { // TODO: Replace `ceil` wrapper with stable `div_cell` after // https://github.com/rust-lang/rust/issues/88581 - let visited_bitmap_size = bit_util::ceil(left_data.1.num_rows(), 8); + let visited_bitmap_size = bit_util::ceil(left_data.num_rows(), 8); self.reservation.try_grow(visited_bitmap_size)?; self.join_metrics.build_mem_used.add(visited_bitmap_size); } - let visited_left_side = self.visited_left_side.get_or_insert_with(|| { - let num_rows = left_data.1.num_rows(); - if need_produce_result_in_final(self.join_type) { - // Some join types need to track which row has be matched or unmatched: - // `left semi` join: need to use the bitmap to produce the matched row in the left side - // `left` join: need to use the bitmap to produce the unmatched row in the left side with null - // `left anti` join: need to use the bitmap to produce the unmatched row in the left side - // `full` join: need to use the bitmap to produce the unmatched row in the left side with null - let mut buffer = BooleanBufferBuilder::new(num_rows); - buffer.append_n(num_rows, false); - buffer - } else { - BooleanBufferBuilder::new(0) - } + let visited_left_side = if need_produce_result_in_final(self.join_type) { + let num_rows = left_data.num_rows(); + // Some join types need to track which row has be matched or unmatched: + // `left semi` join: need to use the bitmap to produce the matched row in the left side + // `left` join: need to use the bitmap to produce the unmatched row in the left side with null + // `left anti` join: need to use the bitmap to produce the unmatched row in the left side + // `full` join: need to use the bitmap to produce the unmatched row in the left side with null + let mut buffer = BooleanBufferBuilder::new(num_rows); + buffer.append_n(num_rows, false); + buffer + } else { + BooleanBufferBuilder::new(0) + }; + + self.state = HashJoinStreamState::FetchProbeBatch; + self.build_side = BuildSide::Ready(BuildSideReadyState { + left_data, + visited_left_side, }); + + Poll::Ready(Ok(StatefulStreamResult::Continue)) + } + + /// Fetches next batch from probe-side + /// + /// If non-empty batch has been fetched, updates state to `ProcessProbeBatchState`, + /// otherwise updates state to `ExhaustedProbeSide` + fn fetch_probe_batch( + &mut self, + cx: &mut std::task::Context<'_>, + ) -> Poll>>> { + match ready!(self.right.poll_next_unpin(cx)) { + None => { + self.state = HashJoinStreamState::ExhaustedProbeSide; + } + Some(Ok(batch)) => { + self.state = + HashJoinStreamState::ProcessProbeBatch(ProcessProbeBatchState { + batch, + }); + } + Some(Err(err)) => return Poll::Ready(Err(err)), + }; + + Poll::Ready(Ok(StatefulStreamResult::Continue)) + } + + /// Joins current probe batch with build-side data and produces batch with matched output + /// + /// Updates state to `FetchProbeBatch` + fn process_probe_batch( + &mut self, + ) -> Result>> { + let state = self.state.try_as_process_probe_batch()?; + let build_side = self.build_side.try_as_ready_mut()?; + + self.join_metrics.input_batches.add(1); + self.join_metrics.input_rows.add(state.batch.num_rows()); + let timer = self.join_metrics.join_time.timer(); + let mut hashes_buffer = vec![]; + // get the matched two indices for the on condition 
+ let left_right_indices = build_equal_condition_join_indices( + build_side.left_data.hash_map(), + build_side.left_data.batch(), + &state.batch, + &self.on_left, + &self.on_right, + &self.random_state, + self.null_equals_null, + &mut hashes_buffer, + self.filter.as_ref(), + JoinSide::Left, + None, + true, + ); - // Fetch next right (probe) input batch if required - if self.probe_batch.is_none() { - match ready!(self.right.poll_next_unpin(cx)) { - Some(Ok(batch)) => { - self.output_state.set_probe_rows(batch.num_rows()); - self.probe_batch = Some(batch); + let result = match left_right_indices { + Ok((left_side, right_side)) => { + // set the left bitmap + // and only left, full, left semi, left anti need the left bitmap + if need_produce_result_in_final(self.join_type) { + left_side.iter().flatten().for_each(|x| { + build_side.visited_left_side.set_bit(x as usize, true); + }); } - None => { - self.probe_batch = None; - } - Some(err) => return Poll::Ready(Some(err)), - } - } - let output_batch = match &self.probe_batch { - // one right batch in the join loop - Some(batch) => { - self.join_metrics.input_batches.add(1); - self.join_metrics.input_rows.add(batch.num_rows()); - let timer = self.join_metrics.join_time.timer(); - - // get the matched two indices for the on condition - let left_right_indices = build_equal_condition_join_indices( - &left_data.0, - &left_data.1, - batch, - &self.on_left, - &self.on_right, - &self.random_state, - self.null_equals_null, - &mut hashes_buffer, - self.filter.as_ref(), - JoinSide::Left, - None, - self.batch_size, - &mut self.output_state, + // adjust the two side indices base on the join type + let (left_side, right_side) = adjust_indices_by_join_type( + left_side, + right_side, + state.batch.num_rows(), + self.join_type, ); - let result = match left_right_indices { - Ok((left_side, right_side)) => { - // set the left bitmap - // and only left, full, left semi, left anti need the left bitmap - if need_produce_result_in_final(self.join_type) { - left_side.iter().flatten().for_each(|x| { - visited_left_side.set_bit(x as usize, true); - }); - } - - // adjust the two side indices base on the join type - - let (left_side, right_side) = adjust_indices_by_join_type( - left_side, - right_side, - self.output_state.adjust_range(), - self.join_type, - ); - - let result = build_batch_from_indices( - &self.schema, - &left_data.1, - batch, - &left_side, - &right_side, - &self.column_indices, - JoinSide::Left, - ); - self.join_metrics.output_batches.add(1); - self.join_metrics.output_rows.add(batch.num_rows()); - - if self.output_state.is_completed() { - self.probe_batch = None; - self.output_state.reset_state(); - } - - Some(result) - } - Err(err) => Some(exec_err!( - "Fail to build join indices in HashJoinExec, error:{err}" - )), - }; - - timer.done(); + let result = build_batch_from_indices( + &self.schema, + build_side.left_data.batch(), + &state.batch, + &left_side, + &right_side, + &self.column_indices, + JoinSide::Left, + ); + self.join_metrics.output_batches.add(1); + self.join_metrics.output_rows.add(state.batch.num_rows()); result } - None => { - let timer = self.join_metrics.join_time.timer(); - if need_produce_result_in_final(self.join_type) && !self.is_exhausted { - // use the global left bitmap to produce the left indices and right indices - let (left_side, right_side) = - get_final_indices_from_bit_map(visited_left_side, self.join_type); - let empty_right_batch = RecordBatch::new_empty(self.right.schema()); - // use the left and right indices to produce 
the batch result - let result = build_batch_from_indices( - &self.schema, - &left_data.1, - &empty_right_batch, - &left_side, - &right_side, - &self.column_indices, - JoinSide::Left, - ); - - if let Ok(ref batch) = result { - self.join_metrics.input_batches.add(1); - self.join_metrics.input_rows.add(batch.num_rows()); - - self.join_metrics.output_batches.add(1); - self.join_metrics.output_rows.add(batch.num_rows()); - } - timer.done(); - self.is_exhausted = true; - Some(result) - } else { - // end of the join loop - None - } + Err(err) => { + exec_err!("Fail to build join indices in HashJoinExec, error:{err}") } }; + timer.done(); + + self.state = HashJoinStreamState::FetchProbeBatch; - Poll::Ready(output_batch) + Ok(StatefulStreamResult::Ready(Some(result?))) + } + + /// Processes unmatched build-side rows for certain join types and produces output batch + /// + /// Updates state to `Completed` + fn process_unmatched_build_batch( + &mut self, + ) -> Result>> { + let timer = self.join_metrics.join_time.timer(); + + if !need_produce_result_in_final(self.join_type) { + self.state = HashJoinStreamState::Completed; + + return Ok(StatefulStreamResult::Continue); + } + + let build_side = self.build_side.try_as_ready()?; + + // use the global left bitmap to produce the left indices and right indices + let (left_side, right_side) = + get_final_indices_from_bit_map(&build_side.visited_left_side, self.join_type); + let empty_right_batch = RecordBatch::new_empty(self.right.schema()); + // use the left and right indices to produce the batch result + let result = build_batch_from_indices( + &self.schema, + build_side.left_data.batch(), + &empty_right_batch, + &left_side, + &right_side, + &self.column_indices, + JoinSide::Left, + ); + + if let Ok(ref batch) = result { + self.join_metrics.input_batches.add(1); + self.join_metrics.input_rows.add(batch.num_rows()); + + self.join_metrics.output_batches.add(1); + self.join_metrics.output_rows.add(batch.num_rows()); + } + timer.done(); + + self.state = HashJoinStreamState::Completed; + + Ok(StatefulStreamResult::Ready(Some(result?))) } } @@ -1351,6 +1420,9 @@ mod tests { use datafusion_common::{ assert_batches_eq, assert_batches_sorted_eq, assert_contains, ScalarValue, }; + use datafusion_common::{ + assert_batches_eq, assert_batches_sorted_eq, assert_contains, ScalarValue, + }; use datafusion_execution::config::SessionConfig; use datafusion_execution::runtime_env::{RuntimeConfig, RuntimeEnv}; use datafusion_expr::Operator; @@ -1531,7 +1603,9 @@ mod tests { "| 3 | 5 | 9 | 20 | 5 | 80 |", "+----+----+----+----+----+----+", ]; - assert_batches_sorted_eq!(expected, &batches); + + // Inner join output is expected to preserve both inputs order + assert_batches_eq!(expected, &batches); Ok(()) } @@ -1614,7 +1688,49 @@ mod tests { "+----+----+----+----+----+----+", ]; - assert_batches_sorted_eq!(expected, &batches); + // Inner join output is expected to preserve both inputs order + assert_batches_eq!(expected, &batches); + + Ok(()) + } + + #[apply(batch_sizes)] + #[tokio::test] + async fn join_inner_one_randomly_ordered() -> Result<()> { + let task_ctx = Arc::new(TaskContext::default()); + let left = build_table( + ("a1", &vec![0, 3, 2, 1]), + ("b1", &vec![4, 5, 5, 4]), + ("c1", &vec![6, 9, 8, 7]), + ); + let right = build_table( + ("a2", &vec![20, 30, 10]), + ("b2", &vec![5, 6, 4]), + ("c2", &vec![80, 90, 70]), + ); + let on = vec![( + Column::new_with_schema("b1", &left.schema())?, + Column::new_with_schema("b2", &right.schema())?, + )]; + + let (columns, batches) 
= + join_collect(left, right, on, &JoinType::Inner, false, task_ctx).await?; + + assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]); + + let expected = [ + "+----+----+----+----+----+----+", + "| a1 | b1 | c1 | a2 | b2 | c2 |", + "+----+----+----+----+----+----+", + "| 3 | 5 | 9 | 20 | 5 | 80 |", + "| 2 | 5 | 8 | 20 | 5 | 80 |", + "| 0 | 4 | 6 | 10 | 4 | 70 |", + "| 1 | 4 | 7 | 10 | 4 | 70 |", + "+----+----+----+----+----+----+", + ]; + + // Inner join output is expected to preserve both inputs order + assert_batches_eq!(expected, &batches); Ok(()) } @@ -1667,7 +1783,8 @@ mod tests { "+----+----+----+----+----+----+", ]; - assert_batches_sorted_eq!(expected, &batches); + // Inner join output is expected to preserve both inputs order + assert_batches_eq!(expected, &batches); Ok(()) } @@ -1728,7 +1845,58 @@ mod tests { "+----+----+----+----+----+----+", ]; - assert_batches_sorted_eq!(expected, &batches); + // Inner join output is expected to preserve both inputs order + assert_batches_eq!(expected, &batches); + + Ok(()) + } + + #[tokio::test] + async fn join_inner_one_two_parts_left_randomly_ordered() -> Result<()> { + let task_ctx = Arc::new(TaskContext::default()); + let batch1 = build_table_i32( + ("a1", &vec![0, 3]), + ("b1", &vec![4, 5]), + ("c1", &vec![6, 9]), + ); + let batch2 = build_table_i32( + ("a1", &vec![2, 1]), + ("b1", &vec![5, 4]), + ("c1", &vec![8, 7]), + ); + let schema = batch1.schema(); + + let left = Arc::new( + MemoryExec::try_new(&[vec![batch1], vec![batch2]], schema, None).unwrap(), + ); + let right = build_table( + ("a2", &vec![20, 30, 10]), + ("b2", &vec![5, 6, 4]), + ("c2", &vec![80, 90, 70]), + ); + let on = vec![( + Column::new_with_schema("b1", &left.schema())?, + Column::new_with_schema("b2", &right.schema())?, + )]; + + let (columns, batches) = + join_collect(left, right, on, &JoinType::Inner, false, task_ctx).await?; + + assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]); + + let expected = [ + "+----+----+----+----+----+----+", + "| a1 | b1 | c1 | a2 | b2 | c2 |", + "+----+----+----+----+----+----+", + "| 3 | 5 | 9 | 20 | 5 | 80 |", + "| 2 | 5 | 8 | 20 | 5 | 80 |", + "| 0 | 4 | 6 | 10 | 4 | 70 |", + "| 1 | 4 | 7 | 10 | 4 | 70 |", + "+----+----+----+----+----+----+", + ]; + + // Inner join output is expected to preserve both inputs order + assert_batches_eq!(expected, &batches); Ok(()) } @@ -1785,7 +1953,9 @@ mod tests { "| 1 | 4 | 7 | 10 | 4 | 70 |", "+----+----+----+----+----+----+", ]; - assert_batches_sorted_eq!(expected, &batches); + + // Inner join output is expected to preserve both inputs order + assert_batches_eq!(expected, &batches); // second part let stream = join.execute(1, task_ctx.clone())?; @@ -1804,7 +1974,8 @@ mod tests { "+----+----+----+----+----+----+", ]; - assert_batches_sorted_eq!(expected, &batches); + // Inner join output is expected to preserve both inputs order + assert_batches_eq!(expected, &batches); Ok(()) } @@ -2237,12 +2408,14 @@ mod tests { "+----+----+-----+", "| a2 | b2 | c2 |", "+----+----+-----+", - "| 10 | 10 | 100 |", - "| 12 | 10 | 40 |", "| 8 | 8 | 20 |", + "| 12 | 10 | 40 |", + "| 10 | 10 | 100 |", "+----+----+-----+", ]; - assert_batches_sorted_eq!(expected, &batches); + + // RightSemi join output is expected to preserve right input order + assert_batches_eq!(expected, &batches); Ok(()) } @@ -2298,12 +2471,14 @@ mod tests { "+----+----+-----+", "| a2 | b2 | c2 |", "+----+----+-----+", - "| 10 | 10 | 100 |", - "| 12 | 10 | 40 |", "| 8 | 8 | 20 |", + "| 12 | 10 | 40 |", + "| 10 | 10 | 100 |", 
"+----+----+-----+", ]; - assert_batches_sorted_eq!(expected, &batches); + + // RightSemi join output is expected to preserve right input order + assert_batches_eq!(expected, &batches); // left_table right semi join right_table on left_table.b1 = right_table.b2 on left_table.a1!=9 let filter_expression = Arc::new(BinaryExpr::new( @@ -2324,11 +2499,13 @@ mod tests { "+----+----+-----+", "| a2 | b2 | c2 |", "+----+----+-----+", - "| 10 | 10 | 100 |", "| 12 | 10 | 40 |", + "| 10 | 10 | 100 |", "+----+----+-----+", ]; - assert_batches_sorted_eq!(expected, &batches); + + // RightSemi join output is expected to preserve right input order + assert_batches_eq!(expected, &batches); Ok(()) } @@ -2484,12 +2661,14 @@ mod tests { "+----+----+-----+", "| a2 | b2 | c2 |", "+----+----+-----+", + "| 6 | 6 | 60 |", "| 2 | 2 | 80 |", "| 4 | 4 | 120 |", - "| 6 | 6 | 60 |", "+----+----+-----+", ]; - assert_batches_sorted_eq!(expected, &batches); + + // RightAnti join output is expected to preserve right input order + assert_batches_eq!(expected, &batches); Ok(()) } @@ -2543,14 +2722,16 @@ mod tests { "+----+----+-----+", "| a2 | b2 | c2 |", "+----+----+-----+", - "| 10 | 10 | 100 |", "| 12 | 10 | 40 |", + "| 6 | 6 | 60 |", "| 2 | 2 | 80 |", + "| 10 | 10 | 100 |", "| 4 | 4 | 120 |", - "| 6 | 6 | 60 |", "+----+----+-----+", ]; - assert_batches_sorted_eq!(expected, &batches); + + // RightAnti join output is expected to preserve right input order + assert_batches_eq!(expected, &batches); // left_table right anti join right_table on left_table.b1 = right_table.b2 and right_table.b2!=8 let column_indices = vec![ColumnIndex { @@ -2579,13 +2760,15 @@ mod tests { "+----+----+-----+", "| a2 | b2 | c2 |", "+----+----+-----+", + "| 8 | 8 | 20 |", + "| 6 | 6 | 60 |", "| 2 | 2 | 80 |", "| 4 | 4 | 120 |", - "| 6 | 6 | 60 |", - "| 8 | 8 | 20 |", "+----+----+-----+", ]; - assert_batches_sorted_eq!(expected, &batches); + + // RightAnti join output is expected to preserve right input order + assert_batches_eq!(expected, &batches); Ok(()) } @@ -2737,16 +2920,11 @@ mod tests { ("c", &vec![30, 40]), ); - let left_data = ( - JoinHashMap { - map: hashmap_left, - next, - }, - left, - ); + let join_hash_map = JoinHashMap::new(hashmap_left, next); + let (l, r) = build_equal_condition_join_indices( - &left_data.0, - &left_data.1, + &join_hash_map, + &left, &right, &[Column::new("a", 0)], &[Column::new("a", 0)], @@ -2756,8 +2934,7 @@ mod tests { None, JoinSide::Left, None, - 64, - &mut HashJoinOutputState::default(), + false, )?; let mut left_ids = UInt64Builder::with_capacity(0); diff --git a/datafusion/physical-plan/src/joins/mod.rs b/datafusion/physical-plan/src/joins/mod.rs index 19f10d06e1ef8..6ddf19c511933 100644 --- a/datafusion/physical-plan/src/joins/mod.rs +++ b/datafusion/physical-plan/src/joins/mod.rs @@ -25,9 +25,9 @@ pub use sort_merge_join::SortMergeJoinExec; pub use symmetric_hash_join::SymmetricHashJoinExec; mod cross_join; mod hash_join; -mod hash_join_utils; mod nested_loop_join; mod sort_merge_join; +mod stream_join_utils; mod symmetric_hash_join; pub mod utils; diff --git a/datafusion/physical-plan/src/joins/hash_join_utils.rs b/datafusion/physical-plan/src/joins/stream_join_utils.rs similarity index 61% rename from datafusion/physical-plan/src/joins/hash_join_utils.rs rename to datafusion/physical-plan/src/joins/stream_join_utils.rs index 3a2a85c727226..9a4c98927683d 100644 --- a/datafusion/physical-plan/src/joins/hash_join_utils.rs +++ b/datafusion/physical-plan/src/joins/stream_join_utils.rs @@ -15,137 +15,39 @@ // 
specific language governing permissions and limitations // under the License. -//! This file contains common subroutines for regular and symmetric hash join +//! This file contains common subroutines for symmetric hash join //! related functionality, used both in join calculations and optimization rules. use std::collections::{HashMap, VecDeque}; -use std::fmt::Debug; -use std::ops::IndexMut; use std::sync::Arc; -use std::{fmt, usize}; +use std::task::{Context, Poll}; +use std::usize; -use crate::joins::utils::JoinFilter; +use crate::joins::utils::{JoinFilter, JoinHashMapType, StatefulStreamResult}; +use crate::metrics::{ExecutionPlanMetricsSet, MetricBuilder}; +use crate::{handle_async_state, handle_state, metrics, ExecutionPlan}; use arrow::compute::concat_batches; -use arrow::datatypes::{ArrowNativeType, SchemaRef}; -use arrow_array::builder::BooleanBufferBuilder; use arrow_array::{ArrowPrimitiveType, NativeAdapter, PrimitiveArray, RecordBatch}; +use arrow_buffer::{ArrowNativeType, BooleanBufferBuilder}; +use arrow_schema::{Schema, SchemaRef}; use datafusion_common::tree_node::{Transformed, TreeNode}; -use datafusion_common::{DataFusionError, JoinSide, Result, ScalarValue}; +use datafusion_common::{ + arrow_datafusion_err, plan_datafusion_err, DataFusionError, JoinSide, Result, + ScalarValue, +}; +use datafusion_execution::SendableRecordBatchStream; +use datafusion_expr::interval_arithmetic::Interval; use datafusion_physical_expr::expressions::Column; -use datafusion_physical_expr::intervals::{Interval, IntervalBound}; +use datafusion_physical_expr::intervals::cp_solver::ExprIntervalGraph; use datafusion_physical_expr::utils::collect_columns; use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr}; +use async_trait::async_trait; +use futures::{ready, FutureExt, StreamExt}; use hashbrown::raw::RawTable; use hashbrown::HashSet; -// Maps a `u64` hash value based on the build side ["on" values] to a list of indices with this key's value. -// By allocating a `HashMap` with capacity for *at least* the number of rows for entries at the build side, -// we make sure that we don't have to re-hash the hashmap, which needs access to the key (the hash in this case) value. -// E.g. 1 -> [3, 6, 8] indicates that the column values map to rows 3, 6 and 8 for hash value 1 -// As the key is a hash value, we need to check possible hash collisions in the probe stage -// During this stage it might be the case that a row is contained the same hashmap value, -// but the values don't match. Those are checked in the [equal_rows] macro -// The indices (values) are stored in a separate chained list stored in the `Vec`. -// The first value (+1) is stored in the hashmap, whereas the next value is stored in array at the position value. -// The chain can be followed until the value "0" has been reached, meaning the end of the list. 
-// Also see chapter 5.3 of [Balancing vectorized query execution with bandwidth-optimized storage](https://dare.uva.nl/search?identifier=5ccbb60a-38b8-4eeb-858a-e7735dd37487) -// See the example below: -// Insert (1,1) -// map: -// --------- -// | 1 | 2 | -// --------- -// next: -// --------------------- -// | 0 | 0 | 0 | 0 | 0 | -// --------------------- -// Insert (2,2) -// map: -// --------- -// | 1 | 2 | -// | 2 | 3 | -// --------- -// next: -// --------------------- -// | 0 | 0 | 0 | 0 | 0 | -// --------------------- -// Insert (1,3) -// map: -// --------- -// | 1 | 4 | -// | 2 | 3 | -// --------- -// next: -// --------------------- -// | 0 | 0 | 0 | 2 | 0 | <--- hash value 1 maps to 4,2 (which means indices values 3,1) -// --------------------- -// Insert (1,4) -// map: -// --------- -// | 1 | 5 | -// | 2 | 3 | -// --------- -// next: -// --------------------- -// | 0 | 0 | 0 | 2 | 4 | <--- hash value 1 maps to 5,4,2 (which means indices values 4,3,1) -// --------------------- -// TODO: speed up collision checks -// https://github.com/apache/arrow-datafusion/issues/50 -pub struct JoinHashMap { - // Stores hash value to last row index - pub map: RawTable<(u64, u64)>, - // Stores indices in chained list data structure - pub next: Vec, -} - -impl JoinHashMap { - pub(crate) fn with_capacity(capacity: usize) -> Self { - JoinHashMap { - map: RawTable::with_capacity(capacity), - next: vec![0; capacity], - } - } -} - -/// Trait defining methods that must be implemented by a hash map type to be used for joins. -pub trait JoinHashMapType { - /// The type of list used to store the hash values. - type NextType: IndexMut; - /// Extend with zero - fn extend_zero(&mut self, len: usize); - /// Returns mutable references to the hash map and the next. - fn get_mut(&mut self) -> (&mut RawTable<(u64, u64)>, &mut Self::NextType); - /// Returns a reference to the hash map. - fn get_map(&self) -> &RawTable<(u64, u64)>; - /// Returns a reference to the next. - fn get_list(&self) -> &Self::NextType; -} - -/// Implementation of `JoinHashMapType` for `JoinHashMap`. -impl JoinHashMapType for JoinHashMap { - type NextType = Vec; - - // Void implementation - fn extend_zero(&mut self, _: usize) {} - - /// Get mutable references to the hash map and the next. - fn get_mut(&mut self) -> (&mut RawTable<(u64, u64)>, &mut Self::NextType) { - (&mut self.map, &mut self.next) - } - - /// Get a reference to the hash map. - fn get_map(&self) -> &RawTable<(u64, u64)> { - &self.map - } - - /// Get a reference to the next. - fn get_list(&self) -> &Self::NextType { - &self.next - } -} - /// Implementation of `JoinHashMapType` for `PruningJoinHashMap`. impl JoinHashMapType for PruningJoinHashMap { type NextType = VecDeque; @@ -171,12 +73,6 @@ impl JoinHashMapType for PruningJoinHashMap { } } -impl fmt::Debug for JoinHashMap { - fn fmt(&self, _f: &mut fmt::Formatter) -> fmt::Result { - Ok(()) - } -} - /// The `PruningJoinHashMap` is similar to a regular `JoinHashMap`, but with /// the capability of pruning elements in an efficient manner. This structure /// is particularly useful for cases where it's necessary to remove elements @@ -188,15 +84,15 @@ impl fmt::Debug for JoinHashMap { /// Let's continue the example of `JoinHashMap` and then show how `PruningJoinHashMap` would /// handle the pruning scenario. 
/// -/// Insert the pair (1,4) into the `PruningJoinHashMap`: +/// Insert the pair (10,4) into the `PruningJoinHashMap`: /// map: -/// --------- -/// | 1 | 5 | -/// | 2 | 3 | -/// --------- +/// ---------- +/// | 10 | 5 | +/// | 20 | 3 | +/// ---------- /// list: /// --------------------- -/// | 0 | 0 | 0 | 2 | 4 | <--- hash value 1 maps to 5,4,2 (which means indices values 4,3,1) +/// | 0 | 0 | 0 | 2 | 4 | <--- hash value 10 maps to 5,4,2 (which means indices values 4,3,1) /// --------------------- /// /// Now, let's prune 3 rows from `PruningJoinHashMap`: @@ -206,7 +102,7 @@ impl fmt::Debug for JoinHashMap { /// --------- /// list: /// --------- -/// | 2 | 4 | <--- hash value 1 maps to 2 (5 - 3), 1 (4 - 3), NA (2 - 3) (which means indices values 1,0) +/// | 2 | 4 | <--- hash value 10 maps to 2 (5 - 3), 1 (4 - 3), NA (2 - 3) (which means indices values 1,0) /// --------- /// /// After pruning, the | 2 | 3 | entry is deleted from `PruningJoinHashMap` since @@ -281,7 +177,7 @@ impl PruningJoinHashMap { prune_length: usize, deleting_offset: u64, shrink_factor: usize, - ) -> Result<()> { + ) { // Remove elements from the list based on the pruning length. self.next.drain(0..prune_length); @@ -304,7 +200,6 @@ impl PruningJoinHashMap { // Shrink the map if necessary. self.shrink_if_necessary(shrink_factor); - Ok(()) } } @@ -333,7 +228,7 @@ pub fn map_origin_col_to_filter_col( side: &JoinSide, ) -> Result> { let filter_schema = filter.schema(); - let mut col_to_col_map: HashMap = HashMap::new(); + let mut col_to_col_map = HashMap::::new(); for (filter_schema_index, index) in filter.column_indices().iter().enumerate() { if index.side.eq(side) { // Get the main field from column index: @@ -425,7 +320,11 @@ pub fn build_filter_input_order( order: &PhysicalSortExpr, ) -> Result> { let opt_expr = convert_sort_expr_with_filter_schema(&side, filter, schema, order)?; - Ok(opt_expr.map(|filter_expr| SortedFilterExpr::new(order.clone(), filter_expr))) + opt_expr + .map(|filter_expr| { + SortedFilterExpr::try_new(order.clone(), filter_expr, filter.schema()) + }) + .transpose() } /// Convert a physical expression into a filter expression using the given @@ -468,16 +367,18 @@ pub struct SortedFilterExpr { impl SortedFilterExpr { /// Constructor - pub fn new( + pub fn try_new( origin_sorted_expr: PhysicalSortExpr, filter_expr: Arc, - ) -> Self { - Self { + filter_schema: &Schema, + ) -> Result { + let dt = &filter_expr.data_type(filter_schema)?; + Ok(Self { origin_sorted_expr, filter_expr, - interval: Interval::default(), + interval: Interval::make_unbounded(dt)?, node_index: 0, - } + }) } /// Get origin expr information pub fn origin_sorted_expr(&self) -> &PhysicalSortExpr { @@ -599,16 +500,16 @@ pub fn update_filter_expr_interval( .origin_sorted_expr() .expr .evaluate(batch)? - .into_array(1); + .into_array(1)?; // Convert the array to a ScalarValue: let value = ScalarValue::try_from_array(&array, 0)?; // Create a ScalarValue representing positive or negative infinity for the same data type: - let unbounded = IntervalBound::make_unbounded(value.data_type())?; + let inf = ScalarValue::try_from(value.data_type())?; // Update the interval with lower and upper bounds based on the sort option: let interval = if sorted_expr.origin_sorted_expr().options.descending { - Interval::new(unbounded, IntervalBound::new(value, false)) + Interval::try_new(inf, value)? } else { - Interval::new(IntervalBound::new(value, false), unbounded) + Interval::try_new(value, inf)? 
}; // Set the calculated interval for the sorted filter expression: sorted_expr.set_interval(interval); @@ -681,7 +582,7 @@ where // get the semi index (0..prune_length) .filter_map(|idx| (bitmap.get_bit(idx)).then_some(T::Native::from_usize(idx))) - .collect::>() + .collect() } pub fn combine_two_batches( @@ -697,7 +598,7 @@ pub fn combine_two_batches( (Some(left_batch), Some(right_batch)) => { // If both batches are present, concatenate them: concat_batches(output_schema, &[left_batch, right_batch]) - .map_err(DataFusionError::ArrowError) + .map_err(|e| arrow_datafusion_err!(e)) .map(Some) } (None, None) => { @@ -726,68 +627,516 @@ pub fn record_visited_indices( } } +/// Represents the various states of an eager join stream operation. +/// +/// This enum is used to track the current state of streaming during a join +/// operation. It provides indicators as to which side of the join needs to be +/// pulled next or if one (or both) sides have been exhausted. This allows +/// for efficient management of resources and optimal performance during the +/// join process. +#[derive(Clone, Debug)] +pub enum EagerJoinStreamState { + /// Indicates that the next step should pull from the right side of the join. + PullRight, + + /// Indicates that the next step should pull from the left side of the join. + PullLeft, + + /// State representing that the right side of the join has been fully processed. + RightExhausted, + + /// State representing that the left side of the join has been fully processed. + LeftExhausted, + + /// Represents a state where both sides of the join are exhausted. + /// + /// The `final_result` field indicates whether the join operation has + /// produced a final result or not. + BothExhausted { final_result: bool }, +} + +/// `EagerJoinStream` is an asynchronous trait designed for managing incremental +/// join operations between two streams, such as those used in `SymmetricHashJoinExec` +/// and `SortMergeJoinExec`. Unlike traditional join approaches that need to scan +/// one side of the join fully before proceeding, `EagerJoinStream` facilitates +/// more dynamic join operations by working with streams as they emit data. This +/// approach allows for more efficient processing, particularly in scenarios +/// where waiting for complete data materialization is not feasible or optimal. +/// The trait provides a framework for handling various states of such a join +/// process, ensuring that join logic is efficiently executed as data becomes +/// available from either stream. +/// +/// Implementors of this trait can perform eager joins of data from two different +/// asynchronous streams, typically referred to as left and right streams. The +/// trait provides a comprehensive set of methods to control and execute the join +/// process, leveraging the states defined in `EagerJoinStreamState`. Methods are +/// primarily focused on asynchronously fetching data batches from each stream, +/// processing them, and managing transitions between various states of the join. +/// +/// This trait's default implementations use a state machine approach to navigate +/// different stages of the join operation, handling data from both streams and +/// determining when the join completes. +/// +/// State Transitions: +/// - From `PullLeft` to `PullRight` or `LeftExhausted`: +/// - In `fetch_next_from_left_stream`, when fetching a batch from the left stream: +/// - On success (`Some(Ok(batch))`), state transitions to `PullRight` for +/// processing the batch. 
+/// - On error (`Some(Err(e))`), the error is returned, and the state remains +/// unchanged. +/// - On no data (`None`), state changes to `LeftExhausted`, returning `Continue` +/// to proceed with the join process. +/// - From `PullRight` to `PullLeft` or `RightExhausted`: +/// - In `fetch_next_from_right_stream`, when fetching from the right stream: +/// - If a batch is available, state changes to `PullLeft` for processing. +/// - On error, the error is returned without changing the state. +/// - If right stream is exhausted (`None`), state transitions to `RightExhausted`, +/// with a `Continue` result. +/// - Handling `RightExhausted` and `LeftExhausted`: +/// - Methods `handle_right_stream_end` and `handle_left_stream_end` manage scenarios +/// when streams are exhausted: +/// - They attempt to continue processing with the other stream. +/// - If both streams are exhausted, state changes to `BothExhausted { final_result: false }`. +/// - Transition to `BothExhausted { final_result: true }`: +/// - Occurs in `prepare_for_final_results_after_exhaustion` when both streams are +/// exhausted, indicating completion of processing and availability of final results. +#[async_trait] +pub trait EagerJoinStream { + /// Implements the main polling logic for the join stream. + /// + /// This method continuously checks the state of the join stream and + /// acts accordingly by delegating the handling to appropriate sub-methods + /// depending on the current state. + /// + /// # Arguments + /// + /// * `cx` - A context that facilitates cooperative non-blocking execution within a task. + /// + /// # Returns + /// + /// * `Poll>>` - A polled result, either a `RecordBatch` or None. + fn poll_next_impl( + &mut self, + cx: &mut Context<'_>, + ) -> Poll>> + where + Self: Send, + { + loop { + return match self.state() { + EagerJoinStreamState::PullRight => { + handle_async_state!(self.fetch_next_from_right_stream(), cx) + } + EagerJoinStreamState::PullLeft => { + handle_async_state!(self.fetch_next_from_left_stream(), cx) + } + EagerJoinStreamState::RightExhausted => { + handle_async_state!(self.handle_right_stream_end(), cx) + } + EagerJoinStreamState::LeftExhausted => { + handle_async_state!(self.handle_left_stream_end(), cx) + } + EagerJoinStreamState::BothExhausted { + final_result: false, + } => { + handle_state!(self.prepare_for_final_results_after_exhaustion()) + } + EagerJoinStreamState::BothExhausted { final_result: true } => { + Poll::Ready(None) + } + }; + } + } + /// Asynchronously pulls the next batch from the right stream. + /// + /// This default implementation checks for the next value in the right stream. + /// If a batch is found, the state is switched to `PullLeft`, and the batch handling + /// is delegated to `process_batch_from_right`. If the stream ends, the state is set to `RightExhausted`. + /// + /// # Returns + /// + /// * `Result>>` - The state result after pulling the batch. + async fn fetch_next_from_right_stream( + &mut self, + ) -> Result>> { + match self.right_stream().next().await { + Some(Ok(batch)) => { + if batch.num_rows() == 0 { + return Ok(StatefulStreamResult::Continue); + } + self.set_state(EagerJoinStreamState::PullLeft); + self.process_batch_from_right(batch) + } + Some(Err(e)) => Err(e), + None => { + self.set_state(EagerJoinStreamState::RightExhausted); + Ok(StatefulStreamResult::Continue) + } + } + } + + /// Asynchronously pulls the next batch from the left stream. + /// + /// This default implementation checks for the next value in the left stream. 
+ /// If a batch is found, the state is switched to `PullRight`, and the batch handling + /// is delegated to `process_batch_from_left`. If the stream ends, the state is set to `LeftExhausted`. + /// + /// # Returns + /// + /// * `Result>>` - The state result after pulling the batch. + async fn fetch_next_from_left_stream( + &mut self, + ) -> Result>> { + match self.left_stream().next().await { + Some(Ok(batch)) => { + if batch.num_rows() == 0 { + return Ok(StatefulStreamResult::Continue); + } + self.set_state(EagerJoinStreamState::PullRight); + self.process_batch_from_left(batch) + } + Some(Err(e)) => Err(e), + None => { + self.set_state(EagerJoinStreamState::LeftExhausted); + Ok(StatefulStreamResult::Continue) + } + } + } + + /// Asynchronously handles the scenario when the right stream is exhausted. + /// + /// In this default implementation, when the right stream is exhausted, it attempts + /// to pull from the left stream. If a batch is found in the left stream, it delegates + /// the handling to `process_batch_from_left`. If both streams are exhausted, the state is set + /// to indicate both streams are exhausted without final results yet. + /// + /// # Returns + /// + /// * `Result>>` - The state result after checking the exhaustion state. + async fn handle_right_stream_end( + &mut self, + ) -> Result>> { + match self.left_stream().next().await { + Some(Ok(batch)) => { + if batch.num_rows() == 0 { + return Ok(StatefulStreamResult::Continue); + } + self.process_batch_after_right_end(batch) + } + Some(Err(e)) => Err(e), + None => { + self.set_state(EagerJoinStreamState::BothExhausted { + final_result: false, + }); + Ok(StatefulStreamResult::Continue) + } + } + } + + /// Asynchronously handles the scenario when the left stream is exhausted. + /// + /// When the left stream is exhausted, this default + /// implementation tries to pull from the right stream and delegates the batch + /// handling to `process_batch_after_left_end`. If both streams are exhausted, the state + /// is updated to indicate so. + /// + /// # Returns + /// + /// * `Result>>` - The state result after checking the exhaustion state. + async fn handle_left_stream_end( + &mut self, + ) -> Result>> { + match self.right_stream().next().await { + Some(Ok(batch)) => { + if batch.num_rows() == 0 { + return Ok(StatefulStreamResult::Continue); + } + self.process_batch_after_left_end(batch) + } + Some(Err(e)) => Err(e), + None => { + self.set_state(EagerJoinStreamState::BothExhausted { + final_result: false, + }); + Ok(StatefulStreamResult::Continue) + } + } + } + + /// Handles the state when both streams are exhausted and final results are yet to be produced. + /// + /// This default implementation switches the state to indicate both streams are + /// exhausted with final results and then invokes the handling for this specific + /// scenario via `process_batches_before_finalization`. + /// + /// # Returns + /// + /// * `Result>>` - The state result after both streams are exhausted. + fn prepare_for_final_results_after_exhaustion( + &mut self, + ) -> Result>> { + self.set_state(EagerJoinStreamState::BothExhausted { final_result: true }); + self.process_batches_before_finalization() + } + + /// Handles a pulled batch from the right stream. + /// + /// # Arguments + /// + /// * `batch` - The pulled `RecordBatch` from the right stream. + /// + /// # Returns + /// + /// * `Result>>` - The state result after processing the batch. 
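The default methods above form a small state machine. As a reading aid, here is a minimal, std-only sketch of the same `PullRight`/`PullLeft`/exhaustion cycle over two plain iterators; `ToyState`, `drive`, and the `Vec<i32>` "batches" are illustrative stand-ins, not the DataFusion trait or its types, and there is no async machinery involved.

```rust
// Hedged sketch of the alternating pull pattern described in the trait docs.
enum ToyState {
    PullRight,
    PullLeft,
    RightExhausted,
    LeftExhausted,
    BothExhausted { final_result: bool },
}

fn drive(
    mut left: impl Iterator<Item = Vec<i32>>,
    mut right: impl Iterator<Item = Vec<i32>>,
) -> Vec<(char, Vec<i32>)> {
    let mut state = ToyState::PullRight;
    let mut out = Vec::new();
    loop {
        state = match state {
            ToyState::PullRight => match right.next() {
                Some(batch) => {
                    out.push(('R', batch));
                    ToyState::PullLeft
                }
                None => ToyState::RightExhausted,
            },
            ToyState::PullLeft => match left.next() {
                Some(batch) => {
                    out.push(('L', batch));
                    ToyState::PullRight
                }
                None => ToyState::LeftExhausted,
            },
            // One side is done: keep draining the other side.
            ToyState::RightExhausted => match left.next() {
                Some(batch) => {
                    out.push(('L', batch));
                    ToyState::RightExhausted
                }
                None => ToyState::BothExhausted { final_result: false },
            },
            ToyState::LeftExhausted => match right.next() {
                Some(batch) => {
                    out.push(('R', batch));
                    ToyState::LeftExhausted
                }
                None => ToyState::BothExhausted { final_result: false },
            },
            // Emit the "final results" step exactly once, then stop.
            ToyState::BothExhausted { final_result: false } => {
                out.push(('F', vec![]));
                ToyState::BothExhausted { final_result: true }
            }
            ToyState::BothExhausted { final_result: true } => return out,
        };
    }
}

fn main() {
    let left = vec![vec![1, 2], vec![3]].into_iter();
    let right = vec![vec![10]].into_iter();
    // Alternates R, L, drains the remaining left side, then finalizes once.
    println!("{:?}", drive(left, right));
}
```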
+ fn process_batch_from_right( + &mut self, + batch: RecordBatch, + ) -> Result>>; + + /// Handles a pulled batch from the left stream. + /// + /// # Arguments + /// + /// * `batch` - The pulled `RecordBatch` from the left stream. + /// + /// # Returns + /// + /// * `Result>>` - The state result after processing the batch. + fn process_batch_from_left( + &mut self, + batch: RecordBatch, + ) -> Result>>; + + /// Handles the situation when only the left stream is exhausted. + /// + /// # Arguments + /// + /// * `right_batch` - The `RecordBatch` from the right stream. + /// + /// # Returns + /// + /// * `Result>>` - The state result after the left stream is exhausted. + fn process_batch_after_left_end( + &mut self, + right_batch: RecordBatch, + ) -> Result>>; + + /// Handles the situation when only the right stream is exhausted. + /// + /// # Arguments + /// + /// * `left_batch` - The `RecordBatch` from the left stream. + /// + /// # Returns + /// + /// * `Result>>` - The state result after the right stream is exhausted. + fn process_batch_after_right_end( + &mut self, + left_batch: RecordBatch, + ) -> Result>>; + + /// Handles the final state after both streams are exhausted. + /// + /// # Returns + /// + /// * `Result>>` - The final state result after processing. + fn process_batches_before_finalization( + &mut self, + ) -> Result>>; + + /// Provides mutable access to the right stream. + /// + /// # Returns + /// + /// * `&mut SendableRecordBatchStream` - Returns a mutable reference to the right stream. + fn right_stream(&mut self) -> &mut SendableRecordBatchStream; + + /// Provides mutable access to the left stream. + /// + /// # Returns + /// + /// * `&mut SendableRecordBatchStream` - Returns a mutable reference to the left stream. + fn left_stream(&mut self) -> &mut SendableRecordBatchStream; + + /// Sets the current state of the join stream. + /// + /// # Arguments + /// + /// * `state` - The new state to be set. + fn set_state(&mut self, state: EagerJoinStreamState); + + /// Fetches the current state of the join stream. + /// + /// # Returns + /// + /// * `EagerJoinStreamState` - The current state of the join stream. 
+ fn state(&mut self) -> EagerJoinStreamState; +} + +#[derive(Debug)] +pub struct StreamJoinSideMetrics { + /// Number of batches consumed by this operator + pub(crate) input_batches: metrics::Count, + /// Number of rows consumed by this operator + pub(crate) input_rows: metrics::Count, +} + +/// Metrics for HashJoinExec +#[derive(Debug)] +pub struct StreamJoinMetrics { + /// Number of left batches/rows consumed by this operator + pub(crate) left: StreamJoinSideMetrics, + /// Number of right batches/rows consumed by this operator + pub(crate) right: StreamJoinSideMetrics, + /// Memory used by sides in bytes + pub(crate) stream_memory_usage: metrics::Gauge, + /// Number of batches produced by this operator + pub(crate) output_batches: metrics::Count, + /// Number of rows produced by this operator + pub(crate) output_rows: metrics::Count, +} + +impl StreamJoinMetrics { + pub fn new(partition: usize, metrics: &ExecutionPlanMetricsSet) -> Self { + let input_batches = + MetricBuilder::new(metrics).counter("input_batches", partition); + let input_rows = MetricBuilder::new(metrics).counter("input_rows", partition); + let left = StreamJoinSideMetrics { + input_batches, + input_rows, + }; + + let input_batches = + MetricBuilder::new(metrics).counter("input_batches", partition); + let input_rows = MetricBuilder::new(metrics).counter("input_rows", partition); + let right = StreamJoinSideMetrics { + input_batches, + input_rows, + }; + + let stream_memory_usage = + MetricBuilder::new(metrics).gauge("stream_memory_usage", partition); + + let output_batches = + MetricBuilder::new(metrics).counter("output_batches", partition); + + let output_rows = MetricBuilder::new(metrics).output_rows(partition); + + Self { + left, + right, + output_batches, + stream_memory_usage, + output_rows, + } + } +} + +/// Updates sorted filter expressions with corresponding node indices from the +/// expression interval graph. +/// +/// This function iterates through the provided sorted filter expressions, +/// gathers the corresponding node indices from the expression interval graph, +/// and then updates the sorted expressions with these indices. It ensures +/// that these sorted expressions are aligned with the structure of the graph. +fn update_sorted_exprs_with_node_indices( + graph: &mut ExprIntervalGraph, + sorted_exprs: &mut [SortedFilterExpr], +) { + // Extract filter expressions from the sorted expressions: + let filter_exprs = sorted_exprs + .iter() + .map(|expr| expr.filter_expr().clone()) + .collect::>(); + + // Gather corresponding node indices for the extracted filter expressions from the graph: + let child_node_indices = graph.gather_node_indices(&filter_exprs); + + // Iterate through the sorted expressions and the gathered node indices: + for (sorted_expr, (_, index)) in sorted_exprs.iter_mut().zip(child_node_indices) { + // Update each sorted expression with the corresponding node index: + sorted_expr.set_node_index(index); + } +} + +/// Prepares and sorts expressions based on a given filter, left and right execution plans, and sort expressions. +/// +/// # Arguments +/// +/// * `filter` - The join filter to base the sorting on. +/// * `left` - The left execution plan. +/// * `right` - The right execution plan. +/// * `left_sort_exprs` - The expressions to sort on the left side. +/// * `right_sort_exprs` - The expressions to sort on the right side. +/// +/// # Returns +/// +/// * A tuple consisting of the sorted filter expression for the left and right sides, and an expression interval graph. 
+pub fn prepare_sorted_exprs( + filter: &JoinFilter, + left: &Arc, + right: &Arc, + left_sort_exprs: &[PhysicalSortExpr], + right_sort_exprs: &[PhysicalSortExpr], +) -> Result<(SortedFilterExpr, SortedFilterExpr, ExprIntervalGraph)> { + // Build the filter order for the left side + let err = || plan_datafusion_err!("Filter does not include the child order"); + + let left_temp_sorted_filter_expr = build_filter_input_order( + JoinSide::Left, + filter, + &left.schema(), + &left_sort_exprs[0], + )? + .ok_or_else(err)?; + + // Build the filter order for the right side + let right_temp_sorted_filter_expr = build_filter_input_order( + JoinSide::Right, + filter, + &right.schema(), + &right_sort_exprs[0], + )? + .ok_or_else(err)?; + + // Collect the sorted expressions + let mut sorted_exprs = + vec![left_temp_sorted_filter_expr, right_temp_sorted_filter_expr]; + + // Build the expression interval graph + let mut graph = + ExprIntervalGraph::try_new(filter.expression().clone(), filter.schema())?; + + // Update sorted expressions with node indices + update_sorted_exprs_with_node_indices(&mut graph, &mut sorted_exprs); + + // Swap and remove to get the final sorted filter expressions + let right_sorted_filter_expr = sorted_exprs.swap_remove(1); + let left_sorted_filter_expr = sorted_exprs.swap_remove(0); + + Ok((left_sorted_filter_expr, right_sorted_filter_expr, graph)) +} + #[cfg(test)] pub mod tests { + use std::sync::Arc; + use super::*; + use crate::joins::stream_join_utils::{ + build_filter_input_order, check_filter_expr_contains_sort_information, + convert_sort_expr_with_filter_schema, PruningJoinHashMap, + }; use crate::{ - expressions::Column, - expressions::PhysicalSortExpr, + expressions::{Column, PhysicalSortExpr}, + joins::test_utils::complicated_filter, joins::utils::{ColumnIndex, JoinFilter}, }; + use arrow::compute::SortOptions; use arrow::datatypes::{DataType, Field, Schema}; - use datafusion_common::ScalarValue; + use datafusion_common::JoinSide; use datafusion_expr::Operator; - use datafusion_physical_expr::expressions::{binary, cast, col, lit}; - use std::sync::Arc; - - /// Filter expr for a + b > c + 10 AND a + b < c + 100 - pub(crate) fn complicated_filter( - filter_schema: &Schema, - ) -> Result> { - let left_expr = binary( - cast( - binary( - col("0", filter_schema)?, - Operator::Plus, - col("1", filter_schema)?, - filter_schema, - )?, - filter_schema, - DataType::Int64, - )?, - Operator::Gt, - binary( - cast(col("2", filter_schema)?, filter_schema, DataType::Int64)?, - Operator::Plus, - lit(ScalarValue::Int64(Some(10))), - filter_schema, - )?, - filter_schema, - )?; - - let right_expr = binary( - cast( - binary( - col("0", filter_schema)?, - Operator::Plus, - col("1", filter_schema)?, - filter_schema, - )?, - filter_schema, - DataType::Int64, - )?, - Operator::Lt, - binary( - cast(col("2", filter_schema)?, filter_schema, DataType::Int64)?, - Operator::Plus, - lit(ScalarValue::Int64(Some(100))), - filter_schema, - )?, - filter_schema, - )?; - binary(left_expr, Operator::And, right_expr, filter_schema) - } + use datafusion_physical_expr::expressions::{binary, cast, col}; #[test] fn test_column_exchange() -> Result<()> { diff --git a/datafusion/physical-plan/src/joins/symmetric_hash_join.rs b/datafusion/physical-plan/src/joins/symmetric_hash_join.rs index a7e0877537cf2..cfeb80708e685 100644 --- a/datafusion/physical-plan/src/joins/symmetric_hash_join.rs +++ b/datafusion/physical-plan/src/joins/symmetric_hash_join.rs @@ -25,32 +25,30 @@ //! 
This plan uses the [`OneSideHashJoiner`] object to facilitate join calculations //! for both its children. -use std::fmt; -use std::fmt::Debug; +use std::any::Any; +use std::fmt::{self, Debug}; use std::sync::Arc; use std::task::Poll; -use std::vec; -use std::{any::Any, usize}; +use std::{usize, vec}; use crate::common::SharedMemoryReservation; -use crate::joins::hash_join::{ - build_equal_condition_join_indices, update_hash, HashJoinOutputState, -}; -use crate::joins::hash_join_utils::{ +use crate::joins::hash_join::{build_equal_condition_join_indices, update_hash}; +use crate::joins::stream_join_utils::{ calculate_filter_expr_intervals, combine_two_batches, convert_sort_expr_with_filter_schema, get_pruning_anti_indices, - get_pruning_semi_indices, record_visited_indices, PruningJoinHashMap, - SortedFilterExpr, + get_pruning_semi_indices, prepare_sorted_exprs, record_visited_indices, + EagerJoinStream, EagerJoinStreamState, PruningJoinHashMap, SortedFilterExpr, + StreamJoinMetrics, }; use crate::joins::utils::{ build_batch_from_indices, build_join_schema, check_join_is_valid, - partitioned_join_output_partitioning, prepare_sorted_exprs, ColumnIndex, JoinFilter, - JoinOn, + partitioned_join_output_partitioning, ColumnIndex, JoinFilter, JoinOn, + StatefulStreamResult, }; use crate::{ expressions::{Column, PhysicalSortExpr}, joins::StreamJoinPartitionMode, - metrics::{self, ExecutionPlanMetricsSet, MetricBuilder, MetricsSet}, + metrics::{ExecutionPlanMetricsSet, MetricsSet}, DisplayAs, DisplayFormatType, Distribution, EquivalenceProperties, ExecutionPlan, Partitioning, RecordBatchStream, SendableRecordBatchStream, Statistics, }; @@ -65,12 +63,12 @@ use datafusion_common::{ }; use datafusion_execution::memory_pool::MemoryConsumer; use datafusion_execution::TaskContext; +use datafusion_expr::interval_arithmetic::Interval; use datafusion_physical_expr::equivalence::join_equivalence_properties; -use datafusion_physical_expr::intervals::ExprIntervalGraph; +use datafusion_physical_expr::intervals::cp_solver::ExprIntervalGraph; use ahash::RandomState; -use futures::stream::{select, BoxStream}; -use futures::{Stream, StreamExt}; +use futures::Stream; use hashbrown::HashSet; use parking_lot::Mutex; @@ -187,65 +185,6 @@ pub struct SymmetricHashJoinExec { mode: StreamJoinPartitionMode, } -#[derive(Debug)] -struct SymmetricHashJoinSideMetrics { - /// Number of batches consumed by this operator - input_batches: metrics::Count, - /// Number of rows consumed by this operator - input_rows: metrics::Count, -} - -/// Metrics for HashJoinExec -#[derive(Debug)] -struct SymmetricHashJoinMetrics { - /// Number of left batches/rows consumed by this operator - left: SymmetricHashJoinSideMetrics, - /// Number of right batches/rows consumed by this operator - right: SymmetricHashJoinSideMetrics, - /// Memory used by sides in bytes - pub(crate) stream_memory_usage: metrics::Gauge, - /// Number of batches produced by this operator - output_batches: metrics::Count, - /// Number of rows produced by this operator - output_rows: metrics::Count, -} - -impl SymmetricHashJoinMetrics { - pub fn new(partition: usize, metrics: &ExecutionPlanMetricsSet) -> Self { - let input_batches = - MetricBuilder::new(metrics).counter("input_batches", partition); - let input_rows = MetricBuilder::new(metrics).counter("input_rows", partition); - let left = SymmetricHashJoinSideMetrics { - input_batches, - input_rows, - }; - - let input_batches = - MetricBuilder::new(metrics).counter("input_batches", partition); - let input_rows = 
MetricBuilder::new(metrics).counter("input_rows", partition); - let right = SymmetricHashJoinSideMetrics { - input_batches, - input_rows, - }; - - let stream_memory_usage = - MetricBuilder::new(metrics).gauge("stream_memory_usage", partition); - - let output_batches = - MetricBuilder::new(metrics).counter("output_batches", partition); - - let output_rows = MetricBuilder::new(metrics).output_rows(partition); - - Self { - left, - right, - output_batches, - stream_memory_usage, - output_rows, - } - } -} - impl SymmetricHashJoinExec { /// Tries to create a new [SymmetricHashJoinExec]. /// # Error @@ -327,6 +266,11 @@ impl SymmetricHashJoinExec { self.null_equals_null } + /// Get partition mode + pub fn partition_mode(&self) -> StreamJoinPartitionMode { + self.mode + } + /// Check if order information covers every column in the filter expression. pub fn check_if_order_information_available(&self) -> Result { if let Some(filter) = self.filter() { @@ -513,21 +457,9 @@ impl ExecutionPlan for SymmetricHashJoinExec { let right_side_joiner = OneSideHashJoiner::new(JoinSide::Right, on_right, self.right.schema()); - let left_stream = self - .left - .execute(partition, context.clone())? - .map(|val| (JoinSide::Left, val)); - - let right_stream = self - .right - .execute(partition, context.clone())? - .map(|val| (JoinSide::Right, val)); - // This function will attempt to pull items from both streams. - // Each stream will be polled in a round-robin fashion, and whenever a stream is - // ready to yield an item that item is yielded. - // After one of the two input streams completes, the remaining one will be polled exclusively. - // The returned stream completes when both input streams have completed. - let input_stream = select(left_stream, right_stream).boxed(); + let left_stream = self.left.execute(partition, context.clone())?; + + let right_stream = self.right.execute(partition, context.clone())?; let reservation = Arc::new(Mutex::new( MemoryConsumer::new(format!("SymmetricHashJoinStream[{partition}]")) @@ -538,7 +470,8 @@ impl ExecutionPlan for SymmetricHashJoinExec { } Ok(Box::pin(SymmetricHashJoinStream { - input_stream, + left_stream, + right_stream, schema: self.schema(), filter: self.filter.clone(), join_type: self.join_type, @@ -546,12 +479,12 @@ impl ExecutionPlan for SymmetricHashJoinExec { left: left_side_joiner, right: right_side_joiner, column_indices: self.column_indices.clone(), - metrics: SymmetricHashJoinMetrics::new(partition, &self.metrics), + metrics: StreamJoinMetrics::new(partition, &self.metrics), graph, left_sorted_filter_expr, right_sorted_filter_expr, null_equals_null: self.null_equals_null, - final_result: false, + state: EagerJoinStreamState::PullRight, reservation, output_state: HashJoinOutputState::default(), })) @@ -560,8 +493,9 @@ impl ExecutionPlan for SymmetricHashJoinExec { /// A stream that issues [RecordBatch]es as they arrive from the right of the join. 
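The `execute()` hunk above removes the `futures::stream::select`-based input handling described in the deleted comment. For comparison, this is a small illustration (assuming the `futures` crate, with string items standing in for record batches) of the fair, readiness-driven interleaving that `select` provides and that the new `EagerJoinStream` state machine replaces with an explicit, deterministic alternation.

```rust
// Illustrative only: the merging behaviour of `futures::stream::select`
// that the old `execute()` relied on, shown with plain string streams
// instead of `SendableRecordBatchStream`s.
use futures::executor::block_on;
use futures::stream::{self, StreamExt};

fn main() {
    let left = stream::iter(vec!["L1", "L2", "L3"]);
    let right = stream::iter(vec!["R1"]);

    // `select` polls both inputs fairly and yields items as they become
    // ready; after one side finishes, the other is polled exclusively.
    let merged: Vec<_> = block_on(stream::select(left, right).collect());
    // Items interleave, e.g. ["L1", "R1", "L2", "L3"]; the exact order is an
    // implementation detail, which is part of why the new code avoids it.
    println!("{merged:?}");
}
```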
struct SymmetricHashJoinStream { - /// Input stream - input_stream: BoxStream<'static, (JoinSide, Result)>, + /// Input streams + left_stream: SendableRecordBatchStream, + right_stream: SendableRecordBatchStream, /// Input schema schema: Arc, /// join filter @@ -585,13 +519,11 @@ struct SymmetricHashJoinStream { /// If null_equals_null is true, null == null else null != null null_equals_null: bool, /// Metrics - metrics: SymmetricHashJoinMetrics, + metrics: StreamJoinMetrics, /// Memory reservation reservation: SharedMemoryReservation, - /// Flag indicating whether there is nothing to process anymore - final_result: bool, - /// Stream state for compatibility with HashJoinExec - output_state: HashJoinOutputState, + /// State machine for input execution + state: EagerJoinStreamState, } impl RecordBatchStream for SymmetricHashJoinStream { @@ -626,7 +558,9 @@ impl Stream for SymmetricHashJoinStream { /// # Returns /// /// A [Result] object that contains the pruning length. The function will return -/// an error if there is an issue evaluating the build side filter expression. +/// an error if +/// - there is an issue evaluating the build side filter expression; +/// - there is an issue converting the build side filter expression into an array fn determine_prune_length( buffer: &RecordBatch, build_side_filter_expr: &SortedFilterExpr, @@ -637,13 +571,13 @@ fn determine_prune_length( let batch_arr = origin_sorted_expr .expr .evaluate(buffer)? - .into_array(buffer.num_rows()); + .into_array(buffer.num_rows())?; // Get the lower or upper interval based on the sort direction let target = if origin_sorted_expr.options.descending { - interval.upper.value.clone() + interval.upper().clone() } else { - interval.lower.value.clone() + interval.lower().clone() }; // Perform binary search on the array to determine the length of the record batch to be pruned @@ -761,7 +695,9 @@ pub(crate) fn build_side_determined_results( column_indices: &[ColumnIndex], ) -> Result> { // Check if we need to produce a result in the final output: - if need_to_produce_result_in_final(build_hash_joiner.build_side, join_type) { + if prune_length > 0 + && need_to_produce_result_in_final(build_hash_joiner.build_side, join_type) + { // Calculate the indices for build and probe sides based on join type and build side: let (build_indices, probe_indices) = calculate_indices_by_join_type( build_hash_joiner.build_side, @@ -837,8 +773,7 @@ pub(crate) fn join_with_probe_batch( filter, build_hash_joiner.build_side, Some(build_hash_joiner.deleted_offset), - usize::MAX, - output_state, + false, )?; // Resetting state to avoid potential overflows @@ -955,31 +890,22 @@ impl OneSideHashJoiner { random_state, &mut self.hashes_buffer, self.deleted_offset, + false, )?; Ok(()) } - /// Prunes the internal buffer. - /// - /// Argument `probe_batch` is used to update the intervals of the sorted - /// filter expressions. The updated build interval determines the new length - /// of the build side. If there are rows to prune, they are removed from the - /// internal buffer. + /// Calculate prune length. /// /// # Arguments /// - /// * `schema` - The schema of the final output record batch - /// * `probe_batch` - Incoming RecordBatch of the probe side. + /// * `build_side_sorted_filter_expr` - Build side mutable sorted filter expression.. /// * `probe_side_sorted_filter_expr` - Probe side mutable sorted filter expression. - /// * `join_type` - The type of join (e.g. inner, left, right, etc.). 
- /// * `column_indices` - A vector of column indices that specifies which columns from the - /// build side should be included in the output. /// * `graph` - A mutable reference to the physical expression graph. /// /// # Returns /// - /// If there are rows to prune, returns the pruned build side record batch wrapped in an `Ok` variant. - /// Otherwise, returns `Ok(None)`. + /// A Result object that contains the pruning length. pub(crate) fn calculate_prune_length_with_probe_batch( &mut self, build_side_sorted_filter_expr: &mut SortedFilterExpr, @@ -1000,7 +926,7 @@ impl OneSideHashJoiner { filter_intervals.push((expr.node_index(), expr.interval().clone())) } // Update the physical expression graph using the join filter intervals: - graph.update_ranges(&mut filter_intervals)?; + graph.update_ranges(&mut filter_intervals, Interval::CERTAINLY_TRUE)?; // Extract the new join filter interval for the build side: let calculated_build_side_interval = filter_intervals.remove(0).1; // If the intervals have not changed, return early without pruning: @@ -1019,7 +945,7 @@ impl OneSideHashJoiner { prune_length, self.deleted_offset as u64, HASHMAP_SHRINK_SCALE_FACTOR, - )?; + ); // Remove pruned rows from the visited rows set: for row in self.deleted_offset..(self.deleted_offset + prune_length) { self.visited_rows.remove(&row); @@ -1034,10 +960,104 @@ impl OneSideHashJoiner { } } +impl EagerJoinStream for SymmetricHashJoinStream { + fn process_batch_from_right( + &mut self, + batch: RecordBatch, + ) -> Result>> { + self.perform_join_for_given_side(batch, JoinSide::Right) + .map(|maybe_batch| { + if maybe_batch.is_some() { + StatefulStreamResult::Ready(maybe_batch) + } else { + StatefulStreamResult::Continue + } + }) + } + + fn process_batch_from_left( + &mut self, + batch: RecordBatch, + ) -> Result>> { + self.perform_join_for_given_side(batch, JoinSide::Left) + .map(|maybe_batch| { + if maybe_batch.is_some() { + StatefulStreamResult::Ready(maybe_batch) + } else { + StatefulStreamResult::Continue + } + }) + } + + fn process_batch_after_left_end( + &mut self, + right_batch: RecordBatch, + ) -> Result>> { + self.process_batch_from_right(right_batch) + } + + fn process_batch_after_right_end( + &mut self, + left_batch: RecordBatch, + ) -> Result>> { + self.process_batch_from_left(left_batch) + } + + fn process_batches_before_finalization( + &mut self, + ) -> Result>> { + // Get the left side results: + let left_result = build_side_determined_results( + &self.left, + &self.schema, + self.left.input_buffer.num_rows(), + self.right.input_buffer.schema(), + self.join_type, + &self.column_indices, + )?; + // Get the right side results: + let right_result = build_side_determined_results( + &self.right, + &self.schema, + self.right.input_buffer.num_rows(), + self.left.input_buffer.schema(), + self.join_type, + &self.column_indices, + )?; + + // Combine the left and right results: + let result = combine_two_batches(&self.schema, left_result, right_result)?; + + // Update the metrics and return the result: + if let Some(batch) = &result { + // Update the metrics: + self.metrics.output_batches.add(1); + self.metrics.output_rows.add(batch.num_rows()); + return Ok(StatefulStreamResult::Ready(result)); + } + Ok(StatefulStreamResult::Continue) + } + + fn right_stream(&mut self) -> &mut SendableRecordBatchStream { + &mut self.right_stream + } + + fn left_stream(&mut self) -> &mut SendableRecordBatchStream { + &mut self.left_stream + } + + fn set_state(&mut self, state: EagerJoinStreamState) { + self.state = state; + } + 
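`process_batch_from_right`/`process_batch_from_left` above both delegate to `perform_join_for_given_side`, which selects the probe and build halves of the operator state as two disjoint mutable borrows based on the join side. A minimal, std-only sketch of that borrow pattern on a simplified struct; `JoinSideToy`, `SideState`, `ToyJoin`, and `pick_sides` are hypothetical names, not DataFusion types.

```rust
// Hedged sketch of the (probe, build) reference-swap pattern used by
// `perform_join_for_given_side`.
#[derive(Clone, Copy, PartialEq)]
enum JoinSideToy {
    Left,
    Right,
}

#[derive(Default, Debug)]
struct SideState {
    rows_seen: usize,
}

#[derive(Default, Debug)]
struct ToyJoin {
    left: SideState,
    right: SideState,
}

impl ToyJoin {
    /// Returns (probe, build) as two disjoint mutable borrows, so the caller
    /// can update both halves in one pass without cloning either side.
    fn pick_sides(&mut self, probe_side: JoinSideToy) -> (&mut SideState, &mut SideState) {
        if probe_side == JoinSideToy::Left {
            (&mut self.left, &mut self.right)
        } else {
            (&mut self.right, &mut self.left)
        }
    }

    fn process_batch(&mut self, probe_side: JoinSideToy, batch_rows: usize) {
        let (probe, _build) = self.pick_sides(probe_side);
        probe.rows_seen += batch_rows; // e.g. update the per-side input metrics
    }
}

fn main() {
    let mut join = ToyJoin::default();
    join.process_batch(JoinSideToy::Right, 4);
    join.process_batch(JoinSideToy::Left, 2);
    println!("{join:?}");
}
```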
+ fn state(&mut self) -> EagerJoinStreamState { + self.state.clone() + } +} + impl SymmetricHashJoinStream { fn size(&self) -> usize { let mut size = 0; - size += std::mem::size_of_val(&self.input_stream); size += std::mem::size_of_val(&self.schema); size += std::mem::size_of_val(&self.filter); size += std::mem::size_of_val(&self.join_type); @@ -1050,166 +1070,111 @@ impl SymmetricHashJoinStream { size += std::mem::size_of_val(&self.random_state); size += std::mem::size_of_val(&self.null_equals_null); size += std::mem::size_of_val(&self.metrics); - size += std::mem::size_of_val(&self.final_result); size } - /// Polls the next result of the join operation. - /// - /// If the result of the join is ready, it returns the next record batch. - /// If the join has completed and there are no more results, it returns - /// `Poll::Ready(None)`. If the join operation is not complete, but the - /// current stream is not ready yet, it returns `Poll::Pending`. - fn poll_next_impl( + + /// Performs a join operation for the specified `probe_side` (either left or right). + /// This function: + /// 1. Determines which side is the probe and which is the build side. + /// 2. Updates metrics based on the batch that was polled. + /// 3. Executes the join with the given `probe_batch`. + /// 4. Optionally computes anti-join results if all conditions are met. + /// 5. Combines the results and returns a combined batch or `None` if no batch was produced. + fn perform_join_for_given_side( &mut self, - cx: &mut std::task::Context<'_>, - ) -> Poll>> { - loop { - // Poll the next batch from `input_stream`: - match self.input_stream.poll_next_unpin(cx) { - // Batch is available - Poll::Ready(Some((side, Ok(probe_batch)))) => { - // Determine which stream should be polled next. The side the - // RecordBatch comes from becomes the probe side. 
- let ( - probe_hash_joiner, - build_hash_joiner, - probe_side_sorted_filter_expr, - build_side_sorted_filter_expr, - probe_side_metrics, - ) = if side.eq(&JoinSide::Left) { - ( - &mut self.left, - &mut self.right, - &mut self.left_sorted_filter_expr, - &mut self.right_sorted_filter_expr, - &mut self.metrics.left, - ) - } else { - ( - &mut self.right, - &mut self.left, - &mut self.right_sorted_filter_expr, - &mut self.left_sorted_filter_expr, - &mut self.metrics.right, - ) - }; - // Update the metrics for the stream that was polled: - probe_side_metrics.input_batches.add(1); - probe_side_metrics.input_rows.add(probe_batch.num_rows()); - // Update the internal state of the hash joiner for the build side: - probe_hash_joiner - .update_internal_state(&probe_batch, &self.random_state)?; - // Join the two sides: - let equal_result = join_with_probe_batch( - build_hash_joiner, - probe_hash_joiner, - &self.schema, - self.join_type, - self.filter.as_ref(), - &probe_batch, - &self.column_indices, - &self.random_state, - self.null_equals_null, - &mut self.output_state, - )?; - // Increment the offset for the probe hash joiner: - probe_hash_joiner.offset += probe_batch.num_rows(); - - let anti_result = if let ( - Some(build_side_sorted_filter_expr), - Some(probe_side_sorted_filter_expr), - Some(graph), - ) = ( - build_side_sorted_filter_expr.as_mut(), - probe_side_sorted_filter_expr.as_mut(), - self.graph.as_mut(), - ) { - // Calculate filter intervals: - calculate_filter_expr_intervals( - &build_hash_joiner.input_buffer, - build_side_sorted_filter_expr, - &probe_batch, - probe_side_sorted_filter_expr, - )?; - let prune_length = build_hash_joiner - .calculate_prune_length_with_probe_batch( - build_side_sorted_filter_expr, - probe_side_sorted_filter_expr, - graph, - )?; - - if prune_length > 0 { - let res = build_side_determined_results( - build_hash_joiner, - &self.schema, - prune_length, - probe_batch.schema(), - self.join_type, - &self.column_indices, - )?; - build_hash_joiner.prune_internal_state(prune_length)?; - res - } else { - None - } - } else { - None - }; - - // Combine results: - let result = - combine_two_batches(&self.schema, equal_result, anti_result)?; - let capacity = self.size(); - self.metrics.stream_memory_usage.set(capacity); - self.reservation.lock().try_resize(capacity)?; - // Update the metrics if we have a batch; otherwise, continue the loop. 
- if let Some(batch) = &result { - self.metrics.output_batches.add(1); - self.metrics.output_rows.add(batch.num_rows()); - return Poll::Ready(Ok(result).transpose()); - } - } - Poll::Ready(Some((_, Err(e)))) => return Poll::Ready(Some(Err(e))), - Poll::Ready(None) => { - // If the final result has already been obtained, return `Poll::Ready(None)`: - if self.final_result { - return Poll::Ready(None); - } - self.final_result = true; - // Get the left side results: - let left_result = build_side_determined_results( - &self.left, - &self.schema, - self.left.input_buffer.num_rows(), - self.right.input_buffer.schema(), - self.join_type, - &self.column_indices, - )?; - // Get the right side results: - let right_result = build_side_determined_results( - &self.right, - &self.schema, - self.right.input_buffer.num_rows(), - self.left.input_buffer.schema(), - self.join_type, - &self.column_indices, - )?; - - // Combine the left and right results: - let result = - combine_two_batches(&self.schema, left_result, right_result)?; - - // Update the metrics and return the result: - if let Some(batch) = &result { - // Update the metrics: - self.metrics.output_batches.add(1); - self.metrics.output_rows.add(batch.num_rows()); - return Poll::Ready(Ok(result).transpose()); - } - } - Poll::Pending => return Poll::Pending, - } + probe_batch: RecordBatch, + probe_side: JoinSide, + ) -> Result> { + let ( + probe_hash_joiner, + build_hash_joiner, + probe_side_sorted_filter_expr, + build_side_sorted_filter_expr, + probe_side_metrics, + ) = if probe_side.eq(&JoinSide::Left) { + ( + &mut self.left, + &mut self.right, + &mut self.left_sorted_filter_expr, + &mut self.right_sorted_filter_expr, + &mut self.metrics.left, + ) + } else { + ( + &mut self.right, + &mut self.left, + &mut self.right_sorted_filter_expr, + &mut self.left_sorted_filter_expr, + &mut self.metrics.right, + ) + }; + // Update the metrics for the stream that was polled: + probe_side_metrics.input_batches.add(1); + probe_side_metrics.input_rows.add(probe_batch.num_rows()); + // Update the internal state of the hash joiner for the build side: + probe_hash_joiner.update_internal_state(&probe_batch, &self.random_state)?; + // Join the two sides: + let equal_result = join_with_probe_batch( + build_hash_joiner, + probe_hash_joiner, + &self.schema, + self.join_type, + self.filter.as_ref(), + &probe_batch, + &self.column_indices, + &self.random_state, + self.null_equals_null, + )?; + // Increment the offset for the probe hash joiner: + probe_hash_joiner.offset += probe_batch.num_rows(); + + let anti_result = if let ( + Some(build_side_sorted_filter_expr), + Some(probe_side_sorted_filter_expr), + Some(graph), + ) = ( + build_side_sorted_filter_expr.as_mut(), + probe_side_sorted_filter_expr.as_mut(), + self.graph.as_mut(), + ) { + // Calculate filter intervals: + calculate_filter_expr_intervals( + &build_hash_joiner.input_buffer, + build_side_sorted_filter_expr, + &probe_batch, + probe_side_sorted_filter_expr, + )?; + let prune_length = build_hash_joiner + .calculate_prune_length_with_probe_batch( + build_side_sorted_filter_expr, + probe_side_sorted_filter_expr, + graph, + )?; + let result = build_side_determined_results( + build_hash_joiner, + &self.schema, + prune_length, + probe_batch.schema(), + self.join_type, + &self.column_indices, + )?; + build_hash_joiner.prune_internal_state(prune_length)?; + result + } else { + None + }; + + // Combine results: + let result = combine_two_batches(&self.schema, equal_result, anti_result)?; + let capacity = self.size(); + 
self.metrics.stream_memory_usage.set(capacity); + self.reservation.lock().try_resize(capacity)?; + // Update the metrics if we have a batch; otherwise, continue the loop. + if let Some(batch) = &result { + self.metrics.output_batches.add(1); + self.metrics.output_rows.add(batch.num_rows()); } + Ok(result) } } @@ -1219,10 +1184,9 @@ mod tests { use std::sync::Mutex; use super::*; - use crate::joins::hash_join_utils::tests::complicated_filter; use crate::joins::test_utils::{ - build_sides_record_batches, compare_batches, create_memory_table, - join_expr_tests_fixture_f64, join_expr_tests_fixture_i32, + build_sides_record_batches, compare_batches, complicated_filter, + create_memory_table, join_expr_tests_fixture_f64, join_expr_tests_fixture_i32, join_expr_tests_fixture_temporal, partitioned_hash_join_with_filter, partitioned_sym_join_with_filter, split_record_batches, }; @@ -1849,6 +1813,73 @@ mod tests { Ok(()) } + #[tokio::test(flavor = "multi_thread")] + async fn complex_join_all_one_ascending_equivalence() -> Result<()> { + let cardinality = (3, 4); + let join_type = JoinType::Full; + + // a + b > c + 10 AND a + b < c + 100 + let config = SessionConfig::new().with_repartition_joins(false); + // let session_ctx = SessionContext::with_config(config); + // let task_ctx = session_ctx.task_ctx(); + let task_ctx = Arc::new(TaskContext::default().with_session_config(config)); + let (left_partition, right_partition) = get_or_create_table(cardinality, 8)?; + let left_schema = &left_partition[0].schema(); + let right_schema = &right_partition[0].schema(); + let left_sorted = vec![ + vec![PhysicalSortExpr { + expr: col("la1", left_schema)?, + options: SortOptions::default(), + }], + vec![PhysicalSortExpr { + expr: col("la2", left_schema)?, + options: SortOptions::default(), + }], + ]; + + let right_sorted = vec![PhysicalSortExpr { + expr: col("ra1", right_schema)?, + options: SortOptions::default(), + }]; + + let (left, right) = create_memory_table( + left_partition, + right_partition, + left_sorted, + vec![right_sorted], + )?; + + let on = vec![( + Column::new_with_schema("lc1", left_schema)?, + Column::new_with_schema("rc1", right_schema)?, + )]; + + let intermediate_schema = Schema::new(vec![ + Field::new("0", DataType::Int32, true), + Field::new("1", DataType::Int32, true), + Field::new("2", DataType::Int32, true), + ]); + let filter_expr = complicated_filter(&intermediate_schema)?; + let column_indices = vec![ + ColumnIndex { + index: 0, + side: JoinSide::Left, + }, + ColumnIndex { + index: 4, + side: JoinSide::Left, + }, + ColumnIndex { + index: 0, + side: JoinSide::Right, + }, + ]; + let filter = JoinFilter::new(filter_expr, column_indices, intermediate_schema); + + experiment(left, right, Some(filter), join_type, on, task_ctx).await?; + Ok(()) + } + #[rstest] #[tokio::test(flavor = "multi_thread")] async fn testing_with_temporal_columns( @@ -1868,7 +1899,7 @@ mod tests { (12, 17), )] cardinality: (i32, i32), - #[values(0, 1)] case_expr: usize, + #[values(0, 1, 2)] case_expr: usize, ) -> Result<()> { let session_config = SessionConfig::new().with_repartition_joins(false); let task_ctx = TaskContext::default().with_session_config(session_config); @@ -1933,6 +1964,7 @@ mod tests { experiment(left, right, Some(filter), join_type, on, task_ctx).await?; Ok(()) } + #[rstest] #[tokio::test(flavor = "multi_thread")] async fn test_with_interval_columns( diff --git a/datafusion/physical-plan/src/joins/test_utils.rs b/datafusion/physical-plan/src/joins/test_utils.rs index bb4a86199112e..fbd52ddf0c704 
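`determine_prune_length` earlier in this file turns the updated filter interval into a number of prunable build-side rows by binary-searching the buffered sort column against the interval bound. Below is a std-only sketch of the same idea, with a plain sorted slice standing in for the evaluated column and an integer standing in for `interval.lower()`; `prune_length_ascending` is an illustrative helper, not the DataFusion function.

```rust
// Hedged sketch of the prune-length idea: for an ascending buffer, every
// leading row strictly below the surviving lower bound of the filter
// interval can never match a future probe row and is safe to drop.
fn prune_length_ascending(sorted_build_column: &[i64], lower_bound: i64) -> usize {
    // `partition_point` is a binary search: count of leading elements that
    // are strictly smaller than the bound.
    sorted_build_column.partition_point(|v| *v < lower_bound)
}

fn main() {
    let buffered = [1, 3, 5, 7, 9, 11]; // ascending sort column of the build buffer
    let lower = 6; // surviving lower bound from interval analysis
    let prune = prune_length_ascending(&buffered, lower);
    assert_eq!(prune, 3); // rows with values 1, 3, 5 can be pruned
    println!("prune {prune} of {} buffered rows", buffered.len());
}
```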
100644 --- a/datafusion/physical-plan/src/joins/test_utils.rs +++ b/datafusion/physical-plan/src/joins/test_utils.rs @@ -17,6 +17,9 @@ //! This file has test utils for hash joins +use std::sync::Arc; +use std::usize; + use crate::joins::utils::{JoinFilter, JoinOn}; use crate::joins::{ HashJoinExec, PartitionMode, StreamJoinPartitionMode, SymmetricHashJoinExec, @@ -24,24 +27,24 @@ use crate::joins::{ use crate::memory::MemoryExec; use crate::repartition::RepartitionExec; use crate::{common, ExecutionPlan, Partitioning}; + use arrow::util::pretty::pretty_format_batches; use arrow_array::{ ArrayRef, Float64Array, Int32Array, IntervalDayTimeArray, RecordBatch, TimestampMillisecondArray, }; -use arrow_schema::Schema; -use datafusion_common::Result; -use datafusion_common::ScalarValue; +use arrow_schema::{DataType, Schema}; +use datafusion_common::{Result, ScalarValue}; use datafusion_execution::TaskContext; use datafusion_expr::{JoinType, Operator}; +use datafusion_physical_expr::expressions::{binary, cast, col, lit}; use datafusion_physical_expr::intervals::test_utils::{ gen_conjunctive_numerical_expr, gen_conjunctive_temporal_expr, }; use datafusion_physical_expr::{LexOrdering, PhysicalExpr}; + use rand::prelude::StdRng; use rand::{Rng, SeedableRng}; -use std::sync::Arc; -use std::usize; pub fn compare_batches(collected_1: &[RecordBatch], collected_2: &[RecordBatch]) { // compare @@ -240,6 +243,20 @@ pub fn join_expr_tests_fixture_temporal( ScalarValue::TimestampMillisecond(Some(1672574402000), None), // 2023-01-01:12.00.02 schema, ), + // constructs ((left_col - DURATION '3 secs') > (right_col - DURATION '2 secs')) AND ((left_col - DURATION '5 secs') < (right_col - DURATION '4 secs')) + 2 => gen_conjunctive_temporal_expr( + left_col, + right_col, + Operator::Minus, + Operator::Minus, + Operator::Minus, + Operator::Minus, + ScalarValue::DurationMillisecond(Some(3000)), // 3 secs + ScalarValue::DurationMillisecond(Some(2000)), // 2 secs + ScalarValue::DurationMillisecond(Some(5000)), // 5 secs + ScalarValue::DurationMillisecond(Some(4000)), // 4 secs + schema, + ), _ => unreachable!(), } } @@ -500,3 +517,51 @@ pub fn create_memory_table( .with_sort_information(right_sorted); Ok((Arc::new(left), Arc::new(right))) } + +/// Filter expr for a + b > c + 10 AND a + b < c + 100 +pub(crate) fn complicated_filter( + filter_schema: &Schema, +) -> Result> { + let left_expr = binary( + cast( + binary( + col("0", filter_schema)?, + Operator::Plus, + col("1", filter_schema)?, + filter_schema, + )?, + filter_schema, + DataType::Int64, + )?, + Operator::Gt, + binary( + cast(col("2", filter_schema)?, filter_schema, DataType::Int64)?, + Operator::Plus, + lit(ScalarValue::Int64(Some(10))), + filter_schema, + )?, + filter_schema, + )?; + + let right_expr = binary( + cast( + binary( + col("0", filter_schema)?, + Operator::Plus, + col("1", filter_schema)?, + filter_schema, + )?, + filter_schema, + DataType::Int64, + )?, + Operator::Lt, + binary( + cast(col("2", filter_schema)?, filter_schema, DataType::Int64)?, + Operator::Plus, + lit(ScalarValue::Int64(Some(100))), + filter_schema, + )?, + filter_schema, + )?; + binary(left_expr, Operator::And, right_expr, filter_schema) +} diff --git a/datafusion/physical-plan/src/joins/utils.rs b/datafusion/physical-plan/src/joins/utils.rs index 53c762ff9511c..36c7143ee0d8c 100644 --- a/datafusion/physical-plan/src/joins/utils.rs +++ b/datafusion/physical-plan/src/joins/utils.rs @@ -18,19 +18,19 @@ //! 
Join related functionality used both on logical and physical plans

 use std::collections::HashSet;
+use std::fmt::{self, Debug};
 use std::future::Future;
-use std::ops::Range;
+use std::ops::{IndexMut, Range};
 use std::sync::Arc;
 use std::task::{Context, Poll};
 use std::usize;

-use crate::joins::hash_join_utils::{build_filter_input_order, SortedFilterExpr};
 use crate::metrics::{self, ExecutionPlanMetricsSet, MetricBuilder};
 use crate::{ColumnStatistics, ExecutionPlan, Partitioning, Statistics};

 use arrow::array::{
     downcast_array, new_null_array, Array, BooleanBufferBuilder, UInt32Array,
-    UInt32Builder, UInt64Array,
+    UInt32BufferBuilder, UInt32Builder, UInt64Array, UInt64BufferBuilder,
 };
 use arrow::compute;
 use arrow::datatypes::{Field, Schema, SchemaBuilder};
@@ -40,12 +40,11 @@ use arrow_buffer::ArrowNativeType;
 use datafusion_common::cast::as_boolean_array;
 use datafusion_common::stats::Precision;
 use datafusion_common::{
-    plan_datafusion_err, plan_err, DataFusionError, JoinSide, JoinType, Result,
-    SharedResult,
+    plan_err, DataFusionError, JoinSide, JoinType, Result, SharedResult,
 };
+use datafusion_expr::interval_arithmetic::Interval;
 use datafusion_physical_expr::equivalence::add_offset_to_expr;
 use datafusion_physical_expr::expressions::Column;
-use datafusion_physical_expr::intervals::{ExprIntervalGraph, Interval, IntervalBound};
 use datafusion_physical_expr::utils::merge_vectors;
 use datafusion_physical_expr::{
     LexOrdering, LexOrderingRef, PhysicalExpr, PhysicalSortExpr,
@@ -53,8 +52,211 @@
 use futures::future::{BoxFuture, Shared};
 use futures::{ready, FutureExt};
+use hashbrown::raw::RawTable;
 use parking_lot::Mutex;

+/// Maps a `u64` hash value based on the build side ["on" values] to a list of indices with this key's value.
+///
+/// By allocating a `HashMap` with capacity for *at least* the number of rows for entries at the build side,
+/// we make sure that the hashmap never has to be re-hashed (re-hashing needs access to the key, i.e. the
+/// hash value, of every entry).
+///
+/// E.g. 1 -> [3, 6, 8] indicates that the column values map to rows 3, 6 and 8 for hash value 1.
+/// As the key is a hash value, we need to check for possible hash collisions in the probe stage:
+/// during this stage a row may map to the same hash value even though its actual column values do not
+/// match. Those are checked in the [`equal_rows_arr`](crate::joins::hash_join::equal_rows_arr) method.
+///
+/// The indices (values) are stored in a separate chained list stored in the `Vec<u64>`.
+///
+/// The first index (offset by +1) is stored in the hashmap, whereas the next index of the chain is stored
+/// in the `next` array at that position.
+///
+/// The chain can be followed until the value "0" has been reached, meaning the end of the list.
+/// Also see chapter 5.3 of [Balancing vectorized query execution with bandwidth-optimized storage](https://dare.uva.nl/search?identifier=5ccbb60a-38b8-4eeb-858a-e7735dd37487)
+///
+/// # Example
+///
+/// ``` text
+/// See the example below:
+///
+/// Insert (10,1) <-- insert hash value 10 with row index 1
+/// map:
+/// ----------
+/// | 10 | 2 |
+/// ----------
+/// next:
+/// ---------------------
+/// | 0 | 0 | 0 | 0 | 0 |
+/// ---------------------
+/// Insert (20,2)
+/// map:
+/// ----------
+/// | 10 | 2 |
+/// | 20 | 3 |
+/// ----------
+/// next:
+/// ---------------------
+/// | 0 | 0 | 0 | 0 | 0 |
+/// ---------------------
+/// Insert (10,3) <-- collision!
+/// row index 3 has a hash value of 10 as well
+/// map:
+/// ----------
+/// | 10 | 4 |
+/// | 20 | 3 |
+/// ----------
+/// next:
+/// ---------------------
+/// | 0 | 0 | 0 | 2 | 0 | <--- hash value 10 maps to 4,2 (which means indices values 3,1)
+/// ---------------------
+/// Insert (10,4) <-- another collision! row index 4 ALSO has a hash value of 10
+/// map:
+/// ---------
+/// | 10 | 5 |
+/// | 20 | 3 |
+/// ---------
+/// next:
+/// ---------------------
+/// | 0 | 0 | 0 | 2 | 4 | <--- hash value 10 maps to 5,4,2 (which means indices values 4,3,1)
+/// ---------------------
+/// ```
+pub struct JoinHashMap {
+    // Stores hash value to last row index
+    map: RawTable<(u64, u64)>,
+    // Stores indices in chained list data structure
+    next: Vec<u64>,
+}
+
+impl JoinHashMap {
+    #[cfg(test)]
+    pub(crate) fn new(map: RawTable<(u64, u64)>, next: Vec<u64>) -> Self {
+        Self { map, next }
+    }
+
+    pub(crate) fn with_capacity(capacity: usize) -> Self {
+        JoinHashMap {
+            map: RawTable::with_capacity(capacity),
+            next: vec![0; capacity],
+        }
+    }
+}
+
+// Trait defining methods that must be implemented by a hash map type to be used for joins.
+pub trait JoinHashMapType {
+    /// The type of the list used to store the chain of next indices
+    type NextType: IndexMut<usize, Output = u64>;
+    /// Extend with zero
+    fn extend_zero(&mut self, len: usize);
+    /// Returns mutable references to the hash map and the next list.
+    fn get_mut(&mut self) -> (&mut RawTable<(u64, u64)>, &mut Self::NextType);
+    /// Returns a reference to the hash map.
+    fn get_map(&self) -> &RawTable<(u64, u64)>;
+    /// Returns a reference to the next list.
+    fn get_list(&self) -> &Self::NextType;
+
+    /// Updates the hashmap from an iterator of (row index, row hash) pairs.
+    fn update_from_iter<'a>(
+        &mut self,
+        iter: impl Iterator<Item = (usize, &'a u64)>,
+        deleted_offset: usize,
+    ) {
+        let (mut_map, mut_list) = self.get_mut();
+        for (row, hash_value) in iter {
+            let item = mut_map.get_mut(*hash_value, |(hash, _)| *hash_value == *hash);
+            if let Some((_, index)) = item {
+                // Already exists: add index to next array
+                let prev_index = *index;
+                // Store new value inside hashmap
+                *index = (row + 1) as u64;
+                // Update chained Vec at `row` with previous value
+                mut_list[row - deleted_offset] = prev_index;
+            } else {
+                mut_map.insert(
+                    *hash_value,
+                    // store the value + 1 as 0 value reserved for end of list
+                    (*hash_value, (row + 1) as u64),
+                    |(hash, _)| *hash,
+                );
+                // chained list at `row` is already initialized with 0
+                // meaning end of list
+            }
+        }
+    }
+
+    /// Returns all pairs of row indices matched by hash.
+    ///
+    /// This method only compares hashes, so an additional check for actual value
+    /// equality may be required.
+    fn get_matched_indices<'a>(
+        &self,
+        iter: impl Iterator<Item = (usize, &'a u64)>,
+        deleted_offset: Option<usize>,
+    ) -> (UInt32BufferBuilder, UInt64BufferBuilder) {
+        let mut input_indices = UInt32BufferBuilder::new(0);
+        let mut match_indices = UInt64BufferBuilder::new(0);
+
+        let hash_map = self.get_map();
+        let next_chain = self.get_list();
+        for (row_idx, hash_value) in iter {
+            // Get the hash and find it in the index
+            if let Some((_, index)) =
+                hash_map.get(*hash_value, |(hash, _)| *hash_value == *hash)
+            {
+                let mut i = *index - 1;
+                loop {
+                    let match_row_idx = if let Some(offset) = deleted_offset {
+                        // A set `deleted_offset` means that this chain was pruned
+                        // earlier; indices below the offset no longer exist.
+ if i < offset as u64 { + // End of the list due to pruning + break; + } + i - offset as u64 + } else { + i + }; + match_indices.append(match_row_idx); + input_indices.append(row_idx as u32); + // Follow the chain to get the next index value + let next = next_chain[match_row_idx as usize]; + if next == 0 { + // end of list + break; + } + i = next - 1; + } + } + } + + (input_indices, match_indices) + } +} + +/// Implementation of `JoinHashMapType` for `JoinHashMap`. +impl JoinHashMapType for JoinHashMap { + type NextType = Vec; + + // Void implementation + fn extend_zero(&mut self, _: usize) {} + + /// Get mutable references to the hash map and the next. + fn get_mut(&mut self) -> (&mut RawTable<(u64, u64)>, &mut Self::NextType) { + (&mut self.map, &mut self.next) + } + + /// Get a reference to the hash map. + fn get_map(&self) -> &RawTable<(u64, u64)> { + &self.map + } + + /// Get a reference to the next. + fn get_list(&self) -> &Self::NextType { + &self.next + } +} + +impl fmt::Debug for JoinHashMap { + fn fmt(&self, _f: &mut fmt::Formatter) -> fmt::Result { + Ok(()) + } +} + /// The on clause of the join, as vector of (left, right) columns. pub type JoinOn = Vec<(Column, Column)>; /// Reference for JoinOn. @@ -222,7 +424,7 @@ pub fn calculate_join_output_ordering( } /// Information about the index and placement (left or right) of the columns -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq)] pub struct ColumnIndex { /// Index of the column pub index: usize, @@ -587,8 +789,8 @@ fn estimate_inner_join_cardinality( ); } - let left_max_distinct = max_distinct_count(&left_stats.num_rows, left_stat)?; - let right_max_distinct = max_distinct_count(&right_stats.num_rows, right_stat)?; + let left_max_distinct = max_distinct_count(&left_stats.num_rows, left_stat); + let right_max_distinct = max_distinct_count(&right_stats.num_rows, right_stat); let max_distinct = left_max_distinct.max(&right_max_distinct); if max_distinct.get_value().is_some() { // Seems like there are a few implementations of this algorithm that implement @@ -619,48 +821,60 @@ fn estimate_inner_join_cardinality( } /// Estimate the number of maximum distinct values that can be present in the -/// given column from its statistics. -/// -/// If distinct_count is available, uses it directly. If the column numeric, and -/// has min/max values, then they might be used as a fallback option. Otherwise, -/// returns None. +/// given column from its statistics. If distinct_count is available, uses it +/// directly. Otherwise, if the column is numeric and has min/max values, it +/// estimates the maximum distinct count from those. fn max_distinct_count( num_rows: &Precision, stats: &ColumnStatistics, -) -> Option> { - match ( - &stats.distinct_count, - stats.max_value.get_value(), - stats.min_value.get_value(), - ) { - (Precision::Exact(_), _, _) | (Precision::Inexact(_), _, _) => { - Some(stats.distinct_count.clone()) - } - (_, Some(max), Some(min)) => { - let numeric_range = Interval::new( - IntervalBound::new(min.clone(), false), - IntervalBound::new(max.clone(), false), - ) - .cardinality() - .ok() - .flatten()? as usize; - - // The number can never be greater than the number of rows we have (minus - // the nulls, since they don't count as distinct values). - let ceiling = - num_rows.get_value()? 
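The `map`/`next` layout documented above is easy to experiment with in isolation. The sketch below mirrors the same convention (store `last row index + 1` in the map, chain earlier rows through `next`, `0` terminates the chain) using `std::collections::HashMap` instead of `hashbrown::raw::RawTable`; `ToyJoinHashMap` and its methods are illustrative names, not the DataFusion type.

```rust
// Hedged sketch of the chained-list join hash map documented above.
use std::collections::HashMap;

struct ToyJoinHashMap {
    map: HashMap<u64, u64>, // hash value -> last inserted row index + 1
    next: Vec<u64>,         // chain of previous row indices (+1); 0 = end of list
}

impl ToyJoinHashMap {
    fn with_capacity(capacity: usize) -> Self {
        Self {
            map: HashMap::with_capacity(capacity),
            next: vec![0; capacity],
        }
    }

    fn insert(&mut self, hash: u64, row: usize) {
        // Remember whichever row previously headed this hash's chain (or 0),
        // store the new head in the map, and link the old head behind it.
        let prev = self.map.insert(hash, (row + 1) as u64).unwrap_or(0);
        self.next[row] = prev;
    }

    fn matched_rows(&self, hash: u64) -> Vec<usize> {
        let mut rows = Vec::new();
        let mut i = *self.map.get(&hash).unwrap_or(&0);
        while i != 0 {
            rows.push((i - 1) as usize);
            i = self.next[(i - 1) as usize];
        }
        rows
    }
}

fn main() {
    let mut m = ToyJoinHashMap::with_capacity(5);
    m.insert(10, 1);
    m.insert(20, 2);
    m.insert(10, 3);
    m.insert(10, 4);
    // Matches come back newest-first, mirroring the worked example: 4, 3, 1.
    assert_eq!(m.matched_rows(10), vec![4, 3, 1]);
    assert_eq!(m.matched_rows(20), vec![2]);
    println!("ok");
}
```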
- stats.null_count.get_value().unwrap_or(&0); - Some( - if num_rows.is_exact().unwrap_or(false) - && stats.max_value.is_exact().unwrap_or(false) - && stats.min_value.is_exact().unwrap_or(false) +) -> Precision { + match &stats.distinct_count { + dc @ (Precision::Exact(_) | Precision::Inexact(_)) => dc.clone(), + _ => { + // The number can never be greater than the number of rows we have + // minus the nulls (since they don't count as distinct values). + let result = match num_rows { + Precision::Absent => Precision::Absent, + Precision::Inexact(count) => { + Precision::Inexact(count - stats.null_count.get_value().unwrap_or(&0)) + } + Precision::Exact(count) => { + let count = count - stats.null_count.get_value().unwrap_or(&0); + if stats.null_count.is_exact().unwrap_or(false) { + Precision::Exact(count) + } else { + Precision::Inexact(count) + } + } + }; + // Cap the estimate using the number of possible values: + if let (Some(min), Some(max)) = + (stats.min_value.get_value(), stats.max_value.get_value()) + { + if let Some(range_dc) = Interval::try_new(min.clone(), max.clone()) + .ok() + .and_then(|e| e.cardinality()) { - Precision::Exact(numeric_range.min(ceiling)) - } else { - Precision::Inexact(numeric_range.min(ceiling)) - }, - ) + let range_dc = range_dc as usize; + // Note that the `unwrap` calls in the below statement are safe. + return if matches!(result, Precision::Absent) + || &range_dc < result.get_value().unwrap() + { + if stats.min_value.is_exact().unwrap() + && stats.max_value.is_exact().unwrap() + { + Precision::Exact(range_dc) + } else { + Precision::Inexact(range_dc) + } + } else { + result + }; + } + } + + result } - _ => None, } } @@ -710,6 +924,22 @@ impl OnceFut { ), } } + + /// Get shared reference to the result of the computation if it is ready, without consuming it + pub(crate) fn get_shared(&mut self, cx: &mut Context<'_>) -> Poll>> { + if let OnceFutState::Pending(fut) = &mut self.state { + let r = ready!(fut.poll_unpin(cx)); + self.state = OnceFutState::Ready(r); + } + + match &self.state { + OnceFutState::Pending(_) => unreachable!(), + OnceFutState::Ready(r) => Poll::Ready( + r.clone() + .map_err(|e| DataFusionError::External(Box::new(e))), + ), + } + } } /// Some type `join_type` of join need to maintain the matched indices bit map for the left side, and @@ -781,7 +1011,7 @@ pub(crate) fn apply_join_filter_to_indices( let filter_result = filter .expression() .evaluate(&intermediate_batch)? - .into_array(intermediate_batch.num_rows()); + .into_array(intermediate_batch.num_rows())?; let mask = as_boolean_array(&filter_result)?; let left_filtered = compute::filter(&build_indices, mask)?; @@ -1042,88 +1272,71 @@ impl BuildProbeJoinMetrics { } } -/// Updates sorted filter expressions with corresponding node indices from the -/// expression interval graph. +/// The `handle_state` macro is designed to process the result of a state-changing +/// operation, encountered e.g. in implementations of `EagerJoinStream`. It +/// operates on a `StatefulStreamResult` by matching its variants and executing +/// corresponding actions. This macro is used to streamline code that deals with +/// state transitions, reducing boilerplate and improving readability. /// -/// This function iterates through the provided sorted filter expressions, -/// gathers the corresponding node indices from the expression interval graph, -/// and then updates the sorted expressions with these indices. It ensures -/// that these sorted expressions are aligned with the structure of the graph. 
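The rewritten `max_distinct_count` above is easier to follow with concrete numbers. The sketch below restates the same decision flow over a simplified, integer-only precision type (the `distinct_count` shortcut and `ScalarValue`-based intervals are omitted); `Prec` and `max_distinct_estimate` are illustrative names, not the `datafusion_common` API.

```rust
// Hedged re-statement of the flow above: start from num_rows - null_count,
// then cap the estimate by the value-range cardinality when min/max are known.
#[derive(Debug, PartialEq)]
enum Prec {
    Exact(usize),
    Inexact(usize),
    Absent,
}

fn max_distinct_estimate(
    num_rows: Prec,
    null_count_exact: Option<usize>, // None => null count unknown, treated as 0 but inexact
    min_max: Option<(i64, i64)>,     // integer min/max, both assumed exact here
) -> Prec {
    let nulls = null_count_exact.unwrap_or(0);
    // Step 1: the distinct count can never exceed the non-null row count.
    let base = match num_rows {
        Prec::Absent => Prec::Absent,
        Prec::Inexact(n) => Prec::Inexact(n.saturating_sub(nulls)),
        Prec::Exact(n) => {
            let n = n.saturating_sub(nulls);
            if null_count_exact.is_some() {
                Prec::Exact(n)
            } else {
                Prec::Inexact(n)
            }
        }
    };
    // Step 2: cap it by the number of representable values in [min, max].
    if let Some((min, max)) = min_max {
        let range_dc = (max - min) as usize + 1;
        let range_is_tighter = match &base {
            Prec::Absent => true,
            Prec::Exact(n) | Prec::Inexact(n) => range_dc < *n,
        };
        if range_is_tighter {
            return Prec::Exact(range_dc);
        }
    }
    base
}

fn main() {
    // 1000 exact rows, 10 exact nulls, values known to lie in [1, 50]:
    // the 50-wide value range is the tighter (and exact) bound.
    assert_eq!(
        max_distinct_estimate(Prec::Exact(1000), Some(10), Some((1, 50))),
        Prec::Exact(50)
    );
    // An unknown null count keeps the row-based estimate inexact.
    assert_eq!(
        max_distinct_estimate(Prec::Exact(100), None, None),
        Prec::Inexact(100)
    );
    println!("ok");
}
```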
-fn update_sorted_exprs_with_node_indices( - graph: &mut ExprIntervalGraph, - sorted_exprs: &mut [SortedFilterExpr], -) { - // Extract filter expressions from the sorted expressions: - let filter_exprs = sorted_exprs - .iter() - .map(|expr| expr.filter_expr().clone()) - .collect::>(); - - // Gather corresponding node indices for the extracted filter expressions from the graph: - let child_node_indices = graph.gather_node_indices(&filter_exprs); - - // Iterate through the sorted expressions and the gathered node indices: - for (sorted_expr, (_, index)) in sorted_exprs.iter_mut().zip(child_node_indices) { - // Update each sorted expression with the corresponding node index: - sorted_expr.set_node_index(index); - } +/// # Cases +/// +/// - `Ok(StatefulStreamResult::Continue)`: Continues the loop, indicating the +/// stream join operation should proceed to the next step. +/// - `Ok(StatefulStreamResult::Ready(result))`: Returns a `Poll::Ready` with the +/// result, either yielding a value or indicating the stream is awaiting more +/// data. +/// - `Err(e)`: Returns a `Poll::Ready` containing an error, signaling an issue +/// during the stream join operation. +/// +/// # Arguments +/// +/// * `$match_case`: An expression that evaluates to a `Result>`. +#[macro_export] +macro_rules! handle_state { + ($match_case:expr) => { + match $match_case { + Ok(StatefulStreamResult::Continue) => continue, + Ok(StatefulStreamResult::Ready(result)) => { + Poll::Ready(Ok(result).transpose()) + } + Err(e) => Poll::Ready(Some(Err(e))), + } + }; } -/// Prepares and sorts expressions based on a given filter, left and right execution plans, and sort expressions. +/// The `handle_async_state` macro adapts the `handle_state` macro for use in +/// asynchronous operations, particularly when dealing with `Poll` results within +/// async traits like `EagerJoinStream`. It polls the asynchronous state-changing +/// function using `poll_unpin` and then passes the result to `handle_state` for +/// further processing. /// /// # Arguments /// -/// * `filter` - The join filter to base the sorting on. -/// * `left` - The left execution plan. -/// * `right` - The right execution plan. -/// * `left_sort_exprs` - The expressions to sort on the left side. -/// * `right_sort_exprs` - The expressions to sort on the right side. +/// * `$state_func`: An async function or future that returns a +/// `Result>`. +/// * `$cx`: The context to be passed for polling, usually of type `&mut Context`. /// -/// # Returns +#[macro_export] +macro_rules! handle_async_state { + ($state_func:expr, $cx:expr) => { + $crate::handle_state!(ready!($state_func.poll_unpin($cx))) + }; +} + +/// Represents the result of an operation on stateful join stream. /// -/// * A tuple consisting of the sorted filter expression for the left and right sides, and an expression interval graph. -pub fn prepare_sorted_exprs( - filter: &JoinFilter, - left: &Arc, - right: &Arc, - left_sort_exprs: &[PhysicalSortExpr], - right_sort_exprs: &[PhysicalSortExpr], -) -> Result<(SortedFilterExpr, SortedFilterExpr, ExprIntervalGraph)> { - // Build the filter order for the left side - let err = || plan_datafusion_err!("Filter does not include the child order"); - - let left_temp_sorted_filter_expr = build_filter_input_order( - JoinSide::Left, - filter, - &left.schema(), - &left_sort_exprs[0], - )? 
- .ok_or_else(err)?; - - // Build the filter order for the right side - let right_temp_sorted_filter_expr = build_filter_input_order( - JoinSide::Right, - filter, - &right.schema(), - &right_sort_exprs[0], - )? - .ok_or_else(err)?; - - // Collect the sorted expressions - let mut sorted_exprs = - vec![left_temp_sorted_filter_expr, right_temp_sorted_filter_expr]; - - // Build the expression interval graph - let mut graph = ExprIntervalGraph::try_new(filter.expression().clone())?; - - // Update sorted expressions with node indices - update_sorted_exprs_with_node_indices(&mut graph, &mut sorted_exprs); - - // Swap and remove to get the final sorted filter expressions - let right_sorted_filter_expr = sorted_exprs.swap_remove(1); - let left_sorted_filter_expr = sorted_exprs.swap_remove(0); - - Ok((left_sorted_filter_expr, right_sorted_filter_expr, graph)) +/// This enumueration indicates whether the state produced a result that is +/// ready for use (`Ready`) or if the operation requires continuation (`Continue`). +/// +/// Variants: +/// - `Ready(T)`: Indicates that the operation is complete with a result of type `T`. +/// - `Continue`: Indicates that the operation is not yet complete and requires further +/// processing or more data. When this variant is returned, it typically means that the +/// current invocation of the state did not produce a final result, and the operation +/// should be invoked again later with more data and possibly with a different state. +pub enum StatefulStreamResult { + Ready(T), + Continue, } #[cfg(test)] @@ -1136,7 +1349,7 @@ mod tests { use arrow::error::{ArrowError, Result as ArrowResult}; use arrow_schema::SortOptions; - use datafusion_common::ScalarValue; + use datafusion_common::{arrow_datafusion_err, arrow_err, ScalarValue}; fn check(left: &[Column], right: &[Column], on: &[(Column, Column)]) -> Result<()> { let left = left @@ -1172,9 +1385,7 @@ mod tests { #[tokio::test] async fn check_error_nesting() { let once_fut = OnceFut::<()>::new(async { - Err(DataFusionError::ArrowError(ArrowError::CsvError( - "some error".to_string(), - ))) + arrow_err!(ArrowError::CsvError("some error".to_string())) }); struct TestFut(OnceFut<()>); @@ -1198,10 +1409,10 @@ mod tests { let wrapped_err = DataFusionError::from(arrow_err_from_fut); let root_err = wrapped_err.find_root(); - assert!(matches!( - root_err, - DataFusionError::ArrowError(ArrowError::CsvError(_)) - )) + let _expected = + arrow_datafusion_err!(ArrowError::CsvError("some error".to_owned())); + + assert!(matches!(root_err, _expected)) } #[test] @@ -1560,7 +1771,7 @@ mod tests { column_statistics: right_col_stats, }, ), - None + Some(Precision::Inexact(100)) ); Ok(()) } diff --git a/datafusion/physical-plan/src/lib.rs b/datafusion/physical-plan/src/lib.rs index 081916f4f42d0..cae48c627f688 100644 --- a/datafusion/physical-plan/src/lib.rs +++ b/datafusion/physical-plan/src/lib.rs @@ -58,6 +58,8 @@ pub mod joins; pub mod limit; pub mod memory; pub mod metrics; +mod ordering; +pub mod placeholder_row; pub mod projection; pub mod repartition; pub mod sorts; @@ -72,6 +74,7 @@ pub mod windows; pub use crate::display::{DefaultDisplay, DisplayAs, DisplayFormatType, VerboseDisplay}; pub use crate::metrics::Metric; +pub use crate::ordering::InputOrderMode; pub use crate::topk::TopK; pub use crate::visitor::{accept, visit_execution_plan, ExecutionPlanVisitor}; @@ -203,7 +206,23 @@ pub trait ExecutionPlan: Debug + DisplayAs + Send + Sync { .collect() } - /// Get the [`EquivalenceProperties`] within the plan + /// Get the 
[`EquivalenceProperties`] within the plan. + /// + /// Equivalence properties tell DataFusion what columns are known to be + /// equal, during various optimization passes. By default, this returns "no + /// known equivalences" which is always correct, but may cause DataFusion to + /// unnecessarily resort data. + /// + /// If this ExecutionPlan makes no changes to the schema of the rows flowing + /// through it or how columns within each row relate to each other, it + /// should return the equivalence properties of its input. For + /// example, since `FilterExec` may remove rows from its input, but does not + /// otherwise modify them, it preserves its input equivalence properties. + /// However, since `ProjectionExec` may calculate derived expressions, it + /// needs special handling. + /// + /// See also [`Self::maintains_input_order`] and [`Self::output_ordering`] + /// for related concepts. fn equivalence_properties(&self) -> EquivalenceProperties { EquivalenceProperties::new(self.schema()) } @@ -554,6 +573,13 @@ pub fn unbounded_output(plan: &Arc) -> bool { .unwrap_or(true) } +/// Utility function yielding a string representation of the given [`ExecutionPlan`]. +pub fn get_plan_string(plan: &Arc) -> Vec { + let formatted = displayable(plan.as_ref()).indent(true).to_string(); + let actual: Vec<&str> = formatted.trim().lines().collect(); + actual.iter().map(|elem| elem.to_string()).collect() +} + #[cfg(test)] #[allow(clippy::single_component_path_imports)] use rstest_reuse; diff --git a/datafusion/physical-plan/src/limit.rs b/datafusion/physical-plan/src/limit.rs index c8427f9bc2c66..37e8ffd761598 100644 --- a/datafusion/physical-plan/src/limit.rs +++ b/datafusion/physical-plan/src/limit.rs @@ -229,7 +229,7 @@ impl ExecutionPlan for GlobalLimitExec { let remaining_rows: usize = nr - skip; let mut skip_some_rows_stats = Statistics { num_rows: Precision::Exact(remaining_rows), - column_statistics: col_stats.clone(), + column_statistics: col_stats, total_byte_size: Precision::Absent, }; if !input_stats.num_rows.is_exact().unwrap_or(false) { @@ -878,7 +878,6 @@ mod tests { build_group_by(&csv.schema().clone(), vec!["i".to_string()]), vec![], vec![None], - vec![None], csv.clone(), csv.schema().clone(), )?; diff --git a/datafusion/physical-plan/src/memory.rs b/datafusion/physical-plan/src/memory.rs index 5f1660a225b98..7de474fda11c3 100644 --- a/datafusion/physical-plan/src/memory.rs +++ b/datafusion/physical-plan/src/memory.rs @@ -55,7 +55,7 @@ impl fmt::Debug for MemoryExec { write!(f, "partitions: [...]")?; write!(f, "schema: {:?}", self.projected_schema)?; write!(f, "projection: {:?}", self.projection)?; - if let Some(sort_info) = &self.sort_information.get(0) { + if let Some(sort_info) = &self.sort_information.first() { write!(f, ", output_ordering: {:?}", sort_info)?; } Ok(()) @@ -177,6 +177,14 @@ impl MemoryExec { }) } + pub fn partitions(&self) -> &[Vec] { + &self.partitions + } + + pub fn projection(&self) -> &Option> { + &self.projection + } + /// A memory table can be ordered by multiple expressions simultaneously. /// [`EquivalenceProperties`] keeps track of expressions that describe the /// global ordering of the schema. These columns are not necessarily same; e.g. 
@@ -197,6 +205,10 @@ impl MemoryExec {
         self.sort_information = sort_information;
         self
     }
+
+    pub fn original_schema(&self) -> SchemaRef {
+        self.schema.clone()
+    }
 }
 
 /// Iterator over batches
diff --git a/datafusion/physical-plan/src/ordering.rs b/datafusion/physical-plan/src/ordering.rs
new file mode 100644
index 0000000000000..047f89eef1932
--- /dev/null
+++ b/datafusion/physical-plan/src/ordering.rs
@@ -0,0 +1,51 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+/// Specifies how the input to an aggregation or window operator is ordered
+/// relative to their `GROUP BY` or `PARTITION BY` expressions.
+///
+/// For example, if the existing ordering is `[a ASC, b ASC, c ASC]`:
+///
+/// ## Window Functions
+/// - A `PARTITION BY b` clause can use `Linear` mode.
+/// - A `PARTITION BY a, c` or a `PARTITION BY c, a` can use
+///   `PartiallySorted([0])` or `PartiallySorted([1])` modes, respectively.
+///   (The vector stores the index of `a` in the respective PARTITION BY expression.)
+/// - A `PARTITION BY a, b` or a `PARTITION BY b, a` can use `Sorted` mode.
+///
+/// ## Aggregations
+/// - A `GROUP BY b` clause can use `Linear` mode.
+/// - A `GROUP BY a, c` or a `GROUP BY c, a` can use
+///   `PartiallySorted([0])` or `PartiallySorted([1])` modes, respectively.
+///   (The vector stores the index of `a` in the respective GROUP BY expression.)
+/// - A `GROUP BY a, b` or a `GROUP BY b, a` can use `Sorted` mode.
+///
+/// Note these are the same examples as above, but with `GROUP BY` instead of
+/// `PARTITION BY` to make the examples easier to read.
+#[derive(Debug, Clone, PartialEq)]
+pub enum InputOrderMode {
+    /// There is no partial permutation of the expressions satisfying the
+    /// existing ordering.
+    Linear,
+    /// There is a partial permutation of the expressions satisfying the
+    /// existing ordering. Indices describing the longest partial permutation
+    /// are stored in the vector.
+    PartiallySorted(Vec<usize>),
+    /// There is a (full) permutation of the expressions satisfying the
+    /// existing ordering.
+    Sorted,
+}
diff --git a/datafusion/physical-plan/src/placeholder_row.rs b/datafusion/physical-plan/src/placeholder_row.rs
new file mode 100644
index 0000000000000..3ab3de62f37a7
--- /dev/null
+++ b/datafusion/physical-plan/src/placeholder_row.rs
@@ -0,0 +1,230 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! EmptyRelation produce_one_row=true execution plan
+
+use std::any::Any;
+use std::sync::Arc;
+
+use super::expressions::PhysicalSortExpr;
+use super::{common, DisplayAs, SendableRecordBatchStream, Statistics};
+use crate::{memory::MemoryStream, DisplayFormatType, ExecutionPlan, Partitioning};
+
+use arrow::array::{ArrayRef, NullArray};
+use arrow::datatypes::{DataType, Field, Fields, Schema, SchemaRef};
+use arrow::record_batch::RecordBatch;
+use arrow_array::RecordBatchOptions;
+use datafusion_common::{internal_err, DataFusionError, Result};
+use datafusion_execution::TaskContext;
+
+use log::trace;
+
+/// Execution plan for empty relation with produce_one_row=true
+#[derive(Debug)]
+pub struct PlaceholderRowExec {
+    /// The schema for the produced row
+    schema: SchemaRef,
+    /// Number of partitions
+    partitions: usize,
+}
+
+impl PlaceholderRowExec {
+    /// Create a new PlaceholderRowExec
+    pub fn new(schema: SchemaRef) -> Self {
+        PlaceholderRowExec {
+            schema,
+            partitions: 1,
+        }
+    }
+
+    /// Create a new PlaceholderRowExec with the specified number of partitions
+    pub fn with_partitions(mut self, partitions: usize) -> Self {
+        self.partitions = partitions;
+        self
+    }
+
+    fn data(&self) -> Result<Vec<RecordBatch>> {
+        Ok({
+            let n_field = self.schema.fields.len();
+            vec![RecordBatch::try_new_with_options(
+                Arc::new(Schema::new(
+                    (0..n_field)
+                        .map(|i| {
+                            Field::new(format!("placeholder_{i}"), DataType::Null, true)
+                        })
+                        .collect::<Fields>(),
+                )),
+                (0..n_field)
+                    .map(|_i| {
+                        let ret: ArrayRef = Arc::new(NullArray::new(1));
+                        ret
+                    })
+                    .collect(),
+                // Even if the column count is zero, we can still generate a single row.
+                &RecordBatchOptions::new().with_row_count(Some(1)),
+            )?]
+ }) + } +} + +impl DisplayAs for PlaceholderRowExec { + fn fmt_as( + &self, + t: DisplayFormatType, + f: &mut std::fmt::Formatter, + ) -> std::fmt::Result { + match t { + DisplayFormatType::Default | DisplayFormatType::Verbose => { + write!(f, "PlaceholderRowExec") + } + } + } +} + +impl ExecutionPlan for PlaceholderRowExec { + /// Return a reference to Any that can be used for downcasting + fn as_any(&self) -> &dyn Any { + self + } + + fn schema(&self) -> SchemaRef { + self.schema.clone() + } + + fn children(&self) -> Vec> { + vec![] + } + + /// Get the output partitioning of this plan + fn output_partitioning(&self) -> Partitioning { + Partitioning::UnknownPartitioning(self.partitions) + } + + fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> { + None + } + + fn with_new_children( + self: Arc, + _: Vec>, + ) -> Result> { + Ok(Arc::new(PlaceholderRowExec::new(self.schema.clone()))) + } + + fn execute( + &self, + partition: usize, + context: Arc, + ) -> Result { + trace!("Start PlaceholderRowExec::execute for partition {} of context session_id {} and task_id {:?}", partition, context.session_id(), context.task_id()); + + if partition >= self.partitions { + return internal_err!( + "PlaceholderRowExec invalid partition {} (expected less than {})", + partition, + self.partitions + ); + } + + Ok(Box::pin(MemoryStream::try_new( + self.data()?, + self.schema.clone(), + None, + )?)) + } + + fn statistics(&self) -> Result { + let batch = self + .data() + .expect("Create single row placeholder RecordBatch should not fail"); + Ok(common::compute_record_batch_statistics( + &[batch], + &self.schema, + None, + )) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::with_new_children_if_necessary; + use crate::{common, test}; + + #[test] + fn with_new_children() -> Result<()> { + let schema = test::aggr_test_schema(); + + let placeholder = Arc::new(PlaceholderRowExec::new(schema)); + + let placeholder_2 = + with_new_children_if_necessary(placeholder.clone(), vec![])?.into(); + assert_eq!(placeholder.schema(), placeholder_2.schema()); + + let too_many_kids = vec![placeholder_2]; + assert!( + with_new_children_if_necessary(placeholder, too_many_kids).is_err(), + "expected error when providing list of kids" + ); + Ok(()) + } + + #[tokio::test] + async fn invalid_execute() -> Result<()> { + let task_ctx = Arc::new(TaskContext::default()); + let schema = test::aggr_test_schema(); + let placeholder = PlaceholderRowExec::new(schema); + + // ask for the wrong partition + assert!(placeholder.execute(1, task_ctx.clone()).is_err()); + assert!(placeholder.execute(20, task_ctx).is_err()); + Ok(()) + } + + #[tokio::test] + async fn produce_one_row() -> Result<()> { + let task_ctx = Arc::new(TaskContext::default()); + let schema = test::aggr_test_schema(); + let placeholder = PlaceholderRowExec::new(schema); + + let iter = placeholder.execute(0, task_ctx)?; + let batches = common::collect(iter).await?; + + // should have one item + assert_eq!(batches.len(), 1); + + Ok(()) + } + + #[tokio::test] + async fn produce_one_row_multiple_partition() -> Result<()> { + let task_ctx = Arc::new(TaskContext::default()); + let schema = test::aggr_test_schema(); + let partitions = 3; + let placeholder = PlaceholderRowExec::new(schema).with_partitions(partitions); + + for n in 0..partitions { + let iter = placeholder.execute(n, task_ctx.clone())?; + let batches = common::collect(iter).await?; + + // should have one item + assert_eq!(batches.len(), 1); + } + + Ok(()) + } +} diff --git 
a/datafusion/physical-plan/src/projection.rs b/datafusion/physical-plan/src/projection.rs index c5d94b08e0e17..cc2ab62049ed5 100644 --- a/datafusion/physical-plan/src/projection.rs +++ b/datafusion/physical-plan/src/projection.rs @@ -38,15 +38,15 @@ use arrow::record_batch::{RecordBatch, RecordBatchOptions}; use datafusion_common::stats::Precision; use datafusion_common::Result; use datafusion_execution::TaskContext; +use datafusion_physical_expr::equivalence::ProjectionMapping; use datafusion_physical_expr::expressions::{Literal, UnKnownColumn}; use datafusion_physical_expr::EquivalenceProperties; -use datafusion_physical_expr::equivalence::ProjectionMapping; use futures::stream::{Stream, StreamExt}; use log::trace; /// Execution plan for a projection -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct ProjectionExec { /// The projection expressions stored as tuples of (expression, output column name) pub(crate) expr: Vec<(Arc, String)>, @@ -257,16 +257,12 @@ fn get_field_metadata( e: &Arc, input_schema: &Schema, ) -> Option> { - let name = if let Some(column) = e.as_any().downcast_ref::() { - column.name() - } else { - return None; - }; - - input_schema - .field_with_name(name) - .ok() - .map(|f| f.metadata().clone()) + // Look up field by index in schema (not NAME as there can be more than one + // column with the same name) + e.as_any() + .downcast_ref::() + .map(|column| input_schema.field(column.index()).metadata()) + .cloned() } fn stats_projection( @@ -310,8 +306,10 @@ impl ProjectionStream { let arrays = self .expr .iter() - .map(|expr| expr.evaluate(batch)) - .map(|r| r.map(|v| v.into_array(batch.num_rows()))) + .map(|expr| { + expr.evaluate(batch) + .and_then(|v| v.into_array(batch.num_rows())) + }) .collect::>>()?; if arrays.is_empty() { @@ -399,12 +397,8 @@ mod tests { }, ColumnStatistics { distinct_count: Precision::Exact(1), - max_value: Precision::Exact(ScalarValue::Utf8(Some(String::from( - "x", - )))), - min_value: Precision::Exact(ScalarValue::Utf8(Some(String::from( - "a", - )))), + max_value: Precision::Exact(ScalarValue::from("x")), + min_value: Precision::Exact(ScalarValue::from("a")), null_count: Precision::Exact(3), }, ColumnStatistics { @@ -441,12 +435,8 @@ mod tests { column_statistics: vec![ ColumnStatistics { distinct_count: Precision::Exact(1), - max_value: Precision::Exact(ScalarValue::Utf8(Some(String::from( - "x", - )))), - min_value: Precision::Exact(ScalarValue::Utf8(Some(String::from( - "a", - )))), + max_value: Precision::Exact(ScalarValue::from("x")), + min_value: Precision::Exact(ScalarValue::from("a")), null_count: Precision::Exact(3), }, ColumnStatistics { diff --git a/datafusion/physical-plan/src/repartition/mod.rs b/datafusion/physical-plan/src/repartition/mod.rs index 66f7037e5c2df..07693f747feec 100644 --- a/datafusion/physical-plan/src/repartition/mod.rs +++ b/datafusion/physical-plan/src/repartition/mod.rs @@ -15,15 +15,30 @@ // specific language governing permissions and limitations // under the License. -//! The repartition operator maps N input partitions to M output partitions based on a -//! partitioning scheme (according to flag `preserve_order` ordering can be preserved during -//! repartitioning if its input is ordered). +//! This file implements the [`RepartitionExec`] operator, which maps N input +//! partitions to M output partitions based on a partitioning scheme, optionally +//! maintaining the order of the input rows in the output. 
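+//!
+//! # Example (sketch)
+//!
+//! A minimal sketch of wiring the operator up, assuming `input` is an
+//! already-built `Arc<dyn ExecutionPlan>` (for instance a `MemoryExec`):
+//!
+//! ```text
+//! // Redistribute the input's partitions across ten output partitions,
+//! // handing out record batches in round-robin fashion.
+//! let exec = RepartitionExec::try_new(input, Partitioning::RoundRobinBatch(10))?;
+//! ```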
use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; use std::{any::Any, vec}; +use arrow::array::{ArrayRef, UInt64Builder}; +use arrow::datatypes::SchemaRef; +use arrow::record_batch::RecordBatch; +use futures::stream::Stream; +use futures::{FutureExt, StreamExt}; +use hashbrown::HashMap; +use log::trace; +use parking_lot::Mutex; +use tokio::task::JoinHandle; + +use datafusion_common::{arrow_datafusion_err, not_impl_err, DataFusionError, Result}; +use datafusion_execution::memory_pool::MemoryConsumer; +use datafusion_execution::TaskContext; +use datafusion_physical_expr::{EquivalenceProperties, PhysicalExpr}; + use crate::common::transpose; use crate::hash_utils::create_hashes; use crate::metrics::BaselineMetrics; @@ -31,27 +46,12 @@ use crate::repartition::distributor_channels::{channels, partition_aware_channel use crate::sorts::streaming_merge; use crate::{DisplayFormatType, ExecutionPlan, Partitioning, Statistics}; -use self::distributor_channels::{DistributionReceiver, DistributionSender}; - use super::common::{AbortOnDropMany, AbortOnDropSingle, SharedMemoryReservation}; use super::expressions::PhysicalSortExpr; use super::metrics::{self, ExecutionPlanMetricsSet, MetricBuilder, MetricsSet}; use super::{DisplayAs, RecordBatchStream, SendableRecordBatchStream}; -use arrow::array::{ArrayRef, UInt64Builder}; -use arrow::datatypes::SchemaRef; -use arrow::record_batch::RecordBatch; -use datafusion_common::{not_impl_err, DataFusionError, Result}; -use datafusion_execution::memory_pool::MemoryConsumer; -use datafusion_execution::TaskContext; -use datafusion_physical_expr::{EquivalenceProperties, PhysicalExpr}; - -use futures::stream::Stream; -use futures::{FutureExt, StreamExt}; -use hashbrown::HashMap; -use log::trace; -use parking_lot::Mutex; -use tokio::task::JoinHandle; +use self::distributor_channels::{DistributionReceiver, DistributionSender}; mod distributor_channels; @@ -169,9 +169,7 @@ impl BatchPartitioner { let arrays = exprs .iter() - .map(|expr| { - Ok(expr.evaluate(&batch)?.into_array(batch.num_rows())) - }) + .map(|expr| expr.evaluate(&batch)?.into_array(batch.num_rows())) .collect::>>()?; hash_buffer.clear(); @@ -202,7 +200,7 @@ impl BatchPartitioner { .iter() .map(|c| { arrow::compute::take(c.as_ref(), &indices, None) - .map_err(DataFusionError::ArrowError) + .map_err(|e| arrow_datafusion_err!(e)) }) .collect::>>()?; @@ -281,8 +279,9 @@ impl BatchPartitioner { /// /// # Output Ordering /// -/// No guarantees are made about the order of the resulting -/// partitions unless `preserve_order` is set. +/// If more than one stream is being repartitioned, the output will be some +/// arbitrary interleaving (and thus unordered) unless +/// [`Self::with_preserve_order`] specifies otherwise. 
 ///
 /// # Footnote
 ///
@@ -371,11 +370,7 @@ impl RepartitionExec {
     /// Get name used to display this Exec
    pub fn name(&self) -> &str {
-        if self.preserve_order {
-            "SortPreservingRepartitionExec"
-        } else {
-            "RepartitionExec"
-        }
+        "RepartitionExec"
    }
 }
 
@@ -395,6 +390,10 @@ impl DisplayAs for RepartitionExec {
                     self.input.output_partitioning().partition_count()
                 )?;
+                if self.preserve_order {
+                    write!(f, ", preserve_order=true")?;
+                }
+
                 if let Some(sort_exprs) = self.sort_exprs() {
                     write!(
                         f,
@@ -427,9 +426,12 @@ impl ExecutionPlan for RepartitionExec {
         self: Arc<Self>,
         mut children: Vec<Arc<dyn ExecutionPlan>>,
     ) -> Result<Arc<dyn ExecutionPlan>> {
-        let repartition =
-            RepartitionExec::try_new(children.swap_remove(0), self.partitioning.clone());
-        repartition.map(|r| Arc::new(r.with_preserve_order(self.preserve_order)) as _)
+        let mut repartition =
+            RepartitionExec::try_new(children.swap_remove(0), self.partitioning.clone())?;
+        if self.preserve_order {
+            repartition = repartition.with_preserve_order();
+        }
+        Ok(Arc::new(repartition))
     }
 
     /// Specifies whether this plan generates an infinite stream of records.
@@ -470,9 +472,6 @@ impl ExecutionPlan for RepartitionExec {
         if !self.maintains_input_order()[0] {
             result.clear_orderings();
         }
-        if self.preserve_order {
-            result = result.with_reorder(self.sort_exprs().unwrap_or_default().to_vec())
-        }
         result
     }
 
@@ -627,7 +626,9 @@ impl ExecutionPlan for RepartitionExec {
 }
 
 impl RepartitionExec {
-    /// Create a new RepartitionExec
+    /// Create a new RepartitionExec that produces the output `partitioning` and
+    /// does not preserve the order of the input (see [`Self::with_preserve_order`]
+    /// for more details)
     pub fn try_new(
         input: Arc<dyn ExecutionPlan>,
         partitioning: Partitioning,
@@ -644,16 +645,20 @@ impl RepartitionExec {
         })
     }
 
-    /// Set Order preserving flag
-    pub fn with_preserve_order(mut self, preserve_order: bool) -> Self {
-        // Set "preserve order" mode only if the input partition count is larger than 1
-        // Because in these cases naive `RepartitionExec` cannot maintain ordering. Using
-        // `SortPreservingRepartitionExec` is necessity. However, when input partition number
-        // is 1, `RepartitionExec` can maintain ordering. In this case, we don't need to use
-        // `SortPreservingRepartitionExec` variant to maintain ordering.
-        if self.input.output_partitioning().partition_count() > 1 {
-            self.preserve_order = preserve_order
-        }
+    /// Specify whether this repartitioning operation should preserve the order
+    /// of rows from its input when producing output. Preserving order is more
+    /// expensive at runtime, so it should only be set if the output of this
+    /// operator can take advantage of it.
+    ///
+    /// If the input is not ordered, or has only one partition, this is a no-op
+    /// and the node remains a `RepartitionExec`.
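+    ///
+    /// A minimal usage sketch (mirroring the tests in this module), assuming
+    /// `union` is an ordered `UnionExec` with more than one output partition:
+    ///
+    /// ```text
+    /// let exec = RepartitionExec::try_new(
+    ///     Arc::new(union),
+    ///     Partitioning::RoundRobinBatch(10),
+    /// )?
+    /// .with_preserve_order();
+    /// // The plan now displays `preserve_order=true` and keeps its input order.
+    /// ```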
+ pub fn with_preserve_order(mut self) -> Self { + self.preserve_order = + // If the input isn't ordered, there is no ordering to preserve + self.input.output_ordering().is_some() && + // if there is only one input partition, merging is not required + // to maintain order + self.input.output_partitioning().partition_count() > 1; self } @@ -913,7 +918,19 @@ impl RecordBatchStream for PerPartitionStream { #[cfg(test)] mod tests { - use super::*; + use std::collections::HashSet; + + use arrow::array::{ArrayRef, StringArray}; + use arrow::datatypes::{DataType, Field, Schema}; + use arrow::record_batch::RecordBatch; + use arrow_array::UInt32Array; + use futures::FutureExt; + use tokio::task::JoinHandle; + + use datafusion_common::cast::as_string_array; + use datafusion_common::{assert_batches_sorted_eq, exec_err}; + use datafusion_execution::runtime_env::{RuntimeConfig, RuntimeEnv}; + use crate::{ test::{ assert_is_pending, @@ -924,16 +941,8 @@ mod tests { }, {collect, expressions::col, memory::MemoryExec}, }; - use arrow::array::{ArrayRef, StringArray}; - use arrow::datatypes::{DataType, Field, Schema}; - use arrow::record_batch::RecordBatch; - use arrow_array::UInt32Array; - use datafusion_common::cast::as_string_array; - use datafusion_common::{assert_batches_sorted_eq, exec_err}; - use datafusion_execution::runtime_env::{RuntimeConfig, RuntimeEnv}; - use futures::FutureExt; - use std::collections::HashSet; - use tokio::task::JoinHandle; + + use super::*; #[tokio::test] async fn one_to_many_round_robin() -> Result<()> { @@ -1405,9 +1414,8 @@ mod tests { // pull partitions for i in 0..exec.partitioning.partition_count() { let mut stream = exec.execute(i, task_ctx.clone())?; - let err = DataFusionError::ArrowError( - stream.next().await.unwrap().unwrap_err().into(), - ); + let err = + arrow_datafusion_err!(stream.next().await.unwrap().unwrap_err().into()); let err = err.find_root(); assert!( matches!(err, DataFusionError::ResourcesExhausted(_)), @@ -1434,3 +1442,129 @@ mod tests { .unwrap() } } + +#[cfg(test)] +mod test { + use arrow_schema::{DataType, Field, Schema, SortOptions}; + + use datafusion_physical_expr::expressions::col; + + use crate::memory::MemoryExec; + use crate::union::UnionExec; + + use super::*; + + /// Asserts that the plan is as expected + /// + /// `$EXPECTED_PLAN_LINES`: input plan + /// `$PLAN`: the plan to optimized + /// + macro_rules! 
assert_plan { + ($EXPECTED_PLAN_LINES: expr, $PLAN: expr) => { + let physical_plan = $PLAN; + let formatted = crate::displayable(&physical_plan).indent(true).to_string(); + let actual: Vec<&str> = formatted.trim().lines().collect(); + + let expected_plan_lines: Vec<&str> = $EXPECTED_PLAN_LINES + .iter().map(|s| *s).collect(); + + assert_eq!( + expected_plan_lines, actual, + "\n**Original Plan Mismatch\n\nexpected:\n\n{expected_plan_lines:#?}\nactual:\n\n{actual:#?}\n\n" + ); + }; + } + + #[tokio::test] + async fn test_preserve_order() -> Result<()> { + let schema = test_schema(); + let sort_exprs = sort_exprs(&schema); + let source1 = sorted_memory_exec(&schema, sort_exprs.clone()); + let source2 = sorted_memory_exec(&schema, sort_exprs); + // output has multiple partitions, and is sorted + let union = UnionExec::new(vec![source1, source2]); + let exec = + RepartitionExec::try_new(Arc::new(union), Partitioning::RoundRobinBatch(10)) + .unwrap() + .with_preserve_order(); + + // Repartition should preserve order + let expected_plan = [ + "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2, preserve_order=true, sort_exprs=c0@0 ASC", + " UnionExec", + " MemoryExec: partitions=1, partition_sizes=[0], output_ordering=c0@0 ASC", + " MemoryExec: partitions=1, partition_sizes=[0], output_ordering=c0@0 ASC", + ]; + assert_plan!(expected_plan, exec); + Ok(()) + } + + #[tokio::test] + async fn test_preserve_order_one_partition() -> Result<()> { + let schema = test_schema(); + let sort_exprs = sort_exprs(&schema); + let source = sorted_memory_exec(&schema, sort_exprs); + // output is sorted, but has only a single partition, so no need to sort + let exec = RepartitionExec::try_new(source, Partitioning::RoundRobinBatch(10)) + .unwrap() + .with_preserve_order(); + + // Repartition should not preserve order + let expected_plan = [ + "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + " MemoryExec: partitions=1, partition_sizes=[0], output_ordering=c0@0 ASC", + ]; + assert_plan!(expected_plan, exec); + Ok(()) + } + + #[tokio::test] + async fn test_preserve_order_input_not_sorted() -> Result<()> { + let schema = test_schema(); + let source1 = memory_exec(&schema); + let source2 = memory_exec(&schema); + // output has multiple partitions, but is not sorted + let union = UnionExec::new(vec![source1, source2]); + let exec = + RepartitionExec::try_new(Arc::new(union), Partitioning::RoundRobinBatch(10)) + .unwrap() + .with_preserve_order(); + + // Repartition should not preserve order, as there is no order to preserve + let expected_plan = [ + "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2", + " UnionExec", + " MemoryExec: partitions=1, partition_sizes=[0]", + " MemoryExec: partitions=1, partition_sizes=[0]", + ]; + assert_plan!(expected_plan, exec); + Ok(()) + } + + fn test_schema() -> Arc { + Arc::new(Schema::new(vec![Field::new("c0", DataType::UInt32, false)])) + } + + fn sort_exprs(schema: &Schema) -> Vec { + let options = SortOptions::default(); + vec![PhysicalSortExpr { + expr: col("c0", schema).unwrap(), + options, + }] + } + + fn memory_exec(schema: &SchemaRef) -> Arc { + Arc::new(MemoryExec::try_new(&[vec![]], schema.clone(), None).unwrap()) + } + + fn sorted_memory_exec( + schema: &SchemaRef, + sort_exprs: Vec, + ) -> Arc { + Arc::new( + MemoryExec::try_new(&[vec![]], schema.clone(), None) + .unwrap() + .with_sort_information(vec![sort_exprs]), + ) + } +} diff --git a/datafusion/physical-plan/src/sorts/sort_preserving_merge.rs 
b/datafusion/physical-plan/src/sorts/sort_preserving_merge.rs index 65cd8e41480e1..f4b57e8bfb45c 100644 --- a/datafusion/physical-plan/src/sorts/sort_preserving_merge.rs +++ b/datafusion/physical-plan/src/sorts/sort_preserving_merge.rs @@ -174,8 +174,7 @@ impl ExecutionPlan for SortPreservingMergeExec { } fn equivalence_properties(&self) -> EquivalenceProperties { - let output_oeq = self.input.equivalence_properties(); - output_oeq.with_reorder(self.expr.to_vec()) + self.input.equivalence_properties() } fn children(&self) -> Vec> { diff --git a/datafusion/physical-plan/src/sorts/stream.rs b/datafusion/physical-plan/src/sorts/stream.rs index 4cabdc6e178c1..135b4fbdece49 100644 --- a/datafusion/physical-plan/src/sorts/stream.rs +++ b/datafusion/physical-plan/src/sorts/stream.rs @@ -118,7 +118,7 @@ impl RowCursorStream { let cols = self .column_expressions .iter() - .map(|expr| Ok(expr.evaluate(batch)?.into_array(batch.num_rows()))) + .map(|expr| expr.evaluate(batch)?.into_array(batch.num_rows())) .collect::>>()?; let rows = self.converter.convert_columns(&cols)?; @@ -181,7 +181,7 @@ impl FieldCursorStream { fn convert_batch(&mut self, batch: &RecordBatch) -> Result> { let value = self.sort.expr.evaluate(batch)?; - let array = value.into_array(batch.num_rows()); + let array = value.into_array(batch.num_rows())?; let array = array.as_any().downcast_ref::().expect("field values"); Ok(ArrayValues::new(self.sort.options, array)) } diff --git a/datafusion/physical-plan/src/streaming.rs b/datafusion/physical-plan/src/streaming.rs index 1923a5f3abad1..59819c6921fb4 100644 --- a/datafusion/physical-plan/src/streaming.rs +++ b/datafusion/physical-plan/src/streaming.rs @@ -26,6 +26,7 @@ use crate::stream::RecordBatchStreamAdapter; use crate::{ExecutionPlan, Partitioning, SendableRecordBatchStream}; use arrow::datatypes::SchemaRef; +use arrow_schema::Schema; use datafusion_common::{internal_err, plan_err, DataFusionError, Result}; use datafusion_execution::TaskContext; use datafusion_physical_expr::{EquivalenceProperties, LexOrdering, PhysicalSortExpr}; @@ -55,7 +56,7 @@ pub struct StreamingTableExec { partitions: Vec>, projection: Option>, projected_schema: SchemaRef, - projected_output_ordering: Option, + projected_output_ordering: Vec, infinite: bool, } @@ -65,14 +66,14 @@ impl StreamingTableExec { schema: SchemaRef, partitions: Vec>, projection: Option<&Vec>, - projected_output_ordering: Option, + projected_output_ordering: impl IntoIterator, infinite: bool, ) -> Result { for x in partitions.iter() { let partition_schema = x.schema(); - if !schema.contains(partition_schema) { + if !schema.eq(partition_schema) { debug!( - "target schema does not contain partition schema. \ + "Target schema does not match with partition schema. \ Target_schema: {schema:?}. 
Partiton Schema: {partition_schema:?}" ); return plan_err!("Mismatch between schema and batches"); @@ -88,10 +89,34 @@ impl StreamingTableExec { partitions, projected_schema, projection: projection.cloned().map(Into::into), - projected_output_ordering, + projected_output_ordering: projected_output_ordering.into_iter().collect(), infinite, }) } + + pub fn partitions(&self) -> &Vec> { + &self.partitions + } + + pub fn partition_schema(&self) -> &SchemaRef { + self.partitions[0].schema() + } + + pub fn projection(&self) -> &Option> { + &self.projection + } + + pub fn projected_schema(&self) -> &Schema { + &self.projected_schema + } + + pub fn projected_output_ordering(&self) -> impl IntoIterator { + self.projected_output_ordering.clone() + } + + pub fn is_infinite(&self) -> bool { + self.infinite + } } impl std::fmt::Debug for StreamingTableExec { @@ -125,7 +150,7 @@ impl DisplayAs for StreamingTableExec { } self.projected_output_ordering - .as_deref() + .first() .map_or(Ok(()), |ordering| { if !ordering.is_empty() { write!( @@ -160,15 +185,16 @@ impl ExecutionPlan for StreamingTableExec { } fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> { - self.projected_output_ordering.as_deref() + self.projected_output_ordering + .first() + .map(|ordering| ordering.as_slice()) } fn equivalence_properties(&self) -> EquivalenceProperties { - let mut result = EquivalenceProperties::new(self.schema()); - if let Some(ordering) = &self.projected_output_ordering { - result.add_new_orderings([ordering.clone()]) - } - result + EquivalenceProperties::new_with_orderings( + self.schema(), + &self.projected_output_ordering, + ) } fn children(&self) -> Vec> { diff --git a/datafusion/physical-plan/src/test/exec.rs b/datafusion/physical-plan/src/test/exec.rs index fcc0cf6b7af88..5a8ef2db77c28 100644 --- a/datafusion/physical-plan/src/test/exec.rs +++ b/datafusion/physical-plan/src/test/exec.rs @@ -790,7 +790,7 @@ impl Stream for PanicStream { } else { self.ready = true; // get called again - cx.waker().clone().wake(); + cx.waker().wake_by_ref(); return Poll::Pending; } } diff --git a/datafusion/physical-plan/src/topk/mod.rs b/datafusion/physical-plan/src/topk/mod.rs index 4638c0dcf2646..9120566273d35 100644 --- a/datafusion/physical-plan/src/topk/mod.rs +++ b/datafusion/physical-plan/src/topk/mod.rs @@ -153,7 +153,7 @@ impl TopK { .iter() .map(|expr| { let value = expr.expr.evaluate(&batch)?; - Ok(value.into_array(batch.num_rows())) + value.into_array(batch.num_rows()) }) .collect::>>()?; diff --git a/datafusion/physical-plan/src/udaf.rs b/datafusion/physical-plan/src/udaf.rs index 7cc3cc7d59fed..94017efe97aa1 100644 --- a/datafusion/physical-plan/src/udaf.rs +++ b/datafusion/physical-plan/src/udaf.rs @@ -50,7 +50,7 @@ pub fn create_aggregate_expr( Ok(Arc::new(AggregateFunctionExpr { fun: fun.clone(), args: input_phy_exprs.to_vec(), - data_type: (fun.return_type)(&input_exprs_types)?.as_ref().clone(), + data_type: fun.return_type(&input_exprs_types)?, name: name.into(), })) } @@ -83,7 +83,9 @@ impl AggregateExpr for AggregateFunctionExpr { } fn state_fields(&self) -> Result> { - let fields = (self.fun.state_type)(&self.data_type)? + let fields = self + .fun + .state_type(&self.data_type)? 
.iter() .enumerate() .map(|(i, data_type)| { @@ -103,11 +105,11 @@ impl AggregateExpr for AggregateFunctionExpr { } fn create_accumulator(&self) -> Result> { - (self.fun.accumulator)(&self.data_type) + self.fun.accumulator(&self.data_type) } fn create_sliding_accumulator(&self) -> Result> { - let accumulator = (self.fun.accumulator)(&self.data_type)?; + let accumulator = self.fun.accumulator(&self.data_type)?; // Accumulators that have window frame startings different // than `UNBOUNDED PRECEDING`, such as `1 PRECEEDING`, need to diff --git a/datafusion/physical-plan/src/union.rs b/datafusion/physical-plan/src/union.rs index 9700605ce406f..d01ea55074498 100644 --- a/datafusion/physical-plan/src/union.rs +++ b/datafusion/physical-plan/src/union.rs @@ -21,6 +21,7 @@ //! The Union operator combines multiple inputs with the same schema +use std::borrow::Borrow; use std::pin::Pin; use std::task::{Context, Poll}; use std::{any::Any, sync::Arc}; @@ -38,7 +39,7 @@ use crate::stream::ObservedStream; use arrow::datatypes::{Field, Schema, SchemaRef}; use arrow::record_batch::RecordBatch; use datafusion_common::stats::Precision; -use datafusion_common::{exec_err, internal_err, DFSchemaRef, DataFusionError, Result}; +use datafusion_common::{exec_err, internal_err, DataFusionError, Result}; use datafusion_execution::TaskContext; use datafusion_physical_expr::EquivalenceProperties; @@ -95,38 +96,6 @@ pub struct UnionExec { } impl UnionExec { - /// Create a new UnionExec with specified schema. - /// The `schema` should always be a subset of the schema of `inputs`, - /// otherwise, an error will be returned. - pub fn try_new_with_schema( - inputs: Vec>, - schema: DFSchemaRef, - ) -> Result { - let mut exec = Self::new(inputs); - let exec_schema = exec.schema(); - let fields = schema - .fields() - .iter() - .map(|dff| { - exec_schema - .field_with_name(dff.name()) - .cloned() - .map_err(|_| { - DataFusionError::Internal(format!( - "Cannot find the field {:?} in child schema", - dff.name() - )) - }) - }) - .collect::>>()?; - let schema = Arc::new(Schema::new_with_metadata( - fields, - exec.schema().metadata().clone(), - )); - exec.schema = schema; - Ok(exec) - } - /// Create a new UnionExec pub fn new(inputs: Vec>) -> Self { let schema = union_schema(&inputs); @@ -368,7 +337,7 @@ impl InterleaveExec { pub fn try_new(inputs: Vec>) -> Result { let schema = union_schema(&inputs); - if !can_interleave(&inputs) { + if !can_interleave(inputs.iter()) { return internal_err!( "Not all InterleaveExec children have a consistent hash partitioning" ); @@ -506,17 +475,18 @@ impl ExecutionPlan for InterleaveExec { /// It might be too strict here in the case that the input partition specs are compatible but not exactly the same. /// For example one input partition has the partition spec Hash('a','b','c') and /// other has the partition spec Hash('a'), It is safe to derive the out partition with the spec Hash('a','b','c'). 
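+///
+/// A minimal sketch of the intended usage, assuming `inputs` is a
+/// `Vec<Arc<dyn ExecutionPlan>>` whose elements are all hash-partitioned on
+/// the same expressions:
+///
+/// ```text
+/// // Returns true only when every child reports an identical Hash partitioning.
+/// if can_interleave(inputs.iter()) {
+///     // it is safe to build an InterleaveExec over `inputs`
+/// }
+/// ```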
-pub fn can_interleave(inputs: &[Arc]) -> bool { - if inputs.is_empty() { +pub fn can_interleave>>( + mut inputs: impl Iterator, +) -> bool { + let Some(first) = inputs.next() else { return false; - } + }; - let first_input_partition = inputs[0].output_partitioning(); - matches!(first_input_partition, Partitioning::Hash(_, _)) + let reference = first.borrow().output_partitioning(); + matches!(reference, Partitioning::Hash(_, _)) && inputs - .iter() - .map(|plan| plan.output_partitioning()) - .all(|partition| partition == first_input_partition) + .map(|plan| plan.borrow().output_partitioning()) + .all(|partition| partition == reference) } fn union_schema(inputs: &[Arc]) -> SchemaRef { @@ -706,12 +676,8 @@ mod tests { }, ColumnStatistics { distinct_count: Precision::Exact(1), - max_value: Precision::Exact(ScalarValue::Utf8(Some(String::from( - "x", - )))), - min_value: Precision::Exact(ScalarValue::Utf8(Some(String::from( - "a", - )))), + max_value: Precision::Exact(ScalarValue::from("x")), + min_value: Precision::Exact(ScalarValue::from("a")), null_count: Precision::Exact(3), }, ColumnStatistics { @@ -735,12 +701,8 @@ mod tests { }, ColumnStatistics { distinct_count: Precision::Absent, - max_value: Precision::Exact(ScalarValue::Utf8(Some(String::from( - "c", - )))), - min_value: Precision::Exact(ScalarValue::Utf8(Some(String::from( - "b", - )))), + max_value: Precision::Exact(ScalarValue::from("c")), + min_value: Precision::Exact(ScalarValue::from("b")), null_count: Precision::Absent, }, ColumnStatistics { @@ -765,12 +727,8 @@ mod tests { }, ColumnStatistics { distinct_count: Precision::Absent, - max_value: Precision::Exact(ScalarValue::Utf8(Some(String::from( - "x", - )))), - min_value: Precision::Exact(ScalarValue::Utf8(Some(String::from( - "a", - )))), + max_value: Precision::Exact(ScalarValue::from("x")), + min_value: Precision::Exact(ScalarValue::from("a")), null_count: Precision::Absent, }, ColumnStatistics { diff --git a/datafusion/physical-plan/src/unnest.rs b/datafusion/physical-plan/src/unnest.rs index c9f3fb76c2e54..b9e732c317af4 100644 --- a/datafusion/physical-plan/src/unnest.rs +++ b/datafusion/physical-plan/src/unnest.rs @@ -17,8 +17,6 @@ //! Defines the unnest column plan for unnesting values in a column that contains a list //! type, conceptually is like joining each row with all the values in the list column. - -use std::time::Instant; use std::{any::Any, sync::Arc}; use super::DisplayAs; @@ -44,6 +42,8 @@ use async_trait::async_trait; use futures::{Stream, StreamExt}; use log::trace; +use super::metrics::{self, ExecutionPlanMetricsSet, MetricBuilder, MetricsSet}; + /// Unnest the given column by joining the row with each value in the /// nested type. 
/// @@ -58,6 +58,8 @@ pub struct UnnestExec { column: Column, /// Options options: UnnestOptions, + /// Execution metrics + metrics: ExecutionPlanMetricsSet, } impl UnnestExec { @@ -73,6 +75,7 @@ impl UnnestExec { schema, column, options, + metrics: Default::default(), } } } @@ -141,19 +144,58 @@ impl ExecutionPlan for UnnestExec { context: Arc, ) -> Result { let input = self.input.execute(partition, context)?; + let metrics = UnnestMetrics::new(partition, &self.metrics); Ok(Box::pin(UnnestStream { input, schema: self.schema.clone(), column: self.column.clone(), options: self.options.clone(), - num_input_batches: 0, - num_input_rows: 0, - num_output_batches: 0, - num_output_rows: 0, - unnest_time: 0, + metrics, })) } + + fn metrics(&self) -> Option { + Some(self.metrics.clone_inner()) + } +} + +#[derive(Clone, Debug)] +struct UnnestMetrics { + /// total time for column unnesting + elapsed_compute: metrics::Time, + /// Number of batches consumed + input_batches: metrics::Count, + /// Number of rows consumed + input_rows: metrics::Count, + /// Number of batches produced + output_batches: metrics::Count, + /// Number of rows produced by this operator + output_rows: metrics::Count, +} + +impl UnnestMetrics { + fn new(partition: usize, metrics: &ExecutionPlanMetricsSet) -> Self { + let elapsed_compute = MetricBuilder::new(metrics).elapsed_compute(partition); + + let input_batches = + MetricBuilder::new(metrics).counter("input_batches", partition); + + let input_rows = MetricBuilder::new(metrics).counter("input_rows", partition); + + let output_batches = + MetricBuilder::new(metrics).counter("output_batches", partition); + + let output_rows = MetricBuilder::new(metrics).output_rows(partition); + + Self { + input_batches, + input_rows, + output_batches, + output_rows, + elapsed_compute, + } + } } /// A stream that issues [RecordBatch]es with unnested column data. 
@@ -166,16 +208,8 @@ struct UnnestStream { column: Column, /// Options options: UnnestOptions, - /// number of input batches - num_input_batches: usize, - /// number of input rows - num_input_rows: usize, - /// number of batches produced - num_output_batches: usize, - /// number of rows produced - num_output_rows: usize, - /// total time for column unnesting, in ms - unnest_time: usize, + /// Metrics + metrics: UnnestMetrics, } impl RecordBatchStream for UnnestStream { @@ -207,15 +241,15 @@ impl UnnestStream { .poll_next_unpin(cx) .map(|maybe_batch| match maybe_batch { Some(Ok(batch)) => { - let start = Instant::now(); + let timer = self.metrics.elapsed_compute.timer(); let result = build_batch(&batch, &self.schema, &self.column, &self.options); - self.num_input_batches += 1; - self.num_input_rows += batch.num_rows(); + self.metrics.input_batches.add(1); + self.metrics.input_rows.add(batch.num_rows()); if let Ok(ref batch) = result { - self.unnest_time += start.elapsed().as_millis() as usize; - self.num_output_batches += 1; - self.num_output_rows += batch.num_rows(); + timer.done(); + self.metrics.output_batches.add(1); + self.metrics.output_rows.add(batch.num_rows()); } Some(result) @@ -223,12 +257,12 @@ impl UnnestStream { other => { trace!( "Processed {} probe-side input batches containing {} rows and \ - produced {} output batches containing {} rows in {} ms", - self.num_input_batches, - self.num_input_rows, - self.num_output_batches, - self.num_output_rows, - self.unnest_time, + produced {} output batches containing {} rows in {}", + self.metrics.input_batches, + self.metrics.input_rows, + self.metrics.output_batches, + self.metrics.output_rows, + self.metrics.elapsed_compute, ); other } @@ -242,7 +276,7 @@ fn build_batch( column: &Column, options: &UnnestOptions, ) -> Result { - let list_array = column.evaluate(batch)?.into_array(batch.num_rows()); + let list_array = column.evaluate(batch)?.into_array(batch.num_rows())?; match list_array.data_type() { DataType::List(_) => { let list_array = list_array.as_any().downcast_ref::().unwrap(); diff --git a/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs b/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs index fb679b013863f..0871ec0d7ff3a 100644 --- a/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs +++ b/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs @@ -31,15 +31,16 @@ use crate::expressions::PhysicalSortExpr; use crate::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet}; use crate::windows::{ calc_requirements, get_ordered_partition_by_indices, get_partition_by_sort_exprs, - window_equivalence_properties, PartitionSearchMode, + window_equivalence_properties, }; use crate::{ ColumnStatistics, DisplayAs, DisplayFormatType, Distribution, ExecutionPlan, - Partitioning, RecordBatchStream, SendableRecordBatchStream, Statistics, WindowExpr, + InputOrderMode, Partitioning, RecordBatchStream, SendableRecordBatchStream, + Statistics, WindowExpr, }; use arrow::{ - array::{Array, ArrayRef, UInt32Builder}, + array::{Array, ArrayRef, RecordBatchOptions, UInt32Builder}, compute::{concat, concat_batches, sort_to_indices}, datatypes::{Schema, SchemaBuilder, SchemaRef}, record_batch::RecordBatch, @@ -50,7 +51,7 @@ use datafusion_common::utils::{ evaluate_partition_ranges, get_arrayref_at_indices, get_at_indices, get_record_batch_at_indices, get_row_at_idx, }; -use datafusion_common::{exec_err, plan_err, DataFusionError, Result}; +use datafusion_common::{arrow_datafusion_err, 
exec_err, DataFusionError, Result}; use datafusion_execution::TaskContext; use datafusion_expr::window_state::{PartitionBatchState, WindowAggState}; use datafusion_expr::ColumnarValue; @@ -81,8 +82,8 @@ pub struct BoundedWindowAggExec { pub partition_keys: Vec>, /// Execution metrics metrics: ExecutionPlanMetricsSet, - /// Partition by search mode - pub partition_search_mode: PartitionSearchMode, + /// Describes how the input is ordered relative to the partition keys + pub input_order_mode: InputOrderMode, /// Partition by indices that define ordering // For example, if input ordering is ORDER BY a, b and window expression // contains PARTITION BY b, a; `ordered_partition_by_indices` would be 1, 0. @@ -98,13 +99,13 @@ impl BoundedWindowAggExec { window_expr: Vec>, input: Arc, partition_keys: Vec>, - partition_search_mode: PartitionSearchMode, + input_order_mode: InputOrderMode, ) -> Result { let schema = create_schema(&input.schema(), &window_expr)?; let schema = Arc::new(schema); let partition_by_exprs = window_expr[0].partition_by(); - let ordered_partition_by_indices = match &partition_search_mode { - PartitionSearchMode::Sorted => { + let ordered_partition_by_indices = match &input_order_mode { + InputOrderMode::Sorted => { let indices = get_ordered_partition_by_indices( window_expr[0].partition_by(), &input, @@ -115,10 +116,8 @@ impl BoundedWindowAggExec { (0..partition_by_exprs.len()).collect::>() } } - PartitionSearchMode::PartiallySorted(ordered_indices) => { - ordered_indices.clone() - } - PartitionSearchMode::Linear => { + InputOrderMode::PartiallySorted(ordered_indices) => ordered_indices.clone(), + InputOrderMode::Linear => { vec![] } }; @@ -128,7 +127,7 @@ impl BoundedWindowAggExec { schema, partition_keys, metrics: ExecutionPlanMetricsSet::new(), - partition_search_mode, + input_order_mode, ordered_partition_by_indices, }) } @@ -162,8 +161,8 @@ impl BoundedWindowAggExec { fn get_search_algo(&self) -> Result> { let partition_by_sort_keys = self.partition_by_sort_keys()?; let ordered_partition_by_indices = self.ordered_partition_by_indices.clone(); - Ok(match &self.partition_search_mode { - PartitionSearchMode::Sorted => { + Ok(match &self.input_order_mode { + InputOrderMode::Sorted => { // In Sorted mode, all partition by columns should be ordered. 
if self.window_expr()[0].partition_by().len() != ordered_partition_by_indices.len() @@ -175,7 +174,7 @@ impl BoundedWindowAggExec { ordered_partition_by_indices, }) } - PartitionSearchMode::Linear | PartitionSearchMode::PartiallySorted(_) => { + InputOrderMode::Linear | InputOrderMode::PartiallySorted(_) => { Box::new(LinearSearch::new(ordered_partition_by_indices)) } }) @@ -203,7 +202,7 @@ impl DisplayAs for BoundedWindowAggExec { ) }) .collect(); - let mode = &self.partition_search_mode; + let mode = &self.input_order_mode; write!(f, "wdw=[{}], mode=[{:?}]", g.join(", "), mode)?; } } @@ -244,7 +243,7 @@ impl ExecutionPlan for BoundedWindowAggExec { fn required_input_ordering(&self) -> Vec>> { let partition_bys = self.window_expr()[0].partition_by(); let order_keys = self.window_expr()[0].order_by(); - if self.partition_search_mode != PartitionSearchMode::Sorted + if self.input_order_mode != InputOrderMode::Sorted || self.ordered_partition_by_indices.len() >= partition_bys.len() { let partition_bys = self @@ -283,7 +282,7 @@ impl ExecutionPlan for BoundedWindowAggExec { self.window_expr.clone(), children[0].clone(), self.partition_keys.clone(), - self.partition_search_mode.clone(), + self.input_order_mode.clone(), )?)) } @@ -500,7 +499,7 @@ impl PartitionSearcher for LinearSearch { .iter() .map(|items| { concat(&items.iter().map(|e| e.as_ref()).collect::>()) - .map_err(DataFusionError::ArrowError) + .map_err(|e| arrow_datafusion_err!(e)) }) .collect::>>()?; // We should emit columns according to row index ordering. @@ -586,7 +585,7 @@ impl LinearSearch { .map(|item| match item.evaluate(record_batch)? { ColumnarValue::Array(array) => Ok(array), ColumnarValue::Scalar(scalar) => { - plan_err!("Sort operation is not applicable to scalar value {scalar}") + scalar.to_array_of_size(record_batch.num_rows()) } }) .collect() @@ -1027,8 +1026,11 @@ impl BoundedWindowAggStream { .iter() .map(|elem| elem.slice(n_out, n_to_keep)) .collect::>(); - self.input_buffer = - RecordBatch::try_new(self.input_buffer.schema(), batch_to_keep)?; + self.input_buffer = RecordBatch::try_new_with_options( + self.input_buffer.schema(), + batch_to_keep, + &RecordBatchOptions::new().with_row_count(Some(n_to_keep)), + )?; Ok(()) } @@ -1109,3 +1111,131 @@ fn get_aggregate_result_out_column( result .ok_or_else(|| DataFusionError::Execution("Should contain something".to_string())) } + +#[cfg(test)] +mod tests { + use crate::common::collect; + use crate::memory::MemoryExec; + use crate::windows::{BoundedWindowAggExec, InputOrderMode}; + use crate::{get_plan_string, ExecutionPlan}; + use arrow_array::RecordBatch; + use arrow_schema::{DataType, Field, Schema}; + use datafusion_common::{assert_batches_eq, Result, ScalarValue}; + use datafusion_execution::config::SessionConfig; + use datafusion_execution::TaskContext; + use datafusion_expr::{WindowFrame, WindowFrameBound, WindowFrameUnits}; + use datafusion_physical_expr::expressions::col; + use datafusion_physical_expr::expressions::NthValue; + use datafusion_physical_expr::window::BuiltInWindowExpr; + use datafusion_physical_expr::window::BuiltInWindowFunctionExpr; + use std::sync::Arc; + + // Tests NTH_VALUE(negative index) with memoize feature. + // To be able to trigger memoize feature for NTH_VALUE we need to + // - feed BoundedWindowAggExec with batch stream data. + // - Window frame should contain UNBOUNDED PRECEDING. + // It hard to ensure these conditions are met, from the sql query. 
+ #[tokio::test] + async fn test_window_nth_value_bounded_memoize() -> Result<()> { + let config = SessionConfig::new().with_target_partitions(1); + let task_ctx = Arc::new(TaskContext::default().with_session_config(config)); + + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); + // Create a new batch of data to insert into the table + let batch = RecordBatch::try_new( + schema.clone(), + vec![Arc::new(arrow_array::Int32Array::from(vec![1, 2, 3]))], + )?; + + let memory_exec = MemoryExec::try_new( + &[vec![batch.clone(), batch.clone(), batch.clone()]], + schema.clone(), + None, + ) + .map(|e| Arc::new(e) as Arc)?; + let col_a = col("a", &schema)?; + let nth_value_func1 = + NthValue::nth("nth_value(-1)", col_a.clone(), DataType::Int32, 1)? + .reverse_expr() + .unwrap(); + let nth_value_func2 = + NthValue::nth("nth_value(-2)", col_a.clone(), DataType::Int32, 2)? + .reverse_expr() + .unwrap(); + let last_value_func = + Arc::new(NthValue::last("last", col_a.clone(), DataType::Int32)) as _; + let window_exprs = vec![ + // LAST_VALUE(a) + Arc::new(BuiltInWindowExpr::new( + last_value_func, + &[], + &[], + Arc::new(WindowFrame { + units: WindowFrameUnits::Rows, + start_bound: WindowFrameBound::Preceding(ScalarValue::UInt64(None)), + end_bound: WindowFrameBound::CurrentRow, + }), + )) as _, + // NTH_VALUE(a, -1) + Arc::new(BuiltInWindowExpr::new( + nth_value_func1, + &[], + &[], + Arc::new(WindowFrame { + units: WindowFrameUnits::Rows, + start_bound: WindowFrameBound::Preceding(ScalarValue::UInt64(None)), + end_bound: WindowFrameBound::CurrentRow, + }), + )) as _, + // NTH_VALUE(a, -2) + Arc::new(BuiltInWindowExpr::new( + nth_value_func2, + &[], + &[], + Arc::new(WindowFrame { + units: WindowFrameUnits::Rows, + start_bound: WindowFrameBound::Preceding(ScalarValue::UInt64(None)), + end_bound: WindowFrameBound::CurrentRow, + }), + )) as _, + ]; + let physical_plan = BoundedWindowAggExec::try_new( + window_exprs, + memory_exec, + vec![], + InputOrderMode::Sorted, + ) + .map(|e| Arc::new(e) as Arc)?; + + let batches = collect(physical_plan.execute(0, task_ctx)?).await?; + + let expected = vec![ + "BoundedWindowAggExec: wdw=[last: Ok(Field { name: \"last\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow }, nth_value(-1): Ok(Field { name: \"nth_value(-1)\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow }, nth_value(-2): Ok(Field { name: \"nth_value(-2)\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow }], mode=[Sorted]", + " MemoryExec: partitions=1, partition_sizes=[3]", + ]; + // Get string representation of the plan + let actual = get_plan_string(&physical_plan); + assert_eq!( + expected, actual, + "\n**Optimized Plan Mismatch\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" + ); + + let expected = [ + "+---+------+---------------+---------------+", + "| a | last | nth_value(-1) | nth_value(-2) |", + "+---+------+---------------+---------------+", + "| 1 | 1 | 1 | |", + "| 2 | 2 | 2 | 1 |", + "| 3 | 3 | 3 | 2 |", + "| 1 | 1 | 1 | 3 |", + "| 2 | 2 | 2 | 1 |", + "| 3 | 3 | 3 | 2 |", + "| 1 | 1 | 1 | 3 |", + "| 2 | 2 | 2 | 1 |", + "| 3 | 3 | 3 | 2 |", + 
"+---+------+---------------+---------------+", + ]; + assert_batches_eq!(expected, &batches); + Ok(()) + } +} diff --git a/datafusion/physical-plan/src/windows/mod.rs b/datafusion/physical-plan/src/windows/mod.rs index b6ed6e482ff50..fec168fabf48b 100644 --- a/datafusion/physical-plan/src/windows/mod.rs +++ b/datafusion/physical-plan/src/windows/mod.rs @@ -27,15 +27,15 @@ use crate::{ cume_dist, dense_rank, lag, lead, percent_rank, rank, Literal, NthValue, Ntile, PhysicalSortExpr, RowNumber, }, - udaf, unbounded_output, ExecutionPlan, PhysicalExpr, + udaf, unbounded_output, ExecutionPlan, InputOrderMode, PhysicalExpr, }; use arrow::datatypes::Schema; use arrow_schema::{DataType, Field, SchemaRef}; use datafusion_common::{exec_err, DataFusionError, Result, ScalarValue}; use datafusion_expr::{ - window_function::{BuiltInWindowFunction, WindowFunction}, - PartitionEvaluator, WindowFrame, WindowUDF, + BuiltInWindowFunction, PartitionEvaluator, WindowFrame, WindowFunctionDefinition, + WindowUDF, }; use datafusion_physical_expr::equivalence::collapse_lex_req; use datafusion_physical_expr::{ @@ -54,33 +54,9 @@ pub use datafusion_physical_expr::window::{ BuiltInWindowExpr, PlainAggregateWindowExpr, WindowExpr, }; -#[derive(Debug, Clone, PartialEq)] -/// Specifies aggregation grouping and/or window partitioning properties of a -/// set of expressions in terms of the existing ordering. -/// For example, if the existing ordering is `[a ASC, b ASC, c ASC]`: -/// - A `PARTITION BY b` clause will result in `Linear` mode. -/// - A `PARTITION BY a, c` or a `PARTITION BY c, a` clause will result in -/// `PartiallySorted([0])` or `PartiallySorted([1])` modes, respectively. -/// The vector stores the index of `a` in the respective PARTITION BY expression. -/// - A `PARTITION BY a, b` or a `PARTITION BY b, a` clause will result in -/// `Sorted` mode. -/// Note that the examples above are applicable for `GROUP BY` clauses too. -pub enum PartitionSearchMode { - /// There is no partial permutation of the expressions satisfying the - /// existing ordering. - Linear, - /// There is a partial permutation of the expressions satisfying the - /// existing ordering. Indices describing the longest partial permutation - /// are stored in the vector. - PartiallySorted(Vec), - /// There is a (full) permutation of the expressions satisfying the - /// existing ordering. 
- Sorted, -} - /// Create a physical expression for window function pub fn create_window_expr( - fun: &WindowFunction, + fun: &WindowFunctionDefinition, name: String, args: &[Arc], partition_by: &[Arc], @@ -89,7 +65,7 @@ pub fn create_window_expr( input_schema: &Schema, ) -> Result> { Ok(match fun { - WindowFunction::AggregateFunction(fun) => { + WindowFunctionDefinition::AggregateFunction(fun) => { let aggregate = aggregates::create_aggregate_expr( fun, false, @@ -105,13 +81,15 @@ pub fn create_window_expr( aggregate, ) } - WindowFunction::BuiltInWindowFunction(fun) => Arc::new(BuiltInWindowExpr::new( - create_built_in_window_expr(fun, args, input_schema, name)?, - partition_by, - order_by, - window_frame, - )), - WindowFunction::AggregateUDF(fun) => { + WindowFunctionDefinition::BuiltInWindowFunction(fun) => { + Arc::new(BuiltInWindowExpr::new( + create_built_in_window_expr(fun, args, input_schema, name)?, + partition_by, + order_by, + window_frame, + )) + } + WindowFunctionDefinition::AggregateUDF(fun) => { let aggregate = udaf::create_aggregate_expr(fun.as_ref(), args, input_schema, name)?; window_expr_from_aggregate_expr( @@ -121,7 +99,7 @@ pub fn create_window_expr( aggregate, ) } - WindowFunction::WindowUDF(fun) => Arc::new(BuiltInWindowExpr::new( + WindowFunctionDefinition::WindowUDF(fun) => Arc::new(BuiltInWindowExpr::new( create_udwf_window_expr(fun, args, input_schema, name)?, partition_by, order_by, @@ -189,15 +167,26 @@ fn create_built_in_window_expr( BuiltInWindowFunction::PercentRank => Arc::new(percent_rank(name)), BuiltInWindowFunction::CumeDist => Arc::new(cume_dist(name)), BuiltInWindowFunction::Ntile => { - let n: i64 = get_scalar_value_from_args(args, 0)? - .ok_or_else(|| { - DataFusionError::Execution( - "NTILE requires at least 1 argument".to_string(), - ) - })? - .try_into()?; - let n: u64 = n as u64; - Arc::new(Ntile::new(name, n)) + let n = get_scalar_value_from_args(args, 0)?.ok_or_else(|| { + DataFusionError::Execution( + "NTILE requires a positive integer".to_string(), + ) + })?; + + if n.is_null() { + return exec_err!("NTILE requires a positive integer, but finds NULL"); + } + + if n.is_unsigned() { + let n: u64 = n.try_into()?; + Arc::new(Ntile::new(name, n)) + } else { + let n: i64 = n.try_into()?; + if n <= 0 { + return exec_err!("NTILE requires a positive integer"); + } + Arc::new(Ntile::new(name, n as u64)) + } } BuiltInWindowFunction::Lag => { let arg = args[0].clone(); @@ -255,7 +244,7 @@ fn create_udwf_window_expr( .collect::>()?; // figure out the output type - let data_type = (fun.return_type)(&input_types)?; + let data_type = fun.return_type(&input_types)?; Ok(Arc::new(WindowUDFExpr { fun: Arc::clone(fun), args: args.to_vec(), @@ -272,7 +261,7 @@ struct WindowUDFExpr { /// Display name name: String, /// result type - data_type: Arc, + data_type: DataType, } impl BuiltInWindowFunctionExpr for WindowUDFExpr { @@ -282,11 +271,7 @@ impl BuiltInWindowFunctionExpr for WindowUDFExpr { fn field(&self) -> Result { let nullable = true; - Ok(Field::new( - &self.name, - self.data_type.as_ref().clone(), - nullable, - )) + Ok(Field::new(&self.name, self.data_type.clone(), nullable)) } fn expressions(&self) -> Vec> { @@ -294,7 +279,7 @@ impl BuiltInWindowFunctionExpr for WindowUDFExpr { } fn create_evaluator(&self) -> Result> { - (self.fun.partition_evaluator_factory)() + self.fun.partition_evaluator_factory() } fn name(&self) -> &str { @@ -407,17 +392,17 @@ pub fn get_best_fitting_window( // of the window_exprs are same. 
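// Sketch with local stand-ins, not the DataFusion API: the NTILE argument
// handling added above in `create_built_in_window_expr`, reduced to plain Rust.
// A NULL argument and a non-positive signed argument are both rejected before
// the value is widened to u64. `NtileArg` is a hypothetical type used only for
// this illustration.
#[derive(Debug, Clone, Copy)]
enum NtileArg {
    Null,
    Signed(i64),
    Unsigned(u64),
}

fn validate_ntile_arg(arg: NtileArg) -> Result<u64, String> {
    match arg {
        NtileArg::Null => Err("NTILE requires a positive integer, but finds NULL".into()),
        NtileArg::Unsigned(n) => Ok(n),
        NtileArg::Signed(n) if n <= 0 => Err("NTILE requires a positive integer".into()),
        NtileArg::Signed(n) => Ok(n as u64),
    }
}

#[test]
fn ntile_argument_validation() {
    assert_eq!(validate_ntile_arg(NtileArg::Signed(4)), Ok(4));
    assert!(validate_ntile_arg(NtileArg::Signed(0)).is_err());
    assert!(validate_ntile_arg(NtileArg::Null).is_err());
}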
let partitionby_exprs = window_exprs[0].partition_by(); let orderby_keys = window_exprs[0].order_by(); - let (should_reverse, partition_search_mode) = - if let Some((should_reverse, partition_search_mode)) = - get_window_mode(partitionby_exprs, orderby_keys, input)? + let (should_reverse, input_order_mode) = + if let Some((should_reverse, input_order_mode)) = + get_window_mode(partitionby_exprs, orderby_keys, input) { - (should_reverse, partition_search_mode) + (should_reverse, input_order_mode) } else { return Ok(None); }; let is_unbounded = unbounded_output(input); - if !is_unbounded && partition_search_mode != PartitionSearchMode::Sorted { - // Executor has bounded input and `partition_search_mode` is not `PartitionSearchMode::Sorted` + if !is_unbounded && input_order_mode != InputOrderMode::Sorted { + // Executor has bounded input and `input_order_mode` is not `InputOrderMode::Sorted` // in this case removing the sort is not helpful, return: return Ok(None); }; @@ -445,13 +430,13 @@ pub fn get_best_fitting_window( window_expr, input.clone(), physical_partition_keys.to_vec(), - partition_search_mode, + input_order_mode, )?) as _)) - } else if partition_search_mode != PartitionSearchMode::Sorted { + } else if input_order_mode != InputOrderMode::Sorted { // For `WindowAggExec` to work correctly PARTITION BY columns should be sorted. - // Hence, if `partition_search_mode` is not `PartitionSearchMode::Sorted` we should convert - // input ordering such that it can work with PartitionSearchMode::Sorted (add `SortExec`). - // Effectively `WindowAggExec` works only in PartitionSearchMode::Sorted mode. + // Hence, if `input_order_mode` is not `Sorted` we should convert + // input ordering such that it can work with `Sorted` (add `SortExec`). + // Effectively `WindowAggExec` works only in `Sorted` mode. Ok(None) } else { Ok(Some(Arc::new(WindowAggExec::try_new( @@ -467,16 +452,16 @@ pub fn get_best_fitting_window( /// is sufficient to run the current window operator. /// - A `None` return value indicates that we can not remove the sort in question /// (input ordering is not sufficient to run current window executor). -/// - A `Some((bool, PartitionSearchMode))` value indicates that the window operator +/// - A `Some((bool, InputOrderMode))` value indicates that the window operator /// can run with existing input ordering, so we can remove `SortExec` before it. /// The `bool` field in the return value represents whether we should reverse window -/// operator to remove `SortExec` before it. The `PartitionSearchMode` field represents -/// the mode this window operator should work in to accomodate the existing ordering. +/// operator to remove `SortExec` before it. The `InputOrderMode` field represents +/// the mode this window operator should work in to accommodate the existing ordering. 
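// Simplified sketch, not the real implementation: how the `InputOrderMode` part
// of the `(bool, InputOrderMode)` returned by `get_window_mode` (defined next)
// is chosen from the PARTITION BY columns. Reversibility, equivalence classes,
// constant columns and the ORDER BY satisfaction check are all ignored here;
// the enum and `classify` helper are local stand-ins.
#[derive(Debug, PartialEq)]
enum Mode {
    Linear,
    PartiallySorted(Vec<usize>),
    Sorted,
}

fn classify(existing_ordering: &[&str], partition_by: &[&str]) -> Mode {
    // Record, in order, where each ordered column occurs in the PARTITION BY
    // list; stop at the first ordered column that is not partitioned on.
    let mut indices = Vec::new();
    for ordered_col in existing_ordering {
        match partition_by.iter().position(|p| p == ordered_col) {
            Some(idx) => indices.push(idx),
            None => break,
        }
    }
    if indices.len() == partition_by.len() {
        Mode::Sorted
    } else if indices.is_empty() {
        Mode::Linear
    } else {
        Mode::PartiallySorted(indices)
    }
}

#[test]
fn classification_matches_the_doc_examples() {
    // Existing input ordering: a ASC, b ASC, c ASC.
    let ordering = ["a", "b", "c"];
    assert_eq!(classify(&ordering, &["b"]), Mode::Linear);
    assert_eq!(classify(&ordering, &["a", "c"]), Mode::PartiallySorted(vec![0]));
    assert_eq!(classify(&ordering, &["c", "a"]), Mode::PartiallySorted(vec![1]));
    assert_eq!(classify(&ordering, &["b", "a"]), Mode::Sorted);
}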
pub fn get_window_mode( partitionby_exprs: &[Arc], orderby_keys: &[PhysicalSortExpr], input: &Arc, -) -> Result> { +) -> Option<(bool, InputOrderMode)> { let input_eqs = input.equivalence_properties(); let mut partition_by_reqs: Vec = vec![]; let (_, indices) = input_eqs.find_longest_permutation(partitionby_exprs); @@ -497,16 +482,16 @@ pub fn get_window_mode( if partition_by_eqs.ordering_satisfy_requirement(&req) { // Window can be run with existing ordering let mode = if indices.len() == partitionby_exprs.len() { - PartitionSearchMode::Sorted + InputOrderMode::Sorted } else if indices.is_empty() { - PartitionSearchMode::Linear + InputOrderMode::Linear } else { - PartitionSearchMode::PartiallySorted(indices) + InputOrderMode::PartiallySorted(indices) }; - return Ok(Some((should_swap, mode))); + return Some((should_swap, mode)); } } - Ok(None) + None } #[cfg(test)] @@ -525,7 +510,7 @@ mod tests { use futures::FutureExt; - use PartitionSearchMode::{Linear, PartiallySorted, Sorted}; + use InputOrderMode::{Linear, PartiallySorted, Sorted}; fn create_test_schema() -> Result { let nullable_column = Field::new("nullable_col", DataType::Int32, true); @@ -664,7 +649,7 @@ mod tests { let refs = blocking_exec.refs(); let window_agg_exec = Arc::new(WindowAggExec::try_new( vec![create_window_expr( - &WindowFunction::AggregateFunction(AggregateFunction::Count), + &WindowFunctionDefinition::AggregateFunction(AggregateFunction::Count), "count".to_owned(), &[col("a", &schema)?], &[], @@ -785,11 +770,11 @@ mod tests { // Second field in the tuple is Vec where each element in the vector represents ORDER BY columns // For instance, vec!["c"], corresponds to ORDER BY c ASC NULLS FIRST, (ordering is default ordering. We do not check // for reversibility in this test). - // Third field in the tuple is Option, which corresponds to expected algorithm mode. + // Third field in the tuple is Option, which corresponds to expected algorithm mode. // None represents that existing ordering is not sufficient to run executor with any one of the algorithms // (We need to add SortExec to be able to run it). - // Some(PartitionSearchMode) represents, we can run algorithm with existing ordering; and algorithm should work in - // PartitionSearchMode. + // Some(InputOrderMode) represents, we can run algorithm with existing ordering; and algorithm should work in + // InputOrderMode. let test_cases = vec![ (vec!["a"], vec!["a"], Some(Sorted)), (vec!["a"], vec!["b"], Some(Sorted)), @@ -873,8 +858,8 @@ mod tests { order_by_exprs.push(PhysicalSortExpr { expr, options }); } let res = - get_window_mode(&partition_by_exprs, &order_by_exprs, &exec_unbounded)?; - // Since reversibility is not important in this test. Convert Option<(bool, PartitionSearchMode)> to Option + get_window_mode(&partition_by_exprs, &order_by_exprs, &exec_unbounded); + // Since reversibility is not important in this test. Convert Option<(bool, InputOrderMode)> to Option let res = res.map(|(_, mode)| mode); assert_eq!( res, *expected, @@ -905,12 +890,12 @@ mod tests { // Second field in the tuple is Vec<(str, bool, bool)> where each element in the vector represents ORDER BY columns // For instance, vec![("c", false, false)], corresponds to ORDER BY c ASC NULLS LAST, // similarly, vec![("c", true, true)], corresponds to ORDER BY c DESC NULLS FIRST, - // Third field in the tuple is Option<(bool, PartitionSearchMode)>, which corresponds to expected result. + // Third field in the tuple is Option<(bool, InputOrderMode)>, which corresponds to expected result. 
// None represents that existing ordering is not sufficient to run executor with any one of the algorithms // (We need to add SortExec to be able to run it). - // Some((bool, PartitionSearchMode)) represents, we can run algorithm with existing ordering. Algorithm should work in - // PartitionSearchMode, bool field represents whether we should reverse window expressions to run executor with existing ordering. - // For instance, `Some((false, PartitionSearchMode::Sorted))`, represents that we shouldn't reverse window expressions. And algorithm + // Some((bool, InputOrderMode)) represents, we can run algorithm with existing ordering. Algorithm should work in + // InputOrderMode, bool field represents whether we should reverse window expressions to run executor with existing ordering. + // For instance, `Some((false, InputOrderMode::Sorted))`, represents that we shouldn't reverse window expressions. And algorithm // should work in Sorted mode to work with existing ordering. let test_cases = vec![ // PARTITION BY a, b ORDER BY c ASC NULLS LAST @@ -1037,7 +1022,7 @@ mod tests { } assert_eq!( - get_window_mode(&partition_by_exprs, &order_by_exprs, &exec_unbounded)?, + get_window_mode(&partition_by_exprs, &order_by_exprs, &exec_unbounded), *expected, "Unexpected result for in unbounded test case#: {case_idx:?}, case: {test_case:?}" ); diff --git a/datafusion/proto/Cargo.toml b/datafusion/proto/Cargo.toml index ac3439a64ca81..f9f24b28db813 100644 --- a/datafusion/proto/Cargo.toml +++ b/datafusion/proto/Cargo.toml @@ -43,11 +43,11 @@ parquet = ["datafusion/parquet", "datafusion-common/parquet"] [dependencies] arrow = { workspace = true } chrono = { workspace = true } -datafusion = { path = "../core", version = "33.0.0" } +datafusion = { path = "../core", version = "34.0.0" } datafusion-common = { workspace = true } datafusion-expr = { workspace = true } -object_store = { version = "0.7.0" } -pbjson = { version = "0.5", optional = true } +object_store = { workspace = true } +pbjson = { version = "0.6.0", optional = true } prost = "0.12.0" serde = { version = "1.0", optional = true } serde_json = { workspace = true, optional = true } diff --git a/datafusion/proto/gen/Cargo.toml b/datafusion/proto/gen/Cargo.toml index 37c49666d3d73..8b3f3f98a8a1d 100644 --- a/datafusion/proto/gen/Cargo.toml +++ b/datafusion/proto/gen/Cargo.toml @@ -32,4 +32,4 @@ publish = false [dependencies] # Pin these dependencies so that the generated output is deterministic pbjson-build = "=0.6.2" -prost-build = "=0.12.1" +prost-build = "=0.12.3" diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 9b6a0448f810e..d5f8397aa30cf 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -73,6 +73,8 @@ message LogicalPlanNode { CustomTableScanNode custom_scan = 25; PrepareNode prepare = 26; DropViewNode drop_view = 27; + DistinctOnNode distinct_on = 28; + CopyToNode copy_to = 29; } } @@ -215,6 +217,7 @@ message CreateExternalTableNode { bool unbounded = 14; map options = 11; Constraints constraints = 15; + map column_defaults = 16; } message PrepareNode { @@ -308,6 +311,33 @@ message DistinctNode { LogicalPlanNode input = 1; } +message DistinctOnNode { + repeated LogicalExprNode on_expr = 1; + repeated LogicalExprNode select_expr = 2; + repeated LogicalExprNode sort_expr = 3; + LogicalPlanNode input = 4; +} + +message CopyToNode { + LogicalPlanNode input = 1; + string output_url = 2; + bool single_file_output = 3; + oneof CopyOptions { + 
SQLOptions sql_options = 4; + FileTypeWriterOptions writer_options = 5; + } + string file_type = 6; +} + +message SQLOptions { + repeated SQLOption option = 1; +} + +message SQLOption { + string key = 1; + string value = 2; +} + message UnionNode { repeated LogicalPlanNode inputs = 1; } @@ -363,7 +393,7 @@ message LogicalExprNode { SortExprNode sort = 12; NegativeNode negative = 13; InListNode in_list = 14; - bool wildcard = 15; + Wildcard wildcard = 15; ScalarFunctionNode scalar_function = 16; TryCastNode try_cast = 17; @@ -399,6 +429,10 @@ message LogicalExprNode { } } +message Wildcard { + string qualifier = 1; +} + message PlaceholderNode { string id = 1; ArrowType data_type = 2; @@ -481,6 +515,7 @@ message Not { message AliasNode { LogicalExprNode expr = 1; string alias = 2; + repeated OwnedTableReference relation = 3; } message BinaryExprNode { @@ -621,6 +656,17 @@ enum ScalarFunction { ArrayPopBack = 116; StringToArray = 117; ToTimestampNanos = 118; + ArrayIntersect = 119; + ArrayUnion = 120; + OverLay = 121; + Range = 122; + ArrayExcept = 123; + ArrayPopFront = 124; + Levenshtein = 125; + SubstrIndex = 126; + FindInSet = 127; + ArraySort = 128; + ArrayDistinct = 129; } message ScalarFunctionNode { @@ -666,6 +712,7 @@ enum AggregateFunction { REGR_SXX = 32; REGR_SYY = 33; REGR_SXY = 34; + STRING_AGG = 35; } message AggregateExprNode { @@ -815,6 +862,8 @@ message Field { // for complex data types like structs, unions repeated Field children = 4; map metadata = 5; + int64 dict_id = 6; + bool dict_ordered = 7; } message FixedSizeBinary{ @@ -963,7 +1012,9 @@ message ScalarValue{ // Literal Date32 value always has a unit of day int32 date_32_value = 14; ScalarTime32Value time32_value = 15; + ScalarListValue large_list_value = 16; ScalarListValue list_value = 17; + ScalarListValue fixed_size_list_value = 18; Decimal128 decimal128_value = 20; Decimal256 decimal256_value = 39; @@ -1070,8 +1121,10 @@ message PlanType { OptimizedLogicalPlanType OptimizedLogicalPlan = 2; EmptyMessage FinalLogicalPlan = 3; EmptyMessage InitialPhysicalPlan = 4; + EmptyMessage InitialPhysicalPlanWithStats = 9; OptimizedPhysicalPlanType OptimizedPhysicalPlan = 5; EmptyMessage FinalPhysicalPlan = 6; + EmptyMessage FinalPhysicalPlanWithStats = 10; } } @@ -1130,9 +1183,119 @@ message PhysicalPlanNode { SortPreservingMergeExecNode sort_preserving_merge = 21; NestedLoopJoinExecNode nested_loop_join = 22; AnalyzeExecNode analyze = 23; + JsonSinkExecNode json_sink = 24; + SymmetricHashJoinExecNode symmetric_hash_join = 25; + InterleaveExecNode interleave = 26; + PlaceholderRowExecNode placeholder_row = 27; + CsvSinkExecNode csv_sink = 28; + ParquetSinkExecNode parquet_sink = 29; + } +} + +enum CompressionTypeVariant { + GZIP = 0; + BZIP2 = 1; + XZ = 2; + ZSTD = 3; + UNCOMPRESSED = 4; +} + +message PartitionColumn { + string name = 1; + ArrowType arrow_type = 2; +} + +message FileTypeWriterOptions { + oneof FileType { + JsonWriterOptions json_options = 1; + ParquetWriterOptions parquet_options = 2; + CsvWriterOptions csv_options = 3; } } +message JsonWriterOptions { + CompressionTypeVariant compression = 1; +} + +message ParquetWriterOptions { + WriterProperties writer_properties = 1; +} + +message CsvWriterOptions { + // Compression type + CompressionTypeVariant compression = 1; + // Optional column delimiter. Defaults to `b','` + string delimiter = 2; + // Whether to write column names as file headers. 
Defaults to `true` + bool has_header = 3; + // Optional date format for date arrays + string date_format = 4; + // Optional datetime format for datetime arrays + string datetime_format = 5; + // Optional timestamp format for timestamp arrays + string timestamp_format = 6; + // Optional time format for time arrays + string time_format = 7; + // Optional value to represent null + string null_value = 8; +} + +message WriterProperties { + uint64 data_page_size_limit = 1; + uint64 dictionary_page_size_limit = 2; + uint64 data_page_row_count_limit = 3; + uint64 write_batch_size = 4; + uint64 max_row_group_size = 5; + string writer_version = 6; + string created_by = 7; +} + +message FileSinkConfig { + reserved 6; // writer_mode + + string object_store_url = 1; + repeated PartitionedFile file_groups = 2; + repeated string table_paths = 3; + Schema output_schema = 4; + repeated PartitionColumn table_partition_cols = 5; + bool single_file_output = 7; + bool overwrite = 8; + FileTypeWriterOptions file_type_writer_options = 9; +} + +message JsonSink { + FileSinkConfig config = 1; +} + +message JsonSinkExecNode { + PhysicalPlanNode input = 1; + JsonSink sink = 2; + Schema sink_schema = 3; + PhysicalSortExprNodeCollection sort_order = 4; +} + +message CsvSink { + FileSinkConfig config = 1; +} + +message CsvSinkExecNode { + PhysicalPlanNode input = 1; + CsvSink sink = 2; + Schema sink_schema = 3; + PhysicalSortExprNodeCollection sort_order = 4; +} + +message ParquetSink { + FileSinkConfig config = 1; +} + +message ParquetSinkExecNode { + PhysicalPlanNode input = 1; + ParquetSink sink = 2; + Schema sink_schema = 3; + PhysicalSortExprNodeCollection sort_order = 4; +} + message PhysicalExtensionNode { bytes node = 1; repeated PhysicalPlanNode inputs = 2; @@ -1291,6 +1454,7 @@ message PhysicalNegativeNode { message FilterExecNode { PhysicalPlanNode input = 1; PhysicalExprNode expr = 2; + uint32 default_filter_selectivity = 3; } message FileGroup { @@ -1359,6 +1523,25 @@ message HashJoinExecNode { JoinFilter filter = 8; } +enum StreamPartitionMode { + SINGLE_PARTITION = 0; + PARTITIONED_EXEC = 1; +} + +message SymmetricHashJoinExecNode { + PhysicalPlanNode left = 1; + PhysicalPlanNode right = 2; + repeated JoinOn on = 3; + JoinType join_type = 4; + StreamPartitionMode partition_mode = 6; + bool null_equals_null = 7; + JoinFilter filter = 8; +} + +message InterleaveExecNode { + repeated PhysicalPlanNode inputs = 1; +} + message UnionExecNode { repeated PhysicalPlanNode inputs = 1; } @@ -1392,8 +1575,11 @@ message JoinOn { } message EmptyExecNode { - bool produce_one_row = 1; - Schema schema = 2; + Schema schema = 1; +} + +message PlaceholderRowExecNode { + Schema schema = 1; } message ProjectionExecNode { @@ -1410,7 +1596,7 @@ enum AggregateMode { SINGLE_PARTITIONED = 4; } -message PartiallySortedPartitionSearchMode { +message PartiallySortedInputOrderMode { repeated uint64 columns = 6; } @@ -1419,9 +1605,9 @@ message WindowAggExecNode { repeated PhysicalWindowExprNode window_expr = 2; repeated PhysicalExprNode partition_keys = 5; // Set optional to `None` for `BoundedWindowAggExec`. 
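// Sketch, assuming the prost-generated Rust types for the messages above are
// exposed as `datafusion_proto::protobuf::*` like the existing ones: building
// the new `CsvWriterOptions` message and round-tripping it through the prost
// wire format. Field names follow the proto definition; unset fields are left
// at their defaults.
use datafusion_proto::protobuf::{CompressionTypeVariant, CsvWriterOptions};
use prost::Message;

fn main() -> Result<(), prost::DecodeError> {
    let options = CsvWriterOptions {
        compression: CompressionTypeVariant::Zstd as i32,
        delimiter: ";".to_string(),
        has_header: true,
        null_value: "NULL".to_string(),
        ..Default::default()
    };
    // Encode to protobuf bytes and decode them back into an equal message.
    let bytes = options.encode_to_vec();
    let decoded = CsvWriterOptions::decode(bytes.as_slice())?;
    assert_eq!(decoded, options);
    Ok(())
}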
- oneof partition_search_mode { + oneof input_order_mode { EmptyMessage linear = 7; - PartiallySortedPartitionSearchMode partially_sorted = 8; + PartiallySortedInputOrderMode partially_sorted = 8; EmptyMessage sorted = 9; } } @@ -1446,7 +1632,6 @@ message AggregateExecNode { repeated PhysicalExprNode null_expr = 8; repeated bool groups = 9; repeated MaybeFilter filter_expr = 10; - repeated MaybePhysicalSortExprs order_by_expr = 11; } message GlobalLimitExecNode { diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 3eeb060f8d01d..12e834d75adff 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -36,9 +36,6 @@ impl serde::Serialize for AggregateExecNode { if !self.filter_expr.is_empty() { len += 1; } - if !self.order_by_expr.is_empty() { - len += 1; - } let mut struct_ser = serializer.serialize_struct("datafusion.AggregateExecNode", len)?; if !self.group_expr.is_empty() { struct_ser.serialize_field("groupExpr", &self.group_expr)?; @@ -72,9 +69,6 @@ impl serde::Serialize for AggregateExecNode { if !self.filter_expr.is_empty() { struct_ser.serialize_field("filterExpr", &self.filter_expr)?; } - if !self.order_by_expr.is_empty() { - struct_ser.serialize_field("orderByExpr", &self.order_by_expr)?; - } struct_ser.end() } } @@ -102,8 +96,6 @@ impl<'de> serde::Deserialize<'de> for AggregateExecNode { "groups", "filter_expr", "filterExpr", - "order_by_expr", - "orderByExpr", ]; #[allow(clippy::enum_variant_names)] @@ -118,7 +110,6 @@ impl<'de> serde::Deserialize<'de> for AggregateExecNode { NullExpr, Groups, FilterExpr, - OrderByExpr, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -150,7 +141,6 @@ impl<'de> serde::Deserialize<'de> for AggregateExecNode { "nullExpr" | "null_expr" => Ok(GeneratedField::NullExpr), "groups" => Ok(GeneratedField::Groups), "filterExpr" | "filter_expr" => Ok(GeneratedField::FilterExpr), - "orderByExpr" | "order_by_expr" => Ok(GeneratedField::OrderByExpr), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -180,7 +170,6 @@ impl<'de> serde::Deserialize<'de> for AggregateExecNode { let mut null_expr__ = None; let mut groups__ = None; let mut filter_expr__ = None; - let mut order_by_expr__ = None; while let Some(k) = map_.next_key()? 
{ match k { GeneratedField::GroupExpr => { @@ -243,12 +232,6 @@ impl<'de> serde::Deserialize<'de> for AggregateExecNode { } filter_expr__ = Some(map_.next_value()?); } - GeneratedField::OrderByExpr => { - if order_by_expr__.is_some() { - return Err(serde::de::Error::duplicate_field("orderByExpr")); - } - order_by_expr__ = Some(map_.next_value()?); - } } } Ok(AggregateExecNode { @@ -262,7 +245,6 @@ impl<'de> serde::Deserialize<'de> for AggregateExecNode { null_expr: null_expr__.unwrap_or_default(), groups: groups__.unwrap_or_default(), filter_expr: filter_expr__.unwrap_or_default(), - order_by_expr: order_by_expr__.unwrap_or_default(), }) } } @@ -474,6 +456,7 @@ impl serde::Serialize for AggregateFunction { Self::RegrSxx => "REGR_SXX", Self::RegrSyy => "REGR_SYY", Self::RegrSxy => "REGR_SXY", + Self::StringAgg => "STRING_AGG", }; serializer.serialize_str(variant) } @@ -520,6 +503,7 @@ impl<'de> serde::Deserialize<'de> for AggregateFunction { "REGR_SXX", "REGR_SYY", "REGR_SXY", + "STRING_AGG", ]; struct GeneratedVisitor; @@ -595,6 +579,7 @@ impl<'de> serde::Deserialize<'de> for AggregateFunction { "REGR_SXX" => Ok(AggregateFunction::RegrSxx), "REGR_SYY" => Ok(AggregateFunction::RegrSyy), "REGR_SXY" => Ok(AggregateFunction::RegrSxy), + "STRING_AGG" => Ok(AggregateFunction::StringAgg), _ => Err(serde::de::Error::unknown_variant(value, FIELDS)), } } @@ -967,6 +952,9 @@ impl serde::Serialize for AliasNode { if !self.alias.is_empty() { len += 1; } + if !self.relation.is_empty() { + len += 1; + } let mut struct_ser = serializer.serialize_struct("datafusion.AliasNode", len)?; if let Some(v) = self.expr.as_ref() { struct_ser.serialize_field("expr", v)?; @@ -974,6 +962,9 @@ impl serde::Serialize for AliasNode { if !self.alias.is_empty() { struct_ser.serialize_field("alias", &self.alias)?; } + if !self.relation.is_empty() { + struct_ser.serialize_field("relation", &self.relation)?; + } struct_ser.end() } } @@ -986,12 +977,14 @@ impl<'de> serde::Deserialize<'de> for AliasNode { const FIELDS: &[&str] = &[ "expr", "alias", + "relation", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { Expr, Alias, + Relation, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -1015,6 +1008,7 @@ impl<'de> serde::Deserialize<'de> for AliasNode { match value { "expr" => Ok(GeneratedField::Expr), "alias" => Ok(GeneratedField::Alias), + "relation" => Ok(GeneratedField::Relation), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -1036,6 +1030,7 @@ impl<'de> serde::Deserialize<'de> for AliasNode { { let mut expr__ = None; let mut alias__ = None; + let mut relation__ = None; while let Some(k) = map_.next_key()? 
{ match k { GeneratedField::Expr => { @@ -1050,11 +1045,18 @@ impl<'de> serde::Deserialize<'de> for AliasNode { } alias__ = Some(map_.next_value()?); } + GeneratedField::Relation => { + if relation__.is_some() { + return Err(serde::de::Error::duplicate_field("relation")); + } + relation__ = Some(map_.next_value()?); + } } } Ok(AliasNode { expr: expr__, alias: alias__.unwrap_or_default(), + relation: relation__.unwrap_or_default(), }) } } @@ -3421,6 +3423,86 @@ impl<'de> serde::Deserialize<'de> for ColumnStats { deserializer.deserialize_struct("datafusion.ColumnStats", FIELDS, GeneratedVisitor) } } +impl serde::Serialize for CompressionTypeVariant { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + let variant = match self { + Self::Gzip => "GZIP", + Self::Bzip2 => "BZIP2", + Self::Xz => "XZ", + Self::Zstd => "ZSTD", + Self::Uncompressed => "UNCOMPRESSED", + }; + serializer.serialize_str(variant) + } +} +impl<'de> serde::Deserialize<'de> for CompressionTypeVariant { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "GZIP", + "BZIP2", + "XZ", + "ZSTD", + "UNCOMPRESSED", + ]; + + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = CompressionTypeVariant; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + fn visit_i64(self, v: i64) -> std::result::Result + where + E: serde::de::Error, + { + i32::try_from(v) + .ok() + .and_then(|x| x.try_into().ok()) + .ok_or_else(|| { + serde::de::Error::invalid_value(serde::de::Unexpected::Signed(v), &self) + }) + } + + fn visit_u64(self, v: u64) -> std::result::Result + where + E: serde::de::Error, + { + i32::try_from(v) + .ok() + .and_then(|x| x.try_into().ok()) + .ok_or_else(|| { + serde::de::Error::invalid_value(serde::de::Unexpected::Unsigned(v), &self) + }) + } + + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "GZIP" => Ok(CompressionTypeVariant::Gzip), + "BZIP2" => Ok(CompressionTypeVariant::Bzip2), + "XZ" => Ok(CompressionTypeVariant::Xz), + "ZSTD" => Ok(CompressionTypeVariant::Zstd), + "UNCOMPRESSED" => Ok(CompressionTypeVariant::Uncompressed), + _ => Err(serde::de::Error::unknown_variant(value, FIELDS)), + } + } + } + deserializer.deserialize_any(GeneratedVisitor) + } +} impl serde::Serialize for Constraint { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result @@ -3622,6 +3704,188 @@ impl<'de> serde::Deserialize<'de> for Constraints { deserializer.deserialize_struct("datafusion.Constraints", FIELDS, GeneratedVisitor) } } +impl serde::Serialize for CopyToNode { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.input.is_some() { + len += 1; + } + if !self.output_url.is_empty() { + len += 1; + } + if self.single_file_output { + len += 1; + } + if !self.file_type.is_empty() { + len += 1; + } + if self.copy_options.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.CopyToNode", len)?; + if let Some(v) = self.input.as_ref() { + struct_ser.serialize_field("input", v)?; + } + if !self.output_url.is_empty() { + struct_ser.serialize_field("outputUrl", &self.output_url)?; + } + if 
self.single_file_output { + struct_ser.serialize_field("singleFileOutput", &self.single_file_output)?; + } + if !self.file_type.is_empty() { + struct_ser.serialize_field("fileType", &self.file_type)?; + } + if let Some(v) = self.copy_options.as_ref() { + match v { + copy_to_node::CopyOptions::SqlOptions(v) => { + struct_ser.serialize_field("sqlOptions", v)?; + } + copy_to_node::CopyOptions::WriterOptions(v) => { + struct_ser.serialize_field("writerOptions", v)?; + } + } + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for CopyToNode { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "input", + "output_url", + "outputUrl", + "single_file_output", + "singleFileOutput", + "file_type", + "fileType", + "sql_options", + "sqlOptions", + "writer_options", + "writerOptions", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Input, + OutputUrl, + SingleFileOutput, + FileType, + SqlOptions, + WriterOptions, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "input" => Ok(GeneratedField::Input), + "outputUrl" | "output_url" => Ok(GeneratedField::OutputUrl), + "singleFileOutput" | "single_file_output" => Ok(GeneratedField::SingleFileOutput), + "fileType" | "file_type" => Ok(GeneratedField::FileType), + "sqlOptions" | "sql_options" => Ok(GeneratedField::SqlOptions), + "writerOptions" | "writer_options" => Ok(GeneratedField::WriterOptions), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = CopyToNode; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.CopyToNode") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut input__ = None; + let mut output_url__ = None; + let mut single_file_output__ = None; + let mut file_type__ = None; + let mut copy_options__ = None; + while let Some(k) = map_.next_key()? 
{ + match k { + GeneratedField::Input => { + if input__.is_some() { + return Err(serde::de::Error::duplicate_field("input")); + } + input__ = map_.next_value()?; + } + GeneratedField::OutputUrl => { + if output_url__.is_some() { + return Err(serde::de::Error::duplicate_field("outputUrl")); + } + output_url__ = Some(map_.next_value()?); + } + GeneratedField::SingleFileOutput => { + if single_file_output__.is_some() { + return Err(serde::de::Error::duplicate_field("singleFileOutput")); + } + single_file_output__ = Some(map_.next_value()?); + } + GeneratedField::FileType => { + if file_type__.is_some() { + return Err(serde::de::Error::duplicate_field("fileType")); + } + file_type__ = Some(map_.next_value()?); + } + GeneratedField::SqlOptions => { + if copy_options__.is_some() { + return Err(serde::de::Error::duplicate_field("sqlOptions")); + } + copy_options__ = map_.next_value::<::std::option::Option<_>>()?.map(copy_to_node::CopyOptions::SqlOptions) +; + } + GeneratedField::WriterOptions => { + if copy_options__.is_some() { + return Err(serde::de::Error::duplicate_field("writerOptions")); + } + copy_options__ = map_.next_value::<::std::option::Option<_>>()?.map(copy_to_node::CopyOptions::WriterOptions) +; + } + } + } + Ok(CopyToNode { + input: input__, + output_url: output_url__.unwrap_or_default(), + single_file_output: single_file_output__.unwrap_or_default(), + file_type: file_type__.unwrap_or_default(), + copy_options: copy_options__, + }) + } + } + deserializer.deserialize_struct("datafusion.CopyToNode", FIELDS, GeneratedVisitor) + } +} impl serde::Serialize for CreateCatalogNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result @@ -3926,6 +4190,9 @@ impl serde::Serialize for CreateExternalTableNode { if self.constraints.is_some() { len += 1; } + if !self.column_defaults.is_empty() { + len += 1; + } let mut struct_ser = serializer.serialize_struct("datafusion.CreateExternalTableNode", len)?; if let Some(v) = self.name.as_ref() { struct_ser.serialize_field("name", v)?; @@ -3969,6 +4236,9 @@ impl serde::Serialize for CreateExternalTableNode { if let Some(v) = self.constraints.as_ref() { struct_ser.serialize_field("constraints", v)?; } + if !self.column_defaults.is_empty() { + struct_ser.serialize_field("columnDefaults", &self.column_defaults)?; + } struct_ser.end() } } @@ -3999,6 +4269,8 @@ impl<'de> serde::Deserialize<'de> for CreateExternalTableNode { "unbounded", "options", "constraints", + "column_defaults", + "columnDefaults", ]; #[allow(clippy::enum_variant_names)] @@ -4017,6 +4289,7 @@ impl<'de> serde::Deserialize<'de> for CreateExternalTableNode { Unbounded, Options, Constraints, + ColumnDefaults, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -4052,6 +4325,7 @@ impl<'de> serde::Deserialize<'de> for CreateExternalTableNode { "unbounded" => Ok(GeneratedField::Unbounded), "options" => Ok(GeneratedField::Options), "constraints" => Ok(GeneratedField::Constraints), + "columnDefaults" | "column_defaults" => Ok(GeneratedField::ColumnDefaults), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -4085,6 +4359,7 @@ impl<'de> serde::Deserialize<'de> for CreateExternalTableNode { let mut unbounded__ = None; let mut options__ = None; let mut constraints__ = None; + let mut column_defaults__ = None; while let Some(k) = map_.next_key()? 
{ match k { GeneratedField::Name => { @@ -4173,6 +4448,14 @@ impl<'de> serde::Deserialize<'de> for CreateExternalTableNode { } constraints__ = map_.next_value()?; } + GeneratedField::ColumnDefaults => { + if column_defaults__.is_some() { + return Err(serde::de::Error::duplicate_field("columnDefaults")); + } + column_defaults__ = Some( + map_.next_value::>()? + ); + } } } Ok(CreateExternalTableNode { @@ -4190,6 +4473,7 @@ impl<'de> serde::Deserialize<'de> for CreateExternalTableNode { unbounded: unbounded__.unwrap_or_default(), options: options__.unwrap_or_default(), constraints: constraints__, + column_defaults: column_defaults__.unwrap_or_default(), }) } } @@ -4867,7 +5151,7 @@ impl<'de> serde::Deserialize<'de> for CsvScanExecNode { deserializer.deserialize_struct("datafusion.CsvScanExecNode", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for CubeNode { +impl serde::Serialize for CsvSink { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -4875,29 +5159,29 @@ impl serde::Serialize for CubeNode { { use serde::ser::SerializeStruct; let mut len = 0; - if !self.expr.is_empty() { + if self.config.is_some() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.CubeNode", len)?; - if !self.expr.is_empty() { - struct_ser.serialize_field("expr", &self.expr)?; + let mut struct_ser = serializer.serialize_struct("datafusion.CsvSink", len)?; + if let Some(v) = self.config.as_ref() { + struct_ser.serialize_field("config", v)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for CubeNode { +impl<'de> serde::Deserialize<'de> for CsvSink { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "expr", + "config", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Expr, + Config, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -4919,7 +5203,7 @@ impl<'de> serde::Deserialize<'de> for CubeNode { E: serde::de::Error, { match value { - "expr" => Ok(GeneratedField::Expr), + "config" => Ok(GeneratedField::Config), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -4929,36 +5213,36 @@ impl<'de> serde::Deserialize<'de> for CubeNode { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = CubeNode; + type Value = CsvSink; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.CubeNode") + formatter.write_str("struct datafusion.CsvSink") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut expr__ = None; + let mut config__ = None; while let Some(k) = map_.next_key()? 
{ match k { - GeneratedField::Expr => { - if expr__.is_some() { - return Err(serde::de::Error::duplicate_field("expr")); + GeneratedField::Config => { + if config__.is_some() { + return Err(serde::de::Error::duplicate_field("config")); } - expr__ = Some(map_.next_value()?); + config__ = map_.next_value()?; } } } - Ok(CubeNode { - expr: expr__.unwrap_or_default(), + Ok(CsvSink { + config: config__, }) } } - deserializer.deserialize_struct("datafusion.CubeNode", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.CsvSink", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for CustomTableScanNode { +impl serde::Serialize for CsvSinkExecNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -4966,35 +5250,488 @@ impl serde::Serialize for CustomTableScanNode { { use serde::ser::SerializeStruct; let mut len = 0; - if self.table_name.is_some() { - len += 1; - } - if self.projection.is_some() { + if self.input.is_some() { len += 1; } - if self.schema.is_some() { + if self.sink.is_some() { len += 1; } - if !self.filters.is_empty() { + if self.sink_schema.is_some() { len += 1; } - if !self.custom_table_data.is_empty() { + if self.sort_order.is_some() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.CustomTableScanNode", len)?; - if let Some(v) = self.table_name.as_ref() { - struct_ser.serialize_field("tableName", v)?; + let mut struct_ser = serializer.serialize_struct("datafusion.CsvSinkExecNode", len)?; + if let Some(v) = self.input.as_ref() { + struct_ser.serialize_field("input", v)?; } - if let Some(v) = self.projection.as_ref() { - struct_ser.serialize_field("projection", v)?; + if let Some(v) = self.sink.as_ref() { + struct_ser.serialize_field("sink", v)?; } - if let Some(v) = self.schema.as_ref() { - struct_ser.serialize_field("schema", v)?; + if let Some(v) = self.sink_schema.as_ref() { + struct_ser.serialize_field("sinkSchema", v)?; } - if !self.filters.is_empty() { - struct_ser.serialize_field("filters", &self.filters)?; + if let Some(v) = self.sort_order.as_ref() { + struct_ser.serialize_field("sortOrder", v)?; } - if !self.custom_table_data.is_empty() { + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for CsvSinkExecNode { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "input", + "sink", + "sink_schema", + "sinkSchema", + "sort_order", + "sortOrder", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Input, + Sink, + SinkSchema, + SortOrder, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "input" => Ok(GeneratedField::Input), + "sink" => Ok(GeneratedField::Sink), + "sinkSchema" | "sink_schema" => Ok(GeneratedField::SinkSchema), + "sortOrder" | "sort_order" => Ok(GeneratedField::SortOrder), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + 
impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = CsvSinkExecNode; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.CsvSinkExecNode") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut input__ = None; + let mut sink__ = None; + let mut sink_schema__ = None; + let mut sort_order__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::Input => { + if input__.is_some() { + return Err(serde::de::Error::duplicate_field("input")); + } + input__ = map_.next_value()?; + } + GeneratedField::Sink => { + if sink__.is_some() { + return Err(serde::de::Error::duplicate_field("sink")); + } + sink__ = map_.next_value()?; + } + GeneratedField::SinkSchema => { + if sink_schema__.is_some() { + return Err(serde::de::Error::duplicate_field("sinkSchema")); + } + sink_schema__ = map_.next_value()?; + } + GeneratedField::SortOrder => { + if sort_order__.is_some() { + return Err(serde::de::Error::duplicate_field("sortOrder")); + } + sort_order__ = map_.next_value()?; + } + } + } + Ok(CsvSinkExecNode { + input: input__, + sink: sink__, + sink_schema: sink_schema__, + sort_order: sort_order__, + }) + } + } + deserializer.deserialize_struct("datafusion.CsvSinkExecNode", FIELDS, GeneratedVisitor) + } +} +impl serde::Serialize for CsvWriterOptions { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.compression != 0 { + len += 1; + } + if !self.delimiter.is_empty() { + len += 1; + } + if self.has_header { + len += 1; + } + if !self.date_format.is_empty() { + len += 1; + } + if !self.datetime_format.is_empty() { + len += 1; + } + if !self.timestamp_format.is_empty() { + len += 1; + } + if !self.time_format.is_empty() { + len += 1; + } + if !self.null_value.is_empty() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.CsvWriterOptions", len)?; + if self.compression != 0 { + let v = CompressionTypeVariant::try_from(self.compression) + .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", self.compression)))?; + struct_ser.serialize_field("compression", &v)?; + } + if !self.delimiter.is_empty() { + struct_ser.serialize_field("delimiter", &self.delimiter)?; + } + if self.has_header { + struct_ser.serialize_field("hasHeader", &self.has_header)?; + } + if !self.date_format.is_empty() { + struct_ser.serialize_field("dateFormat", &self.date_format)?; + } + if !self.datetime_format.is_empty() { + struct_ser.serialize_field("datetimeFormat", &self.datetime_format)?; + } + if !self.timestamp_format.is_empty() { + struct_ser.serialize_field("timestampFormat", &self.timestamp_format)?; + } + if !self.time_format.is_empty() { + struct_ser.serialize_field("timeFormat", &self.time_format)?; + } + if !self.null_value.is_empty() { + struct_ser.serialize_field("nullValue", &self.null_value)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for CsvWriterOptions { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "compression", + "delimiter", + "has_header", + "hasHeader", + "date_format", + "dateFormat", + "datetime_format", + "datetimeFormat", + "timestamp_format", + "timestampFormat", + "time_format", + "timeFormat", + "null_value", + "nullValue", + ]; + + 
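// Sketch, assuming datafusion-proto is built with its optional pbjson/serde
// support so the generated impls above are available: the enum serializer
// writes the screaming-case variant name, so a serde_json round trip looks
// like this.
use datafusion_proto::protobuf::CompressionTypeVariant;

fn main() -> Result<(), serde_json::Error> {
    let json = serde_json::to_string(&CompressionTypeVariant::Zstd)?;
    assert_eq!(json, r#""ZSTD""#);
    let back: CompressionTypeVariant = serde_json::from_str(&json)?;
    assert_eq!(back, CompressionTypeVariant::Zstd);
    Ok(())
}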
#[allow(clippy::enum_variant_names)] + enum GeneratedField { + Compression, + Delimiter, + HasHeader, + DateFormat, + DatetimeFormat, + TimestampFormat, + TimeFormat, + NullValue, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "compression" => Ok(GeneratedField::Compression), + "delimiter" => Ok(GeneratedField::Delimiter), + "hasHeader" | "has_header" => Ok(GeneratedField::HasHeader), + "dateFormat" | "date_format" => Ok(GeneratedField::DateFormat), + "datetimeFormat" | "datetime_format" => Ok(GeneratedField::DatetimeFormat), + "timestampFormat" | "timestamp_format" => Ok(GeneratedField::TimestampFormat), + "timeFormat" | "time_format" => Ok(GeneratedField::TimeFormat), + "nullValue" | "null_value" => Ok(GeneratedField::NullValue), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = CsvWriterOptions; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.CsvWriterOptions") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut compression__ = None; + let mut delimiter__ = None; + let mut has_header__ = None; + let mut date_format__ = None; + let mut datetime_format__ = None; + let mut timestamp_format__ = None; + let mut time_format__ = None; + let mut null_value__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::Compression => { + if compression__.is_some() { + return Err(serde::de::Error::duplicate_field("compression")); + } + compression__ = Some(map_.next_value::()? 
as i32); + } + GeneratedField::Delimiter => { + if delimiter__.is_some() { + return Err(serde::de::Error::duplicate_field("delimiter")); + } + delimiter__ = Some(map_.next_value()?); + } + GeneratedField::HasHeader => { + if has_header__.is_some() { + return Err(serde::de::Error::duplicate_field("hasHeader")); + } + has_header__ = Some(map_.next_value()?); + } + GeneratedField::DateFormat => { + if date_format__.is_some() { + return Err(serde::de::Error::duplicate_field("dateFormat")); + } + date_format__ = Some(map_.next_value()?); + } + GeneratedField::DatetimeFormat => { + if datetime_format__.is_some() { + return Err(serde::de::Error::duplicate_field("datetimeFormat")); + } + datetime_format__ = Some(map_.next_value()?); + } + GeneratedField::TimestampFormat => { + if timestamp_format__.is_some() { + return Err(serde::de::Error::duplicate_field("timestampFormat")); + } + timestamp_format__ = Some(map_.next_value()?); + } + GeneratedField::TimeFormat => { + if time_format__.is_some() { + return Err(serde::de::Error::duplicate_field("timeFormat")); + } + time_format__ = Some(map_.next_value()?); + } + GeneratedField::NullValue => { + if null_value__.is_some() { + return Err(serde::de::Error::duplicate_field("nullValue")); + } + null_value__ = Some(map_.next_value()?); + } + } + } + Ok(CsvWriterOptions { + compression: compression__.unwrap_or_default(), + delimiter: delimiter__.unwrap_or_default(), + has_header: has_header__.unwrap_or_default(), + date_format: date_format__.unwrap_or_default(), + datetime_format: datetime_format__.unwrap_or_default(), + timestamp_format: timestamp_format__.unwrap_or_default(), + time_format: time_format__.unwrap_or_default(), + null_value: null_value__.unwrap_or_default(), + }) + } + } + deserializer.deserialize_struct("datafusion.CsvWriterOptions", FIELDS, GeneratedVisitor) + } +} +impl serde::Serialize for CubeNode { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if !self.expr.is_empty() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.CubeNode", len)?; + if !self.expr.is_empty() { + struct_ser.serialize_field("expr", &self.expr)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for CubeNode { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "expr", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Expr, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "expr" => Ok(GeneratedField::Expr), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = CubeNode; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct 
datafusion.CubeNode") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut expr__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::Expr => { + if expr__.is_some() { + return Err(serde::de::Error::duplicate_field("expr")); + } + expr__ = Some(map_.next_value()?); + } + } + } + Ok(CubeNode { + expr: expr__.unwrap_or_default(), + }) + } + } + deserializer.deserialize_struct("datafusion.CubeNode", FIELDS, GeneratedVisitor) + } +} +impl serde::Serialize for CustomTableScanNode { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.table_name.is_some() { + len += 1; + } + if self.projection.is_some() { + len += 1; + } + if self.schema.is_some() { + len += 1; + } + if !self.filters.is_empty() { + len += 1; + } + if !self.custom_table_data.is_empty() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.CustomTableScanNode", len)?; + if let Some(v) = self.table_name.as_ref() { + struct_ser.serialize_field("tableName", v)?; + } + if let Some(v) = self.projection.as_ref() { + struct_ser.serialize_field("projection", v)?; + } + if let Some(v) = self.schema.as_ref() { + struct_ser.serialize_field("schema", v)?; + } + if !self.filters.is_empty() { + struct_ser.serialize_field("filters", &self.filters)?; + } + if !self.custom_table_data.is_empty() { #[allow(clippy::needless_borrow)] struct_ser.serialize_field("customTableData", pbjson::private::base64::encode(&self.custom_table_data).as_str())?; } @@ -5917,18 +6654,136 @@ impl serde::Serialize for DistinctNode { struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for DistinctNode { +impl<'de> serde::Deserialize<'de> for DistinctNode { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "input", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Input, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "input" => Ok(GeneratedField::Input), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = DistinctNode; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.DistinctNode") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut input__ = None; + while let Some(k) = map_.next_key()? 
{ + match k { + GeneratedField::Input => { + if input__.is_some() { + return Err(serde::de::Error::duplicate_field("input")); + } + input__ = map_.next_value()?; + } + } + } + Ok(DistinctNode { + input: input__, + }) + } + } + deserializer.deserialize_struct("datafusion.DistinctNode", FIELDS, GeneratedVisitor) + } +} +impl serde::Serialize for DistinctOnNode { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if !self.on_expr.is_empty() { + len += 1; + } + if !self.select_expr.is_empty() { + len += 1; + } + if !self.sort_expr.is_empty() { + len += 1; + } + if self.input.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.DistinctOnNode", len)?; + if !self.on_expr.is_empty() { + struct_ser.serialize_field("onExpr", &self.on_expr)?; + } + if !self.select_expr.is_empty() { + struct_ser.serialize_field("selectExpr", &self.select_expr)?; + } + if !self.sort_expr.is_empty() { + struct_ser.serialize_field("sortExpr", &self.sort_expr)?; + } + if let Some(v) = self.input.as_ref() { + struct_ser.serialize_field("input", v)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for DistinctOnNode { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ + "on_expr", + "onExpr", + "select_expr", + "selectExpr", + "sort_expr", + "sortExpr", "input", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { + OnExpr, + SelectExpr, + SortExpr, Input, } impl<'de> serde::Deserialize<'de> for GeneratedField { @@ -5951,6 +6806,9 @@ impl<'de> serde::Deserialize<'de> for DistinctNode { E: serde::de::Error, { match value { + "onExpr" | "on_expr" => Ok(GeneratedField::OnExpr), + "selectExpr" | "select_expr" => Ok(GeneratedField::SelectExpr), + "sortExpr" | "sort_expr" => Ok(GeneratedField::SortExpr), "input" => Ok(GeneratedField::Input), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } @@ -5961,19 +6819,40 @@ impl<'de> serde::Deserialize<'de> for DistinctNode { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = DistinctNode; + type Value = DistinctOnNode; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.DistinctNode") + formatter.write_str("struct datafusion.DistinctOnNode") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { + let mut on_expr__ = None; + let mut select_expr__ = None; + let mut sort_expr__ = None; let mut input__ = None; while let Some(k) = map_.next_key()? 
{ match k { + GeneratedField::OnExpr => { + if on_expr__.is_some() { + return Err(serde::de::Error::duplicate_field("onExpr")); + } + on_expr__ = Some(map_.next_value()?); + } + GeneratedField::SelectExpr => { + if select_expr__.is_some() { + return Err(serde::de::Error::duplicate_field("selectExpr")); + } + select_expr__ = Some(map_.next_value()?); + } + GeneratedField::SortExpr => { + if sort_expr__.is_some() { + return Err(serde::de::Error::duplicate_field("sortExpr")); + } + sort_expr__ = Some(map_.next_value()?); + } GeneratedField::Input => { if input__.is_some() { return Err(serde::de::Error::duplicate_field("input")); @@ -5982,12 +6861,15 @@ impl<'de> serde::Deserialize<'de> for DistinctNode { } } } - Ok(DistinctNode { + Ok(DistinctOnNode { + on_expr: on_expr__.unwrap_or_default(), + select_expr: select_expr__.unwrap_or_default(), + sort_expr: sort_expr__.unwrap_or_default(), input: input__, }) } } - deserializer.deserialize_struct("datafusion.DistinctNode", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.DistinctOnNode", FIELDS, GeneratedVisitor) } } impl serde::Serialize for DropViewNode { @@ -6124,16 +7006,10 @@ impl serde::Serialize for EmptyExecNode { { use serde::ser::SerializeStruct; let mut len = 0; - if self.produce_one_row { - len += 1; - } if self.schema.is_some() { len += 1; } let mut struct_ser = serializer.serialize_struct("datafusion.EmptyExecNode", len)?; - if self.produce_one_row { - struct_ser.serialize_field("produceOneRow", &self.produce_one_row)?; - } if let Some(v) = self.schema.as_ref() { struct_ser.serialize_field("schema", v)?; } @@ -6147,14 +7023,11 @@ impl<'de> serde::Deserialize<'de> for EmptyExecNode { D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "produce_one_row", - "produceOneRow", "schema", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - ProduceOneRow, Schema, } impl<'de> serde::Deserialize<'de> for GeneratedField { @@ -6177,7 +7050,6 @@ impl<'de> serde::Deserialize<'de> for EmptyExecNode { E: serde::de::Error, { match value { - "produceOneRow" | "produce_one_row" => Ok(GeneratedField::ProduceOneRow), "schema" => Ok(GeneratedField::Schema), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } @@ -6198,16 +7070,9 @@ impl<'de> serde::Deserialize<'de> for EmptyExecNode { where V: serde::de::MapAccess<'de>, { - let mut produce_one_row__ = None; let mut schema__ = None; while let Some(k) = map_.next_key()? 
{ match k { - GeneratedField::ProduceOneRow => { - if produce_one_row__.is_some() { - return Err(serde::de::Error::duplicate_field("produceOneRow")); - } - produce_one_row__ = Some(map_.next_value()?); - } GeneratedField::Schema => { if schema__.is_some() { return Err(serde::de::Error::duplicate_field("schema")); @@ -6217,7 +7082,6 @@ impl<'de> serde::Deserialize<'de> for EmptyExecNode { } } Ok(EmptyExecNode { - produce_one_row: produce_one_row__.unwrap_or_default(), schema: schema__, }) } @@ -6645,6 +7509,12 @@ impl serde::Serialize for Field { if !self.metadata.is_empty() { len += 1; } + if self.dict_id != 0 { + len += 1; + } + if self.dict_ordered { + len += 1; + } let mut struct_ser = serializer.serialize_struct("datafusion.Field", len)?; if !self.name.is_empty() { struct_ser.serialize_field("name", &self.name)?; @@ -6661,6 +7531,13 @@ impl serde::Serialize for Field { if !self.metadata.is_empty() { struct_ser.serialize_field("metadata", &self.metadata)?; } + if self.dict_id != 0 { + #[allow(clippy::needless_borrow)] + struct_ser.serialize_field("dictId", ToString::to_string(&self.dict_id).as_str())?; + } + if self.dict_ordered { + struct_ser.serialize_field("dictOrdered", &self.dict_ordered)?; + } struct_ser.end() } } @@ -6677,6 +7554,10 @@ impl<'de> serde::Deserialize<'de> for Field { "nullable", "children", "metadata", + "dict_id", + "dictId", + "dict_ordered", + "dictOrdered", ]; #[allow(clippy::enum_variant_names)] @@ -6686,6 +7567,8 @@ impl<'de> serde::Deserialize<'de> for Field { Nullable, Children, Metadata, + DictId, + DictOrdered, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -6712,6 +7595,8 @@ impl<'de> serde::Deserialize<'de> for Field { "nullable" => Ok(GeneratedField::Nullable), "children" => Ok(GeneratedField::Children), "metadata" => Ok(GeneratedField::Metadata), + "dictId" | "dict_id" => Ok(GeneratedField::DictId), + "dictOrdered" | "dict_ordered" => Ok(GeneratedField::DictOrdered), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -6736,6 +7621,8 @@ impl<'de> serde::Deserialize<'de> for Field { let mut nullable__ = None; let mut children__ = None; let mut metadata__ = None; + let mut dict_id__ = None; + let mut dict_ordered__ = None; while let Some(k) = map_.next_key()? { match k { GeneratedField::Name => { @@ -6770,6 +7657,20 @@ impl<'de> serde::Deserialize<'de> for Field { map_.next_value::>()? 
); } + GeneratedField::DictId => { + if dict_id__.is_some() { + return Err(serde::de::Error::duplicate_field("dictId")); + } + dict_id__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; + } + GeneratedField::DictOrdered => { + if dict_ordered__.is_some() { + return Err(serde::de::Error::duplicate_field("dictOrdered")); + } + dict_ordered__ = Some(map_.next_value()?); + } } } Ok(Field { @@ -6778,6 +7679,8 @@ impl<'de> serde::Deserialize<'de> for Field { nullable: nullable__.unwrap_or_default(), children: children__.unwrap_or_default(), metadata: metadata__.unwrap_or_default(), + dict_id: dict_id__.unwrap_or_default(), + dict_ordered: dict_ordered__.unwrap_or_default(), }) } } @@ -7040,46 +7943,266 @@ impl serde::Serialize for FileScanExecConf { if !self.table_partition_cols.is_empty() { struct_ser.serialize_field("tablePartitionCols", &self.table_partition_cols)?; } - if !self.object_store_url.is_empty() { - struct_ser.serialize_field("objectStoreUrl", &self.object_store_url)?; + if !self.object_store_url.is_empty() { + struct_ser.serialize_field("objectStoreUrl", &self.object_store_url)?; + } + if !self.output_ordering.is_empty() { + struct_ser.serialize_field("outputOrdering", &self.output_ordering)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for FileScanExecConf { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "file_groups", + "fileGroups", + "schema", + "projection", + "limit", + "statistics", + "table_partition_cols", + "tablePartitionCols", + "object_store_url", + "objectStoreUrl", + "output_ordering", + "outputOrdering", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + FileGroups, + Schema, + Projection, + Limit, + Statistics, + TablePartitionCols, + ObjectStoreUrl, + OutputOrdering, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "fileGroups" | "file_groups" => Ok(GeneratedField::FileGroups), + "schema" => Ok(GeneratedField::Schema), + "projection" => Ok(GeneratedField::Projection), + "limit" => Ok(GeneratedField::Limit), + "statistics" => Ok(GeneratedField::Statistics), + "tablePartitionCols" | "table_partition_cols" => Ok(GeneratedField::TablePartitionCols), + "objectStoreUrl" | "object_store_url" => Ok(GeneratedField::ObjectStoreUrl), + "outputOrdering" | "output_ordering" => Ok(GeneratedField::OutputOrdering), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = FileScanExecConf; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.FileScanExecConf") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut file_groups__ = None; + let mut schema__ = None; + let mut projection__ = 
None; + let mut limit__ = None; + let mut statistics__ = None; + let mut table_partition_cols__ = None; + let mut object_store_url__ = None; + let mut output_ordering__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::FileGroups => { + if file_groups__.is_some() { + return Err(serde::de::Error::duplicate_field("fileGroups")); + } + file_groups__ = Some(map_.next_value()?); + } + GeneratedField::Schema => { + if schema__.is_some() { + return Err(serde::de::Error::duplicate_field("schema")); + } + schema__ = map_.next_value()?; + } + GeneratedField::Projection => { + if projection__.is_some() { + return Err(serde::de::Error::duplicate_field("projection")); + } + projection__ = + Some(map_.next_value::>>()? + .into_iter().map(|x| x.0).collect()) + ; + } + GeneratedField::Limit => { + if limit__.is_some() { + return Err(serde::de::Error::duplicate_field("limit")); + } + limit__ = map_.next_value()?; + } + GeneratedField::Statistics => { + if statistics__.is_some() { + return Err(serde::de::Error::duplicate_field("statistics")); + } + statistics__ = map_.next_value()?; + } + GeneratedField::TablePartitionCols => { + if table_partition_cols__.is_some() { + return Err(serde::de::Error::duplicate_field("tablePartitionCols")); + } + table_partition_cols__ = Some(map_.next_value()?); + } + GeneratedField::ObjectStoreUrl => { + if object_store_url__.is_some() { + return Err(serde::de::Error::duplicate_field("objectStoreUrl")); + } + object_store_url__ = Some(map_.next_value()?); + } + GeneratedField::OutputOrdering => { + if output_ordering__.is_some() { + return Err(serde::de::Error::duplicate_field("outputOrdering")); + } + output_ordering__ = Some(map_.next_value()?); + } + } + } + Ok(FileScanExecConf { + file_groups: file_groups__.unwrap_or_default(), + schema: schema__, + projection: projection__.unwrap_or_default(), + limit: limit__, + statistics: statistics__, + table_partition_cols: table_partition_cols__.unwrap_or_default(), + object_store_url: object_store_url__.unwrap_or_default(), + output_ordering: output_ordering__.unwrap_or_default(), + }) + } + } + deserializer.deserialize_struct("datafusion.FileScanExecConf", FIELDS, GeneratedVisitor) + } +} +impl serde::Serialize for FileSinkConfig { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if !self.object_store_url.is_empty() { + len += 1; + } + if !self.file_groups.is_empty() { + len += 1; + } + if !self.table_paths.is_empty() { + len += 1; + } + if self.output_schema.is_some() { + len += 1; + } + if !self.table_partition_cols.is_empty() { + len += 1; + } + if self.single_file_output { + len += 1; + } + if self.overwrite { + len += 1; + } + if self.file_type_writer_options.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.FileSinkConfig", len)?; + if !self.object_store_url.is_empty() { + struct_ser.serialize_field("objectStoreUrl", &self.object_store_url)?; + } + if !self.file_groups.is_empty() { + struct_ser.serialize_field("fileGroups", &self.file_groups)?; + } + if !self.table_paths.is_empty() { + struct_ser.serialize_field("tablePaths", &self.table_paths)?; + } + if let Some(v) = self.output_schema.as_ref() { + struct_ser.serialize_field("outputSchema", v)?; + } + if !self.table_partition_cols.is_empty() { + struct_ser.serialize_field("tablePartitionCols", &self.table_partition_cols)?; + } + if self.single_file_output { + 
struct_ser.serialize_field("singleFileOutput", &self.single_file_output)?; } - if !self.output_ordering.is_empty() { - struct_ser.serialize_field("outputOrdering", &self.output_ordering)?; + if self.overwrite { + struct_ser.serialize_field("overwrite", &self.overwrite)?; + } + if let Some(v) = self.file_type_writer_options.as_ref() { + struct_ser.serialize_field("fileTypeWriterOptions", v)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for FileScanExecConf { +impl<'de> serde::Deserialize<'de> for FileSinkConfig { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ + "object_store_url", + "objectStoreUrl", "file_groups", "fileGroups", - "schema", - "projection", - "limit", - "statistics", + "table_paths", + "tablePaths", + "output_schema", + "outputSchema", "table_partition_cols", "tablePartitionCols", - "object_store_url", - "objectStoreUrl", - "output_ordering", - "outputOrdering", + "single_file_output", + "singleFileOutput", + "overwrite", + "file_type_writer_options", + "fileTypeWriterOptions", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { + ObjectStoreUrl, FileGroups, - Schema, - Projection, - Limit, - Statistics, + TablePaths, + OutputSchema, TablePartitionCols, - ObjectStoreUrl, - OutputOrdering, + SingleFileOutput, + Overwrite, + FileTypeWriterOptions, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -7101,14 +8224,14 @@ impl<'de> serde::Deserialize<'de> for FileScanExecConf { E: serde::de::Error, { match value { + "objectStoreUrl" | "object_store_url" => Ok(GeneratedField::ObjectStoreUrl), "fileGroups" | "file_groups" => Ok(GeneratedField::FileGroups), - "schema" => Ok(GeneratedField::Schema), - "projection" => Ok(GeneratedField::Projection), - "limit" => Ok(GeneratedField::Limit), - "statistics" => Ok(GeneratedField::Statistics), + "tablePaths" | "table_paths" => Ok(GeneratedField::TablePaths), + "outputSchema" | "output_schema" => Ok(GeneratedField::OutputSchema), "tablePartitionCols" | "table_partition_cols" => Ok(GeneratedField::TablePartitionCols), - "objectStoreUrl" | "object_store_url" => Ok(GeneratedField::ObjectStoreUrl), - "outputOrdering" | "output_ordering" => Ok(GeneratedField::OutputOrdering), + "singleFileOutput" | "single_file_output" => Ok(GeneratedField::SingleFileOutput), + "overwrite" => Ok(GeneratedField::Overwrite), + "fileTypeWriterOptions" | "file_type_writer_options" => Ok(GeneratedField::FileTypeWriterOptions), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -7118,58 +8241,49 @@ impl<'de> serde::Deserialize<'de> for FileScanExecConf { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = FileScanExecConf; + type Value = FileSinkConfig; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.FileScanExecConf") + formatter.write_str("struct datafusion.FileSinkConfig") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { + let mut object_store_url__ = None; let mut file_groups__ = None; - let mut schema__ = None; - let mut projection__ = None; - let mut limit__ = None; - let mut statistics__ = None; + let mut table_paths__ = None; + let mut output_schema__ = None; let mut table_partition_cols__ = None; - let mut object_store_url__ = None; - let mut 
output_ordering__ = None; + let mut single_file_output__ = None; + let mut overwrite__ = None; + let mut file_type_writer_options__ = None; while let Some(k) = map_.next_key()? { match k { + GeneratedField::ObjectStoreUrl => { + if object_store_url__.is_some() { + return Err(serde::de::Error::duplicate_field("objectStoreUrl")); + } + object_store_url__ = Some(map_.next_value()?); + } GeneratedField::FileGroups => { if file_groups__.is_some() { return Err(serde::de::Error::duplicate_field("fileGroups")); } file_groups__ = Some(map_.next_value()?); } - GeneratedField::Schema => { - if schema__.is_some() { - return Err(serde::de::Error::duplicate_field("schema")); - } - schema__ = map_.next_value()?; - } - GeneratedField::Projection => { - if projection__.is_some() { - return Err(serde::de::Error::duplicate_field("projection")); - } - projection__ = - Some(map_.next_value::>>()? - .into_iter().map(|x| x.0).collect()) - ; - } - GeneratedField::Limit => { - if limit__.is_some() { - return Err(serde::de::Error::duplicate_field("limit")); + GeneratedField::TablePaths => { + if table_paths__.is_some() { + return Err(serde::de::Error::duplicate_field("tablePaths")); } - limit__ = map_.next_value()?; + table_paths__ = Some(map_.next_value()?); } - GeneratedField::Statistics => { - if statistics__.is_some() { - return Err(serde::de::Error::duplicate_field("statistics")); + GeneratedField::OutputSchema => { + if output_schema__.is_some() { + return Err(serde::de::Error::duplicate_field("outputSchema")); } - statistics__ = map_.next_value()?; + output_schema__ = map_.next_value()?; } GeneratedField::TablePartitionCols => { if table_partition_cols__.is_some() { @@ -7177,33 +8291,164 @@ impl<'de> serde::Deserialize<'de> for FileScanExecConf { } table_partition_cols__ = Some(map_.next_value()?); } - GeneratedField::ObjectStoreUrl => { - if object_store_url__.is_some() { - return Err(serde::de::Error::duplicate_field("objectStoreUrl")); + GeneratedField::SingleFileOutput => { + if single_file_output__.is_some() { + return Err(serde::de::Error::duplicate_field("singleFileOutput")); } - object_store_url__ = Some(map_.next_value()?); + single_file_output__ = Some(map_.next_value()?); } - GeneratedField::OutputOrdering => { - if output_ordering__.is_some() { - return Err(serde::de::Error::duplicate_field("outputOrdering")); + GeneratedField::Overwrite => { + if overwrite__.is_some() { + return Err(serde::de::Error::duplicate_field("overwrite")); } - output_ordering__ = Some(map_.next_value()?); + overwrite__ = Some(map_.next_value()?); + } + GeneratedField::FileTypeWriterOptions => { + if file_type_writer_options__.is_some() { + return Err(serde::de::Error::duplicate_field("fileTypeWriterOptions")); + } + file_type_writer_options__ = map_.next_value()?; } } } - Ok(FileScanExecConf { + Ok(FileSinkConfig { + object_store_url: object_store_url__.unwrap_or_default(), file_groups: file_groups__.unwrap_or_default(), - schema: schema__, - projection: projection__.unwrap_or_default(), - limit: limit__, - statistics: statistics__, + table_paths: table_paths__.unwrap_or_default(), + output_schema: output_schema__, table_partition_cols: table_partition_cols__.unwrap_or_default(), - object_store_url: object_store_url__.unwrap_or_default(), - output_ordering: output_ordering__.unwrap_or_default(), + single_file_output: single_file_output__.unwrap_or_default(), + overwrite: overwrite__.unwrap_or_default(), + file_type_writer_options: file_type_writer_options__, }) } } - 
deserializer.deserialize_struct("datafusion.FileScanExecConf", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.FileSinkConfig", FIELDS, GeneratedVisitor) + } +} +impl serde::Serialize for FileTypeWriterOptions { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.file_type.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.FileTypeWriterOptions", len)?; + if let Some(v) = self.file_type.as_ref() { + match v { + file_type_writer_options::FileType::JsonOptions(v) => { + struct_ser.serialize_field("jsonOptions", v)?; + } + file_type_writer_options::FileType::ParquetOptions(v) => { + struct_ser.serialize_field("parquetOptions", v)?; + } + file_type_writer_options::FileType::CsvOptions(v) => { + struct_ser.serialize_field("csvOptions", v)?; + } + } + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for FileTypeWriterOptions { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "json_options", + "jsonOptions", + "parquet_options", + "parquetOptions", + "csv_options", + "csvOptions", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + JsonOptions, + ParquetOptions, + CsvOptions, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "jsonOptions" | "json_options" => Ok(GeneratedField::JsonOptions), + "parquetOptions" | "parquet_options" => Ok(GeneratedField::ParquetOptions), + "csvOptions" | "csv_options" => Ok(GeneratedField::CsvOptions), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = FileTypeWriterOptions; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.FileTypeWriterOptions") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut file_type__ = None; + while let Some(k) = map_.next_key()? 
{ + match k { + GeneratedField::JsonOptions => { + if file_type__.is_some() { + return Err(serde::de::Error::duplicate_field("jsonOptions")); + } + file_type__ = map_.next_value::<::std::option::Option<_>>()?.map(file_type_writer_options::FileType::JsonOptions) +; + } + GeneratedField::ParquetOptions => { + if file_type__.is_some() { + return Err(serde::de::Error::duplicate_field("parquetOptions")); + } + file_type__ = map_.next_value::<::std::option::Option<_>>()?.map(file_type_writer_options::FileType::ParquetOptions) +; + } + GeneratedField::CsvOptions => { + if file_type__.is_some() { + return Err(serde::de::Error::duplicate_field("csvOptions")); + } + file_type__ = map_.next_value::<::std::option::Option<_>>()?.map(file_type_writer_options::FileType::CsvOptions) +; + } + } + } + Ok(FileTypeWriterOptions { + file_type: file_type__, + }) + } + } + deserializer.deserialize_struct("datafusion.FileTypeWriterOptions", FIELDS, GeneratedVisitor) } } impl serde::Serialize for FilterExecNode { @@ -7220,6 +8465,9 @@ impl serde::Serialize for FilterExecNode { if self.expr.is_some() { len += 1; } + if self.default_filter_selectivity != 0 { + len += 1; + } let mut struct_ser = serializer.serialize_struct("datafusion.FilterExecNode", len)?; if let Some(v) = self.input.as_ref() { struct_ser.serialize_field("input", v)?; @@ -7227,6 +8475,9 @@ impl serde::Serialize for FilterExecNode { if let Some(v) = self.expr.as_ref() { struct_ser.serialize_field("expr", v)?; } + if self.default_filter_selectivity != 0 { + struct_ser.serialize_field("defaultFilterSelectivity", &self.default_filter_selectivity)?; + } struct_ser.end() } } @@ -7239,12 +8490,15 @@ impl<'de> serde::Deserialize<'de> for FilterExecNode { const FIELDS: &[&str] = &[ "input", "expr", + "default_filter_selectivity", + "defaultFilterSelectivity", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { Input, Expr, + DefaultFilterSelectivity, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -7268,6 +8522,7 @@ impl<'de> serde::Deserialize<'de> for FilterExecNode { match value { "input" => Ok(GeneratedField::Input), "expr" => Ok(GeneratedField::Expr), + "defaultFilterSelectivity" | "default_filter_selectivity" => Ok(GeneratedField::DefaultFilterSelectivity), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -7289,6 +8544,7 @@ impl<'de> serde::Deserialize<'de> for FilterExecNode { { let mut input__ = None; let mut expr__ = None; + let mut default_filter_selectivity__ = None; while let Some(k) = map_.next_key()? 
{ match k { GeneratedField::Input => { @@ -7303,11 +8559,20 @@ impl<'de> serde::Deserialize<'de> for FilterExecNode { } expr__ = map_.next_value()?; } + GeneratedField::DefaultFilterSelectivity => { + if default_filter_selectivity__.is_some() { + return Err(serde::de::Error::duplicate_field("defaultFilterSelectivity")); + } + default_filter_selectivity__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; + } } } Ok(FilterExecNode { input: input__, expr: expr__, + default_filter_selectivity: default_filter_selectivity__.unwrap_or_default(), }) } } @@ -8574,18 +9839,109 @@ impl<'de> serde::Deserialize<'de> for InListNode { if negated__.is_some() { return Err(serde::de::Error::duplicate_field("negated")); } - negated__ = Some(map_.next_value()?); + negated__ = Some(map_.next_value()?); + } + } + } + Ok(InListNode { + expr: expr__, + list: list__.unwrap_or_default(), + negated: negated__.unwrap_or_default(), + }) + } + } + deserializer.deserialize_struct("datafusion.InListNode", FIELDS, GeneratedVisitor) + } +} +impl serde::Serialize for InterleaveExecNode { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if !self.inputs.is_empty() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.InterleaveExecNode", len)?; + if !self.inputs.is_empty() { + struct_ser.serialize_field("inputs", &self.inputs)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for InterleaveExecNode { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "inputs", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Inputs, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "inputs" => Ok(GeneratedField::Inputs), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = InterleaveExecNode; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.InterleaveExecNode") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut inputs__ = None; + while let Some(k) = map_.next_key()? 
{ + match k { + GeneratedField::Inputs => { + if inputs__.is_some() { + return Err(serde::de::Error::duplicate_field("inputs")); + } + inputs__ = Some(map_.next_value()?); } } } - Ok(InListNode { - expr: expr__, - list: list__.unwrap_or_default(), - negated: negated__.unwrap_or_default(), + Ok(InterleaveExecNode { + inputs: inputs__.unwrap_or_default(), }) } } - deserializer.deserialize_struct("datafusion.InListNode", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.InterleaveExecNode", FIELDS, GeneratedVisitor) } } impl serde::Serialize for IntervalMonthDayNanoValue { @@ -10203,7 +11559,335 @@ impl<'de> serde::Deserialize<'de> for JoinType { } } } - deserializer.deserialize_any(GeneratedVisitor) + deserializer.deserialize_any(GeneratedVisitor) + } +} +impl serde::Serialize for JsonSink { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.config.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.JsonSink", len)?; + if let Some(v) = self.config.as_ref() { + struct_ser.serialize_field("config", v)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for JsonSink { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "config", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Config, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "config" => Ok(GeneratedField::Config), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = JsonSink; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.JsonSink") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut config__ = None; + while let Some(k) = map_.next_key()? 
{ + match k { + GeneratedField::Config => { + if config__.is_some() { + return Err(serde::de::Error::duplicate_field("config")); + } + config__ = map_.next_value()?; + } + } + } + Ok(JsonSink { + config: config__, + }) + } + } + deserializer.deserialize_struct("datafusion.JsonSink", FIELDS, GeneratedVisitor) + } +} +impl serde::Serialize for JsonSinkExecNode { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.input.is_some() { + len += 1; + } + if self.sink.is_some() { + len += 1; + } + if self.sink_schema.is_some() { + len += 1; + } + if self.sort_order.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.JsonSinkExecNode", len)?; + if let Some(v) = self.input.as_ref() { + struct_ser.serialize_field("input", v)?; + } + if let Some(v) = self.sink.as_ref() { + struct_ser.serialize_field("sink", v)?; + } + if let Some(v) = self.sink_schema.as_ref() { + struct_ser.serialize_field("sinkSchema", v)?; + } + if let Some(v) = self.sort_order.as_ref() { + struct_ser.serialize_field("sortOrder", v)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for JsonSinkExecNode { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "input", + "sink", + "sink_schema", + "sinkSchema", + "sort_order", + "sortOrder", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Input, + Sink, + SinkSchema, + SortOrder, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "input" => Ok(GeneratedField::Input), + "sink" => Ok(GeneratedField::Sink), + "sinkSchema" | "sink_schema" => Ok(GeneratedField::SinkSchema), + "sortOrder" | "sort_order" => Ok(GeneratedField::SortOrder), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = JsonSinkExecNode; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.JsonSinkExecNode") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut input__ = None; + let mut sink__ = None; + let mut sink_schema__ = None; + let mut sort_order__ = None; + while let Some(k) = map_.next_key()? 
{ + match k { + GeneratedField::Input => { + if input__.is_some() { + return Err(serde::de::Error::duplicate_field("input")); + } + input__ = map_.next_value()?; + } + GeneratedField::Sink => { + if sink__.is_some() { + return Err(serde::de::Error::duplicate_field("sink")); + } + sink__ = map_.next_value()?; + } + GeneratedField::SinkSchema => { + if sink_schema__.is_some() { + return Err(serde::de::Error::duplicate_field("sinkSchema")); + } + sink_schema__ = map_.next_value()?; + } + GeneratedField::SortOrder => { + if sort_order__.is_some() { + return Err(serde::de::Error::duplicate_field("sortOrder")); + } + sort_order__ = map_.next_value()?; + } + } + } + Ok(JsonSinkExecNode { + input: input__, + sink: sink__, + sink_schema: sink_schema__, + sort_order: sort_order__, + }) + } + } + deserializer.deserialize_struct("datafusion.JsonSinkExecNode", FIELDS, GeneratedVisitor) + } +} +impl serde::Serialize for JsonWriterOptions { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.compression != 0 { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.JsonWriterOptions", len)?; + if self.compression != 0 { + let v = CompressionTypeVariant::try_from(self.compression) + .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", self.compression)))?; + struct_ser.serialize_field("compression", &v)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for JsonWriterOptions { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "compression", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Compression, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "compression" => Ok(GeneratedField::Compression), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = JsonWriterOptions; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.JsonWriterOptions") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut compression__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::Compression => { + if compression__.is_some() { + return Err(serde::de::Error::duplicate_field("compression")); + } + compression__ = Some(map_.next_value::()? 
as i32); + } + } + } + Ok(JsonWriterOptions { + compression: compression__.unwrap_or_default(), + }) + } + } + deserializer.deserialize_struct("datafusion.JsonWriterOptions", FIELDS, GeneratedVisitor) } } impl serde::Serialize for LikeNode { @@ -11871,7 +13555,8 @@ impl<'de> serde::Deserialize<'de> for LogicalExprNode { if expr_type__.is_some() { return Err(serde::de::Error::duplicate_field("wildcard")); } - expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::Wildcard); + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::Wildcard) +; } GeneratedField::ScalarFunction => { if expr_type__.is_some() { @@ -12311,6 +13996,12 @@ impl serde::Serialize for LogicalPlanNode { logical_plan_node::LogicalPlanType::DropView(v) => { struct_ser.serialize_field("dropView", v)?; } + logical_plan_node::LogicalPlanType::DistinctOn(v) => { + struct_ser.serialize_field("distinctOn", v)?; + } + logical_plan_node::LogicalPlanType::CopyTo(v) => { + struct_ser.serialize_field("copyTo", v)?; + } } } struct_ser.end() @@ -12360,6 +14051,10 @@ impl<'de> serde::Deserialize<'de> for LogicalPlanNode { "prepare", "drop_view", "dropView", + "distinct_on", + "distinctOn", + "copy_to", + "copyTo", ]; #[allow(clippy::enum_variant_names)] @@ -12390,6 +14085,8 @@ impl<'de> serde::Deserialize<'de> for LogicalPlanNode { CustomScan, Prepare, DropView, + DistinctOn, + CopyTo, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -12437,6 +14134,8 @@ impl<'de> serde::Deserialize<'de> for LogicalPlanNode { "customScan" | "custom_scan" => Ok(GeneratedField::CustomScan), "prepare" => Ok(GeneratedField::Prepare), "dropView" | "drop_view" => Ok(GeneratedField::DropView), + "distinctOn" | "distinct_on" => Ok(GeneratedField::DistinctOn), + "copyTo" | "copy_to" => Ok(GeneratedField::CopyTo), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -12639,6 +14338,20 @@ impl<'de> serde::Deserialize<'de> for LogicalPlanNode { return Err(serde::de::Error::duplicate_field("dropView")); } logical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::DropView) +; + } + GeneratedField::DistinctOn => { + if logical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("distinctOn")); + } + logical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::DistinctOn) +; + } + GeneratedField::CopyTo => { + if logical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("copyTo")); + } + logical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::CopyTo) ; } } @@ -13926,17 +15639,344 @@ impl<'de> serde::Deserialize<'de> for ParquetScanExecNode { if predicate__.is_some() { return Err(serde::de::Error::duplicate_field("predicate")); } - predicate__ = map_.next_value()?; + predicate__ = map_.next_value()?; + } + } + } + Ok(ParquetScanExecNode { + base_conf: base_conf__, + predicate: predicate__, + }) + } + } + deserializer.deserialize_struct("datafusion.ParquetScanExecNode", FIELDS, GeneratedVisitor) + } +} +impl serde::Serialize for ParquetSink { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.config.is_some() { + len += 1; + } + let mut struct_ser = 
serializer.serialize_struct("datafusion.ParquetSink", len)?; + if let Some(v) = self.config.as_ref() { + struct_ser.serialize_field("config", v)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for ParquetSink { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "config", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Config, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "config" => Ok(GeneratedField::Config), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = ParquetSink; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.ParquetSink") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut config__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::Config => { + if config__.is_some() { + return Err(serde::de::Error::duplicate_field("config")); + } + config__ = map_.next_value()?; + } + } + } + Ok(ParquetSink { + config: config__, + }) + } + } + deserializer.deserialize_struct("datafusion.ParquetSink", FIELDS, GeneratedVisitor) + } +} +impl serde::Serialize for ParquetSinkExecNode { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.input.is_some() { + len += 1; + } + if self.sink.is_some() { + len += 1; + } + if self.sink_schema.is_some() { + len += 1; + } + if self.sort_order.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.ParquetSinkExecNode", len)?; + if let Some(v) = self.input.as_ref() { + struct_ser.serialize_field("input", v)?; + } + if let Some(v) = self.sink.as_ref() { + struct_ser.serialize_field("sink", v)?; + } + if let Some(v) = self.sink_schema.as_ref() { + struct_ser.serialize_field("sinkSchema", v)?; + } + if let Some(v) = self.sort_order.as_ref() { + struct_ser.serialize_field("sortOrder", v)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for ParquetSinkExecNode { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "input", + "sink", + "sink_schema", + "sinkSchema", + "sort_order", + "sortOrder", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Input, + Sink, + SinkSchema, + SortOrder, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = 
GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "input" => Ok(GeneratedField::Input), + "sink" => Ok(GeneratedField::Sink), + "sinkSchema" | "sink_schema" => Ok(GeneratedField::SinkSchema), + "sortOrder" | "sort_order" => Ok(GeneratedField::SortOrder), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = ParquetSinkExecNode; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.ParquetSinkExecNode") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut input__ = None; + let mut sink__ = None; + let mut sink_schema__ = None; + let mut sort_order__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::Input => { + if input__.is_some() { + return Err(serde::de::Error::duplicate_field("input")); + } + input__ = map_.next_value()?; + } + GeneratedField::Sink => { + if sink__.is_some() { + return Err(serde::de::Error::duplicate_field("sink")); + } + sink__ = map_.next_value()?; + } + GeneratedField::SinkSchema => { + if sink_schema__.is_some() { + return Err(serde::de::Error::duplicate_field("sinkSchema")); + } + sink_schema__ = map_.next_value()?; + } + GeneratedField::SortOrder => { + if sort_order__.is_some() { + return Err(serde::de::Error::duplicate_field("sortOrder")); + } + sort_order__ = map_.next_value()?; + } + } + } + Ok(ParquetSinkExecNode { + input: input__, + sink: sink__, + sink_schema: sink_schema__, + sort_order: sort_order__, + }) + } + } + deserializer.deserialize_struct("datafusion.ParquetSinkExecNode", FIELDS, GeneratedVisitor) + } +} +impl serde::Serialize for ParquetWriterOptions { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.writer_properties.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.ParquetWriterOptions", len)?; + if let Some(v) = self.writer_properties.as_ref() { + struct_ser.serialize_field("writerProperties", v)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for ParquetWriterOptions { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "writer_properties", + "writerProperties", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + WriterProperties, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "writerProperties" | "writer_properties" => 
Ok(GeneratedField::WriterProperties), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = ParquetWriterOptions; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.ParquetWriterOptions") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut writer_properties__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::WriterProperties => { + if writer_properties__.is_some() { + return Err(serde::de::Error::duplicate_field("writerProperties")); + } + writer_properties__ = map_.next_value()?; } } } - Ok(ParquetScanExecNode { - base_conf: base_conf__, - predicate: predicate__, + Ok(ParquetWriterOptions { + writer_properties: writer_properties__, }) } } - deserializer.deserialize_struct("datafusion.ParquetScanExecNode", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.ParquetWriterOptions", FIELDS, GeneratedVisitor) } } impl serde::Serialize for PartialTableReference { @@ -14047,7 +16087,7 @@ impl<'de> serde::Deserialize<'de> for PartialTableReference { deserializer.deserialize_struct("datafusion.PartialTableReference", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for PartiallySortedPartitionSearchMode { +impl serde::Serialize for PartiallySortedInputOrderMode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -14058,14 +16098,14 @@ impl serde::Serialize for PartiallySortedPartitionSearchMode { if !self.columns.is_empty() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.PartiallySortedPartitionSearchMode", len)?; + let mut struct_ser = serializer.serialize_struct("datafusion.PartiallySortedInputOrderMode", len)?; if !self.columns.is_empty() { struct_ser.serialize_field("columns", &self.columns.iter().map(ToString::to_string).collect::>())?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for PartiallySortedPartitionSearchMode { +impl<'de> serde::Deserialize<'de> for PartiallySortedInputOrderMode { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where @@ -14109,13 +16149,13 @@ impl<'de> serde::Deserialize<'de> for PartiallySortedPartitionSearchMode { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = PartiallySortedPartitionSearchMode; + type Value = PartiallySortedInputOrderMode; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.PartiallySortedPartitionSearchMode") + formatter.write_str("struct datafusion.PartiallySortedInputOrderMode") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { @@ -14133,12 +16173,121 @@ impl<'de> serde::Deserialize<'de> for PartiallySortedPartitionSearchMode { } } } - Ok(PartiallySortedPartitionSearchMode { + Ok(PartiallySortedInputOrderMode { columns: columns__.unwrap_or_default(), }) } } - deserializer.deserialize_struct("datafusion.PartiallySortedPartitionSearchMode", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.PartiallySortedInputOrderMode", FIELDS, GeneratedVisitor) + } +} +impl serde::Serialize for PartitionColumn { + #[allow(deprecated)] + 
fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if !self.name.is_empty() { + len += 1; + } + if self.arrow_type.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.PartitionColumn", len)?; + if !self.name.is_empty() { + struct_ser.serialize_field("name", &self.name)?; + } + if let Some(v) = self.arrow_type.as_ref() { + struct_ser.serialize_field("arrowType", v)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for PartitionColumn { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "name", + "arrow_type", + "arrowType", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Name, + ArrowType, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "name" => Ok(GeneratedField::Name), + "arrowType" | "arrow_type" => Ok(GeneratedField::ArrowType), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = PartitionColumn; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.PartitionColumn") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut name__ = None; + let mut arrow_type__ = None; + while let Some(k) = map_.next_key()? 
{ + match k { + GeneratedField::Name => { + if name__.is_some() { + return Err(serde::de::Error::duplicate_field("name")); + } + name__ = Some(map_.next_value()?); + } + GeneratedField::ArrowType => { + if arrow_type__.is_some() { + return Err(serde::de::Error::duplicate_field("arrowType")); + } + arrow_type__ = map_.next_value()?; + } + } + } + Ok(PartitionColumn { + name: name__.unwrap_or_default(), + arrow_type: arrow_type__, + }) + } + } + deserializer.deserialize_struct("datafusion.PartitionColumn", FIELDS, GeneratedVisitor) } } impl serde::Serialize for PartitionMode { @@ -16812,6 +18961,24 @@ impl serde::Serialize for PhysicalPlanNode { physical_plan_node::PhysicalPlanType::Analyze(v) => { struct_ser.serialize_field("analyze", v)?; } + physical_plan_node::PhysicalPlanType::JsonSink(v) => { + struct_ser.serialize_field("jsonSink", v)?; + } + physical_plan_node::PhysicalPlanType::SymmetricHashJoin(v) => { + struct_ser.serialize_field("symmetricHashJoin", v)?; + } + physical_plan_node::PhysicalPlanType::Interleave(v) => { + struct_ser.serialize_field("interleave", v)?; + } + physical_plan_node::PhysicalPlanType::PlaceholderRow(v) => { + struct_ser.serialize_field("placeholderRow", v)?; + } + physical_plan_node::PhysicalPlanType::CsvSink(v) => { + struct_ser.serialize_field("csvSink", v)?; + } + physical_plan_node::PhysicalPlanType::ParquetSink(v) => { + struct_ser.serialize_field("parquetSink", v)?; + } } } struct_ser.end() @@ -16856,6 +19023,17 @@ impl<'de> serde::Deserialize<'de> for PhysicalPlanNode { "nested_loop_join", "nestedLoopJoin", "analyze", + "json_sink", + "jsonSink", + "symmetric_hash_join", + "symmetricHashJoin", + "interleave", + "placeholder_row", + "placeholderRow", + "csv_sink", + "csvSink", + "parquet_sink", + "parquetSink", ]; #[allow(clippy::enum_variant_names)] @@ -16882,6 +19060,12 @@ impl<'de> serde::Deserialize<'de> for PhysicalPlanNode { SortPreservingMerge, NestedLoopJoin, Analyze, + JsonSink, + SymmetricHashJoin, + Interleave, + PlaceholderRow, + CsvSink, + ParquetSink, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -16925,6 +19109,12 @@ impl<'de> serde::Deserialize<'de> for PhysicalPlanNode { "sortPreservingMerge" | "sort_preserving_merge" => Ok(GeneratedField::SortPreservingMerge), "nestedLoopJoin" | "nested_loop_join" => Ok(GeneratedField::NestedLoopJoin), "analyze" => Ok(GeneratedField::Analyze), + "jsonSink" | "json_sink" => Ok(GeneratedField::JsonSink), + "symmetricHashJoin" | "symmetric_hash_join" => Ok(GeneratedField::SymmetricHashJoin), + "interleave" => Ok(GeneratedField::Interleave), + "placeholderRow" | "placeholder_row" => Ok(GeneratedField::PlaceholderRow), + "csvSink" | "csv_sink" => Ok(GeneratedField::CsvSink), + "parquetSink" | "parquet_sink" => Ok(GeneratedField::ParquetSink), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -17099,6 +19289,48 @@ impl<'de> serde::Deserialize<'de> for PhysicalPlanNode { return Err(serde::de::Error::duplicate_field("analyze")); } physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::Analyze) +; + } + GeneratedField::JsonSink => { + if physical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("jsonSink")); + } + physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::JsonSink) +; + } + GeneratedField::SymmetricHashJoin => { + if physical_plan_type__.is_some() { + return 
Err(serde::de::Error::duplicate_field("symmetricHashJoin")); + } + physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::SymmetricHashJoin) +; + } + GeneratedField::Interleave => { + if physical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("interleave")); + } + physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::Interleave) +; + } + GeneratedField::PlaceholderRow => { + if physical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("placeholderRow")); + } + physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::PlaceholderRow) +; + } + GeneratedField::CsvSink => { + if physical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("csvSink")); + } + physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::CsvSink) +; + } + GeneratedField::ParquetSink => { + if physical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("parquetSink")); + } + physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::ParquetSink) ; } } @@ -18007,20 +20239,129 @@ impl<'de> serde::Deserialize<'de> for PhysicalWindowExprNode { } } } - Ok(PhysicalWindowExprNode { - args: args__.unwrap_or_default(), - partition_by: partition_by__.unwrap_or_default(), - order_by: order_by__.unwrap_or_default(), - window_frame: window_frame__, - name: name__.unwrap_or_default(), - window_function: window_function__, + Ok(PhysicalWindowExprNode { + args: args__.unwrap_or_default(), + partition_by: partition_by__.unwrap_or_default(), + order_by: order_by__.unwrap_or_default(), + window_frame: window_frame__, + name: name__.unwrap_or_default(), + window_function: window_function__, + }) + } + } + deserializer.deserialize_struct("datafusion.PhysicalWindowExprNode", FIELDS, GeneratedVisitor) + } +} +impl serde::Serialize for PlaceholderNode { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if !self.id.is_empty() { + len += 1; + } + if self.data_type.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.PlaceholderNode", len)?; + if !self.id.is_empty() { + struct_ser.serialize_field("id", &self.id)?; + } + if let Some(v) = self.data_type.as_ref() { + struct_ser.serialize_field("dataType", v)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for PlaceholderNode { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "id", + "data_type", + "dataType", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Id, + DataType, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "id" => 
Ok(GeneratedField::Id), + "dataType" | "data_type" => Ok(GeneratedField::DataType), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = PlaceholderNode; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.PlaceholderNode") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut id__ = None; + let mut data_type__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::Id => { + if id__.is_some() { + return Err(serde::de::Error::duplicate_field("id")); + } + id__ = Some(map_.next_value()?); + } + GeneratedField::DataType => { + if data_type__.is_some() { + return Err(serde::de::Error::duplicate_field("dataType")); + } + data_type__ = map_.next_value()?; + } + } + } + Ok(PlaceholderNode { + id: id__.unwrap_or_default(), + data_type: data_type__, }) } } - deserializer.deserialize_struct("datafusion.PhysicalWindowExprNode", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.PlaceholderNode", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for PlaceholderNode { +impl serde::Serialize for PlaceholderRowExecNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -18028,38 +20369,29 @@ impl serde::Serialize for PlaceholderNode { { use serde::ser::SerializeStruct; let mut len = 0; - if !self.id.is_empty() { - len += 1; - } - if self.data_type.is_some() { + if self.schema.is_some() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.PlaceholderNode", len)?; - if !self.id.is_empty() { - struct_ser.serialize_field("id", &self.id)?; - } - if let Some(v) = self.data_type.as_ref() { - struct_ser.serialize_field("dataType", v)?; + let mut struct_ser = serializer.serialize_struct("datafusion.PlaceholderRowExecNode", len)?; + if let Some(v) = self.schema.as_ref() { + struct_ser.serialize_field("schema", v)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for PlaceholderNode { +impl<'de> serde::Deserialize<'de> for PlaceholderRowExecNode { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "id", - "data_type", - "dataType", + "schema", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Id, - DataType, + Schema, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -18081,8 +20413,7 @@ impl<'de> serde::Deserialize<'de> for PlaceholderNode { E: serde::de::Error, { match value { - "id" => Ok(GeneratedField::Id), - "dataType" | "data_type" => Ok(GeneratedField::DataType), + "schema" => Ok(GeneratedField::Schema), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -18092,41 +20423,33 @@ impl<'de> serde::Deserialize<'de> for PlaceholderNode { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = PlaceholderNode; + type Value = PlaceholderRowExecNode; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.PlaceholderNode") + formatter.write_str("struct datafusion.PlaceholderRowExecNode") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> 
std::result::Result where V: serde::de::MapAccess<'de>, { - let mut id__ = None; - let mut data_type__ = None; + let mut schema__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::Id => { - if id__.is_some() { - return Err(serde::de::Error::duplicate_field("id")); - } - id__ = Some(map_.next_value()?); - } - GeneratedField::DataType => { - if data_type__.is_some() { - return Err(serde::de::Error::duplicate_field("dataType")); + GeneratedField::Schema => { + if schema__.is_some() { + return Err(serde::de::Error::duplicate_field("schema")); } - data_type__ = map_.next_value()?; + schema__ = map_.next_value()?; } } } - Ok(PlaceholderNode { - id: id__.unwrap_or_default(), - data_type: data_type__, + Ok(PlaceholderRowExecNode { + schema: schema__, }) } } - deserializer.deserialize_struct("datafusion.PlaceholderNode", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.PlaceholderRowExecNode", FIELDS, GeneratedVisitor) } } impl serde::Serialize for PlanType { @@ -18161,12 +20484,18 @@ impl serde::Serialize for PlanType { plan_type::PlanTypeEnum::InitialPhysicalPlan(v) => { struct_ser.serialize_field("InitialPhysicalPlan", v)?; } + plan_type::PlanTypeEnum::InitialPhysicalPlanWithStats(v) => { + struct_ser.serialize_field("InitialPhysicalPlanWithStats", v)?; + } plan_type::PlanTypeEnum::OptimizedPhysicalPlan(v) => { struct_ser.serialize_field("OptimizedPhysicalPlan", v)?; } plan_type::PlanTypeEnum::FinalPhysicalPlan(v) => { struct_ser.serialize_field("FinalPhysicalPlan", v)?; } + plan_type::PlanTypeEnum::FinalPhysicalPlanWithStats(v) => { + struct_ser.serialize_field("FinalPhysicalPlanWithStats", v)?; + } } } struct_ser.end() @@ -18185,8 +20514,10 @@ impl<'de> serde::Deserialize<'de> for PlanType { "OptimizedLogicalPlan", "FinalLogicalPlan", "InitialPhysicalPlan", + "InitialPhysicalPlanWithStats", "OptimizedPhysicalPlan", "FinalPhysicalPlan", + "FinalPhysicalPlanWithStats", ]; #[allow(clippy::enum_variant_names)] @@ -18197,8 +20528,10 @@ impl<'de> serde::Deserialize<'de> for PlanType { OptimizedLogicalPlan, FinalLogicalPlan, InitialPhysicalPlan, + InitialPhysicalPlanWithStats, OptimizedPhysicalPlan, FinalPhysicalPlan, + FinalPhysicalPlanWithStats, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -18226,8 +20559,10 @@ impl<'de> serde::Deserialize<'de> for PlanType { "OptimizedLogicalPlan" => Ok(GeneratedField::OptimizedLogicalPlan), "FinalLogicalPlan" => Ok(GeneratedField::FinalLogicalPlan), "InitialPhysicalPlan" => Ok(GeneratedField::InitialPhysicalPlan), + "InitialPhysicalPlanWithStats" => Ok(GeneratedField::InitialPhysicalPlanWithStats), "OptimizedPhysicalPlan" => Ok(GeneratedField::OptimizedPhysicalPlan), "FinalPhysicalPlan" => Ok(GeneratedField::FinalPhysicalPlan), + "FinalPhysicalPlanWithStats" => Ok(GeneratedField::FinalPhysicalPlanWithStats), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -18290,6 +20625,13 @@ impl<'de> serde::Deserialize<'de> for PlanType { return Err(serde::de::Error::duplicate_field("InitialPhysicalPlan")); } plan_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(plan_type::PlanTypeEnum::InitialPhysicalPlan) +; + } + GeneratedField::InitialPhysicalPlanWithStats => { + if plan_type_enum__.is_some() { + return Err(serde::de::Error::duplicate_field("InitialPhysicalPlanWithStats")); + } + plan_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(plan_type::PlanTypeEnum::InitialPhysicalPlanWithStats) ; } 
GeneratedField::OptimizedPhysicalPlan => { @@ -18304,6 +20646,13 @@ impl<'de> serde::Deserialize<'de> for PlanType { return Err(serde::de::Error::duplicate_field("FinalPhysicalPlan")); } plan_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(plan_type::PlanTypeEnum::FinalPhysicalPlan) +; + } + GeneratedField::FinalPhysicalPlanWithStats => { + if plan_type_enum__.is_some() { + return Err(serde::de::Error::duplicate_field("FinalPhysicalPlanWithStats")); + } + plan_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(plan_type::PlanTypeEnum::FinalPhysicalPlanWithStats) ; } } @@ -19320,21 +21669,220 @@ impl<'de> serde::Deserialize<'de> for RepartitionNode { if partition_method__.is_some() { return Err(serde::de::Error::duplicate_field("hash")); } - partition_method__ = map_.next_value::<::std::option::Option<_>>()?.map(repartition_node::PartitionMethod::Hash) -; + partition_method__ = map_.next_value::<::std::option::Option<_>>()?.map(repartition_node::PartitionMethod::Hash) +; + } + } + } + Ok(RepartitionNode { + input: input__, + partition_method: partition_method__, + }) + } + } + deserializer.deserialize_struct("datafusion.RepartitionNode", FIELDS, GeneratedVisitor) + } +} +impl serde::Serialize for RollupNode { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if !self.expr.is_empty() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.RollupNode", len)?; + if !self.expr.is_empty() { + struct_ser.serialize_field("expr", &self.expr)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for RollupNode { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "expr", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Expr, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "expr" => Ok(GeneratedField::Expr), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = RollupNode; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.RollupNode") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut expr__ = None; + while let Some(k) = map_.next_key()? 
{ + match k { + GeneratedField::Expr => { + if expr__.is_some() { + return Err(serde::de::Error::duplicate_field("expr")); + } + expr__ = Some(map_.next_value()?); + } + } + } + Ok(RollupNode { + expr: expr__.unwrap_or_default(), + }) + } + } + deserializer.deserialize_struct("datafusion.RollupNode", FIELDS, GeneratedVisitor) + } +} +impl serde::Serialize for SqlOption { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if !self.key.is_empty() { + len += 1; + } + if !self.value.is_empty() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.SQLOption", len)?; + if !self.key.is_empty() { + struct_ser.serialize_field("key", &self.key)?; + } + if !self.value.is_empty() { + struct_ser.serialize_field("value", &self.value)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for SqlOption { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "key", + "value", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Key, + Value, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "key" => Ok(GeneratedField::Key), + "value" => Ok(GeneratedField::Value), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = SqlOption; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.SQLOption") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut key__ = None; + let mut value__ = None; + while let Some(k) = map_.next_key()? 
{ + match k { + GeneratedField::Key => { + if key__.is_some() { + return Err(serde::de::Error::duplicate_field("key")); + } + key__ = Some(map_.next_value()?); + } + GeneratedField::Value => { + if value__.is_some() { + return Err(serde::de::Error::duplicate_field("value")); + } + value__ = Some(map_.next_value()?); } } } - Ok(RepartitionNode { - input: input__, - partition_method: partition_method__, + Ok(SqlOption { + key: key__.unwrap_or_default(), + value: value__.unwrap_or_default(), }) } } - deserializer.deserialize_struct("datafusion.RepartitionNode", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.SQLOption", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for RollupNode { +impl serde::Serialize for SqlOptions { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -19342,29 +21890,29 @@ impl serde::Serialize for RollupNode { { use serde::ser::SerializeStruct; let mut len = 0; - if !self.expr.is_empty() { + if !self.option.is_empty() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.RollupNode", len)?; - if !self.expr.is_empty() { - struct_ser.serialize_field("expr", &self.expr)?; + let mut struct_ser = serializer.serialize_struct("datafusion.SQLOptions", len)?; + if !self.option.is_empty() { + struct_ser.serialize_field("option", &self.option)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for RollupNode { +impl<'de> serde::Deserialize<'de> for SqlOptions { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "expr", + "option", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Expr, + Option, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -19386,7 +21934,7 @@ impl<'de> serde::Deserialize<'de> for RollupNode { E: serde::de::Error, { match value { - "expr" => Ok(GeneratedField::Expr), + "option" => Ok(GeneratedField::Option), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -19396,33 +21944,33 @@ impl<'de> serde::Deserialize<'de> for RollupNode { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = RollupNode; + type Value = SqlOptions; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.RollupNode") + formatter.write_str("struct datafusion.SQLOptions") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut expr__ = None; + let mut option__ = None; while let Some(k) = map_.next_key()? 
{ match k { - GeneratedField::Expr => { - if expr__.is_some() { - return Err(serde::de::Error::duplicate_field("expr")); + GeneratedField::Option => { + if option__.is_some() { + return Err(serde::de::Error::duplicate_field("option")); } - expr__ = Some(map_.next_value()?); + option__ = Some(map_.next_value()?); } } } - Ok(RollupNode { - expr: expr__.unwrap_or_default(), + Ok(SqlOptions { + option: option__.unwrap_or_default(), }) } } - deserializer.deserialize_struct("datafusion.RollupNode", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.SQLOptions", FIELDS, GeneratedVisitor) } } impl serde::Serialize for ScalarDictionaryValue { @@ -19773,6 +22321,17 @@ impl serde::Serialize for ScalarFunction { Self::ArrayPopBack => "ArrayPopBack", Self::StringToArray => "StringToArray", Self::ToTimestampNanos => "ToTimestampNanos", + Self::ArrayIntersect => "ArrayIntersect", + Self::ArrayUnion => "ArrayUnion", + Self::OverLay => "OverLay", + Self::Range => "Range", + Self::ArrayExcept => "ArrayExcept", + Self::ArrayPopFront => "ArrayPopFront", + Self::Levenshtein => "Levenshtein", + Self::SubstrIndex => "SubstrIndex", + Self::FindInSet => "FindInSet", + Self::ArraySort => "ArraySort", + Self::ArrayDistinct => "ArrayDistinct", }; serializer.serialize_str(variant) } @@ -19903,6 +22462,17 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "ArrayPopBack", "StringToArray", "ToTimestampNanos", + "ArrayIntersect", + "ArrayUnion", + "OverLay", + "Range", + "ArrayExcept", + "ArrayPopFront", + "Levenshtein", + "SubstrIndex", + "FindInSet", + "ArraySort", + "ArrayDistinct", ]; struct GeneratedVisitor; @@ -20062,6 +22632,17 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "ArrayPopBack" => Ok(ScalarFunction::ArrayPopBack), "StringToArray" => Ok(ScalarFunction::StringToArray), "ToTimestampNanos" => Ok(ScalarFunction::ToTimestampNanos), + "ArrayIntersect" => Ok(ScalarFunction::ArrayIntersect), + "ArrayUnion" => Ok(ScalarFunction::ArrayUnion), + "OverLay" => Ok(ScalarFunction::OverLay), + "Range" => Ok(ScalarFunction::Range), + "ArrayExcept" => Ok(ScalarFunction::ArrayExcept), + "ArrayPopFront" => Ok(ScalarFunction::ArrayPopFront), + "Levenshtein" => Ok(ScalarFunction::Levenshtein), + "SubstrIndex" => Ok(ScalarFunction::SubstrIndex), + "FindInSet" => Ok(ScalarFunction::FindInSet), + "ArraySort" => Ok(ScalarFunction::ArraySort), + "ArrayDistinct" => Ok(ScalarFunction::ArrayDistinct), _ => Err(serde::de::Error::unknown_variant(value, FIELDS)), } } @@ -20861,9 +23442,15 @@ impl serde::Serialize for ScalarValue { scalar_value::Value::Time32Value(v) => { struct_ser.serialize_field("time32Value", v)?; } + scalar_value::Value::LargeListValue(v) => { + struct_ser.serialize_field("largeListValue", v)?; + } scalar_value::Value::ListValue(v) => { struct_ser.serialize_field("listValue", v)?; } + scalar_value::Value::FixedSizeListValue(v) => { + struct_ser.serialize_field("fixedSizeListValue", v)?; + } scalar_value::Value::Decimal128Value(v) => { struct_ser.serialize_field("decimal128Value", v)?; } @@ -20967,8 +23554,12 @@ impl<'de> serde::Deserialize<'de> for ScalarValue { "date32Value", "time32_value", "time32Value", + "large_list_value", + "largeListValue", "list_value", "listValue", + "fixed_size_list_value", + "fixedSizeListValue", "decimal128_value", "decimal128Value", "decimal256_value", @@ -21023,7 +23614,9 @@ impl<'de> serde::Deserialize<'de> for ScalarValue { Float64Value, Date32Value, Time32Value, + LargeListValue, ListValue, + FixedSizeListValue, Decimal128Value, 
Decimal256Value, Date64Value, @@ -21078,7 +23671,9 @@ impl<'de> serde::Deserialize<'de> for ScalarValue { "float64Value" | "float64_value" => Ok(GeneratedField::Float64Value), "date32Value" | "date_32_value" => Ok(GeneratedField::Date32Value), "time32Value" | "time32_value" => Ok(GeneratedField::Time32Value), + "largeListValue" | "large_list_value" => Ok(GeneratedField::LargeListValue), "listValue" | "list_value" => Ok(GeneratedField::ListValue), + "fixedSizeListValue" | "fixed_size_list_value" => Ok(GeneratedField::FixedSizeListValue), "decimal128Value" | "decimal128_value" => Ok(GeneratedField::Decimal128Value), "decimal256Value" | "decimal256_value" => Ok(GeneratedField::Decimal256Value), "date64Value" | "date_64_value" => Ok(GeneratedField::Date64Value), @@ -21214,6 +23809,13 @@ impl<'de> serde::Deserialize<'de> for ScalarValue { return Err(serde::de::Error::duplicate_field("time32Value")); } value__ = map_.next_value::<::std::option::Option<_>>()?.map(scalar_value::Value::Time32Value) +; + } + GeneratedField::LargeListValue => { + if value__.is_some() { + return Err(serde::de::Error::duplicate_field("largeListValue")); + } + value__ = map_.next_value::<::std::option::Option<_>>()?.map(scalar_value::Value::LargeListValue) ; } GeneratedField::ListValue => { @@ -21221,6 +23823,13 @@ impl<'de> serde::Deserialize<'de> for ScalarValue { return Err(serde::de::Error::duplicate_field("listValue")); } value__ = map_.next_value::<::std::option::Option<_>>()?.map(scalar_value::Value::ListValue) +; + } + GeneratedField::FixedSizeListValue => { + if value__.is_some() { + return Err(serde::de::Error::duplicate_field("fixedSizeListValue")); + } + value__ = map_.next_value::<::std::option::Option<_>>()?.map(scalar_value::Value::FixedSizeListValue) ; } GeneratedField::Decimal128Value => { @@ -22544,6 +25153,77 @@ impl<'de> serde::Deserialize<'de> for Statistics { deserializer.deserialize_struct("datafusion.Statistics", FIELDS, GeneratedVisitor) } } +impl serde::Serialize for StreamPartitionMode { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + let variant = match self { + Self::SinglePartition => "SINGLE_PARTITION", + Self::PartitionedExec => "PARTITIONED_EXEC", + }; + serializer.serialize_str(variant) + } +} +impl<'de> serde::Deserialize<'de> for StreamPartitionMode { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "SINGLE_PARTITION", + "PARTITIONED_EXEC", + ]; + + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = StreamPartitionMode; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + fn visit_i64(self, v: i64) -> std::result::Result + where + E: serde::de::Error, + { + i32::try_from(v) + .ok() + .and_then(|x| x.try_into().ok()) + .ok_or_else(|| { + serde::de::Error::invalid_value(serde::de::Unexpected::Signed(v), &self) + }) + } + + fn visit_u64(self, v: u64) -> std::result::Result + where + E: serde::de::Error, + { + i32::try_from(v) + .ok() + .and_then(|x| x.try_into().ok()) + .ok_or_else(|| { + serde::de::Error::invalid_value(serde::de::Unexpected::Unsigned(v), &self) + }) + } + + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "SINGLE_PARTITION" => Ok(StreamPartitionMode::SinglePartition), + "PARTITIONED_EXEC" => 
Ok(StreamPartitionMode::PartitionedExec), + _ => Err(serde::de::Error::unknown_variant(value, FIELDS)), + } + } + } + deserializer.deserialize_any(GeneratedVisitor) + } +} impl serde::Serialize for StringifiedPlan { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result @@ -22939,27 +25619,227 @@ impl<'de> serde::Deserialize<'de> for SubqueryAliasNode { let mut alias__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::Input => { - if input__.is_some() { - return Err(serde::de::Error::duplicate_field("input")); + GeneratedField::Input => { + if input__.is_some() { + return Err(serde::de::Error::duplicate_field("input")); + } + input__ = map_.next_value()?; + } + GeneratedField::Alias => { + if alias__.is_some() { + return Err(serde::de::Error::duplicate_field("alias")); + } + alias__ = map_.next_value()?; + } + } + } + Ok(SubqueryAliasNode { + input: input__, + alias: alias__, + }) + } + } + deserializer.deserialize_struct("datafusion.SubqueryAliasNode", FIELDS, GeneratedVisitor) + } +} +impl serde::Serialize for SymmetricHashJoinExecNode { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.left.is_some() { + len += 1; + } + if self.right.is_some() { + len += 1; + } + if !self.on.is_empty() { + len += 1; + } + if self.join_type != 0 { + len += 1; + } + if self.partition_mode != 0 { + len += 1; + } + if self.null_equals_null { + len += 1; + } + if self.filter.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.SymmetricHashJoinExecNode", len)?; + if let Some(v) = self.left.as_ref() { + struct_ser.serialize_field("left", v)?; + } + if let Some(v) = self.right.as_ref() { + struct_ser.serialize_field("right", v)?; + } + if !self.on.is_empty() { + struct_ser.serialize_field("on", &self.on)?; + } + if self.join_type != 0 { + let v = JoinType::try_from(self.join_type) + .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", self.join_type)))?; + struct_ser.serialize_field("joinType", &v)?; + } + if self.partition_mode != 0 { + let v = StreamPartitionMode::try_from(self.partition_mode) + .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", self.partition_mode)))?; + struct_ser.serialize_field("partitionMode", &v)?; + } + if self.null_equals_null { + struct_ser.serialize_field("nullEqualsNull", &self.null_equals_null)?; + } + if let Some(v) = self.filter.as_ref() { + struct_ser.serialize_field("filter", v)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for SymmetricHashJoinExecNode { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "left", + "right", + "on", + "join_type", + "joinType", + "partition_mode", + "partitionMode", + "null_equals_null", + "nullEqualsNull", + "filter", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Left, + Right, + On, + JoinType, + PartitionMode, + NullEqualsNull, + Filter, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", 
&FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "left" => Ok(GeneratedField::Left), + "right" => Ok(GeneratedField::Right), + "on" => Ok(GeneratedField::On), + "joinType" | "join_type" => Ok(GeneratedField::JoinType), + "partitionMode" | "partition_mode" => Ok(GeneratedField::PartitionMode), + "nullEqualsNull" | "null_equals_null" => Ok(GeneratedField::NullEqualsNull), + "filter" => Ok(GeneratedField::Filter), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = SymmetricHashJoinExecNode; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.SymmetricHashJoinExecNode") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut left__ = None; + let mut right__ = None; + let mut on__ = None; + let mut join_type__ = None; + let mut partition_mode__ = None; + let mut null_equals_null__ = None; + let mut filter__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::Left => { + if left__.is_some() { + return Err(serde::de::Error::duplicate_field("left")); } - input__ = map_.next_value()?; + left__ = map_.next_value()?; } - GeneratedField::Alias => { - if alias__.is_some() { - return Err(serde::de::Error::duplicate_field("alias")); + GeneratedField::Right => { + if right__.is_some() { + return Err(serde::de::Error::duplicate_field("right")); } - alias__ = map_.next_value()?; + right__ = map_.next_value()?; + } + GeneratedField::On => { + if on__.is_some() { + return Err(serde::de::Error::duplicate_field("on")); + } + on__ = Some(map_.next_value()?); + } + GeneratedField::JoinType => { + if join_type__.is_some() { + return Err(serde::de::Error::duplicate_field("joinType")); + } + join_type__ = Some(map_.next_value::()? as i32); + } + GeneratedField::PartitionMode => { + if partition_mode__.is_some() { + return Err(serde::de::Error::duplicate_field("partitionMode")); + } + partition_mode__ = Some(map_.next_value::()? 
as i32); + } + GeneratedField::NullEqualsNull => { + if null_equals_null__.is_some() { + return Err(serde::de::Error::duplicate_field("nullEqualsNull")); + } + null_equals_null__ = Some(map_.next_value()?); + } + GeneratedField::Filter => { + if filter__.is_some() { + return Err(serde::de::Error::duplicate_field("filter")); + } + filter__ = map_.next_value()?; } } } - Ok(SubqueryAliasNode { - input: input__, - alias: alias__, + Ok(SymmetricHashJoinExecNode { + left: left__, + right: right__, + on: on__.unwrap_or_default(), + join_type: join_type__.unwrap_or_default(), + partition_mode: partition_mode__.unwrap_or_default(), + null_equals_null: null_equals_null__.unwrap_or_default(), + filter: filter__, }) } } - deserializer.deserialize_struct("datafusion.SubqueryAliasNode", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.SymmetricHashJoinExecNode", FIELDS, GeneratedVisitor) } } impl serde::Serialize for TimeUnit { @@ -24122,6 +27002,97 @@ impl<'de> serde::Deserialize<'de> for WhenThen { deserializer.deserialize_struct("datafusion.WhenThen", FIELDS, GeneratedVisitor) } } +impl serde::Serialize for Wildcard { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if !self.qualifier.is_empty() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.Wildcard", len)?; + if !self.qualifier.is_empty() { + struct_ser.serialize_field("qualifier", &self.qualifier)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for Wildcard { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "qualifier", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Qualifier, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "qualifier" => Ok(GeneratedField::Qualifier), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = Wildcard; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.Wildcard") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut qualifier__ = None; + while let Some(k) = map_.next_key()? 
{ + match k { + GeneratedField::Qualifier => { + if qualifier__.is_some() { + return Err(serde::de::Error::duplicate_field("qualifier")); + } + qualifier__ = Some(map_.next_value()?); + } + } + } + Ok(Wildcard { + qualifier: qualifier__.unwrap_or_default(), + }) + } + } + deserializer.deserialize_struct("datafusion.Wildcard", FIELDS, GeneratedVisitor) + } +} impl serde::Serialize for WindowAggExecNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result @@ -24139,7 +27110,7 @@ impl serde::Serialize for WindowAggExecNode { if !self.partition_keys.is_empty() { len += 1; } - if self.partition_search_mode.is_some() { + if self.input_order_mode.is_some() { len += 1; } let mut struct_ser = serializer.serialize_struct("datafusion.WindowAggExecNode", len)?; @@ -24152,15 +27123,15 @@ impl serde::Serialize for WindowAggExecNode { if !self.partition_keys.is_empty() { struct_ser.serialize_field("partitionKeys", &self.partition_keys)?; } - if let Some(v) = self.partition_search_mode.as_ref() { + if let Some(v) = self.input_order_mode.as_ref() { match v { - window_agg_exec_node::PartitionSearchMode::Linear(v) => { + window_agg_exec_node::InputOrderMode::Linear(v) => { struct_ser.serialize_field("linear", v)?; } - window_agg_exec_node::PartitionSearchMode::PartiallySorted(v) => { + window_agg_exec_node::InputOrderMode::PartiallySorted(v) => { struct_ser.serialize_field("partiallySorted", v)?; } - window_agg_exec_node::PartitionSearchMode::Sorted(v) => { + window_agg_exec_node::InputOrderMode::Sorted(v) => { struct_ser.serialize_field("sorted", v)?; } } @@ -24243,7 +27214,7 @@ impl<'de> serde::Deserialize<'de> for WindowAggExecNode { let mut input__ = None; let mut window_expr__ = None; let mut partition_keys__ = None; - let mut partition_search_mode__ = None; + let mut input_order_mode__ = None; while let Some(k) = map_.next_key()? 
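                        // The loop below fills `input_order_mode__` (renamed from the old
                        // `partition_search_mode__`) from whichever of the "linear",
                        // "partiallySorted", or "sorted" keys is present.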
{ match k { GeneratedField::Input => { @@ -24265,24 +27236,24 @@ impl<'de> serde::Deserialize<'de> for WindowAggExecNode { partition_keys__ = Some(map_.next_value()?); } GeneratedField::Linear => { - if partition_search_mode__.is_some() { + if input_order_mode__.is_some() { return Err(serde::de::Error::duplicate_field("linear")); } - partition_search_mode__ = map_.next_value::<::std::option::Option<_>>()?.map(window_agg_exec_node::PartitionSearchMode::Linear) + input_order_mode__ = map_.next_value::<::std::option::Option<_>>()?.map(window_agg_exec_node::InputOrderMode::Linear) ; } GeneratedField::PartiallySorted => { - if partition_search_mode__.is_some() { + if input_order_mode__.is_some() { return Err(serde::de::Error::duplicate_field("partiallySorted")); } - partition_search_mode__ = map_.next_value::<::std::option::Option<_>>()?.map(window_agg_exec_node::PartitionSearchMode::PartiallySorted) + input_order_mode__ = map_.next_value::<::std::option::Option<_>>()?.map(window_agg_exec_node::InputOrderMode::PartiallySorted) ; } GeneratedField::Sorted => { - if partition_search_mode__.is_some() { + if input_order_mode__.is_some() { return Err(serde::de::Error::duplicate_field("sorted")); } - partition_search_mode__ = map_.next_value::<::std::option::Option<_>>()?.map(window_agg_exec_node::PartitionSearchMode::Sorted) + input_order_mode__ = map_.next_value::<::std::option::Option<_>>()?.map(window_agg_exec_node::InputOrderMode::Sorted) ; } } @@ -24291,7 +27262,7 @@ impl<'de> serde::Deserialize<'de> for WindowAggExecNode { input: input__, window_expr: window_expr__.unwrap_or_default(), partition_keys: partition_keys__.unwrap_or_default(), - partition_search_mode: partition_search_mode__, + input_order_mode: input_order_mode__, }) } } @@ -25009,3 +27980,218 @@ impl<'de> serde::Deserialize<'de> for WindowNode { deserializer.deserialize_struct("datafusion.WindowNode", FIELDS, GeneratedVisitor) } } +impl serde::Serialize for WriterProperties { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.data_page_size_limit != 0 { + len += 1; + } + if self.dictionary_page_size_limit != 0 { + len += 1; + } + if self.data_page_row_count_limit != 0 { + len += 1; + } + if self.write_batch_size != 0 { + len += 1; + } + if self.max_row_group_size != 0 { + len += 1; + } + if !self.writer_version.is_empty() { + len += 1; + } + if !self.created_by.is_empty() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.WriterProperties", len)?; + if self.data_page_size_limit != 0 { + #[allow(clippy::needless_borrow)] + struct_ser.serialize_field("dataPageSizeLimit", ToString::to_string(&self.data_page_size_limit).as_str())?; + } + if self.dictionary_page_size_limit != 0 { + #[allow(clippy::needless_borrow)] + struct_ser.serialize_field("dictionaryPageSizeLimit", ToString::to_string(&self.dictionary_page_size_limit).as_str())?; + } + if self.data_page_row_count_limit != 0 { + #[allow(clippy::needless_borrow)] + struct_ser.serialize_field("dataPageRowCountLimit", ToString::to_string(&self.data_page_row_count_limit).as_str())?; + } + if self.write_batch_size != 0 { + #[allow(clippy::needless_borrow)] + struct_ser.serialize_field("writeBatchSize", ToString::to_string(&self.write_batch_size).as_str())?; + } + if self.max_row_group_size != 0 { + #[allow(clippy::needless_borrow)] + struct_ser.serialize_field("maxRowGroupSize", 
ToString::to_string(&self.max_row_group_size).as_str())?; + } + if !self.writer_version.is_empty() { + struct_ser.serialize_field("writerVersion", &self.writer_version)?; + } + if !self.created_by.is_empty() { + struct_ser.serialize_field("createdBy", &self.created_by)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for WriterProperties { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "data_page_size_limit", + "dataPageSizeLimit", + "dictionary_page_size_limit", + "dictionaryPageSizeLimit", + "data_page_row_count_limit", + "dataPageRowCountLimit", + "write_batch_size", + "writeBatchSize", + "max_row_group_size", + "maxRowGroupSize", + "writer_version", + "writerVersion", + "created_by", + "createdBy", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + DataPageSizeLimit, + DictionaryPageSizeLimit, + DataPageRowCountLimit, + WriteBatchSize, + MaxRowGroupSize, + WriterVersion, + CreatedBy, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "dataPageSizeLimit" | "data_page_size_limit" => Ok(GeneratedField::DataPageSizeLimit), + "dictionaryPageSizeLimit" | "dictionary_page_size_limit" => Ok(GeneratedField::DictionaryPageSizeLimit), + "dataPageRowCountLimit" | "data_page_row_count_limit" => Ok(GeneratedField::DataPageRowCountLimit), + "writeBatchSize" | "write_batch_size" => Ok(GeneratedField::WriteBatchSize), + "maxRowGroupSize" | "max_row_group_size" => Ok(GeneratedField::MaxRowGroupSize), + "writerVersion" | "writer_version" => Ok(GeneratedField::WriterVersion), + "createdBy" | "created_by" => Ok(GeneratedField::CreatedBy), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = WriterProperties; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.WriterProperties") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut data_page_size_limit__ = None; + let mut dictionary_page_size_limit__ = None; + let mut data_page_row_count_limit__ = None; + let mut write_batch_size__ = None; + let mut max_row_group_size__ = None; + let mut writer_version__ = None; + let mut created_by__ = None; + while let Some(k) = map_.next_key()? 
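                        // The u64 limits were written as strings by the serializer above
                        // (proto3 JSON encodes 64-bit integers as strings), so they are read
                        // back here through pbjson's NumberDeserialize wrapper.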
{ + match k { + GeneratedField::DataPageSizeLimit => { + if data_page_size_limit__.is_some() { + return Err(serde::de::Error::duplicate_field("dataPageSizeLimit")); + } + data_page_size_limit__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; + } + GeneratedField::DictionaryPageSizeLimit => { + if dictionary_page_size_limit__.is_some() { + return Err(serde::de::Error::duplicate_field("dictionaryPageSizeLimit")); + } + dictionary_page_size_limit__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; + } + GeneratedField::DataPageRowCountLimit => { + if data_page_row_count_limit__.is_some() { + return Err(serde::de::Error::duplicate_field("dataPageRowCountLimit")); + } + data_page_row_count_limit__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; + } + GeneratedField::WriteBatchSize => { + if write_batch_size__.is_some() { + return Err(serde::de::Error::duplicate_field("writeBatchSize")); + } + write_batch_size__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; + } + GeneratedField::MaxRowGroupSize => { + if max_row_group_size__.is_some() { + return Err(serde::de::Error::duplicate_field("maxRowGroupSize")); + } + max_row_group_size__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; + } + GeneratedField::WriterVersion => { + if writer_version__.is_some() { + return Err(serde::de::Error::duplicate_field("writerVersion")); + } + writer_version__ = Some(map_.next_value()?); + } + GeneratedField::CreatedBy => { + if created_by__.is_some() { + return Err(serde::de::Error::duplicate_field("createdBy")); + } + created_by__ = Some(map_.next_value()?); + } + } + } + Ok(WriterProperties { + data_page_size_limit: data_page_size_limit__.unwrap_or_default(), + dictionary_page_size_limit: dictionary_page_size_limit__.unwrap_or_default(), + data_page_row_count_limit: data_page_row_count_limit__.unwrap_or_default(), + write_batch_size: write_batch_size__.unwrap_or_default(), + max_row_group_size: max_row_group_size__.unwrap_or_default(), + writer_version: writer_version__.unwrap_or_default(), + created_by: created_by__.unwrap_or_default(), + }) + } + } + deserializer.deserialize_struct("datafusion.WriterProperties", FIELDS, GeneratedVisitor) + } +} diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index d18bacfb3bcc8..4ee0b70325ca7 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -38,7 +38,7 @@ pub struct DfSchema { pub struct LogicalPlanNode { #[prost( oneof = "logical_plan_node::LogicalPlanType", - tags = "1, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27" + tags = "1, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29" )] pub logical_plan_type: ::core::option::Option, } @@ -99,6 +99,10 @@ pub mod logical_plan_node { Prepare(::prost::alloc::boxed::Box), #[prost(message, tag = "27")] DropView(super::DropViewNode), + #[prost(message, tag = "28")] + DistinctOn(::prost::alloc::boxed::Box), + #[prost(message, tag = "29")] + CopyTo(::prost::alloc::boxed::Box), } } #[allow(clippy::derive_partial_eq_without_eq)] @@ -358,6 +362,11 @@ pub struct CreateExternalTableNode { >, #[prost(message, optional, tag = "15")] pub constraints: ::core::option::Option, + #[prost(map = "string, message", tag = "16")] + pub column_defaults: ::std::collections::HashMap< + ::prost::alloc::string::String, + 
LogicalExprNode, + >, } #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] @@ -483,6 +492,57 @@ pub struct DistinctNode { } #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] +pub struct DistinctOnNode { + #[prost(message, repeated, tag = "1")] + pub on_expr: ::prost::alloc::vec::Vec, + #[prost(message, repeated, tag = "2")] + pub select_expr: ::prost::alloc::vec::Vec, + #[prost(message, repeated, tag = "3")] + pub sort_expr: ::prost::alloc::vec::Vec, + #[prost(message, optional, boxed, tag = "4")] + pub input: ::core::option::Option<::prost::alloc::boxed::Box>, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct CopyToNode { + #[prost(message, optional, boxed, tag = "1")] + pub input: ::core::option::Option<::prost::alloc::boxed::Box>, + #[prost(string, tag = "2")] + pub output_url: ::prost::alloc::string::String, + #[prost(bool, tag = "3")] + pub single_file_output: bool, + #[prost(string, tag = "6")] + pub file_type: ::prost::alloc::string::String, + #[prost(oneof = "copy_to_node::CopyOptions", tags = "4, 5")] + pub copy_options: ::core::option::Option, +} +/// Nested message and enum types in `CopyToNode`. +pub mod copy_to_node { + #[allow(clippy::derive_partial_eq_without_eq)] + #[derive(Clone, PartialEq, ::prost::Oneof)] + pub enum CopyOptions { + #[prost(message, tag = "4")] + SqlOptions(super::SqlOptions), + #[prost(message, tag = "5")] + WriterOptions(super::FileTypeWriterOptions), + } +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct SqlOptions { + #[prost(message, repeated, tag = "1")] + pub option: ::prost::alloc::vec::Vec, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct SqlOption { + #[prost(string, tag = "1")] + pub key: ::prost::alloc::string::String, + #[prost(string, tag = "2")] + pub value: ::prost::alloc::string::String, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] pub struct UnionNode { #[prost(message, repeated, tag = "1")] pub inputs: ::prost::alloc::vec::Vec, @@ -569,8 +629,8 @@ pub mod logical_expr_node { Negative(::prost::alloc::boxed::Box), #[prost(message, tag = "14")] InList(::prost::alloc::boxed::Box), - #[prost(bool, tag = "15")] - Wildcard(bool), + #[prost(message, tag = "15")] + Wildcard(super::Wildcard), #[prost(message, tag = "16")] ScalarFunction(super::ScalarFunctionNode), #[prost(message, tag = "17")] @@ -616,6 +676,12 @@ pub mod logical_expr_node { } #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] +pub struct Wildcard { + #[prost(string, tag = "1")] + pub qualifier: ::prost::alloc::string::String, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] pub struct PlaceholderNode { #[prost(string, tag = "1")] pub id: ::prost::alloc::string::String, @@ -748,6 +814,8 @@ pub struct AliasNode { pub expr: ::core::option::Option<::prost::alloc::boxed::Box>, #[prost(string, tag = "2")] pub alias: ::prost::alloc::string::String, + #[prost(message, repeated, tag = "3")] + pub relation: ::prost::alloc::vec::Vec, } #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] @@ -1000,6 +1068,10 @@ pub struct Field { ::prost::alloc::string::String, ::prost::alloc::string::String, >, + #[prost(int64, tag = "6")] + pub dict_id: i64, + 
#[prost(bool, tag = "7")] + pub dict_ordered: bool, } #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] @@ -1178,7 +1250,7 @@ pub struct ScalarFixedSizeBinary { pub struct ScalarValue { #[prost( oneof = "scalar_value::Value", - tags = "33, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 17, 20, 39, 21, 24, 25, 35, 36, 37, 38, 26, 27, 28, 29, 30, 31, 32, 34" + tags = "33, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 20, 39, 21, 24, 25, 35, 36, 37, 38, 26, 27, 28, 29, 30, 31, 32, 34" )] pub value: ::core::option::Option, } @@ -1222,8 +1294,12 @@ pub mod scalar_value { Date32Value(i32), #[prost(message, tag = "15")] Time32Value(super::ScalarTime32Value), + #[prost(message, tag = "16")] + LargeListValue(super::ScalarListValue), #[prost(message, tag = "17")] ListValue(super::ScalarListValue), + #[prost(message, tag = "18")] + FixedSizeListValue(super::ScalarListValue), #[prost(message, tag = "20")] Decimal128Value(super::Decimal128), #[prost(message, tag = "39")] @@ -1401,7 +1477,7 @@ pub struct OptimizedPhysicalPlanType { #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct PlanType { - #[prost(oneof = "plan_type::PlanTypeEnum", tags = "1, 7, 8, 2, 3, 4, 5, 6")] + #[prost(oneof = "plan_type::PlanTypeEnum", tags = "1, 7, 8, 2, 3, 4, 9, 5, 6, 10")] pub plan_type_enum: ::core::option::Option, } /// Nested message and enum types in `PlanType`. @@ -1421,10 +1497,14 @@ pub mod plan_type { FinalLogicalPlan(super::EmptyMessage), #[prost(message, tag = "4")] InitialPhysicalPlan(super::EmptyMessage), + #[prost(message, tag = "9")] + InitialPhysicalPlanWithStats(super::EmptyMessage), #[prost(message, tag = "5")] OptimizedPhysicalPlan(super::OptimizedPhysicalPlanType), #[prost(message, tag = "6")] FinalPhysicalPlan(super::EmptyMessage), + #[prost(message, tag = "10")] + FinalPhysicalPlanWithStats(super::EmptyMessage), } } #[allow(clippy::derive_partial_eq_without_eq)] @@ -1486,7 +1566,7 @@ pub mod owned_table_reference { pub struct PhysicalPlanNode { #[prost( oneof = "physical_plan_node::PhysicalPlanType", - tags = "1, 2, 3, 4, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23" + tags = "1, 2, 3, 4, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29" )] pub physical_plan_type: ::core::option::Option, } @@ -1541,10 +1621,181 @@ pub mod physical_plan_node { NestedLoopJoin(::prost::alloc::boxed::Box), #[prost(message, tag = "23")] Analyze(::prost::alloc::boxed::Box), + #[prost(message, tag = "24")] + JsonSink(::prost::alloc::boxed::Box), + #[prost(message, tag = "25")] + SymmetricHashJoin(::prost::alloc::boxed::Box), + #[prost(message, tag = "26")] + Interleave(super::InterleaveExecNode), + #[prost(message, tag = "27")] + PlaceholderRow(super::PlaceholderRowExecNode), + #[prost(message, tag = "28")] + CsvSink(::prost::alloc::boxed::Box), + #[prost(message, tag = "29")] + ParquetSink(::prost::alloc::boxed::Box), } } #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] +pub struct PartitionColumn { + #[prost(string, tag = "1")] + pub name: ::prost::alloc::string::String, + #[prost(message, optional, tag = "2")] + pub arrow_type: ::core::option::Option, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct FileTypeWriterOptions { + #[prost(oneof = "file_type_writer_options::FileType", tags = "1, 2, 3")] + pub file_type: ::core::option::Option, +} 
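The sink, interleave, and placeholder-row plan variants added above are exposed through the same pbjson-generated Serialize/Deserialize pattern shown at the top of this diff. As an illustrative sketch only (not part of the generated files), the snippet below round-trips the new PartitionColumn message through JSON; the datafusion_proto::protobuf re-export path and the serde_json dependency are assumptions here, not something this diff adds.

    // Sketch: JSON round-trip of the new PartitionColumn message via the
    // pbjson-generated impls earlier in this diff. Paths and deps are assumed.
    use datafusion_proto::protobuf::PartitionColumn; // re-export path assumed

    fn main() -> Result<(), serde_json::Error> {
        // Only non-default fields are emitted, so `arrow_type: None` is skipped
        // and the field name is camelCased ("arrowType") when present.
        let col = PartitionColumn {
            name: "year".to_string(),
            arrow_type: None,
        };
        let json = serde_json::to_string(&col)?;
        assert_eq!(json, r#"{"name":"year"}"#);

        // The generated visitor accepts both "arrowType" and "arrow_type".
        let back: PartitionColumn = serde_json::from_str(&json)?;
        assert_eq!(back, col);
        Ok(())
    }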
+/// Nested message and enum types in `FileTypeWriterOptions`. +pub mod file_type_writer_options { + #[allow(clippy::derive_partial_eq_without_eq)] + #[derive(Clone, PartialEq, ::prost::Oneof)] + pub enum FileType { + #[prost(message, tag = "1")] + JsonOptions(super::JsonWriterOptions), + #[prost(message, tag = "2")] + ParquetOptions(super::ParquetWriterOptions), + #[prost(message, tag = "3")] + CsvOptions(super::CsvWriterOptions), + } +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct JsonWriterOptions { + #[prost(enumeration = "CompressionTypeVariant", tag = "1")] + pub compression: i32, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct ParquetWriterOptions { + #[prost(message, optional, tag = "1")] + pub writer_properties: ::core::option::Option, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct CsvWriterOptions { + /// Compression type + #[prost(enumeration = "CompressionTypeVariant", tag = "1")] + pub compression: i32, + /// Optional column delimiter. Defaults to `b','` + #[prost(string, tag = "2")] + pub delimiter: ::prost::alloc::string::String, + /// Whether to write column names as file headers. Defaults to `true` + #[prost(bool, tag = "3")] + pub has_header: bool, + /// Optional date format for date arrays + #[prost(string, tag = "4")] + pub date_format: ::prost::alloc::string::String, + /// Optional datetime format for datetime arrays + #[prost(string, tag = "5")] + pub datetime_format: ::prost::alloc::string::String, + /// Optional timestamp format for timestamp arrays + #[prost(string, tag = "6")] + pub timestamp_format: ::prost::alloc::string::String, + /// Optional time format for time arrays + #[prost(string, tag = "7")] + pub time_format: ::prost::alloc::string::String, + /// Optional value to represent null + #[prost(string, tag = "8")] + pub null_value: ::prost::alloc::string::String, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct WriterProperties { + #[prost(uint64, tag = "1")] + pub data_page_size_limit: u64, + #[prost(uint64, tag = "2")] + pub dictionary_page_size_limit: u64, + #[prost(uint64, tag = "3")] + pub data_page_row_count_limit: u64, + #[prost(uint64, tag = "4")] + pub write_batch_size: u64, + #[prost(uint64, tag = "5")] + pub max_row_group_size: u64, + #[prost(string, tag = "6")] + pub writer_version: ::prost::alloc::string::String, + #[prost(string, tag = "7")] + pub created_by: ::prost::alloc::string::String, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct FileSinkConfig { + #[prost(string, tag = "1")] + pub object_store_url: ::prost::alloc::string::String, + #[prost(message, repeated, tag = "2")] + pub file_groups: ::prost::alloc::vec::Vec, + #[prost(string, repeated, tag = "3")] + pub table_paths: ::prost::alloc::vec::Vec<::prost::alloc::string::String>, + #[prost(message, optional, tag = "4")] + pub output_schema: ::core::option::Option, + #[prost(message, repeated, tag = "5")] + pub table_partition_cols: ::prost::alloc::vec::Vec, + #[prost(bool, tag = "7")] + pub single_file_output: bool, + #[prost(bool, tag = "8")] + pub overwrite: bool, + #[prost(message, optional, tag = "9")] + pub file_type_writer_options: ::core::option::Option, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct JsonSink { + 
#[prost(message, optional, tag = "1")] + pub config: ::core::option::Option, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct JsonSinkExecNode { + #[prost(message, optional, boxed, tag = "1")] + pub input: ::core::option::Option<::prost::alloc::boxed::Box>, + #[prost(message, optional, tag = "2")] + pub sink: ::core::option::Option, + #[prost(message, optional, tag = "3")] + pub sink_schema: ::core::option::Option, + #[prost(message, optional, tag = "4")] + pub sort_order: ::core::option::Option, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct CsvSink { + #[prost(message, optional, tag = "1")] + pub config: ::core::option::Option, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct CsvSinkExecNode { + #[prost(message, optional, boxed, tag = "1")] + pub input: ::core::option::Option<::prost::alloc::boxed::Box>, + #[prost(message, optional, tag = "2")] + pub sink: ::core::option::Option, + #[prost(message, optional, tag = "3")] + pub sink_schema: ::core::option::Option, + #[prost(message, optional, tag = "4")] + pub sort_order: ::core::option::Option, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct ParquetSink { + #[prost(message, optional, tag = "1")] + pub config: ::core::option::Option, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct ParquetSinkExecNode { + #[prost(message, optional, boxed, tag = "1")] + pub input: ::core::option::Option<::prost::alloc::boxed::Box>, + #[prost(message, optional, tag = "2")] + pub sink: ::core::option::Option, + #[prost(message, optional, tag = "3")] + pub sink_schema: ::core::option::Option, + #[prost(message, optional, tag = "4")] + pub sort_order: ::core::option::Option, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] pub struct PhysicalExtensionNode { #[prost(bytes = "vec", tag = "1")] pub node: ::prost::alloc::vec::Vec, @@ -1813,6 +2064,8 @@ pub struct FilterExecNode { pub input: ::core::option::Option<::prost::alloc::boxed::Box>, #[prost(message, optional, tag = "2")] pub expr: ::core::option::Option, + #[prost(uint32, tag = "3")] + pub default_filter_selectivity: u32, } #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] @@ -1910,6 +2163,30 @@ pub struct HashJoinExecNode { } #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] +pub struct SymmetricHashJoinExecNode { + #[prost(message, optional, boxed, tag = "1")] + pub left: ::core::option::Option<::prost::alloc::boxed::Box>, + #[prost(message, optional, boxed, tag = "2")] + pub right: ::core::option::Option<::prost::alloc::boxed::Box>, + #[prost(message, repeated, tag = "3")] + pub on: ::prost::alloc::vec::Vec, + #[prost(enumeration = "JoinType", tag = "4")] + pub join_type: i32, + #[prost(enumeration = "StreamPartitionMode", tag = "6")] + pub partition_mode: i32, + #[prost(bool, tag = "7")] + pub null_equals_null: bool, + #[prost(message, optional, tag = "8")] + pub filter: ::core::option::Option, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct InterleaveExecNode { + #[prost(message, repeated, tag = "1")] + pub inputs: ::prost::alloc::vec::Vec, +} +#[allow(clippy::derive_partial_eq_without_eq)] 
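For reference, a minimal sketch of how the new sink messages are expected to round-trip through prost, assuming the generated field types shown above (e.g. `JsonSink { config: Option<FileSinkConfig> }`); the function name is illustrative only:

    use prost::Message;

    fn roundtrip_json_sink() -> Result<(), prost::DecodeError> {
        // Encode an (empty) JsonSink to protobuf wire bytes and decode it back.
        let sink = protobuf::JsonSink { config: None };
        let bytes = sink.encode_to_vec();
        let decoded = protobuf::JsonSink::decode(bytes.as_slice())?;
        // The generated messages derive PartialEq, so the round trip can be checked directly.
        assert_eq!(sink, decoded);
        Ok(())
    }
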
+#[derive(Clone, PartialEq, ::prost::Message)] pub struct UnionExecNode { #[prost(message, repeated, tag = "1")] pub inputs: ::prost::alloc::vec::Vec, @@ -1963,9 +2240,13 @@ pub struct JoinOn { #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct EmptyExecNode { - #[prost(bool, tag = "1")] - pub produce_one_row: bool, - #[prost(message, optional, tag = "2")] + #[prost(message, optional, tag = "1")] + pub schema: ::core::option::Option, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct PlaceholderRowExecNode { + #[prost(message, optional, tag = "1")] pub schema: ::core::option::Option, } #[allow(clippy::derive_partial_eq_without_eq)] @@ -1980,7 +2261,7 @@ pub struct ProjectionExecNode { } #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] -pub struct PartiallySortedPartitionSearchMode { +pub struct PartiallySortedInputOrderMode { #[prost(uint64, repeated, tag = "6")] pub columns: ::prost::alloc::vec::Vec, } @@ -1994,21 +2275,19 @@ pub struct WindowAggExecNode { #[prost(message, repeated, tag = "5")] pub partition_keys: ::prost::alloc::vec::Vec, /// Set optional to `None` for `BoundedWindowAggExec`. - #[prost(oneof = "window_agg_exec_node::PartitionSearchMode", tags = "7, 8, 9")] - pub partition_search_mode: ::core::option::Option< - window_agg_exec_node::PartitionSearchMode, - >, + #[prost(oneof = "window_agg_exec_node::InputOrderMode", tags = "7, 8, 9")] + pub input_order_mode: ::core::option::Option, } /// Nested message and enum types in `WindowAggExecNode`. pub mod window_agg_exec_node { /// Set optional to `None` for `BoundedWindowAggExec`. #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Oneof)] - pub enum PartitionSearchMode { + pub enum InputOrderMode { #[prost(message, tag = "7")] Linear(super::EmptyMessage), #[prost(message, tag = "8")] - PartiallySorted(super::PartiallySortedPartitionSearchMode), + PartiallySorted(super::PartiallySortedInputOrderMode), #[prost(message, tag = "9")] Sorted(super::EmptyMessage), } @@ -2049,8 +2328,6 @@ pub struct AggregateExecNode { pub groups: ::prost::alloc::vec::Vec, #[prost(message, repeated, tag = "10")] pub filter_expr: ::prost::alloc::vec::Vec, - #[prost(message, repeated, tag = "11")] - pub order_by_expr: ::prost::alloc::vec::Vec, } #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] @@ -2466,6 +2743,17 @@ pub enum ScalarFunction { ArrayPopBack = 116, StringToArray = 117, ToTimestampNanos = 118, + ArrayIntersect = 119, + ArrayUnion = 120, + OverLay = 121, + Range = 122, + ArrayExcept = 123, + ArrayPopFront = 124, + Levenshtein = 125, + SubstrIndex = 126, + FindInSet = 127, + ArraySort = 128, + ArrayDistinct = 129, } impl ScalarFunction { /// String value of the enum field names used in the ProtoBuf definition. 
@@ -2593,6 +2881,17 @@ impl ScalarFunction { ScalarFunction::ArrayPopBack => "ArrayPopBack", ScalarFunction::StringToArray => "StringToArray", ScalarFunction::ToTimestampNanos => "ToTimestampNanos", + ScalarFunction::ArrayIntersect => "ArrayIntersect", + ScalarFunction::ArrayUnion => "ArrayUnion", + ScalarFunction::OverLay => "OverLay", + ScalarFunction::Range => "Range", + ScalarFunction::ArrayExcept => "ArrayExcept", + ScalarFunction::ArrayPopFront => "ArrayPopFront", + ScalarFunction::Levenshtein => "Levenshtein", + ScalarFunction::SubstrIndex => "SubstrIndex", + ScalarFunction::FindInSet => "FindInSet", + ScalarFunction::ArraySort => "ArraySort", + ScalarFunction::ArrayDistinct => "ArrayDistinct", } } /// Creates an enum from field names used in the ProtoBuf definition. @@ -2717,6 +3016,17 @@ impl ScalarFunction { "ArrayPopBack" => Some(Self::ArrayPopBack), "StringToArray" => Some(Self::StringToArray), "ToTimestampNanos" => Some(Self::ToTimestampNanos), + "ArrayIntersect" => Some(Self::ArrayIntersect), + "ArrayUnion" => Some(Self::ArrayUnion), + "OverLay" => Some(Self::OverLay), + "Range" => Some(Self::Range), + "ArrayExcept" => Some(Self::ArrayExcept), + "ArrayPopFront" => Some(Self::ArrayPopFront), + "Levenshtein" => Some(Self::Levenshtein), + "SubstrIndex" => Some(Self::SubstrIndex), + "FindInSet" => Some(Self::FindInSet), + "ArraySort" => Some(Self::ArraySort), + "ArrayDistinct" => Some(Self::ArrayDistinct), _ => None, } } @@ -2761,6 +3071,7 @@ pub enum AggregateFunction { RegrSxx = 32, RegrSyy = 33, RegrSxy = 34, + StringAgg = 35, } impl AggregateFunction { /// String value of the enum field names used in the ProtoBuf definition. @@ -2806,6 +3117,7 @@ impl AggregateFunction { AggregateFunction::RegrSxx => "REGR_SXX", AggregateFunction::RegrSyy => "REGR_SYY", AggregateFunction::RegrSxy => "REGR_SXY", + AggregateFunction::StringAgg => "STRING_AGG", } } /// Creates an enum from field names used in the ProtoBuf definition. @@ -2848,6 +3160,7 @@ impl AggregateFunction { "REGR_SXX" => Some(Self::RegrSxx), "REGR_SYY" => Some(Self::RegrSyy), "REGR_SXY" => Some(Self::RegrSxy), + "STRING_AGG" => Some(Self::StringAgg), _ => None, } } @@ -3078,6 +3391,41 @@ impl UnionMode { } #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] #[repr(i32)] +pub enum CompressionTypeVariant { + Gzip = 0, + Bzip2 = 1, + Xz = 2, + Zstd = 3, + Uncompressed = 4, +} +impl CompressionTypeVariant { + /// String value of the enum field names used in the ProtoBuf definition. + /// + /// The values are not transformed in any way and thus are considered stable + /// (if the ProtoBuf definition does not change) and safe for programmatic use. + pub fn as_str_name(&self) -> &'static str { + match self { + CompressionTypeVariant::Gzip => "GZIP", + CompressionTypeVariant::Bzip2 => "BZIP2", + CompressionTypeVariant::Xz => "XZ", + CompressionTypeVariant::Zstd => "ZSTD", + CompressionTypeVariant::Uncompressed => "UNCOMPRESSED", + } + } + /// Creates an enum from field names used in the ProtoBuf definition. 
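A small usage sketch for the new CompressionTypeVariant helpers (variant and string names exactly as defined above; `protobuf` is the generated module):

    let v = protobuf::CompressionTypeVariant::Zstd;
    assert_eq!(v.as_str_name(), "ZSTD");
    assert_eq!(
        protobuf::CompressionTypeVariant::from_str_name("ZSTD"),
        Some(protobuf::CompressionTypeVariant::Zstd),
    );
    // Unrecognized names fall through to None rather than erroring.
    assert_eq!(protobuf::CompressionTypeVariant::from_str_name("SNAPPY"), None);
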
+ pub fn from_str_name(value: &str) -> ::core::option::Option { + match value { + "GZIP" => Some(Self::Gzip), + "BZIP2" => Some(Self::Bzip2), + "XZ" => Some(Self::Xz), + "ZSTD" => Some(Self::Zstd), + "UNCOMPRESSED" => Some(Self::Uncompressed), + _ => None, + } + } +} +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] +#[repr(i32)] pub enum PartitionMode { CollectLeft = 0, Partitioned = 1, @@ -3107,6 +3455,32 @@ impl PartitionMode { } #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] #[repr(i32)] +pub enum StreamPartitionMode { + SinglePartition = 0, + PartitionedExec = 1, +} +impl StreamPartitionMode { + /// String value of the enum field names used in the ProtoBuf definition. + /// + /// The values are not transformed in any way and thus are considered stable + /// (if the ProtoBuf definition does not change) and safe for programmatic use. + pub fn as_str_name(&self) -> &'static str { + match self { + StreamPartitionMode::SinglePartition => "SINGLE_PARTITION", + StreamPartitionMode::PartitionedExec => "PARTITIONED_EXEC", + } + } + /// Creates an enum from field names used in the ProtoBuf definition. + pub fn from_str_name(value: &str) -> ::core::option::Option { + match value { + "SINGLE_PARTITION" => Some(Self::SinglePartition), + "PARTITIONED_EXEC" => Some(Self::PartitionedExec), + _ => None, + } + } +} +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] +#[repr(i32)] pub enum AggregateMode { Partial = 0, Final = 1, diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index 26bd0163d0a31..36c5b44f00b9c 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -19,7 +19,8 @@ use crate::protobuf::{ self, plan_type::PlanTypeEnum::{ AnalyzedLogicalPlan, FinalAnalyzedLogicalPlan, FinalLogicalPlan, - FinalPhysicalPlan, InitialLogicalPlan, InitialPhysicalPlan, OptimizedLogicalPlan, + FinalPhysicalPlan, FinalPhysicalPlanWithStats, InitialLogicalPlan, + InitialPhysicalPlan, InitialPhysicalPlanWithStats, OptimizedLogicalPlan, OptimizedPhysicalPlan, }, AnalyzedLogicalPlanType, CubeNode, GroupingSetNode, OptimizedLogicalPlanType, @@ -35,28 +36,31 @@ use arrow::{ }; use datafusion::execution::registry::FunctionRegistry; use datafusion_common::{ - internal_err, plan_datafusion_err, Column, Constraint, Constraints, DFField, - DFSchema, DFSchemaRef, DataFusionError, OwnedTableReference, Result, ScalarValue, + arrow_datafusion_err, internal_err, plan_datafusion_err, Column, Constraint, + Constraints, DFField, DFSchema, DFSchemaRef, DataFusionError, OwnedTableReference, + Result, ScalarValue, }; +use datafusion_expr::window_frame::{check_window_frame, regularize_window_order_by}; use datafusion_expr::{ - abs, acos, acosh, array, array_append, array_concat, array_dims, array_element, - array_has, array_has_all, array_has_any, array_length, array_ndims, array_position, - array_positions, array_prepend, array_remove, array_remove_all, array_remove_n, - array_repeat, array_replace, array_replace_all, array_replace_n, array_slice, - array_to_string, ascii, asin, asinh, atan, atan2, atanh, bit_length, btrim, - cardinality, cbrt, ceil, character_length, chr, coalesce, concat_expr, + abs, acos, acosh, array, array_append, array_concat, array_dims, array_distinct, + array_element, array_except, array_has, array_has_all, array_has_any, + array_intersect, array_length, array_ndims, 
array_position, array_positions, + array_prepend, array_remove, array_remove_all, array_remove_n, array_repeat, + array_replace, array_replace_all, array_replace_n, array_slice, array_sort, + array_to_string, arrow_typeof, ascii, asin, asinh, atan, atan2, atanh, bit_length, + btrim, cardinality, cbrt, ceil, character_length, chr, coalesce, concat_expr, concat_ws_expr, cos, cosh, cot, current_date, current_time, date_bin, date_part, - date_trunc, degrees, digest, exp, + date_trunc, decode, degrees, digest, encode, exp, expr::{self, InList, Sort, WindowFunction}, - factorial, floor, from_unixtime, gcd, isnan, iszero, lcm, left, ln, log, log10, log2, + factorial, find_in_set, flatten, floor, from_unixtime, gcd, gen_range, isnan, iszero, + lcm, left, levenshtein, ln, log, log10, log2, logical_plan::{PlanType, StringifiedPlan}, - lower, lpad, ltrim, md5, nanvl, now, nullif, octet_length, pi, power, radians, - random, regexp_match, regexp_replace, repeat, replace, reverse, right, round, rpad, - rtrim, sha224, sha256, sha384, sha512, signum, sin, sinh, split_part, sqrt, - starts_with, strpos, substr, substring, tan, tanh, to_hex, to_timestamp_micros, - to_timestamp_millis, to_timestamp_nanos, to_timestamp_seconds, translate, trim, - trunc, upper, uuid, - window_frame::regularize, + lower, lpad, ltrim, md5, nanvl, now, nullif, octet_length, overlay, pi, power, + radians, random, regexp_match, regexp_replace, repeat, replace, reverse, right, + round, rpad, rtrim, sha224, sha256, sha384, sha512, signum, sin, sinh, split_part, + sqrt, starts_with, string_to_array, strpos, struct_fun, substr, substr_index, + substring, tan, tanh, to_hex, to_timestamp_micros, to_timestamp_millis, + to_timestamp_nanos, to_timestamp_seconds, translate, trim, trunc, upper, uuid, AggregateFunction, Between, BinaryExpr, BuiltInWindowFunction, BuiltinScalarFunction, Case, Cast, Expr, GetFieldAccess, GetIndexedField, GroupingSet, GroupingSet::GroupingSets, @@ -64,7 +68,7 @@ use datafusion_expr::{ WindowFrameUnits, }; use datafusion_expr::{ - array_empty, array_pop_back, + array_empty, array_pop_back, array_pop_front, expr::{Alias, Placeholder}, }; use std::sync::Arc; @@ -374,8 +378,20 @@ impl TryFrom<&protobuf::Field> for Field { type Error = Error; fn try_from(field: &protobuf::Field) -> Result { let datatype = field.arrow_type.as_deref().required("arrow_type")?; - Ok(Self::new(field.name.as_str(), datatype, field.nullable) - .with_metadata(field.metadata.clone())) + let field = if field.dict_id != 0 { + Self::new_dict( + field.name.as_str(), + datatype, + field.nullable, + field.dict_id, + field.dict_ordered, + ) + .with_metadata(field.metadata.clone()) + } else { + Self::new(field.name.as_str(), datatype, field.nullable) + .with_metadata(field.metadata.clone()) + }; + Ok(field) } } @@ -405,12 +421,14 @@ impl From<&protobuf::StringifiedPlan> for StringifiedPlan { } FinalLogicalPlan(_) => PlanType::FinalLogicalPlan, InitialPhysicalPlan(_) => PlanType::InitialPhysicalPlan, + InitialPhysicalPlanWithStats(_) => PlanType::InitialPhysicalPlanWithStats, OptimizedPhysicalPlan(OptimizedPhysicalPlanType { optimizer_name }) => { PlanType::OptimizedPhysicalPlan { optimizer_name: optimizer_name.clone(), } } FinalPhysicalPlan(_) => PlanType::FinalPhysicalPlan, + FinalPhysicalPlanWithStats(_) => PlanType::FinalPhysicalPlanWithStats, }, plan: Arc::new(stringified_plan.plan.clone()), } @@ -459,16 +477,20 @@ impl From<&protobuf::ScalarFunction> for BuiltinScalarFunction { ScalarFunction::Rtrim => Self::Rtrim, ScalarFunction::ToTimestamp => 
Self::ToTimestamp, ScalarFunction::ArrayAppend => Self::ArrayAppend, + ScalarFunction::ArraySort => Self::ArraySort, ScalarFunction::ArrayConcat => Self::ArrayConcat, ScalarFunction::ArrayEmpty => Self::ArrayEmpty, + ScalarFunction::ArrayExcept => Self::ArrayExcept, ScalarFunction::ArrayHasAll => Self::ArrayHasAll, ScalarFunction::ArrayHasAny => Self::ArrayHasAny, ScalarFunction::ArrayHas => Self::ArrayHas, ScalarFunction::ArrayDims => Self::ArrayDims, + ScalarFunction::ArrayDistinct => Self::ArrayDistinct, ScalarFunction::ArrayElement => Self::ArrayElement, ScalarFunction::Flatten => Self::Flatten, ScalarFunction::ArrayLength => Self::ArrayLength, ScalarFunction::ArrayNdims => Self::ArrayNdims, + ScalarFunction::ArrayPopFront => Self::ArrayPopFront, ScalarFunction::ArrayPopBack => Self::ArrayPopBack, ScalarFunction::ArrayPosition => Self::ArrayPosition, ScalarFunction::ArrayPositions => Self::ArrayPositions, @@ -482,6 +504,9 @@ impl From<&protobuf::ScalarFunction> for BuiltinScalarFunction { ScalarFunction::ArrayReplaceAll => Self::ArrayReplaceAll, ScalarFunction::ArraySlice => Self::ArraySlice, ScalarFunction::ArrayToString => Self::ArrayToString, + ScalarFunction::ArrayIntersect => Self::ArrayIntersect, + ScalarFunction::ArrayUnion => Self::ArrayUnion, + ScalarFunction::Range => Self::Range, ScalarFunction::Cardinality => Self::Cardinality, ScalarFunction::Array => Self::MakeArray, ScalarFunction::NullIf => Self::NullIf, @@ -540,6 +565,10 @@ impl From<&protobuf::ScalarFunction> for BuiltinScalarFunction { ScalarFunction::Isnan => Self::Isnan, ScalarFunction::Iszero => Self::Iszero, ScalarFunction::ArrowTypeof => Self::ArrowTypeof, + ScalarFunction::OverLay => Self::OverLay, + ScalarFunction::Levenshtein => Self::Levenshtein, + ScalarFunction::SubstrIndex => Self::SubstrIndex, + ScalarFunction::FindInSet => Self::FindInSet, } } } @@ -586,6 +615,7 @@ impl From for AggregateFunction { protobuf::AggregateFunction::Median => Self::Median, protobuf::AggregateFunction::FirstValueAgg => Self::FirstValue, protobuf::AggregateFunction::LastValueAgg => Self::LastValue, + protobuf::AggregateFunction::StringAgg => Self::StringAgg, } } } @@ -648,7 +678,9 @@ impl TryFrom<&protobuf::ScalarValue> for ScalarValue { Value::Float64Value(v) => Self::Float64(Some(*v)), Value::Date32Value(v) => Self::Date32(Some(*v)), // ScalarValue::List is serialized using arrow IPC format - Value::ListValue(scalar_list) => { + Value::ListValue(scalar_list) + | Value::FixedSizeListValue(scalar_list) + | Value::LargeListValue(scalar_list) => { let protobuf::ScalarListValue { ipc_message, arrow_data, @@ -686,10 +718,15 @@ impl TryFrom<&protobuf::ScalarValue> for ScalarValue { None, &message.version(), ) - .map_err(DataFusionError::ArrowError) + .map_err(|e| arrow_datafusion_err!(e)) .map_err(|e| e.context("Decoding ScalarValue::List Value"))?; let arr = record_batch.column(0); - Self::List(arr.to_owned()) + match value { + Value::ListValue(_) => Self::List(arr.to_owned()), + Value::LargeListValue(_) => Self::LargeList(arr.to_owned()), + Value::FixedSizeListValue(_) => Self::FixedSizeList(arr.to_owned()), + _ => unreachable!(), + } } Value::NullValue(v) => { let null_type: DataType = v.try_into()?; @@ -1049,7 +1086,7 @@ pub fn parse_expr( .iter() .map(|e| parse_expr(e, registry)) .collect::, _>>()?; - let order_by = expr + let mut order_by = expr .order_by .iter() .map(|e| parse_expr(e, registry)) @@ -1059,7 +1096,8 @@ pub fn parse_expr( .as_ref() .map::, _>(|window_frame| { let window_frame = window_frame.clone().try_into()?; 
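The window-frame handling here replaces the old single `regularize` call with a validate-then-regularize sequence; a condensed sketch of the new flow, where `proto_frame` stands for the decoded `protobuf::WindowFrame` and `order_by` is the mutable `Vec<Expr>` of parsed ORDER BY expressions (both names assumed for illustration):

    // Convert the protobuf frame and validate it against the ORDER BY length.
    let window_frame: WindowFrame = proto_frame.clone().try_into()?;
    check_window_frame(&window_frame, order_by.len())?;
    // The ORDER BY list itself may now be adjusted in place to satisfy the frame.
    regularize_window_order_by(&window_frame, &mut order_by)?;
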
- regularize(window_frame, order_by.len()) + check_window_frame(&window_frame, order_by.len()) + .map(|_| window_frame) }) .transpose()? .ok_or_else(|| { @@ -1067,13 +1105,14 @@ pub fn parse_expr( "missing window frame during deserialization".to_string(), ) })?; + regularize_window_order_by(&window_frame, &mut order_by)?; match window_function { window_expr_node::WindowFunction::AggrFunction(i) => { let aggr_function = parse_i32_to_aggregate_function(i)?; Ok(Expr::WindowFunction(WindowFunction::new( - datafusion_expr::window_function::WindowFunction::AggregateFunction( + datafusion_expr::expr::WindowFunctionDefinition::AggregateFunction( aggr_function, ), vec![parse_required_expr(expr.expr.as_deref(), registry, "expr")?], @@ -1092,7 +1131,7 @@ pub fn parse_expr( .unwrap_or_else(Vec::new); Ok(Expr::WindowFunction(WindowFunction::new( - datafusion_expr::window_function::WindowFunction::BuiltInWindowFunction( + datafusion_expr::expr::WindowFunctionDefinition::BuiltInWindowFunction( built_in_function, ), args, @@ -1107,7 +1146,7 @@ pub fn parse_expr( .map(|e| vec![e]) .unwrap_or_else(Vec::new); Ok(Expr::WindowFunction(WindowFunction::new( - datafusion_expr::window_function::WindowFunction::AggregateUDF( + datafusion_expr::expr::WindowFunctionDefinition::AggregateUDF( udaf_function, ), args, @@ -1122,7 +1161,7 @@ pub fn parse_expr( .map(|e| vec![e]) .unwrap_or_else(Vec::new); Ok(Expr::WindowFunction(WindowFunction::new( - datafusion_expr::window_function::WindowFunction::WindowUDF( + datafusion_expr::expr::WindowFunctionDefinition::WindowUDF( udwf_function, ), args, @@ -1149,6 +1188,11 @@ pub fn parse_expr( } ExprType::Alias(alias) => Ok(Expr::Alias(Alias::new( parse_required_expr(alias.expr.as_deref(), registry, "expr")?, + alias + .relation + .first() + .map(|r| OwnedTableReference::try_from(r.clone())) + .transpose()?, alias.alias.clone(), ))), ExprType::IsNullExpr(is_null) => Ok(Expr::IsNull(Box::new(parse_required_expr( @@ -1294,7 +1338,13 @@ pub fn parse_expr( .collect::, _>>()?, in_list.negated, ))), - ExprType::Wildcard(_) => Ok(Expr::Wildcard), + ExprType::Wildcard(protobuf::Wildcard { qualifier }) => Ok(Expr::Wildcard { + qualifier: if qualifier.is_empty() { + None + } else { + Some(qualifier.clone()) + }, + }), ExprType::ScalarFunction(expr) => { let scalar_function = protobuf::ScalarFunction::try_from(expr.fun) .map_err(|_| Error::unknown("ScalarFunction", expr.fun))?; @@ -1315,6 +1365,14 @@ pub fn parse_expr( parse_expr(&args[0], registry)?, parse_expr(&args[1], registry)?, )), + ScalarFunction::ArraySort => Ok(array_sort( + parse_expr(&args[0], registry)?, + parse_expr(&args[1], registry)?, + parse_expr(&args[2], registry)?, + )), + ScalarFunction::ArrayPopFront => { + Ok(array_pop_front(parse_expr(&args[0], registry)?)) + } ScalarFunction::ArrayPopBack => { Ok(array_pop_back(parse_expr(&args[0], registry)?)) } @@ -1328,6 +1386,10 @@ pub fn parse_expr( .map(|expr| parse_expr(expr, registry)) .collect::, _>>()?, )), + ScalarFunction::ArrayExcept => Ok(array_except( + parse_expr(&args[0], registry)?, + parse_expr(&args[1], registry)?, + )), ScalarFunction::ArrayHasAll => Ok(array_has_all( parse_expr(&args[0], registry)?, parse_expr(&args[1], registry)?, @@ -1340,6 +1402,10 @@ pub fn parse_expr( parse_expr(&args[0], registry)?, parse_expr(&args[1], registry)?, )), + ScalarFunction::ArrayIntersect => Ok(array_intersect( + parse_expr(&args[0], registry)?, + parse_expr(&args[1], registry)?, + )), ScalarFunction::ArrayPosition => Ok(array_position( parse_expr(&args[0], registry)?, 
parse_expr(&args[1], registry)?, @@ -1391,6 +1457,12 @@ pub fn parse_expr( parse_expr(&args[0], registry)?, parse_expr(&args[1], registry)?, )), + ScalarFunction::Range => Ok(gen_range( + args.to_owned() + .iter() + .map(|expr| parse_expr(expr, registry)) + .collect::, _>>()?, + )), ScalarFunction::Cardinality => { Ok(cardinality(parse_expr(&args[0], registry)?)) } @@ -1401,6 +1473,9 @@ pub fn parse_expr( ScalarFunction::ArrayDims => { Ok(array_dims(parse_expr(&args[0], registry)?)) } + ScalarFunction::ArrayDistinct => { + Ok(array_distinct(parse_expr(&args[0], registry)?)) + } ScalarFunction::ArrayElement => Ok(array_element( parse_expr(&args[0], registry)?, parse_expr(&args[1], registry)?, @@ -1411,6 +1486,12 @@ pub fn parse_expr( ScalarFunction::ArrayNdims => { Ok(array_ndims(parse_expr(&args[0], registry)?)) } + ScalarFunction::ArrayUnion => Ok(array( + args.to_owned() + .iter() + .map(|expr| parse_expr(expr, registry)) + .collect::, _>>()?, + )), ScalarFunction::Sqrt => Ok(sqrt(parse_expr(&args[0], registry)?)), ScalarFunction::Cbrt => Ok(cbrt(parse_expr(&args[0], registry)?)), ScalarFunction::Sin => Ok(sin(parse_expr(&args[0], registry)?)), @@ -1472,6 +1553,14 @@ pub fn parse_expr( ScalarFunction::Sha384 => Ok(sha384(parse_expr(&args[0], registry)?)), ScalarFunction::Sha512 => Ok(sha512(parse_expr(&args[0], registry)?)), ScalarFunction::Md5 => Ok(md5(parse_expr(&args[0], registry)?)), + ScalarFunction::Encode => Ok(encode( + parse_expr(&args[0], registry)?, + parse_expr(&args[1], registry)?, + )), + ScalarFunction::Decode => Ok(decode( + parse_expr(&args[0], registry)?, + parse_expr(&args[1], registry)?, + )), ScalarFunction::NullIf => Ok(nullif( parse_expr(&args[0], registry)?, parse_expr(&args[1], registry)?, @@ -1587,6 +1676,10 @@ pub fn parse_expr( )) } } + ScalarFunction::Levenshtein => Ok(levenshtein( + parse_expr(&args[0], registry)?, + parse_expr(&args[1], registry)?, + )), ScalarFunction::ToHex => Ok(to_hex(parse_expr(&args[0], registry)?)), ScalarFunction::ToTimestampMillis => { Ok(to_timestamp_millis(parse_expr(&args[0], registry)?)) @@ -1637,14 +1730,41 @@ pub fn parse_expr( )), ScalarFunction::Isnan => Ok(isnan(parse_expr(&args[0], registry)?)), ScalarFunction::Iszero => Ok(iszero(parse_expr(&args[0], registry)?)), - _ => Err(proto_error( - "Protobuf deserialization error: Unsupported scalar function", + ScalarFunction::ArrowTypeof => { + Ok(arrow_typeof(parse_expr(&args[0], registry)?)) + } + ScalarFunction::ToTimestamp => { + Ok(to_timestamp_seconds(parse_expr(&args[0], registry)?)) + } + ScalarFunction::Flatten => Ok(flatten(parse_expr(&args[0], registry)?)), + ScalarFunction::StringToArray => Ok(string_to_array( + parse_expr(&args[0], registry)?, + parse_expr(&args[1], registry)?, + parse_expr(&args[2], registry)?, )), + ScalarFunction::OverLay => Ok(overlay( + args.to_owned() + .iter() + .map(|expr| parse_expr(expr, registry)) + .collect::, _>>()?, + )), + ScalarFunction::SubstrIndex => Ok(substr_index( + parse_expr(&args[0], registry)?, + parse_expr(&args[1], registry)?, + parse_expr(&args[2], registry)?, + )), + ScalarFunction::FindInSet => Ok(find_in_set( + parse_expr(&args[0], registry)?, + parse_expr(&args[1], registry)?, + )), + ScalarFunction::StructFun => { + Ok(struct_fun(parse_expr(&args[0], registry)?)) + } } } ExprType::ScalarUdfExpr(protobuf::ScalarUdfExprNode { fun_name, args }) => { let scalar_fn = registry.udf(fun_name.as_str())?; - Ok(Expr::ScalarUDF(expr::ScalarUDF::new( + Ok(Expr::ScalarFunction(expr::ScalarFunction::new_udf( scalar_fn, 
args.iter() .map(|expr| parse_expr(expr, registry)) @@ -1654,12 +1774,13 @@ pub fn parse_expr( ExprType::AggregateUdfExpr(pb) => { let agg_fn = registry.udaf(pb.fun_name.as_str())?; - Ok(Expr::AggregateUDF(expr::AggregateUDF::new( + Ok(Expr::AggregateFunction(expr::AggregateFunction::new_udf( agg_fn, pb.args .iter() .map(|expr| parse_expr(expr, registry)) .collect::, Error>>()?, + false, parse_optional_expr(pb.filter.as_deref(), registry)?.map(Box::new), parse_vec_expr(&pb.order_by, registry)?, ))) diff --git a/datafusion/proto/src/logical_plan/mod.rs b/datafusion/proto/src/logical_plan/mod.rs index e426c598523e3..e8a38784481b0 100644 --- a/datafusion/proto/src/logical_plan/mod.rs +++ b/datafusion/proto/src/logical_plan/mod.rs @@ -15,13 +15,18 @@ // specific language governing permissions and limitations // under the License. +use arrow::csv::WriterBuilder; +use std::collections::HashMap; use std::fmt::Debug; use std::str::FromStr; use std::sync::Arc; use crate::common::{byte_to_string, proto_error, str_to_byte}; use crate::protobuf::logical_plan_node::LogicalPlanType::CustomScan; -use crate::protobuf::{CustomTableScanNode, LogicalExprNodeCollection}; +use crate::protobuf::{ + copy_to_node, file_type_writer_options, CustomTableScanNode, + LogicalExprNodeCollection, SqlOption, +}; use crate::{ convert_required, protobuf::{ @@ -43,21 +48,26 @@ use datafusion::{ datasource::{provider_as_source, source_as_provider}, prelude::SessionContext, }; -use datafusion_common::plan_datafusion_err; use datafusion_common::{ - context, internal_err, not_impl_err, parsers::CompressionTypeVariant, - DataFusionError, OwnedTableReference, Result, + context, file_options::StatementOptions, internal_err, not_impl_err, + parsers::CompressionTypeVariant, plan_datafusion_err, DataFusionError, FileType, + FileTypeWriterOptions, OwnedTableReference, Result, }; use datafusion_expr::{ + dml, logical_plan::{ builder::project, Aggregate, CreateCatalog, CreateCatalogSchema, CreateExternalTable, CreateView, CrossJoin, DdlStatement, Distinct, EmptyRelation, Extension, Join, JoinConstraint, Limit, Prepare, Projection, Repartition, Sort, SubqueryAlias, TableScan, Values, Window, }, - DropView, Expr, LogicalPlan, LogicalPlanBuilder, + DistinctOn, DropView, Expr, LogicalPlan, LogicalPlanBuilder, }; +use datafusion::parquet::file::properties::{WriterProperties, WriterVersion}; +use datafusion_common::file_options::csv_writer::CsvWriterOptions; +use datafusion_common::file_options::parquet_writer::ParquetWriterOptions; +use datafusion_expr::dml::CopyOptions; use prost::bytes::BufMut; use prost::Message; @@ -252,7 +262,7 @@ impl AsLogicalPlan for LogicalPlanNode { Some(a) => match a { protobuf::projection_node::OptionalAlias::Alias(alias) => { Ok(LogicalPlan::SubqueryAlias(SubqueryAlias::try_new( - new_proj, + Arc::new(new_proj), alias.clone(), )?)) } @@ -521,6 +531,13 @@ impl AsLogicalPlan for LogicalPlanNode { order_exprs.push(order_expr) } + let mut column_defaults = + HashMap::with_capacity(create_extern_table.column_defaults.len()); + for (col_name, expr) in &create_extern_table.column_defaults { + let expr = from_proto::parse_expr(expr, ctx)?; + column_defaults.insert(col_name.clone(), expr); + } + Ok(LogicalPlan::Ddl(DdlStatement::CreateExternalTable(CreateExternalTable { schema: pb_schema.try_into()?, name: from_owned_table_reference(create_extern_table.name.as_ref(), "CreateExternalTable")?, @@ -540,6 +557,7 @@ impl AsLogicalPlan for LogicalPlanNode { unbounded: create_extern_table.unbounded, options: 
create_extern_table.options.clone(), constraints: constraints.into(), + column_defaults, }))) } LogicalPlanType::CreateView(create_view) => { @@ -734,6 +752,33 @@ impl AsLogicalPlan for LogicalPlanNode { into_logical_plan!(distinct.input, ctx, extension_codec)?; LogicalPlanBuilder::from(input).distinct()?.build() } + LogicalPlanType::DistinctOn(distinct_on) => { + let input: LogicalPlan = + into_logical_plan!(distinct_on.input, ctx, extension_codec)?; + let on_expr = distinct_on + .on_expr + .iter() + .map(|expr| from_proto::parse_expr(expr, ctx)) + .collect::, _>>()?; + let select_expr = distinct_on + .select_expr + .iter() + .map(|expr| from_proto::parse_expr(expr, ctx)) + .collect::, _>>()?; + let sort_expr = match distinct_on.sort_expr.len() { + 0 => None, + _ => Some( + distinct_on + .sort_expr + .iter() + .map(|expr| from_proto::parse_expr(expr, ctx)) + .collect::, _>>()?, + ), + }; + LogicalPlanBuilder::from(input) + .distinct_on(on_expr, select_expr, sort_expr)? + .build() + } LogicalPlanType::ViewScan(scan) => { let schema: Schema = convert_required!(scan.schema)?; @@ -787,6 +832,79 @@ impl AsLogicalPlan for LogicalPlanNode { schema: Arc::new(convert_required!(dropview.schema)?), }), )), + LogicalPlanType::CopyTo(copy) => { + let input: LogicalPlan = + into_logical_plan!(copy.input, ctx, extension_codec)?; + + let copy_options = match ©.copy_options { + Some(copy_to_node::CopyOptions::SqlOptions(opt)) => { + let options = opt + .option + .iter() + .map(|o| (o.key.clone(), o.value.clone())) + .collect(); + CopyOptions::SQLOptions(StatementOptions::from(&options)) + } + Some(copy_to_node::CopyOptions::WriterOptions(opt)) => { + match &opt.file_type { + Some(ft) => match ft { + file_type_writer_options::FileType::CsvOptions( + writer_options, + ) => { + let writer_builder = + csv_writer_options_from_proto(writer_options)?; + CopyOptions::WriterOptions(Box::new( + FileTypeWriterOptions::CSV( + CsvWriterOptions::new( + writer_builder, + CompressionTypeVariant::UNCOMPRESSED, + ), + ), + )) + } + file_type_writer_options::FileType::ParquetOptions( + writer_options, + ) => { + let writer_properties = + match &writer_options.writer_properties { + Some(serialized_writer_options) => { + writer_properties_from_proto( + serialized_writer_options, + )? + } + _ => WriterProperties::default(), + }; + CopyOptions::WriterOptions(Box::new( + FileTypeWriterOptions::Parquet( + ParquetWriterOptions::new(writer_properties), + ), + )) + } + _ => { + return Err(proto_error( + "WriterOptions unsupported file_type", + )) + } + }, + None => { + return Err(proto_error( + "WriterOptions missing file_type", + )) + } + } + } + None => return Err(proto_error("CopyTo missing CopyOptions")), + }; + Ok(datafusion_expr::LogicalPlan::Copy( + datafusion_expr::dml::CopyTo { + input: Arc::new(input), + output_url: copy.output_url.clone(), + file_format: FileType::from_str(©.file_type)?, + single_file_output: copy.single_file_output, + copy_options, + }, + )) + } } } @@ -1005,7 +1123,7 @@ impl AsLogicalPlan for LogicalPlanNode { ))), }) } - LogicalPlan::Distinct(Distinct { input }) => { + LogicalPlan::Distinct(Distinct::All(input)) => { let input: protobuf::LogicalPlanNode = protobuf::LogicalPlanNode::try_from_logical_plan( input.as_ref(), @@ -1019,6 +1137,42 @@ impl AsLogicalPlan for LogicalPlanNode { ))), }) } + LogicalPlan::Distinct(Distinct::On(DistinctOn { + on_expr, + select_expr, + sort_expr, + input, + .. 
+ })) => { + let input: protobuf::LogicalPlanNode = + protobuf::LogicalPlanNode::try_from_logical_plan( + input.as_ref(), + extension_codec, + )?; + let sort_expr = match sort_expr { + None => vec![], + Some(sort_expr) => sort_expr + .iter() + .map(|expr| expr.try_into()) + .collect::, _>>()?, + }; + Ok(protobuf::LogicalPlanNode { + logical_plan_type: Some(LogicalPlanType::DistinctOn(Box::new( + protobuf::DistinctOnNode { + on_expr: on_expr + .iter() + .map(|expr| expr.try_into()) + .collect::, _>>()?, + select_expr: select_expr + .iter() + .map(|expr| expr.try_into()) + .collect::, _>>()?, + sort_expr, + input: Some(Box::new(input)), + }, + ))), + }) + } LogicalPlan::Window(Window { input, window_expr, .. }) => { @@ -1235,6 +1389,7 @@ impl AsLogicalPlan for LogicalPlanNode { unbounded, options, constraints, + column_defaults, }, )) => { let mut converted_order_exprs: Vec = vec![]; @@ -1249,6 +1404,12 @@ impl AsLogicalPlan for LogicalPlanNode { converted_order_exprs.push(temp); } + let mut converted_column_defaults = + HashMap::with_capacity(column_defaults.len()); + for (col_name, expr) in column_defaults { + converted_column_defaults.insert(col_name.clone(), expr.try_into()?); + } + Ok(protobuf::LogicalPlanNode { logical_plan_type: Some(LogicalPlanType::CreateExternalTable( protobuf::CreateExternalTableNode { @@ -1266,6 +1427,7 @@ impl AsLogicalPlan for LogicalPlanNode { unbounded: *unbounded, options: options.clone(), constraints: Some(constraints.clone().into()), + column_defaults: converted_column_defaults, }, )), }) @@ -1454,12 +1616,163 @@ impl AsLogicalPlan for LogicalPlanNode { LogicalPlan::Dml(_) => Err(proto_error( "LogicalPlan serde is not yet implemented for Dml", )), - LogicalPlan::Copy(_) => Err(proto_error( - "LogicalPlan serde is not yet implemented for Copy", - )), + LogicalPlan::Copy(dml::CopyTo { + input, + output_url, + single_file_output, + file_format, + copy_options, + }) => { + let input = protobuf::LogicalPlanNode::try_from_logical_plan( + input, + extension_codec, + )?; + + let copy_options_proto: Option = + match copy_options { + CopyOptions::SQLOptions(opt) => { + let options: Vec = opt + .clone() + .into_inner() + .iter() + .map(|(k, v)| SqlOption { + key: k.to_string(), + value: v.to_string(), + }) + .collect(); + Some(copy_to_node::CopyOptions::SqlOptions( + protobuf::SqlOptions { option: options }, + )) + } + CopyOptions::WriterOptions(opt) => { + match opt.as_ref() { + FileTypeWriterOptions::CSV(csv_opts) => { + let csv_options = &csv_opts.writer_options; + let csv_writer_options = csv_writer_options_to_proto( + csv_options, + &csv_opts.compression, + ); + let csv_options = + file_type_writer_options::FileType::CsvOptions( + csv_writer_options, + ); + Some(copy_to_node::CopyOptions::WriterOptions( + protobuf::FileTypeWriterOptions { + file_type: Some(csv_options), + }, + )) + } + FileTypeWriterOptions::Parquet(parquet_opts) => { + let parquet_writer_options = + protobuf::ParquetWriterOptions { + writer_properties: Some( + writer_properties_to_proto( + &parquet_opts.writer_options, + ), + ), + }; + let parquet_options = file_type_writer_options::FileType::ParquetOptions(parquet_writer_options); + Some(copy_to_node::CopyOptions::WriterOptions( + protobuf::FileTypeWriterOptions { + file_type: Some(parquet_options), + }, + )) + } + _ => { + return Err(proto_error( + "Unsupported FileTypeWriterOptions in CopyTo", + )) + } + } + } + }; + + Ok(protobuf::LogicalPlanNode { + logical_plan_type: Some(LogicalPlanType::CopyTo(Box::new( + protobuf::CopyToNode { + input: 
Some(Box::new(input)), + single_file_output: *single_file_output, + output_url: output_url.to_string(), + file_type: file_format.to_string(), + copy_options: copy_options_proto, + }, + ))), + }) + } LogicalPlan::DescribeTable(_) => Err(proto_error( "LogicalPlan serde is not yet implemented for DescribeTable", )), } } } + +pub(crate) fn csv_writer_options_to_proto( + csv_options: &WriterBuilder, + compression: &CompressionTypeVariant, +) -> protobuf::CsvWriterOptions { + let compression: protobuf::CompressionTypeVariant = compression.into(); + protobuf::CsvWriterOptions { + compression: compression.into(), + delimiter: (csv_options.delimiter() as char).to_string(), + has_header: csv_options.header(), + date_format: csv_options.date_format().unwrap_or("").to_owned(), + datetime_format: csv_options.datetime_format().unwrap_or("").to_owned(), + timestamp_format: csv_options.timestamp_format().unwrap_or("").to_owned(), + time_format: csv_options.time_format().unwrap_or("").to_owned(), + null_value: csv_options.null().to_owned(), + } +} + +pub(crate) fn csv_writer_options_from_proto( + writer_options: &protobuf::CsvWriterOptions, +) -> Result { + let mut builder = WriterBuilder::new(); + if !writer_options.delimiter.is_empty() { + if let Some(delimiter) = writer_options.delimiter.chars().next() { + if delimiter.is_ascii() { + builder = builder.with_delimiter(delimiter as u8); + } else { + return Err(proto_error("CSV Delimiter is not ASCII")); + } + } else { + return Err(proto_error("Error parsing CSV Delimiter")); + } + } + Ok(builder + .with_header(writer_options.has_header) + .with_date_format(writer_options.date_format.clone()) + .with_datetime_format(writer_options.datetime_format.clone()) + .with_timestamp_format(writer_options.timestamp_format.clone()) + .with_time_format(writer_options.time_format.clone()) + .with_null(writer_options.null_value.clone())) +} + +pub(crate) fn writer_properties_to_proto( + props: &WriterProperties, +) -> protobuf::WriterProperties { + protobuf::WriterProperties { + data_page_size_limit: props.data_page_size_limit() as u64, + dictionary_page_size_limit: props.dictionary_page_size_limit() as u64, + data_page_row_count_limit: props.data_page_row_count_limit() as u64, + write_batch_size: props.write_batch_size() as u64, + max_row_group_size: props.max_row_group_size() as u64, + writer_version: format!("{:?}", props.writer_version()), + created_by: props.created_by().to_string(), + } +} + +pub(crate) fn writer_properties_from_proto( + props: &protobuf::WriterProperties, +) -> Result { + let writer_version = + WriterVersion::from_str(&props.writer_version).map_err(proto_error)?; + Ok(WriterProperties::builder() + .set_created_by(props.created_by.clone()) + .set_writer_version(writer_version) + .set_dictionary_page_size_limit(props.dictionary_page_size_limit as usize) + .set_data_page_row_count_limit(props.data_page_row_count_limit as usize) + .set_data_page_size_limit(props.data_page_size_limit as usize) + .set_write_batch_size(props.write_batch_size as usize) + .set_max_row_group_size(props.max_row_group_size as usize) + .build()) +} diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 687b73cfc886f..a162b2389cd1a 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -24,7 +24,8 @@ use crate::protobuf::{ arrow_type::ArrowTypeEnum, plan_type::PlanTypeEnum::{ AnalyzedLogicalPlan, FinalAnalyzedLogicalPlan, FinalLogicalPlan, - FinalPhysicalPlan, 
InitialLogicalPlan, InitialPhysicalPlan, OptimizedLogicalPlan, + FinalPhysicalPlan, FinalPhysicalPlanWithStats, InitialLogicalPlan, + InitialPhysicalPlan, InitialPhysicalPlanWithStats, OptimizedLogicalPlan, OptimizedPhysicalPlan, }, AnalyzedLogicalPlanType, CubeNode, EmptyMessage, GroupingSetNode, LogicalExprList, @@ -43,13 +44,14 @@ use datafusion_common::{ ScalarValue, }; use datafusion_expr::expr::{ - self, Alias, Between, BinaryExpr, Cast, GetFieldAccess, GetIndexedField, GroupingSet, - InList, Like, Placeholder, ScalarFunction, ScalarUDF, Sort, + self, AggregateFunctionDefinition, Alias, Between, BinaryExpr, Cast, GetFieldAccess, + GetIndexedField, GroupingSet, InList, Like, Placeholder, ScalarFunction, + ScalarFunctionDefinition, Sort, }; use datafusion_expr::{ logical_plan::PlanType, logical_plan::StringifiedPlan, AggregateFunction, BuiltInWindowFunction, BuiltinScalarFunction, Expr, JoinConstraint, JoinType, - TryCast, WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunction, + TryCast, WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, }; #[derive(Debug)] @@ -106,6 +108,8 @@ impl TryFrom<&Field> for protobuf::Field { nullable: field.is_nullable(), children: Vec::new(), metadata: field.metadata().clone(), + dict_id: field.dict_id().unwrap_or(0), + dict_ordered: field.dict_is_ordered().unwrap_or(false), }) } } @@ -352,6 +356,12 @@ impl From<&StringifiedPlan> for protobuf::StringifiedPlan { PlanType::FinalPhysicalPlan => Some(protobuf::PlanType { plan_type_enum: Some(FinalPhysicalPlan(EmptyMessage {})), }), + PlanType::InitialPhysicalPlanWithStats => Some(protobuf::PlanType { + plan_type_enum: Some(InitialPhysicalPlanWithStats(EmptyMessage {})), + }), + PlanType::FinalPhysicalPlanWithStats => Some(protobuf::PlanType { + plan_type_enum: Some(FinalPhysicalPlanWithStats(EmptyMessage {})), + }), }, plan: stringified_plan.plan.to_string(), } @@ -398,6 +408,7 @@ impl From<&AggregateFunction> for protobuf::AggregateFunction { AggregateFunction::Median => Self::Median, AggregateFunction::FirstValue => Self::FirstValueAgg, AggregateFunction::LastValue => Self::LastValueAgg, + AggregateFunction::StringAgg => Self::StringAgg, } } } @@ -476,9 +487,17 @@ impl TryFrom<&Expr> for protobuf::LogicalExprNode { Expr::Column(c) => Self { expr_type: Some(ExprType::Column(c.into())), }, - Expr::Alias(Alias { expr, name, .. 
}) => { + Expr::Alias(Alias { + expr, + relation, + name, + }) => { let alias = Box::new(protobuf::AliasNode { expr: Some(Box::new(expr.as_ref().try_into()?)), + relation: relation + .to_owned() + .map(|r| vec![r.into()]) + .unwrap_or(vec![]), alias: name.to_owned(), }); Self { @@ -586,24 +605,24 @@ impl TryFrom<&Expr> for protobuf::LogicalExprNode { ref window_frame, }) => { let window_function = match fun { - WindowFunction::AggregateFunction(fun) => { + WindowFunctionDefinition::AggregateFunction(fun) => { protobuf::window_expr_node::WindowFunction::AggrFunction( protobuf::AggregateFunction::from(fun).into(), ) } - WindowFunction::BuiltInWindowFunction(fun) => { + WindowFunctionDefinition::BuiltInWindowFunction(fun) => { protobuf::window_expr_node::WindowFunction::BuiltInFunction( protobuf::BuiltInWindowFunction::from(fun).into(), ) } - WindowFunction::AggregateUDF(aggr_udf) => { + WindowFunctionDefinition::AggregateUDF(aggr_udf) => { protobuf::window_expr_node::WindowFunction::Udaf( - aggr_udf.name.clone(), + aggr_udf.name().to_string(), ) } - WindowFunction::WindowUDF(window_udf) => { + WindowFunctionDefinition::WindowUDF(window_udf) => { protobuf::window_expr_node::WindowFunction::Udwf( - window_udf.name.clone(), + window_udf.name().to_string(), ) } }; @@ -636,159 +655,178 @@ impl TryFrom<&Expr> for protobuf::LogicalExprNode { } } Expr::AggregateFunction(expr::AggregateFunction { - ref fun, + ref func_def, ref args, ref distinct, ref filter, ref order_by, }) => { - let aggr_function = match fun { - AggregateFunction::ApproxDistinct => { - protobuf::AggregateFunction::ApproxDistinct - } - AggregateFunction::ApproxPercentileCont => { - protobuf::AggregateFunction::ApproxPercentileCont - } - AggregateFunction::ApproxPercentileContWithWeight => { - protobuf::AggregateFunction::ApproxPercentileContWithWeight - } - AggregateFunction::ArrayAgg => protobuf::AggregateFunction::ArrayAgg, - AggregateFunction::Min => protobuf::AggregateFunction::Min, - AggregateFunction::Max => protobuf::AggregateFunction::Max, - AggregateFunction::Sum => protobuf::AggregateFunction::Sum, - AggregateFunction::BitAnd => protobuf::AggregateFunction::BitAnd, - AggregateFunction::BitOr => protobuf::AggregateFunction::BitOr, - AggregateFunction::BitXor => protobuf::AggregateFunction::BitXor, - AggregateFunction::BoolAnd => protobuf::AggregateFunction::BoolAnd, - AggregateFunction::BoolOr => protobuf::AggregateFunction::BoolOr, - AggregateFunction::Avg => protobuf::AggregateFunction::Avg, - AggregateFunction::Count => protobuf::AggregateFunction::Count, - AggregateFunction::Variance => protobuf::AggregateFunction::Variance, - AggregateFunction::VariancePop => { - protobuf::AggregateFunction::VariancePop - } - AggregateFunction::Covariance => { - protobuf::AggregateFunction::Covariance - } - AggregateFunction::CovariancePop => { - protobuf::AggregateFunction::CovariancePop - } - AggregateFunction::Stddev => protobuf::AggregateFunction::Stddev, - AggregateFunction::StddevPop => { - protobuf::AggregateFunction::StddevPop - } - AggregateFunction::Correlation => { - protobuf::AggregateFunction::Correlation - } - AggregateFunction::RegrSlope => { - protobuf::AggregateFunction::RegrSlope - } - AggregateFunction::RegrIntercept => { - protobuf::AggregateFunction::RegrIntercept - } - AggregateFunction::RegrR2 => protobuf::AggregateFunction::RegrR2, - AggregateFunction::RegrAvgx => protobuf::AggregateFunction::RegrAvgx, - AggregateFunction::RegrAvgy => protobuf::AggregateFunction::RegrAvgy, - AggregateFunction::RegrCount => { - 
protobuf::AggregateFunction::RegrCount - } - AggregateFunction::RegrSXX => protobuf::AggregateFunction::RegrSxx, - AggregateFunction::RegrSYY => protobuf::AggregateFunction::RegrSyy, - AggregateFunction::RegrSXY => protobuf::AggregateFunction::RegrSxy, - AggregateFunction::ApproxMedian => { - protobuf::AggregateFunction::ApproxMedian - } - AggregateFunction::Grouping => protobuf::AggregateFunction::Grouping, - AggregateFunction::Median => protobuf::AggregateFunction::Median, - AggregateFunction::FirstValue => { - protobuf::AggregateFunction::FirstValueAgg - } - AggregateFunction::LastValue => { - protobuf::AggregateFunction::LastValueAgg + match func_def { + AggregateFunctionDefinition::BuiltIn(fun) => { + let aggr_function = match fun { + AggregateFunction::ApproxDistinct => { + protobuf::AggregateFunction::ApproxDistinct + } + AggregateFunction::ApproxPercentileCont => { + protobuf::AggregateFunction::ApproxPercentileCont + } + AggregateFunction::ApproxPercentileContWithWeight => { + protobuf::AggregateFunction::ApproxPercentileContWithWeight + } + AggregateFunction::ArrayAgg => protobuf::AggregateFunction::ArrayAgg, + AggregateFunction::Min => protobuf::AggregateFunction::Min, + AggregateFunction::Max => protobuf::AggregateFunction::Max, + AggregateFunction::Sum => protobuf::AggregateFunction::Sum, + AggregateFunction::BitAnd => protobuf::AggregateFunction::BitAnd, + AggregateFunction::BitOr => protobuf::AggregateFunction::BitOr, + AggregateFunction::BitXor => protobuf::AggregateFunction::BitXor, + AggregateFunction::BoolAnd => protobuf::AggregateFunction::BoolAnd, + AggregateFunction::BoolOr => protobuf::AggregateFunction::BoolOr, + AggregateFunction::Avg => protobuf::AggregateFunction::Avg, + AggregateFunction::Count => protobuf::AggregateFunction::Count, + AggregateFunction::Variance => protobuf::AggregateFunction::Variance, + AggregateFunction::VariancePop => { + protobuf::AggregateFunction::VariancePop + } + AggregateFunction::Covariance => { + protobuf::AggregateFunction::Covariance + } + AggregateFunction::CovariancePop => { + protobuf::AggregateFunction::CovariancePop + } + AggregateFunction::Stddev => protobuf::AggregateFunction::Stddev, + AggregateFunction::StddevPop => { + protobuf::AggregateFunction::StddevPop + } + AggregateFunction::Correlation => { + protobuf::AggregateFunction::Correlation + } + AggregateFunction::RegrSlope => { + protobuf::AggregateFunction::RegrSlope + } + AggregateFunction::RegrIntercept => { + protobuf::AggregateFunction::RegrIntercept + } + AggregateFunction::RegrR2 => protobuf::AggregateFunction::RegrR2, + AggregateFunction::RegrAvgx => protobuf::AggregateFunction::RegrAvgx, + AggregateFunction::RegrAvgy => protobuf::AggregateFunction::RegrAvgy, + AggregateFunction::RegrCount => { + protobuf::AggregateFunction::RegrCount + } + AggregateFunction::RegrSXX => protobuf::AggregateFunction::RegrSxx, + AggregateFunction::RegrSYY => protobuf::AggregateFunction::RegrSyy, + AggregateFunction::RegrSXY => protobuf::AggregateFunction::RegrSxy, + AggregateFunction::ApproxMedian => { + protobuf::AggregateFunction::ApproxMedian + } + AggregateFunction::Grouping => protobuf::AggregateFunction::Grouping, + AggregateFunction::Median => protobuf::AggregateFunction::Median, + AggregateFunction::FirstValue => { + protobuf::AggregateFunction::FirstValueAgg + } + AggregateFunction::LastValue => { + protobuf::AggregateFunction::LastValueAgg + } + AggregateFunction::StringAgg => { + protobuf::AggregateFunction::StringAgg + } + }; + + let aggregate_expr = 
protobuf::AggregateExprNode { + aggr_function: aggr_function.into(), + expr: args + .iter() + .map(|v| v.try_into()) + .collect::, _>>()?, + distinct: *distinct, + filter: match filter { + Some(e) => Some(Box::new(e.as_ref().try_into()?)), + None => None, + }, + order_by: match order_by { + Some(e) => e + .iter() + .map(|expr| expr.try_into()) + .collect::, _>>()?, + None => vec![], + }, + }; + Self { + expr_type: Some(ExprType::AggregateExpr(Box::new( + aggregate_expr, + ))), + } } - }; - - let aggregate_expr = protobuf::AggregateExprNode { - aggr_function: aggr_function.into(), - expr: args - .iter() - .map(|v| v.try_into()) - .collect::, _>>()?, - distinct: *distinct, - filter: match filter { - Some(e) => Some(Box::new(e.as_ref().try_into()?)), - None => None, - }, - order_by: match order_by { - Some(e) => e - .iter() - .map(|expr| expr.try_into()) - .collect::, _>>()?, - None => vec![], + AggregateFunctionDefinition::UDF(fun) => Self { + expr_type: Some(ExprType::AggregateUdfExpr(Box::new( + protobuf::AggregateUdfExprNode { + fun_name: fun.name().to_string(), + args: args + .iter() + .map(|expr| expr.try_into()) + .collect::, Error>>()?, + filter: match filter { + Some(e) => Some(Box::new(e.as_ref().try_into()?)), + None => None, + }, + order_by: match order_by { + Some(e) => e + .iter() + .map(|expr| expr.try_into()) + .collect::, _>>()?, + None => vec![], + }, + }, + ))), }, - }; - Self { - expr_type: Some(ExprType::AggregateExpr(Box::new(aggregate_expr))), + AggregateFunctionDefinition::Name(_) => { + return Err(Error::NotImplemented( + "Proto serialization error: Trying to serialize a unresolved function" + .to_string(), + )); + } } } + Expr::ScalarVariable(_, _) => { return Err(Error::General( "Proto serialization error: Scalar Variable not supported" .to_string(), )) } - Expr::ScalarFunction(ScalarFunction { fun, args }) => { - let fun: protobuf::ScalarFunction = fun.try_into()?; - let args: Vec = args + Expr::ScalarFunction(ScalarFunction { func_def, args }) => { + let args = args .iter() - .map(|e| e.try_into()) - .collect::, Error>>()?; - Self { - expr_type: Some(ExprType::ScalarFunction( - protobuf::ScalarFunctionNode { - fun: fun.into(), - args, - }, - )), + .map(|expr| expr.try_into()) + .collect::, Error>>()?; + match func_def { + ScalarFunctionDefinition::BuiltIn(fun) => { + let fun: protobuf::ScalarFunction = fun.try_into()?; + Self { + expr_type: Some(ExprType::ScalarFunction( + protobuf::ScalarFunctionNode { + fun: fun.into(), + args, + }, + )), + } + } + ScalarFunctionDefinition::UDF(fun) => Self { + expr_type: Some(ExprType::ScalarUdfExpr( + protobuf::ScalarUdfExprNode { + fun_name: fun.name().to_string(), + args, + }, + )), + }, + ScalarFunctionDefinition::Name(_) => { + return Err(Error::NotImplemented( + "Proto serialization error: Trying to serialize a unresolved function" + .to_string(), + )); + } } } - Expr::ScalarUDF(ScalarUDF { fun, args }) => Self { - expr_type: Some(ExprType::ScalarUdfExpr(protobuf::ScalarUdfExprNode { - fun_name: fun.name.clone(), - args: args - .iter() - .map(|expr| expr.try_into()) - .collect::, Error>>()?, - })), - }, - Expr::AggregateUDF(expr::AggregateUDF { - fun, - args, - filter, - order_by, - }) => Self { - expr_type: Some(ExprType::AggregateUdfExpr(Box::new( - protobuf::AggregateUdfExprNode { - fun_name: fun.name.clone(), - args: args.iter().map(|expr| expr.try_into()).collect::, - Error, - >>( - )?, - filter: match filter { - Some(e) => Some(Box::new(e.as_ref().try_into()?)), - None => None, - }, - order_by: match order_by { - 
Some(e) => e - .iter() - .map(|expr| expr.try_into()) - .collect::, _>>()?, - None => vec![], - }, - }, - ))), - }, Expr::Not(expr) => { let expr = Box::new(protobuf::Not { expr: Some(Box::new(expr.as_ref().try_into()?)), @@ -960,8 +998,10 @@ impl TryFrom<&Expr> for protobuf::LogicalExprNode { expr_type: Some(ExprType::InList(expr)), } } - Expr::Wildcard => Self { - expr_type: Some(ExprType::Wildcard(true)), + Expr::Wildcard { qualifier } => Self { + expr_type: Some(ExprType::Wildcard(protobuf::Wildcard { + qualifier: qualifier.clone().unwrap_or("".to_string()), + })), }, Expr::ScalarSubquery(_) | Expr::InSubquery(_) @@ -1052,11 +1092,6 @@ impl TryFrom<&Expr> for protobuf::LogicalExprNode { })), } } - - Expr::QualifiedWildcard { .. } => return Err(Error::General( - "Proto serialization error: Expr::QualifiedWildcard { .. } not supported" - .to_string(), - )), }; Ok(expr_node) @@ -1122,13 +1157,11 @@ impl TryFrom<&ScalarValue> for protobuf::ScalarValue { Value::LargeUtf8Value(s.to_owned()) }) } - ScalarValue::Fixedsizelist(..) => Err(Error::General( - "Proto serialization error: ScalarValue::Fixedsizelist not supported" - .to_string(), - )), - // ScalarValue::List is serialized using Arrow IPC messages. - // as a single column RecordBatch - ScalarValue::List(arr) => { + // ScalarValue::List and ScalarValue::FixedSizeList are serialized using + // Arrow IPC messages as a single column RecordBatch + ScalarValue::List(arr) + | ScalarValue::LargeList(arr) + | ScalarValue::FixedSizeList(arr) => { // Wrap in a "field_name" column let batch = RecordBatch::try_from_iter(vec![( "field_name", @@ -1156,11 +1189,24 @@ impl TryFrom<&ScalarValue> for protobuf::ScalarValue { schema: Some(schema), }; - Ok(protobuf::ScalarValue { - value: Some(protobuf::scalar_value::Value::ListValue( - scalar_list_value, - )), - }) + match val { + ScalarValue::List(_) => Ok(protobuf::ScalarValue { + value: Some(protobuf::scalar_value::Value::ListValue( + scalar_list_value, + )), + }), + ScalarValue::LargeList(_) => Ok(protobuf::ScalarValue { + value: Some(protobuf::scalar_value::Value::LargeListValue( + scalar_list_value, + )), + }), + ScalarValue::FixedSizeList(_) => Ok(protobuf::ScalarValue { + value: Some(protobuf::scalar_value::Value::FixedSizeListValue( + scalar_list_value, + )), + }), + _ => unreachable!(), + } } ScalarValue::Date32(val) => { create_proto_scalar(val.as_ref(), &data_type, |s| Value::Date32Value(*s)) @@ -1458,16 +1504,20 @@ impl TryFrom<&BuiltinScalarFunction> for protobuf::ScalarFunction { BuiltinScalarFunction::Rtrim => Self::Rtrim, BuiltinScalarFunction::ToTimestamp => Self::ToTimestamp, BuiltinScalarFunction::ArrayAppend => Self::ArrayAppend, + BuiltinScalarFunction::ArraySort => Self::ArraySort, BuiltinScalarFunction::ArrayConcat => Self::ArrayConcat, BuiltinScalarFunction::ArrayEmpty => Self::ArrayEmpty, + BuiltinScalarFunction::ArrayExcept => Self::ArrayExcept, BuiltinScalarFunction::ArrayHasAll => Self::ArrayHasAll, BuiltinScalarFunction::ArrayHasAny => Self::ArrayHasAny, BuiltinScalarFunction::ArrayHas => Self::ArrayHas, BuiltinScalarFunction::ArrayDims => Self::ArrayDims, + BuiltinScalarFunction::ArrayDistinct => Self::ArrayDistinct, BuiltinScalarFunction::ArrayElement => Self::ArrayElement, BuiltinScalarFunction::Flatten => Self::Flatten, BuiltinScalarFunction::ArrayLength => Self::ArrayLength, BuiltinScalarFunction::ArrayNdims => Self::ArrayNdims, + BuiltinScalarFunction::ArrayPopFront => Self::ArrayPopFront, BuiltinScalarFunction::ArrayPopBack => Self::ArrayPopBack, 
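A short sketch of the wildcard round trip enabled by the changes above: the optional qualifier is serialized as an empty string on the wire and mapped back to `None` when parsed (see the matching branch in from_proto.rs). `registry` is assumed to be any `FunctionRegistry` (a `SessionContext` works), and the module paths follow this crate's internal layout:

    fn roundtrip_wildcard(registry: &dyn FunctionRegistry) {
        let expr = Expr::Wildcard { qualifier: None };
        // Serialize to the protobuf expression node; `None` becomes "" on the wire.
        let node = protobuf::LogicalExprNode::try_from(&expr).expect("serialize");
        // Deserialize; the empty qualifier maps back to `None`.
        let round_tripped = from_proto::parse_expr(&node, registry).expect("deserialize");
        assert_eq!(expr, round_tripped);
    }
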
BuiltinScalarFunction::ArrayPosition => Self::ArrayPosition, BuiltinScalarFunction::ArrayPositions => Self::ArrayPositions, @@ -1481,6 +1531,9 @@ impl TryFrom<&BuiltinScalarFunction> for protobuf::ScalarFunction { BuiltinScalarFunction::ArrayReplaceAll => Self::ArrayReplaceAll, BuiltinScalarFunction::ArraySlice => Self::ArraySlice, BuiltinScalarFunction::ArrayToString => Self::ArrayToString, + BuiltinScalarFunction::ArrayIntersect => Self::ArrayIntersect, + BuiltinScalarFunction::ArrayUnion => Self::ArrayUnion, + BuiltinScalarFunction::Range => Self::Range, BuiltinScalarFunction::Cardinality => Self::Cardinality, BuiltinScalarFunction::MakeArray => Self::Array, BuiltinScalarFunction::NullIf => Self::NullIf, @@ -1539,6 +1592,10 @@ impl TryFrom<&BuiltinScalarFunction> for protobuf::ScalarFunction { BuiltinScalarFunction::Isnan => Self::Isnan, BuiltinScalarFunction::Iszero => Self::Iszero, BuiltinScalarFunction::ArrowTypeof => Self::ArrowTypeof, + BuiltinScalarFunction::OverLay => Self::OverLay, + BuiltinScalarFunction::Levenshtein => Self::Levenshtein, + BuiltinScalarFunction::SubstrIndex => Self::SubstrIndex, + BuiltinScalarFunction::FindInSet => Self::FindInSet, }; Ok(scalar_function) diff --git a/datafusion/proto/src/physical_plan/from_proto.rs b/datafusion/proto/src/physical_plan/from_proto.rs index a956eded9032c..23ab813ca7397 100644 --- a/datafusion/proto/src/physical_plan/from_proto.rs +++ b/datafusion/proto/src/physical_plan/from_proto.rs @@ -18,17 +18,20 @@ //! Serde code to convert from protocol buffers to Rust data structures. use std::convert::{TryFrom, TryInto}; -use std::ops::Deref; use std::sync::Arc; use arrow::compute::SortOptions; use datafusion::arrow::datatypes::Schema; -use datafusion::datasource::listing::{FileRange, PartitionedFile}; +use datafusion::datasource::file_format::csv::CsvSink; +use datafusion::datasource::file_format::json::JsonSink; +#[cfg(feature = "parquet")] +use datafusion::datasource::file_format::parquet::ParquetSink; +use datafusion::datasource::listing::{FileRange, ListingTableUrl, PartitionedFile}; use datafusion::datasource::object_store::ObjectStoreUrl; -use datafusion::datasource::physical_plan::FileScanConfig; +use datafusion::datasource::physical_plan::{FileScanConfig, FileSinkConfig}; use datafusion::execution::context::ExecutionProps; use datafusion::execution::FunctionRegistry; -use datafusion::logical_expr::window_function::WindowFunction; +use datafusion::logical_expr::WindowFunctionDefinition; use datafusion::physical_expr::{PhysicalSortExpr, ScalarFunctionExpr}; use datafusion::physical_plan::expressions::{ in_list, BinaryExpr, CaseExpr, CastExpr, Column, IsNotNullExpr, IsNullExpr, LikeExpr, @@ -39,8 +42,14 @@ use datafusion::physical_plan::windows::create_window_expr; use datafusion::physical_plan::{ functions, ColumnStatistics, Partitioning, PhysicalExpr, Statistics, WindowExpr, }; +use datafusion_common::file_options::csv_writer::CsvWriterOptions; +use datafusion_common::file_options::json_writer::JsonWriterOptions; +use datafusion_common::file_options::parquet_writer::ParquetWriterOptions; +use datafusion_common::parsers::CompressionTypeVariant; use datafusion_common::stats::Precision; -use datafusion_common::{not_impl_err, DataFusionError, JoinSide, Result, ScalarValue}; +use datafusion_common::{ + not_impl_err, DataFusionError, FileTypeWriterOptions, JoinSide, Result, ScalarValue, +}; use crate::common::proto_error; use crate::convert_required; @@ -48,6 +57,7 @@ use crate::logical_plan; use crate::protobuf; use 
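// The new BuiltinScalarFunction arms above (ArraySort, ArrayExcept, Levenshtein, ...) rely on
// an exhaustive match with no catch-all, so a newly added function that lacks a protobuf
// mapping fails to compile rather than failing at runtime. A tiny illustration of that
// property with hypothetical enums:
#[derive(Debug)]
enum BuiltinFn {
    ArraySort,
    ArrayExcept,
    Levenshtein,
}

#[derive(Debug, PartialEq)]
enum ProtoFn {
    ArraySort,
    ArrayExcept,
    Levenshtein,
}

impl From<&BuiltinFn> for ProtoFn {
    fn from(f: &BuiltinFn) -> Self {
        // No `_ =>` arm: adding a BuiltinFn variant without a mapping is a compile error.
        match f {
            BuiltinFn::ArraySort => ProtoFn::ArraySort,
            BuiltinFn::ArrayExcept => ProtoFn::ArrayExcept,
            BuiltinFn::Levenshtein => ProtoFn::Levenshtein,
        }
    }
}

fn main() {
    assert_eq!(ProtoFn::from(&BuiltinFn::Levenshtein), ProtoFn::Levenshtein);
}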
crate::protobuf::physical_expr_node::ExprType; +use crate::logical_plan::{csv_writer_options_from_proto, writer_properties_from_proto}; use chrono::{TimeZone, Utc}; use object_store::path::Path; use object_store::ObjectMeta; @@ -308,12 +318,12 @@ pub fn parse_physical_expr( &e.name, fun_expr, args, - &convert_required!(e.return_type)?, + convert_required!(e.return_type)?, None, )) } ExprType::ScalarUdf(e) => { - let scalar_fun = registry.udf(e.name.as_str())?.deref().clone().fun; + let scalar_fun = registry.udf(e.name.as_str())?.fun().clone(); let args = e .args @@ -325,7 +335,7 @@ pub fn parse_physical_expr( e.name.as_str(), scalar_fun, args, - &convert_required!(e.return_type)?, + convert_required!(e.return_type)?, None, )) } @@ -404,7 +414,9 @@ fn parse_required_physical_expr( }) } -impl TryFrom<&protobuf::physical_window_expr_node::WindowFunction> for WindowFunction { +impl TryFrom<&protobuf::physical_window_expr_node::WindowFunction> + for WindowFunctionDefinition +{ type Error = DataFusionError; fn try_from( @@ -418,7 +430,7 @@ impl TryFrom<&protobuf::physical_window_expr_node::WindowFunction> for WindowFun )) })?; - Ok(WindowFunction::AggregateFunction(f.into())) + Ok(WindowFunctionDefinition::AggregateFunction(f.into())) } protobuf::physical_window_expr_node::WindowFunction::BuiltInFunction(n) => { let f = protobuf::BuiltInWindowFunction::try_from(*n).map_err(|_| { @@ -427,7 +439,7 @@ impl TryFrom<&protobuf::physical_window_expr_node::WindowFunction> for WindowFun )) })?; - Ok(WindowFunction::BuiltInWindowFunction(f.into())) + Ok(WindowFunctionDefinition::BuiltInWindowFunction(f.into())) } } } @@ -522,7 +534,6 @@ pub fn parse_protobuf_file_scan_config( limit: proto.limit.as_ref().map(|sl| sl.limit as usize), table_partition_cols, output_ordering, - infinite_source: false, }) } @@ -536,6 +547,7 @@ impl TryFrom<&protobuf::PartitionedFile> for PartitionedFile { last_modified: Utc.timestamp_nanos(val.last_modified_ns as i64), size: val.size as usize, e_tag: None, + version: None, }, partition_values: val .partition_values @@ -697,3 +709,103 @@ impl TryFrom<&protobuf::Statistics> for Statistics { }) } } + +impl TryFrom<&protobuf::JsonSink> for JsonSink { + type Error = DataFusionError; + + fn try_from(value: &protobuf::JsonSink) -> Result { + Ok(Self::new(convert_required!(value.config)?)) + } +} + +#[cfg(feature = "parquet")] +impl TryFrom<&protobuf::ParquetSink> for ParquetSink { + type Error = DataFusionError; + + fn try_from(value: &protobuf::ParquetSink) -> Result { + Ok(Self::new(convert_required!(value.config)?)) + } +} + +impl TryFrom<&protobuf::CsvSink> for CsvSink { + type Error = DataFusionError; + + fn try_from(value: &protobuf::CsvSink) -> Result { + Ok(Self::new(convert_required!(value.config)?)) + } +} + +impl TryFrom<&protobuf::FileSinkConfig> for FileSinkConfig { + type Error = DataFusionError; + + fn try_from(conf: &protobuf::FileSinkConfig) -> Result { + let file_groups = conf + .file_groups + .iter() + .map(TryInto::try_into) + .collect::>>()?; + let table_paths = conf + .table_paths + .iter() + .map(ListingTableUrl::parse) + .collect::>>()?; + let table_partition_cols = conf + .table_partition_cols + .iter() + .map(|protobuf::PartitionColumn { name, arrow_type }| { + let data_type = convert_required!(arrow_type)?; + Ok((name.clone(), data_type)) + }) + .collect::>>()?; + Ok(Self { + object_store_url: ObjectStoreUrl::parse(&conf.object_store_url)?, + file_groups, + table_paths, + output_schema: Arc::new(convert_required!(conf.output_schema)?), + table_partition_cols, 
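// The new JsonSink/CsvSink/ParquetSink conversions above lean on `convert_required!` to turn
// optional protobuf message fields into required Rust values. Below is a generic, std-only
// approximation of that helper; the names and the toy message types are illustrative, not
// DataFusion's actual macro or structs.
fn convert_required<'a, P, T>(field: Option<&'a P>) -> Result<T, String>
where
    T: TryFrom<&'a P, Error = String>,
{
    field
        .ok_or_else(|| "Missing required field in protobuf".to_string())
        .and_then(|p| T::try_from(p))
}

// Toy "protobuf" message and domain type to exercise the helper.
struct ProtoConfig {
    path: String,
}

#[derive(Debug)]
struct SinkConfig {
    path: String,
}

impl TryFrom<&ProtoConfig> for SinkConfig {
    type Error = String;
    fn try_from(p: &ProtoConfig) -> Result<Self, String> {
        Ok(SinkConfig { path: p.path.clone() })
    }
}

fn main() {
    let msg = Some(ProtoConfig { path: "/tmp/out".into() });
    let cfg: SinkConfig = convert_required(msg.as_ref()).unwrap();
    println!("{cfg:?}");

    let missing: Option<ProtoConfig> = None;
    assert!(convert_required::<ProtoConfig, SinkConfig>(missing.as_ref()).is_err());
}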
+ single_file_output: conf.single_file_output, + overwrite: conf.overwrite, + file_type_writer_options: convert_required!(conf.file_type_writer_options)?, + }) + } +} + +impl From for CompressionTypeVariant { + fn from(value: protobuf::CompressionTypeVariant) -> Self { + match value { + protobuf::CompressionTypeVariant::Gzip => Self::GZIP, + protobuf::CompressionTypeVariant::Bzip2 => Self::BZIP2, + protobuf::CompressionTypeVariant::Xz => Self::XZ, + protobuf::CompressionTypeVariant::Zstd => Self::ZSTD, + protobuf::CompressionTypeVariant::Uncompressed => Self::UNCOMPRESSED, + } + } +} + +impl TryFrom<&protobuf::FileTypeWriterOptions> for FileTypeWriterOptions { + type Error = DataFusionError; + + fn try_from(value: &protobuf::FileTypeWriterOptions) -> Result { + let file_type = value + .file_type + .as_ref() + .ok_or_else(|| proto_error("Missing required file_type field in protobuf"))?; + + match file_type { + protobuf::file_type_writer_options::FileType::JsonOptions(opts) => { + let compression: CompressionTypeVariant = opts.compression().into(); + Ok(Self::JSON(JsonWriterOptions::new(compression))) + } + protobuf::file_type_writer_options::FileType::CsvOptions(opts) => { + let write_options = csv_writer_options_from_proto(opts)?; + let compression: CompressionTypeVariant = opts.compression().into(); + Ok(Self::CSV(CsvWriterOptions::new(write_options, compression))) + } + protobuf::file_type_writer_options::FileType::ParquetOptions(opt) => { + let props = opt.writer_properties.clone().unwrap_or_default(); + let writer_properties = writer_properties_from_proto(&props)?; + Ok(Self::Parquet(ParquetWriterOptions::new(writer_properties))) + } + } + } +} diff --git a/datafusion/proto/src/physical_plan/mod.rs b/datafusion/proto/src/physical_plan/mod.rs index 431b8e42cdaf1..95becb3fe4b3a 100644 --- a/datafusion/proto/src/physical_plan/mod.rs +++ b/datafusion/proto/src/physical_plan/mod.rs @@ -21,7 +21,11 @@ use std::sync::Arc; use datafusion::arrow::compute::SortOptions; use datafusion::arrow::datatypes::SchemaRef; +use datafusion::datasource::file_format::csv::CsvSink; use datafusion::datasource::file_format::file_compression_type::FileCompressionType; +use datafusion::datasource::file_format::json::JsonSink; +#[cfg(feature = "parquet")] +use datafusion::datasource::file_format::parquet::ParquetSink; #[cfg(feature = "parquet")] use datafusion::datasource::physical_plan::ParquetExec; use datafusion::datasource::physical_plan::{AvroExec, CsvExec}; @@ -36,20 +40,23 @@ use datafusion::physical_plan::empty::EmptyExec; use datafusion::physical_plan::explain::ExplainExec; use datafusion::physical_plan::expressions::{Column, PhysicalSortExpr}; use datafusion::physical_plan::filter::FilterExec; +use datafusion::physical_plan::insert::FileSinkExec; use datafusion::physical_plan::joins::utils::{ColumnIndex, JoinFilter}; -use datafusion::physical_plan::joins::{CrossJoinExec, NestedLoopJoinExec}; +use datafusion::physical_plan::joins::{ + CrossJoinExec, NestedLoopJoinExec, StreamJoinPartitionMode, SymmetricHashJoinExec, +}; use datafusion::physical_plan::joins::{HashJoinExec, PartitionMode}; use datafusion::physical_plan::limit::{GlobalLimitExec, LocalLimitExec}; +use datafusion::physical_plan::placeholder_row::PlaceholderRowExec; use datafusion::physical_plan::projection::ProjectionExec; use datafusion::physical_plan::repartition::RepartitionExec; use datafusion::physical_plan::sorts::sort::SortExec; use datafusion::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; -use 
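// CompressionTypeVariant is mapped in both directions in this diff (proto -> domain here in
// from_proto, domain -> proto later in to_proto). A self-contained sketch of a symmetric pair
// of From impls plus a round-trip check, using hypothetical enums rather than the generated
// protobuf types:
#[derive(Debug, Clone, Copy, PartialEq)]
enum ProtoCompression {
    Gzip,
    Bzip2,
    Xz,
    Zstd,
    Uncompressed,
}

#[derive(Debug, Clone, Copy, PartialEq)]
enum Compression {
    Gzip,
    Bzip2,
    Xz,
    Zstd,
    Uncompressed,
}

impl From<ProtoCompression> for Compression {
    fn from(value: ProtoCompression) -> Self {
        match value {
            ProtoCompression::Gzip => Compression::Gzip,
            ProtoCompression::Bzip2 => Compression::Bzip2,
            ProtoCompression::Xz => Compression::Xz,
            ProtoCompression::Zstd => Compression::Zstd,
            ProtoCompression::Uncompressed => Compression::Uncompressed,
        }
    }
}

impl From<Compression> for ProtoCompression {
    fn from(value: Compression) -> Self {
        match value {
            Compression::Gzip => ProtoCompression::Gzip,
            Compression::Bzip2 => ProtoCompression::Bzip2,
            Compression::Xz => ProtoCompression::Xz,
            Compression::Zstd => ProtoCompression::Zstd,
            Compression::Uncompressed => ProtoCompression::Uncompressed,
        }
    }
}

fn main() {
    // Round trip: proto -> domain -> proto should be lossless.
    for c in [ProtoCompression::Gzip, ProtoCompression::Zstd, ProtoCompression::Uncompressed] {
        assert_eq!(ProtoCompression::from(Compression::from(c)), c);
    }
}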
datafusion::physical_plan::union::UnionExec; -use datafusion::physical_plan::windows::{ - BoundedWindowAggExec, PartitionSearchMode, WindowAggExec, -}; +use datafusion::physical_plan::union::{InterleaveExec, UnionExec}; +use datafusion::physical_plan::windows::{BoundedWindowAggExec, WindowAggExec}; use datafusion::physical_plan::{ - udaf, AggregateExpr, ExecutionPlan, Partitioning, PhysicalExpr, WindowExpr, + udaf, AggregateExpr, ExecutionPlan, InputOrderMode, Partitioning, PhysicalExpr, + WindowExpr, }; use datafusion_common::{internal_err, not_impl_err, DataFusionError, Result}; use prost::bytes::BufMut; @@ -64,7 +71,9 @@ use crate::protobuf::physical_aggregate_expr_node::AggregateFunction; use crate::protobuf::physical_expr_node::ExprType; use crate::protobuf::physical_plan_node::PhysicalPlanType; use crate::protobuf::repartition_exec_node::PartitionMethod; -use crate::protobuf::{self, window_agg_exec_node, PhysicalPlanNode}; +use crate::protobuf::{ + self, window_agg_exec_node, PhysicalPlanNode, PhysicalSortExprNodeCollection, +}; use crate::{convert_required, into_required}; use self::from_proto::parse_physical_window_expr; @@ -153,7 +162,16 @@ impl AsExecutionPlan for PhysicalPlanNode { .to_owned(), ) })?; - Ok(Arc::new(FilterExec::try_new(predicate, input)?)) + let filter_selectivity = filter.default_filter_selectivity.try_into(); + let filter = FilterExec::try_new(predicate, input)?; + match filter_selectivity { + Ok(filter_selectivity) => Ok(Arc::new( + filter.with_default_selectivity(filter_selectivity)?, + )), + Err(_) => Err(DataFusionError::Internal( + "filter_selectivity in PhysicalPlanNode is invalid ".to_owned(), + )), + } } PhysicalPlanType::CsvScan(scan) => Ok(Arc::new(CsvExec::new( parse_protobuf_file_scan_config( @@ -307,20 +325,18 @@ impl AsExecutionPlan for PhysicalPlanNode { }) .collect::>>>()?; - if let Some(partition_search_mode) = - window_agg.partition_search_mode.as_ref() - { - let partition_search_mode = match partition_search_mode { - window_agg_exec_node::PartitionSearchMode::Linear(_) => { - PartitionSearchMode::Linear + if let Some(input_order_mode) = window_agg.input_order_mode.as_ref() { + let input_order_mode = match input_order_mode { + window_agg_exec_node::InputOrderMode::Linear(_) => { + InputOrderMode::Linear } - window_agg_exec_node::PartitionSearchMode::PartiallySorted( - protobuf::PartiallySortedPartitionSearchMode { columns }, - ) => PartitionSearchMode::PartiallySorted( + window_agg_exec_node::InputOrderMode::PartiallySorted( + protobuf::PartiallySortedInputOrderMode { columns }, + ) => InputOrderMode::PartiallySorted( columns.iter().map(|c| *c as usize).collect(), ), - window_agg_exec_node::PartitionSearchMode::Sorted(_) => { - PartitionSearchMode::Sorted + window_agg_exec_node::InputOrderMode::Sorted(_) => { + InputOrderMode::Sorted } }; @@ -328,7 +344,7 @@ impl AsExecutionPlan for PhysicalPlanNode { physical_window_expr, input, partition_keys, - partition_search_mode, + input_order_mode, )?)) } else { Ok(Arc::new(WindowAggExec::try_new( @@ -414,19 +430,6 @@ impl AsExecutionPlan for PhysicalPlanNode { .transpose() }) .collect::, _>>()?; - let physical_order_by_expr = hash_agg - .order_by_expr - .iter() - .map(|expr| { - expr.sort_expr - .iter() - .map(|e| { - parse_physical_sort_expr(e, registry, &physical_schema) - }) - .collect::>>() - .map(|exprs| (!exprs.is_empty()).then_some(exprs)) - }) - .collect::>>()?; let physical_aggr_expr: Vec> = hash_agg .aggr_expr @@ -485,9 +488,8 @@ impl AsExecutionPlan for PhysicalPlanNode { 
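// The FilterExec branch above narrows the protobuf `default_filter_selectivity` (a u32 on the
// wire) back to the small integer the executor expects, and turns an out-of-range value into
// an error instead of silently truncating. A tiny std-only sketch of that checked narrowing;
// the u8 target and the 0..=100 bound here are illustrative assumptions.
fn selectivity_from_proto(raw: u32) -> Result<u8, String> {
    let value: u8 = raw
        .try_into()
        .map_err(|_| format!("filter_selectivity {raw} does not fit in u8"))?;
    if value > 100 {
        return Err(format!("filter_selectivity {value} must be at most 100"));
    }
    Ok(value)
}

fn main() {
    assert_eq!(selectivity_from_proto(20), Ok(20));
    assert!(selectivity_from_proto(300).is_err());
}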
PhysicalGroupBy::new(group_expr, null_expr, groups), physical_aggr_expr, physical_filter_expr, - physical_order_by_expr, input, - Arc::new(input_schema.try_into()?), + physical_schema, )?)) } PhysicalPlanType::HashJoin(hashjoin) => { @@ -533,7 +535,7 @@ impl AsExecutionPlan for PhysicalPlanNode { f.expression.as_ref().ok_or_else(|| { proto_error("Unexpected empty filter expression") })?, - registry, &schema + registry, &schema, )?; let column_indices = f.column_indices .iter() @@ -544,7 +546,7 @@ impl AsExecutionPlan for PhysicalPlanNode { i.side)) )?; - Ok(ColumnIndex{ + Ok(ColumnIndex { index: i.index as usize, side: side.into(), }) @@ -579,6 +581,97 @@ impl AsExecutionPlan for PhysicalPlanNode { hashjoin.null_equals_null, )?)) } + PhysicalPlanType::SymmetricHashJoin(sym_join) => { + let left = into_physical_plan( + &sym_join.left, + registry, + runtime, + extension_codec, + )?; + let right = into_physical_plan( + &sym_join.right, + registry, + runtime, + extension_codec, + )?; + let on = sym_join + .on + .iter() + .map(|col| { + let left = into_required!(col.left)?; + let right = into_required!(col.right)?; + Ok((left, right)) + }) + .collect::>()?; + let join_type = protobuf::JoinType::try_from(sym_join.join_type) + .map_err(|_| { + proto_error(format!( + "Received a SymmetricHashJoin message with unknown JoinType {}", + sym_join.join_type + )) + })?; + let filter = sym_join + .filter + .as_ref() + .map(|f| { + let schema = f + .schema + .as_ref() + .ok_or_else(|| proto_error("Missing JoinFilter schema"))? + .try_into()?; + + let expression = parse_physical_expr( + f.expression.as_ref().ok_or_else(|| { + proto_error("Unexpected empty filter expression") + })?, + registry, &schema, + )?; + let column_indices = f.column_indices + .iter() + .map(|i| { + let side = protobuf::JoinSide::try_from(i.side) + .map_err(|_| proto_error(format!( + "Received a HashJoinNode message with JoinSide in Filter {}", + i.side)) + )?; + + Ok(ColumnIndex { + index: i.index as usize, + side: side.into(), + }) + }) + .collect::>()?; + + Ok(JoinFilter::new(expression, column_indices, schema)) + }) + .map_or(Ok(None), |v: Result| v.map(Some))?; + + let partition_mode = + protobuf::StreamPartitionMode::try_from(sym_join.partition_mode).map_err(|_| { + proto_error(format!( + "Received a SymmetricHashJoin message with unknown PartitionMode {}", + sym_join.partition_mode + )) + })?; + let partition_mode = match partition_mode { + protobuf::StreamPartitionMode::SinglePartition => { + StreamJoinPartitionMode::SinglePartition + } + protobuf::StreamPartitionMode::PartitionedExec => { + StreamJoinPartitionMode::Partitioned + } + }; + SymmetricHashJoinExec::try_new( + left, + right, + on, + filter, + &join_type.into(), + sym_join.null_equals_null, + partition_mode, + ) + .map(|e| Arc::new(e) as _) + } PhysicalPlanType::Union(union) => { let mut inputs: Vec> = vec![]; for input in &union.inputs { @@ -590,6 +683,17 @@ impl AsExecutionPlan for PhysicalPlanNode { } Ok(Arc::new(UnionExec::new(inputs))) } + PhysicalPlanType::Interleave(interleave) => { + let mut inputs: Vec> = vec![]; + for input in &interleave.inputs { + inputs.push(input.try_into_physical_plan( + registry, + runtime, + extension_codec, + )?); + } + Ok(Arc::new(InterleaveExec::try_new(inputs)?)) + } PhysicalPlanType::CrossJoin(crossjoin) => { let left: Arc = into_physical_plan( &crossjoin.left, @@ -607,7 +711,11 @@ impl AsExecutionPlan for PhysicalPlanNode { } PhysicalPlanType::Empty(empty) => { let schema = Arc::new(convert_required!(empty.schema)?); - 
Ok(Arc::new(EmptyExec::new(empty.produce_one_row, schema))) + Ok(Arc::new(EmptyExec::new(schema))) + } + PhysicalPlanType::PlaceholderRow(placeholder) => { + let schema = Arc::new(convert_required!(placeholder.schema)?); + Ok(Arc::new(PlaceholderRowExec::new(schema))) } PhysicalPlanType::Sort(sort) => { let input: Arc = @@ -632,7 +740,7 @@ impl AsExecutionPlan for PhysicalPlanNode { })? .as_ref(); Ok(PhysicalSortExpr { - expr: parse_physical_expr(expr,registry, input.schema().as_ref())?, + expr: parse_physical_expr(expr, registry, input.schema().as_ref())?, options: SortOptions { descending: !sort_expr.asc, nulls_first: sort_expr.nulls_first, @@ -679,7 +787,7 @@ impl AsExecutionPlan for PhysicalPlanNode { })? .as_ref(); Ok(PhysicalSortExpr { - expr: parse_physical_expr(expr,registry, input.schema().as_ref())?, + expr: parse_physical_expr(expr, registry, input.schema().as_ref())?, options: SortOptions { descending: !sort_expr.asc, nulls_first: sort_expr.nulls_first, @@ -742,7 +850,7 @@ impl AsExecutionPlan for PhysicalPlanNode { f.expression.as_ref().ok_or_else(|| { proto_error("Unexpected empty filter expression") })?, - registry, &schema + registry, &schema, )?; let column_indices = f.column_indices .iter() @@ -753,7 +861,7 @@ impl AsExecutionPlan for PhysicalPlanNode { i.side)) )?; - Ok(ColumnIndex{ + Ok(ColumnIndex { index: i.index as usize, side: side.into(), }) @@ -782,7 +890,100 @@ impl AsExecutionPlan for PhysicalPlanNode { analyze.verbose, analyze.show_statistics, input, - Arc::new(analyze.schema.as_ref().unwrap().try_into()?), + Arc::new(convert_required!(analyze.schema)?), + ))) + } + PhysicalPlanType::JsonSink(sink) => { + let input = + into_physical_plan(&sink.input, registry, runtime, extension_codec)?; + + let data_sink: JsonSink = sink + .sink + .as_ref() + .ok_or_else(|| proto_error("Missing required field in protobuf"))? + .try_into()?; + let sink_schema = convert_required!(sink.sink_schema)?; + let sort_order = sink + .sort_order + .as_ref() + .map(|collection| { + collection + .physical_sort_expr_nodes + .iter() + .map(|proto| { + parse_physical_sort_expr(proto, registry, &sink_schema) + .map(Into::into) + }) + .collect::>>() + }) + .transpose()?; + Ok(Arc::new(FileSinkExec::new( + input, + Arc::new(data_sink), + Arc::new(sink_schema), + sort_order, + ))) + } + PhysicalPlanType::CsvSink(sink) => { + let input = + into_physical_plan(&sink.input, registry, runtime, extension_codec)?; + + let data_sink: CsvSink = sink + .sink + .as_ref() + .ok_or_else(|| proto_error("Missing required field in protobuf"))? + .try_into()?; + let sink_schema = convert_required!(sink.sink_schema)?; + let sort_order = sink + .sort_order + .as_ref() + .map(|collection| { + collection + .physical_sort_expr_nodes + .iter() + .map(|proto| { + parse_physical_sort_expr(proto, registry, &sink_schema) + .map(Into::into) + }) + .collect::>>() + }) + .transpose()?; + Ok(Arc::new(FileSinkExec::new( + input, + Arc::new(data_sink), + Arc::new(sink_schema), + sort_order, + ))) + } + PhysicalPlanType::ParquetSink(sink) => { + let input = + into_physical_plan(&sink.input, registry, runtime, extension_codec)?; + + let data_sink: ParquetSink = sink + .sink + .as_ref() + .ok_or_else(|| proto_error("Missing required field in protobuf"))? 
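// Each sink branch above converts an optional sort-order collection with `Option::map` over a
// fallible conversion and then `transpose`, so a missing order stays `None` while a malformed
// entry still propagates an error. A compact std-only illustration of that pattern; the sort
// expression syntax here is made up for the example.
fn parse_sort_key(s: &str) -> Result<(String, bool), String> {
    // "col ASC" / "col DESC" -> (column, ascending); anything else is an error.
    match s.rsplit_once(' ') {
        Some((col, "ASC")) => Ok((col.to_string(), true)),
        Some((col, "DESC")) => Ok((col.to_string(), false)),
        _ => Err(format!("invalid sort expression: {s}")),
    }
}

fn parse_sort_order(order: Option<&[&str]>) -> Result<Option<Vec<(String, bool)>>, String> {
    order
        .map(|exprs| {
            exprs
                .iter()
                .map(|e| parse_sort_key(e))
                .collect::<Result<Vec<_>, _>>()
        })
        .transpose() // Option<Result<..>> -> Result<Option<..>>
}

fn main() -> Result<(), String> {
    assert_eq!(parse_sort_order(None)?, None);

    let exprs = ["a ASC", "b DESC"];
    assert!(parse_sort_order(Some(exprs.as_slice()))?.is_some());

    let bad = ["oops"];
    assert!(parse_sort_order(Some(bad.as_slice())).is_err());
    Ok(())
}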
+ .try_into()?; + let sink_schema = convert_required!(sink.sink_schema)?; + let sort_order = sink + .sort_order + .as_ref() + .map(|collection| { + collection + .physical_sort_expr_nodes + .iter() + .map(|proto| { + parse_physical_sort_expr(proto, registry, &sink_schema) + .map(Into::into) + }) + .collect::>>() + }) + .transpose()?; + Ok(Arc::new(FileSinkExec::new( + input, + Arc::new(data_sink), + Arc::new(sink_schema), + sort_order, ))) } } @@ -863,6 +1064,7 @@ impl AsExecutionPlan for PhysicalPlanNode { protobuf::FilterExecNode { input: Some(Box::new(input)), expr: Some(exec.predicate().clone().try_into()?), + default_filter_selectivity: exec.default_selectivity() as u32, }, ))), }); @@ -973,6 +1175,79 @@ impl AsExecutionPlan for PhysicalPlanNode { }); } + if let Some(exec) = plan.downcast_ref::() { + let left = protobuf::PhysicalPlanNode::try_from_physical_plan( + exec.left().to_owned(), + extension_codec, + )?; + let right = protobuf::PhysicalPlanNode::try_from_physical_plan( + exec.right().to_owned(), + extension_codec, + )?; + let on = exec + .on() + .iter() + .map(|tuple| protobuf::JoinOn { + left: Some(protobuf::PhysicalColumn { + name: tuple.0.name().to_string(), + index: tuple.0.index() as u32, + }), + right: Some(protobuf::PhysicalColumn { + name: tuple.1.name().to_string(), + index: tuple.1.index() as u32, + }), + }) + .collect(); + let join_type: protobuf::JoinType = exec.join_type().to_owned().into(); + let filter = exec + .filter() + .as_ref() + .map(|f| { + let expression = f.expression().to_owned().try_into()?; + let column_indices = f + .column_indices() + .iter() + .map(|i| { + let side: protobuf::JoinSide = i.side.to_owned().into(); + protobuf::ColumnIndex { + index: i.index as u32, + side: side.into(), + } + }) + .collect(); + let schema = f.schema().try_into()?; + Ok(protobuf::JoinFilter { + expression: Some(expression), + column_indices, + schema: Some(schema), + }) + }) + .map_or(Ok(None), |v: Result| v.map(Some))?; + + let partition_mode = match exec.partition_mode() { + StreamJoinPartitionMode::SinglePartition => { + protobuf::StreamPartitionMode::SinglePartition + } + StreamJoinPartitionMode::Partitioned => { + protobuf::StreamPartitionMode::PartitionedExec + } + }; + + return Ok(protobuf::PhysicalPlanNode { + physical_plan_type: Some(PhysicalPlanType::SymmetricHashJoin(Box::new( + protobuf::SymmetricHashJoinExecNode { + left: Some(Box::new(left)), + right: Some(Box::new(right)), + on, + join_type: join_type.into(), + partition_mode: partition_mode.into(), + null_equals_null: exec.null_equals_null(), + filter, + }, + ))), + }); + } + if let Some(exec) = plan.downcast_ref::() { let left = protobuf::PhysicalPlanNode::try_from_physical_plan( exec.left().to_owned(), @@ -1013,12 +1288,6 @@ impl AsExecutionPlan for PhysicalPlanNode { .map(|expr| expr.to_owned().try_into()) .collect::>>()?; - let order_by = exec - .order_by_expr() - .iter() - .map(|expr| expr.to_owned().try_into()) - .collect::>>()?; - let agg = exec .aggr_expr() .iter() @@ -1071,7 +1340,6 @@ impl AsExecutionPlan for PhysicalPlanNode { group_expr_name: group_names, aggr_expr: agg, filter_expr: filter, - order_by_expr: order_by, aggr_expr_name: agg_names, mode: agg_mode as i32, input: Some(Box::new(input)), @@ -1088,7 +1356,17 @@ impl AsExecutionPlan for PhysicalPlanNode { return Ok(protobuf::PhysicalPlanNode { physical_plan_type: Some(PhysicalPlanType::Empty( protobuf::EmptyExecNode { - produce_one_row: empty.produce_one_row(), + schema: Some(schema), + }, + )), + }); + } + + if let Some(empty) = 
plan.downcast_ref::() { + let schema = empty.schema().as_ref().try_into()?; + return Ok(protobuf::PhysicalPlanNode { + physical_plan_type: Some(PhysicalPlanType::PlaceholderRow( + protobuf::PlaceholderRowExecNode { schema: Some(schema), }, )), @@ -1255,6 +1533,21 @@ impl AsExecutionPlan for PhysicalPlanNode { }); } + if let Some(interleave) = plan.downcast_ref::() { + let mut inputs: Vec = vec![]; + for input in interleave.inputs() { + inputs.push(protobuf::PhysicalPlanNode::try_from_physical_plan( + input.to_owned(), + extension_codec, + )?); + } + return Ok(protobuf::PhysicalPlanNode { + physical_plan_type: Some(PhysicalPlanType::Interleave( + protobuf::InterleaveExecNode { inputs }, + )), + }); + } + if let Some(exec) = plan.downcast_ref::() { let input = protobuf::PhysicalPlanNode::try_from_physical_plan( exec.input().to_owned(), @@ -1359,7 +1652,7 @@ impl AsExecutionPlan for PhysicalPlanNode { input: Some(Box::new(input)), window_expr, partition_keys, - partition_search_mode: None, + input_order_mode: None, }, ))), }); @@ -1383,24 +1676,20 @@ impl AsExecutionPlan for PhysicalPlanNode { .map(|e| e.clone().try_into()) .collect::>>()?; - let partition_search_mode = match &exec.partition_search_mode { - PartitionSearchMode::Linear => { - window_agg_exec_node::PartitionSearchMode::Linear( - protobuf::EmptyMessage {}, - ) - } - PartitionSearchMode::PartiallySorted(columns) => { - window_agg_exec_node::PartitionSearchMode::PartiallySorted( - protobuf::PartiallySortedPartitionSearchMode { + let input_order_mode = match &exec.input_order_mode { + InputOrderMode::Linear => window_agg_exec_node::InputOrderMode::Linear( + protobuf::EmptyMessage {}, + ), + InputOrderMode::PartiallySorted(columns) => { + window_agg_exec_node::InputOrderMode::PartiallySorted( + protobuf::PartiallySortedInputOrderMode { columns: columns.iter().map(|c| *c as u64).collect(), }, ) } - PartitionSearchMode::Sorted => { - window_agg_exec_node::PartitionSearchMode::Sorted( - protobuf::EmptyMessage {}, - ) - } + InputOrderMode::Sorted => window_agg_exec_node::InputOrderMode::Sorted( + protobuf::EmptyMessage {}, + ), }; return Ok(protobuf::PhysicalPlanNode { @@ -1409,12 +1698,80 @@ impl AsExecutionPlan for PhysicalPlanNode { input: Some(Box::new(input)), window_expr, partition_keys, - partition_search_mode: Some(partition_search_mode), + input_order_mode: Some(input_order_mode), }, ))), }); } + if let Some(exec) = plan.downcast_ref::() { + let input = protobuf::PhysicalPlanNode::try_from_physical_plan( + exec.input().to_owned(), + extension_codec, + )?; + let sort_order = match exec.sort_order() { + Some(requirements) => { + let expr = requirements + .iter() + .map(|requirement| { + let expr: PhysicalSortExpr = requirement.to_owned().into(); + let sort_expr = protobuf::PhysicalSortExprNode { + expr: Some(Box::new(expr.expr.to_owned().try_into()?)), + asc: !expr.options.descending, + nulls_first: expr.options.nulls_first, + }; + Ok(sort_expr) + }) + .collect::>>()?; + Some(PhysicalSortExprNodeCollection { + physical_sort_expr_nodes: expr, + }) + } + None => None, + }; + + if let Some(sink) = exec.sink().as_any().downcast_ref::() { + return Ok(protobuf::PhysicalPlanNode { + physical_plan_type: Some(PhysicalPlanType::JsonSink(Box::new( + protobuf::JsonSinkExecNode { + input: Some(Box::new(input)), + sink: Some(sink.try_into()?), + sink_schema: Some(exec.schema().as_ref().try_into()?), + sort_order, + }, + ))), + }); + } + + if let Some(sink) = exec.sink().as_any().downcast_ref::() { + return Ok(protobuf::PhysicalPlanNode { + 
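// The serialization path above discovers the concrete sink by downcasting the trait object
// and falls through (eventually to the extension codec) when nothing matches. A std-only
// sketch of that `as_any`/`downcast_ref` chain with hypothetical sink types:
use std::any::Any;

trait DataSink: Any {
    fn as_any(&self) -> &dyn Any;
}

struct JsonSink;
struct CsvSink;

impl DataSink for JsonSink {
    fn as_any(&self) -> &dyn Any {
        self
    }
}
impl DataSink for CsvSink {
    fn as_any(&self) -> &dyn Any {
        self
    }
}

// Returns the protobuf node name for known sinks, or None so a caller can
// hand the plan to an extension codec instead.
fn sink_proto_name(sink: &dyn DataSink) -> Option<&'static str> {
    if sink.as_any().downcast_ref::<JsonSink>().is_some() {
        return Some("JsonSink");
    }
    if sink.as_any().downcast_ref::<CsvSink>().is_some() {
        return Some("CsvSink");
    }
    None // unknown sink: let the extension codec handle it
}

fn main() {
    assert_eq!(sink_proto_name(&JsonSink), Some("JsonSink"));
    assert_eq!(sink_proto_name(&CsvSink), Some("CsvSink"));
}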
physical_plan_type: Some(PhysicalPlanType::CsvSink(Box::new( + protobuf::CsvSinkExecNode { + input: Some(Box::new(input)), + sink: Some(sink.try_into()?), + sink_schema: Some(exec.schema().as_ref().try_into()?), + sort_order, + }, + ))), + }); + } + + if let Some(sink) = exec.sink().as_any().downcast_ref::() { + return Ok(protobuf::PhysicalPlanNode { + physical_plan_type: Some(PhysicalPlanType::ParquetSink(Box::new( + protobuf::ParquetSinkExecNode { + input: Some(Box::new(input)), + sink: Some(sink.try_into()?), + sink_schema: Some(exec.schema().as_ref().try_into()?), + sort_order, + }, + ))), + }); + } + + // If unknown DataSink then let extension handle it + } + let mut buf: Vec = vec![]; match extension_codec.try_encode(plan_clone.clone(), &mut buf) { Ok(_) => { diff --git a/datafusion/proto/src/physical_plan/to_proto.rs b/datafusion/proto/src/physical_plan/to_proto.rs index 114baab6ccc49..f4e3f9e4dca7f 100644 --- a/datafusion/proto/src/physical_plan/to_proto.rs +++ b/datafusion/proto/src/physical_plan/to_proto.rs @@ -28,8 +28,17 @@ use crate::protobuf::{ ScalarValue, }; -use datafusion::datasource::listing::{FileRange, PartitionedFile}; -use datafusion::datasource::physical_plan::FileScanConfig; +#[cfg(feature = "parquet")] +use datafusion::datasource::file_format::parquet::ParquetSink; + +use crate::logical_plan::{csv_writer_options_to_proto, writer_properties_to_proto}; +use datafusion::datasource::{ + file_format::csv::CsvSink, + file_format::json::JsonSink, + listing::{FileRange, PartitionedFile}, + physical_plan::FileScanConfig, + physical_plan::FileSinkConfig, +}; use datafusion::logical_expr::BuiltinScalarFunction; use datafusion::physical_expr::expressions::{GetFieldAccessExpr, GetIndexedFieldExpr}; use datafusion::physical_expr::window::{NthValueKind, SlidingAggregateWindowExpr}; @@ -50,7 +59,15 @@ use datafusion::physical_plan::{ AggregateExpr, ColumnStatistics, PhysicalExpr, Statistics, WindowExpr, }; use datafusion_common::{ - internal_err, not_impl_err, stats::Precision, DataFusionError, JoinSide, Result, + file_options::{ + arrow_writer::ArrowWriterOptions, avro_writer::AvroWriterOptions, + csv_writer::CsvWriterOptions, json_writer::JsonWriterOptions, + parquet_writer::ParquetWriterOptions, + }, + internal_err, not_impl_err, + parsers::CompressionTypeVariant, + stats::Precision, + DataFusionError, FileTypeWriterOptions, JoinSide, Result, }; impl TryFrom> for protobuf::PhysicalExprNode { @@ -71,10 +88,11 @@ impl TryFrom> for protobuf::PhysicalExprNode { .collect::>>()?; if let Some(a) = a.as_any().downcast_ref::() { + let name = a.fun().name().to_string(); return Ok(protobuf::PhysicalExprNode { expr_type: Some(protobuf::physical_expr_node::ExprType::AggregateExpr( protobuf::PhysicalAggregateExprNode { - aggregate_function: Some(physical_aggregate_expr_node::AggregateFunction::UserDefinedAggrFunction(a.fun().name.clone())), + aggregate_function: Some(physical_aggregate_expr_node::AggregateFunction::UserDefinedAggrFunction(name)), expr: expressions, ordering_req, distinct: false, @@ -167,7 +185,7 @@ impl TryFrom> for protobuf::PhysicalWindowExprNode { args.insert( 1, Arc::new(Literal::new( - datafusion_common::ScalarValue::Int64(Some(n as i64)), + datafusion_common::ScalarValue::Int64(Some(n)), )), ); protobuf::BuiltInWindowFunction::NthValue @@ -790,3 +808,126 @@ impl TryFrom for protobuf::PhysicalSortExprNode { }) } } + +impl TryFrom<&JsonSink> for protobuf::JsonSink { + type Error = DataFusionError; + + fn try_from(value: &JsonSink) -> Result { + Ok(Self { + config: 
Some(value.config().try_into()?), + }) + } +} + +impl TryFrom<&CsvSink> for protobuf::CsvSink { + type Error = DataFusionError; + + fn try_from(value: &CsvSink) -> Result { + Ok(Self { + config: Some(value.config().try_into()?), + }) + } +} + +#[cfg(feature = "parquet")] +impl TryFrom<&ParquetSink> for protobuf::ParquetSink { + type Error = DataFusionError; + + fn try_from(value: &ParquetSink) -> Result { + Ok(Self { + config: Some(value.config().try_into()?), + }) + } +} + +impl TryFrom<&FileSinkConfig> for protobuf::FileSinkConfig { + type Error = DataFusionError; + + fn try_from(conf: &FileSinkConfig) -> Result { + let file_groups = conf + .file_groups + .iter() + .map(TryInto::try_into) + .collect::>>()?; + let table_paths = conf + .table_paths + .iter() + .map(ToString::to_string) + .collect::>(); + let table_partition_cols = conf + .table_partition_cols + .iter() + .map(|(name, data_type)| { + Ok(protobuf::PartitionColumn { + name: name.to_owned(), + arrow_type: Some(data_type.try_into()?), + }) + }) + .collect::>>()?; + let file_type_writer_options = &conf.file_type_writer_options; + Ok(Self { + object_store_url: conf.object_store_url.to_string(), + file_groups, + table_paths, + output_schema: Some(conf.output_schema.as_ref().try_into()?), + table_partition_cols, + single_file_output: conf.single_file_output, + overwrite: conf.overwrite, + file_type_writer_options: Some(file_type_writer_options.try_into()?), + }) + } +} + +impl From<&CompressionTypeVariant> for protobuf::CompressionTypeVariant { + fn from(value: &CompressionTypeVariant) -> Self { + match value { + CompressionTypeVariant::GZIP => Self::Gzip, + CompressionTypeVariant::BZIP2 => Self::Bzip2, + CompressionTypeVariant::XZ => Self::Xz, + CompressionTypeVariant::ZSTD => Self::Zstd, + CompressionTypeVariant::UNCOMPRESSED => Self::Uncompressed, + } + } +} + +impl TryFrom<&FileTypeWriterOptions> for protobuf::FileTypeWriterOptions { + type Error = DataFusionError; + + fn try_from(opts: &FileTypeWriterOptions) -> Result { + let file_type = match opts { + #[cfg(feature = "parquet")] + FileTypeWriterOptions::Parquet(ParquetWriterOptions { writer_options }) => { + protobuf::file_type_writer_options::FileType::ParquetOptions( + protobuf::ParquetWriterOptions { + writer_properties: Some(writer_properties_to_proto( + writer_options, + )), + }, + ) + } + FileTypeWriterOptions::CSV(CsvWriterOptions { + writer_options, + compression, + }) => protobuf::file_type_writer_options::FileType::CsvOptions( + csv_writer_options_to_proto(writer_options, compression), + ), + FileTypeWriterOptions::JSON(JsonWriterOptions { compression }) => { + let compression: protobuf::CompressionTypeVariant = compression.into(); + protobuf::file_type_writer_options::FileType::JsonOptions( + protobuf::JsonWriterOptions { + compression: compression.into(), + }, + ) + } + FileTypeWriterOptions::Avro(AvroWriterOptions {}) => { + return not_impl_err!("Avro file sink protobuf serialization") + } + FileTypeWriterOptions::Arrow(ArrowWriterOptions {}) => { + return not_impl_err!("Arrow file sink protobuf serialization") + } + }; + Ok(Self { + file_type: Some(file_type), + }) + } +} diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index ca801df337f14..402781e17e6f2 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -15,14 +15,16 @@ // specific language governing permissions and limitations // under the License. 
+use std::any::Any; use std::collections::HashMap; use std::fmt::{self, Debug, Formatter}; use std::sync::Arc; -use arrow::array::ArrayRef; +use arrow::array::{ArrayRef, FixedSizeListArray}; +use arrow::csv::WriterBuilder; use arrow::datatypes::{ - DataType, Field, Fields, IntervalDayTimeType, IntervalMonthDayNanoType, IntervalUnit, - Schema, SchemaRef, TimeUnit, UnionFields, UnionMode, + DataType, Field, Fields, Int32Type, IntervalDayTimeType, IntervalMonthDayNanoType, + IntervalUnit, Schema, SchemaRef, TimeUnit, UnionFields, UnionMode, }; use prost::Message; @@ -31,22 +33,29 @@ use datafusion::datasource::provider::TableProviderFactory; use datafusion::datasource::TableProvider; use datafusion::execution::context::SessionState; use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv}; +use datafusion::parquet::file::properties::{WriterProperties, WriterVersion}; use datafusion::physical_plan::functions::make_scalar_function; use datafusion::prelude::{create_udf, CsvReadOptions, SessionConfig, SessionContext}; use datafusion::test_util::{TestTableFactory, TestTableProvider}; -use datafusion_common::Result; -use datafusion_common::{internal_err, not_impl_err, plan_err}; +use datafusion_common::file_options::csv_writer::CsvWriterOptions; +use datafusion_common::file_options::parquet_writer::ParquetWriterOptions; +use datafusion_common::file_options::StatementOptions; +use datafusion_common::parsers::CompressionTypeVariant; +use datafusion_common::{internal_err, not_impl_err, plan_err, FileTypeWriterOptions}; use datafusion_common::{DFField, DFSchema, DFSchemaRef, DataFusionError, ScalarValue}; +use datafusion_common::{FileType, Result}; +use datafusion_expr::dml::{CopyOptions, CopyTo}; use datafusion_expr::expr::{ self, Between, BinaryExpr, Case, Cast, GroupingSet, InList, Like, ScalarFunction, - ScalarUDF, Sort, + Sort, }; use datafusion_expr::logical_plan::{Extension, UserDefinedLogicalNodeCore}; use datafusion_expr::{ col, create_udaf, lit, Accumulator, AggregateFunction, BuiltinScalarFunction::{Sqrt, Substr}, Expr, LogicalPlan, Operator, PartitionEvaluator, Signature, TryCast, Volatility, - WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunction, WindowUDF, + WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, WindowUDF, + WindowUDFImpl, }; use datafusion_proto::bytes::{ logical_plan_from_bytes, logical_plan_from_bytes_with_extension_codec, @@ -217,11 +226,10 @@ async fn roundtrip_custom_memory_tables() -> Result<()> { async fn roundtrip_custom_listing_tables() -> Result<()> { let ctx = SessionContext::new(); - // Make sure during round-trip, constraint information is preserved let query = "CREATE EXTERNAL TABLE multiple_ordered_table_with_pk ( a0 INTEGER, - a INTEGER, - b INTEGER, + a INTEGER DEFAULT 1*2 + 3, + b INTEGER DEFAULT NULL, c INTEGER, d INTEGER, primary key(c) @@ -232,11 +240,13 @@ async fn roundtrip_custom_listing_tables() -> Result<()> { WITH ORDER (c ASC) LOCATION '../core/tests/data/window_2.csv';"; - let plan = ctx.sql(query).await?.into_optimized_plan()?; + let plan = ctx.state().create_logical_plan(query).await?; let bytes = logical_plan_to_bytes(&plan)?; let logical_round_trip = logical_plan_from_bytes(&bytes, &ctx)?; - assert_eq!(format!("{plan:?}"), format!("{logical_round_trip:?}")); + // Use exact matching to verify everything. Make sure during round-trip, + // information like constraints, column defaults, and other aspects of the plan are preserved. 
+ assert_eq!(plan, logical_round_trip); Ok(()) } @@ -300,6 +310,184 @@ async fn roundtrip_logical_plan_aggregation() -> Result<()> { Ok(()) } +#[tokio::test] +async fn roundtrip_logical_plan_copy_to_sql_options() -> Result<()> { + let ctx = SessionContext::new(); + + let input = create_csv_scan(&ctx).await?; + + let mut options = HashMap::new(); + options.insert("foo".to_string(), "bar".to_string()); + + let plan = LogicalPlan::Copy(CopyTo { + input: Arc::new(input), + output_url: "test.csv".to_string(), + file_format: FileType::CSV, + single_file_output: true, + copy_options: CopyOptions::SQLOptions(StatementOptions::from(&options)), + }); + + let bytes = logical_plan_to_bytes(&plan)?; + let logical_round_trip = logical_plan_from_bytes(&bytes, &ctx)?; + assert_eq!(format!("{plan:?}"), format!("{logical_round_trip:?}")); + + Ok(()) +} + +#[tokio::test] +async fn roundtrip_logical_plan_copy_to_writer_options() -> Result<()> { + let ctx = SessionContext::new(); + + let input = create_csv_scan(&ctx).await?; + + let writer_properties = WriterProperties::builder() + .set_bloom_filter_enabled(true) + .set_created_by("DataFusion Test".to_string()) + .set_writer_version(WriterVersion::PARQUET_2_0) + .set_write_batch_size(111) + .set_data_page_size_limit(222) + .set_data_page_row_count_limit(333) + .set_dictionary_page_size_limit(444) + .set_max_row_group_size(555) + .build(); + let plan = LogicalPlan::Copy(CopyTo { + input: Arc::new(input), + output_url: "test.parquet".to_string(), + file_format: FileType::PARQUET, + single_file_output: true, + copy_options: CopyOptions::WriterOptions(Box::new( + FileTypeWriterOptions::Parquet(ParquetWriterOptions::new(writer_properties)), + )), + }); + + let bytes = logical_plan_to_bytes(&plan)?; + let logical_round_trip = logical_plan_from_bytes(&bytes, &ctx)?; + assert_eq!(format!("{plan:?}"), format!("{logical_round_trip:?}")); + + match logical_round_trip { + LogicalPlan::Copy(copy_to) => { + assert_eq!("test.parquet", copy_to.output_url); + assert_eq!(FileType::PARQUET, copy_to.file_format); + assert!(copy_to.single_file_output); + match ©_to.copy_options { + CopyOptions::WriterOptions(y) => match y.as_ref() { + FileTypeWriterOptions::Parquet(p) => { + let props = &p.writer_options; + assert_eq!("DataFusion Test", props.created_by()); + assert_eq!( + "PARQUET_2_0", + format!("{:?}", props.writer_version()) + ); + assert_eq!(111, props.write_batch_size()); + assert_eq!(222, props.data_page_size_limit()); + assert_eq!(333, props.data_page_row_count_limit()); + assert_eq!(444, props.dictionary_page_size_limit()); + assert_eq!(555, props.max_row_group_size()); + } + _ => panic!(), + }, + _ => panic!(), + } + } + _ => panic!(), + } + Ok(()) +} + +#[tokio::test] +async fn roundtrip_logical_plan_copy_to_csv() -> Result<()> { + let ctx = SessionContext::new(); + + let input = create_csv_scan(&ctx).await?; + + let writer_properties = WriterBuilder::new() + .with_delimiter(b'*') + .with_date_format("dd/MM/yyyy".to_string()) + .with_datetime_format("dd/MM/yyyy HH:mm:ss".to_string()) + .with_timestamp_format("HH:mm:ss.SSSSSS".to_string()) + .with_time_format("HH:mm:ss".to_string()) + .with_null("NIL".to_string()); + + let plan = LogicalPlan::Copy(CopyTo { + input: Arc::new(input), + output_url: "test.csv".to_string(), + file_format: FileType::CSV, + single_file_output: true, + copy_options: CopyOptions::WriterOptions(Box::new(FileTypeWriterOptions::CSV( + CsvWriterOptions::new( + writer_properties, + CompressionTypeVariant::UNCOMPRESSED, + ), + ))), + }); + + let bytes 
= logical_plan_to_bytes(&plan)?; + let logical_round_trip = logical_plan_from_bytes(&bytes, &ctx)?; + assert_eq!(format!("{plan:?}"), format!("{logical_round_trip:?}")); + + match logical_round_trip { + LogicalPlan::Copy(copy_to) => { + assert_eq!("test.csv", copy_to.output_url); + assert_eq!(FileType::CSV, copy_to.file_format); + assert!(copy_to.single_file_output); + match ©_to.copy_options { + CopyOptions::WriterOptions(y) => match y.as_ref() { + FileTypeWriterOptions::CSV(p) => { + let props = &p.writer_options; + assert_eq!(b'*', props.delimiter()); + assert_eq!("dd/MM/yyyy", props.date_format().unwrap()); + assert_eq!( + "dd/MM/yyyy HH:mm:ss", + props.datetime_format().unwrap() + ); + assert_eq!("HH:mm:ss.SSSSSS", props.timestamp_format().unwrap()); + assert_eq!("HH:mm:ss", props.time_format().unwrap()); + assert_eq!("NIL", props.null()); + } + _ => panic!(), + }, + _ => panic!(), + } + } + _ => panic!(), + } + + Ok(()) +} +async fn create_csv_scan(ctx: &SessionContext) -> Result { + ctx.register_csv("t1", "tests/testdata/test.csv", CsvReadOptions::default()) + .await?; + + let input = ctx.table("t1").await?.into_optimized_plan()?; + Ok(input) +} + +#[tokio::test] +async fn roundtrip_logical_plan_distinct_on() -> Result<()> { + let ctx = SessionContext::new(); + + let schema = Schema::new(vec![ + Field::new("a", DataType::Int64, true), + Field::new("b", DataType::Decimal128(15, 2), true), + ]); + + ctx.register_csv( + "t1", + "tests/testdata/test.csv", + CsvReadOptions::default().schema(&schema), + ) + .await?; + + let query = "SELECT DISTINCT ON (a % 2) a, b * 2 FROM t1 ORDER BY a % 2 DESC, b"; + let plan = ctx.sql(query).await?.into_optimized_plan()?; + + let bytes = logical_plan_to_bytes(&plan)?; + let logical_round_trip = logical_plan_from_bytes(&bytes, &ctx)?; + assert_eq!(format!("{plan:?}"), format!("{logical_round_trip:?}")); + + Ok(()) +} + #[tokio::test] async fn roundtrip_single_count_distinct() -> Result<()> { let ctx = SessionContext::new(); @@ -548,6 +736,7 @@ fn round_trip_scalar_values() { ScalarValue::Utf8(None), ScalarValue::LargeUtf8(None), ScalarValue::List(ScalarValue::new_list(&[], &DataType::Boolean)), + ScalarValue::LargeList(ScalarValue::new_large_list(&[], &DataType::Boolean)), ScalarValue::Date32(None), ScalarValue::Boolean(Some(true)), ScalarValue::Boolean(Some(false)), @@ -648,6 +837,16 @@ fn round_trip_scalar_values() { ], &DataType::Float32, )), + ScalarValue::LargeList(ScalarValue::new_large_list( + &[ + ScalarValue::Float32(Some(-213.1)), + ScalarValue::Float32(None), + ScalarValue::Float32(Some(5.5)), + ScalarValue::Float32(Some(2.0)), + ScalarValue::Float32(Some(1.0)), + ], + &DataType::Float32, + )), ScalarValue::List(ScalarValue::new_list( &[ ScalarValue::List(ScalarValue::new_list(&[], &DataType::Float32)), @@ -664,9 +863,36 @@ fn round_trip_scalar_values() { ], &DataType::List(new_arc_field("item", DataType::Float32, true)), )), + ScalarValue::LargeList(ScalarValue::new_large_list( + &[ + ScalarValue::LargeList(ScalarValue::new_large_list( + &[], + &DataType::Float32, + )), + ScalarValue::LargeList(ScalarValue::new_large_list( + &[ + ScalarValue::Float32(Some(-213.1)), + ScalarValue::Float32(None), + ScalarValue::Float32(Some(5.5)), + ScalarValue::Float32(Some(2.0)), + ScalarValue::Float32(Some(1.0)), + ], + &DataType::Float32, + )), + ], + &DataType::LargeList(new_arc_field("item", DataType::Float32, true)), + )), + ScalarValue::FixedSizeList(Arc::new(FixedSizeListArray::from_iter_primitive::< + Int32Type, + _, + _, + >( + 
vec![Some(vec![Some(1), Some(2), Some(3)])], + 3, + ))), ScalarValue::Dictionary( Box::new(DataType::Int32), - Box::new(ScalarValue::Utf8(Some("foo".into()))), + Box::new(ScalarValue::from("foo")), ), ScalarValue::Dictionary( Box::new(DataType::Int32), @@ -907,6 +1133,45 @@ fn round_trip_datatype() { } } +#[test] +fn roundtrip_dict_id() -> Result<()> { + let dict_id = 42; + let field = Field::new( + "keys", + DataType::List(Arc::new(Field::new_dict( + "item", + DataType::Dictionary(Box::new(DataType::UInt16), Box::new(DataType::Utf8)), + true, + dict_id, + false, + ))), + false, + ); + let schema = Arc::new(Schema::new(vec![field])); + + // encode + let mut buf: Vec = vec![]; + let schema_proto: datafusion_proto::generated::datafusion::Schema = + schema.try_into().unwrap(); + schema_proto.encode(&mut buf).unwrap(); + + // decode + let schema_proto = + datafusion_proto::generated::datafusion::Schema::decode(buf.as_slice()).unwrap(); + let decoded: Schema = (&schema_proto).try_into()?; + + // assert + let keys = decoded.fields().iter().last().unwrap(); + match keys.data_type() { + DataType::List(field) => { + assert_eq!(field.dict_id(), Some(dict_id), "dict_id should be retained"); + } + _ => panic!("Invalid type"), + } + + Ok(()) +} + #[test] fn roundtrip_null_scalar_values() { let test_types = vec![ @@ -1147,7 +1412,17 @@ fn roundtrip_inlist() { #[test] fn roundtrip_wildcard() { - let test_expr = Expr::Wildcard; + let test_expr = Expr::Wildcard { qualifier: None }; + + let ctx = SessionContext::new(); + roundtrip_expr_test(test_expr, ctx); +} + +#[test] +fn roundtrip_qualified_wildcard() { + let test_expr = Expr::Wildcard { + qualifier: Some("foo".into()), + }; let ctx = SessionContext::new(); roundtrip_expr_test(test_expr, ctx); @@ -1301,9 +1576,10 @@ fn roundtrip_aggregate_udf() { Arc::new(vec![DataType::Float64, DataType::UInt32]), ); - let test_expr = Expr::AggregateUDF(expr::AggregateUDF::new( + let test_expr = Expr::AggregateFunction(expr::AggregateFunction::new_udf( Arc::new(dummy_agg.clone()), vec![lit(1.0_f64)], + false, Some(Box::new(lit(true))), None, )); @@ -1328,7 +1604,10 @@ fn roundtrip_scalar_udf() { scalar_fn, ); - let test_expr = Expr::ScalarUDF(ScalarUDF::new(Arc::new(udf.clone()), vec![lit("")])); + let test_expr = Expr::ScalarFunction(ScalarFunction::new_udf( + Arc::new(udf.clone()), + vec![lit("")], + )); let ctx = SessionContext::new(); ctx.register_udf(udf); @@ -1386,8 +1665,8 @@ fn roundtrip_window() { // 1. without window_frame let test_expr1 = Expr::WindowFunction(expr::WindowFunction::new( - WindowFunction::BuiltInWindowFunction( - datafusion_expr::window_function::BuiltInWindowFunction::Rank, + WindowFunctionDefinition::BuiltInWindowFunction( + datafusion_expr::BuiltInWindowFunction::Rank, ), vec![], vec![col("col1")], @@ -1397,8 +1676,8 @@ fn roundtrip_window() { // 2. 
with default window_frame let test_expr2 = Expr::WindowFunction(expr::WindowFunction::new( - WindowFunction::BuiltInWindowFunction( - datafusion_expr::window_function::BuiltInWindowFunction::Rank, + WindowFunctionDefinition::BuiltInWindowFunction( + datafusion_expr::BuiltInWindowFunction::Rank, ), vec![], vec![col("col1")], @@ -1414,8 +1693,8 @@ fn roundtrip_window() { }; let test_expr3 = Expr::WindowFunction(expr::WindowFunction::new( - WindowFunction::BuiltInWindowFunction( - datafusion_expr::window_function::BuiltInWindowFunction::Rank, + WindowFunctionDefinition::BuiltInWindowFunction( + datafusion_expr::BuiltInWindowFunction::Rank, ), vec![], vec![col("col1")], @@ -1431,7 +1710,7 @@ fn roundtrip_window() { }; let test_expr4 = Expr::WindowFunction(expr::WindowFunction::new( - WindowFunction::AggregateFunction(AggregateFunction::Max), + WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max), vec![col("col1")], vec![col("col1")], vec![col("col2")], @@ -1482,7 +1761,7 @@ fn roundtrip_window() { ); let test_expr5 = Expr::WindowFunction(expr::WindowFunction::new( - WindowFunction::AggregateUDF(Arc::new(dummy_agg.clone())), + WindowFunctionDefinition::AggregateUDF(Arc::new(dummy_agg.clone())), vec![col("col1")], vec![col("col1")], vec![col("col2")], @@ -1508,30 +1787,55 @@ fn roundtrip_window() { } } - fn return_type(arg_types: &[DataType]) -> Result> { - if arg_types.len() != 1 { - return plan_err!( - "dummy_udwf expects 1 argument, got {}: {:?}", - arg_types.len(), - arg_types - ); + struct SimpleWindowUDF { + signature: Signature, + } + + impl SimpleWindowUDF { + fn new() -> Self { + let signature = + Signature::exact(vec![DataType::Float64], Volatility::Immutable); + Self { signature } + } + } + + impl WindowUDFImpl for SimpleWindowUDF { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "dummy_udwf" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + if arg_types.len() != 1 { + return plan_err!( + "dummy_udwf expects 1 argument, got {}: {:?}", + arg_types.len(), + arg_types + ); + } + Ok(arg_types[0].clone()) + } + + fn partition_evaluator(&self) -> Result> { + make_partition_evaluator() } - Ok(Arc::new(arg_types[0].clone())) } fn make_partition_evaluator() -> Result> { Ok(Box::new(DummyWindow {})) } - let dummy_window_udf = WindowUDF { - name: String::from("dummy_udwf"), - signature: Signature::exact(vec![DataType::Float64], Volatility::Immutable), - return_type: Arc::new(return_type), - partition_evaluator_factory: Arc::new(make_partition_evaluator), - }; + let dummy_window_udf = WindowUDF::from(SimpleWindowUDF::new()); let test_expr6 = Expr::WindowFunction(expr::WindowFunction::new( - WindowFunction::WindowUDF(Arc::new(dummy_window_udf.clone())), + WindowFunctionDefinition::WindowUDF(Arc::new(dummy_window_udf.clone())), vec![col("col1")], vec![col("col1")], vec![col("col2")], diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs index 01a0916d8cd23..27ac5d122f83f 100644 --- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs @@ -15,21 +15,28 @@ // specific language governing permissions and limitations // under the License. 
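// The test above now defines the window UDF by implementing a trait instead of filling a
// struct of function pointers. The following trimmed, std-only analog only illustrates that
// design shift; `WindowUdfLike`, `DummyUdwf`, and this `DataType` are hypothetical, not the
// DataFusion API.
#[derive(Debug, PartialEq)]
enum DataType {
    Float64,
    Int64,
}

trait WindowUdfLike {
    fn name(&self) -> &str;
    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType, String>;
}

struct DummyUdwf;

impl WindowUdfLike for DummyUdwf {
    fn name(&self) -> &str {
        "dummy_udwf"
    }

    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType, String> {
        // Argument validation lives next to the definition instead of in a free function.
        match arg_types {
            [only] => Ok(match only {
                DataType::Float64 => DataType::Float64,
                DataType::Int64 => DataType::Int64,
            }),
            other => Err(format!("dummy_udwf expects 1 argument, got {}", other.len())),
        }
    }
}

fn main() {
    let udf: Box<dyn WindowUdfLike> = Box::new(DummyUdwf);
    assert_eq!(udf.name(), "dummy_udwf");
    assert_eq!(udf.return_type(&[DataType::Float64]), Ok(DataType::Float64));
}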
+use arrow::csv::WriterBuilder; use std::ops::Deref; use std::sync::Arc; use datafusion::arrow::array::ArrayRef; use datafusion::arrow::compute::kernels::sort::SortOptions; use datafusion::arrow::datatypes::{DataType, Field, Fields, IntervalUnit, Schema}; -use datafusion::datasource::listing::PartitionedFile; +use datafusion::datasource::file_format::csv::CsvSink; +use datafusion::datasource::file_format::json::JsonSink; +use datafusion::datasource::file_format::parquet::ParquetSink; +use datafusion::datasource::listing::{ListingTableUrl, PartitionedFile}; use datafusion::datasource::object_store::ObjectStoreUrl; -use datafusion::datasource::physical_plan::{FileScanConfig, ParquetExec}; +use datafusion::datasource::physical_plan::{ + FileScanConfig, FileSinkConfig, ParquetExec, +}; use datafusion::execution::context::ExecutionProps; use datafusion::logical_expr::{ create_udf, BuiltinScalarFunction, JoinType, Operator, Volatility, }; +use datafusion::parquet::file::properties::WriterProperties; use datafusion::physical_expr::window::SlidingAggregateWindowExpr; -use datafusion::physical_expr::ScalarFunctionExpr; +use datafusion::physical_expr::{PhysicalSortRequirement, ScalarFunctionExpr}; use datafusion::physical_plan::aggregates::{ AggregateExec, AggregateMode, PhysicalGroupBy, }; @@ -41,20 +48,30 @@ use datafusion::physical_plan::expressions::{ }; use datafusion::physical_plan::filter::FilterExec; use datafusion::physical_plan::functions::make_scalar_function; -use datafusion::physical_plan::joins::{HashJoinExec, NestedLoopJoinExec, PartitionMode}; +use datafusion::physical_plan::insert::FileSinkExec; +use datafusion::physical_plan::joins::{ + HashJoinExec, NestedLoopJoinExec, PartitionMode, StreamJoinPartitionMode, +}; use datafusion::physical_plan::limit::{GlobalLimitExec, LocalLimitExec}; +use datafusion::physical_plan::placeholder_row::PlaceholderRowExec; use datafusion::physical_plan::projection::ProjectionExec; +use datafusion::physical_plan::repartition::RepartitionExec; use datafusion::physical_plan::sorts::sort::SortExec; +use datafusion::physical_plan::union::{InterleaveExec, UnionExec}; use datafusion::physical_plan::windows::{ BuiltInWindowExpr, PlainAggregateWindowExpr, WindowAggExec, }; use datafusion::physical_plan::{ - functions, udaf, AggregateExpr, ExecutionPlan, PhysicalExpr, Statistics, + functions, udaf, AggregateExpr, ExecutionPlan, Partitioning, PhysicalExpr, Statistics, }; use datafusion::prelude::SessionContext; use datafusion::scalar::ScalarValue; +use datafusion_common::file_options::csv_writer::CsvWriterOptions; +use datafusion_common::file_options::json_writer::JsonWriterOptions; +use datafusion_common::file_options::parquet_writer::ParquetWriterOptions; +use datafusion_common::parsers::CompressionTypeVariant; use datafusion_common::stats::Precision; -use datafusion_common::Result; +use datafusion_common::{FileTypeWriterOptions, Result}; use datafusion_expr::{ Accumulator, AccumulatorFactoryFunction, AggregateUDF, ReturnTypeFunction, Signature, StateTypeFunction, WindowFrame, WindowFrameBound, @@ -62,7 +79,23 @@ use datafusion_expr::{ use datafusion_proto::physical_plan::{AsExecutionPlan, DefaultPhysicalExtensionCodec}; use datafusion_proto::protobuf; +/// Perform a serde roundtrip and assert that the string representation of the before and after plans +/// are identical. Note that this often isn't sufficient to guarantee that no information is +/// lost during serde because the string representation of a plan often only shows a subset of state. 
fn roundtrip_test(exec_plan: Arc) -> Result<()> { + let _ = roundtrip_test_and_return(exec_plan); + Ok(()) +} + +/// Perform a serde roundtrip and assert that the string representation of the before and after plans +/// are identical. Note that this often isn't sufficient to guarantee that no information is +/// lost during serde because the string representation of a plan often only shows a subset of state. +/// +/// This version of the roundtrip_test method returns the final plan after serde so that it can be inspected +/// farther in tests. +fn roundtrip_test_and_return( + exec_plan: Arc, +) -> Result> { let ctx = SessionContext::new(); let codec = DefaultPhysicalExtensionCodec {}; let proto: protobuf::PhysicalPlanNode = @@ -73,9 +106,15 @@ fn roundtrip_test(exec_plan: Arc) -> Result<()> { .try_into_physical_plan(&ctx, runtime.deref(), &codec) .expect("from proto"); assert_eq!(format!("{exec_plan:?}"), format!("{result_exec_plan:?}")); - Ok(()) + Ok(result_exec_plan) } +/// Perform a serde roundtrip and assert that the string representation of the before and after plans +/// are identical. Note that this often isn't sufficient to guarantee that no information is +/// lost during serde because the string representation of a plan often only shows a subset of state. +/// +/// This version of the roundtrip_test function accepts a SessionContext, which is required when +/// performing serde on some plans. fn roundtrip_test_with_context( exec_plan: Arc, ctx: SessionContext, @@ -94,7 +133,7 @@ fn roundtrip_test_with_context( #[test] fn roundtrip_empty() -> Result<()> { - roundtrip_test(Arc::new(EmptyExec::new(false, Arc::new(Schema::empty())))) + roundtrip_test(Arc::new(EmptyExec::new(Arc::new(Schema::empty())))) } #[test] @@ -107,7 +146,7 @@ fn roundtrip_date_time_interval() -> Result<()> { false, ), ]); - let input = Arc::new(EmptyExec::new(false, Arc::new(schema.clone()))); + let input = Arc::new(EmptyExec::new(Arc::new(schema.clone()))); let date_expr = col("some_date", &schema)?; let literal_expr = col("some_interval", &schema)?; let date_time_interval_expr = @@ -122,7 +161,7 @@ fn roundtrip_date_time_interval() -> Result<()> { #[test] fn roundtrip_local_limit() -> Result<()> { roundtrip_test(Arc::new(LocalLimitExec::new( - Arc::new(EmptyExec::new(false, Arc::new(Schema::empty()))), + Arc::new(EmptyExec::new(Arc::new(Schema::empty()))), 25, ))) } @@ -130,7 +169,7 @@ fn roundtrip_local_limit() -> Result<()> { #[test] fn roundtrip_global_limit() -> Result<()> { roundtrip_test(Arc::new(GlobalLimitExec::new( - Arc::new(EmptyExec::new(false, Arc::new(Schema::empty()))), + Arc::new(EmptyExec::new(Arc::new(Schema::empty()))), 0, Some(25), ))) @@ -139,7 +178,7 @@ fn roundtrip_global_limit() -> Result<()> { #[test] fn roundtrip_global_skip_no_limit() -> Result<()> { roundtrip_test(Arc::new(GlobalLimitExec::new( - Arc::new(EmptyExec::new(false, Arc::new(Schema::empty()))), + Arc::new(EmptyExec::new(Arc::new(Schema::empty()))), 10, None, // no limit ))) @@ -169,8 +208,8 @@ fn roundtrip_hash_join() -> Result<()> { ] { for partition_mode in &[PartitionMode::Partitioned, PartitionMode::CollectLeft] { roundtrip_test(Arc::new(HashJoinExec::try_new( - Arc::new(EmptyExec::new(false, schema_left.clone())), - Arc::new(EmptyExec::new(false, schema_right.clone())), + Arc::new(EmptyExec::new(schema_left.clone())), + Arc::new(EmptyExec::new(schema_right.clone())), on.clone(), None, join_type, @@ -201,8 +240,8 @@ fn roundtrip_nested_loop_join() -> Result<()> { JoinType::RightSemi, ] { 
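// The new doc comments above note that comparing formatted plans can miss state the string
// representation does not show, which is why one variant returns the decoded plan for deeper
// inspection. A minimal std-only sketch of both flavors, with a toy encode/decode pair
// standing in for the protobuf codec:
#[derive(Debug, Clone, PartialEq)]
struct Plan {
    name: String,
    limit: Option<usize>,
}

fn encode(plan: &Plan) -> String {
    format!("{}|{}", plan.name, plan.limit.map_or(-1i64, |l| l as i64))
}

fn decode(encoded: &str) -> Plan {
    let (name, limit) = encoded.split_once('|').expect("separator");
    let limit: i64 = limit.parse().expect("number");
    Plan {
        name: name.to_string(),
        limit: (limit >= 0).then(|| limit as usize),
    }
}

// Cheap check: string representations match (may miss hidden state).
fn roundtrip_test(plan: &Plan) {
    let decoded = decode(&encode(plan));
    assert_eq!(format!("{plan:?}"), format!("{decoded:?}"));
}

// Stronger check: return the decoded value so the caller can assert on fields directly.
fn roundtrip_test_and_return(plan: &Plan) -> Plan {
    let decoded = decode(&encode(plan));
    assert_eq!(format!("{plan:?}"), format!("{decoded:?}"));
    decoded
}

fn main() {
    let plan = Plan { name: "limit".into(), limit: Some(25) };
    roundtrip_test(&plan);
    let back = roundtrip_test_and_return(&plan);
    assert_eq!(back.limit, Some(25)); // field-level assertion, not just formatting
}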
roundtrip_test(Arc::new(NestedLoopJoinExec::try_new( - Arc::new(EmptyExec::new(false, schema_left.clone())), - Arc::new(EmptyExec::new(false, schema_right.clone())), + Arc::new(EmptyExec::new(schema_left.clone())), + Arc::new(EmptyExec::new(schema_right.clone())), None, join_type, )?))?; @@ -223,21 +262,21 @@ fn roundtrip_window() -> Result<()> { }; let builtin_window_expr = Arc::new(BuiltInWindowExpr::new( - Arc::new(NthValue::first( - "FIRST_VALUE(a) PARTITION BY [b] ORDER BY [a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", - col("a", &schema)?, - DataType::Int64, - )), - &[col("b", &schema)?], - &[PhysicalSortExpr { - expr: col("a", &schema)?, - options: SortOptions { - descending: false, - nulls_first: false, - }, - }], - Arc::new(window_frame), - )); + Arc::new(NthValue::first( + "FIRST_VALUE(a) PARTITION BY [b] ORDER BY [a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", + col("a", &schema)?, + DataType::Int64, + )), + &[col("b", &schema)?], + &[PhysicalSortExpr { + expr: col("a", &schema)?, + options: SortOptions { + descending: false, + nulls_first: false, + }, + }], + Arc::new(window_frame), + )); let plain_aggr_window_expr = Arc::new(PlainAggregateWindowExpr::new( Arc::new(Avg::new( @@ -267,7 +306,7 @@ fn roundtrip_window() -> Result<()> { Arc::new(window_frame), )); - let input = Arc::new(EmptyExec::new(false, schema.clone())); + let input = Arc::new(EmptyExec::new(schema.clone())); roundtrip_test(Arc::new(WindowAggExec::try_new( vec![ @@ -300,8 +339,7 @@ fn rountrip_aggregate() -> Result<()> { PhysicalGroupBy::new_single(groups.clone()), aggregates.clone(), vec![None], - vec![None], - Arc::new(EmptyExec::new(false, schema.clone())), + Arc::new(EmptyExec::new(schema.clone())), schema, )?)) } @@ -368,8 +406,7 @@ fn roundtrip_aggregate_udaf() -> Result<()> { PhysicalGroupBy::new_single(groups.clone()), aggregates.clone(), vec![None], - vec![None], - Arc::new(EmptyExec::new(false, schema.clone())), + Arc::new(EmptyExec::new(schema.clone())), schema, )?), ctx, @@ -395,7 +432,7 @@ fn roundtrip_filter_with_not_and_in_list() -> Result<()> { let and = binary(not, Operator::And, in_list, &schema)?; roundtrip_test(Arc::new(FilterExec::try_new( and, - Arc::new(EmptyExec::new(false, schema.clone())), + Arc::new(EmptyExec::new(schema.clone())), )?)) } @@ -422,7 +459,7 @@ fn roundtrip_sort() -> Result<()> { ]; roundtrip_test(Arc::new(SortExec::new( sort_exprs, - Arc::new(EmptyExec::new(false, schema)), + Arc::new(EmptyExec::new(schema)), ))) } @@ -450,11 +487,11 @@ fn roundtrip_sort_preserve_partitioning() -> Result<()> { roundtrip_test(Arc::new(SortExec::new( sort_exprs.clone(), - Arc::new(EmptyExec::new(false, schema.clone())), + Arc::new(EmptyExec::new(schema.clone())), )))?; roundtrip_test(Arc::new( - SortExec::new(sort_exprs, Arc::new(EmptyExec::new(false, schema))) + SortExec::new(sort_exprs, Arc::new(EmptyExec::new(schema))) .with_preserve_partitioning(true), )) } @@ -483,7 +520,6 @@ fn roundtrip_parquet_exec_with_pruning_predicate() -> Result<()> { limit: None, table_partition_cols: vec![], output_ordering: vec![], - infinite_source: false, }; let predicate = Arc::new(BinaryExpr::new( @@ -504,7 +540,7 @@ fn roundtrip_builtin_scalar_function() -> Result<()> { let field_b = Field::new("b", DataType::Int64, false); let schema = Arc::new(Schema::new(vec![field_a, field_b])); - let input = Arc::new(EmptyExec::new(false, schema.clone())); + let input = Arc::new(EmptyExec::new(schema.clone())); let execution_props = ExecutionProps::new(); @@ -515,7 
+551,7 @@ fn roundtrip_builtin_scalar_function() -> Result<()> { "acos", fun_expr, vec![col("a", &schema)?], - &DataType::Int64, + DataType::Int64, None, ); @@ -531,7 +567,7 @@ fn roundtrip_scalar_udf() -> Result<()> { let field_b = Field::new("b", DataType::Int64, false); let schema = Arc::new(Schema::new(vec![field_a, field_b])); - let input = Arc::new(EmptyExec::new(false, schema.clone())); + let input = Arc::new(EmptyExec::new(schema.clone())); let fn_impl = |args: &[ArrayRef]| Ok(Arc::new(args[0].clone()) as ArrayRef); @@ -549,7 +585,7 @@ fn roundtrip_scalar_udf() -> Result<()> { "dummy", scalar_fn, vec![col("a", &schema)?], - &DataType::Int64, + DataType::Int64, None, ); @@ -583,8 +619,7 @@ fn roundtrip_distinct_count() -> Result<()> { PhysicalGroupBy::new_single(groups), aggregates.clone(), vec![None], - vec![None], - Arc::new(EmptyExec::new(false, schema.clone())), + Arc::new(EmptyExec::new(schema.clone())), schema, )?)) } @@ -595,7 +630,7 @@ fn roundtrip_like() -> Result<()> { Field::new("a", DataType::Utf8, false), Field::new("b", DataType::Utf8, false), ]); - let input = Arc::new(EmptyExec::new(false, Arc::new(schema.clone()))); + let input = Arc::new(EmptyExec::new(Arc::new(schema.clone()))); let like_expr = like( false, false, @@ -622,13 +657,13 @@ fn roundtrip_get_indexed_field_named_struct_field() -> Result<()> { ]; let schema = Schema::new(fields); - let input = Arc::new(EmptyExec::new(false, Arc::new(schema.clone()))); + let input = Arc::new(EmptyExec::new(Arc::new(schema.clone()))); let col_arg = col("arg", &schema)?; let get_indexed_field_expr = Arc::new(GetIndexedFieldExpr::new( col_arg, GetFieldAccessExpr::NamedStructField { - name: ScalarValue::Utf8(Some(String::from("name"))), + name: ScalarValue::from("name"), }, )); @@ -649,7 +684,7 @@ fn roundtrip_get_indexed_field_list_index() -> Result<()> { ]; let schema = Schema::new(fields); - let input = Arc::new(EmptyExec::new(true, Arc::new(schema.clone()))); + let input = Arc::new(PlaceholderRowExec::new(Arc::new(schema.clone()))); let col_arg = col("arg", &schema)?; let col_key = col("key", &schema)?; @@ -676,7 +711,7 @@ fn roundtrip_get_indexed_field_list_range() -> Result<()> { ]; let schema = Schema::new(fields); - let input = Arc::new(EmptyExec::new(false, Arc::new(schema.clone()))); + let input = Arc::new(EmptyExec::new(Arc::new(schema.clone()))); let col_arg = col("arg", &schema)?; let col_start = col("start", &schema)?; @@ -698,11 +733,11 @@ fn roundtrip_get_indexed_field_list_range() -> Result<()> { } #[test] -fn rountrip_analyze() -> Result<()> { +fn roundtrip_analyze() -> Result<()> { let field_a = Field::new("plan_type", DataType::Utf8, false); let field_b = Field::new("plan", DataType::Utf8, false); let schema = Schema::new(vec![field_a, field_b]); - let input = Arc::new(EmptyExec::new(true, Arc::new(schema.clone()))); + let input = Arc::new(PlaceholderRowExec::new(Arc::new(schema.clone()))); roundtrip_test(Arc::new(AnalyzeExec::new( false, @@ -711,3 +746,207 @@ fn rountrip_analyze() -> Result<()> { Arc::new(schema), ))) } + +#[test] +fn roundtrip_json_sink() -> Result<()> { + let field_a = Field::new("plan_type", DataType::Utf8, false); + let field_b = Field::new("plan", DataType::Utf8, false); + let schema = Arc::new(Schema::new(vec![field_a, field_b])); + let input = Arc::new(PlaceholderRowExec::new(schema.clone())); + + let file_sink_config = FileSinkConfig { + object_store_url: ObjectStoreUrl::local_filesystem(), + file_groups: vec![PartitionedFile::new("/tmp".to_string(), 1)], + table_paths: 
vec![ListingTableUrl::parse("file:///")?],
+        output_schema: schema.clone(),
+        table_partition_cols: vec![("plan_type".to_string(), DataType::Utf8)],
+        single_file_output: true,
+        overwrite: true,
+        file_type_writer_options: FileTypeWriterOptions::JSON(JsonWriterOptions::new(
+            CompressionTypeVariant::UNCOMPRESSED,
+        )),
+    };
+    let data_sink = Arc::new(JsonSink::new(file_sink_config));
+    let sort_order = vec![PhysicalSortRequirement::new(
+        Arc::new(Column::new("plan_type", 0)),
+        Some(SortOptions {
+            descending: true,
+            nulls_first: false,
+        }),
+    )];
+
+    roundtrip_test(Arc::new(FileSinkExec::new(
+        input,
+        data_sink,
+        schema.clone(),
+        Some(sort_order),
+    )))
+}
+
+#[test]
+fn roundtrip_csv_sink() -> Result<()> {
+    let field_a = Field::new("plan_type", DataType::Utf8, false);
+    let field_b = Field::new("plan", DataType::Utf8, false);
+    let schema = Arc::new(Schema::new(vec![field_a, field_b]));
+    let input = Arc::new(PlaceholderRowExec::new(schema.clone()));
+
+    let file_sink_config = FileSinkConfig {
+        object_store_url: ObjectStoreUrl::local_filesystem(),
+        file_groups: vec![PartitionedFile::new("/tmp".to_string(), 1)],
+        table_paths: vec![ListingTableUrl::parse("file:///")?],
+        output_schema: schema.clone(),
+        table_partition_cols: vec![("plan_type".to_string(), DataType::Utf8)],
+        single_file_output: true,
+        overwrite: true,
+        file_type_writer_options: FileTypeWriterOptions::CSV(CsvWriterOptions::new(
+            WriterBuilder::default(),
+            CompressionTypeVariant::ZSTD,
+        )),
+    };
+    let data_sink = Arc::new(CsvSink::new(file_sink_config));
+    let sort_order = vec![PhysicalSortRequirement::new(
+        Arc::new(Column::new("plan_type", 0)),
+        Some(SortOptions {
+            descending: true,
+            nulls_first: false,
+        }),
+    )];
+
+    let roundtrip_plan = roundtrip_test_and_return(Arc::new(FileSinkExec::new(
+        input,
+        data_sink,
+        schema.clone(),
+        Some(sort_order),
+    )))
+    .unwrap();
+
+    let roundtrip_plan = roundtrip_plan
+        .as_any()
+        .downcast_ref::<FileSinkExec>()
+        .unwrap();
+    let csv_sink = roundtrip_plan
+        .sink()
+        .as_any()
+        .downcast_ref::<CsvSink>()
+        .unwrap();
+    assert_eq!(
+        CompressionTypeVariant::ZSTD,
+        csv_sink
+            .config()
+            .file_type_writer_options
+            .try_into_csv()
+            .unwrap()
+            .compression
+    );
+
+    Ok(())
+}
+
+#[test]
+fn roundtrip_parquet_sink() -> Result<()> {
+    let field_a = Field::new("plan_type", DataType::Utf8, false);
+    let field_b = Field::new("plan", DataType::Utf8, false);
+    let schema = Arc::new(Schema::new(vec![field_a, field_b]));
+    let input = Arc::new(PlaceholderRowExec::new(schema.clone()));
+
+    let file_sink_config = FileSinkConfig {
+        object_store_url: ObjectStoreUrl::local_filesystem(),
+        file_groups: vec![PartitionedFile::new("/tmp".to_string(), 1)],
+        table_paths: vec![ListingTableUrl::parse("file:///")?],
+        output_schema: schema.clone(),
+        table_partition_cols: vec![("plan_type".to_string(), DataType::Utf8)],
+        single_file_output: true,
+        overwrite: true,
+        file_type_writer_options: FileTypeWriterOptions::Parquet(
+            ParquetWriterOptions::new(WriterProperties::default()),
+        ),
+    };
+    let data_sink = Arc::new(ParquetSink::new(file_sink_config));
+    let sort_order = vec![PhysicalSortRequirement::new(
+        Arc::new(Column::new("plan_type", 0)),
+        Some(SortOptions {
+            descending: true,
+            nulls_first: false,
+        }),
+    )];
+
+    roundtrip_test(Arc::new(FileSinkExec::new(
+        input,
+        data_sink,
+        schema.clone(),
+        Some(sort_order),
+    )))
+}
+
+#[test]
+fn roundtrip_sym_hash_join() -> Result<()> {
+    let field_a = Field::new("col", DataType::Int64, false);
+    let schema_left = Schema::new(vec![field_a.clone()]);
+    let schema_right = Schema::new(vec![field_a]);
+    let on = vec![(
+        Column::new("col", schema_left.index_of("col")?),
+        Column::new("col", schema_right.index_of("col")?),
+    )];
+
+    let schema_left = Arc::new(schema_left);
+    let schema_right = Arc::new(schema_right);
+    for join_type in &[
+        JoinType::Inner,
+        JoinType::Left,
+        JoinType::Right,
+        JoinType::Full,
+        JoinType::LeftAnti,
+        JoinType::RightAnti,
+        JoinType::LeftSemi,
+        JoinType::RightSemi,
+    ] {
+        for partition_mode in &[
+            StreamJoinPartitionMode::Partitioned,
+            StreamJoinPartitionMode::SinglePartition,
+        ] {
+            roundtrip_test(Arc::new(
+                datafusion::physical_plan::joins::SymmetricHashJoinExec::try_new(
+                    Arc::new(EmptyExec::new(schema_left.clone())),
+                    Arc::new(EmptyExec::new(schema_right.clone())),
+                    on.clone(),
+                    None,
+                    join_type,
+                    false,
+                    *partition_mode,
+                )?,
+            ))?;
+        }
+    }
+    Ok(())
+}
+
+#[test]
+fn roundtrip_union() -> Result<()> {
+    let field_a = Field::new("col", DataType::Int64, false);
+    let schema_left = Schema::new(vec![field_a.clone()]);
+    let schema_right = Schema::new(vec![field_a]);
+    let left = EmptyExec::new(Arc::new(schema_left));
+    let right = EmptyExec::new(Arc::new(schema_right));
+    let inputs: Vec<Arc<dyn ExecutionPlan>> = vec![Arc::new(left), Arc::new(right)];
+    let union = UnionExec::new(inputs);
+    roundtrip_test(Arc::new(union))
+}
+
+#[test]
+fn roundtrip_interleave() -> Result<()> {
+    let field_a = Field::new("col", DataType::Int64, false);
+    let schema_left = Schema::new(vec![field_a.clone()]);
+    let schema_right = Schema::new(vec![field_a]);
+    let partition = Partitioning::Hash(vec![], 3);
+    let left = RepartitionExec::try_new(
+        Arc::new(EmptyExec::new(Arc::new(schema_left))),
+        partition.clone(),
+    )?;
+    let right = RepartitionExec::try_new(
+        Arc::new(EmptyExec::new(Arc::new(schema_right))),
+        partition.clone(),
+    )?;
+    let inputs: Vec<Arc<dyn ExecutionPlan>> = vec![Arc::new(left), Arc::new(right)];
+    let interleave = InterleaveExec::try_new(inputs)?;
+    roundtrip_test(Arc::new(interleave))
+}
diff --git a/datafusion/proto/tests/cases/serialize.rs b/datafusion/proto/tests/cases/serialize.rs
index f32c81527925d..5b890accd81f2 100644
--- a/datafusion/proto/tests/cases/serialize.rs
+++ b/datafusion/proto/tests/cases/serialize.rs
@@ -128,6 +128,12 @@ fn exact_roundtrip_linearized_binary_expr() {
     }
 }
 
+#[test]
+fn roundtrip_qualified_alias() {
+    let qual_alias = col("c1").alias_qualified(Some("my_table"), "my_column");
+    assert_eq!(qual_alias, roundtrip_expr(&qual_alias));
+}
+
 #[test]
 fn roundtrip_deeply_nested_binary_expr() {
     // We need more stack space so this doesn't overflow in dev builds
diff --git a/datafusion/sql/src/expr/arrow_cast.rs b/datafusion/sql/src/expr/arrow_cast.rs
index 8c0184b6d1192..ade8b96b5cc21 100644
--- a/datafusion/sql/src/expr/arrow_cast.rs
+++ b/datafusion/sql/src/expr/arrow_cast.rs
@@ -149,6 +149,7 @@ impl<'a> Parser<'a> {
             Token::Decimal256 => self.parse_decimal_256(),
             Token::Dictionary => self.parse_dictionary(),
             Token::List => self.parse_list(),
+            Token::LargeList => self.parse_large_list(),
             tok => Err(make_error(
                 self.val,
                 &format!("finding next type, got unexpected '{tok}'"),
@@ -166,6 +167,16 @@ impl<'a> Parser<'a> {
         ))))
     }
 
+    /// Parses the LargeList type
+    fn parse_large_list(&mut self) -> Result<DataType> {
+        self.expect_token(Token::LParen)?;
+        let data_type = self.parse_next_type()?;
+        self.expect_token(Token::RParen)?;
+        Ok(DataType::LargeList(Arc::new(Field::new(
+            "item", data_type, true,
+        ))))
+    }
+
     /// Parses the next timeunit
     fn parse_time_unit(&mut self, context: &str) ->
Result { match self.next_token()? { @@ -496,6 +507,7 @@ impl<'a> Tokenizer<'a> { "Date64" => Token::SimpleType(DataType::Date64), "List" => Token::List, + "LargeList" => Token::LargeList, "Second" => Token::TimeUnit(TimeUnit::Second), "Millisecond" => Token::TimeUnit(TimeUnit::Millisecond), @@ -585,6 +597,7 @@ enum Token { Integer(i64), DoubleQuotedString(String), List, + LargeList, } impl Display for Token { @@ -592,6 +605,7 @@ impl Display for Token { match self { Token::SimpleType(t) => write!(f, "{t}"), Token::List => write!(f, "List"), + Token::LargeList => write!(f, "LargeList"), Token::Timestamp => write!(f, "Timestamp"), Token::Time32 => write!(f, "Time32"), Token::Time64 => write!(f, "Time64"), diff --git a/datafusion/sql/src/expr/function.rs b/datafusion/sql/src/expr/function.rs index c58b8319ceb72..395f10b6f7834 100644 --- a/datafusion/sql/src/expr/function.rs +++ b/datafusion/sql/src/expr/function.rs @@ -19,12 +19,12 @@ use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; use datafusion_common::{ not_impl_err, plan_datafusion_err, plan_err, DFSchema, DataFusionError, Result, }; -use datafusion_expr::expr::{ScalarFunction, ScalarUDF}; +use datafusion_expr::expr::ScalarFunction; use datafusion_expr::function::suggest_valid_function; -use datafusion_expr::window_frame::regularize; +use datafusion_expr::window_frame::{check_window_frame, regularize_window_order_by}; use datafusion_expr::{ - expr, window_function, AggregateFunction, BuiltinScalarFunction, Expr, WindowFrame, - WindowFunction, + expr, AggregateFunction, BuiltinScalarFunction, Expr, WindowFrame, + WindowFunctionDefinition, }; use sqlparser::ast::{ Expr as SQLExpr, Function as SQLFunction, FunctionArg, FunctionArgExpr, WindowType, @@ -66,7 +66,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { // user-defined function (UDF) should have precedence in case it has the same name as a scalar built-in function if let Some(fm) = self.context_provider.get_function_meta(&name) { let args = self.function_args_to_expr(args, schema, planner_context)?; - return Ok(Expr::ScalarUDF(ScalarUDF::new(fm, args))); + return Ok(Expr::ScalarFunction(ScalarFunction::new_udf(fm, args))); } // next, scalar built-in @@ -90,31 +90,43 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let partition_by = window .partition_by .into_iter() + // ignore window spec PARTITION BY for scalar values + // as they do not change and thus do not generate new partitions + .filter(|e| !matches!(e, sqlparser::ast::Expr::Value { .. 
},)) .map(|e| self.sql_expr_to_logical_expr(e, schema, planner_context)) .collect::>>()?; - let order_by = - self.order_by_to_sort_expr(&window.order_by, schema, planner_context)?; + let mut order_by = self.order_by_to_sort_expr( + &window.order_by, + schema, + planner_context, + // Numeric literals in window function ORDER BY are treated as constants + false, + )?; let window_frame = window .window_frame .as_ref() .map(|window_frame| { let window_frame = window_frame.clone().try_into()?; - regularize(window_frame, order_by.len()) + check_window_frame(&window_frame, order_by.len()) + .map(|_| window_frame) }) .transpose()?; + let window_frame = if let Some(window_frame) = window_frame { + regularize_window_order_by(&window_frame, &mut order_by)?; window_frame } else { WindowFrame::new(!order_by.is_empty()) }; + if let Ok(fun) = self.find_window_func(&name) { let expr = match fun { - WindowFunction::AggregateFunction(aggregate_fun) => { + WindowFunctionDefinition::AggregateFunction(aggregate_fun) => { let args = self.function_args_to_expr(args, schema, planner_context)?; Expr::WindowFunction(expr::WindowFunction::new( - WindowFunction::AggregateFunction(aggregate_fun), + WindowFunctionDefinition::AggregateFunction(aggregate_fun), args, partition_by, order_by, @@ -135,15 +147,15 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { // User defined aggregate functions (UDAF) have precedence in case it has the same name as a scalar built-in function if let Some(fm) = self.context_provider.get_aggregate_meta(&name) { let args = self.function_args_to_expr(args, schema, planner_context)?; - return Ok(Expr::AggregateUDF(expr::AggregateUDF::new( - fm, args, None, None, + return Ok(Expr::AggregateFunction(expr::AggregateFunction::new_udf( + fm, args, false, None, None, ))); } // next, aggregate built-ins if let Ok(fun) = AggregateFunction::from_str(&name) { let order_by = - self.order_by_to_sort_expr(&order_by, schema, planner_context)?; + self.order_by_to_sort_expr(&order_by, schema, planner_context, true)?; let order_by = (!order_by.is_empty()).then_some(order_by); let args = self.function_args_to_expr(args, schema, planner_context)?; let filter: Option> = filter @@ -179,19 +191,22 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { Ok(Expr::ScalarFunction(ScalarFunction::new(fun, args))) } - pub(super) fn find_window_func(&self, name: &str) -> Result { - window_function::find_df_window_func(name) + pub(super) fn find_window_func( + &self, + name: &str, + ) -> Result { + expr::find_df_window_func(name) // next check user defined aggregates .or_else(|| { self.context_provider .get_aggregate_meta(name) - .map(WindowFunction::AggregateUDF) + .map(WindowFunctionDefinition::AggregateUDF) }) // next check user defined window functions .or_else(|| { self.context_provider .get_window_meta(name) - .map(WindowFunction::WindowUDF) + .map(WindowFunctionDefinition::WindowUDF) }) .ok_or_else(|| { plan_datafusion_err!("There is no window function named {name}") @@ -212,11 +227,13 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { FunctionArg::Named { name: _, arg: FunctionArgExpr::Wildcard, - } => Ok(Expr::Wildcard), + } => Ok(Expr::Wildcard { qualifier: None }), FunctionArg::Unnamed(FunctionArgExpr::Expr(arg)) => { self.sql_expr_to_logical_expr(arg, schema, planner_context) } - FunctionArg::Unnamed(FunctionArgExpr::Wildcard) => Ok(Expr::Wildcard), + FunctionArg::Unnamed(FunctionArgExpr::Wildcard) => { + Ok(Expr::Wildcard { qualifier: None }) + } _ => not_impl_err!("Unsupported qualified wildcard argument: {sql:?}"), } 
} diff --git a/datafusion/sql/src/expr/mod.rs b/datafusion/sql/src/expr/mod.rs index 1cf0fc133f040..27351e10eb34e 100644 --- a/datafusion/sql/src/expr/mod.rs +++ b/datafusion/sql/src/expr/mod.rs @@ -29,10 +29,12 @@ mod value; use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; use arrow_schema::DataType; +use arrow_schema::TimeUnit; use datafusion_common::{ internal_err, not_impl_err, plan_err, Column, DFSchema, DataFusionError, Result, ScalarValue, }; +use datafusion_expr::expr::AggregateFunctionDefinition; use datafusion_expr::expr::InList; use datafusion_expr::expr::ScalarFunction; use datafusion_expr::{ @@ -169,7 +171,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { Ok(Expr::ScalarFunction(ScalarFunction::new( BuiltinScalarFunction::DatePart, vec![ - Expr::Literal(ScalarValue::Utf8(Some(format!("{field}")))), + Expr::Literal(ScalarValue::from(format!("{field}"))), self.sql_expr_to_logical_expr(*expr, schema, planner_context)?, ], ))) @@ -224,14 +226,27 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { SQLExpr::Cast { expr, data_type, .. - } => Ok(Expr::Cast(Cast::new( - Box::new(self.sql_expr_to_logical_expr( - *expr, - schema, - planner_context, - )?), - self.convert_data_type(&data_type)?, - ))), + } => { + let dt = self.convert_data_type(&data_type)?; + let expr = + self.sql_expr_to_logical_expr(*expr, schema, planner_context)?; + + // numeric constants are treated as seconds (rather as nanoseconds) + // to align with postgres / duckdb semantics + let expr = match &dt { + DataType::Timestamp(TimeUnit::Nanosecond, tz) + if expr.get_type(schema)? == DataType::Int64 => + { + Expr::Cast(Cast::new( + Box::new(expr), + DataType::Timestamp(TimeUnit::Second, tz.clone()), + )) + } + _ => expr, + }; + + Ok(Expr::Cast(Cast::new(Box::new(expr), dt))) + } SQLExpr::TryCast { expr, data_type, .. @@ -459,7 +474,19 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { schema, planner_context, ), - + SQLExpr::Overlay { + expr, + overlay_what, + overlay_from, + overlay_for, + } => self.sql_overlay_to_expr( + *expr, + *overlay_what, + *overlay_from, + overlay_for, + schema, + planner_context, + ), SQLExpr::Nested(e) => { self.sql_expr_to_logical_expr(*e, schema, planner_context) } @@ -528,7 +555,12 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } = array_agg; let order_by = if let Some(order_by) = order_by { - Some(self.order_by_to_sort_expr(&order_by, input_schema, planner_context)?) + Some(self.order_by_to_sort_expr( + &order_by, + input_schema, + planner_context, + true, + )?) 
} else { None }; @@ -645,6 +677,32 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { Ok(Expr::ScalarFunction(ScalarFunction::new(fun, args))) } + fn sql_overlay_to_expr( + &self, + expr: SQLExpr, + overlay_what: SQLExpr, + overlay_from: SQLExpr, + overlay_for: Option>, + schema: &DFSchema, + planner_context: &mut PlannerContext, + ) -> Result { + let fun = BuiltinScalarFunction::OverLay; + let arg = self.sql_expr_to_logical_expr(expr, schema, planner_context)?; + let what_arg = + self.sql_expr_to_logical_expr(overlay_what, schema, planner_context)?; + let from_arg = + self.sql_expr_to_logical_expr(overlay_from, schema, planner_context)?; + let args = match overlay_for { + Some(for_expr) => { + let for_expr = + self.sql_expr_to_logical_expr(*for_expr, schema, planner_context)?; + vec![arg, what_arg, from_arg, for_expr] + } + None => vec![arg, what_arg, from_arg], + }; + Ok(Expr::ScalarFunction(ScalarFunction::new(fun, args))) + } + fn sql_agg_with_filter_to_expr( &self, expr: SQLExpr, @@ -654,7 +712,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { ) -> Result { match self.sql_expr_to_logical_expr(expr, schema, planner_context)? { Expr::AggregateFunction(expr::AggregateFunction { - fun, + func_def: AggregateFunctionDefinition::BuiltIn(fun), args, distinct, order_by, @@ -686,7 +744,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { SQLExpr::Value( Value::SingleQuotedString(s) | Value::DoubleQuotedString(s), ) => GetFieldAccess::NamedStructField { - name: ScalarValue::Utf8(Some(s)), + name: ScalarValue::from(s), }, SQLExpr::JsonAccess { left, diff --git a/datafusion/sql/src/expr/order_by.rs b/datafusion/sql/src/expr/order_by.rs index 1dccc2376f0b1..772255bd9773a 100644 --- a/datafusion/sql/src/expr/order_by.rs +++ b/datafusion/sql/src/expr/order_by.rs @@ -24,12 +24,17 @@ use datafusion_expr::Expr; use sqlparser::ast::{Expr as SQLExpr, OrderByExpr, Value}; impl<'a, S: ContextProvider> SqlToRel<'a, S> { - /// convert sql [OrderByExpr] to `Vec` + /// Convert sql [OrderByExpr] to `Vec`. + /// + /// If `literal_to_column` is true, treat any numeric literals (e.g. `2`) as a 1 based index + /// into the SELECT list (e.g. `SELECT a, b FROM table ORDER BY 2`). + /// If false, interpret numeric literals as constant values. pub(crate) fn order_by_to_sort_expr( &self, exprs: &[OrderByExpr], schema: &DFSchema, planner_context: &mut PlannerContext, + literal_to_column: bool, ) -> Result> { let mut expr_vec = vec![]; for e in exprs { @@ -40,7 +45,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } = e; let expr = match expr { - SQLExpr::Value(Value::Number(v, _)) => { + SQLExpr::Value(Value::Number(v, _)) if literal_to_column => { let field_index = v .parse::() .map_err(|err| plan_datafusion_err!("{}", err))?; diff --git a/datafusion/sql/src/expr/value.rs b/datafusion/sql/src/expr/value.rs index 3a06fdb158f76..9f88318ab21a6 100644 --- a/datafusion/sql/src/expr/value.rs +++ b/datafusion/sql/src/expr/value.rs @@ -16,20 +16,20 @@ // under the License. 
use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; -use arrow::array::new_null_array; use arrow::compute::kernels::cast_utils::parse_interval_month_day_nano; use arrow::datatypes::DECIMAL128_MAX_PRECISION; use arrow_schema::DataType; use datafusion_common::{ not_impl_err, plan_err, DFSchema, DataFusionError, Result, ScalarValue, }; +use datafusion_expr::expr::ScalarFunction; use datafusion_expr::expr::{BinaryExpr, Placeholder}; use datafusion_expr::{lit, Expr, Operator}; +use datafusion_expr::{BuiltinScalarFunction, ScalarFunctionDefinition}; use log::debug; use sqlparser::ast::{BinaryOperator, Expr as SQLExpr, Interval, Value}; use sqlparser::parser::ParserError::ParserError; use std::borrow::Cow; -use std::collections::HashSet; impl<'a, S: ContextProvider> SqlToRel<'a, S> { pub(crate) fn parse_value( @@ -108,7 +108,12 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } Ok(index) => index - 1, Err(_) => { - return plan_err!("Invalid placeholder, not a number: {param}"); + return if param_data_types.is_empty() { + Ok(Expr::Placeholder(Placeholder::new(param, None))) + } else { + // when PREPARE Statement, param_data_types length is always 0 + plan_err!("Invalid placeholder, not a number: {param}") + }; } }; // Check if the placeholder is in the parameter list @@ -138,9 +143,22 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { schema, &mut PlannerContext::new(), )?; + match value { - Expr::Literal(scalar) => { - values.push(scalar); + Expr::Literal(_) => { + values.push(value); + } + Expr::ScalarFunction(ScalarFunction { + func_def: ScalarFunctionDefinition::BuiltIn(fun), + .. + }) => { + if fun == BuiltinScalarFunction::MakeArray { + values.push(value); + } else { + return not_impl_err!( + "ScalarFunctions without MakeArray are not supported: {value}" + ); + } } _ => { return not_impl_err!( @@ -150,18 +168,10 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } } - let data_types: HashSet = - values.iter().map(|e| e.data_type()).collect(); - - if data_types.is_empty() { - Ok(lit(ScalarValue::List(new_null_array(&DataType::Null, 0)))) - } else if data_types.len() > 1 { - not_impl_err!("Arrays with different types are not supported: {data_types:?}") - } else { - let data_type = values[0].data_type(); - let arr = ScalarValue::new_list(&values, &data_type); - Ok(lit(ScalarValue::List(arr))) - } + Ok(Expr::ScalarFunction(ScalarFunction::new( + BuiltinScalarFunction::MakeArray, + values, + ))) } /// Convert a SQL interval expression to a DataFusion logical plan @@ -333,6 +343,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { // TODO make interval parsing better in arrow-rs / expose `IntervalType` fn has_units(val: &str) -> bool { + let val = val.to_lowercase(); val.ends_with("century") || val.ends_with("centuries") || val.ends_with("decade") diff --git a/datafusion/sql/src/parser.rs b/datafusion/sql/src/parser.rs index 9c104ff18a9b3..dbd72ec5eb7a8 100644 --- a/datafusion/sql/src/parser.rs +++ b/datafusion/sql/src/parser.rs @@ -213,13 +213,6 @@ impl fmt::Display for CreateExternalTable { } } -/// DataFusion extension DDL for `DESCRIBE TABLE` -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct DescribeTableStmt { - /// Table name - pub table_name: ObjectName, -} - /// DataFusion SQL Statement. 
 ///
 /// This can either be a [`Statement`] from [`sqlparser`] from a
@@ -233,8 +226,6 @@ pub enum Statement {
     Statement(Box<SQLStatement>),
     /// Extension: `CREATE EXTERNAL TABLE`
     CreateExternalTable(CreateExternalTable),
-    /// Extension: `DESCRIBE TABLE`
-    DescribeTableStmt(DescribeTableStmt),
     /// Extension: `COPY TO`
     CopyTo(CopyToStatement),
     /// EXPLAIN for extensions
@@ -246,7 +237,6 @@ impl fmt::Display for Statement {
         match self {
             Statement::Statement(stmt) => write!(f, "{stmt}"),
             Statement::CreateExternalTable(stmt) => write!(f, "{stmt}"),
-            Statement::DescribeTableStmt(_) => write!(f, "DESCRIBE TABLE ..."),
             Statement::CopyTo(stmt) => write!(f, "{stmt}"),
             Statement::Explain(stmt) => write!(f, "{stmt}"),
         }
@@ -345,10 +335,6 @@ impl<'a> DFParser<'a> {
                 self.parser.next_token(); // COPY
                 self.parse_copy()
             }
-            Keyword::DESCRIBE => {
-                self.parser.next_token(); // DESCRIBE
-                self.parse_describe()
-            }
             Keyword::EXPLAIN => {
                 // (TODO parse all supported statements)
                 self.parser.next_token(); // EXPLAIN
@@ -371,14 +357,6 @@ impl<'a> DFParser<'a> {
         }
     }
 
-    /// Parse a SQL `DESCRIBE` statement
-    pub fn parse_describe(&mut self) -> Result<Statement> {
-        let table_name = self.parser.parse_object_name()?;
-        Ok(Statement::DescribeTableStmt(DescribeTableStmt {
-            table_name,
-        }))
-    }
-
     /// Parse a SQL `COPY TO` statement
     pub fn parse_copy(&mut self) -> Result<Statement> {
         // parse as a query
diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs
index ca5e260aee050..c5c30e3a22536 100644
--- a/datafusion/sql/src/planner.rs
+++ b/datafusion/sql/src/planner.rs
@@ -21,8 +21,9 @@ use std::sync::Arc;
 use std::vec;
 
 use arrow_schema::*;
-use datafusion_common::field_not_found;
-use datafusion_common::internal_err;
+use datafusion_common::{
+    field_not_found, internal_err, plan_datafusion_err, SchemaError,
+};
 use datafusion_expr::WindowUDF;
 use sqlparser::ast::TimezoneInfo;
 use sqlparser::ast::{ArrayElemTypeDef, ExactNumberInfo};
@@ -51,6 +52,15 @@ pub trait ContextProvider {
     }
     /// Getter for a datasource
     fn get_table_source(&self, name: TableReference) -> Result<Arc<dyn TableSource>>;
+    /// Getter for a table function
+    fn get_table_function_source(
+        &self,
+        _name: &str,
+        _args: Vec<Expr>,
+    ) -> Result<Arc<dyn TableSource>> {
+        not_impl_err!("Table Functions are not supported")
+    }
+
     /// Getter for a UDF description
     fn get_function_meta(&self, name: &str) -> Option<Arc<ScalarUDF>>;
     /// Getter for a UDAF description
@@ -230,6 +240,42 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
         Ok(Schema::new(fields))
     }
 
+    /// Returns a vector of (column_name, default_expr) pairs
+    pub(super) fn build_column_defaults(
+        &self,
+        columns: &Vec<ColumnDef>,
+        planner_context: &mut PlannerContext,
+    ) -> Result<Vec<(String, Expr)>> {
+        let mut column_defaults = vec![];
+        // Default expressions are restricted, column references are not allowed
+        let empty_schema = DFSchema::empty();
+        let error_desc = |e: DataFusionError| match e {
+            DataFusionError::SchemaError(SchemaError::FieldNotFound { ..
}) => { + plan_datafusion_err!( + "Column reference is not allowed in the DEFAULT expression : {}", + e + ) + } + _ => e, + }; + + for column in columns { + if let Some(default_sql_expr) = + column.options.iter().find_map(|o| match &o.option { + ColumnOption::Default(expr) => Some(expr), + _ => None, + }) + { + let default_expr = self + .sql_to_expr(default_sql_expr.clone(), &empty_schema, planner_context) + .map_err(error_desc)?; + column_defaults + .push((self.normalizer.normalize(column.name.clone()), default_expr)); + } + } + Ok(column_defaults) + } + /// Apply the given TableAlias to the input plan pub(crate) fn apply_table_alias( &self, diff --git a/datafusion/sql/src/query.rs b/datafusion/sql/src/query.rs index fc2a3fb9a57b3..dd4cab126261e 100644 --- a/datafusion/sql/src/query.rs +++ b/datafusion/sql/src/query.rs @@ -23,7 +23,7 @@ use datafusion_common::{ not_impl_err, plan_err, sql_err, Constraints, DataFusionError, Result, ScalarValue, }; use datafusion_expr::{ - CreateMemoryTable, DdlStatement, Expr, LogicalPlan, LogicalPlanBuilder, + CreateMemoryTable, DdlStatement, Distinct, Expr, LogicalPlan, LogicalPlanBuilder, }; use sqlparser::ast::{ Expr as SQLExpr, Offset as SQLOffset, OrderByExpr, Query, SetExpr, Value, @@ -90,6 +90,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { input: Arc::new(plan), if_not_exists: false, or_replace: false, + column_defaults: vec![], })) } _ => plan, @@ -160,7 +161,15 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } let order_by_rex = - self.order_by_to_sort_expr(&order_by, plan.schema(), planner_context)?; - LogicalPlanBuilder::from(plan).sort(order_by_rex)?.build() + self.order_by_to_sort_expr(&order_by, plan.schema(), planner_context, true)?; + + if let LogicalPlan::Distinct(Distinct::On(ref distinct_on)) = plan { + // In case of `DISTINCT ON` we must capture the sort expressions since during the plan + // optimization we're effectively doing a `first_value` aggregation according to them. + let distinct_on = distinct_on.clone().with_sort_expr(order_by_rex)?; + Ok(LogicalPlan::Distinct(Distinct::On(distinct_on))) + } else { + LogicalPlanBuilder::from(plan).sort(order_by_rex)?.build() + } } } diff --git a/datafusion/sql/src/relation/mod.rs b/datafusion/sql/src/relation/mod.rs index 180743d19b7bd..b233f47a058fb 100644 --- a/datafusion/sql/src/relation/mod.rs +++ b/datafusion/sql/src/relation/mod.rs @@ -16,9 +16,11 @@ // under the License. use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; -use datafusion_common::{not_impl_err, DataFusionError, Result}; +use datafusion_common::{ + not_impl_err, plan_err, DFSchema, DataFusionError, Result, TableReference, +}; use datafusion_expr::{LogicalPlan, LogicalPlanBuilder}; -use sqlparser::ast::TableFactor; +use sqlparser::ast::{FunctionArg, FunctionArgExpr, TableFactor}; mod join; @@ -30,24 +32,58 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { planner_context: &mut PlannerContext, ) -> Result { let (plan, alias) = match relation { - TableFactor::Table { name, alias, .. } => { - // normalize name and alias - let table_ref = self.object_name_to_table_reference(name)?; - let table_name = table_ref.to_string(); - let cte = planner_context.get_cte(&table_name); - ( - match ( - cte, - self.context_provider.get_table_source(table_ref.clone()), - ) { - (Some(cte_plan), _) => Ok(cte_plan.clone()), - (_, Ok(provider)) => { - LogicalPlanBuilder::scan(table_ref, provider, None)?.build() - } - (None, Err(e)) => Err(e), - }?, - alias, - ) + TableFactor::Table { + name, alias, args, .. 
+ } => { + if let Some(func_args) = args { + let tbl_func_name = name.0.first().unwrap().value.to_string(); + let args = func_args + .into_iter() + .flat_map(|arg| { + if let FunctionArg::Unnamed(FunctionArgExpr::Expr(expr)) = arg + { + self.sql_expr_to_logical_expr( + expr, + &DFSchema::empty(), + planner_context, + ) + } else { + plan_err!("Unsupported function argument type: {:?}", arg) + } + }) + .collect::>(); + let provider = self + .context_provider + .get_table_function_source(&tbl_func_name, args)?; + let plan = LogicalPlanBuilder::scan( + TableReference::Bare { + table: std::borrow::Cow::Borrowed("tmp_table"), + }, + provider, + None, + )? + .build()?; + (plan, alias) + } else { + // normalize name and alias + let table_ref = self.object_name_to_table_reference(name)?; + let table_name = table_ref.to_string(); + let cte = planner_context.get_cte(&table_name); + ( + match ( + cte, + self.context_provider.get_table_source(table_ref.clone()), + ) { + (Some(cte_plan), _) => Ok(cte_plan.clone()), + (_, Ok(provider)) => { + LogicalPlanBuilder::scan(table_ref, provider, None)? + .build() + } + (None, Err(e)) => Err(e), + }?, + alias, + ) + } } TableFactor::Derived { subquery, alias, .. diff --git a/datafusion/sql/src/select.rs b/datafusion/sql/src/select.rs index 2062afabfc1a4..a0819e4aaf8e8 100644 --- a/datafusion/sql/src/select.rs +++ b/datafusion/sql/src/select.rs @@ -25,10 +25,7 @@ use crate::utils::{ }; use datafusion_common::Column; -use datafusion_common::{ - get_target_functional_dependencies, not_impl_err, plan_err, DFSchemaRef, - DataFusionError, Result, -}; +use datafusion_common::{not_impl_err, plan_err, DataFusionError, Result}; use datafusion_expr::expr::Alias; use datafusion_expr::expr_rewriter::{ normalize_col, normalize_col_with_schemas_and_ambiguity_check, @@ -76,7 +73,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let empty_from = matches!(plan, LogicalPlan::EmptyRelation(_)); // process `where` clause - let plan = self.plan_selection(select.selection, plan, planner_context)?; + let base_plan = self.plan_selection(select.selection, plan, planner_context)?; // handle named windows before processing the projection expression check_conflicting_windows(&select.named_window)?; @@ -84,16 +81,16 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { // process the SELECT expressions, with wildcards expanded. 
let select_exprs = self.prepare_select_exprs( - &plan, + &base_plan, select.projection, empty_from, planner_context, )?; // having and group by clause may reference aliases defined in select projection - let projected_plan = self.project(plan.clone(), select_exprs.clone())?; + let projected_plan = self.project(base_plan.clone(), select_exprs.clone())?; let mut combined_schema = (**projected_plan.schema()).clone(); - combined_schema.merge(plan.schema()); + combined_schema.merge(base_plan.schema()); // this alias map is resolved and looked up in both having exprs and group by exprs let alias_map = extract_aliases(&select_exprs); @@ -148,7 +145,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { )?; // aliases from the projection can conflict with same-named expressions in the input let mut alias_map = alias_map.clone(); - for f in plan.schema().fields() { + for f in base_plan.schema().fields() { alias_map.remove(f.name()); } let group_by_expr = @@ -158,7 +155,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { .unwrap_or(group_by_expr); let group_by_expr = normalize_col(group_by_expr, &projected_plan)?; self.validate_schema_satisfies_exprs( - plan.schema(), + base_plan.schema(), &[group_by_expr.clone()], )?; Ok(group_by_expr) @@ -170,11 +167,10 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { select_exprs .iter() .filter(|select_expr| match select_expr { - Expr::AggregateFunction(_) | Expr::AggregateUDF(_) => false, - Expr::Alias(Alias { expr, name: _ }) => !matches!( - **expr, - Expr::AggregateFunction(_) | Expr::AggregateUDF(_) - ), + Expr::AggregateFunction(_) => false, + Expr::Alias(Alias { expr, name: _, .. }) => { + !matches!(**expr, Expr::AggregateFunction(_)) + } _ => true, }) .cloned() @@ -187,16 +183,16 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { || !aggr_exprs.is_empty() { self.aggregate( - plan, + &base_plan, &select_exprs, having_expr_opt.as_ref(), - group_by_exprs, - aggr_exprs, + &group_by_exprs, + &aggr_exprs, )? } else { match having_expr_opt { Some(having_expr) => return plan_err!("HAVING clause references: {having_expr} must appear in the GROUP BY clause or be used in an aggregate function"), - None => (plan, select_exprs, having_expr_opt) + None => (base_plan.clone(), select_exprs.clone(), having_expr_opt) } }; @@ -229,19 +225,31 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let plan = project(plan, select_exprs_post_aggr)?; // process distinct clause - let distinct = select - .distinct - .map(|distinct| match distinct { - Distinct::Distinct => Ok(true), - Distinct::On(_) => not_impl_err!("DISTINCT ON Exprs not supported"), - }) - .transpose()? - .unwrap_or(false); + let plan = match select.distinct { + None => Ok(plan), + Some(Distinct::Distinct) => { + LogicalPlanBuilder::from(plan).distinct()?.build() + } + Some(Distinct::On(on_expr)) => { + if !aggr_exprs.is_empty() + || !group_by_exprs.is_empty() + || !window_func_exprs.is_empty() + { + return not_impl_err!("DISTINCT ON expressions with GROUP BY, aggregation or window functions are not supported "); + } - let plan = if distinct { - LogicalPlanBuilder::from(plan).distinct()?.build() - } else { - Ok(plan) + let on_expr = on_expr + .into_iter() + .map(|e| { + self.sql_expr_to_logical_expr(e, plan.schema(), planner_context) + }) + .collect::>>()?; + + // Build the final plan + return LogicalPlanBuilder::from(base_plan) + .distinct_on(on_expr, select_exprs, None)? 
+ .build(); + } }?; // DISTRIBUTE BY @@ -373,7 +381,12 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { &[&[plan.schema()]], &plan.using_columns()?, )?; - let expr = Expr::Alias(Alias::new(col, self.normalizer.normalize(alias))); + let name = self.normalizer.normalize(alias); + // avoiding adding an alias if the column name is the same. + let expr = match &col { + Expr::Column(column) if column.name.eq(&name) => col, + _ => col.alias(name), + }; Ok(vec![expr]) } SelectItem::Wildcard(options) => { @@ -471,6 +484,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { .clone(); *expr = Expr::Alias(Alias { expr: Box::new(new_expr), + relation: None, name: name.clone(), }); } @@ -511,20 +525,23 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { /// the aggregate fn aggregate( &self, - input: LogicalPlan, + input: &LogicalPlan, select_exprs: &[Expr], having_expr_opt: Option<&Expr>, - group_by_exprs: Vec, - aggr_exprs: Vec, + group_by_exprs: &[Expr], + aggr_exprs: &[Expr], ) -> Result<(LogicalPlan, Vec, Option)> { - let group_by_exprs = - get_updated_group_by_exprs(&group_by_exprs, select_exprs, input.schema())?; - // create the aggregate plan let plan = LogicalPlanBuilder::from(input.clone()) - .aggregate(group_by_exprs.clone(), aggr_exprs.clone())? + .aggregate(group_by_exprs.to_vec(), aggr_exprs.to_vec())? .build()?; + let group_by_exprs = if let LogicalPlan::Aggregate(agg) = &plan { + &agg.group_expr + } else { + unreachable!(); + }; + // in this next section of code we are re-writing the projection to refer to columns // output by the aggregate plan. For example, if the projection contains the expression // `SUM(a)` then we replace that with a reference to a column `SUM(a)` produced by @@ -533,7 +550,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { // combine the original grouping and aggregate expressions into one list (note that // we do not add the "having" expression since that is not part of the projection) let mut aggr_projection_exprs = vec![]; - for expr in &group_by_exprs { + for expr in group_by_exprs { match expr { Expr::GroupingSet(GroupingSet::Rollup(exprs)) => { aggr_projection_exprs.extend_from_slice(exprs) @@ -549,25 +566,25 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { _ => aggr_projection_exprs.push(expr.clone()), } } - aggr_projection_exprs.extend_from_slice(&aggr_exprs); + aggr_projection_exprs.extend_from_slice(aggr_exprs); // now attempt to resolve columns and replace with fully-qualified columns let aggr_projection_exprs = aggr_projection_exprs .iter() - .map(|expr| resolve_columns(expr, &input)) + .map(|expr| resolve_columns(expr, input)) .collect::>>()?; // next we replace any expressions that are not a column with a column referencing // an output column from the aggregate schema let column_exprs_post_aggr = aggr_projection_exprs .iter() - .map(|expr| expr_as_column_expr(expr, &input)) + .map(|expr| expr_as_column_expr(expr, input)) .collect::>>()?; // next we re-write the projection let select_exprs_post_aggr = select_exprs .iter() - .map(|expr| rebase_expr(expr, &aggr_projection_exprs, &input)) + .map(|expr| rebase_expr(expr, &aggr_projection_exprs, input)) .collect::>>()?; // finally, we have some validation that the re-written projection can be resolved @@ -582,7 +599,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { // aggregation. 
let having_expr_post_aggr = if let Some(having_expr) = having_expr_opt { let having_expr_post_aggr = - rebase_expr(having_expr, &aggr_projection_exprs, &input)?; + rebase_expr(having_expr, &aggr_projection_exprs, input)?; check_columns_satisfy_exprs( &column_exprs_post_aggr, @@ -642,61 +659,3 @@ fn match_window_definitions( } Ok(()) } - -/// Update group by exprs, according to functional dependencies -/// The query below -/// -/// SELECT sn, amount -/// FROM sales_global -/// GROUP BY sn -/// -/// cannot be calculated, because it has a column(`amount`) which is not -/// part of group by expression. -/// However, if we know that, `sn` is determinant of `amount`. We can -/// safely, determine value of `amount` for each distinct `sn`. For these cases -/// we rewrite the query above as -/// -/// SELECT sn, amount -/// FROM sales_global -/// GROUP BY sn, amount -/// -/// Both queries, are functionally same. \[Because, (`sn`, `amount`) and (`sn`) -/// defines the identical groups. \] -/// This function updates group by expressions such that select expressions that are -/// not in group by expression, are added to the group by expressions if they are dependent -/// of the sub-set of group by expressions. -fn get_updated_group_by_exprs( - group_by_exprs: &[Expr], - select_exprs: &[Expr], - schema: &DFSchemaRef, -) -> Result> { - let mut new_group_by_exprs = group_by_exprs.to_vec(); - let fields = schema.fields(); - let group_by_expr_names = group_by_exprs - .iter() - .map(|group_by_expr| group_by_expr.display_name()) - .collect::>>()?; - // Get targets that can be used in a select, even if they do not occur in aggregation: - if let Some(target_indices) = - get_target_functional_dependencies(schema, &group_by_expr_names) - { - // Calculate dependent fields names with determinant GROUP BY expression: - let associated_field_names = target_indices - .iter() - .map(|idx| fields[*idx].qualified_name()) - .collect::>(); - // Expand GROUP BY expressions with select expressions: If a GROUP - // BY expression is a determinant key, we can use its dependent - // columns in select statements also. 
- for expr in select_exprs { - let expr_name = format!("{}", expr); - if !new_group_by_exprs.contains(expr) - && associated_field_names.contains(&expr_name) - { - new_group_by_exprs.push(expr.clone()); - } - } - } - - Ok(new_group_by_exprs) -} diff --git a/datafusion/sql/src/statement.rs b/datafusion/sql/src/statement.rs index 9d9c55361a5e9..b96553ffbf860 100644 --- a/datafusion/sql/src/statement.rs +++ b/datafusion/sql/src/statement.rs @@ -19,8 +19,8 @@ use std::collections::{BTreeMap, HashMap, HashSet}; use std::sync::Arc; use crate::parser::{ - CopyToSource, CopyToStatement, CreateExternalTable, DFParser, DescribeTableStmt, - ExplainStatement, LexOrdering, Statement as DFStatement, + CopyToSource, CopyToStatement, CreateExternalTable, DFParser, ExplainStatement, + LexOrdering, Statement as DFStatement, }; use crate::planner::{ object_name_to_qualifier, ContextProvider, PlannerContext, SqlToRel, @@ -31,9 +31,9 @@ use arrow_schema::DataType; use datafusion_common::file_options::StatementOptions; use datafusion_common::parsers::CompressionTypeVariant; use datafusion_common::{ - not_impl_err, plan_datafusion_err, plan_err, unqualified_field_not_found, + not_impl_err, plan_datafusion_err, plan_err, unqualified_field_not_found, Column, Constraints, DFField, DFSchema, DFSchemaRef, DataFusionError, OwnedTableReference, - Result, SchemaReference, TableReference, ToDFSchema, + Result, ScalarValue, SchemaReference, TableReference, ToDFSchema, }; use datafusion_expr::dml::{CopyOptions, CopyTo}; use datafusion_expr::expr_rewriter::normalize_col_with_schemas_and_ambiguity_check; @@ -136,7 +136,6 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { match statement { DFStatement::CreateExternalTable(s) => self.external_table_to_plan(s), DFStatement::Statement(s) => self.sql_statement_to_plan(*s), - DFStatement::DescribeTableStmt(s) => self.describe_table_to_plan(s), DFStatement::CopyTo(s) => self.copy_to_plan(s), DFStatement::Explain(ExplainStatement { verbose, @@ -170,6 +169,10 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { ) -> Result { let sql = Some(statement.to_string()); match statement { + Statement::ExplainTable { + describe_alias: true, // only parse 'DESCRIBE table_name' and not 'EXPLAIN table_name' + table_name, + } => self.describe_table_to_plan(table_name), Statement::Explain { verbose, statement, @@ -204,6 +207,9 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let mut all_constraints = constraints; let inline_constraints = calc_inline_constraints_from_columns(&columns); all_constraints.extend(inline_constraints); + // Build column default values + let column_defaults = + self.build_column_defaults(&columns, planner_context)?; match query { Some(query) => { let plan = self.query_to_plan(*query, planner_context)?; @@ -250,6 +256,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { input: Arc::new(plan), if_not_exists, or_replace, + column_defaults, }, ))) } @@ -272,6 +279,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { input: Arc::new(plan), if_not_exists, or_replace, + column_defaults, }, ))) } @@ -453,6 +461,9 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { if ignore { plan_err!("Insert-ignore clause not supported")?; } + let Some(source) = source else { + plan_err!("Inserts without a source not supported")? 
+ }; let _ = into; // optional keyword doesn't change behavior self.insert_to_plan(table_name, columns, source, overwrite) } @@ -505,7 +516,13 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { Statement::StartTransaction { modes, begin: false, + modifier, } => { + if let Some(modifier) = modifier { + return not_impl_err!( + "Transaction modifier not supported: {modifier}" + ); + } let isolation_level: ast::TransactionIsolationLevel = modes .iter() .filter_map(|m: &ast::TransactionMode| match m { @@ -561,7 +578,10 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { }); Ok(LogicalPlan::Statement(statement)) } - Statement::Rollback { chain } => { + Statement::Rollback { chain, savepoint } => { + if savepoint.is_some() { + plan_err!("Savepoints not supported")?; + } let statement = PlanStatement::TransactionEnd(TransactionEnd { conclusion: TransactionConclusion::Rollback, chain, @@ -618,11 +638,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } } - fn describe_table_to_plan( - &self, - statement: DescribeTableStmt, - ) -> Result { - let DescribeTableStmt { table_name } = statement; + fn describe_table_to_plan(&self, table_name: ObjectName) -> Result { let table_ref = self.object_name_to_table_reference(table_name)?; let table_source = self.context_provider.get_table_source(table_ref)?; @@ -699,7 +715,8 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let mut all_results = vec![]; for expr in order_exprs { // Convert each OrderByExpr to a SortExpr: - let expr_vec = self.order_by_to_sort_expr(&expr, schema, planner_context)?; + let expr_vec = + self.order_by_to_sort_expr(&expr, schema, planner_context, true)?; // Verify that columns of all SortExprs exist in the schema: for expr in expr_vec.iter() { for column in expr.to_columns()?.iter() { @@ -750,11 +767,18 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { )?; } + let mut planner_context = PlannerContext::new(); + + let column_defaults = self + .build_column_defaults(&columns, &mut planner_context)? + .into_iter() + .collect(); + let schema = self.build_schema(columns)?; let df_schema = schema.to_dfschema_ref()?; let ordered_exprs = - self.build_order_by(order_exprs, &df_schema, &mut PlannerContext::new())?; + self.build_order_by(order_exprs, &df_schema, &mut planner_context)?; // External tables do not support schemas at the moment, so the name is just a table name let name = OwnedTableReference::bare(name); @@ -776,6 +800,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { unbounded, options, constraints, + column_defaults, }, ))) } @@ -970,8 +995,8 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { from: Option, predicate_expr: Option, ) -> Result { - let table_name = match &table.relation { - TableFactor::Table { name, .. } => name.clone(), + let (table_name, table_alias) = match &table.relation { + TableFactor::Table { name, alias, .. } => (name.clone(), alias.clone()), _ => plan_err!("Cannot update non-table relation!")?, }; @@ -1017,7 +1042,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { expr_to_columns(&filter_expr, &mut using_columns)?; let filter_expr = normalize_col_with_schemas_and_ambiguity_check( filter_expr, - &[&[&scan.schema()]], + &[&[scan.schema()]], &[using_columns], )?; LogicalPlan::Filter(Filter::try_new(filter_expr, Arc::new(scan))?) @@ -1047,7 +1072,17 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { // Cast to target column type, if necessary expr.cast_to(field.data_type(), source.schema())? 
} - None => datafusion_expr::Expr::Column(field.qualified_column()), + None => { + // If the target table has an alias, use it to qualify the column name + if let Some(alias) = &table_alias { + datafusion_expr::Expr::Column(Column::new( + Some(self.normalizer.normalize(alias.name.clone())), + field.name(), + )) + } else { + datafusion_expr::Expr::Column(field.qualified_column()) + } + } }; Ok(expr.alias(field.name())) }) @@ -1077,9 +1112,14 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let arrow_schema = (*table_source.schema()).clone(); let table_schema = DFSchema::try_from(arrow_schema)?; - // Get insert fields and index_mapping - // The i-th field of the table is `fields[index_mapping[i]]` - let (fields, index_mapping) = if columns.is_empty() { + // Get insert fields and target table's value indices + // + // if value_indices[i] = Some(j), it means that the value of the i-th target table's column is + // derived from the j-th output of the source. + // + // if value_indices[i] = None, it means that the value of the i-th target table's column is + // not provided, and should be filled with a default value later. + let (fields, value_indices) = if columns.is_empty() { // Empty means we're inserting into all columns of the table ( table_schema.fields().clone(), @@ -1088,7 +1128,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { .collect::>(), ) } else { - let mut mapping = vec![None; table_schema.fields().len()]; + let mut value_indices = vec![None; table_schema.fields().len()]; let fields = columns .into_iter() .map(|c| self.normalizer.normalize(c)) @@ -1097,19 +1137,19 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let column_index = table_schema .index_of_column_by_name(None, &c)? .ok_or_else(|| unqualified_field_not_found(&c, &table_schema))?; - if mapping[column_index].is_some() { + if value_indices[column_index].is_some() { return Err(DataFusionError::SchemaError( datafusion_common::SchemaError::DuplicateUnqualifiedField { name: c, }, )); } else { - mapping[column_index] = Some(i); + value_indices[column_index] = Some(i); } Ok(table_schema.field(column_index).clone()) }) .collect::>>()?; - (fields, mapping) + (fields, value_indices) }; // infer types for Values clause... other types should be resolvable the regular way @@ -1144,17 +1184,28 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { plan_err!("Column count doesn't match insert query!")?; } - let exprs = index_mapping + let exprs = value_indices .into_iter() - .flatten() - .map(|i| { - let target_field = &fields[i]; - let source_field = source.schema().field(i); - let expr = - datafusion_expr::Expr::Column(source_field.unqualified_column()) - .cast_to(target_field.data_type(), source.schema())? - .alias(target_field.name()); - Ok(expr) + .enumerate() + .map(|(i, value_index)| { + let target_field = table_schema.field(i); + let expr = match value_index { + Some(v) => { + let source_field = source.schema().field(v); + datafusion_expr::Expr::Column(source_field.qualified_column()) + .cast_to(target_field.data_type(), source.schema())? + } + // The value is not specified. Fill in the default value for the column. 
+ None => table_source + .get_column_default(target_field.name()) + .cloned() + .unwrap_or_else(|| { + // If there is no default for the column, then the default is NULL + datafusion_expr::Expr::Literal(ScalarValue::Null) + }) + .cast_to(target_field.data_type(), &DFSchema::empty())?, + }; + Ok(expr.alias(target_field.name())) }) .collect::>>()?; let source = project(source, exprs)?; diff --git a/datafusion/sql/tests/sql_integration.rs b/datafusion/sql/tests/sql_integration.rs index ff6dca7eef2a8..48ba50145308c 100644 --- a/datafusion/sql/tests/sql_integration.rs +++ b/datafusion/sql/tests/sql_integration.rs @@ -22,11 +22,11 @@ use std::{sync::Arc, vec}; use arrow_schema::*; use sqlparser::dialect::{Dialect, GenericDialect, HiveDialect, MySqlDialect}; -use datafusion_common::plan_err; use datafusion_common::{ assert_contains, config::ConfigOptions, DataFusionError, Result, ScalarValue, TableReference, }; +use datafusion_common::{plan_err, ParamValues}; use datafusion_expr::{ logical_plan::{LogicalPlan, Prepare}, AggregateUDF, ScalarUDF, TableSource, WindowUDF, @@ -422,12 +422,11 @@ CopyTo: format=csv output_url=output.csv single_file_output=true options: () fn plan_insert() { let sql = "insert into person (id, first_name, last_name) values (1, 'Alan', 'Turing')"; - let plan = r#" -Dml: op=[Insert Into] table=[person] - Projection: CAST(column1 AS UInt32) AS id, column2 AS first_name, column3 AS last_name - Values: (Int64(1), Utf8("Alan"), Utf8("Turing")) - "# - .trim(); + let plan = "Dml: op=[Insert Into] table=[person]\ + \n Projection: CAST(column1 AS UInt32) AS id, column2 AS first_name, column3 AS last_name, \ + CAST(NULL AS Int32) AS age, CAST(NULL AS Utf8) AS state, CAST(NULL AS Float64) AS salary, \ + CAST(NULL AS Timestamp(Nanosecond, None)) AS birth_date, CAST(NULL AS Int32) AS 😀\ + \n Values: (Int64(1), Utf8(\"Alan\"), Utf8(\"Turing\"))"; quick_test(sql, plan); } @@ -472,6 +471,10 @@ Dml: op=[Insert Into] table=[test_decimal] "INSERT INTO person (id, first_name, last_name) VALUES ($2, $4, $6)", "Error during planning: Placeholder type could not be resolved" )] +#[case::placeholder_type_unresolved( + "INSERT INTO person (id, first_name, last_name) VALUES ($id, $first_name, $last_name)", + "Error during planning: Can't parse placeholder: $id" +)] #[test] fn test_insert_schema_errors(#[case] sql: &str, #[case] error: &str) { let err = logical_plan(sql).unwrap_err(); @@ -607,11 +610,9 @@ fn select_compound_filter() { #[test] fn test_timestamp_filter() { let sql = "SELECT state FROM person WHERE birth_date < CAST (158412331400600000 as timestamp)"; - let expected = "Projection: person.state\ - \n Filter: person.birth_date < CAST(Int64(158412331400600000) AS Timestamp(Nanosecond, None))\ + \n Filter: person.birth_date < CAST(CAST(Int64(158412331400600000) AS Timestamp(Second, None)) AS Timestamp(Nanosecond, None))\ \n TableScan: person"; - quick_test(sql, expected); } @@ -1384,18 +1385,6 @@ fn select_interval_out_of_range() { ); } -#[test] -fn select_array_no_common_type() { - let sql = "SELECT [1, true, null]"; - let err = logical_plan(sql).expect_err("query should have failed"); - - // HashSet doesn't guarantee order - assert_contains!( - err.strip_backtrace(), - "This feature is not implemented: Arrays with different types are not supported: " - ); -} - #[test] fn recursive_ctes() { let sql = " @@ -1412,16 +1401,6 @@ fn recursive_ctes() { ); } -#[test] -fn select_array_non_literal_type() { - let sql = "SELECT [now()]"; - let err = logical_plan(sql).expect_err("query should have 
failed"); - assert_eq!( - "This feature is not implemented: Arrays with elements other than literal are not supported: now()", - err.strip_backtrace() - ); -} - #[test] fn select_simple_aggregate_with_groupby_and_column_is_in_aggregate_and_groupby() { quick_test( @@ -2699,7 +2678,7 @@ fn prepare_stmt_quick_test( fn prepare_stmt_replace_params_quick_test( plan: LogicalPlan, - param_values: Vec, + param_values: impl Into, expected_plan: &str, ) -> LogicalPlan { // replace params @@ -3567,13 +3546,24 @@ fn test_select_unsupported_syntax_errors(#[case] sql: &str, #[case] error: &str) fn select_order_by_with_cast() { let sql = "SELECT first_name AS first_name FROM (SELECT first_name AS first_name FROM person) ORDER BY CAST(first_name as INT)"; - let expected = "Sort: CAST(first_name AS first_name AS Int32) ASC NULLS LAST\ - \n Projection: first_name AS first_name\ - \n Projection: person.first_name AS first_name\ + let expected = "Sort: CAST(person.first_name AS Int32) ASC NULLS LAST\ + \n Projection: person.first_name\ + \n Projection: person.first_name\ \n TableScan: person"; quick_test(sql, expected); } +#[test] +fn test_avoid_add_alias() { + // avoiding adding an alias if the column name is the same. + // plan1 = plan2 + let sql = "select person.id as id from person order by person.id"; + let plan1 = logical_plan(sql).unwrap(); + let sql = "select id from person order by id"; + let plan2 = logical_plan(sql).unwrap(); + assert_eq!(format!("{plan1:?}"), format!("{plan2:?}")); +} + #[test] fn test_duplicated_left_join_key_inner_join() { // person.id * 2 happen twice in left side. @@ -3751,7 +3741,7 @@ fn test_prepare_statement_to_plan_no_param() { /////////////////// // replace params with values - let param_values = vec![]; + let param_values: Vec = vec![]; let expected_plan = "Projection: person.id, person.age\ \n Filter: person.age = Int64(10)\ \n TableScan: person"; @@ -3765,7 +3755,7 @@ fn test_prepare_statement_to_plan_one_param_no_value_panic() { let sql = "PREPARE my_plan(INT) AS SELECT id, age FROM person WHERE age = 10"; let plan = logical_plan(sql).unwrap(); // declare 1 param but provide 0 - let param_values = vec![]; + let param_values: Vec = vec![]; assert_eq!( plan.with_param_values(param_values) .unwrap_err() @@ -3878,7 +3868,7 @@ Projection: person.id, orders.order_id assert_eq!(actual_types, expected_types); // replace params with values - let param_values = vec![ScalarValue::Int32(Some(10))]; + let param_values = vec![ScalarValue::Int32(Some(10))].into(); let expected_plan = r#" Projection: person.id, orders.order_id Inner Join: Filter: person.id = orders.customer_id AND person.age = Int32(10) @@ -3910,7 +3900,7 @@ Projection: person.id, person.age assert_eq!(actual_types, expected_types); // replace params with values - let param_values = vec![ScalarValue::Int32(Some(10))]; + let param_values = vec![ScalarValue::Int32(Some(10))].into(); let expected_plan = r#" Projection: person.id, person.age Filter: person.age = Int32(10) @@ -3944,7 +3934,8 @@ Projection: person.id, person.age assert_eq!(actual_types, expected_types); // replace params with values - let param_values = vec![ScalarValue::Int32(Some(10)), ScalarValue::Int32(Some(30))]; + let param_values = + vec![ScalarValue::Int32(Some(10)), ScalarValue::Int32(Some(30))].into(); let expected_plan = r#" Projection: person.id, person.age Filter: person.age BETWEEN Int32(10) AND Int32(30) @@ -3980,7 +3971,7 @@ Projection: person.id, person.age assert_eq!(actual_types, expected_types); // replace params with values - let 
param_values = vec![ScalarValue::UInt32(Some(10))]; + let param_values = vec![ScalarValue::UInt32(Some(10))].into(); let expected_plan = r#" Projection: person.id, person.age Filter: person.age = () @@ -4020,7 +4011,8 @@ Dml: op=[Update] table=[person] assert_eq!(actual_types, expected_types); // replace params with values - let param_values = vec![ScalarValue::Int32(Some(42)), ScalarValue::UInt32(Some(1))]; + let param_values = + vec![ScalarValue::Int32(Some(42)), ScalarValue::UInt32(Some(1))].into(); let expected_plan = r#" Dml: op=[Update] table=[person] Projection: person.id AS id, person.first_name AS first_name, person.last_name AS last_name, Int32(42) AS age, person.state AS state, person.salary AS salary, person.birth_date AS birth_date, person.😀 AS 😀 @@ -4037,12 +4029,11 @@ Dml: op=[Update] table=[person] fn test_prepare_statement_insert_infer() { let sql = "insert into person (id, first_name, last_name) values ($1, $2, $3)"; - let expected_plan = r#" -Dml: op=[Insert Into] table=[person] - Projection: column1 AS id, column2 AS first_name, column3 AS last_name - Values: ($1, $2, $3) - "# - .trim(); + let expected_plan = "Dml: op=[Insert Into] table=[person]\ + \n Projection: column1 AS id, column2 AS first_name, column3 AS last_name, \ + CAST(NULL AS Int32) AS age, CAST(NULL AS Utf8) AS state, CAST(NULL AS Float64) AS salary, \ + CAST(NULL AS Timestamp(Nanosecond, None)) AS birth_date, CAST(NULL AS Int32) AS 😀\ + \n Values: ($1, $2, $3)"; let expected_dt = "[Int32]"; let plan = prepare_stmt_quick_test(sql, expected_plan, expected_dt); @@ -4058,15 +4049,15 @@ Dml: op=[Insert Into] table=[person] // replace params with values let param_values = vec![ ScalarValue::UInt32(Some(1)), - ScalarValue::Utf8(Some("Alan".to_string())), - ScalarValue::Utf8(Some("Turing".to_string())), - ]; - let expected_plan = r#" -Dml: op=[Insert Into] table=[person] - Projection: column1 AS id, column2 AS first_name, column3 AS last_name - Values: (UInt32(1), Utf8("Alan"), Utf8("Turing")) - "# - .trim(); + ScalarValue::from("Alan"), + ScalarValue::from("Turing"), + ] + .into(); + let expected_plan = "Dml: op=[Insert Into] table=[person]\ + \n Projection: column1 AS id, column2 AS first_name, column3 AS last_name, \ + CAST(NULL AS Int32) AS age, CAST(NULL AS Utf8) AS state, CAST(NULL AS Float64) AS salary, \ + CAST(NULL AS Timestamp(Nanosecond, None)) AS birth_date, CAST(NULL AS Int32) AS 😀\ + \n Values: (UInt32(1), Utf8(\"Alan\"), Utf8(\"Turing\"))"; let plan = plan.replace_params_with_values(¶m_values).unwrap(); prepare_stmt_replace_params_quick_test(plan, param_values, expected_plan); @@ -4140,11 +4131,11 @@ fn test_prepare_statement_to_plan_multi_params() { // replace params with values let param_values = vec![ ScalarValue::Int32(Some(10)), - ScalarValue::Utf8(Some("abc".to_string())), + ScalarValue::from("abc"), ScalarValue::Float64(Some(100.0)), ScalarValue::Int32(Some(20)), ScalarValue::Float64(Some(200.0)), - ScalarValue::Utf8(Some("xyz".to_string())), + ScalarValue::from("xyz"), ]; let expected_plan = "Projection: person.id, person.age, Utf8(\"xyz\")\ @@ -4210,8 +4201,8 @@ fn test_prepare_statement_to_plan_value_list() { /////////////////// // replace params with values let param_values = vec![ - ScalarValue::Utf8(Some("a".to_string())), - ScalarValue::Utf8(Some("b".to_string())), + ScalarValue::from("a".to_string()), + ScalarValue::from("b".to_string()), ]; let expected_plan = "Projection: t.num, t.letter\ \n SubqueryAlias: t\ diff --git a/datafusion/sqllogictest/Cargo.toml 
b/datafusion/sqllogictest/Cargo.toml index d27e88274f8fd..e333dc816f666 100644 --- a/datafusion/sqllogictest/Cargo.toml +++ b/datafusion/sqllogictest/Cargo.toml @@ -36,7 +36,7 @@ async-trait = { workspace = true } bigdecimal = { workspace = true } bytes = { version = "1.4.0", optional = true } chrono = { workspace = true, optional = true } -datafusion = { path = "../core", version = "33.0.0" } +datafusion = { path = "../core", version = "34.0.0" } datafusion-common = { workspace = true } futures = { version = "0.3.28" } half = { workspace = true } @@ -46,7 +46,7 @@ object_store = { workspace = true } postgres-protocol = { version = "0.6.4", optional = true } postgres-types = { version = "0.2.4", optional = true } rust_decimal = { version = "1.27.0" } -sqllogictest = "0.17.0" +sqllogictest = "0.19.0" sqlparser = { workspace = true } tempfile = { workspace = true } thiserror = { workspace = true } diff --git a/datafusion/sqllogictest/README.md b/datafusion/sqllogictest/README.md index 0349ed852f468..bda00a2dce0f8 100644 --- a/datafusion/sqllogictest/README.md +++ b/datafusion/sqllogictest/README.md @@ -240,7 +240,7 @@ query - NULL values are rendered as `NULL`, - empty strings are rendered as `(empty)`, - boolean values are rendered as `true`/`false`, - - this list can be not exhaustive, check the `datafusion/core/tests/sqllogictests/src/engines/conversion.rs` for + - this list can be not exhaustive, check the `datafusion/sqllogictest/src/engines/conversion.rs` for details. - `sort_mode`: If included, it must be one of `nosort` (**default**), `rowsort`, or `valuesort`. In `nosort` mode, the results appear in exactly the order in which they were received from the database engine. The `nosort` mode should diff --git a/datafusion/sqllogictest/bin/sqllogictests.rs b/datafusion/sqllogictest/bin/sqllogictests.rs index 618e3106c6292..aeb1cc4ec9195 100644 --- a/datafusion/sqllogictest/bin/sqllogictests.rs +++ b/datafusion/sqllogictest/bin/sqllogictests.rs @@ -26,7 +26,7 @@ use futures::stream::StreamExt; use log::info; use sqllogictest::strict_column_validator; -use datafusion_common::{exec_err, DataFusionError, Result}; +use datafusion_common::{exec_datafusion_err, exec_err, DataFusionError, Result}; const TEST_DIRECTORY: &str = "test_files/"; const PG_COMPAT_FILE_PREFIX: &str = "pg_compat_"; @@ -84,7 +84,7 @@ async fn run_tests() -> Result<()> { // Doing so is safe because each slt file runs with its own // `SessionContext` and should not have side effects (like // modifying shared state like `/tmp/`) - let errors: Vec<_> = futures::stream::iter(read_test_files(&options)) + let errors: Vec<_> = futures::stream::iter(read_test_files(&options)?) 
.map(|test_file| { tokio::task::spawn(async move { println!("Running {:?}", test_file.relative_path); @@ -159,6 +159,7 @@ async fn run_test_file_with_postgres(test_file: TestFile) -> Result<()> { relative_path, } = test_file; info!("Running with Postgres runner: {}", path.display()); + setup_scratch_dir(&relative_path)?; let mut runner = sqllogictest::Runner::new(|| Postgres::connect(relative_path.clone())); runner.with_column_validator(strict_column_validator); @@ -188,6 +189,7 @@ async fn run_complete_file(test_file: TestFile) -> Result<()> { info!("Skipping: {}", path.display()); return Ok(()); }; + setup_scratch_dir(&relative_path)?; let mut runner = sqllogictest::Runner::new(|| async { Ok(DataFusion::new( test_ctx.session_ctx().clone(), @@ -245,30 +247,45 @@ impl TestFile { } } -fn read_test_files<'a>(options: &'a Options) -> Box<dyn Iterator<Item = TestFile> + 'a> { - Box::new( - read_dir_recursive(TEST_DIRECTORY) +fn read_test_files<'a>( + options: &'a Options, +) -> Result<Box<dyn Iterator<Item = TestFile> + 'a>> { + Ok(Box::new( + read_dir_recursive(TEST_DIRECTORY)? + .into_iter() .map(TestFile::new) .filter(|f| options.check_test_file(&f.relative_path)) .filter(|f| f.is_slt_file()) .filter(|f| f.check_tpch(options)) .filter(|f| options.check_pg_compat_file(f.path.as_path())), - ) + )) } -fn read_dir_recursive<P: AsRef<Path>>(path: P) -> Box<dyn Iterator<Item = PathBuf>> { - Box::new( - std::fs::read_dir(path) - .expect("Readable directory") - .map(|path| path.expect("Readable entry").path()) - .flat_map(|path| { - if path.is_dir() { - read_dir_recursive(path) - } else { - Box::new(std::iter::once(path)) - } - }), - ) +fn read_dir_recursive<P: AsRef<Path>>(path: P) -> Result<Vec<PathBuf>> { + let mut dst = vec![]; + read_dir_recursive_impl(&mut dst, path.as_ref())?; + Ok(dst) +} + +/// Append all paths recursively to dst +fn read_dir_recursive_impl(dst: &mut Vec<PathBuf>, path: &Path) -> Result<()> { + let entries = std::fs::read_dir(path) + .map_err(|e| exec_datafusion_err!("Error reading directory {path:?}: {e}"))?; + for entry in entries { + let path = entry + .map_err(|e| { + exec_datafusion_err!("Error reading entry in directory {path:?}: {e}") + })? + .path(); + + if path.is_dir() { + read_dir_recursive_impl(dst, &path)?; + } else { + dst.push(path); + } + } + + Ok(()) } /// Parsed command line options diff --git a/datafusion/sqllogictest/src/engines/datafusion_engine/mod.rs b/datafusion/sqllogictest/src/engines/datafusion_engine/mod.rs index 663bbdd5a3c7c..8e2bbbfe4f697 100644 --- a/datafusion/sqllogictest/src/engines/datafusion_engine/mod.rs +++ b/datafusion/sqllogictest/src/engines/datafusion_engine/mod.rs @@ -21,5 +21,4 @@ mod normalize; mod runner; pub use error::*; -pub use normalize::*; pub use runner::*; diff --git a/datafusion/sqllogictest/src/test_context.rs b/datafusion/sqllogictest/src/test_context.rs index b2314f34f3601..a5ce7ccb9fe08 100644 --- a/datafusion/sqllogictest/src/test_context.rs +++ b/datafusion/sqllogictest/src/test_context.rs @@ -15,30 +15,33 @@ // specific language governing permissions and limitations // under the License.
-use async_trait::async_trait; +use std::collections::HashMap; +use std::fs::File; +use std::io::Write; +use std::path::Path; +use std::sync::Arc; + +use arrow::array::{ + ArrayRef, BinaryArray, Float64Array, Int32Array, LargeBinaryArray, LargeStringArray, + StringArray, TimestampNanosecondArray, +}; +use arrow::datatypes::{DataType, Field, Schema, SchemaRef, TimeUnit}; +use arrow::record_batch::RecordBatch; use datafusion::execution::context::SessionState; -use datafusion::logical_expr::Expr; +use datafusion::logical_expr::{create_udf, Expr, ScalarUDF, Volatility}; +use datafusion::physical_expr::functions::make_scalar_function; use datafusion::physical_plan::ExecutionPlan; use datafusion::prelude::SessionConfig; use datafusion::{ - arrow::{ - array::{ - BinaryArray, Float64Array, Int32Array, LargeBinaryArray, LargeStringArray, - StringArray, TimestampNanosecondArray, - }, - datatypes::{DataType, Field, Schema, SchemaRef, TimeUnit}, - record_batch::RecordBatch, - }, catalog::{schema::MemorySchemaProvider, CatalogProvider, MemoryCatalogProvider}, datasource::{MemTable, TableProvider, TableType}, prelude::{CsvReadOptions, SessionContext}, }; +use datafusion_common::cast::as_float64_array; use datafusion_common::DataFusionError; + +use async_trait::async_trait; use log::info; -use std::fs::File; -use std::io::Write; -use std::path::Path; -use std::sync::Arc; use tempfile::TempDir; /// Context for running tests @@ -57,8 +60,8 @@ impl TestContext { } } - /// Create a SessionContext, configured for the specific test, if - /// possible. + /// Create a SessionContext, configured for the specific sqllogictest + /// test(.slt file) , if possible. /// /// If `None` is returned (e.g. because some needed feature is not /// enabled), the file should be skipped @@ -67,7 +70,7 @@ impl TestContext { // hardcode target partitions so plans are deterministic .with_target_partitions(4); - let test_ctx = TestContext::new(SessionContext::new_with_config(config)); + let mut test_ctx = TestContext::new(SessionContext::new_with_config(config)); let file_name = relative_path.file_name().unwrap().to_str().unwrap(); match file_name { @@ -83,13 +86,15 @@ impl TestContext { info!("Registering table with many types"); register_table_with_many_types(test_ctx.session_ctx()).await; } + "map.slt" => { + info!("Registering table with map"); + register_table_with_map(test_ctx.session_ctx()).await; + } "avro.slt" => { #[cfg(feature = "avro")] { - let mut test_ctx = test_ctx; info!("Registering avro tables"); register_avro_tables(&mut test_ctx).await; - return Some(test_ctx); } #[cfg(not(feature = "avro"))] { @@ -99,10 +104,13 @@ impl TestContext { } "joins.slt" => { info!("Registering partition table tables"); - - let mut test_ctx = test_ctx; + let example_udf = create_example_udf(); + test_ctx.ctx.register_udf(example_udf); register_partition_table(&mut test_ctx).await; - return Some(test_ctx); + } + "metadata.slt" => { + info!("Registering metadata table tables"); + register_metadata_tables(test_ctx.session_ctx()).await; } _ => { info!("Using default SessionContext"); @@ -268,6 +276,23 @@ pub async fn register_table_with_many_types(ctx: &SessionContext) { .unwrap(); } +pub async fn register_table_with_map(ctx: &SessionContext) { + let key = Field::new("key", DataType::Int64, false); + let value = Field::new("value", DataType::Int64, true); + let map_field = + Field::new("entries", DataType::Struct(vec![key, value].into()), false); + let fields = vec![ + Field::new("int_field", DataType::Int64, true), + 
Field::new("map_field", DataType::Map(map_field.into(), false), true), + ]; + let schema = Schema::new(fields); + + let memory_table = MemTable::try_new(schema.into(), vec![vec![]]).unwrap(); + + ctx.register_table("table_with_map", Arc::new(memory_table)) + .unwrap(); +} + fn table_with_many_types() -> Arc { let schema = Schema::new(vec![ Field::new("int32_col", DataType::Int32, false), @@ -299,3 +324,58 @@ fn table_with_many_types() -> Arc { let provider = MemTable::try_new(Arc::new(schema), vec![vec![batch]]).unwrap(); Arc::new(provider) } + +/// Registers a table_with_metadata that contains both field level and Table level metadata +pub async fn register_metadata_tables(ctx: &SessionContext) { + let id = Field::new("id", DataType::Int32, true).with_metadata(HashMap::from([( + String::from("metadata_key"), + String::from("the id field"), + )])); + let name = Field::new("name", DataType::Utf8, true).with_metadata(HashMap::from([( + String::from("metadata_key"), + String::from("the name field"), + )])); + + let schema = Schema::new(vec![id, name]).with_metadata(HashMap::from([( + String::from("metadata_key"), + String::from("the entire schema"), + )])); + + let batch = RecordBatch::try_new( + Arc::new(schema), + vec![ + Arc::new(Int32Array::from(vec![Some(1), None, Some(3)])) as _, + Arc::new(StringArray::from(vec![None, Some("bar"), Some("baz")])) as _, + ], + ) + .unwrap(); + + ctx.register_batch("table_with_metadata", batch).unwrap(); +} + +/// Create a UDF function named "example". See the `sample_udf.rs` example +/// file for an explanation of the API. +fn create_example_udf() -> ScalarUDF { + let adder = make_scalar_function(|args: &[ArrayRef]| { + let lhs = as_float64_array(&args[0]).expect("cast failed"); + let rhs = as_float64_array(&args[1]).expect("cast failed"); + let array = lhs + .iter() + .zip(rhs.iter()) + .map(|(lhs, rhs)| match (lhs, rhs) { + (Some(lhs), Some(rhs)) => Some(lhs + rhs), + _ => None, + }) + .collect::(); + Ok(Arc::new(array) as ArrayRef) + }); + create_udf( + "example", + // Expects two f64 values: + vec![DataType::Float64, DataType::Float64], + // Returns an f64 value: + Arc::new(DataType::Float64), + Volatility::Immutable, + adder, + ) +} diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt index 6217f12279a94..78575c9dffc51 100644 --- a/datafusion/sqllogictest/test_files/aggregate.slt +++ b/datafusion/sqllogictest/test_files/aggregate.slt @@ -106,6 +106,36 @@ FROM ---- [0VVIHzxWtNOFLtnhjHEKjXaJOSLJfm, 0keZ5G8BffGwgF2RwQD59TFzMStxCB, 0og6hSkhbX8AC1ktFS4kounvTzy8Vo, 1aOcrEGd0cOqZe2I5XBOm0nDcwtBZO, 2T3wSlHdEmASmO0xcXHnndkKEt6bz8] +statement ok +CREATE EXTERNAL TABLE agg_order ( +c1 INT NOT NULL, +c2 INT NOT NULL, +c3 INT NOT NULL +) +STORED AS CSV +WITH HEADER ROW +LOCATION '../core/tests/data/aggregate_agg_multi_order.csv'; + +# test array_agg with order by multiple columns +query ? 
+select array_agg(c1 order by c2 desc, c3) from agg_order; +---- +[5, 6, 7, 8, 9, 1, 2, 3, 4, 10] + +query TT +explain select array_agg(c1 order by c2 desc, c3) from agg_order; +---- +logical_plan +Aggregate: groupBy=[[]], aggr=[[ARRAY_AGG(agg_order.c1) ORDER BY [agg_order.c2 DESC NULLS FIRST, agg_order.c3 ASC NULLS LAST]]] +--TableScan: agg_order projection=[c1, c2, c3] +physical_plan +AggregateExec: mode=Final, gby=[], aggr=[ARRAY_AGG(agg_order.c1)] +--CoalescePartitionsExec +----AggregateExec: mode=Partial, gby=[], aggr=[ARRAY_AGG(agg_order.c1)] +------SortExec: expr=[c2@1 DESC,c3@2 ASC NULLS LAST] +--------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +----------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/aggregate_agg_multi_order.csv]]}, projection=[c1, c2, c3], has_header=true + statement error This feature is not implemented: LIMIT not supported in ARRAY_AGG: 1 SELECT array_agg(c13 LIMIT 1) FROM aggregate_test_100 @@ -1327,36 +1357,128 @@ select avg(c1), arrow_typeof(avg(c1)) from d_table ---- 5 Decimal128(14, 7) -# FIX: different test table + # aggregate -# query I -# SELECT SUM(c1), SUM(c2) FROM test -# ---- -# 60 220 +query II +SELECT SUM(c1), SUM(c2) FROM test +---- +7 6 -# TODO: aggregate_empty +# aggregate_empty -# TODO: aggregate_avg +query II +SELECT SUM(c1), SUM(c2) FROM test where c1 > 100000 +---- +NULL NULL -# TODO: aggregate_max +# aggregate_avg +query RR +SELECT AVG(c1), AVG(c2) FROM test +---- +1.75 1.5 -# TODO: aggregate_min +# aggregate_max +query II +SELECT MAX(c1), MAX(c2) FROM test +---- +3 2 -# TODO: aggregate_grouped +# aggregate_min +query II +SELECT MIN(c1), MIN(c2) FROM test +---- +0 1 -# TODO: aggregate_grouped_avg +# aggregate_grouped +query II +SELECT c1, SUM(c2) FROM test GROUP BY c1 order by c1 +---- +0 NULL +1 1 +3 4 +NULL 1 -# TODO: aggregate_grouped_empty +# aggregate_grouped_avg +query IR +SELECT c1, AVG(c2) FROM test GROUP BY c1 order by c1 +---- +0 NULL +1 1 +3 2 +NULL 1 -# TODO: aggregate_grouped_max +# aggregate_grouped_empty +query IR +SELECT c1, AVG(c2) FROM test WHERE c1 = 123 GROUP BY c1 +---- -# TODO: aggregate_grouped_min +# aggregate_grouped_max +query II +SELECT c1, MAX(c2) FROM test GROUP BY c1 order by c1 +---- +0 NULL +1 1 +3 2 +NULL 1 -# TODO: aggregate_avg_add +# aggregate_grouped_min +query II +SELECT c1, MIN(c2) FROM test GROUP BY c1 order by c1 +---- +0 NULL +1 1 +3 2 +NULL 1 -# TODO: case_sensitive_identifiers_aggregates +# aggregate_min_max_w_custom_window_frames +query RR +SELECT +MIN(c12) OVER (ORDER BY C12 RANGE BETWEEN 0.3 PRECEDING AND 0.2 FOLLOWING) as min1, +MAX(c12) OVER (ORDER BY C11 RANGE BETWEEN 0.1 PRECEDING AND 0.2 FOLLOWING) as max1 +FROM aggregate_test_100 +ORDER BY C9 +LIMIT 5 +---- +0.014793053078 0.996540038759 +0.014793053078 0.980019341044 +0.014793053078 0.970671228336 +0.266717779508 0.996540038759 +0.360076636233 0.970671228336 -# TODO: count_basic +# aggregate_min_max_with_custom_window_frames_unbounded_start +query RR +SELECT +MIN(c12) OVER (ORDER BY C12 RANGE BETWEEN UNBOUNDED PRECEDING AND 0.2 FOLLOWING) as min1, +MAX(c12) OVER (ORDER BY C11 RANGE BETWEEN UNBOUNDED PRECEDING AND 0.2 FOLLOWING) as max1 +FROM aggregate_test_100 +ORDER BY C9 +LIMIT 5 +---- +0.014793053078 0.996540038759 +0.014793053078 0.980019341044 +0.014793053078 0.980019341044 +0.014793053078 0.996540038759 +0.014793053078 0.980019341044 + +# aggregate_avg_add +query RRRR +SELECT AVG(c1), AVG(c1) + 1, AVG(c1) + 2, 1 + AVG(c1) FROM test +---- +1.75 2.75 3.75 2.75 + +# 
case_sensitive_identifiers_aggregates +query I +SELECT max(c1) FROM test; +---- +3 + + + +# count_basic +query II +SELECT COUNT(c1), COUNT(c2) FROM test +---- +4 4 # TODO: count_partitioned @@ -1364,9 +1486,59 @@ select avg(c1), arrow_typeof(avg(c1)) from d_table # TODO: count_aggregated_cube -# TODO: simple_avg +# count_multi_expr +query I +SELECT count(c1, c2) FROM test +---- +3 + +# count_null +query III +SELECT count(null), count(null, null), count(distinct null) FROM test +---- +0 0 0 + +# count_multi_expr_group_by +query I +SELECT count(c1, c2) FROM test group by c1 order by c1 +---- +0 +1 +2 +0 + +# count_null_group_by +query III +SELECT count(null), count(null, null), count(distinct null) FROM test group by c1 order by c1 +---- +0 0 0 +0 0 0 +0 0 0 +0 0 0 + +# aggreggte_with_alias +query II +select c1, sum(c2) as `Total Salary` from test group by c1 order by c1 +---- +0 NULL +1 1 +3 4 +NULL 1 + +# simple_avg + +query R +select avg(c1) from test +---- +1.75 + +# simple_mean +query R +select mean(c1) from test +---- +1.75 + -# TODO: simple_mean # query_sum_distinct - 2 different aggregate functions: avg and sum(distinct) query RI @@ -1396,7 +1568,7 @@ SELECT COUNT(DISTINCT c1) FROM test query ? SELECT ARRAY_AGG([]) ---- -[] +[[]] # array_agg_one query ? @@ -1419,7 +1591,7 @@ e 4 query ? SELECT ARRAY_AGG([]); ---- -[] +[[]] # array_agg_one query ? @@ -2294,6 +2466,15 @@ select max(x_dict) from value_dict group by x_dict % 2 order by max(x_dict); 4 5 +query T +select arrow_typeof(x_dict) from value_dict group by x_dict; +---- +Int32 +Int32 +Int32 +Int32 +Int32 + statement ok drop table value @@ -2523,6 +2704,204 @@ NULL 0 0 b 0 0 c 1 1 +# +# Push limit into distinct group-by aggregation tests +# + +# Make results deterministic +statement ok +set datafusion.optimizer.repartition_aggregations = false; + +# +query TT +EXPLAIN SELECT DISTINCT c3 FROM aggregate_test_100 group by c3 limit 5; +---- +logical_plan +Limit: skip=0, fetch=5 +--Aggregate: groupBy=[[aggregate_test_100.c3]], aggr=[[]] +----Aggregate: groupBy=[[aggregate_test_100.c3]], aggr=[[]] +------TableScan: aggregate_test_100 projection=[c3] +physical_plan +GlobalLimitExec: skip=0, fetch=5 +--AggregateExec: mode=Final, gby=[c3@0 as c3], aggr=[], lim=[5] +----CoalescePartitionsExec +------AggregateExec: mode=Partial, gby=[c3@0 as c3], aggr=[], lim=[5] +--------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +----------AggregateExec: mode=Final, gby=[c3@0 as c3], aggr=[], lim=[5] +------------CoalescePartitionsExec +--------------AggregateExec: mode=Partial, gby=[c3@0 as c3], aggr=[], lim=[5] +----------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c3], has_header=true + +query I +SELECT DISTINCT c3 FROM aggregate_test_100 group by c3 limit 5; +---- +1 +-40 +29 +-85 +-82 + +query TT +EXPLAIN SELECT c2, c3 FROM aggregate_test_100 group by c2, c3 limit 5 offset 4; +---- +logical_plan +Limit: skip=4, fetch=5 +--Aggregate: groupBy=[[aggregate_test_100.c2, aggregate_test_100.c3]], aggr=[[]] +----TableScan: aggregate_test_100 projection=[c2, c3] +physical_plan +GlobalLimitExec: skip=4, fetch=5 +--AggregateExec: mode=Final, gby=[c2@0 as c2, c3@1 as c3], aggr=[], lim=[9] +----CoalescePartitionsExec +------AggregateExec: mode=Partial, gby=[c2@0 as c2, c3@1 as c3], aggr=[], lim=[9] +--------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 
+----------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c2, c3], has_header=true + +query II +SELECT c2, c3 FROM aggregate_test_100 group by c2, c3 limit 5 offset 4; +---- +5 -82 +4 -111 +3 104 +3 13 +1 38 + +# The limit should only apply to the aggregations which group by c3 +query TT +EXPLAIN SELECT DISTINCT c3 FROM aggregate_test_100 WHERE c3 between 10 and 20 group by c2, c3 limit 4; +---- +logical_plan +Limit: skip=0, fetch=4 +--Aggregate: groupBy=[[aggregate_test_100.c3]], aggr=[[]] +----Projection: aggregate_test_100.c3 +------Aggregate: groupBy=[[aggregate_test_100.c2, aggregate_test_100.c3]], aggr=[[]] +--------Filter: aggregate_test_100.c3 >= Int16(10) AND aggregate_test_100.c3 <= Int16(20) +----------TableScan: aggregate_test_100 projection=[c2, c3], partial_filters=[aggregate_test_100.c3 >= Int16(10), aggregate_test_100.c3 <= Int16(20)] +physical_plan +GlobalLimitExec: skip=0, fetch=4 +--AggregateExec: mode=Final, gby=[c3@0 as c3], aggr=[], lim=[4] +----CoalescePartitionsExec +------AggregateExec: mode=Partial, gby=[c3@0 as c3], aggr=[], lim=[4] +--------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +----------ProjectionExec: expr=[c3@1 as c3] +------------AggregateExec: mode=Final, gby=[c2@0 as c2, c3@1 as c3], aggr=[] +--------------CoalescePartitionsExec +----------------AggregateExec: mode=Partial, gby=[c2@0 as c2, c3@1 as c3], aggr=[] +------------------CoalesceBatchesExec: target_batch_size=8192 +--------------------FilterExec: c3@1 >= 10 AND c3@1 <= 20 +----------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c2, c3], has_header=true + +query I +SELECT DISTINCT c3 FROM aggregate_test_100 WHERE c3 between 10 and 20 group by c2, c3 limit 4; +---- +13 +17 +12 +14 + +# An aggregate expression causes the limit to not be pushed to the aggregation +query TT +EXPLAIN SELECT max(c1), c2, c3 FROM aggregate_test_100 group by c2, c3 limit 5; +---- +logical_plan +Projection: MAX(aggregate_test_100.c1), aggregate_test_100.c2, aggregate_test_100.c3 +--Limit: skip=0, fetch=5 +----Aggregate: groupBy=[[aggregate_test_100.c2, aggregate_test_100.c3]], aggr=[[MAX(aggregate_test_100.c1)]] +------TableScan: aggregate_test_100 projection=[c1, c2, c3] +physical_plan +ProjectionExec: expr=[MAX(aggregate_test_100.c1)@2 as MAX(aggregate_test_100.c1), c2@0 as c2, c3@1 as c3] +--GlobalLimitExec: skip=0, fetch=5 +----AggregateExec: mode=Final, gby=[c2@0 as c2, c3@1 as c3], aggr=[MAX(aggregate_test_100.c1)] +------CoalescePartitionsExec +--------AggregateExec: mode=Partial, gby=[c2@1 as c2, c3@2 as c3], aggr=[MAX(aggregate_test_100.c1)] +----------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1, c2, c3], has_header=true + +# TODO(msirek): Extend checking in LimitedDistinctAggregation equal groupings to ignore the order of columns +# in the group-by column lists, so the limit could be pushed to the lowest AggregateExec in this case +query TT +EXPLAIN SELECT DISTINCT c3, c2 FROM aggregate_test_100 group by c2, c3 limit 3 offset 10; +---- +logical_plan +Limit: skip=10, fetch=3 +--Aggregate: groupBy=[[aggregate_test_100.c3, aggregate_test_100.c2]], aggr=[[]] +----Projection: aggregate_test_100.c3, aggregate_test_100.c2 
+------Aggregate: groupBy=[[aggregate_test_100.c2, aggregate_test_100.c3]], aggr=[[]] +--------TableScan: aggregate_test_100 projection=[c2, c3] +physical_plan +GlobalLimitExec: skip=10, fetch=3 +--AggregateExec: mode=Final, gby=[c3@0 as c3, c2@1 as c2], aggr=[], lim=[13] +----CoalescePartitionsExec +------AggregateExec: mode=Partial, gby=[c3@0 as c3, c2@1 as c2], aggr=[], lim=[13] +--------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +----------ProjectionExec: expr=[c3@1 as c3, c2@0 as c2] +------------AggregateExec: mode=Final, gby=[c2@0 as c2, c3@1 as c3], aggr=[] +--------------CoalescePartitionsExec +----------------AggregateExec: mode=Partial, gby=[c2@0 as c2, c3@1 as c3], aggr=[] +------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +--------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c2, c3], has_header=true + +query II +SELECT DISTINCT c3, c2 FROM aggregate_test_100 group by c2, c3 limit 3 offset 10; +---- +57 1 +-54 4 +112 3 + +query TT +EXPLAIN SELECT c2, c3 FROM aggregate_test_100 group by rollup(c2, c3) limit 3; +---- +logical_plan +Limit: skip=0, fetch=3 +--Aggregate: groupBy=[[ROLLUP (aggregate_test_100.c2, aggregate_test_100.c3)]], aggr=[[]] +----TableScan: aggregate_test_100 projection=[c2, c3] +physical_plan +GlobalLimitExec: skip=0, fetch=3 +--AggregateExec: mode=Final, gby=[c2@0 as c2, c3@1 as c3], aggr=[], lim=[3] +----CoalescePartitionsExec +------AggregateExec: mode=Partial, gby=[(NULL as c2, NULL as c3), (c2@0 as c2, NULL as c3), (c2@0 as c2, c3@1 as c3)], aggr=[] +--------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +----------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c2, c3], has_header=true + +query II +SELECT c2, c3 FROM aggregate_test_100 group by rollup(c2, c3) limit 3; +---- +NULL NULL +2 NULL +5 NULL + + +statement ok +set datafusion.optimizer.enable_distinct_aggregation_soft_limit = false; + +# The limit should not be pushed into the aggregations +query TT +EXPLAIN SELECT DISTINCT c3 FROM aggregate_test_100 group by c3 limit 5; +---- +logical_plan +Limit: skip=0, fetch=5 +--Aggregate: groupBy=[[aggregate_test_100.c3]], aggr=[[]] +----Aggregate: groupBy=[[aggregate_test_100.c3]], aggr=[[]] +------TableScan: aggregate_test_100 projection=[c3] +physical_plan +GlobalLimitExec: skip=0, fetch=5 +--AggregateExec: mode=Final, gby=[c3@0 as c3], aggr=[] +----CoalescePartitionsExec +------AggregateExec: mode=Partial, gby=[c3@0 as c3], aggr=[] +--------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +----------AggregateExec: mode=Final, gby=[c3@0 as c3], aggr=[] +------------CoalescePartitionsExec +--------------AggregateExec: mode=Partial, gby=[c3@0 as c3], aggr=[] +----------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c3], has_header=true + +statement ok +set datafusion.optimizer.enable_distinct_aggregation_soft_limit = true; + +statement ok +set datafusion.optimizer.repartition_aggregations = true; + # # regr_*() tests # @@ -2789,3 +3168,92 @@ NULL NULL 1 NULL 3 6 0 0 0 NULL NULL 1 NULL 5 15 0 0 0 3 0 2 1 5.5 16.5 0.5 4.5 1.5 3 0 3 1 6 18 2 18 6 + +statement error +SELECT STRING_AGG() + +statement error +SELECT STRING_AGG(1,2,3) + +statement error +SELECT STRING_AGG(STRING_AGG('a', 
',')) + +query T +SELECT STRING_AGG('a', ',') +---- +a + +query TTTT +SELECT STRING_AGG('a',','), STRING_AGG('a', NULL), STRING_AGG(NULL, ','), STRING_AGG(NULL, NULL) +---- +a a NULL NULL + +query TT +select string_agg('', '|'), string_agg('a', ''); +---- +(empty) a + +query T +SELECT STRING_AGG(column1, '|') FROM (values (''), (null), ('')); +---- +| + +statement ok +CREATE TABLE strings(g INTEGER, x VARCHAR, y VARCHAR) + +query ITT +INSERT INTO strings VALUES (1,'a','/'), (1,'b','-'), (2,'i','/'), (2,NULL,'-'), (2,'j','+'), (3,'p','/'), (4,'x','/'), (4,'y','-'), (4,'z','+') +---- +9 + +query IT +SELECT g, STRING_AGG(x,'|') FROM strings GROUP BY g ORDER BY g +---- +1 a|b +2 i|j +3 p +4 x|y|z + +query T +SELECT STRING_AGG(x,',') FROM strings WHERE g > 100 +---- +NULL + +statement ok +drop table strings + +query T +WITH my_data as ( +SELECT 'text1'::varchar(1000) as my_column union all +SELECT 'text1'::varchar(1000) as my_column union all +SELECT 'text1'::varchar(1000) as my_column +) +SELECT string_agg(my_column,', ') as my_string_agg +FROM my_data +---- +text1, text1, text1 + +query T +WITH my_data as ( +SELECT 1 as dummy, 'text1'::varchar(1000) as my_column union all +SELECT 1 as dummy, 'text1'::varchar(1000) as my_column union all +SELECT 1 as dummy, 'text1'::varchar(1000) as my_column +) +SELECT string_agg(my_column,', ') as my_string_agg +FROM my_data +GROUP BY dummy +---- +text1, text1, text1 + + +# Queries with nested count(*) + +query I +select count(*) from (select count(*) from (select 1)); +---- +1 + +query I +select count(*) from (select count(*) a, count(*) b from (select 1)); +---- +1 diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt index b5601a22226c0..7cee615a5729b 100644 --- a/datafusion/sqllogictest/test_files/array.slt +++ b/datafusion/sqllogictest/test_files/array.slt @@ -67,6 +67,16 @@ AS VALUES (make_array(make_array(15, 16),make_array(NULL, 18)), make_array(16.6, 17.7, 18.8), NULL) ; +statement ok +CREATE TABLE large_arrays +AS + SELECT + arrow_cast(column1, 'LargeList(List(Int64))') AS column1, + arrow_cast(column2, 'LargeList(Float64)') AS column2, + arrow_cast(column3, 'LargeList(Utf8)') AS column3 + FROM arrays +; + statement ok CREATE TABLE slices AS VALUES @@ -97,6 +107,19 @@ AS VALUES (make_array(make_array(4, 5, 6), make_array(10, 11, 12), make_array(4, 9, 8), make_array(7, 8, 9), make_array(10, 11, 12), make_array(1, 8, 7)), make_array(10, 11, 12), 3, make_array([[11, 12, 13], [14, 15, 16]], [[17, 18, 19], [20, 21, 22]]), make_array(121, 131, 141)) ; +# TODO: add this when #8305 is fixed +# statement ok +# CREATE TABLE large_nested_arrays +# AS +# SELECT +# arrow_cast(column1, 'LargeList(LargeList(Int64))') AS column1, +# arrow_cast(column2, 'LargeList(Int64)') AS column2, +# column3, +# arrow_cast(column4, 'LargeList(LargeList(List(Int64)))') AS column4, +# arrow_cast(column5, 'LargeList(Int64)') AS column5 +# FROM nested_arrays +# ; + statement ok CREATE TABLE arrays_values AS VALUES @@ -110,6 +133,17 @@ AS VALUES (make_array(61, 62, 63, 64, 65, 66, 67, 68, 69, 70), 66, 7, NULL) ; +statement ok +CREATE TABLE large_arrays_values +AS SELECT + arrow_cast(column1, 'LargeList(Int64)') AS column1, + column2, + column3, + column4 +FROM arrays_values +; + + statement ok CREATE TABLE arrays_values_v2 AS VALUES @@ -121,6 +155,17 @@ AS VALUES (NULL, NULL, NULL, NULL) ; +# TODO: add this when #8305 is fixed +# statement ok +# CREATE TABLE large_arrays_values_v2 +# AS SELECT +# arrow_cast(column1, 
'LargeList(Int64)') AS column1, +# arrow_cast(column2, 'LargeList(Int64)') AS column2, +# column3, +# arrow_cast(column4, 'LargeList(LargeList(Int64))') AS column4 +# FROM arrays_values_v2 +# ; + statement ok CREATE TABLE flatten_table AS VALUES @@ -182,6 +227,168 @@ AS VALUES (make_array([[1], [2]], [[2], [3]]), make_array([1], [2])) ; +statement ok +CREATE TABLE array_distinct_table_1D +AS VALUES + (make_array(1, 1, 2, 2, 3)), + (make_array(1, 2, 3, 4, 5)), + (make_array(3, 5, 3, 3, 3)) +; + +statement ok +CREATE TABLE array_distinct_table_1D_UTF8 +AS VALUES + (make_array('a', 'a', 'bc', 'bc', 'def')), + (make_array('a', 'bc', 'def', 'defg', 'defg')), + (make_array('defg', 'defg', 'defg', 'defg', 'defg')) +; + +statement ok +CREATE TABLE array_distinct_table_2D +AS VALUES + (make_array([1,2], [1,2], [3,4], [3,4], [5,6])), + (make_array([1,2], [3,4], [5,6], [7,8], [9,10])), + (make_array([5,6], [5,6], NULL)) +; + +statement ok +CREATE TABLE array_distinct_table_1D_large +AS VALUES + (arrow_cast(make_array(1, 1, 2, 2, 3), 'LargeList(Int64)')), + (arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)')), + (arrow_cast(make_array(3, 5, 3, 3, 3), 'LargeList(Int64)')) +; + +statement ok +CREATE TABLE array_intersect_table_1D +AS VALUES + (make_array(1, 2), make_array(1), make_array(1,2,3), make_array(1,3), make_array(1,3,5), make_array(2,4,6,8,1,3)), + (make_array(11, 22), make_array(11), make_array(11,22,33), make_array(11,33), make_array(11,33,55), make_array(22,44,66,88,11,33)) +; + +statement ok +CREATE TABLE large_array_intersect_table_1D +AS + SELECT + arrow_cast(column1, 'LargeList(Int64)') as column1, + arrow_cast(column2, 'LargeList(Int64)') as column2, + arrow_cast(column3, 'LargeList(Int64)') as column3, + arrow_cast(column4, 'LargeList(Int64)') as column4, + arrow_cast(column5, 'LargeList(Int64)') as column5, + arrow_cast(column6, 'LargeList(Int64)') as column6 +FROM array_intersect_table_1D +; + +statement ok +CREATE TABLE array_intersect_table_1D_Float +AS VALUES + (make_array(1.0, 2.0), make_array(1.0), make_array(1.0,2.0,3.0), make_array(1.0,3.0), make_array(1.11), make_array(2.22, 3.33)), + (make_array(3.0, 4.0, 5.0), make_array(2.0), make_array(1.0,2.0,3.0,4.0), make_array(2.0,5.0), make_array(2.22, 1.11), make_array(1.11, 3.33)) +; + +statement ok +CREATE TABLE large_array_intersect_table_1D_Float +AS + SELECT + arrow_cast(column1, 'LargeList(Float64)') as column1, + arrow_cast(column2, 'LargeList(Float64)') as column2, + arrow_cast(column3, 'LargeList(Float64)') as column3, + arrow_cast(column4, 'LargeList(Float64)') as column4, + arrow_cast(column5, 'LargeList(Float64)') as column5, + arrow_cast(column6, 'LargeList(Float64)') as column6 +FROM array_intersect_table_1D_Float +; + +statement ok +CREATE TABLE array_intersect_table_1D_Boolean +AS VALUES + (make_array(true, true, true), make_array(false), make_array(true, true, false, true, false), make_array(true, false, true), make_array(false), make_array(true, false)), + (make_array(false, false, false), make_array(false), make_array(true, false, true), make_array(true, true), make_array(true, true), make_array(false,false,true)) +; + +statement ok +CREATE TABLE large_array_intersect_table_1D_Boolean +AS + SELECT + arrow_cast(column1, 'LargeList(Boolean)') as column1, + arrow_cast(column2, 'LargeList(Boolean)') as column2, + arrow_cast(column3, 'LargeList(Boolean)') as column3, + arrow_cast(column4, 'LargeList(Boolean)') as column4, + arrow_cast(column5, 'LargeList(Boolean)') as column5, + arrow_cast(column6, 
'LargeList(Boolean)') as column6 +FROM array_intersect_table_1D_Boolean +; + +statement ok +CREATE TABLE array_intersect_table_1D_UTF8 +AS VALUES + (make_array('a', 'bc', 'def'), make_array('bc'), make_array('datafusion', 'rust', 'arrow'), make_array('rust', 'arrow'), make_array('rust', 'arrow', 'python'), make_array('data')), + (make_array('a', 'bc', 'def'), make_array('defg'), make_array('datafusion', 'rust', 'arrow'), make_array('datafusion', 'rust', 'arrow', 'python'), make_array('rust', 'arrow'), make_array('datafusion', 'rust', 'arrow')) +; + +statement ok +CREATE TABLE large_array_intersect_table_1D_UTF8 +AS + SELECT + arrow_cast(column1, 'LargeList(Utf8)') as column1, + arrow_cast(column2, 'LargeList(Utf8)') as column2, + arrow_cast(column3, 'LargeList(Utf8)') as column3, + arrow_cast(column4, 'LargeList(Utf8)') as column4, + arrow_cast(column5, 'LargeList(Utf8)') as column5, + arrow_cast(column6, 'LargeList(Utf8)') as column6 +FROM array_intersect_table_1D_UTF8 +; + +statement ok +CREATE TABLE array_intersect_table_2D +AS VALUES + (make_array([1,2]), make_array([1,3]), make_array([1,2,3], [4,5], [6,7]), make_array([4,5], [6,7])), + (make_array([3,4], [5]), make_array([3,4]), make_array([1,2,3,4], [5,6,7], [8,9,10]), make_array([1,2,3], [5,6,7], [8,9,10])) +; + +statement ok +CREATE TABLE large_array_intersect_table_2D +AS + SELECT + arrow_cast(column1, 'LargeList(List(Int64))') as column1, + arrow_cast(column2, 'LargeList(List(Int64))') as column2, + arrow_cast(column3, 'LargeList(List(Int64))') as column3, + arrow_cast(column4, 'LargeList(List(Int64))') as column4 +FROM array_intersect_table_2D +; + +statement ok +CREATE TABLE array_intersect_table_2D_float +AS VALUES + (make_array([1.0, 2.0, 3.0], [1.1, 2.2], [3.3]), make_array([1.1, 2.2], [3.3])), + (make_array([1.0, 2.0, 3.0], [1.1, 2.2], [3.3]), make_array([1.0], [1.1, 2.2], [3.3])) +; + +statement ok +CREATE TABLE large_array_intersect_table_2D_Float +AS + SELECT + arrow_cast(column1, 'LargeList(List(Float64))') as column1, + arrow_cast(column2, 'LargeList(List(Float64))') as column2 +FROM array_intersect_table_2D_Float +; + +statement ok +CREATE TABLE array_intersect_table_3D +AS VALUES + (make_array([[1,2]]), make_array([[1]])), + (make_array([[1,2]]), make_array([[1,2]])) +; + +statement ok +CREATE TABLE large_array_intersect_table_3D +AS + SELECT + arrow_cast(column1, 'LargeList(List(List(Int64)))') as column1, + arrow_cast(column2, 'LargeList(List(List(Int64)))') as column2 +FROM array_intersect_table_3D +; + statement ok CREATE TABLE arrays_values_without_nulls AS VALUES @@ -191,6 +398,24 @@ AS VALUES (make_array(31, 32, 33, 34, 35, 26, 37, 38, 39, 40), 34, 4, 'ok', [8,9]) ; +statement ok +CREATE TABLE large_arrays_values_without_nulls +AS SELECT + arrow_cast(column1, 'LargeList(Int64)') AS column1, + column2, + column3, + column4, + arrow_cast(column5, 'LargeList(Int64)') AS column5 +FROM arrays_values_without_nulls +; + +statement ok +CREATE TABLE arrays_range +AS VALUES + (3, 10, 2), + (4, 13, 3) +; + statement ok CREATE TABLE arrays_with_repeating_elements AS VALUES @@ -200,6 +425,17 @@ AS VALUES (make_array(10, 11, 12, 10, 11, 12, 10, 11, 12, 10), 10, 13, 10) ; +statement ok +CREATE TABLE large_arrays_with_repeating_elements +AS + SELECT + arrow_cast(column1, 'LargeList(Int64)') AS column1, + column2, + column3, + column4 + FROM arrays_with_repeating_elements +; + statement ok CREATE TABLE nested_arrays_with_repeating_elements AS VALUES @@ -209,6 +445,23 @@ AS VALUES (make_array([28, 29, 30], [31, 32, 33], [34, 
35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30]), [28, 29, 30], [37, 38, 39], 10) ; +statement ok +CREATE TABLE large_nested_arrays_with_repeating_elements +AS + SELECT + arrow_cast(column1, 'LargeList(List(Int64))') AS column1, + column2, + column3, + column4 + FROM nested_arrays_with_repeating_elements +; + +query error +select [1, true, null] + +query error DataFusion error: This feature is not implemented: ScalarFunctions without MakeArray are not supported: now() +SELECT [now()] + query TTT select arrow_typeof(column1), arrow_typeof(column2), arrow_typeof(column3) from arrays; ---- @@ -623,7 +876,7 @@ from arrays_values_without_nulls; ## array_element (aliases: array_extract, list_extract, list_element) # array_element error -query error DataFusion error: Error during planning: The array_element function can only accept list as the first argument +query error DataFusion error: Error during planning: The array_element function can only accept list or largelist as the first argument select array_element(1, 2); @@ -633,58 +886,106 @@ select array_element(make_array(1, 2, 3, 4, 5), 2), array_element(make_array('h' ---- 2 l +query IT +select array_element(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), 2), array_element(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), 3); +---- +2 l + # array_element scalar function #2 (with positive index; out of bounds) query IT select array_element(make_array(1, 2, 3, 4, 5), 7), array_element(make_array('h', 'e', 'l', 'l', 'o'), 11); ---- NULL NULL +query IT +select array_element(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), 7), array_element(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), 11); +---- +NULL NULL + # array_element scalar function #3 (with zero) query IT select array_element(make_array(1, 2, 3, 4, 5), 0), array_element(make_array('h', 'e', 'l', 'l', 'o'), 0); ---- NULL NULL +query IT +select array_element(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), 0), array_element(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), 0); +---- +NULL NULL + # array_element scalar function #4 (with NULL) -query error +query error select array_element(make_array(1, 2, 3, 4, 5), NULL), array_element(make_array('h', 'e', 'l', 'l', 'o'), NULL); +query error +select array_element(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), NULL), array_element(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), NULL); + # array_element scalar function #5 (with negative index) query IT select array_element(make_array(1, 2, 3, 4, 5), -2), array_element(make_array('h', 'e', 'l', 'l', 'o'), -3); ---- 4 l +query IT +select array_element(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), -2), array_element(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), -3); +---- +4 l + # array_element scalar function #6 (with negative index; out of bounds) query IT select array_element(make_array(1, 2, 3, 4, 5), -11), array_element(make_array('h', 'e', 'l', 'l', 'o'), -7); ---- NULL NULL +query IT +select array_element(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), -11), array_element(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), -7); +---- +NULL NULL + # array_element scalar function #7 (nested array) query ? select array_element(make_array(make_array(1, 2, 3, 4, 5), make_array(6, 7, 8, 9, 10)), 1); ---- [1, 2, 3, 4, 5] +query ? 
+select array_element(arrow_cast(make_array(make_array(1, 2, 3, 4, 5), make_array(6, 7, 8, 9, 10)), 'LargeList(List(Int64))'), 1); +---- +[1, 2, 3, 4, 5] + # array_extract scalar function #8 (function alias `array_slice`) query IT select array_extract(make_array(1, 2, 3, 4, 5), 2), array_extract(make_array('h', 'e', 'l', 'l', 'o'), 3); ---- 2 l +query IT +select array_extract(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), 2), array_extract(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), 3); +---- +2 l + # list_element scalar function #9 (function alias `array_slice`) query IT select list_element(make_array(1, 2, 3, 4, 5), 2), list_element(make_array('h', 'e', 'l', 'l', 'o'), 3); ---- 2 l +query IT +select list_element(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), 2), array_extract(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), 3); +---- +2 l + # list_extract scalar function #10 (function alias `array_slice`) query IT select list_extract(make_array(1, 2, 3, 4, 5), 2), list_extract(make_array('h', 'e', 'l', 'l', 'o'), 3); ---- 2 l +query IT +select list_extract(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), 2), array_extract(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), 3); +---- +2 l + # array_element with columns query I select array_element(column1, column2) from slices; @@ -697,6 +998,17 @@ NULL NULL 55 +query I +select array_element(arrow_cast(column1, 'LargeList(Int64)'), column2) from slices; +---- +NULL +12 +NULL +37 +NULL +NULL +55 + # array_element with columns and scalars query II select array_element(make_array(1, 2, 3, 4, 5), column2), array_element(column1, 3) from slices; @@ -709,6 +1021,17 @@ NULL 23 NULL 43 5 NULL +query II +select array_element(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), column2), array_element(arrow_cast(column1, 'LargeList(Int64)'), 3) from slices; +---- +1 3 +2 13 +NULL 23 +2 33 +4 NULL +NULL 43 +5 NULL + ## array_pop_back (aliases: `list_pop_back`) # array_pop_back scalar function #1 @@ -717,18 +1040,33 @@ select array_pop_back(make_array(1, 2, 3, 4, 5)), array_pop_back(make_array('h', ---- [1, 2, 3, 4] [h, e, l, l] +query ?? +select array_pop_back(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)')), array_pop_back(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)')); +---- +[1, 2, 3, 4] [h, e, l, l] + # array_pop_back scalar function #2 (after array_pop_back, array is empty) query ? select array_pop_back(make_array(1)); ---- [] +query ? +select array_pop_back(arrow_cast(make_array(1), 'LargeList(Int64)')); +---- +[] + # array_pop_back scalar function #3 (array_pop_back the empty array) query ? select array_pop_back(array_pop_back(make_array(1))); ---- [] +query ? +select array_pop_back(array_pop_back(arrow_cast(make_array(1), 'LargeList(Int64)'))); +---- +[] + # array_pop_back scalar function #4 (array_pop_back the arrays which have NULL) query ?? select array_pop_back(make_array(1, 2, 3, 4, NULL)), array_pop_back(make_array(NULL, 'e', 'l', NULL, 'o')); @@ -741,24 +1079,44 @@ select array_pop_back(make_array(make_array(1, 2, 3), make_array(2, 9, 1), make_ ---- [[1, 2, 3], [2, 9, 1], [7, 8, 9], [1, 2, 3], [1, 7, 4]] +query ? 
+select array_pop_back(arrow_cast(make_array(make_array(1, 2, 3), make_array(2, 9, 1), make_array(7, 8, 9), make_array(1, 2, 3), make_array(1, 7, 4), make_array(4, 5, 6)), 'LargeList(List(Int64))')); +---- +[[1, 2, 3], [2, 9, 1], [7, 8, 9], [1, 2, 3], [1, 7, 4]] + # array_pop_back scalar function #6 (array_pop_back the nested arrays with NULL) query ? select array_pop_back(make_array(make_array(1, 2, 3), make_array(2, 9, 1), make_array(7, 8, 9), make_array(1, 2, 3), make_array(1, 7, 4), NULL)); ---- [[1, 2, 3], [2, 9, 1], [7, 8, 9], [1, 2, 3], [1, 7, 4]] +query ? +select array_pop_back(arrow_cast(make_array(make_array(1, 2, 3), make_array(2, 9, 1), make_array(7, 8, 9), make_array(1, 2, 3), make_array(1, 7, 4), NULL), 'LargeList(List(Int64))')); +---- +[[1, 2, 3], [2, 9, 1], [7, 8, 9], [1, 2, 3], [1, 7, 4]] + # array_pop_back scalar function #7 (array_pop_back the nested arrays with NULL) query ? select array_pop_back(make_array(make_array(1, 2, 3), make_array(2, 9, 1), make_array(7, 8, 9), NULL, make_array(1, 7, 4))); ---- [[1, 2, 3], [2, 9, 1], [7, 8, 9], ] +query ? +select array_pop_back(arrow_cast(make_array(make_array(1, 2, 3), make_array(2, 9, 1), make_array(7, 8, 9), NULL, make_array(1, 7, 4)), 'LargeList(List(Int64))')); +---- +[[1, 2, 3], [2, 9, 1], [7, 8, 9], ] + # array_pop_back scalar function #8 (after array_pop_back, nested array is empty) query ? select array_pop_back(make_array(make_array(1, 2, 3))); ---- [] +query ? +select array_pop_back(arrow_cast(make_array(make_array(1, 2, 3)), 'LargeList(List(Int64))')); +---- +[] + # array_pop_back with columns query ? select array_pop_back(column1) from arrayspop; @@ -770,6 +1128,84 @@ select array_pop_back(column1) from arrayspop; [] [, 10, 11] +query ? +select array_pop_back(arrow_cast(column1, 'LargeList(Int64)')) from arrayspop; +---- +[1, 2] +[3, 4, 5] +[6, 7, 8, ] +[, ] +[] +[, 10, 11] + +## array_pop_front (aliases: `list_pop_front`) + +# array_pop_front scalar function #1 +query ?? +select array_pop_front(make_array(1, 2, 3, 4, 5)), array_pop_front(make_array('h', 'e', 'l', 'l', 'o')); +---- +[2, 3, 4, 5] [e, l, l, o] + +query ?? +select array_pop_front(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)')), array_pop_front(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)')); +---- +[2, 3, 4, 5] [e, l, l, o] + +# array_pop_front scalar function #2 (after array_pop_front, array is empty) +query ? +select array_pop_front(make_array(1)); +---- +[] + +query ? +select array_pop_front(arrow_cast(make_array(1), 'LargeList(Int64)')); +---- +[] + +# array_pop_front scalar function #3 (array_pop_front the empty array) +query ? +select array_pop_front(array_pop_front(make_array(1))); +---- +[] + +query ? +select array_pop_front(array_pop_front(arrow_cast(make_array(1), 'LargeList(Int64)'))); +---- +[] + +# array_pop_front scalar function #5 (array_pop_front the nested arrays) +query ? +select array_pop_front(make_array(make_array(1, 2, 3), make_array(2, 9, 1), make_array(7, 8, 9), make_array(1, 2, 3), make_array(1, 7, 4), make_array(4, 5, 6))); +---- +[[2, 9, 1], [7, 8, 9], [1, 2, 3], [1, 7, 4], [4, 5, 6]] + +query ? +select array_pop_front(arrow_cast(make_array(make_array(1, 2, 3), make_array(2, 9, 1), make_array(7, 8, 9), make_array(1, 2, 3), make_array(1, 7, 4), make_array(4, 5, 6)), 'LargeList(List(Int64))')); +---- +[[2, 9, 1], [7, 8, 9], [1, 2, 3], [1, 7, 4], [4, 5, 6]] + +# array_pop_front scalar function #6 (array_pop_front the nested arrays with NULL) +query ? 
+select array_pop_front(make_array(NULL, make_array(1, 2, 3), make_array(2, 9, 1), make_array(7, 8, 9), make_array(1, 2, 3), make_array(1, 7, 4))); +---- +[[1, 2, 3], [2, 9, 1], [7, 8, 9], [1, 2, 3], [1, 7, 4]] + +query ? +select array_pop_front(arrow_cast(make_array(NULL, make_array(1, 2, 3), make_array(2, 9, 1), make_array(7, 8, 9), make_array(1, 2, 3), make_array(1, 7, 4)), 'LargeList(List(Int64))')); +---- +[[1, 2, 3], [2, 9, 1], [7, 8, 9], [1, 2, 3], [1, 7, 4]] + +# array_pop_front scalar function #8 (after array_pop_front, nested array is empty) +query ? +select array_pop_front(make_array(make_array(1, 2, 3))); +---- +[] + +query ? +select array_pop_front(arrow_cast(make_array(make_array(1, 2, 3)), 'LargeList(List(Int64))')); +---- +[] + ## array_slice (aliases: list_slice) # array_slice scalar function #1 (with positive indexes) @@ -778,109 +1214,201 @@ select array_slice(make_array(1, 2, 3, 4, 5), 2, 4), array_slice(make_array('h', ---- [2, 3, 4] [h, e] +query ?? +select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), 2, 4), array_slice(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), 1, 2); +---- +[2, 3, 4] [h, e] + # array_slice scalar function #2 (with positive indexes; full array) query ?? select array_slice(make_array(1, 2, 3, 4, 5), 0, 6), array_slice(make_array('h', 'e', 'l', 'l', 'o'), 0, 5); ---- [1, 2, 3, 4, 5] [h, e, l, l, o] +query ?? +select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), 0, 6), array_slice(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), 0, 5); +---- +[1, 2, 3, 4, 5] [h, e, l, l, o] + # array_slice scalar function #3 (with positive indexes; first index = second index) query ?? select array_slice(make_array(1, 2, 3, 4, 5), 4, 4), array_slice(make_array('h', 'e', 'l', 'l', 'o'), 3, 3); ---- [4] [l] +query ?? +select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), 4, 4), array_slice(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), 3, 3); +---- +[4] [l] + # array_slice scalar function #4 (with positive indexes; first index > second_index) query ?? select array_slice(make_array(1, 2, 3, 4, 5), 2, 1), array_slice(make_array('h', 'e', 'l', 'l', 'o'), 4, 1); ---- [] [] +query ?? +select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), 2, 1), array_slice(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), 4, 1); +---- +[] [] + # array_slice scalar function #5 (with positive indexes; out of bounds) query ?? select array_slice(make_array(1, 2, 3, 4, 5), 2, 6), array_slice(make_array('h', 'e', 'l', 'l', 'o'), 3, 7); ---- [2, 3, 4, 5] [l, l, o] +query ?? +select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), 2, 6), array_slice(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), 3, 7); +---- +[2, 3, 4, 5] [l, l, o] + # array_slice scalar function #6 (with positive indexes; nested array) query ? select array_slice(make_array(make_array(1, 2, 3, 4, 5), make_array(6, 7, 8, 9, 10)), 1, 1); ---- [[1, 2, 3, 4, 5]] +query ? +select array_slice(arrow_cast(make_array(make_array(1, 2, 3, 4, 5), make_array(6, 7, 8, 9, 10)), 'LargeList(List(Int64))'), 1, 1); +---- +[[1, 2, 3, 4, 5]] + # array_slice scalar function #7 (with zero and positive number) query ?? 
select array_slice(make_array(1, 2, 3, 4, 5), 0, 4), array_slice(make_array('h', 'e', 'l', 'l', 'o'), 0, 3); ---- [1, 2, 3, 4] [h, e, l] -# array_slice scalar function #8 (with NULL and positive number) -query error -select array_slice(make_array(1, 2, 3, 4, 5), NULL, 4), array_slice(make_array('h', 'e', 'l', 'l', 'o'), NULL, 3); +query ?? +select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), 0, 4), array_slice(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), 0, 3); +---- +[1, 2, 3, 4] [h, e, l] + +# array_slice scalar function #8 (with NULL and positive number) +query error +select array_slice(make_array(1, 2, 3, 4, 5), NULL, 4), array_slice(make_array('h', 'e', 'l', 'l', 'o'), NULL, 3); + +query error +select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), NULL, 4), array_slice(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), NULL, 3); # array_slice scalar function #9 (with positive number and NULL) -query error +query error select array_slice(make_array(1, 2, 3, 4, 5), 2, NULL), array_slice(make_array('h', 'e', 'l', 'l', 'o'), 3, NULL); +query error +select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), 2, NULL), array_slice(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), 3, NULL); + # array_slice scalar function #10 (with zero-zero) query ?? select array_slice(make_array(1, 2, 3, 4, 5), 0, 0), array_slice(make_array('h', 'e', 'l', 'l', 'o'), 0, 0); ---- [] [] +query ?? +select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), 0, 0), array_slice(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), 0, 0); +---- +[] [] + # array_slice scalar function #11 (with NULL-NULL) -query error +query error select array_slice(make_array(1, 2, 3, 4, 5), NULL), array_slice(make_array('h', 'e', 'l', 'l', 'o'), NULL); +query error +select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), NULL), array_slice(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), NULL); + + # array_slice scalar function #12 (with zero and negative number) query ?? select array_slice(make_array(1, 2, 3, 4, 5), 0, -4), array_slice(make_array('h', 'e', 'l', 'l', 'o'), 0, -3); ---- [1] [h, e] +query ?? +select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), 0, -4), array_slice(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), 0, -3); +---- +[1] [h, e] + # array_slice scalar function #13 (with negative number and NULL) -query error -select array_slice(make_array(1, 2, 3, 4, 5), 2, NULL), array_slice(make_array('h', 'e', 'l', 'l', 'o'), 3, NULL); +query error +select array_slice(make_array(1, 2, 3, 4, 5), -2, NULL), array_slice(make_array('h', 'e', 'l', 'l', 'o'), -3, NULL); + +query error +select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), -2, NULL), array_slice(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), -3, NULL); # array_slice scalar function #14 (with NULL and negative number) -query error +query error select array_slice(make_array(1, 2, 3, 4, 5), NULL, -4), array_slice(make_array('h', 'e', 'l', 'l', 'o'), NULL, -3); +query error +select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), NULL, -4), array_slice(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), NULL, -3); + # array_slice scalar function #15 (with negative indexes) query ?? 
select array_slice(make_array(1, 2, 3, 4, 5), -4, -1), array_slice(make_array('h', 'e', 'l', 'l', 'o'), -3, -1); ---- [2, 3, 4] [l, l] +query ?? +select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), -4, -1), array_slice(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), -3, -1); +---- +[2, 3, 4] [l, l] + # array_slice scalar function #16 (with negative indexes; almost full array (only with negative indices cannot return full array)) query ?? select array_slice(make_array(1, 2, 3, 4, 5), -5, -1), array_slice(make_array('h', 'e', 'l', 'l', 'o'), -5, -1); ---- [1, 2, 3, 4] [h, e, l, l] +query ?? +select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), -5, -1), array_slice(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), -5, -1); +---- +[1, 2, 3, 4] [h, e, l, l] + # array_slice scalar function #17 (with negative indexes; first index = second index) query ?? select array_slice(make_array(1, 2, 3, 4, 5), -4, -4), array_slice(make_array('h', 'e', 'l', 'l', 'o'), -3, -3); ---- [] [] +query ?? +select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), -4, -4), array_slice(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), -3, -3); +---- +[] [] + # array_slice scalar function #18 (with negative indexes; first index > second_index) query ?? select array_slice(make_array(1, 2, 3, 4, 5), -4, -6), array_slice(make_array('h', 'e', 'l', 'l', 'o'), -3, -6); ---- [] [] +query ?? +select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), -4, -6), array_slice(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), -3, -6); +---- +[] [] + # array_slice scalar function #19 (with negative indexes; out of bounds) query ?? select array_slice(make_array(1, 2, 3, 4, 5), -7, -2), array_slice(make_array('h', 'e', 'l', 'l', 'o'), -7, -3); ---- [] [] +query ?? +select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), -7, -2), array_slice(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), -7, -3); +---- +[] [] + # array_slice scalar function #20 (with negative indexes; nested array) -query ? -select array_slice(make_array(make_array(1, 2, 3, 4, 5), make_array(6, 7, 8, 9, 10)), -2, -1); +query ?? +select array_slice(make_array(make_array(1, 2, 3, 4, 5), make_array(6, 7, 8, 9, 10)), -2, -1), array_slice(make_array(make_array(1, 2, 3), make_array(6, 7, 8)), -1, -1); ---- -[[1, 2, 3, 4, 5]] +[[1, 2, 3, 4, 5]] [] + +query ?? +select array_slice(arrow_cast(make_array(make_array(1, 2, 3, 4, 5), make_array(6, 7, 8, 9, 10)), 'LargeList(List(Int64))'), -2, -1), array_slice(arrow_cast(make_array(make_array(1, 2, 3), make_array(6, 7, 8)), 'LargeList(List(Int64))'), -1, -1); +---- +[[1, 2, 3, 4, 5]] [] + # array_slice scalar function #21 (with first positive index and last negative index) query ?? @@ -888,18 +1416,33 @@ select array_slice(make_array(1, 2, 3, 4, 5), 2, -3), array_slice(make_array('h' ---- [2] [e, l] +query ?? +select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), 2, -3), array_slice(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), 2, -2); +---- +[2] [e, l] + # array_slice scalar function #22 (with first negative index and last positive index) query ?? select array_slice(make_array(1, 2, 3, 4, 5), -2, 5), array_slice(make_array('h', 'e', 'l', 'l', 'o'), -3, 4); ---- [4, 5] [l, l] +query ?? 
+select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), -2, 5), array_slice(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), -3, 4); +---- +[4, 5] [l, l] + # list_slice scalar function #23 (function alias `array_slice`) query ?? select list_slice(make_array(1, 2, 3, 4, 5), 2, 4), list_slice(make_array('h', 'e', 'l', 'l', 'o'), 1, 2); ---- [2, 3, 4] [h, e] +query ?? +select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), 2, 4), array_slice(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), 1, 2); +---- +[2, 3, 4] [h, e] + # array_slice with columns query ? select array_slice(column1, column2, column3) from slices; @@ -912,6 +1455,17 @@ select array_slice(column1, column2, column3) from slices; [41, 42, 43, 44, 45, 46] [55, 56, 57, 58, 59, 60] +query ? +select array_slice(arrow_cast(column1, 'LargeList(Int64)'), column2, column3) from slices; +---- +[] +[12, 13, 14, 15, 16] +[] +[] +[] +[41, 42, 43, 44, 45, 46] +[55, 56, 57, 58, 59, 60] + # TODO: support NULLS in output instead of `[]` # array_slice with columns and scalars query ??? @@ -925,6 +1479,17 @@ select array_slice(make_array(1, 2, 3, 4, 5), column2, column3), array_slice(col [1, 2, 3, 4, 5] [43, 44, 45, 46] [41, 42, 43, 44, 45] [5] [, 54, 55, 56, 57, 58, 59, 60] [55] +query ??? +select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), column2, column3), array_slice(arrow_cast(column1, 'LargeList(Int64)'), 3, column3), array_slice(arrow_cast(column1, 'LargeList(Int64)'), column2, 5) from slices; +---- +[1] [] [, 2, 3, 4, 5] +[] [13, 14, 15, 16] [12, 13, 14, 15] +[] [] [21, 22, 23, , 25] +[] [33] [] +[4, 5] [] [] +[1, 2, 3, 4, 5] [43, 44, 45, 46] [41, 42, 43, 44, 45] +[5] [, 54, 55, 56, 57, 58, 59, 60] [55] + # make_array with nulls query ??????? select make_array(make_array('a','b'), null), @@ -950,20 +1515,102 @@ select make_array(['a','b'], null); ---- [[a, b], ] +## array_sort (aliases: `list_sort`) +query ??? +select array_sort(make_array(1, 3, null, 5, NULL, -5)), array_sort(make_array(1, 3, null, 2), 'ASC'), array_sort(make_array(1, 3, null, 2), 'desc', 'NULLS FIRST'); +---- +[, , -5, 1, 3, 5] [, 1, 2, 3] [, 3, 2, 1] + +query ? +select array_sort(column1, 'DESC', 'NULLS LAST') from arrays_values; +---- +[10, 9, 8, 7, 6, 5, 4, 3, 2, ] +[20, 18, 17, 16, 15, 14, 13, 12, 11, ] +[30, 29, 28, 27, 26, 25, 23, 22, 21, ] +[40, 39, 38, 37, 35, 34, 33, 32, 31, ] +NULL +[50, 49, 48, 47, 46, 45, 44, 43, 42, 41] +[60, 59, 58, 57, 56, 55, 54, 52, 51, ] +[70, 69, 68, 67, 66, 65, 64, 63, 62, 61] + +query ? +select array_sort(column1, 'ASC', 'NULLS FIRST') from arrays_values; +---- +[, 2, 3, 4, 5, 6, 7, 8, 9, 10] +[, 11, 12, 13, 14, 15, 16, 17, 18, 20] +[, 21, 22, 23, 25, 26, 27, 28, 29, 30] +[, 31, 32, 33, 34, 35, 37, 38, 39, 40] +NULL +[41, 42, 43, 44, 45, 46, 47, 48, 49, 50] +[, 51, 52, 54, 55, 56, 57, 58, 59, 60] +[61, 62, 63, 64, 65, 66, 67, 68, 69, 70] + + +## list_sort (aliases: `array_sort`) +query ??? +select list_sort(make_array(1, 3, null, 5, NULL, -5)), list_sort(make_array(1, 3, null, 2), 'ASC'), list_sort(make_array(1, 3, null, 2), 'desc', 'NULLS FIRST'); +---- +[, , -5, 1, 3, 5] [, 1, 2, 3] [, 3, 2, 1] + + ## array_append (aliases: `list_append`, `array_push_back`, `list_push_back`) -# TODO: array_append with NULLs -# array_append scalar function #1 -# query ? 
-# select array_append(make_array(), 4);
+# array_append with NULLs
+
+query error
+select array_append(null, 1);
+
+query error
+select array_append(null, [2, 3]);
+
+query error
+select array_append(null, [[4]]);
+
+query ????
+select
+  array_append(make_array(), 4),
+  array_append(make_array(), null),
+  array_append(make_array(1, null, 3), 4),
+  array_append(make_array(null, null), 1)
+;
+----
+[4] [] [1, , 3, 4] [, , 1]
+
+# TODO: add this when #8305 is fixed
+# query ????
+# select
+#  array_append(arrow_cast(make_array(), 'LargeList(Null)'), 4),
+#  array_append(make_array(), null),
+#  array_append(make_array(1, null, 3), 4),
+#  array_append(make_array(null, null), 1)
+# ;
 # ----
-# [4]
+# [4] [] [1, , 3, 4] [, , 1]
+
+# test invalid (non-null)
+query error
+select array_append(1, 2);
+
+query error
+select array_append(1, [2]);
+
+query error
+select array_append([1], [2]);

-# array_append scalar function #2
+query ??
+select
+  array_append(make_array(make_array(1, null, 3)), make_array(null)),
+  array_append(make_array(make_array(1, null, 3)), null);
+----
+[[1, , 3], []] [[1, , 3], ]
+
+# TODO: add this when #8305 is fixed
 # query ??
-# select array_append(make_array(), make_array()), array_append(make_array(), make_array(4));
+# select
+#  array_append(arrow_cast(make_array(make_array(1, null, 3)), 'LargeList(LargeList(Int64))'), arrow_cast(make_array(null), 'LargeList(Int64)')),
+#  array_append(arrow_cast(make_array(make_array(1, null, 3)), 'LargeList(LargeList(Int64))'), null);
 # ----
-# [[]] [[4]]
+# [[1, , 3], []] [[1, , 3], ]

 # array_append scalar function #3
 query ???
@@ -971,30 +1618,56 @@ select array_append(make_array(1, 2, 3), 4), array_append(make_array(1.0, 2.0, 3
 ----
 [1, 2, 3, 4] [1.0, 2.0, 3.0, 4.0] [h, e, l, l, o]

+query ???
+select array_append(arrow_cast(make_array(1, 2, 3), 'LargeList(Int64)'), 4), array_append(arrow_cast(make_array(1.0, 2.0, 3.0), 'LargeList(Float64)'), 4.0), array_append(make_array('h', 'e', 'l', 'l'), 'o');
+----
+[1, 2, 3, 4] [1.0, 2.0, 3.0, 4.0] [h, e, l, l, o]
+
 # array_append scalar function #4 (element is list)
 query ???
 select array_append(make_array([1], [2], [3]), make_array(4)), array_append(make_array([1.0], [2.0], [3.0]), make_array(4.0)), array_append(make_array(['h'], ['e'], ['l'], ['l']), make_array('o'));
 ----
 [[1], [2], [3], [4]] [[1.0], [2.0], [3.0], [4.0]] [[h], [e], [l], [l], [o]]

+# TODO: add this when #8305 is fixed
+# query ???
+# select array_append(arrow_cast(make_array([1], [2], [3]), 'LargeList(LargeList(Int64))'), arrow_cast(make_array(4), 'LargeList(Int64)')), array_append(arrow_cast(make_array([1.0], [2.0], [3.0]), 'LargeList(LargeList(Float64))'), arrow_cast(make_array(4.0), 'LargeList(Float64)')), array_append(arrow_cast(make_array(['h'], ['e'], ['l'], ['l']), 'LargeList(LargeList(Utf8))'), arrow_cast(make_array('o'), 'LargeList(Utf8)'));
+# ----
+# [[1], [2], [3], [4]] [[1.0], [2.0], [3.0], [4.0]] [[h], [e], [l], [l], [o]]
+
 # list_append scalar function #5 (function alias `array_append`)
 query ???
 select list_append(make_array(1, 2, 3), 4), list_append(make_array(1.0, 2.0, 3.0), 4.0), list_append(make_array('h', 'e', 'l', 'l'), 'o');
 ----
 [1, 2, 3, 4] [1.0, 2.0, 3.0, 4.0] [h, e, l, l, o]

+query ???
+select list_append(arrow_cast(make_array(1, 2, 3), 'LargeList(Int64)'), 4), list_append(arrow_cast(make_array(1.0, 2.0, 3.0), 'LargeList(Float64)'), 4.0), list_append(make_array('h', 'e', 'l', 'l'), 'o'); +---- +[1, 2, 3, 4] [1.0, 2.0, 3.0, 4.0] [h, e, l, l, o] + # array_push_back scalar function #6 (function alias `array_append`) query ??? select array_push_back(make_array(1, 2, 3), 4), array_push_back(make_array(1.0, 2.0, 3.0), 4.0), array_push_back(make_array('h', 'e', 'l', 'l'), 'o'); ---- [1, 2, 3, 4] [1.0, 2.0, 3.0, 4.0] [h, e, l, l, o] +query ??? +select array_push_back(arrow_cast(make_array(1, 2, 3), 'LargeList(Int64)'), 4), array_push_back(arrow_cast(make_array(1.0, 2.0, 3.0), 'LargeList(Float64)'), 4.0), array_push_back(make_array('h', 'e', 'l', 'l'), 'o'); +---- +[1, 2, 3, 4] [1.0, 2.0, 3.0, 4.0] [h, e, l, l, o] + # list_push_back scalar function #7 (function alias `array_append`) query ??? select list_push_back(make_array(1, 2, 3), 4), list_push_back(make_array(1.0, 2.0, 3.0), 4.0), list_push_back(make_array('h', 'e', 'l', 'l'), 'o'); ---- [1, 2, 3, 4] [1.0, 2.0, 3.0, 4.0] [h, e, l, l, o] +query ??? +select list_push_back(arrow_cast(make_array(1, 2, 3), 'LargeList(Int64)'), 4), list_push_back(arrow_cast(make_array(1.0, 2.0, 3.0), 'LargeList(Float64)'), 4.0), list_push_back(make_array('h', 'e', 'l', 'l'), 'o'); +---- +[1, 2, 3, 4] [1.0, 2.0, 3.0, 4.0] [h, e, l, l, o] + # array_append with columns #1 query ? select array_append(column1, column2) from arrays_values; @@ -1008,6 +1681,18 @@ select array_append(column1, column2) from arrays_values; [51, 52, , 54, 55, 56, 57, 58, 59, 60, 55] [61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 66] +query ? +select array_append(column1, column2) from large_arrays_values; +---- +[, 2, 3, 4, 5, 6, 7, 8, 9, 10, 1] +[11, 12, 13, 14, 15, 16, 17, 18, , 20, 12] +[21, 22, 23, , 25, 26, 27, 28, 29, 30, 23] +[31, 32, 33, 34, 35, , 37, 38, 39, 40, 34] +[44] +[41, 42, 43, 44, 45, 46, 47, 48, 49, 50, ] +[51, 52, , 54, 55, 56, 57, 58, 59, 60, 55] +[61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 66] + # array_append with columns #2 (element is list) query ? select array_append(column1, column2) from nested_arrays; @@ -1015,6 +1700,13 @@ select array_append(column1, column2) from nested_arrays; [[1, 2, 3], [2, 9, 1], [7, 8, 9], [1, 2, 3], [1, 7, 4], [4, 5, 6], [7, 8, 9]] [[4, 5, 6], [10, 11, 12], [4, 9, 8], [7, 8, 9], [10, 11, 12], [1, 8, 7], [10, 11, 12]] +# TODO: add this when #8305 is fixed +# query ? +# select array_append(column1, column2) from large_nested_arrays; +# ---- +# [[1, 2, 3], [2, 9, 1], [7, 8, 9], [1, 2, 3], [1, 7, 4], [4, 5, 6], [7, 8, 9]] +# [[4, 5, 6], [10, 11, 12], [4, 9, 8], [7, 8, 9], [10, 11, 12], [1, 8, 7], [10, 11, 12]] + # array_append with columns and scalars #1 query ?? select array_append(column2, 100.1), array_append(column3, '.') from arrays; @@ -1027,6 +1719,17 @@ select array_append(column2, 100.1), array_append(column3, '.') from arrays; [100.1] [,, .] [16.6, 17.7, 18.8, 100.1] [.] +query ?? +select array_append(column2, 100.1), array_append(column3, '.') from large_arrays; +---- +[1.1, 2.2, 3.3, 100.1] [L, o, r, e, m, .] +[, 5.5, 6.6, 100.1] [i, p, , u, m, .] +[7.7, 8.8, 9.9, 100.1] [d, , l, o, r, .] +[10.1, , 12.2, 100.1] [s, i, t, .] +[13.3, 14.4, 15.5, 100.1] [a, m, e, t, .] +[100.1] [,, .] +[16.6, 17.7, 18.8, 100.1] [.] + # array_append with columns and scalars #2 query ?? 
select array_append(column1, make_array(1, 11, 111)), array_append(make_array(make_array(1, 2, 3), make_array(11, 12, 13)), column2) from nested_arrays; @@ -1034,20 +1737,67 @@ select array_append(column1, make_array(1, 11, 111)), array_append(make_array(ma [[1, 2, 3], [2, 9, 1], [7, 8, 9], [1, 2, 3], [1, 7, 4], [4, 5, 6], [1, 11, 111]] [[1, 2, 3], [11, 12, 13], [7, 8, 9]] [[4, 5, 6], [10, 11, 12], [4, 9, 8], [7, 8, 9], [10, 11, 12], [1, 8, 7], [1, 11, 111]] [[1, 2, 3], [11, 12, 13], [10, 11, 12]] +# TODO: add this when #8305 is fixed +# query ?? +# select array_append(column1, arrow_cast(make_array(1, 11, 111), 'LargeList(Int64)')), array_append(arrow_cast(make_array(make_array(1, 2, 3), make_array(11, 12, 13)), 'LargeList(LargeList(Int64))'), column2) from large_nested_arrays; +# ---- +# [[1, 2, 3], [2, 9, 1], [7, 8, 9], [1, 2, 3], [1, 7, 4], [4, 5, 6], [1, 11, 111]] [[1, 2, 3], [11, 12, 13], [7, 8, 9]] +# [[4, 5, 6], [10, 11, 12], [4, 9, 8], [7, 8, 9], [10, 11, 12], [1, 8, 7], [1, 11, 111]] [[1, 2, 3], [11, 12, 13], [10, 11, 12]] + ## array_prepend (aliases: `list_prepend`, `array_push_front`, `list_push_front`) -# TODO: array_prepend with NULLs -# array_prepend scalar function #1 -# query ? -# select array_prepend(4, make_array()); -# ---- -# [4] +# array_prepend with NULLs + +# DuckDB: [4] +# ClickHouse: Null +# Since they dont have the same result, we just follow Postgres, return error +query error +select array_prepend(4, NULL); + +query ? +select array_prepend(4, []); +---- +[4] + +query ? +select array_prepend(4, [null]); +---- +[4, ] + +# DuckDB: [null] +# ClickHouse: [null] +query ? +select array_prepend(null, []); +---- +[] + +query ? +select array_prepend(null, [1]); +---- +[, 1] + +query ? +select array_prepend(null, [[1,2,3]]); +---- +[, [1, 2, 3]] + +# DuckDB: [[]] +# ClickHouse: [[]] +# TODO: We may also return [[]] +query error +select array_prepend([], []); + +# DuckDB: [null] +# ClickHouse: [null] +# TODO: We may also return [null] +query error +select array_prepend(null, null); + +query ? +select array_append([], null); +---- +[] -# array_prepend scalar function #2 -# query ?? -# select array_prepend(make_array(), make_array()), array_prepend(make_array(4), make_array()); -# ---- -# [[]] [[4]] # array_prepend scalar function #3 query ??? @@ -1055,30 +1805,56 @@ select array_prepend(1, make_array(2, 3, 4)), array_prepend(1.0, make_array(2.0, ---- [1, 2, 3, 4] [1.0, 2.0, 3.0, 4.0] [h, e, l, l, o] +query ??? +select array_prepend(1, arrow_cast(make_array(2, 3, 4), 'LargeList(Int64)')), array_prepend(1.0, arrow_cast(make_array(2.0, 3.0, 4.0), 'LargeList(Float64)')), array_prepend('h', arrow_cast(make_array('e', 'l', 'l', 'o'), 'LargeList(Utf8)')); +---- +[1, 2, 3, 4] [1.0, 2.0, 3.0, 4.0] [h, e, l, l, o] + # array_prepend scalar function #4 (element is list) query ??? select array_prepend(make_array(1), make_array(make_array(2), make_array(3), make_array(4))), array_prepend(make_array(1.0), make_array([2.0], [3.0], [4.0])), array_prepend(make_array('h'), make_array(['e'], ['l'], ['l'], ['o'])); ---- [[1], [2], [3], [4]] [[1.0], [2.0], [3.0], [4.0]] [[h], [e], [l], [l], [o]] +# TODO: add this when #8305 is fixed +# query ??? 
+# select array_prepend(arrow_cast(make_array(1), 'LargeList(Int64)'), arrow_cast(make_array(make_array(2), make_array(3), make_array(4)), 'LargeList(LargeList(Int64))')), array_prepend(arrow_cast(make_array(1.0), 'LargeList(Float64)'), arrow_cast(make_array([2.0], [3.0], [4.0]), 'LargeList(LargeList(Float64))')), array_prepend(arrow_cast(make_array('h'), 'LargeList(Utf8)'), arrow_cast(make_array(['e'], ['l'], ['l'], ['o']), 'LargeList(LargeList(Utf8))'));
+# ----
+# [[1], [2], [3], [4]] [[1.0], [2.0], [3.0], [4.0]] [[h], [e], [l], [l], [o]]
+
 # list_prepend scalar function #5 (function alias `array_prepend`)
 query ???
 select list_prepend(1, make_array(2, 3, 4)), list_prepend(1.0, make_array(2.0, 3.0, 4.0)), list_prepend('h', make_array('e', 'l', 'l', 'o'));
 ----
 [1, 2, 3, 4] [1.0, 2.0, 3.0, 4.0] [h, e, l, l, o]

+query ???
+select list_prepend(1, arrow_cast(make_array(2, 3, 4), 'LargeList(Int64)')), list_prepend(1.0, arrow_cast(make_array(2.0, 3.0, 4.0), 'LargeList(Float64)')), list_prepend('h', arrow_cast(make_array('e', 'l', 'l', 'o'), 'LargeList(Utf8)'));
+----
+[1, 2, 3, 4] [1.0, 2.0, 3.0, 4.0] [h, e, l, l, o]
+
 # array_push_front scalar function #6 (function alias `array_prepend`)
 query ???
 select array_push_front(1, make_array(2, 3, 4)), array_push_front(1.0, make_array(2.0, 3.0, 4.0)), array_push_front('h', make_array('e', 'l', 'l', 'o'));
 ----
 [1, 2, 3, 4] [1.0, 2.0, 3.0, 4.0] [h, e, l, l, o]

+query ???
+select array_push_front(1, arrow_cast(make_array(2, 3, 4), 'LargeList(Int64)')), array_push_front(1.0, arrow_cast(make_array(2.0, 3.0, 4.0), 'LargeList(Float64)')), array_push_front('h', arrow_cast(make_array('e', 'l', 'l', 'o'), 'LargeList(Utf8)'));
+----
+[1, 2, 3, 4] [1.0, 2.0, 3.0, 4.0] [h, e, l, l, o]
+
 # list_push_front scalar function #7 (function alias `array_prepend`)
 query ???
 select list_push_front(1, make_array(2, 3, 4)), list_push_front(1.0, make_array(2.0, 3.0, 4.0)), list_push_front('h', make_array('e', 'l', 'l', 'o'));
 ----
 [1, 2, 3, 4] [1.0, 2.0, 3.0, 4.0] [h, e, l, l, o]

+query ???
+select list_push_front(1, arrow_cast(make_array(2, 3, 4), 'LargeList(Int64)')), list_push_front(1.0, arrow_cast(make_array(2.0, 3.0, 4.0), 'LargeList(Float64)')), list_push_front('h', arrow_cast(make_array('e', 'l', 'l', 'o'), 'LargeList(Utf8)'));
+----
+[1, 2, 3, 4] [1.0, 2.0, 3.0, 4.0] [h, e, l, l, o]
+
 # array_prepend with columns #1
 query ?
 select array_prepend(column2, column1) from arrays_values;
@@ -1092,6 +1868,18 @@ select array_prepend(column2, column1) from arrays_values;
 [55, 51, 52, , 54, 55, 56, 57, 58, 59, 60]
 [66, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70]

+query ?
+select array_prepend(column2, column1) from large_arrays_values;
+----
+[1, , 2, 3, 4, 5, 6, 7, 8, 9, 10]
+[12, 11, 12, 13, 14, 15, 16, 17, 18, , 20]
+[23, 21, 22, 23, , 25, 26, 27, 28, 29, 30]
+[34, 31, 32, 33, 34, 35, , 37, 38, 39, 40]
+[44]
+[, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50]
+[55, 51, 52, , 54, 55, 56, 57, 58, 59, 60]
+[66, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70]
+
 # array_prepend with columns #2 (element is list)
 query ?
 select array_prepend(column2, column1) from nested_arrays;
@@ -1099,6 +1887,13 @@ select array_prepend(column2, column1) from nested_arrays;
 [[7, 8, 9], [1, 2, 3], [2, 9, 1], [7, 8, 9], [1, 2, 3], [1, 7, 4], [4, 5, 6]]
 [[10, 11, 12], [4, 5, 6], [10, 11, 12], [4, 9, 8], [7, 8, 9], [10, 11, 12], [1, 8, 7]]

+# TODO: add this when #8305 is fixed
+# query ?
+# select array_prepend(column2, column1) from large_nested_arrays; +# ---- +# [[7, 8, 9], [1, 2, 3], [2, 9, 1], [7, 8, 9], [1, 2, 3], [1, 7, 4], [4, 5, 6]] +# [[10, 11, 12], [4, 5, 6], [10, 11, 12], [4, 9, 8], [7, 8, 9], [10, 11, 12], [1, 8, 7]] + # array_prepend with columns and scalars #1 query ?? select array_prepend(100.1, column2), array_prepend('.', column3) from arrays; @@ -1111,6 +1906,17 @@ select array_prepend(100.1, column2), array_prepend('.', column3) from arrays; [100.1] [., ,] [100.1, 16.6, 17.7, 18.8] [.] +query ?? +select array_prepend(100.1, column2), array_prepend('.', column3) from large_arrays; +---- +[100.1, 1.1, 2.2, 3.3] [., L, o, r, e, m] +[100.1, , 5.5, 6.6] [., i, p, , u, m] +[100.1, 7.7, 8.8, 9.9] [., d, , l, o, r] +[100.1, 10.1, , 12.2] [., s, i, t] +[100.1, 13.3, 14.4, 15.5] [., a, m, e, t] +[100.1] [., ,] +[100.1, 16.6, 17.7, 18.8] [.] + # array_prepend with columns and scalars #2 (element is list) query ?? select array_prepend(make_array(1, 11, 111), column1), array_prepend(column2, make_array(make_array(1, 2, 3), make_array(11, 12, 13))) from nested_arrays; @@ -1118,71 +1924,103 @@ select array_prepend(make_array(1, 11, 111), column1), array_prepend(column2, ma [[1, 11, 111], [1, 2, 3], [2, 9, 1], [7, 8, 9], [1, 2, 3], [1, 7, 4], [4, 5, 6]] [[7, 8, 9], [1, 2, 3], [11, 12, 13]] [[1, 11, 111], [4, 5, 6], [10, 11, 12], [4, 9, 8], [7, 8, 9], [10, 11, 12], [1, 8, 7]] [[10, 11, 12], [1, 2, 3], [11, 12, 13]] +# TODO: add this when #8305 is fixed +# query ?? +# select array_prepend(arrow_cast(make_array(1, 11, 111), 'LargeList(Int64)'), column1), array_prepend(column2, arrow_cast(make_array(make_array(1, 2, 3), make_array(11, 12, 13)), 'LargeList(LargeList(Int64))')) from large_nested_arrays; +# ---- +# [[1, 11, 111], [1, 2, 3], [2, 9, 1], [7, 8, 9], [1, 2, 3], [1, 7, 4], [4, 5, 6]] [[7, 8, 9], [1, 2, 3], [11, 12, 13]] +# [[1, 11, 111], [4, 5, 6], [10, 11, 12], [4, 9, 8], [7, 8, 9], [10, 11, 12], [1, 8, 7]] [[10, 11, 12], [1, 2, 3], [11, 12, 13]] + ## array_repeat (aliases: `list_repeat`) # array_repeat scalar function #1 -query ??? -select array_repeat(1, 5), array_repeat(3.14, 3), array_repeat('l', 4); ----- -[1, 1, 1, 1, 1] [3.14, 3.14, 3.14] [l, l, l, l] +query ???????? +select + array_repeat(1, 5), + array_repeat(3.14, 3), + array_repeat('l', 4), + array_repeat(null, 2), + list_repeat(-1, 5), + list_repeat(-3.14, 0), + list_repeat('rust', 4), + list_repeat(null, 0); +---- +[1, 1, 1, 1, 1] [3.14, 3.14, 3.14] [l, l, l, l] [, ] [-1, -1, -1, -1, -1] [] [rust, rust, rust, rust] [] # array_repeat scalar function #2 (element as list) -query ??? -select array_repeat([1], 5), array_repeat([1.1, 2.2, 3.3], 3), array_repeat([[1, 2], [3, 4]], 2); +query ???? +select + array_repeat([1], 5), + array_repeat([1.1, 2.2, 3.3], 3), + array_repeat([null, null], 3), + array_repeat([[1, 2], [3, 4]], 2); ---- -[[1], [1], [1], [1], [1]] [[1.1, 2.2, 3.3], [1.1, 2.2, 3.3], [1.1, 2.2, 3.3]] [[[1, 2], [3, 4]], [[1, 2], [3, 4]]] +[[1], [1], [1], [1], [1]] [[1.1, 2.2, 3.3], [1.1, 2.2, 3.3], [1.1, 2.2, 3.3]] [[, ], [, ], [, ]] [[[1, 2], [3, 4]], [[1, 2], [3, 4]]] -# list_repeat scalar function #3 (function alias: `array_repeat`) -query ??? -select list_repeat(1, 5), list_repeat(3.14, 3), list_repeat('l', 4); +query ???? 
+select + array_repeat(arrow_cast([1], 'LargeList(Int64)'), 5), + array_repeat(arrow_cast([1.1, 2.2, 3.3], 'LargeList(Float64)'), 3), + array_repeat(arrow_cast([null, null], 'LargeList(Null)'), 3), + array_repeat(arrow_cast([[1, 2], [3, 4]], 'LargeList(List(Int64))'), 2); ---- -[1, 1, 1, 1, 1] [3.14, 3.14, 3.14] [l, l, l, l] +[[1], [1], [1], [1], [1]] [[1.1, 2.2, 3.3], [1.1, 2.2, 3.3], [1.1, 2.2, 3.3]] [[, ], [, ], [, ]] [[[1, 2], [3, 4]], [[1, 2], [3, 4]]] # array_repeat with columns #1 -query ? -select array_repeat(column4, column1) from values_without_nulls; ----- -[1.1] -[2.2, 2.2] -[3.3, 3.3, 3.3] -[4.4, 4.4, 4.4, 4.4] -[5.5, 5.5, 5.5, 5.5, 5.5] -[6.6, 6.6, 6.6, 6.6, 6.6, 6.6] -[7.7, 7.7, 7.7, 7.7, 7.7, 7.7, 7.7] -[8.8, 8.8, 8.8, 8.8, 8.8, 8.8, 8.8, 8.8] -[9.9, 9.9, 9.9, 9.9, 9.9, 9.9, 9.9, 9.9, 9.9] -# array_repeat with columns #2 (element as list) -query ? -select array_repeat(column1, column3) from arrays_values_without_nulls; ----- -[[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]] -[[11, 12, 13, 14, 15, 16, 17, 18, 19, 20], [11, 12, 13, 14, 15, 16, 17, 18, 19, 20]] -[[21, 22, 23, 24, 25, 26, 27, 28, 29, 30], [21, 22, 23, 24, 25, 26, 27, 28, 29, 30], [21, 22, 23, 24, 25, 26, 27, 28, 29, 30]] -[[31, 32, 33, 34, 35, 26, 37, 38, 39, 40], [31, 32, 33, 34, 35, 26, 37, 38, 39, 40], [31, 32, 33, 34, 35, 26, 37, 38, 39, 40], [31, 32, 33, 34, 35, 26, 37, 38, 39, 40]] +statement ok +CREATE TABLE array_repeat_table +AS VALUES + (1, 1, 1.1, 'a', make_array(4, 5, 6)), + (2, null, null, null, null), + (3, 2, 2.2, 'rust', make_array(7)), + (0, 3, 3.3, 'datafusion', make_array(8, 9)); -# array_repeat with columns and scalars #1 -query ?? -select array_repeat(1, column1), array_repeat(column4, 3) from values_without_nulls; ----- -[1] [1.1, 1.1, 1.1] -[1, 1] [2.2, 2.2, 2.2] -[1, 1, 1] [3.3, 3.3, 3.3] -[1, 1, 1, 1] [4.4, 4.4, 4.4] -[1, 1, 1, 1, 1] [5.5, 5.5, 5.5] -[1, 1, 1, 1, 1, 1] [6.6, 6.6, 6.6] -[1, 1, 1, 1, 1, 1, 1] [7.7, 7.7, 7.7] -[1, 1, 1, 1, 1, 1, 1, 1] [8.8, 8.8, 8.8] -[1, 1, 1, 1, 1, 1, 1, 1, 1] [9.9, 9.9, 9.9] +statement ok +CREATE TABLE large_array_repeat_table +AS SELECT + column1, + column2, + column3, + column4, + arrow_cast(column5, 'LargeList(Int64)') as column5 +FROM array_repeat_table; + +query ?????? +select + array_repeat(column2, column1), + array_repeat(column3, column1), + array_repeat(column4, column1), + array_repeat(column5, column1), + array_repeat(column2, 3), + array_repeat(make_array(1), column1) +from array_repeat_table; +---- +[1] [1.1] [a] [[4, 5, 6]] [1, 1, 1] [[1]] +[, ] [, ] [, ] [, ] [, , ] [[1], [1]] +[2, 2, 2] [2.2, 2.2, 2.2] [rust, rust, rust] [[7], [7], [7]] [2, 2, 2] [[1], [1], [1]] +[] [] [] [] [3, 3, 3] [] + +query ?????? +select + array_repeat(column2, column1), + array_repeat(column3, column1), + array_repeat(column4, column1), + array_repeat(column5, column1), + array_repeat(column2, 3), + array_repeat(make_array(1), column1) +from large_array_repeat_table; +---- +[1] [1.1] [a] [[4, 5, 6]] [1, 1, 1] [[1]] +[, ] [, ] [, ] [, ] [, , ] [[1], [1]] +[2, 2, 2] [2.2, 2.2, 2.2] [rust, rust, rust] [[7], [7], [7]] [2, 2, 2] [[1], [1], [1]] +[] [] [] [] [3, 3, 3] [] -# array_repeat with columns and scalars #2 (element as list) -query ?? 
-select array_repeat([1], column3), array_repeat(column1, 3) from arrays_values_without_nulls; ----- -[[1]] [[1, 2, 3, 4, 5, 6, 7, 8, 9, 10], [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]] -[[1], [1]] [[11, 12, 13, 14, 15, 16, 17, 18, 19, 20], [11, 12, 13, 14, 15, 16, 17, 18, 19, 20], [11, 12, 13, 14, 15, 16, 17, 18, 19, 20]] -[[1], [1], [1]] [[21, 22, 23, 24, 25, 26, 27, 28, 29, 30], [21, 22, 23, 24, 25, 26, 27, 28, 29, 30], [21, 22, 23, 24, 25, 26, 27, 28, 29, 30]] -[[1], [1], [1], [1]] [[31, 32, 33, 34, 35, 26, 37, 38, 39, 40], [31, 32, 33, 34, 35, 26, 37, 38, 39, 40], [31, 32, 33, 34, 35, 26, 37, 38, 39, 40]] +statement ok +drop table array_repeat_table; + +statement ok +drop table large_array_repeat_table; ## array_concat (aliases: `array_cat`, `list_concat`, `list_cat`) @@ -1444,15 +2282,25 @@ select array_position(['h', 'e', 'l', 'l', 'o'], 'l'), array_position([1, 2, 3, ---- 3 5 1 +query III +select array_position(arrow_cast(['h', 'e', 'l', 'l', 'o'], 'LargeList(Utf8)'), 'l'), array_position(arrow_cast([1, 2, 3, 4, 5], 'LargeList(Int64)'), 5), array_position(arrow_cast([1, 1, 1], 'LargeList(Int64)'), 1); +---- +3 5 1 + # array_position scalar function #2 (with optional argument) query III select array_position(['h', 'e', 'l', 'l', 'o'], 'l', 4), array_position([1, 2, 5, 4, 5], 5, 4), array_position([1, 1, 1], 1, 2); ---- 4 5 2 -# array_position scalar function #3 (element is list) -query II -select array_position(make_array([1, 2, 3], [4, 5, 6], [5, 5, 5], [4, 5, 6], [7, 8, 9]), [4, 5, 6]), array_position(make_array([1, 3, 2], [2, 3, 4], [2, 3, 4], [5, 3, 1], [1, 3, 2]), [2, 3, 4]); +query III +select array_position(arrow_cast(['h', 'e', 'l', 'l', 'o'], 'LargeList(Utf8)'), 'l', 4), array_position(arrow_cast([1, 2, 3, 4, 5], 'LargeList(Int64)'), 5, 4), array_position(arrow_cast([1, 1, 1], 'LargeList(Int64)'), 1, 2); +---- +4 5 2 + +# array_position scalar function #3 (element is list) +query II +select array_position(make_array([1, 2, 3], [4, 5, 6], [5, 5, 5], [4, 5, 6], [7, 8, 9]), [4, 5, 6]), array_position(make_array([1, 3, 2], [2, 3, 4], [2, 3, 4], [5, 3, 1], [1, 3, 2]), [2, 3, 4]); ---- 2 2 @@ -1462,24 +2310,44 @@ select array_position(make_array([1, 2, 3], [4, 5, 6], [5, 5, 5], [4, 5, 6], [7, ---- 4 3 +query II +select array_position(arrow_cast(make_array([1, 2, 3], [4, 5, 6], [5, 5, 5], [4, 5, 6], [7, 8, 9]), 'LargeList(List(Int64))'), [4, 5, 6]), array_position(arrow_cast(make_array([1, 3, 2], [2, 3, 4], [2, 3, 4], [5, 3, 1], [1, 3, 2]), 'LargeList(List(Int64))'), [2, 3, 4]); +---- +2 2 + # list_position scalar function #5 (function alias `array_position`) query III select list_position(['h', 'e', 'l', 'l', 'o'], 'l'), list_position([1, 2, 3, 4, 5], 5), list_position([1, 1, 1], 1); ---- 3 5 1 +query III +select list_position(arrow_cast(['h', 'e', 'l', 'l', 'o'], 'LargeList(Utf8)'), 'l'), list_position(arrow_cast([1, 2, 3, 4, 5], 'LargeList(Int64)'), 5), list_position(arrow_cast([1, 1, 1], 'LargeList(Int64)'), 1); +---- +3 5 1 + # array_indexof scalar function #6 (function alias `array_position`) query III select array_indexof(['h', 'e', 'l', 'l', 'o'], 'l'), array_indexof([1, 2, 3, 4, 5], 5), array_indexof([1, 1, 1], 1); ---- 3 5 1 +query III +select array_indexof(arrow_cast(['h', 'e', 'l', 'l', 'o'], 'LargeList(Utf8)'), 'l'), array_indexof(arrow_cast([1, 2, 3, 4, 5], 'LargeList(Int64)'), 5), array_indexof(arrow_cast([1, 1, 1], 'LargeList(Int64)'), 1); +---- +3 5 1 + # list_indexof scalar function #7 (function alias `array_position`) query III select 
list_indexof(['h', 'e', 'l', 'l', 'o'], 'l'), list_indexof([1, 2, 3, 4, 5], 5), list_indexof([1, 1, 1], 1);
 ----
 3 5 1

+query III
+select list_indexof(arrow_cast(['h', 'e', 'l', 'l', 'o'], 'LargeList(Utf8)'), 'l'), list_indexof(arrow_cast([1, 2, 3, 4, 5], 'LargeList(Int64)'), 5), list_indexof(arrow_cast([1, 1, 1], 'LargeList(Int64)'), 1);
+----
+3 5 1
+
 # array_position with columns #1
 query II
 select array_position(column1, column2), array_position(column1, column2, column3) from arrays_values_without_nulls;
@@ -1489,6 +2357,14 @@ select array_position(column1, column2), array_position(column1, column2, column
 3 3
 4 4

+query II
+select array_position(column1, column2), array_position(column1, column2, column3) from large_arrays_values_without_nulls;
+----
+1 1
+2 2
+3 3
+4 4
+
 # array_position with columns #2 (element is list)
 query II
 select array_position(column1, column2), array_position(column1, column2, column3) from nested_arrays;
@@ -1496,6 +2372,13 @@ select array_position(column1, column2), array_position(column1, column2, column
 3 3
 2 5

+#TODO: add this test when #8305 is fixed
+#query II
+#select array_position(column1, column2), array_position(column1, column2, column3) from large_nested_arrays;
+#----
+#3 3
+#2 5
+
 # array_position with columns and scalars #1
 query III
 select array_position(make_array(1, 2, 3, 4, 5), column2), array_position(column1, 3), array_position(column1, 3, 5) from arrays_values_without_nulls;
@@ -1505,6 +2388,14 @@ NULL NULL NULL
 NULL NULL NULL
 NULL NULL NULL

+query III
+select array_position(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), column2), array_position(column1, 3), array_position(column1, 3, 5) from large_arrays_values_without_nulls;
+----
+1 3 NULL
+NULL NULL NULL
+NULL NULL NULL
+NULL NULL NULL
+
 # array_position with columns and scalars #2 (element is list)
 query III
 select array_position(make_array([1, 2, 3], [4, 5, 6], [11, 12, 13]), column2), array_position(column1, make_array(4, 5, 6)), array_position(column1, make_array(1, 2, 3), 2) from nested_arrays;
@@ -1512,6 +2403,13 @@ select array_position(make_array([1, 2, 3], [4, 5, 6], [11, 12, 13]), column2), 
 NULL 6 4
 NULL 1 NULL

+#TODO: add this test when #8305 is fixed
+#query III
+#select array_position(arrow_cast(make_array([1, 2, 3], [4, 5, 6], [11, 12, 13]), 'LargeList(List(Int64))'), column2), array_position(column1, make_array(4, 5, 6)), array_position(column1, make_array(1, 2, 3), 2) from large_nested_arrays;
+#----
+#NULL 6 4
+#NULL 1 NULL
+
 ## array_positions (aliases: `list_positions`)

 # array_positions scalar function #1
@@ -1520,18 +2418,33 @@ select array_positions(['h', 'e', 'l', 'l', 'o'], 'l'), array_positions([1, 2, 3
 ----
 [3, 4] [5] [1, 2, 3]

+query ???
+select array_positions(arrow_cast(['h', 'e', 'l', 'l', 'o'], 'LargeList(Utf8)'), 'l'), array_positions(arrow_cast([1, 2, 3, 4, 5], 'LargeList(Int64)'), 5), array_positions(arrow_cast([1, 1, 1], 'LargeList(Int64)'), 1);
+----
+[3, 4] [5] [1, 2, 3]
+
 # array_positions scalar function #2 (element is list)
 query ?
 select array_positions(make_array([1, 2, 3], [2, 1, 3], [1, 5, 6], [2, 1, 3], [4, 5, 6]), [2, 1, 3]);
 ----
 [2, 4]

+query ?
+select array_positions(arrow_cast(make_array([1, 2, 3], [2, 1, 3], [1, 5, 6], [2, 1, 3], [4, 5, 6]), 'LargeList(List(Int64))'), [2, 1, 3]);
+----
+[2, 4]
+
 # list_positions scalar function #3 (function alias `array_positions`)
 query ???
select list_positions(['h', 'e', 'l', 'l', 'o'], 'l'), list_positions([1, 2, 3, 4, 5], 5), list_positions([1, 1, 1], 1); ---- [3, 4] [5] [1, 2, 3] +query ??? +select list_positions(arrow_cast(['h', 'e', 'l', 'l', 'o'], 'LargeList(Utf8)'), 'l'), list_positions(arrow_cast([1, 2, 3, 4, 5], 'LargeList(Int64)'), 5), list_positions(arrow_cast([1, 1, 1], 'LargeList(Int64)'), 1); +---- +[3, 4] [5] [1, 2, 3] + # array_positions with columns #1 query ? select array_positions(column1, column2) from arrays_values_without_nulls; @@ -1541,6 +2454,14 @@ select array_positions(column1, column2) from arrays_values_without_nulls; [3] [4] +query ? +select array_positions(arrow_cast(column1, 'LargeList(Int64)'), column2) from arrays_values_without_nulls; +---- +[1] +[2] +[3] +[4] + # array_positions with columns #2 (element is list) query ? select array_positions(column1, column2) from nested_arrays; @@ -1548,6 +2469,12 @@ select array_positions(column1, column2) from nested_arrays; [3] [2, 5] +query ? +select array_positions(arrow_cast(column1, 'LargeList(List(Int64))'), column2) from nested_arrays; +---- +[3] +[2, 5] + # array_positions with columns and scalars #1 query ?? select array_positions(column1, 4), array_positions(array[1, 2, 23, 13, 33, 45], column2) from arrays_values_without_nulls; @@ -1557,6 +2484,14 @@ select array_positions(column1, 4), array_positions(array[1, 2, 23, 13, 33, 45], [] [3] [] [] +query ?? +select array_positions(arrow_cast(column1, 'LargeList(Int64)'), 4), array_positions(array[1, 2, 23, 13, 33, 45], column2) from arrays_values_without_nulls; +---- +[4] [1] +[] [] +[] [3] +[] [] + # array_positions with columns and scalars #2 (element is list) query ?? select array_positions(column1, make_array(4, 5, 6)), array_positions(make_array([1, 2, 3], [11, 12, 13], [4, 5, 6]), column2) from nested_arrays; @@ -1564,23 +2499,76 @@ select array_positions(column1, make_array(4, 5, 6)), array_positions(make_array [6] [] [1] [] +query ?? +select array_positions(arrow_cast(column1, 'LargeList(List(Int64))'), make_array(4, 5, 6)), array_positions(arrow_cast(make_array([1, 2, 3], [11, 12, 13], [4, 5, 6]), 'LargeList(List(Int64))'), column2) from nested_arrays; +---- +[6] [] +[1] [] + ## array_replace (aliases: `list_replace`) # array_replace scalar function #1 query ??? -select array_replace(make_array(1, 2, 3, 4), 2, 3), array_replace(make_array(1, 4, 4, 5, 4, 6, 7), 4, 0), array_replace(make_array(1, 2, 3), 4, 0); +select + array_replace(make_array(1, 2, 3, 4), 2, 3), + array_replace(make_array(1, 4, 4, 5, 4, 6, 7), 4, 0), + array_replace(make_array(1, 2, 3), 4, 0); +---- +[1, 3, 3, 4] [1, 0, 4, 5, 4, 6, 7] [1, 2, 3] + +query ??? +select + array_replace(arrow_cast(make_array(1, 2, 3, 4), 'LargeList(Int64)'), 2, 3), + array_replace(arrow_cast(make_array(1, 4, 4, 5, 4, 6, 7), 'LargeList(Int64)'), 4, 0), + array_replace(arrow_cast(make_array(1, 2, 3), 'LargeList(Int64)'), 4, 0); ---- [1, 3, 3, 4] [1, 0, 4, 5, 4, 6, 7] [1, 2, 3] # array_replace scalar function #2 (element is list) query ?? 
-select array_replace(make_array([1, 2, 3], [4, 5, 6], [5, 5, 5], [4, 5, 6], [7, 8, 9]), [4, 5, 6], [1, 1, 1]), array_replace(make_array([1, 3, 2], [2, 3, 4], [2, 3, 4], [5, 3, 1], [1, 3, 2]), [2, 3, 4], [3, 1, 4]); +select + array_replace( + make_array([1, 2, 3], [4, 5, 6], [5, 5, 5], [4, 5, 6], [7, 8, 9]), + [4, 5, 6], + [1, 1, 1] + ), + array_replace( + make_array([1, 3, 2], [2, 3, 4], [2, 3, 4], [5, 3, 1], [1, 3, 2]), + [2, 3, 4], + [3, 1, 4] + ); +---- +[[1, 2, 3], [1, 1, 1], [5, 5, 5], [4, 5, 6], [7, 8, 9]] [[1, 3, 2], [3, 1, 4], [2, 3, 4], [5, 3, 1], [1, 3, 2]] + +query ?? +select + array_replace( + arrow_cast(make_array([1, 2, 3], [4, 5, 6], [5, 5, 5], [4, 5, 6], [7, 8, 9]), 'LargeList(List(Int64))'), + [4, 5, 6], + [1, 1, 1] + ), + array_replace( + arrow_cast(make_array([1, 3, 2], [2, 3, 4], [2, 3, 4], [5, 3, 1], [1, 3, 2]), 'LargeList(List(Int64))'), + [2, 3, 4], + [3, 1, 4] + ); ---- [[1, 2, 3], [1, 1, 1], [5, 5, 5], [4, 5, 6], [7, 8, 9]] [[1, 3, 2], [3, 1, 4], [2, 3, 4], [5, 3, 1], [1, 3, 2]] # list_replace scalar function #3 (function alias `list_replace`) query ??? -select list_replace(make_array(1, 2, 3, 4), 2, 3), list_replace(make_array(1, 4, 4, 5, 4, 6, 7), 4, 0), list_replace(make_array(1, 2, 3), 4, 0); +select list_replace( + make_array(1, 2, 3, 4), 2, 3), + list_replace(make_array(1, 4, 4, 5, 4, 6, 7), 4, 0), + list_replace(make_array(1, 2, 3), 4, 0); +---- +[1, 3, 3, 4] [1, 0, 4, 5, 4, 6, 7] [1, 2, 3] + +query ??? +select list_replace( + arrow_cast(make_array(1, 2, 3, 4), 'LargeList(Int64)'), 2, 3), + list_replace(arrow_cast(make_array(1, 4, 4, 5, 4, 6, 7), 'LargeList(Int64)'), 4, 0), + list_replace(arrow_cast(make_array(1, 2, 3), 'LargeList(Int64)'), 4, 0); ---- [1, 3, 3, 4] [1, 0, 4, 5, 4, 6, 7] [1, 2, 3] @@ -1593,6 +2581,14 @@ select array_replace(column1, column2, column3) from arrays_with_repeating_eleme [10, 7, 7, 8, 7, 9, 7, 8, 7, 7] [13, 11, 12, 10, 11, 12, 10, 11, 12, 10] +query ? +select array_replace(column1, column2, column3) from large_arrays_with_repeating_elements; +---- +[1, 4, 1, 3, 2, 2, 1, 3, 2, 3] +[7, 4, 5, 5, 6, 5, 5, 5, 4, 4] +[10, 7, 7, 8, 7, 9, 7, 8, 7, 7] +[13, 11, 12, 10, 11, 12, 10, 11, 12, 10] + # array_replace scalar function with columns #2 (element is list) query ? select array_replace(column1, column2, column3) from nested_arrays_with_repeating_elements; @@ -1602,9 +2598,33 @@ select array_replace(column1, column2, column3) from nested_arrays_with_repeatin [[28, 29, 30], [19, 20, 21], [19, 20, 21], [22, 23, 24], [19, 20, 21], [25, 26, 27], [19, 20, 21], [22, 23, 24], [19, 20, 21], [19, 20, 21]] [[37, 38, 39], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30]] +query ? +select array_replace(column1, column2, column3) from large_nested_arrays_with_repeating_elements; +---- +[[1, 2, 3], [10, 11, 12], [1, 2, 3], [7, 8, 9], [4, 5, 6], [4, 5, 6], [1, 2, 3], [7, 8, 9], [4, 5, 6], [7, 8, 9]] +[[19, 20, 21], [10, 11, 12], [13, 14, 15], [13, 14, 15], [16, 17, 18], [13, 14, 15], [13, 14, 15], [13, 14, 15], [10, 11, 12], [10, 11, 12]] +[[28, 29, 30], [19, 20, 21], [19, 20, 21], [22, 23, 24], [19, 20, 21], [25, 26, 27], [19, 20, 21], [22, 23, 24], [19, 20, 21], [19, 20, 21]] +[[37, 38, 39], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30]] + # array_replace scalar function with columns and scalars #1 query ??? 
-select array_replace(make_array(1, 2, 2, 4, 5, 4, 4, 7, 7, 10, 7, 8), column2, column3), array_replace(column1, 1, column3), array_replace(column1, column2, 4) from arrays_with_repeating_elements; +select + array_replace(make_array(1, 2, 2, 4, 5, 4, 4, 7, 7, 10, 7, 8), column2, column3), + array_replace(column1, 1, column3), + array_replace(column1, column2, 4) +from arrays_with_repeating_elements; +---- +[1, 4, 2, 4, 5, 4, 4, 7, 7, 10, 7, 8] [4, 2, 1, 3, 2, 2, 1, 3, 2, 3] [1, 4, 1, 3, 2, 2, 1, 3, 2, 3] +[1, 2, 2, 7, 5, 4, 4, 7, 7, 10, 7, 8] [4, 4, 5, 5, 6, 5, 5, 5, 4, 4] [4, 4, 5, 5, 6, 5, 5, 5, 4, 4] +[1, 2, 2, 4, 5, 4, 4, 10, 7, 10, 7, 8] [7, 7, 7, 8, 7, 9, 7, 8, 7, 7] [4, 7, 7, 8, 7, 9, 7, 8, 7, 7] +[1, 2, 2, 4, 5, 4, 4, 7, 7, 13, 7, 8] [10, 11, 12, 10, 11, 12, 10, 11, 12, 10] [4, 11, 12, 10, 11, 12, 10, 11, 12, 10] + +query ??? +select + array_replace(arrow_cast(make_array(1, 2, 2, 4, 5, 4, 4, 7, 7, 10, 7, 8), 'LargeList(Int64)'), column2, column3), + array_replace(column1, 1, column3), + array_replace(column1, column2, 4) +from large_arrays_with_repeating_elements; ---- [1, 4, 2, 4, 5, 4, 4, 7, 7, 10, 7, 8] [4, 2, 1, 3, 2, 2, 1, 3, 2, 3] [1, 4, 1, 3, 2, 2, 1, 3, 2, 3] [1, 2, 2, 7, 5, 4, 4, 7, 7, 10, 7, 8] [4, 4, 5, 5, 6, 5, 5, 5, 4, 4] [4, 4, 5, 5, 6, 5, 5, 5, 4, 4] @@ -1613,7 +2633,33 @@ select array_replace(make_array(1, 2, 2, 4, 5, 4, 4, 7, 7, 10, 7, 8), column2, c # array_replace scalar function with columns and scalars #2 (element is list) query ??? -select array_replace(make_array([1, 2, 3], [4, 5, 6], [4, 5, 6], [10, 11, 12], [13, 14, 15], [10, 11, 12], [10, 11, 12], [19, 20, 21], [19, 20, 21], [28, 29, 30], [19, 20, 21], [22, 23, 24]), column2, column3), array_replace(column1, make_array(1, 2, 3), column3), array_replace(column1, column2, make_array(11, 12, 13)) from nested_arrays_with_repeating_elements; +select + array_replace( + make_array( + [1, 2, 3], [4, 5, 6], [4, 5, 6], [10, 11, 12], [13, 14, 15], [10, 11, 12], [10, 11, 12], [19, 20, 21], [19, 20, 21], [28, 29, 30], [19, 20, 21], [22, 23, 24]), + column2, + column3 + ), + array_replace(column1, make_array(1, 2, 3), column3), + array_replace(column1, column2, make_array(11, 12, 13)) +from nested_arrays_with_repeating_elements; +---- +[[1, 2, 3], [10, 11, 12], [4, 5, 6], [10, 11, 12], [13, 14, 15], [10, 11, 12], [10, 11, 12], [19, 20, 21], [19, 20, 21], [28, 29, 30], [19, 20, 21], [22, 23, 24]] [[10, 11, 12], [4, 5, 6], [1, 2, 3], [7, 8, 9], [4, 5, 6], [4, 5, 6], [1, 2, 3], [7, 8, 9], [4, 5, 6], [7, 8, 9]] [[1, 2, 3], [11, 12, 13], [1, 2, 3], [7, 8, 9], [4, 5, 6], [4, 5, 6], [1, 2, 3], [7, 8, 9], [4, 5, 6], [7, 8, 9]] +[[1, 2, 3], [4, 5, 6], [4, 5, 6], [19, 20, 21], [13, 14, 15], [10, 11, 12], [10, 11, 12], [19, 20, 21], [19, 20, 21], [28, 29, 30], [19, 20, 21], [22, 23, 24]] [[10, 11, 12], [10, 11, 12], [13, 14, 15], [13, 14, 15], [16, 17, 18], [13, 14, 15], [13, 14, 15], [13, 14, 15], [10, 11, 12], [10, 11, 12]] [[11, 12, 13], [10, 11, 12], [13, 14, 15], [13, 14, 15], [16, 17, 18], [13, 14, 15], [13, 14, 15], [13, 14, 15], [10, 11, 12], [10, 11, 12]] +[[1, 2, 3], [4, 5, 6], [4, 5, 6], [10, 11, 12], [13, 14, 15], [10, 11, 12], [10, 11, 12], [28, 29, 30], [19, 20, 21], [28, 29, 30], [19, 20, 21], [22, 23, 24]] [[19, 20, 21], [19, 20, 21], [19, 20, 21], [22, 23, 24], [19, 20, 21], [25, 26, 27], [19, 20, 21], [22, 23, 24], [19, 20, 21], [19, 20, 21]] [[11, 12, 13], [19, 20, 21], [19, 20, 21], [22, 23, 24], [19, 20, 21], [25, 26, 27], [19, 20, 21], [22, 23, 24], [19, 20, 21], [19, 20, 21]] +[[1, 2, 3], [4, 5, 6], [4, 5, 
6], [10, 11, 12], [13, 14, 15], [10, 11, 12], [10, 11, 12], [19, 20, 21], [19, 20, 21], [37, 38, 39], [19, 20, 21], [22, 23, 24]] [[28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30]] [[11, 12, 13], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30]] + +query ??? +select + array_replace( + arrow_cast(make_array( + [1, 2, 3], [4, 5, 6], [4, 5, 6], [10, 11, 12], [13, 14, 15], [10, 11, 12], [10, 11, 12], [19, 20, 21], [19, 20, 21], [28, 29, 30], [19, 20, 21], [22, 23, 24]),'LargeList(List(Int64))'), + column2, + column3 + ), + array_replace(column1, make_array(1, 2, 3), column3), + array_replace(column1, column2, make_array(11, 12, 13)) +from large_nested_arrays_with_repeating_elements; ---- [[1, 2, 3], [10, 11, 12], [4, 5, 6], [10, 11, 12], [13, 14, 15], [10, 11, 12], [10, 11, 12], [19, 20, 21], [19, 20, 21], [28, 29, 30], [19, 20, 21], [22, 23, 24]] [[10, 11, 12], [4, 5, 6], [1, 2, 3], [7, 8, 9], [4, 5, 6], [4, 5, 6], [1, 2, 3], [7, 8, 9], [4, 5, 6], [7, 8, 9]] [[1, 2, 3], [11, 12, 13], [1, 2, 3], [7, 8, 9], [4, 5, 6], [4, 5, 6], [1, 2, 3], [7, 8, 9], [4, 5, 6], [7, 8, 9]] [[1, 2, 3], [4, 5, 6], [4, 5, 6], [19, 20, 21], [13, 14, 15], [10, 11, 12], [10, 11, 12], [19, 20, 21], [19, 20, 21], [28, 29, 30], [19, 20, 21], [22, 23, 24]] [[10, 11, 12], [10, 11, 12], [13, 14, 15], [13, 14, 15], [16, 17, 18], [13, 14, 15], [13, 14, 15], [13, 14, 15], [10, 11, 12], [10, 11, 12]] [[11, 12, 13], [10, 11, 12], [13, 14, 15], [13, 14, 15], [16, 17, 18], [13, 14, 15], [13, 14, 15], [13, 14, 15], [10, 11, 12], [10, 11, 12]] @@ -1624,25 +2670,88 @@ select array_replace(make_array([1, 2, 3], [4, 5, 6], [4, 5, 6], [10, 11, 12], [ # array_replace_n scalar function #1 query ??? -select array_replace_n(make_array(1, 2, 3, 4), 2, 3, 2), array_replace_n(make_array(1, 4, 4, 5, 4, 6, 7), 4, 0, 2), array_replace_n(make_array(1, 2, 3), 4, 0, 3); +select + array_replace_n(make_array(1, 2, 3, 4), 2, 3, 2), + array_replace_n(make_array(1, 4, 4, 5, 4, 6, 7), 4, 0, 2), + array_replace_n(make_array(1, 2, 3), 4, 0, 3); +---- +[1, 3, 3, 4] [1, 0, 0, 5, 4, 6, 7] [1, 2, 3] + +query ??? +select + array_replace_n(arrow_cast(make_array(1, 2, 3, 4), 'LargeList(Int64)'), 2, 3, 2), + array_replace_n(arrow_cast(make_array(1, 4, 4, 5, 4, 6, 7), 'LargeList(Int64)'), 4, 0, 2), + array_replace_n(arrow_cast(make_array(1, 2, 3), 'LargeList(Int64)'), 4, 0, 3); ---- [1, 3, 3, 4] [1, 0, 0, 5, 4, 6, 7] [1, 2, 3] # array_replace_n scalar function #2 (element is list) query ?? -select array_replace_n(make_array([1, 2, 3], [4, 5, 6], [5, 5, 5], [4, 5, 6], [7, 8, 9]), [4, 5, 6], [1, 1, 1], 2), array_replace_n(make_array([1, 3, 2], [2, 3, 4], [2, 3, 4], [5, 3, 1], [1, 3, 2]), [2, 3, 4], [3, 1, 4], 2); +select + array_replace_n( + make_array([1, 2, 3], [4, 5, 6], [5, 5, 5], [4, 5, 6], [7, 8, 9]), + [4, 5, 6], + [1, 1, 1], + 2 + ), + array_replace_n( + make_array([1, 3, 2], [2, 3, 4], [2, 3, 4], [5, 3, 1], [1, 3, 2]), + [2, 3, 4], + [3, 1, 4], + 2 + ); +---- +[[1, 2, 3], [1, 1, 1], [5, 5, 5], [1, 1, 1], [7, 8, 9]] [[1, 3, 2], [3, 1, 4], [3, 1, 4], [5, 3, 1], [1, 3, 2]] + +query ?? 
+select + array_replace_n( + arrow_cast(make_array([1, 2, 3], [4, 5, 6], [5, 5, 5], [4, 5, 6], [7, 8, 9]), 'LargeList(List(Int64))'), + [4, 5, 6], + [1, 1, 1], + 2 + ), + array_replace_n( + arrow_cast(make_array([1, 3, 2], [2, 3, 4], [2, 3, 4], [5, 3, 1], [1, 3, 2]), 'LargeList(List(Int64))'), + [2, 3, 4], + [3, 1, 4], + 2 + ); ---- [[1, 2, 3], [1, 1, 1], [5, 5, 5], [1, 1, 1], [7, 8, 9]] [[1, 3, 2], [3, 1, 4], [3, 1, 4], [5, 3, 1], [1, 3, 2]] # list_replace_n scalar function #3 (function alias `array_replace_n`) query ??? -select list_replace_n(make_array(1, 2, 3, 4), 2, 3, 2), list_replace_n(make_array(1, 4, 4, 5, 4, 6, 7), 4, 0, 2), list_replace_n(make_array(1, 2, 3), 4, 0, 3); +select + list_replace_n(make_array(1, 2, 3, 4), 2, 3, 2), + list_replace_n(make_array(1, 4, 4, 5, 4, 6, 7), 4, 0, 2), + list_replace_n(make_array(1, 2, 3), 4, 0, 3); +---- +[1, 3, 3, 4] [1, 0, 0, 5, 4, 6, 7] [1, 2, 3] + +query ??? +select + list_replace_n(arrow_cast(make_array(1, 2, 3, 4), 'LargeList(Int64)'), 2, 3, 2), + list_replace_n(arrow_cast(make_array(1, 4, 4, 5, 4, 6, 7), 'LargeList(Int64)'), 4, 0, 2), + list_replace_n(arrow_cast(make_array(1, 2, 3), 'LargeList(Int64)'), 4, 0, 3); ---- [1, 3, 3, 4] [1, 0, 0, 5, 4, 6, 7] [1, 2, 3] # array_replace_n scalar function with columns #1 query ? -select array_replace_n(column1, column2, column3, column4) from arrays_with_repeating_elements; +select + array_replace_n(column1, column2, column3, column4) +from arrays_with_repeating_elements; +---- +[1, 4, 1, 3, 4, 4, 1, 3, 2, 3] +[7, 7, 5, 5, 6, 5, 5, 5, 4, 4] +[10, 10, 10, 8, 10, 9, 10, 8, 7, 7] +[13, 11, 12, 13, 11, 12, 13, 11, 12, 13] + +query ? +select + array_replace_n(column1, column2, column3, column4) +from large_arrays_with_repeating_elements; ---- [1, 4, 1, 3, 4, 4, 1, 3, 2, 3] [7, 7, 5, 5, 6, 5, 5, 5, 4, 4] @@ -1651,16 +2760,47 @@ select array_replace_n(column1, column2, column3, column4) from arrays_with_repe # array_replace_n scalar function with columns #2 (element is list) query ? -select array_replace_n(column1, column2, column3, column4) from nested_arrays_with_repeating_elements; +select + array_replace_n(column1, column2, column3, column4) +from nested_arrays_with_repeating_elements; +---- +[[1, 2, 3], [10, 11, 12], [1, 2, 3], [7, 8, 9], [10, 11, 12], [10, 11, 12], [1, 2, 3], [7, 8, 9], [4, 5, 6], [7, 8, 9]] +[[19, 20, 21], [19, 20, 21], [13, 14, 15], [13, 14, 15], [16, 17, 18], [13, 14, 15], [13, 14, 15], [13, 14, 15], [10, 11, 12], [10, 11, 12]] +[[28, 29, 30], [28, 29, 30], [28, 29, 30], [22, 23, 24], [28, 29, 30], [25, 26, 27], [28, 29, 30], [22, 23, 24], [19, 20, 21], [19, 20, 21]] +[[37, 38, 39], [31, 32, 33], [34, 35, 36], [37, 38, 39], [31, 32, 33], [34, 35, 36], [37, 38, 39], [31, 32, 33], [34, 35, 36], [37, 38, 39]] + +query ? +select + array_replace_n(column1, column2, column3, column4) +from large_nested_arrays_with_repeating_elements; ---- [[1, 2, 3], [10, 11, 12], [1, 2, 3], [7, 8, 9], [10, 11, 12], [10, 11, 12], [1, 2, 3], [7, 8, 9], [4, 5, 6], [7, 8, 9]] [[19, 20, 21], [19, 20, 21], [13, 14, 15], [13, 14, 15], [16, 17, 18], [13, 14, 15], [13, 14, 15], [13, 14, 15], [10, 11, 12], [10, 11, 12]] [[28, 29, 30], [28, 29, 30], [28, 29, 30], [22, 23, 24], [28, 29, 30], [25, 26, 27], [28, 29, 30], [22, 23, 24], [19, 20, 21], [19, 20, 21]] [[37, 38, 39], [31, 32, 33], [34, 35, 36], [37, 38, 39], [31, 32, 33], [34, 35, 36], [37, 38, 39], [31, 32, 33], [34, 35, 36], [37, 38, 39]] + # array_replace_n scalar function with columns and scalars #1 query ???? 
-select array_replace_n(make_array(1, 2, 2, 4, 5, 4, 4, 7, 7, 10, 7, 8), column2, column3, column4), array_replace_n(column1, 1, column3, column4), array_replace_n(column1, column2, 4, column4), array_replace_n(column1, column2, column3, 2) from arrays_with_repeating_elements; +select + array_replace_n(make_array(1, 2, 2, 4, 5, 4, 4, 7, 7, 10, 7, 8), column2, column3, column4), + array_replace_n(column1, 1, column3, column4), + array_replace_n(column1, column2, 4, column4), + array_replace_n(column1, column2, column3, 2) +from arrays_with_repeating_elements; +---- +[1, 4, 4, 4, 5, 4, 4, 7, 7, 10, 7, 8] [4, 2, 4, 3, 2, 2, 4, 3, 2, 3] [1, 4, 1, 3, 4, 4, 1, 3, 2, 3] [1, 4, 1, 3, 4, 2, 1, 3, 2, 3] +[1, 2, 2, 7, 5, 7, 4, 7, 7, 10, 7, 8] [4, 4, 5, 5, 6, 5, 5, 5, 4, 4] [4, 4, 5, 5, 6, 5, 5, 5, 4, 4] [7, 7, 5, 5, 6, 5, 5, 5, 4, 4] +[1, 2, 2, 4, 5, 4, 4, 10, 10, 10, 10, 8] [7, 7, 7, 8, 7, 9, 7, 8, 7, 7] [4, 4, 4, 8, 4, 9, 4, 8, 7, 7] [10, 10, 7, 8, 7, 9, 7, 8, 7, 7] +[1, 2, 2, 4, 5, 4, 4, 7, 7, 13, 7, 8] [10, 11, 12, 10, 11, 12, 10, 11, 12, 10] [4, 11, 12, 4, 11, 12, 4, 11, 12, 4] [13, 11, 12, 13, 11, 12, 10, 11, 12, 10] + +query ???? +select + array_replace_n(arrow_cast(make_array(1, 2, 2, 4, 5, 4, 4, 7, 7, 10, 7, 8), 'LargeList(Int64)'), column2, column3, column4), + array_replace_n(column1, 1, column3, column4), + array_replace_n(column1, column2, 4, column4), + array_replace_n(column1, column2, column3, 2) +from large_arrays_with_repeating_elements; ---- [1, 4, 4, 4, 5, 4, 4, 7, 7, 10, 7, 8] [4, 2, 4, 3, 2, 2, 4, 3, 2, 3] [1, 4, 1, 3, 4, 4, 1, 3, 2, 3] [1, 4, 1, 3, 4, 2, 1, 3, 2, 3] [1, 2, 2, 7, 5, 7, 4, 7, 7, 10, 7, 8] [4, 4, 5, 5, 6, 5, 5, 5, 4, 4] [4, 4, 5, 5, 6, 5, 5, 5, 4, 4] [7, 7, 5, 5, 6, 5, 5, 5, 4, 4] @@ -1669,7 +2809,37 @@ select array_replace_n(make_array(1, 2, 2, 4, 5, 4, 4, 7, 7, 10, 7, 8), column2, # array_replace_n scalar function with columns and scalars #2 (element is list) query ???? 
-select array_replace_n(make_array([7, 8, 9], [2, 1, 3], [1, 5, 6], [10, 11, 12], [2, 1, 3], [7, 8, 9], [4, 5, 6]), column2, column3, column4), array_replace_n(column1, make_array(1, 2, 3), column3, column4), array_replace_n(column1, column2, make_array(11, 12, 13), column4), array_replace_n(column1, column2, column3, 2) from nested_arrays_with_repeating_elements; +select + array_replace_n( + make_array( + [7, 8, 9], [2, 1, 3], [1, 5, 6], [10, 11, 12], [2, 1, 3], [7, 8, 9], [4, 5, 6]), + column2, + column3, + column4 + ), + array_replace_n(column1, make_array(1, 2, 3), column3, column4), + array_replace_n(column1, column2, make_array(11, 12, 13), column4), + array_replace_n(column1, column2, column3, 2) +from nested_arrays_with_repeating_elements; +---- +[[7, 8, 9], [2, 1, 3], [1, 5, 6], [10, 11, 12], [2, 1, 3], [7, 8, 9], [10, 11, 12]] [[10, 11, 12], [4, 5, 6], [10, 11, 12], [7, 8, 9], [4, 5, 6], [4, 5, 6], [10, 11, 12], [7, 8, 9], [4, 5, 6], [7, 8, 9]] [[1, 2, 3], [11, 12, 13], [1, 2, 3], [7, 8, 9], [11, 12, 13], [11, 12, 13], [1, 2, 3], [7, 8, 9], [4, 5, 6], [7, 8, 9]] [[1, 2, 3], [10, 11, 12], [1, 2, 3], [7, 8, 9], [10, 11, 12], [4, 5, 6], [1, 2, 3], [7, 8, 9], [4, 5, 6], [7, 8, 9]] +[[7, 8, 9], [2, 1, 3], [1, 5, 6], [19, 20, 21], [2, 1, 3], [7, 8, 9], [4, 5, 6]] [[10, 11, 12], [10, 11, 12], [13, 14, 15], [13, 14, 15], [16, 17, 18], [13, 14, 15], [13, 14, 15], [13, 14, 15], [10, 11, 12], [10, 11, 12]] [[11, 12, 13], [11, 12, 13], [13, 14, 15], [13, 14, 15], [16, 17, 18], [13, 14, 15], [13, 14, 15], [13, 14, 15], [10, 11, 12], [10, 11, 12]] [[19, 20, 21], [19, 20, 21], [13, 14, 15], [13, 14, 15], [16, 17, 18], [13, 14, 15], [13, 14, 15], [13, 14, 15], [10, 11, 12], [10, 11, 12]] +[[7, 8, 9], [2, 1, 3], [1, 5, 6], [10, 11, 12], [2, 1, 3], [7, 8, 9], [4, 5, 6]] [[19, 20, 21], [19, 20, 21], [19, 20, 21], [22, 23, 24], [19, 20, 21], [25, 26, 27], [19, 20, 21], [22, 23, 24], [19, 20, 21], [19, 20, 21]] [[11, 12, 13], [11, 12, 13], [11, 12, 13], [22, 23, 24], [11, 12, 13], [25, 26, 27], [11, 12, 13], [22, 23, 24], [19, 20, 21], [19, 20, 21]] [[28, 29, 30], [28, 29, 30], [19, 20, 21], [22, 23, 24], [19, 20, 21], [25, 26, 27], [19, 20, 21], [22, 23, 24], [19, 20, 21], [19, 20, 21]] +[[7, 8, 9], [2, 1, 3], [1, 5, 6], [10, 11, 12], [2, 1, 3], [7, 8, 9], [4, 5, 6]] [[28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30]] [[11, 12, 13], [31, 32, 33], [34, 35, 36], [11, 12, 13], [31, 32, 33], [34, 35, 36], [11, 12, 13], [31, 32, 33], [34, 35, 36], [11, 12, 13]] [[37, 38, 39], [31, 32, 33], [34, 35, 36], [37, 38, 39], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30]] + +query ???? 
+select + array_replace_n( + arrow_cast(make_array( + [7, 8, 9], [2, 1, 3], [1, 5, 6], [10, 11, 12], [2, 1, 3], [7, 8, 9], [4, 5, 6]), 'LargeList(List(Int64))'), + column2, + column3, + column4 + ), + array_replace_n(column1, make_array(1, 2, 3), column3, column4), + array_replace_n(column1, column2, make_array(11, 12, 13), column4), + array_replace_n(column1, column2, column3, 2) +from large_nested_arrays_with_repeating_elements; ---- [[7, 8, 9], [2, 1, 3], [1, 5, 6], [10, 11, 12], [2, 1, 3], [7, 8, 9], [10, 11, 12]] [[10, 11, 12], [4, 5, 6], [10, 11, 12], [7, 8, 9], [4, 5, 6], [4, 5, 6], [10, 11, 12], [7, 8, 9], [4, 5, 6], [7, 8, 9]] [[1, 2, 3], [11, 12, 13], [1, 2, 3], [7, 8, 9], [11, 12, 13], [11, 12, 13], [1, 2, 3], [7, 8, 9], [4, 5, 6], [7, 8, 9]] [[1, 2, 3], [10, 11, 12], [1, 2, 3], [7, 8, 9], [10, 11, 12], [4, 5, 6], [1, 2, 3], [7, 8, 9], [4, 5, 6], [7, 8, 9]] [[7, 8, 9], [2, 1, 3], [1, 5, 6], [19, 20, 21], [2, 1, 3], [7, 8, 9], [4, 5, 6]] [[10, 11, 12], [10, 11, 12], [13, 14, 15], [13, 14, 15], [16, 17, 18], [13, 14, 15], [13, 14, 15], [13, 14, 15], [10, 11, 12], [10, 11, 12]] [[11, 12, 13], [11, 12, 13], [13, 14, 15], [13, 14, 15], [16, 17, 18], [13, 14, 15], [13, 14, 15], [13, 14, 15], [10, 11, 12], [10, 11, 12]] [[19, 20, 21], [19, 20, 21], [13, 14, 15], [13, 14, 15], [16, 17, 18], [13, 14, 15], [13, 14, 15], [13, 14, 15], [10, 11, 12], [10, 11, 12]] @@ -1680,25 +2850,84 @@ select array_replace_n(make_array([7, 8, 9], [2, 1, 3], [1, 5, 6], [10, 11, 12], # array_replace_all scalar function #1 query ??? -select array_replace_all(make_array(1, 2, 3, 4), 2, 3), array_replace_all(make_array(1, 4, 4, 5, 4, 6, 7), 4, 0), array_replace_all(make_array(1, 2, 3), 4, 0); +select + array_replace_all(make_array(1, 2, 3, 4), 2, 3), + array_replace_all(make_array(1, 4, 4, 5, 4, 6, 7), 4, 0), + array_replace_all(make_array(1, 2, 3), 4, 0); +---- +[1, 3, 3, 4] [1, 0, 0, 5, 0, 6, 7] [1, 2, 3] + +query ??? +select + array_replace_all(arrow_cast(make_array(1, 2, 3, 4), 'LargeList(Int64)'), 2, 3), + array_replace_all(arrow_cast(make_array(1, 4, 4, 5, 4, 6, 7), 'LargeList(Int64)'), 4, 0), + array_replace_all(arrow_cast(make_array(1, 2, 3), 'LargeList(Int64)'), 4, 0); ---- [1, 3, 3, 4] [1, 0, 0, 5, 0, 6, 7] [1, 2, 3] # array_replace_all scalar function #2 (element is list) query ?? -select array_replace_all(make_array([1, 2, 3], [4, 5, 6], [5, 5, 5], [4, 5, 6], [7, 8, 9]), [4, 5, 6], [1, 1, 1]), array_replace_all(make_array([1, 3, 2], [2, 3, 4], [2, 3, 4], [5, 3, 1], [1, 3, 2]), [2, 3, 4], [3, 1, 4]); +select + array_replace_all( + make_array([1, 2, 3], [4, 5, 6], [5, 5, 5], [4, 5, 6], [7, 8, 9]), + [4, 5, 6], + [1, 1, 1] + ), + array_replace_all( + make_array([1, 3, 2], [2, 3, 4], [2, 3, 4], [5, 3, 1], [1, 3, 2]), + [2, 3, 4], + [3, 1, 4] + ); +---- +[[1, 2, 3], [1, 1, 1], [5, 5, 5], [1, 1, 1], [7, 8, 9]] [[1, 3, 2], [3, 1, 4], [3, 1, 4], [5, 3, 1], [1, 3, 2]] + +query ?? +select + array_replace_all( + arrow_cast(make_array([1, 2, 3], [4, 5, 6], [5, 5, 5], [4, 5, 6], [7, 8, 9]), 'LargeList(List(Int64))'), + [4, 5, 6], + [1, 1, 1] + ), + array_replace_all( + arrow_cast(make_array([1, 3, 2], [2, 3, 4], [2, 3, 4], [5, 3, 1], [1, 3, 2]), 'LargeList(List(Int64))'), + [2, 3, 4], + [3, 1, 4] + ); ---- [[1, 2, 3], [1, 1, 1], [5, 5, 5], [1, 1, 1], [7, 8, 9]] [[1, 3, 2], [3, 1, 4], [3, 1, 4], [5, 3, 1], [1, 3, 2]] # list_replace_all scalar function #3 (function alias `array_replace_all`) query ??? 
-select list_replace_all(make_array(1, 2, 3, 4), 2, 3), list_replace_all(make_array(1, 4, 4, 5, 4, 6, 7), 4, 0), list_replace_all(make_array(1, 2, 3), 4, 0); +select + list_replace_all(make_array(1, 2, 3, 4), 2, 3), + list_replace_all(make_array(1, 4, 4, 5, 4, 6, 7), 4, 0), + list_replace_all(make_array(1, 2, 3), 4, 0); +---- +[1, 3, 3, 4] [1, 0, 0, 5, 0, 6, 7] [1, 2, 3] + +query ??? +select + list_replace_all(arrow_cast(make_array(1, 2, 3, 4), 'LargeList(Int64)'), 2, 3), + list_replace_all(arrow_cast(make_array(1, 4, 4, 5, 4, 6, 7), 'LargeList(Int64)'), 4, 0), + list_replace_all(arrow_cast(make_array(1, 2, 3), 'LargeList(Int64)'), 4, 0); ---- [1, 3, 3, 4] [1, 0, 0, 5, 0, 6, 7] [1, 2, 3] # array_replace_all scalar function with columns #1 query ? -select array_replace_all(column1, column2, column3) from arrays_with_repeating_elements; +select + array_replace_all(column1, column2, column3) +from arrays_with_repeating_elements; +---- +[1, 4, 1, 3, 4, 4, 1, 3, 4, 3] +[7, 7, 5, 5, 6, 5, 5, 5, 7, 7] +[10, 10, 10, 8, 10, 9, 10, 8, 10, 10] +[13, 11, 12, 13, 11, 12, 13, 11, 12, 13] + +query ? +select + array_replace_all(column1, column2, column3) +from large_arrays_with_repeating_elements; ---- [1, 4, 1, 3, 4, 4, 1, 3, 4, 3] [7, 7, 5, 5, 6, 5, 5, 5, 7, 7] @@ -1707,7 +2936,19 @@ select array_replace_all(column1, column2, column3) from arrays_with_repeating_e # array_replace_all scalar function with columns #2 (element is list) query ? -select array_replace_all(column1, column2, column3) from nested_arrays_with_repeating_elements; +select + array_replace_all(column1, column2, column3) +from nested_arrays_with_repeating_elements; +---- +[[1, 2, 3], [10, 11, 12], [1, 2, 3], [7, 8, 9], [10, 11, 12], [10, 11, 12], [1, 2, 3], [7, 8, 9], [10, 11, 12], [7, 8, 9]] +[[19, 20, 21], [19, 20, 21], [13, 14, 15], [13, 14, 15], [16, 17, 18], [13, 14, 15], [13, 14, 15], [13, 14, 15], [19, 20, 21], [19, 20, 21]] +[[28, 29, 30], [28, 29, 30], [28, 29, 30], [22, 23, 24], [28, 29, 30], [25, 26, 27], [28, 29, 30], [22, 23, 24], [28, 29, 30], [28, 29, 30]] +[[37, 38, 39], [31, 32, 33], [34, 35, 36], [37, 38, 39], [31, 32, 33], [34, 35, 36], [37, 38, 39], [31, 32, 33], [34, 35, 36], [37, 38, 39]] + +query ? +select + array_replace_all(column1, column2, column3) +from large_nested_arrays_with_repeating_elements; ---- [[1, 2, 3], [10, 11, 12], [1, 2, 3], [7, 8, 9], [10, 11, 12], [10, 11, 12], [1, 2, 3], [7, 8, 9], [10, 11, 12], [7, 8, 9]] [[19, 20, 21], [19, 20, 21], [13, 14, 15], [13, 14, 15], [16, 17, 18], [13, 14, 15], [13, 14, 15], [13, 14, 15], [19, 20, 21], [19, 20, 21]] @@ -1716,7 +2957,23 @@ select array_replace_all(column1, column2, column3) from nested_arrays_with_repe # array_replace_all scalar function with columns and scalars #1 query ??? 
-select array_replace_all(make_array(1, 2, 2, 4, 5, 4, 4, 7, 7, 10, 7, 8), column2, column3), array_replace_all(column1, 1, column3), array_replace_all(column1, column2, 4) from arrays_with_repeating_elements; +select + array_replace_all(make_array(1, 2, 2, 4, 5, 4, 4, 7, 7, 10, 7, 8), column2, column3), + array_replace_all(column1, 1, column3), + array_replace_all(column1, column2, 4) +from arrays_with_repeating_elements; +---- +[1, 4, 4, 4, 5, 4, 4, 7, 7, 10, 7, 8] [4, 2, 4, 3, 2, 2, 4, 3, 2, 3] [1, 4, 1, 3, 4, 4, 1, 3, 4, 3] +[1, 2, 2, 7, 5, 7, 7, 7, 7, 10, 7, 8] [4, 4, 5, 5, 6, 5, 5, 5, 4, 4] [4, 4, 5, 5, 6, 5, 5, 5, 4, 4] +[1, 2, 2, 4, 5, 4, 4, 10, 10, 10, 10, 8] [7, 7, 7, 8, 7, 9, 7, 8, 7, 7] [4, 4, 4, 8, 4, 9, 4, 8, 4, 4] +[1, 2, 2, 4, 5, 4, 4, 7, 7, 13, 7, 8] [10, 11, 12, 10, 11, 12, 10, 11, 12, 10] [4, 11, 12, 4, 11, 12, 4, 11, 12, 4] + +query ??? +select + array_replace_all(arrow_cast(make_array(1, 2, 2, 4, 5, 4, 4, 7, 7, 10, 7, 8), 'LargeList(Int64)'), column2, column3), + array_replace_all(column1, 1, column3), + array_replace_all(column1, column2, 4) +from large_arrays_with_repeating_elements; ---- [1, 4, 4, 4, 5, 4, 4, 7, 7, 10, 7, 8] [4, 2, 4, 3, 2, 2, 4, 3, 2, 3] [1, 4, 1, 3, 4, 4, 1, 3, 4, 3] [1, 2, 2, 7, 5, 7, 7, 7, 7, 10, 7, 8] [4, 4, 5, 5, 6, 5, 5, 5, 4, 4] [4, 4, 5, 5, 6, 5, 5, 5, 4, 4] @@ -1725,13 +2982,68 @@ select array_replace_all(make_array(1, 2, 2, 4, 5, 4, 4, 7, 7, 10, 7, 8), column # array_replace_all scalar function with columns and scalars #2 (element is list) query ??? -select array_replace_all(make_array([1, 2, 3], [4, 5, 6], [4, 5, 6], [10, 11, 12], [13, 14, 15], [10, 11, 12], [10, 11, 12], [19, 20, 21], [19, 20, 21], [28, 29, 30], [19, 20, 21], [22, 23, 24]), column2, column3), array_replace_all(column1, make_array(1, 2, 3), column3), array_replace_all(column1, column2, make_array(11, 12, 13)) from nested_arrays_with_repeating_elements; +select + array_replace_all( + make_array([1, 2, 3], [4, 5, 6], [4, 5, 6], [10, 11, 12], [13, 14, 15], [10, 11, 12], [10, 11, 12], [19, 20, 21], [19, 20, 21], [28, 29, 30], [19, 20, 21], [22, 23, 24]), + column2, + column3 + ), + array_replace_all(column1, make_array(1, 2, 3), column3), + array_replace_all(column1, column2, make_array(11, 12, 13)) +from nested_arrays_with_repeating_elements; +---- +[[1, 2, 3], [10, 11, 12], [10, 11, 12], [10, 11, 12], [13, 14, 15], [10, 11, 12], [10, 11, 12], [19, 20, 21], [19, 20, 21], [28, 29, 30], [19, 20, 21], [22, 23, 24]] [[10, 11, 12], [4, 5, 6], [10, 11, 12], [7, 8, 9], [4, 5, 6], [4, 5, 6], [10, 11, 12], [7, 8, 9], [4, 5, 6], [7, 8, 9]] [[1, 2, 3], [11, 12, 13], [1, 2, 3], [7, 8, 9], [11, 12, 13], [11, 12, 13], [1, 2, 3], [7, 8, 9], [11, 12, 13], [7, 8, 9]] +[[1, 2, 3], [4, 5, 6], [4, 5, 6], [19, 20, 21], [13, 14, 15], [19, 20, 21], [19, 20, 21], [19, 20, 21], [19, 20, 21], [28, 29, 30], [19, 20, 21], [22, 23, 24]] [[10, 11, 12], [10, 11, 12], [13, 14, 15], [13, 14, 15], [16, 17, 18], [13, 14, 15], [13, 14, 15], [13, 14, 15], [10, 11, 12], [10, 11, 12]] [[11, 12, 13], [11, 12, 13], [13, 14, 15], [13, 14, 15], [16, 17, 18], [13, 14, 15], [13, 14, 15], [13, 14, 15], [11, 12, 13], [11, 12, 13]] +[[1, 2, 3], [4, 5, 6], [4, 5, 6], [10, 11, 12], [13, 14, 15], [10, 11, 12], [10, 11, 12], [28, 29, 30], [28, 29, 30], [28, 29, 30], [28, 29, 30], [22, 23, 24]] [[19, 20, 21], [19, 20, 21], [19, 20, 21], [22, 23, 24], [19, 20, 21], [25, 26, 27], [19, 20, 21], [22, 23, 24], [19, 20, 21], [19, 20, 21]] [[11, 12, 13], [11, 12, 13], [11, 12, 13], [22, 23, 24], [11, 12, 13], [25, 26, 27], [11, 12, 
13], [22, 23, 24], [11, 12, 13], [11, 12, 13]] +[[1, 2, 3], [4, 5, 6], [4, 5, 6], [10, 11, 12], [13, 14, 15], [10, 11, 12], [10, 11, 12], [19, 20, 21], [19, 20, 21], [37, 38, 39], [19, 20, 21], [22, 23, 24]] [[28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30]] [[11, 12, 13], [31, 32, 33], [34, 35, 36], [11, 12, 13], [31, 32, 33], [34, 35, 36], [11, 12, 13], [31, 32, 33], [34, 35, 36], [11, 12, 13]] + +query ??? +select + array_replace_all( + arrow_cast(make_array([1, 2, 3], [4, 5, 6], [4, 5, 6], [10, 11, 12], [13, 14, 15], [10, 11, 12], [10, 11, 12], [19, 20, 21], [19, 20, 21], [28, 29, 30], [19, 20, 21], [22, 23, 24]), 'LargeList(List(Int64))'), + column2, + column3 + ), + array_replace_all(column1, make_array(1, 2, 3), column3), + array_replace_all(column1, column2, make_array(11, 12, 13)) +from nested_arrays_with_repeating_elements; ---- [[1, 2, 3], [10, 11, 12], [10, 11, 12], [10, 11, 12], [13, 14, 15], [10, 11, 12], [10, 11, 12], [19, 20, 21], [19, 20, 21], [28, 29, 30], [19, 20, 21], [22, 23, 24]] [[10, 11, 12], [4, 5, 6], [10, 11, 12], [7, 8, 9], [4, 5, 6], [4, 5, 6], [10, 11, 12], [7, 8, 9], [4, 5, 6], [7, 8, 9]] [[1, 2, 3], [11, 12, 13], [1, 2, 3], [7, 8, 9], [11, 12, 13], [11, 12, 13], [1, 2, 3], [7, 8, 9], [11, 12, 13], [7, 8, 9]] [[1, 2, 3], [4, 5, 6], [4, 5, 6], [19, 20, 21], [13, 14, 15], [19, 20, 21], [19, 20, 21], [19, 20, 21], [19, 20, 21], [28, 29, 30], [19, 20, 21], [22, 23, 24]] [[10, 11, 12], [10, 11, 12], [13, 14, 15], [13, 14, 15], [16, 17, 18], [13, 14, 15], [13, 14, 15], [13, 14, 15], [10, 11, 12], [10, 11, 12]] [[11, 12, 13], [11, 12, 13], [13, 14, 15], [13, 14, 15], [16, 17, 18], [13, 14, 15], [13, 14, 15], [13, 14, 15], [11, 12, 13], [11, 12, 13]] [[1, 2, 3], [4, 5, 6], [4, 5, 6], [10, 11, 12], [13, 14, 15], [10, 11, 12], [10, 11, 12], [28, 29, 30], [28, 29, 30], [28, 29, 30], [28, 29, 30], [22, 23, 24]] [[19, 20, 21], [19, 20, 21], [19, 20, 21], [22, 23, 24], [19, 20, 21], [25, 26, 27], [19, 20, 21], [22, 23, 24], [19, 20, 21], [19, 20, 21]] [[11, 12, 13], [11, 12, 13], [11, 12, 13], [22, 23, 24], [11, 12, 13], [25, 26, 27], [11, 12, 13], [22, 23, 24], [11, 12, 13], [11, 12, 13]] [[1, 2, 3], [4, 5, 6], [4, 5, 6], [10, 11, 12], [13, 14, 15], [10, 11, 12], [10, 11, 12], [19, 20, 21], [19, 20, 21], [37, 38, 39], [19, 20, 21], [22, 23, 24]] [[28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30]] [[11, 12, 13], [31, 32, 33], [34, 35, 36], [11, 12, 13], [31, 32, 33], [34, 35, 36], [11, 12, 13], [31, 32, 33], [34, 35, 36], [11, 12, 13]] +# array_replace with null handling + +statement ok +create table t as values + (make_array(3, 1, NULL, 3), 3, 4, 2), + (make_array(3, 1, NULL, 3), NULL, 5, 2), + (NULL, 3, 2, 1), + (make_array(3, 1, 3), 3, NULL, 1) +; + + +# ([3, 1, NULL, 3], 3, 4, 2) => [4, 1, NULL, 4] NULL not matched +# ([3, 1, NULL, 3], NULL, 5, 2) => [3, 1, NULL, 3] NULL is replaced with 5 +# ([NULL], 3, 2, 1) => NULL +# ([3, 1, 3], 3, NULL, 1) => [NULL, 1 3] + +query ?III? 
+select column1, column2, column3, column4, array_replace_n(column1, column2, column3, column4) from t; +---- +[3, 1, , 3] 3 4 2 [4, 1, , 4] +[3, 1, , 3] NULL 5 2 [3, 1, 5, 3] +NULL 3 2 1 NULL +[3, 1, 3] 3 NULL 1 [, 1, 3] + + + +statement ok +drop table t; + + + ## array_to_string (aliases: `list_to_string`, `array_join`, `list_join`) # array_to_string scalar function #1 @@ -1752,54 +3064,222 @@ select array_to_string(make_array(), ',') ---- (empty) -# list_to_string scalar function #4 (function alias `array_to_string`) -query TTT -select list_to_string(['h', 'e', 'l', 'l', 'o'], ','), list_to_string([1, 2, 3, 4, 5], '-'), list_to_string([1.0, 2.0, 3.0], '|'); ----- -h,e,l,l,o 1-2-3-4-5 1|2|3 -# array_join scalar function #5 (function alias `array_to_string`) -query TTT -select array_join(['h', 'e', 'l', 'l', 'o'], ','), array_join([1, 2, 3, 4, 5], '-'), array_join([1.0, 2.0, 3.0], '|'); +## array_union (aliases: `list_union`) + +# array_union scalar function #1 +query ? +select array_union([1, 2, 3, 4], [5, 6, 3, 4]); ---- -h,e,l,l,o 1-2-3-4-5 1|2|3 +[1, 2, 3, 4, 5, 6] -# list_join scalar function #6 (function alias `list_join`) -query TTT -select list_join(['h', 'e', 'l', 'l', 'o'], ','), list_join([1, 2, 3, 4, 5], '-'), list_join([1.0, 2.0, 3.0], '|'); +query ? +select array_union(arrow_cast([1, 2, 3, 4], 'LargeList(Int64)'), arrow_cast([5, 6, 3, 4], 'LargeList(Int64)')); ---- -h,e,l,l,o 1-2-3-4-5 1|2|3 +[1, 2, 3, 4, 5, 6] -# array_to_string scalar function with nulls #1 -query TTT -select array_to_string(make_array('h', NULL, 'l', NULL, 'o'), ','), array_to_string(make_array(1, NULL, 3, NULL, 5), '-'), array_to_string(make_array(NULL, 2.0, 3.0), '|'); +# array_union scalar function #2 +query ? +select array_union([1, 2, 3, 4], [5, 6, 7, 8]); ---- -h,l,o 1-3-5 2|3 +[1, 2, 3, 4, 5, 6, 7, 8] -# array_to_string scalar function with nulls #2 -query TTT -select array_to_string(make_array('h', NULL, NULL, NULL, 'o'), ',', '-'), array_to_string(make_array(NULL, 2, NULL, 4, 5), '-', 'nil'), array_to_string(make_array(1.0, NULL, 3.0), '|', '0'); +query ? +select array_union(arrow_cast([1, 2, 3, 4], 'LargeList(Int64)'), arrow_cast([5, 6, 7, 8], 'LargeList(Int64)')); ---- -h,-,-,-,o nil-2-nil-4-5 1|0|3 +[1, 2, 3, 4, 5, 6, 7, 8] -# array_to_string with columns #1 +# array_union scalar function #3 +query ? +select array_union([1,2,3], []); +---- +[1, 2, 3] -# For reference -# select column1, column4 from arrays_values; -# ---- -# [, 2, 3, 4, 5, 6, 7, 8, 9, 10] , -# [11, 12, 13, 14, 15, 16, 17, 18, , 20] . -# [21, 22, 23, , 25, 26, 27, 28, 29, 30] - -# [31, 32, 33, 34, 35, , 37, 38, 39, 40] ok -# NULL @ -# [41, 42, 43, 44, 45, 46, 47, 48, 49, 50] $ -# [51, 52, , 54, 55, 56, 57, 58, 59, 60] ^ -# [61, 62, 63, 64, 65, 66, 67, 68, 69, 70] NULL +query ? +select array_union(arrow_cast([1,2,3], 'LargeList(Int64)'), arrow_cast([], 'LargeList(Null)')); +---- +[1, 2, 3] -query T -select array_to_string(column1, column4) from arrays_values; +# array_union scalar function #4 +query ? +select array_union([1, 2, 3, 4], [5, 4]); ---- -2,3,4,5,6,7,8,9,10 +[1, 2, 3, 4, 5] + +query ? +select array_union(arrow_cast([1, 2, 3, 4], 'LargeList(Int64)'), arrow_cast([5, 4], 'LargeList(Int64)')); +---- +[1, 2, 3, 4, 5] + +# array_union scalar function #5 +statement ok +CREATE TABLE arrays_with_repeating_elements_for_union +AS VALUES + ([1], [2]), + ([2, 3], [3]), + ([3], [3, 4]) +; + +query ? +select array_union(column1, column2) from arrays_with_repeating_elements_for_union; +---- +[1, 2] +[2, 3] +[3, 4] + +query ? 
+select array_union(arrow_cast(column1, 'LargeList(Int64)'), arrow_cast(column2, 'LargeList(Int64)')) from arrays_with_repeating_elements_for_union; +---- +[1, 2] +[2, 3] +[3, 4] + +statement ok +drop table arrays_with_repeating_elements_for_union; + +# array_union scalar function #6 +query ? +select array_union([], []); +---- +[] + +query ? +select array_union(arrow_cast([], 'LargeList(Null)'), arrow_cast([], 'LargeList(Null)')); +---- +[] + +# array_union scalar function #7 +query ? +select array_union([[null]], []); +---- +[[]] + +query ? +select array_union(arrow_cast([[null]], 'LargeList(List(Null))'), arrow_cast([], 'LargeList(Null)')); +---- +[[]] + +# array_union scalar function #8 +query ? +select array_union([null], [null]); +---- +[] + +query ? +select array_union(arrow_cast([[null]], 'LargeList(List(Null))'), arrow_cast([[null]], 'LargeList(List(Null))')); +---- +[[]] + +# array_union scalar function #9 +query ? +select array_union(null, []); +---- +[] + +query ? +select array_union(null, arrow_cast([], 'LargeList(Null)')); +---- +[] + +# array_union scalar function #10 +query ? +select array_union(null, null); +---- +NULL + +# array_union scalar function #11 +query ? +select array_union([1, 1, 2, 2, 3, 3], null); +---- +[1, 2, 3] + +query ? +select array_union(arrow_cast([1, 1, 2, 2, 3, 3], 'LargeList(Int64)'), null); +---- +[1, 2, 3] + +# array_union scalar function #12 +query ? +select array_union(null, [1, 1, 2, 2, 3, 3]); +---- +[1, 2, 3] + +query ? +select array_union(null, arrow_cast([1, 1, 2, 2, 3, 3], 'LargeList(Int64)')); +---- +[1, 2, 3] + +# array_union scalar function #13 +query ? +select array_union([1.2, 3.0], [1.2, 3.0, 5.7]); +---- +[1.2, 3.0, 5.7] + +query ? +select array_union(arrow_cast([1.2, 3.0], 'LargeList(Float64)'), arrow_cast([1.2, 3.0, 5.7], 'LargeList(Float64)')); +---- +[1.2, 3.0, 5.7] + +# array_union scalar function #14 +query ? +select array_union(['hello'], ['hello','datafusion']); +---- +[hello, datafusion] + +query ? 
+select array_union(arrow_cast(['hello'], 'LargeList(Utf8)'), arrow_cast(['hello','datafusion'], 'LargeList(Utf8)')); +---- +[hello, datafusion] + + +# list_to_string scalar function #4 (function alias `array_to_string`) +query TTT +select list_to_string(['h', 'e', 'l', 'l', 'o'], ','), list_to_string([1, 2, 3, 4, 5], '-'), list_to_string([1.0, 2.0, 3.0], '|'); +---- +h,e,l,l,o 1-2-3-4-5 1|2|3 + +# array_join scalar function #5 (function alias `array_to_string`) +query TTT +select array_join(['h', 'e', 'l', 'l', 'o'], ','), array_join([1, 2, 3, 4, 5], '-'), array_join([1.0, 2.0, 3.0], '|'); +---- +h,e,l,l,o 1-2-3-4-5 1|2|3 + +# list_join scalar function #6 (function alias `list_join`) +query TTT +select list_join(['h', 'e', 'l', 'l', 'o'], ','), list_join([1, 2, 3, 4, 5], '-'), list_join([1.0, 2.0, 3.0], '|'); +---- +h,e,l,l,o 1-2-3-4-5 1|2|3 + +# array_to_string scalar function with nulls #1 +query TTT +select array_to_string(make_array('h', NULL, 'l', NULL, 'o'), ','), array_to_string(make_array(1, NULL, 3, NULL, 5), '-'), array_to_string(make_array(NULL, 2.0, 3.0), '|'); +---- +h,l,o 1-3-5 2|3 + +# array_to_string scalar function with nulls #2 +query TTT +select array_to_string(make_array('h', NULL, NULL, NULL, 'o'), ',', '-'), array_to_string(make_array(NULL, 2, NULL, 4, 5), '-', 'nil'), array_to_string(make_array(1.0, NULL, 3.0), '|', '0'); +---- +h,-,-,-,o nil-2-nil-4-5 1|0|3 + +# array_to_string with columns #1 + +# For reference +# select column1, column4 from arrays_values; +# ---- +# [, 2, 3, 4, 5, 6, 7, 8, 9, 10] , +# [11, 12, 13, 14, 15, 16, 17, 18, , 20] . +# [21, 22, 23, , 25, 26, 27, 28, 29, 30] - +# [31, 32, 33, 34, 35, , 37, 38, 39, 40] ok +# NULL @ +# [41, 42, 43, 44, 45, 46, 47, 48, 49, 50] $ +# [51, 52, , 54, 55, 56, 57, 58, 59, 60] ^ +# [61, 62, 63, 64, 65, 66, 67, 68, 69, 70] NULL + +query T +select array_to_string(column1, column4) from arrays_values; +---- +2,3,4,5,6,7,8,9,10 11.12.13.14.15.16.17.18.20 21-22-23-25-26-27-28-29-30 31ok32ok33ok34ok35ok37ok38ok39ok40 @@ -1872,6 +3352,20 @@ select array_remove(make_array(1, 2, 2, 1, 1), 2), array_remove(make_array(1.0, ---- [1, 2, 1, 1] [2.0, 2.0, 1.0, 1.0] [h, e, l, o] +query ??? +select + array_remove(make_array(1, null, 2, 3), 2), + array_remove(make_array(1.1, null, 2.2, 3.3), 1.1), + array_remove(make_array('a', null, 'bc'), 'a'); +---- +[1, , 3] [, 2.2, 3.3] [, bc] + +# TODO: https://github.com/apache/arrow-datafusion/issues/7142 +# query +# select +# array_remove(make_array(1, null, 2), null), +# array_remove(make_array(1, null, 2, null), null); + # array_remove scalar function #2 (element is list) query ?? 
select array_remove(make_array([1, 2, 3], [4, 5, 6], [5, 5, 5], [4, 5, 6], [7, 8, 9]), [4, 5, 6]), array_remove(make_array([1, 3, 2], [2, 3, 4], [2, 3, 4], [5, 3, 1], [1, 3, 2]), [2, 3, 4]); @@ -2042,33 +3536,64 @@ select array_length(make_array(1, 2, 3, 4, 5)), array_length(make_array(1, 2, 3) ---- 5 3 3 +query III +select array_length(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)')), array_length(arrow_cast(make_array(1, 2, 3), 'LargeList(Int64)')), array_length(arrow_cast(make_array([1, 2], [3, 4], [5, 6]), 'LargeList(List(Int64))')); +---- +5 3 3 + # array_length scalar function #2 query III select array_length(make_array(1, 2, 3, 4, 5), 1), array_length(make_array(1, 2, 3), 1), array_length(make_array([1, 2], [3, 4], [5, 6]), 1); ---- 5 3 3 +query III +select array_length(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), 1), array_length(arrow_cast(make_array(1, 2, 3), 'LargeList(Int64)'), 1), array_length(arrow_cast(make_array([1, 2], [3, 4], [5, 6]), 'LargeList(List(Int64))'), 1); +---- +5 3 3 + # array_length scalar function #3 query III select array_length(make_array(1, 2, 3, 4, 5), 2), array_length(make_array(1, 2, 3), 2), array_length(make_array([1, 2], [3, 4], [5, 6]), 2); ---- NULL NULL 2 +query III +select array_length(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), 2), array_length(arrow_cast(make_array(1, 2, 3), 'LargeList(Int64)'), 2), array_length(arrow_cast(make_array([1, 2], [3, 4], [5, 6]), 'LargeList(List(Int64))'), 2); +---- +NULL NULL 2 + # array_length scalar function #4 query II select array_length(array_repeat(array_repeat(array_repeat(3, 5), 2), 3), 1), array_length(array_repeat(array_repeat(array_repeat(3, 5), 2), 3), 2); ---- 3 2 +query II +select array_length(arrow_cast(array_repeat(array_repeat(array_repeat(3, 5), 2), 3), 'LargeList(List(List(Int64)))'), 1), array_length(arrow_cast(array_repeat(array_repeat(array_repeat(3, 5), 2), 3), 'LargeList(List(List(Int64)))'), 2); +---- +3 2 + # array_length scalar function #5 query III select array_length(make_array()), array_length(make_array(), 1), array_length(make_array(), 2) ---- 0 0 NULL -# list_length scalar function #6 (function alias `array_length`) +# array_length scalar function #6 nested array +query III +select array_length([[1, 2, 3, 4], [5, 6, 7, 8]]), array_length([[1, 2, 3, 4], [5, 6, 7, 8]], 1), array_length([[1, 2, 3, 4], [5, 6, 7, 8]], 2); +---- +2 2 4 + +# list_length scalar function #7 (function alias `array_length`) +query IIII +select list_length(make_array(1, 2, 3, 4, 5)), list_length(make_array(1, 2, 3)), list_length(make_array([1, 2], [3, 4], [5, 6])), array_length([[1, 2, 3, 4], [5, 6, 7, 8]], 3); +---- +5 3 3 NULL + query III -select list_length(make_array(1, 2, 3, 4, 5)), list_length(make_array(1, 2, 3)), list_length(make_array([1, 2], [3, 4], [5, 6])); +select list_length(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)')), list_length(arrow_cast(make_array(1, 2, 3), 'LargeList(Int64)')), list_length(arrow_cast(make_array([1, 2], [3, 4], [5, 6]), 'LargeList(List(Int64))')); ---- 5 3 3 @@ -2085,6 +3610,18 @@ NULL NULL NULL +query I +select array_length(arrow_cast(column1, 'LargeList(Int64)'), column3) from arrays_values; +---- +10 +NULL +NULL +NULL +NULL +NULL +NULL +NULL + # array_length with columns and scalars query II select array_length(array[array[1, 2], array[3, 4]], column3), array_length(column1, 1) from arrays_values; @@ -2098,11 +3635,22 @@ NULL 10 NULL 10 NULL 10 +query II +select array_length(arrow_cast(array[array[1, 2], array[3, 4]], 
'LargeList(List(Int64))'), column3), array_length(arrow_cast(column1, 'LargeList(Int64)'), 1) from arrays_values; +---- +2 10 +2 10 +NULL 10 +NULL 10 +NULL NULL +NULL 10 +NULL 10 +NULL 10 + ## array_dims (aliases: `list_dims`) # array dims error -# TODO this is a separate bug -query error Internal error: could not cast value to arrow_array::array::list_array::GenericListArray\. +query error Execution error: array_dims does not support type 'Int64' select array_dims(1); # array_dims scalar function @@ -2111,6 +3659,11 @@ select array_dims(make_array(1, 2, 3)), array_dims(make_array([1, 2], [3, 4])), ---- [3] [2, 2] [1, 1, 1, 2, 1] +query ??? +select array_dims(arrow_cast(make_array(1, 2, 3), 'LargeList(Int64)')), array_dims(arrow_cast(make_array([1, 2], [3, 4]), 'LargeList(List(Int64))')), array_dims(arrow_cast(make_array([[[[1], [2]]]]), 'LargeList(List(List(List(List(Int64)))))')); +---- +[3] [2, 2] [1, 1, 1, 2, 1] + # array_dims scalar function #2 query ?? select array_dims(array_repeat(array_repeat(array_repeat(2, 3), 2), 1)), array_dims(array_repeat(array_repeat(array_repeat(3, 4), 5), 2)); @@ -2123,12 +3676,22 @@ select array_dims(make_array()), array_dims(make_array(make_array())) ---- NULL [1, 0] +query ?? +select array_dims(arrow_cast(make_array(), 'LargeList(Null)')), array_dims(arrow_cast(make_array(make_array()), 'LargeList(List(Null))')) +---- +NULL [1, 0] + # list_dims scalar function #4 (function alias `array_dims`) query ??? select list_dims(make_array(1, 2, 3)), list_dims(make_array([1, 2], [3, 4])), list_dims(make_array([[[[1], [2]]]])); ---- [3] [2, 2] [1, 1, 1, 2, 1] +query ??? +select list_dims(arrow_cast(make_array(1, 2, 3), 'LargeList(Int64)')), list_dims(arrow_cast(make_array([1, 2], [3, 4]), 'LargeList(List(Int64))')), list_dims(arrow_cast(make_array([[[[1], [2]]]]), 'LargeList(List(List(List(List(Int64)))))')); +---- +[3] [2, 2] [1, 1, 1, 2, 1] + # array_dims with columns query ??? select array_dims(column1), array_dims(column2), array_dims(column3) from arrays; @@ -2141,171 +3704,808 @@ NULL [3] [4] [2, 2] NULL [1] [2, 2] [3] NULL -## array_ndims (aliases: `list_ndims`) +query ??? 
+select array_dims(column1), array_dims(column2), array_dims(column3) from large_arrays; +---- +[2, 2] [3] [5] +[2, 2] [3] [5] +[2, 2] [3] [5] +[2, 2] [3] [3] +NULL [3] [4] +[2, 2] NULL [1] +[2, 2] [3] NULL + + +## array_ndims (aliases: `list_ndims`) + +# array_ndims scalar function #1 + +query III +select + array_ndims(1), + array_ndims(null), + array_ndims([2, 3]); +---- +0 0 1 + +statement ok +CREATE TABLE array_ndims_table +AS VALUES + (1, [1, 2, 3], [[7]], [[[[[10]]]]]), + (2, [4, 5], [[8]], [[[[[10]]]]]), + (null, [6], [[9]], [[[[[10]]]]]), + (3, [6], [[9]], [[[[[10]]]]]) +; + +statement ok +CREATE TABLE large_array_ndims_table +AS SELECT + column1, + arrow_cast(column2, 'LargeList(Int64)') as column2, + arrow_cast(column3, 'LargeList(List(Int64))') as column3, + arrow_cast(column4, 'LargeList(List(List(List(List(Int64)))))') as column4 +FROM array_ndims_table; + +query IIII +select + array_ndims(column1), + array_ndims(column2), + array_ndims(column3), + array_ndims(column4) +from array_ndims_table; +---- +0 1 2 5 +0 1 2 5 +0 1 2 5 +0 1 2 5 + +query IIII +select + array_ndims(column1), + array_ndims(column2), + array_ndims(column3), + array_ndims(column4) +from large_array_ndims_table; +---- +0 1 2 5 +0 1 2 5 +0 1 2 5 +0 1 2 5 + +statement ok +drop table array_ndims_table; + +statement ok +drop table large_array_ndims_table + +query I +select array_ndims(arrow_cast([null], 'List(List(List(Int64)))')); +---- +3 + +# array_ndims scalar function #2 +query II +select array_ndims(array_repeat(array_repeat(array_repeat(1, 3), 2), 1)), array_ndims([[[[[[[[[[[[[[[[[[[[[1]]]]]]]]]]]]]]]]]]]]]); +---- +3 21 + +# array_ndims scalar function #3 +query II +select array_ndims(make_array()), array_ndims(make_array(make_array())) +---- +1 2 + +query II +select array_ndims(arrow_cast(make_array(), 'LargeList(Null)')), array_ndims(arrow_cast(make_array(make_array()), 'LargeList(List(Null))')) +---- +1 2 + +# list_ndims scalar function #4 (function alias `array_ndims`) +query III +select list_ndims(make_array(1, 2, 3)), list_ndims(make_array([1, 2], [3, 4])), list_ndims(make_array([[[[1], [2]]]])); +---- +1 2 5 + +query III +select list_ndims(arrow_cast(make_array(1, 2, 3), 'LargeList(Int64)')), list_ndims(arrow_cast(make_array([1, 2], [3, 4]), 'LargeList(List(Int64))')), list_ndims(arrow_cast(make_array([[[[1], [2]]]]), 'LargeList(List(List(List(List(Int64)))))')); +---- +1 2 5 + +query II +select list_ndims(make_array()), list_ndims(make_array(make_array())) +---- +1 2 + +query II +select list_ndims(arrow_cast(make_array(), 'LargeList(Null)')), list_ndims(arrow_cast(make_array(make_array()), 'LargeList(List(Null))')) +---- +1 2 + +# array_ndims with columns +query III +select array_ndims(column1), array_ndims(column2), array_ndims(column3) from arrays; +---- +2 1 1 +2 1 1 +2 1 1 +2 1 1 +NULL 1 1 +2 NULL 1 +2 1 NULL + +query III +select array_ndims(column1), array_ndims(column2), array_ndims(column3) from large_arrays; +---- +2 1 1 +2 1 1 +2 1 1 +2 1 1 +NULL 1 1 +2 NULL 1 +2 1 NULL + +## array_has/array_has_all/array_has_any + +query BBBBBBBBBBBB +select array_has(make_array(1,2), 1), + array_has(make_array(1,2,NULL), 1), + array_has(make_array([2,3], [3,4]), make_array(2,3)), + array_has(make_array([[1], [2,3]], [[4,5], [6]]), make_array([1], [2,3])), + array_has(make_array([[1], [2,3]], [[4,5], [6]]), make_array([4,5], [6])), + array_has(make_array([[1], [2,3]], [[4,5], [6]]), make_array([1])), + array_has(make_array([[[1]]]), make_array([[1]])), + array_has(make_array([[[1]]], [[[1], [2]]]), 
make_array([[2]])), + array_has(make_array([[[1]]], [[[1], [2]]]), make_array([[1], [2]])), + list_has(make_array(1,2,3), 4), + array_contains(make_array(1,2,3), 3), + list_contains(make_array(1,2,3), 0) +; +---- +true true true true true false true false true false true false + +query BBBBBBBBBBBB +select array_has(arrow_cast(make_array(1,2), 'LargeList(Int64)'), 1), + array_has(arrow_cast(make_array(1,2,NULL), 'LargeList(Int64)'), 1), + array_has(arrow_cast(make_array([2,3], [3,4]), 'LargeList(List(Int64))'), make_array(2,3)), + array_has(arrow_cast(make_array([[1], [2,3]], [[4,5], [6]]), 'LargeList(List(List(Int64)))'), make_array([1], [2,3])), + array_has(arrow_cast(make_array([[1], [2,3]], [[4,5], [6]]), 'LargeList(List(List(Int64)))'), make_array([4,5], [6])), + array_has(arrow_cast(make_array([[1], [2,3]], [[4,5], [6]]), 'LargeList(List(List(Int64)))'), make_array([1])), + array_has(arrow_cast(make_array([[[1]]]), 'LargeList(List(List(List(Int64))))'), make_array([[1]])), + array_has(arrow_cast(make_array([[[1]]], [[[1], [2]]]), 'LargeList(List(List(List(Int64))))'), make_array([[2]])), + array_has(arrow_cast(make_array([[[1]]], [[[1], [2]]]), 'LargeList(List(List(List(Int64))))'), make_array([[1], [2]])), + list_has(arrow_cast(make_array(1,2,3), 'LargeList(Int64)'), 4), + array_contains(arrow_cast(make_array(1,2,3), 'LargeList(Int64)'), 3), + list_contains(arrow_cast(make_array(1,2,3), 'LargeList(Int64)'), 0) +; +---- +true true true true true false true false true false true false + +query BBB +select array_has(column1, column2), + array_has_all(column3, column4), + array_has_any(column5, column6) +from array_has_table_1D; +---- +true true true +false false false + +query BBB +select array_has(arrow_cast(column1, 'LargeList(Int64)'), column2), + array_has_all(arrow_cast(column3, 'LargeList(Int64)'), arrow_cast(column4, 'LargeList(Int64)')), + array_has_any(arrow_cast(column5, 'LargeList(Int64)'), arrow_cast(column6, 'LargeList(Int64)')) +from array_has_table_1D; +---- +true true true +false false false + +query BBB +select array_has(column1, column2), + array_has_all(column3, column4), + array_has_any(column5, column6) +from array_has_table_1D_Float; +---- +true true false +false false true + +query BBB +select array_has(arrow_cast(column1, 'LargeList(Float64)'), column2), + array_has_all(arrow_cast(column3, 'LargeList(Float64)'), arrow_cast(column4, 'LargeList(Float64)')), + array_has_any(arrow_cast(column5, 'LargeList(Float64)'), arrow_cast(column6, 'LargeList(Float64)')) +from array_has_table_1D_Float; +---- +true true false +false false true + +query BBB +select array_has(column1, column2), + array_has_all(column3, column4), + array_has_any(column5, column6) +from array_has_table_1D_Boolean; +---- +false true true +true true true + +query BBB +select array_has(arrow_cast(column1, 'LargeList(Boolean)'), column2), + array_has_all(arrow_cast(column3, 'LargeList(Boolean)'), arrow_cast(column4, 'LargeList(Boolean)')), + array_has_any(arrow_cast(column5, 'LargeList(Boolean)'), arrow_cast(column6, 'LargeList(Boolean)')) +from array_has_table_1D_Boolean; +---- +false true true +true true true + +query BBB +select array_has(column1, column2), + array_has_all(column3, column4), + array_has_any(column5, column6) +from array_has_table_1D_UTF8; +---- +true true false +false false true + +query BBB +select array_has(arrow_cast(column1, 'LargeList(Utf8)'), column2), + array_has_all(arrow_cast(column3, 'LargeList(Utf8)'), arrow_cast(column4, 'LargeList(Utf8)')), + 
array_has_any(arrow_cast(column5, 'LargeList(Utf8)'), arrow_cast(column6, 'LargeList(Utf8)')) +from array_has_table_1D_UTF8; +---- +true true false +false false true + +query BB +select array_has(column1, column2), + array_has_all(column3, column4) +from array_has_table_2D; +---- +false true +true false + +query BB +select array_has(arrow_cast(column1, 'LargeList(List(Int64))'), column2), + array_has_all(arrow_cast(column3, 'LargeList(List(Int64))'), arrow_cast(column4, 'LargeList(List(Int64))')) +from array_has_table_2D; +---- +false true +true false + +query B +select array_has_all(column1, column2) +from array_has_table_2D_float; +---- +true +false + +query B +select array_has_all(arrow_cast(column1, 'LargeList(List(Float64))'), arrow_cast(column2, 'LargeList(List(Float64))')) +from array_has_table_2D_float; +---- +true +false + +query B +select array_has(column1, column2) from array_has_table_3D; +---- +false +true +false +false +true +false +true + +query B +select array_has(arrow_cast(column1, 'LargeList(List(List(Int64)))'), column2) from array_has_table_3D; +---- +false +true +false +false +true +false +true + +query BBBB +select array_has(column1, make_array(5, 6)), + array_has(column1, make_array(7, NULL)), + array_has(column2, 5.5), + array_has(column3, 'o') +from arrays; +---- +false false false true +true false true false +true false false true +false true false false +false false false false +false false false false + +query BBBB +select array_has(arrow_cast(column1, 'LargeList(List(Int64))'), make_array(5, 6)), + array_has(arrow_cast(column1, 'LargeList(List(Int64))'), make_array(7, NULL)), + array_has(arrow_cast(column2, 'LargeList(Float64)'), 5.5), + array_has(arrow_cast(column3, 'LargeList(Utf8)'), 'o') +from arrays; +---- +false false false true +true false true false +true false false true +false true false false +false false false false +false false false false + +query BBBBBBBBBBBBB +select array_has_all(make_array(1,2,3), make_array(1,3)), + array_has_all(make_array(1,2,3), make_array(1,4)), + array_has_all(make_array([1,2], [3,4]), make_array([1,2])), + array_has_all(make_array([1,2], [3,4]), make_array([1,3])), + array_has_all(make_array([1,2], [3,4]), make_array([1,2], [3,4], [5,6])), + array_has_all(make_array([[1,2,3]]), make_array([[1]])), + array_has_all(make_array([[1,2,3]]), make_array([[1,2,3]])), + array_has_any(make_array(1,2,3), make_array(1,10,100)), + array_has_any(make_array(1,2,3), make_array(10,100)), + array_has_any(make_array([1,2], [3,4]), make_array([1,10], [10,4])), + array_has_any(make_array([1,2], [3,4]), make_array([10,20], [3,4])), + array_has_any(make_array([[1,2,3]]), make_array([[1,2,3], [4,5,6]])), + array_has_any(make_array([[1,2,3]]), make_array([[1,2,3]], [[4,5,6]])) +; +---- +true false true false false false true true false false true false true + +query BBBBBBBBBBBBB +select array_has_all(arrow_cast(make_array(1,2,3), 'LargeList(Int64)'), arrow_cast(make_array(1,3), 'LargeList(Int64)')), + array_has_all(arrow_cast(make_array(1,2,3),'LargeList(Int64)'), arrow_cast(make_array(1,4), 'LargeList(Int64)')), + array_has_all(arrow_cast(make_array([1,2], [3,4]), 'LargeList(List(Int64))'), arrow_cast(make_array([1,2]), 'LargeList(List(Int64))')), + array_has_all(arrow_cast(make_array([1,2], [3,4]), 'LargeList(List(Int64))'), arrow_cast(make_array([1,3]), 'LargeList(List(Int64))')), + array_has_all(arrow_cast(make_array([1,2], [3,4]), 'LargeList(List(Int64))'), arrow_cast(make_array([1,2], [3,4], [5,6]), 'LargeList(List(Int64))')), + 
array_has_all(arrow_cast(make_array([[1,2,3]]), 'LargeList(List(List(Int64)))'), arrow_cast(make_array([[1]]), 'LargeList(List(List(Int64)))')), + array_has_all(arrow_cast(make_array([[1,2,3]]), 'LargeList(List(List(Int64)))'), arrow_cast(make_array([[1,2,3]]), 'LargeList(List(List(Int64)))')), + array_has_any(arrow_cast(make_array(1,2,3),'LargeList(Int64)'), arrow_cast(make_array(1,10,100), 'LargeList(Int64)')), + array_has_any(arrow_cast(make_array(1,2,3),'LargeList(Int64)'), arrow_cast(make_array(10,100),'LargeList(Int64)')), + array_has_any(arrow_cast(make_array([1,2], [3,4]), 'LargeList(List(Int64))'), arrow_cast(make_array([1,10], [10,4]), 'LargeList(List(Int64))')), + array_has_any(arrow_cast(make_array([1,2], [3,4]), 'LargeList(List(Int64))'), arrow_cast(make_array([10,20], [3,4]), 'LargeList(List(Int64))')), + array_has_any(arrow_cast(make_array([[1,2,3]]), 'LargeList(List(List(Int64)))'), arrow_cast(make_array([[1,2,3], [4,5,6]]), 'LargeList(List(List(Int64)))')), + array_has_any(arrow_cast(make_array([[1,2,3]]), 'LargeList(List(List(Int64)))'), arrow_cast(make_array([[1,2,3]], [[4,5,6]]), 'LargeList(List(List(Int64)))')) +; +---- +true false true false false false true true false false true false true + +query BBBBBBBBBBBBB +select array_has_all(arrow_cast(make_array(1,2,3), 'LargeList(Int64)'), arrow_cast(make_array(1,3), 'LargeList(Int64)')), + array_has_all(arrow_cast(make_array(1,2,3),'LargeList(Int64)'), arrow_cast(make_array(1,4), 'LargeList(Int64)')), + array_has_all(arrow_cast(make_array([1,2], [3,4]), 'LargeList(List(Int64))'), arrow_cast(make_array([1,2]), 'LargeList(List(Int64))')), + array_has_all(arrow_cast(make_array([1,2], [3,4]), 'LargeList(List(Int64))'), arrow_cast(make_array([1,3]), 'LargeList(List(Int64))')), + array_has_all(arrow_cast(make_array([1,2], [3,4]), 'LargeList(List(Int64))'), arrow_cast(make_array([1,2], [3,4], [5,6]), 'LargeList(List(Int64))')), + array_has_all(arrow_cast(make_array([[1,2,3]]), 'LargeList(List(List(Int64)))'), arrow_cast(make_array([[1]]), 'LargeList(List(List(Int64)))')), + array_has_all(arrow_cast(make_array([[1,2,3]]), 'LargeList(List(List(Int64)))'), arrow_cast(make_array([[1,2,3]]), 'LargeList(List(List(Int64)))')), + array_has_any(arrow_cast(make_array(1,2,3),'LargeList(Int64)'), arrow_cast(make_array(1,10,100), 'LargeList(Int64)')), + array_has_any(arrow_cast(make_array(1,2,3),'LargeList(Int64)'), arrow_cast(make_array(10,100),'LargeList(Int64)')), + array_has_any(arrow_cast(make_array([1,2], [3,4]), 'LargeList(List(Int64))'), arrow_cast(make_array([1,10], [10,4]), 'LargeList(List(Int64))')), + array_has_any(arrow_cast(make_array([1,2], [3,4]), 'LargeList(List(Int64))'), arrow_cast(make_array([10,20], [3,4]), 'LargeList(List(Int64))')), + array_has_any(arrow_cast(make_array([[1,2,3]]), 'LargeList(List(List(Int64)))'), arrow_cast(make_array([[1,2,3], [4,5,6]]), 'LargeList(List(List(Int64)))')), + array_has_any(arrow_cast(make_array([[1,2,3]]), 'LargeList(List(List(Int64)))'), arrow_cast(make_array([[1,2,3]], [[4,5,6]]), 'LargeList(List(List(Int64)))')) +; +---- +true false true false false false true true false false true false true + +## array_distinct + +query ? +select array_distinct(null); +---- +NULL + +query ? +select array_distinct([]); +---- +[] + +query ? +select array_distinct([[], []]); +---- +[[]] + +query ? +select array_distinct(column1) +from array_distinct_table_1D; +---- +[1, 2, 3] +[1, 2, 3, 4, 5] +[3, 5] + +query ? 
+select array_distinct(column1) +from array_distinct_table_1D_UTF8; +---- +[a, bc, def] +[a, bc, def, defg] +[defg] + +query ? +select array_distinct(column1) +from array_distinct_table_2D; +---- +[[1, 2], [3, 4], [5, 6]] +[[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]] +[, [5, 6]] + +query ? +select array_distinct(column1) +from array_distinct_table_1D_large; +---- +[1, 2, 3] +[1, 2, 3, 4, 5] +[3, 5] + +query ??? +select array_intersect(column1, column2), + array_intersect(column3, column4), + array_intersect(column5, column6) +from array_intersect_table_1D; +---- +[1] [1, 3] [1, 3] +[11] [11, 33] [11, 33] + +query ??? +select array_intersect(column1, column2), + array_intersect(column3, column4), + array_intersect(column5, column6) +from large_array_intersect_table_1D; +---- +[1] [1, 3] [1, 3] +[11] [11, 33] [11, 33] + +query ??? +select array_intersect(column1, column2), + array_intersect(column3, column4), + array_intersect(column5, column6) +from array_intersect_table_1D_Float; +---- +[1.0] [1.0, 3.0] [] +[] [2.0] [1.11] + +query ??? +select array_intersect(column1, column2), + array_intersect(column3, column4), + array_intersect(column5, column6) +from array_intersect_table_1D_Boolean; +---- +[] [false, true] [false] +[false] [true] [true] + +query ??? +select array_intersect(column1, column2), + array_intersect(column3, column4), + array_intersect(column5, column6) +from large_array_intersect_table_1D_Boolean; +---- +[] [false, true] [false] +[false] [true] [true] + +query ??? +select array_intersect(column1, column2), + array_intersect(column3, column4), + array_intersect(column5, column6) +from array_intersect_table_1D_UTF8; +---- +[bc] [arrow, rust] [] +[] [arrow, datafusion, rust] [arrow, rust] + +query ??? +select array_intersect(column1, column2), + array_intersect(column3, column4), + array_intersect(column5, column6) +from large_array_intersect_table_1D_UTF8; +---- +[bc] [arrow, rust] [] +[] [arrow, datafusion, rust] [arrow, rust] + +query ?? +select array_intersect(column1, column2), + array_intersect(column3, column4) +from array_intersect_table_2D; +---- +[] [[4, 5], [6, 7]] +[[3, 4]] [[5, 6, 7], [8, 9, 10]] + +query ?? +select array_intersect(column1, column2), + array_intersect(column3, column4) +from large_array_intersect_table_2D; +---- +[] [[4, 5], [6, 7]] +[[3, 4]] [[5, 6, 7], [8, 9, 10]] + + +query ? +select array_intersect(column1, column2) +from array_intersect_table_2D_float; +---- +[[1.1, 2.2], [3.3]] +[[1.1, 2.2], [3.3]] + +query ? +select array_intersect(column1, column2) +from large_array_intersect_table_2D_float; +---- +[[1.1, 2.2], [3.3]] +[[1.1, 2.2], [3.3]] + +query ? +select array_intersect(column1, column2) +from array_intersect_table_3D; +---- +[] +[[[1, 2]]] + +query ? +select array_intersect(column1, column2) +from large_array_intersect_table_3D; +---- +[] +[[[1, 2]]] + +query ?????? +SELECT array_intersect(make_array(1,2,3), make_array(2,3,4)), + array_intersect(make_array(1,3,5), make_array(2,4,6)), + array_intersect(make_array('aa','bb','cc'), make_array('cc','aa','dd')), + array_intersect(make_array(true, false), make_array(true)), + array_intersect(make_array(1.1, 2.2, 3.3), make_array(2.2, 3.3, 4.4)), + array_intersect(make_array([1, 1], [2, 2], [3, 3]), make_array([2, 2], [3, 3], [4, 4])) +; +---- +[2, 3] [] [aa, cc] [true] [2.2, 3.3] [[2, 2], [3, 3]] + +query ?????? 
+SELECT array_intersect(arrow_cast(make_array(1,2,3), 'LargeList(Int64)'), arrow_cast(make_array(2,3,4), 'LargeList(Int64)')), + array_intersect(arrow_cast(make_array(1,3,5), 'LargeList(Int64)'), arrow_cast(make_array(2,4,6), 'LargeList(Int64)')), + array_intersect(arrow_cast(make_array('aa','bb','cc'), 'LargeList(Utf8)'), arrow_cast(make_array('cc','aa','dd'), 'LargeList(Utf8)')), + array_intersect(arrow_cast(make_array(true, false), 'LargeList(Boolean)'), arrow_cast(make_array(true), 'LargeList(Boolean)')), + array_intersect(arrow_cast(make_array(1.1, 2.2, 3.3), 'LargeList(Float64)'), arrow_cast(make_array(2.2, 3.3, 4.4), 'LargeList(Float64)')), + array_intersect(arrow_cast(make_array([1, 1], [2, 2], [3, 3]), 'LargeList(List(Int64))'), arrow_cast(make_array([2, 2], [3, 3], [4, 4]), 'LargeList(List(Int64))')) +; +---- +[2, 3] [] [aa, cc] [true] [2.2, 3.3] [[2, 2], [3, 3]] + +query ? +select array_intersect([], []); +---- +[] + +query ? +select array_intersect(arrow_cast([], 'LargeList(Null)'), arrow_cast([], 'LargeList(Null)')); +---- +[] + +query ? +select array_intersect([1, 1, 2, 2, 3, 3], null); +---- +[] + +query ? +select array_intersect(arrow_cast([1, 1, 2, 2, 3, 3], 'LargeList(Int64)'), null); +---- +[] + +query ? +select array_intersect(null, [1, 1, 2, 2, 3, 3]); +---- +NULL + +query ? +select array_intersect(null, arrow_cast([1, 1, 2, 2, 3, 3], 'LargeList(Int64)')); +---- +NULL + +query ? +select array_intersect([], null); +---- +[] + +query ? +select array_intersect(arrow_cast([], 'LargeList(Null)'), null); +---- +[] + +query ? +select array_intersect(null, []); +---- +NULL + +query ? +select array_intersect(null, arrow_cast([], 'LargeList(Null)')); +---- +NULL -# array_ndims scalar function #1 -query III -select array_ndims(make_array(1, 2, 3)), array_ndims(make_array([1, 2], [3, 4])), array_ndims(make_array([[[[1], [2]]]])); +query ? +select array_intersect(null, null); ---- -1 2 5 +NULL -# array_ndims scalar function #2 -query II -select array_ndims(array_repeat(array_repeat(array_repeat(1, 3), 2), 1)), array_ndims([[[[[[[[[[[[[[[[[[[[[1]]]]]]]]]]]]]]]]]]]]]); +query ?????? +SELECT list_intersect(make_array(1,2,3), make_array(2,3,4)), + list_intersect(make_array(1,3,5), make_array(2,4,6)), + list_intersect(make_array('aa','bb','cc'), make_array('cc','aa','dd')), + list_intersect(make_array(true, false), make_array(true)), + list_intersect(make_array(1.1, 2.2, 3.3), make_array(2.2, 3.3, 4.4)), + list_intersect(make_array([1, 1], [2, 2], [3, 3]), make_array([2, 2], [3, 3], [4, 4])) +; ---- -3 21 +[2, 3] [] [aa, cc] [true] [2.2, 3.3] [[2, 2], [3, 3]] -# array_ndims scalar function #3 -query II -select array_ndims(make_array()), array_ndims(make_array(make_array())) +query ?????? 
+SELECT list_intersect(arrow_cast(make_array(1,2,3), 'LargeList(Int64)'), arrow_cast(make_array(2,3,4), 'LargeList(Int64)')), + list_intersect(arrow_cast(make_array(1,3,5), 'LargeList(Int64)'), arrow_cast(make_array(2,4,6), 'LargeList(Int64)')), + list_intersect(arrow_cast(make_array('aa','bb','cc'), 'LargeList(Utf8)'), arrow_cast(make_array('cc','aa','dd'), 'LargeList(Utf8)')), + list_intersect(arrow_cast(make_array(true, false), 'LargeList(Boolean)'), arrow_cast(make_array(true), 'LargeList(Boolean)')), + list_intersect(arrow_cast(make_array(1.1, 2.2, 3.3), 'LargeList(Float64)'), arrow_cast(make_array(2.2, 3.3, 4.4), 'LargeList(Float64)')), + list_intersect(arrow_cast(make_array([1, 1], [2, 2], [3, 3]), 'LargeList(List(Int64))'), arrow_cast(make_array([2, 2], [3, 3], [4, 4]), 'LargeList(List(Int64))')) +; ---- -NULL 2 +[2, 3] [] [aa, cc] [true] [2.2, 3.3] [[2, 2], [3, 3]] -# list_ndims scalar function #4 (function alias `array_ndims`) -query III -select list_ndims(make_array(1, 2, 3)), list_ndims(make_array([1, 2], [3, 4])), list_ndims(make_array([[[[1], [2]]]])); +query BBBB +select list_has_all(make_array(1,2,3), make_array(4,5,6)), + list_has_all(make_array(1,2,3), make_array(1,2)), + list_has_any(make_array(1,2,3), make_array(4,5,6)), + list_has_any(make_array(1,2,3), make_array(1,2,4)) +; ---- -1 2 5 +false true false true -query II -select array_ndims(make_array()), array_ndims(make_array(make_array())) +query ??? +select range(column2), + range(column1, column2), + range(column1, column2, column3) +from arrays_range; +---- +[0, 1, 2, 3, 4, 5, 6, 7, 8, 9] [3, 4, 5, 6, 7, 8, 9] [3, 5, 7, 9] +[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12] [4, 5, 6, 7, 8, 9, 10, 11, 12] [4, 7, 10] + +query ?????? +select range(5), + range(2, 5), + range(2, 10, 3), + range(1, 5, -1), + range(1, -5, 1), + range(1, -5, -1) +; ---- -NULL 2 +[0, 1, 2, 3, 4] [2, 3, 4] [2, 5, 8] [] [] [1, 0, -1, -2, -3, -4] -# array_ndims with columns -query III -select array_ndims(column1), array_ndims(column2), array_ndims(column3) from arrays; +query ??? +select generate_series(5), + generate_series(2, 5), + generate_series(2, 10, 3) +; ---- -2 1 1 -2 1 1 -2 1 1 -2 1 1 -NULL 1 1 -2 NULL 1 -2 1 NULL +[0, 1, 2, 3, 4] [2, 3, 4] [2, 5, 8] -## array_has/array_has_all/array_has_any +## array_except -query BBBBBBBBBBBB -select array_has(make_array(1,2), 1), - array_has(make_array(1,2,NULL), 1), - array_has(make_array([2,3], [3,4]), make_array(2,3)), - array_has(make_array([[1], [2,3]], [[4,5], [6]]), make_array([1], [2,3])), - array_has(make_array([[1], [2,3]], [[4,5], [6]]), make_array([4,5], [6])), - array_has(make_array([[1], [2,3]], [[4,5], [6]]), make_array([1])), - array_has(make_array([[[1]]]), make_array([[1]])), - array_has(make_array([[[1]]], [[[1], [2]]]), make_array([[2]])), - array_has(make_array([[[1]]], [[[1], [2]]]), make_array([[1], [2]])), - list_has(make_array(1,2,3), 4), - array_contains(make_array(1,2,3), 3), - list_contains(make_array(1,2,3), 0) +statement ok +CREATE TABLE array_except_table +AS VALUES + ([1, 2, 2, 3], [2, 3, 4]), + ([2, 3, 3], [3]), + ([3], [3, 3, 4]), + (null, [3, 4]), + ([1, 2], null), + (null, null) ; ----- -true true true true true false true false true false true false -query BBB -select array_has(column1, column2), - array_has_all(column3, column4), - array_has_any(column5, column6) -from array_has_table_1D; +query ? 
+select array_except(column1, column2) from array_except_table; ---- -true true true -false false false +[1] +[2] +[] +NULL +[1, 2] +NULL -query BBB -select array_has(column1, column2), - array_has_all(column3, column4), - array_has_any(column5, column6) -from array_has_table_1D_Float; ----- -true true false -false false true +statement ok +drop table array_except_table; -query BBB -select array_has(column1, column2), - array_has_all(column3, column4), - array_has_any(column5, column6) -from array_has_table_1D_Boolean; ----- -false true true -true true true +statement ok +CREATE TABLE array_except_nested_list_table +AS VALUES + ([[1, 2], [3]], [[2], [3], [4, 5]]), + ([[1, 2], [3]], [[2], [1, 2]]), + ([[1, 2], [3]], null), + (null, [[1], [2, 3], [4, 5, 6]]), + ([[1], [2, 3], [4, 5, 6]], [[2, 3], [4, 5, 6], [1]]) +; -query BBB -select array_has(column1, column2), - array_has_all(column3, column4), - array_has_any(column5, column6) -from array_has_table_1D_UTF8; +query ? +select array_except(column1, column2) from array_except_nested_list_table; ---- -true true false -false false true +[[1, 2]] +[[3]] +[[1, 2], [3]] +NULL +[] -query BB -select array_has(column1, column2), - array_has_all(column3, column4) -from array_has_table_2D; +statement ok +drop table array_except_nested_list_table; + +statement ok +CREATE TABLE array_except_table_float +AS VALUES + ([1.1, 2.2, 3.3], [2.2]), + ([1.1, 2.2, 3.3], [4.4]), + ([1.1, 2.2, 3.3], [3.3, 2.2, 1.1]) +; + +query ? +select array_except(column1, column2) from array_except_table_float; ---- -false true -true false +[1.1, 3.3] +[1.1, 2.2, 3.3] +[] -query B -select array_has_all(column1, column2) -from array_has_table_2D_float; +statement ok +drop table array_except_table_float; + +statement ok +CREATE TABLE array_except_table_ut8 +AS VALUES + (['a', 'b', 'c'], ['a']), + (['a', 'bc', 'def'], ['g', 'def']), + (['a', 'bc', 'def'], null), + (null, ['a']) +; + +query ? +select array_except(column1, column2) from array_except_table_ut8; ---- -true -false +[b, c] +[a, bc] +[a, bc, def] +NULL -query B -select array_has(column1, column2) from array_has_table_3D; +statement ok +drop table array_except_table_ut8; + +statement ok +CREATE TABLE array_except_table_bool +AS VALUES + ([true, false, false], [false]), + ([true, true, true], [false]), + ([false, false, false], [true]), + ([true, false], null), + (null, [true, false]) +; + +query ? +select array_except(column1, column2) from array_except_table_bool; ---- -false -true -false -false -true -false -true +[true] +[true] +[false] +[true, false] +NULL -query BBBB -select array_has(column1, make_array(5, 6)), - array_has(column1, make_array(7, NULL)), - array_has(column2, 5.5), - array_has(column3, 'o') -from arrays; +statement ok +drop table array_except_table_bool; + +query ? 
+select array_except([], null); ---- -false false false true -true false true false -true false false true -false true false false -false false false false -false false false false +[] -query BBBBBBBBBBBBB -select array_has_all(make_array(1,2,3), make_array(1,3)), - array_has_all(make_array(1,2,3), make_array(1,4)), - array_has_all(make_array([1,2], [3,4]), make_array([1,2])), - array_has_all(make_array([1,2], [3,4]), make_array([1,3])), - array_has_all(make_array([1,2], [3,4]), make_array([1,2], [3,4], [5,6])), - array_has_all(make_array([[1,2,3]]), make_array([[1]])), - array_has_all(make_array([[1,2,3]]), make_array([[1,2,3]])), - array_has_any(make_array(1,2,3), make_array(1,10,100)), - array_has_any(make_array(1,2,3), make_array(10,100)), - array_has_any(make_array([1,2], [3,4]), make_array([1,10], [10,4])), - array_has_any(make_array([1,2], [3,4]), make_array([10,20], [3,4])), - array_has_any(make_array([[1,2,3]]), make_array([[1,2,3], [4,5,6]])), - array_has_any(make_array([[1,2,3]]), make_array([[1,2,3]], [[4,5,6]])) -; +query ? +select array_except([], []); ---- -true false true false false false true true false false true false true +[] -query BBBB -select list_has_all(make_array(1,2,3), make_array(4,5,6)), - list_has_all(make_array(1,2,3), make_array(1,2)), - list_has_any(make_array(1,2,3), make_array(4,5,6)), - list_has_any(make_array(1,2,3), make_array(1,2,4)) -; +query ? +select array_except(null, []); ---- -false true false true +NULL +query ? +select array_except(null, null) +---- +NULL ### Array operators tests @@ -2466,18 +4666,33 @@ select empty(make_array(1)); ---- false +query B +select empty(arrow_cast(make_array(1), 'LargeList(Int64)')); +---- +false + # empty scalar function #2 query B select empty(make_array()); ---- true +query B +select empty(arrow_cast(make_array(), 'LargeList(Null)')); +---- +true + # empty scalar function #3 query B select empty(make_array(NULL)); ---- false +query B +select empty(arrow_cast(make_array(NULL), 'LargeList(Null)')); +---- +false + # empty scalar function #4 query B select empty(NULL); @@ -2496,6 +4711,17 @@ NULL false false +query B +select empty(arrow_cast(column1, 'LargeList(List(Int64))')) from arrays; +---- +false +false +false +false +NULL +false +false + query ? 
SELECT string_to_array('abcxxxdef', 'xxx') ---- @@ -2556,6 +4782,9 @@ drop table nested_arrays; statement ok drop table arrays; +statement ok +drop table large_arrays; + statement ok drop table slices; @@ -2589,14 +4818,65 @@ drop table array_has_table_2D_float; statement ok drop table array_has_table_3D; +statement ok +drop table array_intersect_table_1D; + +statement ok +drop table large_array_intersect_table_1D; + +statement ok +drop table array_intersect_table_1D_Float; + +statement ok +drop table large_array_intersect_table_1D_Float; + +statement ok +drop table array_intersect_table_1D_Boolean; + +statement ok +drop table large_array_intersect_table_1D_Boolean; + +statement ok +drop table array_intersect_table_1D_UTF8; + +statement ok +drop table large_array_intersect_table_1D_UTF8; + +statement ok +drop table array_intersect_table_2D; + +statement ok +drop table large_array_intersect_table_2D; + +statement ok +drop table array_intersect_table_2D_float; + +statement ok +drop table large_array_intersect_table_2D_float; + +statement ok +drop table array_intersect_table_3D; + +statement ok +drop table large_array_intersect_table_3D; + statement ok drop table arrays_values_without_nulls; +statement ok +drop table arrays_range; + statement ok drop table arrays_with_repeating_elements; +statement ok +drop table large_arrays_with_repeating_elements; + statement ok drop table nested_arrays_with_repeating_elements; +statement ok +drop table large_nested_arrays_with_repeating_elements; + statement ok drop table flatten_table; diff --git a/datafusion/sqllogictest/test_files/arrow_files.slt b/datafusion/sqllogictest/test_files/arrow_files.slt new file mode 100644 index 0000000000000..5c1b6fb726ed7 --- /dev/null +++ b/datafusion/sqllogictest/test_files/arrow_files.slt @@ -0,0 +1,44 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
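# The ARROW format exercised below is the Arrow IPC file format. As a minimal
# illustrative sketch only (the table name `arrow_example` and the LOCATION
# path are hypothetical assumptions, not taken from this change), an
# Arrow-backed external table is registered and then queried like any other
# table:
#
#   CREATE EXTERNAL TABLE arrow_example
#   STORED AS ARROW
#   LOCATION 'path/to/data.arrow';
#
#   SELECT count(*) FROM arrow_example;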
+ +############# +## Arrow Files Format support +############# + + +statement ok + +CREATE EXTERNAL TABLE arrow_simple +STORED AS ARROW +LOCATION '../core/tests/data/example.arrow'; + + +# physical plan +query TT +EXPLAIN SELECT * FROM arrow_simple +---- +logical_plan TableScan: arrow_simple projection=[f0, f1, f2] +physical_plan ArrowExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/example.arrow]]}, projection=[f0, f1, f2] + +# correct content +query ITB +SELECT * FROM arrow_simple +---- +1 foo true +2 bar NULL +3 baz false +4 NULL true diff --git a/datafusion/sqllogictest/test_files/arrow_typeof.slt b/datafusion/sqllogictest/test_files/arrow_typeof.slt index e485251b73421..3fad4d0f61b98 100644 --- a/datafusion/sqllogictest/test_files/arrow_typeof.slt +++ b/datafusion/sqllogictest/test_files/arrow_typeof.slt @@ -338,3 +338,41 @@ select arrow_cast(timestamp '2000-01-01T00:00:00Z', 'Timestamp(Nanosecond, Some( statement error Arrow error: Parser error: Invalid timezone "\+25:00": '\+25:00' is not a valid timezone select arrow_cast(timestamp '2000-01-01T00:00:00', 'Timestamp(Nanosecond, Some( "+25:00" ))'); + + +## List + + +query ? +select arrow_cast('1', 'List(Int64)'); +---- +[1] + +query ? +select arrow_cast(make_array(1, 2, 3), 'List(Int64)'); +---- +[1, 2, 3] + +query T +select arrow_typeof(arrow_cast(make_array(1, 2, 3), 'List(Int64)')); +---- +List(Field { name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) + + +## LargeList + + +query ? +select arrow_cast('1', 'LargeList(Int64)'); +---- +[1] + +query ? +select arrow_cast(make_array(1, 2, 3), 'LargeList(Int64)'); +---- +[1, 2, 3] + +query T +select arrow_typeof(arrow_cast(make_array(1, 2, 3), 'LargeList(Int64)')); +---- +LargeList(Field { name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) \ No newline at end of file diff --git a/datafusion/sqllogictest/test_files/copy.slt b/datafusion/sqllogictest/test_files/copy.slt index f2fe216ee864f..89b23917884c6 100644 --- a/datafusion/sqllogictest/test_files/copy.slt +++ b/datafusion/sqllogictest/test_files/copy.slt @@ -32,7 +32,7 @@ logical_plan CopyTo: format=parquet output_url=test_files/scratch/copy/table single_file_output=false options: (compression 'zstd(10)') --TableScan: source_table projection=[col1, col2] physical_plan -InsertExec: sink=ParquetSink(writer_mode=PutMultipart, file_groups=[]) +FileSinkExec: sink=ParquetSink(file_groups=[]) --MemoryExec: partitions=1, partition_sizes=[1] # Error case @@ -66,8 +66,8 @@ select * from validate_parquet; # Copy parquet with all supported statment overrides query IT -COPY source_table -TO 'test_files/scratch/copy/table_with_options' +COPY source_table +TO 'test_files/scratch/copy/table_with_options' (format parquet, single_file_output false, compression snappy, @@ -206,11 +206,11 @@ select * from validate_single_json; # COPY csv files with all options set query IT -COPY source_table -to 'test_files/scratch/copy/table_csv_with_options' -(format csv, -single_file_output false, -header false, +COPY source_table +to 'test_files/scratch/copy/table_csv_with_options' +(format csv, +single_file_output false, +header false, compression 'uncompressed', datetime_format '%FT%H:%M:%S.%9f', delimiter ';', @@ -220,8 +220,8 @@ null_value 'NULLVAL'); # Validate single csv output statement ok -CREATE EXTERNAL TABLE validate_csv_with_options -STORED AS csv +CREATE EXTERNAL TABLE validate_csv_with_options +STORED AS csv LOCATION 
'test_files/scratch/copy/table_csv_with_options'; query T @@ -230,6 +230,62 @@ select * from validate_csv_with_options; ---- 1;Foo 2;Bar +# Copy from table to single arrow file +query IT +COPY source_table to 'test_files/scratch/copy/table.arrow'; +---- +2 + +# Validate single arrow file output +statement ok +CREATE EXTERNAL TABLE validate_arrow_file +STORED AS arrow +LOCATION 'test_files/scratch/copy/table.arrow'; + +query IT +select * from validate_arrow_file; +---- +1 Foo +2 Bar + +# Copy from dict encoded values to single arrow file +query T? +COPY (values +('c', arrow_cast('foo', 'Dictionary(Int32, Utf8)')), ('d', arrow_cast('bar', 'Dictionary(Int32, Utf8)'))) +to 'test_files/scratch/copy/table_dict.arrow'; +---- +2 + +# Validate dictionary encoded arrow file output +statement ok +CREATE EXTERNAL TABLE validate_arrow_file_dict +STORED AS arrow +LOCATION 'test_files/scratch/copy/table_dict.arrow'; + +query T? +select * from validate_arrow_file_dict; +---- +c foo +d bar + + +# Copy from table to folder of arrow files +query IT +COPY source_table to 'test_files/scratch/copy/table_arrow' (format arrow, single_file_output false); +---- +2 + +# Validate arrow output +statement ok +CREATE EXTERNAL TABLE validate_arrow STORED AS arrow LOCATION 'test_files/scratch/copy/table_arrow'; + +query IT +select * from validate_arrow; +---- +1 Foo +2 Bar + + # Error cases: # Copy from table with options diff --git a/datafusion/sqllogictest/test_files/csv_files.slt b/datafusion/sqllogictest/test_files/csv_files.slt new file mode 100644 index 0000000000000..9facb064bf32a --- /dev/null +++ b/datafusion/sqllogictest/test_files/csv_files.slt @@ -0,0 +1,65 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License.
+ +# create_external_table_with_quote_escape +statement ok +CREATE EXTERNAL TABLE csv_with_quote ( +c1 VARCHAR, +c2 VARCHAR +) STORED AS CSV +WITH HEADER ROW +DELIMITER ',' +OPTIONS ('quote' '~') +LOCATION '../core/tests/data/quote.csv'; + +statement ok +CREATE EXTERNAL TABLE csv_with_escape ( +c1 VARCHAR, +c2 VARCHAR +) STORED AS CSV +WITH HEADER ROW +DELIMITER ',' +OPTIONS ('escape' '\"') +LOCATION '../core/tests/data/escape.csv'; + +query TT +select * from csv_with_quote; +---- +id0 value0 +id1 value1 +id2 value2 +id3 value3 +id4 value4 +id5 value5 +id6 value6 +id7 value7 +id8 value8 +id9 value9 + +query TT +select * from csv_with_escape; +---- +id0 value"0 +id1 value"1 +id2 value"2 +id3 value"3 +id4 value"4 +id5 value"5 +id6 value"6 +id7 value"7 +id8 value"8 +id9 value"9 diff --git a/datafusion/sqllogictest/test_files/ddl.slt b/datafusion/sqllogictest/test_files/ddl.slt index ed4f4b4a11ac1..682972b5572a9 100644 --- a/datafusion/sqllogictest/test_files/ddl.slt +++ b/datafusion/sqllogictest/test_files/ddl.slt @@ -750,7 +750,7 @@ query TT explain select c1 from t; ---- logical_plan TableScan: t projection=[c1] -physical_plan CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/empty.csv]]}, projection=[c1], infinite_source=true, has_header=true +physical_plan StreamingTableExec: partition_sizes=1, projection=[c1], infinite_source=true statement ok drop table t; diff --git a/datafusion/sqllogictest/test_files/describe.slt b/datafusion/sqllogictest/test_files/describe.slt index 007aec443cbc9..f94a2e453884f 100644 --- a/datafusion/sqllogictest/test_files/describe.slt +++ b/datafusion/sqllogictest/test_files/describe.slt @@ -62,3 +62,27 @@ DROP TABLE aggregate_simple; statement error Error during planning: table 'datafusion.public.../core/tests/data/aggregate_simple.csv' not found DESCRIBE '../core/tests/data/aggregate_simple.csv'; + +########## +# Describe command +########## + +statement ok +CREATE EXTERNAL TABLE alltypes_tiny_pages STORED AS PARQUET LOCATION '../../parquet-testing/data/alltypes_tiny_pages.parquet'; + +query TTT +describe alltypes_tiny_pages; +---- +id Int32 YES +bool_col Boolean YES +tinyint_col Int8 YES +smallint_col Int16 YES +int_col Int32 YES +bigint_col Int64 YES +float_col Float32 YES +double_col Float64 YES +date_string_col Utf8 YES +string_col Utf8 YES +timestamp_col Timestamp(Nanosecond, None) YES +year Int32 YES +month Int32 YES diff --git a/datafusion/sqllogictest/test_files/distinct_on.slt b/datafusion/sqllogictest/test_files/distinct_on.slt new file mode 100644 index 0000000000000..3f609e2548398 --- /dev/null +++ b/datafusion/sqllogictest/test_files/distinct_on.slt @@ -0,0 +1,145 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +statement ok +CREATE EXTERNAL TABLE aggregate_test_100 ( + c1 VARCHAR NOT NULL, + c2 TINYINT NOT NULL, + c3 SMALLINT NOT NULL, + c4 SMALLINT, + c5 INT, + c6 BIGINT NOT NULL, + c7 SMALLINT NOT NULL, + c8 INT NOT NULL, + c9 BIGINT UNSIGNED NOT NULL, + c10 VARCHAR NOT NULL, + c11 FLOAT NOT NULL, + c12 DOUBLE NOT NULL, + c13 VARCHAR NOT NULL +) +STORED AS CSV +WITH HEADER ROW +LOCATION '../../testing/data/csv/aggregate_test_100.csv' + +# Basic example: distinct on the first column project the second one, and +# order by the third +query TI +SELECT DISTINCT ON (c1) c1, c2 FROM aggregate_test_100 ORDER BY c1, c3, c9; +---- +a 4 +b 4 +c 2 +d 1 +e 3 + +# Basic example + reverse order of the selected column +query TI +SELECT DISTINCT ON (c1) c1, c2 FROM aggregate_test_100 ORDER BY c1, c3 DESC, c9; +---- +a 1 +b 5 +c 4 +d 1 +e 1 + +# Basic example + reverse order of the ON column +query TI +SELECT DISTINCT ON (c1) c1, c2 FROM aggregate_test_100 ORDER BY c1 DESC, c3, c9; +---- +e 3 +d 1 +c 2 +b 4 +a 4 + +# Basic example + reverse order of both columns + limit +query TI +SELECT DISTINCT ON (c1) c1, c2 FROM aggregate_test_100 ORDER BY c1 DESC, c3 DESC LIMIT 3; +---- +e 1 +d 1 +c 4 + +# Basic example + omit ON column from selection +query I +SELECT DISTINCT ON (c1) c2 FROM aggregate_test_100 ORDER BY c1, c3; +---- +4 +4 +2 +1 +3 + +# Test explain makes sense +query TT +EXPLAIN SELECT DISTINCT ON (c1) c3, c2 FROM aggregate_test_100 ORDER BY c1, c3; +---- +logical_plan +Projection: FIRST_VALUE(aggregate_test_100.c3) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c3 ASC NULLS LAST] AS c3, FIRST_VALUE(aggregate_test_100.c2) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c3 ASC NULLS LAST] AS c2 +--Sort: aggregate_test_100.c1 ASC NULLS LAST +----Aggregate: groupBy=[[aggregate_test_100.c1]], aggr=[[FIRST_VALUE(aggregate_test_100.c3) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c3 ASC NULLS LAST], FIRST_VALUE(aggregate_test_100.c2) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c3 ASC NULLS LAST]]] +------TableScan: aggregate_test_100 projection=[c1, c2, c3] +physical_plan +ProjectionExec: expr=[FIRST_VALUE(aggregate_test_100.c3) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c3 ASC NULLS LAST]@1 as c3, FIRST_VALUE(aggregate_test_100.c2) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c3 ASC NULLS LAST]@2 as c2] +--SortPreservingMergeExec: [c1@0 ASC NULLS LAST] +----SortExec: expr=[c1@0 ASC NULLS LAST] +------AggregateExec: mode=FinalPartitioned, gby=[c1@0 as c1], aggr=[FIRST_VALUE(aggregate_test_100.c3), FIRST_VALUE(aggregate_test_100.c2)] +--------CoalesceBatchesExec: target_batch_size=8192 +----------RepartitionExec: partitioning=Hash([c1@0], 4), input_partitions=4 +------------AggregateExec: mode=Partial, gby=[c1@0 as c1], aggr=[FIRST_VALUE(aggregate_test_100.c3), FIRST_VALUE(aggregate_test_100.c2)] +--------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +----------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1, c2, c3], has_header=true + +# ON expressions are not a sub-set of the ORDER BY expressions +query error SELECT DISTINCT ON expressions must match initial ORDER BY expressions +SELECT DISTINCT ON (c2 % 2 = 0) c2, c3 - 100 FROM aggregate_test_100 ORDER BY c2, c3; + +# ON expressions are empty +query error DataFusion error: Error during planning: No `ON` expressions provided +SELECT 
DISTINCT ON () c1, c2 FROM aggregate_test_100 ORDER BY c1, c2; + +# Use expressions in the ON and ORDER BY clauses, as well as the selection +query II +SELECT DISTINCT ON (c2 % 2 = 0) c2, c3 - 100 FROM aggregate_test_100 ORDER BY c2 % 2 = 0, c3 DESC; +---- +1 25 +4 23 + +# Multiple complex expressions +query TIB +SELECT DISTINCT ON (chr(ascii(c1) + 3), c2 % 2) chr(ascii(upper(c1)) + 3), c2 % 2, c3 > 80 AND c2 % 2 = 1 +FROM aggregate_test_100 +WHERE c1 IN ('a', 'b') +ORDER BY chr(ascii(c1) + 3), c2 % 2, c3 DESC; +---- +D 0 false +D 1 true +E 0 false +E 1 false + +# Joins using CTEs +query II +WITH t1 AS (SELECT * FROM aggregate_test_100), +t2 AS (SELECT * FROM aggregate_test_100) +SELECT DISTINCT ON (t1.c1, t2.c2) t2.c3, t1.c4 +FROM t1 INNER JOIN t2 ON t1.c13 = t2.c13 +ORDER BY t1.c1, t2.c2, t2.c5 +LIMIT 3; +---- +-25 15295 +45 15673 +-72 -11122 diff --git a/datafusion/sqllogictest/test_files/errors.slt b/datafusion/sqllogictest/test_files/errors.slt index 4aded8a576fb0..e3b2610e51be3 100644 --- a/datafusion/sqllogictest/test_files/errors.slt +++ b/datafusion/sqllogictest/test_files/errors.slt @@ -133,4 +133,4 @@ order by c9 statement error Inconsistent data type across values list at row 1 column 0. Was Int64 but found Utf8 -create table foo as values (1), ('foo'); \ No newline at end of file +create table foo as values (1), ('foo'); diff --git a/datafusion/sqllogictest/test_files/explain.slt b/datafusion/sqllogictest/test_files/explain.slt index 40a6d43574881..4583ef319b7fc 100644 --- a/datafusion/sqllogictest/test_files/explain.slt +++ b/datafusion/sqllogictest/test_files/explain.slt @@ -94,7 +94,7 @@ EXPLAIN select count(*) from (values ('a', 1, 100), ('a', 2, 150)) as t (c1,c2,c ---- physical_plan ProjectionExec: expr=[2 as COUNT(*)] ---EmptyExec: produce_one_row=true +--PlaceholderRowExec statement ok set datafusion.explain.physical_plan_only = false @@ -140,7 +140,7 @@ physical_plan CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/te # create a sink table, path is same with aggregate_test_100 table # we do not overwrite this file, we only assert plan. statement ok -CREATE EXTERNAL TABLE sink_table ( +CREATE UNBOUNDED EXTERNAL TABLE sink_table ( c1 VARCHAR NOT NULL, c2 TINYINT NOT NULL, c3 SMALLINT NOT NULL, @@ -168,10 +168,9 @@ Dml: op=[Insert Into] table=[sink_table] ----Sort: aggregate_test_100.c1 ASC NULLS LAST ------TableScan: aggregate_test_100 projection=[c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13] physical_plan -InsertExec: sink=CsvSink(writer_mode=Append, file_groups=[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]) ---ProjectionExec: expr=[c1@0 as c1, c2@1 as c2, c3@2 as c3, c4@3 as c4, c5@4 as c5, c6@5 as c6, c7@6 as c7, c8@7 as c8, c9@8 as c9, c10@9 as c10, c11@10 as c11, c12@11 as c12, c13@12 as c13] -----SortExec: expr=[c1@0 ASC NULLS LAST] -------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13], has_header=true +FileSinkExec: sink=StreamWrite { location: "../../testing/data/csv/aggregate_test_100.csv", batch_size: 8192, encoding: Csv, header: true, .. 
} +--SortExec: expr=[c1@0 ASC NULLS LAST] +----CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13], has_header=true # test EXPLAIN VERBOSE query TT @@ -193,7 +192,6 @@ logical_plan after decorrelate_predicate_subquery SAME TEXT AS ABOVE logical_plan after scalar_subquery_to_join SAME TEXT AS ABOVE logical_plan after extract_equijoin_predicate SAME TEXT AS ABOVE logical_plan after simplify_expressions SAME TEXT AS ABOVE -logical_plan after merge_projection SAME TEXT AS ABOVE logical_plan after rewrite_disjunctive_predicate SAME TEXT AS ABOVE logical_plan after eliminate_duplicated_expr SAME TEXT AS ABOVE logical_plan after eliminate_filter SAME TEXT AS ABOVE @@ -210,11 +208,7 @@ logical_plan after single_distinct_aggregation_to_group_by SAME TEXT AS ABOVE logical_plan after simplify_expressions SAME TEXT AS ABOVE logical_plan after unwrap_cast_in_comparison SAME TEXT AS ABOVE logical_plan after common_sub_expression_eliminate SAME TEXT AS ABOVE -logical_plan after push_down_projection -Projection: simple_explain_test.a, simple_explain_test.b, simple_explain_test.c ---TableScan: simple_explain_test projection=[a, b, c] -logical_plan after eliminate_projection TableScan: simple_explain_test projection=[a, b, c] -logical_plan after push_down_limit SAME TEXT AS ABOVE +logical_plan after optimize_projections TableScan: simple_explain_test projection=[a, b, c] logical_plan after eliminate_nested_union SAME TEXT AS ABOVE logical_plan after simplify_expressions SAME TEXT AS ABOVE logical_plan after unwrap_cast_in_comparison SAME TEXT AS ABOVE @@ -224,7 +218,6 @@ logical_plan after decorrelate_predicate_subquery SAME TEXT AS ABOVE logical_plan after scalar_subquery_to_join SAME TEXT AS ABOVE logical_plan after extract_equijoin_predicate SAME TEXT AS ABOVE logical_plan after simplify_expressions SAME TEXT AS ABOVE -logical_plan after merge_projection SAME TEXT AS ABOVE logical_plan after rewrite_disjunctive_predicate SAME TEXT AS ABOVE logical_plan after eliminate_duplicated_expr SAME TEXT AS ABOVE logical_plan after eliminate_filter SAME TEXT AS ABOVE @@ -241,16 +234,16 @@ logical_plan after single_distinct_aggregation_to_group_by SAME TEXT AS ABOVE logical_plan after simplify_expressions SAME TEXT AS ABOVE logical_plan after unwrap_cast_in_comparison SAME TEXT AS ABOVE logical_plan after common_sub_expression_eliminate SAME TEXT AS ABOVE -logical_plan after push_down_projection SAME TEXT AS ABOVE -logical_plan after eliminate_projection SAME TEXT AS ABOVE -logical_plan after push_down_limit SAME TEXT AS ABOVE +logical_plan after optimize_projections SAME TEXT AS ABOVE logical_plan TableScan: simple_explain_test projection=[a, b, c] initial_physical_plan CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/example.csv]]}, projection=[a, b, c], has_header=true +initial_physical_plan_with_stats CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/example.csv]]}, projection=[a, b, c], has_header=true, statistics=[Rows=Absent, Bytes=Absent, [(Col[0]:),(Col[1]:),(Col[2]:)]] physical_plan after OutputRequirements OutputRequirementExec --CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/example.csv]]}, projection=[a, b, c], has_header=true physical_plan after aggregate_statistics SAME TEXT AS ABOVE physical_plan after join_selection SAME TEXT AS ABOVE +physical_plan after LimitedDistinctAggregation SAME TEXT AS ABOVE physical_plan 
after EnforceDistribution SAME TEXT AS ABOVE physical_plan after CombinePartialFinalAggregate SAME TEXT AS ABOVE physical_plan after EnforceSorting SAME TEXT AS ABOVE @@ -258,7 +251,9 @@ physical_plan after coalesce_batches SAME TEXT AS ABOVE physical_plan after OutputRequirements CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/example.csv]]}, projection=[a, b, c], has_header=true physical_plan after PipelineChecker SAME TEXT AS ABOVE physical_plan after LimitAggregation SAME TEXT AS ABOVE +physical_plan after ProjectionPushdown SAME TEXT AS ABOVE physical_plan CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/example.csv]]}, projection=[a, b, c], has_header=true +physical_plan_with_stats CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/example.csv]]}, projection=[a, b, c], has_header=true, statistics=[Rows=Absent, Bytes=Absent, [(Col[0]:),(Col[1]:),(Col[2]:)]] ### tests for EXPLAIN with display statistics enabled @@ -273,8 +268,8 @@ query TT EXPLAIN SELECT a, b, c FROM simple_explain_test limit 10; ---- physical_plan -GlobalLimitExec: skip=0, fetch=10, statistics=[Rows=Inexact(10), Bytes=Absent] ---CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/example.csv]]}, projection=[a, b, c], limit=10, has_header=true, statistics=[Rows=Absent, Bytes=Absent] +GlobalLimitExec: skip=0, fetch=10, statistics=[Rows=Inexact(10), Bytes=Absent, [(Col[0]:),(Col[1]:),(Col[2]:)]] +--CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/example.csv]]}, projection=[a, b, c], limit=10, has_header=true, statistics=[Rows=Absent, Bytes=Absent, [(Col[0]:),(Col[1]:),(Col[2]:)]] # Parquet scan with statistics collected statement ok @@ -287,11 +282,100 @@ query TT EXPLAIN SELECT * FROM alltypes_plain limit 10; ---- physical_plan -GlobalLimitExec: skip=0, fetch=10, statistics=[Rows=Exact(8), Bytes=Absent] ---ParquetExec: file_groups={1 group: [[WORKSPACE_ROOT/parquet-testing/data/alltypes_plain.parquet]]}, projection=[id, bool_col, tinyint_col, smallint_col, int_col, bigint_col, float_col, double_col, date_string_col, string_col, timestamp_col], limit=10, statistics=[Rows=Exact(8), Bytes=Absent] +GlobalLimitExec: skip=0, fetch=10, statistics=[Rows=Exact(8), Bytes=Absent, [(Col[0]:),(Col[1]:),(Col[2]:),(Col[3]:),(Col[4]:),(Col[5]:),(Col[6]:),(Col[7]:),(Col[8]:),(Col[9]:),(Col[10]:)]] +--ParquetExec: file_groups={1 group: [[WORKSPACE_ROOT/parquet-testing/data/alltypes_plain.parquet]]}, projection=[id, bool_col, tinyint_col, smallint_col, int_col, bigint_col, float_col, double_col, date_string_col, string_col, timestamp_col], limit=10, statistics=[Rows=Exact(8), Bytes=Absent, [(Col[0]:),(Col[1]:),(Col[2]:),(Col[3]:),(Col[4]:),(Col[5]:),(Col[6]:),(Col[7]:),(Col[8]:),(Col[9]:),(Col[10]:)]] + +# explain verbose with both collect & show statistics on +query TT +EXPLAIN VERBOSE SELECT * FROM alltypes_plain limit 10; +---- +initial_physical_plan +GlobalLimitExec: skip=0, fetch=10, statistics=[Rows=Exact(8), Bytes=Absent, [(Col[0]:),(Col[1]:),(Col[2]:),(Col[3]:),(Col[4]:),(Col[5]:),(Col[6]:),(Col[7]:),(Col[8]:),(Col[9]:),(Col[10]:)]] +--ParquetExec: file_groups={1 group: [[WORKSPACE_ROOT/parquet-testing/data/alltypes_plain.parquet]]}, projection=[id, bool_col, tinyint_col, smallint_col, int_col, bigint_col, float_col, double_col, date_string_col, string_col, timestamp_col], limit=10, statistics=[Rows=Exact(8), Bytes=Absent, 
[(Col[0]:),(Col[1]:),(Col[2]:),(Col[3]:),(Col[4]:),(Col[5]:),(Col[6]:),(Col[7]:),(Col[8]:),(Col[9]:),(Col[10]:)]] +physical_plan after OutputRequirements +OutputRequirementExec, statistics=[Rows=Exact(8), Bytes=Absent, [(Col[0]:),(Col[1]:),(Col[2]:),(Col[3]:),(Col[4]:),(Col[5]:),(Col[6]:),(Col[7]:),(Col[8]:),(Col[9]:),(Col[10]:)]] +--GlobalLimitExec: skip=0, fetch=10, statistics=[Rows=Exact(8), Bytes=Absent, [(Col[0]:),(Col[1]:),(Col[2]:),(Col[3]:),(Col[4]:),(Col[5]:),(Col[6]:),(Col[7]:),(Col[8]:),(Col[9]:),(Col[10]:)]] +----ParquetExec: file_groups={1 group: [[WORKSPACE_ROOT/parquet-testing/data/alltypes_plain.parquet]]}, projection=[id, bool_col, tinyint_col, smallint_col, int_col, bigint_col, float_col, double_col, date_string_col, string_col, timestamp_col], limit=10, statistics=[Rows=Exact(8), Bytes=Absent, [(Col[0]:),(Col[1]:),(Col[2]:),(Col[3]:),(Col[4]:),(Col[5]:),(Col[6]:),(Col[7]:),(Col[8]:),(Col[9]:),(Col[10]:)]] +physical_plan after aggregate_statistics SAME TEXT AS ABOVE +physical_plan after join_selection SAME TEXT AS ABOVE +physical_plan after LimitedDistinctAggregation SAME TEXT AS ABOVE +physical_plan after EnforceDistribution SAME TEXT AS ABOVE +physical_plan after CombinePartialFinalAggregate SAME TEXT AS ABOVE +physical_plan after EnforceSorting SAME TEXT AS ABOVE +physical_plan after coalesce_batches SAME TEXT AS ABOVE +physical_plan after OutputRequirements +GlobalLimitExec: skip=0, fetch=10, statistics=[Rows=Exact(8), Bytes=Absent, [(Col[0]:),(Col[1]:),(Col[2]:),(Col[3]:),(Col[4]:),(Col[5]:),(Col[6]:),(Col[7]:),(Col[8]:),(Col[9]:),(Col[10]:)]] +--ParquetExec: file_groups={1 group: [[WORKSPACE_ROOT/parquet-testing/data/alltypes_plain.parquet]]}, projection=[id, bool_col, tinyint_col, smallint_col, int_col, bigint_col, float_col, double_col, date_string_col, string_col, timestamp_col], limit=10, statistics=[Rows=Exact(8), Bytes=Absent, [(Col[0]:),(Col[1]:),(Col[2]:),(Col[3]:),(Col[4]:),(Col[5]:),(Col[6]:),(Col[7]:),(Col[8]:),(Col[9]:),(Col[10]:)]] +physical_plan after PipelineChecker SAME TEXT AS ABOVE +physical_plan after LimitAggregation SAME TEXT AS ABOVE +physical_plan after ProjectionPushdown SAME TEXT AS ABOVE +physical_plan +GlobalLimitExec: skip=0, fetch=10, statistics=[Rows=Exact(8), Bytes=Absent, [(Col[0]:),(Col[1]:),(Col[2]:),(Col[3]:),(Col[4]:),(Col[5]:),(Col[6]:),(Col[7]:),(Col[8]:),(Col[9]:),(Col[10]:)]] +--ParquetExec: file_groups={1 group: [[WORKSPACE_ROOT/parquet-testing/data/alltypes_plain.parquet]]}, projection=[id, bool_col, tinyint_col, smallint_col, int_col, bigint_col, float_col, double_col, date_string_col, string_col, timestamp_col], limit=10, statistics=[Rows=Exact(8), Bytes=Absent, [(Col[0]:),(Col[1]:),(Col[2]:),(Col[3]:),(Col[4]:),(Col[5]:),(Col[6]:),(Col[7]:),(Col[8]:),(Col[9]:),(Col[10]:)]] + + +statement ok +set datafusion.explain.show_statistics = false; + +# explain verbose with collect on & show statistics off: still has stats +query TT +EXPLAIN VERBOSE SELECT * FROM alltypes_plain limit 10; +---- +initial_physical_plan +GlobalLimitExec: skip=0, fetch=10 +--ParquetExec: file_groups={1 group: [[WORKSPACE_ROOT/parquet-testing/data/alltypes_plain.parquet]]}, projection=[id, bool_col, tinyint_col, smallint_col, int_col, bigint_col, float_col, double_col, date_string_col, string_col, timestamp_col], limit=10 +initial_physical_plan_with_stats +GlobalLimitExec: skip=0, fetch=10, statistics=[Rows=Exact(8), Bytes=Absent, [(Col[0]:),(Col[1]:),(Col[2]:),(Col[3]:),(Col[4]:),(Col[5]:),(Col[6]:),(Col[7]:),(Col[8]:),(Col[9]:),(Col[10]:)]]
+--ParquetExec: file_groups={1 group: [[WORKSPACE_ROOT/parquet-testing/data/alltypes_plain.parquet]]}, projection=[id, bool_col, tinyint_col, smallint_col, int_col, bigint_col, float_col, double_col, date_string_col, string_col, timestamp_col], limit=10, statistics=[Rows=Exact(8), Bytes=Absent, [(Col[0]:),(Col[1]:),(Col[2]:),(Col[3]:),(Col[4]:),(Col[5]:),(Col[6]:),(Col[7]:),(Col[8]:),(Col[9]:),(Col[10]:)]] +physical_plan after OutputRequirements +OutputRequirementExec +--GlobalLimitExec: skip=0, fetch=10 +----ParquetExec: file_groups={1 group: [[WORKSPACE_ROOT/parquet-testing/data/alltypes_plain.parquet]]}, projection=[id, bool_col, tinyint_col, smallint_col, int_col, bigint_col, float_col, double_col, date_string_col, string_col, timestamp_col], limit=10 +physical_plan after aggregate_statistics SAME TEXT AS ABOVE +physical_plan after join_selection SAME TEXT AS ABOVE +physical_plan after LimitedDistinctAggregation SAME TEXT AS ABOVE +physical_plan after EnforceDistribution SAME TEXT AS ABOVE +physical_plan after CombinePartialFinalAggregate SAME TEXT AS ABOVE +physical_plan after EnforceSorting SAME TEXT AS ABOVE +physical_plan after coalesce_batches SAME TEXT AS ABOVE +physical_plan after OutputRequirements +GlobalLimitExec: skip=0, fetch=10 +--ParquetExec: file_groups={1 group: [[WORKSPACE_ROOT/parquet-testing/data/alltypes_plain.parquet]]}, projection=[id, bool_col, tinyint_col, smallint_col, int_col, bigint_col, float_col, double_col, date_string_col, string_col, timestamp_col], limit=10 +physical_plan after PipelineChecker SAME TEXT AS ABOVE +physical_plan after LimitAggregation SAME TEXT AS ABOVE +physical_plan after ProjectionPushdown SAME TEXT AS ABOVE +physical_plan +GlobalLimitExec: skip=0, fetch=10 +--ParquetExec: file_groups={1 group: [[WORKSPACE_ROOT/parquet-testing/data/alltypes_plain.parquet]]}, projection=[id, bool_col, tinyint_col, smallint_col, int_col, bigint_col, float_col, double_col, date_string_col, string_col, timestamp_col], limit=10 +physical_plan_with_stats +GlobalLimitExec: skip=0, fetch=10, statistics=[Rows=Exact(8), Bytes=Absent, [(Col[0]:),(Col[1]:),(Col[2]:),(Col[3]:),(Col[4]:),(Col[5]:),(Col[6]:),(Col[7]:),(Col[8]:),(Col[9]:),(Col[10]:)]] +--ParquetExec: file_groups={1 group: [[WORKSPACE_ROOT/parquet-testing/data/alltypes_plain.parquet]]}, projection=[id, bool_col, tinyint_col, smallint_col, int_col, bigint_col, float_col, double_col, date_string_col, string_col, timestamp_col], limit=10, statistics=[Rows=Exact(8), Bytes=Absent, [(Col[0]:),(Col[1]:),(Col[2]:),(Col[3]:),(Col[4]:),(Col[5]:),(Col[6]:),(Col[7]:),(Col[8]:),(Col[9]:),(Col[10]:)]] + statement ok set datafusion.execution.collect_statistics = false; +# Explain ArrayFunctions + statement ok -set datafusion.explain.show_statistics = false; +set datafusion.explain.physical_plan_only = false + +query TT +explain select make_array(make_array(1, 2, 3), make_array(4, 5, 6)); +---- +logical_plan +Projection: List([[1, 2, 3], [4, 5, 6]]) AS make_array(make_array(Int64(1),Int64(2),Int64(3)),make_array(Int64(4),Int64(5),Int64(6))) +--EmptyRelation +physical_plan +ProjectionExec: expr=[[[1, 2, 3], [4, 5, 6]] as make_array(make_array(Int64(1),Int64(2),Int64(3)),make_array(Int64(4),Int64(5),Int64(6)))] +--PlaceholderRowExec + +query TT +explain select [[1, 2, 3], [4, 5, 6]]; +---- +logical_plan +Projection: List([[1, 2, 3], [4, 5, 6]]) AS make_array(make_array(Int64(1),Int64(2),Int64(3)),make_array(Int64(4),Int64(5),Int64(6))) +--EmptyRelation +physical_plan +ProjectionExec: expr=[[[1, 2, 3], [4, 5, 6]] as
make_array(make_array(Int64(1),Int64(2),Int64(3)),make_array(Int64(4),Int64(5),Int64(6)))] +--PlaceholderRowExec diff --git a/datafusion/sqllogictest/test_files/functions.slt b/datafusion/sqllogictest/test_files/functions.slt index e3e39ef6cc4c8..1903088b0748d 100644 --- a/datafusion/sqllogictest/test_files/functions.slt +++ b/datafusion/sqllogictest/test_files/functions.slt @@ -494,6 +494,10 @@ SELECT counter(*) from test; statement error Did you mean 'STDDEV'? SELECT STDEV(v1) from test; +# Aggregate function +statement error Did you mean 'COVAR'? +SELECT COVARIA(1,1); + # Window function statement error Did you mean 'SUM'? SELECT v1, v2, SUMM(v2) OVER(ORDER BY v1) from test; @@ -784,7 +788,7 @@ INSERT INTO products (product_id, product_name, price) VALUES (1, 'OldBrand Product 1', 19.99), (2, 'OldBrand Product 2', 29.99), (3, 'OldBrand Product 3', 39.99), -(4, 'OldBrand Product 4', 49.99) +(4, 'OldBrand Product 4', 49.99) query ITR SELECT * REPLACE (price*2 AS price) FROM products @@ -811,3 +815,189 @@ SELECT products.* REPLACE (price*2 AS price, product_id+1000 AS product_id) FROM 1002 OldBrand Product 2 59.98 1003 OldBrand Product 3 79.98 1004 OldBrand Product 4 99.98 + +#overlay tests +statement ok +CREATE TABLE over_test( + str TEXT, + characters TEXT, + pos INT, + len INT +) as VALUES + ('123', 'abc', 4, 5), + ('abcdefg', 'qwertyasdfg', 1, 7), + ('xyz', 'ijk', 1, 2), + ('Txxxxas', 'hom', 2, 4), + (NULL, 'hom', 2, 4), + ('Txxxxas', 'hom', NULL, 4), + ('Txxxxas', 'hom', 2, NULL), + ('Txxxxas', NULL, 2, 4) +; + +query T +SELECT overlay(str placing characters from pos for len) from over_test +---- +abc +qwertyasdfg +ijkz +Thomas +NULL +NULL +NULL +NULL + +query T +SELECT overlay(str placing characters from pos) from over_test +---- +abc +qwertyasdfg +ijk +Thomxas +NULL +NULL +Thomxas +NULL + +query I +SELECT levenshtein('kitten', 'sitting') +---- +3 + +query I +SELECT levenshtein('kitten', NULL) +---- +NULL + +query ? +SELECT levenshtein(NULL, 'sitting') +---- +NULL + +query ? +SELECT levenshtein(NULL, NULL) +---- +NULL + +query T +SELECT substr_index('www.apache.org', '.', 1) +---- +www + +query T +SELECT substr_index('www.apache.org', '.', 2) +---- +www.apache + +query T +SELECT substr_index('www.apache.org', '.', -1) +---- +org + +query T +SELECT substr_index('www.apache.org', '.', -2) +---- +apache.org + +query T +SELECT substr_index('www.apache.org', 'ac', 1) +---- +www.ap + +query T +SELECT substr_index('www.apache.org', 'ac', -1) +---- +he.org + +query T +SELECT substr_index('www.apache.org', 'ac', 2) +---- +www.apache.org + +query T +SELECT substr_index('www.apache.org', 'ac', -2) +---- +www.apache.org + +query ? +SELECT substr_index(NULL, 'ac', 1) +---- +NULL + +query T +SELECT substr_index('www.apache.org', NULL, 1) +---- +NULL + +query T +SELECT substr_index('www.apache.org', 'ac', NULL) +---- +NULL + +query T +SELECT substr_index('', 'ac', 1) +---- +(empty) + +query T +SELECT substr_index('www.apache.org', '', 1) +---- +(empty) + +query T +SELECT substr_index('www.apache.org', 'ac', 0) +---- +(empty) + +query ? +SELECT substr_index(NULL, NULL, NULL) +---- +NULL + +query I +SELECT find_in_set('b', 'a,b,c,d') +---- +2 + + +query I +SELECT find_in_set('a', 'a,b,c,d,a') +---- +1 + +query I +SELECT find_in_set('', 'a,b,c,d,a') +---- +0 + +query I +SELECT find_in_set('a', '') +---- +0 + + +query I +SELECT find_in_set('', '') +---- +1 + +query ? +SELECT find_in_set(NULL, 'a,b,c,d') +---- +NULL + +query I +SELECT find_in_set('a', NULL) +---- +NULL + + +query ? 
+SELECT find_in_set(NULL, NULL) +---- +NULL + +# Verify that multiple calls to volatile functions like `random()` are not combined / optimized away +query B +SELECT r FROM (SELECT r1 == r2 r, r1, r2 FROM (SELECT random() r1, random() r2) WHERE r1 > 0 AND r2 > 0) +---- +false diff --git a/datafusion/sqllogictest/test_files/groupby.slt b/datafusion/sqllogictest/test_files/groupby.slt index cb0b0b7c76a58..b09ff79e88d50 100644 --- a/datafusion/sqllogictest/test_files/groupby.slt +++ b/datafusion/sqllogictest/test_files/groupby.slt @@ -2019,8 +2019,8 @@ SortPreservingMergeExec: [col0@0 ASC NULLS LAST] ------AggregateExec: mode=FinalPartitioned, gby=[col0@0 as col0, col1@1 as col1, col2@2 as col2], aggr=[LAST_VALUE(r.col1)] --------CoalesceBatchesExec: target_batch_size=8192 ----------RepartitionExec: partitioning=Hash([col0@0, col1@1, col2@2], 4), input_partitions=4 -------------AggregateExec: mode=Partial, gby=[col0@0 as col0, col1@1 as col1, col2@2 as col2], aggr=[LAST_VALUE(r.col1)], ordering_mode=PartiallySorted([0]) ---------------SortExec: expr=[col0@3 ASC NULLS LAST] +------------AggregateExec: mode=Partial, gby=[col0@0 as col0, col1@1 as col1, col2@2 as col2], aggr=[LAST_VALUE(r.col1)] +--------------ProjectionExec: expr=[col0@2 as col0, col1@3 as col1, col2@4 as col2, col0@0 as col0, col1@1 as col1] ----------------CoalesceBatchesExec: target_batch_size=8192 ------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(col0@0, col0@0)] --------------------CoalesceBatchesExec: target_batch_size=8192 @@ -2084,9 +2084,7 @@ logical_plan Projection: multiple_ordered_table.a --Sort: multiple_ordered_table.c ASC NULLS LAST ----TableScan: multiple_ordered_table projection=[a, c] -physical_plan -ProjectionExec: expr=[a@0 as a] ---CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, c], output_ordering=[a@0 ASC NULLS LAST], has_header=true +physical_plan CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a], output_ordering=[a@0 ASC NULLS LAST], has_header=true # Final plan shouldn't have SortExec a ASC, b ASC, # because table already satisfies this ordering. 
@@ -2097,9 +2095,7 @@ logical_plan Projection: multiple_ordered_table.a --Sort: multiple_ordered_table.a ASC NULLS LAST, multiple_ordered_table.b ASC NULLS LAST ----TableScan: multiple_ordered_table projection=[a, b] -physical_plan -ProjectionExec: expr=[a@0 as a] ---CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, b], output_ordering=[a@0 ASC NULLS LAST, b@1 ASC NULLS LAST], has_header=true +physical_plan CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a], output_ordering=[a@0 ASC NULLS LAST], has_header=true # test_window_agg_sort statement ok @@ -2119,7 +2115,7 @@ Projection: annotated_data_infinite2.a, annotated_data_infinite2.b, SUM(annotate physical_plan ProjectionExec: expr=[a@1 as a, b@0 as b, SUM(annotated_data_infinite2.c)@2 as summation1] --AggregateExec: mode=Single, gby=[b@1 as b, a@0 as a], aggr=[SUM(annotated_data_infinite2.c)], ordering_mode=Sorted -----CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, b, c], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST, b@1 ASC NULLS LAST, c@2 ASC NULLS LAST], has_header=true +----StreamingTableExec: partition_sizes=1, projection=[a, b, c], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST, b@1 ASC NULLS LAST, c@2 ASC NULLS LAST] query III @@ -2150,7 +2146,7 @@ Projection: annotated_data_infinite2.a, annotated_data_infinite2.d, SUM(annotate physical_plan ProjectionExec: expr=[a@1 as a, d@0 as d, SUM(annotated_data_infinite2.c) ORDER BY [annotated_data_infinite2.a DESC NULLS FIRST]@2 as summation1] --AggregateExec: mode=Single, gby=[d@2 as d, a@0 as a], aggr=[SUM(annotated_data_infinite2.c)], ordering_mode=PartiallySorted([1]) -----CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true +----StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST] query III SELECT a, d, @@ -2183,7 +2179,7 @@ Projection: annotated_data_infinite2.a, annotated_data_infinite2.b, FIRST_VALUE( physical_plan ProjectionExec: expr=[a@0 as a, b@1 as b, FIRST_VALUE(annotated_data_infinite2.c) ORDER BY [annotated_data_infinite2.a DESC NULLS FIRST]@2 as first_c] --AggregateExec: mode=Single, gby=[a@0 as a, b@1 as b], aggr=[FIRST_VALUE(annotated_data_infinite2.c)], ordering_mode=Sorted -----CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, b, c], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST, b@1 ASC NULLS LAST, c@2 ASC NULLS LAST], has_header=true +----StreamingTableExec: partition_sizes=1, projection=[a, b, c], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST, b@1 ASC NULLS LAST, c@2 ASC NULLS LAST] query III SELECT a, b, FIRST_VALUE(c ORDER BY a DESC) as first_c @@ -2209,10 +2205,10 @@ Projection: annotated_data_infinite2.a, annotated_data_infinite2.b, LAST_VALUE(a physical_plan ProjectionExec: expr=[a@0 as a, b@1 as b, LAST_VALUE(annotated_data_infinite2.c) ORDER BY [annotated_data_infinite2.a DESC NULLS FIRST]@2 as last_c] --AggregateExec: mode=Single, gby=[a@0 as a, b@1 as b], aggr=[LAST_VALUE(annotated_data_infinite2.c)], ordering_mode=Sorted -----CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, b, c], infinite_source=true, output_ordering=[a@0 ASC 
NULLS LAST, b@1 ASC NULLS LAST, c@2 ASC NULLS LAST], has_header=true +----StreamingTableExec: partition_sizes=1, projection=[a, b, c], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST, b@1 ASC NULLS LAST, c@2 ASC NULLS LAST] query III -SELECT a, b, LAST_VALUE(c ORDER BY a DESC) as last_c +SELECT a, b, LAST_VALUE(c ORDER BY a DESC, c ASC) as last_c FROM annotated_data_infinite2 GROUP BY a, b ---- @@ -2236,7 +2232,7 @@ Projection: annotated_data_infinite2.a, annotated_data_infinite2.b, LAST_VALUE(a physical_plan ProjectionExec: expr=[a@0 as a, b@1 as b, LAST_VALUE(annotated_data_infinite2.c)@2 as last_c] --AggregateExec: mode=Single, gby=[a@0 as a, b@1 as b], aggr=[LAST_VALUE(annotated_data_infinite2.c)], ordering_mode=Sorted -----CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, b, c], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST, b@1 ASC NULLS LAST, c@2 ASC NULLS LAST], has_header=true +----StreamingTableExec: partition_sizes=1, projection=[a, b, c], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST, b@1 ASC NULLS LAST, c@2 ASC NULLS LAST] query III SELECT a, b, LAST_VALUE(c) as last_c @@ -2333,15 +2329,15 @@ ProjectionExec: expr=[country@0 as country, ARRAY_AGG(s.amount) ORDER BY [s.amou ----SortExec: expr=[amount@1 DESC] ------MemoryExec: partitions=1, partition_sizes=[1] -query T?R +query T?R rowsort SELECT s.country, ARRAY_AGG(s.amount ORDER BY s.amount DESC) AS amounts, SUM(s.amount) AS sum1 FROM sales_global AS s GROUP BY s.country ---- FRA [200.0, 50.0] 250 -TUR [100.0, 75.0] 175 GRC [80.0, 30.0] 110 +TUR [100.0, 75.0] 175 # test_ordering_sensitive_aggregation3 # When different aggregators have conflicting requirements, we cannot satisfy all of them in current implementation. @@ -2377,7 +2373,7 @@ ProjectionExec: expr=[country@0 as country, ARRAY_AGG(s.amount) ORDER BY [s.amou ----SortExec: expr=[country@0 ASC NULLS LAST,amount@1 DESC] ------MemoryExec: partitions=1, partition_sizes=[1] -query T?R +query T?R rowsort SELECT s.country, ARRAY_AGG(s.amount ORDER BY s.amount DESC) AS amounts, SUM(s.amount) AS sum1 FROM (SELECT * @@ -2413,7 +2409,7 @@ ProjectionExec: expr=[country@0 as country, zip_code@1 as zip_code, ARRAY_AGG(s. 
----SortExec: expr=[country@1 ASC NULLS LAST,amount@2 DESC] ------MemoryExec: partitions=1, partition_sizes=[1] -query TI?R +query TI?R rowsort SELECT s.country, s.zip_code, ARRAY_AGG(s.amount ORDER BY s.amount DESC) AS amounts, SUM(s.amount) AS sum1 FROM (SELECT * @@ -2449,7 +2445,7 @@ ProjectionExec: expr=[country@0 as country, ARRAY_AGG(s.amount) ORDER BY [s.coun ----SortExec: expr=[country@0 ASC NULLS LAST] ------MemoryExec: partitions=1, partition_sizes=[1] -query T?R +query T?R rowsort SELECT s.country, ARRAY_AGG(s.amount ORDER BY s.amount DESC) AS amounts, SUM(s.amount) AS sum1 FROM (SELECT * @@ -2484,7 +2480,7 @@ ProjectionExec: expr=[country@0 as country, ARRAY_AGG(s.amount) ORDER BY [s.coun ----SortExec: expr=[country@0 ASC NULLS LAST,amount@1 DESC] ------MemoryExec: partitions=1, partition_sizes=[1] -query T?R +query T?R rowsort SELECT s.country, ARRAY_AGG(s.amount ORDER BY s.country DESC, s.amount DESC) AS amounts, SUM(s.amount) AS sum1 FROM (SELECT * @@ -2516,7 +2512,7 @@ ProjectionExec: expr=[country@0 as country, ARRAY_AGG(sales_global.amount) ORDER ----SortExec: expr=[amount@1 DESC] ------MemoryExec: partitions=1, partition_sizes=[1] -query T?RR +query T?RR rowsort SELECT country, ARRAY_AGG(amount ORDER BY amount DESC) AS amounts, FIRST_VALUE(amount ORDER BY amount ASC) AS fv1, LAST_VALUE(amount ORDER BY amount DESC) AS fv2 @@ -2524,8 +2520,8 @@ SELECT country, ARRAY_AGG(amount ORDER BY amount DESC) AS amounts, GROUP BY country ---- FRA [200.0, 50.0] 50 50 -TUR [100.0, 75.0] 75 75 GRC [80.0, 30.0] 30 30 +TUR [100.0, 75.0] 75 75 # test_reverse_aggregate_expr2 # Some of the Aggregators can be reversed, by this way we can still run aggregators without re-ordering @@ -2641,10 +2637,9 @@ Projection: sales_global.country, FIRST_VALUE(sales_global.amount) ORDER BY [sal physical_plan ProjectionExec: expr=[country@0 as country, FIRST_VALUE(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST]@1 as fv1, LAST_VALUE(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST]@2 as lv1, SUM(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST]@3 as sum1] --AggregateExec: mode=Single, gby=[country@0 as country], aggr=[LAST_VALUE(sales_global.amount), FIRST_VALUE(sales_global.amount), SUM(sales_global.amount)] -----SortExec: expr=[ts@1 ASC NULLS LAST] -------MemoryExec: partitions=1, partition_sizes=[1] +----MemoryExec: partitions=1, partition_sizes=[1] -query TRRR +query TRRR rowsort SELECT country, FIRST_VALUE(amount ORDER BY ts DESC) as fv1, LAST_VALUE(amount ORDER BY ts DESC) as lv1, SUM(amount ORDER BY ts DESC) as sum1 @@ -2653,8 +2648,8 @@ SELECT country, FIRST_VALUE(amount ORDER BY ts DESC) as fv1, ORDER BY ts ASC) GROUP BY country ---- -GRC 80 30 110 FRA 200 50 250 +GRC 80 30 110 TUR 100 75 175 # If existing ordering doesn't satisfy requirement, we should do calculations @@ -2675,19 +2670,18 @@ Projection: sales_global.country, FIRST_VALUE(sales_global.amount) ORDER BY [sal physical_plan ProjectionExec: expr=[country@0 as country, FIRST_VALUE(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST]@1 as fv1, LAST_VALUE(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST]@2 as lv1, SUM(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST]@3 as sum1] --AggregateExec: mode=Single, gby=[country@0 as country], aggr=[FIRST_VALUE(sales_global.amount), LAST_VALUE(sales_global.amount), SUM(sales_global.amount)] -----SortExec: expr=[ts@1 DESC] -------MemoryExec: partitions=1, partition_sizes=[1] +----MemoryExec: partitions=1, 
partition_sizes=[1] -query TRRR +query TRRR rowsort SELECT country, FIRST_VALUE(amount ORDER BY ts DESC) as fv1, LAST_VALUE(amount ORDER BY ts DESC) as lv1, SUM(amount ORDER BY ts DESC) as sum1 FROM sales_global GROUP BY country ---- -TUR 100 75 175 -GRC 80 30 110 FRA 200 50 250 +GRC 80 30 110 +TUR 100 75 175 query TT EXPLAIN SELECT s.zip_code, s.country, s.sn, s.ts, s.currency, LAST_VALUE(e.amount ORDER BY e.sn) AS last_rate @@ -2712,14 +2706,13 @@ physical_plan SortExec: expr=[sn@2 ASC NULLS LAST] --ProjectionExec: expr=[zip_code@1 as zip_code, country@2 as country, sn@0 as sn, ts@3 as ts, currency@4 as currency, LAST_VALUE(e.amount) ORDER BY [e.sn ASC NULLS LAST]@5 as last_rate] ----AggregateExec: mode=Single, gby=[sn@2 as sn, zip_code@0 as zip_code, country@1 as country, ts@3 as ts, currency@4 as currency], aggr=[LAST_VALUE(e.amount)] -------SortExec: expr=[sn@5 ASC NULLS LAST] ---------ProjectionExec: expr=[zip_code@0 as zip_code, country@1 as country, sn@2 as sn, ts@3 as ts, currency@4 as currency, sn@5 as sn, amount@8 as amount] -----------CoalesceBatchesExec: target_batch_size=8192 -------------HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(currency@4, currency@2)], filter=ts@0 >= ts@1 ---------------MemoryExec: partitions=1, partition_sizes=[1] ---------------MemoryExec: partitions=1, partition_sizes=[1] +------ProjectionExec: expr=[zip_code@4 as zip_code, country@5 as country, sn@6 as sn, ts@7 as ts, currency@8 as currency, sn@0 as sn, amount@3 as amount] +--------CoalesceBatchesExec: target_batch_size=8192 +----------HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(currency@2, currency@4)], filter=ts@0 >= ts@1 +------------MemoryExec: partitions=1, partition_sizes=[1] +------------MemoryExec: partitions=1, partition_sizes=[1] -query ITIPTR +query ITIPTR rowsort SELECT s.zip_code, s.country, s.sn, s.ts, s.currency, LAST_VALUE(e.amount ORDER BY e.sn) AS last_rate FROM sales_global AS s JOIN sales_global AS e @@ -2729,10 +2722,10 @@ GROUP BY s.sn, s.zip_code, s.country, s.ts, s.currency ORDER BY s.sn ---- 0 GRC 0 2022-01-01T06:00:00 EUR 30 +0 GRC 4 2022-01-03T10:00:00 EUR 80 1 FRA 1 2022-01-01T08:00:00 EUR 50 -1 TUR 2 2022-01-01T11:30:00 TRY 75 1 FRA 3 2022-01-02T12:00:00 EUR 200 -0 GRC 4 2022-01-03T10:00:00 EUR 80 +1 TUR 2 2022-01-01T11:30:00 TRY 75 1 TUR 4 2022-01-03T10:00:00 TRY 100 # Run order-sensitive aggregators in multiple partitions @@ -2762,8 +2755,7 @@ SortPreservingMergeExec: [country@0 ASC NULLS LAST] ----------RepartitionExec: partitioning=Hash([country@0], 8), input_partitions=8 ------------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 --------------AggregateExec: mode=Partial, gby=[country@0 as country], aggr=[FIRST_VALUE(sales_global.amount), LAST_VALUE(sales_global.amount)] -----------------SortExec: expr=[ts@1 ASC NULLS LAST] -------------------MemoryExec: partitions=1, partition_sizes=[1] +----------------MemoryExec: partitions=1, partition_sizes=[1] query TRR SELECT country, FIRST_VALUE(amount ORDER BY ts ASC) AS fv1, @@ -2794,13 +2786,12 @@ physical_plan SortPreservingMergeExec: [country@0 ASC NULLS LAST] --SortExec: expr=[country@0 ASC NULLS LAST] ----ProjectionExec: expr=[country@0 as country, FIRST_VALUE(sales_global.amount) ORDER BY [sales_global.ts ASC NULLS LAST]@1 as fv1, LAST_VALUE(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST]@2 as fv2] -------AggregateExec: mode=FinalPartitioned, gby=[country@0 as country], aggr=[FIRST_VALUE(sales_global.amount), FIRST_VALUE(sales_global.amount)] 
+------AggregateExec: mode=FinalPartitioned, gby=[country@0 as country], aggr=[FIRST_VALUE(sales_global.amount), LAST_VALUE(sales_global.amount)] --------CoalesceBatchesExec: target_batch_size=8192 ----------RepartitionExec: partitioning=Hash([country@0], 8), input_partitions=8 ------------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 ---------------AggregateExec: mode=Partial, gby=[country@0 as country], aggr=[FIRST_VALUE(sales_global.amount), FIRST_VALUE(sales_global.amount)] -----------------SortExec: expr=[ts@1 ASC NULLS LAST] -------------------MemoryExec: partitions=1, partition_sizes=[1] +--------------AggregateExec: mode=Partial, gby=[country@0 as country], aggr=[FIRST_VALUE(sales_global.amount), LAST_VALUE(sales_global.amount)] +----------------MemoryExec: partitions=1, partition_sizes=[1] query TRR SELECT country, FIRST_VALUE(amount ORDER BY ts ASC) AS fv1, @@ -2834,16 +2825,15 @@ ProjectionExec: expr=[FIRST_VALUE(sales_global.amount) ORDER BY [sales_global.ts --AggregateExec: mode=Final, gby=[], aggr=[FIRST_VALUE(sales_global.amount), LAST_VALUE(sales_global.amount)] ----CoalescePartitionsExec ------AggregateExec: mode=Partial, gby=[], aggr=[FIRST_VALUE(sales_global.amount), LAST_VALUE(sales_global.amount)] ---------SortExec: expr=[ts@0 ASC NULLS LAST] -----------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 -------------MemoryExec: partitions=1, partition_sizes=[1] +--------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 +----------MemoryExec: partitions=1, partition_sizes=[1] query RR SELECT FIRST_VALUE(amount ORDER BY ts ASC) AS fv1, LAST_VALUE(amount ORDER BY ts ASC) AS fv2 FROM sales_global ---- -30 80 +30 100 # Conversion in between FIRST_VALUE and LAST_VALUE to resolve # contradictory requirements should work in multi partitions. 
@@ -2858,12 +2848,11 @@ Projection: FIRST_VALUE(sales_global.amount) ORDER BY [sales_global.ts ASC NULLS ----TableScan: sales_global projection=[ts, amount] physical_plan ProjectionExec: expr=[FIRST_VALUE(sales_global.amount) ORDER BY [sales_global.ts ASC NULLS LAST]@0 as fv1, LAST_VALUE(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST]@1 as fv2] ---AggregateExec: mode=Final, gby=[], aggr=[FIRST_VALUE(sales_global.amount), FIRST_VALUE(sales_global.amount)] +--AggregateExec: mode=Final, gby=[], aggr=[FIRST_VALUE(sales_global.amount), LAST_VALUE(sales_global.amount)] ----CoalescePartitionsExec -------AggregateExec: mode=Partial, gby=[], aggr=[FIRST_VALUE(sales_global.amount), FIRST_VALUE(sales_global.amount)] ---------SortExec: expr=[ts@0 ASC NULLS LAST] -----------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 -------------MemoryExec: partitions=1, partition_sizes=[1] +------AggregateExec: mode=Partial, gby=[], aggr=[FIRST_VALUE(sales_global.amount), LAST_VALUE(sales_global.amount)] +--------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 +----------MemoryExec: partitions=1, partition_sizes=[1] query RR SELECT FIRST_VALUE(amount ORDER BY ts ASC) AS fv1, @@ -2996,7 +2985,7 @@ physical_plan SortPreservingMergeExec: [country@0 ASC NULLS LAST] --SortExec: expr=[country@0 ASC NULLS LAST] ----ProjectionExec: expr=[country@0 as country, ARRAY_AGG(sales_global.amount) ORDER BY [sales_global.amount DESC NULLS FIRST]@1 as amounts, FIRST_VALUE(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST]@2 as fv1, LAST_VALUE(sales_global.amount) ORDER BY [sales_global.amount DESC NULLS FIRST]@3 as fv2] -------AggregateExec: mode=FinalPartitioned, gby=[country@0 as country], aggr=[ARRAY_AGG(sales_global.amount), LAST_VALUE(sales_global.amount), LAST_VALUE(sales_global.amount)] +------AggregateExec: mode=FinalPartitioned, gby=[country@0 as country], aggr=[ARRAY_AGG(sales_global.amount), FIRST_VALUE(sales_global.amount), LAST_VALUE(sales_global.amount)] --------CoalesceBatchesExec: target_batch_size=4 ----------RepartitionExec: partitioning=Hash([country@0], 8), input_partitions=8 ------------AggregateExec: mode=Partial, gby=[country@0 as country], aggr=[ARRAY_AGG(sales_global.amount), LAST_VALUE(sales_global.amount), LAST_VALUE(sales_global.amount)] @@ -3215,6 +3204,21 @@ SELECT s.sn, s.amount, 2*s.sn 3 200 6 4 100 8 +# we should be able to re-write group by expression +# using functional dependencies for complex expressions also. +# In this case, we use 2*s.amount instead of s.amount. 
+query IRI +SELECT s.sn, 2*s.amount, 2*s.sn + FROM sales_global_with_pk AS s + GROUP BY sn + ORDER BY sn +---- +0 60 0 +1 100 2 +2 150 4 +3 400 6 +4 200 8 + query IRI SELECT s.sn, s.amount, 2*s.sn FROM sales_global_with_pk_alternate AS s @@ -3368,7 +3372,7 @@ SELECT column1, COUNT(*) as column2 FROM (VALUES (['a', 'b'], 1), (['c', 'd', 'e # primary key should be aware from which columns it is associated -statement error DataFusion error: Error during planning: Projection references non-aggregate values: Expression r.sn could not be resolved from available columns: l.sn, SUM\(l.amount\) +statement error DataFusion error: Error during planning: Projection references non-aggregate values: Expression r.sn could not be resolved from available columns: l.sn, l.zip_code, l.country, l.ts, l.currency, l.amount, SUM\(l.amount\) SELECT l.sn, r.sn, SUM(l.amount), r.amount FROM sales_global_with_pk AS l JOIN sales_global_with_pk AS r @@ -3460,7 +3464,7 @@ ORDER BY r.sn 4 100 2022-01-03T10:00:00 # after join, new window expressions shouldn't be associated with primary keys -statement error DataFusion error: Error during planning: Projection references non-aggregate values: Expression rn1 could not be resolved from available columns: r.sn, SUM\(r.amount\) +statement error DataFusion error: Error during planning: Projection references non-aggregate values: Expression rn1 could not be resolved from available columns: r.sn, r.ts, r.amount, SUM\(r.amount\) SELECT r.sn, SUM(r.amount), rn1 FROM (SELECT r.ts, r.sn, r.amount, @@ -3632,7 +3636,7 @@ ProjectionExec: expr=[FIRST_VALUE(multiple_ordered_table.a) ORDER BY [multiple_o ------RepartitionExec: partitioning=Hash([d@0], 8), input_partitions=8 --------AggregateExec: mode=Partial, gby=[d@2 as d], aggr=[FIRST_VALUE(multiple_ordered_table.a), FIRST_VALUE(multiple_ordered_table.c)] ----------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 -------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true +------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, c, d], output_orderings=[[a@0 ASC NULLS LAST], [c@1 ASC NULLS LAST]], has_header=true query II rowsort SELECT FIRST_VALUE(a ORDER BY a ASC) as first_a, @@ -3696,16 +3700,15 @@ Projection: amount_usd ----------------SubqueryAlias: r ------------------TableScan: multiple_ordered_table projection=[a, d] physical_plan -ProjectionExec: expr=[amount_usd@0 as amount_usd] ---ProjectionExec: expr=[LAST_VALUE(l.d) ORDER BY [l.a ASC NULLS LAST]@1 as amount_usd, row_n@0 as row_n] -----AggregateExec: mode=Single, gby=[row_n@2 as row_n], aggr=[LAST_VALUE(l.d)], ordering_mode=Sorted -------ProjectionExec: expr=[a@0 as a, d@1 as d, row_n@4 as row_n] ---------CoalesceBatchesExec: target_batch_size=2 -----------HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(d@1, d@1)], filter=CAST(a@0 AS Int64) >= CAST(a@1 AS Int64) - 10 -------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true -------------ProjectionExec: expr=[a@0 as a, d@1 as d, ROW_NUMBER() ORDER BY [r.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@2 as row_n] ---------------BoundedWindowAggExec: wdw=[ROW_NUMBER() ORDER BY [r.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "ROW_NUMBER() ORDER 
BY [r.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: CurrentRow }], mode=[Sorted] -----------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true +ProjectionExec: expr=[LAST_VALUE(l.d) ORDER BY [l.a ASC NULLS LAST]@1 as amount_usd] +--AggregateExec: mode=Single, gby=[row_n@2 as row_n], aggr=[LAST_VALUE(l.d)], ordering_mode=Sorted +----ProjectionExec: expr=[a@0 as a, d@1 as d, row_n@4 as row_n] +------CoalesceBatchesExec: target_batch_size=2 +--------HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(d@1, d@1)], filter=CAST(a@0 AS Int64) >= CAST(a@1 AS Int64) - 10 +----------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true +----------ProjectionExec: expr=[a@0 as a, d@1 as d, ROW_NUMBER() ORDER BY [r.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@2 as row_n] +------------BoundedWindowAggExec: wdw=[ROW_NUMBER() ORDER BY [r.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "ROW_NUMBER() ORDER BY [r.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: CurrentRow }], mode=[Sorted] +--------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true # reset partition number to 8. 
statement ok @@ -3789,6 +3792,192 @@ AggregateExec: mode=FinalPartitioned, gby=[c@0 as c, b@1 as b], aggr=[SUM(multip ----------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 ------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[b, c, d], output_ordering=[c@1 ASC NULLS LAST], has_header=true +statement ok +set datafusion.execution.target_partitions = 1; + +query TT +EXPLAIN SELECT c, sum1 + FROM + (SELECT c, b, a, SUM(d) as sum1 + FROM multiple_ordered_table_with_pk + GROUP BY c) +GROUP BY c; +---- +logical_plan +Aggregate: groupBy=[[multiple_ordered_table_with_pk.c, sum1]], aggr=[[]] +--Projection: multiple_ordered_table_with_pk.c, SUM(multiple_ordered_table_with_pk.d) AS sum1 +----Aggregate: groupBy=[[multiple_ordered_table_with_pk.c]], aggr=[[SUM(CAST(multiple_ordered_table_with_pk.d AS Int64))]] +------TableScan: multiple_ordered_table_with_pk projection=[c, d] +physical_plan +AggregateExec: mode=Single, gby=[c@0 as c, sum1@1 as sum1], aggr=[], ordering_mode=PartiallySorted([0]) +--ProjectionExec: expr=[c@0 as c, SUM(multiple_ordered_table_with_pk.d)@1 as sum1] +----AggregateExec: mode=Single, gby=[c@0 as c], aggr=[SUM(multiple_ordered_table_with_pk.d)], ordering_mode=Sorted +------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[c, d], output_ordering=[c@0 ASC NULLS LAST], has_header=true + +query TT +EXPLAIN SELECT c, sum1, SUM(b) OVER() as sumb + FROM + (SELECT c, b, a, SUM(d) as sum1 + FROM multiple_ordered_table_with_pk + GROUP BY c); +---- +logical_plan +Projection: multiple_ordered_table_with_pk.c, sum1, SUM(multiple_ordered_table_with_pk.b) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING AS sumb +--WindowAggr: windowExpr=[[SUM(CAST(multiple_ordered_table_with_pk.b AS Int64)) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]] +----Projection: multiple_ordered_table_with_pk.c, multiple_ordered_table_with_pk.b, SUM(multiple_ordered_table_with_pk.d) AS sum1 +------Aggregate: groupBy=[[multiple_ordered_table_with_pk.c, multiple_ordered_table_with_pk.b]], aggr=[[SUM(CAST(multiple_ordered_table_with_pk.d AS Int64))]] +--------TableScan: multiple_ordered_table_with_pk projection=[b, c, d] +physical_plan +ProjectionExec: expr=[c@0 as c, sum1@2 as sum1, SUM(multiple_ordered_table_with_pk.b) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING@3 as sumb] +--WindowAggExec: wdw=[SUM(multiple_ordered_table_with_pk.b) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "SUM(multiple_ordered_table_with_pk.b) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)) }] +----ProjectionExec: expr=[c@0 as c, b@1 as b, SUM(multiple_ordered_table_with_pk.d)@2 as sum1] +------AggregateExec: mode=Single, gby=[c@1 as c, b@0 as b], aggr=[SUM(multiple_ordered_table_with_pk.d)], ordering_mode=PartiallySorted([0]) +--------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[b, c, d], output_ordering=[c@1 ASC NULLS LAST], has_header=true + +query TT +EXPLAIN SELECT lhs.c, rhs.c, lhs.sum1, rhs.sum1 + FROM + (SELECT c, b, a, SUM(d) as sum1 + FROM multiple_ordered_table_with_pk + GROUP BY c) as lhs + JOIN + (SELECT c, b, a, SUM(d) as sum1 + FROM multiple_ordered_table_with_pk + GROUP BY 
c) as rhs + ON lhs.b=rhs.b; +---- +logical_plan +Projection: lhs.c, rhs.c, lhs.sum1, rhs.sum1 +--Inner Join: lhs.b = rhs.b +----SubqueryAlias: lhs +------Projection: multiple_ordered_table_with_pk.c, multiple_ordered_table_with_pk.b, SUM(multiple_ordered_table_with_pk.d) AS sum1 +--------Aggregate: groupBy=[[multiple_ordered_table_with_pk.c, multiple_ordered_table_with_pk.b]], aggr=[[SUM(CAST(multiple_ordered_table_with_pk.d AS Int64))]] +----------TableScan: multiple_ordered_table_with_pk projection=[b, c, d] +----SubqueryAlias: rhs +------Projection: multiple_ordered_table_with_pk.c, multiple_ordered_table_with_pk.b, SUM(multiple_ordered_table_with_pk.d) AS sum1 +--------Aggregate: groupBy=[[multiple_ordered_table_with_pk.c, multiple_ordered_table_with_pk.b]], aggr=[[SUM(CAST(multiple_ordered_table_with_pk.d AS Int64))]] +----------TableScan: multiple_ordered_table_with_pk projection=[b, c, d] +physical_plan +ProjectionExec: expr=[c@0 as c, c@3 as c, sum1@2 as sum1, sum1@5 as sum1] +--CoalesceBatchesExec: target_batch_size=2 +----HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(b@1, b@1)] +------ProjectionExec: expr=[c@0 as c, b@1 as b, SUM(multiple_ordered_table_with_pk.d)@2 as sum1] +--------AggregateExec: mode=Single, gby=[c@1 as c, b@0 as b], aggr=[SUM(multiple_ordered_table_with_pk.d)], ordering_mode=PartiallySorted([0]) +----------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[b, c, d], output_ordering=[c@1 ASC NULLS LAST], has_header=true +------ProjectionExec: expr=[c@0 as c, b@1 as b, SUM(multiple_ordered_table_with_pk.d)@2 as sum1] +--------AggregateExec: mode=Single, gby=[c@1 as c, b@0 as b], aggr=[SUM(multiple_ordered_table_with_pk.d)], ordering_mode=PartiallySorted([0]) +----------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[b, c, d], output_ordering=[c@1 ASC NULLS LAST], has_header=true + +query TT +EXPLAIN SELECT lhs.c, rhs.c, lhs.sum1, rhs.sum1 + FROM + (SELECT c, b, a, SUM(d) as sum1 + FROM multiple_ordered_table_with_pk + GROUP BY c) as lhs + CROSS JOIN + (SELECT c, b, a, SUM(d) as sum1 + FROM multiple_ordered_table_with_pk + GROUP BY c) as rhs; +---- +logical_plan +Projection: lhs.c, rhs.c, lhs.sum1, rhs.sum1 +--CrossJoin: +----SubqueryAlias: lhs +------Projection: multiple_ordered_table_with_pk.c, SUM(multiple_ordered_table_with_pk.d) AS sum1 +--------Aggregate: groupBy=[[multiple_ordered_table_with_pk.c]], aggr=[[SUM(CAST(multiple_ordered_table_with_pk.d AS Int64))]] +----------TableScan: multiple_ordered_table_with_pk projection=[c, d] +----SubqueryAlias: rhs +------Projection: multiple_ordered_table_with_pk.c, SUM(multiple_ordered_table_with_pk.d) AS sum1 +--------Aggregate: groupBy=[[multiple_ordered_table_with_pk.c]], aggr=[[SUM(CAST(multiple_ordered_table_with_pk.d AS Int64))]] +----------TableScan: multiple_ordered_table_with_pk projection=[c, d] +physical_plan +ProjectionExec: expr=[c@0 as c, c@2 as c, sum1@1 as sum1, sum1@3 as sum1] +--CrossJoinExec +----ProjectionExec: expr=[c@0 as c, SUM(multiple_ordered_table_with_pk.d)@1 as sum1] +------AggregateExec: mode=Single, gby=[c@0 as c], aggr=[SUM(multiple_ordered_table_with_pk.d)], ordering_mode=Sorted +--------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[c, d], output_ordering=[c@0 ASC NULLS LAST], has_header=true +----ProjectionExec: expr=[c@0 as c, SUM(multiple_ordered_table_with_pk.d)@1 as sum1] +------AggregateExec: mode=Single, 
gby=[c@0 as c], aggr=[SUM(multiple_ordered_table_with_pk.d)], ordering_mode=Sorted +--------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[c, d], output_ordering=[c@0 ASC NULLS LAST], has_header=true + +# we do not generate physical plan for Repartition yet (e.g Distribute By queries). +query TT +EXPLAIN SELECT a, b, sum1 +FROM (SELECT c, b, a, SUM(d) as sum1 + FROM multiple_ordered_table_with_pk + GROUP BY c) +DISTRIBUTE BY a +---- +logical_plan +Repartition: DistributeBy(a) +--Projection: multiple_ordered_table_with_pk.a, multiple_ordered_table_with_pk.b, SUM(multiple_ordered_table_with_pk.d) AS sum1 +----Aggregate: groupBy=[[multiple_ordered_table_with_pk.c, multiple_ordered_table_with_pk.a, multiple_ordered_table_with_pk.b]], aggr=[[SUM(CAST(multiple_ordered_table_with_pk.d AS Int64))]] +------TableScan: multiple_ordered_table_with_pk projection=[a, b, c, d] + +# union with aggregate +query TT +EXPLAIN SELECT c, a, SUM(d) as sum1 + FROM multiple_ordered_table_with_pk + GROUP BY c +UNION ALL + SELECT c, a, SUM(d) as sum1 + FROM multiple_ordered_table_with_pk + GROUP BY c +---- +logical_plan +Union +--Projection: multiple_ordered_table_with_pk.c, multiple_ordered_table_with_pk.a, SUM(multiple_ordered_table_with_pk.d) AS sum1 +----Aggregate: groupBy=[[multiple_ordered_table_with_pk.c, multiple_ordered_table_with_pk.a]], aggr=[[SUM(CAST(multiple_ordered_table_with_pk.d AS Int64))]] +------TableScan: multiple_ordered_table_with_pk projection=[a, c, d] +--Projection: multiple_ordered_table_with_pk.c, multiple_ordered_table_with_pk.a, SUM(multiple_ordered_table_with_pk.d) AS sum1 +----Aggregate: groupBy=[[multiple_ordered_table_with_pk.c, multiple_ordered_table_with_pk.a]], aggr=[[SUM(CAST(multiple_ordered_table_with_pk.d AS Int64))]] +------TableScan: multiple_ordered_table_with_pk projection=[a, c, d] +physical_plan +UnionExec +--ProjectionExec: expr=[c@0 as c, a@1 as a, SUM(multiple_ordered_table_with_pk.d)@2 as sum1] +----AggregateExec: mode=Single, gby=[c@1 as c, a@0 as a], aggr=[SUM(multiple_ordered_table_with_pk.d)], ordering_mode=Sorted +------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, c, d], output_orderings=[[a@0 ASC NULLS LAST], [c@1 ASC NULLS LAST]], has_header=true +--ProjectionExec: expr=[c@0 as c, a@1 as a, SUM(multiple_ordered_table_with_pk.d)@2 as sum1] +----AggregateExec: mode=Single, gby=[c@1 as c, a@0 as a], aggr=[SUM(multiple_ordered_table_with_pk.d)], ordering_mode=Sorted +------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, c, d], output_orderings=[[a@0 ASC NULLS LAST], [c@1 ASC NULLS LAST]], has_header=true + +# table scan should be simplified. 
+query TT +EXPLAIN SELECT c, a, SUM(d) as sum1 + FROM multiple_ordered_table_with_pk + GROUP BY c +---- +logical_plan +Projection: multiple_ordered_table_with_pk.c, multiple_ordered_table_with_pk.a, SUM(multiple_ordered_table_with_pk.d) AS sum1 +--Aggregate: groupBy=[[multiple_ordered_table_with_pk.c, multiple_ordered_table_with_pk.a]], aggr=[[SUM(CAST(multiple_ordered_table_with_pk.d AS Int64))]] +----TableScan: multiple_ordered_table_with_pk projection=[a, c, d] +physical_plan +ProjectionExec: expr=[c@0 as c, a@1 as a, SUM(multiple_ordered_table_with_pk.d)@2 as sum1] +--AggregateExec: mode=Single, gby=[c@1 as c, a@0 as a], aggr=[SUM(multiple_ordered_table_with_pk.d)], ordering_mode=Sorted +----CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, c, d], output_orderings=[[a@0 ASC NULLS LAST], [c@1 ASC NULLS LAST]], has_header=true + +# limit should be simplified +query TT +EXPLAIN SELECT * + FROM (SELECT c, a, SUM(d) as sum1 + FROM multiple_ordered_table_with_pk + GROUP BY c + LIMIT 5) +---- +logical_plan +Projection: multiple_ordered_table_with_pk.c, multiple_ordered_table_with_pk.a, SUM(multiple_ordered_table_with_pk.d) AS sum1 +--Limit: skip=0, fetch=5 +----Aggregate: groupBy=[[multiple_ordered_table_with_pk.c, multiple_ordered_table_with_pk.a]], aggr=[[SUM(CAST(multiple_ordered_table_with_pk.d AS Int64))]] +------TableScan: multiple_ordered_table_with_pk projection=[a, c, d] +physical_plan +ProjectionExec: expr=[c@0 as c, a@1 as a, SUM(multiple_ordered_table_with_pk.d)@2 as sum1] +--GlobalLimitExec: skip=0, fetch=5 +----AggregateExec: mode=Single, gby=[c@1 as c, a@0 as a], aggr=[SUM(multiple_ordered_table_with_pk.d)], ordering_mode=Sorted +------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, c, d], output_orderings=[[a@0 ASC NULLS LAST], [c@1 ASC NULLS LAST]], has_header=true + +statement ok +set datafusion.execution.target_partitions = 8; + # Tests for single distinct to group by optimization rule statement ok CREATE TABLE t(x int) AS VALUES (1), (2), (1); @@ -3846,3 +4035,252 @@ ProjectionExec: expr=[SUM(alias1)@1 as SUM(DISTINCT t1.x), MAX(alias1)@2 as MAX( ------------------AggregateExec: mode=Partial, gby=[y@1 as y, CAST(t1.x AS Float64)t1.x@0 as alias1], aggr=[] --------------------ProjectionExec: expr=[CAST(x@0 AS Float64) as CAST(t1.x AS Float64)t1.x, y@1 as y] ----------------------MemoryExec: partitions=1, partition_sizes=[1] + +# create an unbounded table that contains ordered timestamp. +statement ok +CREATE UNBOUNDED EXTERNAL TABLE unbounded_csv_with_timestamps ( + name VARCHAR, + ts TIMESTAMP +) +STORED AS CSV +WITH ORDER (ts DESC) +LOCATION '../core/tests/data/timestamps.csv' + +# below query should work in streaming mode. 
+query TT +EXPLAIN SELECT date_bin('15 minutes', ts) as time_chunks + FROM unbounded_csv_with_timestamps + GROUP BY date_bin('15 minutes', ts) + ORDER BY time_chunks DESC + LIMIT 5; +---- +logical_plan +Limit: skip=0, fetch=5 +--Sort: time_chunks DESC NULLS FIRST, fetch=5 +----Projection: date_bin(Utf8("15 minutes"),unbounded_csv_with_timestamps.ts) AS time_chunks +------Aggregate: groupBy=[[date_bin(IntervalMonthDayNano("900000000000"), unbounded_csv_with_timestamps.ts) AS date_bin(Utf8("15 minutes"),unbounded_csv_with_timestamps.ts)]], aggr=[[]] +--------TableScan: unbounded_csv_with_timestamps projection=[ts] +physical_plan +GlobalLimitExec: skip=0, fetch=5 +--SortPreservingMergeExec: [time_chunks@0 DESC], fetch=5 +----ProjectionExec: expr=[date_bin(Utf8("15 minutes"),unbounded_csv_with_timestamps.ts)@0 as time_chunks] +------AggregateExec: mode=FinalPartitioned, gby=[date_bin(Utf8("15 minutes"),unbounded_csv_with_timestamps.ts)@0 as date_bin(Utf8("15 minutes"),unbounded_csv_with_timestamps.ts)], aggr=[], ordering_mode=Sorted +--------CoalesceBatchesExec: target_batch_size=2 +----------RepartitionExec: partitioning=Hash([date_bin(Utf8("15 minutes"),unbounded_csv_with_timestamps.ts)@0], 8), input_partitions=8, preserve_order=true, sort_exprs=date_bin(Utf8("15 minutes"),unbounded_csv_with_timestamps.ts)@0 DESC +------------AggregateExec: mode=Partial, gby=[date_bin(900000000000, ts@0) as date_bin(Utf8("15 minutes"),unbounded_csv_with_timestamps.ts)], aggr=[], ordering_mode=Sorted +--------------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 +----------------StreamingTableExec: partition_sizes=1, projection=[ts], infinite_source=true, output_ordering=[ts@0 DESC] + +query P +SELECT date_bin('15 minutes', ts) as time_chunks + FROM unbounded_csv_with_timestamps + GROUP BY date_bin('15 minutes', ts) + ORDER BY time_chunks DESC + LIMIT 5; +---- +2018-12-13T12:00:00 +2018-11-13T17:00:00 + +# Since extract is not a monotonic function, the query below should not run +# when the source is unbounded. +query error +SELECT extract(month from ts) as months + FROM unbounded_csv_with_timestamps + GROUP BY extract(month from ts) + ORDER BY months DESC + LIMIT 5; + +# Create a table where timestamp is ordered +statement ok +CREATE EXTERNAL TABLE csv_with_timestamps ( + name VARCHAR, + ts TIMESTAMP +) +STORED AS CSV +WITH ORDER (ts DESC) +LOCATION '../core/tests/data/timestamps.csv'; + +# The query below should run since it operates on a bounded source and has a sort +# at the top of its plan.
+query TT +EXPLAIN SELECT extract(month from ts) as months + FROM csv_with_timestamps + GROUP BY extract(month from ts) + ORDER BY months DESC + LIMIT 5; +---- +logical_plan +Limit: skip=0, fetch=5 +--Sort: months DESC NULLS FIRST, fetch=5 +----Projection: date_part(Utf8("MONTH"),csv_with_timestamps.ts) AS months +------Aggregate: groupBy=[[date_part(Utf8("MONTH"), csv_with_timestamps.ts)]], aggr=[[]] +--------TableScan: csv_with_timestamps projection=[ts] +physical_plan +GlobalLimitExec: skip=0, fetch=5 +--SortPreservingMergeExec: [months@0 DESC], fetch=5 +----SortExec: TopK(fetch=5), expr=[months@0 DESC] +------ProjectionExec: expr=[date_part(Utf8("MONTH"),csv_with_timestamps.ts)@0 as months] +--------AggregateExec: mode=FinalPartitioned, gby=[date_part(Utf8("MONTH"),csv_with_timestamps.ts)@0 as date_part(Utf8("MONTH"),csv_with_timestamps.ts)], aggr=[] +----------CoalesceBatchesExec: target_batch_size=2 +------------RepartitionExec: partitioning=Hash([date_part(Utf8("MONTH"),csv_with_timestamps.ts)@0], 8), input_partitions=8 +--------------AggregateExec: mode=Partial, gby=[date_part(MONTH, ts@0) as date_part(Utf8("MONTH"),csv_with_timestamps.ts)], aggr=[] +----------------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 +------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/timestamps.csv]]}, projection=[ts], output_ordering=[ts@0 DESC], has_header=false + +query R +SELECT extract(month from ts) as months + FROM csv_with_timestamps + GROUP BY extract(month from ts) + ORDER BY months DESC + LIMIT 5; +---- +12 +11 + +statement ok +drop table t1 + +# Reproducer for https://github.com/apache/arrow-datafusion/issues/8175 + +statement ok +create table t1(state string, city string, min_temp float, area int, time timestamp) as values + ('MA', 'Boston', 70.4, 1, 50), + ('MA', 'Bedford', 71.59, 2, 150); + +query RI +select date_part('year', time) as bla, count(distinct state) as count from t1 group by bla; +---- +1970 1 + +query PI +select date_bin(interval '1 year', time) as bla, count(distinct state) as count from t1 group by bla; +---- +1970-01-01T00:00:00 1 + +statement ok +drop table t1 + +statement ok +CREATE EXTERNAL TABLE aggregate_test_100 ( + c1 VARCHAR NOT NULL, + c2 TINYINT NOT NULL, + c3 SMALLINT NOT NULL, + c4 SMALLINT, + c5 INT, + c6 BIGINT NOT NULL, + c7 SMALLINT NOT NULL, + c8 INT NOT NULL, + c9 INT UNSIGNED NOT NULL, + c10 BIGINT UNSIGNED NOT NULL, + c11 FLOAT NOT NULL, + c12 DOUBLE NOT NULL, + c13 VARCHAR NOT NULL +) +STORED AS CSV +WITH HEADER ROW +LOCATION '../../testing/data/csv/aggregate_test_100.csv' + +query TIIII +SELECT c1, count(distinct c2), min(distinct c2), min(c3), max(c4) FROM aggregate_test_100 GROUP BY c1 ORDER BY c1; +---- +a 5 1 -101 32064 +b 5 1 -117 25286 +c 5 1 -117 29106 +d 5 1 -99 31106 +e 5 1 -95 32514 + +query TT +EXPLAIN SELECT c1, count(distinct c2), min(distinct c2), sum(c3), max(c4) FROM aggregate_test_100 GROUP BY c1 ORDER BY c1; +---- +logical_plan +Sort: aggregate_test_100.c1 ASC NULLS LAST +--Projection: aggregate_test_100.c1, COUNT(alias1) AS COUNT(DISTINCT aggregate_test_100.c2), MIN(alias1) AS MIN(DISTINCT aggregate_test_100.c2), SUM(alias2) AS SUM(aggregate_test_100.c3), MAX(alias3) AS MAX(aggregate_test_100.c4) +----Aggregate: groupBy=[[aggregate_test_100.c1]], aggr=[[COUNT(alias1), MIN(alias1), SUM(alias2), MAX(alias3)]] +------Aggregate: groupBy=[[aggregate_test_100.c1, aggregate_test_100.c2 AS alias1]], aggr=[[SUM(CAST(aggregate_test_100.c3 AS Int64)) AS alias2, 
MAX(aggregate_test_100.c4) AS alias3]] +--------TableScan: aggregate_test_100 projection=[c1, c2, c3, c4] +physical_plan +SortPreservingMergeExec: [c1@0 ASC NULLS LAST] +--SortExec: expr=[c1@0 ASC NULLS LAST] +----ProjectionExec: expr=[c1@0 as c1, COUNT(alias1)@1 as COUNT(DISTINCT aggregate_test_100.c2), MIN(alias1)@2 as MIN(DISTINCT aggregate_test_100.c2), SUM(alias2)@3 as SUM(aggregate_test_100.c3), MAX(alias3)@4 as MAX(aggregate_test_100.c4)] +------AggregateExec: mode=FinalPartitioned, gby=[c1@0 as c1], aggr=[COUNT(alias1), MIN(alias1), SUM(alias2), MAX(alias3)] +--------CoalesceBatchesExec: target_batch_size=2 +----------RepartitionExec: partitioning=Hash([c1@0], 8), input_partitions=8 +------------AggregateExec: mode=Partial, gby=[c1@0 as c1], aggr=[COUNT(alias1), MIN(alias1), SUM(alias2), MAX(alias3)] +--------------AggregateExec: mode=FinalPartitioned, gby=[c1@0 as c1, alias1@1 as alias1], aggr=[alias2, alias3] +----------------CoalesceBatchesExec: target_batch_size=2 +------------------RepartitionExec: partitioning=Hash([c1@0, alias1@1], 8), input_partitions=8 +--------------------AggregateExec: mode=Partial, gby=[c1@0 as c1, c2@1 as alias1], aggr=[alias2, alias3] +----------------------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 +------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1, c2, c3, c4], has_header=true + +# Use PostgreSQL dialect +statement ok +set datafusion.sql_parser.dialect = 'Postgres'; + +query II +SELECT c2, count(distinct c3) FILTER (WHERE c1 != 'a') FROM aggregate_test_100 GROUP BY c2 ORDER BY c2; +---- +1 17 +2 17 +3 13 +4 19 +5 11 + +query III +SELECT c2, count(distinct c3) FILTER (WHERE c1 != 'a'), count(c5) FILTER (WHERE c1 != 'b') FROM aggregate_test_100 GROUP BY c2 ORDER BY c2; +---- +1 17 19 +2 17 18 +3 13 17 +4 19 18 +5 11 9 + +# Restore the default dialect +statement ok +set datafusion.sql_parser.dialect = 'Generic'; + +statement ok +drop table aggregate_test_100; + + +# Create an unbounded external table with primary key +# column c +statement ok +CREATE EXTERNAL TABLE unbounded_multiple_ordered_table_with_pk ( + a0 INTEGER, + a INTEGER, + b INTEGER, + c INTEGER primary key, + d INTEGER +) +STORED AS CSV +WITH HEADER ROW +WITH ORDER (a ASC, b ASC) +WITH ORDER (c ASC) +LOCATION '../core/tests/data/window_2.csv'; + +# Query below can be executed, since c is primary key. 
+query III rowsort +SELECT c, a, SUM(d) +FROM unbounded_multiple_ordered_table_with_pk +GROUP BY c +ORDER BY c +LIMIT 5 +---- +0 0 0 +1 0 2 +2 0 0 +3 0 0 +4 0 1 + + +query ITIPTR rowsort +SELECT r.* +FROM sales_global_with_pk as l, sales_global_with_pk as r +LIMIT 5 +---- +0 GRC 0 2022-01-01T06:00:00 EUR 30 +1 FRA 1 2022-01-01T08:00:00 EUR 50 +1 FRA 3 2022-01-02T12:00:00 EUR 200 +1 TUR 2 2022-01-01T11:30:00 TRY 75 +1 TUR 4 2022-01-03T10:00:00 TRY 100 diff --git a/datafusion/sqllogictest/test_files/information_schema.slt b/datafusion/sqllogictest/test_files/information_schema.slt index ed85f54a39aa2..1b5ad86546a33 100644 --- a/datafusion/sqllogictest/test_files/information_schema.slt +++ b/datafusion/sqllogictest/test_files/information_schema.slt @@ -150,6 +150,7 @@ datafusion.execution.aggregate.scalar_update_factor 10 datafusion.execution.batch_size 8192 datafusion.execution.coalesce_batches true datafusion.execution.collect_statistics false +datafusion.execution.listing_table_ignore_subdirectory true datafusion.execution.max_buffered_batches_per_output_file 2 datafusion.execution.meta_fetch_concurrency 32 datafusion.execution.minimum_parallel_output_files 4 @@ -188,6 +189,8 @@ datafusion.explain.logical_plan_only false datafusion.explain.physical_plan_only false datafusion.explain.show_statistics false datafusion.optimizer.allow_symmetric_joins_without_pruning true +datafusion.optimizer.default_filter_selectivity 20 +datafusion.optimizer.enable_distinct_aggregation_soft_limit true datafusion.optimizer.enable_round_robin_repartition true datafusion.optimizer.enable_topk_aggregation true datafusion.optimizer.filter_null_join_keys false @@ -222,6 +225,7 @@ datafusion.execution.aggregate.scalar_update_factor 10 Specifies the threshold f datafusion.execution.batch_size 8192 Default batch size while creating new batches, it's especially useful for buffer-in-memory batches since creating tiny batches would result in too much metadata memory consumption datafusion.execution.coalesce_batches true When set to true, record batches will be examined between each operator and small batches will be coalesced into larger batches. This is helpful when there are highly selective filters or joins that could produce tiny output batches. The target batch size is determined by the configuration setting datafusion.execution.collect_statistics false Should DataFusion collect statistics after listing files +datafusion.execution.listing_table_ignore_subdirectory true Should sub directories be ignored when scanning directories for data files. Defaults to true (ignores subdirectories), consistent with Hive. Note that this setting does not affect reading partitioned tables (e.g. `/table/year=2021/month=01/data.parquet`). datafusion.execution.max_buffered_batches_per_output_file 2 This is the maximum number of RecordBatches buffered for each output file being worked. Higher values can potentially give faster write performance at the cost of higher peak memory consumption datafusion.execution.meta_fetch_concurrency 32 Number of files to read in parallel when inferring schema and statistics datafusion.execution.minimum_parallel_output_files 4 Guarantees a minimum level of output files running in parallel. RecordBatches will be distributed in round robin fashion to each parallel writer. Each writer is closed and a new file opened once soft_max_rows_per_output_file is reached. 
@@ -260,6 +264,8 @@ datafusion.explain.logical_plan_only false When set to true, the explain stateme datafusion.explain.physical_plan_only false When set to true, the explain statement will only print physical plans datafusion.explain.show_statistics false When set to true, the explain statement will print operator statistics for physical plans datafusion.optimizer.allow_symmetric_joins_without_pruning true Should DataFusion allow symmetric hash joins for unbounded data sources even when its inputs do not have any ordering or filtering If the flag is not enabled, the SymmetricHashJoin operator will be unable to prune its internal buffers, resulting in certain join types - such as Full, Left, LeftAnti, LeftSemi, Right, RightAnti, and RightSemi - being produced only at the end of the execution. This is not typical in stream processing. Additionally, without proper design for long runner execution, all types of joins may encounter out-of-memory errors. +datafusion.optimizer.default_filter_selectivity 20 The default filter selectivity used by Filter Statistics when an exact selectivity cannot be determined. Valid values are between 0 (no selectivity) and 100 (all rows are selected). +datafusion.optimizer.enable_distinct_aggregation_soft_limit true When set to true, the optimizer will push a limit operation into grouped aggregations which have no aggregate expressions, as a soft limit, emitting groups once the limit is reached, before all rows in the group are read. datafusion.optimizer.enable_round_robin_repartition true When set to true, the physical plan optimizer will try to add round robin repartitioning to increase parallelism to leverage more CPU cores datafusion.optimizer.enable_topk_aggregation true When set to true, the optimizer will attempt to perform limit operations during aggregations, if possible datafusion.optimizer.filter_null_join_keys false When set to true, the optimizer will insert filters before a join between a nullable and non-nullable column to filter out nulls on the nullable side. This filter can add additional overhead when the file format does not fully support predicate push down. 
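The three options added above (datafusion.execution.listing_table_ignore_subdirectory, datafusion.optimizer.default_filter_selectivity, datafusion.optimizer.enable_distinct_aggregation_soft_limit) should be adjustable at runtime like the other settings exercised in these tests; a minimal sqllogictest-style sketch, assuming they accept SET the same way datafusion.execution.target_partitions does elsewhere in this patch (the values shown are simply the documented defaults):
statement ok
set datafusion.execution.listing_table_ignore_subdirectory = true;
statement ok
set datafusion.optimizer.default_filter_selectivity = 20;
statement ok
set datafusion.optimizer.enable_distinct_aggregation_soft_limit = true;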
diff --git a/datafusion/sqllogictest/test_files/insert.slt b/datafusion/sqllogictest/test_files/insert.slt index cc04c62277212..e20b3779459bf 100644 --- a/datafusion/sqllogictest/test_files/insert.slt +++ b/datafusion/sqllogictest/test_files/insert.slt @@ -64,7 +64,7 @@ Dml: op=[Insert Into] table=[table_without_values] --------WindowAggr: windowExpr=[[SUM(CAST(aggregate_test_100.c4 AS Int64)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING, COUNT(UInt8(1)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING AS COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING]] ----------TableScan: aggregate_test_100 projection=[c1, c4, c9] physical_plan -InsertExec: sink=MemoryTable (partitions=1) +FileSinkExec: sink=MemoryTable (partitions=1) --ProjectionExec: expr=[SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@0 as field1, COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@1 as field2] ----SortPreservingMergeExec: [c1@2 ASC NULLS LAST] ------ProjectionExec: expr=[SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@3 as SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING, COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@4 as COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING, c1@0 as c1] @@ -125,7 +125,7 @@ Dml: op=[Insert Into] table=[table_without_values] ----WindowAggr: windowExpr=[[SUM(CAST(aggregate_test_100.c4 AS Int64)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING, COUNT(UInt8(1)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING AS COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING]] ------TableScan: aggregate_test_100 projection=[c1, c4, c9] physical_plan -InsertExec: sink=MemoryTable (partitions=1) +FileSinkExec: sink=MemoryTable (partitions=1) --CoalescePartitionsExec ----ProjectionExec: expr=[SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@3 as field1, COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@4 as field2] ------BoundedWindowAggExec: wdw=[SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: 
Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1)) }, COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1)) }], mode=[Sorted] @@ -175,7 +175,7 @@ Dml: op=[Insert Into] table=[table_without_values] --------WindowAggr: windowExpr=[[SUM(CAST(aggregate_test_100.c4 AS Int64)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING, COUNT(UInt8(1)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING AS COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING]] ----------TableScan: aggregate_test_100 projection=[c1, c4, c9] physical_plan -InsertExec: sink=MemoryTable (partitions=8) +FileSinkExec: sink=MemoryTable (partitions=8) --ProjectionExec: expr=[a1@0 as a1, a2@1 as a2] ----SortPreservingMergeExec: [c1@2 ASC NULLS LAST] ------ProjectionExec: expr=[SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@3 as a1, COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@4 as a2, c1@0 as c1] @@ -217,10 +217,9 @@ Dml: op=[Insert Into] table=[table_without_values] ----Sort: aggregate_test_100.c1 ASC NULLS LAST ------TableScan: aggregate_test_100 projection=[c1] physical_plan -InsertExec: sink=MemoryTable (partitions=1) ---ProjectionExec: expr=[c1@0 as c1] -----SortExec: expr=[c1@0 ASC NULLS LAST] -------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1], has_header=true +FileSinkExec: sink=MemoryTable (partitions=1) +--SortExec: expr=[c1@0 ASC NULLS LAST] +----CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1], has_header=true query T insert into table_without_values select c1 from aggregate_test_100 order by c1; @@ -259,14 +258,18 @@ insert into table_without_values(name, id) values(4, 'zoo'); statement error Error during planning: Column count doesn't match insert query! insert into table_without_values(id) values(4, 'zoo'); -statement error Error during planning: Inserting query must have the same schema with the table. 
+# insert NULL values for the missing column (name) +query IT insert into table_without_values(id) values(4); +---- +1 query IT rowsort select * from table_without_values; ---- 1 foo 2 bar +4 NULL statement ok drop table table_without_values; @@ -286,6 +289,16 @@ insert into table_without_values values(2, NULL); ---- 1 +# insert NULL values for the missing column (field2) +query II +insert into table_without_values(field1) values(3); +---- +1 + +# insert NULL values for the missing column (field1), but the column is non-nullable +statement error Execution error: Invalid batch column at '0' has null but schema specifies non-nullable +insert into table_without_values(field2) values(300); + statement error Execution error: Invalid batch column at '0' has null but schema specifies non-nullable insert into table_without_values values(NULL, 300); @@ -297,6 +310,126 @@ select * from table_without_values; ---- 1 100 2 NULL +3 NULL statement ok drop table table_without_values; + + +### Test for creating tables into directories that do not already exist +# note use of `scratch` directory (which is cleared between runs) + +statement ok +create external table new_empty_table(x int) stored as parquet location 'test_files/scratch/insert/new_empty_table/'; -- needs trailing slash + +# should start empty +query I +select * from new_empty_table; +---- + +# should succeed and the insert should create the directory +statement ok +insert into new_empty_table values (1); + +# Now has values +query I +select * from new_empty_table; +---- +1 + +statement ok +drop table new_empty_table; + +## test that we get an error if the path doesn't end in a slash +statement ok +create external table bad_new_empty_table(x int) stored as parquet location 'test_files/scratch/insert/bad_new_empty_table'; -- no trailing slash + +# should fail +query error DataFusion error: Error during planning: Inserting into a ListingTable backed by a single file is not supported, URL is possibly missing a trailing `/`\. To append to an existing file use StreamTable, e\.g\. by using CREATE UNBOUNDED EXTERNAL TABLE +insert into bad_new_empty_table values (1); + +statement ok +drop table bad_new_empty_table; + + +### Test for specifying column's default value + +statement ok +create table test_column_defaults( + a int, + b int not null default null, + c int default 100*2+300, + d text default lower('DEFAULT_TEXT'), + e timestamp default now() +) + +query IIITP +insert into test_column_defaults values(1, 10, 100, 'ABC', now()) +---- +1 + +statement error DataFusion error: Execution error: Invalid batch column at '1' has null but schema specifies non-nullable +insert into test_column_defaults(a) values(2) + +query IIITP +insert into test_column_defaults(b) values(20) +---- +1 + +query IIIT rowsort +select a,b,c,d from test_column_defaults +---- +1 10 100 ABC +NULL 20 500 default_text + +# fill the timestamp column with the default value `now()` again; it should be different from the previous one +query IIITP +insert into test_column_defaults(a, b, c, d) values(2, 20, 200, 'DEF') +---- +1 + +# Ensure that the default expression `now()` is evaluated during insertion, not optimized away. +# Rows are inserted at different times, so their timestamp values should be different.
+query I rowsort +select count(distinct e) from test_column_defaults +---- +3 + +# Expect all rows to be true as now() was inserted into the table +query B rowsort +select e < now() from test_column_defaults +---- +true +true +true + +statement ok +drop table test_column_defaults + + +# test create table as +statement ok +create table test_column_defaults( + a int, + b int not null default null, + c int default 100*2+300, + d text default lower('DEFAULT_TEXT'), + e timestamp default now() +) as values(1, 10, 100, 'ABC', now()) + +query IIITP +insert into test_column_defaults(b) values(20) +---- +1 + +query IIIT rowsort +select a,b,c,d from test_column_defaults +---- +1 10 100 ABC +NULL 20 500 default_text + +statement ok +drop table test_column_defaults + +statement error DataFusion error: Error during planning: Column reference is not allowed in the DEFAULT expression : Schema error: No field named a. +create table test_column_defaults(a int, b int default a+1) diff --git a/datafusion/sqllogictest/test_files/insert_to_external.slt b/datafusion/sqllogictest/test_files/insert_to_external.slt index 8b01a14568e7c..e73778ad44e52 100644 --- a/datafusion/sqllogictest/test_files/insert_to_external.slt +++ b/datafusion/sqllogictest/test_files/insert_to_external.slt @@ -57,7 +57,7 @@ CREATE EXTERNAL TABLE dictionary_encoded_parquet_partitioned( b varchar, ) STORED AS parquet -LOCATION 'test_files/scratch/insert_to_external/parquet_types_partitioned' +LOCATION 'test_files/scratch/insert_to_external/parquet_types_partitioned/' PARTITIONED BY (b) OPTIONS( create_local_path 'true', @@ -76,6 +76,45 @@ select * from dictionary_encoded_parquet_partitioned order by (a); a foo b bar +statement ok +CREATE EXTERNAL TABLE dictionary_encoded_arrow_partitioned( + a varchar, + b varchar, +) +STORED AS arrow +LOCATION 'test_files/scratch/insert_to_external/arrow_dict_partitioned/' +PARTITIONED BY (b) +OPTIONS( +create_local_path 'true', +insert_mode 'append_new_files', +); + +query TT +insert into dictionary_encoded_arrow_partitioned +select * from dictionary_encoded_values +---- +2 + +statement ok +CREATE EXTERNAL TABLE dictionary_encoded_arrow_test_readback( + a varchar, +) +STORED AS arrow +LOCATION 'test_files/scratch/insert_to_external/arrow_dict_partitioned/b=bar/' +OPTIONS( +create_local_path 'true', +insert_mode 'append_new_files', +); + +query T +select * from dictionary_encoded_arrow_test_readback; +---- +b + +# https://github.com/apache/arrow-datafusion/issues/7816 +query error DataFusion error: Arrow error: Schema error: project index 1 out of bounds, max field 1 +select * from dictionary_encoded_arrow_partitioned order by (a); + # test_insert_into statement ok @@ -100,7 +139,7 @@ Dml: op=[Insert Into] table=[ordered_insert_test] --Projection: column1 AS a, column2 AS b ----Values: (Int64(5), Int64(1)), (Int64(4), Int64(2)), (Int64(7), Int64(7)), (Int64(7), Int64(8)), (Int64(7), Int64(9))... physical_plan -InsertExec: sink=CsvSink(writer_mode=PutMultipart, file_groups=[]) +FileSinkExec: sink=CsvSink(file_groups=[]) --SortExec: expr=[a@0 ASC NULLS LAST,b@1 DESC] ----ProjectionExec: expr=[column1@0 as a, column2@1 as b] ------ValuesExec @@ -254,6 +293,22 @@ create_local_path 'true', single_file 'true', ); +query error DataFusion error: Error during planning: Inserting into a ListingTable backed by a single file is not supported, URL is possibly missing a trailing `/`\. To append to an existing file use StreamTable, e\.g\. 
by using CREATE UNBOUNDED EXTERNAL TABLE +INSERT INTO single_file_test values (1, 2), (3, 4); + +statement ok +drop table single_file_test; + +statement ok +CREATE UNBOUNDED EXTERNAL TABLE +single_file_test(a bigint, b bigint) +STORED AS csv +LOCATION 'test_files/scratch/insert_to_external/single_csv_table.csv' +OPTIONS( +create_local_path 'true', +single_file 'true', +); + query II INSERT INTO single_file_test values (1, 2), (3, 4); ---- @@ -276,7 +331,7 @@ statement ok CREATE EXTERNAL TABLE directory_test(a bigint, b bigint) STORED AS parquet -LOCATION 'test_files/scratch/insert_to_external/external_parquet_table_q0' +LOCATION 'test_files/scratch/insert_to_external/external_parquet_table_q0/' OPTIONS( create_local_path 'true', ); @@ -296,7 +351,7 @@ statement ok CREATE EXTERNAL TABLE table_without_values(field1 BIGINT NULL, field2 BIGINT NULL) STORED AS parquet -LOCATION 'test_files/scratch/insert_to_external/external_parquet_table_q1' +LOCATION 'test_files/scratch/insert_to_external/external_parquet_table_q1/' OPTIONS (create_local_path 'true'); query TT @@ -315,7 +370,7 @@ Dml: op=[Insert Into] table=[table_without_values] --------WindowAggr: windowExpr=[[SUM(CAST(aggregate_test_100.c4 AS Int64)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING, COUNT(UInt8(1)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING AS COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING]] ----------TableScan: aggregate_test_100 projection=[c1, c4, c9] physical_plan -InsertExec: sink=ParquetSink(writer_mode=PutMultipart, file_groups=[]) +FileSinkExec: sink=ParquetSink(file_groups=[]) --ProjectionExec: expr=[SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@0 as field1, COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@1 as field2] ----SortPreservingMergeExec: [c1@2 ASC NULLS LAST] ------ProjectionExec: expr=[SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@3 as SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING, COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@4 as COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING, c1@0 as c1] @@ -362,7 +417,7 @@ statement ok CREATE EXTERNAL TABLE table_without_values(field1 BIGINT NULL, field2 BIGINT NULL) STORED AS parquet -LOCATION 'test_files/scratch/insert_to_external/external_parquet_table_q2' +LOCATION 'test_files/scratch/insert_to_external/external_parquet_table_q2/' OPTIONS (create_local_path 'true'); query TT @@ -378,7 +433,7 @@ Dml: op=[Insert Into] table=[table_without_values] ----WindowAggr: windowExpr=[[SUM(CAST(aggregate_test_100.c4 AS Int64)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING, COUNT(UInt8(1)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS 
BETWEEN 1 PRECEDING AND 1 FOLLOWING AS COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING]] ------TableScan: aggregate_test_100 projection=[c1, c4, c9] physical_plan -InsertExec: sink=ParquetSink(writer_mode=PutMultipart, file_groups=[]) +FileSinkExec: sink=ParquetSink(file_groups=[]) --CoalescePartitionsExec ----ProjectionExec: expr=[SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@3 as field1, COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@4 as field2] ------BoundedWindowAggExec: wdw=[SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1)) }, COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1)) }], mode=[Sorted] @@ -407,7 +462,7 @@ statement ok CREATE EXTERNAL TABLE table_without_values(c1 varchar NULL) STORED AS parquet -LOCATION 'test_files/scratch/insert_to_external/external_parquet_table_q3' +LOCATION 'test_files/scratch/insert_to_external/external_parquet_table_q3/' OPTIONS (create_local_path 'true'); # verify that the sort order of the insert query is maintained into the @@ -422,10 +477,9 @@ Dml: op=[Insert Into] table=[table_without_values] ----Sort: aggregate_test_100.c1 ASC NULLS LAST ------TableScan: aggregate_test_100 projection=[c1] physical_plan -InsertExec: sink=ParquetSink(writer_mode=PutMultipart, file_groups=[]) ---ProjectionExec: expr=[c1@0 as c1] -----SortExec: expr=[c1@0 ASC NULLS LAST] -------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1], has_header=true +FileSinkExec: sink=ParquetSink(file_groups=[]) +--SortExec: expr=[c1@0 ASC NULLS LAST] +----CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1], has_header=true query T insert into table_without_values select c1 from aggregate_test_100 order by c1; @@ -447,7 +501,7 @@ statement ok CREATE EXTERNAL TABLE table_without_values(id BIGINT, name varchar) STORED AS parquet -LOCATION 'test_files/scratch/insert_to_external/external_parquet_table_q4' +LOCATION 'test_files/scratch/insert_to_external/external_parquet_table_q4/' OPTIONS (create_local_path 'true'); query IT @@ -469,14 +523,18 @@ insert into table_without_values(name, id) values(4, 'zoo'); statement error Error during planning: Column count doesn't match insert query! 
insert into table_without_values(id) values(4, 'zoo'); -statement error Error during planning: Inserting query must have the same schema with the table. +# insert NULL values for the missing column (name) +query IT insert into table_without_values(id) values(4); +---- +1 query IT rowsort select * from table_without_values; ---- 1 foo 2 bar +4 NULL statement ok drop table table_without_values; @@ -486,7 +544,7 @@ statement ok CREATE EXTERNAL TABLE table_without_values(field1 BIGINT NOT NULL, field2 BIGINT NULL) STORED AS parquet -LOCATION 'test_files/scratch/insert_to_external/external_parquet_table_q5' +LOCATION 'test_files/scratch/insert_to_external/external_parquet_table_q5/' OPTIONS (create_local_path 'true'); query II @@ -499,6 +557,16 @@ insert into table_without_values values(2, NULL); ---- 1 +# insert NULL values for the missing column (field2) +query II +insert into table_without_values(field1) values(3); +---- +1 + +# insert NULL values for the missing column (field1), but column is non-nullable +statement error Execution error: Invalid batch column at '0' has null but schema specifies non-nullable +insert into table_without_values(field2) values(300); + statement error Execution error: Invalid batch column at '0' has null but schema specifies non-nullable insert into table_without_values values(NULL, 300); @@ -510,6 +578,74 @@ select * from table_without_values; ---- 1 100 2 NULL +3 NULL statement ok drop table table_without_values; + + +### Test for specifying column's default value + +statement ok +CREATE EXTERNAL TABLE test_column_defaults( + a int, + b int not null default null, + c int default 100*2+300, + d text default lower('DEFAULT_TEXT'), + e timestamp default now() +) STORED AS parquet +LOCATION 'test_files/scratch/insert_to_external/external_parquet_table_q6/' +OPTIONS (create_local_path 'true'); + +# fill in all column values +query IIITP +insert into test_column_defaults values(1, 10, 100, 'ABC', now()) +---- +1 + +statement error DataFusion error: Execution error: Invalid batch column at '1' has null but schema specifies non-nullable +insert into test_column_defaults(a) values(2) + +query IIITP +insert into test_column_defaults(b) values(20) +---- +1 + +query IIIT rowsort +select a,b,c,d from test_column_defaults +---- +1 10 100 ABC +NULL 20 500 default_text + +# fill the timestamp column with default value `now()` again, it should be different from the previous one +query IIITP +insert into test_column_defaults(a, b, c, d) values(2, 20, 200, 'DEF') +---- +1 + +# Ensure that the default expression `now()` is evaluated during insertion, not optimized away. +# Rows are inserted during different time, so their timestamp values should be different. +query I rowsort +select count(distinct e) from test_column_defaults +---- +3 + +# Expect all rows to be true as now() was inserted into the table +query B rowsort +select e < now() from test_column_defaults +---- +true +true +true + +statement ok +drop table test_column_defaults + +# test invalid default value +statement error DataFusion error: Error during planning: Column reference is not allowed in the DEFAULT expression : Schema error: No field named a. 
+CREATE EXTERNAL TABLE test_column_defaults( + a int, + b int default a+1 +) STORED AS parquet +LOCATION 'test_files/scratch/insert_to_external/external_parquet_table_q7/' +OPTIONS (create_local_path 'true'); diff --git a/datafusion/sqllogictest/test_files/interval.slt b/datafusion/sqllogictest/test_files/interval.slt index 500876f76221c..f2ae2984f07b7 100644 --- a/datafusion/sqllogictest/test_files/interval.slt +++ b/datafusion/sqllogictest/test_files/interval.slt @@ -126,6 +126,86 @@ select interval '5' nanoseconds ---- 0 years 0 mons 0 days 0 hours 0 mins 0.000000005 secs +query ? +select interval '5 YEAR' +---- +0 years 60 mons 0 days 0 hours 0 mins 0.000000000 secs + +query ? +select interval '5 MONTH' +---- +0 years 5 mons 0 days 0 hours 0 mins 0.000000000 secs + +query ? +select interval '5 WEEK' +---- +0 years 0 mons 35 days 0 hours 0 mins 0.000000000 secs + +query ? +select interval '5 DAY' +---- +0 years 0 mons 5 days 0 hours 0 mins 0.000000000 secs + +query ? +select interval '5 HOUR' +---- +0 years 0 mons 0 days 5 hours 0 mins 0.000000000 secs + +query ? +select interval '5 HOURS' +---- +0 years 0 mons 0 days 5 hours 0 mins 0.000000000 secs + +query ? +select interval '5 MINUTE' +---- +0 years 0 mons 0 days 0 hours 5 mins 0.000000000 secs + +query ? +select interval '5 SECOND' +---- +0 years 0 mons 0 days 0 hours 0 mins 5.000000000 secs + +query ? +select interval '5 SECONDS' +---- +0 years 0 mons 0 days 0 hours 0 mins 5.000000000 secs + +query ? +select interval '5 MILLISECOND' +---- +0 years 0 mons 0 days 0 hours 0 mins 0.005000000 secs + +query ? +select interval '5 MILLISECONDS' +---- +0 years 0 mons 0 days 0 hours 0 mins 0.005000000 secs + +query ? +select interval '5 MICROSECOND' +---- +0 years 0 mons 0 days 0 hours 0 mins 0.000005000 secs + +query ? +select interval '5 MICROSECONDS' +---- +0 years 0 mons 0 days 0 hours 0 mins 0.000005000 secs + +query ? +select interval '5 NANOSECOND' +---- +0 years 0 mons 0 days 0 hours 0 mins 0.000000005 secs + +query ? +select interval '5 NANOSECONDS' +---- +0 years 0 mons 0 days 0 hours 0 mins 0.000000005 secs + +query ? +select interval '5 YEAR 5 MONTH 5 DAY 5 HOUR 5 MINUTE 5 SECOND 5 MILLISECOND 5 MICROSECOND 5 NANOSECOND' +---- +0 years 65 mons 5 days 5 hours 5 mins 5.005005005 secs + # Interval with string literal addition query ? select interval '1 month' + '1 month' diff --git a/datafusion/sqllogictest/test_files/join.slt b/datafusion/sqllogictest/test_files/join.slt index 874d849e9a29b..c9dd7ca604ad9 100644 --- a/datafusion/sqllogictest/test_files/join.slt +++ b/datafusion/sqllogictest/test_files/join.slt @@ -556,7 +556,7 @@ query TT explain select * from t1 join t2 on false; ---- logical_plan EmptyRelation -physical_plan EmptyExec: produce_one_row=false +physical_plan EmptyExec # Make batch size smaller than table row number. to introduce parallelism to the plan. statement ok @@ -594,3 +594,40 @@ drop table IF EXISTS full_join_test; # batch size statement ok set datafusion.execution.batch_size = 8192; + +# related to: https://github.com/apache/arrow-datafusion/issues/8374 +statement ok +CREATE TABLE t1(a text, b int) AS VALUES ('Alice', 50), ('Alice', 100); + +statement ok +CREATE TABLE t2(a text, b int) AS VALUES ('Alice', 2), ('Alice', 1); + +# the current query results are incorrect, because the query was incorrectly rewritten as: +# SELECT t1.a, t1.b FROM t1 JOIN t2 ON t1.a = t2.a ORDER BY t1.a, t1.b; +# the difference is the ORDER BY clause rewrite from t2.b to t1.b, which is incorrect.
+# after https://github.com/apache/arrow-datafusion/issues/8374 fixed, the correct result should be: +# Alice 50 +# Alice 100 +# Alice 50 +# Alice 100 +query TI +SELECT t1.a, t1.b FROM t1 JOIN t2 ON t1.a = t2.a ORDER BY t1.a, t2.b; +---- +Alice 50 +Alice 50 +Alice 100 +Alice 100 + +query TITI +SELECT t1.a, t1.b, t2.a, t2.b FROM t1 JOIN t2 ON t1.a = t2.a ORDER BY t1.a, t2.b; +---- +Alice 50 Alice 1 +Alice 100 Alice 1 +Alice 50 Alice 2 +Alice 100 Alice 2 + +statement ok +DROP TABLE t1; + +statement ok +DROP TABLE t2; diff --git a/datafusion/sqllogictest/test_files/joins.slt b/datafusion/sqllogictest/test_files/joins.slt index c794c4da43108..a7146a5a91c4a 100644 --- a/datafusion/sqllogictest/test_files/joins.slt +++ b/datafusion/sqllogictest/test_files/joins.slt @@ -140,6 +140,17 @@ SELECT FROM test_timestamps_table_source; +# create a table of timestamps with time zone +statement ok +CREATE TABLE test_timestamps_tz_table as +SELECT + arrow_cast(ts::timestamp::bigint, 'Timestamp(Nanosecond, Some("UTC"))') as nanos, + arrow_cast(ts::timestamp::bigint / 1000, 'Timestamp(Microsecond, Some("UTC"))') as micros, + arrow_cast(ts::timestamp::bigint / 1000000, 'Timestamp(Millisecond, Some("UTC"))') as millis, + arrow_cast(ts::timestamp::bigint / 1000000000, 'Timestamp(Second, Some("UTC"))') as secs, + names +FROM + test_timestamps_table_source; statement ok @@ -1443,17 +1454,16 @@ Projection: join_t1.t1_id, join_t1.t1_name, join_t1.t1_int, join_t2.t2_id, join_ ----TableScan: join_t1 projection=[t1_id, t1_name, t1_int] ----TableScan: join_t2 projection=[t2_id, t2_name, t2_int] physical_plan -ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as t1_name, t1_int@2 as t1_int, t2_id@3 as t2_id, t2_name@4 as t2_name, t2_int@5 as t2_int, CAST(t1_id@0 AS Int64) + 11 as join_t1.t1_id + Int64(11)] ---ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as t1_name, t1_int@2 as t1_int, t2_id@4 as t2_id, t2_name@5 as t2_name, t2_int@6 as t2_int] -----CoalesceBatchesExec: target_batch_size=2 -------HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(join_t1.t1_id + Int64(11)@3, CAST(join_t2.t2_id AS Int64)@3)] ---------CoalescePartitionsExec -----------ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as t1_name, t1_int@2 as t1_int, CAST(t1_id@0 AS Int64) + 11 as join_t1.t1_id + Int64(11)] -------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 ---------------MemoryExec: partitions=1, partition_sizes=[1] ---------ProjectionExec: expr=[t2_id@0 as t2_id, t2_name@1 as t2_name, t2_int@2 as t2_int, CAST(t2_id@0 AS Int64) as CAST(join_t2.t2_id AS Int64)] +ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as t1_name, t1_int@2 as t1_int, t2_id@4 as t2_id, t2_name@5 as t2_name, t2_int@6 as t2_int, CAST(t1_id@0 AS Int64) + 11 as join_t1.t1_id + Int64(11)] +--CoalesceBatchesExec: target_batch_size=2 +----HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(join_t1.t1_id + Int64(11)@3, CAST(join_t2.t2_id AS Int64)@3)] +------CoalescePartitionsExec +--------ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as t1_name, t1_int@2 as t1_int, CAST(t1_id@0 AS Int64) + 11 as join_t1.t1_id + Int64(11)] ----------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 ------------MemoryExec: partitions=1, partition_sizes=[1] +------ProjectionExec: expr=[t2_id@0 as t2_id, t2_name@1 as t2_name, t2_int@2 as t2_int, CAST(t2_id@0 AS Int64) as CAST(join_t2.t2_id AS Int64)] +--------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +----------MemoryExec: partitions=1, partition_sizes=[1] 
statement ok set datafusion.optimizer.repartition_joins = true; @@ -1470,20 +1480,19 @@ Projection: join_t1.t1_id, join_t1.t1_name, join_t1.t1_int, join_t2.t2_id, join_ ----TableScan: join_t1 projection=[t1_id, t1_name, t1_int] ----TableScan: join_t2 projection=[t2_id, t2_name, t2_int] physical_plan -ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as t1_name, t1_int@2 as t1_int, t2_id@3 as t2_id, t2_name@4 as t2_name, t2_int@5 as t2_int, CAST(t1_id@0 AS Int64) + 11 as join_t1.t1_id + Int64(11)] ---ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as t1_name, t1_int@2 as t1_int, t2_id@4 as t2_id, t2_name@5 as t2_name, t2_int@6 as t2_int] -----CoalesceBatchesExec: target_batch_size=2 -------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(join_t1.t1_id + Int64(11)@3, CAST(join_t2.t2_id AS Int64)@3)] ---------CoalesceBatchesExec: target_batch_size=2 -----------RepartitionExec: partitioning=Hash([join_t1.t1_id + Int64(11)@3], 2), input_partitions=2 -------------ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as t1_name, t1_int@2 as t1_int, CAST(t1_id@0 AS Int64) + 11 as join_t1.t1_id + Int64(11)] ---------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 -----------------MemoryExec: partitions=1, partition_sizes=[1] ---------CoalesceBatchesExec: target_batch_size=2 -----------RepartitionExec: partitioning=Hash([CAST(join_t2.t2_id AS Int64)@3], 2), input_partitions=2 -------------ProjectionExec: expr=[t2_id@0 as t2_id, t2_name@1 as t2_name, t2_int@2 as t2_int, CAST(t2_id@0 AS Int64) as CAST(join_t2.t2_id AS Int64)] ---------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 -----------------MemoryExec: partitions=1, partition_sizes=[1] +ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as t1_name, t1_int@2 as t1_int, t2_id@4 as t2_id, t2_name@5 as t2_name, t2_int@6 as t2_int, CAST(t1_id@0 AS Int64) + 11 as join_t1.t1_id + Int64(11)] +--CoalesceBatchesExec: target_batch_size=2 +----HashJoinExec: mode=Partitioned, join_type=Inner, on=[(join_t1.t1_id + Int64(11)@3, CAST(join_t2.t2_id AS Int64)@3)] +------CoalesceBatchesExec: target_batch_size=2 +--------RepartitionExec: partitioning=Hash([join_t1.t1_id + Int64(11)@3], 2), input_partitions=2 +----------ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as t1_name, t1_int@2 as t1_int, CAST(t1_id@0 AS Int64) + 11 as join_t1.t1_id + Int64(11)] +------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +--------------MemoryExec: partitions=1, partition_sizes=[1] +------CoalesceBatchesExec: target_batch_size=2 +--------RepartitionExec: partitioning=Hash([CAST(join_t2.t2_id AS Int64)@3], 2), input_partitions=2 +----------ProjectionExec: expr=[t2_id@0 as t2_id, t2_name@1 as t2_name, t2_int@2 as t2_int, CAST(t2_id@0 AS Int64) as CAST(join_t2.t2_id AS Int64)] +------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +--------------MemoryExec: partitions=1, partition_sizes=[1] # Both side expr key inner join @@ -1502,18 +1511,16 @@ Projection: join_t1.t1_id, join_t2.t2_id, join_t1.t1_name ----TableScan: join_t1 projection=[t1_id, t1_name] ----TableScan: join_t2 projection=[t2_id] physical_plan -ProjectionExec: expr=[t1_id@0 as t1_id, t2_id@2 as t2_id, t1_name@1 as t1_name] ---ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as t1_name, t2_id@3 as t2_id] -----ProjectionExec: expr=[t1_id@2 as t1_id, t1_name@3 as t1_name, join_t1.t1_id + UInt32(12)@4 as join_t1.t1_id + UInt32(12), t2_id@0 as t2_id, join_t2.t2_id + UInt32(1)@1 as join_t2.t2_id + UInt32(1)] 
-------CoalesceBatchesExec: target_batch_size=2 ---------HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(join_t2.t2_id + UInt32(1)@1, join_t1.t1_id + UInt32(12)@2)] -----------CoalescePartitionsExec -------------ProjectionExec: expr=[t2_id@0 as t2_id, t2_id@0 + 1 as join_t2.t2_id + UInt32(1)] ---------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 -----------------MemoryExec: partitions=1, partition_sizes=[1] -----------ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as t1_name, t1_id@0 + 12 as join_t1.t1_id + UInt32(12)] -------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 ---------------MemoryExec: partitions=1, partition_sizes=[1] +ProjectionExec: expr=[t1_id@2 as t1_id, t2_id@0 as t2_id, t1_name@3 as t1_name] +--CoalesceBatchesExec: target_batch_size=2 +----HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(join_t2.t2_id + UInt32(1)@1, join_t1.t1_id + UInt32(12)@2)] +------CoalescePartitionsExec +--------ProjectionExec: expr=[t2_id@0 as t2_id, t2_id@0 + 1 as join_t2.t2_id + UInt32(1)] +----------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +------------MemoryExec: partitions=1, partition_sizes=[1] +------ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as t1_name, t1_id@0 + 12 as join_t1.t1_id + UInt32(12)] +--------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +----------MemoryExec: partitions=1, partition_sizes=[1] statement ok set datafusion.optimizer.repartition_joins = true; @@ -1530,21 +1537,19 @@ Projection: join_t1.t1_id, join_t2.t2_id, join_t1.t1_name ----TableScan: join_t1 projection=[t1_id, t1_name] ----TableScan: join_t2 projection=[t2_id] physical_plan -ProjectionExec: expr=[t1_id@0 as t1_id, t2_id@2 as t2_id, t1_name@1 as t1_name] ---ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as t1_name, t2_id@3 as t2_id] -----ProjectionExec: expr=[t1_id@2 as t1_id, t1_name@3 as t1_name, join_t1.t1_id + UInt32(12)@4 as join_t1.t1_id + UInt32(12), t2_id@0 as t2_id, join_t2.t2_id + UInt32(1)@1 as join_t2.t2_id + UInt32(1)] +ProjectionExec: expr=[t1_id@2 as t1_id, t2_id@0 as t2_id, t1_name@3 as t1_name] +--CoalesceBatchesExec: target_batch_size=2 +----HashJoinExec: mode=Partitioned, join_type=Inner, on=[(join_t2.t2_id + UInt32(1)@1, join_t1.t1_id + UInt32(12)@2)] ------CoalesceBatchesExec: target_batch_size=2 ---------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(join_t2.t2_id + UInt32(1)@1, join_t1.t1_id + UInt32(12)@2)] -----------CoalesceBatchesExec: target_batch_size=2 -------------RepartitionExec: partitioning=Hash([join_t2.t2_id + UInt32(1)@1], 2), input_partitions=2 ---------------ProjectionExec: expr=[t2_id@0 as t2_id, t2_id@0 + 1 as join_t2.t2_id + UInt32(1)] -----------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 -------------------MemoryExec: partitions=1, partition_sizes=[1] -----------CoalesceBatchesExec: target_batch_size=2 -------------RepartitionExec: partitioning=Hash([join_t1.t1_id + UInt32(12)@2], 2), input_partitions=2 ---------------ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as t1_name, t1_id@0 + 12 as join_t1.t1_id + UInt32(12)] -----------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 -------------------MemoryExec: partitions=1, partition_sizes=[1] +--------RepartitionExec: partitioning=Hash([join_t2.t2_id + UInt32(1)@1], 2), input_partitions=2 +----------ProjectionExec: expr=[t2_id@0 as t2_id, t2_id@0 + 1 as join_t2.t2_id + UInt32(1)] +------------RepartitionExec: 
partitioning=RoundRobinBatch(2), input_partitions=1 +--------------MemoryExec: partitions=1, partition_sizes=[1] +------CoalesceBatchesExec: target_batch_size=2 +--------RepartitionExec: partitioning=Hash([join_t1.t1_id + UInt32(12)@2], 2), input_partitions=2 +----------ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as t1_name, t1_id@0 + 12 as join_t1.t1_id + UInt32(12)] +------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +--------------MemoryExec: partitions=1, partition_sizes=[1] # Left side expr key inner join @@ -1564,14 +1569,11 @@ Projection: join_t1.t1_id, join_t2.t2_id, join_t1.t1_name ----TableScan: join_t1 projection=[t1_id, t1_name] ----TableScan: join_t2 projection=[t2_id] physical_plan -ProjectionExec: expr=[t1_id@0 as t1_id, t2_id@2 as t2_id, t1_name@1 as t1_name] ---ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as t1_name, t2_id@3 as t2_id] -----CoalesceBatchesExec: target_batch_size=2 -------HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(join_t1.t1_id + UInt32(11)@2, t2_id@0)] ---------CoalescePartitionsExec -----------ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as t1_name, t1_id@0 + 11 as join_t1.t1_id + UInt32(11)] -------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 ---------------MemoryExec: partitions=1, partition_sizes=[1] +ProjectionExec: expr=[t1_id@1 as t1_id, t2_id@0 as t2_id, t1_name@2 as t1_name] +--CoalesceBatchesExec: target_batch_size=2 +----HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(t2_id@0, join_t1.t1_id + UInt32(11)@2)] +------MemoryExec: partitions=1, partition_sizes=[1] +------ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as t1_name, t1_id@0 + 11 as join_t1.t1_id + UInt32(11)] --------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 ----------MemoryExec: partitions=1, partition_sizes=[1] @@ -1591,17 +1593,16 @@ Projection: join_t1.t1_id, join_t2.t2_id, join_t1.t1_name ----TableScan: join_t1 projection=[t1_id, t1_name] ----TableScan: join_t2 projection=[t2_id] physical_plan -ProjectionExec: expr=[t1_id@0 as t1_id, t2_id@2 as t2_id, t1_name@1 as t1_name] ---ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as t1_name, t2_id@3 as t2_id] -----CoalesceBatchesExec: target_batch_size=2 -------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(join_t1.t1_id + UInt32(11)@2, t2_id@0)] ---------CoalesceBatchesExec: target_batch_size=2 -----------RepartitionExec: partitioning=Hash([join_t1.t1_id + UInt32(11)@2], 2), input_partitions=2 -------------ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as t1_name, t1_id@0 + 11 as join_t1.t1_id + UInt32(11)] ---------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 -----------------MemoryExec: partitions=1, partition_sizes=[1] ---------CoalesceBatchesExec: target_batch_size=2 -----------RepartitionExec: partitioning=Hash([t2_id@0], 2), input_partitions=2 +ProjectionExec: expr=[t1_id@1 as t1_id, t2_id@0 as t2_id, t1_name@2 as t1_name] +--CoalesceBatchesExec: target_batch_size=2 +----HashJoinExec: mode=Partitioned, join_type=Inner, on=[(t2_id@0, join_t1.t1_id + UInt32(11)@2)] +------CoalesceBatchesExec: target_batch_size=2 +--------RepartitionExec: partitioning=Hash([t2_id@0], 2), input_partitions=2 +----------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +------------MemoryExec: partitions=1, partition_sizes=[1] +------CoalesceBatchesExec: target_batch_size=2 +--------RepartitionExec: partitioning=Hash([join_t1.t1_id + UInt32(11)@2], 2), input_partitions=2 
+----------ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as t1_name, t1_id@0 + 11 as join_t1.t1_id + UInt32(11)] ------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 --------------MemoryExec: partitions=1, partition_sizes=[1] @@ -1623,17 +1624,15 @@ Projection: join_t1.t1_id, join_t2.t2_id, join_t1.t1_name ----TableScan: join_t1 projection=[t1_id, t1_name] ----TableScan: join_t2 projection=[t2_id] physical_plan -ProjectionExec: expr=[t1_id@0 as t1_id, t2_id@2 as t2_id, t1_name@1 as t1_name] ---ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as t1_name, t2_id@2 as t2_id] -----ProjectionExec: expr=[t1_id@2 as t1_id, t1_name@3 as t1_name, t2_id@0 as t2_id, join_t2.t2_id - UInt32(11)@1 as join_t2.t2_id - UInt32(11)] -------CoalesceBatchesExec: target_batch_size=2 ---------HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(join_t2.t2_id - UInt32(11)@1, t1_id@0)] -----------CoalescePartitionsExec -------------ProjectionExec: expr=[t2_id@0 as t2_id, t2_id@0 - 11 as join_t2.t2_id - UInt32(11)] ---------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 -----------------MemoryExec: partitions=1, partition_sizes=[1] +ProjectionExec: expr=[t1_id@2 as t1_id, t2_id@0 as t2_id, t1_name@3 as t1_name] +--CoalesceBatchesExec: target_batch_size=2 +----HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(join_t2.t2_id - UInt32(11)@1, t1_id@0)] +------CoalescePartitionsExec +--------ProjectionExec: expr=[t2_id@0 as t2_id, t2_id@0 - 11 as join_t2.t2_id - UInt32(11)] ----------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 ------------MemoryExec: partitions=1, partition_sizes=[1] +------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +--------MemoryExec: partitions=1, partition_sizes=[1] statement ok set datafusion.optimizer.repartition_joins = true; @@ -1651,20 +1650,18 @@ Projection: join_t1.t1_id, join_t2.t2_id, join_t1.t1_name ----TableScan: join_t1 projection=[t1_id, t1_name] ----TableScan: join_t2 projection=[t2_id] physical_plan -ProjectionExec: expr=[t1_id@0 as t1_id, t2_id@2 as t2_id, t1_name@1 as t1_name] ---ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as t1_name, t2_id@2 as t2_id] -----ProjectionExec: expr=[t1_id@2 as t1_id, t1_name@3 as t1_name, t2_id@0 as t2_id, join_t2.t2_id - UInt32(11)@1 as join_t2.t2_id - UInt32(11)] +ProjectionExec: expr=[t1_id@2 as t1_id, t2_id@0 as t2_id, t1_name@3 as t1_name] +--CoalesceBatchesExec: target_batch_size=2 +----HashJoinExec: mode=Partitioned, join_type=Inner, on=[(join_t2.t2_id - UInt32(11)@1, t1_id@0)] ------CoalesceBatchesExec: target_batch_size=2 ---------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(join_t2.t2_id - UInt32(11)@1, t1_id@0)] -----------CoalesceBatchesExec: target_batch_size=2 -------------RepartitionExec: partitioning=Hash([join_t2.t2_id - UInt32(11)@1], 2), input_partitions=2 ---------------ProjectionExec: expr=[t2_id@0 as t2_id, t2_id@0 - 11 as join_t2.t2_id - UInt32(11)] -----------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 -------------------MemoryExec: partitions=1, partition_sizes=[1] -----------CoalesceBatchesExec: target_batch_size=2 -------------RepartitionExec: partitioning=Hash([t1_id@0], 2), input_partitions=2 ---------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 -----------------MemoryExec: partitions=1, partition_sizes=[1] +--------RepartitionExec: partitioning=Hash([join_t2.t2_id - UInt32(11)@1], 2), input_partitions=2 +----------ProjectionExec: 
expr=[t2_id@0 as t2_id, t2_id@0 - 11 as join_t2.t2_id - UInt32(11)] +------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +--------------MemoryExec: partitions=1, partition_sizes=[1] +------CoalesceBatchesExec: target_batch_size=2 +--------RepartitionExec: partitioning=Hash([t1_id@0], 2), input_partitions=2 +----------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +------------MemoryExec: partitions=1, partition_sizes=[1] # Select wildcard with expr key inner join @@ -2474,6 +2471,16 @@ test_timestamps_table NULL NULL NULL NULL Row 2 2021-01-01T05:11:10.432 2021-01-01T05:11:10.432 2021-01-01T05:11:10.432 2021-01-01T05:11:10 Row 3 +# show the contents of the timestamp with timezone table +query PPPPT +select * from +test_timestamps_tz_table +---- +2018-11-13T17:11:10.011375885Z 2018-11-13T17:11:10.011375Z 2018-11-13T17:11:10.011Z 2018-11-13T17:11:10Z Row 0 +2011-12-13T11:13:10.123450Z 2011-12-13T11:13:10.123450Z 2011-12-13T11:13:10.123Z 2011-12-13T11:13:10Z Row 1 +NULL NULL NULL NULL Row 2 +2021-01-01T05:11:10.432Z 2021-01-01T05:11:10.432Z 2021-01-01T05:11:10.432Z 2021-01-01T05:11:10Z Row 3 + # test timestamp join on nanos datatype query PPPPTPPPPT rowsort SELECT * FROM test_timestamps_table as t1 JOIN (SELECT * FROM test_timestamps_table ) as t2 ON t1.nanos = t2.nanos; @@ -2482,6 +2489,14 @@ SELECT * FROM test_timestamps_table as t1 JOIN (SELECT * FROM test_timestamps_ta 2018-11-13T17:11:10.011375885 2018-11-13T17:11:10.011375 2018-11-13T17:11:10.011 2018-11-13T17:11:10 Row 0 2018-11-13T17:11:10.011375885 2018-11-13T17:11:10.011375 2018-11-13T17:11:10.011 2018-11-13T17:11:10 Row 0 2021-01-01T05:11:10.432 2021-01-01T05:11:10.432 2021-01-01T05:11:10.432 2021-01-01T05:11:10 Row 3 2021-01-01T05:11:10.432 2021-01-01T05:11:10.432 2021-01-01T05:11:10.432 2021-01-01T05:11:10 Row 3 +# test timestamp with timezone join on nanos datatype +query PPPPTPPPPT rowsort +SELECT * FROM test_timestamps_tz_table as t1 JOIN (SELECT * FROM test_timestamps_tz_table ) as t2 ON t1.nanos = t2.nanos; +---- +2011-12-13T11:13:10.123450Z 2011-12-13T11:13:10.123450Z 2011-12-13T11:13:10.123Z 2011-12-13T11:13:10Z Row 1 2011-12-13T11:13:10.123450Z 2011-12-13T11:13:10.123450Z 2011-12-13T11:13:10.123Z 2011-12-13T11:13:10Z Row 1 +2018-11-13T17:11:10.011375885Z 2018-11-13T17:11:10.011375Z 2018-11-13T17:11:10.011Z 2018-11-13T17:11:10Z Row 0 2018-11-13T17:11:10.011375885Z 2018-11-13T17:11:10.011375Z 2018-11-13T17:11:10.011Z 2018-11-13T17:11:10Z Row 0 +2021-01-01T05:11:10.432Z 2021-01-01T05:11:10.432Z 2021-01-01T05:11:10.432Z 2021-01-01T05:11:10Z Row 3 2021-01-01T05:11:10.432Z 2021-01-01T05:11:10.432Z 2021-01-01T05:11:10.432Z 2021-01-01T05:11:10Z Row 3 + # test timestamp join on micros datatype query PPPPTPPPPT rowsort SELECT * FROM test_timestamps_table as t1 JOIN (SELECT * FROM test_timestamps_table ) as t2 ON t1.micros = t2.micros @@ -2490,6 +2505,14 @@ SELECT * FROM test_timestamps_table as t1 JOIN (SELECT * FROM test_timestamps_ta 2018-11-13T17:11:10.011375885 2018-11-13T17:11:10.011375 2018-11-13T17:11:10.011 2018-11-13T17:11:10 Row 0 2018-11-13T17:11:10.011375885 2018-11-13T17:11:10.011375 2018-11-13T17:11:10.011 2018-11-13T17:11:10 Row 0 2021-01-01T05:11:10.432 2021-01-01T05:11:10.432 2021-01-01T05:11:10.432 2021-01-01T05:11:10 Row 3 2021-01-01T05:11:10.432 2021-01-01T05:11:10.432 2021-01-01T05:11:10.432 2021-01-01T05:11:10 Row 3 +# test timestamp with timezone join on micros datatype +query PPPPTPPPPT rowsort +SELECT * FROM test_timestamps_tz_table as t1 JOIN (SELECT * 
FROM test_timestamps_tz_table ) as t2 ON t1.micros = t2.micros +---- +2011-12-13T11:13:10.123450Z 2011-12-13T11:13:10.123450Z 2011-12-13T11:13:10.123Z 2011-12-13T11:13:10Z Row 1 2011-12-13T11:13:10.123450Z 2011-12-13T11:13:10.123450Z 2011-12-13T11:13:10.123Z 2011-12-13T11:13:10Z Row 1 +2018-11-13T17:11:10.011375885Z 2018-11-13T17:11:10.011375Z 2018-11-13T17:11:10.011Z 2018-11-13T17:11:10Z Row 0 2018-11-13T17:11:10.011375885Z 2018-11-13T17:11:10.011375Z 2018-11-13T17:11:10.011Z 2018-11-13T17:11:10Z Row 0 +2021-01-01T05:11:10.432Z 2021-01-01T05:11:10.432Z 2021-01-01T05:11:10.432Z 2021-01-01T05:11:10Z Row 3 2021-01-01T05:11:10.432Z 2021-01-01T05:11:10.432Z 2021-01-01T05:11:10.432Z 2021-01-01T05:11:10Z Row 3 + # test timestamp join on millis datatype query PPPPTPPPPT rowsort SELECT * FROM test_timestamps_table as t1 JOIN (SELECT * FROM test_timestamps_table ) as t2 ON t1.millis = t2.millis @@ -2498,6 +2521,46 @@ SELECT * FROM test_timestamps_table as t1 JOIN (SELECT * FROM test_timestamps_ta 2018-11-13T17:11:10.011375885 2018-11-13T17:11:10.011375 2018-11-13T17:11:10.011 2018-11-13T17:11:10 Row 0 2018-11-13T17:11:10.011375885 2018-11-13T17:11:10.011375 2018-11-13T17:11:10.011 2018-11-13T17:11:10 Row 0 2021-01-01T05:11:10.432 2021-01-01T05:11:10.432 2021-01-01T05:11:10.432 2021-01-01T05:11:10 Row 3 2021-01-01T05:11:10.432 2021-01-01T05:11:10.432 2021-01-01T05:11:10.432 2021-01-01T05:11:10 Row 3 +# test timestamp with timezone join on millis datatype +query PPPPTPPPPT rowsort +SELECT * FROM test_timestamps_tz_table as t1 JOIN (SELECT * FROM test_timestamps_tz_table ) as t2 ON t1.millis = t2.millis +---- +2011-12-13T11:13:10.123450Z 2011-12-13T11:13:10.123450Z 2011-12-13T11:13:10.123Z 2011-12-13T11:13:10Z Row 1 2011-12-13T11:13:10.123450Z 2011-12-13T11:13:10.123450Z 2011-12-13T11:13:10.123Z 2011-12-13T11:13:10Z Row 1 +2018-11-13T17:11:10.011375885Z 2018-11-13T17:11:10.011375Z 2018-11-13T17:11:10.011Z 2018-11-13T17:11:10Z Row 0 2018-11-13T17:11:10.011375885Z 2018-11-13T17:11:10.011375Z 2018-11-13T17:11:10.011Z 2018-11-13T17:11:10Z Row 0 +2021-01-01T05:11:10.432Z 2021-01-01T05:11:10.432Z 2021-01-01T05:11:10.432Z 2021-01-01T05:11:10Z Row 3 2021-01-01T05:11:10.432Z 2021-01-01T05:11:10.432Z 2021-01-01T05:11:10.432Z 2021-01-01T05:11:10Z Row 3 + +#### +# Config setup +#### + +statement ok +set datafusion.explain.logical_plan_only = false; + +statement ok +set datafusion.optimizer.prefer_hash_join = true; + +# explain hash join on timestamp with timezone type +query TT +EXPLAIN SELECT * FROM test_timestamps_tz_table as t1 JOIN test_timestamps_tz_table as t2 ON t1.millis = t2.millis +---- +logical_plan +Inner Join: t1.millis = t2.millis +--SubqueryAlias: t1 +----TableScan: test_timestamps_tz_table projection=[nanos, micros, millis, secs, names] +--SubqueryAlias: t2 +----TableScan: test_timestamps_tz_table projection=[nanos, micros, millis, secs, names] +physical_plan +CoalesceBatchesExec: target_batch_size=2 +--HashJoinExec: mode=Partitioned, join_type=Inner, on=[(millis@2, millis@2)] +----CoalesceBatchesExec: target_batch_size=2 +------RepartitionExec: partitioning=Hash([millis@2], 2), input_partitions=2 +--------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +----------MemoryExec: partitions=1, partition_sizes=[1] +----CoalesceBatchesExec: target_batch_size=2 +------RepartitionExec: partitioning=Hash([millis@2], 2), input_partitions=2 +--------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +----------MemoryExec: partitions=1, partition_sizes=[1] + # 
left_join_using_2 query II SELECT t1.c1, t2.c2 FROM test_partition_table t1 JOIN test_partition_table t2 USING (c2) ORDER BY t2.c2; @@ -2756,13 +2819,13 @@ physical_plan SortPreservingMergeExec: [t1_id@0 ASC NULLS LAST] --SortExec: expr=[t1_id@0 ASC NULLS LAST] ----CoalesceBatchesExec: target_batch_size=2 -------HashJoinExec: mode=Partitioned, join_type=LeftSemi, on=[(t1_id@0, t2_id@0)] +------HashJoinExec: mode=Partitioned, join_type=RightSemi, on=[(t2_id@0, t1_id@0)] --------CoalesceBatchesExec: target_batch_size=2 -----------RepartitionExec: partitioning=Hash([t1_id@0], 2), input_partitions=2 +----------RepartitionExec: partitioning=Hash([t2_id@0], 2), input_partitions=2 ------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 --------------MemoryExec: partitions=1, partition_sizes=[1] --------CoalesceBatchesExec: target_batch_size=2 -----------RepartitionExec: partitioning=Hash([t2_id@0], 2), input_partitions=2 +----------RepartitionExec: partitioning=Hash([t1_id@0], 2), input_partitions=2 ------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 --------------MemoryExec: partitions=1, partition_sizes=[1] @@ -2797,13 +2860,13 @@ physical_plan SortPreservingMergeExec: [t1_id@0 ASC NULLS LAST] --SortExec: expr=[t1_id@0 ASC NULLS LAST] ----CoalesceBatchesExec: target_batch_size=2 -------HashJoinExec: mode=Partitioned, join_type=LeftSemi, on=[(t1_id@0, t2_id@0)] +------HashJoinExec: mode=Partitioned, join_type=RightSemi, on=[(t2_id@0, t1_id@0)] --------CoalesceBatchesExec: target_batch_size=2 -----------RepartitionExec: partitioning=Hash([t1_id@0], 2), input_partitions=2 +----------RepartitionExec: partitioning=Hash([t2_id@0], 2), input_partitions=2 ------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 --------------MemoryExec: partitions=1, partition_sizes=[1] --------CoalesceBatchesExec: target_batch_size=2 -----------RepartitionExec: partitioning=Hash([t2_id@0], 2), input_partitions=2 +----------RepartitionExec: partitioning=Hash([t1_id@0], 2), input_partitions=2 ------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 --------------MemoryExec: partitions=1, partition_sizes=[1] @@ -2859,7 +2922,7 @@ physical_plan SortPreservingMergeExec: [t1_id@0 ASC NULLS LAST] --SortExec: expr=[t1_id@0 ASC NULLS LAST] ----CoalesceBatchesExec: target_batch_size=2 -------HashJoinExec: mode=CollectLeft, join_type=LeftSemi, on=[(t1_id@0, t2_id@0)] +------HashJoinExec: mode=CollectLeft, join_type=RightSemi, on=[(t2_id@0, t1_id@0)] --------MemoryExec: partitions=1, partition_sizes=[1] --------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 ----------MemoryExec: partitions=1, partition_sizes=[1] @@ -2895,7 +2958,7 @@ physical_plan SortPreservingMergeExec: [t1_id@0 ASC NULLS LAST] --SortExec: expr=[t1_id@0 ASC NULLS LAST] ----CoalesceBatchesExec: target_batch_size=2 -------HashJoinExec: mode=CollectLeft, join_type=LeftSemi, on=[(t1_id@0, t2_id@0)] +------HashJoinExec: mode=CollectLeft, join_type=RightSemi, on=[(t2_id@0, t1_id@0)] --------MemoryExec: partitions=1, partition_sizes=[1] --------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 ----------MemoryExec: partitions=1, partition_sizes=[1] @@ -3349,16 +3412,15 @@ Projection: amount_usd ----------------SubqueryAlias: r ------------------TableScan: multiple_ordered_table projection=[a, d] physical_plan -ProjectionExec: expr=[amount_usd@0 as amount_usd] ---ProjectionExec: expr=[LAST_VALUE(l.d) ORDER BY [l.a ASC NULLS 
LAST]@1 as amount_usd, row_n@0 as row_n] -----AggregateExec: mode=Single, gby=[row_n@2 as row_n], aggr=[LAST_VALUE(l.d)], ordering_mode=Sorted -------ProjectionExec: expr=[a@0 as a, d@1 as d, row_n@4 as row_n] ---------CoalesceBatchesExec: target_batch_size=2 -----------HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(d@1, d@1)], filter=CAST(a@0 AS Int64) >= CAST(a@1 AS Int64) - 10 -------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true -------------ProjectionExec: expr=[a@0 as a, d@1 as d, ROW_NUMBER() ORDER BY [r.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@2 as row_n] ---------------BoundedWindowAggExec: wdw=[ROW_NUMBER() ORDER BY [r.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "ROW_NUMBER() ORDER BY [r.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: CurrentRow }], mode=[Sorted] -----------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true +ProjectionExec: expr=[LAST_VALUE(l.d) ORDER BY [l.a ASC NULLS LAST]@1 as amount_usd] +--AggregateExec: mode=Single, gby=[row_n@2 as row_n], aggr=[LAST_VALUE(l.d)], ordering_mode=Sorted +----ProjectionExec: expr=[a@0 as a, d@1 as d, row_n@4 as row_n] +------CoalesceBatchesExec: target_batch_size=2 +--------HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(d@1, d@1)], filter=CAST(a@0 AS Int64) >= CAST(a@1 AS Int64) - 10 +----------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true +----------ProjectionExec: expr=[a@0 as a, d@1 as d, ROW_NUMBER() ORDER BY [r.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@2 as row_n] +------------BoundedWindowAggExec: wdw=[ROW_NUMBER() ORDER BY [r.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "ROW_NUMBER() ORDER BY [r.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: CurrentRow }], mode=[Sorted] +--------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true # run query above in multiple partitions statement ok @@ -3392,7 +3454,7 @@ SortPreservingMergeExec: [a@0 ASC] ------AggregateExec: mode=FinalPartitioned, gby=[a@0 as a, b@1 as b, c@2 as c], aggr=[LAST_VALUE(r.b)] --------CoalesceBatchesExec: target_batch_size=2 ----------RepartitionExec: partitioning=Hash([a@0, b@1, c@2], 2), input_partitions=2 -------------AggregateExec: mode=Partial, gby=[a@0 as a, b@1 as b, c@2 as c], aggr=[LAST_VALUE(r.b)], ordering_mode=PartiallySorted([0]) +------------AggregateExec: mode=Partial, gby=[a@0 as a, b@1 as b, c@2 as c], aggr=[LAST_VALUE(r.b)] --------------CoalesceBatchesExec: target_batch_size=2 ----------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(a@0, a@0)] ------------------CoalesceBatchesExec: target_batch_size=2 @@ -3400,10 +3462,49 @@ 
SortPreservingMergeExec: [a@0 ASC] ----------------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 ------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, b, c], output_ordering=[a@0 ASC, b@1 ASC NULLS LAST, c@2 ASC NULLS LAST], has_header=true ------------------CoalesceBatchesExec: target_batch_size=2 ---------------------SortPreservingRepartitionExec: partitioning=Hash([a@0], 2), input_partitions=2, sort_exprs=a@0 ASC,b@1 ASC NULLS LAST +--------------------RepartitionExec: partitioning=Hash([a@0], 2), input_partitions=2 ----------------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 ------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, b], output_ordering=[a@0 ASC, b@1 ASC NULLS LAST], has_header=true +query TT +EXPLAIN SELECT * +FROM annotated_data as l, annotated_data as r +WHERE l.a > r.a +---- +logical_plan +Inner Join: Filter: l.a > r.a +--SubqueryAlias: l +----TableScan: annotated_data projection=[a0, a, b, c, d] +--SubqueryAlias: r +----TableScan: annotated_data projection=[a0, a, b, c, d] +physical_plan +NestedLoopJoinExec: join_type=Inner, filter=a@0 > a@1 +--RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +----CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a0, a, b, c, d], output_ordering=[a@1 ASC, b@2 ASC NULLS LAST, c@3 ASC NULLS LAST], has_header=true +--CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a0, a, b, c, d], output_ordering=[a@1 ASC, b@2 ASC NULLS LAST, c@3 ASC NULLS LAST], has_header=true + +# Currently datafusion cannot pushdown filter conditions with scalar UDF into +# cross join. 
+query TT +EXPLAIN SELECT * +FROM annotated_data as t1, annotated_data as t2 +WHERE EXAMPLE(t1.a, t2.a) > 3 +---- +logical_plan +Filter: example(CAST(t1.a AS Float64), CAST(t2.a AS Float64)) > Float64(3) +--CrossJoin: +----SubqueryAlias: t1 +------TableScan: annotated_data projection=[a0, a, b, c, d] +----SubqueryAlias: t2 +------TableScan: annotated_data projection=[a0, a, b, c, d] +physical_plan +CoalesceBatchesExec: target_batch_size=2 +--FilterExec: example(CAST(a@1 AS Float64), CAST(a@6 AS Float64)) > 3 +----CrossJoinExec +------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a0, a, b, c, d], output_ordering=[a@1 ASC, b@2 ASC NULLS LAST, c@3 ASC NULLS LAST], has_header=true +------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +--------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a0, a, b, c, d], output_ordering=[a@1 ASC, b@2 ASC NULLS LAST, c@3 ASC NULLS LAST], has_header=true + #### # Config teardown #### @@ -3422,4 +3523,3 @@ set datafusion.optimizer.prefer_existing_sort = false; statement ok drop table annotated_data; - diff --git a/datafusion/sqllogictest/test_files/limit.slt b/datafusion/sqllogictest/test_files/limit.slt index 21248ddbd8d7d..e063d6e8960af 100644 --- a/datafusion/sqllogictest/test_files/limit.slt +++ b/datafusion/sqllogictest/test_files/limit.slt @@ -312,7 +312,7 @@ Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1)) AS COUNT(*)]] ----TableScan: t1 projection=[], fetch=14 physical_plan ProjectionExec: expr=[0 as COUNT(*)] ---EmptyExec: produce_one_row=true +--PlaceholderRowExec query I SELECT COUNT(*) FROM (SELECT a FROM t1 LIMIT 3 OFFSET 11); @@ -330,7 +330,7 @@ Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1)) AS COUNT(*)]] ----TableScan: t1 projection=[], fetch=11 physical_plan ProjectionExec: expr=[2 as COUNT(*)] ---EmptyExec: produce_one_row=true +--PlaceholderRowExec query I SELECT COUNT(*) FROM (SELECT a FROM t1 LIMIT 3 OFFSET 8); @@ -348,7 +348,7 @@ Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1)) AS COUNT(*)]] ----TableScan: t1 projection=[] physical_plan ProjectionExec: expr=[2 as COUNT(*)] ---EmptyExec: produce_one_row=true +--PlaceholderRowExec query I SELECT COUNT(*) FROM (SELECT a FROM t1 LIMIT 3 OFFSET 8); @@ -361,24 +361,145 @@ EXPLAIN SELECT COUNT(*) FROM (SELECT a FROM t1 WHERE a > 3 LIMIT 3 OFFSET 6); ---- logical_plan Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1)) AS COUNT(*)]] ---Limit: skip=6, fetch=3 -----Filter: t1.a > Int32(3) -------TableScan: t1 projection=[a] +--Projection: +----Limit: skip=6, fetch=3 +------Filter: t1.a > Int32(3) +--------TableScan: t1 projection=[a] physical_plan AggregateExec: mode=Final, gby=[], aggr=[COUNT(*)] --CoalescePartitionsExec ----AggregateExec: mode=Partial, gby=[], aggr=[COUNT(*)] ------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 ---------GlobalLimitExec: skip=6, fetch=3 -----------CoalesceBatchesExec: target_batch_size=8192 -------------FilterExec: a@0 > 3 ---------------MemoryExec: partitions=1, partition_sizes=[1] +--------ProjectionExec: expr=[] +----------GlobalLimitExec: skip=6, fetch=3 +------------CoalesceBatchesExec: target_batch_size=8192 +--------------FilterExec: a@0 > 3 +----------------MemoryExec: partitions=1, partition_sizes=[1] query I SELECT COUNT(*) FROM (SELECT a FROM t1 WHERE a > 3 LIMIT 3 OFFSET 6); ---- 1 +# generate BIGINT data from 1 to 1000 in multiple partitions +statement ok +CREATE TABLE t1000 (i BIGINT) AS +WITH t AS 
(VALUES (0), (0), (0), (0), (0), (0), (0), (0), (0), (0)) +SELECT ROW_NUMBER() OVER (PARTITION BY t1.column1) FROM t t1, t t2, t t3; + +# verify that there are multiple partitions in the input (i.e. MemoryExec says +# there are 4 partitions) so that this tests multi-partition limit. +query TT +EXPLAIN SELECT DISTINCT i FROM t1000; +---- +logical_plan +Aggregate: groupBy=[[t1000.i]], aggr=[[]] +--TableScan: t1000 projection=[i] +physical_plan +AggregateExec: mode=FinalPartitioned, gby=[i@0 as i], aggr=[] +--CoalesceBatchesExec: target_batch_size=8192 +----RepartitionExec: partitioning=Hash([i@0], 4), input_partitions=4 +------AggregateExec: mode=Partial, gby=[i@0 as i], aggr=[] +--------MemoryExec: partitions=4, partition_sizes=[1, 1, 2, 1] + +query I +SELECT i FROM t1000 ORDER BY i DESC LIMIT 3; +---- +1000 +999 +998 + +query I +SELECT i FROM t1000 ORDER BY i LIMIT 3; +---- +1 +2 +3 + +query I +SELECT COUNT(*) FROM (SELECT i FROM t1000 LIMIT 3); +---- +3 + +# limit_multi_partitions +statement ok +CREATE TABLE t15 (i BIGINT); + +query I +INSERT INTO t15 VALUES (1); +---- +1 + +query I +INSERT INTO t15 VALUES (1), (2); +---- +2 + +query I +INSERT INTO t15 VALUES (1), (2), (3); +---- +3 + +query I +INSERT INTO t15 VALUES (1), (2), (3), (4); +---- +4 + +query I +INSERT INTO t15 VALUES (1), (2), (3), (4), (5); +---- +5 + +query I +SELECT COUNT(*) FROM t15; +---- +15 + +query I +SELECT COUNT(*) FROM (SELECT i FROM t15 LIMIT 1); +---- +1 + +query I +SELECT COUNT(*) FROM (SELECT i FROM t15 LIMIT 2); +---- +2 + +query I +SELECT COUNT(*) FROM (SELECT i FROM t15 LIMIT 3); +---- +3 + +query I +SELECT COUNT(*) FROM (SELECT i FROM t15 LIMIT 4); +---- +4 + +query I +SELECT COUNT(*) FROM (SELECT i FROM t15 LIMIT 5); +---- +5 + +query I +SELECT COUNT(*) FROM (SELECT i FROM t15 LIMIT 6); +---- +6 + +query I +SELECT COUNT(*) FROM (SELECT i FROM t15 LIMIT 7); +---- +7 + +query I +SELECT COUNT(*) FROM (SELECT i FROM t15 LIMIT 8); +---- +8 + +query I +SELECT COUNT(*) FROM (SELECT i FROM t15 LIMIT 9); +---- +9 + ######## # Clean up after the test ######## diff --git a/datafusion/sqllogictest/test_files/map.slt b/datafusion/sqllogictest/test_files/map.slt index c3d16fca904e0..7863bf4454997 100644 --- a/datafusion/sqllogictest/test_files/map.slt +++ b/datafusion/sqllogictest/test_files/map.slt @@ -44,3 +44,22 @@ DELETE 24 query T SELECT strings['not_found'] FROM data LIMIT 1; ---- + +statement ok +drop table data; + + +# Testing explain on a table with a map filter, registered in test_context.rs. 
+query TT +explain select * from table_with_map where int_field > 0; +---- +logical_plan +Filter: table_with_map.int_field > Int64(0) +--TableScan: table_with_map projection=[int_field, map_field] +physical_plan +CoalesceBatchesExec: target_batch_size=8192 +--FilterExec: int_field@0 > 0 +----MemoryExec: partitions=1, partition_sizes=[0] + +statement ok +drop table table_with_map; diff --git a/datafusion/sqllogictest/test_files/math.slt b/datafusion/sqllogictest/test_files/math.slt index ee1e345f946a8..0fa7ff9c20511 100644 --- a/datafusion/sqllogictest/test_files/math.slt +++ b/datafusion/sqllogictest/test_files/math.slt @@ -293,53 +293,52 @@ select c1*0, c2*0, c3*0, c4*0, c5*0, c6*0, c7*0, c8*0 from test_non_nullable_int ---- 0 0 0 0 0 0 0 0 -query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nArrow error: Divide by zero error +query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nError during planning: Divide by zero SELECT c1/0 FROM test_non_nullable_integer -query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nArrow error: Divide by zero error +query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nError during planning: Divide by zero SELECT c2/0 FROM test_non_nullable_integer -query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nArrow error: Divide by zero error +query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nError during planning: Divide by zero SELECT c3/0 FROM test_non_nullable_integer -query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nArrow error: Divide by zero error +query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nError during planning: Divide by zero SELECT c4/0 FROM test_non_nullable_integer -query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nArrow error: Divide by zero error +query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nError during planning: Divide by zero SELECT c5/0 FROM test_non_nullable_integer -query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nArrow error: Divide by zero error +query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nError during planning: Divide by zero SELECT c6/0 FROM test_non_nullable_integer -query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nArrow error: Divide by zero error +query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nError during planning: Divide by zero SELECT c7/0 FROM test_non_nullable_integer -query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nArrow error: Divide by zero error +query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nError during planning: Divide by zero SELECT c8/0 FROM test_non_nullable_integer - -query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nArrow error: Divide by zero error +query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nError during planning: Divide by zero SELECT c1%0 FROM test_non_nullable_integer -query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nArrow error: Divide by zero error +query error DataFusion error: Optimizer rule 
'simplify_expressions' failed\ncaused by\nError during planning: Divide by zero SELECT c2%0 FROM test_non_nullable_integer -query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nArrow error: Divide by zero error +query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nError during planning: Divide by zero SELECT c3%0 FROM test_non_nullable_integer -query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nArrow error: Divide by zero error +query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nError during planning: Divide by zero SELECT c4%0 FROM test_non_nullable_integer -query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nArrow error: Divide by zero error +query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nError during planning: Divide by zero SELECT c5%0 FROM test_non_nullable_integer -query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nArrow error: Divide by zero error +query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nError during planning: Divide by zero SELECT c6%0 FROM test_non_nullable_integer -query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nArrow error: Divide by zero error +query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nError during planning: Divide by zero SELECT c7%0 FROM test_non_nullable_integer -query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nArrow error: Divide by zero error +query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nError during planning: Divide by zero SELECT c8%0 FROM test_non_nullable_integer statement ok @@ -557,10 +556,10 @@ SELECT c1*0 FROM test_non_nullable_decimal ---- 0 -query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nArrow error: Divide by zero error +query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nError during planning: Divide by zero SELECT c1/0 FROM test_non_nullable_decimal -query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nArrow error: Divide by zero error +query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nError during planning: Divide by zero SELECT c1%0 FROM test_non_nullable_decimal statement ok diff --git a/datafusion/sqllogictest/test_files/metadata.slt b/datafusion/sqllogictest/test_files/metadata.slt new file mode 100644 index 0000000000000..3b2b219244f55 --- /dev/null +++ b/datafusion/sqllogictest/test_files/metadata.slt @@ -0,0 +1,62 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. + +########## +## Tests for tables that has both metadata on each field as well as metadata on +## the schema itself. +########## + +## Note that table_with_metadata is defined using Rust code +## in the test harness as there is no way to define schema +## with metadata in SQL. + +query IT +select * from table_with_metadata; +---- +1 NULL +NULL bar +3 baz + +query I rowsort +SELECT ( + SELECT id FROM table_with_metadata + ) UNION ( + SELECT id FROM table_with_metadata + ); +---- +1 +3 +NULL + +query I rowsort +SELECT "data"."id" +FROM + ( + (SELECT "id" FROM "table_with_metadata") + UNION + (SELECT "id" FROM "table_with_metadata") + ) as "data", + ( + SELECT "id" FROM "table_with_metadata" + ) as "samples" +WHERE "data"."id" = "samples"."id"; +---- +1 +3 + +statement ok +drop table table_with_metadata; diff --git a/datafusion/sqllogictest/test_files/options.slt b/datafusion/sqllogictest/test_files/options.slt index 83fe85745ef87..9366a9b3b3c8f 100644 --- a/datafusion/sqllogictest/test_files/options.slt +++ b/datafusion/sqllogictest/test_files/options.slt @@ -84,7 +84,7 @@ statement ok drop table a # test datafusion.sql_parser.parse_float_as_decimal -# +# # default option value is false query RR select 10000000000000000000.01, -10000000000000000000.01 @@ -209,5 +209,3 @@ select -123456789.0123456789012345678901234567890 # Restore option to default value statement ok set datafusion.sql_parser.parse_float_as_decimal = false; - - diff --git a/datafusion/sqllogictest/test_files/order.slt b/datafusion/sqllogictest/test_files/order.slt index 8148f1c4c7c9d..77df9e0bb4937 100644 --- a/datafusion/sqllogictest/test_files/order.slt +++ b/datafusion/sqllogictest/test_files/order.slt @@ -441,13 +441,13 @@ physical_plan SortPreservingMergeExec: [result@0 ASC NULLS LAST] --ProjectionExec: expr=[b@1 + a@0 + c@2 as result] ----RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, b, c], output_ordering=[a@0 ASC NULLS LAST], has_header=true +------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, b, c], output_orderings=[[a@0 ASC NULLS LAST], [b@1 ASC NULLS LAST], [c@2 ASC NULLS LAST]], has_header=true statement ok drop table multiple_ordered_table; # Create tables having some ordered columns. In the next step, we will expect to observe that scalar -# functions, such as mathematical functions like atan(), ceil(), sqrt(), or date_time functions +# functions, such as mathematical functions like atan(), ceil(), sqrt(), or date_time functions # like date_bin() and date_trunc(), will maintain the order of its argument columns. 
statement ok CREATE EXTERNAL TABLE csv_with_timestamps ( @@ -559,7 +559,7 @@ physical_plan SortPreservingMergeExec: [log_c11_base_c12@0 ASC NULLS LAST] --ProjectionExec: expr=[log(CAST(c11@0 AS Float64), c12@1) as log_c11_base_c12] ----RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c11, c12], output_ordering=[c11@0 ASC NULLS LAST], has_header=true +------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c11, c12], output_orderings=[[c11@0 ASC NULLS LAST], [c12@1 DESC]], has_header=true query TT EXPLAIN SELECT LOG(c12, c11) as log_c12_base_c11 @@ -574,7 +574,7 @@ physical_plan SortPreservingMergeExec: [log_c12_base_c11@0 DESC] --ProjectionExec: expr=[log(c12@1, CAST(c11@0 AS Float64)) as log_c12_base_c11] ----RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c11, c12], output_ordering=[c11@0 ASC NULLS LAST], has_header=true +------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c11, c12], output_orderings=[[c11@0 ASC NULLS LAST], [c12@1 DESC]], has_header=true statement ok drop table aggregate_test_100; diff --git a/datafusion/sqllogictest/test_files/parquet.slt b/datafusion/sqllogictest/test_files/parquet.slt new file mode 100644 index 0000000000000..0f26c14f00179 --- /dev/null +++ b/datafusion/sqllogictest/test_files/parquet.slt @@ -0,0 +1,357 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +# TESTS FOR PARQUET FILES + +# Set 2 partitions for deterministic output plans +statement ok +set datafusion.execution.target_partitions = 2; + +# Create a table as a data source +statement ok +CREATE TABLE src_table ( + int_col INT, + string_col TEXT, + bigint_col BIGINT, + date_col DATE +) AS VALUES +(1, 'aaa', 100, 1), +(2, 'bbb', 200, 2), +(3, 'ccc', 300, 3), +(4, 'ddd', 400, 4), +(5, 'eee', 500, 5), +(6, 'fff', 600, 6), +(7, 'ggg', 700, 7), +(8, 'hhh', 800, 8), +(9, 'iii', 900, 9); + +# Setup 2 files, i.e., as many as there are partitions: + +# File 1: +query ITID +COPY (SELECT * FROM src_table LIMIT 3) +TO 'test_files/scratch/parquet/test_table/0.parquet' +(FORMAT PARQUET, SINGLE_FILE_OUTPUT true); +---- +3 + +# File 2: +query ITID +COPY (SELECT * FROM src_table WHERE int_col > 3 LIMIT 3) +TO 'test_files/scratch/parquet/test_table/1.parquet' +(FORMAT PARQUET, SINGLE_FILE_OUTPUT true); +---- +3 + +# Create a table from generated parquet files, without ordering: +statement ok +CREATE EXTERNAL TABLE test_table ( + int_col INT, + string_col TEXT, + bigint_col BIGINT, + date_col DATE +) +STORED AS PARQUET +WITH HEADER ROW +LOCATION 'test_files/scratch/parquet/test_table'; + +# Basic query: +query ITID +SELECT * FROM test_table ORDER BY int_col; +---- +1 aaa 100 1970-01-02 +2 bbb 200 1970-01-03 +3 ccc 300 1970-01-04 +4 ddd 400 1970-01-05 +5 eee 500 1970-01-06 +6 fff 600 1970-01-07 + +# Check output plan, expect no "output_ordering" clause in the physical_plan -> ParquetExec: +query TT +EXPLAIN SELECT int_col, string_col +FROM test_table +ORDER BY string_col, int_col; +---- +logical_plan +Sort: test_table.string_col ASC NULLS LAST, test_table.int_col ASC NULLS LAST +--TableScan: test_table projection=[int_col, string_col] +physical_plan +SortPreservingMergeExec: [string_col@1 ASC NULLS LAST,int_col@0 ASC NULLS LAST] +--SortExec: expr=[string_col@1 ASC NULLS LAST,int_col@0 ASC NULLS LAST] +----ParquetExec: file_groups={2 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet/test_table/0.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet/test_table/1.parquet]]}, projection=[int_col, string_col] + +# Tear down test_table: +statement ok +DROP TABLE test_table; + +# Create test_table again, but with ordering: +statement ok +CREATE EXTERNAL TABLE test_table ( + int_col INT, + string_col TEXT, + bigint_col BIGINT, + date_col DATE +) +STORED AS PARQUET +WITH HEADER ROW +WITH ORDER (string_col ASC NULLS LAST, int_col ASC NULLS LAST) +LOCATION 'test_files/scratch/parquet/test_table'; + +# Check output plan, expect an "output_ordering" clause in the physical_plan -> ParquetExec: +query TT +EXPLAIN SELECT int_col, string_col +FROM test_table +ORDER BY string_col, int_col; +---- +logical_plan +Sort: test_table.string_col ASC NULLS LAST, test_table.int_col ASC NULLS LAST +--TableScan: test_table projection=[int_col, string_col] +physical_plan +SortPreservingMergeExec: [string_col@1 ASC NULLS LAST,int_col@0 ASC NULLS LAST] +--ParquetExec: file_groups={2 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet/test_table/0.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet/test_table/1.parquet]]}, projection=[int_col, string_col], output_ordering=[string_col@1 ASC NULLS LAST, int_col@0 ASC NULLS LAST] + +# Add another file to the directory underlying test_table +query ITID +COPY (SELECT * FROM src_table WHERE int_col > 6 LIMIT 3) +TO 'test_files/scratch/parquet/test_table/2.parquet' +(FORMAT PARQUET, 
SINGLE_FILE_OUTPUT true); +---- +3 + +# Check output plan again, expect no "output_ordering" clause in the physical_plan -> ParquetExec, +# due to there being more files than partitions: +query TT +EXPLAIN SELECT int_col, string_col +FROM test_table +ORDER BY string_col, int_col; +---- +logical_plan +Sort: test_table.string_col ASC NULLS LAST, test_table.int_col ASC NULLS LAST +--TableScan: test_table projection=[int_col, string_col] +physical_plan +SortPreservingMergeExec: [string_col@1 ASC NULLS LAST,int_col@0 ASC NULLS LAST] +--SortExec: expr=[string_col@1 ASC NULLS LAST,int_col@0 ASC NULLS LAST] +----ParquetExec: file_groups={2 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet/test_table/0.parquet, WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet/test_table/1.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet/test_table/2.parquet]]}, projection=[int_col, string_col] + + +# Perform queries using MIN and MAX +query I +SELECT max(int_col) FROM test_table; +---- +9 + +query T +SELECT min(string_col) FROM test_table; +---- +aaa + +query I +SELECT max(bigint_col) FROM test_table; +---- +900 + +query D +SELECT min(date_col) FROM test_table; +---- +1970-01-02 + +# Clean up +statement ok +DROP TABLE test_table; + +# Setup alltypes_plain table: +statement ok +CREATE EXTERNAL TABLE alltypes_plain ( + id INT NOT NULL, + bool_col BOOLEAN NOT NULL, + tinyint_col TINYINT NOT NULL, + smallint_col SMALLINT NOT NULL, + int_col INT NOT NULL, + bigint_col BIGINT NOT NULL, + float_col FLOAT NOT NULL, + double_col DOUBLE NOT NULL, + date_string_col BYTEA NOT NULL, + string_col VARCHAR NOT NULL, + timestamp_col TIMESTAMP NOT NULL, +) +STORED AS PARQUET +WITH HEADER ROW +LOCATION '../../parquet-testing/data/alltypes_plain.parquet' + +# Test a basic query with a CAST: +query IT +SELECT id, CAST(string_col AS varchar) FROM alltypes_plain +---- +4 0 +5 1 +6 0 +7 1 +2 0 +3 1 +0 0 +1 1 + +# Clean up +statement ok +DROP TABLE alltypes_plain; + +# Perform SELECT on table with fixed sized binary columns + +statement ok +CREATE EXTERNAL TABLE test_binary +STORED AS PARQUET +WITH HEADER ROW +LOCATION '../core/tests/data/test_binary.parquet'; + +# Check size of table: +query I +SELECT count(ids) FROM test_binary; +---- +466 + +# Do the SELECT query: +query ? 
+SELECT ids FROM test_binary ORDER BY ids LIMIT 10; +---- +008c7196f68089ab692e4739c5fd16b5 +00a51a7bc5ff8eb1627f8f3dc959dce8 +0166ce1d46129ad104fa4990c6057c91 +03a4893f3285b422820b4cd74c9b9786 +04999ac861e14682cd339eae2cc74359 +04b86bf8f228739fde391f850636a77d +050fb9cf722a709eb94b70b3ee7dc342 +052578a65e8e91b8526b182d40e846e8 +05408e6a403e4296526006e20cc4a45a +0592e6fb7d7169b888a4029b53abb701 + +# Clean up +statement ok +DROP TABLE test_binary; + +# Perform a query with a window function and timestamp data: + +statement ok +CREATE EXTERNAL TABLE timestamp_with_tz +STORED AS PARQUET +WITH HEADER ROW +LOCATION '../core/tests/data/timestamp_with_tz.parquet'; + +# Check size of table: +query I +SELECT COUNT(*) FROM timestamp_with_tz; +---- +131072 + +# Perform the query: +query IPT +SELECT + count, + LAG(timestamp, 1) OVER (ORDER BY timestamp), + arrow_typeof(LAG(timestamp, 1) OVER (ORDER BY timestamp)) +FROM timestamp_with_tz +LIMIT 10; +---- +0 NULL Timestamp(Millisecond, Some("UTC")) +0 2014-08-27T14:00:00Z Timestamp(Millisecond, Some("UTC")) +0 2014-08-27T14:00:00Z Timestamp(Millisecond, Some("UTC")) +4 2014-08-27T14:00:00Z Timestamp(Millisecond, Some("UTC")) +0 2014-08-27T14:00:00Z Timestamp(Millisecond, Some("UTC")) +0 2014-08-27T14:00:00Z Timestamp(Millisecond, Some("UTC")) +0 2014-08-27T14:00:00Z Timestamp(Millisecond, Some("UTC")) +14 2014-08-27T14:00:00Z Timestamp(Millisecond, Some("UTC")) +0 2014-08-27T14:00:00Z Timestamp(Millisecond, Some("UTC")) +0 2014-08-27T14:00:00Z Timestamp(Millisecond, Some("UTC")) + +# Test config listing_table_ignore_subdirectory: + +query ITID +COPY (SELECT * FROM src_table WHERE int_col > 6 LIMIT 3) +TO 'test_files/scratch/parquet/test_table/subdir/3.parquet' +(FORMAT PARQUET, SINGLE_FILE_OUTPUT true); +---- +3 + +statement ok +CREATE EXTERNAL TABLE listing_table +STORED AS PARQUET +WITH HEADER ROW +LOCATION 'test_files/scratch/parquet/test_table/*.parquet'; + +statement ok +set datafusion.execution.listing_table_ignore_subdirectory = true; + +# scan file: 0.parquet 1.parquet 2.parquet +query I +select count(*) from listing_table; +---- +9 + +statement ok +set datafusion.execution.listing_table_ignore_subdirectory = false; + +# scan file: 0.parquet 1.parquet 2.parquet 3.parquet +query I +select count(*) from listing_table; +---- +12 + +# Clean up +statement ok +DROP TABLE timestamp_with_tz; + +# Test a query from the single_nan data set: +statement ok +CREATE EXTERNAL TABLE single_nan +STORED AS PARQUET +WITH HEADER ROW +LOCATION '../../parquet-testing/data/single_nan.parquet'; + +# Check table size: +query I +SELECT COUNT(*) FROM single_nan; +---- +1 + +# Query for the single NULL: +query R +SELECT mycol FROM single_nan; +---- +NULL + +# Clean up +statement ok +DROP TABLE single_nan; + +statement ok +CREATE EXTERNAL TABLE list_columns +STORED AS PARQUET +WITH HEADER ROW +LOCATION '../../parquet-testing/data/list_columns.parquet'; + +query ?? 
+SELECT int64_list, utf8_list FROM list_columns +---- +[1, 2, 3] [abc, efg, hij] +[, 1] NULL +[4] [efg, , hij, xyz] + +statement ok +DROP TABLE list_columns; + +# Clean up +statement ok +DROP TABLE listing_table; diff --git a/datafusion/sqllogictest/test_files/predicates.slt b/datafusion/sqllogictest/test_files/predicates.slt index d22b2ff953b72..e992a440d0a25 100644 --- a/datafusion/sqllogictest/test_files/predicates.slt +++ b/datafusion/sqllogictest/test_files/predicates.slt @@ -495,6 +495,7 @@ set datafusion.execution.parquet.bloom_filter_enabled=true; query T SELECT * FROM data_index_bloom_encoding_stats WHERE "String" = 'foo'; +---- query T SELECT * FROM data_index_bloom_encoding_stats WHERE "String" = 'test'; diff --git a/datafusion/sqllogictest/test_files/projection.slt b/datafusion/sqllogictest/test_files/projection.slt new file mode 100644 index 0000000000000..b752f5644b7fb --- /dev/null +++ b/datafusion/sqllogictest/test_files/projection.slt @@ -0,0 +1,235 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +########## +## Projection Statement Tests +########## + +# prepare data +statement ok +CREATE EXTERNAL TABLE aggregate_test_100 ( + c1 VARCHAR NOT NULL, + c2 TINYINT NOT NULL, + c3 SMALLINT NOT NULL, + c4 SMALLINT, + c5 INT, + c6 BIGINT NOT NULL, + c7 SMALLINT NOT NULL, + c8 INT NOT NULL, + c9 BIGINT UNSIGNED NOT NULL, + c10 VARCHAR NOT NULL, + c11 FLOAT NOT NULL, + c12 DOUBLE NOT NULL, + c13 VARCHAR NOT NULL +) +STORED AS CSV +WITH HEADER ROW +LOCATION '../../testing/data/csv/aggregate_test_100.csv' + +statement ok +CREATE EXTERNAL TABLE aggregate_simple ( + c1 FLOAT NOT NULL, + c2 DOUBLE NOT NULL, + c3 BOOLEAN NOT NULL +) +STORED AS CSV +WITH HEADER ROW +LOCATION '../core/tests/data/aggregate_simple.csv' + +statement ok +CREATE TABLE memory_table(a INT NOT NULL, b INT NOT NULL, c INT NOT NULL) AS VALUES +(1, 2, 3), +(10, 12, 12), +(10, 12, 12), +(100, 120, 120); + +statement ok +CREATE TABLE cpu_load_short(host STRING NOT NULL) AS VALUES +('host1'), +('host2'); + +statement ok +CREATE EXTERNAL TABLE test (c1 int, c2 bigint, c3 boolean) +STORED AS CSV LOCATION '../core/tests/data/partitioned_csv'; + +statement ok +CREATE EXTERNAL TABLE test_simple (c1 int, c2 bigint, c3 boolean) +STORED AS CSV LOCATION '../core/tests/data/partitioned_csv/partition-0.csv'; + +# projection same fields +query I rowsort +select (1+1) as a from (select 1 as a) as b; +---- +2 + +# projection type alias +query R rowsort +SELECT c1 as c3 FROM aggregate_simple ORDER BY c3 LIMIT 2; +---- +0.00001 +0.00002 + +# csv query group by avg with projection +query RT rowsort +SELECT avg(c12), c1 FROM aggregate_test_100 GROUP BY c1; +---- +0.410407092638 b +0.486006692713 e +0.487545174661 a +0.488553793875 d +0.660045653644 c + +# parallel projection +query II +SELECT c1, c2 
FROM test ORDER BY c1 DESC, c2 ASC
+----
+3 0
+3 1
+3 2
+3 3
+3 4
+3 5
+3 6
+3 7
+3 8
+3 9
+3 10
+2 0
+2 1
+2 2
+2 3
+2 4
+2 5
+2 6
+2 7
+2 8
+2 9
+2 10
+1 0
+1 1
+1 2
+1 3
+1 4
+1 5
+1 6
+1 7
+1 8
+1 9
+1 10
+0 0
+0 1
+0 2
+0 3
+0 4
+0 5
+0 6
+0 7
+0 8
+0 9
+0 10
+
+# subquery alias case insensitive
+query II
+SELECT V1.c1, v1.C2 FROM (SELECT test_simple.C1, TEST_SIMPLE.c2 FROM test_simple) V1 ORDER BY v1.c1, V1.C2 LIMIT 1;
+----
+0 0
+
+# projection on table scan
+statement ok
+set datafusion.explain.logical_plan_only = true
+
+query TT
+EXPLAIN SELECT c2 FROM test;
+----
+logical_plan TableScan: test projection=[c2]
+
+statement count 44
+select c2 from test;
+
+statement ok
+set datafusion.explain.logical_plan_only = false
+
+# project cast dictionary
+query T
+SELECT
+    CASE
+        WHEN cpu_load_short.host IS NULL THEN ''
+        ELSE cpu_load_short.host
+    END AS host
+FROM
+    cpu_load_short;
+----
+host1
+host2
+
+# projection on memory scan
+query TT
+explain select b from memory_table;
+----
+logical_plan TableScan: memory_table projection=[b]
+physical_plan MemoryExec: partitions=1, partition_sizes=[1]
+
+query I
+select b from memory_table;
+----
+2
+12
+12
+120
+
+# project column with same name as relation
+query I
+select a.a from (select 1 as a) as a;
+----
+1
+
+# project column with filters that can't be pushed down, always false
+query I
+select * from (select 1 as a) f where f.a=2;
+----
+
+
+# project column with filters that can't be pushed down, always true
+query I
+select * from (select 1 as a) f where f.a=1;
+----
+1
+
+# project columns in memory without propagation
+query I
+SELECT column1 as a from (values (1), (2)) f where f.column1 = 2;
+----
+2
+
+# clean data
+statement ok
+DROP TABLE aggregate_simple;
+
+statement ok
+DROP TABLE aggregate_test_100;
+
+statement ok
+DROP TABLE memory_table;
+
+statement ok
+DROP TABLE cpu_load_short;
+
+statement ok
+DROP TABLE test;
+
+statement ok
+DROP TABLE test_simple;
diff --git a/datafusion/sqllogictest/test_files/repartition_scan.slt b/datafusion/sqllogictest/test_files/repartition_scan.slt
new file mode 100644
index 0000000000000..02eccd7c5d06f
--- /dev/null
+++ b/datafusion/sqllogictest/test_files/repartition_scan.slt
@@ -0,0 +1,264 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+ +########## +# Tests for automatically reading files in parallel during scan +########## + +# Set 4 partitions for deterministic output plans +statement ok +set datafusion.execution.target_partitions = 4; + +# automatically partition all files over 1 byte +statement ok +set datafusion.optimizer.repartition_file_min_size = 1; + +################### +### Parquet tests +################### + +# create a single parquet file +# Note filename 2.parquet to test sorting (on local file systems it is often listed before 1.parquet) +statement ok +COPY (VALUES (1), (2), (3), (4), (5)) TO 'test_files/scratch/repartition_scan/parquet_table/2.parquet' +(FORMAT PARQUET, SINGLE_FILE_OUTPUT true); + +statement ok +CREATE EXTERNAL TABLE parquet_table(column1 int) +STORED AS PARQUET +LOCATION 'test_files/scratch/repartition_scan/parquet_table/'; + +query I +select * from parquet_table; +---- +1 +2 +3 +4 +5 + +## Expect to see the scan read the file as "4" groups with even sizes (offsets) +query TT +EXPLAIN SELECT column1 FROM parquet_table WHERE column1 <> 42; +---- +logical_plan +Filter: parquet_table.column1 != Int32(42) +--TableScan: parquet_table projection=[column1], partial_filters=[parquet_table.column1 != Int32(42)] +physical_plan +CoalesceBatchesExec: target_batch_size=8192 +--FilterExec: column1@0 != 42 +----ParquetExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:0..101], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:101..202], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:202..303], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:303..403]]}, projection=[column1], predicate=column1@0 != 42, pruning_predicate=column1_min@0 != 42 OR 42 != column1_max@1 + +# create a second parquet file +statement ok +COPY (VALUES (100), (200)) TO 'test_files/scratch/repartition_scan/parquet_table/1.parquet' +(FORMAT PARQUET, SINGLE_FILE_OUTPUT true); + +## Still expect to see the scan read the file as "4" groups with even sizes. One group should read +## parts of both files. 
+query TT +EXPLAIN SELECT column1 FROM parquet_table WHERE column1 <> 42 ORDER BY column1; +---- +logical_plan +Sort: parquet_table.column1 ASC NULLS LAST +--Filter: parquet_table.column1 != Int32(42) +----TableScan: parquet_table projection=[column1], partial_filters=[parquet_table.column1 != Int32(42)] +physical_plan +SortPreservingMergeExec: [column1@0 ASC NULLS LAST] +--SortExec: expr=[column1@0 ASC NULLS LAST] +----CoalesceBatchesExec: target_batch_size=8192 +------FilterExec: column1@0 != 42 +--------ParquetExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/1.parquet:0..200], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/1.parquet:200..394, WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:0..6], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:6..206], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:206..403]]}, projection=[column1], predicate=column1@0 != 42, pruning_predicate=column1_min@0 != 42 OR 42 != column1_max@1 + + +## Read the files as though they are ordered + +statement ok +CREATE EXTERNAL TABLE parquet_table_with_order(column1 int) +STORED AS PARQUET +LOCATION 'test_files/scratch/repartition_scan/parquet_table' +WITH ORDER (column1 ASC); + +# output should be ordered +query I +SELECT column1 FROM parquet_table_with_order WHERE column1 <> 42 ORDER BY column1; +---- +1 +2 +3 +4 +5 +100 +200 + +# explain should not have any groups with more than one file +# https://github.com/apache/arrow-datafusion/issues/8451 +query TT +EXPLAIN SELECT column1 FROM parquet_table_with_order WHERE column1 <> 42 ORDER BY column1; +---- +logical_plan +Sort: parquet_table_with_order.column1 ASC NULLS LAST +--Filter: parquet_table_with_order.column1 != Int32(42) +----TableScan: parquet_table_with_order projection=[column1], partial_filters=[parquet_table_with_order.column1 != Int32(42)] +physical_plan +SortPreservingMergeExec: [column1@0 ASC NULLS LAST] +--CoalesceBatchesExec: target_batch_size=8192 +----FilterExec: column1@0 != 42 +------ParquetExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/1.parquet:0..197], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:0..201], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:201..403], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/1.parquet:197..394]]}, projection=[column1], output_ordering=[column1@0 ASC NULLS LAST], predicate=column1@0 != 42, pruning_predicate=column1_min@0 != 42 OR 42 != column1_max@1 + +# Cleanup +statement ok +DROP TABLE parquet_table; + +statement ok +DROP TABLE parquet_table_with_order; + + +################### +### CSV tests +################### + +# Since parquet and CSV share most of the same implementation, this test checks +# that the basics are connected properly + +# create a single csv file +statement ok +COPY (VALUES (1), (2), (3), (4), (5)) TO 'test_files/scratch/repartition_scan/csv_table/1.csv' +(FORMAT csv, SINGLE_FILE_OUTPUT true, HEADER true); + +statement ok +CREATE EXTERNAL TABLE csv_table(column1 int) +STORED AS csv +WITH HEADER ROW +LOCATION 'test_files/scratch/repartition_scan/csv_table/'; + +query I +select * from csv_table; +---- +1 
+2 +3 +4 +5 + +## Expect to see the scan read the file as "4" groups with even sizes (offsets) +query TT +EXPLAIN SELECT column1 FROM csv_table WHERE column1 <> 42; +---- +logical_plan +Filter: csv_table.column1 != Int32(42) +--TableScan: csv_table projection=[column1], partial_filters=[csv_table.column1 != Int32(42)] +physical_plan +CoalesceBatchesExec: target_batch_size=8192 +--FilterExec: column1@0 != 42 +----CsvExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/csv_table/1.csv:0..5], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/csv_table/1.csv:5..10], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/csv_table/1.csv:10..15], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/csv_table/1.csv:15..18]]}, projection=[column1], has_header=true + +# Cleanup +statement ok +DROP TABLE csv_table; + + +################### +### JSON tests +################### + +# Since parquet and json share most of the same implementation, this test checks +# that the basics are connected properly + +# create a single json file +statement ok +COPY (VALUES (1), (2), (3), (4), (5)) TO 'test_files/scratch/repartition_scan/json_table/1.json' +(FORMAT json, SINGLE_FILE_OUTPUT true); + +statement ok +CREATE EXTERNAL TABLE json_table (column1 int) +STORED AS json +LOCATION 'test_files/scratch/repartition_scan/json_table/'; + +query I +select * from "json_table"; +---- +1 +2 +3 +4 +5 + +## Expect to see the scan read the file as "4" groups with even sizes (offsets) +query TT +EXPLAIN SELECT column1 FROM "json_table" WHERE column1 <> 42; +---- +logical_plan +Filter: json_table.column1 != Int32(42) +--TableScan: json_table projection=[column1], partial_filters=[json_table.column1 != Int32(42)] +physical_plan +CoalesceBatchesExec: target_batch_size=8192 +--FilterExec: column1@0 != 42 +----JsonExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/json_table/1.json:0..18], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/json_table/1.json:18..36], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/json_table/1.json:36..54], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/json_table/1.json:54..70]]}, projection=[column1] + +# Cleanup +statement ok +DROP TABLE json_table; + + +################### +### Arrow File tests +################### + +## Use pre-existing files we don't have a way to create arrow files yet +## (https://github.com/apache/arrow-datafusion/issues/8504) +statement ok +CREATE EXTERNAL TABLE arrow_table +STORED AS ARROW +LOCATION '../core/tests/data/example.arrow'; + + +# It would be great to see the file read as "4" groups with even sizes (offsets) eventually +# https://github.com/apache/arrow-datafusion/issues/8503 +query TT +EXPLAIN SELECT * FROM arrow_table +---- +logical_plan TableScan: arrow_table projection=[f0, f1, f2] +physical_plan ArrowExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/example.arrow]]}, projection=[f0, f1, f2] + +# Cleanup +statement ok +DROP TABLE arrow_table; + +################### +### Avro File tests +################### + +## Use pre-existing files we don't have a way to create avro files yet + +statement ok +CREATE EXTERNAL TABLE avro_table +STORED AS AVRO +WITH HEADER ROW +LOCATION '../../testing/data/avro/simple_enum.avro' + + +# It would be great to see the file read as "4" groups with 
even sizes (offsets) eventually +query TT +EXPLAIN SELECT * FROM avro_table +---- +logical_plan TableScan: avro_table projection=[f1, f2, f3] +physical_plan AvroExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/avro/simple_enum.avro]]}, projection=[f1, f2, f3] + +# Cleanup +statement ok +DROP TABLE avro_table; diff --git a/datafusion/sqllogictest/test_files/scalar.slt b/datafusion/sqllogictest/test_files/scalar.slt index ecb7fe13fcf4c..9b30699e3fa3e 100644 --- a/datafusion/sqllogictest/test_files/scalar.slt +++ b/datafusion/sqllogictest/test_files/scalar.slt @@ -1926,3 +1926,30 @@ A true B false C false D false + +# test string_temporal_coercion +query BBBBBBBBBB +select + arrow_cast(to_timestamp('2020-01-01 01:01:11.1234567890Z'), 'Timestamp(Second, None)') == '2020-01-01T01:01:11', + arrow_cast(to_timestamp('2020-01-02 01:01:11.1234567890Z'), 'Timestamp(Second, None)') == arrow_cast('2020-01-02T01:01:11', 'LargeUtf8'), + arrow_cast(to_timestamp('2020-01-03 01:01:11.1234567890Z'), 'Time32(Second)') == '01:01:11', + arrow_cast(to_timestamp('2020-01-04 01:01:11.1234567890Z'), 'Time32(Second)') == arrow_cast('01:01:11', 'LargeUtf8'), + arrow_cast(to_timestamp('2020-01-05 01:01:11.1234567890Z'), 'Time64(Microsecond)') == '01:01:11.123456', + arrow_cast(to_timestamp('2020-01-06 01:01:11.1234567890Z'), 'Time64(Microsecond)') == arrow_cast('01:01:11.123456', 'LargeUtf8'), + arrow_cast('2020-01-07', 'Date32') == '2020-01-07', + arrow_cast('2020-01-08', 'Date64') == '2020-01-08', + arrow_cast('2020-01-09', 'Date32') == arrow_cast('2020-01-09', 'LargeUtf8'), + arrow_cast('2020-01-10', 'Date64') == arrow_cast('2020-01-10', 'LargeUtf8') +; +---- +true true true true true true true true true true + +query I +SELECT ALL - CASE WHEN NOT - AVG ( - 41 ) IS NULL THEN 47 WHEN NULL IS NULL THEN COUNT ( * ) END + 93 + - - 44 * 91 + CASE + 44 WHEN - - 21 * 69 - 12 THEN 58 ELSE - 3 END * + + 23 * + 84 * - - 59 +---- +-337914 + +query T +SELECT CASE 3 WHEN 1+2 THEN 'first' WHEN 1+1+1 THEN 'second' END +---- +first diff --git a/datafusion/sqllogictest/test_files/schema_evolution.slt b/datafusion/sqllogictest/test_files/schema_evolution.slt new file mode 100644 index 0000000000000..36d54159e24d2 --- /dev/null +++ b/datafusion/sqllogictest/test_files/schema_evolution.slt @@ -0,0 +1,140 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+
+##########
+# Tests for schema evolution -- reading
+# data from different files with different schemas
+##########
+
+
+statement ok
+CREATE EXTERNAL TABLE parquet_table(a varchar, b int, c float) STORED AS PARQUET
+LOCATION 'test_files/scratch/schema_evolution/parquet_table/';
+
+# File1 has only columns a and b
+statement ok
+COPY (
+  SELECT column1 as a, column2 as b
+  FROM ( VALUES ('foo', 1), ('foo', 2), ('foo', 3) )
+ ) TO 'test_files/scratch/schema_evolution/parquet_table/1.parquet'
+(FORMAT PARQUET, SINGLE_FILE_OUTPUT true);
+
+
+# File2 has only b
+statement ok
+COPY (
+  SELECT column1 as b
+  FROM ( VALUES (10) )
+ ) TO 'test_files/scratch/schema_evolution/parquet_table/2.parquet'
+(FORMAT PARQUET, SINGLE_FILE_OUTPUT true);
+
+# File3 has a column 'z' which does not appear in the table
+# but also values for a, which does appear in the table
+statement ok
+COPY (
+  SELECT column1 as z, column2 as a
+  FROM ( VALUES ('bar', 'foo'), ('blarg', 'foo') )
+ ) TO 'test_files/scratch/schema_evolution/parquet_table/3.parquet'
+(FORMAT PARQUET, SINGLE_FILE_OUTPUT true);
+
+# File4 has data for b and a (reversed) and c
+statement ok
+COPY (
+  SELECT column1 as b, column2 as a, column3 as c
+  FROM ( VALUES (100, 'foo', 10.5), (200, 'foo', 12.6), (300, 'bzz', 13.7) )
+ ) TO 'test_files/scratch/schema_evolution/parquet_table/4.parquet'
+(FORMAT PARQUET, SINGLE_FILE_OUTPUT true);
+
+# The logical distribution of `a`, `b` and `c` in the files is like this:
+#
+## File1:
+# foo 1 NULL
+# foo 2 NULL
+# foo 3 NULL
+#
+## File2:
+# NULL 10 NULL
+#
+## File3:
+# foo NULL NULL
+# foo NULL NULL
+#
+## File4:
+# foo 100 10.5
+# foo 200 12.6
+# bzz 300 13.7
+
+# Show all the data
+query TIR rowsort
+select * from parquet_table;
+----
+NULL 10 NULL
+bzz 300 13.7
+foo 1 NULL
+foo 100 10.5
+foo 2 NULL
+foo 200 12.6
+foo 3 NULL
+foo NULL NULL
+foo NULL NULL
+
+# Should see all 7 rows that have 'a=foo'
+query TIR rowsort
+select * from parquet_table where a = 'foo';
+----
+foo 1 NULL
+foo 100 10.5
+foo 2 NULL
+foo 200 12.6
+foo 3 NULL
+foo NULL NULL
+foo NULL NULL
+
+query TIR rowsort
+select * from parquet_table where a != 'foo';
+----
+bzz 300 13.7
+
+# this should produce at least one row
+query TIR rowsort
+select * from parquet_table where a is NULL;
+----
+NULL 10 NULL
+
+query TIR rowsort
+select * from parquet_table where b > 5;
+----
+NULL 10 NULL
+bzz 300 13.7
+foo 100 10.5
+foo 200 12.6
+
+
+query TIR rowsort
+select * from parquet_table where b < 150;
+----
+NULL 10 NULL
+foo 1 NULL
+foo 100 10.5
+foo 2 NULL
+foo 3 NULL
+
+query TIR rowsort
+select * from parquet_table where c > 11.0;
+----
+bzz 300 13.7
+foo 200 12.6
diff --git a/datafusion/sqllogictest/test_files/select.slt b/datafusion/sqllogictest/test_files/select.slt
index 98ea061c731bf..ea570b99d4dd1 100644
--- a/datafusion/sqllogictest/test_files/select.slt
+++ b/datafusion/sqllogictest/test_files/select.slt
@@ -868,6 +868,21 @@ statement error DataFusion error: Error during planning: EXCLUDE or EXCEPT conta
SELECT * EXCLUDE(d, b, c, a, a, b, c, d)
FROM table1

+# avoid adding an alias if the column name is the same
+query TT
+EXPLAIN select a as a FROM table1 order by a
+----
+logical_plan
+Sort: table1.a ASC NULLS LAST
+--TableScan: table1 projection=[a]
+physical_plan
+SortExec: expr=[a@0 ASC NULLS LAST]
+--MemoryExec: partitions=1, partition_sizes=[1]
+
+# ambiguous column reference over a join
+query error DataFusion error: Schema error: Ambiguous reference to unqualified field a
+EXPLAIN select a as a FROM table1 t1 CROSS JOIN
table1 t2 order by a
+
# run below query in multi partitions
statement ok
set datafusion.execution.target_partitions = 2;
@@ -1013,8 +1028,79 @@ SortPreservingMergeExec: [c@3 ASC NULLS LAST]
--------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1
----------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a0, a, b, c, d], output_ordering=[a@1 ASC NULLS LAST, b@2 ASC NULLS LAST, c@3 ASC NULLS LAST], has_header=true

+# When ordering is lost during projection, we shouldn't keep the SortExec
+# in the final physical plan.
+query TT
+EXPLAIN SELECT c2, COUNT(*)
+FROM (SELECT c2
+FROM aggregate_test_100
+ORDER BY c1, c2)
+GROUP BY c2;
+----
+logical_plan
+Aggregate: groupBy=[[aggregate_test_100.c2]], aggr=[[COUNT(UInt8(1)) AS COUNT(*)]]
+--Projection: aggregate_test_100.c2
+----Sort: aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c2 ASC NULLS LAST
+------Projection: aggregate_test_100.c2, aggregate_test_100.c1
+--------TableScan: aggregate_test_100 projection=[c1, c2]
+physical_plan
+AggregateExec: mode=FinalPartitioned, gby=[c2@0 as c2], aggr=[COUNT(*)]
+--CoalesceBatchesExec: target_batch_size=8192
+----RepartitionExec: partitioning=Hash([c2@0], 2), input_partitions=2
+------AggregateExec: mode=Partial, gby=[c2@0 as c2], aggr=[COUNT(*)]
+--------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1
+----------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c2], has_header=true
+
statement ok
drop table annotated_data_finite2;

statement ok
drop table t;
+
+statement ok
+create table t(x bigint, y bigint) as values (1,2), (1,3);
+
+query II
+select z+1, y from (select x+1 as z, y from t) where y > 1;
+----
+3 2
+3 3
+
+query TT
+EXPLAIN SELECT x/2, x/2+1 FROM t;
+----
+logical_plan
+Projection: t.x / Int64(2)Int64(2)t.x AS t.x / Int64(2), t.x / Int64(2)Int64(2)t.x AS t.x / Int64(2) + Int64(1)
+--Projection: t.x / Int64(2) AS t.x / Int64(2)Int64(2)t.x
+----TableScan: t projection=[x]
+physical_plan
+ProjectionExec: expr=[t.x / Int64(2)Int64(2)t.x@0 as t.x / Int64(2), t.x / Int64(2)Int64(2)t.x@0 + 1 as t.x / Int64(2) + Int64(1)]
+--ProjectionExec: expr=[x@0 / 2 as t.x / Int64(2)Int64(2)t.x]
+----MemoryExec: partitions=1, partition_sizes=[1]
+
+query II
+SELECT x/2, x/2+1 FROM t;
+----
+0 1
+0 1
+
+query TT
+EXPLAIN SELECT abs(x), abs(x) + abs(y) FROM t;
+----
+logical_plan
+Projection: abs(t.x)t.x AS abs(t.x), abs(t.x)t.x AS abs(t.x) + abs(t.y)
+--Projection: abs(t.x) AS abs(t.x)t.x, t.y
+----TableScan: t projection=[x, y]
+physical_plan
+ProjectionExec: expr=[abs(t.x)t.x@0 as abs(t.x), abs(t.x)t.x@0 + abs(y@1) as abs(t.x) + abs(t.y)]
+--ProjectionExec: expr=[abs(x@0) as abs(t.x)t.x, y@1 as y]
+----MemoryExec: partitions=1, partition_sizes=[1]
+
+query II
+SELECT abs(x), abs(x) + abs(y) FROM t;
+----
+1 3
+1 4
+
+statement ok
+DROP TABLE t;
diff --git a/datafusion/sqllogictest/test_files/set_variable.slt b/datafusion/sqllogictest/test_files/set_variable.slt
index 714e1e995e262..440fb2c6ef2b0 100644
--- a/datafusion/sqllogictest/test_files/set_variable.slt
+++ b/datafusion/sqllogictest/test_files/set_variable.slt
@@ -243,4 +243,4 @@ statement ok
SET TIME ZONE = 'Asia/Taipei2'

statement error Arrow error: Parser error: Invalid timezone "Asia/Taipei2": 'Asia/Taipei2' is not a valid timezone
-SELECT '2000-01-01T00:00:00'::TIMESTAMP::TIMESTAMPTZ
\ No newline at end of file
+SELECT '2000-01-01T00:00:00'::TIMESTAMP::TIMESTAMPTZ
diff --git
a/datafusion/sqllogictest/test_files/struct.slt b/datafusion/sqllogictest/test_files/struct.slt index fc14798a3bfed..936dedcc896ec 100644 --- a/datafusion/sqllogictest/test_files/struct.slt +++ b/datafusion/sqllogictest/test_files/struct.slt @@ -58,5 +58,16 @@ select struct(a, b, c) from values; {c0: 2, c1: 2.2, c2: b} {c0: 3, c1: 3.3, c2: c} +# explain struct scalar function with columns #1 +query TT +explain select struct(a, b, c) from values; +---- +logical_plan +Projection: struct(values.a, values.b, values.c) +--TableScan: values projection=[a, b, c] +physical_plan +ProjectionExec: expr=[struct(a@0, b@1, c@2) as struct(values.a,values.b,values.c)] +--MemoryExec: partitions=1, partition_sizes=[1] + statement ok drop table values; diff --git a/datafusion/sqllogictest/test_files/subquery.slt b/datafusion/sqllogictest/test_files/subquery.slt index 822a70bb5badf..3e0fcb7aa96eb 100644 --- a/datafusion/sqllogictest/test_files/subquery.slt +++ b/datafusion/sqllogictest/test_files/subquery.slt @@ -49,6 +49,13 @@ CREATE TABLE t2(t2_id INT, t2_name TEXT, t2_int INT) AS VALUES (44, 'x', 3), (55, 'w', 3); +statement ok +CREATE TABLE t3(t3_id INT PRIMARY KEY, t3_name TEXT, t3_int INT) AS VALUES +(11, 'e', 3), +(22, 'f', 1), +(44, 'g', 3), +(55, 'h', 3); + statement ok CREATE EXTERNAL TABLE IF NOT EXISTS customer ( c_custkey BIGINT, @@ -180,19 +187,18 @@ Projection: t1.t1_id, __scalar_sq_1.SUM(t2.t2_int) AS t2_sum --------Aggregate: groupBy=[[t2.t2_id]], aggr=[[SUM(CAST(t2.t2_int AS Int64))]] ----------TableScan: t2 projection=[t2_id, t2_int] physical_plan -ProjectionExec: expr=[t1_id@0 as t1_id, SUM(t2.t2_int)@1 as t2_sum] ---ProjectionExec: expr=[t1_id@2 as t1_id, SUM(t2.t2_int)@0 as SUM(t2.t2_int), t2_id@1 as t2_id] -----CoalesceBatchesExec: target_batch_size=2 -------HashJoinExec: mode=Partitioned, join_type=Right, on=[(t2_id@1, t1_id@0)] ---------ProjectionExec: expr=[SUM(t2.t2_int)@1 as SUM(t2.t2_int), t2_id@0 as t2_id] -----------AggregateExec: mode=FinalPartitioned, gby=[t2_id@0 as t2_id], aggr=[SUM(t2.t2_int)] -------------CoalesceBatchesExec: target_batch_size=2 ---------------RepartitionExec: partitioning=Hash([t2_id@0], 4), input_partitions=4 -----------------AggregateExec: mode=Partial, gby=[t2_id@0 as t2_id], aggr=[SUM(t2.t2_int)] -------------------MemoryExec: partitions=4, partition_sizes=[1, 0, 0, 0] ---------CoalesceBatchesExec: target_batch_size=2 -----------RepartitionExec: partitioning=Hash([t1_id@0], 4), input_partitions=4 -------------MemoryExec: partitions=4, partition_sizes=[1, 0, 0, 0] +ProjectionExec: expr=[t1_id@2 as t1_id, SUM(t2.t2_int)@0 as t2_sum] +--CoalesceBatchesExec: target_batch_size=2 +----HashJoinExec: mode=Partitioned, join_type=Right, on=[(t2_id@1, t1_id@0)] +------ProjectionExec: expr=[SUM(t2.t2_int)@1 as SUM(t2.t2_int), t2_id@0 as t2_id] +--------AggregateExec: mode=FinalPartitioned, gby=[t2_id@0 as t2_id], aggr=[SUM(t2.t2_int)] +----------CoalesceBatchesExec: target_batch_size=2 +------------RepartitionExec: partitioning=Hash([t2_id@0], 4), input_partitions=4 +--------------AggregateExec: mode=Partial, gby=[t2_id@0 as t2_id], aggr=[SUM(t2.t2_int)] +----------------MemoryExec: partitions=4, partition_sizes=[1, 0, 0, 0] +------CoalesceBatchesExec: target_batch_size=2 +--------RepartitionExec: partitioning=Hash([t1_id@0], 4), input_partitions=4 +----------MemoryExec: partitions=4, partition_sizes=[1, 0, 0, 0] query II rowsort SELECT t1_id, (SELECT sum(t2_int) FROM t2 WHERE t2.t2_id = t1.t1_id) as t2_sum from t1 @@ -215,19 +221,18 @@ Projection: t1.t1_id, 
__scalar_sq_1.SUM(t2.t2_int * Float64(1)) + Int64(1) AS t2 --------Aggregate: groupBy=[[t2.t2_id]], aggr=[[SUM(CAST(t2.t2_int AS Float64)) AS SUM(t2.t2_int * Float64(1))]] ----------TableScan: t2 projection=[t2_id, t2_int] physical_plan -ProjectionExec: expr=[t1_id@0 as t1_id, SUM(t2.t2_int * Float64(1)) + Int64(1)@1 as t2_sum] ---ProjectionExec: expr=[t1_id@2 as t1_id, SUM(t2.t2_int * Float64(1)) + Int64(1)@0 as SUM(t2.t2_int * Float64(1)) + Int64(1), t2_id@1 as t2_id] -----CoalesceBatchesExec: target_batch_size=2 -------HashJoinExec: mode=Partitioned, join_type=Right, on=[(t2_id@1, t1_id@0)] ---------ProjectionExec: expr=[SUM(t2.t2_int * Float64(1))@1 + 1 as SUM(t2.t2_int * Float64(1)) + Int64(1), t2_id@0 as t2_id] -----------AggregateExec: mode=FinalPartitioned, gby=[t2_id@0 as t2_id], aggr=[SUM(t2.t2_int * Float64(1))] -------------CoalesceBatchesExec: target_batch_size=2 ---------------RepartitionExec: partitioning=Hash([t2_id@0], 4), input_partitions=4 -----------------AggregateExec: mode=Partial, gby=[t2_id@0 as t2_id], aggr=[SUM(t2.t2_int * Float64(1))] -------------------MemoryExec: partitions=4, partition_sizes=[1, 0, 0, 0] ---------CoalesceBatchesExec: target_batch_size=2 -----------RepartitionExec: partitioning=Hash([t1_id@0], 4), input_partitions=4 -------------MemoryExec: partitions=4, partition_sizes=[1, 0, 0, 0] +ProjectionExec: expr=[t1_id@2 as t1_id, SUM(t2.t2_int * Float64(1)) + Int64(1)@0 as t2_sum] +--CoalesceBatchesExec: target_batch_size=2 +----HashJoinExec: mode=Partitioned, join_type=Right, on=[(t2_id@1, t1_id@0)] +------ProjectionExec: expr=[SUM(t2.t2_int * Float64(1))@1 + 1 as SUM(t2.t2_int * Float64(1)) + Int64(1), t2_id@0 as t2_id] +--------AggregateExec: mode=FinalPartitioned, gby=[t2_id@0 as t2_id], aggr=[SUM(t2.t2_int * Float64(1))] +----------CoalesceBatchesExec: target_batch_size=2 +------------RepartitionExec: partitioning=Hash([t2_id@0], 4), input_partitions=4 +--------------AggregateExec: mode=Partial, gby=[t2_id@0 as t2_id], aggr=[SUM(t2.t2_int * Float64(1))] +----------------MemoryExec: partitions=4, partition_sizes=[1, 0, 0, 0] +------CoalesceBatchesExec: target_batch_size=2 +--------RepartitionExec: partitioning=Hash([t1_id@0], 4), input_partitions=4 +----------MemoryExec: partitions=4, partition_sizes=[1, 0, 0, 0] query IR rowsort SELECT t1_id, (SELECT sum(t2_int * 1.0) + 1 FROM t2 WHERE t2.t2_id = t1.t1_id) as t2_sum from t1 @@ -287,21 +292,20 @@ Projection: t1.t1_id, __scalar_sq_1.SUM(t2.t2_int) AS t2_sum ----------Aggregate: groupBy=[[t2.t2_id]], aggr=[[SUM(CAST(t2.t2_int AS Int64))]] ------------TableScan: t2 projection=[t2_id, t2_int] physical_plan -ProjectionExec: expr=[t1_id@0 as t1_id, SUM(t2.t2_int)@1 as t2_sum] ---ProjectionExec: expr=[t1_id@2 as t1_id, SUM(t2.t2_int)@0 as SUM(t2.t2_int), t2_id@1 as t2_id] -----CoalesceBatchesExec: target_batch_size=2 -------HashJoinExec: mode=Partitioned, join_type=Right, on=[(t2_id@1, t1_id@0)] ---------ProjectionExec: expr=[SUM(t2.t2_int)@1 as SUM(t2.t2_int), t2_id@0 as t2_id] -----------CoalesceBatchesExec: target_batch_size=2 -------------FilterExec: SUM(t2.t2_int)@1 < 3 ---------------AggregateExec: mode=FinalPartitioned, gby=[t2_id@0 as t2_id], aggr=[SUM(t2.t2_int)] -----------------CoalesceBatchesExec: target_batch_size=2 -------------------RepartitionExec: partitioning=Hash([t2_id@0], 4), input_partitions=4 ---------------------AggregateExec: mode=Partial, gby=[t2_id@0 as t2_id], aggr=[SUM(t2.t2_int)] -----------------------MemoryExec: partitions=4, partition_sizes=[1, 0, 0, 0] +ProjectionExec: 
expr=[t1_id@2 as t1_id, SUM(t2.t2_int)@0 as t2_sum] +--CoalesceBatchesExec: target_batch_size=2 +----HashJoinExec: mode=Partitioned, join_type=Right, on=[(t2_id@1, t1_id@0)] +------ProjectionExec: expr=[SUM(t2.t2_int)@1 as SUM(t2.t2_int), t2_id@0 as t2_id] --------CoalesceBatchesExec: target_batch_size=2 -----------RepartitionExec: partitioning=Hash([t1_id@0], 4), input_partitions=4 -------------MemoryExec: partitions=4, partition_sizes=[1, 0, 0, 0] +----------FilterExec: SUM(t2.t2_int)@1 < 3 +------------AggregateExec: mode=FinalPartitioned, gby=[t2_id@0 as t2_id], aggr=[SUM(t2.t2_int)] +--------------CoalesceBatchesExec: target_batch_size=2 +----------------RepartitionExec: partitioning=Hash([t2_id@0], 4), input_partitions=4 +------------------AggregateExec: mode=Partial, gby=[t2_id@0 as t2_id], aggr=[SUM(t2.t2_int)] +--------------------MemoryExec: partitions=4, partition_sizes=[1, 0, 0, 0] +------CoalesceBatchesExec: target_batch_size=2 +--------RepartitionExec: partitioning=Hash([t1_id@0], 4), input_partitions=4 +----------MemoryExec: partitions=4, partition_sizes=[1, 0, 0, 0] query II rowsort SELECT t1_id, (SELECT sum(t2_int) FROM t2 WHERE t2.t2_id = t1.t1_id having sum(t2_int) < 3) as t2_sum from t1 @@ -422,6 +426,17 @@ SELECT t1_id, t1_name, t1_int FROM t1 order by t1_int in (SELECT t2_int FROM t2 statement error DataFusion error: check_analyzed_plan\ncaused by\nError during planning: Correlated scalar subquery must be aggregated to return at most one row SELECT t1_id, (SELECT t2_int FROM t2 WHERE t2.t2_int = t1.t1_int) as t2_int from t1 +#non_aggregated_correlated_scalar_subquery_unique +query II rowsort +SELECT t1_id, (SELECT t3_int FROM t3 WHERE t3.t3_id = t1.t1_id) as t3_int from t1 +---- +11 3 +22 1 +33 NULL +44 3 + + +#non_aggregated_correlated_scalar_subquery statement error DataFusion error: check_analyzed_plan\ncaused by\nError during planning: Correlated scalar subquery must be aggregated to return at most one row SELECT t1_id, (SELECT t2_int FROM t2 WHERE t2.t2_int = t1_int group by t2_int) as t2_int from t1 @@ -440,7 +455,7 @@ Projection: t1.t1_id, () AS t2_int ------Projection: t2.t2_int --------Filter: t2.t2_int = outer_ref(t1.t1_int) ----------TableScan: t2 ---TableScan: t1 projection=[t1_id] +--TableScan: t1 projection=[t1_id, t1_int] query TT explain SELECT t1_id from t1 where t1_int = (SELECT t2_int FROM t2 WHERE t2.t2_int = t1.t1_int limit 1) @@ -487,27 +502,29 @@ query TT explain SELECT t1_id, t1_name FROM t1 WHERE EXISTS (SELECT sum(t1.t1_int + t2.t2_id) FROM t2 WHERE t1.t1_name = t2.t2_name) ---- logical_plan -Filter: EXISTS () ---Subquery: -----Projection: SUM(outer_ref(t1.t1_int) + t2.t2_id) -------Aggregate: groupBy=[[]], aggr=[[SUM(CAST(outer_ref(t1.t1_int) + t2.t2_id AS Int64))]] ---------Filter: outer_ref(t1.t1_name) = t2.t2_name -----------TableScan: t2 ---TableScan: t1 projection=[t1_id, t1_name] +Projection: t1.t1_id, t1.t1_name +--Filter: EXISTS () +----Subquery: +------Projection: SUM(outer_ref(t1.t1_int) + t2.t2_id) +--------Aggregate: groupBy=[[]], aggr=[[SUM(CAST(outer_ref(t1.t1_int) + t2.t2_id AS Int64))]] +----------Filter: outer_ref(t1.t1_name) = t2.t2_name +------------TableScan: t2 +----TableScan: t1 projection=[t1_id, t1_name, t1_int] #support_agg_correlated_columns2 query TT explain SELECT t1_id, t1_name FROM t1 WHERE EXISTS (SELECT count(*) FROM t2 WHERE t1.t1_name = t2.t2_name having sum(t1_int + t2_id) >0) ---- logical_plan -Filter: EXISTS () ---Subquery: -----Projection: COUNT(*) -------Filter: SUM(outer_ref(t1.t1_int) + t2.t2_id) > 
Int64(0) ---------Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1)) AS COUNT(*), SUM(CAST(outer_ref(t1.t1_int) + t2.t2_id AS Int64))]] -----------Filter: outer_ref(t1.t1_name) = t2.t2_name -------------TableScan: t2 ---TableScan: t1 projection=[t1_id, t1_name] +Projection: t1.t1_id, t1.t1_name +--Filter: EXISTS () +----Subquery: +------Projection: COUNT(*) +--------Filter: SUM(outer_ref(t1.t1_int) + t2.t2_id) > Int64(0) +----------Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1)) AS COUNT(*), SUM(CAST(outer_ref(t1.t1_int) + t2.t2_id AS Int64))]] +------------Filter: outer_ref(t1.t1_name) = t2.t2_name +--------------TableScan: t2 +----TableScan: t1 projection=[t1_id, t1_name, t1_int] #support_join_correlated_columns query TT @@ -991,3 +1008,55 @@ SELECT * FROM ON (severity.cron_job_name = jobs.cron_job_name); ---- catan-prod1-daily success catan-prod1-daily high + +##correlated_scalar_subquery_sum_agg_bug +#query TT +#explain +#select t1.t1_int from t1 where +# (select sum(t2_int) is null from t2 where t1.t1_id = t2.t2_id) +#---- +#logical_plan +#Projection: t1.t1_int +#--Inner Join: t1.t1_id = __scalar_sq_1.t2_id +#----TableScan: t1 projection=[t1_id, t1_int] +#----SubqueryAlias: __scalar_sq_1 +#------Projection: t2.t2_id +#--------Filter: SUM(t2.t2_int) IS NULL +#----------Aggregate: groupBy=[[t2.t2_id]], aggr=[[SUM(t2.t2_int)]] +#------------TableScan: t2 projection=[t2_id, t2_int] + +#query I rowsort +#select t1.t1_int from t1 where +# (select sum(t2_int) is null from t2 where t1.t1_id = t2.t2_id) +#---- +#2 +#3 +#4 + +statement ok +create table t(a bigint); + +# Result of query below shouldn't depend on +# number of optimization passes +# See issue: https://github.com/apache/arrow-datafusion/issues/8296 +statement ok +set datafusion.optimizer.max_passes = 1; + +query TT +explain select a/2, a/2 + 1 from t +---- +logical_plan +Projection: t.a / Int64(2)Int64(2)t.a AS t.a / Int64(2), t.a / Int64(2)Int64(2)t.a AS t.a / Int64(2) + Int64(1) +--Projection: t.a / Int64(2) AS t.a / Int64(2)Int64(2)t.a +----TableScan: t projection=[a] + +statement ok +set datafusion.optimizer.max_passes = 3; + +query TT +explain select a/2, a/2 + 1 from t +---- +logical_plan +Projection: t.a / Int64(2)Int64(2)t.a AS t.a / Int64(2), t.a / Int64(2)Int64(2)t.a AS t.a / Int64(2) + Int64(1) +--Projection: t.a / Int64(2) AS t.a / Int64(2)Int64(2)t.a +----TableScan: t projection=[a] diff --git a/datafusion/sqllogictest/test_files/timestamps.slt b/datafusion/sqllogictest/test_files/timestamps.slt index e186aa12f7a95..c84e46c965fac 100644 --- a/datafusion/sqllogictest/test_files/timestamps.slt +++ b/datafusion/sqllogictest/test_files/timestamps.slt @@ -46,6 +46,30 @@ statement ok create table ts_data_secs as select arrow_cast(ts / 1000000000, 'Timestamp(Second, None)') as ts, value from ts_data; +########## +## Current date Tests +########## + +query B +select cast(now() as date) = current_date(); +---- +true + +query B +select now() = current_date(); +---- +false + +query B +select current_date() = today(); +---- +true + +query B +select cast(now() as date) = today(); +---- +true + ########## ## Timestamp Handling Tests @@ -291,6 +315,35 @@ SELECT COUNT(*) FROM ts_data_secs where ts > to_timestamp_seconds('2020-09-08T12 ---- 2 + +# to_timestamp float inputs + +query PPP +SELECT to_timestamp(1.1) as c1, cast(1.1 as timestamp) as c2, 1.1::timestamp as c3; +---- +1970-01-01T00:00:01.100 1970-01-01T00:00:01.100 1970-01-01T00:00:01.100 + +query PPP +SELECT to_timestamp(-1.1) as c1, cast(-1.1 as timestamp) as c2, 
(-1.1)::timestamp as c3;
+----
+1969-12-31T23:59:58.900 1969-12-31T23:59:58.900 1969-12-31T23:59:58.900
+
+query PPP
+SELECT to_timestamp(0.0) as c1, cast(0.0 as timestamp) as c2, 0.0::timestamp as c3;
+----
+1970-01-01T00:00:00 1970-01-01T00:00:00 1970-01-01T00:00:00
+
+query PPP
+SELECT to_timestamp(1.23456789) as c1, cast(1.23456789 as timestamp) as c2, 1.23456789::timestamp as c3;
+----
+1970-01-01T00:00:01.234567890 1970-01-01T00:00:01.234567890 1970-01-01T00:00:01.234567890
+
+query PPP
+SELECT to_timestamp(123456789.123456789) as c1, cast(123456789.123456789 as timestamp) as c2, 123456789.123456789::timestamp as c3;
+----
+1973-11-29T21:33:09.123456784 1973-11-29T21:33:09.123456784 1973-11-29T21:33:09.123456784
+
+
# from_unixtime

# 1599566400 is '2020-09-08T12:00:00+00:00'
@@ -1677,14 +1730,11 @@ SELECT TIMESTAMPTZ '2022-01-01 01:10:00 AEST'

query P rowsort
SELECT TIMESTAMPTZ '2022-01-01 01:10:00 Australia/Sydney' as ts_geo
  UNION ALL
-SELECT TIMESTAMPTZ '2022-01-01 01:10:00 Antarctica/Vostok' as ts_geo
-  UNION ALL
SELECT TIMESTAMPTZ '2022-01-01 01:10:00 Africa/Johannesburg' as ts_geo
  UNION ALL
SELECT TIMESTAMPTZ '2022-01-01 01:10:00 America/Los_Angeles' as ts_geo
----
2021-12-31T14:10:00Z
-2021-12-31T19:10:00Z
2021-12-31T23:10:00Z
2022-01-01T09:10:00Z
@@ -1788,8 +1838,59 @@ SELECT TIMESTAMPTZ '2020-01-01 00:00:00Z' = TIMESTAMP '2020-01-01'
----
true

-# verify to_timestamp edge cases to be in sync with postgresql
-query PPPPP
-SELECT to_timestamp(null), to_timestamp(-62125747200), to_timestamp(0), to_timestamp(1926632005177), to_timestamp(1926632005)
+# verify timestamp cast with integer input
+query PPPPPP
+SELECT to_timestamp(null), to_timestamp(0), to_timestamp(1926632005), to_timestamp(1), to_timestamp(-1), to_timestamp(0-1)
+----
+NULL 1970-01-01T00:00:00 2031-01-19T23:33:25 1970-01-01T00:00:01 1969-12-31T23:59:59 1969-12-31T23:59:59
+
+# verify timestamp syntax styles are consistent
+query BBBBBBBBBBBBB
+SELECT to_timestamp(null) is null as c1,
+       null::timestamp is null as c2,
+       cast(null as timestamp) is null as c3,
+       to_timestamp(0) = 0::timestamp as c4,
+       to_timestamp(1926632005) = 1926632005::timestamp as c5,
+       to_timestamp(1) = 1::timestamp as c6,
+       to_timestamp(-1) = -1::timestamp as c7,
+       to_timestamp(0-1) = (0-1)::timestamp as c8,
+       to_timestamp(0) = cast(0 as timestamp) as c9,
+       to_timestamp(1926632005) = cast(1926632005 as timestamp) as c10,
+       to_timestamp(1) = cast(1 as timestamp) as c11,
+       to_timestamp(-1) = cast(-1 as timestamp) as c12,
+       to_timestamp(0-1) = cast(0-1 as timestamp) as c13
+----
+true true true true true true true true true true true true true
+
+# verify timestamp output types
+query TTT
+SELECT arrow_typeof(to_timestamp(1)), arrow_typeof(to_timestamp(null)), arrow_typeof(to_timestamp('2023-01-10 12:34:56.000'))
+----
+Timestamp(Nanosecond, None) Timestamp(Nanosecond, None) Timestamp(Nanosecond, None)
+
+# verify timestamp output types using timestamp literal syntax
+query BBBBBB
+SELECT arrow_typeof(to_timestamp(1)) = arrow_typeof(1::timestamp) as c1,
+       arrow_typeof(to_timestamp(null)) = arrow_typeof(null::timestamp) as c2,
+       arrow_typeof(to_timestamp('2023-01-10 12:34:56.000')) = arrow_typeof('2023-01-10 12:34:56.000'::timestamp) as c3,
+       arrow_typeof(to_timestamp(1)) = arrow_typeof(cast(1 as timestamp)) as c4,
+       arrow_typeof(to_timestamp(null)) = arrow_typeof(cast(null as timestamp)) as c5,
+       arrow_typeof(to_timestamp('2023-01-10 12:34:56.000')) = arrow_typeof(cast('2023-01-10 12:34:56.000' as timestamp)) as c6
+----
+true true true true true
true + +# known issues. currently overflows (expects default precision to be microsecond instead of nanoseconds. Work pending) +#verify extreme values +#query PPPPPPPP +#SELECT to_timestamp(-62125747200), to_timestamp(1926632005177), -62125747200::timestamp, 1926632005177::timestamp, cast(-62125747200 as timestamp), cast(1926632005177 as timestamp) +#---- +#0001-04-25T00:00:00 +63022-07-16T12:59:37 0001-04-25T00:00:00 +63022-07-16T12:59:37 0001-04-25T00:00:00 +63022-07-16T12:59:37 + +########## +## Test binary temporal coercion for Date and Timestamp +########## + +query B +select arrow_cast(now(), 'Date64') < arrow_cast('2022-02-02 02:02:02', 'Timestamp(Nanosecond, None)'); ---- -NULL 0001-04-25T00:00:00 1970-01-01T00:00:00 +63022-07-16T12:59:37 2031-01-19T23:33:25 +false diff --git a/datafusion/sqllogictest/test_files/union.slt b/datafusion/sqllogictest/test_files/union.slt index 688774c906fe0..b4e338875e247 100644 --- a/datafusion/sqllogictest/test_files/union.slt +++ b/datafusion/sqllogictest/test_files/union.slt @@ -82,6 +82,11 @@ SELECT 2 as x 1 2 +query I +select count(*) from (select id from t1 union all select id from t2) +---- +6 + # csv_union_all statement ok CREATE EXTERNAL TABLE aggregate_test_100 ( @@ -272,37 +277,36 @@ Union ------TableScan: t1 projection=[id, name] physical_plan UnionExec ---ProjectionExec: expr=[id@0 as id, name@1 as name] -----CoalesceBatchesExec: target_batch_size=2 -------HashJoinExec: mode=Partitioned, join_type=LeftAnti, on=[(id@0, CAST(t2.id AS Int32)@2), (name@1, name@1)] ---------AggregateExec: mode=FinalPartitioned, gby=[id@0 as id, name@1 as name], aggr=[] -----------CoalesceBatchesExec: target_batch_size=2 -------------RepartitionExec: partitioning=Hash([id@0, name@1], 4), input_partitions=4 ---------------AggregateExec: mode=Partial, gby=[id@0 as id, name@1 as name], aggr=[] -----------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -------------------MemoryExec: partitions=1, partition_sizes=[1] +--CoalesceBatchesExec: target_batch_size=2 +----HashJoinExec: mode=Partitioned, join_type=LeftAnti, on=[(id@0, CAST(t2.id AS Int32)@2), (name@1, name@1)] +------AggregateExec: mode=FinalPartitioned, gby=[id@0 as id, name@1 as name], aggr=[] --------CoalesceBatchesExec: target_batch_size=2 -----------RepartitionExec: partitioning=Hash([CAST(t2.id AS Int32)@2, name@1], 4), input_partitions=4 -------------ProjectionExec: expr=[id@0 as id, name@1 as name, CAST(id@0 AS Int32) as CAST(t2.id AS Int32)] +----------RepartitionExec: partitioning=Hash([id@0, name@1], 4), input_partitions=4 +------------AggregateExec: mode=Partial, gby=[id@0 as id, name@1 as name], aggr=[] --------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 ----------------MemoryExec: partitions=1, partition_sizes=[1] ---ProjectionExec: expr=[CAST(id@0 AS Int32) as id, name@1 as name] -----ProjectionExec: expr=[id@0 as id, name@1 as name] ------CoalesceBatchesExec: target_batch_size=2 ---------HashJoinExec: mode=Partitioned, join_type=LeftAnti, on=[(CAST(t2.id AS Int32)@2, id@0), (name@1, name@1)] -----------CoalesceBatchesExec: target_batch_size=2 -------------RepartitionExec: partitioning=Hash([CAST(t2.id AS Int32)@2, name@1], 4), input_partitions=4 ---------------ProjectionExec: expr=[id@0 as id, name@1 as name, CAST(id@0 AS Int32) as CAST(t2.id AS Int32)] -----------------AggregateExec: mode=FinalPartitioned, gby=[id@0 as id, name@1 as name], aggr=[] -------------------CoalesceBatchesExec: target_batch_size=2 
---------------------RepartitionExec: partitioning=Hash([id@0, name@1], 4), input_partitions=4 -----------------------AggregateExec: mode=Partial, gby=[id@0 as id, name@1 as name], aggr=[] -------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 ---------------------------MemoryExec: partitions=1, partition_sizes=[1] -----------CoalesceBatchesExec: target_batch_size=2 -------------RepartitionExec: partitioning=Hash([id@0, name@1], 4), input_partitions=4 ---------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -----------------MemoryExec: partitions=1, partition_sizes=[1] +--------RepartitionExec: partitioning=Hash([CAST(t2.id AS Int32)@2, name@1], 4), input_partitions=4 +----------ProjectionExec: expr=[id@0 as id, name@1 as name, CAST(id@0 AS Int32) as CAST(t2.id AS Int32)] +------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +--------------MemoryExec: partitions=1, partition_sizes=[1] +--ProjectionExec: expr=[CAST(id@0 AS Int32) as id, name@1 as name] +----CoalesceBatchesExec: target_batch_size=2 +------HashJoinExec: mode=Partitioned, join_type=LeftAnti, on=[(CAST(t2.id AS Int32)@2, id@0), (name@1, name@1)] +--------CoalesceBatchesExec: target_batch_size=2 +----------RepartitionExec: partitioning=Hash([CAST(t2.id AS Int32)@2, name@1], 4), input_partitions=4 +------------ProjectionExec: expr=[id@0 as id, name@1 as name, CAST(id@0 AS Int32) as CAST(t2.id AS Int32)] +--------------AggregateExec: mode=FinalPartitioned, gby=[id@0 as id, name@1 as name], aggr=[] +----------------CoalesceBatchesExec: target_batch_size=2 +------------------RepartitionExec: partitioning=Hash([id@0, name@1], 4), input_partitions=4 +--------------------AggregateExec: mode=Partial, gby=[id@0 as id, name@1 as name], aggr=[] +----------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +------------------------MemoryExec: partitions=1, partition_sizes=[1] +--------CoalesceBatchesExec: target_batch_size=2 +----------RepartitionExec: partitioning=Hash([id@0, name@1], 4), input_partitions=4 +------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +--------------MemoryExec: partitions=1, partition_sizes=[1] + query IT rowsort ( @@ -547,11 +551,11 @@ UnionExec ------CoalesceBatchesExec: target_batch_size=2 --------RepartitionExec: partitioning=Hash([Int64(1)@0], 4), input_partitions=1 ----------AggregateExec: mode=Partial, gby=[1 as Int64(1)], aggr=[] -------------EmptyExec: produce_one_row=true +------------PlaceholderRowExec --ProjectionExec: expr=[2 as a] -----EmptyExec: produce_one_row=true +----PlaceholderRowExec --ProjectionExec: expr=[3 as a] -----EmptyExec: produce_one_row=true +----PlaceholderRowExec # test UNION ALL aliases correctly with aliased subquery query TT @@ -579,8 +583,7 @@ UnionExec --------RepartitionExec: partitioning=Hash([n@0], 4), input_partitions=1 ----------AggregateExec: mode=Partial, gby=[n@0 as n], aggr=[COUNT(*)] ------------ProjectionExec: expr=[5 as n] ---------------EmptyExec: produce_one_row=true ---ProjectionExec: expr=[x@0 as count, y@1 as n] -----ProjectionExec: expr=[1 as x, MAX(Int64(10))@0 as y] -------AggregateExec: mode=Single, gby=[], aggr=[MAX(Int64(10))] ---------EmptyExec: produce_one_row=true +--------------PlaceholderRowExec +--ProjectionExec: expr=[1 as count, MAX(Int64(10))@0 as n] +----AggregateExec: mode=Single, gby=[], aggr=[MAX(Int64(10))] +------PlaceholderRowExec diff --git a/datafusion/sqllogictest/test_files/update.slt 
b/datafusion/sqllogictest/test_files/update.slt index cb8c6a4fac28a..6412c3ca859e4 100644 --- a/datafusion/sqllogictest/test_files/update.slt +++ b/datafusion/sqllogictest/test_files/update.slt @@ -76,4 +76,17 @@ create table t3(a int, b varchar, c double, d int); # set from mutiple tables, sqlparser only supports from one table query error DataFusion error: SQL error: ParserError\("Expected end of statement, found: ,"\) -explain update t1 set b = t2.b, c = t3.a, d = 1 from t2, t3 where t1.a = t2.a and t1.a = t3.a; \ No newline at end of file +explain update t1 set b = t2.b, c = t3.a, d = 1 from t2, t3 where t1.a = t2.a and t1.a = t3.a; + +# test table alias +query TT +explain update t1 as T set b = t2.b, c = t.a, d = 1 from t2 where t.a = t2.a and t.b > 'foo' and t2.c > 1.0; +---- +logical_plan +Dml: op=[Update] table=[t1] +--Projection: t.a AS a, t2.b AS b, CAST(t.a AS Float64) AS c, CAST(Int64(1) AS Int32) AS d +----Filter: t.a = t2.a AND t.b > Utf8("foo") AND t2.c > Float64(1) +------CrossJoin: +--------SubqueryAlias: t +----------TableScan: t1 +--------TableScan: t2 diff --git a/datafusion/sqllogictest/test_files/window.slt b/datafusion/sqllogictest/test_files/window.slt index 2eb0576d559bc..7d6d59201396d 100644 --- a/datafusion/sqllogictest/test_files/window.slt +++ b/datafusion/sqllogictest/test_files/window.slt @@ -279,13 +279,13 @@ SortPreservingMergeExec: [b@0 ASC NULLS LAST] ------------AggregateExec: mode=Partial, gby=[b@1 as b], aggr=[MAX(d.a)] --------------UnionExec ----------------ProjectionExec: expr=[1 as a, aa as b] -------------------EmptyExec: produce_one_row=true +------------------PlaceholderRowExec ----------------ProjectionExec: expr=[3 as a, aa as b] -------------------EmptyExec: produce_one_row=true +------------------PlaceholderRowExec ----------------ProjectionExec: expr=[5 as a, bb as b] -------------------EmptyExec: produce_one_row=true +------------------PlaceholderRowExec ----------------ProjectionExec: expr=[7 as a, bb as b] -------------------EmptyExec: produce_one_row=true +------------------PlaceholderRowExec # Check actual result: query TI @@ -365,13 +365,13 @@ SortPreservingMergeExec: [b@0 ASC NULLS LAST] --------------RepartitionExec: partitioning=Hash([b@1], 4), input_partitions=4 ----------------UnionExec ------------------ProjectionExec: expr=[1 as a, aa as b] ---------------------EmptyExec: produce_one_row=true +--------------------PlaceholderRowExec ------------------ProjectionExec: expr=[3 as a, aa as b] ---------------------EmptyExec: produce_one_row=true +--------------------PlaceholderRowExec ------------------ProjectionExec: expr=[5 as a, bb as b] ---------------------EmptyExec: produce_one_row=true +--------------------PlaceholderRowExec ------------------ProjectionExec: expr=[7 as a, bb as b] ---------------------EmptyExec: produce_one_row=true +--------------------PlaceholderRowExec # check actual result @@ -895,14 +895,14 @@ SELECT statement ok create table temp as values -(1664264591000000000), -(1664264592000000000), -(1664264592000000000), -(1664264593000000000), -(1664264594000000000), -(1664364594000000000), -(1664464594000000000), -(1664564594000000000); +(1664264591), +(1664264592), +(1664264592), +(1664264593), +(1664264594), +(1664364594), +(1664464594), +(1664564594); statement ok create table t as select cast(column1 as timestamp) as ts from temp; @@ -1731,26 +1731,28 @@ logical_plan Projection: COUNT(*) AS global_count --Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1)) AS COUNT(*)]] ----SubqueryAlias: a -------Sort: 
aggregate_test_100.c1 ASC NULLS LAST ---------Aggregate: groupBy=[[aggregate_test_100.c1]], aggr=[[]] -----------Projection: aggregate_test_100.c1 -------------Filter: aggregate_test_100.c13 != Utf8("C2GT5KVyOPZpgKVl110TyZO0NcJ434") ---------------TableScan: aggregate_test_100 projection=[c1, c13], partial_filters=[aggregate_test_100.c13 != Utf8("C2GT5KVyOPZpgKVl110TyZO0NcJ434")] +------Projection: +--------Sort: aggregate_test_100.c1 ASC NULLS LAST +----------Aggregate: groupBy=[[aggregate_test_100.c1]], aggr=[[]] +------------Projection: aggregate_test_100.c1 +--------------Filter: aggregate_test_100.c13 != Utf8("C2GT5KVyOPZpgKVl110TyZO0NcJ434") +----------------TableScan: aggregate_test_100 projection=[c1, c13], partial_filters=[aggregate_test_100.c13 != Utf8("C2GT5KVyOPZpgKVl110TyZO0NcJ434")] physical_plan ProjectionExec: expr=[COUNT(*)@0 as global_count] --AggregateExec: mode=Final, gby=[], aggr=[COUNT(*)] ----CoalescePartitionsExec ------AggregateExec: mode=Partial, gby=[], aggr=[COUNT(*)] --------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=2 -----------AggregateExec: mode=FinalPartitioned, gby=[c1@0 as c1], aggr=[] -------------CoalesceBatchesExec: target_batch_size=4096 ---------------RepartitionExec: partitioning=Hash([c1@0], 2), input_partitions=2 -----------------AggregateExec: mode=Partial, gby=[c1@0 as c1], aggr=[] -------------------ProjectionExec: expr=[c1@0 as c1] ---------------------CoalesceBatchesExec: target_batch_size=4096 -----------------------FilterExec: c13@1 != C2GT5KVyOPZpgKVl110TyZO0NcJ434 -------------------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 ---------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1, c13], has_header=true +----------ProjectionExec: expr=[] +------------AggregateExec: mode=FinalPartitioned, gby=[c1@0 as c1], aggr=[] +--------------CoalesceBatchesExec: target_batch_size=4096 +----------------RepartitionExec: partitioning=Hash([c1@0], 2), input_partitions=2 +------------------AggregateExec: mode=Partial, gby=[c1@0 as c1], aggr=[] +--------------------ProjectionExec: expr=[c1@0 as c1] +----------------------CoalesceBatchesExec: target_batch_size=4096 +------------------------FilterExec: c13@1 != C2GT5KVyOPZpgKVl110TyZO0NcJ434 +--------------------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +----------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1, c13], has_header=true query I SELECT count(*) as global_count FROM @@ -2812,7 +2814,7 @@ ProjectionExec: expr=[sum1@0 as sum1, sum2@1 as sum2, count1@2 as count1, count2 ----ProjectionExec: expr=[SUM(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING@4 as sum1, SUM(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING@2 as sum2, COUNT(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING@5 as count1, COUNT(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING@3 as count2, ts@0 as ts] ------BoundedWindowAggExec: wdw=[SUM(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING 
AND 1 FOLLOWING: Ok(Field { name: "SUM(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(1)) }, COUNT(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING: Ok(Field { name: "COUNT(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(1)) }], mode=[Sorted] --------BoundedWindowAggExec: wdw=[SUM(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "SUM(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(3)) }, COUNT(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "COUNT(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(3)) }], mode=[Sorted] -----------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_1.csv]]}, projection=[ts, inc_col], infinite_source=true, output_ordering=[ts@0 ASC NULLS LAST], has_header=true +----------StreamingTableExec: partition_sizes=1, projection=[ts, inc_col], infinite_source=true, output_ordering=[ts@0 ASC NULLS LAST] query IIII @@ -2858,7 +2860,7 @@ ProjectionExec: expr=[sum1@0 as sum1, sum2@1 as sum2, count1@2 as count1, count2 ----ProjectionExec: expr=[SUM(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING@4 as sum1, SUM(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING@2 as sum2, COUNT(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING@5 as count1, COUNT(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING@3 as count2, ts@0 as ts] ------BoundedWindowAggExec: wdw=[SUM(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING: Ok(Field { name: "SUM(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: 
Following(UInt64(1)) }, COUNT(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING: Ok(Field { name: "COUNT(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(1)) }], mode=[Sorted] --------BoundedWindowAggExec: wdw=[SUM(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "SUM(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(3)) }, COUNT(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "COUNT(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(3)) }], mode=[Sorted] -----------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_1.csv]]}, projection=[ts, inc_col], infinite_source=true, output_ordering=[ts@0 ASC NULLS LAST], has_header=true +----------StreamingTableExec: partition_sizes=1, projection=[ts, inc_col], infinite_source=true, output_ordering=[ts@0 ASC NULLS LAST] query IIII @@ -2962,7 +2964,7 @@ ProjectionExec: expr=[a@1 as a, b@2 as b, c@3 as c, SUM(annotated_data_infinite2 ------------BoundedWindowAggExec: wdw=[SUM(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "SUM(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(2)), end_bound: Following(UInt64(1)) }, SUM(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND CURRENT ROW: Ok(Field { name: "SUM(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND CURRENT ROW", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(5)), end_bound: CurrentRow }], mode=[PartiallySorted([0, 1])] --------------BoundedWindowAggExec: wdw=[SUM(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.a, 
annotated_data_infinite2.b] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "SUM(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(2)), end_bound: Following(UInt64(1)) }, SUM(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND 5 FOLLOWING: Ok(Field { name: "SUM(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND 5 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(5)), end_bound: Following(UInt64(5)) }], mode=[Sorted] ----------------ProjectionExec: expr=[CAST(c@2 AS Int64) as CAST(annotated_data_infinite2.c AS Int64)annotated_data_infinite2.c, a@0 as a, b@1 as b, c@2 as c, d@3 as d] -------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, b, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST, b@1 ASC NULLS LAST, c@2 ASC NULLS LAST], has_header=true +------------------StreamingTableExec: partition_sizes=1, projection=[a, b, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST, b@1 ASC NULLS LAST, c@2 ASC NULLS LAST] query IIIIIIIIIIIIIII @@ -3104,7 +3106,7 @@ CoalesceBatchesExec: target_batch_size=4096 ----GlobalLimitExec: skip=0, fetch=5 ------ProjectionExec: expr=[a0@0 as a0, a@1 as a, b@2 as b, c@3 as c, d@4 as d, ROW_NUMBER() ORDER BY [annotated_data_infinite2.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@5 as rn1] --------BoundedWindowAggExec: wdw=[ROW_NUMBER() ORDER BY [annotated_data_infinite2.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "ROW_NUMBER() ORDER BY [annotated_data_infinite2.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: CurrentRow }], mode=[Sorted] -----------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a0, a, b, c, d], infinite_source=true, output_ordering=[a@1 ASC NULLS LAST, b@2 ASC NULLS LAST, c@3 ASC NULLS LAST], has_header=true +----------StreamingTableExec: partition_sizes=1, projection=[a0, a, b, c, d], infinite_source=true, output_ordering=[a@1 ASC NULLS LAST, b@2 ASC NULLS LAST, c@3 ASC NULLS LAST] # this is a negative test for asserting that window functions (other than ROW_NUMBER) # are not added to ordering equivalence @@ -3217,7 +3219,7 @@ ProjectionExec: expr=[SUM(annotated_data_infinite2.a) PARTITION BY [annotated_da ------BoundedWindowAggExec: wdw=[SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "SUM(annotated_data_infinite2.a) PARTITION BY 
[annotated_data_infinite2.b, annotated_data_infinite2.a] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: CurrentRow }], mode=[Sorted] --------BoundedWindowAggExec: wdw=[SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.b ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.b ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: CurrentRow }], mode=[PartiallySorted([0])] ----------BoundedWindowAggExec: wdw=[SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: CurrentRow }], mode=[Sorted] -------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, b, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST, b@1 ASC NULLS LAST, c@2 ASC NULLS LAST], has_header=true +------------StreamingTableExec: partition_sizes=1, projection=[a, b, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST, b@1 ASC NULLS LAST, c@2 ASC NULLS LAST] statement ok set datafusion.execution.target_partitions = 2; @@ -3243,19 +3245,19 @@ physical_plan ProjectionExec: expr=[SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@2 as sum1, SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@4 as sum2, SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.b ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@3 as sum3, SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@5 as sum4] --BoundedWindowAggExec: wdw=[SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: 
false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: CurrentRow }], mode=[Linear] ----CoalesceBatchesExec: target_batch_size=4096 -------SortPreservingRepartitionExec: partitioning=Hash([d@1], 2), input_partitions=2, sort_exprs=a@0 ASC NULLS LAST +------RepartitionExec: partitioning=Hash([d@1], 2), input_partitions=2, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST --------ProjectionExec: expr=[a@0 as a, d@3 as d, SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@4 as SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.b ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@5 as SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.b ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@6 as SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW] ----------BoundedWindowAggExec: wdw=[SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: CurrentRow }], mode=[Sorted] ------------CoalesceBatchesExec: target_batch_size=4096 ---------------SortPreservingRepartitionExec: partitioning=Hash([b@1, a@0], 2), input_partitions=2, sort_exprs=a@0 ASC NULLS LAST,b@1 ASC NULLS LAST,c@2 ASC NULLS LAST +--------------RepartitionExec: partitioning=Hash([b@1, a@0], 2), input_partitions=2, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST,b@1 ASC NULLS LAST,c@2 ASC NULLS LAST ----------------BoundedWindowAggExec: wdw=[SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.b ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.b ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: CurrentRow }], mode=[PartiallySorted([0])] ------------------CoalesceBatchesExec: target_batch_size=4096 
---------------------SortPreservingRepartitionExec: partitioning=Hash([a@0, d@3], 2), input_partitions=2, sort_exprs=a@0 ASC NULLS LAST,b@1 ASC NULLS LAST,c@2 ASC NULLS LAST +--------------------RepartitionExec: partitioning=Hash([a@0, d@3], 2), input_partitions=2, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST,b@1 ASC NULLS LAST,c@2 ASC NULLS LAST ----------------------BoundedWindowAggExec: wdw=[SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: CurrentRow }], mode=[Sorted] ------------------------CoalesceBatchesExec: target_batch_size=4096 ---------------------------SortPreservingRepartitionExec: partitioning=Hash([a@0, b@1], 2), input_partitions=2, sort_exprs=a@0 ASC NULLS LAST,b@1 ASC NULLS LAST,c@2 ASC NULLS LAST +--------------------------RepartitionExec: partitioning=Hash([a@0, b@1], 2), input_partitions=2, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST,b@1 ASC NULLS LAST,c@2 ASC NULLS LAST ----------------------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 -------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, b, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST, b@1 ASC NULLS LAST, c@2 ASC NULLS LAST], has_header=true +------------------------------StreamingTableExec: partition_sizes=1, projection=[a, b, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST, b@1 ASC NULLS LAST, c@2 ASC NULLS LAST] # reset the partition number 1 again statement ok @@ -3396,6 +3398,21 @@ WITH ORDER (a ASC, b ASC) WITH ORDER (c ASC) LOCATION '../core/tests/data/window_2.csv'; +# Create an unbounded source where there is multiple orderings. +statement ok +CREATE UNBOUNDED EXTERNAL TABLE multiple_ordered_table_inf ( + a0 INTEGER, + a INTEGER, + b INTEGER, + c INTEGER, + d INTEGER +) +STORED AS CSV +WITH HEADER ROW +WITH ORDER (a ASC, b ASC) +WITH ORDER (c ASC) +LOCATION '../core/tests/data/window_2.csv'; + # All of the window execs in the physical plan should work in the # sorted mode. 
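The table declared above advertises two valid sort orders, [a ASC, b ASC] and [c ASC]. As a rough illustration of why the window operators below can stay in sorted mode, here is a minimal standalone sketch (a hand-rolled helper, not DataFusion's ordering-equivalence machinery; the function and names are invented for the example): a window's sort requirement needs no extra SortExec when it is a prefix of an ordering the source already guarantees.

fn satisfied_without_sort(required: &[&str], known_orderings: &[Vec<&str>]) -> bool {
    // A requirement is "free" if some known ordering starts with it.
    known_orderings
        .iter()
        .any(|ordering| ordering.starts_with(required))
}

fn main() {
    // The table advertises [a, b] and [c] (both ASC).
    let known = vec![vec!["a", "b"], vec!["c"]];
    assert!(satisfied_without_sort(&["a"], &known)); // prefix of [a, b]
    assert!(satisfied_without_sort(&["a", "b"], &known)); // exactly [a, b]
    assert!(satisfied_without_sort(&["c"], &known)); // exactly [c]
    assert!(!satisfied_without_sort(&["b"], &known)); // would need a re-sort
    println!("ordering checks passed");
}
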
query TT @@ -3414,7 +3431,7 @@ ProjectionExec: expr=[MIN(multiple_ordered_table.d) ORDER BY [multiple_ordered_t --BoundedWindowAggExec: wdw=[MIN(multiple_ordered_table.d) ORDER BY [multiple_ordered_table.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "MIN(multiple_ordered_table.d) ORDER BY [multiple_ordered_table.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: CurrentRow }], mode=[Sorted] ----ProjectionExec: expr=[c@2 as c, d@3 as d, MAX(multiple_ordered_table.d) PARTITION BY [multiple_ordered_table.b, multiple_ordered_table.a] ORDER BY [multiple_ordered_table.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@4 as MAX(multiple_ordered_table.d) PARTITION BY [multiple_ordered_table.b, multiple_ordered_table.a] ORDER BY [multiple_ordered_table.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW] ------BoundedWindowAggExec: wdw=[MAX(multiple_ordered_table.d) PARTITION BY [multiple_ordered_table.b, multiple_ordered_table.a] ORDER BY [multiple_ordered_table.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "MAX(multiple_ordered_table.d) PARTITION BY [multiple_ordered_table.b, multiple_ordered_table.a] ORDER BY [multiple_ordered_table.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: CurrentRow }], mode=[Sorted] ---------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, b, c, d], output_ordering=[a@0 ASC NULLS LAST, b@1 ASC NULLS LAST], has_header=true +--------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, b, c, d], output_orderings=[[a@0 ASC NULLS LAST, b@1 ASC NULLS LAST], [c@2 ASC NULLS LAST]], has_header=true query TT EXPLAIN SELECT MAX(c) OVER(PARTITION BY d ORDER BY c ASC) as max_c @@ -3446,7 +3463,7 @@ Projection: SUM(multiple_ordered_table.d) PARTITION BY [multiple_ordered_table.c physical_plan ProjectionExec: expr=[SUM(multiple_ordered_table.d) PARTITION BY [multiple_ordered_table.c] ORDER BY [multiple_ordered_table.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@3 as SUM(multiple_ordered_table.d) PARTITION BY [multiple_ordered_table.c] ORDER BY [multiple_ordered_table.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW] --BoundedWindowAggExec: wdw=[SUM(multiple_ordered_table.d) PARTITION BY [multiple_ordered_table.c] ORDER BY [multiple_ordered_table.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "SUM(multiple_ordered_table.d) PARTITION BY [multiple_ordered_table.c] ORDER BY [multiple_ordered_table.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: CurrentRow }], mode=[Sorted] -----CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true +----CsvExec: file_groups={1 group: 
[[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, c, d], output_orderings=[[a@0 ASC NULLS LAST], [c@1 ASC NULLS LAST]], has_header=true query TT explain SELECT SUM(d) OVER(PARTITION BY c, a ORDER BY b ASC) @@ -3459,7 +3476,7 @@ Projection: SUM(multiple_ordered_table.d) PARTITION BY [multiple_ordered_table.c physical_plan ProjectionExec: expr=[SUM(multiple_ordered_table.d) PARTITION BY [multiple_ordered_table.c, multiple_ordered_table.a] ORDER BY [multiple_ordered_table.b ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@4 as SUM(multiple_ordered_table.d) PARTITION BY [multiple_ordered_table.c, multiple_ordered_table.a] ORDER BY [multiple_ordered_table.b ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW] --BoundedWindowAggExec: wdw=[SUM(multiple_ordered_table.d) PARTITION BY [multiple_ordered_table.c, multiple_ordered_table.a] ORDER BY [multiple_ordered_table.b ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "SUM(multiple_ordered_table.d) PARTITION BY [multiple_ordered_table.c, multiple_ordered_table.a] ORDER BY [multiple_ordered_table.b ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: CurrentRow }], mode=[Sorted] -----CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, b, c, d], output_ordering=[a@0 ASC NULLS LAST, b@1 ASC NULLS LAST], has_header=true +----CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, b, c, d], output_orderings=[[a@0 ASC NULLS LAST, b@1 ASC NULLS LAST], [c@2 ASC NULLS LAST]], has_header=true query I SELECT SUM(d) OVER(PARTITION BY c, a ORDER BY b ASC) @@ -3477,3 +3494,380 @@ query II select sum(1) over() x, sum(1) over () y ---- 1 1 + +# The NTH_VALUE requirement is c DESC; however, the existing ordering is c ASC. +# If we rewrite the window expression "NTH_VALUE(c, 2) OVER(order by c DESC) as nv1" +# as "NTH_VALUE(c, -2) OVER(order by c ASC RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING) as nv1" +# (note that "NTH_VALUE(c, 2) OVER(order by c DESC) as nv1" is the same as +# "NTH_VALUE(c, 2) OVER(order by c DESC RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) as nv1"), +# we can produce the same result without re-sorting the table. +# Unfortunately, since window expression names are strings, this change is not visible in the plan (we do not do string manipulation). +# TODO: Reflect window expression reversal in the plans.
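The reversal described in the comment above can be made concrete. The snippet below is a standalone illustration (its Bound and FrameSpec types are invented for the example and are not DataFusion's WindowFrame API): flipping the sort direction, negating the nth offset, and mirroring the frame bounds yields an equivalent window specification that can be evaluated over the existing c ASC ordering.

#[derive(Debug, Clone, PartialEq)]
enum Bound {
    UnboundedPreceding,
    CurrentRow,
    UnboundedFollowing,
}

#[derive(Debug, Clone, PartialEq)]
struct FrameSpec {
    descending: bool, // sort direction of the ORDER BY column
    nth: i64,         // offset passed to NTH_VALUE
    start: Bound,
    end: Bound,
}

// Mirror a frame so it can run over the reversed sort direction.
fn reverse_frame(f: &FrameSpec) -> FrameSpec {
    let flip = |b: &Bound| match b {
        Bound::UnboundedPreceding => Bound::UnboundedFollowing,
        Bound::UnboundedFollowing => Bound::UnboundedPreceding,
        Bound::CurrentRow => Bound::CurrentRow,
    };
    FrameSpec {
        descending: !f.descending,
        nth: -f.nth,         // NTH_VALUE(c, 2) becomes NTH_VALUE(c, -2)
        start: flip(&f.end), // swap and mirror the bounds
        end: flip(&f.start),
    }
}

fn main() {
    // NTH_VALUE(c, 2) OVER (ORDER BY c DESC) with the default frame.
    let desc = FrameSpec {
        descending: true,
        nth: 2,
        start: Bound::UnboundedPreceding,
        end: Bound::CurrentRow,
    };
    let asc = reverse_frame(&desc);
    assert!(!asc.descending);
    assert_eq!(asc.nth, -2);
    assert_eq!(asc.start, Bound::CurrentRow);
    assert_eq!(asc.end, Bound::UnboundedFollowing);
    println!("reversed frame: {asc:?}");
}
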
+query TT +EXPLAIN SELECT c, NTH_VALUE(c, 2) OVER(order by c DESC) as nv1 + FROM multiple_ordered_table + ORDER BY c ASC + LIMIT 5 +---- +logical_plan +Limit: skip=0, fetch=5 +--Sort: multiple_ordered_table.c ASC NULLS LAST, fetch=5 +----Projection: multiple_ordered_table.c, NTH_VALUE(multiple_ordered_table.c,Int64(2)) ORDER BY [multiple_ordered_table.c DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS nv1 +------WindowAggr: windowExpr=[[NTH_VALUE(multiple_ordered_table.c, Int64(2)) ORDER BY [multiple_ordered_table.c DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] +--------TableScan: multiple_ordered_table projection=[c] +physical_plan +GlobalLimitExec: skip=0, fetch=5 +--ProjectionExec: expr=[c@0 as c, NTH_VALUE(multiple_ordered_table.c,Int64(2)) ORDER BY [multiple_ordered_table.c DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@1 as nv1] +----WindowAggExec: wdw=[NTH_VALUE(multiple_ordered_table.c,Int64(2)) ORDER BY [multiple_ordered_table.c DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "NTH_VALUE(multiple_ordered_table.c,Int64(2)) ORDER BY [multiple_ordered_table.c DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: CurrentRow, end_bound: Following(Int32(NULL)) }] +------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[c], output_ordering=[c@0 ASC NULLS LAST], has_header=true + +query II +SELECT c, NTH_VALUE(c, 2) OVER(order by c DESC) as nv1 + FROM multiple_ordered_table + ORDER BY c ASC + LIMIT 5 +---- +0 98 +1 98 +2 98 +3 98 +4 98 + +query II +SELECT c, NTH_VALUE(c, 2) OVER(order by c DESC) as nv1 + FROM multiple_ordered_table + ORDER BY c DESC + LIMIT 5 +---- +99 NULL +98 98 +97 98 +96 98 +95 98 + +statement ok +set datafusion.execution.target_partitions = 2; + +# source is ordered by [a ASC, b ASC], [c ASC] +# after sort preserving repartition and sort preserving merge +# we should still have the orderings [a ASC, b ASC], [c ASC]. 
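The "sort preserving merge" step mentioned above is essentially a k-way merge: if each partition is individually sorted on the key, repeatedly taking the smallest head across partitions yields a globally sorted stream, so the declared orderings survive repartitioning. The sketch below is a hand-rolled illustration of that idea, not DataFusion's SortPreservingMergeExec.

use std::cmp::Reverse;
use std::collections::BinaryHeap;

// Merge several individually sorted partitions into one sorted stream.
fn sort_preserving_merge(partitions: Vec<Vec<i32>>) -> Vec<i32> {
    // Heap entries: (next value, partition index, offset within that partition).
    let mut heap = BinaryHeap::new();
    for (p, part) in partitions.iter().enumerate() {
        if let Some(&v) = part.first() {
            heap.push(Reverse((v, p, 0usize)));
        }
    }
    let mut out = Vec::new();
    while let Some(Reverse((v, p, i))) = heap.pop() {
        out.push(v);
        if let Some(&next) = partitions[p].get(i + 1) {
            heap.push(Reverse((next, p, i + 1)));
        }
    }
    out
}

fn main() {
    // Two hash partitions, each already sorted on the key.
    let parts = vec![vec![1, 4, 7, 9], vec![2, 3, 8]];
    assert_eq!(sort_preserving_merge(parts), vec![1, 2, 3, 4, 7, 8, 9]);
}
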
+query TT +EXPLAIN SELECT *, + AVG(d) OVER sliding_window AS avg_d +FROM multiple_ordered_table_inf +WINDOW sliding_window AS ( + PARTITION BY d + ORDER BY a RANGE 10 PRECEDING +) +ORDER BY c +---- +logical_plan +Sort: multiple_ordered_table_inf.c ASC NULLS LAST +--Projection: multiple_ordered_table_inf.a0, multiple_ordered_table_inf.a, multiple_ordered_table_inf.b, multiple_ordered_table_inf.c, multiple_ordered_table_inf.d, AVG(multiple_ordered_table_inf.d) PARTITION BY [multiple_ordered_table_inf.d] ORDER BY [multiple_ordered_table_inf.a ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND CURRENT ROW AS avg_d +----WindowAggr: windowExpr=[[AVG(CAST(multiple_ordered_table_inf.d AS Float64)) PARTITION BY [multiple_ordered_table_inf.d] ORDER BY [multiple_ordered_table_inf.a ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND CURRENT ROW]] +------TableScan: multiple_ordered_table_inf projection=[a0, a, b, c, d] +physical_plan +SortPreservingMergeExec: [c@3 ASC NULLS LAST] +--ProjectionExec: expr=[a0@0 as a0, a@1 as a, b@2 as b, c@3 as c, d@4 as d, AVG(multiple_ordered_table_inf.d) PARTITION BY [multiple_ordered_table_inf.d] ORDER BY [multiple_ordered_table_inf.a ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND CURRENT ROW@5 as avg_d] +----BoundedWindowAggExec: wdw=[AVG(multiple_ordered_table_inf.d) PARTITION BY [multiple_ordered_table_inf.d] ORDER BY [multiple_ordered_table_inf.a ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND CURRENT ROW: Ok(Field { name: "AVG(multiple_ordered_table_inf.d) PARTITION BY [multiple_ordered_table_inf.d] ORDER BY [multiple_ordered_table_inf.a ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND CURRENT ROW", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(10)), end_bound: CurrentRow }], mode=[Linear] +------CoalesceBatchesExec: target_batch_size=4096 +--------RepartitionExec: partitioning=Hash([d@4], 2), input_partitions=2, preserve_order=true, sort_exprs=a@1 ASC NULLS LAST,b@2 ASC NULLS LAST +----------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +------------StreamingTableExec: partition_sizes=1, projection=[a0, a, b, c, d], infinite_source=true, output_ordering=[a@1 ASC NULLS LAST, b@2 ASC NULLS LAST] + +# CTAS with NTILE function +statement ok +CREATE TABLE new_table AS SELECT NTILE(2) OVER(ORDER BY c1) AS ntile_2 FROM aggregate_test_100; + +statement ok +DROP TABLE new_table; + +statement ok +CREATE TABLE t1 (a int) AS VALUES (1), (2), (3); + +query I +SELECT NTILE(9223377) OVER(ORDER BY a) FROM t1; +---- +1 +2 +3 + +query I +SELECT NTILE(9223372036854775809) OVER(ORDER BY a) FROM t1; +---- +1 +2 +3 + +query error DataFusion error: Execution error: NTILE requires a positive integer +SELECT NTILE(-922337203685477580) OVER(ORDER BY a) FROM t1; + +query error DataFusion error: Execution error: Table 't' doesn't exist\. 
+DROP TABLE t; + +# NTILE with PARTITION BY, those tests from duckdb: https://github.com/duckdb/duckdb/blob/main/test/sql/window/test_ntile.test +statement ok +CREATE TABLE score_board (team_name VARCHAR, player VARCHAR, score INTEGER) as VALUES + ('Mongrels', 'Apu', 350), + ('Mongrels', 'Ned', 666), + ('Mongrels', 'Meg', 1030), + ('Mongrels', 'Burns', 1270), + ('Simpsons', 'Homer', 1), + ('Simpsons', 'Lisa', 710), + ('Simpsons', 'Marge', 990), + ('Simpsons', 'Bart', 2010) + +query TTII +SELECT + team_name, + player, + score, + NTILE(2) OVER (PARTITION BY team_name ORDER BY score ASC) AS NTILE +FROM score_board s +ORDER BY team_name, score; +---- +Mongrels Apu 350 1 +Mongrels Ned 666 1 +Mongrels Meg 1030 2 +Mongrels Burns 1270 2 +Simpsons Homer 1 1 +Simpsons Lisa 710 1 +Simpsons Marge 990 2 +Simpsons Bart 2010 2 + +query TTII +SELECT + team_name, + player, + score, + NTILE(2) OVER (ORDER BY score ASC) AS NTILE +FROM score_board s +ORDER BY score; +---- +Simpsons Homer 1 1 +Mongrels Apu 350 1 +Mongrels Ned 666 1 +Simpsons Lisa 710 1 +Simpsons Marge 990 2 +Mongrels Meg 1030 2 +Mongrels Burns 1270 2 +Simpsons Bart 2010 2 + +query TTII +SELECT + team_name, + player, + score, + NTILE(1000) OVER (PARTITION BY team_name ORDER BY score ASC) AS NTILE +FROM score_board s +ORDER BY team_name, score; +---- +Mongrels Apu 350 1 +Mongrels Ned 666 2 +Mongrels Meg 1030 3 +Mongrels Burns 1270 4 +Simpsons Homer 1 1 +Simpsons Lisa 710 2 +Simpsons Marge 990 3 +Simpsons Bart 2010 4 + +query TTII +SELECT + team_name, + player, + score, + NTILE(1) OVER (PARTITION BY team_name ORDER BY score ASC) AS NTILE +FROM score_board s +ORDER BY team_name, score; +---- +Mongrels Apu 350 1 +Mongrels Ned 666 1 +Mongrels Meg 1030 1 +Mongrels Burns 1270 1 +Simpsons Homer 1 1 +Simpsons Lisa 710 1 +Simpsons Marge 990 1 +Simpsons Bart 2010 1 + +# incorrect number of parameters for ntile +query error DataFusion error: Execution error: NTILE requires a positive integer, but finds NULL +SELECT + NTILE(NULL) OVER (PARTITION BY team_name ORDER BY score ASC) AS NTILE +FROM score_board s + +query error DataFusion error: Execution error: NTILE requires a positive integer +SELECT + NTILE(-1) OVER (PARTITION BY team_name ORDER BY score ASC) AS NTILE +FROM score_board s + +query error DataFusion error: Execution error: NTILE requires a positive integer +SELECT + NTILE(0) OVER (PARTITION BY team_name ORDER BY score ASC) AS NTILE +FROM score_board s + +statement error +SELECT + NTILE() OVER (PARTITION BY team_name ORDER BY score ASC) AS NTILE +FROM score_board s + +statement error +SELECT + NTILE(1,2) OVER (PARTITION BY team_name ORDER BY score ASC) AS NTILE +FROM score_board s + +statement error +SELECT + NTILE(1,2,3) OVER (PARTITION BY team_name ORDER BY score ASC) AS NTILE +FROM score_board s + +statement error +SELECT + NTILE(1,2,3,4) OVER (PARTITION BY team_name ORDER BY score ASC) AS NTILE +FROM score_board s + +statement ok +DROP TABLE score_board; + +# Regularize RANGE frame +query error DataFusion error: Error during planning: RANGE requires exactly one ORDER BY column +select a, + rank() over (order by a, a + 1 RANGE BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING) rnk + from (select 1 a union select 2 a) q ORDER BY a + +query II +select a, + rank() over (order by a RANGE BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING) rnk + from (select 1 a union select 2 a) q ORDER BY a +---- +1 1 +2 2 + +query error DataFusion error: Error during planning: RANGE requires exactly one ORDER BY column +select a, + rank() over (RANGE BETWEEN UNBOUNDED 
PRECEDING AND 1 FOLLOWING) rnk + from (select 1 a union select 2 a) q ORDER BY a + +query II +select a, + rank() over (order by a, a + 1 RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) rnk + from (select 1 a union select 2 a) q ORDER BY a +---- +1 1 +2 2 + +query II +select a, + rank() over (order by a RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) rnk + from (select 1 a union select 2 a) q ORDER BY a +---- +1 1 +2 2 + +query II +select a, + rank() over (RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) rnk + from (select 1 a union select 2 a) q ORDER BY a +---- +1 1 +2 1 + +query I +select rank() over (RANGE between UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) rnk + from (select 1 a union select 2 a) q; +---- +1 +1 + +query II +select a, + rank() over (order by 1 RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) rnk + from (select 1 a union select 2 a) q ORDER BY a +---- +1 1 +2 1 + +query II +select a, + rank() over (order by null RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) rnk + from (select 1 a union select 2 a) q ORDER BY a +---- +1 1 +2 1 + +# support scalar value in ORDER BY +query I +select rank() over (order by 1) rnk from (select 1 a union all select 2 a) x +---- +1 +1 + +# support scalar value in ORDER BY +query I +select dense_rank() over () rnk from (select 1 a union all select 2 a) x +---- +1 +1 + +# support scalar value in both ORDER BY and PARTITION BY, RANK function +query IIIIII +select rank() over (partition by 1 order by 1) rnk, + rank() over (partition by a, 1 order by 1) rnk1, + rank() over (partition by a, 1 order by a, 1) rnk2, + rank() over (partition by 1) rnk3, + rank() over (partition by null) rnk4, + rank() over (partition by 1, null, a) rnk5 +from (select 1 a union all select 2 a) x +---- +1 1 1 1 1 1 +1 1 1 1 1 1 + +# support scalar value in both ORDER BY and PARTITION BY, ROW_NUMBER function +query IIIIII +select row_number() over (partition by 1 order by 1) rn, + row_number() over (partition by a, 1 order by 1) rn1, + row_number() over (partition by a, 1 order by a, 1) rn2, + row_number() over (partition by 1) rn3, + row_number() over (partition by null) rn4, + row_number() over (partition by 1, null, a) rn5 +from (select 1 a union all select 2 a) x; +---- +1 1 1 1 1 1 +2 1 1 2 2 1 + +# when partition by expression is empty row number result will be unique. +query TII +SELECT * +FROM (SELECT c1, c2, ROW_NUMBER() OVER() as rn + FROM aggregate_test_100 + LIMIT 5) +GROUP BY rn +ORDER BY rn; +---- +c 2 1 +d 5 2 +b 1 3 +a 1 4 +b 5 5 + +# when partition by expression is constant row number result will be unique. 
+query TII +SELECT * +FROM (SELECT c1, c2, ROW_NUMBER() OVER(PARTITION BY 3) as rn + FROM aggregate_test_100 + LIMIT 5) +GROUP BY rn +ORDER BY rn; +---- +c 2 1 +d 5 2 +b 1 3 +a 1 4 +b 5 5 + +statement error DataFusion error: Error during planning: Projection references non-aggregate values: Expression aggregate_test_100.c1 could not be resolved from available columns: rn +SELECT * +FROM (SELECT c1, c2, ROW_NUMBER() OVER(PARTITION BY c1) as rn + FROM aggregate_test_100 + LIMIT 5) +GROUP BY rn +ORDER BY rn; diff --git a/datafusion/substrait/Cargo.toml b/datafusion/substrait/Cargo.toml index 102b0a7c58f18..0a9a6e8dd12b4 100644 --- a/datafusion/substrait/Cargo.toml +++ b/datafusion/substrait/Cargo.toml @@ -35,7 +35,7 @@ itertools = { workspace = true } object_store = { workspace = true } prost = "0.12" prost-types = "0.12" -substrait = "0.19.0" +substrait = "0.21.0" tokio = "1.17" [features] diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index a15121652452b..a4ec3e7722a23 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -17,25 +17,30 @@ use async_recursion::async_recursion; use datafusion::arrow::datatypes::{DataType, Field, TimeUnit}; -use datafusion::common::{not_impl_err, DFField, DFSchema, DFSchemaRef}; +use datafusion::common::{ + not_impl_err, substrait_datafusion_err, substrait_err, DFField, DFSchema, DFSchemaRef, +}; +use datafusion::execution::FunctionRegistry; use datafusion::logical_expr::{ - aggregate_function, window_function::find_df_window_func, BinaryExpr, - BuiltinScalarFunction, Case, Expr, LogicalPlan, Operator, + aggregate_function, expr::find_df_window_func, BinaryExpr, BuiltinScalarFunction, + Case, Expr, LogicalPlan, Operator, }; use datafusion::logical_expr::{ - expr, Cast, Extension, GroupingSet, Like, LogicalPlanBuilder, WindowFrameBound, - WindowFrameUnits, + expr, Cast, Extension, GroupingSet, Like, LogicalPlanBuilder, Partitioning, + Repartition, Subquery, WindowFrameBound, WindowFrameUnits, }; use datafusion::prelude::JoinType; use datafusion::sql::TableReference; use datafusion::{ error::{DataFusionError, Result}, - optimizer::utils::split_conjunction, + logical_expr::utils::split_conjunction, prelude::{Column, SessionContext}, scalar::ScalarValue, }; -use substrait::proto::expression::{Literal, ScalarFunction}; +use substrait::proto::exchange_rel::ExchangeKind; +use substrait::proto::expression::subquery::SubqueryType; +use substrait::proto::expression::{FieldReference, Literal, ScalarFunction}; use substrait::proto::{ aggregate_function::AggregationInvocation, expression::{ @@ -57,7 +62,7 @@ use substrait::proto::{ use substrait::proto::{FunctionArgument, SortField}; use datafusion::common::plan_err; -use datafusion::logical_expr::expr::{InList, Sort}; +use datafusion::logical_expr::expr::{InList, InSubquery, Sort}; use std::collections::HashMap; use std::str::FromStr; use std::sync::Arc; @@ -72,12 +77,7 @@ use crate::variation_const::{ enum ScalarFunctionType { Builtin(BuiltinScalarFunction), Op(Operator), - /// [Expr::Not] - Not, - /// [Expr::Like] Used for filtering rows based on the given wildcard pattern. 
Case sensitive - Like, - /// [Expr::Like] Case insensitive operator counterpart of `Like` - ILike, + Expr(BuiltinExprBuilder), } pub fn name_to_op(name: &str) -> Result { @@ -122,12 +122,11 @@ fn scalar_function_type_from_str(name: &str) -> Result { return Ok(ScalarFunctionType::Builtin(fun)); } - match name { - "not" => Ok(ScalarFunctionType::Not), - "like" => Ok(ScalarFunctionType::Like), - "ilike" => Ok(ScalarFunctionType::ILike), - others => not_impl_err!("Unsupported function name: {others:?}"), + if let Some(builder) = BuiltinExprBuilder::try_from_name(name) { + return Ok(ScalarFunctionType::Expr(builder)); } + + not_impl_err!("Unsupported function name: {name:?}") } fn split_eq_and_noneq_join_predicate_with_nulls_equality( @@ -232,7 +231,8 @@ pub async fn from_substrait_rel( let mut exprs: Vec = vec![]; for e in &p.expressions { let x = - from_substrait_rex(e, input.clone().schema(), extensions).await?; + from_substrait_rex(ctx, e, input.clone().schema(), extensions) + .await?; // if the expression is WindowFunction, wrap in a Window relation // before returning and do not add to list of this Projection's expression list // otherwise, add expression to the Projection's expression list @@ -258,7 +258,8 @@ pub async fn from_substrait_rel( ); if let Some(condition) = filter.condition.as_ref() { let expr = - from_substrait_rex(condition, input.schema(), extensions).await?; + from_substrait_rex(ctx, condition, input.schema(), extensions) + .await?; input.filter(expr.as_ref().clone())?.build() } else { not_impl_err!("Filter without an condition is not valid") @@ -290,7 +291,8 @@ pub async fn from_substrait_rel( from_substrait_rel(ctx, input, extensions).await?, ); let sorts = - from_substrait_sorts(&sort.sorts, input.schema(), extensions).await?; + from_substrait_sorts(ctx, &sort.sorts, input.schema(), extensions) + .await?; input.sort(sorts)?.build() } else { not_impl_err!("Sort without an input is not valid") @@ -308,7 +310,8 @@ pub async fn from_substrait_rel( 1 => { for e in &agg.groupings[0].grouping_expressions { let x = - from_substrait_rex(e, input.schema(), extensions).await?; + from_substrait_rex(ctx, e, input.schema(), extensions) + .await?; group_expr.push(x.as_ref().clone()); } } @@ -317,8 +320,13 @@ pub async fn from_substrait_rel( for grouping in &agg.groupings { let mut grouping_set = vec![]; for e in &grouping.grouping_expressions { - let x = from_substrait_rex(e, input.schema(), extensions) - .await?; + let x = from_substrait_rex( + ctx, + e, + input.schema(), + extensions, + ) + .await?; grouping_set.push(x.as_ref().clone()); } grouping_sets.push(grouping_set); @@ -336,7 +344,7 @@ pub async fn from_substrait_rel( for m in &agg.measures { let filter = match &m.filter { Some(fil) => Some(Box::new( - from_substrait_rex(fil, input.schema(), extensions) + from_substrait_rex(ctx, fil, input.schema(), extensions) .await? .as_ref() .clone(), @@ -359,6 +367,7 @@ pub async fn from_substrait_rel( _ => false, }; from_substrait_agg_func( + ctx, f, input.schema(), extensions, @@ -403,8 +412,8 @@ pub async fn from_substrait_rel( // Otherwise, build join with only the filter, without join keys match &join.expression.as_ref() { Some(expr) => { - let on = - from_substrait_rex(expr, &in_join_schema, extensions).await?; + let on = from_substrait_rex(ctx, expr, &in_join_schema, extensions) + .await?; // The join expression can contain both equal and non-equal ops. // As of datafusion 31.0.0, the equal and non equal join conditions are in separate fields. 
// So we extract each part as follows: @@ -426,6 +435,15 @@ pub async fn from_substrait_rel( None => plan_err!("JoinRel without join condition is not allowed"), } } + Some(RelType::Cross(cross)) => { + let left: LogicalPlanBuilder = LogicalPlanBuilder::from( + from_substrait_rel(ctx, cross.left.as_ref().unwrap(), extensions).await?, + ); + let right = + from_substrait_rel(ctx, cross.right.as_ref().unwrap(), extensions) + .await?; + left.cross_join(right)?.build() + } Some(RelType::Read(read)) => match &read.as_ref().read_type { Some(ReadType::NamedTable(nt)) => { let table_reference = match nt.names.len() { @@ -502,9 +520,7 @@ pub async fn from_substrait_rel( }, Some(RelType::ExtensionLeaf(extension)) => { let Some(ext_detail) = &extension.detail else { - return Err(DataFusionError::Substrait( - "Unexpected empty detail in ExtensionLeafRel".to_string(), - )); + return substrait_err!("Unexpected empty detail in ExtensionLeafRel"); }; let plan = ctx .state() @@ -514,18 +530,16 @@ pub async fn from_substrait_rel( } Some(RelType::ExtensionSingle(extension)) => { let Some(ext_detail) = &extension.detail else { - return Err(DataFusionError::Substrait( - "Unexpected empty detail in ExtensionSingleRel".to_string(), - )); + return substrait_err!("Unexpected empty detail in ExtensionSingleRel"); }; let plan = ctx .state() .serializer_registry() .deserialize_logical_plan(&ext_detail.type_url, &ext_detail.value)?; let Some(input_rel) = &extension.input else { - return Err(DataFusionError::Substrait( - "ExtensionSingleRel doesn't contains input rel. Try use ExtensionLeafRel instead".to_string() - )); + return substrait_err!( + "ExtensionSingleRel doesn't contains input rel. Try use ExtensionLeafRel instead" + ); }; let input_plan = from_substrait_rel(ctx, input_rel, extensions).await?; let plan = plan.from_template(&plan.expressions(), &[input_plan]); @@ -533,9 +547,7 @@ pub async fn from_substrait_rel( } Some(RelType::ExtensionMulti(extension)) => { let Some(ext_detail) = &extension.detail else { - return Err(DataFusionError::Substrait( - "Unexpected empty detail in ExtensionSingleRel".to_string(), - )); + return substrait_err!("Unexpected empty detail in ExtensionSingleRel"); }; let plan = ctx .state() @@ -549,6 +561,45 @@ pub async fn from_substrait_rel( let plan = plan.from_template(&plan.expressions(), &inputs); Ok(LogicalPlan::Extension(Extension { node: plan })) } + Some(RelType::Exchange(exchange)) => { + let Some(input) = exchange.input.as_ref() else { + return substrait_err!("Unexpected empty input in ExchangeRel"); + }; + let input = Arc::new(from_substrait_rel(ctx, input, extensions).await?); + + let Some(exchange_kind) = &exchange.exchange_kind else { + return substrait_err!("Unexpected empty input in ExchangeRel"); + }; + + // ref: https://substrait.io/relations/physical_relations/#exchange-types + let partitioning_scheme = match exchange_kind { + ExchangeKind::ScatterByFields(scatter_fields) => { + let mut partition_columns = vec![]; + let input_schema = input.schema(); + for field_ref in &scatter_fields.fields { + let column = + from_substrait_field_reference(field_ref, input_schema)?; + partition_columns.push(column); + } + Partitioning::Hash( + partition_columns, + exchange.partition_count as usize, + ) + } + ExchangeKind::RoundRobin(_) => { + Partitioning::RoundRobinBatch(exchange.partition_count as usize) + } + ExchangeKind::SingleTarget(_) + | ExchangeKind::MultiTarget(_) + | ExchangeKind::Broadcast(_) => { + return not_impl_err!("Unsupported exchange kind: {exchange_kind:?}"); + } 
+ }; + Ok(LogicalPlan::Repartition(Repartition { + input, + partitioning_scheme, + })) + } _ => not_impl_err!("Unsupported RelType: {:?}", rel.rel_type), } } @@ -571,14 +622,16 @@ fn from_substrait_jointype(join_type: i32) -> Result { /// Convert Substrait Sorts to DataFusion Exprs pub async fn from_substrait_sorts( + ctx: &SessionContext, substrait_sorts: &Vec, input_schema: &DFSchema, extensions: &HashMap, ) -> Result> { let mut sorts: Vec = vec![]; for s in substrait_sorts { - let expr = from_substrait_rex(s.expr.as_ref().unwrap(), input_schema, extensions) - .await?; + let expr = + from_substrait_rex(ctx, s.expr.as_ref().unwrap(), input_schema, extensions) + .await?; let asc_nullfirst = match &s.sort_kind { Some(k) => match k { Direction(d) => { @@ -619,13 +672,14 @@ pub async fn from_substrait_sorts( /// Convert Substrait Expressions to DataFusion Exprs pub async fn from_substrait_rex_vec( + ctx: &SessionContext, exprs: &Vec, input_schema: &DFSchema, extensions: &HashMap, ) -> Result> { let mut expressions: Vec = vec![]; for expr in exprs { - let expression = from_substrait_rex(expr, input_schema, extensions).await?; + let expression = from_substrait_rex(ctx, expr, input_schema, extensions).await?; expressions.push(expression.as_ref().clone()); } Ok(expressions) @@ -633,6 +687,7 @@ pub async fn from_substrait_rex_vec( /// Convert Substrait FunctionArguments to DataFusion Exprs pub async fn from_substriat_func_args( + ctx: &SessionContext, arguments: &Vec, input_schema: &DFSchema, extensions: &HashMap, @@ -641,7 +696,7 @@ pub async fn from_substriat_func_args( for arg in arguments { let arg_expr = match &arg.arg_type { Some(ArgType::Value(e)) => { - from_substrait_rex(e, input_schema, extensions).await + from_substrait_rex(ctx, e, input_schema, extensions).await } _ => { not_impl_err!("Aggregated function argument non-Value type not supported") @@ -654,6 +709,7 @@ pub async fn from_substriat_func_args( /// Convert Substrait AggregateFunction to DataFusion Expr pub async fn from_substrait_agg_func( + ctx: &SessionContext, f: &AggregateFunction, input_schema: &DFSchema, extensions: &HashMap, @@ -665,7 +721,7 @@ pub async fn from_substrait_agg_func( for arg in &f.arguments { let arg_expr = match &arg.arg_type { Some(ArgType::Value(e)) => { - from_substrait_rex(e, input_schema, extensions).await + from_substrait_rex(ctx, e, input_schema, extensions).await } _ => { not_impl_err!("Aggregated function argument non-Value type not supported") @@ -674,28 +730,36 @@ pub async fn from_substrait_agg_func( args.push(arg_expr?.as_ref().clone()); } - let fun = match extensions.get(&f.function_reference) { - Some(function_name) => { - aggregate_function::AggregateFunction::from_str(function_name) - } - None => not_impl_err!( - "Aggregated function not found: function anchor = {:?}", + let Some(function_name) = extensions.get(&f.function_reference) else { + return plan_err!( + "Aggregate function not registered: function anchor = {:?}", f.function_reference - ), + ); }; - Ok(Arc::new(Expr::AggregateFunction(expr::AggregateFunction { - fun: fun.unwrap(), - args, - distinct, - filter, - order_by, - }))) + // try udaf first, then built-in aggr fn. 
+ if let Ok(fun) = ctx.udaf(function_name) { + Ok(Arc::new(Expr::AggregateFunction( + expr::AggregateFunction::new_udf(fun, args, distinct, filter, order_by), + ))) + } else if let Ok(fun) = aggregate_function::AggregateFunction::from_str(function_name) + { + Ok(Arc::new(Expr::AggregateFunction( + expr::AggregateFunction::new(fun, args, distinct, filter, order_by), + ))) + } else { + not_impl_err!( + "Aggregated function {} is not supported: function anchor = {:?}", + function_name, + f.function_reference + ) + } } /// Convert Substrait Rex to DataFusion Expr #[async_recursion] pub async fn from_substrait_rex( + ctx: &SessionContext, e: &Expression, input_schema: &DFSchema, extensions: &HashMap, @@ -706,37 +770,24 @@ pub async fn from_substrait_rex( let substrait_list = s.options.as_ref(); Ok(Arc::new(Expr::InList(InList { expr: Box::new( - from_substrait_rex(substrait_expr, input_schema, extensions) + from_substrait_rex(ctx, substrait_expr, input_schema, extensions) .await? .as_ref() .clone(), ), - list: from_substrait_rex_vec(substrait_list, input_schema, extensions) - .await?, + list: from_substrait_rex_vec( + ctx, + substrait_list, + input_schema, + extensions, + ) + .await?, negated: false, }))) } - Some(RexType::Selection(field_ref)) => match &field_ref.reference_type { - Some(DirectReference(direct)) => match &direct.reference_type.as_ref() { - Some(StructField(x)) => match &x.child.as_ref() { - Some(_) => not_impl_err!( - "Direct reference StructField with child is not supported" - ), - None => { - let column = - input_schema.field(x.field as usize).qualified_column(); - Ok(Arc::new(Expr::Column(Column { - relation: column.relation, - name: column.name, - }))) - } - }, - _ => not_impl_err!( - "Direct reference with types other than StructField is not supported" - ), - }, - _ => not_impl_err!("unsupported field ref type"), - }, + Some(RexType::Selection(field_ref)) => Ok(Arc::new( + from_substrait_field_reference(field_ref, input_schema)?, + )), Some(RexType::IfThen(if_then)) => { // Parse `ifs` // If the first element does not have a `then` part, then we can assume it's a base expression @@ -748,6 +799,7 @@ pub async fn from_substrait_rex( if if_expr.then.is_none() { expr = Some(Box::new( from_substrait_rex( + ctx, if_expr.r#if.as_ref().unwrap(), input_schema, extensions, @@ -762,6 +814,7 @@ pub async fn from_substrait_rex( when_then_expr.push(( Box::new( from_substrait_rex( + ctx, if_expr.r#if.as_ref().unwrap(), input_schema, extensions, @@ -772,6 +825,7 @@ pub async fn from_substrait_rex( ), Box::new( from_substrait_rex( + ctx, if_expr.then.as_ref().unwrap(), input_schema, extensions, @@ -785,7 +839,7 @@ pub async fn from_substrait_rex( // Parse `else` let else_expr = match &if_then.r#else { Some(e) => Some(Box::new( - from_substrait_rex(e, input_schema, extensions) + from_substrait_rex(ctx, e, input_schema, extensions) .await? 
.as_ref() .clone(), @@ -812,7 +866,7 @@ pub async fn from_substrait_rex( for arg in &f.arguments { let arg_expr = match &arg.arg_type { Some(ArgType::Value(e)) => { - from_substrait_rex(e, input_schema, extensions).await + from_substrait_rex(ctx, e, input_schema, extensions).await } _ => not_impl_err!( "Aggregated function argument non-Value type not supported" @@ -820,10 +874,9 @@ pub async fn from_substrait_rex( }; args.push(arg_expr?.as_ref().clone()); } - Ok(Arc::new(Expr::ScalarFunction(expr::ScalarFunction { - fun, - args, - }))) + Ok(Arc::new(Expr::ScalarFunction(expr::ScalarFunction::new( + fun, args, + )))) } ScalarFunctionType::Op(op) => { if f.arguments.len() != 2 { @@ -838,14 +891,14 @@ pub async fn from_substrait_rex( (Some(ArgType::Value(l)), Some(ArgType::Value(r))) => { Ok(Arc::new(Expr::BinaryExpr(BinaryExpr { left: Box::new( - from_substrait_rex(l, input_schema, extensions) + from_substrait_rex(ctx, l, input_schema, extensions) .await? .as_ref() .clone(), ), op, right: Box::new( - from_substrait_rex(r, input_schema, extensions) + from_substrait_rex(ctx, r, input_schema, extensions) .await? .as_ref() .clone(), @@ -857,28 +910,8 @@ pub async fn from_substrait_rex( ), } } - ScalarFunctionType::Not => { - let arg = f.arguments.first().ok_or_else(|| { - DataFusionError::Substrait( - "expect one argument for `NOT` expr".to_string(), - ) - })?; - match &arg.arg_type { - Some(ArgType::Value(e)) => { - let expr = from_substrait_rex(e, input_schema, extensions) - .await? - .as_ref() - .clone(); - Ok(Arc::new(Expr::Not(Box::new(expr)))) - } - _ => not_impl_err!("Invalid arguments for Not expression"), - } - } - ScalarFunctionType::Like => { - make_datafusion_like(false, f, input_schema, extensions).await - } - ScalarFunctionType::ILike => { - make_datafusion_like(true, f, input_schema, extensions).await + ScalarFunctionType::Expr(builder) => { + builder.build(ctx, f, input_schema, extensions).await } } } @@ -890,6 +923,7 @@ pub async fn from_substrait_rex( Some(output_type) => Ok(Arc::new(Expr::Cast(Cast::new( Box::new( from_substrait_rex( + ctx, cast.as_ref().input.as_ref().unwrap().as_ref(), input_schema, extensions, @@ -900,9 +934,7 @@ pub async fn from_substrait_rex( ), from_substrait_type(output_type)?, )))), - None => Err(DataFusionError::Substrait( - "Cast experssion without output type is not allowed".to_string(), - )), + None => substrait_err!("Cast experssion without output type is not allowed"), }, Some(RexType::WindowFunction(window)) => { let fun = match extensions.get(&window.function_reference) { @@ -913,7 +945,8 @@ pub async fn from_substrait_rex( ), }; let order_by = - from_substrait_sorts(&window.sorts, input_schema, extensions).await?; + from_substrait_sorts(ctx, &window.sorts, input_schema, extensions) + .await?; // Substrait does not encode WindowFrameUnits so we're using a simple logic to determine the units // If there is no `ORDER BY`, then by default, the frame counts each row from the lower up to upper boundary // If there is `ORDER BY`, then by default, each frame is a range starting from unbounded preceding to current row @@ -926,12 +959,14 @@ pub async fn from_substrait_rex( Ok(Arc::new(Expr::WindowFunction(expr::WindowFunction { fun: fun?.unwrap(), args: from_substriat_func_args( + ctx, &window.arguments, input_schema, extensions, ) .await?, partition_by: from_substrait_rex_vec( + ctx, &window.partitions, input_schema, extensions, @@ -945,6 +980,51 @@ pub async fn from_substrait_rex( }, }))) } + Some(RexType::Subquery(subquery)) => match 
&subquery.as_ref().subquery_type { + Some(subquery_type) => match subquery_type { + SubqueryType::InPredicate(in_predicate) => { + if in_predicate.needles.len() != 1 { + Err(DataFusionError::Substrait( + "InPredicate Subquery type must have exactly one Needle expression" + .to_string(), + )) + } else { + let needle_expr = &in_predicate.needles[0]; + let haystack_expr = &in_predicate.haystack; + if let Some(haystack_expr) = haystack_expr { + let haystack_expr = + from_substrait_rel(ctx, haystack_expr, extensions) + .await?; + let outer_refs = haystack_expr.all_out_ref_exprs(); + Ok(Arc::new(Expr::InSubquery(InSubquery { + expr: Box::new( + from_substrait_rex( + ctx, + needle_expr, + input_schema, + extensions, + ) + .await? + .as_ref() + .clone(), + ), + subquery: Subquery { + subquery: Arc::new(haystack_expr), + outer_ref_columns: outer_refs, + }, + negated: false, + }))) + } else { + substrait_err!("InPredicate Subquery type must have a Haystack expression") + } + } + } + _ => substrait_err!("Subquery type not implemented"), + }, + None => { + substrait_err!("Subquery experssion without SubqueryType is not allowed") + } + }, _ => not_impl_err!("unsupported rex_type"), } } @@ -1027,9 +1107,7 @@ fn from_substrait_type(dt: &substrait::proto::Type) -> Result { r#type::Kind::List(list) => { let inner_type = from_substrait_type(list.r#type.as_ref().ok_or_else(|| { - DataFusionError::Substrait( - "List type must have inner type".to_string(), - ) + substrait_datafusion_err!("List type must have inner type") })?)?; let field = Arc::new(Field::new("list_item", inner_type, true)); match list.type_variation_reference { @@ -1081,9 +1159,7 @@ fn from_substrait_bound( } } }, - None => Err(DataFusionError::Substrait( - "WindowFunction missing Substrait Bound kind".to_string(), - )), + None => substrait_err!("WindowFunction missing Substrait Bound kind"), }, None => { if is_lower { @@ -1102,36 +1178,28 @@ pub(crate) fn from_substrait_literal(lit: &Literal) -> Result { DEFAULT_TYPE_REF => ScalarValue::Int8(Some(*n as i8)), UNSIGNED_INTEGER_TYPE_REF => ScalarValue::UInt8(Some(*n as u8)), others => { - return Err(DataFusionError::Substrait(format!( - "Unknown type variation reference {others}", - ))); + return substrait_err!("Unknown type variation reference {others}"); } }, Some(LiteralType::I16(n)) => match lit.type_variation_reference { DEFAULT_TYPE_REF => ScalarValue::Int16(Some(*n as i16)), UNSIGNED_INTEGER_TYPE_REF => ScalarValue::UInt16(Some(*n as u16)), others => { - return Err(DataFusionError::Substrait(format!( - "Unknown type variation reference {others}", - ))); + return substrait_err!("Unknown type variation reference {others}"); } }, Some(LiteralType::I32(n)) => match lit.type_variation_reference { DEFAULT_TYPE_REF => ScalarValue::Int32(Some(*n)), UNSIGNED_INTEGER_TYPE_REF => ScalarValue::UInt32(Some(*n as u32)), others => { - return Err(DataFusionError::Substrait(format!( - "Unknown type variation reference {others}", - ))); + return substrait_err!("Unknown type variation reference {others}"); } }, Some(LiteralType::I64(n)) => match lit.type_variation_reference { DEFAULT_TYPE_REF => ScalarValue::Int64(Some(*n)), UNSIGNED_INTEGER_TYPE_REF => ScalarValue::UInt64(Some(*n as u64)), others => { - return Err(DataFusionError::Substrait(format!( - "Unknown type variation reference {others}", - ))); + return substrait_err!("Unknown type variation reference {others}"); } }, Some(LiteralType::Fp32(f)) => ScalarValue::Float32(Some(*f)), @@ -1142,9 +1210,7 @@ pub(crate) fn from_substrait_literal(lit: 
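Error construction in the consumer is also moving from hand-rolled `Err(DataFusionError::Substrait(format!(...)))` to the `substrait_err!` and `substrait_datafusion_err!` macros from `datafusion::common`: the first form is used where a `Result` is returned directly, the second where a bare error value is needed inside `ok_or_else` or `map_err`. A rough, self-contained approximation of that macro pair (the `Sketch*` names are invented here and are not the real `datafusion-common` macros):

```rust
// Invented approximation only: shows the shape of the error macros, not their
// actual implementation in datafusion-common.
#[derive(Debug)]
enum SketchError {
    Substrait(String),
}

// Bare-error form, usable inside ok_or_else / map_err closures.
macro_rules! sketch_substrait_datafusion_err {
    ($($arg:tt)*) => {
        SketchError::Substrait(format!($($arg)*))
    };
}

// Err-wrapped form, usable where a Result is returned directly.
macro_rules! sketch_substrait_err {
    ($($arg:tt)*) => {
        Err(sketch_substrait_datafusion_err!($($arg)*))
    };
}

fn parse_variation(type_variation_reference: u32) -> Result<&'static str, SketchError> {
    match type_variation_reference {
        0 => Ok("default"),
        1 => Ok("unsigned"),
        others => sketch_substrait_err!("Unknown type variation reference {others}"),
    }
}

fn main() {
    println!("{:?}", parse_variation(0));
    println!("{:?}", parse_variation(42));
}
```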
&Literal) -> Result { TIMESTAMP_MICRO_TYPE_REF => ScalarValue::TimestampMicrosecond(Some(*t), None), TIMESTAMP_NANO_TYPE_REF => ScalarValue::TimestampNanosecond(Some(*t), None), others => { - return Err(DataFusionError::Substrait(format!( - "Unknown type variation reference {others}", - ))); + return substrait_err!("Unknown type variation reference {others}"); } }, Some(LiteralType::Date(d)) => ScalarValue::Date32(Some(*d)), @@ -1152,38 +1218,30 @@ pub(crate) fn from_substrait_literal(lit: &Literal) -> Result { DEFAULT_CONTAINER_TYPE_REF => ScalarValue::Utf8(Some(s.clone())), LARGE_CONTAINER_TYPE_REF => ScalarValue::LargeUtf8(Some(s.clone())), others => { - return Err(DataFusionError::Substrait(format!( - "Unknown type variation reference {others}", - ))); + return substrait_err!("Unknown type variation reference {others}"); } }, Some(LiteralType::Binary(b)) => match lit.type_variation_reference { DEFAULT_CONTAINER_TYPE_REF => ScalarValue::Binary(Some(b.clone())), LARGE_CONTAINER_TYPE_REF => ScalarValue::LargeBinary(Some(b.clone())), others => { - return Err(DataFusionError::Substrait(format!( - "Unknown type variation reference {others}", - ))); + return substrait_err!("Unknown type variation reference {others}"); } }, Some(LiteralType::FixedBinary(b)) => { ScalarValue::FixedSizeBinary(b.len() as _, Some(b.clone())) } Some(LiteralType::Decimal(d)) => { - let value: [u8; 16] = - d.value - .clone() - .try_into() - .or(Err(DataFusionError::Substrait( - "Failed to parse decimal value".to_string(), - )))?; + let value: [u8; 16] = d + .value + .clone() + .try_into() + .or(substrait_err!("Failed to parse decimal value"))?; let p = d.precision.try_into().map_err(|e| { - DataFusionError::Substrait(format!( - "Failed to parse decimal precision: {e}" - )) + substrait_datafusion_err!("Failed to parse decimal precision: {e}") })?; let s = d.scale.try_into().map_err(|e| { - DataFusionError::Substrait(format!("Failed to parse decimal scale: {e}")) + substrait_datafusion_err!("Failed to parse decimal scale: {e}") })?; ScalarValue::Decimal128( Some(std::primitive::i128::from_le_bytes(value)), @@ -1281,50 +1339,157 @@ fn from_substrait_null(null_type: &Type) -> Result { } } -async fn make_datafusion_like( - case_insensitive: bool, - f: &ScalarFunction, +fn from_substrait_field_reference( + field_ref: &FieldReference, input_schema: &DFSchema, - extensions: &HashMap, -) -> Result> { - let fn_name = if case_insensitive { "ILIKE" } else { "LIKE" }; - if f.arguments.len() != 3 { - return not_impl_err!("Expect three arguments for `{fn_name}` expr"); +) -> Result { + match &field_ref.reference_type { + Some(DirectReference(direct)) => match &direct.reference_type.as_ref() { + Some(StructField(x)) => match &x.child.as_ref() { + Some(_) => not_impl_err!( + "Direct reference StructField with child is not supported" + ), + None => { + let column = input_schema.field(x.field as usize).qualified_column(); + Ok(Expr::Column(Column { + relation: column.relation, + name: column.name, + })) + } + }, + _ => not_impl_err!( + "Direct reference with types other than StructField is not supported" + ), + }, + _ => not_impl_err!("unsupported field ref type"), } +} - let Some(ArgType::Value(expr_substrait)) = &f.arguments[0].arg_type else { - return not_impl_err!("Invalid arguments type for `{fn_name}` expr"); - }; - let expr = from_substrait_rex(expr_substrait, input_schema, extensions) - .await? 
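Factoring the direct-reference handling into `from_substrait_field_reference` makes the rule explicit: a `StructField` reference with no child resolves purely by position in the flattened input schema. A minimal stand-alone illustration of that positional lookup, with a slice of column names standing in for `DFSchema` and `Column`:

```rust
// Stand-alone sketch of the idea behind `from_substrait_field_reference`:
// a direct StructField reference is just an index into the input schema.
fn field_index_to_column(schema: &[&str], field: usize) -> Result<String, String> {
    schema
        .get(field)
        .map(|name| (*name).to_string())
        .ok_or_else(|| {
            format!(
                "field {field} out of range for schema with {} columns",
                schema.len()
            )
        })
}

fn main() {
    let schema = ["a", "b", "c"];
    assert_eq!(field_index_to_column(&schema, 1).unwrap(), "b");
    assert!(field_index_to_column(&schema, 9).is_err());
    println!("direct field references resolve by position");
}
```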
- .as_ref() - .clone(); - let Some(ArgType::Value(pattern_substrait)) = &f.arguments[1].arg_type else { - return not_impl_err!("Invalid arguments type for `{fn_name}` expr"); - }; - let pattern = from_substrait_rex(pattern_substrait, input_schema, extensions) - .await? - .as_ref() - .clone(); - let Some(ArgType::Value(escape_char_substrait)) = &f.arguments[2].arg_type else { - return not_impl_err!("Invalid arguments type for `{fn_name}` expr"); - }; - let escape_char_expr = - from_substrait_rex(escape_char_substrait, input_schema, extensions) +/// Build [`Expr`] from its name and required inputs. +struct BuiltinExprBuilder { + expr_name: String, +} + +impl BuiltinExprBuilder { + pub fn try_from_name(name: &str) -> Option { + match name { + "not" | "like" | "ilike" | "is_null" | "is_not_null" | "is_true" + | "is_false" | "is_not_true" | "is_not_false" | "is_unknown" + | "is_not_unknown" | "negative" => Some(Self { + expr_name: name.to_string(), + }), + _ => None, + } + } + + pub async fn build( + self, + ctx: &SessionContext, + f: &ScalarFunction, + input_schema: &DFSchema, + extensions: &HashMap, + ) -> Result> { + match self.expr_name.as_str() { + "like" => { + Self::build_like_expr(ctx, false, f, input_schema, extensions).await + } + "ilike" => { + Self::build_like_expr(ctx, true, f, input_schema, extensions).await + } + "not" | "negative" | "is_null" | "is_not_null" | "is_true" | "is_false" + | "is_not_true" | "is_not_false" | "is_unknown" | "is_not_unknown" => { + Self::build_unary_expr(ctx, &self.expr_name, f, input_schema, extensions) + .await + } + _ => { + not_impl_err!("Unsupported builtin expression: {}", self.expr_name) + } + } + } + + async fn build_unary_expr( + ctx: &SessionContext, + fn_name: &str, + f: &ScalarFunction, + input_schema: &DFSchema, + extensions: &HashMap, + ) -> Result> { + if f.arguments.len() != 1 { + return substrait_err!("Expect one argument for {fn_name} expr"); + } + let Some(ArgType::Value(expr_substrait)) = &f.arguments[0].arg_type else { + return substrait_err!("Invalid arguments type for {fn_name} expr"); + }; + let arg = from_substrait_rex(ctx, expr_substrait, input_schema, extensions) .await? 
.as_ref() .clone(); - let Expr::Literal(ScalarValue::Utf8(escape_char)) = escape_char_expr else { - return Err(DataFusionError::Substrait(format!( - "Expect Utf8 literal for escape char, but found {escape_char_expr:?}", - ))); - }; + let arg = Box::new(arg); - Ok(Arc::new(Expr::Like(Like { - negated: false, - expr: Box::new(expr), - pattern: Box::new(pattern), - escape_char: escape_char.map(|c| c.chars().next().unwrap()), - case_insensitive, - }))) + let expr = match fn_name { + "not" => Expr::Not(arg), + "negative" => Expr::Negative(arg), + "is_null" => Expr::IsNull(arg), + "is_not_null" => Expr::IsNotNull(arg), + "is_true" => Expr::IsTrue(arg), + "is_false" => Expr::IsFalse(arg), + "is_not_true" => Expr::IsNotTrue(arg), + "is_not_false" => Expr::IsNotFalse(arg), + "is_unknown" => Expr::IsUnknown(arg), + "is_not_unknown" => Expr::IsNotUnknown(arg), + _ => return not_impl_err!("Unsupported builtin expression: {}", fn_name), + }; + + Ok(Arc::new(expr)) + } + + async fn build_like_expr( + ctx: &SessionContext, + case_insensitive: bool, + f: &ScalarFunction, + input_schema: &DFSchema, + extensions: &HashMap, + ) -> Result> { + let fn_name = if case_insensitive { "ILIKE" } else { "LIKE" }; + if f.arguments.len() != 3 { + return substrait_err!("Expect three arguments for `{fn_name}` expr"); + } + + let Some(ArgType::Value(expr_substrait)) = &f.arguments[0].arg_type else { + return substrait_err!("Invalid arguments type for `{fn_name}` expr"); + }; + let expr = from_substrait_rex(ctx, expr_substrait, input_schema, extensions) + .await? + .as_ref() + .clone(); + let Some(ArgType::Value(pattern_substrait)) = &f.arguments[1].arg_type else { + return substrait_err!("Invalid arguments type for `{fn_name}` expr"); + }; + let pattern = + from_substrait_rex(ctx, pattern_substrait, input_schema, extensions) + .await? + .as_ref() + .clone(); + let Some(ArgType::Value(escape_char_substrait)) = &f.arguments[2].arg_type else { + return substrait_err!("Invalid arguments type for `{fn_name}` expr"); + }; + let escape_char_expr = + from_substrait_rex(ctx, escape_char_substrait, input_schema, extensions) + .await? 
+ .as_ref() + .clone(); + let Expr::Literal(ScalarValue::Utf8(escape_char)) = escape_char_expr else { + return substrait_err!( + "Expect Utf8 literal for escape char, but found {escape_char_expr:?}" + ); + }; + + Ok(Arc::new(Expr::Like(Like { + negated: false, + expr: Box::new(expr), + pattern: Box::new(pattern), + escape_char: escape_char.map(|c| c.chars().next().unwrap()), + case_insensitive, + }))) + } } diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index e3c6f94d43d58..ab0e8c860858e 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -19,7 +19,9 @@ use std::collections::HashMap; use std::ops::Deref; use std::sync::Arc; -use datafusion::logical_expr::{Like, WindowFrameUnits}; +use datafusion::logical_expr::{ + CrossJoin, Distinct, Like, Partitioning, WindowFrameUnits, +}; use datafusion::{ arrow::datatypes::{DataType, TimeUnit}, error::{DataFusionError, Result}, @@ -28,18 +30,21 @@ use datafusion::{ scalar::ScalarValue, }; -use datafusion::common::DFSchemaRef; use datafusion::common::{exec_err, internal_err, not_impl_err}; +use datafusion::common::{substrait_err, DFSchemaRef}; #[allow(unused_imports)] use datafusion::logical_expr::aggregate_function; use datafusion::logical_expr::expr::{ - Alias, BinaryExpr, Case, Cast, GroupingSet, InList, - ScalarFunction as DFScalarFunction, Sort, WindowFunction, + AggregateFunctionDefinition, Alias, BinaryExpr, Case, Cast, GroupingSet, InList, + InSubquery, ScalarFunctionDefinition, Sort, WindowFunction, }; use datafusion::logical_expr::{expr, Between, JoinConstraint, LogicalPlan, Operator}; use datafusion::prelude::Expr; use prost_types::Any as ProtoAny; +use substrait::proto::exchange_rel::{ExchangeKind, RoundRobin, ScatterFields}; +use substrait::proto::expression::subquery::InPredicate; use substrait::proto::expression::window_function::BoundsType; +use substrait::proto::{CrossRel, ExchangeRel}; use substrait::{ proto::{ aggregate_function::AggregationInvocation, @@ -54,7 +59,8 @@ use substrait::{ window_function::bound::Kind as BoundKind, window_function::Bound, FieldReference, IfThen, Literal, MaskExpression, ReferenceSegment, RexType, - ScalarFunction, SingularOrList, WindowFunction as SubstraitWindowFunction, + ScalarFunction, SingularOrList, Subquery, + WindowFunction as SubstraitWindowFunction, }, extensions::{ self, @@ -163,7 +169,7 @@ pub fn to_substrait_rel( let expressions = p .expr .iter() - .map(|e| to_substrait_rex(e, p.input.schema(), 0, extension_info)) + .map(|e| to_substrait_rex(ctx, e, p.input.schema(), 0, extension_info)) .collect::>>()?; Ok(Box::new(Rel { rel_type: Some(RelType::Project(Box::new(ProjectRel { @@ -177,6 +183,7 @@ pub fn to_substrait_rel( LogicalPlan::Filter(filter) => { let input = to_substrait_rel(filter.input.as_ref(), ctx, extension_info)?; let filter_expr = to_substrait_rex( + ctx, &filter.predicate, filter.input.schema(), 0, @@ -210,7 +217,9 @@ pub fn to_substrait_rel( let sort_fields = sort .expr .iter() - .map(|e| substrait_sort_field(e, sort.input.schema(), extension_info)) + .map(|e| { + substrait_sort_field(ctx, e, sort.input.schema(), extension_info) + }) .collect::>>()?; Ok(Box::new(Rel { rel_type: Some(RelType::Sort(Box::new(SortRel { @@ -224,6 +233,7 @@ pub fn to_substrait_rel( LogicalPlan::Aggregate(agg) => { let input = to_substrait_rel(agg.input.as_ref(), ctx, extension_info)?; let groupings = to_substrait_groupings( + ctx, &agg.group_expr, 
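The `BuiltinExprBuilder` introduced in the consumer above replaces the previous ad-hoc `Not`/`Like`/`ILike` handling: `try_from_name` decides whether a Substrait function name maps to one of the supported unary or LIKE expressions, and `build` performs the conversion with the `SessionContext` in hand. A much-reduced stand-alone sketch of that two-step shape, with invented `NameBuilder`/`SketchExpr` placeholders instead of the real types:

```rust
// Invented placeholders; only the try_from_name + build dispatch pattern is illustrated.
#[derive(Debug)]
enum SketchExpr {
    Not(Box<SketchExpr>),
    IsNull(Box<SketchExpr>),
    Column(String),
}

struct NameBuilder {
    expr_name: String,
}

impl NameBuilder {
    // Step 1: decide whether the Substrait function name is a supported builtin.
    fn try_from_name(name: &str) -> Option<Self> {
        match name {
            "not" | "is_null" => Some(Self {
                expr_name: name.to_string(),
            }),
            _ => None,
        }
    }

    // Step 2: convert the (already validated) name plus its argument into an expression.
    fn build(self, arg: SketchExpr) -> Result<SketchExpr, String> {
        match self.expr_name.as_str() {
            "not" => Ok(SketchExpr::Not(Box::new(arg))),
            "is_null" => Ok(SketchExpr::IsNull(Box::new(arg))),
            other => Err(format!("Unsupported builtin expression: {other}")),
        }
    }
}

fn main() {
    let builder = NameBuilder::try_from_name("is_null").expect("known builtin");
    println!("{:?}", builder.build(SketchExpr::Column("a".to_string())));
    assert!(NameBuilder::try_from_name("no_such_fn").is_none());
}
```

With the matching centralized in `try_from_name`, supporting another unary expression only needs a new name there plus an arm in `build_unary_expr` and its producer counterpart.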
agg.input.schema(), extension_info, @@ -231,7 +241,9 @@ pub fn to_substrait_rel( let measures = agg .aggr_expr .iter() - .map(|e| to_substrait_agg_measure(e, agg.input.schema(), extension_info)) + .map(|e| { + to_substrait_agg_measure(ctx, e, agg.input.schema(), extension_info) + }) .collect::>>()?; Ok(Box::new(Rel { @@ -244,11 +256,11 @@ pub fn to_substrait_rel( }))), })) } - LogicalPlan::Distinct(distinct) => { + LogicalPlan::Distinct(Distinct::All(plan)) => { // Use Substrait's AggregateRel with empty measures to represent `select distinct` - let input = to_substrait_rel(distinct.input.as_ref(), ctx, extension_info)?; + let input = to_substrait_rel(plan.as_ref(), ctx, extension_info)?; // Get grouping keys from the input relation's number of output fields - let grouping = (0..distinct.input.schema().fields().len()) + let grouping = (0..plan.schema().fields().len()) .map(substrait_field_ref) .collect::>>()?; @@ -279,6 +291,7 @@ pub fn to_substrait_rel( let in_join_schema = join.left.schema().join(join.right.schema())?; let join_filter = match &join.filter { Some(filter) => Some(to_substrait_rex( + ctx, filter, &Arc::new(in_join_schema), 0, @@ -295,6 +308,7 @@ pub fn to_substrait_rel( Operator::Eq }; let join_on = to_substrait_join_expr( + ctx, &join.on, eq_op, join.left.schema(), @@ -332,6 +346,23 @@ pub fn to_substrait_rel( }))), })) } + LogicalPlan::CrossJoin(cross_join) => { + let CrossJoin { + left, + right, + schema: _, + } = cross_join; + let left = to_substrait_rel(left.as_ref(), ctx, extension_info)?; + let right = to_substrait_rel(right.as_ref(), ctx, extension_info)?; + Ok(Box::new(Rel { + rel_type: Some(RelType::Cross(Box::new(CrossRel { + common: None, + left: Some(left), + right: Some(right), + advanced_extension: None, + }))), + })) + } LogicalPlan::SubqueryAlias(alias) => { // Do nothing if encounters SubqueryAlias // since there is no corresponding relation type in Substrait @@ -380,6 +411,7 @@ pub fn to_substrait_rel( let mut window_exprs = vec![]; for expr in &window.window_expr { window_exprs.push(to_substrait_rex( + ctx, expr, window.input.schema(), 0, @@ -392,6 +424,53 @@ pub fn to_substrait_rel( rel_type: Some(RelType::Project(project_rel)), })) } + LogicalPlan::Repartition(repartition) => { + let input = + to_substrait_rel(repartition.input.as_ref(), ctx, extension_info)?; + let partition_count = match repartition.partitioning_scheme { + Partitioning::RoundRobinBatch(num) => num, + Partitioning::Hash(_, num) => num, + Partitioning::DistributeBy(_) => { + return not_impl_err!( + "Physical plan does not support DistributeBy partitioning" + ) + } + }; + // ref: https://substrait.io/relations/physical_relations/#exchange-types + let exchange_kind = match &repartition.partitioning_scheme { + Partitioning::RoundRobinBatch(_) => { + ExchangeKind::RoundRobin(RoundRobin::default()) + } + Partitioning::Hash(exprs, _) => { + let fields = exprs + .iter() + .map(|e| { + try_to_substrait_field_reference( + e, + repartition.input.schema(), + ) + }) + .collect::>>()?; + ExchangeKind::ScatterByFields(ScatterFields { fields }) + } + Partitioning::DistributeBy(_) => { + return not_impl_err!( + "Physical plan does not support DistributeBy partitioning" + ) + } + }; + let exchange_rel = ExchangeRel { + common: None, + input: Some(input), + exchange_kind: Some(exchange_kind), + advanced_extension: None, + partition_count: partition_count as i32, + targets: vec![], + }; + Ok(Box::new(Rel { + rel_type: Some(RelType::Exchange(Box::new(exchange_rel))), + })) + } 
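Only `RoundRobinBatch` and `Hash` repartitioning have Substrait counterparts (a `RoundRobin` exchange and a `ScatterByFields` exchange, per the physical relations reference linked above); `DistributeBy` is rejected in both lookups. A condensed, stand-alone sketch of that mapping using invented placeholder enums rather than the DataFusion and substrait-rs types:

```rust
// Invented placeholder enums; only the Partitioning -> exchange-kind mapping is shown.
#[derive(Debug)]
enum SketchPartitioning {
    RoundRobinBatch(usize),
    Hash(Vec<String>, usize),
    DistributeBy(Vec<String>),
}

#[derive(Debug)]
enum SketchExchangeKind {
    RoundRobin,
    ScatterByFields(Vec<String>),
}

fn to_exchange(p: &SketchPartitioning) -> Result<(SketchExchangeKind, usize), String> {
    match p {
        SketchPartitioning::RoundRobinBatch(n) => Ok((SketchExchangeKind::RoundRobin, *n)),
        SketchPartitioning::Hash(fields, n) => {
            Ok((SketchExchangeKind::ScatterByFields(fields.clone()), *n))
        }
        SketchPartitioning::DistributeBy(_) => {
            Err("DistributeBy partitioning is not supported".to_string())
        }
    }
}

fn main() {
    let hash = SketchPartitioning::Hash(vec!["data.a".to_string()], 8);
    println!("{:?}", to_exchange(&hash));
    println!("{:?}", to_exchange(&SketchPartitioning::RoundRobinBatch(8)));
    println!("{:?}", to_exchange(&SketchPartitioning::DistributeBy(vec!["a".into()])));
}
```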
LogicalPlan::Extension(extension_plan) => { let extension_bytes = ctx .state() @@ -432,6 +511,7 @@ pub fn to_substrait_rel( } fn to_substrait_join_expr( + ctx: &SessionContext, join_conditions: &Vec<(Expr, Expr)>, eq_op: Operator, left_schema: &DFSchemaRef, @@ -445,9 +525,10 @@ fn to_substrait_join_expr( let mut exprs: Vec = vec![]; for (left, right) in join_conditions { // Parse left - let l = to_substrait_rex(left, left_schema, 0, extension_info)?; + let l = to_substrait_rex(ctx, left, left_schema, 0, extension_info)?; // Parse right let r = to_substrait_rex( + ctx, right, right_schema, left_schema.fields().len(), // offset to return the correct index @@ -508,6 +589,7 @@ pub fn operator_to_name(op: Operator) -> &'static str { } pub fn parse_flat_grouping_exprs( + ctx: &SessionContext, exprs: &[Expr], schema: &DFSchemaRef, extension_info: &mut ( @@ -517,7 +599,7 @@ pub fn parse_flat_grouping_exprs( ) -> Result { let grouping_expressions = exprs .iter() - .map(|e| to_substrait_rex(e, schema, 0, extension_info)) + .map(|e| to_substrait_rex(ctx, e, schema, 0, extension_info)) .collect::>>()?; Ok(Grouping { grouping_expressions, @@ -525,7 +607,8 @@ pub fn parse_flat_grouping_exprs( } pub fn to_substrait_groupings( - exprs: &Vec, + ctx: &SessionContext, + exprs: &[Expr], schema: &DFSchemaRef, extension_info: &mut ( Vec, @@ -540,7 +623,9 @@ pub fn to_substrait_groupings( )), GroupingSet::GroupingSets(sets) => Ok(sets .iter() - .map(|set| parse_flat_grouping_exprs(set, schema, extension_info)) + .map(|set| { + parse_flat_grouping_exprs(ctx, set, schema, extension_info) + }) .collect::>>()?), GroupingSet::Rollup(set) => { let mut sets: Vec> = vec![vec![]]; @@ -550,17 +635,21 @@ pub fn to_substrait_groupings( Ok(sets .iter() .rev() - .map(|set| parse_flat_grouping_exprs(set, schema, extension_info)) + .map(|set| { + parse_flat_grouping_exprs(ctx, set, schema, extension_info) + }) .collect::>>()?) } }, _ => Ok(vec![parse_flat_grouping_exprs( + ctx, exprs, schema, extension_info, )?]), }, _ => Ok(vec![parse_flat_grouping_exprs( + ctx, exprs, schema, extension_info, @@ -570,6 +659,7 @@ pub fn to_substrait_groupings( #[allow(deprecated)] pub fn to_substrait_agg_measure( + ctx: &SessionContext, expr: &Expr, schema: &DFSchemaRef, extension_info: &mut ( @@ -578,40 +668,75 @@ pub fn to_substrait_agg_measure( ), ) -> Result { match expr { - Expr::AggregateFunction(expr::AggregateFunction { fun, args, distinct, filter, order_by }) => { - let sorts = if let Some(order_by) = order_by { - order_by.iter().map(|expr| to_substrait_sort_field(expr, schema, extension_info)).collect::>>()? 
- } else { - vec![] - }; - let mut arguments: Vec = vec![]; - for arg in args { - arguments.push(FunctionArgument { arg_type: Some(ArgType::Value(to_substrait_rex(arg, schema, 0, extension_info)?)) }); - } - let function_name = fun.to_string().to_lowercase(); - let function_anchor = _register_function(function_name, extension_info); - Ok(Measure { - measure: Some(AggregateFunction { - function_reference: function_anchor, - arguments, - sorts, - output_type: None, - invocation: match distinct { - true => AggregationInvocation::Distinct as i32, - false => AggregationInvocation::All as i32, - }, - phase: AggregationPhase::Unspecified as i32, - args: vec![], - options: vec![], - }), - filter: match filter { - Some(f) => Some(to_substrait_rex(f, schema, 0, extension_info)?), - None => None + Expr::AggregateFunction(expr::AggregateFunction { func_def, args, distinct, filter, order_by }) => { + match func_def { + AggregateFunctionDefinition::BuiltIn (fun) => { + let sorts = if let Some(order_by) = order_by { + order_by.iter().map(|expr| to_substrait_sort_field(ctx, expr, schema, extension_info)).collect::>>()? + } else { + vec![] + }; + let mut arguments: Vec = vec![]; + for arg in args { + arguments.push(FunctionArgument { arg_type: Some(ArgType::Value(to_substrait_rex(ctx, arg, schema, 0, extension_info)?)) }); + } + let function_anchor = _register_function(fun.to_string(), extension_info); + Ok(Measure { + measure: Some(AggregateFunction { + function_reference: function_anchor, + arguments, + sorts, + output_type: None, + invocation: match distinct { + true => AggregationInvocation::Distinct as i32, + false => AggregationInvocation::All as i32, + }, + phase: AggregationPhase::Unspecified as i32, + args: vec![], + options: vec![], + }), + filter: match filter { + Some(f) => Some(to_substrait_rex(ctx, f, schema, 0, extension_info)?), + None => None + } + }) } - }) + AggregateFunctionDefinition::UDF(fun) => { + let sorts = if let Some(order_by) = order_by { + order_by.iter().map(|expr| to_substrait_sort_field(ctx, expr, schema, extension_info)).collect::>>()? + } else { + vec![] + }; + let mut arguments: Vec = vec![]; + for arg in args { + arguments.push(FunctionArgument { arg_type: Some(ArgType::Value(to_substrait_rex(ctx, arg, schema, 0, extension_info)?)) }); + } + let function_anchor = _register_function(fun.name().to_string(), extension_info); + Ok(Measure { + measure: Some(AggregateFunction { + function_reference: function_anchor, + arguments, + sorts, + output_type: None, + invocation: AggregationInvocation::All as i32, + phase: AggregationPhase::Unspecified as i32, + args: vec![], + options: vec![], + }), + filter: match filter { + Some(f) => Some(to_substrait_rex(ctx, f, schema, 0, extension_info)?), + None => None + } + }) + } + AggregateFunctionDefinition::Name(name) => { + internal_err!("AggregateFunctionDefinition::Name({:?}) should be resolved during `AnalyzerRule`", name) + } + } + } Expr::Alias(Alias{expr,..})=> { - to_substrait_agg_measure(expr, schema, extension_info) + to_substrait_agg_measure(ctx, expr, schema, extension_info) } _ => internal_err!( "Expression must be compatible with aggregation. Unsupported expression: {:?}. 
ExpressionType: {:?}", @@ -623,6 +748,7 @@ pub fn to_substrait_agg_measure( /// Converts sort expression to corresponding substrait `SortField` fn to_substrait_sort_field( + ctx: &SessionContext, expr: &Expr, schema: &DFSchemaRef, extension_info: &mut ( @@ -640,6 +766,7 @@ fn to_substrait_sort_field( }; Ok(SortField { expr: Some(to_substrait_rex( + ctx, sort.expr.deref(), schema, 0, @@ -703,8 +830,8 @@ pub fn make_binary_op_scalar_func( HashMap, ), ) -> Expression { - let function_name = operator_to_name(op).to_string().to_lowercase(); - let function_anchor = _register_function(function_name, extension_info); + let function_anchor = + _register_function(operator_to_name(op).to_string(), extension_info); Expression { rex_type: Some(RexType::ScalarFunction(ScalarFunction { function_reference: function_anchor, @@ -748,6 +875,7 @@ pub fn make_binary_op_scalar_func( /// * `extension_info` - Substrait extension info. Contains registered function information #[allow(deprecated)] pub fn to_substrait_rex( + ctx: &SessionContext, expr: &Expr, schema: &DFSchemaRef, col_ref_offset: usize, @@ -764,10 +892,10 @@ pub fn to_substrait_rex( }) => { let substrait_list = list .iter() - .map(|x| to_substrait_rex(x, schema, col_ref_offset, extension_info)) + .map(|x| to_substrait_rex(ctx, x, schema, col_ref_offset, extension_info)) .collect::>>()?; let substrait_expr = - to_substrait_rex(expr, schema, col_ref_offset, extension_info)?; + to_substrait_rex(ctx, expr, schema, col_ref_offset, extension_info)?; let substrait_or_list = Expression { rex_type: Some(RexType::SingularOrList(Box::new(SingularOrList { @@ -795,11 +923,12 @@ pub fn to_substrait_rex( Ok(substrait_or_list) } } - Expr::ScalarFunction(DFScalarFunction { fun, args }) => { + Expr::ScalarFunction(fun) => { let mut arguments: Vec = vec![]; - for arg in args { + for arg in &fun.args { arguments.push(FunctionArgument { arg_type: Some(ArgType::Value(to_substrait_rex( + ctx, arg, schema, col_ref_offset, @@ -807,8 +936,14 @@ pub fn to_substrait_rex( )?)), }); } - let function_name = fun.to_string().to_lowercase(); - let function_anchor = _register_function(function_name, extension_info); + + // function should be resolved during `AnalyzerRule` + if let ScalarFunctionDefinition::Name(_) = fun.func_def { + return internal_err!("Function `Expr` with name should be resolved."); + } + + let function_anchor = + _register_function(fun.name().to_string(), extension_info); Ok(Expression { rex_type: Some(RexType::ScalarFunction(ScalarFunction { function_reference: function_anchor, @@ -828,11 +963,11 @@ pub fn to_substrait_rex( if *negated { // `expr NOT BETWEEN low AND high` can be translated into (expr < low OR high < expr) let substrait_expr = - to_substrait_rex(expr, schema, col_ref_offset, extension_info)?; + to_substrait_rex(ctx, expr, schema, col_ref_offset, extension_info)?; let substrait_low = - to_substrait_rex(low, schema, col_ref_offset, extension_info)?; + to_substrait_rex(ctx, low, schema, col_ref_offset, extension_info)?; let substrait_high = - to_substrait_rex(high, schema, col_ref_offset, extension_info)?; + to_substrait_rex(ctx, high, schema, col_ref_offset, extension_info)?; let l_expr = make_binary_op_scalar_func( &substrait_expr, @@ -856,11 +991,11 @@ pub fn to_substrait_rex( } else { // `expr BETWEEN low AND high` can be translated into (low <= expr AND expr <= high) let substrait_expr = - to_substrait_rex(expr, schema, col_ref_offset, extension_info)?; + to_substrait_rex(ctx, expr, schema, col_ref_offset, extension_info)?; let substrait_low = 
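Because Substrait has no BETWEEN construct, the producer lowers it to plain comparisons as the comments above note: `expr BETWEEN low AND high` becomes `low <= expr AND expr <= high`, and the negated form becomes `expr < low OR high < expr`. A small self-contained check of that rewrite over plain integers, with no DataFusion types involved:

```rust
fn between(x: i32, low: i32, high: i32) -> bool {
    // expr BETWEEN low AND high  =>  low <= expr AND expr <= high
    low <= x && x <= high
}

fn not_between(x: i32, low: i32, high: i32) -> bool {
    // expr NOT BETWEEN low AND high  =>  expr < low OR high < expr
    x < low || high < x
}

fn main() {
    for x in -2..=12 {
        assert_eq!(between(x, 0, 10), (0..=10).contains(&x));
        assert_eq!(not_between(x, 0, 10), !between(x, 0, 10));
    }
    println!("BETWEEN decomposition holds on the sampled range");
}
```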
- to_substrait_rex(low, schema, col_ref_offset, extension_info)?; + to_substrait_rex(ctx, low, schema, col_ref_offset, extension_info)?; let substrait_high = - to_substrait_rex(high, schema, col_ref_offset, extension_info)?; + to_substrait_rex(ctx, high, schema, col_ref_offset, extension_info)?; let l_expr = make_binary_op_scalar_func( &substrait_low, @@ -888,8 +1023,8 @@ pub fn to_substrait_rex( substrait_field_ref(index + col_ref_offset) } Expr::BinaryExpr(BinaryExpr { left, op, right }) => { - let l = to_substrait_rex(left, schema, col_ref_offset, extension_info)?; - let r = to_substrait_rex(right, schema, col_ref_offset, extension_info)?; + let l = to_substrait_rex(ctx, left, schema, col_ref_offset, extension_info)?; + let r = to_substrait_rex(ctx, right, schema, col_ref_offset, extension_info)?; Ok(make_binary_op_scalar_func(&l, &r, *op, extension_info)) } @@ -904,6 +1039,7 @@ pub fn to_substrait_rex( // Base expression exists ifs.push(IfClause { r#if: Some(to_substrait_rex( + ctx, e, schema, col_ref_offset, @@ -916,12 +1052,14 @@ pub fn to_substrait_rex( for (r#if, then) in when_then_expr { ifs.push(IfClause { r#if: Some(to_substrait_rex( + ctx, r#if, schema, col_ref_offset, extension_info, )?), then: Some(to_substrait_rex( + ctx, then, schema, col_ref_offset, @@ -933,6 +1071,7 @@ pub fn to_substrait_rex( // Parse outer `else` let r#else: Option> = match else_expr { Some(e) => Some(Box::new(to_substrait_rex( + ctx, e, schema, col_ref_offset, @@ -951,6 +1090,7 @@ pub fn to_substrait_rex( substrait::proto::expression::Cast { r#type: Some(to_substrait_type(data_type)?), input: Some(Box::new(to_substrait_rex( + ctx, expr, schema, col_ref_offset, @@ -963,7 +1103,7 @@ pub fn to_substrait_rex( } Expr::Literal(value) => to_substrait_literal(value), Expr::Alias(Alias { expr, .. 
}) => { - to_substrait_rex(expr, schema, col_ref_offset, extension_info) + to_substrait_rex(ctx, expr, schema, col_ref_offset, extension_info) } Expr::WindowFunction(WindowFunction { fun, @@ -973,13 +1113,13 @@ pub fn to_substrait_rex( window_frame, }) => { // function reference - let function_name = fun.to_string().to_lowercase(); - let function_anchor = _register_function(function_name, extension_info); + let function_anchor = _register_function(fun.to_string(), extension_info); // arguments let mut arguments: Vec = vec![]; for arg in args { arguments.push(FunctionArgument { arg_type: Some(ArgType::Value(to_substrait_rex( + ctx, arg, schema, col_ref_offset, @@ -990,12 +1130,12 @@ pub fn to_substrait_rex( // partition by expressions let partition_by = partition_by .iter() - .map(|e| to_substrait_rex(e, schema, col_ref_offset, extension_info)) + .map(|e| to_substrait_rex(ctx, e, schema, col_ref_offset, extension_info)) .collect::>>()?; // order by expressions let order_by = order_by .iter() - .map(|e| substrait_sort_field(e, schema, extension_info)) + .map(|e| substrait_sort_field(ctx, e, schema, extension_info)) .collect::>>()?; // window frame let bounds = to_substrait_bounds(window_frame)?; @@ -1016,6 +1156,7 @@ pub fn to_substrait_rex( escape_char, case_insensitive, }) => make_substrait_like_expr( + ctx, *case_insensitive, *negated, expr, @@ -1025,7 +1166,131 @@ pub fn to_substrait_rex( col_ref_offset, extension_info, ), - _ => not_impl_err!("Unsupported expression: {expr:?}"), + Expr::InSubquery(InSubquery { + expr, + subquery, + negated, + }) => { + let substrait_expr = + to_substrait_rex(ctx, expr, schema, col_ref_offset, extension_info)?; + + let subquery_plan = + to_substrait_rel(subquery.subquery.as_ref(), ctx, extension_info)?; + + let substrait_subquery = Expression { + rex_type: Some(RexType::Subquery(Box::new(Subquery { + subquery_type: Some( + substrait::proto::expression::subquery::SubqueryType::InPredicate( + Box::new(InPredicate { + needles: (vec![substrait_expr]), + haystack: Some(subquery_plan), + }), + ), + ), + }))), + }; + if *negated { + let function_anchor = + _register_function("not".to_string(), extension_info); + + Ok(Expression { + rex_type: Some(RexType::ScalarFunction(ScalarFunction { + function_reference: function_anchor, + arguments: vec![FunctionArgument { + arg_type: Some(ArgType::Value(substrait_subquery)), + }], + output_type: None, + args: vec![], + options: vec![], + })), + }) + } else { + Ok(substrait_subquery) + } + } + Expr::Not(arg) => to_substrait_unary_scalar_fn( + ctx, + "not", + arg, + schema, + col_ref_offset, + extension_info, + ), + Expr::IsNull(arg) => to_substrait_unary_scalar_fn( + ctx, + "is_null", + arg, + schema, + col_ref_offset, + extension_info, + ), + Expr::IsNotNull(arg) => to_substrait_unary_scalar_fn( + ctx, + "is_not_null", + arg, + schema, + col_ref_offset, + extension_info, + ), + Expr::IsTrue(arg) => to_substrait_unary_scalar_fn( + ctx, + "is_true", + arg, + schema, + col_ref_offset, + extension_info, + ), + Expr::IsFalse(arg) => to_substrait_unary_scalar_fn( + ctx, + "is_false", + arg, + schema, + col_ref_offset, + extension_info, + ), + Expr::IsUnknown(arg) => to_substrait_unary_scalar_fn( + ctx, + "is_unknown", + arg, + schema, + col_ref_offset, + extension_info, + ), + Expr::IsNotTrue(arg) => to_substrait_unary_scalar_fn( + ctx, + "is_not_true", + arg, + schema, + col_ref_offset, + extension_info, + ), + Expr::IsNotFalse(arg) => to_substrait_unary_scalar_fn( + ctx, + "is_not_false", + arg, + schema, + col_ref_offset, + 
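Substrait's `InPredicate` carries no negation flag, so a negated `IN (subquery)` is produced as the in-predicate wrapped in a scalar `not` call registered through the usual extension mechanism, matching the consumer side that always builds `InSubquery` with `negated: false`. A stand-alone sketch of that wrapping, using an invented `ProducedExpr` placeholder instead of the substrait-rs expression types:

```rust
// Invented placeholder type; only the "wrap in not() when negated" shape is shown.
#[derive(Debug)]
enum ProducedExpr {
    InPredicate { needle: String, haystack_sql: String },
    ScalarFn { name: String, args: Vec<ProducedExpr> },
}

// A negated IN-subquery becomes `not(<in-predicate>)`; the non-negated form is
// emitted as the in-predicate itself.
fn in_subquery(needle: &str, haystack_sql: &str, negated: bool) -> ProducedExpr {
    let in_pred = ProducedExpr::InPredicate {
        needle: needle.to_string(),
        haystack_sql: haystack_sql.to_string(),
    };
    if negated {
        ProducedExpr::ScalarFn {
            name: "not".to_string(),
            args: vec![in_pred],
        }
    } else {
        in_pred
    }
}

fn main() {
    println!("{:?}", in_subquery("data.a", "SELECT a FROM data2", true));
    println!("{:?}", in_subquery("data.a", "SELECT a FROM data2", false));
}
```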
extension_info, + ), + Expr::IsNotUnknown(arg) => to_substrait_unary_scalar_fn( + ctx, + "is_not_unknown", + arg, + schema, + col_ref_offset, + extension_info, + ), + Expr::Negative(arg) => to_substrait_unary_scalar_fn( + ctx, + "negative", + arg, + schema, + col_ref_offset, + extension_info, + ), + _ => { + not_impl_err!("Unsupported expression: {expr:?}") + } } } @@ -1241,6 +1506,7 @@ fn make_substrait_window_function( #[allow(deprecated)] #[allow(clippy::too_many_arguments)] fn make_substrait_like_expr( + ctx: &SessionContext, ignore_case: bool, negated: bool, expr: &Expr, @@ -1258,8 +1524,8 @@ fn make_substrait_like_expr( } else { _register_function("like".to_string(), extension_info) }; - let expr = to_substrait_rex(expr, schema, col_ref_offset, extension_info)?; - let pattern = to_substrait_rex(pattern, schema, col_ref_offset, extension_info)?; + let expr = to_substrait_rex(ctx, expr, schema, col_ref_offset, extension_info)?; + let pattern = to_substrait_rex(ctx, pattern, schema, col_ref_offset, extension_info)?; let escape_char = to_substrait_literal(&ScalarValue::Utf8(escape_char.map(|c| c.to_string())))?; let arguments = vec![ @@ -1487,6 +1753,35 @@ fn to_substrait_literal(value: &ScalarValue) -> Result { }) } +/// Util to generate substrait [RexType::ScalarFunction] with one argument +fn to_substrait_unary_scalar_fn( + ctx: &SessionContext, + fn_name: &str, + arg: &Expr, + schema: &DFSchemaRef, + col_ref_offset: usize, + extension_info: &mut ( + Vec, + HashMap, + ), +) -> Result { + let function_anchor = _register_function(fn_name.to_string(), extension_info); + let substrait_expr = + to_substrait_rex(ctx, arg, schema, col_ref_offset, extension_info)?; + + Ok(Expression { + rex_type: Some(RexType::ScalarFunction(ScalarFunction { + function_reference: function_anchor, + arguments: vec![FunctionArgument { + arg_type: Some(ArgType::Value(substrait_expr)), + }], + output_type: None, + options: vec![], + ..Default::default() + })), + }) +} + fn try_to_substrait_null(v: &ScalarValue) -> Result { let default_nullability = r#type::Nullability::Nullable as i32; match v { @@ -1647,7 +1942,33 @@ fn try_to_substrait_null(v: &ScalarValue) -> Result { } } +/// Try to convert an [Expr] to a [FieldReference]. +/// Returns `Err` if the [Expr] is not a [Expr::Column]. 
+fn try_to_substrait_field_reference( + expr: &Expr, + schema: &DFSchemaRef, +) -> Result { + match expr { + Expr::Column(col) => { + let index = schema.index_of_column(col)?; + Ok(FieldReference { + reference_type: Some(ReferenceType::DirectReference(ReferenceSegment { + reference_type: Some(reference_segment::ReferenceType::StructField( + Box::new(reference_segment::StructField { + field: index as i32, + child: None, + }), + )), + })), + root_type: None, + }) + } + _ => substrait_err!("Expect a `Column` expr, but found {expr:?}"), + } +} + fn substrait_sort_field( + ctx: &SessionContext, expr: &Expr, schema: &DFSchemaRef, extension_info: &mut ( @@ -1661,7 +1982,7 @@ fn substrait_sort_field( asc, nulls_first, }) => { - let e = to_substrait_rex(expr, schema, 0, extension_info)?; + let e = to_substrait_rex(ctx, expr, schema, 0, extension_info)?; let d = match (asc, nulls_first) { (true, true) => SortDirection::AscNullsFirst, (true, false) => SortDirection::AscNullsLast, diff --git a/datafusion/substrait/src/physical_plan/consumer.rs b/datafusion/substrait/src/physical_plan/consumer.rs index 1dab1f9d5e398..3098dc386e6a3 100644 --- a/datafusion/substrait/src/physical_plan/consumer.rs +++ b/datafusion/substrait/src/physical_plan/consumer.rs @@ -89,6 +89,7 @@ pub async fn from_substrait_rel( location: path.into(), size, e_tag: None, + version: None, }, partition_values: vec![], range: None, @@ -111,7 +112,6 @@ pub async fn from_substrait_rel( limit: None, table_partition_cols: vec![], output_ordering: vec![], - infinite_source: false, }; if let Some(MaskExpression { select, .. }) = &read.projection { diff --git a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs index ca2b4d48c4602..d7327caee43d3 100644 --- a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs @@ -15,6 +15,9 @@ // specific language governing permissions and limitations // under the License. 
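Each of the new unary variants (`Expr::Not`, `Expr::IsNull`, `Expr::Negative`, and the `is_*` family) is emitted through `to_substrait_unary_scalar_fn` as a one-argument Substrait scalar function, and `BuiltinExprBuilder` turns it back on the way in; the roundtrip tests below cover that path. A hypothetical spelled-out version of one such test, following the same steps the new repartition tests use and assuming the test module's existing `create_context()` helper for the `data` table:

```rust
// Hypothetical expanded form of `roundtrip("SELECT * FROM data WHERE a IS NOT NULL")`,
// written out like the repartition tests; create_context, to_substrait_plan and
// from_substrait_plan are assumed to be in scope in the test module.
#[tokio::test]
async fn roundtrip_is_not_null_spelled_out() -> Result<()> {
    let ctx = create_context().await?;
    let plan = ctx
        .sql("SELECT * FROM data WHERE a IS NOT NULL")
        .await?
        .into_optimized_plan()?;

    let proto = to_substrait_plan(&plan, &ctx)?;
    let plan2 = from_substrait_plan(&ctx, &proto).await?;
    let plan2 = ctx.state().optimize(&plan2)?;

    assert_eq!(format!("{plan:?}"), format!("{plan2:?}"));
    Ok(())
}
```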
+use datafusion::arrow::array::ArrayRef; +use datafusion::physical_plan::Accumulator; +use datafusion::scalar::ScalarValue; use datafusion_substrait::logical_plan::{ consumer::from_substrait_plan, producer::to_substrait_plan, }; @@ -28,7 +31,9 @@ use datafusion::error::{DataFusionError, Result}; use datafusion::execution::context::SessionState; use datafusion::execution::registry::SerializerRegistry; use datafusion::execution::runtime_env::RuntimeEnv; -use datafusion::logical_expr::{Extension, LogicalPlan, UserDefinedLogicalNode}; +use datafusion::logical_expr::{ + Extension, LogicalPlan, Repartition, UserDefinedLogicalNode, Volatility, +}; use datafusion::optimizer::simplify_expressions::expr_simplifier::THRESHOLD_INLINE_INLIST; use datafusion::prelude::*; @@ -314,6 +319,16 @@ async fn simple_scalar_function_substr() -> Result<()> { roundtrip("SELECT * FROM data WHERE a = SUBSTR('datafusion', 0, 3)").await } +#[tokio::test] +async fn simple_scalar_function_is_null() -> Result<()> { + roundtrip("SELECT * FROM data WHERE a IS NULL").await +} + +#[tokio::test] +async fn simple_scalar_function_is_not_null() -> Result<()> { + roundtrip("SELECT * FROM data WHERE a IS NOT NULL").await +} + #[tokio::test] async fn case_without_base_expression() -> Result<()> { roundtrip("SELECT (CASE WHEN a >= 0 THEN 'positive' ELSE 'negative' END) FROM data") @@ -379,6 +394,29 @@ async fn roundtrip_inlist_4() -> Result<()> { roundtrip("SELECT * FROM data WHERE f NOT IN ('a', 'b', 'c', 'd')").await } +#[tokio::test] +async fn roundtrip_inlist_5() -> Result<()> { + // on roundtrip there is an additional projection during TableScan which includes all column of the table, + // using assert_expected_plan here as a workaround + assert_expected_plan( + "SELECT a, f FROM data WHERE (f IN ('a', 'b', 'c') OR a in (SELECT data2.a FROM data2 WHERE f IN ('b', 'c', 'd')))", + "Filter: data.f = Utf8(\"a\") OR data.f = Utf8(\"b\") OR data.f = Utf8(\"c\") OR data.a IN ()\ + \n Subquery:\ + \n Projection: data2.a\ + \n Filter: data2.f IN ([Utf8(\"b\"), Utf8(\"c\"), Utf8(\"d\")])\ + \n TableScan: data2 projection=[a, b, c, d, e, f]\ + \n TableScan: data projection=[a, f], partial_filters=[data.f = Utf8(\"a\") OR data.f = Utf8(\"b\") OR data.f = Utf8(\"c\") OR data.a IN ()]\ + \n Subquery:\ + \n Projection: data2.a\ + \n Filter: data2.f IN ([Utf8(\"b\"), Utf8(\"c\"), Utf8(\"d\")])\ + \n TableScan: data2 projection=[a, b, c, d, e, f]").await +} + +#[tokio::test] +async fn roundtrip_cross_join() -> Result<()> { + roundtrip("SELECT * FROM data CROSS JOIN data2").await +} + #[tokio::test] async fn roundtrip_inner_join() -> Result<()> { roundtrip("SELECT data.a FROM data JOIN data2 ON data.a = data2.a").await @@ -463,6 +501,46 @@ async fn roundtrip_ilike() -> Result<()> { roundtrip("SELECT f FROM data WHERE f ILIKE 'a%b'").await } +#[tokio::test] +async fn roundtrip_not() -> Result<()> { + roundtrip("SELECT * FROM data WHERE NOT d").await +} + +#[tokio::test] +async fn roundtrip_negative() -> Result<()> { + roundtrip("SELECT * FROM data WHERE -a = 1").await +} + +#[tokio::test] +async fn roundtrip_is_true() -> Result<()> { + roundtrip("SELECT * FROM data WHERE d IS TRUE").await +} + +#[tokio::test] +async fn roundtrip_is_false() -> Result<()> { + roundtrip("SELECT * FROM data WHERE d IS FALSE").await +} + +#[tokio::test] +async fn roundtrip_is_not_true() -> Result<()> { + roundtrip("SELECT * FROM data WHERE d IS NOT TRUE").await +} + +#[tokio::test] +async fn roundtrip_is_not_false() -> Result<()> { + roundtrip("SELECT * FROM data WHERE 
d IS NOT FALSE").await +} + +#[tokio::test] +async fn roundtrip_is_unknown() -> Result<()> { + roundtrip("SELECT * FROM data WHERE d IS UNKNOWN").await +} + +#[tokio::test] +async fn roundtrip_is_not_unknown() -> Result<()> { + roundtrip("SELECT * FROM data WHERE d IS NOT UNKNOWN").await +} + #[tokio::test] async fn roundtrip_union() -> Result<()> { roundtrip("SELECT a, e FROM data UNION SELECT a, e FROM data").await @@ -486,10 +564,11 @@ async fn simple_intersect() -> Result<()> { assert_expected_plan( "SELECT COUNT(*) FROM (SELECT data.a FROM data INTERSECT SELECT data2.a FROM data2);", "Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1))]]\ - \n LeftSemi Join: data.a = data2.a\ - \n Aggregate: groupBy=[[data.a]], aggr=[[]]\ - \n TableScan: data projection=[a]\ - \n TableScan: data2 projection=[a]", + \n Projection: \ + \n LeftSemi Join: data.a = data2.a\ + \n Aggregate: groupBy=[[data.a]], aggr=[[]]\ + \n TableScan: data projection=[a]\ + \n TableScan: data2 projection=[a]", ) .await } @@ -499,10 +578,11 @@ async fn simple_intersect_table_reuse() -> Result<()> { assert_expected_plan( "SELECT COUNT(*) FROM (SELECT data.a FROM data INTERSECT SELECT data.a FROM data);", "Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1))]]\ - \n LeftSemi Join: data.a = data.a\ - \n Aggregate: groupBy=[[data.a]], aggr=[[]]\ - \n TableScan: data projection=[a]\ - \n TableScan: data projection=[a]", + \n Projection: \ + \n LeftSemi Join: data.a = data.a\ + \n Aggregate: groupBy=[[data.a]], aggr=[[]]\ + \n TableScan: data projection=[a]\ + \n TableScan: data projection=[a]", ) .await } @@ -626,6 +706,90 @@ async fn extension_logical_plan() -> Result<()> { Ok(()) } +#[tokio::test] +async fn roundtrip_aggregate_udf() -> Result<()> { + #[derive(Debug)] + struct Dummy {} + + impl Accumulator for Dummy { + fn state(&self) -> datafusion::error::Result> { + Ok(vec![]) + } + + fn update_batch( + &mut self, + _values: &[ArrayRef], + ) -> datafusion::error::Result<()> { + Ok(()) + } + + fn merge_batch(&mut self, _states: &[ArrayRef]) -> datafusion::error::Result<()> { + Ok(()) + } + + fn evaluate(&self) -> datafusion::error::Result { + Ok(ScalarValue::Float64(None)) + } + + fn size(&self) -> usize { + std::mem::size_of_val(self) + } + } + + let dummy_agg = create_udaf( + // the name; used to represent it in plan descriptions and in the registry, to use in SQL. + "dummy_agg", + // the input type; DataFusion guarantees that the first entry of `values` in `update` has this type. + vec![DataType::Int64], + // the return type; DataFusion expects this to match the type returned by `evaluate`. + Arc::new(DataType::Int64), + Volatility::Immutable, + // This is the accumulator factory; DataFusion uses it to create new accumulators. + Arc::new(|_| Ok(Box::new(Dummy {}))), + // This is the description of the state. `state()` must match the types here. 
+ Arc::new(vec![DataType::Float64, DataType::UInt32]), + ); + + let ctx = create_context().await?; + ctx.register_udaf(dummy_agg); + + roundtrip_with_ctx("select dummy_agg(a) from data", ctx).await +} + +#[tokio::test] +async fn roundtrip_repartition_roundrobin() -> Result<()> { + let ctx = create_context().await?; + let scan_plan = ctx.sql("SELECT * FROM data").await?.into_optimized_plan()?; + let plan = LogicalPlan::Repartition(Repartition { + input: Arc::new(scan_plan), + partitioning_scheme: Partitioning::RoundRobinBatch(8), + }); + + let proto = to_substrait_plan(&plan, &ctx)?; + let plan2 = from_substrait_plan(&ctx, &proto).await?; + let plan2 = ctx.state().optimize(&plan2)?; + + assert_eq!(format!("{plan:?}"), format!("{plan2:?}")); + Ok(()) +} + +#[tokio::test] +async fn roundtrip_repartition_hash() -> Result<()> { + let ctx = create_context().await?; + let scan_plan = ctx.sql("SELECT * FROM data").await?.into_optimized_plan()?; + let plan = LogicalPlan::Repartition(Repartition { + input: Arc::new(scan_plan), + partitioning_scheme: Partitioning::Hash(vec![col("data.a")], 8), + }); + + let proto = to_substrait_plan(&plan, &ctx)?; + let plan2 = from_substrait_plan(&ctx, &proto).await?; + let plan2 = ctx.state().optimize(&plan2)?; + + assert_eq!(format!("{plan:?}"), format!("{plan2:?}")); + Ok(()) +} + fn check_post_join_filters(rel: &Rel) -> Result<()> { // search for target_rel and field value in proto match &rel.rel_type { @@ -762,8 +926,7 @@ async fn test_alias(sql_with_alias: &str, sql_no_alias: &str) -> Result<()> { Ok(()) } -async fn roundtrip(sql: &str) -> Result<()> { - let ctx = create_context().await?; +async fn roundtrip_with_ctx(sql: &str, ctx: SessionContext) -> Result<()> { let df = ctx.sql(sql).await?; let plan = df.into_optimized_plan()?; let proto = to_substrait_plan(&plan, &ctx)?; @@ -779,6 +942,10 @@ async fn roundtrip(sql: &str) -> Result<()> { Ok(()) } +async fn roundtrip(sql: &str) -> Result<()> { + roundtrip_with_ctx(sql, create_context().await?).await +} + async fn roundtrip_verify_post_join_filter(sql: &str) -> Result<()> { let ctx = create_context().await?; let df = ctx.sql(sql).await?; diff --git a/datafusion/substrait/tests/cases/roundtrip_physical_plan.rs b/datafusion/substrait/tests/cases/roundtrip_physical_plan.rs index b64dd2c138fc9..e5af3f94cc05d 100644 --- a/datafusion/substrait/tests/cases/roundtrip_physical_plan.rs +++ b/datafusion/substrait/tests/cases/roundtrip_physical_plan.rs @@ -49,7 +49,6 @@ async fn parquet_exec() -> Result<()> { limit: None, table_partition_cols: vec![], output_ordering: vec![], - infinite_source: false, }; let parquet_exec: Arc = Arc::new(ParquetExec::new(scan_config, None, None)); diff --git a/datafusion/wasmtest/Cargo.toml b/datafusion/wasmtest/Cargo.toml index 882b02bcc84b6..c5f795d0653ae 100644 --- a/datafusion/wasmtest/Cargo.toml +++ b/datafusion/wasmtest/Cargo.toml @@ -46,5 +46,5 @@ datafusion-sql = { workspace = true } # getrandom must be compiled with js feature getrandom = { version = "0.2.8", features = ["js"] } -parquet = { version = "48.0.0", default-features = false } +parquet = { workspace = true } wasm-bindgen = "0.2.87" diff --git a/dev/changelog/33.0.0.md b/dev/changelog/33.0.0.md index 9acf40705264b..17862a64a9512 100644 --- a/dev/changelog/33.0.0.md +++ b/dev/changelog/33.0.0.md @@ -17,9 +17,9 @@ under the License. 
--> -## [33.0.0](https://github.com/apache/arrow-datafusion/tree/33.0.0) (2023-11-05) +## [33.0.0](https://github.com/apache/arrow-datafusion/tree/33.0.0) (2023-11-12) -[Full Changelog](https://github.com/apache/arrow-datafusion/compare/31.0.0...32.0.0) +[Full Changelog](https://github.com/apache/arrow-datafusion/compare/32.0.0...33.0.0) **Breaking changes:** @@ -28,6 +28,14 @@ - Add `parquet` feature flag, enabled by default, and make parquet conditional [#7745](https://github.com/apache/arrow-datafusion/pull/7745) (ongchi) - Change input for `to_timestamp` function to be seconds rather than nanoseconds, add `to_timestamp_nanos` [#7844](https://github.com/apache/arrow-datafusion/pull/7844) (comphead) - Percent Decode URL Paths (#8009) [#8012](https://github.com/apache/arrow-datafusion/pull/8012) (tustvold) +- chore: remove panics in datafusion-common::scalar by making more operations return `Result` [#7901](https://github.com/apache/arrow-datafusion/pull/7901) (junjunjd) +- Combine `Expr::Wildcard` and `Wxpr::QualifiedWildcard`, add `wildcard()` expr fn [#8105](https://github.com/apache/arrow-datafusion/pull/8105) (alamb) + +**Performance related:** + +- Add distinct union optimization [#7788](https://github.com/apache/arrow-datafusion/pull/7788) (maruschin) +- Fix join order for TPCH Q17 & Q18 by improving FilterExec statistics [#8126](https://github.com/apache/arrow-datafusion/pull/8126) (andygrove) +- feat: add column statistics into explain [#8112](https://github.com/apache/arrow-datafusion/pull/8112) (NGA-TRAN) **Implemented enhancements:** @@ -36,7 +44,6 @@ - add interval arithmetic for timestamp types [#7758](https://github.com/apache/arrow-datafusion/pull/7758) (mhilton) - Interval Arithmetic NegativeExpr Support [#7804](https://github.com/apache/arrow-datafusion/pull/7804) (berkaysynnada) - Exactness Indicator of Parameters: Precision [#7809](https://github.com/apache/arrow-datafusion/pull/7809) (berkaysynnada) -- Add distinct union optimization [#7788](https://github.com/apache/arrow-datafusion/pull/7788) (maruschin) - Implement GetIndexedField for map-typed columns [#7825](https://github.com/apache/arrow-datafusion/pull/7825) (swgillespie) - Fix precision loss when coercing date_part utf8 argument [#7846](https://github.com/apache/arrow-datafusion/pull/7846) (Dandandan) - Support `Binary`/`LargeBinary` --> `Utf8`/`LargeUtf8` in ilike and string functions [#7840](https://github.com/apache/arrow-datafusion/pull/7840) (alamb) @@ -49,6 +56,10 @@ - feat: Use bloom filter when reading parquet to skip row groups [#7821](https://github.com/apache/arrow-datafusion/pull/7821) (hengfeiyang) - Support Partitioning Data by Dictionary Encoded String Array Types [#7896](https://github.com/apache/arrow-datafusion/pull/7896) (devinjdangelo) - Read only enough bytes to infer Arrow IPC file schema via stream [#7962](https://github.com/apache/arrow-datafusion/pull/7962) (Jefffrey) +- feat: Support determining extensions from names like `foo.parquet.snappy` as well as `foo.parquet` [#7972](https://github.com/apache/arrow-datafusion/pull/7972) (Weijun-H) +- feat: Protobuf serde for Json file sink [#8062](https://github.com/apache/arrow-datafusion/pull/8062) (Jefffrey) +- feat: support target table alias in update statement [#8080](https://github.com/apache/arrow-datafusion/pull/8080) (jonahgao) +- feat: support UDAF in substrait producer/consumer [#8119](https://github.com/apache/arrow-datafusion/pull/8119) (waynexia) **Fixed bugs:** @@ -57,6 +68,8 @@ - fix: generate logical plan for `UPDATE SET 
FROM` statement [#7984](https://github.com/apache/arrow-datafusion/pull/7984) (jonahgao) - fix: single_distinct_aggretation_to_group_by fail [#7997](https://github.com/apache/arrow-datafusion/pull/7997) (haohuaijin) - fix: clippy warnings from nightly rust 1.75 [#8025](https://github.com/apache/arrow-datafusion/pull/8025) (waynexia) +- fix: DataFusion suggests invalid functions [#8083](https://github.com/apache/arrow-datafusion/pull/8083) (jonahgao) +- fix: add encode/decode to protobuf encoding [#8089](https://github.com/apache/arrow-datafusion/pull/8089) (Syleechan) **Documentation updates:** @@ -69,6 +82,10 @@ - Minor: Improve documentation for Filter Pushdown [#8023](https://github.com/apache/arrow-datafusion/pull/8023) (alamb) - Minor: Improve `ExecutionPlan` documentation [#8019](https://github.com/apache/arrow-datafusion/pull/8019) (alamb) - Improve comments for `PartitionSearchMode` struct [#8047](https://github.com/apache/arrow-datafusion/pull/8047) (ozankabak) +- Prepare 33.0.0 Release [#8057](https://github.com/apache/arrow-datafusion/pull/8057) (andygrove) +- Improve documentation for calculate_prune_length method in `SymmetricHashJoin` [#8125](https://github.com/apache/arrow-datafusion/pull/8125) (Asura7969) +- docs: show creation of DFSchema [#8132](https://github.com/apache/arrow-datafusion/pull/8132) (wjones127) +- Improve documentation site to make it easier to find communication on Slack/Discord [#8138](https://github.com/apache/arrow-datafusion/pull/8138) (alamb) **Merged pull requests:** @@ -226,3 +243,50 @@ - General approach for Array replace [#8050](https://github.com/apache/arrow-datafusion/pull/8050) (jayzhan211) - Minor: Remove the irrelevant note from the Expression API doc [#8053](https://github.com/apache/arrow-datafusion/pull/8053) (ongchi) - Minor: Add more documentation about Partitioning [#8022](https://github.com/apache/arrow-datafusion/pull/8022) (alamb) +- Minor: improve documentation for IsNotNull, DISTINCT, etc [#8052](https://github.com/apache/arrow-datafusion/pull/8052) (alamb) +- Prepare 33.0.0 Release [#8057](https://github.com/apache/arrow-datafusion/pull/8057) (andygrove) +- Minor: improve error message by adding types to message [#8065](https://github.com/apache/arrow-datafusion/pull/8065) (alamb) +- Minor: Remove redundant BuiltinScalarFunction::supports_zero_argument() [#8059](https://github.com/apache/arrow-datafusion/pull/8059) (2010YOUY01) +- Add example to ci [#8060](https://github.com/apache/arrow-datafusion/pull/8060) (smallzhongfeng) +- Update substrait requirement from 0.18.0 to 0.19.0 [#8076](https://github.com/apache/arrow-datafusion/pull/8076) (dependabot[bot]) +- Fix incorrect results in COUNT(\*) queries with LIMIT [#8049](https://github.com/apache/arrow-datafusion/pull/8049) (msirek) +- feat: Support determining extensions from names like `foo.parquet.snappy` as well as `foo.parquet` [#7972](https://github.com/apache/arrow-datafusion/pull/7972) (Weijun-H) +- Use FairSpillPool for TaskContext with spillable config [#8072](https://github.com/apache/arrow-datafusion/pull/8072) (viirya) +- Minor: Improve HashJoinStream docstrings [#8070](https://github.com/apache/arrow-datafusion/pull/8070) (alamb) +- Fixing broken link [#8085](https://github.com/apache/arrow-datafusion/pull/8085) (edmondop) +- fix: DataFusion suggests invalid functions [#8083](https://github.com/apache/arrow-datafusion/pull/8083) (jonahgao) +- Replace macro with function for `array_repeat` [#8071](https://github.com/apache/arrow-datafusion/pull/8071) (jayzhan211) +- 
Minor: remove unnecessary projection in `single_distinct_to_group_by` rule [#8061](https://github.com/apache/arrow-datafusion/pull/8061) (haohuaijin) +- minor: Remove duplicate version numbers for arrow, object_store, and parquet dependencies [#8095](https://github.com/apache/arrow-datafusion/pull/8095) (andygrove) +- fix: add encode/decode to protobuf encoding [#8089](https://github.com/apache/arrow-datafusion/pull/8089) (Syleechan) +- feat: Protobuf serde for Json file sink [#8062](https://github.com/apache/arrow-datafusion/pull/8062) (Jefffrey) +- Minor: use `Expr::alias` in a few places to make the code more concise [#8097](https://github.com/apache/arrow-datafusion/pull/8097) (alamb) +- Minor: Cleanup BuiltinScalarFunction::return_type() [#8088](https://github.com/apache/arrow-datafusion/pull/8088) (2010YOUY01) +- Update sqllogictest requirement from 0.17.0 to 0.18.0 [#8102](https://github.com/apache/arrow-datafusion/pull/8102) (dependabot[bot]) +- Projection Pushdown in PhysicalPlan [#8073](https://github.com/apache/arrow-datafusion/pull/8073) (berkaysynnada) +- Push limit into aggregation for DISTINCT ... LIMIT queries [#8038](https://github.com/apache/arrow-datafusion/pull/8038) (msirek) +- Bug-fix in Filter and Limit statistics [#8094](https://github.com/apache/arrow-datafusion/pull/8094) (berkaysynnada) +- feat: support target table alias in update statement [#8080](https://github.com/apache/arrow-datafusion/pull/8080) (jonahgao) +- Minor: Simlify downcast functions in cast.rs. [#8103](https://github.com/apache/arrow-datafusion/pull/8103) (Weijun-H) +- Fix ArrayAgg schema mismatch issue [#8055](https://github.com/apache/arrow-datafusion/pull/8055) (jayzhan211) +- Minor: Support `nulls` in `array_replace`, avoid a copy [#8054](https://github.com/apache/arrow-datafusion/pull/8054) (alamb) +- Minor: Improve the document format of JoinHashMap [#8090](https://github.com/apache/arrow-datafusion/pull/8090) (Asura7969) +- Simplify ProjectionPushdown and make it more general [#8109](https://github.com/apache/arrow-datafusion/pull/8109) (alamb) +- Minor: clean up the code regarding clippy [#8122](https://github.com/apache/arrow-datafusion/pull/8122) (Weijun-H) +- Support remaining functions in protobuf serialization, add `expr_fn` for `StructFunction` [#8100](https://github.com/apache/arrow-datafusion/pull/8100) (JacobOgle) +- Minor: Cleanup BuiltinScalarFunction's phys-expr creation [#8114](https://github.com/apache/arrow-datafusion/pull/8114) (2010YOUY01) +- rewrite `array_append/array_prepend` to remove deplicate codes [#8108](https://github.com/apache/arrow-datafusion/pull/8108) (Veeupup) +- Implementation of `array_intersect` [#8081](https://github.com/apache/arrow-datafusion/pull/8081) (Veeupup) +- Minor: fix ci break [#8136](https://github.com/apache/arrow-datafusion/pull/8136) (haohuaijin) +- Improve documentation for calculate_prune_length method in `SymmetricHashJoin` [#8125](https://github.com/apache/arrow-datafusion/pull/8125) (Asura7969) +- Minor: remove duplicated `array_replace` tests [#8066](https://github.com/apache/arrow-datafusion/pull/8066) (alamb) +- Minor: Fix temporary files created but not deleted during testing [#8115](https://github.com/apache/arrow-datafusion/pull/8115) (2010YOUY01) +- chore: remove panics in datafusion-common::scalar by making more operations return `Result` [#7901](https://github.com/apache/arrow-datafusion/pull/7901) (junjunjd) +- Fix join order for TPCH Q17 & Q18 by improving FilterExec statistics 
[#8126](https://github.com/apache/arrow-datafusion/pull/8126) (andygrove) +- Fix: Do not try and preserve order when there is no order to preserve in RepartitionExec [#8127](https://github.com/apache/arrow-datafusion/pull/8127) (alamb) +- feat: add column statistics into explain [#8112](https://github.com/apache/arrow-datafusion/pull/8112) (NGA-TRAN) +- Add subtrait support for `IS NULL` and `IS NOT NULL` [#8093](https://github.com/apache/arrow-datafusion/pull/8093) (tgujar) +- Combine `Expr::Wildcard` and `Wxpr::QualifiedWildcard`, add `wildcard()` expr fn [#8105](https://github.com/apache/arrow-datafusion/pull/8105) (alamb) +- docs: show creation of DFSchema [#8132](https://github.com/apache/arrow-datafusion/pull/8132) (wjones127) +- feat: support UDAF in substrait producer/consumer [#8119](https://github.com/apache/arrow-datafusion/pull/8119) (waynexia) +- Improve documentation site to make it easier to find communication on Slack/Discord [#8138](https://github.com/apache/arrow-datafusion/pull/8138) (alamb) diff --git a/dev/changelog/34.0.0.md b/dev/changelog/34.0.0.md new file mode 100644 index 0000000000000..c5526f60531c7 --- /dev/null +++ b/dev/changelog/34.0.0.md @@ -0,0 +1,273 @@ + + +## [34.0.0](https://github.com/apache/arrow-datafusion/tree/34.0.0) (2023-12-11) + +[Full Changelog](https://github.com/apache/arrow-datafusion/compare/33.0.0...34.0.0) + +**Breaking changes:** + +- Implement `DISTINCT ON` from Postgres [#7981](https://github.com/apache/arrow-datafusion/pull/7981) (gruuya) +- Encapsulate `EquivalenceClass` into a struct [#8034](https://github.com/apache/arrow-datafusion/pull/8034) (alamb) +- Make fields of `ScalarUDF` , `AggregateUDF` and `WindowUDF` non `pub` [#8079](https://github.com/apache/arrow-datafusion/pull/8079) (alamb) +- Implement StreamTable and StreamTableProvider (#7994) [#8021](https://github.com/apache/arrow-datafusion/pull/8021) (tustvold) +- feat: make FixedSizeList scalar also an ArrayRef [#8221](https://github.com/apache/arrow-datafusion/pull/8221) (wjones127) +- Remove FileWriterMode and ListingTableInsertMode (#7994) [#8017](https://github.com/apache/arrow-datafusion/pull/8017) (tustvold) +- Refactor: Unify `Expr::ScalarFunction` and `Expr::ScalarUDF`, introduce unresolved functions by name [#8258](https://github.com/apache/arrow-datafusion/pull/8258) (2010YOUY01) +- Refactor aggregate function handling [#8358](https://github.com/apache/arrow-datafusion/pull/8358) (Weijun-H) +- Move `PartitionSearchMode` into datafusion_physical_plan, rename to `InputOrderMode` [#8364](https://github.com/apache/arrow-datafusion/pull/8364) (alamb) +- Split `EmptyExec` into `PlaceholderRowExec` [#8446](https://github.com/apache/arrow-datafusion/pull/8446) (razeghi71) + +**Implemented enhancements:** + +- feat: show statistics in explain verbose [#8113](https://github.com/apache/arrow-datafusion/pull/8113) (NGA-TRAN) +- feat:implement postgres style 'overlay' string function [#8117](https://github.com/apache/arrow-datafusion/pull/8117) (Syleechan) +- feat: fill missing values with NULLs while inserting [#8146](https://github.com/apache/arrow-datafusion/pull/8146) (jonahgao) +- feat: to_array_of_size for ScalarValue::FixedSizeList [#8225](https://github.com/apache/arrow-datafusion/pull/8225) (wjones127) +- feat:implement calcite style 'levenshtein' string function [#8168](https://github.com/apache/arrow-datafusion/pull/8168) (Syleechan) +- feat: roundtrip FixedSizeList Scalar to protobuf [#8239](https://github.com/apache/arrow-datafusion/pull/8239) (wjones127) +- 
feat: impl the basic `string_agg` function [#8148](https://github.com/apache/arrow-datafusion/pull/8148) (haohuaijin) +- feat: support simplifying BinaryExpr with arbitrary guarantees in GuaranteeRewriter [#8256](https://github.com/apache/arrow-datafusion/pull/8256) (wjones127) +- feat: support customizing column default values for inserting [#8283](https://github.com/apache/arrow-datafusion/pull/8283) (jonahgao) +- feat:implement sql style 'substr_index' string function [#8272](https://github.com/apache/arrow-datafusion/pull/8272) (Syleechan) +- feat:implement sql style 'find_in_set' string function [#8328](https://github.com/apache/arrow-datafusion/pull/8328) (Syleechan) +- feat: support `LargeList` in `array_empty` [#8321](https://github.com/apache/arrow-datafusion/pull/8321) (Weijun-H) +- feat: support `LargeList` in `make_array` and `array_length` [#8121](https://github.com/apache/arrow-datafusion/pull/8121) (Weijun-H) +- feat: ScalarValue from String [#8411](https://github.com/apache/arrow-datafusion/pull/8411) (QuenKar) +- feat: support `LargeList` for `array_has`, `array_has_all` and `array_has_any` [#8322](https://github.com/apache/arrow-datafusion/pull/8322) (Weijun-H) +- feat: customize column default values for external tables [#8415](https://github.com/apache/arrow-datafusion/pull/8415) (jonahgao) +- feat: Support `array_sort`(`list_sort`) [#8279](https://github.com/apache/arrow-datafusion/pull/8279) (Asura7969) +- feat: support `InterleaveExecNode` in the proto [#8460](https://github.com/apache/arrow-datafusion/pull/8460) (liukun4515) +- feat: improve string statistics display in datafusion-cli `parquet_metadata` function [#8535](https://github.com/apache/arrow-datafusion/pull/8535) (asimsedhain) + +**Fixed bugs:** + +- fix: Timestamp with timezone not considered `join on` [#8150](https://github.com/apache/arrow-datafusion/pull/8150) (ACking-you) +- fix: wrong result of range function [#8313](https://github.com/apache/arrow-datafusion/pull/8313) (smallzhongfeng) +- fix: make `ntile` work in some corner cases [#8371](https://github.com/apache/arrow-datafusion/pull/8371) (haohuaijin) +- fix: Changed labeler.yml to latest format [#8431](https://github.com/apache/arrow-datafusion/pull/8431) (viirya) +- fix: Literal in `ORDER BY` window definition should not be an ordinal referring to relation column [#8419](https://github.com/apache/arrow-datafusion/pull/8419) (viirya) +- fix: ORDER BY window definition should work on null literal [#8444](https://github.com/apache/arrow-datafusion/pull/8444) (viirya) +- fix: RANGE frame for corner cases with empty ORDER BY clause should be treated as constant sort [#8445](https://github.com/apache/arrow-datafusion/pull/8445) (viirya) +- fix: don't unifies projection if expr is non-trival [#8454](https://github.com/apache/arrow-datafusion/pull/8454) (haohuaijin) +- fix: support uppercase when parsing `Interval` [#8478](https://github.com/apache/arrow-datafusion/pull/8478) (QuenKar) +- fix: incorrect set preserve_partitioning in SortExec [#8485](https://github.com/apache/arrow-datafusion/pull/8485) (haohuaijin) +- fix: Pull stats in `IdentVisitor`/`GraphvizVisitor` only when requested [#8514](https://github.com/apache/arrow-datafusion/pull/8514) (vrongmeal) +- fix: volatile expressions should not be target of common subexpt elimination [#8520](https://github.com/apache/arrow-datafusion/pull/8520) (viirya) + +**Documentation updates:** + +- Library Guide: Add Using the DataFrame API [#8319](https://github.com/apache/arrow-datafusion/pull/8319) 
(Veeupup) +- Minor: Add installation link to README.md [#8389](https://github.com/apache/arrow-datafusion/pull/8389) (Weijun-H) +- Prepare version 34.0.0 [#8508](https://github.com/apache/arrow-datafusion/pull/8508) (andygrove) + +**Merged pull requests:** + +- Fix typo in partitioning.rs [#8134](https://github.com/apache/arrow-datafusion/pull/8134) (lewiszlw) +- Implement `DISTINCT ON` from Postgres [#7981](https://github.com/apache/arrow-datafusion/pull/7981) (gruuya) +- Prepare 33.0.0-rc2 [#8144](https://github.com/apache/arrow-datafusion/pull/8144) (andygrove) +- Avoid concat in `array_append` [#8137](https://github.com/apache/arrow-datafusion/pull/8137) (jayzhan211) +- Replace macro with function for array_remove [#8106](https://github.com/apache/arrow-datafusion/pull/8106) (jayzhan211) +- Implement `array_union` [#7897](https://github.com/apache/arrow-datafusion/pull/7897) (edmondop) +- Minor: Document `ExecutionPlan::equivalence_properties` more thoroughly [#8128](https://github.com/apache/arrow-datafusion/pull/8128) (alamb) +- feat: show statistics in explain verbose [#8113](https://github.com/apache/arrow-datafusion/pull/8113) (NGA-TRAN) +- feat:implement postgres style 'overlay' string function [#8117](https://github.com/apache/arrow-datafusion/pull/8117) (Syleechan) +- Minor: Encapsulate `LeftJoinData` into a struct (rather than anonymous enum) and add comments [#8153](https://github.com/apache/arrow-datafusion/pull/8153) (alamb) +- Update sqllogictest requirement from 0.18.0 to 0.19.0 [#8163](https://github.com/apache/arrow-datafusion/pull/8163) (dependabot[bot]) +- feat: fill missing values with NULLs while inserting [#8146](https://github.com/apache/arrow-datafusion/pull/8146) (jonahgao) +- Introduce return type for aggregate sum [#8141](https://github.com/apache/arrow-datafusion/pull/8141) (jayzhan211) +- implement range/generate_series func [#8140](https://github.com/apache/arrow-datafusion/pull/8140) (Veeupup) +- Encapsulate `EquivalenceClass` into a struct [#8034](https://github.com/apache/arrow-datafusion/pull/8034) (alamb) +- Revert "Minor: remove unnecessary projection in `single_distinct_to_g… [#8176](https://github.com/apache/arrow-datafusion/pull/8176) (NGA-TRAN) +- Preserve all of the valid orderings during merging. 
[#8169](https://github.com/apache/arrow-datafusion/pull/8169) (mustafasrepo) +- Make fields of `ScalarUDF` , `AggregateUDF` and `WindowUDF` non `pub` [#8079](https://github.com/apache/arrow-datafusion/pull/8079) (alamb) +- Fix logical conflicts [#8187](https://github.com/apache/arrow-datafusion/pull/8187) (tustvold) +- Minor: Update JoinHashMap comment example to make it clearer [#8154](https://github.com/apache/arrow-datafusion/pull/8154) (alamb) +- Implement StreamTable and StreamTableProvider (#7994) [#8021](https://github.com/apache/arrow-datafusion/pull/8021) (tustvold) +- [MINOR]: Remove unused Results [#8189](https://github.com/apache/arrow-datafusion/pull/8189) (mustafasrepo) +- Minor: clean up the code based on clippy [#8179](https://github.com/apache/arrow-datafusion/pull/8179) (Weijun-H) +- Minor: simplify filter statistics code [#8174](https://github.com/apache/arrow-datafusion/pull/8174) (alamb) +- Replace macro with function for `array_position` and `array_positions` [#8170](https://github.com/apache/arrow-datafusion/pull/8170) (jayzhan211) +- Add Library Guide for User Defined Functions: Window/Aggregate [#8171](https://github.com/apache/arrow-datafusion/pull/8171) (Veeupup) +- Add more stream docs [#8192](https://github.com/apache/arrow-datafusion/pull/8192) (tustvold) +- Implement func `array_pop_front` [#8142](https://github.com/apache/arrow-datafusion/pull/8142) (Veeupup) +- Moving arrow_files SQL tests to sqllogictest [#8217](https://github.com/apache/arrow-datafusion/pull/8217) (edmondop) +- fix regression in the use of name in ProjectionPushdown [#8219](https://github.com/apache/arrow-datafusion/pull/8219) (alamb) +- [MINOR]: Fix column indices in the planning tests [#8191](https://github.com/apache/arrow-datafusion/pull/8191) (mustafasrepo) +- Remove unnecessary reassignment [#8232](https://github.com/apache/arrow-datafusion/pull/8232) (qrilka) +- Update itertools requirement from 0.11 to 0.12 [#8233](https://github.com/apache/arrow-datafusion/pull/8233) (crepererum) +- Port tests in subqueries.rs to sqllogictest [#8231](https://github.com/apache/arrow-datafusion/pull/8231) (PsiACE) +- feat: make FixedSizeList scalar also an ArrayRef [#8221](https://github.com/apache/arrow-datafusion/pull/8221) (wjones127) +- Add versions to datafusion dependencies [#8238](https://github.com/apache/arrow-datafusion/pull/8238) (andygrove) +- feat: to_array_of_size for ScalarValue::FixedSizeList [#8225](https://github.com/apache/arrow-datafusion/pull/8225) (wjones127) +- feat:implement calcite style 'levenshtein' string function [#8168](https://github.com/apache/arrow-datafusion/pull/8168) (Syleechan) +- feat: roundtrip FixedSizeList Scalar to protobuf [#8239](https://github.com/apache/arrow-datafusion/pull/8239) (wjones127) +- Update prost-build requirement from =0.12.1 to =0.12.2 [#8244](https://github.com/apache/arrow-datafusion/pull/8244) (dependabot[bot]) +- Minor: Port tests in `displayable.rs` to sqllogictest [#8246](https://github.com/apache/arrow-datafusion/pull/8246) (Weijun-H) +- Minor: add `with_estimated_selectivity ` to Precision [#8177](https://github.com/apache/arrow-datafusion/pull/8177) (alamb) +- fix: Timestamp with timezone not considered `join on` [#8150](https://github.com/apache/arrow-datafusion/pull/8150) (ACking-you) +- Replace macro in array_array to remove duplicate codes [#8252](https://github.com/apache/arrow-datafusion/pull/8252) (Veeupup) +- Port tests in projection.rs to sqllogictest [#8240](https://github.com/apache/arrow-datafusion/pull/8240) (PsiACE) 
+- Introduce `array_except` function [#8135](https://github.com/apache/arrow-datafusion/pull/8135) (jayzhan211) +- Port tests in `describe.rs` to sqllogictest [#8242](https://github.com/apache/arrow-datafusion/pull/8242) (Asura7969) +- Remove FileWriterMode and ListingTableInsertMode (#7994) [#8017](https://github.com/apache/arrow-datafusion/pull/8017) (tustvold) +- Minor: clean up the code based on Clippy [#8257](https://github.com/apache/arrow-datafusion/pull/8257) (Weijun-H) +- Update arrow 49.0.0 and object_store 0.8.0 [#8029](https://github.com/apache/arrow-datafusion/pull/8029) (tustvold) +- feat: impl the basic `string_agg` function [#8148](https://github.com/apache/arrow-datafusion/pull/8148) (haohuaijin) +- Minor: Make schema of grouping set columns nullable [#8248](https://github.com/apache/arrow-datafusion/pull/8248) (markusa380) +- feat: support simplifying BinaryExpr with arbitrary guarantees in GuaranteeRewriter [#8256](https://github.com/apache/arrow-datafusion/pull/8256) (wjones127) +- Making stream joins extensible: A new Trait implementation for SHJ [#8234](https://github.com/apache/arrow-datafusion/pull/8234) (metesynnada) +- Don't Canonicalize Filesystem Paths in ListingTableUrl / support new external tables for files that do not (yet) exist [#8014](https://github.com/apache/arrow-datafusion/pull/8014) (tustvold) +- Minor: Add sql level test for inserting into non-existent directory [#8278](https://github.com/apache/arrow-datafusion/pull/8278) (alamb) +- Replace `array_has/array_has_all/array_has_any` macro to remove duplicate code [#8263](https://github.com/apache/arrow-datafusion/pull/8263) (Veeupup) +- Fix bug in field level metadata matching code [#8286](https://github.com/apache/arrow-datafusion/pull/8286) (alamb) +- Refactor Interval Arithmetic Updates [#8276](https://github.com/apache/arrow-datafusion/pull/8276) (berkaysynnada) +- [MINOR]: Remove unecessary orderings from the final plan [#8289](https://github.com/apache/arrow-datafusion/pull/8289) (mustafasrepo) +- consistent logical & physical `NTILE` return types [#8270](https://github.com/apache/arrow-datafusion/pull/8270) (korowa) +- make `array_union`/`array_except`/`array_intersect` handle empty/null arrays rightly [#8269](https://github.com/apache/arrow-datafusion/pull/8269) (Veeupup) +- improve file path validation when reading parquet [#8267](https://github.com/apache/arrow-datafusion/pull/8267) (Weijun-H) +- [Benchmarks] Make `partitions` default to number of cores instead of 2 [#8292](https://github.com/apache/arrow-datafusion/pull/8292) (andygrove) +- Update prost-build requirement from =0.12.2 to =0.12.3 [#8298](https://github.com/apache/arrow-datafusion/pull/8298) (dependabot[bot]) +- Fix Display for List [#8261](https://github.com/apache/arrow-datafusion/pull/8261) (jayzhan211) +- feat: support customizing column default values for inserting [#8283](https://github.com/apache/arrow-datafusion/pull/8283) (jonahgao) +- support `LargeList` for `arrow_cast`, support `ScalarValue::LargeList` [#8290](https://github.com/apache/arrow-datafusion/pull/8290) (Weijun-H) +- Minor: remove useless clone based on Clippy [#8300](https://github.com/apache/arrow-datafusion/pull/8300) (Weijun-H) +- Calculate ordering equivalence for expressions (rather than just columns) [#8281](https://github.com/apache/arrow-datafusion/pull/8281) (mustafasrepo) +- Fix sqllogictests link in contributor-guide/index.md [#8314](https://github.com/apache/arrow-datafusion/pull/8314) (qrilka) +- Refactor: Unify `Expr::ScalarFunction` and 
`Expr::ScalarUDF`, introduce unresolved functions by name [#8258](https://github.com/apache/arrow-datafusion/pull/8258) (2010YOUY01) +- Support no distinct aggregate sum/min/max in `single_distinct_to_group_by` rule [#8266](https://github.com/apache/arrow-datafusion/pull/8266) (haohuaijin) +- feat:implement sql style 'substr_index' string function [#8272](https://github.com/apache/arrow-datafusion/pull/8272) (Syleechan) +- Fixing issues with for timestamp literals [#8193](https://github.com/apache/arrow-datafusion/pull/8193) (comphead) +- Projection Pushdown over StreamingTableExec [#8299](https://github.com/apache/arrow-datafusion/pull/8299) (berkaysynnada) +- minor: fix documentation [#8323](https://github.com/apache/arrow-datafusion/pull/8323) (comphead) +- fix: wrong result of range function [#8313](https://github.com/apache/arrow-datafusion/pull/8313) (smallzhongfeng) +- Minor: rename parquet.rs to parquet/mod.rs [#8301](https://github.com/apache/arrow-datafusion/pull/8301) (alamb) +- refactor: output ordering [#8304](https://github.com/apache/arrow-datafusion/pull/8304) (QuenKar) +- Update substrait requirement from 0.19.0 to 0.20.0 [#8339](https://github.com/apache/arrow-datafusion/pull/8339) (dependabot[bot]) +- Port tests in `aggregates.rs` to sqllogictest [#8316](https://github.com/apache/arrow-datafusion/pull/8316) (edmondop) +- Library Guide: Add Using the DataFrame API [#8319](https://github.com/apache/arrow-datafusion/pull/8319) (Veeupup) +- Port tests in limit.rs to sqllogictest [#8315](https://github.com/apache/arrow-datafusion/pull/8315) (zhangxffff) +- move array function unit_tests to sqllogictest [#8332](https://github.com/apache/arrow-datafusion/pull/8332) (Veeupup) +- NTH_VALUE reverse support [#8327](https://github.com/apache/arrow-datafusion/pull/8327) (mustafasrepo) +- Optimize Projections during Logical Plan [#8340](https://github.com/apache/arrow-datafusion/pull/8340) (mustafasrepo) +- [MINOR]: Move merge projections tests to under optimize projections [#8352](https://github.com/apache/arrow-datafusion/pull/8352) (mustafasrepo) +- Add `quote` and `escape` attributes to create csv external table [#8351](https://github.com/apache/arrow-datafusion/pull/8351) (Asura7969) +- Minor: Add DataFrame test [#8341](https://github.com/apache/arrow-datafusion/pull/8341) (alamb) +- Minor: clean up the code based on Clippy [#8359](https://github.com/apache/arrow-datafusion/pull/8359) (Weijun-H) +- Minor: Make it easier to work with Expr::ScalarFunction [#8350](https://github.com/apache/arrow-datafusion/pull/8350) (alamb) +- Minor: Move some datafusion-optimizer::utils down to datafusion-expr::utils [#8354](https://github.com/apache/arrow-datafusion/pull/8354) (Jesse-Bakker) +- Minor: Make `BuiltInScalarFunction::alias` a method [#8349](https://github.com/apache/arrow-datafusion/pull/8349) (alamb) +- Extract parquet statistics to its own module, add tests [#8294](https://github.com/apache/arrow-datafusion/pull/8294) (alamb) +- feat:implement sql style 'find_in_set' string function [#8328](https://github.com/apache/arrow-datafusion/pull/8328) (Syleechan) +- Support LargeUtf8 to Temporal Coercion [#8357](https://github.com/apache/arrow-datafusion/pull/8357) (jayzhan211) +- Refactor aggregate function handling [#8358](https://github.com/apache/arrow-datafusion/pull/8358) (Weijun-H) +- Implement Aliases for ScalarUDF [#8360](https://github.com/apache/arrow-datafusion/pull/8360) (Veeupup) +- Minor: Remove unnecessary name field in `ScalarFunctionDefintion` 
[#8365](https://github.com/apache/arrow-datafusion/pull/8365) (alamb) +- feat: support `LargeList` in `array_empty` [#8321](https://github.com/apache/arrow-datafusion/pull/8321) (Weijun-H) +- Double type argument for to_timestamp function [#8159](https://github.com/apache/arrow-datafusion/pull/8159) (spaydar) +- Support User Defined Table Function [#8306](https://github.com/apache/arrow-datafusion/pull/8306) (Veeupup) +- Document timestamp input limits [#8369](https://github.com/apache/arrow-datafusion/pull/8369) (comphead) +- fix: make `ntile` work in some corner cases [#8371](https://github.com/apache/arrow-datafusion/pull/8371) (haohuaijin) +- Minor: Refactor array_union function to use a generic union_arrays function [#8381](https://github.com/apache/arrow-datafusion/pull/8381) (Weijun-H) +- Minor: Refactor function argument handling in `ScalarFunctionDefinition` [#8387](https://github.com/apache/arrow-datafusion/pull/8387) (Weijun-H) +- Materialize dictionaries in group keys [#8291](https://github.com/apache/arrow-datafusion/pull/8291) (qrilka) +- Rewrite `array_ndims` to fix List(Null) handling [#8320](https://github.com/apache/arrow-datafusion/pull/8320) (jayzhan211) +- Docs: Improve the documentation on `ScalarValue` [#8378](https://github.com/apache/arrow-datafusion/pull/8378) (alamb) +- Avoid concat for `array_replace` [#8337](https://github.com/apache/arrow-datafusion/pull/8337) (jayzhan211) +- add a summary table to benchmark compare output [#8399](https://github.com/apache/arrow-datafusion/pull/8399) (razeghi71) +- Refactors on TreeNode Implementations [#8395](https://github.com/apache/arrow-datafusion/pull/8395) (berkaysynnada) +- feat: support `LargeList` in `make_array` and `array_length` [#8121](https://github.com/apache/arrow-datafusion/pull/8121) (Weijun-H) +- remove `unalias` TableScan filters when create Physical Filter [#8404](https://github.com/apache/arrow-datafusion/pull/8404) (jackwener) +- Update custom-table-providers.md [#8409](https://github.com/apache/arrow-datafusion/pull/8409) (nickpoorman) +- fix transforming `LogicalPlan::Explain` use `TreeNode::transform` fails [#8400](https://github.com/apache/arrow-datafusion/pull/8400) (haohuaijin) +- Docs: Fix `array_except` documentation example error [#8407](https://github.com/apache/arrow-datafusion/pull/8407) (Asura7969) +- Support named query parameters [#8384](https://github.com/apache/arrow-datafusion/pull/8384) (Asura7969) +- Minor: Add installation link to README.md [#8389](https://github.com/apache/arrow-datafusion/pull/8389) (Weijun-H) +- Update code comment for the cases of regularized RANGE frame and add tests for ORDER BY cases with RANGE frame [#8410](https://github.com/apache/arrow-datafusion/pull/8410) (viirya) +- Minor: Add example with parameters to LogicalPlan [#8418](https://github.com/apache/arrow-datafusion/pull/8418) (alamb) +- Minor: Improve `PruningPredicate` documentation [#8394](https://github.com/apache/arrow-datafusion/pull/8394) (alamb) +- feat: ScalarValue from String [#8411](https://github.com/apache/arrow-datafusion/pull/8411) (QuenKar) +- Bump actions/labeler from 4.3.0 to 5.0.0 [#8422](https://github.com/apache/arrow-datafusion/pull/8422) (dependabot[bot]) +- Update sqlparser requirement from 0.39.0 to 0.40.0 [#8338](https://github.com/apache/arrow-datafusion/pull/8338) (dependabot[bot]) +- feat: support `LargeList` for `array_has`, `array_has_all` and `array_has_any` [#8322](https://github.com/apache/arrow-datafusion/pull/8322) (Weijun-H) +- Union `schema` can't be a subset of the 
child schema [#8408](https://github.com/apache/arrow-datafusion/pull/8408) (jackwener) +- Move `PartitionSearchMode` into datafusion_physical_plan, rename to `InputOrderMode` [#8364](https://github.com/apache/arrow-datafusion/pull/8364) (alamb) +- Make filter selectivity for statistics configurable [#8243](https://github.com/apache/arrow-datafusion/pull/8243) (edmondop) +- fix: Changed labeler.yml to latest format [#8431](https://github.com/apache/arrow-datafusion/pull/8431) (viirya) +- Minor: Use `ScalarValue::from` impl for strings [#8429](https://github.com/apache/arrow-datafusion/pull/8429) (alamb) +- Support crossjoin in substrait. [#8427](https://github.com/apache/arrow-datafusion/pull/8427) (my-vegetable-has-exploded) +- Fix ambiguous reference when aliasing in combination with `ORDER BY` [#8425](https://github.com/apache/arrow-datafusion/pull/8425) (Asura7969) +- Minor: convert marcro `list-slice` and `slice` to function [#8424](https://github.com/apache/arrow-datafusion/pull/8424) (Weijun-H) +- Remove macro in iter_to_array for List [#8414](https://github.com/apache/arrow-datafusion/pull/8414) (jayzhan211) +- fix: Literal in `ORDER BY` window definition should not be an ordinal referring to relation column [#8419](https://github.com/apache/arrow-datafusion/pull/8419) (viirya) +- feat: customize column default values for external tables [#8415](https://github.com/apache/arrow-datafusion/pull/8415) (jonahgao) +- feat: Support `array_sort`(`list_sort`) [#8279](https://github.com/apache/arrow-datafusion/pull/8279) (Asura7969) +- Bugfix: Remove df-cli specific SQL statment options before executing with DataFusion [#8426](https://github.com/apache/arrow-datafusion/pull/8426) (devinjdangelo) +- Detect when filters on unique constraints make subqueries scalar [#8312](https://github.com/apache/arrow-datafusion/pull/8312) (Jesse-Bakker) +- Add alias check to optimize projections merge [#8438](https://github.com/apache/arrow-datafusion/pull/8438) (mustafasrepo) +- Fix PartialOrd for ScalarValue::List/FixSizeList/LargeList [#8253](https://github.com/apache/arrow-datafusion/pull/8253) (jayzhan211) +- Support parquet_metadata for datafusion-cli [#8413](https://github.com/apache/arrow-datafusion/pull/8413) (Veeupup) +- Fix bug in optimizing a nested count [#8459](https://github.com/apache/arrow-datafusion/pull/8459) (Dandandan) +- Bump actions/setup-python from 4 to 5 [#8449](https://github.com/apache/arrow-datafusion/pull/8449) (dependabot[bot]) +- fix: ORDER BY window definition should work on null literal [#8444](https://github.com/apache/arrow-datafusion/pull/8444) (viirya) +- flx clippy warnings [#8455](https://github.com/apache/arrow-datafusion/pull/8455) (waynexia) +- fix: RANGE frame for corner cases with empty ORDER BY clause should be treated as constant sort [#8445](https://github.com/apache/arrow-datafusion/pull/8445) (viirya) +- Preserve `dict_id` on `Field` during serde roundtrip [#8457](https://github.com/apache/arrow-datafusion/pull/8457) (avantgardnerio) +- feat: support `InterleaveExecNode` in the proto [#8460](https://github.com/apache/arrow-datafusion/pull/8460) (liukun4515) +- [BUG FIX]: Proper Empty Batch handling in window execution [#8466](https://github.com/apache/arrow-datafusion/pull/8466) (mustafasrepo) +- Minor: update `cast` [#8458](https://github.com/apache/arrow-datafusion/pull/8458) (Weijun-H) +- fix: don't unifies projection if expr is non-trival [#8454](https://github.com/apache/arrow-datafusion/pull/8454) (haohuaijin) +- Minor: Add new bloom filter predicate 
tests [#8433](https://github.com/apache/arrow-datafusion/pull/8433) (alamb) +- Add PRIMARY KEY Aggregate support to dataframe API [#8356](https://github.com/apache/arrow-datafusion/pull/8356) (mustafasrepo) +- Minor: refactor `data_trunc` to reduce duplicated code [#8430](https://github.com/apache/arrow-datafusion/pull/8430) (Weijun-H) +- Support array_distinct function. [#8268](https://github.com/apache/arrow-datafusion/pull/8268) (my-vegetable-has-exploded) +- Add primary key support to stream table [#8467](https://github.com/apache/arrow-datafusion/pull/8467) (mustafasrepo) +- Add `evaluate_demo` and `range_analysis_demo` to Expr examples [#8377](https://github.com/apache/arrow-datafusion/pull/8377) (alamb) +- Minor: fix function name typo [#8473](https://github.com/apache/arrow-datafusion/pull/8473) (Weijun-H) +- Minor: Fix comment typo in table.rs: s/indentical/identical/ [#8469](https://github.com/apache/arrow-datafusion/pull/8469) (KeunwooLee-at) +- Remove `define_array_slice` and reuse `array_slice` for `array_pop_front/back` [#8401](https://github.com/apache/arrow-datafusion/pull/8401) (jayzhan211) +- Minor: refactor `trim` to clean up duplicated code [#8434](https://github.com/apache/arrow-datafusion/pull/8434) (Weijun-H) +- Split `EmptyExec` into `PlaceholderRowExec` [#8446](https://github.com/apache/arrow-datafusion/pull/8446) (razeghi71) +- Enable non-uniform field type for structs created in DataFusion [#8463](https://github.com/apache/arrow-datafusion/pull/8463) (dlovell) +- Minor: Add multi ordering test for array agg order [#8439](https://github.com/apache/arrow-datafusion/pull/8439) (jayzhan211) +- Sort filenames when reading parquet to ensure consistent schema [#6629](https://github.com/apache/arrow-datafusion/pull/6629) (thomas-k-cameron) +- Minor: Improve comments in EnforceDistribution tests [#8474](https://github.com/apache/arrow-datafusion/pull/8474) (alamb) +- fix: support uppercase when parsing `Interval` [#8478](https://github.com/apache/arrow-datafusion/pull/8478) (QuenKar) +- Better Equivalence (ordering and exact equivalence) Propagation through ProjectionExec [#8484](https://github.com/apache/arrow-datafusion/pull/8484) (mustafasrepo) +- Add `today` alias for `current_date` [#8423](https://github.com/apache/arrow-datafusion/pull/8423) (smallzhongfeng) +- Minor: remove useless clone in `array_expression` [#8495](https://github.com/apache/arrow-datafusion/pull/8495) (Weijun-H) +- fix: incorrect set preserve_partitioning in SortExec [#8485](https://github.com/apache/arrow-datafusion/pull/8485) (haohuaijin) +- Explicitly mark parquet for tests in datafusion-common [#8497](https://github.com/apache/arrow-datafusion/pull/8497) (Dennis40816) +- Minor/Doc: Clarify DataFrame::write_table Documentation [#8519](https://github.com/apache/arrow-datafusion/pull/8519) (devinjdangelo) +- fix: Pull stats in `IdentVisitor`/`GraphvizVisitor` only when requested [#8514](https://github.com/apache/arrow-datafusion/pull/8514) (vrongmeal) +- Change display of RepartitionExec from SortPreservingRepartitionExec to RepartitionExec preserve_order=true [#8521](https://github.com/apache/arrow-datafusion/pull/8521) (JacobOgle) +- Fix `DataFrame::cache` errors with `Plan("Mismatch between schema and batches")` [#8510](https://github.com/apache/arrow-datafusion/pull/8510) (Asura7969) +- Minor: update pbjson_dependency [#8470](https://github.com/apache/arrow-datafusion/pull/8470) (alamb) +- Minor: Update prost-derive dependency [#8471](https://github.com/apache/arrow-datafusion/pull/8471) 
(alamb) +- Minor/Doc: Add DataFrame::write_table to DataFrame user guide [#8527](https://github.com/apache/arrow-datafusion/pull/8527) (devinjdangelo) +- Minor: Add repartition_file.slt end to end test for repartitioning files, and supporting tweaks [#8505](https://github.com/apache/arrow-datafusion/pull/8505) (alamb) +- Prepare version 34.0.0 [#8508](https://github.com/apache/arrow-datafusion/pull/8508) (andygrove) +- refactor: use ExprBuilder to consume substrait expr and use macro to generate error [#8515](https://github.com/apache/arrow-datafusion/pull/8515) (waynexia) +- [MINOR]: Make some slt tests deterministic [#8525](https://github.com/apache/arrow-datafusion/pull/8525) (mustafasrepo) +- fix: volatile expressions should not be target of common subexpt elimination [#8520](https://github.com/apache/arrow-datafusion/pull/8520) (viirya) +- Minor: Add LakeSoul to the list of Known Users [#8536](https://github.com/apache/arrow-datafusion/pull/8536) (xuchen-plus) +- Fix regression with Incorrect results when reading parquet files with different schemas and statistics [#8533](https://github.com/apache/arrow-datafusion/pull/8533) (alamb) +- feat: improve string statistics display in datafusion-cli `parquet_metadata` function [#8535](https://github.com/apache/arrow-datafusion/pull/8535) (asimsedhain) +- Defer file creation to write [#8539](https://github.com/apache/arrow-datafusion/pull/8539) (tustvold) +- Minor: Improve error handling in sqllogictest runner [#8544](https://github.com/apache/arrow-datafusion/pull/8544) (alamb) diff --git a/dev/release/generate-changelog.py b/dev/release/generate-changelog.py index ff9e8d4754b2a..f419bdb3a1ac7 100755 --- a/dev/release/generate-changelog.py +++ b/dev/release/generate-changelog.py @@ -57,6 +57,7 @@ def generate_changelog(repo, repo_name, tag1, tag2): bugs = [] docs = [] enhancements = [] + performance = [] # categorize the pull requests based on GitHub labels print("Categorizing pull requests", file=sys.stderr) @@ -79,6 +80,8 @@ def generate_changelog(repo, repo_name, tag1, tag2): breaking.append((pull, commit)) elif 'bug' in labels or cc_type == 'fix': bugs.append((pull, commit)) + elif 'performance' in labels or cc_type == 'perf': + performance.append((pull, commit)) elif 'enhancement' in labels or cc_type == 'feat': enhancements.append((pull, commit)) elif 'documentation' in labels or cc_type == 'docs': @@ -87,6 +90,7 @@ def generate_changelog(repo, repo_name, tag1, tag2): # produce the changelog content print("Generating changelog content", file=sys.stderr) print_pulls(repo_name, "Breaking changes", breaking) + print_pulls(repo_name, "Performance related", performance) print_pulls(repo_name, "Implemented enhancements", enhancements) print_pulls(repo_name, "Fixed bugs", bugs) print_pulls(repo_name, "Documentation updates", docs) diff --git a/docs/Cargo.toml b/docs/Cargo.toml index 4d01466924f99..813335e30f777 100644 --- a/docs/Cargo.toml +++ b/docs/Cargo.toml @@ -29,4 +29,4 @@ authors = { workspace = true } rust-version = "1.70" [dependencies] -datafusion = { path = "../datafusion/core", version = "33.0.0", default-features = false } +datafusion = { path = "../datafusion/core", version = "34.0.0", default-features = false } diff --git a/docs/source/contributor-guide/communication.md b/docs/source/contributor-guide/communication.md index 11e0e4e0f0eaa..8678aa534baf0 100644 --- a/docs/source/contributor-guide/communication.md +++ b/docs/source/contributor-guide/communication.md @@ -26,15 +26,25 @@ All participation in the Apache Arrow 
DataFusion project is governed by the Apache Software Foundation's [code of conduct](https://www.apache.org/foundation/policies/conduct.html).
 
+## GitHub
+
 The vast majority of communication occurs in the open on our
-[github repository](https://github.com/apache/arrow-datafusion).
+[github repository](https://github.com/apache/arrow-datafusion) in the form of tickets, issues, discussions, and Pull Requests.
+
+## Slack and Discord
 
-## Questions?
+We use the Slack and Discord platforms for informal discussions and coordination. These are great places to
+meet other contributors and get guidance on where to contribute. It is important to note that any technical designs and
+decisions are made fully in the open, on GitHub.
 
-### Mailing list
+Most of us use the `#arrow-datafusion` and `#arrow-rust` channels in the [ASF Slack workspace](https://s.apache.org/slack-invite).
+Unfortunately, due to spammers, the ASF Slack workspace requires an invitation to join. To get an invitation,
+request one in the `Arrow Rust` channel of the [Arrow Rust Discord server](https://discord.gg/Qw5gKqHxUM).
 
-We use arrow.apache.org's `dev@` mailing list for project management, release
-coordination and design discussions
+## Mailing list
+
+We also use arrow.apache.org's `dev@` mailing list for release coordination and occasional design discussions. Other
+than the release process, most DataFusion mailing list traffic will link to a GitHub issue or PR for discussion.
 ([subscribe](mailto:dev-subscribe@arrow.apache.org),
 [unsubscribe](mailto:dev-unsubscribe@arrow.apache.org),
 [archives](https://lists.apache.org/list.html?dev@arrow.apache.org)).
 
@@ -42,33 +52,3 @@ coordination and design discussions
 When emailing the dev list, please make sure to prefix the subject line with a `[DataFusion]` tag, e.g. `"[DataFusion] New API for remote data sources"`, so that the appropriate people in the Apache Arrow community notice the message.
-
-### Slack and Discord
-
-We use the official [ASF](https://s.apache.org/slack-invite) Slack workspace
-for informal discussions and coordination. This is a great place to meet other
-contributors and get guidance on where to contribute. Join us in the
-`#arrow-rust` channel.
-
-We also have a backup Arrow Rust Discord
-server ([invite link](https://discord.gg/Qw5gKqHxUM)) in case you are not able
-to join the Slack workspace. If you need an invite to the Slack workspace, you
-can also ask for one in our Discord server.
-
-### Sync up video calls
-
-We have biweekly sync calls every other Thursdays at both 04:00 UTC
-and 16:00 UTC (starting September 30, 2021) depending on if there are
-items on the agenda to discuss and someone being willing to host.
-
-Please see the [agenda](https://docs.google.com/document/d/1atCVnoff5SR4eM4Lwf2M1BBJTY6g3_HUNR6qswYJW_U/edit)
-for the video call link, add topics and to see what others plan to discuss.
-
-The goals of these calls are:
-
-1. Help "put a face to the name" of some of other contributors we are working with
-2. Discuss / synchronize on the goals and major initiatives from different stakeholders to identify areas where more alignment is needed
-
-No decisions are made on the call and anything of substance will be discussed on the mailing list or in github issues / google docs.
-
-We will send a summary of all sync ups to the dev@arrow.apache.org mailing list.
diff --git a/docs/source/contributor-guide/index.md b/docs/source/contributor-guide/index.md index e42ab0dee07a0..8d69ade83d72e 100644 --- a/docs/source/contributor-guide/index.md +++ b/docs/source/contributor-guide/index.md @@ -151,7 +151,7 @@ Tests for code in an individual module are defined in the same source file with ### sqllogictests Tests -DataFusion's SQL implementation is tested using [sqllogictest](https://github.com/apache/arrow-datafusion/tree/main/datafusion/core/tests/sqllogictests) which are run like any other Rust test using `cargo test --test sqllogictests`. +DataFusion's SQL implementation is tested using [sqllogictest](https://github.com/apache/arrow-datafusion/tree/main/datafusion/sqllogictest) which are run like any other Rust test using `cargo test --test sqllogictests`. `sqllogictests` tests may be less convenient for new contributors who are familiar with writing `.rs` tests as they require learning another tool. However, `sqllogictest` based tests are much easier to develop and maintain as they 1) do not require a slow recompile/link cycle and 2) can be automatically updated via `cargo test --test sqllogictests -- --complete`. @@ -221,8 +221,8 @@ Below is a checklist of what you need to do to add a new scalar function to Data - a new line in `signature` with the signature of the function (number and types of its arguments) - a new line in `create_physical_expr`/`create_physical_fun` mapping the built-in to the implementation - tests to the function. -- In [core/tests/sqllogictests/test_files](../../../datafusion/core/tests/sqllogictests/test_files), add new `sqllogictest` integration tests where the function is called through SQL against well known data and returns the expected result. - - Documentation for `sqllogictest` [here](../../../datafusion/core/tests/sqllogictests/README.md) +- In [sqllogictest/test_files](../../../datafusion/sqllogictest/test_files), add new `sqllogictest` integration tests where the function is called through SQL against well known data and returns the expected result. + - Documentation for `sqllogictest` [here](../../../datafusion/sqllogictest/README.md) - In [expr/src/expr_fn.rs](../../../datafusion/expr/src/expr_fn.rs), add: - a new entry of the `unary_scalar_expr!` macro for the new function. - Add SQL reference documentation [here](../../../docs/source/user-guide/sql/scalar_functions.md) @@ -243,8 +243,8 @@ Below is a checklist of what you need to do to add a new aggregate function to D - a new line in `signature` with the signature of the function (number and types of its arguments) - a new line in `create_aggregate_expr` mapping the built-in to the implementation - tests to the function. -- In [core/tests/sqllogictests/test_files](../../../datafusion/core/tests/sqllogictests/test_files), add new `sqllogictest` integration tests where the function is called through SQL against well known data and returns the expected result. - - Documentation for `sqllogictest` [here](../../../datafusion/core/tests/sqllogictests/README.md) +- In [sqllogictest/test_files](../../../datafusion/sqllogictest/test_files), add new `sqllogictest` integration tests where the function is called through SQL against well known data and returns the expected result. 
+  - Documentation for `sqllogictest` [here](../../../datafusion/sqllogictest/README.md)
 - Add SQL reference documentation [here](../../../docs/source/user-guide/sql/aggregate_functions.md)
 
 ### How to display plans graphically
 
diff --git a/docs/source/index.rst b/docs/source/index.rst
index bb8e2127f1e75..3853716617162 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -43,11 +43,12 @@ community.
 The `example usage`_ section in the user guide and the `datafusion-examples`_ code in the crate contain information on using DataFusion.
 
-The `developer’s guide`_ contains information on how to contribute.
+Please see the `developer’s guide`_ for contributing and `communication`_ for getting in touch with us.
 
 .. _example usage: user-guide/example-usage.html
 .. _datafusion-examples: https://github.com/apache/arrow-datafusion/tree/master/datafusion-examples
 .. _developer’s guide: contributor-guide/index.html#developer-s-guide
+.. _communication: contributor-guide/communication.html
 
 .. _toc.links:
 .. toctree::
diff --git a/docs/source/library-user-guide/adding-udfs.md b/docs/source/library-user-guide/adding-udfs.md
index a4b5ed0b40f17..1f687f978f30e 100644
--- a/docs/source/library-user-guide/adding-udfs.md
+++ b/docs/source/library-user-guide/adding-udfs.md
@@ -17,17 +17,18 @@
 under the License.
 -->
 
-# Adding User Defined Functions: Scalar/Window/Aggregate
+# Adding User Defined Functions: Scalar/Window/Aggregate/Table Functions
 
 User Defined Functions (UDFs) are functions that can be used in the context of DataFusion execution. This page covers how to add UDFs to DataFusion. In particular, it covers how to add Scalar, Window, and Aggregate UDFs.
 
-| UDF Type  | Description                                                                                                 | Example                                                                                                            |
-| --------- | ----------------------------------------------------------------------------------------------------------- | -------------------------------------------------------------------------------------------------------------------- |
-| Scalar    | A function that takes a row of data and returns a single value.                                            | [simple_udf.rs](https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/simple_udf.rs)   |
-| Window    | A function that takes a row of data and returns a single value, but also has access to the rows around it. | [simple_udwf.rs](https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/simple_udwf.rs) |
-| Aggregate | A function that takes a group of rows and returns a single value.                                          | [simple_udaf.rs](https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/simple_udaf.rs) |
+| UDF Type  | Description                                                                                                 | Example             |
+| --------- | ----------------------------------------------------------------------------------------------------------- | ------------------- |
+| Scalar    | A function that takes a row of data and returns a single value.                                            | [simple_udf.rs][1]  |
+| Window    | A function that takes a row of data and returns a single value, but also has access to the rows around it. | [simple_udwf.rs][2] |
+| Aggregate | A function that takes a group of rows and returns a single value.                                          | [simple_udaf.rs][3] |
+| Table     | A function that takes parameters and returns a `TableProvider` to be used in a query plan.                 | [simple_udtf.rs][4] |
 
 First we'll talk about adding a Scalar UDF end-to-end, then we'll talk about the differences between the different types of UDFs.
 
@@ -38,7 +39,7 @@ A Scalar UDF is a function that takes a row of data and returns a single value.
 ```rust
 use std::sync::Arc;
-use arrow::array::{ArrayRef, Int64Array};
+use datafusion::arrow::array::{ArrayRef, Int64Array};
 use datafusion::common::Result;
 use datafusion::common::cast::as_int64_array;
@@ -75,9 +76,16 @@ The challenge however is that DataFusion doesn't know about this function. We ne
 
 ### Registering a Scalar UDF
 
-To register a Scalar UDF, you need to wrap the function implementation in a `ScalarUDF` struct and then register it with the `SessionContext`. DataFusion provides the `create_udf` and `make_scalar_function` helper functions to make this easier.
+To register a Scalar UDF, you need to wrap the function implementation in a [`ScalarUDF`] struct and then register it with the `SessionContext`.
+DataFusion provides the [`create_udf`] and [`make_scalar_function`] helper functions to make this easier.
+There is a lower level API with more functionality, but it is more complex; it is documented in [`advanced_udf.rs`].
 
 ```rust
+use datafusion::logical_expr::{Volatility, create_udf};
+use datafusion::physical_plan::functions::make_scalar_function;
+use datafusion::arrow::datatypes::DataType;
+use std::sync::Arc;
+
 let udf = create_udf(
     "add_one",
     vec![DataType::Int64],
@@ -87,6 +95,11 @@ let udf = create_udf(
 );
 ```
 
+[`scalarudf`]: https://docs.rs/datafusion/latest/datafusion/logical_expr/struct.ScalarUDF.html
+[`create_udf`]: https://docs.rs/datafusion/latest/datafusion/logical_expr/fn.create_udf.html
+[`make_scalar_function`]: https://docs.rs/datafusion/latest/datafusion/physical_expr/functions/fn.make_scalar_function.html
+[`advanced_udf.rs`]: https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/advanced_udf.rs
+
 A few things to note:
 
 - The first argument is the name of the function. This is the name that will be used in SQL queries.
@@ -98,6 +111,8 @@ A few things to note:
 That gives us a `ScalarUDF` that we can register with the `SessionContext`:
 
 ```rust
+use datafusion::execution::context::SessionContext;
+
 let mut ctx = SessionContext::new();
 
 ctx.register_udf(udf);
@@ -115,10 +130,415 @@ let df = ctx.sql(&sql).await.unwrap();
 
 Scalar UDFs are functions that take a row of data and return a single value. Window UDFs are similar, but they also have access to the rows around them. Access to the proximal rows is helpful, but adds some complexity to the implementation.
 
-Body coming soon.
+For example, we will declare a user defined window function that computes a moving average.
+
+```rust
+use datafusion::arrow::{array::{ArrayRef, Float64Array, AsArray}, datatypes::Float64Type};
+use datafusion::logical_expr::{PartitionEvaluator};
+use datafusion::common::ScalarValue;
+use datafusion::error::Result;
+/// This implements the lowest level evaluation for a window function
+///
+/// It handles calculating the value of the window function for each
+/// distinct value of `PARTITION BY`
+#[derive(Clone, Debug)]
+struct MyPartitionEvaluator {}
+
+impl MyPartitionEvaluator {
+    fn new() -> Self {
+        Self {}
+    }
+}
+
+/// Different evaluation methods are called depending on the various
+/// settings of WindowUDF. This example uses the simplest and most
+/// general, `evaluate`. See `PartitionEvaluator` for the other more
+/// advanced uses.
+impl PartitionEvaluator for MyPartitionEvaluator {
+    /// Tell DataFusion the window function varies based on the value
+    /// of the window frame.
+    fn uses_window_frame(&self) -> bool {
+        true
+    }
+
+    /// This function is called once per input row.
+    ///
+    /// `range` specifies which indexes of `values` should be
+    /// considered for the calculation.
+    ///
+    /// Note this is the SLOWEST, but simplest, way to evaluate a
+    /// window function. It is much faster to implement
+    /// evaluate_all or evaluate_all_with_rank, if possible
+    fn evaluate(
+        &mut self,
+        values: &[ArrayRef],
+        range: &std::ops::Range<usize>,
+    ) -> Result<ScalarValue> {
+        // Again, the input argument is an array of floating
+        // point numbers to calculate a moving average
+        let arr: &Float64Array = values[0].as_ref().as_primitive::<Float64Type>();
+
+        let range_len = range.end - range.start;
+
+        // our smoothing function will average all the values in the range
+        let output = if range_len > 0 {
+            let sum: f64 = arr.values().iter().skip(range.start).take(range_len).sum();
+            Some(sum / range_len as f64)
+        } else {
+            None
+        };
+
+        Ok(ScalarValue::Float64(output))
+    }
+}
+
+/// Create a `PartitionEvaluator` to evaluate this function on a new
+/// partition.
+fn make_partition_evaluator() -> Result<Box<dyn PartitionEvaluator>> {
+    Ok(Box::new(MyPartitionEvaluator::new()))
+}
+```
+
+### Registering a Window UDF
+
+To register a Window UDF, you need to wrap the function implementation in a [`WindowUDF`] struct and then register it with the `SessionContext`. DataFusion provides the [`create_udwf`] helper function to make this easier.
+There is a lower level API with more functionality, but it is more complex; it is documented in [`advanced_udwf.rs`].
+
+```rust
+use datafusion::logical_expr::{Volatility, create_udwf};
+use datafusion::arrow::datatypes::DataType;
+use std::sync::Arc;
+
+// here is where we define the UDWF. We also declare its signature:
+let smooth_it = create_udwf(
+    "smooth_it",
+    DataType::Float64,
+    Arc::new(DataType::Float64),
+    Volatility::Immutable,
+    Arc::new(make_partition_evaluator),
+);
+```
+
+[`windowudf`]: https://docs.rs/datafusion/latest/datafusion/logical_expr/struct.WindowUDF.html
+[`create_udwf`]: https://docs.rs/datafusion/latest/datafusion/logical_expr/fn.create_udwf.html
+[`advanced_udwf.rs`]: https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/advanced_udwf.rs
+
+The `create_udwf` has five arguments to check:
+
+- The first argument is the name of the function. This is the name that will be used in SQL queries.
+- **The second argument** is the `DataType` of the input array (note: this is not a list of arrays). I.e. in this case, the function accepts `Float64` as its argument.
+- The third argument is the return type of the function. I.e. in this case, the function returns a `Float64`.
+- The fourth argument is the volatility of the function. In short, this is used to determine if the function's performance can be optimized in some situations. In this case, the function is `Immutable` because it always returns the same value for the same input. A random number generator would be `Volatile` because it returns a different value for the same input.
+- **The fifth argument** is the function implementation. This is the function that we defined above.
+
+That gives us a `WindowUDF` that we can register with the `SessionContext`:
+
+```rust
+use datafusion::execution::context::SessionContext;
+
+let ctx = SessionContext::new();
+
+ctx.register_udwf(smooth_it);
+```
+
+At this point, you can use the `smooth_it` function in your query:
+
+For example, if we have a [`cars.csv`](https://github.com/apache/arrow-datafusion/blob/main/datafusion/core/tests/data/cars.csv) whose contents look like
+
+```csv
+car,speed,time
+red,20.0,1996-04-12T12:05:03.000000000
+red,20.3,1996-04-12T12:05:04.000000000
+green,10.0,1996-04-12T12:05:03.000000000
+green,10.3,1996-04-12T12:05:04.000000000
+...
+```
+
+Then, we can run a query like the one below:
+
+```rust
+use datafusion::datasource::file_format::options::CsvReadOptions;
+// register csv table first
+let csv_path = "cars.csv".to_string();
+ctx.register_csv("cars", &csv_path, CsvReadOptions::default().has_header(true)).await?;
+// do query with smooth_it
+let df = ctx
+    .sql(
+        "SELECT \
+            car, \
+            speed, \
+            smooth_it(speed) OVER (PARTITION BY car ORDER BY time) as smooth_speed,\
+            time \
+            from cars \
+            ORDER BY \
+            car",
+    )
+    .await?;
+// print the results
+df.show().await?;
+```
+
+The output will look like:
+
+```csv
++-------+-------+--------------------+---------------------+
+| car | speed | smooth_speed | time |
++-------+-------+--------------------+---------------------+
+| green | 10.0 | 10.0 | 1996-04-12T12:05:03 |
+| green | 10.3 | 10.15 | 1996-04-12T12:05:04 |
+| green | 10.4 | 10.233333333333334 | 1996-04-12T12:05:05 |
+| green | 10.5 | 10.3 | 1996-04-12T12:05:06 |
+| green | 11.0 | 10.440000000000001 | 1996-04-12T12:05:07 |
+| green | 12.0 | 10.700000000000001 | 1996-04-12T12:05:08 |
+| green | 14.0 | 11.171428571428573 | 1996-04-12T12:05:09 |
+| green | 15.0 | 11.65 | 1996-04-12T12:05:10 |
+| green | 15.1 | 12.033333333333333 | 1996-04-12T12:05:11 |
+| green | 15.2 | 12.35 | 1996-04-12T12:05:12 |
+| green | 8.0 | 11.954545454545455 | 1996-04-12T12:05:13 |
+| green | 2.0 | 11.125 | 1996-04-12T12:05:14 |
+| red | 20.0 | 20.0 | 1996-04-12T12:05:03 |
+| red | 20.3 | 20.15 | 1996-04-12T12:05:04 |
+...
+```
 
 ## Adding an Aggregate UDF
 
 Aggregate UDFs are functions that take a group of rows and return a single value. These are akin to SQL's `SUM` or `COUNT` functions.
 
-Body coming soon.
+For example, we will declare a single-type, single return type UDAF that computes the geometric mean.
+
+```rust
+use datafusion::arrow::array::ArrayRef;
+use datafusion::scalar::ScalarValue;
+use datafusion::{error::Result, physical_plan::Accumulator};
+
+/// A UDAF has state across multiple rows, and thus we require a `struct` with that state.
+#[derive(Debug)]
+struct GeometricMean {
+    n: u32,
+    prod: f64,
+}
+
+impl GeometricMean {
+    // how the struct is initialized
+    pub fn new() -> Self {
+        GeometricMean { n: 0, prod: 1.0 }
+    }
+}
+
+// UDAFs are built using the trait `Accumulator`, that offers DataFusion the necessary functions
+// to use them.
+impl Accumulator for GeometricMean {
+    // This function serializes our state to `ScalarValue`, which DataFusion uses
+    // to pass this state between execution stages.
+    // Note that this can be arbitrary data.
+    fn state(&self) -> Result<Vec<ScalarValue>> {
+        Ok(vec![
+            ScalarValue::from(self.prod),
+            ScalarValue::from(self.n),
+        ])
+    }
+
+    // DataFusion expects this function to return the final value of this aggregator.
+    // in this case, this is the formula of the geometric mean
+    fn evaluate(&self) -> Result<ScalarValue> {
+        let value = self.prod.powf(1.0 / self.n as f64);
+        Ok(ScalarValue::from(value))
+    }
+
+    // DataFusion calls this function to update the accumulator's state for a batch
+    // of input rows. In this case the product is updated with values from the first column
+    // and the count is updated based on the row count
+    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
+        if values.is_empty() {
+            return Ok(());
+        }
+        let arr = &values[0];
+        (0..arr.len()).try_for_each(|index| {
+            let v = ScalarValue::try_from_array(arr, index)?;
+
+            if let ScalarValue::Float64(Some(value)) = v {
+                self.prod *= value;
+                self.n += 1;
+            } else {
+                unreachable!("")
+            }
+            Ok(())
+        })
+    }
+
+    // Optimization hint: this trait also supports `update_batch` and `merge_batch`,
+    // that can be used to perform these operations on arrays instead of single values.
+    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
+        if states.is_empty() {
+            return Ok(());
+        }
+        let arr = &states[0];
+        (0..arr.len()).try_for_each(|index| {
+            let v = states
+                .iter()
+                .map(|array| ScalarValue::try_from_array(array, index))
+                .collect::<Result<Vec<ScalarValue>>>()?;
+            if let (ScalarValue::Float64(Some(prod)), ScalarValue::UInt32(Some(n))) = (&v[0], &v[1])
+            {
+                self.prod *= prod;
+                self.n += n;
+            } else {
+                unreachable!("")
+            }
+            Ok(())
+        })
+    }
+
+    fn size(&self) -> usize {
+        std::mem::size_of_val(self)
+    }
+}
+```
+
+### Registering an Aggregate UDF
+
+To register an Aggregate UDF, you need to wrap the function implementation in an `AggregateUDF` struct and then register it with the `SessionContext`. DataFusion provides the `create_udaf` helper function to make this easier.
+
+```rust
+use datafusion::logical_expr::{Volatility, create_udaf};
+use datafusion::arrow::datatypes::DataType;
+use std::sync::Arc;
+
+// here is where we define the UDAF. We also declare its signature:
+let geometric_mean = create_udaf(
+    // the name; used to represent it in plan descriptions and in the registry, to use in SQL.
+    "geo_mean",
+    // the input type; DataFusion guarantees that the first entry of `values` in `update` has this type.
+    vec![DataType::Float64],
+    // the return type; DataFusion expects this to match the type returned by `evaluate`.
+    Arc::new(DataType::Float64),
+    Volatility::Immutable,
+    // This is the accumulator factory; DataFusion uses it to create new accumulators.
+    Arc::new(|_| Ok(Box::new(GeometricMean::new()))),
+    // This is the description of the state. `state()` must match the types here.
+    Arc::new(vec![DataType::Float64, DataType::UInt32]),
+);
+```
+
+The `create_udaf` has six arguments to check:
+
+- The first argument is the name of the function. This is the name that will be used in SQL queries.
+- The second argument is a vector of `DataType`s. This is the list of argument types that the function accepts. I.e. in this case, the function accepts a single `Float64` argument.
+- The third argument is the return type of the function. I.e. in this case, the function returns a `Float64`.
+- The fourth argument is the volatility of the function. In short, this is used to determine if the function's performance can be optimized in some situations. In this case, the function is `Immutable` because it always returns the same value for the same input. A random number generator would be `Volatile` because it returns a different value for the same input.
+- The fifth argument is the function implementation. This is the function that we defined above.
This is a closure that creates a new `GeometricMean` accumulator (the struct we defined above).
+- The sixth argument is the description of the state, which will be passed between execution stages.
+
+That gives us an `AggregateUDF` that we can register with the `SessionContext`:
+
+```rust
+use datafusion::execution::context::SessionContext;
+
+let ctx = SessionContext::new();
+
+ctx.register_udaf(geometric_mean);
+```
+
+Then, we can use it in a query:
+
+```rust
+let df = ctx.sql("SELECT geo_mean(a) FROM t").await?;
+```
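+
+As a more complete sketch, you could run this end to end against a small in-memory table. The table name `t` and the column `a` match the query above; the sample values and the use of `MemTable` are just for illustration, and this assumes the `geometric_mean` UDAF defined above is in scope:
+
+```rust
+use std::sync::Arc;
+
+use datafusion::arrow::array::Float64Array;
+use datafusion::arrow::datatypes::{DataType, Field, Schema};
+use datafusion::arrow::record_batch::RecordBatch;
+use datafusion::datasource::MemTable;
+use datafusion::execution::context::SessionContext;
+
+let ctx = SessionContext::new();
+
+// a single partition with one batch: column `a` = [2.0, 4.0, 8.0, 64.0]
+let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Float64, false)]));
+let batch = RecordBatch::try_new(
+    schema.clone(),
+    vec![Arc::new(Float64Array::from(vec![2.0, 4.0, 8.0, 64.0]))],
+)?;
+ctx.register_table("t", Arc::new(MemTable::try_new(schema, vec![vec![batch]])?))?;
+
+// register the UDAF and call it from SQL
+ctx.register_udaf(geometric_mean);
+let df = ctx.sql("SELECT geo_mean(a) FROM t").await?;
+df.show().await?;
+// the geometric mean of 2, 4, 8 and 64 is 8
+```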
+
+## Adding a User-Defined Table Function
+
+A User-Defined Table Function (UDTF) is a function that takes parameters and returns a `TableProvider`.
+
+Because we're returning a `TableProvider`, in this example we'll use the `MemTable` data source to represent a table. This is a simple struct that holds a set of RecordBatches in memory and treats them as a table. In your case, this would be replaced with your own struct that implements `TableProvider`.
+
+While this is a simple example for illustrative purposes, UDTFs have a lot of potential use cases, and can be particularly useful for reading data from external sources and for interactive analysis. For example, see the [example][4] that reads from a CSV file. As another example, you could use the built-in UDTF `parquet_metadata` in the CLI to read the metadata from a Parquet file.
+
+```console
+❯ select filename, row_group_id, row_group_num_rows, row_group_bytes, stats_min, stats_max from parquet_metadata('./benchmarks/data/hits.parquet') where column_id = 17 limit 10;
++--------------------------------+--------------+--------------------+-----------------+-----------+-----------+
+| filename                       | row_group_id | row_group_num_rows | row_group_bytes | stats_min | stats_max |
++--------------------------------+--------------+--------------------+-----------------+-----------+-----------+
+| ./benchmarks/data/hits.parquet | 0            | 450560             | 188921521       | 0         | 73256     |
+| ./benchmarks/data/hits.parquet | 1            | 612174             | 210338885       | 0         | 109827    |
+| ./benchmarks/data/hits.parquet | 2            | 344064             | 161242466       | 0         | 122484    |
+| ./benchmarks/data/hits.parquet | 3            | 606208             | 235549898       | 0         | 121073    |
+| ./benchmarks/data/hits.parquet | 4            | 335872             | 137103898       | 0         | 108996    |
+| ./benchmarks/data/hits.parquet | 5            | 311296             | 145453612       | 0         | 108996    |
+| ./benchmarks/data/hits.parquet | 6            | 303104             | 138833963       | 0         | 108996    |
+| ./benchmarks/data/hits.parquet | 7            | 303104             | 191140113       | 0         | 73256     |
+| ./benchmarks/data/hits.parquet | 8            | 573440             | 208038598       | 0         | 95823     |
+| ./benchmarks/data/hits.parquet | 9            | 344064             | 147838157       | 0         | 73256     |
++--------------------------------+--------------+--------------------+-----------------+-----------+-----------+
+```
+
+### Writing the UDTF
+
+The simple UDTF used here takes a single `Int64` argument and returns a table with a single column with the value of the argument. To create a function in DataFusion, you need to implement the `TableFunctionImpl` trait. This trait has a single method, `call`, that takes a slice of `Expr`s and returns a `Result<Arc<dyn TableProvider>>`.
+
+In the `call` method, you parse the input `Expr`s and return a `TableProvider`. You might also want to do some validation of the input `Expr`s, e.g. checking that the number of arguments is correct.
+
+```rust
+use datafusion::common::plan_err;
+use datafusion::datasource::function::TableFunctionImpl;
+// Other imports here
+
+/// A table function that returns a table provider with the value as a single column
+#[derive(Default)]
+pub struct EchoFunction {}
+
+impl TableFunctionImpl for EchoFunction {
+    fn call(&self, exprs: &[Expr]) -> Result<Arc<dyn TableProvider>> {
+        let Some(Expr::Literal(ScalarValue::Int64(Some(value)))) = exprs.get(0) else {
+            return plan_err!("First argument must be an integer");
+        };
+
+        // Create the schema for the table
+        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, false)]));
+
+        // Create a single RecordBatch with the value as a single column
+        let batch = RecordBatch::try_new(
+            schema.clone(),
+            vec![Arc::new(Int64Array::from(vec![*value]))],
+        )?;
+
+        // Create a MemTable plan that returns the RecordBatch
+        let provider = MemTable::try_new(schema, vec![vec![batch]])?;
+
+        Ok(Arc::new(provider))
+    }
+}
+```
+
+### Registering and Using the UDTF
+
+With the UDTF implemented, you can register it with the `SessionContext`:
+
+```rust
+use datafusion::execution::context::SessionContext;
+
+let ctx = SessionContext::new();
+
+ctx.register_udtf("echo", Arc::new(EchoFunction::default()));
+```
+
+And if all goes well, you can use it in your query:
+
+```rust
+use datafusion::arrow::util::pretty;
+
+let df = ctx.sql("SELECT * FROM echo(1)").await?;
+
+let results = df.collect().await?;
+pretty::print_batches(&results)?;
+// +---+
+// | a |
+// +---+
+// | 1 |
+// +---+
+```
+
+[1]: https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/simple_udf.rs
+[2]: https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/simple_udwf.rs
+[3]: https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/simple_udaf.rs
+[4]: https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/simple_udtf.rs
diff --git a/docs/source/library-user-guide/custom-table-providers.md b/docs/source/library-user-guide/custom-table-providers.md
index ca0e9de779efa..9da207da68f32 100644
--- a/docs/source/library-user-guide/custom-table-providers.md
+++ b/docs/source/library-user-guide/custom-table-providers.md
@@ -25,7 +25,7 @@ This section will also touch on how to have DataFusion use the new `TableProvide

 ## Table Provider and Scan

-The `scan` method on the `TableProvider` is likely its most important. It returns an `ExecutionPlan` that DataFusion will use to read the actual data during execution o the query.
+The `scan` method on the `TableProvider` is likely its most important. It returns an `ExecutionPlan` that DataFusion will use to read the actual data during execution of the query.

 ### Scan

diff --git a/docs/source/library-user-guide/using-the-dataframe-api.md b/docs/source/library-user-guide/using-the-dataframe-api.md
index fdf309980dc2e..c4f4ecd4f1370 100644
--- a/docs/source/library-user-guide/using-the-dataframe-api.md
+++ b/docs/source/library-user-guide/using-the-dataframe-api.md
@@ -19,4 +19,129 @@

 # Using the DataFrame API

-Coming Soon
+## What is a DataFrame
+
+A `DataFrame` in DataFusion is modeled after the Pandas DataFrame interface, and is a thin wrapper over a `LogicalPlan` that adds functionality for building and executing those plans.
+
+```rust
+pub struct DataFrame {
+    session_state: SessionState,
+    plan: LogicalPlan,
+}
+```
+
+You can build up a `DataFrame` using its methods, similarly to building a `LogicalPlan` using `LogicalPlanBuilder`:
+
+```rust
+let df = ctx.table("users").await?;
+
+// Create a new DataFrame with columns `id` and `bank_account`, sorted by `id`
+let new_df = df.select(vec![col("id"), col("bank_account")])?
+    .sort(vec![col("id")])?;
+
+// Build the same plan using the LogicalPlanBuilder
+let plan = LogicalPlanBuilder::from(&df.to_logical_plan())
+    .project(vec![col("id"), col("bank_account")])?
+    .sort(vec![col("id")])?
+    .build()?;
+```
+
+You can use `collect` or `execute_stream` to execute the query.
+
+## How to generate a DataFrame
+
+You can directly use the `DataFrame` API or generate a `DataFrame` from a SQL query.
+
+For example, to use `sql` to construct a `DataFrame`:
+
+```rust
+let ctx = SessionContext::new();
+// Register the in-memory table containing the data
+ctx.register_table("users", Arc::new(create_memtable()?))?;
+let dataframe = ctx.sql("SELECT * FROM users;").await?;
+```
+
+To construct a `DataFrame` using the API:
+
+```rust
+let ctx = SessionContext::new();
+// Register the in-memory table containing the data
+ctx.register_table("users", Arc::new(create_memtable()?))?;
+let dataframe = ctx
+    .table("users")
+    .filter(col("a").lt_eq(col("b")))?
+    .sort(vec![col("a").sort(true, true), col("b").sort(false, false)])?;
+```
+
+## Collect / Streaming Exec
+
+DataFusion `DataFrame`s are "lazy", meaning they do not do any processing until they are executed, which allows for additional optimizations.
+
+When you have a `DataFrame`, you can run it in one of three ways:
+
+1. `collect`, which executes the query and buffers all the output into a `Vec<RecordBatch>`
+2. `execute_stream`, which begins execution and returns a `SendableRecordBatchStream` that incrementally computes output on each call to `next()`
+3. `cache`, which executes the query and buffers the output into a new in-memory `DataFrame`
+
+You can collect all of the output at once:
+
+```rust
+let ctx = SessionContext::new();
+let df = ctx.read_csv("tests/data/example.csv", CsvReadOptions::new()).await?;
+let batches = df.collect().await?;
+```
+
+You can also stream the output, incrementally generating one `RecordBatch` at a time:
+
+```rust
+let ctx = SessionContext::new();
+let df = ctx.read_csv("tests/data/example.csv", CsvReadOptions::new()).await?;
+let mut stream = df.execute_stream().await?;
+while let Some(rb) = stream.next().await {
+    println!("{rb:?}");
+}
+```
+
+## Write DataFrame to Files
+
+You can also serialize a `DataFrame` to a file. For now, DataFusion supports writing a `DataFrame` to `csv`, `json` and `parquet`.
+
+When writing a file, DataFusion executes the `DataFrame` and streams the results to the file.
+
+For example, to write a CSV file:
+
+```rust
+let ctx = SessionContext::new();
+// Register the in-memory table containing the data
+ctx.register_table("users", Arc::new(mem_table))?;
+let dataframe = ctx.sql("SELECT * FROM users;").await?;
+
+dataframe
+    .write_csv("user_dataframe.csv", DataFrameWriteOptions::default(), None)
+    .await?;
+```
+
+and the file will look like this (example output):
+
+```
+id,bank_account
+1,9000
+```
+
+## Transform between LogicalPlan and DataFrame
+
+As shown above, a `DataFrame` is just a very thin wrapper over a `LogicalPlan`, so you can easily go back and forth between them.
+ +```rust +// Just combine LogicalPlan with SessionContext and you get a DataFrame +let ctx = SessionContext::new(); +// Register the in-memory table containing the data +ctx.register_table("users", Arc::new(mem_table))?; +let dataframe = ctx.sql("SELECT * FROM users;").await?; + +// get LogicalPlan in dataframe +let plan = dataframe.logical_plan().clone(); + +// construct a DataFrame with LogicalPlan +let new_df = DataFrame::new(ctx.state(), plan); +``` diff --git a/docs/source/library-user-guide/working-with-exprs.md b/docs/source/library-user-guide/working-with-exprs.md index a8baf24d5f0ae..96be8ef7f1aeb 100644 --- a/docs/source/library-user-guide/working-with-exprs.md +++ b/docs/source/library-user-guide/working-with-exprs.md @@ -17,7 +17,7 @@ under the License. --> -# Working with Exprs +# Working with `Expr`s @@ -48,12 +48,11 @@ As another example, the SQL expression `a + b * c` would be represented as an `E └────────────────────┘ └────────────────────┘ ``` -As the writer of a library, you may want to use or create `Expr`s to represent computations that you want to perform. This guide will walk you through how to make your own scalar UDF as an `Expr` and how to rewrite `Expr`s to inline the simple UDF. +As the writer of a library, you can use `Expr`s to represent computations that you want to perform. This guide will walk you through how to make your own scalar UDF as an `Expr` and how to rewrite `Expr`s to inline the simple UDF. -There are also executable examples for working with `Expr`s: +## Creating and Evaluating `Expr`s -- [rewrite_expr.rs](https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/rewrite_expr.rs) -- [expr_api.rs](https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/expr_api.rs) +Please see [expr_api.rs](https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/expr_api.rs) for well commented code for creating, evaluating, simplifying, and analyzing `Expr`s. ## A Scalar UDF Example @@ -79,7 +78,9 @@ let expr = add_one_udf.call(vec![col("my_column")]); If you'd like to learn more about `Expr`s, before we get into the details of creating and rewriting them, you can read the [expression user-guide](./../user-guide/expressions.md). -## Rewriting Exprs +## Rewriting `Expr`s + +[rewrite_expr.rs](https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/rewrite_expr.rs) contains example code for rewriting `Expr`s. Rewriting Expressions is the process of taking an `Expr` and transforming it into another `Expr`. This is useful for a number of reasons, including: diff --git a/docs/source/user-guide/cli.md b/docs/source/user-guide/cli.md index e8fdae7bb097b..525ab090ce514 100644 --- a/docs/source/user-guide/cli.md +++ b/docs/source/user-guide/cli.md @@ -31,7 +31,9 @@ The easiest way to install DataFusion CLI a spin is via `cargo install datafusio ### Install and run using Homebrew (on MacOS) -DataFusion CLI can also be installed via Homebrew (on MacOS). Install it as any other pre-built software like this: +DataFusion CLI can also be installed via Homebrew (on MacOS). If you don't have Homebrew installed, you can check how to install it [here](https://docs.brew.sh/Installation). + +Install it as any other pre-built software like this: ```bash brew install datafusion @@ -46,6 +48,34 @@ brew install datafusion datafusion-cli ``` +### Install and run using PyPI + +DataFusion CLI can also be installed via PyPI. 
You can check how to install PyPI [here](https://pip.pypa.io/en/latest/installation/). + +Install it as any other pre-built software like this: + +```bash +pip3 install datafusion +# Defaulting to user installation because normal site-packages is not writeable +# Collecting datafusion +# Downloading datafusion-33.0.0-cp38-abi3-macosx_11_0_arm64.whl.metadata (9.6 kB) +# Collecting pyarrow>=11.0.0 (from datafusion) +# Downloading pyarrow-14.0.1-cp39-cp39-macosx_11_0_arm64.whl.metadata (3.0 kB) +# Requirement already satisfied: numpy>=1.16.6 in /Users/Library/Python/3.9/lib/python/site-packages (from pyarrow>=11.0.0->datafusion) (1.23.4) +# Downloading datafusion-33.0.0-cp38-abi3-macosx_11_0_arm64.whl (13.5 MB) +# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 13.5/13.5 MB 3.6 MB/s eta 0:00:00 +# Downloading pyarrow-14.0.1-cp39-cp39-macosx_11_0_arm64.whl (24.0 MB) +# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 24.0/24.0 MB 36.4 MB/s eta 0:00:00 +# Installing collected packages: pyarrow, datafusion +# Attempting uninstall: pyarrow +# Found existing installation: pyarrow 10.0.1 +# Uninstalling pyarrow-10.0.1: +# Successfully uninstalled pyarrow-10.0.1 +# Successfully installed datafusion-33.0.0 pyarrow-14.0.1 + +datafusion-cli +``` + ### Run using Docker There is no officially published Docker image for the DataFusion CLI, so it is necessary to build from source diff --git a/docs/source/user-guide/configs.md b/docs/source/user-guide/configs.md index 4cc4fd1c3a25e..0a5c221c50343 100644 --- a/docs/source/user-guide/configs.md +++ b/docs/source/user-guide/configs.md @@ -64,7 +64,7 @@ Environment variables are read during `SessionConfig` initialisation so they mus | datafusion.execution.parquet.statistics_enabled | NULL | Sets if statistics are enabled for any column Valid values are: "none", "chunk", and "page" These values are not case sensitive. If NULL, uses default parquet writer setting | | datafusion.execution.parquet.max_statistics_size | NULL | Sets max statistics size for any column. If NULL, uses default parquet writer setting | | datafusion.execution.parquet.max_row_group_size | 1048576 | Sets maximum number of rows in a row group | -| datafusion.execution.parquet.created_by | datafusion version 33.0.0 | Sets "created by" property | +| datafusion.execution.parquet.created_by | datafusion version 34.0.0 | Sets "created by" property | | datafusion.execution.parquet.column_index_truncate_length | NULL | Sets column index truncate length | | datafusion.execution.parquet.data_page_row_count_limit | 18446744073709551615 | Sets best effort maximum number of rows in data page | | datafusion.execution.parquet.encoding | NULL | Sets default encoding for any column Valid values are: plain, plain_dictionary, rle, bit_packed, delta_binary_packed, delta_length_byte_array, delta_byte_array, rle_dictionary, and byte_stream_split. These values are not case sensitive. If NULL, uses default parquet writer setting | @@ -82,6 +82,8 @@ Environment variables are read during `SessionConfig` initialisation so they mus | datafusion.execution.minimum_parallel_output_files | 4 | Guarantees a minimum level of output files running in parallel. RecordBatches will be distributed in round robin fashion to each parallel writer. Each writer is closed and a new file opened once soft_max_rows_per_output_file is reached. | | datafusion.execution.soft_max_rows_per_output_file | 50000000 | Target number of rows in output files when writing multiple. This is a soft max, so it can be exceeded slightly. 
There also will be one file smaller than the limit if the total number of rows written is not roughly divisible by the soft max | | datafusion.execution.max_buffered_batches_per_output_file | 2 | This is the maximum number of RecordBatches buffered for each output file being worked. Higher values can potentially give faster write performance at the cost of higher peak memory consumption | +| datafusion.execution.listing_table_ignore_subdirectory | true | Should sub directories be ignored when scanning directories for data files. Defaults to true (ignores subdirectories), consistent with Hive. Note that this setting does not affect reading partitioned tables (e.g. `/table/year=2021/month=01/data.parquet`). | +| datafusion.optimizer.enable_distinct_aggregation_soft_limit | true | When set to true, the optimizer will push a limit operation into grouped aggregations which have no aggregate expressions, as a soft limit, emitting groups once the limit is reached, before all rows in the group are read. | | datafusion.optimizer.enable_round_robin_repartition | true | When set to true, the physical plan optimizer will try to add round robin repartitioning to increase parallelism to leverage more CPU cores | | datafusion.optimizer.enable_topk_aggregation | true | When set to true, the optimizer will attempt to perform limit operations during aggregations, if possible | | datafusion.optimizer.filter_null_join_keys | false | When set to true, the optimizer will insert filters before a join between a nullable and non-nullable column to filter out nulls on the nullable side. This filter can add additional overhead when the file format does not fully support predicate push down. | @@ -98,6 +100,7 @@ Environment variables are read during `SessionConfig` initialisation so they mus | datafusion.optimizer.top_down_join_key_reordering | true | When set to true, the physical plan optimizer will run a top down process to reorder the join keys | | datafusion.optimizer.prefer_hash_join | true | When set to true, the physical plan optimizer will prefer HashJoin over SortMergeJoin. HashJoin can work more efficiently than SortMergeJoin but consumes more memory | | datafusion.optimizer.hash_join_single_partition_threshold | 1048576 | The maximum estimated size in bytes for one input side of a HashJoin will be collected into a single partition | +| datafusion.optimizer.default_filter_selectivity | 20 | The default filter selectivity used by Filter Statistics when an exact selectivity cannot be determined. Valid values are between 0 (no selectivity) and 100 (all rows are selected). | | datafusion.explain.logical_plan_only | false | When set to true, the explain statement will only print logical plans | | datafusion.explain.physical_plan_only | false | When set to true, the explain statement will only print physical plans | | datafusion.explain.show_statistics | false | When set to true, the explain statement will print operator statistics for physical plans | diff --git a/docs/source/user-guide/dataframe.md b/docs/source/user-guide/dataframe.md index 4484b2c510197..c0210200a246f 100644 --- a/docs/source/user-guide/dataframe.md +++ b/docs/source/user-guide/dataframe.md @@ -95,6 +95,7 @@ These methods execute the logical plan represented by the DataFrame and either c | write_csv | Execute this DataFrame and write the results to disk in CSV format. | | write_json | Execute this DataFrame and write the results to disk in JSON format. 
 | write_parquet | Execute this DataFrame and write the results to disk in Parquet format. |
+| write_table | Execute this DataFrame and write the results via the insert_into method of the registered TableProvider |

 ## Other DataFrame Methods

diff --git a/docs/source/user-guide/expressions.md b/docs/source/user-guide/expressions.md
index dbe12df335648..b8689e5567415 100644
--- a/docs/source/user-guide/expressions.md
+++ b/docs/source/user-guide/expressions.md
@@ -215,10 +215,12 @@ Unlike to some databases the math functions in Datafusion works the same way as
 | array_has_all(array, sub-array) | Returns true if all elements of sub-array exist in array `array_has_all([1,2,3], [1,3]) -> true` |
 | array_has_any(array, sub-array) | Returns true if any elements exist in both arrays `array_has_any([1,2,3], [1,4]) -> true` |
 | array_dims(array) | Returns an array of the array's dimensions. `array_dims([[1, 2, 3], [4, 5, 6]]) -> [2, 3]` |
+| array_distinct(array) | Returns distinct values from the array after removing duplicates. `array_distinct([1, 3, 2, 3, 1, 2, 4]) -> [1, 2, 3, 4]` |
 | array_element(array, index) | Extracts the element with the index n from the array `array_element([1, 2, 3, 4], 3) -> 3` |
 | flatten(array) | Converts an array of arrays to a flat array `flatten([[1], [2, 3], [4, 5, 6]]) -> [1, 2, 3, 4, 5, 6]` |
 | array_length(array, dimension) | Returns the length of the array dimension. `array_length([1, 2, 3, 4, 5]) -> 5` |
 | array_ndims(array) | Returns the number of dimensions of the array. `array_ndims([[1, 2, 3], [4, 5, 6]]) -> 2` |
+| array_pop_front(array) | Returns the array without the first element. `array_pop_front([1, 2, 3]) -> [2, 3]` |
 | array_pop_back(array) | Returns the array without the last element. `array_pop_back([1, 2, 3]) -> [1, 2]` |
 | array_position(array, element) | Searches for an element in the array, returns first occurrence. `array_position([1, 2, 2, 3, 4], 2) -> 2` |
 | array_positions(array, element) | Searches for an element in the array, returns all occurrences. `array_positions([1, 2, 2, 3, 4], 2) -> [2, 3]` |
@@ -232,8 +234,12 @@ Unlike to some databases the math functions in Datafusion works the same way as
 | array_replace_all(array, from, to) | Replaces all occurrences of the specified element with another specified element. `array_replace_all([1, 2, 2, 3, 2, 1, 4], 2, 5) -> [1, 5, 5, 3, 5, 1, 4]` |
 | array_slice(array, index) | Returns a slice of the array. `array_slice([1, 2, 3, 4, 5, 6, 7, 8], 3, 6) -> [3, 4, 5, 6]` |
 | array_to_string(array, delimiter) | Converts each element to its text representation. `array_to_string([1, 2, 3, 4], ',') -> 1,2,3,4` |
+| array_intersect(array1, array2) | Returns an array of the elements in the intersection of array1 and array2. `array_intersect([1, 2, 3, 4], [5, 6, 3, 4]) -> [3, 4]` |
+| array_union(array1, array2) | Returns an array of the elements in the union of array1 and array2 without duplicates. `array_union([1, 2, 3, 4], [5, 6, 3, 4]) -> [1, 2, 3, 4, 5, 6]` |
+| array_except(array1, array2) | Returns an array of the elements that appear in the first array but not in the second. `array_except([1, 2, 3, 4], [5, 6, 3, 4]) -> [1, 2]` |
 | cardinality(array) | Returns the total number of elements in the array. `cardinality([[1, 2, 3], [4, 5, 6]]) -> 6` |
 | make_array(value1, [value2 [, ...]]) | Returns an Arrow array using the specified input expressions. `make_array(1, 2, 3) -> [1, 2, 3]` |
+| range(start [, stop, step]) | Returns an Arrow array between start and stop with step. 
`SELECT range(2, 10, 3) -> [2, 5, 8]` | | trim_array(array, n) | Deprecated | ## Regular Expressions diff --git a/docs/source/user-guide/introduction.md b/docs/source/user-guide/introduction.md index da250fbb1f9c0..b737c3bab2666 100644 --- a/docs/source/user-guide/introduction.md +++ b/docs/source/user-guide/introduction.md @@ -75,7 +75,7 @@ latency). Here are some example systems built using DataFusion: -- Specialized Analytical Database systems such as [CeresDB] and more general Apache Spark like system such a [Ballista]. +- Specialized Analytical Database systems such as [HoraeDB] and more general Apache Spark like system such a [Ballista]. - New query language engines such as [prql-query] and accelerators such as [VegaFusion] - Research platform for new Database Systems, such as [Flock] - SQL support to another library, such as [dask sql] @@ -96,7 +96,6 @@ Here are some active projects using DataFusion: - [Arroyo](https://github.com/ArroyoSystems/arroyo) Distributed stream processing engine in Rust - [Ballista](https://github.com/apache/arrow-ballista) Distributed SQL Query Engine -- [CeresDB](https://github.com/CeresDB/ceresdb) Distributed Time-Series Database - [CnosDB](https://github.com/cnosdb/cnosdb) Open Source Distributed Time Series Database - [Cube Store](https://github.com/cube-js/cube.js/tree/master/rust) - [Dask SQL](https://github.com/dask-contrib/dask-sql) Distributed SQL query engine in Python @@ -104,8 +103,10 @@ Here are some active projects using DataFusion: - [delta-rs](https://github.com/delta-io/delta-rs) Native Rust implementation of Delta Lake - [GreptimeDB](https://github.com/GreptimeTeam/greptimedb) Open Source & Cloud Native Distributed Time Series Database - [GlareDB](https://github.com/GlareDB/glaredb) Fast SQL database for querying and analyzing distributed data. +- [HoraeDB](https://github.com/apache/incubator-horaedb) Distributed Time-Series Database - [InfluxDB IOx](https://github.com/influxdata/influxdb_iox) Time Series Database - [Kamu](https://github.com/kamu-data/kamu-cli/) Planet-scale streaming data pipeline +- [LakeSoul](https://github.com/lakesoul-io/LakeSoul) Open source LakeHouse framework with native IO in Rust. 
- [Lance](https://github.com/lancedb/lance) Modern columnar data format for ML - [Parseable](https://github.com/parseablehq/parseable) Log storage and observability platform - [qv](https://github.com/timvw/qv) Quickly view your data @@ -127,7 +128,6 @@ Here are some less active projects that used DataFusion: [ballista]: https://github.com/apache/arrow-ballista [blaze]: https://github.com/blaze-init/blaze -[ceresdb]: https://github.com/CeresDB/ceresdb [cloudfuse buzz]: https://github.com/cloudfuse-io/buzz-rust [cnosdb]: https://github.com/cnosdb/cnosdb [cube store]: https://github.com/cube-js/cube.js/tree/master/rust @@ -137,6 +137,7 @@ Here are some less active projects that used DataFusion: [flock]: https://github.com/flock-lab/flock [kamu]: https://github.com/kamu-data/kamu-cli [greptime db]: https://github.com/GreptimeTeam/greptimedb +[horaedb]: https://github.com/apache/incubator-horaedb [influxdb iox]: https://github.com/influxdata/influxdb_iox [parseable]: https://github.com/parseablehq/parseable [prql-query]: https://github.com/prql/prql-query diff --git a/docs/source/user-guide/sql/scalar_functions.md b/docs/source/user-guide/sql/scalar_functions.md index be05084fb2491..629a5f6ecb882 100644 --- a/docs/source/user-guide/sql/scalar_functions.md +++ b/docs/source/user-guide/sql/scalar_functions.md @@ -635,6 +635,10 @@ nullif(expression1, expression2) - [trim](#trim) - [upper](#upper) - [uuid](#uuid) +- [overlay](#overlay) +- [levenshtein](#levenshtein) +- [substr_index](#substr_index) +- [find_in_set](#find_in_set) ### `ascii` @@ -1120,6 +1124,67 @@ Returns UUID v4 string value which is unique per row. uuid() ``` +### `overlay` + +Returns the string which is replaced by another string from the specified position and specified count length. +For example, `overlay('Txxxxas' placing 'hom' from 2 for 4) → Thomas` + +``` +overlay(str PLACING substr FROM pos [FOR count]) +``` + +#### Arguments + +- **str**: String expression to operate on. +- **substr**: the string to replace part of str. +- **pos**: the start position to replace of str. +- **count**: the count of characters to be replaced from start position of str. If not specified, will use substr length instead. + +### `levenshtein` + +Returns the Levenshtein distance between the two given strings. +For example, `levenshtein('kitten', 'sitting') = 3` + +``` +levenshtein(str1, str2) +``` + +#### Arguments + +- **str1**: String expression to compute Levenshtein distance with str2. +- **str2**: String expression to compute Levenshtein distance with str1. + +### `substr_index` + +Returns the substring from str before count occurrences of the delimiter delim. +If count is positive, everything to the left of the final delimiter (counting from the left) is returned. +If count is negative, everything to the right of the final delimiter (counting from the right) is returned. +For example, `substr_index('www.apache.org', '.', 1) = www`, `substr_index('www.apache.org', '.', -1) = org` + +``` +substr_index(str, delim, count) +``` + +#### Arguments + +- **str**: String expression to operate on. +- **delim**: the string to find in str to split str. +- **count**: The number of times to search for the delimiter. Can be both a positive or negative number. + +### `find_in_set` + +Returns a value in the range of 1 to N if the string str is in the string list strlist consisting of N substrings. +For example, `find_in_set('b', 'a,b,c,d') = 2` + +``` +find_in_set(str, strlist) +``` + +#### Arguments + +- **str**: String expression to find in strlist. 
+- **strlist**: A string list is a string composed of substrings separated by , characters. + ## Binary String Functions - [decode](#decode) @@ -1215,6 +1280,7 @@ regexp_replace(str, regexp, replacement, flags) - [datepart](#datepart) - [extract](#extract) - [to_timestamp](#to_timestamp) +- [today](#today) - [to_timestamp_millis](#to_timestamp_millis) - [to_timestamp_micros](#to_timestamp_micros) - [to_timestamp_seconds](#to_timestamp_seconds) @@ -1243,6 +1309,14 @@ no matter when in the query plan the function executes. current_date() ``` +#### Aliases + +- today + +### `today` + +_Alias of [current_date](#current_date)._ + ### `current_time` Returns the current UTC time. @@ -1336,6 +1410,7 @@ date_part(part, expression) The following date parts are supported: - year + - quarter _(emits value in inclusive range [1, 4] based on which quartile of the year the date is in)_ - month - week _(week of the year)_ - day _(day of the month)_ @@ -1347,6 +1422,7 @@ date_part(part, expression) - nanosecond - dow _(day of the week)_ - doy _(day of the year)_ + - epoch _(seconds since Unix epoch)_ - **expression**: Time expression to operate on. Can be a constant, column, or function. @@ -1374,6 +1450,7 @@ extract(field FROM source) The following date fields are supported: - year + - quarter _(emits value in inclusive range [1, 4] based on which quartile of the year the date is in)_ - month - week _(week of the year)_ - day _(day of the month)_ @@ -1385,6 +1462,7 @@ extract(field FROM source) - nanosecond - dow _(day of the week)_ - doy _(day of the year)_ + - epoch _(seconds since Unix epoch)_ - **source**: Source time expression to operate on. Can be a constant, column, or function. @@ -1392,11 +1470,14 @@ extract(field FROM source) ### `to_timestamp` Converts a value to a timestamp (`YYYY-MM-DDT00:00:00Z`). -Supports strings, integer, and unsigned integer types as input. +Supports strings, integer, unsigned integer, and double types as input. Strings are parsed as RFC3339 (e.g. '2023-07-20T05:44:00') -Integers and unsigned integers are interpreted as seconds since the unix epoch (`1970-01-01T00:00:00Z`) +Integers, unsigned integers, and doubles are interpreted as seconds since the unix epoch (`1970-01-01T00:00:00Z`) return the corresponding timestamp. +Note: `to_timestamp` returns `Timestamp(Nanosecond)`. The supported range for integer input is between `-9223372037` and `9223372036`. +Supported range for string input is between `1677-09-21T00:12:44.0` and `2262-04-11T23:47:16.0`. Please use `to_timestamp_seconds` for the input outside of supported bounds. 
+ ``` to_timestamp(expression) ``` @@ -1487,6 +1568,7 @@ from_unixtime(expression) ## Array Functions - [array_append](#array_append) +- [array_sort](#array_sort) - [array_cat](#array_cat) - [array_concat](#array_concat) - [array_contains](#array_contains) @@ -1498,6 +1580,7 @@ from_unixtime(expression) - [array_length](#array_length) - [array_ndims](#array_ndims) - [array_prepend](#array_prepend) +- [array_pop_front](#array_pop_front) - [array_pop_back](#array_pop_back) - [array_position](#array_position) - [array_positions](#array_positions) @@ -1515,6 +1598,7 @@ from_unixtime(expression) - [cardinality](#cardinality) - [empty](#empty) - [list_append](#list_append) +- [list_sort](#list_sort) - [list_cat](#list_cat) - [list_concat](#list_concat) - [list_dims](#list_dims) @@ -1543,6 +1627,7 @@ from_unixtime(expression) - [string_to_array](#string_to_array) - [string_to_list](#string_to_list) - [trim_array](#trim_array) +- [range](#range) ### `array_append` @@ -1575,6 +1660,36 @@ array_append(array, element) - list_append - list_push_back +### `array_sort` + +Sort array. + +``` +array_sort(array, desc, nulls_first) +``` + +#### Arguments + +- **array**: Array expression. + Can be a constant, column, or function, and any combination of array operators. +- **desc**: Whether to sort in descending order(`ASC` or `DESC`). +- **nulls_first**: Whether to sort nulls first(`NULLS FIRST` or `NULLS LAST`). + +#### Example + +``` +❯ select array_sort([3, 1, 2]); ++-----------------------------+ +| array_sort(List([3,1,2])) | ++-----------------------------+ +| [1, 2, 3] | ++-----------------------------+ +``` + +#### Aliases + +- list_sort + ### `array_cat` _Alias of [array_concat](#array_concat)._ @@ -1850,6 +1965,30 @@ array_prepend(element, array) - list_prepend - list_push_front +### `array_pop_front` + +Returns the array without the first element. + +``` +array_pop_first(array) +``` + +#### Arguments + +- **array**: Array expression. + Can be a constant, column, or function, and any combination of array operators. + +#### Example + +``` +❯ select array_pop_first([1, 2, 3]); ++-------------------------------+ +| array_pop_first(List([1,2,3])) | ++-------------------------------+ +| [2, 3] | ++-------------------------------+ +``` + ### `array_pop_back` Returns the array without the last element. @@ -2211,6 +2350,82 @@ array_to_string(array, delimiter) - list_join - list_to_string +### `array_union` + +Returns an array of elements that are present in both arrays (all elements from both arrays) with out duplicates. + +``` +array_union(array1, array2) +``` + +#### Arguments + +- **array1**: Array expression. + Can be a constant, column, or function, and any combination of array operators. +- **array2**: Array expression. + Can be a constant, column, or function, and any combination of array operators. 
+ +#### Example + +``` +❯ select array_union([1, 2, 3, 4], [5, 6, 3, 4]); ++----------------------------------------------------+ +| array_union([1, 2, 3, 4], [5, 6, 3, 4]); | ++----------------------------------------------------+ +| [1, 2, 3, 4, 5, 6] | ++----------------------------------------------------+ +❯ select array_union([1, 2, 3, 4], [5, 6, 7, 8]); ++----------------------------------------------------+ +| array_union([1, 2, 3, 4], [5, 6, 7, 8]); | ++----------------------------------------------------+ +| [1, 2, 3, 4, 5, 6, 7, 8] | ++----------------------------------------------------+ +``` + +--- + +#### Aliases + +- list_union + +### `array_except` + +Returns an array of the elements that appear in the first array but not in the second. + +``` +array_except(array1, array2) +``` + +#### Arguments + +- **array1**: Array expression. + Can be a constant, column, or function, and any combination of array operators. +- **array2**: Array expression. + Can be a constant, column, or function, and any combination of array operators. + +#### Example + +``` +❯ select array_except([1, 2, 3, 4], [5, 6, 3, 4]); ++----------------------------------------------------+ +| array_except([1, 2, 3, 4], [5, 6, 3, 4]); | ++----------------------------------------------------+ +| [1, 2] | ++----------------------------------------------------+ +❯ select array_except([1, 2, 3, 4], [3, 4, 5, 6]); ++----------------------------------------------------+ +| array_except([1, 2, 3, 4], [3, 4, 5, 6]); | ++----------------------------------------------------+ +| [1, 2] | ++----------------------------------------------------+ +``` + +--- + +#### Aliases + +- list_except + ### `cardinality` Returns the total number of elements in the array. @@ -2263,6 +2478,10 @@ empty(array) _Alias of [array_append](#array_append)._ +### `list_sort` + +_Alias of [array_sort](#array_sort)._ + ### `list_cat` _Alias of [array_concat](#array_concat)._ @@ -2426,6 +2645,20 @@ trim_array(array, n) Can be a constant, column, or function, and any combination of array operators. - **n**: Element to trim the array. +### `range` + +Returns an Arrow array between start and stop with step. `SELECT range(2, 10, 3) -> [2, 5, 8]` + +The range start..end contains all values with start <= x < end. It is empty if start >= end. + +Step can not be 0 (then the range will be nonsense.). + +#### Arguments + +- **start**: start of the range +- **end**: end of the range (not included) +- **step**: increase by step (can not be 0) + ## Struct Functions - [struct](#struct) diff --git a/docs/source/user-guide/sql/write_options.md b/docs/source/user-guide/sql/write_options.md index 941484e84efd0..470591afafff2 100644 --- a/docs/source/user-guide/sql/write_options.md +++ b/docs/source/user-guide/sql/write_options.md @@ -42,12 +42,11 @@ WITH HEADER ROW DELIMITER ';' LOCATION '/test/location/my_csv_table/' OPTIONS( -CREATE_LOCAL_PATH 'true', NULL_VALUE 'NAN' ); ``` -When running `INSERT INTO my_table ...`, the options from the `CREATE TABLE` will be respected (gzip compression, special delimiter, and header row included). Note that compression, header, and delimiter settings can also be specified within the `OPTIONS` tuple list. Dedicated syntax within the SQL statement always takes precedence over arbitrary option tuples, so if both are specified the `OPTIONS` setting will be ignored. CREATE_LOCAL_PATH is a special option that indicates if DataFusion should create local file paths when writing new files if they do not already exist. 
This option is useful if you wish to create an external table from scratch, using only DataFusion SQL statements. Finally, NULL_VALUE is a CSV format specific option that determines how null values should be encoded within the CSV file. +When running `INSERT INTO my_table ...`, the options from the `CREATE TABLE` will be respected (gzip compression, special delimiter, and header row included). Note that compression, header, and delimiter settings can also be specified within the `OPTIONS` tuple list. Dedicated syntax within the SQL statement always takes precedence over arbitrary option tuples, so if both are specified the `OPTIONS` setting will be ignored. NULL_VALUE is a CSV format specific option that determines how null values should be encoded within the CSV file. Finally, options can be passed when running a `COPY` command. @@ -70,19 +69,9 @@ In this example, we write the entirety of `source_table` out to a folder of parq The following special options are specific to the `COPY` command. | Option | Description | Default Value | -| ------------------ | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ------------- | +| ------------------ | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ------------- | --- | | SINGLE_FILE_OUTPUT | If true, COPY query will write output to a single file. Otherwise, multiple files will be written to a directory in parallel. | true | -| FORMAT | Specifies the file format COPY query will write out. If single_file_output is false or the format cannot be inferred from the file extension, then FORMAT must be specified. | N/A | - -### CREATE EXTERNAL TABLE Specific Options - -The following special options are specific to creating an external table. - -| Option | Description | Default Value | -| ----------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | ---------------------------------------------------------------------------- | -| SINGLE_FILE | If true, indicates that this external table is backed by a single file. INSERT INTO queries will append to this file. | false | -| CREATE_LOCAL_PATH | If true, the folder or file backing this table will be created on the local file system if it does not already exist when running INSERT INTO queries. | false | -| INSERT_MODE | Determines if INSERT INTO queries should append to existing files or append new files to an existing directory. Valid values are append_to_file, append_new_files, and error. Note that "error" will block inserting data into this table. | CSV and JSON default to append_to_file. Parquet defaults to append_new_files | +| FORMAT | Specifies the file format COPY query will write out. If single_file_output is false or the format cannot be inferred from the file extension, then FORMAT must be specified. | N/A | | ### JSON Format Specific Options