diff --git a/.github/dependabot.yml b/.github/dependabot.yml index 42134c3edfd2..7c2b7e3a5458 100644 --- a/.github/dependabot.yml +++ b/.github/dependabot.yml @@ -21,13 +21,29 @@ updates: directory: "/" schedule: interval: daily - open-pull-requests-limit: 10 target-branch: main labels: [auto-dependencies] ignore: - # arrow is bumped manually + # major version bumps of arrow* and parquet are handled manually - dependency-name: "arrow*" update-types: ["version-update:semver-major"] + - dependency-name: "parquet" + update-types: ["version-update:semver-major"] + groups: + # minor and patch bumps of arrow* and parquet are grouped + arrow-parquet: + applies-to: version-updates + patterns: + - "arrow*" + - "parquet" + update-types: + - "minor" + - "patch" + proto: + applies-to: version-updates + patterns: + - "prost*" + - "pbjson*" - package-ecosystem: "github-actions" directory: "/" schedule: diff --git a/.github/workflows/extended.yml b/.github/workflows/extended.yml index 7fa89ea773f8..19910957a85b 100644 --- a/.github/workflows/extended.yml +++ b/.github/workflows/extended.yml @@ -48,28 +48,34 @@ jobs: with: rust-version: stable - name: Prepare cargo build - run: cargo check --profile ci --all-targets + run: | + cargo check --profile ci --all-targets + cargo clean - # Run extended tests (with feature 'extended_tests') - linux-test-extended: - name: cargo test 'extended_tests' (amd64) - needs: linux-build-lib - runs-on: ubuntu-latest - container: - image: amd64/rust - steps: - - uses: actions/checkout@v4 - with: - submodules: true - fetch-depth: 1 - - name: Setup Rust toolchain - uses: ./.github/actions/setup-builder - with: - rust-version: stable - - name: Run tests (excluding doctests) - run: cargo test --profile ci --exclude datafusion-examples --exclude datafusion-benchmarks --workspace --lib --tests --bins --features avro,json,backtrace,extended_tests - - name: Verify Working Directory Clean - run: git diff --exit-code +# # Run extended tests (with feature 'extended_tests') +# # Disabling as it is running out of disk space +# # see https://github.com/apache/datafusion/issues/14576 +# linux-test-extended: +# name: cargo test 'extended_tests' (amd64) +# needs: linux-build-lib +# runs-on: ubuntu-latest +# container: +# image: amd64/rust +# steps: +# - uses: actions/checkout@v4 +# with: +# submodules: true +# fetch-depth: 1 +# - name: Setup Rust toolchain +# uses: ./.github/actions/setup-builder +# with: +# rust-version: stable +# - name: Run tests (excluding doctests) +# run: cargo test --profile ci --exclude datafusion-examples --exclude datafusion-benchmarks --workspace --lib --tests --bins --features avro,json,backtrace,extended_tests +# - name: Verify Working Directory Clean +# run: git diff --exit-code +# - name: Cleanup +# run: cargo clean # Check answers are correct when hash values collide hash-collisions: @@ -90,6 +96,7 @@ jobs: run: | cd datafusion cargo test --profile ci --exclude datafusion-examples --exclude datafusion-benchmarks --exclude datafusion-sqllogictest --workspace --lib --tests --features=force_hash_collisions,avro,extended_tests + cargo clean sqllogictest-sqlite: name: "Run sqllogictests with the sqlite test suite" @@ -106,4 +113,8 @@ jobs: with: rust-version: stable - name: Run sqllogictest - run: cargo test --profile release-nonlto --test sqllogictests -- --include-sqlite + run: | + cargo test --profile release-nonlto --test sqllogictests -- --include-sqlite + cargo clean + + diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index 
c023faa9b168..a743d0e8fd07 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -41,7 +41,7 @@ on: jobs: # Check license header license-header-check: - runs-on: ubuntu-20.04 + runs-on: ubuntu-latest name: Check License Header steps: - uses: actions/checkout@v4 diff --git a/Cargo.lock b/Cargo.lock index cb77384cb371..b7794d731b75 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -223,8 +223,8 @@ dependencies = [ "serde_bytes", "serde_json", "snap", - "strum", - "strum_macros", + "strum 0.26.3", + "strum_macros 0.26.4", "thiserror 1.0.69", "typed-builder", "uuid", @@ -246,9 +246,9 @@ checksum = "7c02d123df017efcdfbd739ef81735b36c5ba83ec3c59c80a9d7ecc718f92e50" [[package]] name = "arrow" -version = "54.1.0" +version = "54.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6422e12ac345a0678d7a17e316238e3a40547ae7f92052b77bd86d5e0239f3fc" +checksum = "755b6da235ac356a869393c23668c663720b8749dd6f15e52b6c214b4b964cc7" dependencies = [ "arrow-arith", "arrow-array", @@ -270,9 +270,9 @@ dependencies = [ [[package]] name = "arrow-arith" -version = "54.1.0" +version = "54.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "23cf34bb1f48c41d3475927bcc7be498665b8e80b379b88f62a840337f8b8248" +checksum = "64656a1e0b13ca766f8440752e9a93e11014eec7b67909986f83ed0ab1fe37b8" dependencies = [ "arrow-array", "arrow-buffer", @@ -284,9 +284,9 @@ dependencies = [ [[package]] name = "arrow-array" -version = "54.1.0" +version = "54.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fb4a06d507f54b70a277be22a127c8ffe0cec6cd98c0ad8a48e77779bbda8223" +checksum = "57a4a6d2896083cfbdf84a71a863b22460d0708f8206a8373c52e326cc72ea1a" dependencies = [ "ahash 0.8.11", "arrow-buffer", @@ -301,9 +301,9 @@ dependencies = [ [[package]] name = "arrow-buffer" -version = "54.1.0" +version = "54.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d69d326d5ad1cb82dcefa9ede3fee8fdca98f9982756b16f9cb142f4aa6edc89" +checksum = "cef870583ce5e4f3b123c181706f2002fb134960f9a911900f64ba4830c7a43a" dependencies = [ "bytes", "half", @@ -312,9 +312,9 @@ dependencies = [ [[package]] name = "arrow-cast" -version = "54.1.0" +version = "54.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "626e65bd42636a84a238bed49d09c8777e3d825bf81f5087a70111c2831d9870" +checksum = "1ac7eba5a987f8b4a7d9629206ba48e19a1991762795bbe5d08497b7736017ee" dependencies = [ "arrow-array", "arrow-buffer", @@ -333,9 +333,9 @@ dependencies = [ [[package]] name = "arrow-csv" -version = "54.1.0" +version = "54.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "71c8f959f7a1389b1dbd883cdcd37c3ed12475329c111912f7f69dad8195d8c6" +checksum = "90f12542b8164398fc9ec595ff783c4cf6044daa89622c5a7201be920e4c0d4c" dependencies = [ "arrow-array", "arrow-cast", @@ -349,9 +349,9 @@ dependencies = [ [[package]] name = "arrow-data" -version = "54.1.0" +version = "54.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1858e7c7d01c44cf71c21a85534fd1a54501e8d60d1195d0d6fbcc00f4b10754" +checksum = "b095e8a4f3c309544935d53e04c3bfe4eea4e71c3de6fe0416d1f08bb4441a83" dependencies = [ "arrow-buffer", "arrow-schema", @@ -361,9 +361,9 @@ dependencies = [ [[package]] name = "arrow-flight" -version = "54.1.0" +version = "54.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a9b3aaba47ed4b6146563c8b79ad0f7aa283f794cde0c057c656291b81196746" +checksum 
= "cf7806ee3d229ee866013e83446e937ab3c8a9e6a664b259d41dd960b309c5d0" dependencies = [ "arrow-arith", "arrow-array", @@ -388,9 +388,9 @@ dependencies = [ [[package]] name = "arrow-ipc" -version = "54.1.0" +version = "54.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a6bb3f727f049884c7603f0364bc9315363f356b59e9f605ea76541847e06a1e" +checksum = "65c63da4afedde2b25ef69825cd4663ca76f78f79ffe2d057695742099130ff6" dependencies = [ "arrow-array", "arrow-buffer", @@ -402,9 +402,9 @@ dependencies = [ [[package]] name = "arrow-json" -version = "54.1.0" +version = "54.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "35de94f165ed8830aede72c35f238763794f0d49c69d30c44d49c9834267ff8c" +checksum = "9551d9400532f23a370cabbea1dc5a53c49230397d41f96c4c8eedf306199305" dependencies = [ "arrow-array", "arrow-buffer", @@ -422,9 +422,9 @@ dependencies = [ [[package]] name = "arrow-ord" -version = "54.1.0" +version = "54.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8aa06e5f267dc53efbacb933485c79b6fc1685d3ffbe870a16ce4e696fb429da" +checksum = "6c07223476f8219d1ace8cd8d85fa18c4ebd8d945013f25ef5c72e85085ca4ee" dependencies = [ "arrow-array", "arrow-buffer", @@ -435,9 +435,9 @@ dependencies = [ [[package]] name = "arrow-row" -version = "54.1.0" +version = "54.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "66f1144bb456a2f9d82677bd3abcea019217e572fc8f07de5a7bac4b2c56eb2c" +checksum = "91b194b38bfd89feabc23e798238989c6648b2506ad639be42ec8eb1658d82c4" dependencies = [ "arrow-array", "arrow-buffer", @@ -448,19 +448,18 @@ dependencies = [ [[package]] name = "arrow-schema" -version = "54.1.0" +version = "54.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "105f01ec0090259e9a33a9263ec18ff223ab91a0ea9fbc18042f7e38005142f6" +checksum = "0f40f6be8f78af1ab610db7d9b236e21d587b7168e368a36275d2e5670096735" dependencies = [ "bitflags 2.8.0", - "serde", ] [[package]] name = "arrow-select" -version = "54.1.0" +version = "54.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f690752fdbd2dee278b5f1636fefad8f2f7134c85e20fd59c4199e15a39a6807" +checksum = "ac265273864a820c4a179fc67182ccc41ea9151b97024e1be956f0f2369c2539" dependencies = [ "ahash 0.8.11", "arrow-array", @@ -472,9 +471,9 @@ dependencies = [ [[package]] name = "arrow-string" -version = "54.1.0" +version = "54.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d0fff9cd745a7039b66c47ecaf5954460f9fa12eed628f65170117ea93e64ee0" +checksum = "d44c8eed43be4ead49128370f7131f054839d3d6003e52aebf64322470b8fbd0" dependencies = [ "arrow-array", "arrow-buffer", @@ -1036,15 +1035,16 @@ dependencies = [ [[package]] name = "blake3" -version = "1.5.5" +version = "1.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b8ee0c1824c4dea5b5f81736aff91bae041d2c07ee1192bec91054e10e3e601e" +checksum = "1230237285e3e10cde447185e8975408ae24deaa67205ce684805c25bc0c7937" dependencies = [ "arrayref", "arrayvec", "cc", "cfg-if", "constant_time_eq", + "memmap2", ] [[package]] @@ -1223,19 +1223,18 @@ dependencies = [ [[package]] name = "bzip2" -version = "0.5.0" +version = "0.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bafdbf26611df8c14810e268ddceda071c297570a5fb360ceddf617fe417ef58" +checksum = "75b89e7c29231c673a61a46e722602bcd138298f6b9e81e71119693534585f5c" dependencies = [ "bzip2-sys", - "libc", ] 
[[package]] name = "bzip2-sys" -version = "0.1.11+1.0.8" +version = "0.1.12+1.0.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "736a955f3fa7875102d57c82b8cac37ec45224a07fd32d58f9f7a186b6cd4cdc" +checksum = "72ebc2f1a417f01e1da30ef264ee86ae31d2dcd2d603ea283d3c244a883ca2a9" dependencies = [ "cc", "libc", @@ -1347,9 +1346,9 @@ dependencies = [ [[package]] name = "clap" -version = "4.5.28" +version = "4.5.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3e77c3243bd94243c03672cb5154667347c457ca271254724f9f393aee1c05ff" +checksum = "92b7b18d71fad5313a1e320fa9897994228ce274b60faa4d694fe0ea89cd9e6d" dependencies = [ "clap_builder", "clap_derive", @@ -1357,9 +1356,9 @@ dependencies = [ [[package]] name = "clap_builder" -version = "4.5.27" +version = "4.5.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1b26884eb4b57140e4d2d93652abfa49498b938b3c9179f9fc487b0acc3edad7" +checksum = "a35db2071778a7344791a4fb4f95308b5673d219dee3ae348b86642574ecc90c" dependencies = [ "anstream", "anstyle", @@ -1396,9 +1395,9 @@ dependencies = [ [[package]] name = "cmake" -version = "0.1.53" +version = "0.1.54" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e24a03c8b52922d68a1589ad61032f2c1aa5a8158d2aa0d93c6e9534944bbad6" +checksum = "e7caa3f9de89ddbe2c607f4101924c5abec803763ae9534e4f4d7d8f84aa81f0" dependencies = [ "cc", ] @@ -1551,7 +1550,7 @@ dependencies = [ "anes", "cast", "ciborium", - "clap 4.5.28", + "clap 4.5.30", "criterion-plot", "futures", "is-terminal", @@ -1724,25 +1723,29 @@ dependencies = [ "arrow", "arrow-ipc", "arrow-schema", - "async-compression", "async-trait", "bytes", - "bzip2 0.5.0", + "bzip2 0.5.1", "chrono", "criterion", "ctor", + "dashmap", "datafusion-catalog", "datafusion-catalog-listing", "datafusion-common", "datafusion-common-runtime", + "datafusion-datasource", + "datafusion-doc", "datafusion-execution", "datafusion-expr", + "datafusion-expr-common", "datafusion-functions", "datafusion-functions-aggregate", "datafusion-functions-nested", "datafusion-functions-table", "datafusion-functions-window", "datafusion-functions-window-common", + "datafusion-macros", "datafusion-optimizer", "datafusion-physical-expr", "datafusion-physical-expr-common", @@ -1753,7 +1756,6 @@ dependencies = [ "env_logger", "flate2", "futures", - "glob", "itertools 0.14.0", "log", "nix", @@ -1766,13 +1768,13 @@ dependencies = [ "rand_distr", "regex", "rstest", + "serde", "serde_json", "sqlparser", "sysinfo", "tempfile", "test-utils", "tokio", - "tokio-util", "url", "uuid", "xz2", @@ -1816,7 +1818,6 @@ dependencies = [ "itertools 0.14.0", "log", "parking_lot", - "sqlparser", "tokio", ] @@ -1825,25 +1826,20 @@ name = "datafusion-catalog-listing" version = "45.0.0" dependencies = [ "arrow", - "arrow-schema", - "async-compression", "async-trait", - "chrono", "datafusion-catalog", "datafusion-common", + "datafusion-datasource", "datafusion-execution", "datafusion-expr", "datafusion-physical-expr", "datafusion-physical-expr-common", "datafusion-physical-plan", "futures", - "glob", - "itertools 0.14.0", "log", "object_store", "tempfile", "tokio", - "url", ] [[package]] @@ -1855,7 +1851,7 @@ dependencies = [ "async-trait", "aws-config", "aws-credential-types", - "clap 4.5.28", + "clap 4.5.30", "ctor", "datafusion", "dirs", @@ -1881,7 +1877,6 @@ dependencies = [ "apache-avro", "arrow", "arrow-ipc", - "arrow-schema", "base64 0.22.1", "chrono", "half", @@ -1908,6 +1903,37 @@ dependencies = [ "tokio", ] 
+[[package]] +name = "datafusion-datasource" +version = "45.0.0" +dependencies = [ + "arrow", + "async-compression", + "async-trait", + "bytes", + "bzip2 0.5.1", + "chrono", + "datafusion-catalog", + "datafusion-common", + "datafusion-common-runtime", + "datafusion-execution", + "datafusion-expr", + "datafusion-physical-plan", + "flate2", + "futures", + "glob", + "itertools 0.14.0", + "log", + "object_store", + "rand 0.8.5", + "tempfile", + "tokio", + "tokio-util", + "url", + "xz2", + "zstd", +] + [[package]] name = "datafusion-doc" version = "45.0.0" @@ -1984,6 +2010,7 @@ version = "45.0.0" dependencies = [ "arrow", "datafusion-common", + "indexmap 2.7.1", "itertools 0.14.0", "paste", ] @@ -1994,7 +2021,6 @@ version = "45.0.0" dependencies = [ "abi_stable", "arrow", - "arrow-schema", "async-ffi", "async-trait", "datafusion", @@ -2024,7 +2050,6 @@ dependencies = [ "datafusion-expr", "datafusion-expr-common", "datafusion-macros", - "hashbrown 0.14.5", "hex", "itertools 0.14.0", "log", @@ -2043,7 +2068,6 @@ version = "45.0.0" dependencies = [ "ahash 0.8.11", "arrow", - "arrow-schema", "criterion", "datafusion-common", "datafusion-doc", @@ -2078,7 +2102,6 @@ version = "45.0.0" dependencies = [ "arrow", "arrow-ord", - "arrow-schema", "criterion", "datafusion-common", "datafusion-doc", @@ -2171,7 +2194,6 @@ version = "45.0.0" dependencies = [ "ahash 0.8.11", "arrow", - "arrow-schema", "criterion", "datafusion-common", "datafusion-expr", @@ -2185,7 +2207,7 @@ dependencies = [ "itertools 0.14.0", "log", "paste", - "petgraph 0.7.1", + "petgraph", "rand 0.8.5", "rstest", ] @@ -2207,7 +2229,6 @@ name = "datafusion-physical-optimizer" version = "45.0.0" dependencies = [ "arrow", - "arrow-schema", "datafusion-common", "datafusion-execution", "datafusion-expr", @@ -2216,13 +2237,9 @@ dependencies = [ "datafusion-physical-expr", "datafusion-physical-expr-common", "datafusion-physical-plan", - "futures", "itertools 0.14.0", "log", "recursive", - "rstest", - "tokio", - "url", ] [[package]] @@ -2278,7 +2295,7 @@ dependencies = [ "prost", "serde", "serde_json", - "strum", + "strum 0.27.1", "tokio", ] @@ -2300,7 +2317,6 @@ name = "datafusion-sql" version = "45.0.0" dependencies = [ "arrow", - "arrow-schema", "bigdecimal", "ctor", "datafusion-common", @@ -2328,11 +2344,8 @@ dependencies = [ "bigdecimal", "bytes", "chrono", - "clap 4.5.28", + "clap 4.5.30", "datafusion", - "datafusion-catalog", - "datafusion-common", - "datafusion-common-runtime", "env_logger", "futures", "half", @@ -2379,23 +2392,14 @@ dependencies = [ "chrono", "console_error_panic_hook", "datafusion", - "datafusion-catalog", "datafusion-common", - "datafusion-common-runtime", "datafusion-execution", "datafusion-expr", - "datafusion-expr-common", - "datafusion-functions", - "datafusion-functions-aggregate", - "datafusion-functions-aggregate-common", - "datafusion-functions-table", "datafusion-optimizer", "datafusion-physical-expr", - "datafusion-physical-expr-common", "datafusion-physical-plan", "datafusion-sql", "getrandom 0.2.15", - "parquet", "tokio", "wasm-bindgen", "wasm-bindgen-futures", @@ -2625,7 +2629,6 @@ version = "0.1.0" dependencies = [ "abi_stable", "arrow", - "arrow-schema", "datafusion", "datafusion-ffi", "ffi_module_interface", @@ -2662,12 +2665,6 @@ dependencies = [ "windows-sys 0.59.0", ] -[[package]] -name = "fixedbitset" -version = "0.4.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0ce7134b9999ecaf8bcd65542e436736ef32ddca1b3e06094cb6ec5755203b80" - [[package]] name = "fixedbitset" 
version = "0.5.7" @@ -3680,7 +3677,7 @@ checksum = "5297962ef19edda4ce33aaa484386e0a5b3d7f2f4e037cbeee00503ef6b29d33" dependencies = [ "anstream", "anstyle", - "clap 4.5.28", + "clap 4.5.30", "escape8259", ] @@ -3754,6 +3751,15 @@ version = "2.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "78ca9ab1a0babb1e7d5695e3530886289c18cf2f87ec19a575a0abdce112e3a3" +[[package]] +name = "memmap2" +version = "0.9.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fd3f7eed9d3848f8b98834af67102b720745c4ec028fcd0aa0239277e7de374f" +dependencies = [ + "libc", +] + [[package]] name = "memoffset" version = "0.9.1" @@ -4047,9 +4053,9 @@ dependencies = [ [[package]] name = "parquet" -version = "54.1.0" +version = "54.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8a01a0efa30bbd601ae85b375c728efdb211ade54390281628a7b16708beb235" +checksum = "761c44d824fe83106e0600d2510c07bf4159a4985bf0569b513ea4288dc1b4fb" dependencies = [ "ahash 0.8.11", "arrow-array", @@ -4165,23 +4171,13 @@ version = "2.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e" -[[package]] -name = "petgraph" -version = "0.6.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b4c5cc86750666a3ed20bdaf5ca2a0344f9c67674cae0515bec2da16fbaa47db" -dependencies = [ - "fixedbitset 0.4.2", - "indexmap 2.7.1", -] - [[package]] name = "petgraph" version = "0.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3672b37090dbd86368a4145bc067582552b29c27377cad4e0a306c97f9bd7772" dependencies = [ - "fixedbitset 0.5.7", + "fixedbitset", "indexmap 2.7.1", ] @@ -4437,9 +4433,9 @@ dependencies = [ [[package]] name = "prost" -version = "0.13.4" +version = "0.13.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2c0fef6c4230e4ccf618a35c59d7ede15dea37de8427500f50aff708806e42ec" +checksum = "2796faa41db3ec313a31f7624d9286acf277b52de526150b7e69f3debf891ee5" dependencies = [ "bytes", "prost-derive", @@ -4447,16 +4443,16 @@ dependencies = [ [[package]] name = "prost-build" -version = "0.13.4" +version = "0.13.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d0f3e5beed80eb580c68e2c600937ac2c4eedabdfd5ef1e5b7ea4f3fba84497b" +checksum = "be769465445e8c1474e9c5dac2018218498557af32d9ed057325ec9a41ae81bf" dependencies = [ "heck 0.5.0", - "itertools 0.13.0", + "itertools 0.14.0", "log", "multimap", "once_cell", - "petgraph 0.6.5", + "petgraph", "prettyplease", "prost", "prost-types", @@ -4467,12 +4463,12 @@ dependencies = [ [[package]] name = "prost-derive" -version = "0.13.4" +version = "0.13.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "157c5a9d7ea5c2ed2d9fb8f495b64759f7816c7eaea54ba3978f0d63000162e3" +checksum = "8a56d757972c98b346a9b766e3f02746cde6dd1cd1d1d563472929fdd74bec4d" dependencies = [ "anyhow", - "itertools 0.13.0", + "itertools 0.14.0", "proc-macro2", "quote", "syn 2.0.98", @@ -4480,9 +4476,9 @@ dependencies = [ [[package]] name = "prost-types" -version = "0.13.4" +version = "0.13.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cc2f1e56baa61e93533aebc21af4d2134b70f66275e0fcdf3cbe43d77ff7e8fc" +checksum = "52c2c1bf36ddb1a1c396b3601a3cec27c2462e45f07c386894ec3ccf5332bd16" dependencies = [ "prost", ] @@ -5367,9 +5363,9 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.138" 
+version = "1.0.139" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d434192e7da787e94a6ea7e9670b26a036d0ca41e0b7efb2676dd32bae872949" +checksum = "44f86c3acccc9c65b153fe1b85a3be07fe5515274ec9f0653b4a0875731c72a6" dependencies = [ "itoa", "memchr", @@ -5577,9 +5573,9 @@ checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67" [[package]] name = "sqllogictest" -version = "0.26.4" +version = "0.27.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3bc65b5525b4674a844eb6e39a5d4ef2385a3a2b96c13ef82bbe73220f24bcad" +checksum = "07a06aea5e52b0a63b9d8328b46ea2740cdab4cac13def8ef4f2e5288610f9ed" dependencies = [ "async-trait", "educe", @@ -5630,9 +5626,9 @@ checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3" [[package]] name = "stacker" -version = "0.1.17" +version = "0.1.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "799c883d55abdb5e98af1a7b3f23b9b6de8ecada0ecac058672d7635eb48ca7b" +checksum = "1d08feb8f695b465baed819b03c128dc23f57a694510ab1f06c77f763975685e" dependencies = [ "cc", "cfg-if", @@ -5716,8 +5712,14 @@ name = "strum" version = "0.26.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8fec0f0aef304996cf250b31b5a10dee7980c85da9d759361292b8bca5a18f06" + +[[package]] +name = "strum" +version = "0.27.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f64def088c51c9510a8579e3c5d67c65349dcf755e5479ad3d010aa6454e2c32" dependencies = [ - "strum_macros", + "strum_macros 0.27.1", ] [[package]] @@ -5733,6 +5735,19 @@ dependencies = [ "syn 2.0.98", ] +[[package]] +name = "strum_macros" +version = "0.27.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c77a8c5abcaf0f9ce05d62342b7d298c346515365c36b673df4ebe3ced01fde8" +dependencies = [ + "heck 0.5.0", + "proc-macro2", + "quote", + "rustversion", + "syn 2.0.98", +] + [[package]] name = "subst" version = "0.3.7" @@ -5745,9 +5760,9 @@ dependencies = [ [[package]] name = "substrait" -version = "0.53.0" +version = "0.53.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ef201e3234acdb66865840012c0b9c3d04269b74416fb6285cd480b01718c2f9" +checksum = "6fac3d70185423235f37b889764e184b81a5af4bb7c95833396ee9bd92577e1b" dependencies = [ "heck 0.5.0", "pbjson", @@ -5845,9 +5860,9 @@ checksum = "61c41af27dd6d1e27b1b16b489db798443478cef1f06a660c96db617ba5de3b1" [[package]] name = "tempfile" -version = "3.16.0" +version = "3.17.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "38c246215d7d24f48ae091a2902398798e05d978b24315d6efbc00ede9a8bb91" +checksum = "22e5a0acb1f3f55f65cc4a866c361b2fb2a0ff6366785ae6fbb5f85df07ba230" dependencies = [ "cfg-if", "fastrand", @@ -6169,9 +6184,9 @@ checksum = "0dd7358ecb8fc2f8d014bf86f6f638ce72ba252a2c3a2572f2a795f1d23efb41" [[package]] name = "toml_edit" -version = "0.22.23" +version = "0.22.24" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "02a8b472d1a3d7c18e2d61a489aee3453fd9031c33e4f55bd533f4a7adca1bee" +checksum = "17b4795ff5edd201c7cd6dca065ae59972ce77d1b80fa0a84d94950ece7d1474" dependencies = [ "indexmap 2.7.1", "toml_datetime", @@ -6497,9 +6512,9 @@ checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" [[package]] name = "uuid" -version = "1.13.1" +version = "1.13.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"ced87ca4be083373936a67f8de945faa23b6b42384bd5b64434850802c6dccd0" +checksum = "8c1f41ffb7cf259f1ecc2876861a17e7142e63ead296f671f81f6ae85903e0d6" dependencies = [ "getrandom 0.3.1", "js-sys", @@ -6980,9 +6995,9 @@ checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" [[package]] name = "winnow" -version = "0.7.1" +version = "0.7.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "86e376c75f4f43f44db463cf729e0d3acbf954d13e22c51e26e4c264b4ab545f" +checksum = "59690dea168f2198d1a3b0cac23b8063efcd11012f10ae4698f284808c8ef603" dependencies = [ "memchr", ] diff --git a/Cargo.toml b/Cargo.toml index 56bc218f2706..ccf3f02a2fde 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -78,15 +78,15 @@ version = "45.0.0" ahash = { version = "0.8", default-features = false, features = [ "runtime-rng", ] } -arrow = { version = "54.1.0", features = [ +arrow = { version = "54.2.0", features = [ "prettyprint", "chrono-tz", ] } arrow-buffer = { version = "54.1.0", default-features = false } -arrow-flight = { version = "54.1.0", features = [ +arrow-flight = { version = "54.2.0", features = [ "flight-sql-experimental", ] } -arrow-ipc = { version = "54.1.0", default-features = false, features = [ +arrow-ipc = { version = "54.2.0", default-features = false, features = [ "lz4", ] } arrow-ord = { version = "54.1.0", default-features = false } @@ -102,6 +102,7 @@ datafusion-catalog = { path = "datafusion/catalog", version = "45.0.0" } datafusion-catalog-listing = { path = "datafusion/catalog-listing", version = "45.0.0" } datafusion-common = { path = "datafusion/common", version = "45.0.0", default-features = false } datafusion-common-runtime = { path = "datafusion/common-runtime", version = "45.0.0" } +datafusion-datasource = { path = "datafusion/datasource", version = "45.0.0", default-features = false } datafusion-doc = { path = "datafusion/doc", version = "45.0.0" } datafusion-execution = { path = "datafusion/execution", version = "45.0.0" } datafusion-expr = { path = "datafusion/expr", version = "45.0.0" } @@ -133,7 +134,7 @@ itertools = "0.14" log = "^0.4" object_store = { version = "0.11.0", default-features = false } parking_lot = "0.12" -parquet = { version = "54.1.0", default-features = false, features = [ +parquet = { version = "54.2.0", default-features = false, features = [ "arrow", "async", "object_store", diff --git a/README.md b/README.md index 2c2febab09cc..158033d40599 100644 --- a/README.md +++ b/README.md @@ -25,6 +25,7 @@ ![Commit Activity][commit-activity-badge] [![Open Issues][open-issues-badge]][open-issues-url] [![Discord chat][discord-badge]][discord-url] +[![Linkedin][linkedin-badge]][linkedin-url] [crates-badge]: https://img.shields.io/crates/v/datafusion.svg [crates-url]: https://crates.io/crates/datafusion @@ -32,11 +33,13 @@ [license-url]: https://github.com/apache/datafusion/blob/main/LICENSE.txt [actions-badge]: https://github.com/apache/datafusion/actions/workflows/rust.yml/badge.svg [actions-url]: https://github.com/apache/datafusion/actions?query=branch%3Amain -[discord-badge]: https://img.shields.io/discord/885562378132000778.svg?logo=discord&style=flat-square +[discord-badge]: https://img.shields.io/badge/Chat-Discord-purple [discord-url]: https://discord.com/invite/Qw5gKqHxUM [commit-activity-badge]: https://img.shields.io/github/commit-activity/m/apache/datafusion [open-issues-badge]: https://img.shields.io/github/issues-raw/apache/datafusion [open-issues-url]: https://github.com/apache/datafusion/issues +[linkedin-badge]: 
https://img.shields.io/badge/Follow-Linkedin-blue +[linkedin-url]: https://www.linkedin.com/company/apache-datafusion/ [Website](https://datafusion.apache.org/) | [API Docs](https://docs.rs/datafusion/latest/datafusion/) | diff --git a/benchmarks/lineprotocol.py b/benchmarks/lineprotocol.py new file mode 100644 index 000000000000..75e09b662e3e --- /dev/null +++ b/benchmarks/lineprotocol.py @@ -0,0 +1,188 @@ +#!/usr/bin/env python +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + + +""" +Converts a given json to LineProtocol format that can be +visualised by grafana/other systems that support LineProtocol. + +Usage example: +$ python3 lineprotocol.py sort.json +benchmark,name=sort,version=28.0.0,datafusion_version=28.0.0,num_cpus=8 query="sort utf8",iteration=0,row_count=10838832,elapsed_ms=85626006 1691105678000000000 +benchmark,name=sort,version=28.0.0,datafusion_version=28.0.0,num_cpus=8 query="sort utf8",iteration=1,row_count=10838832,elapsed_ms=68694468 1691105678000000000 +benchmark,name=sort,version=28.0.0,datafusion_version=28.0.0,num_cpus=8 query="sort utf8",iteration=2,row_count=10838832,elapsed_ms=63392883 1691105678000000000 +benchmark,name=sort,version=28.0.0,datafusion_version=28.0.0,num_cpus=8 query="sort utf8",iteration=3,row_count=10838832,elapsed_ms=66388367 1691105678000000000 +""" + +# sort.json +""" +{ + "queries": [ + { + "iterations": [ + { + "elapsed": 85626.006132, + "row_count": 10838832 + }, + { + "elapsed": 68694.467851, + "row_count": 10838832 + }, + { + "elapsed": 63392.883406, + "row_count": 10838832 + }, + { + "elapsed": 66388.367387, + "row_count": 10838832 + }, + ], + "query": "sort utf8", + "start_time": 1691105678 + }, + ], + "context": { + "arguments": [ + "sort", + "--path", + "benchmarks/data", + "--scale-factor", + "1.0", + "--iterations", + "4", + "-o", + "sort.json" + ], + "benchmark_version": "28.0.0", + "datafusion_version": "28.0.0", + "num_cpus": 8, + "start_time": 1691105678 + } +} +""" + +from __future__ import annotations + +import json +from dataclasses import dataclass +from typing import Dict, List, Any +from pathlib import Path +from argparse import ArgumentParser +import sys +print = sys.stdout.write + + +@dataclass +class QueryResult: + elapsed: float + row_count: int + + @classmethod + def load_from(cls, data: Dict[str, Any]) -> QueryResult: + return cls(elapsed=data["elapsed"], row_count=data["row_count"]) + + +@dataclass +class QueryRun: + query: int + iterations: List[QueryResult] + start_time: int + + @classmethod + def load_from(cls, data: Dict[str, Any]) -> QueryRun: + return cls( + query=data["query"], + iterations=[QueryResult(**iteration) for iteration in data["iterations"]], + start_time=data["start_time"], + ) + + @property + def execution_time(self) -> float: + assert len(self.iterations) >= 
1 + + # Use minimum execution time to account for variations / other + # things the system was doing + return min(iteration.elapsed for iteration in self.iterations) + + +@dataclass +class Context: + benchmark_version: str + datafusion_version: str + num_cpus: int + start_time: int + arguments: List[str] + name: str + + @classmethod + def load_from(cls, data: Dict[str, Any]) -> Context: + return cls( + benchmark_version=data["benchmark_version"], + datafusion_version=data["datafusion_version"], + num_cpus=data["num_cpus"], + start_time=data["start_time"], + arguments=data["arguments"], + name=data["arguments"][0] + ) + + +@dataclass +class BenchmarkRun: + context: Context + queries: List[QueryRun] + + @classmethod + def load_from(cls, data: Dict[str, Any]) -> BenchmarkRun: + return cls( + context=Context.load_from(data["context"]), + queries=[QueryRun.load_from(result) for result in data["queries"]], + ) + + @classmethod + def load_from_file(cls, path: Path) -> BenchmarkRun: + with open(path, "r") as f: + return cls.load_from(json.load(f)) + + +def lineformat( + baseline: Path, +) -> None: + baseline = BenchmarkRun.load_from_file(baseline) + context = baseline.context + benchmark_str = f"benchmark,name={context.name},version={context.benchmark_version},datafusion_version={context.datafusion_version},num_cpus={context.num_cpus}" + for query in baseline.queries: + query_str = f"query=\"{query.query}\"" + timestamp = f"{query.start_time*10**9}" + for iter_num, result in enumerate(query.iterations): + print(f"{benchmark_str} {query_str},iteration={iter_num},row_count={result.row_count},elapsed_ms={result.elapsed*1000:.0f} {timestamp}\n") + +def main() -> None: + parser = ArgumentParser() + parser.add_argument( + "path", + type=Path, + help="Path to the benchmark file.", + ) + options = parser.parse_args() + + lineformat(options.path) + + + +if __name__ == "__main__": + main() diff --git a/benchmarks/src/bin/external_aggr.rs b/benchmarks/src/bin/external_aggr.rs index a2fb75dd1941..578f71f8275d 100644 --- a/benchmarks/src/bin/external_aggr.rs +++ b/benchmarks/src/bin/external_aggr.rs @@ -17,6 +17,8 @@ //! external_aggr binary entrypoint +use datafusion::execution::memory_pool::GreedyMemoryPool; +use datafusion::execution::memory_pool::MemoryPool; use std::collections::HashMap; use std::path::PathBuf; use std::sync::Arc; @@ -41,7 +43,7 @@ use datafusion::prelude::*; use datafusion_benchmarks::util::{BenchmarkRun, CommonOpt}; use datafusion_common::instant::Instant; use datafusion_common::utils::get_available_parallelism; -use datafusion_common::{exec_datafusion_err, exec_err, DEFAULT_PARQUET_EXTENSION}; +use datafusion_common::{exec_err, DEFAULT_PARQUET_EXTENSION}; #[derive(Debug, StructOpt)] #[structopt( @@ -58,10 +60,6 @@ struct ExternalAggrConfig { #[structopt(short, long)] query: Option, - /// Memory limit (e.g. '100M', '1.5G'). If not specified, run all pre-defined memory limits for given query.
- #[structopt(long)] - memory_limit: Option, - /// Common options #[structopt(flatten)] common: CommonOpt, @@ -129,10 +127,8 @@ impl ExternalAggrConfig { pub async fn run(&self) -> Result<()> { let mut benchmark_run = BenchmarkRun::new(); - let memory_limit = match &self.memory_limit { - Some(limit) => Some(Self::parse_memory_limit(limit)?), - None => None, - }; + let memory_limit = self.common.memory_limit.map(|limit| limit as u64); + let mem_pool_type = self.common.mem_pool_type.as_str(); let query_range = match self.query { Some(query_id) => query_id..=query_id, @@ -171,7 +167,9 @@ impl ExternalAggrConfig { human_readable_size(mem_limit as usize) )); - let query_results = self.benchmark_query(query_id, mem_limit).await?; + let query_results = self + .benchmark_query(query_id, mem_limit, mem_pool_type) + .await?; for iter in query_results { benchmark_run.write_iter(iter.elapsed, iter.row_count); } @@ -187,12 +185,20 @@ impl ExternalAggrConfig { &self, query_id: usize, mem_limit: u64, + mem_pool_type: &str, ) -> Result> { let query_name = format!("Q{query_id}({})", human_readable_size(mem_limit as usize)); let config = self.common.config(); + let memory_pool: Arc = match mem_pool_type { + "fair" => Arc::new(FairSpillPool::new(mem_limit as usize)), + "greedy" => Arc::new(GreedyMemoryPool::new(mem_limit as usize)), + _ => { + return exec_err!("Invalid memory pool type: {}", mem_pool_type); + } + }; let runtime_env = RuntimeEnvBuilder::new() - .with_memory_pool(Arc::new(FairSpillPool::new(mem_limit as usize))) + .with_memory_pool(memory_pool) .build_arc()?; let state = SessionStateBuilder::new() .with_config(config) @@ -331,22 +337,6 @@ impl ExternalAggrConfig { .partitions .unwrap_or(get_available_parallelism()) } - - /// Parse memory limit from string to number of bytes - /// e.g. 
'1.5G', '100M' -> 1572864 - fn parse_memory_limit(limit: &str) -> Result { - let (number, unit) = limit.split_at(limit.len() - 1); - let number: f64 = number.parse().map_err(|_| { - exec_datafusion_err!("Failed to parse number from memory limit '{}'", limit) - })?; - - match unit { - "K" => Ok((number * 1024.0) as u64), - "M" => Ok((number * 1024.0 * 1024.0) as u64), - "G" => Ok((number * 1024.0 * 1024.0 * 1024.0) as u64), - _ => exec_err!("Unsupported unit '{}' in memory limit '{}'", unit, limit), - } - } } #[tokio::main] @@ -359,31 +349,3 @@ pub async fn main() -> Result<()> { Ok(()) } - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_parse_memory_limit_all() { - // Test valid inputs - assert_eq!( - ExternalAggrConfig::parse_memory_limit("100K").unwrap(), - 102400 - ); - assert_eq!( - ExternalAggrConfig::parse_memory_limit("1.5M").unwrap(), - 1572864 - ); - assert_eq!( - ExternalAggrConfig::parse_memory_limit("2G").unwrap(), - 2147483648 - ); - - // Test invalid unit - assert!(ExternalAggrConfig::parse_memory_limit("500X").is_err()); - - // Test invalid number - assert!(ExternalAggrConfig::parse_memory_limit("abcM").is_err()); - } -} diff --git a/benchmarks/src/clickbench.rs b/benchmarks/src/clickbench.rs index 6b7c75ed4bab..a9750d9b4b84 100644 --- a/benchmarks/src/clickbench.rs +++ b/benchmarks/src/clickbench.rs @@ -124,7 +124,8 @@ impl RunOpt { parquet_options.binary_as_string = true; } - let ctx = SessionContext::new_with_config(config); + let rt_builder = self.common.runtime_env_builder()?; + let ctx = SessionContext::new_with_config_rt(config, rt_builder.build_arc()?); self.register_hits(&ctx).await?; let iterations = self.common.iterations; diff --git a/benchmarks/src/h2o.rs b/benchmarks/src/h2o.rs index 53a516ceb56d..eae7f67f1d62 100644 --- a/benchmarks/src/h2o.rs +++ b/benchmarks/src/h2o.rs @@ -68,7 +68,8 @@ impl RunOpt { }; let config = self.common.config(); - let ctx = SessionContext::new_with_config(config); + let rt_builder = self.common.runtime_env_builder()?; + let ctx = SessionContext::new_with_config_rt(config, rt_builder.build_arc()?); // Register data self.register_data(&ctx).await?; diff --git a/benchmarks/src/imdb/run.rs b/benchmarks/src/imdb/run.rs index 8d2317c62ef1..d7d7a56d0540 100644 --- a/benchmarks/src/imdb/run.rs +++ b/benchmarks/src/imdb/run.rs @@ -306,8 +306,8 @@ impl RunOpt { .config() .with_collect_statistics(!self.disable_statistics); config.options_mut().optimizer.prefer_hash_join = self.prefer_hash_join; - - let ctx = SessionContext::new_with_config(config); + let rt_builder = self.common.runtime_env_builder()?; + let ctx = SessionContext::new_with_config_rt(config, rt_builder.build_arc()?); // register tables self.register_tables(&ctx).await?; @@ -515,6 +515,9 @@ mod tests { iterations: 1, partitions: Some(2), batch_size: 8192, + mem_pool_type: "fair".to_string(), + memory_limit: None, + sort_spill_reservation_bytes: None, debug: false, }; let opt = RunOpt { @@ -548,6 +551,9 @@ mod tests { iterations: 1, partitions: Some(2), batch_size: 8192, + mem_pool_type: "fair".to_string(), + memory_limit: None, + sort_spill_reservation_bytes: None, debug: false, }; let opt = RunOpt { diff --git a/benchmarks/src/sort_tpch.rs b/benchmarks/src/sort_tpch.rs index 566a5ea62c2d..b1997b40e09e 100644 --- a/benchmarks/src/sort_tpch.rs +++ b/benchmarks/src/sort_tpch.rs @@ -188,8 +188,10 @@ impl RunOpt { /// Benchmark query `query_id` in `SORT_QUERIES` async fn benchmark_query(&self, query_id: usize) -> Result> { let config = self.common.config(); + let 
rt_builder = self.common.runtime_env_builder()?; let state = SessionStateBuilder::new() .with_config(config) + .with_runtime_env(rt_builder.build_arc()?) .with_default_features() .build(); let ctx = SessionContext::from(state); diff --git a/benchmarks/src/tpch/run.rs b/benchmarks/src/tpch/run.rs index de3ee3d67db2..eb9db821db02 100644 --- a/benchmarks/src/tpch/run.rs +++ b/benchmarks/src/tpch/run.rs @@ -121,7 +121,8 @@ impl RunOpt { .config() .with_collect_statistics(!self.disable_statistics); config.options_mut().optimizer.prefer_hash_join = self.prefer_hash_join; - let ctx = SessionContext::new_with_config(config); + let rt_builder = self.common.runtime_env_builder()?; + let ctx = SessionContext::new_with_config_rt(config, rt_builder.build_arc()?); // register tables self.register_tables(&ctx).await?; @@ -342,6 +343,9 @@ mod tests { iterations: 1, partitions: Some(2), batch_size: 8192, + mem_pool_type: "fair".to_string(), + memory_limit: None, + sort_spill_reservation_bytes: None, debug: false, }; let opt = RunOpt { @@ -375,6 +379,9 @@ mod tests { iterations: 1, partitions: Some(2), batch_size: 8192, + mem_pool_type: "fair".to_string(), + memory_limit: None, + sort_spill_reservation_bytes: None, debug: false, }; let opt = RunOpt { diff --git a/benchmarks/src/util/options.rs b/benchmarks/src/util/options.rs index b1570a1d1bc1..a1cf31525dd9 100644 --- a/benchmarks/src/util/options.rs +++ b/benchmarks/src/util/options.rs @@ -15,8 +15,17 @@ // specific language governing permissions and limitations // under the License. -use datafusion::prelude::SessionConfig; -use datafusion_common::utils::get_available_parallelism; +use std::{num::NonZeroUsize, sync::Arc}; + +use datafusion::{ + execution::{ + disk_manager::DiskManagerConfig, + memory_pool::{FairSpillPool, GreedyMemoryPool, MemoryPool, TrackConsumersPool}, + runtime_env::RuntimeEnvBuilder, + }, + prelude::SessionConfig, +}; +use datafusion_common::{utils::get_available_parallelism, DataFusionError, Result}; use structopt::StructOpt; // Common benchmark options (don't use doc comments otherwise this doc @@ -35,6 +44,20 @@ pub struct CommonOpt { #[structopt(short = "s", long = "batch-size", default_value = "8192")] pub batch_size: usize, + /// The memory pool type to use, should be one of "fair" or "greedy" + #[structopt(long = "mem-pool-type", default_value = "fair")] + pub mem_pool_type: String, + + /// Memory limit (e.g. '100M', '1.5G'). If not specified, run all pre-defined memory limits for given query + /// if there's any, otherwise run with no memory limit. + #[structopt(long = "memory-limit", parse(try_from_str = parse_memory_limit))] + pub memory_limit: Option, + + /// The amount of memory to reserve for sort spill operations. DataFusion's default value will be used + /// if not specified. 
+ #[structopt(long = "sort-spill-reservation-bytes", parse(try_from_str = parse_memory_limit))] + pub sort_spill_reservation_bytes: Option, + /// Activate debug mode to see more details #[structopt(short, long)] pub debug: bool, @@ -48,10 +71,81 @@ impl CommonOpt { /// Modify the existing config appropriately pub fn update_config(&self, config: SessionConfig) -> SessionConfig { - config + let mut config = config .with_target_partitions( self.partitions.unwrap_or(get_available_parallelism()), ) - .with_batch_size(self.batch_size) + .with_batch_size(self.batch_size); + if let Some(sort_spill_reservation_bytes) = self.sort_spill_reservation_bytes { + config = + config.with_sort_spill_reservation_bytes(sort_spill_reservation_bytes); + } + config + } + + /// Return an appropriately configured `RuntimeEnvBuilder` + pub fn runtime_env_builder(&self) -> Result { + let mut rt_builder = RuntimeEnvBuilder::new(); + const NUM_TRACKED_CONSUMERS: usize = 5; + if let Some(memory_limit) = self.memory_limit { + let pool: Arc = match self.mem_pool_type.as_str() { + "fair" => Arc::new(TrackConsumersPool::new( + FairSpillPool::new(memory_limit), + NonZeroUsize::new(NUM_TRACKED_CONSUMERS).unwrap(), + )), + "greedy" => Arc::new(TrackConsumersPool::new( + GreedyMemoryPool::new(memory_limit), + NonZeroUsize::new(NUM_TRACKED_CONSUMERS).unwrap(), + )), + _ => { + return Err(DataFusionError::Configuration(format!( + "Invalid memory pool type: {}", + self.mem_pool_type + ))) + } + }; + rt_builder = rt_builder + .with_memory_pool(pool) + .with_disk_manager(DiskManagerConfig::NewOs); + } + Ok(rt_builder) + } +} + +/// Parse memory limit from string to number of bytes +/// e.g. '1.5G', '100M' -> 1572864 +fn parse_memory_limit(limit: &str) -> Result { + let (number, unit) = limit.split_at(limit.len() - 1); + let number: f64 = number + .parse() + .map_err(|_| format!("Failed to parse number from memory limit '{}'", limit))?; + + match unit { + "K" => Ok((number * 1024.0) as usize), + "M" => Ok((number * 1024.0 * 1024.0) as usize), + "G" => Ok((number * 1024.0 * 1024.0 * 1024.0) as usize), + _ => Err(format!( + "Unsupported unit '{}' in memory limit '{}'", + unit, limit + )), + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_parse_memory_limit_all() { + // Test valid inputs + assert_eq!(parse_memory_limit("100K").unwrap(), 102400); + assert_eq!(parse_memory_limit("1.5M").unwrap(), 1572864); + assert_eq!(parse_memory_limit("2G").unwrap(), 2147483648); + + // Test invalid unit + assert!(parse_memory_limit("500X").is_err()); + + // Test invalid number + assert!(parse_memory_limit("abcM").is_err()); } } diff --git a/datafusion-cli/Cargo.toml b/datafusion-cli/Cargo.toml index 987ac97452a9..d88f8fccb928 100644 --- a/datafusion-cli/Cargo.toml +++ b/datafusion-cli/Cargo.toml @@ -32,12 +32,13 @@ arrow = { workspace = true } async-trait = { workspace = true } aws-config = "1.5.16" aws-credential-types = "1.2.0" -clap = { version = "4.5.28", features = ["derive", "cargo"] } +clap = { version = "4.5.30", features = ["derive", "cargo"] } datafusion = { workspace = true, features = [ "avro", "crypto_expressions", "datetime_expressions", "encoding_expressions", + "nested_expressions", "parquet", "recursive_protection", "regex_expressions", diff --git a/datafusion-examples/Cargo.toml b/datafusion-examples/Cargo.toml index ec6e0ab71d50..feafa48b3954 100644 --- a/datafusion-examples/Cargo.toml +++ b/datafusion-examples/Cargo.toml @@ -74,7 +74,7 @@ test-utils = { path = "../test-utils" } tokio = { workspace = true, 
features = ["rt-multi-thread", "parking_lot"] } tonic = "0.12.1" url = { workspace = true } -uuid = "1.7" +uuid = "1.13" [target.'cfg(not(target_os = "windows"))'.dev-dependencies] nix = { version = "0.29.0", features = ["fs"] } diff --git a/datafusion-examples/examples/advanced_parquet_index.rs b/datafusion-examples/examples/advanced_parquet_index.rs index 43dc592b997e..bb1cf3c8f78d 100644 --- a/datafusion-examples/examples/advanced_parquet_index.rs +++ b/datafusion-examples/examples/advanced_parquet_index.rs @@ -504,7 +504,7 @@ impl TableProvider for IndexTableProvider { .with_file(partitioned_file); // Finally, put it all together into a DataSourceExec - Ok(file_scan_config.new_exec()) + Ok(file_scan_config.build()) } /// Tell DataFusion to push filters down to the scan method diff --git a/datafusion-examples/examples/advanced_udaf.rs b/datafusion-examples/examples/advanced_udaf.rs index fd65c3352bbc..9cda726db719 100644 --- a/datafusion-examples/examples/advanced_udaf.rs +++ b/datafusion-examples/examples/advanced_udaf.rs @@ -423,11 +423,11 @@ impl AggregateUDFImpl for SimplifiedGeoMeanUdaf { // In real-world scenarios, you might create UDFs from built-in expressions. Ok(Expr::AggregateFunction(AggregateFunction::new_udf( Arc::new(AggregateUDF::from(GeoMeanUdaf::new())), - aggregate_function.args, - aggregate_function.distinct, - aggregate_function.filter, - aggregate_function.order_by, - aggregate_function.null_treatment, + aggregate_function.params.args, + aggregate_function.params.distinct, + aggregate_function.params.filter, + aggregate_function.params.order_by, + aggregate_function.params.null_treatment, ))) }; Some(Box::new(simplify)) diff --git a/datafusion-examples/examples/advanced_udwf.rs b/datafusion-examples/examples/advanced_udwf.rs index ac326be9cb04..8330e783319d 100644 --- a/datafusion-examples/examples/advanced_udwf.rs +++ b/datafusion-examples/examples/advanced_udwf.rs @@ -26,7 +26,7 @@ use arrow::{ use datafusion::common::ScalarValue; use datafusion::error::Result; use datafusion::functions_aggregate::average::avg_udaf; -use datafusion::logical_expr::expr::WindowFunction; +use datafusion::logical_expr::expr::{WindowFunction, WindowFunctionParams}; use datafusion::logical_expr::function::{ PartitionEvaluatorArgs, WindowFunctionSimplification, WindowUDFFieldArgs, }; @@ -192,11 +192,13 @@ impl WindowUDFImpl for SimplifySmoothItUdf { let simplify = |window_function: WindowFunction, _: &dyn SimplifyInfo| { Ok(Expr::WindowFunction(WindowFunction { fun: WindowFunctionDefinition::AggregateUDF(avg_udaf()), - args: window_function.args, - partition_by: window_function.partition_by, - order_by: window_function.order_by, - window_frame: window_function.window_frame, - null_treatment: window_function.null_treatment, + params: WindowFunctionParams { + args: window_function.params.args, + partition_by: window_function.params.partition_by, + order_by: window_function.params.order_by, + window_frame: window_function.params.window_frame, + null_treatment: window_function.params.null_treatment, + }, })) }; diff --git a/datafusion-examples/examples/expr_api.rs b/datafusion-examples/examples/expr_api.rs index 2908edbb754d..349850df6148 100644 --- a/datafusion-examples/examples/expr_api.rs +++ b/datafusion-examples/examples/expr_api.rs @@ -22,8 +22,9 @@ use arrow::array::{BooleanArray, Int32Array, Int8Array}; use arrow::record_batch::RecordBatch; use datafusion::arrow::datatypes::{DataType, Field, Schema, TimeUnit}; +use datafusion::common::stats::Precision; use 
datafusion::common::tree_node::{Transformed, TreeNode}; -use datafusion::common::DFSchema; +use datafusion::common::{ColumnStatistics, DFSchema}; use datafusion::common::{ScalarValue, ToDFSchema}; use datafusion::error::Result; use datafusion::functions_aggregate::first_last::first_value_udaf; @@ -80,6 +81,9 @@ async fn main() -> Result<()> { // See how to analyze ranges in expressions range_analysis_demo()?; + // See how to analyze boundaries in different kinds of expressions. + boundary_analysis_and_selectivity_demo()?; + // See how to determine the data types of expressions expression_type_demo()?; @@ -275,6 +279,74 @@ fn range_analysis_demo() -> Result<()> { Ok(()) } +// DataFusion's analysis can infer boundary statistics and selectivity in +// various situations which can be helpful in building more efficient +// query plans. +fn boundary_analysis_and_selectivity_demo() -> Result<()> { + // Consider the example where we want all rows with an `id` greater than + // or equal to 5000. + let id_greater_5000 = col("id").gt_eq(lit(5000i64)); + + // As in most examples we must tell DataFusion the type of the column. + let schema = Arc::new(Schema::new(vec![make_field("id", DataType::Int64)])); + + // DataFusion is able to do cardinality estimation on various column types; + // these estimates, represented by the `ColumnStatistics` type, describe + // properties such as the maximum and minimum value, the number of distinct + // values and the number of null values. + let column_stats = ColumnStatistics { + null_count: Precision::Exact(0), + max_value: Precision::Exact(ScalarValue::Int64(Some(10000))), + min_value: Precision::Exact(ScalarValue::Int64(Some(1))), + sum_value: Precision::Absent, + distinct_count: Precision::Absent, + }; + + // We can then build our expression boundaries from the column statistics, + // allowing the analysis to be more precise. + let initial_boundaries = + vec![ExprBoundaries::try_from_column(&schema, &column_stats, 0)?]; + + // With the above we can perform the boundary analysis similar to the previous + // example. + let df_schema = DFSchema::try_from(schema.clone())?; + + // Analysis case id >= 5000 + let physical_expr1 = + SessionContext::new().create_physical_expr(id_greater_5000, &df_schema)?; + let analysis = analyze( + &physical_expr1, + AnalysisContext::new(initial_boundaries.clone()), + df_schema.as_ref(), + )?; + + // The analysis will return better bounds thanks to the column statistics.
+ assert_eq!( + analysis.boundaries.first().map(|boundary| boundary + .interval + .clone() + .unwrap() + .into_bounds()), + Some(( + ScalarValue::Int64(Some(5000)), + ScalarValue::Int64(Some(10000)) + )) + ); + + // We can also infer selectivity from the column statistics by assuming + // that the column is uniformly distributed and using the following + // estimation formula: + // Assuming the original range is [a, b] and the new range: [a', b'] + // + // (a' - b' + 1) / (a - b) + // (10000 - 5000 + 1) / (10000 - 1) + assert!(analysis + .selectivity + .is_some_and(|selectivity| (0.5..=0.6).contains(&selectivity))); + + Ok(()) +} + fn make_field(name: &str, data_type: DataType) -> Field { let nullable = false; Field::new(name, data_type, nullable) diff --git a/datafusion-examples/examples/ffi/ffi_example_table_provider/Cargo.toml b/datafusion-examples/examples/ffi/ffi_example_table_provider/Cargo.toml index 2d91ea2329e4..e9c0c5b43d68 100644 --- a/datafusion-examples/examples/ffi/ffi_example_table_provider/Cargo.toml +++ b/datafusion-examples/examples/ffi/ffi_example_table_provider/Cargo.toml @@ -24,7 +24,6 @@ publish = false [dependencies] abi_stable = "0.11.3" arrow = { workspace = true } -arrow-schema = { workspace = true } datafusion = { workspace = true } datafusion-ffi = { workspace = true } ffi_module_interface = { path = "../ffi_module_interface" } diff --git a/datafusion-examples/examples/optimizer_rule.rs b/datafusion-examples/examples/optimizer_rule.rs index 0206c7cd157e..63f17484809e 100644 --- a/datafusion-examples/examples/optimizer_rule.rs +++ b/datafusion-examples/examples/optimizer_rule.rs @@ -20,8 +20,8 @@ use arrow::datatypes::DataType; use datafusion::common::tree_node::{Transformed, TreeNode}; use datafusion::common::{assert_batches_eq, Result, ScalarValue}; use datafusion::logical_expr::{ - BinaryExpr, ColumnarValue, Expr, LogicalPlan, Operator, ScalarUDF, ScalarUDFImpl, - Signature, Volatility, + BinaryExpr, ColumnarValue, Expr, LogicalPlan, Operator, ScalarFunctionArgs, + ScalarUDF, ScalarUDFImpl, Signature, Volatility, }; use datafusion::optimizer::ApplyOrder; use datafusion::optimizer::{OptimizerConfig, OptimizerRule}; @@ -205,11 +205,7 @@ impl ScalarUDFImpl for MyEq { Ok(DataType::Boolean) } - fn invoke_batch( - &self, - _args: &[ColumnarValue], - _number_rows: usize, - ) -> Result { + fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result { // this example simply returns "true" which is not what a real // implementation would do. 
Ok(ColumnarValue::Scalar(ScalarValue::from(true))) diff --git a/datafusion-examples/examples/parquet_index.rs b/datafusion-examples/examples/parquet_index.rs index f465699abed2..3851dca2a775 100644 --- a/datafusion-examples/examples/parquet_index.rs +++ b/datafusion-examples/examples/parquet_index.rs @@ -258,7 +258,7 @@ impl TableProvider for IndexTableProvider { file_size, )); } - Ok(file_scan_config.new_exec()) + Ok(file_scan_config.build()) } /// Tell DataFusion to push filters down to the scan method diff --git a/datafusion/catalog-listing/Cargo.toml b/datafusion/catalog-listing/Cargo.toml index 03132e7b7bb5..68d0ca3a149f 100644 --- a/datafusion/catalog-listing/Cargo.toml +++ b/datafusion/catalog-listing/Cargo.toml @@ -29,33 +29,22 @@ version.workspace = true [dependencies] arrow = { workspace = true } -arrow-schema = { workspace = true } -async-compression = { version = "0.4.0", features = [ - "bzip2", - "gzip", - "xz", - "zstd", - "tokio", -], optional = true } -chrono = { workspace = true } +async-trait = { workspace = true } datafusion-catalog = { workspace = true } datafusion-common = { workspace = true, features = ["object_store"] } +datafusion-datasource = { workspace = true } datafusion-execution = { workspace = true } datafusion-expr = { workspace = true } datafusion-physical-expr = { workspace = true } datafusion-physical-expr-common = { workspace = true } datafusion-physical-plan = { workspace = true } futures = { workspace = true } -glob = "0.3.0" -itertools = { workspace = true } log = { workspace = true } object_store = { workspace = true } -url = { workspace = true } +tokio = { workspace = true } [dev-dependencies] -async-trait = { workspace = true } tempfile = { workspace = true } -tokio = { workspace = true } [lints] workspace = true diff --git a/datafusion/catalog-listing/src/helpers.rs b/datafusion/catalog-listing/src/helpers.rs index 6cb3f661e652..cf475263535a 100644 --- a/datafusion/catalog-listing/src/helpers.rs +++ b/datafusion/catalog-listing/src/helpers.rs @@ -20,20 +20,19 @@ use std::mem; use std::sync::Arc; -use super::ListingTableUrl; -use super::PartitionedFile; use datafusion_catalog::Session; use datafusion_common::internal_err; use datafusion_common::{HashMap, Result, ScalarValue}; +use datafusion_datasource::ListingTableUrl; +use datafusion_datasource::PartitionedFile; use datafusion_expr::{BinaryExpr, Operator}; use arrow::{ array::{Array, ArrayRef, AsArray, StringBuilder}, compute::{and, cast, prep_null_mask_filter}, - datatypes::{DataType, Field, Schema}, + datatypes::{DataType, Field, Fields, Schema}, record_batch::RecordBatch, }; -use arrow_schema::Fields; use datafusion_expr::execution_props::ExecutionProps; use futures::stream::FuturesUnordered; use futures::{stream::BoxStream, StreamExt, TryStreamExt}; diff --git a/datafusion/catalog-listing/src/mod.rs b/datafusion/catalog-listing/src/mod.rs index 709fa88b5867..b98790e86455 100644 --- a/datafusion/catalog-listing/src/mod.rs +++ b/datafusion/catalog-listing/src/mod.rs @@ -15,264 +15,4 @@ // specific language governing permissions and limitations // under the License. -//! A table that uses the `ObjectStore` listing capability -//! to get the list of files to process. 
- -pub mod file_groups; pub mod helpers; -pub mod url; -use chrono::TimeZone; -use datafusion_common::Result; -use datafusion_common::{ScalarValue, Statistics}; -use futures::Stream; -use object_store::{path::Path, ObjectMeta}; -use std::pin::Pin; -use std::sync::Arc; - -pub use self::url::ListingTableUrl; - -/// Stream of files get listed from object store -pub type PartitionedFileStream = - Pin> + Send + Sync + 'static>>; - -/// Only scan a subset of Row Groups from the Parquet file whose data "midpoint" -/// lies within the [start, end) byte offsets. This option can be used to scan non-overlapping -/// sections of a Parquet file in parallel. -#[derive(Debug, Clone, PartialEq, Hash, Eq, PartialOrd, Ord)] -pub struct FileRange { - /// Range start - pub start: i64, - /// Range end - pub end: i64, -} - -impl FileRange { - /// returns true if this file range contains the specified offset - pub fn contains(&self, offset: i64) -> bool { - offset >= self.start && offset < self.end - } -} - -#[derive(Debug, Clone)] -/// A single file or part of a file that should be read, along with its schema, statistics -/// and partition column values that need to be appended to each row. -pub struct PartitionedFile { - /// Path for the file (e.g. URL, filesystem path, etc) - pub object_meta: ObjectMeta, - /// Values of partition columns to be appended to each row. - /// - /// These MUST have the same count, order, and type than the [`table_partition_cols`]. - /// - /// You may use [`wrap_partition_value_in_dict`] to wrap them if you have used [`wrap_partition_type_in_dict`] to wrap the column type. - /// - /// - /// [`wrap_partition_type_in_dict`]: https://github.com/apache/datafusion/blob/main/datafusion/core/src/datasource/physical_plan/file_scan_config.rs#L55 - /// [`wrap_partition_value_in_dict`]: https://github.com/apache/datafusion/blob/main/datafusion/core/src/datasource/physical_plan/file_scan_config.rs#L62 - /// [`table_partition_cols`]: https://github.com/apache/datafusion/blob/main/datafusion/core/src/datasource/file_format/options.rs#L190 - pub partition_values: Vec, - /// An optional file range for a more fine-grained parallel execution - pub range: Option, - /// Optional statistics that describe the data in this file if known. - /// - /// DataFusion relies on these statistics for planning (in particular to sort file groups), - /// so if they are incorrect, incorrect answers may result. 
- pub statistics: Option, - /// An optional field for user defined per object metadata - pub extensions: Option>, - /// The estimated size of the parquet metadata, in bytes - pub metadata_size_hint: Option, -} - -impl PartitionedFile { - /// Create a simple file without metadata or partition - pub fn new(path: impl Into, size: u64) -> Self { - Self { - object_meta: ObjectMeta { - location: Path::from(path.into()), - last_modified: chrono::Utc.timestamp_nanos(0), - size: size as usize, - e_tag: None, - version: None, - }, - partition_values: vec![], - range: None, - statistics: None, - extensions: None, - metadata_size_hint: None, - } - } - - /// Create a file range without metadata or partition - pub fn new_with_range(path: String, size: u64, start: i64, end: i64) -> Self { - Self { - object_meta: ObjectMeta { - location: Path::from(path), - last_modified: chrono::Utc.timestamp_nanos(0), - size: size as usize, - e_tag: None, - version: None, - }, - partition_values: vec![], - range: Some(FileRange { start, end }), - statistics: None, - extensions: None, - metadata_size_hint: None, - } - .with_range(start, end) - } - - /// Provide a hint to the size of the file metadata. If a hint is provided - /// the reader will try and fetch the last `size_hint` bytes of the parquet file optimistically. - /// Without an appropriate hint, two read may be required to fetch the metadata. - pub fn with_metadata_size_hint(mut self, metadata_size_hint: usize) -> Self { - self.metadata_size_hint = Some(metadata_size_hint); - self - } - - /// Return a file reference from the given path - pub fn from_path(path: String) -> Result { - let size = std::fs::metadata(path.clone())?.len(); - Ok(Self::new(path, size)) - } - - /// Return the path of this partitioned file - pub fn path(&self) -> &Path { - &self.object_meta.location - } - - /// Update the file to only scan the specified range (in bytes) - pub fn with_range(mut self, start: i64, end: i64) -> Self { - self.range = Some(FileRange { start, end }); - self - } - - /// Update the user defined extensions for this file. - /// - /// This can be used to pass reader specific information. 
- pub fn with_extensions( - mut self, - extensions: Arc, - ) -> Self { - self.extensions = Some(extensions); - self - } -} - -impl From for PartitionedFile { - fn from(object_meta: ObjectMeta) -> Self { - PartitionedFile { - object_meta, - partition_values: vec![], - range: None, - statistics: None, - extensions: None, - metadata_size_hint: None, - } - } -} - -#[cfg(test)] -mod tests { - use super::ListingTableUrl; - use datafusion_execution::object_store::{ - DefaultObjectStoreRegistry, ObjectStoreRegistry, - }; - use object_store::{local::LocalFileSystem, path::Path}; - use std::{ops::Not, sync::Arc}; - use url::Url; - - #[test] - fn test_object_store_listing_url() { - let listing = ListingTableUrl::parse("file:///").unwrap(); - let store = listing.object_store(); - assert_eq!(store.as_str(), "file:///"); - - let listing = ListingTableUrl::parse("s3://bucket/").unwrap(); - let store = listing.object_store(); - assert_eq!(store.as_str(), "s3://bucket/"); - } - - #[test] - fn test_get_store_hdfs() { - let sut = DefaultObjectStoreRegistry::default(); - let url = Url::parse("hdfs://localhost:8020").unwrap(); - sut.register_store(&url, Arc::new(LocalFileSystem::new())); - let url = ListingTableUrl::parse("hdfs://localhost:8020/key").unwrap(); - sut.get_store(url.as_ref()).unwrap(); - } - - #[test] - fn test_get_store_s3() { - let sut = DefaultObjectStoreRegistry::default(); - let url = Url::parse("s3://bucket/key").unwrap(); - sut.register_store(&url, Arc::new(LocalFileSystem::new())); - let url = ListingTableUrl::parse("s3://bucket/key").unwrap(); - sut.get_store(url.as_ref()).unwrap(); - } - - #[test] - fn test_get_store_file() { - let sut = DefaultObjectStoreRegistry::default(); - let url = ListingTableUrl::parse("file:///bucket/key").unwrap(); - sut.get_store(url.as_ref()).unwrap(); - } - - #[test] - fn test_get_store_local() { - let sut = DefaultObjectStoreRegistry::default(); - let url = ListingTableUrl::parse("../").unwrap(); - sut.get_store(url.as_ref()).unwrap(); - } - - #[test] - fn test_url_contains() { - let url = ListingTableUrl::parse("file:///var/data/mytable/").unwrap(); - - // standard case with default config - assert!(url.contains( - &Path::parse("/var/data/mytable/data.parquet").unwrap(), - true - )); - - // standard case with `ignore_subdirectory` set to false - assert!(url.contains( - &Path::parse("/var/data/mytable/data.parquet").unwrap(), - false - )); - - // as per documentation, when `ignore_subdirectory` is true, we should ignore files that aren't - // a direct child of the `url` - assert!(url - .contains( - &Path::parse("/var/data/mytable/mysubfolder/data.parquet").unwrap(), - true - ) - .not()); - - // when we set `ignore_subdirectory` to false, we should not ignore the file - assert!(url.contains( - &Path::parse("/var/data/mytable/mysubfolder/data.parquet").unwrap(), - false - )); - - // as above, `ignore_subdirectory` is false, so we include the file - assert!(url.contains( - &Path::parse("/var/data/mytable/year=2024/data.parquet").unwrap(), - false - )); - - // in this case, we include the file even when `ignore_subdirectory` is true because the - // path segment is a hive partition which doesn't count as a subdirectory for the purposes - // of `Url::contains` - assert!(url.contains( - &Path::parse("/var/data/mytable/year=2024/data.parquet").unwrap(), - true - )); - - // testing an empty path with default config - assert!(url.contains(&Path::parse("/var/data/mytable/").unwrap(), true)); - - // testing an empty path with `ignore_subdirectory` set to false - 
assert!(url.contains(&Path::parse("/var/data/mytable/").unwrap(), false)); - } -} diff --git a/datafusion/catalog/Cargo.toml b/datafusion/catalog/Cargo.toml index 749457855ca2..73ac44a0316e 100644 --- a/datafusion/catalog/Cargo.toml +++ b/datafusion/catalog/Cargo.toml @@ -40,7 +40,6 @@ futures = { workspace = true } itertools = { workspace = true } log = { workspace = true } parking_lot = { workspace = true } -sqlparser = { workspace = true } [dev-dependencies] tokio = { workspace = true } diff --git a/datafusion/catalog/src/information_schema.rs b/datafusion/catalog/src/information_schema.rs index e68e636989f8..7948c0299d39 100644 --- a/datafusion/catalog/src/information_schema.rs +++ b/datafusion/catalog/src/information_schema.rs @@ -405,7 +405,7 @@ fn get_udf_args_and_return_types( udf: &Arc, ) -> Result, Option)>> { let signature = udf.signature(); - let arg_types = signature.type_signature.get_possible_types(); + let arg_types = signature.type_signature.get_example_types(); if arg_types.is_empty() { Ok(vec![(vec![], None)]) } else { @@ -428,7 +428,7 @@ fn get_udaf_args_and_return_types( udaf: &Arc, ) -> Result, Option)>> { let signature = udaf.signature(); - let arg_types = signature.type_signature.get_possible_types(); + let arg_types = signature.type_signature.get_example_types(); if arg_types.is_empty() { Ok(vec![(vec![], None)]) } else { @@ -452,7 +452,7 @@ fn get_udwf_args_and_return_types( udwf: &Arc, ) -> Result, Option)>> { let signature = udwf.signature(); - let arg_types = signature.type_signature.get_possible_types(); + let arg_types = signature.type_signature.get_example_types(); if arg_types.is_empty() { Ok(vec![(vec![], None)]) } else { diff --git a/datafusion/catalog/src/table.rs b/datafusion/catalog/src/table.rs index 88d2d8bde51e..ecc792f73d30 100644 --- a/datafusion/catalog/src/table.rs +++ b/datafusion/catalog/src/table.rs @@ -33,18 +33,19 @@ use datafusion_expr::{ }; use datafusion_physical_plan::ExecutionPlan; -/// A named table which can be queried. +/// A table which can be queried and modified. /// /// Please see [`CatalogProvider`] for details of implementing a custom catalog. /// /// [`TableProvider`] represents a source of data which can provide data as -/// Apache Arrow `RecordBatch`es. Implementations of this trait provide +/// Apache Arrow [`RecordBatch`]es. Implementations of this trait provide /// important information for planning such as: /// /// 1. [`Self::schema`]: The schema (columns and their types) of the table /// 2. [`Self::supports_filters_pushdown`]: Should filters be pushed into this scan /// 2. [`Self::scan`]: An [`ExecutionPlan`] that can read data /// +/// [`RecordBatch`]: https://docs.rs/arrow/latest/arrow/record_batch/struct.RecordBatch.html /// [`CatalogProvider`]: super::CatalogProvider #[async_trait] pub trait TableProvider: Debug + Sync + Send { diff --git a/datafusion/common/Cargo.toml b/datafusion/common/Cargo.toml index 215a06e81c3d..3be666ce7974 100644 --- a/datafusion/common/Cargo.toml +++ b/datafusion/common/Cargo.toml @@ -52,7 +52,6 @@ apache-avro = { version = "0.17", default-features = false, features = [ ], optional = true } arrow = { workspace = true } arrow-ipc = { workspace = true } -arrow-schema = { workspace = true } base64 = "0.22.1" half = { workspace = true } hashbrown = { workspace = true } diff --git a/datafusion/common/src/cast.rs b/datafusion/common/src/cast.rs index bc37e59c9b92..28202c6684b5 100644 --- a/datafusion/common/src/cast.rs +++ b/datafusion/common/src/cast.rs @@ -20,7 +20,7 @@ //! 
but provide an error message rather than a panic, as the corresponding //! kernels in arrow-rs such as `as_boolean_array` do. -use crate::{downcast_value, DataFusionError, Result}; +use crate::{downcast_value, Result}; use arrow::array::{ BinaryViewArray, Float16Array, Int16Array, Int8Array, LargeBinaryArray, LargeStringArray, StringViewArray, UInt16Array, diff --git a/datafusion/common/src/column.rs b/datafusion/common/src/column.rs index 05e2dff0bd43..50a4e257d1c9 100644 --- a/datafusion/common/src/column.rs +++ b/datafusion/common/src/column.rs @@ -17,10 +17,10 @@ //! Column -use crate::error::_schema_err; +use crate::error::{_schema_err, add_possible_columns_to_diag}; use crate::utils::{parse_identifiers_normalized, quote_identifier}; use crate::{DFSchema, Diagnostic, Result, SchemaError, Spans, TableReference}; -use arrow_schema::{Field, FieldRef}; +use arrow::datatypes::{Field, FieldRef}; use std::collections::HashSet; use std::convert::Infallible; use std::fmt; @@ -273,18 +273,11 @@ impl Column { // user which columns are candidates, or which table // they come from. For now, let's list the table names // only. - for qualified_field in qualified_fields { - let (Some(table), _) = qualified_field else { - continue; - }; - diagnostic.add_note( - format!( - "possible reference to '{}' in table '{}'", - &self.name, table - ), - None, - ); - } + add_possible_columns_to_diag( + &mut diagnostic, + &Column::new_unqualified(&self.name), + &columns, + ); err.with_diagnostic(diagnostic) }); } @@ -380,8 +373,7 @@ impl fmt::Display for Column { #[cfg(test)] mod tests { use super::*; - use arrow::datatypes::DataType; - use arrow_schema::SchemaBuilder; + use arrow::datatypes::{DataType, SchemaBuilder}; use std::sync::Arc; fn create_qualified_schema(qualifier: &str, names: Vec<&str>) -> Result { diff --git a/datafusion/common/src/config.rs b/datafusion/common/src/config.rs index 15d8c04371f8..edb323eeef17 100644 --- a/datafusion/common/src/config.rs +++ b/datafusion/common/src/config.rs @@ -1233,35 +1233,72 @@ macro_rules! extensions_options { Box::new(self.clone()) } - fn set(&mut self, key: &str, value: &str) -> $crate::Result<()> { - match key { - $( - stringify!($field_name) => { - self.$field_name = value.parse().map_err(|e| { - $crate::DataFusionError::Context( - format!(concat!("Error parsing {} as ", stringify!($t),), value), - Box::new($crate::DataFusionError::External(Box::new(e))), - ) - })?; - Ok(()) - } - )* - _ => Err($crate::DataFusionError::Configuration( - format!(concat!("Config value \"{}\" not found on ", stringify!($struct_name)), key) - )) - } + fn set(&mut self, key: &str, value: &str) -> $crate::error::Result<()> { + $crate::config::ConfigField::set(self, key, value) } fn entries(&self) -> Vec<$crate::config::ConfigEntry> { - vec![ + struct Visitor(Vec<$crate::config::ConfigEntry>); + + impl $crate::config::Visit for Visitor { + fn some( + &mut self, + key: &str, + value: V, + description: &'static str, + ) { + self.0.push($crate::config::ConfigEntry { + key: key.to_string(), + value: Some(value.to_string()), + description, + }) + } + + fn none(&mut self, key: &str, description: &'static str) { + self.0.push($crate::config::ConfigEntry { + key: key.to_string(), + value: None, + description, + }) + } + } + + let mut v = Visitor(vec![]); + // The prefix is not used for extensions. + // The description is generated in ConfigField::visit. + // We can just pass empty strings here. 
+ $crate::config::ConfigField::visit(self, &mut v, "", ""); + v.0 + } + } + + impl $crate::config::ConfigField for $struct_name { + fn set(&mut self, key: &str, value: &str) -> $crate::error::Result<()> { + let (key, rem) = key.split_once('.').unwrap_or((key, "")); + match key { $( - $crate::config::ConfigEntry { - key: stringify!($field_name).to_owned(), - value: (self.$field_name != $default).then(|| self.$field_name.to_string()), - description: concat!($($d),*).trim(), + stringify!($field_name) => { + // Safely apply deprecated attribute if present + // $(#[allow(deprecated)])? + { + #[allow(deprecated)] + self.$field_name.set(rem, value.as_ref()) + } }, )* - ] + _ => return $crate::error::_config_err!( + "Config value \"{}\" not found on {}", key, stringify!($struct_name) + ) + } + } + + fn visit(&self, v: &mut V, _key_prefix: &str, _description: &'static str) { + $( + let key = stringify!($field_name).to_string(); + let desc = concat!($($d),*).trim(); + #[allow(deprecated)] + self.$field_name.visit(v, key.as_str(), desc); + )* } } } diff --git a/datafusion/common/src/dfschema.rs b/datafusion/common/src/dfschema.rs index 2ac629432ce9..99fb179c76a3 100644 --- a/datafusion/common/src/dfschema.rs +++ b/datafusion/common/src/dfschema.rs @@ -30,8 +30,9 @@ use crate::{ }; use arrow::compute::can_cast_types; -use arrow::datatypes::{DataType, Field, FieldRef, Fields, Schema, SchemaRef}; -use arrow_schema::SchemaBuilder; +use arrow::datatypes::{ + DataType, Field, FieldRef, Fields, Schema, SchemaBuilder, SchemaRef, +}; /// A reference-counted reference to a [DFSchema]. pub type DFSchemaRef = Arc; @@ -56,7 +57,7 @@ pub type DFSchemaRef = Arc; /// /// ```rust /// use datafusion_common::{DFSchema, Column}; -/// use arrow_schema::{DataType, Field, Schema}; +/// use arrow::datatypes::{DataType, Field, Schema}; /// /// let arrow_schema = Schema::new(vec![ /// Field::new("c1", DataType::Int32, false), @@ -77,7 +78,7 @@ pub type DFSchemaRef = Arc; /// /// ```rust /// use datafusion_common::{DFSchema, Column}; -/// use arrow_schema::{DataType, Field, Schema}; +/// use arrow::datatypes::{DataType, Field, Schema}; /// /// let arrow_schema = Schema::new(vec![ /// Field::new("c1", DataType::Int32, false), @@ -94,8 +95,7 @@ pub type DFSchemaRef = Arc; /// /// ```rust /// use datafusion_common::DFSchema; -/// use arrow_schema::Schema; -/// use arrow::datatypes::Field; +/// use arrow::datatypes::{Schema, Field}; /// use std::collections::HashMap; /// /// let df_schema = DFSchema::from_unqualified_fields(vec![ @@ -1002,12 +1002,14 @@ pub trait SchemaExt { /// It works the same as [`DFSchema::equivalent_names_and_types`]. fn equivalent_names_and_types(&self, other: &Self) -> bool; - /// Returns true if the two schemas have the same qualified named - /// fields with logically equivalent data types. Returns false otherwise. + /// Returns nothing if the two schemas have the same qualified named + /// fields with logically equivalent data types. Returns internal error otherwise. /// /// Use [DFSchema]::equivalent_names_and_types for stricter semantic type /// equivalence checking. - fn logically_equivalent_names_and_types(&self, other: &Self) -> bool; + /// + /// It is only used by insert into cases. + fn logically_equivalent_names_and_types(&self, other: &Self) -> Result<()>; } impl SchemaExt for Schema { @@ -1028,21 +1030,36 @@ impl SchemaExt for Schema { }) } - fn logically_equivalent_names_and_types(&self, other: &Self) -> bool { + // It is only used by insert into cases. 
+ fn logically_equivalent_names_and_types(&self, other: &Self) -> Result<()> { + // case 1 : schema length mismatch if self.fields().len() != other.fields().len() { - return false; + _plan_err!( + "Inserting query must have the same schema length as the table. \ + Expected table schema length: {}, got: {}", + self.fields().len(), + other.fields().len() + ) + } else { + // case 2 : schema length match, but fields mismatch + // check if the fields name are the same and have the same data types + self.fields() + .iter() + .zip(other.fields().iter()) + .try_for_each(|(f1, f2)| { + if f1.name() != f2.name() || !DFSchema::datatype_is_logically_equal(f1.data_type(), f2.data_type()) { + _plan_err!( + "Inserting query schema mismatch: Expected table field '{}' with type {:?}, \ + but got '{}' with type {:?}.", + f1.name(), + f1.data_type(), + f2.name(), + f2.data_type()) + } else { + Ok(()) + } + }) } - - self.fields() - .iter() - .zip(other.fields().iter()) - .all(|(f1, f2)| { - f1.name() == f2.name() - && DFSchema::datatype_is_logically_equal( - f1.data_type(), - f2.data_type(), - ) - }) } } @@ -1069,7 +1086,7 @@ mod tests { Column names are case sensitive. \ You can use double quotes to refer to the \"\"t1.c0\"\" column \ or set the datafusion.sql_parser.enable_ident_normalization configuration. \ - Valid fields are t1.c0, t1.c1."; + Did you mean 't1.c0'?."; assert_eq!(err.strip_backtrace(), expected); Ok(()) } diff --git a/datafusion/common/src/error.rs b/datafusion/common/src/error.rs index 013b1d5a2cab..c50ec64759d5 100644 --- a/datafusion/common/src/error.rs +++ b/datafusion/common/src/error.rs @@ -27,6 +27,7 @@ use std::io; use std::result; use std::sync::Arc; +use crate::utils::datafusion_strsim::normalized_levenshtein; use crate::utils::quote_identifier; use crate::{Column, DFSchema, Diagnostic, TableReference}; #[cfg(feature = "avro")] @@ -190,6 +191,11 @@ impl Display for SchemaError { .iter() .map(|column| column.flat_name().to_lowercase()) .collect::>(); + + let valid_fields_names = valid_fields + .iter() + .map(|column| column.flat_name()) + .collect::>(); if lower_valid_fields.contains(&field.flat_name().to_lowercase()) { write!( f, @@ -198,7 +204,15 @@ impl Display for SchemaError { field.quoted_flat_name() )?; } - if !valid_fields.is_empty() { + let field_name = field.name(); + if let Some(matched) = valid_fields_names + .iter() + .filter(|str| normalized_levenshtein(str, field_name) >= 0.5) + .collect::>() + .first() + { + write!(f, ". Did you mean '{matched}'?")?; + } else if !valid_fields.is_empty() { write!( f, ". Valid fields are {}", @@ -468,6 +482,11 @@ impl DataFusionError { "".to_owned() } + /// Return a [`DataFusionErrorBuilder`] to build a [`DataFusionError`] + pub fn builder() -> DataFusionErrorBuilder { + DataFusionErrorBuilder::default() + } + fn error_prefix(&self) -> &'static str { match self { DataFusionError::ArrowError(_, _) => "Arrow error: ", @@ -602,6 +621,9 @@ impl DataFusionError { DiagnosticsIterator { head: self }.next() } + /// Return an iterator over this [`DataFusionError`] and any other + /// [`DataFusionError`]s in a [`DataFusionError::Collection`]. + /// /// Sometimes DataFusion is able to collect multiple errors in a SQL query /// before terminating, e.g. across different expressions in a SELECT /// statements or different sides of a UNION. 
This method returns an @@ -634,17 +656,65 @@ impl DataFusionError { } } +/// A builder for [`DataFusionError`] +/// +/// This builder can be used to collect multiple errors and return them as a +/// [`DataFusionError::Collection`]. +/// +/// # Example: no errors +/// ``` +/// # use datafusion_common::DataFusionError; +/// let mut builder = DataFusionError::builder(); +/// // error_or returns the value if no errors have been added +/// assert_eq!(builder.error_or(42).unwrap(), 42); +/// ``` +/// +/// # Example: with errors +/// ``` +/// # use datafusion_common::{assert_contains, DataFusionError}; +/// let mut builder = DataFusionError::builder(); +/// builder.add_error(DataFusionError::Internal("foo".to_owned())); +/// // error_or returns an error if any errors have been added +/// assert_contains!(builder.error_or(42).unwrap_err().to_string(), "Internal error: foo"); +/// ``` +#[derive(Debug, Default)] pub struct DataFusionErrorBuilder(Vec<DataFusionError>); impl DataFusionErrorBuilder { + /// Create a new [`DataFusionErrorBuilder`] pub fn new() -> Self { - Self(Vec::new()) + Default::default() } + /// Add an error to the in-progress list + /// + /// # Example + /// ``` + /// # use datafusion_common::{assert_contains, DataFusionError}; + /// let mut builder = DataFusionError::builder(); + /// builder.add_error(DataFusionError::Internal("foo".to_owned())); + /// assert_contains!(builder.error_or(42).unwrap_err().to_string(), "Internal error: foo"); + /// ``` pub fn add_error(&mut self, error: DataFusionError) { self.0.push(error); } + /// Add an error to the in-progress list, returning the builder + /// + /// # Example + /// ``` + /// # use datafusion_common::{assert_contains, DataFusionError}; + /// let builder = DataFusionError::builder() + /// .with_error(DataFusionError::Internal("foo".to_owned())); + /// assert_contains!(builder.error_or(42).unwrap_err().to_string(), "Internal error: foo"); + /// ``` + pub fn with_error(mut self, error: DataFusionError) -> Self { + self.0.push(error); + self + } + + /// Returns `Ok(ok)` if no errors were added to the builder, + /// otherwise returns a `Result::Err` pub fn error_or<T>(self, ok: T) -> Result<T> { match self.0.len() { 0 => Ok(ok), @@ -654,12 +724,6 @@ impl DataFusionErrorBuilder { } } -impl Default for DataFusionErrorBuilder { - fn default() -> Self { - Self::new() - } -} - /// Unwrap an `Option` if possible. Otherwise return an `DataFusionError::Internal`. /// In normal usage of DataFusion the unwrap should always succeed.
/// @@ -827,6 +891,27 @@ pub fn unqualified_field_not_found(name: &str, schema: &DFSchema) -> DataFusionE }) } +pub fn add_possible_columns_to_diag( + diagnostic: &mut Diagnostic, + field: &Column, + valid_fields: &[Column], +) { + let field_names: Vec = valid_fields + .iter() + .filter_map(|f| { + if normalized_levenshtein(f.name(), field.name()) >= 0.5 { + Some(f.flat_name()) + } else { + None + } + }) + .collect(); + + for name in field_names { + diagnostic.add_note(format!("possible column {}", name), None); + } +} + #[cfg(test)] mod test { use std::sync::Arc; diff --git a/datafusion/common/src/file_options/parquet_writer.rs b/datafusion/common/src/file_options/parquet_writer.rs index 6a717d3c0c60..8c785b84313c 100644 --- a/datafusion/common/src/file_options/parquet_writer.rs +++ b/datafusion/common/src/file_options/parquet_writer.rs @@ -25,7 +25,7 @@ use crate::{ DataFusionError, Result, _internal_datafusion_err, }; -use arrow_schema::Schema; +use arrow::datatypes::Schema; // TODO: handle once deprecated #[allow(deprecated)] use parquet::{ diff --git a/datafusion/common/src/lib.rs b/datafusion/common/src/lib.rs index 1ad2a5f0cec3..df1ae100f581 100644 --- a/datafusion/common/src/lib.rs +++ b/datafusion/common/src/lib.rs @@ -104,21 +104,84 @@ pub type HashSet = hashbrown::HashSet; #[macro_export] macro_rules! downcast_value { ($Value: expr, $Type: ident) => {{ - use std::any::type_name; - $Value.as_any().downcast_ref::<$Type>().ok_or_else(|| { - DataFusionError::Internal(format!( - "could not cast value to {}", - type_name::<$Type>() - )) - })? + use $crate::__private::DowncastArrayHelper; + $Value.downcast_array_helper::<$Type>()? }}; ($Value: expr, $Type: ident, $T: tt) => {{ - use std::any::type_name; - $Value.as_any().downcast_ref::<$Type<$T>>().ok_or_else(|| { - DataFusionError::Internal(format!( - "could not cast value to {}", - type_name::<$Type<$T>>() - )) - })? + use $crate::__private::DowncastArrayHelper; + $Value.downcast_array_helper::<$Type<$T>>()? }}; } + +// Not public API. 
+#[doc(hidden)] +pub mod __private { + use crate::error::_internal_datafusion_err; + use crate::Result; + use arrow::array::Array; + use std::any::{type_name, Any}; + + #[doc(hidden)] + pub trait DowncastArrayHelper { + fn downcast_array_helper(&self) -> Result<&U>; + } + + impl DowncastArrayHelper for T { + fn downcast_array_helper(&self) -> Result<&U> { + self.as_any().downcast_ref().ok_or_else(|| { + _internal_datafusion_err!( + "could not cast array of type {} to {}", + self.data_type(), + type_name::() + ) + }) + } + } +} + +#[cfg(test)] +mod tests { + use arrow::array::{ArrayRef, Int32Array, UInt64Array}; + use std::any::{type_name, type_name_of_val}; + use std::sync::Arc; + + #[test] + fn test_downcast_value() -> crate::Result<()> { + let boxed: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3])); + let array = downcast_value!(&boxed, Int32Array); + assert_eq!(type_name_of_val(&array), type_name::<&Int32Array>()); + + let expected: Int32Array = vec![1, 2, 3].into_iter().map(Some).collect(); + assert_eq!(array, &expected); + Ok(()) + } + + #[test] + fn test_downcast_value_err_message() { + let boxed: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3])); + let error: crate::DataFusionError = (|| { + downcast_value!(&boxed, UInt64Array); + Ok(()) + })() + .err() + .unwrap(); + + assert_starts_with( + error.to_string(), + "Internal error: could not cast array of type Int32 to arrow_array::array::primitive_array::PrimitiveArray" + ); + } + + // `err.to_string()` depends on backtrace being present (may have backtrace appended) + // `err.strip_backtrace()` also depends on backtrace being present (may have "This was likely caused by ..." stripped) + fn assert_starts_with(actual: impl AsRef, expected_prefix: impl AsRef) { + let actual = actual.as_ref(); + let expected_prefix = expected_prefix.as_ref(); + assert!( + actual.starts_with(expected_prefix), + "Expected '{}' to start with '{}'", + actual, + expected_prefix + ); + } +} diff --git a/datafusion/common/src/param_value.rs b/datafusion/common/src/param_value.rs index 8d61bad97b9f..d2802c096da1 100644 --- a/datafusion/common/src/param_value.rs +++ b/datafusion/common/src/param_value.rs @@ -17,7 +17,7 @@ use crate::error::{_plan_datafusion_err, _plan_err}; use crate::{Result, ScalarValue}; -use arrow_schema::DataType; +use arrow::datatypes::DataType; use std::collections::HashMap; /// The parameter value corresponding to the placeholder diff --git a/datafusion/common/src/scalar/mod.rs b/datafusion/common/src/scalar/mod.rs index 5db0f5ed5cc0..9059ae07e648 100644 --- a/datafusion/common/src/scalar/mod.rs +++ b/datafusion/common/src/scalar/mod.rs @@ -55,10 +55,9 @@ use arrow::datatypes::{ Int8Type, IntervalDayTimeType, IntervalMonthDayNanoType, IntervalUnit, IntervalYearMonthType, TimeUnit, TimestampMicrosecondType, TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, UInt16Type, UInt32Type, UInt64Type, - UInt8Type, DECIMAL128_MAX_PRECISION, + UInt8Type, UnionFields, UnionMode, DECIMAL128_MAX_PRECISION, }; use arrow::util::display::{array_value_to_string, ArrayFormatter, FormatOptions}; -use arrow_schema::{UnionFields, UnionMode}; use crate::format::DEFAULT_CAST_OPTIONS; use half::f16; @@ -976,6 +975,129 @@ impl ScalarValue { ) } + /// Create a Null instance of ScalarValue for this datatype + /// + /// Example + /// ``` + /// use datafusion_common::ScalarValue; + /// use arrow::datatypes::DataType; + /// + /// let scalar = ScalarValue::try_new_null(&DataType::Int32).unwrap(); + /// assert_eq!(scalar.is_null(), true); + /// 
assert_eq!(scalar.data_type(), DataType::Int32); + /// ``` + pub fn try_new_null(data_type: &DataType) -> Result { + Ok(match data_type { + DataType::Boolean => ScalarValue::Boolean(None), + DataType::Float16 => ScalarValue::Float16(None), + DataType::Float64 => ScalarValue::Float64(None), + DataType::Float32 => ScalarValue::Float32(None), + DataType::Int8 => ScalarValue::Int8(None), + DataType::Int16 => ScalarValue::Int16(None), + DataType::Int32 => ScalarValue::Int32(None), + DataType::Int64 => ScalarValue::Int64(None), + DataType::UInt8 => ScalarValue::UInt8(None), + DataType::UInt16 => ScalarValue::UInt16(None), + DataType::UInt32 => ScalarValue::UInt32(None), + DataType::UInt64 => ScalarValue::UInt64(None), + DataType::Decimal128(precision, scale) => { + ScalarValue::Decimal128(None, *precision, *scale) + } + DataType::Decimal256(precision, scale) => { + ScalarValue::Decimal256(None, *precision, *scale) + } + DataType::Utf8 => ScalarValue::Utf8(None), + DataType::LargeUtf8 => ScalarValue::LargeUtf8(None), + DataType::Utf8View => ScalarValue::Utf8View(None), + DataType::Binary => ScalarValue::Binary(None), + DataType::BinaryView => ScalarValue::BinaryView(None), + DataType::FixedSizeBinary(len) => ScalarValue::FixedSizeBinary(*len, None), + DataType::LargeBinary => ScalarValue::LargeBinary(None), + DataType::Date32 => ScalarValue::Date32(None), + DataType::Date64 => ScalarValue::Date64(None), + DataType::Time32(TimeUnit::Second) => ScalarValue::Time32Second(None), + DataType::Time32(TimeUnit::Millisecond) => { + ScalarValue::Time32Millisecond(None) + } + DataType::Time64(TimeUnit::Microsecond) => { + ScalarValue::Time64Microsecond(None) + } + DataType::Time64(TimeUnit::Nanosecond) => ScalarValue::Time64Nanosecond(None), + DataType::Timestamp(TimeUnit::Second, tz_opt) => { + ScalarValue::TimestampSecond(None, tz_opt.clone()) + } + DataType::Timestamp(TimeUnit::Millisecond, tz_opt) => { + ScalarValue::TimestampMillisecond(None, tz_opt.clone()) + } + DataType::Timestamp(TimeUnit::Microsecond, tz_opt) => { + ScalarValue::TimestampMicrosecond(None, tz_opt.clone()) + } + DataType::Timestamp(TimeUnit::Nanosecond, tz_opt) => { + ScalarValue::TimestampNanosecond(None, tz_opt.clone()) + } + DataType::Interval(IntervalUnit::YearMonth) => { + ScalarValue::IntervalYearMonth(None) + } + DataType::Interval(IntervalUnit::DayTime) => { + ScalarValue::IntervalDayTime(None) + } + DataType::Interval(IntervalUnit::MonthDayNano) => { + ScalarValue::IntervalMonthDayNano(None) + } + DataType::Duration(TimeUnit::Second) => ScalarValue::DurationSecond(None), + DataType::Duration(TimeUnit::Millisecond) => { + ScalarValue::DurationMillisecond(None) + } + DataType::Duration(TimeUnit::Microsecond) => { + ScalarValue::DurationMicrosecond(None) + } + DataType::Duration(TimeUnit::Nanosecond) => { + ScalarValue::DurationNanosecond(None) + } + DataType::Dictionary(index_type, value_type) => ScalarValue::Dictionary( + index_type.clone(), + Box::new(value_type.as_ref().try_into()?), + ), + // `ScalaValue::List` contains single element `ListArray`. + DataType::List(field_ref) => ScalarValue::List(Arc::new( + GenericListArray::new_null(Arc::clone(field_ref), 1), + )), + // `ScalarValue::LargeList` contains single element `LargeListArray`. + DataType::LargeList(field_ref) => ScalarValue::LargeList(Arc::new( + GenericListArray::new_null(Arc::clone(field_ref), 1), + )), + // `ScalaValue::FixedSizeList` contains single element `FixedSizeList`. 
+ DataType::FixedSizeList(field_ref, fixed_length) => { + ScalarValue::FixedSizeList(Arc::new(FixedSizeListArray::new_null( + Arc::clone(field_ref), + *fixed_length, + 1, + ))) + } + DataType::Struct(fields) => ScalarValue::Struct( + new_null_array(&DataType::Struct(fields.to_owned()), 1) + .as_struct() + .to_owned() + .into(), + ), + DataType::Map(fields, sorted) => ScalarValue::Map( + new_null_array(&DataType::Map(fields.to_owned(), sorted.to_owned()), 1) + .as_map() + .to_owned() + .into(), + ), + DataType::Union(fields, mode) => { + ScalarValue::Union(None, fields.clone(), *mode) + } + DataType::Null => ScalarValue::Null, + _ => { + return _not_impl_err!( + "Can't create a null scalar from data_type \"{data_type:?}\"" + ); + } + }) + } + /// Returns a [`ScalarValue::Utf8`] representing `val` pub fn new_utf8(val: impl Into) -> Self { ScalarValue::from(val.into()) @@ -3457,115 +3579,7 @@ impl TryFrom<&DataType> for ScalarValue { /// Create a Null instance of ScalarValue for this datatype fn try_from(data_type: &DataType) -> Result { - Ok(match data_type { - DataType::Boolean => ScalarValue::Boolean(None), - DataType::Float16 => ScalarValue::Float16(None), - DataType::Float64 => ScalarValue::Float64(None), - DataType::Float32 => ScalarValue::Float32(None), - DataType::Int8 => ScalarValue::Int8(None), - DataType::Int16 => ScalarValue::Int16(None), - DataType::Int32 => ScalarValue::Int32(None), - DataType::Int64 => ScalarValue::Int64(None), - DataType::UInt8 => ScalarValue::UInt8(None), - DataType::UInt16 => ScalarValue::UInt16(None), - DataType::UInt32 => ScalarValue::UInt32(None), - DataType::UInt64 => ScalarValue::UInt64(None), - DataType::Decimal128(precision, scale) => { - ScalarValue::Decimal128(None, *precision, *scale) - } - DataType::Decimal256(precision, scale) => { - ScalarValue::Decimal256(None, *precision, *scale) - } - DataType::Utf8 => ScalarValue::Utf8(None), - DataType::LargeUtf8 => ScalarValue::LargeUtf8(None), - DataType::Utf8View => ScalarValue::Utf8View(None), - DataType::Binary => ScalarValue::Binary(None), - DataType::BinaryView => ScalarValue::BinaryView(None), - DataType::FixedSizeBinary(len) => ScalarValue::FixedSizeBinary(*len, None), - DataType::LargeBinary => ScalarValue::LargeBinary(None), - DataType::Date32 => ScalarValue::Date32(None), - DataType::Date64 => ScalarValue::Date64(None), - DataType::Time32(TimeUnit::Second) => ScalarValue::Time32Second(None), - DataType::Time32(TimeUnit::Millisecond) => { - ScalarValue::Time32Millisecond(None) - } - DataType::Time64(TimeUnit::Microsecond) => { - ScalarValue::Time64Microsecond(None) - } - DataType::Time64(TimeUnit::Nanosecond) => ScalarValue::Time64Nanosecond(None), - DataType::Timestamp(TimeUnit::Second, tz_opt) => { - ScalarValue::TimestampSecond(None, tz_opt.clone()) - } - DataType::Timestamp(TimeUnit::Millisecond, tz_opt) => { - ScalarValue::TimestampMillisecond(None, tz_opt.clone()) - } - DataType::Timestamp(TimeUnit::Microsecond, tz_opt) => { - ScalarValue::TimestampMicrosecond(None, tz_opt.clone()) - } - DataType::Timestamp(TimeUnit::Nanosecond, tz_opt) => { - ScalarValue::TimestampNanosecond(None, tz_opt.clone()) - } - DataType::Interval(IntervalUnit::YearMonth) => { - ScalarValue::IntervalYearMonth(None) - } - DataType::Interval(IntervalUnit::DayTime) => { - ScalarValue::IntervalDayTime(None) - } - DataType::Interval(IntervalUnit::MonthDayNano) => { - ScalarValue::IntervalMonthDayNano(None) - } - DataType::Duration(TimeUnit::Second) => ScalarValue::DurationSecond(None), - 
DataType::Duration(TimeUnit::Millisecond) => { - ScalarValue::DurationMillisecond(None) - } - DataType::Duration(TimeUnit::Microsecond) => { - ScalarValue::DurationMicrosecond(None) - } - DataType::Duration(TimeUnit::Nanosecond) => { - ScalarValue::DurationNanosecond(None) - } - DataType::Dictionary(index_type, value_type) => ScalarValue::Dictionary( - index_type.clone(), - Box::new(value_type.as_ref().try_into()?), - ), - // `ScalaValue::List` contains single element `ListArray`. - DataType::List(field_ref) => ScalarValue::List(Arc::new( - GenericListArray::new_null(Arc::clone(field_ref), 1), - )), - // `ScalarValue::LargeList` contains single element `LargeListArray`. - DataType::LargeList(field_ref) => ScalarValue::LargeList(Arc::new( - GenericListArray::new_null(Arc::clone(field_ref), 1), - )), - // `ScalaValue::FixedSizeList` contains single element `FixedSizeList`. - DataType::FixedSizeList(field_ref, fixed_length) => { - ScalarValue::FixedSizeList(Arc::new(FixedSizeListArray::new_null( - Arc::clone(field_ref), - *fixed_length, - 1, - ))) - } - DataType::Struct(fields) => ScalarValue::Struct( - new_null_array(&DataType::Struct(fields.to_owned()), 1) - .as_struct() - .to_owned() - .into(), - ), - DataType::Map(fields, sorted) => ScalarValue::Map( - new_null_array(&DataType::Map(fields.to_owned(), sorted.to_owned()), 1) - .as_map() - .to_owned() - .into(), - ), - DataType::Union(fields, mode) => { - ScalarValue::Union(None, fields.clone(), *mode) - } - DataType::Null => ScalarValue::Null, - _ => { - return _not_impl_err!( - "Can't create a scalar from data_type \"{data_type:?}\"" - ); - } - }) + Self::try_new_null(data_type) } } @@ -3964,9 +3978,9 @@ mod tests { use arrow::array::{types::Float64Type, NullBufferBuilder}; use arrow::buffer::{Buffer, OffsetBuffer}; use arrow::compute::{is_null, kernels}; + use arrow::datatypes::Fields; use arrow::error::ArrowError; use arrow::util::pretty::pretty_format_columns; - use arrow_schema::Fields; use chrono::NaiveDate; use rand::Rng; @@ -7269,4 +7283,88 @@ mod tests { let dictionary_array = dictionary_scalar.to_array().unwrap(); assert!(dictionary_array.is_null(0)); } + + #[test] + fn test_scalar_value_try_new_null() { + let scalars = vec![ + ScalarValue::try_new_null(&DataType::Boolean).unwrap(), + ScalarValue::try_new_null(&DataType::Int8).unwrap(), + ScalarValue::try_new_null(&DataType::Int16).unwrap(), + ScalarValue::try_new_null(&DataType::Int32).unwrap(), + ScalarValue::try_new_null(&DataType::Int64).unwrap(), + ScalarValue::try_new_null(&DataType::UInt8).unwrap(), + ScalarValue::try_new_null(&DataType::UInt16).unwrap(), + ScalarValue::try_new_null(&DataType::UInt32).unwrap(), + ScalarValue::try_new_null(&DataType::UInt64).unwrap(), + ScalarValue::try_new_null(&DataType::Float16).unwrap(), + ScalarValue::try_new_null(&DataType::Float32).unwrap(), + ScalarValue::try_new_null(&DataType::Float64).unwrap(), + ScalarValue::try_new_null(&DataType::Decimal128(42, 42)).unwrap(), + ScalarValue::try_new_null(&DataType::Decimal256(42, 42)).unwrap(), + ScalarValue::try_new_null(&DataType::Utf8).unwrap(), + ScalarValue::try_new_null(&DataType::LargeUtf8).unwrap(), + ScalarValue::try_new_null(&DataType::Utf8View).unwrap(), + ScalarValue::try_new_null(&DataType::Binary).unwrap(), + ScalarValue::try_new_null(&DataType::BinaryView).unwrap(), + ScalarValue::try_new_null(&DataType::FixedSizeBinary(42)).unwrap(), + ScalarValue::try_new_null(&DataType::LargeBinary).unwrap(), + ScalarValue::try_new_null(&DataType::Date32).unwrap(), + 
ScalarValue::try_new_null(&DataType::Date64).unwrap(), + ScalarValue::try_new_null(&DataType::Time32(TimeUnit::Second)).unwrap(), + ScalarValue::try_new_null(&DataType::Time32(TimeUnit::Millisecond)).unwrap(), + ScalarValue::try_new_null(&DataType::Time64(TimeUnit::Microsecond)).unwrap(), + ScalarValue::try_new_null(&DataType::Time64(TimeUnit::Nanosecond)).unwrap(), + ScalarValue::try_new_null(&DataType::Timestamp(TimeUnit::Second, None)) + .unwrap(), + ScalarValue::try_new_null(&DataType::Timestamp(TimeUnit::Millisecond, None)) + .unwrap(), + ScalarValue::try_new_null(&DataType::Timestamp(TimeUnit::Microsecond, None)) + .unwrap(), + ScalarValue::try_new_null(&DataType::Timestamp(TimeUnit::Nanosecond, None)) + .unwrap(), + ScalarValue::try_new_null(&DataType::Interval(IntervalUnit::YearMonth)) + .unwrap(), + ScalarValue::try_new_null(&DataType::Interval(IntervalUnit::DayTime)) + .unwrap(), + ScalarValue::try_new_null(&DataType::Interval(IntervalUnit::MonthDayNano)) + .unwrap(), + ScalarValue::try_new_null(&DataType::Duration(TimeUnit::Second)).unwrap(), + ScalarValue::try_new_null(&DataType::Duration(TimeUnit::Microsecond)) + .unwrap(), + ScalarValue::try_new_null(&DataType::Duration(TimeUnit::Nanosecond)).unwrap(), + ScalarValue::try_new_null(&DataType::Null).unwrap(), + ]; + assert!(scalars.iter().all(|s| s.is_null())); + + let field_ref = Arc::new(Field::new("foo", DataType::Int32, true)); + let map_field_ref = Arc::new(Field::new( + "foo", + DataType::Struct(Fields::from(vec![ + Field::new("bar", DataType::Utf8, true), + Field::new("baz", DataType::Int32, true), + ])), + true, + )); + let scalars = vec![ + ScalarValue::try_new_null(&DataType::List(Arc::clone(&field_ref))).unwrap(), + ScalarValue::try_new_null(&DataType::LargeList(Arc::clone(&field_ref))) + .unwrap(), + ScalarValue::try_new_null(&DataType::FixedSizeList( + Arc::clone(&field_ref), + 42, + )) + .unwrap(), + ScalarValue::try_new_null(&DataType::Struct( + vec![Arc::clone(&field_ref)].into(), + )) + .unwrap(), + ScalarValue::try_new_null(&DataType::Map(map_field_ref, false)).unwrap(), + ScalarValue::try_new_null(&DataType::Union( + UnionFields::new(vec![42], vec![field_ref]), + UnionMode::Dense, + )) + .unwrap(), + ]; + assert!(scalars.iter().all(|s| s.is_null())); + } } diff --git a/datafusion/common/src/scalar/struct_builder.rs b/datafusion/common/src/scalar/struct_builder.rs index 4a6a8f0289a7..5ed464018401 100644 --- a/datafusion/common/src/scalar/struct_builder.rs +++ b/datafusion/common/src/scalar/struct_builder.rs @@ -20,8 +20,7 @@ use crate::error::_internal_err; use crate::{Result, ScalarValue}; use arrow::array::{ArrayRef, StructArray}; -use arrow::datatypes::{DataType, FieldRef, Fields}; -use arrow_schema::Field; +use arrow::datatypes::{DataType, Field, FieldRef, Fields}; use std::sync::Arc; /// Builder for [`ScalarValue::Struct`]. diff --git a/datafusion/common/src/stats.rs b/datafusion/common/src/stats.rs index dd8848d24923..5b841db53c5e 100644 --- a/datafusion/common/src/stats.rs +++ b/datafusion/common/src/stats.rs @@ -21,7 +21,7 @@ use std::fmt::{self, Debug, Display}; use crate::{Result, ScalarValue}; -use arrow_schema::{DataType, Schema, SchemaRef}; +use arrow::datatypes::{DataType, Schema, SchemaRef}; /// Represents a value with a degree of certainty. `Precision` is used to /// propagate information the precision of statistical values. 
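The test above exercises the new `ScalarValue::try_new_null` constructor across primitive and nested types. As a minimal usage sketch (not part of this patch, and assuming only the `arrow` and `datafusion_common` APIs already shown in this diff), the constructor can be used to build a typed NULL literal for every field of a schema:

```rust
use arrow::datatypes::{DataType, Field, Schema};
use datafusion_common::{Result, ScalarValue};

/// Build a typed NULL `ScalarValue` for each field of a schema.
/// Unsupported types surface as a "not implemented" error from `try_new_null`.
fn null_row(schema: &Schema) -> Result<Vec<ScalarValue>> {
    schema
        .fields()
        .iter()
        .map(|field| ScalarValue::try_new_null(field.data_type()))
        .collect()
}

fn main() -> Result<()> {
    let schema = Schema::new(vec![
        Field::new("id", DataType::Int64, true),
        Field::new("name", DataType::Utf8, true),
    ]);
    let nulls = null_row(&schema)?;
    assert!(nulls.iter().all(|v| v.is_null()));
    assert_eq!(nulls[0].data_type(), DataType::Int64);
    Ok(())
}
```

Note that the existing `TryFrom<&DataType>` impl now delegates to `try_new_null`, so both entry points return the same values.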
diff --git a/datafusion/common/src/test_util.rs b/datafusion/common/src/test_util.rs index 22a7d87a8949..a1f883f20525 100644 --- a/datafusion/common/src/test_util.rs +++ b/datafusion/common/src/test_util.rs @@ -338,9 +338,9 @@ macro_rules! create_array { macro_rules! record_batch { ($(($name: expr, $type: ident, $values: expr)),*) => { { - let schema = std::sync::Arc::new(arrow_schema::Schema::new(vec![ + let schema = std::sync::Arc::new(arrow::datatypes::Schema::new(vec![ $( - arrow_schema::Field::new($name, arrow_schema::DataType::$type, true), + arrow::datatypes::Field::new($name, arrow::datatypes::DataType::$type, true), )* ])); diff --git a/datafusion/common/src/types/field.rs b/datafusion/common/src/types/field.rs index 85c7c157272a..5a880ba10a41 100644 --- a/datafusion/common/src/types/field.rs +++ b/datafusion/common/src/types/field.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -use arrow_schema::{Field, Fields, UnionFields}; +use arrow::datatypes::{Field, Fields, UnionFields}; use std::hash::{Hash, Hasher}; use std::{ops::Deref, sync::Arc}; diff --git a/datafusion/common/src/types/logical.rs b/datafusion/common/src/types/logical.rs index a65392cae344..884ce20fd9e2 100644 --- a/datafusion/common/src/types/logical.rs +++ b/datafusion/common/src/types/logical.rs @@ -17,7 +17,7 @@ use super::NativeType; use crate::error::Result; -use arrow_schema::DataType; +use arrow::datatypes::DataType; use core::fmt; use std::{cmp::Ordering, hash::Hash, sync::Arc}; diff --git a/datafusion/common/src/types/native.rs b/datafusion/common/src/types/native.rs index c5f180a15035..39c79b4b9974 100644 --- a/datafusion/common/src/types/native.rs +++ b/datafusion/common/src/types/native.rs @@ -21,7 +21,7 @@ use super::{ }; use crate::error::{Result, _internal_err}; use arrow::compute::can_cast_types; -use arrow_schema::{ +use arrow::datatypes::{ DataType, Field, FieldRef, Fields, IntervalUnit, TimeUnit, UnionFields, }; use std::{fmt::Display, sync::Arc}; @@ -126,7 +126,7 @@ pub enum NativeType { /// nevertheless correct). 
/// /// ``` - /// # use arrow_schema::{DataType, TimeUnit}; + /// # use arrow::datatypes::{DataType, TimeUnit}; /// DataType::Timestamp(TimeUnit::Second, None); /// DataType::Timestamp(TimeUnit::Second, Some("literal".into())); /// DataType::Timestamp(TimeUnit::Second, Some("string".to_string().into())); @@ -198,6 +198,11 @@ impl LogicalType for NativeType { TypeSignature::Native(self) } + /// Returns the default cast type for the given Arrow type + /// + /// For types like String or Date, multiple Arrow types map to the same logical type. + /// If the given Arrow type is one of them, we return it unchanged; + /// otherwise, we return the default cast type for that Arrow type. fn default_cast_for(&self, origin: &DataType) -> Result<DataType> { use DataType::*; @@ -226,6 +231,10 @@ impl LogicalType for NativeType { (Self::Decimal(p, s), _) if p <= &38 => Decimal128(*p, *s), (Self::Decimal(p, s), _) => Decimal256(*p, *s), (Self::Timestamp(tu, tz), _) => Timestamp(*tu, tz.clone()), + // If the given type is already a date type, return it unchanged + (Self::Date, origin) if matches!(origin, Date32 | Date64) => { + origin.to_owned() + } (Self::Date, _) => Date32, (Self::Time(tu), _) => match tu { TimeUnit::Second | TimeUnit::Millisecond => Time32(*tu), diff --git a/datafusion/common/src/utils/mod.rs b/datafusion/common/src/utils/mod.rs index f2377cc5410a..ff9cdedab8b1 100644 --- a/datafusion/common/src/utils/mod.rs +++ b/datafusion/common/src/utils/mod.rs @@ -22,7 +22,7 @@ pub mod memory; pub mod proxy; pub mod string_utils; -use crate::error::{_internal_datafusion_err, _internal_err}; +use crate::error::{_exec_datafusion_err, _internal_datafusion_err, _internal_err}; use crate::{DataFusionError, Result, ScalarValue}; use arrow::array::{ cast::AsArray, Array, ArrayRef, FixedSizeListArray, LargeListArray, ListArray, @@ -30,8 +30,7 @@ use arrow::array::{ }; use arrow::buffer::OffsetBuffer; use arrow::compute::{partition, SortColumn, SortOptions}; -use arrow::datatypes::{Field, SchemaRef}; -use arrow_schema::DataType; +use arrow::datatypes::{DataType, Field, SchemaRef}; use sqlparser::ast::Ident; use sqlparser::dialect::GenericDialect; use sqlparser::parser::Parser; @@ -591,6 +590,13 @@ pub fn base_type(data_type: &DataType) -> DataType { } } +/// Information about how to coerce lists. +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)] +pub enum ListCoercion { + /// [`DataType::FixedSizeList`] should be coerced to [`DataType::List`]. + FixedSizedListToList, +} + /// A helper function to coerce base type in List.
/// /// Example @@ -601,16 +607,22 @@ pub fn base_type(data_type: &DataType) -> DataType { /// /// let data_type = DataType::List(Arc::new(Field::new_list_field(DataType::Int32, true))); /// let base_type = DataType::Float64; -/// let coerced_type = coerced_type_with_base_type_only(&data_type, &base_type); +/// let coerced_type = coerced_type_with_base_type_only(&data_type, &base_type, None); /// assert_eq!(coerced_type, DataType::List(Arc::new(Field::new_list_field(DataType::Float64, true)))); pub fn coerced_type_with_base_type_only( data_type: &DataType, base_type: &DataType, + array_coercion: Option<&ListCoercion>, ) -> DataType { - match data_type { - DataType::List(field) | DataType::FixedSizeList(field, _) => { - let field_type = - coerced_type_with_base_type_only(field.data_type(), base_type); + match (data_type, array_coercion) { + (DataType::List(field), _) + | (DataType::FixedSizeList(field, _), Some(ListCoercion::FixedSizedListToList)) => + { + let field_type = coerced_type_with_base_type_only( + field.data_type(), + base_type, + array_coercion, + ); DataType::List(Arc::new(Field::new( field.name(), @@ -618,9 +630,24 @@ pub fn coerced_type_with_base_type_only( field.is_nullable(), ))) } - DataType::LargeList(field) => { - let field_type = - coerced_type_with_base_type_only(field.data_type(), base_type); + (DataType::FixedSizeList(field, len), _) => { + let field_type = coerced_type_with_base_type_only( + field.data_type(), + base_type, + array_coercion, + ); + + DataType::FixedSizeList( + Arc::new(Field::new(field.name(), field_type, field.is_nullable())), + *len, + ) + } + (DataType::LargeList(field), _) => { + let field_type = coerced_type_with_base_type_only( + field.data_type(), + base_type, + array_coercion, + ); DataType::LargeList(Arc::new(Field::new( field.name(), @@ -735,6 +762,27 @@ pub mod datafusion_strsim { pub fn levenshtein(a: &str, b: &str) -> usize { generic_levenshtein(&StringWrapper(a), &StringWrapper(b)) } + + /// Calculates the normalized Levenshtein distance between two strings. + /// The normalized distance is a value between 0.0 and 1.0, where 1.0 indicates + /// that the strings are identical and 0.0 indicates no similarity. + /// + /// ``` + /// use datafusion_common::utils::datafusion_strsim::normalized_levenshtein; + /// + /// assert!((normalized_levenshtein("kitten", "sitting") - 0.57142).abs() < 0.00001); + /// + /// assert!(normalized_levenshtein("", "second").abs() < 0.00001); + /// + /// assert!((normalized_levenshtein("kitten", "sitten") - 0.833).abs() < 0.001); + /// ``` + pub fn normalized_levenshtein(a: &str, b: &str) -> f64 { + if a.is_empty() && b.is_empty() { + return 1.0; + } + 1.0 - (levenshtein(a, b) as f64) + / (a.chars().count().max(b.chars().count()) as f64) + } } /// Merges collections `first` and `second`, removes duplicates and sorts the @@ -885,6 +933,45 @@ pub fn get_available_parallelism() -> usize { .get() } +/// Converts a collection of function arguments into an fixed-size array of length N +/// producing a reasonable error message in case of unexpected number of arguments. +/// +/// # Example +/// ``` +/// # use datafusion_common::Result; +/// # use datafusion_common::utils::take_function_args; +/// # use datafusion_common::ScalarValue; +/// fn my_function(args: &[ScalarValue]) -> Result<()> { +/// // function expects 2 args, so create a 2-element array +/// let [arg1, arg2] = take_function_args("my_function", args)?; +/// // ... do stuff.. 
+/// Ok(()) +/// } +/// +/// // Calling the function with 1 argument produces an error: +/// let args = vec![ScalarValue::Int32(Some(10))]; +/// let err = my_function(&args).unwrap_err(); +/// assert_eq!(err.to_string(), "Execution error: my_function function requires 2 arguments, got 1"); +/// // Calling the function with 2 arguments works great +/// let args = vec![ScalarValue::Int32(Some(10)), ScalarValue::Int32(Some(20))]; +/// my_function(&args).unwrap(); +/// ``` +pub fn take_function_args( + function_name: &str, + args: impl IntoIterator, +) -> Result<[T; N]> { + let args = args.into_iter().collect::>(); + args.try_into().map_err(|v: Vec| { + _exec_datafusion_err!( + "{} function requires {} {}, got {}", + function_name, + N, + if N == 1 { "argument" } else { "arguments" }, + v.len() + ) + }) +} + #[cfg(test)] mod tests { use super::*; diff --git a/datafusion/core/Cargo.toml b/datafusion/core/Cargo.toml index bbd999ffe98b..8a706ca19f4d 100644 --- a/datafusion/core/Cargo.toml +++ b/datafusion/core/Cargo.toml @@ -43,7 +43,7 @@ array_expressions = ["nested_expressions"] # Used to enable the avro format avro = ["apache-avro", "num-traits", "datafusion-common/avro"] backtrace = ["datafusion-common/backtrace"] -compression = ["xz2", "bzip2", "flate2", "zstd", "async-compression", "tokio-util"] +compression = ["xz2", "bzip2", "flate2", "zstd", "datafusion-datasource/compression"] crypto_expressions = ["datafusion-functions/crypto_expressions"] datetime_expressions = ["datafusion-functions/datetime_expressions"] default = [ @@ -74,7 +74,7 @@ recursive_protection = [ "datafusion-physical-optimizer/recursive_protection", "datafusion-sql/recursive_protection", ] -serde = ["arrow-schema/serde"] +serde = ["dep:serde"] string_expressions = ["datafusion-functions/string_expressions"] unicode_expressions = [ "datafusion-sql/unicode_expressions", @@ -87,28 +87,24 @@ apache-avro = { version = "0.17", optional = true } arrow = { workspace = true } arrow-ipc = { workspace = true } arrow-schema = { workspace = true } -async-compression = { version = "0.4.0", features = [ - "bzip2", - "gzip", - "xz", - "zstd", - "tokio", -], optional = true } async-trait = { workspace = true } bytes = { workspace = true } -bzip2 = { version = "0.5.0", optional = true } +bzip2 = { version = "0.5.1", optional = true } chrono = { workspace = true } datafusion-catalog = { workspace = true } datafusion-catalog-listing = { workspace = true } datafusion-common = { workspace = true, features = ["object_store"] } datafusion-common-runtime = { workspace = true } +datafusion-datasource = { workspace = true } datafusion-execution = { workspace = true } datafusion-expr = { workspace = true } +datafusion-expr-common = { workspace = true } datafusion-functions = { workspace = true } datafusion-functions-aggregate = { workspace = true } datafusion-functions-nested = { workspace = true, optional = true } datafusion-functions-table = { workspace = true } datafusion-functions-window = { workspace = true } +datafusion-macros = { workspace = true } datafusion-optimizer = { workspace = true } datafusion-physical-expr = { workspace = true } datafusion-physical-expr-common = { workspace = true } @@ -117,7 +113,6 @@ datafusion-physical-plan = { workspace = true } datafusion-sql = { workspace = true } flate2 = { version = "1.0.24", optional = true } futures = { workspace = true } -glob = "0.3.0" itertools = { workspace = true } log = { workspace = true } num-traits = { version = "0.2", optional = true } @@ -126,12 +121,12 @@ parking_lot = { 
workspace = true } parquet = { workspace = true, optional = true, default-features = true } rand = { workspace = true } regex = { workspace = true } +serde = { version = "1.0", default-features = false, features = ["derive"], optional = true } sqlparser = { workspace = true } tempfile = { workspace = true } tokio = { workspace = true } -tokio-util = { version = "0.7.4", features = ["io"], optional = true } url = { workspace = true } -uuid = { version = "1.7", features = ["v4", "js"] } +uuid = { version = "1.13", features = ["v4", "js"] } xz2 = { version = "0.1", optional = true, features = ["static"] } zstd = { version = "0.13", optional = true, default-features = false } @@ -139,6 +134,8 @@ zstd = { version = "0.13", optional = true, default-features = false } async-trait = { workspace = true } criterion = { version = "0.5", features = ["async_tokio"] } ctor = { workspace = true } +dashmap = "6.1.0" +datafusion-doc = { workspace = true } datafusion-functions-window-common = { workspace = true } datafusion-physical-optimizer = { workspace = true } doc-comment = { workspace = true } @@ -221,3 +218,7 @@ name = "topk_aggregate" harness = false name = "map_query_sql" required-features = ["nested_expressions"] + +[[bench]] +harness = false +name = "dataframe" diff --git a/datafusion/core/benches/dataframe.rs b/datafusion/core/benches/dataframe.rs new file mode 100644 index 000000000000..087764883a33 --- /dev/null +++ b/datafusion/core/benches/dataframe.rs @@ -0,0 +1,86 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +extern crate arrow; +#[macro_use] +extern crate criterion; +extern crate datafusion; + +use arrow_schema::{DataType, Field, Schema}; +use criterion::Criterion; +use datafusion::datasource::MemTable; +use datafusion::prelude::SessionContext; +use datafusion_expr::col; +use datafusion_functions::expr_fn::btrim; +use std::sync::Arc; +use tokio::runtime::Runtime; + +fn create_context(field_count: u32) -> datafusion_common::Result<Arc<SessionContext>> { + let mut fields = vec![]; + for i in 0..field_count { + fields.push(Field::new(format!("str{}", i), DataType::Utf8, true)) + } + + let schema = Arc::new(Schema::new(fields)); + let ctx = SessionContext::new(); + let table = MemTable::try_new(Arc::clone(&schema), vec![vec![]])?; + + ctx.register_table("t", Arc::new(table))?; + + Ok(Arc::new(ctx)) +} + +fn run(column_count: u32, ctx: Arc<SessionContext>) { + let rt = Runtime::new().unwrap(); + + criterion::black_box(rt.block_on(async { + let mut data_frame = ctx.table("t").await.unwrap(); + + for i in 0..column_count { + let field_name = &format!("str{}", i); + let new_field_name = &format!("newstr{}", i); + + data_frame = data_frame + .with_column_renamed(field_name, new_field_name) + .unwrap(); + data_frame = data_frame + .with_column(new_field_name, btrim(vec![col(new_field_name)])) + .unwrap(); + } + + Some(true) + })) + .unwrap(); +} + +fn criterion_benchmark(c: &mut Criterion) { + // 500 takes far too long right now + for column_count in [10, 100, 200 /* 500 */] { + let ctx = create_context(column_count).unwrap(); + + c.bench_function(&format!("with_column_{column_count}"), |b| { + b.iter(|| run(column_count, ctx.clone())) + }); + } +} + +criterion_group! { + name = benches; + config = Criterion::default().sample_size(10); + targets = criterion_benchmark +} +criterion_main!(benches); diff --git a/datafusion/core/benches/sql_query_with_io.rs b/datafusion/core/benches/sql_query_with_io.rs index 58f8409313aa..58d71ee5b2eb 100644 --- a/datafusion/core/benches/sql_query_with_io.rs +++ b/datafusion/core/benches/sql_query_with_io.rs @@ -18,7 +18,7 @@ use std::{fmt::Write, sync::Arc, time::Duration}; use arrow::array::{Int64Builder, RecordBatch, UInt64Builder}; -use arrow_schema::{DataType, Field, Schema, SchemaRef}; +use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use bytes::Bytes; use criterion::{criterion_group, criterion_main, Criterion, SamplingMode}; use datafusion::{ diff --git a/datafusion/core/src/bin/print_functions_docs.rs b/datafusion/core/src/bin/print_functions_docs.rs index 8b453d5e9698..7afb90282a80 100644 --- a/datafusion/core/src/bin/print_functions_docs.rs +++ b/datafusion/core/src/bin/print_functions_docs.rs @@ -193,7 +193,7 @@ fn print_docs( {} -``` +```sql {} ``` "#, diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs index 9731b8784076..6f540fa02c75 100644 --- a/datafusion/core/src/dataframe/mod.rs +++ b/datafusion/core/src/dataframe/mod.rs @@ -25,7 +25,9 @@ use crate::arrow::util::pretty; use crate::datasource::file_format::csv::CsvFormatFactory; use crate::datasource::file_format::format_as_file_type; use crate::datasource::file_format::json::JsonFormatFactory; -use crate::datasource::{provider_as_source, MemTable, TableProvider}; +use crate::datasource::{ + provider_as_source, DefaultTableSource, MemTable, TableProvider, +}; use crate::error::Result; use crate::execution::context::{SessionState, TaskContext}; use crate::execution::FunctionRegistry; @@ -45,8 +47,7 @@ use std::sync::Arc; use arrow::array::{Array, ArrayRef, Int64Array, StringArray}; use
arrow::compute::{cast, concat}; -use arrow::datatypes::{DataType, Field}; -use arrow_schema::{Schema, SchemaRef}; +use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use datafusion_common::config::{CsvOptions, JsonOptions}; use datafusion_common::{ exec_err, not_impl_err, plan_err, Column, DFSchema, DataFusionError, ParamValues, @@ -63,6 +64,7 @@ use datafusion_functions_aggregate::expr_fn::{ use async_trait::async_trait; use datafusion_catalog::Session; +use datafusion_sql::TableReference; /// Contains options that control how data is /// written out from a DataFrame @@ -1527,8 +1529,6 @@ impl DataFrame { table_name: &str, write_options: DataFrameWriteOptions, ) -> Result, DataFusionError> { - let arrow_schema = Schema::from(self.schema()); - let plan = if write_options.sort_by.is_empty() { self.plan } else { @@ -1537,10 +1537,19 @@ impl DataFrame { .build()? }; + let table_ref: TableReference = table_name.into(); + let table_schema = self.session_state.schema_for_ref(table_ref.clone())?; + let target = match table_schema.table(table_ref.table()).await? { + Some(ref provider) => Ok(Arc::clone(provider)), + _ => plan_err!("No table named '{table_name}'"), + }?; + + let target = Arc::new(DefaultTableSource::new(target)); + let plan = LogicalPlanBuilder::insert_into( plan, - table_name.to_owned(), - &arrow_schema, + table_ref, + target, write_options.insert_op, )? .build()?; @@ -1802,7 +1811,8 @@ impl DataFrame { .iter() .map(|(qualifier, field)| { if qualifier.eq(&qualifier_rename) && field.as_ref() == field_rename { - col(Column::from((qualifier, field))).alias(new_name) + col(Column::from((qualifier, field))) + .alias_qualified(qualifier.cloned(), new_name) } else { col(Column::from((qualifier, field))) } diff --git a/datafusion/core/src/datasource/data_source.rs b/datafusion/core/src/datasource/data_source.rs index 03bfb4175022..d31b68019e30 100644 --- a/datafusion/core/src/datasource/data_source.rs +++ b/datafusion/core/src/datasource/data_source.rs @@ -24,7 +24,7 @@ use std::sync::Arc; use crate::datasource::physical_plan::{FileOpener, FileScanConfig}; -use arrow_schema::SchemaRef; +use arrow::datatypes::SchemaRef; use datafusion_common::Statistics; use datafusion_physical_plan::metrics::ExecutionPlanMetricsSet; use datafusion_physical_plan::DisplayFormatType; diff --git a/datafusion/core/src/datasource/default_table_source.rs b/datafusion/core/src/datasource/default_table_source.rs index 91c1e0ac97fc..541e0b6dfa91 100644 --- a/datafusion/core/src/datasource/default_table_source.rs +++ b/datafusion/core/src/datasource/default_table_source.rs @@ -26,12 +26,15 @@ use arrow::datatypes::SchemaRef; use datafusion_common::{internal_err, Constraints}; use datafusion_expr::{Expr, TableProviderFilterPushDown, TableSource, TableType}; -/// DataFusion default table source, wrapping TableProvider. +/// Implements [`TableSource`] for a [`TableProvider`] /// -/// This structure adapts a `TableProvider` (physical plan trait) to the `TableSource` -/// (logical plan trait) and is necessary because the logical plan is contained in -/// the `datafusion_expr` crate, and is not aware of table providers, which exist in -/// the core `datafusion` crate. +/// This structure adapts a [`TableProvider`] (a physical plan trait) to the +/// [`TableSource`] (logical plan trait). +/// +/// It is used so logical plans in the `datafusion_expr` crate do not have a +/// direct dependency on physical plans, such as [`TableProvider`]s. 
+/// +/// [`TableProvider`]: https://docs.rs/datafusion/latest/datafusion/datasource/provider/trait.TableProvider.html pub struct DefaultTableSource { /// table provider pub table_provider: Arc, diff --git a/datafusion/core/src/datasource/file_format/arrow.rs b/datafusion/core/src/datasource/file_format/arrow.rs index 339199310ca6..efe4cce6d8a8 100644 --- a/datafusion/core/src/datasource/file_format/arrow.rs +++ b/datafusion/core/src/datasource/file_format/arrow.rs @@ -37,11 +37,12 @@ use crate::datasource::physical_plan::{ use crate::error::Result; use crate::physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan}; +use arrow::datatypes::{Schema, SchemaRef}; +use arrow::error::ArrowError; use arrow::ipc::convert::fb_to_schema; use arrow::ipc::reader::FileReader; use arrow::ipc::writer::IpcWriteOptions; use arrow::ipc::{root_as_message, CompressionType}; -use arrow_schema::{ArrowError, Schema, SchemaRef}; use datafusion_catalog::Session; use datafusion_common::parsers::CompressionTypeVariant; use datafusion_common::{ @@ -170,11 +171,10 @@ impl FileFormat for ArrowFormat { async fn create_physical_plan( &self, _state: &dyn Session, - mut conf: FileScanConfig, + conf: FileScanConfig, _filters: Option<&Arc>, ) -> Result> { - conf = conf.with_source(Arc::new(ArrowSource::default())); - Ok(conf.new_exec()) + Ok(conf.with_source(Arc::new(ArrowSource::default())).build()) } async fn create_writer_physical_plan( diff --git a/datafusion/core/src/datasource/file_format/avro.rs b/datafusion/core/src/datasource/file_format/avro.rs index 100aa4fd51e2..c0c8f25722c2 100644 --- a/datafusion/core/src/datasource/file_format/avro.rs +++ b/datafusion/core/src/datasource/file_format/avro.rs @@ -148,11 +148,10 @@ impl FileFormat for AvroFormat { async fn create_physical_plan( &self, _state: &dyn Session, - mut conf: FileScanConfig, + conf: FileScanConfig, _filters: Option<&Arc>, ) -> Result> { - conf = conf.with_source(self.file_source()); - Ok(conf.new_exec()) + Ok(conf.with_source(self.file_source()).build()) } fn file_source(&self) -> Arc { diff --git a/datafusion/core/src/datasource/file_format/csv.rs b/datafusion/core/src/datasource/file_format/csv.rs index 2e30d2a3b196..565a18515919 100644 --- a/datafusion/core/src/datasource/file_format/csv.rs +++ b/datafusion/core/src/datasource/file_format/csv.rs @@ -44,7 +44,7 @@ use crate::physical_plan::{ use arrow::array::RecordBatch; use arrow::csv::WriterBuilder; use arrow::datatypes::{DataType, Field, Fields, Schema, SchemaRef}; -use arrow_schema::ArrowError; +use arrow::error::ArrowError; use datafusion_catalog::Session; use datafusion_common::config::{ConfigField, ConfigFileType, CsvOptions}; use datafusion_common::file_options::csv_writer::CsvWriterOptions; @@ -434,9 +434,7 @@ impl FileFormat for CsvFormat { .with_terminator(self.options.terminator) .with_comment(self.options.comment), ); - conf = conf.with_source(source); - - Ok(conf.new_exec()) + Ok(conf.with_source(source).build()) } async fn create_writer_physical_plan( diff --git a/datafusion/core/src/datasource/file_format/json.rs b/datafusion/core/src/datasource/file_format/json.rs index 6648e48159ea..04f380441d66 100644 --- a/datafusion/core/src/datasource/file_format/json.rs +++ b/datafusion/core/src/datasource/file_format/json.rs @@ -43,11 +43,10 @@ use crate::physical_plan::{ }; use arrow::array::RecordBatch; -use arrow::datatypes::Schema; -use arrow::datatypes::SchemaRef; +use arrow::datatypes::{Schema, SchemaRef}; +use arrow::error::ArrowError; use arrow::json; use 
arrow::json::reader::{infer_json_schema_from_iterator, ValueIter}; -use arrow_schema::ArrowError; use datafusion_catalog::Session; use datafusion_common::config::{ConfigField, ConfigFileType, JsonOptions}; use datafusion_common::file_options::json_writer::JsonWriterOptions; @@ -255,9 +254,7 @@ impl FileFormat for JsonFormat { ) -> Result> { let source = Arc::new(JsonSource::new()); conf.file_compression_type = FileCompressionType::from(self.options.compression); - conf = conf.with_source(source); - - Ok(conf.new_exec()) + Ok(conf.with_source(source).build()) } async fn create_writer_physical_plan( @@ -438,9 +435,9 @@ mod tests { use crate::test::object_store::local_unpartitioned_file; use arrow::compute::concat_batches; + use arrow::datatypes::{DataType, Field}; use arrow::json::ReaderBuilder; use arrow::util::pretty; - use arrow_schema::{DataType, Field}; use datafusion_common::cast::as_int64_array; use datafusion_common::stats::Precision; use datafusion_common::{assert_batches_eq, internal_err}; diff --git a/datafusion/core/src/datasource/file_format/mod.rs b/datafusion/core/src/datasource/file_format/mod.rs index 0caf363f106d..46fee53cfb4e 100644 --- a/datafusion/core/src/datasource/file_format/mod.rs +++ b/datafusion/core/src/datasource/file_format/mod.rs @@ -24,12 +24,12 @@ pub const DEFAULT_SCHEMA_INFER_MAX_RECORD: usize = 1000; pub mod arrow; pub mod avro; pub mod csv; -pub mod file_compression_type; pub mod json; pub mod options; #[cfg(feature = "parquet")] pub mod parquet; -pub mod write; +pub use datafusion_datasource::file_compression_type; +pub use datafusion_datasource::write; use std::any::Any; use std::collections::{HashMap, VecDeque}; @@ -38,12 +38,12 @@ use std::sync::Arc; use std::task::Poll; use crate::arrow::array::RecordBatch; -use crate::arrow::datatypes::SchemaRef; +use crate::arrow::datatypes::{DataType, Field, FieldRef, Schema, SchemaRef}; +use crate::arrow::error::ArrowError; use crate::datasource::physical_plan::{FileScanConfig, FileSinkConfig}; use crate::error::Result; use crate::physical_plan::{ExecutionPlan, Statistics}; -use arrow_schema::{ArrowError, DataType, Field, FieldRef, Schema}; use datafusion_catalog::Session; use datafusion_common::file_options::file_type::FileType; use datafusion_common::{internal_err, not_impl_err, GetExt}; diff --git a/datafusion/core/src/datasource/file_format/parquet.rs b/datafusion/core/src/datasource/file_format/parquet.rs index 13a57278c981..902174b6d664 100644 --- a/datafusion/core/src/datasource/file_format/parquet.rs +++ b/datafusion/core/src/datasource/file_format/parquet.rs @@ -398,7 +398,7 @@ impl FileFormat for ParquetFormat { async fn create_physical_plan( &self, _state: &dyn Session, - mut conf: FileScanConfig, + conf: FileScanConfig, filters: Option<&Arc>, ) -> Result> { let mut predicate = None; @@ -424,8 +424,7 @@ impl FileFormat for ParquetFormat { if let Some(metadata_size_hint) = metadata_size_hint { source = source.with_metadata_size_hint(metadata_size_hint) } - conf = conf.with_source(Arc::new(source)); - Ok(conf.new_exec()) + Ok(conf.with_source(Arc::new(source)).build()) } async fn create_writer_physical_plan( @@ -1313,7 +1312,7 @@ mod tests { types::Int32Type, Array, ArrayRef, DictionaryArray, Int32Array, Int64Array, StringArray, }; - use arrow_schema::{DataType, Field}; + use arrow::datatypes::{DataType, Field}; use async_trait::async_trait; use datafusion_common::cast::{ as_binary_array, as_binary_view_array, as_boolean_array, as_float32_array, diff --git a/datafusion/core/src/datasource/listing/mod.rs 
b/datafusion/core/src/datasource/listing/mod.rs index 39323b993d45..a58db55bccb6 100644 --- a/datafusion/core/src/datasource/listing/mod.rs +++ b/datafusion/core/src/datasource/listing/mod.rs @@ -19,5 +19,8 @@ //! to get the list of files to process. mod table; -pub use datafusion_catalog_listing::*; +pub use datafusion_catalog_listing::helpers; +pub use datafusion_datasource::{ + FileRange, ListingTableUrl, PartitionedFile, PartitionedFileStream, +}; pub use table::{ListingOptions, ListingTable, ListingTableConfig}; diff --git a/datafusion/core/src/datasource/listing/table.rs b/datafusion/core/src/datasource/listing/table.rs index 4d898e6a24a4..e38bb6bccabc 100644 --- a/datafusion/core/src/datasource/listing/table.rs +++ b/datafusion/core/src/datasource/listing/table.rs @@ -998,27 +998,8 @@ impl TableProvider for ListingTable { insert_op: InsertOp, ) -> Result> { // Check that the schema of the plan matches the schema of this table. - if !self - .schema() - .logically_equivalent_names_and_types(&input.schema()) - { - // Return an error if schema of the input query does not match with the table schema. - return plan_err!( - "Inserting query must have the same schema with the table. \ - Expected: {:?}, got: {:?}", - self.schema() - .fields() - .iter() - .map(|field| field.data_type()) - .collect::>(), - input - .schema() - .fields() - .iter() - .map(|field| field.data_type()) - .collect::>() - ); - } + self.schema() + .logically_equivalent_names_and_types(&input.schema())?; let table_path = &self.table_paths()[0]; if !table_path.is_collection() { @@ -1197,7 +1178,7 @@ mod tests { use crate::datasource::file_format::json::JsonFormat; #[cfg(feature = "parquet")] use crate::datasource::file_format::parquet::ParquetFormat; - use crate::datasource::{provider_as_source, MemTable}; + use crate::datasource::{provider_as_source, DefaultTableSource, MemTable}; use crate::execution::options::ArrowReadOptions; use crate::prelude::*; use crate::{ @@ -1206,8 +1187,8 @@ mod tests { }; use datafusion_physical_plan::collect; + use arrow::compute::SortOptions; use arrow::record_batch::RecordBatch; - use arrow_schema::SortOptions; use datafusion_common::stats::Precision; use datafusion_common::{assert_contains, ScalarValue}; use datafusion_expr::{BinaryExpr, LogicalPlanBuilder, Operator}; @@ -2067,6 +2048,8 @@ mod tests { session_ctx.register_table("source", source_table.clone())?; // Convert the source table into a provider so that it can be used in a query let source = provider_as_source(source_table); + let target = session_ctx.table_provider("t").await?; + let target = Arc::new(DefaultTableSource::new(target)); // Create a table scan logical plan to read from the source table let scan_plan = LogicalPlanBuilder::scan("source", source, None)? .filter(filter_predicate)? @@ -2075,7 +2058,7 @@ mod tests { // Therefore, we will have 8 partitions in the final plan. // Create an insert plan to insert the source data into the initial table let insert_into_table = - LogicalPlanBuilder::insert_into(scan_plan, "t", &schema, InsertOp::Append)? + LogicalPlanBuilder::insert_into(scan_plan, "t", target, InsertOp::Append)? 
.build()?; // Create a physical plan from the insert plan let plan = session_ctx diff --git a/datafusion/core/src/datasource/memory.rs b/datafusion/core/src/datasource/memory.rs index a996990105b3..94c6e45804e8 100644 --- a/datafusion/core/src/datasource/memory.rs +++ b/datafusion/core/src/datasource/memory.rs @@ -278,26 +278,9 @@ impl TableProvider for MemTable { // Create a physical plan from the logical plan. // Check that the schema of the plan matches the schema of this table. - if !self - .schema() - .logically_equivalent_names_and_types(&input.schema()) - { - return plan_err!( - "Inserting query must have the same schema with the table. \ - Expected: {:?}, got: {:?}", - self.schema() - .fields() - .iter() - .map(|field| field.data_type()) - .collect::>(), - input - .schema() - .fields() - .iter() - .map(|field| field.data_type()) - .collect::>() - ); - } + self.schema() + .logically_equivalent_names_and_types(&input.schema())?; + if insert_op != InsertOp::Append { return not_impl_err!("{insert_op} not implemented for MemoryTable yet"); } @@ -390,7 +373,7 @@ impl DataSink for MemSink { mod tests { use super::*; - use crate::datasource::provider_as_source; + use crate::datasource::{provider_as_source, DefaultTableSource}; use crate::physical_plan::collect; use crate::prelude::SessionContext; @@ -640,6 +623,7 @@ mod tests { // Create and register the initial table with the provided schema and data let initial_table = Arc::new(MemTable::try_new(schema.clone(), initial_data)?); session_ctx.register_table("t", initial_table.clone())?; + let target = Arc::new(DefaultTableSource::new(initial_table.clone())); // Create and register the source table with the provided schema and inserted data let source_table = Arc::new(MemTable::try_new(schema.clone(), inserted_data)?); session_ctx.register_table("source", source_table.clone())?; @@ -649,7 +633,7 @@ mod tests { let scan_plan = LogicalPlanBuilder::scan("source", source, None)?.build()?; // Create an insert plan to insert the source data into the initial table let insert_into_table = - LogicalPlanBuilder::insert_into(scan_plan, "t", &schema, InsertOp::Append)? + LogicalPlanBuilder::insert_into(scan_plan, "t", target, InsertOp::Append)? 
.build()?; // Create a physical plan from the insert plan let plan = session_ctx diff --git a/datafusion/core/src/datasource/mod.rs b/datafusion/core/src/datasource/mod.rs index 55df55ae3543..12dd9d7cab38 100644 --- a/datafusion/core/src/datasource/mod.rs +++ b/datafusion/core/src/datasource/mod.rs @@ -47,7 +47,8 @@ pub use crate::logical_expr::TableType; pub use datafusion_execution::object_store; pub use statistics::get_statistics_with_limit; -use arrow_schema::{Schema, SortOptions}; +use arrow::compute::SortOptions; +use arrow::datatypes::Schema; use datafusion_common::{plan_err, Result}; use datafusion_expr::{Expr, SortExpr}; use datafusion_physical_expr::{expressions, LexOrdering, PhysicalSortExpr}; diff --git a/datafusion/core/src/datasource/physical_plan/arrow_file.rs b/datafusion/core/src/datasource/physical_plan/arrow_file.rs index 1a486a54ca39..4a7cdc192cd3 100644 --- a/datafusion/core/src/datasource/physical_plan/arrow_file.rs +++ b/datafusion/core/src/datasource/physical_plan/arrow_file.rs @@ -28,8 +28,8 @@ use crate::datasource::physical_plan::{ use crate::error::Result; use arrow::buffer::Buffer; +use arrow::datatypes::SchemaRef; use arrow_ipc::reader::FileDecoder; -use arrow_schema::SchemaRef; use datafusion_common::config::ConfigOptions; use datafusion_common::{Constraints, Statistics}; use datafusion_execution::{SendableRecordBatchStream, TaskContext}; @@ -321,7 +321,7 @@ impl FileOpener for ArrowOpener { footer_buf[..footer_len].try_into().unwrap(), ) .map_err(|err| { - arrow_schema::ArrowError::ParseError(format!( + arrow::error::ArrowError::ParseError(format!( "Unable to get root as footer: {err:?}" )) })?; diff --git a/datafusion/core/src/datasource/physical_plan/avro.rs b/datafusion/core/src/datasource/physical_plan/avro.rs index b148c412c48e..b0a1d8c8c9e2 100644 --- a/datafusion/core/src/datasource/physical_plan/avro.rs +++ b/datafusion/core/src/datasource/physical_plan/avro.rs @@ -265,8 +265,8 @@ impl FileSource for AvroSource { #[cfg(feature = "avro")] mod private { use super::*; - use crate::datasource::physical_plan::file_stream::{FileOpenFuture, FileOpener}; use crate::datasource::physical_plan::FileMeta; + use crate::datasource::physical_plan::{FileOpenFuture, FileOpener}; use bytes::Buf; use futures::StreamExt; @@ -399,7 +399,7 @@ mod tests { .with_file(meta.into()) .with_projection(Some(vec![0, 1, 2])); - let source_exec = conf.new_exec(); + let source_exec = conf.build(); assert_eq!( source_exec .properties() @@ -472,7 +472,7 @@ mod tests { .with_file(meta.into()) .with_projection(projection); - let source_exec = conf.new_exec(); + let source_exec = conf.build(); assert_eq!( source_exec .properties() @@ -546,7 +546,7 @@ mod tests { .with_file(partitioned_file) .with_table_partition_cols(vec![Field::new("date", DataType::Utf8, false)]); - let source_exec = conf.new_exec(); + let source_exec = conf.build(); assert_eq!( source_exec diff --git a/datafusion/core/src/datasource/physical_plan/csv.rs b/datafusion/core/src/datasource/physical_plan/csv.rs index bfc2c1df8eab..c0952229b5e0 100644 --- a/datafusion/core/src/datasource/physical_plan/csv.rs +++ b/datafusion/core/src/datasource/physical_plan/csv.rs @@ -28,8 +28,8 @@ use crate::datasource::data_source::FileSource; use crate::datasource::file_format::file_compression_type::FileCompressionType; use crate::datasource::file_format::{deserialize_stream, DecoderDeserializer}; use crate::datasource::listing::{FileRange, ListingTableUrl, PartitionedFile}; -use 
crate::datasource::physical_plan::file_stream::{FileOpenFuture, FileOpener}; use crate::datasource::physical_plan::FileMeta; +use crate::datasource::physical_plan::{FileOpenFuture, FileOpener}; use crate::error::{DataFusionError, Result}; use crate::physical_plan::{ExecutionPlan, ExecutionPlanProperties}; @@ -425,7 +425,7 @@ impl ExecutionPlan for CsvExec { /// let file_scan_config = FileScanConfig::new(object_store_url, file_schema, source) /// .with_file(PartitionedFile::new("file1.csv", 100*1024*1024)) /// .with_newlines_in_values(true); // The file contains newlines in values; -/// let exec = file_scan_config.new_exec(); +/// let exec = file_scan_config.build(); /// ``` #[derive(Debug, Clone, Default)] pub struct CsvSource { @@ -836,14 +836,14 @@ mod tests { )?; let source = Arc::new(CsvSource::new(true, b',', b'"')); - let mut config = partitioned_csv_config(file_schema, file_groups, source) + let config = partitioned_csv_config(file_schema, file_groups, source) .with_file_compression_type(file_compression_type) - .with_newlines_in_values(false); - config.projection = Some(vec![0, 2, 4]); - - let csv = config.new_exec(); + .with_newlines_in_values(false) + .with_projection(Some(vec![0, 2, 4])); assert_eq!(13, config.file_schema.fields().len()); + let csv = config.build(); + assert_eq!(3, csv.schema().fields().len()); let mut stream = csv.execute(0, task_ctx)?; @@ -901,12 +901,12 @@ mod tests { )?; let source = Arc::new(CsvSource::new(true, b',', b'"')); - let mut config = partitioned_csv_config(file_schema, file_groups, source) + let config = partitioned_csv_config(file_schema, file_groups, source) .with_newlines_in_values(false) - .with_file_compression_type(file_compression_type.to_owned()); - config.projection = Some(vec![4, 0, 2]); - let csv = config.new_exec(); + .with_file_compression_type(file_compression_type.to_owned()) + .with_projection(Some(vec![4, 0, 2])); assert_eq!(13, config.file_schema.fields().len()); + let csv = config.build(); assert_eq!(3, csv.schema().fields().len()); let mut stream = csv.execute(0, task_ctx)?; @@ -964,12 +964,12 @@ mod tests { )?; let source = Arc::new(CsvSource::new(true, b',', b'"')); - let mut config = partitioned_csv_config(file_schema, file_groups, source) + let config = partitioned_csv_config(file_schema, file_groups, source) .with_newlines_in_values(false) - .with_file_compression_type(file_compression_type.to_owned()); - config.limit = Some(5); - let csv = config.new_exec(); + .with_file_compression_type(file_compression_type.to_owned()) + .with_limit(Some(5)); assert_eq!(13, config.file_schema.fields().len()); + let csv = config.build(); assert_eq!(13, csv.schema().fields().len()); let mut it = csv.execute(0, task_ctx)?; @@ -1024,12 +1024,12 @@ mod tests { )?; let source = Arc::new(CsvSource::new(true, b',', b'"')); - let mut config = partitioned_csv_config(file_schema, file_groups, source) + let config = partitioned_csv_config(file_schema, file_groups, source) .with_newlines_in_values(false) - .with_file_compression_type(file_compression_type.to_owned()); - config.limit = Some(5); - let csv = config.new_exec(); + .with_file_compression_type(file_compression_type.to_owned()) + .with_limit(Some(5)); assert_eq!(14, config.file_schema.fields().len()); + let csv = config.build(); assert_eq!(14, csv.schema().fields().len()); // errors due to https://github.com/apache/datafusion/issues/4918 @@ -1089,8 +1089,8 @@ mod tests { // we don't have `/date=xx/` in the path but that is ok because // partitions are resolved during scan anyway - let csv 
= config.new_exec(); assert_eq!(13, config.file_schema.fields().len()); + let csv = config.build(); assert_eq!(2, csv.schema().fields().len()); let mut it = csv.execute(0, task_ctx)?; @@ -1179,7 +1179,7 @@ mod tests { let config = partitioned_csv_config(file_schema, file_groups, source) .with_newlines_in_values(false) .with_file_compression_type(file_compression_type.to_owned()); - let csv = config.new_exec(); + let csv = config.build(); let it = csv.execute(0, task_ctx).unwrap(); let batches: Vec<_> = it.try_collect().await.unwrap(); diff --git a/datafusion/core/src/datasource/physical_plan/file_scan_config.rs b/datafusion/core/src/datasource/physical_plan/file_scan_config.rs index dc9207da51cb..123ecc2f9582 100644 --- a/datafusion/core/src/datasource/physical_plan/file_scan_config.rs +++ b/datafusion/core/src/datasource/physical_plan/file_scan_config.rs @@ -27,24 +27,15 @@ use crate::datasource::{listing::PartitionedFile, object_store::ObjectStoreUrl}; use crate::{error::Result, scalar::ScalarValue}; use std::any::Any; use std::fmt::Formatter; -use std::{ - borrow::Cow, collections::HashMap, fmt, fmt::Debug, marker::PhantomData, - mem::size_of, sync::Arc, vec, -}; +use std::{fmt, sync::Arc}; -use arrow::array::{ - ArrayData, ArrayRef, BufferBuilder, DictionaryArray, RecordBatch, RecordBatchOptions, -}; -use arrow::buffer::Buffer; -use arrow::datatypes::{ArrowNativeType, UInt16Type}; -use arrow_schema::{DataType, Field, Schema, SchemaRef}; +use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use datafusion_common::stats::Precision; -use datafusion_common::{ - exec_err, ColumnStatistics, Constraints, DataFusionError, Statistics, -}; +use datafusion_common::{ColumnStatistics, Constraints, Statistics}; use datafusion_physical_expr::{EquivalenceProperties, LexOrdering, Partitioning}; use crate::datasource::data_source::FileSource; +pub use datafusion_datasource::file_scan_config::*; use datafusion_execution::{SendableRecordBatchStream, TaskContext}; use datafusion_physical_plan::display::{display_orderings, ProjectSchemaDisplay}; use datafusion_physical_plan::metrics::ExecutionPlanMetricsSet; @@ -53,7 +44,6 @@ use datafusion_physical_plan::projection::{ }; use datafusion_physical_plan::source::{DataSource, DataSourceExec}; use datafusion_physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan}; -use log::warn; /// Convert type to a type suitable for use as a [`ListingTable`] /// partition column. Returns `Dictionary(UInt16, val_type)`, which is @@ -78,21 +68,30 @@ pub fn wrap_partition_value_in_dict(val: ScalarValue) -> ScalarValue { ScalarValue::Dictionary(Box::new(DataType::UInt16), Box::new(val)) } -/// The base configurations to provide when creating a physical plan for +/// The base configurations for a [`DataSourceExec`], the a physical plan for /// any given file format. /// +/// Use [`Self::build`] to create a [`DataSourceExec`] from a ``FileScanConfig`. 
+/// /// # Example /// ``` /// # use std::sync::Arc; -/// # use arrow_schema::Schema; -/// use datafusion::datasource::listing::PartitionedFile; +/// # use arrow::datatypes::{Field, Fields, DataType, Schema}; +/// # use datafusion::datasource::listing::PartitionedFile; /// # use datafusion::datasource::physical_plan::FileScanConfig; /// # use datafusion_execution::object_store::ObjectStoreUrl; /// # use datafusion::datasource::physical_plan::ArrowSource; -/// # let file_schema = Arc::new(Schema::empty()); -/// // create FileScan config for reading data from file:// +/// # use datafusion_physical_plan::ExecutionPlan; +/// # let file_schema = Arc::new(Schema::new(vec![ +/// # Field::new("c1", DataType::Int32, false), +/// # Field::new("c2", DataType::Int32, false), +/// # Field::new("c3", DataType::Int32, false), +/// # Field::new("c4", DataType::Int32, false), +/// # ])); +/// // create FileScan config for reading arrow files from file:// /// let object_store_url = ObjectStoreUrl::local_filesystem(); -/// let config = FileScanConfig::new(object_store_url, file_schema, Arc::new(ArrowSource::default())) +/// let file_source = Arc::new(ArrowSource::default()); +/// let config = FileScanConfig::new(object_store_url, file_schema, file_source) /// .with_limit(Some(1000)) // read only the first 1000 records /// .with_projection(Some(vec![2, 3])) // project columns 2 and 3 /// // Read /tmp/file1.parquet with known size of 1234 bytes in a single group @@ -103,6 +102,8 @@ pub fn wrap_partition_value_in_dict(val: ScalarValue) -> ScalarValue { /// PartitionedFile::new("file2.parquet", 56), /// PartitionedFile::new("file3.parquet", 78), /// ]); +/// // create an execution plan from the config +/// let plan: Arc = config.build(); /// ``` #[derive(Clone)] pub struct FileScanConfig { @@ -262,19 +263,20 @@ impl DataSource for FileScanConfig { // If there is any non-column or alias-carrier expression, Projection should not be removed. // This process can be moved into CsvExec, but it would be an overlap of their responsibility. Ok(all_alias_free_columns(projection.expr()).then(|| { - let mut file_scan = self.clone(); + let file_scan = self.clone(); let source = Arc::clone(&file_scan.source); let new_projections = new_projections_for_columns( projection, &file_scan .projection + .clone() .unwrap_or((0..self.file_schema.fields().len()).collect()), ); - file_scan.projection = Some(new_projections); - // Assign projected statistics to source - file_scan = file_scan.with_source(source); - - file_scan.new_exec() as _ + file_scan + // Assign projected statistics to source + .with_projection(Some(new_projections)) + .with_source(source) + .build() as _ })) } } @@ -584,9 +586,9 @@ impl FileScanConfig { } // TODO: This function should be moved into DataSourceExec once FileScanConfig moved out of datafusion/core - /// Returns a new [`DataSourceExec`] from file configurations - pub fn new_exec(&self) -> Arc { - Arc::new(DataSourceExec::new(Arc::new(self.clone()))) + /// Returns a new [`DataSourceExec`] to scan the files specified by this config + pub fn build(self) -> Arc { + Arc::new(DataSourceExec::new(Arc::new(self))) } /// Write the data_type based on file_source @@ -601,261 +603,13 @@ impl FileScanConfig { } } -/// A helper that projects partition columns into the file record batches. -/// -/// One interesting trick is the usage of a cache for the key buffers of the partition column -/// dictionaries. 
Indeed, the partition columns are constant, so the dictionaries that represent them -/// have all their keys equal to 0. This enables us to re-use the same "all-zero" buffer across batches, -/// which makes the space consumption of the partition columns O(batch_size) instead of O(record_count). -pub struct PartitionColumnProjector { - /// An Arrow buffer initialized to zeros that represents the key array of all partition - /// columns (partition columns are materialized by dictionary arrays with only one - /// value in the dictionary, thus all the keys are equal to zero). - key_buffer_cache: ZeroBufferGenerators, - /// Mapping between the indexes in the list of partition columns and the target - /// schema. Sorted by index in the target schema so that we can iterate on it to - /// insert the partition columns in the target record batch. - projected_partition_indexes: Vec<(usize, usize)>, - /// The schema of the table once the projection was applied. - projected_schema: SchemaRef, -} - -impl PartitionColumnProjector { - // Create a projector to insert the partitioning columns into batches read from files - // - `projected_schema`: the target schema with both file and partitioning columns - // - `table_partition_cols`: all the partitioning column names - pub fn new(projected_schema: SchemaRef, table_partition_cols: &[String]) -> Self { - let mut idx_map = HashMap::new(); - for (partition_idx, partition_name) in table_partition_cols.iter().enumerate() { - if let Ok(schema_idx) = projected_schema.index_of(partition_name) { - idx_map.insert(partition_idx, schema_idx); - } - } - - let mut projected_partition_indexes: Vec<_> = idx_map.into_iter().collect(); - projected_partition_indexes.sort_by(|(_, a), (_, b)| a.cmp(b)); - - Self { - projected_partition_indexes, - key_buffer_cache: Default::default(), - projected_schema, - } - } - - // Transform the batch read from the file by inserting the partitioning columns - // to the right positions as deduced from `projected_schema` - // - `file_batch`: batch read from the file, with internal projection applied - // - `partition_values`: the list of partition values, one for each partition column - pub fn project( - &mut self, - file_batch: RecordBatch, - partition_values: &[ScalarValue], - ) -> Result { - let expected_cols = - self.projected_schema.fields().len() - self.projected_partition_indexes.len(); - - if file_batch.columns().len() != expected_cols { - return exec_err!( - "Unexpected batch schema from file, expected {} cols but got {}", - expected_cols, - file_batch.columns().len() - ); - } - - let mut cols = file_batch.columns().to_vec(); - for &(pidx, sidx) in &self.projected_partition_indexes { - let p_value = - partition_values - .get(pidx) - .ok_or(DataFusionError::Execution( - "Invalid partitioning found on disk".to_string(), - ))?; - - let mut partition_value = Cow::Borrowed(p_value); - - // check if user forgot to dict-encode the partition value - let field = self.projected_schema.field(sidx); - let expected_data_type = field.data_type(); - let actual_data_type = partition_value.data_type(); - if let DataType::Dictionary(key_type, _) = expected_data_type { - if !matches!(actual_data_type, DataType::Dictionary(_, _)) { - warn!("Partition value for column {} was not dictionary-encoded, applied auto-fix.", field.name()); - partition_value = Cow::Owned(ScalarValue::Dictionary( - key_type.clone(), - Box::new(partition_value.as_ref().clone()), - )); - } - } - - cols.insert( - sidx, - create_output_array( - &mut self.key_buffer_cache, - 
partition_value.as_ref(), - file_batch.num_rows(), - )?, - ) - } - - RecordBatch::try_new_with_options( - Arc::clone(&self.projected_schema), - cols, - &RecordBatchOptions::new().with_row_count(Some(file_batch.num_rows())), - ) - .map_err(Into::into) - } -} - -#[derive(Debug, Default)] -struct ZeroBufferGenerators { - gen_i8: ZeroBufferGenerator, - gen_i16: ZeroBufferGenerator, - gen_i32: ZeroBufferGenerator, - gen_i64: ZeroBufferGenerator, - gen_u8: ZeroBufferGenerator, - gen_u16: ZeroBufferGenerator, - gen_u32: ZeroBufferGenerator, - gen_u64: ZeroBufferGenerator, -} - -/// Generate a arrow [`Buffer`] that contains zero values. -#[derive(Debug, Default)] -struct ZeroBufferGenerator -where - T: ArrowNativeType, -{ - cache: Option, - _t: PhantomData, -} - -impl ZeroBufferGenerator -where - T: ArrowNativeType, -{ - const SIZE: usize = size_of::(); - - fn get_buffer(&mut self, n_vals: usize) -> Buffer { - match &mut self.cache { - Some(buf) if buf.len() >= n_vals * Self::SIZE => { - buf.slice_with_length(0, n_vals * Self::SIZE) - } - _ => { - let mut key_buffer_builder = BufferBuilder::::new(n_vals); - key_buffer_builder.advance(n_vals); // keys are all 0 - self.cache.insert(key_buffer_builder.finish()).clone() - } - } - } -} - -fn create_dict_array( - buffer_gen: &mut ZeroBufferGenerator, - dict_val: &ScalarValue, - len: usize, - data_type: DataType, -) -> Result -where - T: ArrowNativeType, -{ - let dict_vals = dict_val.to_array()?; - - let sliced_key_buffer = buffer_gen.get_buffer(len); - - // assemble pieces together - let mut builder = ArrayData::builder(data_type) - .len(len) - .add_buffer(sliced_key_buffer); - builder = builder.add_child_data(dict_vals.to_data()); - Ok(Arc::new(DictionaryArray::::from( - builder.build().unwrap(), - ))) -} - -fn create_output_array( - key_buffer_cache: &mut ZeroBufferGenerators, - val: &ScalarValue, - len: usize, -) -> Result { - if let ScalarValue::Dictionary(key_type, dict_val) = &val { - match key_type.as_ref() { - DataType::Int8 => { - return create_dict_array( - &mut key_buffer_cache.gen_i8, - dict_val, - len, - val.data_type(), - ); - } - DataType::Int16 => { - return create_dict_array( - &mut key_buffer_cache.gen_i16, - dict_val, - len, - val.data_type(), - ); - } - DataType::Int32 => { - return create_dict_array( - &mut key_buffer_cache.gen_i32, - dict_val, - len, - val.data_type(), - ); - } - DataType::Int64 => { - return create_dict_array( - &mut key_buffer_cache.gen_i64, - dict_val, - len, - val.data_type(), - ); - } - DataType::UInt8 => { - return create_dict_array( - &mut key_buffer_cache.gen_u8, - dict_val, - len, - val.data_type(), - ); - } - DataType::UInt16 => { - return create_dict_array( - &mut key_buffer_cache.gen_u16, - dict_val, - len, - val.data_type(), - ); - } - DataType::UInt32 => { - return create_dict_array( - &mut key_buffer_cache.gen_u32, - dict_val, - len, - val.data_type(), - ); - } - DataType::UInt64 => { - return create_dict_array( - &mut key_buffer_cache.gen_u64, - dict_val, - len, - val.data_type(), - ); - } - _ => {} - } - } - - val.to_array_of_size(len) -} - #[cfg(test)] mod tests { - use arrow::array::Int32Array; - use super::*; use crate::datasource::physical_plan::ArrowSource; use crate::{test::columns, test_util::aggr_test_schema}; + use arrow::array::{Int32Array, RecordBatch}; + use std::collections::HashMap; #[test] fn physical_plan_config_no_projection() { diff --git a/datafusion/core/src/datasource/physical_plan/file_stream.rs b/datafusion/core/src/datasource/physical_plan/file_stream.rs index 
85b1d714548d..7944d6fa9020 100644 --- a/datafusion/core/src/datasource/physical_plan/file_stream.rs +++ b/datafusion/core/src/datasource/physical_plan/file_stream.rs @@ -31,49 +31,18 @@ use crate::datasource::listing::PartitionedFile; use crate::datasource::physical_plan::file_scan_config::PartitionColumnProjector; use crate::datasource::physical_plan::{FileMeta, FileScanConfig}; use crate::error::Result; -use crate::physical_plan::metrics::{ - BaselineMetrics, Count, ExecutionPlanMetricsSet, MetricBuilder, Time, -}; +use crate::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet}; use crate::physical_plan::RecordBatchStream; use arrow::datatypes::SchemaRef; use arrow::error::ArrowError; use arrow::record_batch::RecordBatch; -use datafusion_common::instant::Instant; use datafusion_common::ScalarValue; +pub use datafusion_datasource::file_stream::{FileOpenFuture, FileOpener, OnError}; +use datafusion_datasource::file_stream::{FileStreamMetrics, FileStreamState, NextOpen}; -use futures::future::BoxFuture; -use futures::stream::BoxStream; use futures::{ready, FutureExt, Stream, StreamExt}; -/// A fallible future that resolves to a stream of [`RecordBatch`] -pub type FileOpenFuture = - BoxFuture<'static, Result>>>; - -/// Describes the behavior of the `FileStream` if file opening or scanning fails -pub enum OnError { - /// Fail the entire stream and return the underlying error - Fail, - /// Continue scanning, ignoring the failed file - Skip, -} - -impl Default for OnError { - fn default() -> Self { - Self::Fail - } -} - -/// Generic API for opening a file using an [`ObjectStore`] and resolving to a -/// stream of [`RecordBatch`] -/// -/// [`ObjectStore`]: object_store::ObjectStore -pub trait FileOpener: Unpin + Send + Sync { - /// Asynchronously open the specified file and return a stream - /// of [`RecordBatch`] - fn open(&self, file_meta: FileMeta) -> Result; -} - /// A stream that iterates record batch by record batch, file over file. pub struct FileStream { /// An iterator over input files. @@ -98,151 +67,6 @@ pub struct FileStream { on_error: OnError, } -/// Represents the state of the next `FileOpenFuture`. Since we need to poll -/// this future while scanning the current file, we need to store the result if it -/// is ready -enum NextOpen { - Pending(FileOpenFuture), - Ready(Result>>), -} - -enum FileStreamState { - /// The idle state, no file is currently being read - Idle, - /// Currently performing asynchronous IO to obtain a stream of RecordBatch - /// for a given file - Open { - /// A [`FileOpenFuture`] returned by [`FileOpener::open`] - future: FileOpenFuture, - /// The partition values for this file - partition_values: Vec, - }, - /// Scanning the [`BoxStream`] returned by the completion of a [`FileOpenFuture`] - /// returned by [`FileOpener::open`] - Scan { - /// Partitioning column values for the current batch_iter - partition_values: Vec, - /// The reader instance - reader: BoxStream<'static, Result>, - /// A [`FileOpenFuture`] for the next file to be processed, - /// and its corresponding partition column values, if any. - /// This allows the next file to be opened in parallel while the - /// current file is read. - next: Option<(NextOpen, Vec)>, - }, - /// Encountered an error - Error, - /// Reached the row limit - Limit, -} - -/// A timer that can be started and stopped. -pub struct StartableTime { - pub(crate) metrics: Time, - // use for record each part cost time, will eventually add into 'metrics'. 
- pub(crate) start: Option, -} - -impl StartableTime { - pub(crate) fn start(&mut self) { - assert!(self.start.is_none()); - self.start = Some(Instant::now()); - } - - pub(crate) fn stop(&mut self) { - if let Some(start) = self.start.take() { - self.metrics.add_elapsed(start); - } - } -} - -/// Metrics for [`FileStream`] -/// -/// Note that all of these metrics are in terms of wall clock time -/// (not cpu time) so they include time spent waiting on I/O as well -/// as other operators. -struct FileStreamMetrics { - /// Wall clock time elapsed for file opening. - /// - /// Time between when [`FileOpener::open`] is called and when the - /// [`FileStream`] receives a stream for reading. - /// - /// If there are multiple files being scanned, the stream - /// will open the next file in the background while scanning the - /// current file. This metric will only capture time spent opening - /// while not also scanning. - pub time_opening: StartableTime, - /// Wall clock time elapsed for file scanning + first record batch of decompression + decoding - /// - /// Time between when the [`FileStream`] requests data from the - /// stream and when the first [`RecordBatch`] is produced. - pub time_scanning_until_data: StartableTime, - /// Total elapsed wall clock time for scanning + record batch decompression / decoding - /// - /// Sum of time between when the [`FileStream`] requests data from - /// the stream and when a [`RecordBatch`] is produced for all - /// record batches in the stream. Note that this metric also - /// includes the time of the parent operator's execution. - pub time_scanning_total: StartableTime, - /// Wall clock time elapsed for data decompression + decoding - /// - /// Time spent waiting for the FileStream's input. - pub time_processing: StartableTime, - /// Count of errors opening file. - /// - /// If using `OnError::Skip` this will provide a count of the number of files - /// which were skipped and will not be included in the scan results. - pub file_open_errors: Count, - /// Count of errors scanning file - /// - /// If using `OnError::Skip` this will provide a count of the number of files - /// which were skipped and will not be included in the scan results. 
- pub file_scan_errors: Count, -} - -impl FileStreamMetrics { - fn new(metrics: &ExecutionPlanMetricsSet, partition: usize) -> Self { - let time_opening = StartableTime { - metrics: MetricBuilder::new(metrics) - .subset_time("time_elapsed_opening", partition), - start: None, - }; - - let time_scanning_until_data = StartableTime { - metrics: MetricBuilder::new(metrics) - .subset_time("time_elapsed_scanning_until_data", partition), - start: None, - }; - - let time_scanning_total = StartableTime { - metrics: MetricBuilder::new(metrics) - .subset_time("time_elapsed_scanning_total", partition), - start: None, - }; - - let time_processing = StartableTime { - metrics: MetricBuilder::new(metrics) - .subset_time("time_elapsed_processing", partition), - start: None, - }; - - let file_open_errors = - MetricBuilder::new(metrics).counter("file_open_errors", partition); - - let file_scan_errors = - MetricBuilder::new(metrics).counter("file_scan_errors", partition); - - Self { - time_opening, - time_scanning_until_data, - time_scanning_total, - time_processing, - file_open_errors, - file_scan_errors, - } - } -} - impl FileStream { /// Create a new `FileStream` using the give `FileOpener` to scan underlying files pub fn new( @@ -526,7 +350,7 @@ mod tests { use crate::test::{make_partition, object_store::register_test_store}; use crate::datasource::physical_plan::CsvSource; - use arrow_schema::Schema; + use arrow::datatypes::Schema; use datafusion_common::internal_err; /// Test `FileOpener` which will simulate errors during file opening or scanning diff --git a/datafusion/core/src/datasource/physical_plan/json.rs b/datafusion/core/src/datasource/physical_plan/json.rs index 76cb657b0c5f..590b1cb88dcd 100644 --- a/datafusion/core/src/datasource/physical_plan/json.rs +++ b/datafusion/core/src/datasource/physical_plan/json.rs @@ -27,8 +27,8 @@ use crate::datasource::data_source::FileSource; use crate::datasource::file_format::file_compression_type::FileCompressionType; use crate::datasource::file_format::{deserialize_stream, DecoderDeserializer}; use crate::datasource::listing::{ListingTableUrl, PartitionedFile}; -use crate::datasource::physical_plan::file_stream::{FileOpenFuture, FileOpener}; use crate::datasource::physical_plan::FileMeta; +use crate::datasource::physical_plan::{FileOpenFuture, FileOpener}; use crate::error::{DataFusionError, Result}; use crate::physical_plan::{ExecutionPlan, ExecutionPlanProperties}; @@ -589,7 +589,7 @@ mod tests { .with_file_groups(file_groups) .with_limit(Some(3)) .with_file_compression_type(file_compression_type.to_owned()); - let exec = conf.new_exec(); + let exec = conf.build(); // TODO: this is not where schema inference should be tested @@ -660,7 +660,7 @@ mod tests { .with_file_groups(file_groups) .with_limit(Some(3)) .with_file_compression_type(file_compression_type.to_owned()); - let exec = conf.new_exec(); + let exec = conf.build(); let mut it = exec.execute(0, task_ctx)?; let batch = it.next().await.unwrap()?; @@ -700,7 +700,7 @@ mod tests { .with_file_groups(file_groups) .with_projection(Some(vec![0, 2])) .with_file_compression_type(file_compression_type.to_owned()); - let exec = conf.new_exec(); + let exec = conf.build(); let inferred_schema = exec.schema(); assert_eq!(inferred_schema.fields().len(), 2); @@ -745,7 +745,7 @@ mod tests { .with_file_groups(file_groups) .with_projection(Some(vec![3, 0, 2])) .with_file_compression_type(file_compression_type.to_owned()); - let exec = conf.new_exec(); + let exec = conf.build(); let inferred_schema = exec.schema(); 
assert_eq!(inferred_schema.fields().len(), 3); diff --git a/datafusion/core/src/datasource/physical_plan/mod.rs b/datafusion/core/src/datasource/physical_plan/mod.rs index 873df859702a..953c99322e16 100644 --- a/datafusion/core/src/datasource/physical_plan/mod.rs +++ b/datafusion/core/src/datasource/physical_plan/mod.rs @@ -35,6 +35,12 @@ pub use self::parquet::source::ParquetSource; pub use self::parquet::{ ParquetExec, ParquetExecBuilder, ParquetFileMetrics, ParquetFileReaderFactory, }; +use crate::error::Result; +use crate::physical_plan::{DisplayAs, DisplayFormatType}; +use crate::{ + datasource::listing::{FileRange, PartitionedFile}, + physical_plan::display::{display_orderings, ProjectSchemaDisplay}, +}; #[allow(deprecated)] pub use arrow_file::ArrowExec; pub use arrow_file::ArrowSource; @@ -44,16 +50,19 @@ pub use avro::AvroSource; #[allow(deprecated)] pub use csv::{CsvExec, CsvExecBuilder}; pub use csv::{CsvOpener, CsvSource}; -pub use datafusion_catalog_listing::file_groups::FileGroupPartitioner; -use datafusion_expr::dml::InsertOp; +pub use datafusion_datasource::file_groups::FileGroupPartitioner; +pub use datafusion_datasource::file_meta::FileMeta; +pub use datafusion_datasource::file_sink_config::*; pub use file_scan_config::{ wrap_partition_type_in_dict, wrap_partition_value_in_dict, FileScanConfig, }; pub use file_stream::{FileOpenFuture, FileOpener, FileStream, OnError}; +use futures::StreamExt; #[allow(deprecated)] pub use json::NdJsonExec; pub use json::{JsonOpener, JsonSource}; - +use log::debug; +use object_store::{path::Path, GetOptions, GetRange, ObjectStore}; use std::{ fmt::{Debug, Formatter, Result as FmtResult}, ops::Range, @@ -61,115 +70,10 @@ use std::{ vec, }; -use super::{file_format::write::demux::start_demuxer_task, listing::ListingTableUrl}; -use crate::datasource::file_format::write::demux::DemuxedStreamReceiver; -use crate::error::Result; -use crate::physical_plan::{DisplayAs, DisplayFormatType}; -use crate::{ - datasource::{ - listing::{FileRange, PartitionedFile}, - object_store::ObjectStoreUrl, - }, - physical_plan::display::{display_orderings, ProjectSchemaDisplay}, -}; - -use arrow::datatypes::{DataType, SchemaRef}; -use datafusion_common_runtime::SpawnedTask; -use datafusion_execution::{SendableRecordBatchStream, TaskContext}; +use arrow::datatypes::SchemaRef; use datafusion_physical_expr::expressions::Column; use datafusion_physical_expr::PhysicalSortExpr; use datafusion_physical_expr_common::sort_expr::LexOrdering; -use datafusion_physical_plan::insert::DataSink; - -use async_trait::async_trait; -use futures::StreamExt; -use log::debug; -use object_store::{path::Path, GetOptions, GetRange, ObjectMeta, ObjectStore}; - -/// General behaviors for files that do `DataSink` operations -#[async_trait] -pub trait FileSink: DataSink { - /// Retrieves the file sink configuration. - fn config(&self) -> &FileSinkConfig; - - /// Spawns writer tasks and joins them to perform file writing operations. - /// Is a critical part of `FileSink` trait, since it's the very last step for `write_all`. - /// - /// This function handles the process of writing data to files by: - /// 1. Spawning tasks for writing data to individual files. - /// 2. Coordinating the tasks using a demuxer to distribute data among files. - /// 3. Collecting results using `tokio::join`, ensuring that all tasks complete successfully. - /// - /// # Parameters - /// - `context`: The execution context (`TaskContext`) that provides resources - /// like memory management and runtime environment. 
- /// - `demux_task`: A spawned task that handles demuxing, responsible for splitting - /// an input [`SendableRecordBatchStream`] into dynamically determined partitions. - /// See `start_demuxer_task()` - /// - `file_stream_rx`: A receiver that yields streams of record batches and their - /// corresponding file paths for writing. See `start_demuxer_task()` - /// - `object_store`: A handle to the object store where the files are written. - /// - /// # Returns - /// - `Result`: Returns the total number of rows written across all files. - async fn spawn_writer_tasks_and_join( - &self, - context: &Arc, - demux_task: SpawnedTask>, - file_stream_rx: DemuxedStreamReceiver, - object_store: Arc, - ) -> Result; - - /// File sink implementation of the [`DataSink::write_all`] method. - async fn write_all( - &self, - data: SendableRecordBatchStream, - context: &Arc, - ) -> Result { - let config = self.config(); - let object_store = context - .runtime_env() - .object_store(&config.object_store_url)?; - let (demux_task, file_stream_rx) = start_demuxer_task(config, data, context); - self.spawn_writer_tasks_and_join( - context, - demux_task, - file_stream_rx, - object_store, - ) - .await - } -} - -/// The base configurations to provide when creating a physical plan for -/// writing to any given file format. -pub struct FileSinkConfig { - /// Object store URL, used to get an ObjectStore instance - pub object_store_url: ObjectStoreUrl, - /// A vector of [`PartitionedFile`] structs, each representing a file partition - pub file_groups: Vec, - /// Vector of partition paths - pub table_paths: Vec, - /// The schema of the output file - pub output_schema: SchemaRef, - /// A vector of column names and their corresponding data types, - /// representing the partitioning columns for the file - pub table_partition_cols: Vec<(String, DataType)>, - /// Controls how new data should be written to the file, determining whether - /// to append to, overwrite, or replace records in existing files. - pub insert_op: InsertOp, - /// Controls whether partition columns are kept for the file - pub keep_partition_by_columns: bool, - /// File extension without a dot(.) - pub file_extension: String, -} - -impl FileSinkConfig { - /// Get output schema - pub fn output_schema(&self) -> &SchemaRef { - &self.output_schema - } -} impl Debug for FileScanConfig { fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult { @@ -316,36 +220,6 @@ where Ok(()) } -/// A single file or part of a file that should be read, along with its schema, statistics -pub struct FileMeta { - /// Path for the file (e.g. URL, filesystem path, etc) - pub object_meta: ObjectMeta, - /// An optional file range for a more fine-grained parallel execution - pub range: Option, - /// An optional field for user defined per object metadata - pub extensions: Option>, - /// Size hint for the metadata of this file - pub metadata_size_hint: Option, -} - -impl FileMeta { - /// The full path to the object - pub fn location(&self) -> &Path { - &self.object_meta.location - } -} - -impl From for FileMeta { - fn from(object_meta: ObjectMeta) -> Self { - Self { - object_meta, - range: None, - extensions: None, - metadata_size_hint: None, - } - } -} - /// The various listing tables does not attempt to read all files /// concurrently, instead they will read files in sequence within a /// partition. 
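Since this patch converts every call site from `FileScanConfig::new_exec()` to the consuming `FileScanConfig::build()`, it may help to see the whole builder chain in one place. A minimal sketch of constructing a CSV scan with the new API, assuming the import paths shown in the updated doc comments; the file name, size, schema, and projection are illustrative only:

```rust
use std::sync::Arc;

use arrow::datatypes::{DataType, Field, Schema};
use datafusion::datasource::listing::PartitionedFile;
use datafusion::datasource::physical_plan::{CsvSource, FileScanConfig};
use datafusion_execution::object_store::ObjectStoreUrl;
use datafusion_physical_plan::ExecutionPlan;

fn csv_scan() -> Arc<dyn ExecutionPlan> {
    // Schema of the file being scanned (illustrative)
    let file_schema = Arc::new(Schema::new(vec![
        Field::new("c1", DataType::Int32, false),
        Field::new("c2", DataType::Utf8, false),
    ]));

    let object_store_url = ObjectStoreUrl::local_filesystem();
    // has_header, delimiter, quote
    let file_source = Arc::new(CsvSource::new(true, b',', b'"'));

    // Previously this chain ended with `config.new_exec()`; `build()` now
    // consumes the config and returns the DataSourceExec that scans the files.
    FileScanConfig::new(object_store_url, file_schema, file_source)
        .with_file(PartitionedFile::new("file1.csv", 10 * 1024))
        .with_projection(Some(vec![0]))
        .build()
}
```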
This is an important property as it allows plans to @@ -586,7 +460,8 @@ mod tests { BinaryArray, BooleanArray, Float32Array, Int32Array, Int64Array, RecordBatch, StringArray, UInt64Array, }; - use arrow_schema::{Field, Schema}; + use arrow::datatypes::{DataType, Field, Schema}; + use object_store::ObjectMeta; use crate::datasource::schema_adapter::{ DefaultSchemaAdapterFactory, SchemaAdapterFactory, diff --git a/datafusion/core/src/datasource/physical_plan/parquet/mod.rs b/datafusion/core/src/datasource/physical_plan/parquet/mod.rs index 4ba449e2ee82..4bd43cd1aaca 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/mod.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/mod.rs @@ -575,9 +575,8 @@ mod tests { ArrayRef, Date64Array, Int32Array, Int64Array, Int8Array, StringArray, StructArray, }; - use arrow::datatypes::{Field, Schema, SchemaBuilder}; + use arrow::datatypes::{DataType, Field, Fields, Schema, SchemaBuilder}; use arrow::record_batch::RecordBatch; - use arrow_schema::{DataType, Fields}; use bytes::{BufMut, BytesMut}; use datafusion_common::{assert_contains, ScalarValue}; use datafusion_expr::{col, lit, when, Expr}; @@ -709,7 +708,7 @@ mod tests { let session_ctx = SessionContext::new(); let task_ctx = session_ctx.task_ctx(); - let parquet_exec = base_config.new_exec(); + let parquet_exec = base_config.clone().build(); RoundTripResult { batches: collect(parquet_exec.clone(), task_ctx).await, parquet_exec, @@ -1355,7 +1354,7 @@ mod tests { Arc::new(ParquetSource::default()), ) .with_file_groups(file_groups) - .new_exec(); + .build(); assert_eq!( parquet_exec .properties() @@ -1469,7 +1468,7 @@ mod tests { false, ), ]) - .new_exec(); + .build(); let partition_count = parquet_exec .source() .output_partitioning() @@ -1532,7 +1531,7 @@ mod tests { Arc::new(ParquetSource::default()), ) .with_file(partitioned_file) - .new_exec(); + .build(); let mut results = parquet_exec.execute(0, state.task_ctx())?; let batch = results.next().await.unwrap(); @@ -2189,7 +2188,7 @@ mod tests { extensions: None, metadata_size_hint: None, }) - .new_exec(); + .build(); let res = collect(exec, ctx.task_ctx()).await.unwrap(); assert_eq!(res.len(), 2); diff --git a/datafusion/core/src/datasource/physical_plan/parquet/opener.rs b/datafusion/core/src/datasource/physical_plan/parquet/opener.rs index 02ad9dd55100..138b44897931 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/opener.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/opener.rs @@ -32,7 +32,8 @@ use crate::datasource::physical_plan::{ }; use crate::datasource::schema_adapter::SchemaAdapterFactory; -use arrow_schema::{ArrowError, SchemaRef}; +use arrow::datatypes::SchemaRef; +use arrow::error::ArrowError; use datafusion_common::{exec_err, Result}; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; use datafusion_physical_optimizer::pruning::PruningPredicate; diff --git a/datafusion/core/src/datasource/physical_plan/parquet/page_filter.rs b/datafusion/core/src/datasource/physical_plan/parquet/page_filter.rs index dcc4b0bc8150..02329effb09a 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/page_filter.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/page_filter.rs @@ -24,8 +24,10 @@ use super::metrics::ParquetFileMetrics; use crate::datasource::physical_plan::parquet::ParquetAccessPlan; use arrow::array::BooleanArray; -use arrow::{array::ArrayRef, datatypes::SchemaRef}; -use arrow_schema::Schema; +use arrow::{ + array::ArrayRef, + datatypes::{Schema, 
SchemaRef}, +}; use datafusion_common::ScalarValue; use datafusion_physical_expr::{split_conjunction, PhysicalExpr}; use datafusion_physical_optimizer::pruning::{PruningPredicate, PruningStatistics}; diff --git a/datafusion/core/src/datasource/physical_plan/parquet/row_filter.rs b/datafusion/core/src/datasource/physical_plan/parquet/row_filter.rs index bcd2c0af6f6f..ac6eaf2c8f63 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/row_filter.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/row_filter.rs @@ -588,8 +588,7 @@ mod test { DefaultSchemaAdapterFactory, SchemaAdapterFactory, }; - use arrow::datatypes::Field; - use arrow_schema::{Fields, TimeUnit::Nanosecond}; + use arrow::datatypes::{Field, Fields, TimeUnit::Nanosecond}; use datafusion_expr::{cast, col, lit, Expr}; use datafusion_physical_expr::planner::logical2physical; use datafusion_physical_plan::metrics::{Count, Time}; diff --git a/datafusion/core/src/datasource/physical_plan/parquet/source.rs b/datafusion/core/src/datasource/physical_plan/parquet/source.rs index 0705a398f4fb..21881112075d 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/source.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/source.rs @@ -31,7 +31,7 @@ use crate::datasource::schema_adapter::{ DefaultSchemaAdapterFactory, SchemaAdapterFactory, }; -use arrow_schema::{Schema, SchemaRef}; +use arrow::datatypes::{Schema, SchemaRef}; use datafusion_common::config::TableParquetOptions; use datafusion_common::Statistics; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; @@ -94,7 +94,7 @@ use object_store::ObjectStore; /// // Create a DataSourceExec for reading `file1.parquet` with a file size of 100MB /// let file_scan_config = FileScanConfig::new(object_store_url, file_schema, source) /// .with_file(PartitionedFile::new("file1.parquet", 100*1024*1024)); -/// let exec = file_scan_config.new_exec(); +/// let exec = file_scan_config.build(); /// ``` /// /// # Features @@ -176,7 +176,7 @@ use object_store::ObjectStore; /// .clone() /// .with_file_groups(vec![file_group.clone()]); /// -/// new_config.new_exec() +/// new_config.build() /// }) /// .collect::>(); /// ``` @@ -196,7 +196,7 @@ use object_store::ObjectStore; /// /// ``` /// # use std::sync::Arc; -/// # use arrow_schema::{Schema, SchemaRef}; +/// # use arrow::datatypes::{Schema, SchemaRef}; /// # use datafusion::datasource::listing::PartitionedFile; /// # use datafusion::datasource::physical_plan::parquet::ParquetAccessPlan; /// # use datafusion::datasource::physical_plan::FileScanConfig; @@ -219,7 +219,7 @@ use object_store::ObjectStore; /// .with_file(partitioned_file); /// // this parquet DataSourceExec will not even try to read row groups 2 and 4. Additional /// // pruning based on predicates may also happen -/// let exec = file_scan_config.new_exec(); +/// let exec = file_scan_config.build(); /// ``` /// /// For a complete example, see the [`advanced_parquet_index` example]). 
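The `new_exec()` → `build()` rename in the doc examples above recurs throughout the rest of this diff. A minimal, self-contained sketch of the renamed entry point (file name and size are placeholders, not taken from the patch):

```rust
use std::sync::Arc;

use arrow::datatypes::SchemaRef;
use datafusion::datasource::listing::PartitionedFile;
use datafusion::datasource::physical_plan::{FileScanConfig, ParquetSource};
use datafusion::physical_plan::ExecutionPlan;
use datafusion_execution::object_store::ObjectStoreUrl;

// Build a DataSourceExec for a single parquet file; `build()` replaces the
// former `FileScanConfig::new_exec()` used in earlier versions of these tests.
fn parquet_scan(file_schema: SchemaRef) -> Arc<dyn ExecutionPlan> {
    let source = Arc::new(ParquetSource::default());
    FileScanConfig::new(ObjectStoreUrl::local_filesystem(), file_schema, source)
        .with_file(PartitionedFile::new("file1.parquet", 1024))
        .build()
}
```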
diff --git a/datafusion/core/src/datasource/physical_plan/statistics.rs b/datafusion/core/src/datasource/physical_plan/statistics.rs index 64eb2b00de94..5811c19be408 100644 --- a/datafusion/core/src/datasource/physical_plan/statistics.rs +++ b/datafusion/core/src/datasource/physical_plan/statistics.rs @@ -29,11 +29,11 @@ use std::sync::Arc; use crate::datasource::listing::PartitionedFile; use arrow::array::RecordBatch; +use arrow::datatypes::SchemaRef; use arrow::{ compute::SortColumn, row::{Row, Rows}, }; -use arrow_schema::SchemaRef; use datafusion_common::{plan_err, DataFusionError, Result}; use datafusion_physical_expr::{expressions::Column, PhysicalSortExpr}; use datafusion_physical_expr_common::sort_expr::LexOrdering; diff --git a/datafusion/core/src/datasource/schema_adapter.rs b/datafusion/core/src/datasource/schema_adapter.rs index efaae403b415..41e375cf81f8 100644 --- a/datafusion/core/src/datasource/schema_adapter.rs +++ b/datafusion/core/src/datasource/schema_adapter.rs @@ -23,7 +23,7 @@ use arrow::array::{new_null_array, RecordBatch, RecordBatchOptions}; use arrow::compute::{can_cast_types, cast}; -use arrow_schema::{Schema, SchemaRef}; +use arrow::datatypes::{Schema, SchemaRef}; use datafusion_common::plan_err; use std::fmt::Debug; use std::sync::Arc; @@ -435,9 +435,8 @@ mod tests { use crate::assert_batches_sorted_eq; use arrow::array::{Int32Array, StringArray}; - use arrow::datatypes::{Field, Schema}; + use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use arrow::record_batch::RecordBatch; - use arrow_schema::{DataType, SchemaRef}; use object_store::path::Path; use object_store::ObjectMeta; @@ -508,7 +507,7 @@ mod tests { FileScanConfig::new(ObjectStoreUrl::local_filesystem(), schema, source) .with_file(partitioned_file); - let parquet_exec = base_conf.new_exec(); + let parquet_exec = base_conf.build(); let session_ctx = SessionContext::new(); let task_ctx = session_ctx.task_ctx(); diff --git a/datafusion/core/src/datasource/stream.rs b/datafusion/core/src/datasource/stream.rs index eab192362b27..4ca26f6213cb 100644 --- a/datafusion/core/src/datasource/stream.rs +++ b/datafusion/core/src/datasource/stream.rs @@ -29,7 +29,7 @@ use crate::catalog::{TableProvider, TableProviderFactory}; use crate::datasource::create_ordering; use arrow::array::{RecordBatch, RecordBatchReader, RecordBatchWriter}; -use arrow_schema::SchemaRef; +use arrow::datatypes::SchemaRef; use datafusion_common::{config_err, plan_err, Constraints, DataFusionError, Result}; use datafusion_common_runtime::SpawnedTask; use datafusion_execution::{SendableRecordBatchStream, TaskContext}; diff --git a/datafusion/core/src/execution/context/mod.rs b/datafusion/core/src/execution/context/mod.rs index 61ddba10b09c..c27d1e4fd46b 100644 --- a/datafusion/core/src/execution/context/mod.rs +++ b/datafusion/core/src/execution/context/mod.rs @@ -49,9 +49,8 @@ use crate::{ variable::{VarProvider, VarType}, }; -use arrow::datatypes::SchemaRef; +use arrow::datatypes::{Schema, SchemaRef}; use arrow::record_batch::RecordBatch; -use arrow_schema::Schema; use datafusion_common::{ config::{ConfigExtension, TableOptions}, exec_datafusion_err, exec_err, not_impl_err, plan_datafusion_err, plan_err, @@ -1817,7 +1816,7 @@ mod tests { use crate::execution::memory_pool::MemoryConsumer; use crate::test; use crate::test_util::{plan_and_collect, populate_csv_partitions}; - use arrow_schema::{DataType, TimeUnit}; + use arrow::datatypes::{DataType, TimeUnit}; use std::env; use std::error::Error; use std::path::PathBuf; diff --git 
a/datafusion/core/src/execution/session_state.rs b/datafusion/core/src/execution/session_state.rs index 85c2b2a0fd78..f1abf30c0c54 100644 --- a/datafusion/core/src/execution/session_state.rs +++ b/datafusion/core/src/execution/session_state.rs @@ -35,7 +35,7 @@ use datafusion_catalog::information_schema::{ }; use datafusion_catalog::MemoryCatalogProviderList; -use arrow_schema::{DataType, SchemaRef}; +use arrow::datatypes::{DataType, SchemaRef}; use datafusion_catalog::{Session, TableFunction, TableFunctionImpl}; use datafusion_common::alias::AliasGenerator; use datafusion_common::config::{ConfigExtension, ConfigOptions, TableOptions}; @@ -1991,7 +1991,7 @@ mod tests { use crate::datasource::MemTable; use crate::execution::context::SessionState; use arrow::array::{ArrayRef, Int32Array, RecordBatch, StringArray}; - use arrow_schema::{DataType, Field, Schema}; + use arrow::datatypes::{DataType, Field, Schema}; use datafusion_catalog::MemoryCatalogProviderList; use datafusion_common::DFSchema; use datafusion_common::Result; diff --git a/datafusion/core/src/lib.rs b/datafusion/core/src/lib.rs index b256ed38039a..f4aa366500ef 100644 --- a/datafusion/core/src/lib.rs +++ b/datafusion/core/src/lib.rs @@ -229,9 +229,9 @@ //! 1. The query string is parsed to an Abstract Syntax Tree (AST) //! [`Statement`] using [sqlparser]. //! -//! 2. The AST is converted to a [`LogicalPlan`] and logical -//! expressions [`Expr`]s to compute the desired result by the -//! [`SqlToRel`] planner. +//! 2. The AST is converted to a [`LogicalPlan`] and logical expressions +//! [`Expr`]s to compute the desired result by [`SqlToRel`]. This phase +//! also includes name and type resolution ("binding"). //! //! [`Statement`]: https://docs.rs/sqlparser/latest/sqlparser/ast/enum.Statement.html //! @@ -737,6 +737,11 @@ pub mod logical_expr { pub use datafusion_expr::*; } +/// re-export of [`datafusion_expr_common`] crate +pub mod logical_expr_common { + pub use datafusion_expr_common::*; +} + /// re-export of [`datafusion_optimizer`] crate pub mod optimizer { pub use datafusion_optimizer::*; @@ -844,11 +849,17 @@ doc_comment::doctest!("../../../README.md", readme_example_test); // // For example, if `user_guide_expressions(line 123)` fails, // go to `docs/source/user-guide/expressions.md` to find the relevant problem. 
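The doctest registrations added below all follow the same pattern: the macro roughly turns every Rust code block in the named markdown file into a doctest, and the second argument is only the name given to the generated test. A sketch with a hypothetical path:

```rust
// Illustrative only; the real registrations for each guide page follow below.
#[cfg(doctest)]
doc_comment::doctest!(
    "../../../docs/source/user-guide/some-page.md",
    user_guide_some_page
);
```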
+// +#[cfg(doctest)] +doc_comment::doctest!( + "../../../docs/source/user-guide/concepts-readings-events.md", + user_guide_concepts_readings_events +); #[cfg(doctest)] doc_comment::doctest!( - "../../../docs/source/user-guide/example-usage.md", - user_guide_example_usage + "../../../docs/source/user-guide/configs.md", + user_guide_configs ); #[cfg(doctest)] @@ -859,14 +870,20 @@ doc_comment::doctest!( #[cfg(doctest)] doc_comment::doctest!( - "../../../docs/source/user-guide/configs.md", - user_guide_configs + "../../../docs/source/user-guide/dataframe.md", + user_guide_dataframe ); #[cfg(doctest)] doc_comment::doctest!( - "../../../docs/source/user-guide/dataframe.md", - user_guide_dataframe + "../../../docs/source/user-guide/example-usage.md", + user_guide_example_usage +); + +#[cfg(doctest)] +doc_comment::doctest!( + "../../../docs/source/user-guide/explain-usage.md", + user_guide_explain_usage ); #[cfg(doctest)] @@ -875,16 +892,187 @@ doc_comment::doctest!( user_guide_expressions ); +#[cfg(doctest)] +doc_comment::doctest!("../../../docs/source/user-guide/faq.md", user_guide_faq); + #[cfg(doctest)] doc_comment::doctest!( - "../../../docs/source/library-user-guide/using-the-sql-api.md", - library_user_guide_sql_api + "../../../docs/source/user-guide/introduction.md", + user_guide_introduction +); + +#[cfg(doctest)] +doc_comment::doctest!( + "../../../docs/source/user-guide/cli/datasources.md", + user_guide_cli_datasource +); + +#[cfg(doctest)] +doc_comment::doctest!( + "../../../docs/source/user-guide/cli/installation.md", + user_guide_cli_installation +); + +#[cfg(doctest)] +doc_comment::doctest!( + "../../../docs/source/user-guide/cli/overview.md", + user_guide_cli_overview +); + +#[cfg(doctest)] +doc_comment::doctest!( + "../../../docs/source/user-guide/cli/usage.md", + user_guide_cli_usage +); + +#[cfg(doctest)] +doc_comment::doctest!( + "../../../docs/source/user-guide/features.md", + user_guide_features +); + +#[cfg(doctest)] +doc_comment::doctest!( + "../../../docs/source/user-guide/sql/aggregate_functions.md", + user_guide_sql_aggregate_functions +); + +#[cfg(doctest)] +doc_comment::doctest!( + "../../../docs/source/user-guide/sql/data_types.md", + user_guide_sql_data_types +); + +#[cfg(doctest)] +doc_comment::doctest!( + "../../../docs/source/user-guide/sql/ddl.md", + user_guide_sql_ddl +); + +#[cfg(doctest)] +doc_comment::doctest!( + "../../../docs/source/user-guide/sql/dml.md", + user_guide_sql_dml +); + +#[cfg(doctest)] +doc_comment::doctest!( + "../../../docs/source/user-guide/sql/explain.md", + user_guide_sql_exmplain +); + +#[cfg(doctest)] +doc_comment::doctest!( + "../../../docs/source/user-guide/sql/information_schema.md", + user_guide_sql_information_schema +); + +#[cfg(doctest)] +doc_comment::doctest!( + "../../../docs/source/user-guide/sql/operators.md", + user_guide_sql_operators +); + +#[cfg(doctest)] +doc_comment::doctest!( + "../../../docs/source/user-guide/sql/prepared_statements.md", + user_guide_prepared_statements +); + +#[cfg(doctest)] +doc_comment::doctest!( + "../../../docs/source/user-guide/sql/scalar_functions.md", + user_guide_sql_scalar_functions +); + +#[cfg(doctest)] +doc_comment::doctest!( + "../../../docs/source/user-guide/sql/select.md", + user_guide_sql_select +); + +#[cfg(doctest)] +doc_comment::doctest!( + "../../../docs/source/user-guide/sql/special_functions.md", + user_guide_sql_special_functions +); + +#[cfg(doctest)] +doc_comment::doctest!( + "../../../docs/source/user-guide/sql/subqueries.md", + user_guide_sql_subqueries +); + 
+#[cfg(doctest)] +doc_comment::doctest!( + "../../../docs/source/user-guide/sql/window_functions.md", + user_guide_sql_window_functions +); + +#[cfg(doctest)] +doc_comment::doctest!( + "../../../docs/source/user-guide/sql/write_options.md", + user_guide_sql_write_options +); + +#[cfg(doctest)] +doc_comment::doctest!( + "../../../docs/source/library-user-guide/adding-udfs.md", + library_user_guide_adding_udfs +); + +#[cfg(doctest)] +doc_comment::doctest!( + "../../../docs/source/library-user-guide/api-health.md", + library_user_guide_api_health ); #[cfg(doctest)] doc_comment::doctest!( "../../../docs/source/library-user-guide/building-logical-plans.md", - library_user_guide_logical_plans + library_user_guide_building_logical_plans +); + +#[cfg(doctest)] +doc_comment::doctest!( + "../../../docs/source/library-user-guide/catalogs.md", + library_user_guide_catalogs +); + +#[cfg(doctest)] +doc_comment::doctest!( + "../../../docs/source/library-user-guide/custom-table-providers.md", + library_user_guide_custom_table_providers +); + +#[cfg(doctest)] +doc_comment::doctest!( + "../../../docs/source/library-user-guide/extending-operators.md", + library_user_guide_extending_operators +); + +#[cfg(doctest)] +doc_comment::doctest!( + "../../../docs/source/library-user-guide/extensions.md", + library_user_guide_extensions +); + +#[cfg(doctest)] +doc_comment::doctest!( + "../../../docs/source/library-user-guide/index.md", + library_user_guide_index +); + +#[cfg(doctest)] +doc_comment::doctest!( + "../../../docs/source/library-user-guide/profiling.md", + library_user_guide_profiling +); + +#[cfg(doctest)] +doc_comment::doctest!( + "../../../docs/source/library-user-guide/query-optimizer.md", + library_user_guide_query_optimizer ); #[cfg(doctest)] @@ -892,3 +1080,15 @@ doc_comment::doctest!( "../../../docs/source/library-user-guide/using-the-dataframe-api.md", library_user_guide_dataframe_api ); + +#[cfg(doctest)] +doc_comment::doctest!( + "../../../docs/source/library-user-guide/using-the-sql-api.md", + library_user_guide_sql_api +); + +#[cfg(doctest)] +doc_comment::doctest!( + "../../../docs/source/library-user-guide/working-with-exprs.md", + library_user_guide_working_with_exprs +); diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index 9fcb9562a485..d73b7d81536a 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -24,7 +24,7 @@ use std::sync::Arc; use crate::datasource::file_format::file_type_to_format; use crate::datasource::listing::ListingTableUrl; use crate::datasource::physical_plan::FileSinkConfig; -use crate::datasource::source_as_provider; +use crate::datasource::{source_as_provider, DefaultTableSource}; use crate::error::{DataFusionError, Result}; use crate::execution::context::{ExecutionProps, SessionState}; use crate::logical_expr::utils::generate_sort_key; @@ -70,7 +70,8 @@ use datafusion_common::{ }; use datafusion_expr::dml::{CopyTo, InsertOp}; use datafusion_expr::expr::{ - physical_name, AggregateFunction, Alias, GroupingSet, WindowFunction, + physical_name, AggregateFunction, AggregateFunctionParams, Alias, GroupingSet, + WindowFunction, WindowFunctionParams, }; use datafusion_expr::expr_rewriter::unnormalize_cols; use datafusion_expr::logical_plan::builder::wrap_projection_for_join_if_necessary; @@ -508,7 +509,7 @@ impl DefaultPhysicalPlanner { // the column name rather than column name + explicit data type. 
let table_partition_cols = partition_by .iter() - .map(|s| (s.to_string(), arrow_schema::DataType::Null)) + .map(|s| (s.to_string(), arrow::datatypes::DataType::Null)) .collect::>(); let keep_partition_by_columns = match source_option_tuples @@ -541,19 +542,22 @@ impl DefaultPhysicalPlanner { .await? } LogicalPlan::Dml(DmlStatement { - table_name, + target, op: WriteOp::Insert(insert_op), .. }) => { - let name = table_name.table(); - let schema = session_state.schema_for_ref(table_name.clone())?; - if let Some(provider) = schema.table(name).await? { + if let Some(provider) = + target.as_any().downcast_ref::() + { let input_exec = children.one()?; provider + .table_provider .insert_into(session_state, input_exec, *insert_op) .await? } else { - return exec_err!("Table '{table_name}' does not exist"); + return exec_err!( + "Table source can't be downcasted to DefaultTableSource" + ); } } LogicalPlan::Window(Window { window_expr, .. }) => { @@ -565,16 +569,24 @@ impl DefaultPhysicalPlanner { let get_sort_keys = |expr: &Expr| match expr { Expr::WindowFunction(WindowFunction { - ref partition_by, - ref order_by, + params: + WindowFunctionParams { + ref partition_by, + ref order_by, + .. + }, .. }) => generate_sort_key(partition_by, order_by), Expr::Alias(Alias { expr, .. }) => { // Convert &Box to &T match &**expr { Expr::WindowFunction(WindowFunction { - ref partition_by, - ref order_by, + params: + WindowFunctionParams { + ref partition_by, + ref order_by, + .. + }, .. }) => generate_sort_key(partition_by, order_by), _ => unreachable!(), @@ -1505,11 +1517,14 @@ pub fn create_window_expr_with_name( match e { Expr::WindowFunction(WindowFunction { fun, - args, - partition_by, - order_by, - window_frame, - null_treatment, + params: + WindowFunctionParams { + args, + partition_by, + order_by, + window_frame, + null_treatment, + }, }) => { let physical_args = create_physical_exprs(args, logical_schema, execution_props)?; @@ -1576,11 +1591,14 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter( match e { Expr::AggregateFunction(AggregateFunction { func, - distinct, - args, - filter, - order_by, - null_treatment, + params: + AggregateFunctionParams { + args, + distinct, + filter, + order_by, + null_treatment, + }, }) => { let name = if let Some(name) = name { name diff --git a/datafusion/core/src/schema_equivalence.rs b/datafusion/core/src/schema_equivalence.rs index f0d2acad6be9..70bee206655b 100644 --- a/datafusion/core/src/schema_equivalence.rs +++ b/datafusion/core/src/schema_equivalence.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -use arrow_schema::{DataType, Field, Fields, Schema}; +use arrow::datatypes::{DataType, Field, Fields, Schema}; /// Verifies whether the original planned schema can be satisfied with data /// adhering to the candidate schema. 
In practice, this is equality check on the diff --git a/datafusion/core/src/test/mod.rs b/datafusion/core/src/test/mod.rs index f2fef06c8f30..ba85f9afb6da 100644 --- a/datafusion/core/src/test/mod.rs +++ b/datafusion/core/src/test/mod.rs @@ -93,29 +93,7 @@ pub fn scan_partitioned_csv( let source = Arc::new(CsvSource::new(true, b'"', b'"')); let config = partitioned_csv_config(schema, file_groups, source) .with_file_compression_type(FileCompressionType::UNCOMPRESSED); - Ok(config.new_exec()) -} - -/// Auto finish the wrapped BzEncoder on drop -#[cfg(feature = "compression")] -struct AutoFinishBzEncoder(BzEncoder); - -#[cfg(feature = "compression")] -impl Write for AutoFinishBzEncoder { - fn write(&mut self, buf: &[u8]) -> std::io::Result { - self.0.write(buf) - } - - fn flush(&mut self) -> std::io::Result<()> { - self.0.flush() - } -} - -#[cfg(feature = "compression")] -impl Drop for AutoFinishBzEncoder { - fn drop(&mut self) { - let _ = self.0.try_finish(); - } + Ok(config.build()) } /// Returns file groups [`Vec>`] for scanning `partitions` of `filename` @@ -159,10 +137,9 @@ pub fn partitioned_file_groups( Box::new(encoder) } #[cfg(feature = "compression")] - FileCompressionType::BZIP2 => Box::new(AutoFinishBzEncoder(BzEncoder::new( - file, - BzCompression::default(), - ))), + FileCompressionType::BZIP2 => { + Box::new(BzEncoder::new(file, BzCompression::default())) + } #[cfg(not(feature = "compression"))] FileCompressionType::GZIP | FileCompressionType::BZIP2 diff --git a/datafusion/core/src/test_util/parquet.rs b/datafusion/core/src/test_util/parquet.rs index 67e0e1726917..0e0090ef028e 100644 --- a/datafusion/core/src/test_util/parquet.rs +++ b/datafusion/core/src/test_util/parquet.rs @@ -156,7 +156,7 @@ impl TestParquetFile { ) -> Result> { let parquet_options = ctx.copied_table_options().parquet; let source = Arc::new(ParquetSource::new(parquet_options.clone())); - let mut scan_config = FileScanConfig::new( + let scan_config = FileScanConfig::new( self.object_store_url.clone(), Arc::clone(&self.schema), source, @@ -185,13 +185,12 @@ impl TestParquetFile { Arc::clone(&scan_config.file_schema), Arc::clone(&physical_filter_expr), )); - scan_config = scan_config.with_source(source); - let parquet_exec = scan_config.new_exec(); + let parquet_exec = scan_config.with_source(source).build(); let exec = Arc::new(FilterExec::try_new(physical_filter_expr, parquet_exec)?); Ok(exec) } else { - Ok(scan_config.new_exec()) + Ok(scan_config.build()) } } diff --git a/datafusion/core/tests/catalog/memory.rs b/datafusion/core/tests/catalog/memory.rs index bef23fff3e96..3e45fb753226 100644 --- a/datafusion/core/tests/catalog/memory.rs +++ b/datafusion/core/tests/catalog/memory.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -use arrow_schema::Schema; +use arrow::datatypes::Schema; use datafusion::catalog::CatalogProvider; use datafusion::datasource::empty::EmptyTable; use datafusion::datasource::listing::{ diff --git a/datafusion/core/tests/dataframe/dataframe_functions.rs b/datafusion/core/tests/dataframe/dataframe_functions.rs index 89ec5a5908de..29c24948fbf0 100644 --- a/datafusion/core/tests/dataframe/dataframe_functions.rs +++ b/datafusion/core/tests/dataframe/dataframe_functions.rs @@ -16,12 +16,12 @@ // under the License. 
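Not part of the patch: a minimal sketch of how the reworked `TestParquetFile::create_scan` above composes the pieces, assuming `scan_config` and `predicate` are constructed elsewhere:

```rust
use std::sync::Arc;

use datafusion::datasource::physical_plan::FileScanConfig;
use datafusion::error::Result;
use datafusion::physical_plan::filter::FilterExec;
use datafusion::physical_plan::ExecutionPlan;
use datafusion_physical_expr_common::physical_expr::PhysicalExpr;

// Build the scan with the renamed `build()` and re-apply the predicate on top,
// mirroring the reworked test helper above.
fn filtered_scan(
    scan_config: FileScanConfig,
    predicate: Arc<dyn PhysicalExpr>,
) -> Result<Arc<dyn ExecutionPlan>> {
    let parquet_exec = scan_config.build();
    Ok(Arc::new(FilterExec::try_new(predicate, parquet_exec)?))
}
```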
use arrow::array::{types::Int32Type, ListArray}; +use arrow::datatypes::SchemaRef; use arrow::datatypes::{DataType, Field, Schema}; use arrow::{ array::{Int32Array, StringArray}, record_batch::RecordBatch, }; -use arrow_schema::SchemaRef; use std::sync::Arc; use datafusion::error::Result; diff --git a/datafusion/core/tests/dataframe/mod.rs b/datafusion/core/tests/dataframe/mod.rs index 772d9dbc8f46..d545157607c7 100644 --- a/datafusion/core/tests/dataframe/mod.rs +++ b/datafusion/core/tests/dataframe/mod.rs @@ -26,9 +26,12 @@ use arrow::array::{ StringBuilder, StructBuilder, UInt32Array, UInt32Builder, UnionArray, }; use arrow::buffer::ScalarBuffer; -use arrow::datatypes::{DataType, Field, Float32Type, Int32Type, Schema, UInt64Type}; +use arrow::datatypes::{ + DataType, Field, Float32Type, Int32Type, Schema, SchemaRef, UInt64Type, UnionFields, + UnionMode, +}; +use arrow::error::ArrowError; use arrow::util::pretty::pretty_format_batches; -use arrow_schema::{ArrowError, SchemaRef, UnionFields, UnionMode}; use datafusion_functions_aggregate::count::count_udaf; use datafusion_functions_aggregate::expr_fn::{ array_agg, avg, count, count_distinct, max, median, min, sum, @@ -60,7 +63,7 @@ use datafusion::{assert_batches_eq, assert_batches_sorted_eq}; use datafusion_catalog::TableProvider; use datafusion_common::{ assert_contains, Constraint, Constraints, DataFusionError, ParamValues, ScalarValue, - UnnestOptions, + TableReference, UnnestOptions, }; use datafusion_common_runtime::SpawnedTask; use datafusion_execution::config::SessionConfig; @@ -1614,9 +1617,25 @@ async fn with_column_renamed() -> Result<()> { // accepts table qualifier .with_column_renamed("aggregate_test_100.c2", "two")? // no-op for missing column - .with_column_renamed("c4", "boom")? - .collect() - .await?; + .with_column_renamed("c4", "boom")?; + + let references: Vec<_> = df_sum_renamed + .schema() + .iter() + .map(|(a, _)| a.cloned()) + .collect(); + + assert_eq!( + references, + vec![ + Some(TableReference::bare("aggregate_test_100")), // table name is preserved + Some(TableReference::bare("aggregate_test_100")), + Some(TableReference::bare("aggregate_test_100")), + None // total column + ] + ); + + let batches = &df_sum_renamed.collect().await?; assert_batches_sorted_eq!( [ @@ -1626,7 +1645,7 @@ async fn with_column_renamed() -> Result<()> { "| a | 3 | -72 | -69 |", "+-----+-----+-----+-------+", ], - &df_sum_renamed + batches ); Ok(()) @@ -5271,3 +5290,55 @@ async fn register_non_parquet_file() { "1.json' does not match the expected extension '.parquet'" ); } + +// Test inserting into checking. +#[tokio::test] +async fn test_insert_into_checking() -> Result<()> { + // Create a new schema with one field called "a" of type Int64, and setting nullable to false + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, false)])); + + let session_ctx = SessionContext::new(); + + // Create and register the initial table with the provided schema and data + let initial_table = Arc::new(MemTable::try_new(schema.clone(), vec![vec![]])?); + session_ctx.register_table("t", initial_table.clone())?; + + // There are two cases we need to check + // 1. The len of the schema of the plan and the schema of the table should be the same + // 2. 
The datatype of the schema of the plan and the schema of the table should be the same + + // Test case 1: + let write_df = session_ctx.sql("values (1, 2), (3, 4)").await.unwrap(); + + let e = write_df + .write_table("t", DataFrameWriteOptions::new()) + .await + .unwrap_err(); + + assert_contains!( + e.to_string(), + "Inserting query must have the same schema length as the table." + ); + + // Setting nullable to true + // Make sure the nullable check go through + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, true)])); + + let session_ctx = SessionContext::new(); + + // Create and register the initial table with the provided schema and data + let initial_table = Arc::new(MemTable::try_new(schema.clone(), vec![vec![]])?); + session_ctx.register_table("t", initial_table.clone())?; + + // Test case 2: + let write_df = session_ctx.sql("values ('a123'), ('b456')").await.unwrap(); + + let e = write_df + .write_table("t", DataFrameWriteOptions::new()) + .await + .unwrap_err(); + + assert_contains!(e.to_string(), "Inserting query schema mismatch: Expected table field 'a' with type Int64, but got 'column1' with type Utf8"); + + Ok(()) +} diff --git a/datafusion/core/tests/execution/logical_plan.rs b/datafusion/core/tests/execution/logical_plan.rs index 86acbe16474c..a17bb5eec8a3 100644 --- a/datafusion/core/tests/execution/logical_plan.rs +++ b/datafusion/core/tests/execution/logical_plan.rs @@ -16,11 +16,11 @@ // under the License. use arrow::array::Int64Array; -use arrow_schema::{DataType, Field}; +use arrow::datatypes::{DataType, Field}; use datafusion::execution::session_state::SessionStateBuilder; use datafusion_common::{Column, DFSchema, Result, ScalarValue, Spans}; use datafusion_execution::TaskContext; -use datafusion_expr::expr::AggregateFunction; +use datafusion_expr::expr::{AggregateFunction, AggregateFunctionParams}; use datafusion_expr::logical_plan::{LogicalPlan, Values}; use datafusion_expr::{Aggregate, AggregateUDF, Expr}; use datafusion_functions_aggregate::count::Count; @@ -60,11 +60,13 @@ async fn count_only_nulls() -> Result<()> { vec![], vec![Expr::AggregateFunction(AggregateFunction { func: Arc::new(AggregateUDF::new_from_impl(Count::new())), - args: vec![input_col_ref], - distinct: false, - filter: None, - order_by: None, - null_treatment: None, + params: AggregateFunctionParams { + args: vec![input_col_ref], + distinct: false, + filter: None, + order_by: None, + null_treatment: None, + }, })], )?); diff --git a/datafusion/core/tests/expr_api/mod.rs b/datafusion/core/tests/expr_api/mod.rs index 8f8ca21c206d..7c0119e8ae83 100644 --- a/datafusion/core/tests/expr_api/mod.rs +++ b/datafusion/core/tests/expr_api/mod.rs @@ -19,8 +19,8 @@ use arrow::array::{ builder::{ListBuilder, StringBuilder}, ArrayRef, Int64Array, RecordBatch, StringArray, StructArray, }; +use arrow::datatypes::{DataType, Field}; use arrow::util::pretty::{pretty_format_batches, pretty_format_columns}; -use arrow_schema::{DataType, Field}; use datafusion::prelude::*; use datafusion_common::{DFSchema, ScalarValue}; use datafusion_expr::ExprFunctionExt; diff --git a/datafusion/core/tests/expr_api/parse_sql_expr.rs b/datafusion/core/tests/expr_api/parse_sql_expr.rs index cc049f0004d9..92c18204324f 100644 --- a/datafusion/core/tests/expr_api/parse_sql_expr.rs +++ b/datafusion/core/tests/expr_api/parse_sql_expr.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. 
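For code that constructs aggregate expressions directly, the `count_only_nulls` change above moves the argument fields into a nested `params` struct. A minimal sketch of the new construction pattern (the column name `a` is illustrative):

```rust
use datafusion_expr::expr::{AggregateFunction, AggregateFunctionParams};
use datafusion_expr::{col, Expr};
use datafusion_functions_aggregate::count::count_udaf;

// Builds `count(a)` by filling the nested `params` struct explicitly,
// instead of the former flattened fields on `AggregateFunction`.
fn count_a() -> Expr {
    Expr::AggregateFunction(AggregateFunction {
        func: count_udaf(),
        params: AggregateFunctionParams {
            args: vec![col("a")],
            distinct: false,
            filter: None,
            order_by: None,
            null_treatment: None,
        },
    })
}
```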
-use arrow_schema::{DataType, Field, Schema}; +use arrow::datatypes::{DataType, Field, Schema}; use datafusion::prelude::{CsvReadOptions, SessionContext}; use datafusion_common::DFSchema; use datafusion_common::{DFSchemaRef, Result, ToDFSchema}; diff --git a/datafusion/core/tests/expr_api/simplification.rs b/datafusion/core/tests/expr_api/simplification.rs index 83e96bffdf48..7bb21725ef40 100644 --- a/datafusion/core/tests/expr_api/simplification.rs +++ b/datafusion/core/tests/expr_api/simplification.rs @@ -365,6 +365,33 @@ fn test_const_evaluator() { ); } +#[test] +fn test_const_evaluator_alias() { + // true --> true + test_evaluate(lit(true).alias("a"), lit(true)); + // true or true --> true + test_evaluate(lit(true).alias("a").or(lit(true).alias("b")), lit(true)); + // "foo" == "foo" --> true + test_evaluate(lit("foo").alias("a").eq(lit("foo").alias("b")), lit(true)); + // c = 1 + 2 --> c + 3 + test_evaluate( + col("c") + .alias("a") + .eq(lit(1).alias("b") + lit(2).alias("c")), + col("c").alias("a").eq(lit(3)), + ); + // (foo != foo) OR (c = 1) --> false OR (c = 1) + test_evaluate( + lit("foo") + .alias("a") + .not_eq(lit("foo").alias("b")) + .alias("c") + .or(col("c").alias("d").eq(lit(1).alias("e"))) + .alias("f"), + col("c").alias("d").eq(lit(1)).alias("f"), + ); +} + #[test] fn test_const_evaluator_scalar_functions() { // concat("foo", "bar") --> "foobar" diff --git a/datafusion/core/tests/fifo/mod.rs b/datafusion/core/tests/fifo/mod.rs index cb587e3510c2..141a3f3b7558 100644 --- a/datafusion/core/tests/fifo/mod.rs +++ b/datafusion/core/tests/fifo/mod.rs @@ -28,8 +28,7 @@ mod unix_test { use arrow::array::Array; use arrow::csv::ReaderBuilder; - use arrow::datatypes::{DataType, Field, Schema}; - use arrow_schema::SchemaRef; + use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use datafusion::datasource::stream::{FileStreamProvider, StreamConfig, StreamTable}; use datafusion::datasource::TableProvider; use datafusion::{ diff --git a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs index 0257850ffc50..5e1f263b4c76 100644 --- a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs @@ -24,12 +24,11 @@ use crate::fuzz_cases::aggregation_fuzzer::{ use arrow::array::{types::Int64Type, Array, ArrayRef, AsArray, Int64Array, RecordBatch}; use arrow::compute::{concat_batches, SortOptions}; -use arrow::datatypes::DataType; -use arrow::util::pretty::pretty_format_batches; -use arrow_schema::{ - IntervalUnit, TimeUnit, DECIMAL128_MAX_PRECISION, DECIMAL128_MAX_SCALE, +use arrow::datatypes::{ + DataType, IntervalUnit, TimeUnit, DECIMAL128_MAX_PRECISION, DECIMAL128_MAX_SCALE, DECIMAL256_MAX_PRECISION, DECIMAL256_MAX_SCALE, }; +use arrow::util::pretty::pretty_format_batches; use datafusion::common::Result; use datafusion::datasource::MemTable; use datafusion::physical_expr::aggregate::AggregateExprBuilder; diff --git a/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/context_generator.rs b/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/context_generator.rs index 9c8f83f75ccb..8a8aa180b3c4 100644 --- a/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/context_generator.rs +++ b/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/context_generator.rs @@ -254,7 +254,7 @@ impl SkipPartialParams { #[cfg(test)] mod test { use arrow::array::{RecordBatch, StringArray, UInt32Array}; - use arrow_schema::{DataType, Field, Schema}; + use arrow::datatypes::{DataType, Field, Schema}; use 
crate::fuzz_cases::aggregation_fuzzer::check_equality_of_batches; diff --git a/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/data_generator.rs b/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/data_generator.rs index 3ebd899f4e15..4d4c6aa79357 100644 --- a/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/data_generator.rs +++ b/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/data_generator.rs @@ -19,15 +19,15 @@ use std::sync::Arc; use arrow::array::{ArrayRef, RecordBatch}; use arrow::datatypes::{ - BinaryType, BinaryViewType, BooleanType, ByteArrayType, ByteViewType, Date32Type, - Date64Type, Decimal128Type, Decimal256Type, Float32Type, Float64Type, Int16Type, - Int32Type, Int64Type, Int8Type, IntervalDayTimeType, IntervalMonthDayNanoType, - IntervalYearMonthType, LargeBinaryType, LargeUtf8Type, StringViewType, - Time32MillisecondType, Time32SecondType, Time64MicrosecondType, Time64NanosecondType, - TimestampMicrosecondType, TimestampMillisecondType, TimestampNanosecondType, - TimestampSecondType, UInt16Type, UInt32Type, UInt64Type, UInt8Type, Utf8Type, + BinaryType, BinaryViewType, BooleanType, ByteArrayType, ByteViewType, DataType, + Date32Type, Date64Type, Decimal128Type, Decimal256Type, Field, Float32Type, + Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, IntervalDayTimeType, + IntervalMonthDayNanoType, IntervalUnit, IntervalYearMonthType, LargeBinaryType, + LargeUtf8Type, Schema, StringViewType, Time32MillisecondType, Time32SecondType, + Time64MicrosecondType, Time64NanosecondType, TimeUnit, TimestampMicrosecondType, + TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, UInt16Type, + UInt32Type, UInt64Type, UInt8Type, Utf8Type, }; -use arrow_schema::{DataType, Field, IntervalUnit, Schema, TimeUnit}; use datafusion_common::{arrow_datafusion_err, DataFusionError, Result}; use datafusion_physical_expr::{expressions::col, PhysicalSortExpr}; use datafusion_physical_expr_common::sort_expr::LexOrdering; diff --git a/datafusion/core/tests/fuzz_cases/equivalence/ordering.rs b/datafusion/core/tests/fuzz_cases/equivalence/ordering.rs index cd9897d43baa..769deef1187d 100644 --- a/datafusion/core/tests/fuzz_cases/equivalence/ordering.rs +++ b/datafusion/core/tests/fuzz_cases/equivalence/ordering.rs @@ -20,7 +20,7 @@ use crate::fuzz_cases::equivalence::utils::{ generate_table_for_eq_properties, generate_table_for_orderings, is_table_same_after_sort, TestScalarUDF, }; -use arrow_schema::SortOptions; +use arrow::compute::SortOptions; use datafusion_common::Result; use datafusion_expr::{Operator, ScalarUDF}; use datafusion_physical_expr::expressions::{col, BinaryExpr}; diff --git a/datafusion/core/tests/fuzz_cases/equivalence/projection.rs b/datafusion/core/tests/fuzz_cases/equivalence/projection.rs index 78fbda16c0a0..a3fa1157b38f 100644 --- a/datafusion/core/tests/fuzz_cases/equivalence/projection.rs +++ b/datafusion/core/tests/fuzz_cases/equivalence/projection.rs @@ -19,7 +19,7 @@ use crate::fuzz_cases::equivalence::utils::{ apply_projection, create_random_schema, generate_table_for_eq_properties, is_table_same_after_sort, TestScalarUDF, }; -use arrow_schema::SortOptions; +use arrow::compute::SortOptions; use datafusion_common::Result; use datafusion_expr::{Operator, ScalarUDF}; use datafusion_physical_expr::equivalence::ProjectionMapping; diff --git a/datafusion/core/tests/fuzz_cases/equivalence/utils.rs b/datafusion/core/tests/fuzz_cases/equivalence/utils.rs index b66b7b2aca43..d4b41b686631 100644 --- 
a/datafusion/core/tests/fuzz_cases/equivalence/utils.rs +++ b/datafusion/core/tests/fuzz_cases/equivalence/utils.rs @@ -23,13 +23,15 @@ use std::cmp::Ordering; use std::sync::Arc; use arrow::array::{ArrayRef, Float32Array, Float64Array, RecordBatch, UInt32Array}; +use arrow::compute::SortOptions; use arrow::compute::{lexsort_to_indices, take_record_batch, SortColumn}; -use arrow::datatypes::{DataType, Field, Schema}; -use arrow_schema::{SchemaRef, SortOptions}; +use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use datafusion_common::utils::{compare_rows, get_row_at_idx}; use datafusion_common::{exec_err, plan_datafusion_err, DataFusionError, Result}; use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; -use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; +use datafusion_expr::{ + ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, +}; use datafusion_physical_expr::equivalence::{EquivalenceClass, ProjectionMapping}; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; use datafusion_physical_expr_common::sort_expr::LexOrdering; @@ -581,12 +583,8 @@ impl ScalarUDFImpl for TestScalarUDF { Ok(input[0].sort_properties) } - fn invoke_batch( - &self, - args: &[ColumnarValue], - _number_rows: usize, - ) -> Result { - let args = ColumnarValue::values_to_arrays(args)?; + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + let args = ColumnarValue::values_to_arrays(&args.args)?; let arr: ArrayRef = match args[0].data_type() { DataType::Float64 => Arc::new({ diff --git a/datafusion/core/tests/fuzz_cases/join_fuzz.rs b/datafusion/core/tests/fuzz_cases/join_fuzz.rs index 8e8178e55d87..5dd29f90ef83 100644 --- a/datafusion/core/tests/fuzz_cases/join_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/join_fuzz.rs @@ -22,9 +22,9 @@ use crate::fuzz_cases::join_fuzz::JoinTestType::{HjSmj, NljHj}; use arrow::array::{ArrayRef, Int32Array}; use arrow::compute::SortOptions; +use arrow::datatypes::Schema; use arrow::record_batch::RecordBatch; use arrow::util::pretty::pretty_format_batches; -use arrow_schema::Schema; use datafusion::common::JoinSide; use datafusion::logical_expr::{JoinType, Operator}; use datafusion::physical_expr::expressions::BinaryExpr; diff --git a/datafusion/core/tests/fuzz_cases/limit_fuzz.rs b/datafusion/core/tests/fuzz_cases/limit_fuzz.rs index a73845c56a0f..987a732eb294 100644 --- a/datafusion/core/tests/fuzz_cases/limit_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/limit_fuzz.rs @@ -19,8 +19,8 @@ use arrow::array::{Float64Array, Int32Array, Int64Array, RecordBatch, StringArray}; use arrow::compute::concat_batches; +use arrow::datatypes::SchemaRef; use arrow::util::pretty::pretty_format_batches; -use arrow_schema::SchemaRef; use datafusion::datasource::MemTable; use datafusion::prelude::SessionContext; use datafusion_common::assert_contains; diff --git a/datafusion/core/tests/fuzz_cases/pruning.rs b/datafusion/core/tests/fuzz_cases/pruning.rs index fef009fa911c..c6876d4a7e96 100644 --- a/datafusion/core/tests/fuzz_cases/pruning.rs +++ b/datafusion/core/tests/fuzz_cases/pruning.rs @@ -18,7 +18,7 @@ use std::sync::{Arc, OnceLock}; use arrow::array::{Array, RecordBatch, StringArray}; -use arrow_schema::{DataType, Field, Schema}; +use arrow::datatypes::{DataType, Field, Schema}; use bytes::{BufMut, Bytes, BytesMut}; use datafusion::{ datasource::{ @@ -110,6 +110,13 @@ async fn test_utf8_not_like_prefix() { .await; } +#[tokio::test] +async fn test_utf8_not_like_ecsape() { + Utf8Test::new(|value| 
col("a").not_like(lit(format!("\\%{}%", value)))) + .run() + .await; +} + #[tokio::test] async fn test_utf8_not_like_suffix() { Utf8Test::new(|value| col("a").not_like(lit(format!("{}%", value)))) @@ -117,6 +124,13 @@ async fn test_utf8_not_like_suffix() { .await; } +#[tokio::test] +async fn test_utf8_not_like_suffix_one() { + Utf8Test::new(|value| col("a").not_like(lit(format!("{}_", value)))) + .run() + .await; +} + /// Fuzz testing for UTF8 predicate pruning /// The basic idea is that query results should always be the same with or without stats/pruning /// If we get this right we at least guarantee that there are no incorrect results @@ -321,7 +335,7 @@ async fn execute_with_predicate( }) .collect(), ); - let exec = scan.new_exec(); + let exec = scan.build(); let exec = Arc::new(FilterExec::try_new(predicate, exec).unwrap()) as Arc; diff --git a/datafusion/core/tests/fuzz_cases/sort_fuzz.rs b/datafusion/core/tests/fuzz_cases/sort_fuzz.rs index ecc077261acc..51a5bc87efd9 100644 --- a/datafusion/core/tests/fuzz_cases/sort_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/sort_fuzz.rs @@ -20,7 +20,7 @@ use std::sync::Arc; use arrow::{ - array::{ArrayRef, Int32Array}, + array::{as_string_array, ArrayRef, Int32Array, StringArray}, compute::SortOptions, record_batch::RecordBatch, }; @@ -29,6 +29,7 @@ use datafusion::physical_plan::expressions::PhysicalSortExpr; use datafusion::physical_plan::sorts::sort::SortExec; use datafusion::physical_plan::{collect, ExecutionPlan}; use datafusion::prelude::{SessionConfig, SessionContext}; +use datafusion_common::cast::as_int32_array; use datafusion_execution::memory_pool::GreedyMemoryPool; use datafusion_physical_expr::expressions::col; use datafusion_physical_expr_common::sort_expr::LexOrdering; @@ -42,42 +43,139 @@ const KB: usize = 1 << 10; #[cfg_attr(tarpaulin, ignore)] async fn test_sort_10k_mem() { for (batch_size, should_spill) in [(5, false), (20000, true), (500000, true)] { - SortTest::new() + let (input, collected) = SortTest::new() .with_int32_batches(batch_size) + .with_sort_columns(vec!["x"]) .with_pool_size(10 * KB) .with_should_spill(should_spill) .run() .await; + + let expected = partitions_to_sorted_vec(&input); + let actual = batches_to_vec(&collected); + assert_eq!(expected, actual, "failure in @ batch_size {batch_size:?}"); } } #[tokio::test] #[cfg_attr(tarpaulin, ignore)] async fn test_sort_100k_mem() { - for (batch_size, should_spill) in [(5, false), (20000, false), (1000000, true)] { - SortTest::new() + for (batch_size, should_spill) in + [(5, false), (10000, false), (20000, true), (1000000, true)] + { + let (input, collected) = SortTest::new() .with_int32_batches(batch_size) + .with_sort_columns(vec!["x"]) .with_pool_size(100 * KB) .with_should_spill(should_spill) .run() .await; + + let expected = partitions_to_sorted_vec(&input); + let actual = batches_to_vec(&collected); + assert_eq!(expected, actual, "failure in @ batch_size {batch_size:?}"); + } +} + +#[tokio::test] +#[cfg_attr(tarpaulin, ignore)] +async fn test_sort_strings_100k_mem() { + for (batch_size, should_spill) in + [(5, false), (1000, false), (10000, true), (20000, true)] + { + let (input, collected) = SortTest::new() + .with_utf8_batches(batch_size) + .with_sort_columns(vec!["x"]) + .with_pool_size(100 * KB) + .with_should_spill(should_spill) + .run() + .await; + + let mut input = input + .iter() + .flat_map(|p| p.iter()) + .flat_map(|b| { + let array = b.column(0); + as_string_array(array) + .iter() + .map(|s| s.unwrap().to_string()) + }) + .collect::>(); + 
input.sort_unstable(); + let actual = collected + .iter() + .flat_map(|b| { + let array = b.column(0); + as_string_array(array) + .iter() + .map(|s| s.unwrap().to_string()) + }) + .collect::>(); + assert_eq!(input, actual); + } +} + +#[tokio::test] +#[cfg_attr(tarpaulin, ignore)] +async fn test_sort_multi_columns_100k_mem() { + for (batch_size, should_spill) in + [(5, false), (1000, false), (10000, true), (20000, true)] + { + let (input, collected) = SortTest::new() + .with_int32_utf8_batches(batch_size) + .with_sort_columns(vec!["x", "y"]) + .with_pool_size(100 * KB) + .with_should_spill(should_spill) + .run() + .await; + + fn record_batch_to_vec(b: &RecordBatch) -> Vec<(i32, String)> { + let mut rows: Vec<_> = Vec::new(); + let i32_array = as_int32_array(b.column(0)).unwrap(); + let string_array = as_string_array(b.column(1)); + for i in 0..b.num_rows() { + let str = string_array.value(i).to_string(); + let i32 = i32_array.value(i); + rows.push((i32, str)); + } + rows + } + let mut input = input + .iter() + .flat_map(|p| p.iter()) + .flat_map(record_batch_to_vec) + .collect::>(); + input.sort_unstable(); + let actual = collected + .iter() + .flat_map(record_batch_to_vec) + .collect::>(); + assert_eq!(input, actual); } } #[tokio::test] async fn test_sort_unlimited_mem() { for (batch_size, should_spill) in [(5, false), (20000, false), (1000000, false)] { - SortTest::new() + let (input, collected) = SortTest::new() .with_int32_batches(batch_size) + .with_sort_columns(vec!["x"]) .with_pool_size(usize::MAX) .with_should_spill(should_spill) .run() .await; + + let expected = partitions_to_sorted_vec(&input); + let actual = batches_to_vec(&collected); + assert_eq!(expected, actual, "failure in @ batch_size {batch_size:?}"); } } + #[derive(Debug, Default)] struct SortTest { input: Vec>, + /// The names of the columns to sort by + sort_columns: Vec, /// GreedyMemoryPool size, if specified pool_size: Option, /// If true, expect the sort to spill @@ -89,12 +187,29 @@ impl SortTest { Default::default() } + fn with_sort_columns(mut self, sort_columns: Vec<&str>) -> Self { + self.sort_columns = sort_columns.iter().map(|s| s.to_string()).collect(); + self + } + /// Create batches of int32 values of rows fn with_int32_batches(mut self, rows: usize) -> Self { self.input = vec![make_staggered_i32_batches(rows)]; self } + /// Create batches of utf8 values of rows + fn with_utf8_batches(mut self, rows: usize) -> Self { + self.input = vec![make_staggered_utf8_batches(rows)]; + self + } + + /// Create batches of int32 and utf8 values of rows + fn with_int32_utf8_batches(mut self, rows: usize) -> Self { + self.input = vec![make_staggered_i32_utf8_batches(rows)]; + self + } + /// specify that this test should use a memory pool of the specified size fn with_pool_size(mut self, pool_size: usize) -> Self { self.pool_size = Some(pool_size); @@ -108,7 +223,7 @@ impl SortTest { /// Sort the input using SortExec and ensure the results are /// correct according to `Vec::sort` both with and without spilling - async fn run(&self) { + async fn run(&self) -> (Vec>, Vec) { let input = self.input.clone(); let first_batch = input .iter() @@ -117,16 +232,21 @@ impl SortTest { .expect("at least one batch"); let schema = first_batch.schema(); - let sort = LexOrdering::new(vec![PhysicalSortExpr { - expr: col("x", &schema).unwrap(), - options: SortOptions { - descending: false, - nulls_first: true, - }, - }]); + let sort_ordering = LexOrdering::new( + self.sort_columns + .iter() + .map(|c| PhysicalSortExpr { + expr: col(c, 
&schema).unwrap(), + options: SortOptions { + descending: false, + nulls_first: true, + }, + }) + .collect(), + ); let exec = MemorySourceConfig::try_new_exec(&input, schema, None).unwrap(); - let sort = Arc::new(SortExec::new(sort, exec)); + let sort = Arc::new(SortExec::new(sort_ordering, exec)); let session_config = SessionConfig::new(); let session_ctx = if let Some(pool_size) = self.pool_size { @@ -151,9 +271,6 @@ impl SortTest { let task_ctx = session_ctx.task_ctx(); let collected = collect(sort.clone(), task_ctx).await.unwrap(); - let expected = partitions_to_sorted_vec(&input); - let actual = batches_to_vec(&collected); - if self.should_spill { assert_ne!( sort.metrics().unwrap().spill_count().unwrap(), @@ -173,7 +290,8 @@ impl SortTest { 0, "The sort should have returned all memory used back to the memory pool" ); - assert_eq!(expected, actual, "failure in @ pool_size {self:?}"); + + (input, collected) } } @@ -201,3 +319,63 @@ fn make_staggered_i32_batches(len: usize) -> Vec { } batches } + +/// Return randomly sized record batches in a field named 'x' of type `Utf8` +/// with randomized content +fn make_staggered_utf8_batches(len: usize) -> Vec { + let mut rng = rand::thread_rng(); + let max_batch = 1024; + + let mut batches = vec![]; + let mut remaining = len; + while remaining != 0 { + let to_read = rng.gen_range(0..=remaining.min(max_batch)); + remaining -= to_read; + + batches.push( + RecordBatch::try_from_iter(vec![( + "x", + Arc::new(StringArray::from_iter_values( + (0..to_read).map(|_| format!("test_string_{}", rng.gen::())), + )) as ArrayRef, + )]) + .unwrap(), + ) + } + batches +} + +/// Return randomly sized record batches in a field named 'x' of type `Int32` +/// with randomized i32 content and a field named 'y' of type `Utf8` +/// with randomized content +fn make_staggered_i32_utf8_batches(len: usize) -> Vec { + let mut rng = rand::thread_rng(); + let max_batch = 1024; + + let mut batches = vec![]; + let mut remaining = len; + while remaining != 0 { + let to_read = rng.gen_range(0..=remaining.min(max_batch)); + remaining -= to_read; + + batches.push( + RecordBatch::try_from_iter(vec![ + ( + "x", + Arc::new(Int32Array::from_iter_values( + (0..to_read).map(|_| rng.gen()), + )) as ArrayRef, + ), + ( + "y", + Arc::new(StringArray::from_iter_values( + (0..to_read).map(|_| format!("test_string_{}", rng.gen::())), + )) as ArrayRef, + ), + ]) + .unwrap(), + ) + } + + batches +} diff --git a/datafusion/core/tests/fuzz_cases/sort_preserving_repartition_fuzz.rs b/datafusion/core/tests/fuzz_cases/sort_preserving_repartition_fuzz.rs index 8ffc78a9f59d..d23408743f9f 100644 --- a/datafusion/core/tests/fuzz_cases/sort_preserving_repartition_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/sort_preserving_repartition_fuzz.rs @@ -20,8 +20,8 @@ mod sp_repartition_fuzz_tests { use std::sync::Arc; use arrow::array::{ArrayRef, Int64Array, RecordBatch, UInt64Array}; - use arrow::compute::{concat_batches, lexsort, SortColumn}; - use arrow_schema::{DataType, Field, Schema, SchemaRef, SortOptions}; + use arrow::compute::{concat_batches, lexsort, SortColumn, SortOptions}; + use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use datafusion::physical_plan::{ collect, diff --git a/datafusion/core/tests/memory_limit/mod.rs b/datafusion/core/tests/memory_limit/mod.rs index b7c656627187..669294d38af1 100644 --- a/datafusion/core/tests/memory_limit/mod.rs +++ b/datafusion/core/tests/memory_limit/mod.rs @@ -24,8 +24,8 @@ use std::sync::{Arc, LazyLock}; #[cfg(feature = "extended_tests")] mod 
memory_limit_validation; use arrow::array::{ArrayRef, DictionaryArray, RecordBatch}; +use arrow::compute::SortOptions; use arrow::datatypes::{Int32Type, SchemaRef}; -use arrow_schema::SortOptions; use datafusion::assert_batches_eq; use datafusion::datasource::{MemTable, TableProvider}; use datafusion::execution::disk_manager::DiskManagerConfig; @@ -69,7 +69,7 @@ async fn oom_sort() { .with_expected_errors(vec![ "Resources exhausted: Memory Exhausted while Sorting (DiskManager is disabled)", ]) - .with_memory_limit(200_000) + .with_memory_limit(500_000) .run() .await } @@ -271,7 +271,8 @@ async fn sort_spill_reservation() { // Merge operation needs extra memory to do row conversion, so make the // memory limit larger. - let mem_limit = partition_size * 2; + let mem_limit = + ((partition_size * 2 + 1024) as f64 / MEMORY_FRACTION).ceil() as usize; let test = TestCase::new() // This query uses a different order than the input table to // force a sort. It also needs to have multiple columns to @@ -308,7 +309,8 @@ async fn sort_spill_reservation() { test.clone() .with_expected_errors(vec![ - "Resources exhausted: Additional allocation failed with top memory consumers (across reservations) as: ExternalSorterMerge", + "Resources exhausted: Additional allocation failed with top memory consumers (across reservations) as:", + "bytes for ExternalSorterMerge", ]) .with_config(config) .run() diff --git a/datafusion/core/tests/optimizer/mod.rs b/datafusion/core/tests/optimizer/mod.rs index 37a6ca7f5934..585540bd5875 100644 --- a/datafusion/core/tests/optimizer/mod.rs +++ b/datafusion/core/tests/optimizer/mod.rs @@ -22,8 +22,9 @@ use std::any::Any; use std::collections::HashMap; use std::sync::Arc; -use arrow::datatypes::{DataType, Field, Schema, SchemaRef, TimeUnit}; -use arrow_schema::{Fields, SchemaBuilder}; +use arrow::datatypes::{ + DataType, Field, Fields, Schema, SchemaBuilder, SchemaRef, TimeUnit, +}; use datafusion_common::config::ConfigOptions; use datafusion_common::tree_node::{TransformedResult, TreeNode}; use datafusion_common::{plan_err, DFSchema, Result, ScalarValue, TableReference}; diff --git a/datafusion/core/tests/parquet/custom_reader.rs b/datafusion/core/tests/parquet/custom_reader.rs index 928b650e0300..b12b3be2d435 100644 --- a/datafusion/core/tests/parquet/custom_reader.rs +++ b/datafusion/core/tests/parquet/custom_reader.rs @@ -90,7 +90,7 @@ async fn route_data_access_ops_to_parquet_file_reader_factory() { ) .with_file_group(file_group); - let parquet_exec = base_config.new_exec(); + let parquet_exec = base_config.build(); let session_ctx = SessionContext::new(); let task_ctx = session_ctx.task_ctx(); diff --git a/datafusion/core/tests/parquet/external_access_plan.rs b/datafusion/core/tests/parquet/external_access_plan.rs index 216f03aac746..1eacbe42c525 100644 --- a/datafusion/core/tests/parquet/external_access_plan.rs +++ b/datafusion/core/tests/parquet/external_access_plan.rs @@ -23,8 +23,8 @@ use std::sync::Arc; use crate::parquet::utils::MetricsFinder; use crate::parquet::{create_data_batch, Scenario}; +use arrow::datatypes::SchemaRef; use arrow::util::pretty::pretty_format_batches; -use arrow_schema::SchemaRef; use datafusion::common::Result; use datafusion::datasource::listing::PartitionedFile; use datafusion::datasource::physical_plan::parquet::{ParquetAccessPlan, RowGroupAccess}; @@ -351,7 +351,7 @@ impl TestFull { let config = FileScanConfig::new(object_store_url, schema.clone(), source) .with_file(partitioned_file); - let plan: Arc = config.new_exec(); + let plan: Arc = 
config.build(); // run the DataSourceExec and collect the results let results = diff --git a/datafusion/core/tests/parquet/row_group_pruning.rs b/datafusion/core/tests/parquet/row_group_pruning.rs index d8ce2970bdf7..5a85f47c015a 100644 --- a/datafusion/core/tests/parquet/row_group_pruning.rs +++ b/datafusion/core/tests/parquet/row_group_pruning.rs @@ -1506,9 +1506,6 @@ async fn test_bloom_filter_binary_dict() { .await; } -// Makes sense to enable (or at least try to) after -// https://github.com/apache/datafusion/issues/13821 -#[ignore] #[tokio::test] async fn test_bloom_filter_decimal_dict() { RowGroupPruningTest::new() diff --git a/datafusion/core/tests/parquet/schema_coercion.rs b/datafusion/core/tests/parquet/schema_coercion.rs index 9175a6e91e91..4cbbcf12f32b 100644 --- a/datafusion/core/tests/parquet/schema_coercion.rs +++ b/datafusion/core/tests/parquet/schema_coercion.rs @@ -21,8 +21,7 @@ use arrow::array::{ types::Int32Type, ArrayRef, DictionaryArray, Float32Array, Int64Array, RecordBatch, StringArray, }; -use arrow::datatypes::{Field, Schema}; -use arrow_schema::DataType; +use arrow::datatypes::{DataType, Field, Schema}; use datafusion::assert_batches_sorted_eq; use datafusion::datasource::physical_plan::{FileScanConfig, ParquetSource}; use datafusion::physical_plan::collect; @@ -65,7 +64,7 @@ async fn multi_parquet_coercion() { FileScanConfig::new(ObjectStoreUrl::local_filesystem(), file_schema, source) .with_file_group(file_group); - let parquet_exec = conf.new_exec(); + let parquet_exec = conf.build(); let session_ctx = SessionContext::new(); let task_ctx = session_ctx.task_ctx(); @@ -122,7 +121,7 @@ async fn multi_parquet_coercion_projection() { ) .with_file_group(file_group) .with_projection(Some(vec![1, 0, 2])) - .new_exec(); + .build(); let session_ctx = SessionContext::new(); let task_ctx = session_ctx.task_ctx(); diff --git a/datafusion/core/tests/physical_optimizer/enforce_distribution.rs b/datafusion/core/tests/physical_optimizer/enforce_distribution.rs index 0fdb09b0d079..4e87e908251c 100644 --- a/datafusion/core/tests/physical_optimizer/enforce_distribution.rs +++ b/datafusion/core/tests/physical_optimizer/enforce_distribution.rs @@ -182,7 +182,7 @@ fn parquet_exec_multiple_sorted( vec![PartitionedFile::new("y".to_string(), 100)], ]) .with_output_ordering(output_ordering) - .new_exec() + .build() } fn csv_exec() -> Arc { @@ -197,7 +197,7 @@ fn csv_exec_with_sort(output_ordering: Vec) -> Arc ) .with_file(PartitionedFile::new("x".to_string(), 100)) .with_output_ordering(output_ordering) - .new_exec() + .build() } fn csv_exec_multiple() -> Arc { @@ -216,7 +216,7 @@ fn csv_exec_multiple_sorted(output_ordering: Vec) -> Arc Result<()> { ) .with_file(PartitionedFile::new("x".to_string(), 100)) .with_file_compression_type(compression_type) - .new_exec(), + .build(), vec![("a".to_string(), "a".to_string())], ); assert_optimized!(expected, plan, true, 2, true, 10, false); diff --git a/datafusion/core/tests/physical_optimizer/enforce_sorting.rs b/datafusion/core/tests/physical_optimizer/enforce_sorting.rs index 920d1ecd3ee5..d8b9df633277 100644 --- a/datafusion/core/tests/physical_optimizer/enforce_sorting.rs +++ b/datafusion/core/tests/physical_optimizer/enforce_sorting.rs @@ -24,13 +24,13 @@ use crate::physical_optimizer::test_utils::{ create_test_schema3, create_test_schema4, filter_exec, global_limit_exec, hash_join_exec, limit_exec, local_limit_exec, memory_exec, parquet_exec, repartition_exec, sort_exec, sort_expr, sort_expr_options, sort_merge_join_exec, - 
sort_preserving_merge_exec, spr_repartition_exec, stream_exec_ordered, union_exec, - RequirementsTestExec, + sort_preserving_merge_exec, sort_preserving_merge_exec_with_fetch, + spr_repartition_exec, stream_exec_ordered, union_exec, RequirementsTestExec, }; use datafusion_physical_plan::displayable; use arrow::compute::SortOptions; -use arrow_schema::SchemaRef; +use arrow::datatypes::SchemaRef; use datafusion_common::Result; use datafusion_expr::JoinType; use datafusion_physical_expr::expressions::{col, Column, NotExpr}; @@ -69,7 +69,7 @@ fn csv_exec_ordered( ) .with_file(PartitionedFile::new("file_path".to_string(), 100)) .with_output_ordering(vec![sort_exprs]) - .new_exec() + .build() } /// Created a sorted parquet exec @@ -87,7 +87,7 @@ pub fn parquet_exec_sorted( ) .with_file(PartitionedFile::new("x".to_string(), 100)) .with_output_ordering(vec![sort_exprs]) - .new_exec() + .build() } /// Create a sorted Csv exec @@ -104,7 +104,7 @@ fn csv_exec_sorted( ) .with_file(PartitionedFile::new("x".to_string(), 100)) .with_output_ordering(vec![sort_exprs]) - .new_exec() + .build() } /// Runs the sort enforcement optimizer and asserts the plan @@ -1941,6 +1941,30 @@ async fn test_remove_unnecessary_spm1() -> Result<()> { Ok(()) } +#[tokio::test] +async fn test_remove_unnecessary_spm2() -> Result<()> { + let schema = create_test_schema()?; + let source = memory_exec(&schema); + let input = sort_preserving_merge_exec_with_fetch( + vec![sort_expr("non_nullable_col", &schema)], + source, + 100, + ); + + let expected_input = [ + "SortPreservingMergeExec: [non_nullable_col@1 ASC], fetch=100", + " DataSourceExec: partitions=1, partition_sizes=[0]", + ]; + let expected_optimized = [ + "LocalLimitExec: fetch=100", + " SortExec: expr=[non_nullable_col@1 ASC], preserve_partitioning=[false]", + " DataSourceExec: partitions=1, partition_sizes=[0]", + ]; + assert_optimized!(expected_input, expected_optimized, input, true); + + Ok(()) +} + #[tokio::test] async fn test_change_wrong_sorting() -> Result<()> { let schema = create_test_schema()?; diff --git a/datafusion/core/tests/physical_optimizer/join_selection.rs b/datafusion/core/tests/physical_optimizer/join_selection.rs index ae7adacadb19..375af94acaf4 100644 --- a/datafusion/core/tests/physical_optimizer/join_selection.rs +++ b/datafusion/core/tests/physical_optimizer/join_selection.rs @@ -22,9 +22,8 @@ use std::{ task::{Context, Poll}, }; -use arrow::datatypes::{DataType, Field, Schema}; +use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use arrow::record_batch::RecordBatch; -use arrow_schema::SchemaRef; use datafusion_common::config::ConfigOptions; use datafusion_common::JoinSide; use datafusion_common::{stats::Precision, ColumnStatistics, JoinType, ScalarValue}; diff --git a/datafusion/core/tests/physical_optimizer/limited_distinct_aggregation.rs b/datafusion/core/tests/physical_optimizer/limited_distinct_aggregation.rs index 6e5c677541c5..f9810eab8f59 100644 --- a/datafusion/core/tests/physical_optimizer/limited_distinct_aggregation.rs +++ b/datafusion/core/tests/physical_optimizer/limited_distinct_aggregation.rs @@ -24,8 +24,8 @@ use crate::physical_optimizer::test_utils::{ schema, TestAggregate, }; +use arrow::datatypes::DataType; use arrow::{compute::SortOptions, util::pretty::pretty_format_batches}; -use arrow_schema::DataType; use datafusion::prelude::SessionContext; use datafusion_common::Result; use datafusion_execution::config::SessionConfig; diff --git a/datafusion/core/tests/physical_optimizer/projection_pushdown.rs 
b/datafusion/core/tests/physical_optimizer/projection_pushdown.rs index cc005c0aa889..92a329290ae1 100644 --- a/datafusion/core/tests/physical_optimizer/projection_pushdown.rs +++ b/datafusion/core/tests/physical_optimizer/projection_pushdown.rs @@ -18,7 +18,8 @@ use std::any::Any; use std::sync::Arc; -use arrow_schema::{DataType, Field, Schema, SchemaRef, SortOptions}; +use arrow::compute::SortOptions; +use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use datafusion::datasource::listing::PartitionedFile; use datafusion::datasource::physical_plan::{CsvSource, FileScanConfig}; use datafusion_common::config::ConfigOptions; @@ -26,9 +27,7 @@ use datafusion_common::Result; use datafusion_common::{JoinSide, JoinType, ScalarValue}; use datafusion_execution::object_store::ObjectStoreUrl; use datafusion_execution::{SendableRecordBatchStream, TaskContext}; -use datafusion_expr::{ - ColumnarValue, Operator, ScalarUDF, ScalarUDFImpl, Signature, Volatility, -}; +use datafusion_expr::{Operator, ScalarUDF, ScalarUDFImpl, Signature, Volatility}; use datafusion_physical_expr::expressions::{ binary, col, BinaryExpr, CaseExpr, CastExpr, Column, Literal, NegativeExpr, }; @@ -91,14 +90,6 @@ impl ScalarUDFImpl for DummyUDF { fn return_type(&self, _arg_types: &[DataType]) -> Result { Ok(DataType::Int32) } - - fn invoke_batch( - &self, - _args: &[ColumnarValue], - _number_rows: usize, - ) -> Result { - unimplemented!("DummyUDF::invoke") - } } #[test] @@ -382,7 +373,7 @@ fn create_simple_csv_exec() -> Arc { ) .with_file(PartitionedFile::new("x".to_string(), 100)) .with_projection(Some(vec![0, 1, 2, 3, 4])) - .new_exec() + .build() } fn create_projecting_csv_exec() -> Arc { @@ -399,7 +390,7 @@ fn create_projecting_csv_exec() -> Arc { ) .with_file(PartitionedFile::new("x".to_string(), 100)) .with_projection(Some(vec![3, 2, 1])) - .new_exec() + .build() } fn create_projecting_memory_exec() -> Arc { diff --git a/datafusion/core/tests/physical_optimizer/sanity_checker.rs b/datafusion/core/tests/physical_optimizer/sanity_checker.rs index ccfec1fcb10e..a73d084a081f 100644 --- a/datafusion/core/tests/physical_optimizer/sanity_checker.rs +++ b/datafusion/core/tests/physical_optimizer/sanity_checker.rs @@ -22,7 +22,8 @@ use crate::physical_optimizer::test_utils::{ repartition_exec, sort_exec, sort_expr_options, sort_merge_join_exec, }; -use arrow_schema::{DataType, Field, Schema, SchemaRef, SortOptions}; +use arrow::compute::SortOptions; +use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use datafusion::datasource::stream::{FileStreamProvider, StreamConfig, StreamTable}; use datafusion::prelude::{CsvReadOptions, SessionContext}; use datafusion_common::config::ConfigOptions; diff --git a/datafusion/core/tests/physical_optimizer/test_utils.rs b/datafusion/core/tests/physical_optimizer/test_utils.rs index fe25f8e19f61..1c840c85cc58 100644 --- a/datafusion/core/tests/physical_optimizer/test_utils.rs +++ b/datafusion/core/tests/physical_optimizer/test_utils.rs @@ -22,8 +22,9 @@ use std::fmt::Formatter; use std::sync::Arc; use arrow::array::Int32Array; +use arrow::compute::SortOptions; +use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use arrow::record_batch::RecordBatch; -use arrow_schema::{DataType, Field, Schema, SchemaRef, SortOptions}; use datafusion::datasource::listing::PartitionedFile; use datafusion::datasource::physical_plan::{FileScanConfig, ParquetSource}; use datafusion_common::config::ConfigOptions; @@ -74,7 +75,7 @@ pub fn parquet_exec(schema: &SchemaRef) -> Arc { 
Arc::new(ParquetSource::default()), ) .with_file(PartitionedFile::new("x".to_string(), 100)) - .new_exec() + .build() } /// Create a single parquet file that is sorted @@ -88,7 +89,7 @@ pub(crate) fn parquet_exec_with_sort( ) .with_file(PartitionedFile::new("x".to_string(), 100)) .with_output_ordering(output_ordering) - .new_exec() + .build() } pub fn schema() -> SchemaRef { @@ -288,6 +289,15 @@ pub fn sort_preserving_merge_exec( Arc::new(SortPreservingMergeExec::new(sort_exprs, input)) } +pub fn sort_preserving_merge_exec_with_fetch( + sort_exprs: impl IntoIterator, + input: Arc, + fetch: usize, +) -> Arc { + let sort_exprs = sort_exprs.into_iter().collect(); + Arc::new(SortPreservingMergeExec::new(sort_exprs, input).with_fetch(Some(fetch))) +} + pub fn union_exec(input: Vec>) -> Arc { Arc::new(UnionExec::new(input)) } diff --git a/datafusion/core/tests/user_defined/insert_operation.rs b/datafusion/core/tests/user_defined/insert_operation.rs index aa531632c60b..12f700ce572b 100644 --- a/datafusion/core/tests/user_defined/insert_operation.rs +++ b/datafusion/core/tests/user_defined/insert_operation.rs @@ -17,7 +17,7 @@ use std::{any::Any, sync::Arc}; -use arrow_schema::{DataType, Field, Schema, SchemaRef}; +use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use async_trait::async_trait; use datafusion::{ error::Result, diff --git a/datafusion/core/tests/user_defined/user_defined_aggregates.rs b/datafusion/core/tests/user_defined/user_defined_aggregates.rs index aa0f6c8fed8d..7cda6d410f4e 100644 --- a/datafusion/core/tests/user_defined/user_defined_aggregates.rs +++ b/datafusion/core/tests/user_defined/user_defined_aggregates.rs @@ -28,8 +28,7 @@ use std::sync::{ use arrow::array::{ types::UInt64Type, AsArray, Int32Array, PrimitiveArray, StringArray, StructArray, }; -use arrow::datatypes::Fields; -use arrow_schema::Schema; +use arrow::datatypes::{Fields, Schema}; use datafusion::dataframe::DataFrame; use datafusion::datasource::MemTable; diff --git a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs index ea83bd16b468..43e7ec9e45e4 100644 --- a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs +++ b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs @@ -25,21 +25,22 @@ use arrow::array::{ Int32Array, RecordBatch, StringArray, }; use arrow::compute::kernels::numeric::add; -use arrow_schema::{DataType, Field, Schema}; +use arrow::datatypes::{DataType, Field, Schema}; use datafusion::execution::context::{FunctionFactory, RegisterFunction, SessionState}; use datafusion::prelude::*; use datafusion::{execution::registry::FunctionRegistry, test_util}; use datafusion_common::cast::{as_float64_array, as_int32_array}; use datafusion_common::tree_node::{Transformed, TreeNode}; +use datafusion_common::utils::take_function_args; use datafusion_common::{ - assert_batches_eq, assert_batches_sorted_eq, assert_contains, exec_err, internal_err, - not_impl_err, plan_err, DFSchema, DataFusionError, HashMap, Result, ScalarValue, + assert_batches_eq, assert_batches_sorted_eq, assert_contains, exec_err, not_impl_err, + plan_err, DFSchema, DataFusionError, HashMap, Result, ScalarValue, }; use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo}; use datafusion_expr::{ Accumulator, ColumnarValue, CreateFunction, CreateFunctionBody, LogicalPlanBuilder, - OperateFunctionArg, ReturnInfo, ReturnTypeArgs, ScalarUDF, ScalarUDFImpl, Signature, - Volatility, + OperateFunctionArg, 
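
Editor's note: the import hunk in progress here pulls in `ScalarFunctionArgs` and `take_function_args`, and the UDF rewrites just below replace `invoke_batch(&[ColumnarValue], usize)` with `invoke_with_args(ScalarFunctionArgs)`. A minimal sketch of the new entry point, modeled on the rewritten test UDFs (the UDF name and behavior are illustrative, not part of this PR):

```rust
use std::any::Any;

use arrow::datatypes::DataType;
use datafusion_common::utils::take_function_args;
use datafusion_common::{Result, ScalarValue};
use datafusion_expr::{
    ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
};

/// Returns the row count of the batch it is evaluated against.
#[derive(Debug)]
struct RowCountUdf {
    signature: Signature,
}

impl RowCountUdf {
    fn new() -> Self {
        Self {
            signature: Signature::any(1, Volatility::Immutable),
        }
    }
}

impl ScalarUDFImpl for RowCountUdf {
    fn as_any(&self) -> &dyn Any {
        self
    }
    fn name(&self) -> &str {
        "row_count"
    }
    fn signature(&self) -> &Signature {
        &self.signature
    }
    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
        Ok(DataType::Int64)
    }
    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
        // `take_function_args` destructures and arity-checks in one step.
        let [_input] = take_function_args(self.name(), &args.args)?;
        // The row count that used to be a separate parameter now rides along.
        Ok(ColumnarValue::Scalar(ScalarValue::Int64(Some(
            args.number_rows as i64,
        ))))
    }
}
```
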
ReturnInfo, ReturnTypeArgs, ScalarFunctionArgs, ScalarUDF, + ScalarUDFImpl, Signature, Volatility, }; use datafusion_functions_nested::range::range_udf; use parking_lot::Mutex; @@ -207,11 +208,7 @@ impl ScalarUDFImpl for Simple0ArgsScalarUDF { Ok(self.return_type.clone()) } - fn invoke_batch( - &self, - _args: &[ColumnarValue], - _number_rows: usize, - ) -> Result { + fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result { Ok(ColumnarValue::Scalar(ScalarValue::Int32(Some(100)))) } } @@ -518,16 +515,13 @@ impl ScalarUDFImpl for AddIndexToStringVolatileScalarUDF { Ok(self.return_type.clone()) } - fn invoke_batch( - &self, - args: &[ColumnarValue], - number_rows: usize, - ) -> Result { - let answer = match &args[0] { + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + let [arg] = take_function_args(self.name(), &args.args)?; + let answer = match arg { // When called with static arguments, the result is returned as an array. ColumnarValue::Scalar(ScalarValue::Utf8(Some(value))) => { let mut answer = vec![]; - for index in 1..=number_rows { + for index in 1..=args.number_rows { // When calling a function with immutable arguments, the result is returned with ")". // Example: SELECT add_index_to_string('const_value') FROM table; answer.push(index.to_string() + ") " + value); @@ -713,14 +707,6 @@ impl ScalarUDFImpl for CastToI64UDF { // return the newly written argument to DataFusion Ok(ExprSimplifyResult::Simplified(new_expr)) } - - fn invoke_batch( - &self, - _args: &[ColumnarValue], - _number_rows: usize, - ) -> Result { - unimplemented!("Function should have been simplified prior to evaluation") - } } #[tokio::test] @@ -850,17 +836,14 @@ impl ScalarUDFImpl for TakeUDF { } // The actual implementation - fn invoke_batch( - &self, - args: &[ColumnarValue], - _number_rows: usize, - ) -> Result { - let take_idx = match &args[2] { + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + let [_arg0, _arg1, arg2] = take_function_args(self.name(), &args.args)?; + let take_idx = match arg2 { ColumnarValue::Scalar(ScalarValue::Utf8(Some(v))) if v == "0" => 0, ColumnarValue::Scalar(ScalarValue::Utf8(Some(v))) if v == "1" => 1, _ => unreachable!(), }; - match &args[take_idx] { + match &args.args[take_idx] { ColumnarValue::Array(array) => Ok(ColumnarValue::Array(array.clone())), ColumnarValue::Scalar(_) => unimplemented!(), } @@ -963,14 +946,6 @@ impl ScalarUDFImpl for ScalarFunctionWrapper { Ok(self.return_type.clone()) } - fn invoke_batch( - &self, - _args: &[ColumnarValue], - _number_rows: usize, - ) -> Result { - internal_err!("This function should not get invoked!") - } - fn simplify( &self, args: Vec, diff --git a/datafusion/core/tests/user_defined/user_defined_window_functions.rs b/datafusion/core/tests/user_defined/user_defined_window_functions.rs index 204d786994f8..9acd17493da4 100644 --- a/datafusion/core/tests/user_defined/user_defined_window_functions.rs +++ b/datafusion/core/tests/user_defined/user_defined_window_functions.rs @@ -19,7 +19,7 @@ //! 
user defined window functions use arrow::array::{ArrayRef, AsArray, Int64Array, RecordBatch, StringArray}; -use arrow_schema::{DataType, Field, Schema}; +use arrow::datatypes::{DataType, Field, Schema}; use datafusion::{assert_batches_eq, prelude::SessionContext}; use datafusion_common::{Result, ScalarValue}; use datafusion_expr::{ diff --git a/datafusion/datasource/Cargo.toml b/datafusion/datasource/Cargo.toml new file mode 100644 index 000000000000..caf1c60a785d --- /dev/null +++ b/datafusion/datasource/Cargo.toml @@ -0,0 +1,74 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +[package] +name = "datafusion-datasource" +description = "datafusion-datasource" +authors.workspace = true +edition.workspace = true +homepage.workspace = true +license.workspace = true +readme.workspace = true +repository.workspace = true +rust-version.workspace = true +version.workspace = true + +[features] +compression = ["async-compression", "xz2", "bzip2", "flate2", "zstd", "tokio-util"] +default = ["compression"] + +[dependencies] +arrow = { workspace = true } +async-compression = { version = "0.4.0", features = [ + "bzip2", + "gzip", + "xz", + "zstd", + "tokio", +], optional = true } +async-trait = { workspace = true } +bytes = { workspace = true } +bzip2 = { version = "0.5.1", optional = true } +chrono = { workspace = true } +datafusion-catalog = { workspace = true } +datafusion-common = { workspace = true, features = ["object_store"] } +datafusion-common-runtime = { workspace = true } +datafusion-execution = { workspace = true } +datafusion-expr = { workspace = true } +datafusion-physical-plan = { workspace = true } +flate2 = { version = "1.0.24", optional = true } +futures = { workspace = true } +glob = "0.3.0" +itertools = { workspace = true } +log = { workspace = true } +object_store = { workspace = true } +rand = { workspace = true } +tokio = { workspace = true } +tokio-util = { version = "0.7.4", features = ["io"], optional = true } +url = { workspace = true } +xz2 = { version = "0.1", optional = true, features = ["static"] } +zstd = { version = "0.13", optional = true, default-features = false } + +[dev-dependencies] +tempfile = { workspace = true } + +[lints] +workspace = true + +[lib] +name = "datafusion_datasource" +path = "src/mod.rs" diff --git a/datafusion/datasource/LICENSE.txt b/datafusion/datasource/LICENSE.txt new file mode 120000 index 000000000000..1ef648f64b34 --- /dev/null +++ b/datafusion/datasource/LICENSE.txt @@ -0,0 +1 @@ +../../LICENSE.txt \ No newline at end of file diff --git a/datafusion/datasource/NOTICE.txt b/datafusion/datasource/NOTICE.txt new file mode 120000 index 000000000000..fb051c92b10b --- /dev/null +++ b/datafusion/datasource/NOTICE.txt @@ -0,0 +1 @@ +../../NOTICE.txt \ No newline at end of file diff --git 
a/datafusion/datasource/README.md b/datafusion/datasource/README.md new file mode 100644 index 000000000000..2479a28ae68d --- /dev/null +++ b/datafusion/datasource/README.md @@ -0,0 +1,24 @@ + + +# DataFusion datasource + +[DataFusion][df] is an extensible query execution framework, written in Rust, that uses Apache Arrow as its in-memory format. + +This crate is a submodule of DataFusion that defines common DataSource related components like FileScanConfig, FileCompression etc. diff --git a/datafusion/core/src/datasource/file_format/file_compression_type.rs b/datafusion/datasource/src/file_compression_type.rs similarity index 98% rename from datafusion/core/src/datasource/file_format/file_compression_type.rs rename to datafusion/datasource/src/file_compression_type.rs index 6612de077988..7cc3142564e9 100644 --- a/datafusion/core/src/datasource/file_format/file_compression_type.rs +++ b/datafusion/datasource/src/file_compression_type.rs @@ -19,7 +19,7 @@ use std::str::FromStr; -use crate::error::{DataFusionError, Result}; +use datafusion_common::error::{DataFusionError, Result}; use datafusion_common::parsers::CompressionTypeVariant::{self, *}; use datafusion_common::GetExt; @@ -254,8 +254,8 @@ pub trait FileTypeExt { mod tests { use std::str::FromStr; - use crate::datasource::file_format::file_compression_type::FileCompressionType; - use crate::error::DataFusionError; + use super::FileCompressionType; + use datafusion_common::error::DataFusionError; use bytes::Bytes; use futures::StreamExt; diff --git a/datafusion/catalog-listing/src/file_groups.rs b/datafusion/datasource/src/file_groups.rs similarity index 100% rename from datafusion/catalog-listing/src/file_groups.rs rename to datafusion/datasource/src/file_groups.rs diff --git a/datafusion/datasource/src/file_meta.rs b/datafusion/datasource/src/file_meta.rs new file mode 100644 index 000000000000..098a15eeb38a --- /dev/null +++ b/datafusion/datasource/src/file_meta.rs @@ -0,0 +1,52 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::sync::Arc; + +use object_store::{path::Path, ObjectMeta}; + +use crate::FileRange; + +/// A single file or part of a file that should be read, along with its schema, statistics +pub struct FileMeta { + /// Path for the file (e.g. 
URL, filesystem path, etc) + pub object_meta: ObjectMeta, + /// An optional file range for a more fine-grained parallel execution + pub range: Option, + /// An optional field for user defined per object metadata + pub extensions: Option>, + /// Size hint for the metadata of this file + pub metadata_size_hint: Option, +} + +impl FileMeta { + /// The full path to the object + pub fn location(&self) -> &Path { + &self.object_meta.location + } +} + +impl From for FileMeta { + fn from(object_meta: ObjectMeta) -> Self { + Self { + object_meta, + range: None, + extensions: None, + metadata_size_hint: None, + } + } +} diff --git a/datafusion/datasource/src/file_scan_config.rs b/datafusion/datasource/src/file_scan_config.rs new file mode 100644 index 000000000000..bfddbc3a1fc4 --- /dev/null +++ b/datafusion/datasource/src/file_scan_config.rs @@ -0,0 +1,278 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::{borrow::Cow, collections::HashMap, marker::PhantomData, sync::Arc}; + +use arrow::{ + array::{ + ArrayData, ArrayRef, BufferBuilder, DictionaryArray, RecordBatch, + RecordBatchOptions, + }, + buffer::Buffer, + datatypes::{ArrowNativeType, DataType, SchemaRef, UInt16Type}, +}; +use datafusion_common::{exec_err, Result}; +use datafusion_common::{DataFusionError, ScalarValue}; +use log::warn; + +/// A helper that projects partition columns into the file record batches. +/// +/// One interesting trick is the usage of a cache for the key buffers of the partition column +/// dictionaries. Indeed, the partition columns are constant, so the dictionaries that represent them +/// have all their keys equal to 0. This enables us to re-use the same "all-zero" buffer across batches, +/// which makes the space consumption of the partition columns O(batch_size) instead of O(record_count). +pub struct PartitionColumnProjector { + /// An Arrow buffer initialized to zeros that represents the key array of all partition + /// columns (partition columns are materialized by dictionary arrays with only one + /// value in the dictionary, thus all the keys are equal to zero). + key_buffer_cache: ZeroBufferGenerators, + /// Mapping between the indexes in the list of partition columns and the target + /// schema. Sorted by index in the target schema so that we can iterate on it to + /// insert the partition columns in the target record batch. + projected_partition_indexes: Vec<(usize, usize)>, + /// The schema of the table once the projection was applied. 
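
Editor's note on the `FileMeta` type introduced above: a short sketch of building one from an `ObjectMeta`, then narrowing it to a byte range. The `datafusion_datasource::*` import paths are an assumption based on the new crate layout in this PR; the fields and the `From<ObjectMeta>` impl are exactly as shown in the hunk.

```rust
use object_store::{path::Path, ObjectMeta};
// Assumed import paths for the new crate added in this PR.
use datafusion_datasource::file_meta::FileMeta;
use datafusion_datasource::FileRange;

fn example_file_meta() -> FileMeta {
    let object_meta = ObjectMeta {
        location: Path::from("data/part-0.parquet"),
        last_modified: chrono::Utc::now(),
        size: 4096,
        e_tag: None,
        version: None,
    };

    // `From<ObjectMeta>` leaves range, extensions and metadata_size_hint unset.
    let mut meta = FileMeta::from(object_meta);

    // Scan only the first half of the file and hint at the footer size.
    meta.range = Some(FileRange { start: 0, end: 2048 });
    meta.metadata_size_hint = Some(512);
    assert_eq!(meta.location().to_string(), "data/part-0.parquet");
    meta
}
```
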
+ projected_schema: SchemaRef, +} + +impl PartitionColumnProjector { + // Create a projector to insert the partitioning columns into batches read from files + // - `projected_schema`: the target schema with both file and partitioning columns + // - `table_partition_cols`: all the partitioning column names + pub fn new(projected_schema: SchemaRef, table_partition_cols: &[String]) -> Self { + let mut idx_map = HashMap::new(); + for (partition_idx, partition_name) in table_partition_cols.iter().enumerate() { + if let Ok(schema_idx) = projected_schema.index_of(partition_name) { + idx_map.insert(partition_idx, schema_idx); + } + } + + let mut projected_partition_indexes: Vec<_> = idx_map.into_iter().collect(); + projected_partition_indexes.sort_by(|(_, a), (_, b)| a.cmp(b)); + + Self { + projected_partition_indexes, + key_buffer_cache: Default::default(), + projected_schema, + } + } + + // Transform the batch read from the file by inserting the partitioning columns + // to the right positions as deduced from `projected_schema` + // - `file_batch`: batch read from the file, with internal projection applied + // - `partition_values`: the list of partition values, one for each partition column + pub fn project( + &mut self, + file_batch: RecordBatch, + partition_values: &[ScalarValue], + ) -> Result { + let expected_cols = + self.projected_schema.fields().len() - self.projected_partition_indexes.len(); + + if file_batch.columns().len() != expected_cols { + return exec_err!( + "Unexpected batch schema from file, expected {} cols but got {}", + expected_cols, + file_batch.columns().len() + ); + } + + let mut cols = file_batch.columns().to_vec(); + for &(pidx, sidx) in &self.projected_partition_indexes { + let p_value = + partition_values + .get(pidx) + .ok_or(DataFusionError::Execution( + "Invalid partitioning found on disk".to_string(), + ))?; + + let mut partition_value = Cow::Borrowed(p_value); + + // check if user forgot to dict-encode the partition value + let field = self.projected_schema.field(sidx); + let expected_data_type = field.data_type(); + let actual_data_type = partition_value.data_type(); + if let DataType::Dictionary(key_type, _) = expected_data_type { + if !matches!(actual_data_type, DataType::Dictionary(_, _)) { + warn!("Partition value for column {} was not dictionary-encoded, applied auto-fix.", field.name()); + partition_value = Cow::Owned(ScalarValue::Dictionary( + key_type.clone(), + Box::new(partition_value.as_ref().clone()), + )); + } + } + + cols.insert( + sidx, + create_output_array( + &mut self.key_buffer_cache, + partition_value.as_ref(), + file_batch.num_rows(), + )?, + ) + } + + RecordBatch::try_new_with_options( + Arc::clone(&self.projected_schema), + cols, + &RecordBatchOptions::new().with_row_count(Some(file_batch.num_rows())), + ) + .map_err(Into::into) + } +} + +#[derive(Debug, Default)] +struct ZeroBufferGenerators { + gen_i8: ZeroBufferGenerator, + gen_i16: ZeroBufferGenerator, + gen_i32: ZeroBufferGenerator, + gen_i64: ZeroBufferGenerator, + gen_u8: ZeroBufferGenerator, + gen_u16: ZeroBufferGenerator, + gen_u32: ZeroBufferGenerator, + gen_u64: ZeroBufferGenerator, +} + +/// Generate a arrow [`Buffer`] that contains zero values. 
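
Editor's note: the `PartitionColumnProjector` above splices constant partition columns into batches read from files. A minimal usage sketch under the same crate-path assumption as the previous example; the partition column here is a plain `Utf8` field, so the dictionary auto-fix path is not triggered.

```rust
use std::sync::Arc;

use arrow::array::{Int32Array, RecordBatch};
use arrow::datatypes::{DataType, Field, Schema};
use datafusion_common::{Result, ScalarValue};
// Assumed path; the struct is defined in file_scan_config.rs above.
use datafusion_datasource::file_scan_config::PartitionColumnProjector;

fn project_partition_column() -> Result<RecordBatch> {
    // Target schema: one file column (`a`) plus one partition column (`year`).
    let projected_schema = Arc::new(Schema::new(vec![
        Field::new("a", DataType::Int32, false),
        Field::new("year", DataType::Utf8, false),
    ]));
    let mut projector = PartitionColumnProjector::new(
        Arc::clone(&projected_schema),
        &["year".to_string()],
    );

    // The batch as it comes out of the file reader: file columns only.
    let file_schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
    let file_batch = RecordBatch::try_new(
        file_schema,
        vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
    )?;

    // One ScalarValue per partition column, repeated for every row of the batch.
    projector.project(file_batch, &[ScalarValue::from("2024")])
}
```
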
+#[derive(Debug, Default)] +struct ZeroBufferGenerator +where + T: ArrowNativeType, +{ + cache: Option, + _t: PhantomData, +} + +impl ZeroBufferGenerator +where + T: ArrowNativeType, +{ + const SIZE: usize = size_of::(); + + fn get_buffer(&mut self, n_vals: usize) -> Buffer { + match &mut self.cache { + Some(buf) if buf.len() >= n_vals * Self::SIZE => { + buf.slice_with_length(0, n_vals * Self::SIZE) + } + _ => { + let mut key_buffer_builder = BufferBuilder::::new(n_vals); + key_buffer_builder.advance(n_vals); // keys are all 0 + self.cache.insert(key_buffer_builder.finish()).clone() + } + } + } +} + +fn create_dict_array( + buffer_gen: &mut ZeroBufferGenerator, + dict_val: &ScalarValue, + len: usize, + data_type: DataType, +) -> Result +where + T: ArrowNativeType, +{ + let dict_vals = dict_val.to_array()?; + + let sliced_key_buffer = buffer_gen.get_buffer(len); + + // assemble pieces together + let mut builder = ArrayData::builder(data_type) + .len(len) + .add_buffer(sliced_key_buffer); + builder = builder.add_child_data(dict_vals.to_data()); + Ok(Arc::new(DictionaryArray::::from( + builder.build().unwrap(), + ))) +} + +fn create_output_array( + key_buffer_cache: &mut ZeroBufferGenerators, + val: &ScalarValue, + len: usize, +) -> Result { + if let ScalarValue::Dictionary(key_type, dict_val) = &val { + match key_type.as_ref() { + DataType::Int8 => { + return create_dict_array( + &mut key_buffer_cache.gen_i8, + dict_val, + len, + val.data_type(), + ); + } + DataType::Int16 => { + return create_dict_array( + &mut key_buffer_cache.gen_i16, + dict_val, + len, + val.data_type(), + ); + } + DataType::Int32 => { + return create_dict_array( + &mut key_buffer_cache.gen_i32, + dict_val, + len, + val.data_type(), + ); + } + DataType::Int64 => { + return create_dict_array( + &mut key_buffer_cache.gen_i64, + dict_val, + len, + val.data_type(), + ); + } + DataType::UInt8 => { + return create_dict_array( + &mut key_buffer_cache.gen_u8, + dict_val, + len, + val.data_type(), + ); + } + DataType::UInt16 => { + return create_dict_array( + &mut key_buffer_cache.gen_u16, + dict_val, + len, + val.data_type(), + ); + } + DataType::UInt32 => { + return create_dict_array( + &mut key_buffer_cache.gen_u32, + dict_val, + len, + val.data_type(), + ); + } + DataType::UInt64 => { + return create_dict_array( + &mut key_buffer_cache.gen_u64, + dict_val, + len, + val.data_type(), + ); + } + _ => {} + } + } + + val.to_array_of_size(len) +} diff --git a/datafusion/datasource/src/file_sink_config.rs b/datafusion/datasource/src/file_sink_config.rs new file mode 100644 index 000000000000..6087f930d3fe --- /dev/null +++ b/datafusion/datasource/src/file_sink_config.rs @@ -0,0 +1,114 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
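
Editor's note: the zero-buffer generators above exist because a constant partition column can be encoded as a dictionary with a single value and an all-zero key buffer, keeping the per-batch cost O(batch_size) regardless of how many rows are scanned. A standalone, arrow-only sketch of the same construction used by `create_dict_array` (the column name and value are illustrative):

```rust
use arrow::array::{Array, ArrayData, DictionaryArray, StringArray};
use arrow::buffer::Buffer;
use arrow::datatypes::{DataType, Int32Type};

/// Build a 4-row constant column "2024" as a dictionary array whose keys are
/// all zero, pointing at the single dictionary value.
fn constant_dictionary_column() -> DictionaryArray<Int32Type> {
    let dict_values = StringArray::from(vec!["2024"]);

    // Key buffer of four i32 zeros; in the projector this buffer is cached and
    // sliced instead of being rebuilt for every batch.
    let keys = Buffer::from_slice_ref([0i32, 0, 0, 0]);

    let data = ArrayData::builder(DataType::Dictionary(
        Box::new(DataType::Int32),
        Box::new(DataType::Utf8),
    ))
    .len(4)
    .add_buffer(keys)
    .add_child_data(dict_values.to_data())
    .build()
    .expect("valid dictionary layout");

    DictionaryArray::<Int32Type>::from(data)
}
```
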
+ +use crate::write::demux::{start_demuxer_task, DemuxedStreamReceiver}; +use crate::{ListingTableUrl, PartitionedFile}; +use arrow::datatypes::{DataType, SchemaRef}; +use async_trait::async_trait; +use datafusion_common::Result; +use datafusion_common_runtime::SpawnedTask; +use datafusion_execution::object_store::ObjectStoreUrl; +use datafusion_execution::{SendableRecordBatchStream, TaskContext}; +use datafusion_expr::dml::InsertOp; +use datafusion_physical_plan::insert::DataSink; +use object_store::ObjectStore; +use std::sync::Arc; + +/// General behaviors for files that do `DataSink` operations +#[async_trait] +pub trait FileSink: DataSink { + /// Retrieves the file sink configuration. + fn config(&self) -> &FileSinkConfig; + + /// Spawns writer tasks and joins them to perform file writing operations. + /// Is a critical part of `FileSink` trait, since it's the very last step for `write_all`. + /// + /// This function handles the process of writing data to files by: + /// 1. Spawning tasks for writing data to individual files. + /// 2. Coordinating the tasks using a demuxer to distribute data among files. + /// 3. Collecting results using `tokio::join`, ensuring that all tasks complete successfully. + /// + /// # Parameters + /// - `context`: The execution context (`TaskContext`) that provides resources + /// like memory management and runtime environment. + /// - `demux_task`: A spawned task that handles demuxing, responsible for splitting + /// an input [`SendableRecordBatchStream`] into dynamically determined partitions. + /// See `start_demuxer_task()` + /// - `file_stream_rx`: A receiver that yields streams of record batches and their + /// corresponding file paths for writing. See `start_demuxer_task()` + /// - `object_store`: A handle to the object store where the files are written. + /// + /// # Returns + /// - `Result`: Returns the total number of rows written across all files. + async fn spawn_writer_tasks_and_join( + &self, + context: &Arc, + demux_task: SpawnedTask>, + file_stream_rx: DemuxedStreamReceiver, + object_store: Arc, + ) -> Result; + + /// File sink implementation of the [`DataSink::write_all`] method. + async fn write_all( + &self, + data: SendableRecordBatchStream, + context: &Arc, + ) -> Result { + let config = self.config(); + let object_store = context + .runtime_env() + .object_store(&config.object_store_url)?; + let (demux_task, file_stream_rx) = start_demuxer_task(config, data, context); + self.spawn_writer_tasks_and_join( + context, + demux_task, + file_stream_rx, + object_store, + ) + .await + } +} + +/// The base configurations to provide when creating a physical plan for +/// writing to any given file format. +pub struct FileSinkConfig { + /// Object store URL, used to get an ObjectStore instance + pub object_store_url: ObjectStoreUrl, + /// A vector of [`PartitionedFile`] structs, each representing a file partition + pub file_groups: Vec, + /// Vector of partition paths + pub table_paths: Vec, + /// The schema of the output file + pub output_schema: SchemaRef, + /// A vector of column names and their corresponding data types, + /// representing the partitioning columns for the file + pub table_partition_cols: Vec<(String, DataType)>, + /// Controls how new data should be written to the file, determining whether + /// to append to, overwrite, or replace records in existing files. + pub insert_op: InsertOp, + /// Controls whether partition columns are kept for the file + pub keep_partition_by_columns: bool, + /// File extension without a dot(.) 
+ pub file_extension: String, +} + +impl FileSinkConfig { + /// Get output schema + pub fn output_schema(&self) -> &SchemaRef { + &self.output_schema + } +} diff --git a/datafusion/datasource/src/file_stream.rs b/datafusion/datasource/src/file_stream.rs new file mode 100644 index 000000000000..570ca6678538 --- /dev/null +++ b/datafusion/datasource/src/file_stream.rs @@ -0,0 +1,214 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! A generic stream over file format readers that can be used by +//! any file format that read its files from start to end. +//! +//! Note: Most traits here need to be marked `Sync + Send` to be +//! compliant with the `SendableRecordBatchStream` trait. + +use crate::file_meta::FileMeta; +use datafusion_common::error::Result; +use datafusion_physical_plan::metrics::{ + Count, ExecutionPlanMetricsSet, MetricBuilder, Time, +}; + +use arrow::error::ArrowError; +use arrow::record_batch::RecordBatch; +use datafusion_common::instant::Instant; +use datafusion_common::ScalarValue; + +use futures::future::BoxFuture; +use futures::stream::BoxStream; + +/// A fallible future that resolves to a stream of [`RecordBatch`] +pub type FileOpenFuture = + BoxFuture<'static, Result>>>; + +/// Describes the behavior of the `FileStream` if file opening or scanning fails +pub enum OnError { + /// Fail the entire stream and return the underlying error + Fail, + /// Continue scanning, ignoring the failed file + Skip, +} + +impl Default for OnError { + fn default() -> Self { + Self::Fail + } +} + +/// Generic API for opening a file using an [`ObjectStore`] and resolving to a +/// stream of [`RecordBatch`] +/// +/// [`ObjectStore`]: object_store::ObjectStore +pub trait FileOpener: Unpin + Send + Sync { + /// Asynchronously open the specified file and return a stream + /// of [`RecordBatch`] + fn open(&self, file_meta: FileMeta) -> Result; +} + +/// Represents the state of the next `FileOpenFuture`. 
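
Editor's note: `FileOpener` is the extension point a format implements to turn a `FileMeta` into a stream of batches. A toy implementation is sketched below, assuming the `datafusion_datasource::file_stream` paths from this PR and that the inner stream items are `Result<RecordBatch, ArrowError>`, as the `ArrowError` import above suggests.

```rust
use std::sync::Arc;

use arrow::array::RecordBatch;
use arrow::datatypes::SchemaRef;
use arrow::error::ArrowError;
use datafusion_common::Result;
use futures::{stream, FutureExt, StreamExt};
// Assumed paths into the new crate introduced by this PR.
use datafusion_datasource::file_meta::FileMeta;
use datafusion_datasource::file_stream::{FileOpenFuture, FileOpener};

/// A toy opener that "reads" any file by emitting one empty batch.
struct EmptyBatchOpener {
    schema: SchemaRef,
}

impl FileOpener for EmptyBatchOpener {
    fn open(&self, _file_meta: FileMeta) -> Result<FileOpenFuture> {
        let schema = Arc::clone(&self.schema);
        // The outer future resolves once the "file" is open; the inner stream
        // then yields record batches (or errors) until the file is exhausted.
        Ok(async move {
            let batch = RecordBatch::new_empty(schema);
            Ok(stream::iter([Ok::<_, ArrowError>(batch)]).boxed())
        }
        .boxed())
    }
}
```
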
Since we need to poll +/// this future while scanning the current file, we need to store the result if it +/// is ready +pub enum NextOpen { + Pending(FileOpenFuture), + Ready(Result>>), +} + +pub enum FileStreamState { + /// The idle state, no file is currently being read + Idle, + /// Currently performing asynchronous IO to obtain a stream of RecordBatch + /// for a given file + Open { + /// A [`FileOpenFuture`] returned by [`FileOpener::open`] + future: FileOpenFuture, + /// The partition values for this file + partition_values: Vec, + }, + /// Scanning the [`BoxStream`] returned by the completion of a [`FileOpenFuture`] + /// returned by [`FileOpener::open`] + Scan { + /// Partitioning column values for the current batch_iter + partition_values: Vec, + /// The reader instance + reader: BoxStream<'static, Result>, + /// A [`FileOpenFuture`] for the next file to be processed, + /// and its corresponding partition column values, if any. + /// This allows the next file to be opened in parallel while the + /// current file is read. + next: Option<(NextOpen, Vec)>, + }, + /// Encountered an error + Error, + /// Reached the row limit + Limit, +} + +/// A timer that can be started and stopped. +pub struct StartableTime { + pub metrics: Time, + // use for record each part cost time, will eventually add into 'metrics'. + pub start: Option, +} + +impl StartableTime { + pub fn start(&mut self) { + assert!(self.start.is_none()); + self.start = Some(Instant::now()); + } + + pub fn stop(&mut self) { + if let Some(start) = self.start.take() { + self.metrics.add_elapsed(start); + } + } +} + +#[allow(rustdoc::broken_intra_doc_links)] +/// Metrics for [`FileStream`] +/// +/// Note that all of these metrics are in terms of wall clock time +/// (not cpu time) so they include time spent waiting on I/O as well +/// as other operators. +/// +/// [`FileStream`]: +pub struct FileStreamMetrics { + /// Wall clock time elapsed for file opening. + /// + /// Time between when [`FileOpener::open`] is called and when the + /// [`FileStream`] receives a stream for reading. + /// + /// If there are multiple files being scanned, the stream + /// will open the next file in the background while scanning the + /// current file. This metric will only capture time spent opening + /// while not also scanning. + /// [`FileStream`]: + pub time_opening: StartableTime, + /// Wall clock time elapsed for file scanning + first record batch of decompression + decoding + /// + /// Time between when the [`FileStream`] requests data from the + /// stream and when the first [`RecordBatch`] is produced. + /// [`FileStream`]: + pub time_scanning_until_data: StartableTime, + /// Total elapsed wall clock time for scanning + record batch decompression / decoding + /// + /// Sum of time between when the [`FileStream`] requests data from + /// the stream and when a [`RecordBatch`] is produced for all + /// record batches in the stream. Note that this metric also + /// includes the time of the parent operator's execution. + pub time_scanning_total: StartableTime, + /// Wall clock time elapsed for data decompression + decoding + /// + /// Time spent waiting for the FileStream's input. + pub time_processing: StartableTime, + /// Count of errors opening file. + /// + /// If using `OnError::Skip` this will provide a count of the number of files + /// which were skipped and will not be included in the scan results. 
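
Editor's note: `StartableTime` is a small start/stop wrapper that folds elapsed wall-clock time into a `Time` metric; `FileStreamMetrics` below bundles several of them per partition. A usage sketch, assuming the same crate paths as the earlier examples:

```rust
// Assumed path for the type defined in file_stream.rs above.
use datafusion_datasource::file_stream::StartableTime;
use datafusion_physical_plan::metrics::{ExecutionPlanMetricsSet, MetricBuilder};

fn time_an_open_call() {
    let metrics = ExecutionPlanMetricsSet::new();
    let mut time_opening = StartableTime {
        metrics: MetricBuilder::new(&metrics).subset_time("time_elapsed_opening", 0),
        start: None,
    };

    // Surround the measured region with start()/stop(); stop() adds the
    // elapsed duration to the underlying Time metric.
    time_opening.start();
    std::thread::sleep(std::time::Duration::from_millis(1));
    time_opening.stop();

    assert!(time_opening.metrics.value() > 0);
}
```
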
+ pub file_open_errors: Count, + /// Count of errors scanning file + /// + /// If using `OnError::Skip` this will provide a count of the number of files + /// which were skipped and will not be included in the scan results. + pub file_scan_errors: Count, +} + +impl FileStreamMetrics { + pub fn new(metrics: &ExecutionPlanMetricsSet, partition: usize) -> Self { + let time_opening = StartableTime { + metrics: MetricBuilder::new(metrics) + .subset_time("time_elapsed_opening", partition), + start: None, + }; + + let time_scanning_until_data = StartableTime { + metrics: MetricBuilder::new(metrics) + .subset_time("time_elapsed_scanning_until_data", partition), + start: None, + }; + + let time_scanning_total = StartableTime { + metrics: MetricBuilder::new(metrics) + .subset_time("time_elapsed_scanning_total", partition), + start: None, + }; + + let time_processing = StartableTime { + metrics: MetricBuilder::new(metrics) + .subset_time("time_elapsed_processing", partition), + start: None, + }; + + let file_open_errors = + MetricBuilder::new(metrics).counter("file_open_errors", partition); + + let file_scan_errors = + MetricBuilder::new(metrics).counter("file_scan_errors", partition); + + Self { + time_opening, + time_scanning_until_data, + time_scanning_total, + time_processing, + file_open_errors, + file_scan_errors, + } + } +} diff --git a/datafusion/datasource/src/mod.rs b/datafusion/datasource/src/mod.rs new file mode 100644 index 000000000000..c735c3108b3d --- /dev/null +++ b/datafusion/datasource/src/mod.rs @@ -0,0 +1,283 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! A table that uses the `ObjectStore` listing capability +//! to get the list of files to process. + +pub mod file_compression_type; +pub mod file_groups; +pub mod file_meta; +pub mod file_scan_config; +pub mod file_sink_config; +pub mod file_stream; +pub mod url; +pub mod write; +use chrono::TimeZone; +use datafusion_common::Result; +use datafusion_common::{ScalarValue, Statistics}; +use futures::Stream; +use object_store::{path::Path, ObjectMeta}; +use std::pin::Pin; +use std::sync::Arc; + +pub use self::url::ListingTableUrl; + +/// Stream of files get listed from object store +pub type PartitionedFileStream = + Pin> + Send + Sync + 'static>>; + +/// Only scan a subset of Row Groups from the Parquet file whose data "midpoint" +/// lies within the [start, end) byte offsets. This option can be used to scan non-overlapping +/// sections of a Parquet file in parallel. 
+#[derive(Debug, Clone, PartialEq, Hash, Eq, PartialOrd, Ord)] +pub struct FileRange { + /// Range start + pub start: i64, + /// Range end + pub end: i64, +} + +impl FileRange { + /// returns true if this file range contains the specified offset + pub fn contains(&self, offset: i64) -> bool { + offset >= self.start && offset < self.end + } +} + +#[derive(Debug, Clone)] +/// A single file or part of a file that should be read, along with its schema, statistics +/// and partition column values that need to be appended to each row. +pub struct PartitionedFile { + /// Path for the file (e.g. URL, filesystem path, etc) + pub object_meta: ObjectMeta, + /// Values of partition columns to be appended to each row. + /// + /// These MUST have the same count, order, and type than the [`table_partition_cols`]. + /// + /// You may use [`wrap_partition_value_in_dict`] to wrap them if you have used [`wrap_partition_type_in_dict`] to wrap the column type. + /// + /// + /// [`wrap_partition_type_in_dict`]: https://github.com/apache/datafusion/blob/main/datafusion/core/src/datasource/physical_plan/file_scan_config.rs#L55 + /// [`wrap_partition_value_in_dict`]: https://github.com/apache/datafusion/blob/main/datafusion/core/src/datasource/physical_plan/file_scan_config.rs#L62 + /// [`table_partition_cols`]: https://github.com/apache/datafusion/blob/main/datafusion/core/src/datasource/file_format/options.rs#L190 + pub partition_values: Vec, + /// An optional file range for a more fine-grained parallel execution + pub range: Option, + /// Optional statistics that describe the data in this file if known. + /// + /// DataFusion relies on these statistics for planning (in particular to sort file groups), + /// so if they are incorrect, incorrect answers may result. + pub statistics: Option, + /// An optional field for user defined per object metadata + pub extensions: Option>, + /// The estimated size of the parquet metadata, in bytes + pub metadata_size_hint: Option, +} + +impl PartitionedFile { + /// Create a simple file without metadata or partition + pub fn new(path: impl Into, size: u64) -> Self { + Self { + object_meta: ObjectMeta { + location: Path::from(path.into()), + last_modified: chrono::Utc.timestamp_nanos(0), + size: size as usize, + e_tag: None, + version: None, + }, + partition_values: vec![], + range: None, + statistics: None, + extensions: None, + metadata_size_hint: None, + } + } + + /// Create a file range without metadata or partition + pub fn new_with_range(path: String, size: u64, start: i64, end: i64) -> Self { + Self { + object_meta: ObjectMeta { + location: Path::from(path), + last_modified: chrono::Utc.timestamp_nanos(0), + size: size as usize, + e_tag: None, + version: None, + }, + partition_values: vec![], + range: Some(FileRange { start, end }), + statistics: None, + extensions: None, + metadata_size_hint: None, + } + .with_range(start, end) + } + + /// Provide a hint to the size of the file metadata. If a hint is provided + /// the reader will try and fetch the last `size_hint` bytes of the parquet file optimistically. + /// Without an appropriate hint, two read may be required to fetch the metadata. 
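
Editor's note: `PartitionedFile` collects everything the scanner needs to know about one file (or slice of a file). A short sketch using the builder methods of this impl block (some of them appear just below); the import path assumes the re-exports from the new crate's `mod.rs` above, and the file name and sizes are illustrative.

```rust
// Assumed import path: `PartitionedFile` and `FileRange` are re-exported from
// the new crate root shown earlier in this diff.
use datafusion_datasource::PartitionedFile;

fn scan_half_a_file() -> PartitionedFile {
    // A 4 KiB file, of which only the first 2 KiB should be scanned; the
    // metadata hint lets the reader fetch the footer in a single read.
    PartitionedFile::new("data/part-0.parquet", 4096)
        .with_range(0, 2048)
        .with_metadata_size_hint(512)
}

fn main() {
    let file = scan_half_a_file();
    assert_eq!(file.path().to_string(), "data/part-0.parquet");
    assert!(file.range.as_ref().is_some_and(|r| r.contains(1024)));
}
```
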
+ pub fn with_metadata_size_hint(mut self, metadata_size_hint: usize) -> Self { + self.metadata_size_hint = Some(metadata_size_hint); + self + } + + /// Return a file reference from the given path + pub fn from_path(path: String) -> Result { + let size = std::fs::metadata(path.clone())?.len(); + Ok(Self::new(path, size)) + } + + /// Return the path of this partitioned file + pub fn path(&self) -> &Path { + &self.object_meta.location + } + + /// Update the file to only scan the specified range (in bytes) + pub fn with_range(mut self, start: i64, end: i64) -> Self { + self.range = Some(FileRange { start, end }); + self + } + + /// Update the user defined extensions for this file. + /// + /// This can be used to pass reader specific information. + pub fn with_extensions( + mut self, + extensions: Arc, + ) -> Self { + self.extensions = Some(extensions); + self + } +} + +impl From for PartitionedFile { + fn from(object_meta: ObjectMeta) -> Self { + PartitionedFile { + object_meta, + partition_values: vec![], + range: None, + statistics: None, + extensions: None, + metadata_size_hint: None, + } + } +} + +#[cfg(test)] +mod tests { + use super::ListingTableUrl; + use datafusion_execution::object_store::{ + DefaultObjectStoreRegistry, ObjectStoreRegistry, + }; + use object_store::{local::LocalFileSystem, path::Path}; + use std::{ops::Not, sync::Arc}; + use url::Url; + + #[test] + fn test_object_store_listing_url() { + let listing = ListingTableUrl::parse("file:///").unwrap(); + let store = listing.object_store(); + assert_eq!(store.as_str(), "file:///"); + + let listing = ListingTableUrl::parse("s3://bucket/").unwrap(); + let store = listing.object_store(); + assert_eq!(store.as_str(), "s3://bucket/"); + } + + #[test] + fn test_get_store_hdfs() { + let sut = DefaultObjectStoreRegistry::default(); + let url = Url::parse("hdfs://localhost:8020").unwrap(); + sut.register_store(&url, Arc::new(LocalFileSystem::new())); + let url = ListingTableUrl::parse("hdfs://localhost:8020/key").unwrap(); + sut.get_store(url.as_ref()).unwrap(); + } + + #[test] + fn test_get_store_s3() { + let sut = DefaultObjectStoreRegistry::default(); + let url = Url::parse("s3://bucket/key").unwrap(); + sut.register_store(&url, Arc::new(LocalFileSystem::new())); + let url = ListingTableUrl::parse("s3://bucket/key").unwrap(); + sut.get_store(url.as_ref()).unwrap(); + } + + #[test] + fn test_get_store_file() { + let sut = DefaultObjectStoreRegistry::default(); + let url = ListingTableUrl::parse("file:///bucket/key").unwrap(); + sut.get_store(url.as_ref()).unwrap(); + } + + #[test] + fn test_get_store_local() { + let sut = DefaultObjectStoreRegistry::default(); + let url = ListingTableUrl::parse("../").unwrap(); + sut.get_store(url.as_ref()).unwrap(); + } + + #[test] + fn test_url_contains() { + let url = ListingTableUrl::parse("file:///var/data/mytable/").unwrap(); + + // standard case with default config + assert!(url.contains( + &Path::parse("/var/data/mytable/data.parquet").unwrap(), + true + )); + + // standard case with `ignore_subdirectory` set to false + assert!(url.contains( + &Path::parse("/var/data/mytable/data.parquet").unwrap(), + false + )); + + // as per documentation, when `ignore_subdirectory` is true, we should ignore files that aren't + // a direct child of the `url` + assert!(url + .contains( + &Path::parse("/var/data/mytable/mysubfolder/data.parquet").unwrap(), + true + ) + .not()); + + // when we set `ignore_subdirectory` to false, we should not ignore the file + assert!(url.contains( + 
&Path::parse("/var/data/mytable/mysubfolder/data.parquet").unwrap(), + false + )); + + // as above, `ignore_subdirectory` is false, so we include the file + assert!(url.contains( + &Path::parse("/var/data/mytable/year=2024/data.parquet").unwrap(), + false + )); + + // in this case, we include the file even when `ignore_subdirectory` is true because the + // path segment is a hive partition which doesn't count as a subdirectory for the purposes + // of `Url::contains` + assert!(url.contains( + &Path::parse("/var/data/mytable/year=2024/data.parquet").unwrap(), + true + )); + + // testing an empty path with default config + assert!(url.contains(&Path::parse("/var/data/mytable/").unwrap(), true)); + + // testing an empty path with `ignore_subdirectory` set to false + assert!(url.contains(&Path::parse("/var/data/mytable/").unwrap(), false)); + } +} diff --git a/datafusion/catalog-listing/src/url.rs b/datafusion/datasource/src/url.rs similarity index 99% rename from datafusion/catalog-listing/src/url.rs rename to datafusion/datasource/src/url.rs index 2e6415ba3b2b..89e73a8a2b26 100644 --- a/datafusion/catalog-listing/src/url.rs +++ b/datafusion/datasource/src/url.rs @@ -193,7 +193,7 @@ impl ListingTableUrl { /// /// Examples: /// ```rust - /// use datafusion_catalog_listing::ListingTableUrl; + /// use datafusion_datasource::ListingTableUrl; /// let url = ListingTableUrl::parse("file:///foo/bar.csv").unwrap(); /// assert_eq!(url.file_extension(), Some("csv")); /// let url = ListingTableUrl::parse("file:///foo/bar").unwrap(); diff --git a/datafusion/core/src/datasource/file_format/write/demux.rs b/datafusion/datasource/src/write/demux.rs similarity index 98% rename from datafusion/core/src/datasource/file_format/write/demux.rs rename to datafusion/datasource/src/write/demux.rs index 454666003254..111d22060c0d 100644 --- a/datafusion/core/src/datasource/file_format/write/demux.rs +++ b/datafusion/datasource/src/write/demux.rs @@ -22,16 +22,16 @@ use std::borrow::Cow; use std::collections::HashMap; use std::sync::Arc; -use crate::datasource::listing::ListingTableUrl; -use crate::datasource::physical_plan::FileSinkConfig; -use crate::error::Result; -use crate::physical_plan::SendableRecordBatchStream; +use crate::url::ListingTableUrl; +use crate::write::FileSinkConfig; +use datafusion_common::error::Result; +use datafusion_physical_plan::SendableRecordBatchStream; use arrow::array::{ builder::UInt64Builder, cast::AsArray, downcast_dictionary_array, RecordBatch, StringArray, StructArray, }; -use arrow_schema::{DataType, Schema}; +use arrow::datatypes::{DataType, Schema}; use datafusion_common::cast::{ as_boolean_array, as_date32_array, as_date64_array, as_int32_array, as_int64_array, as_string_array, as_string_view_array, diff --git a/datafusion/core/src/datasource/file_format/write/mod.rs b/datafusion/datasource/src/write/mod.rs similarity index 89% rename from datafusion/core/src/datasource/file_format/write/mod.rs rename to datafusion/datasource/src/write/mod.rs index 81ecf3f0f88c..f581126095a7 100644 --- a/datafusion/core/src/datasource/file_format/write/mod.rs +++ b/datafusion/datasource/src/write/mod.rs @@ -21,30 +21,30 @@ use std::io::Write; use std::sync::Arc; -use crate::datasource::file_format::file_compression_type::FileCompressionType; -use crate::datasource::physical_plan::FileSinkConfig; -use crate::error::Result; +use crate::file_compression_type::FileCompressionType; +use crate::file_sink_config::FileSinkConfig; +use datafusion_common::error::Result; use arrow::array::RecordBatch; -use 
arrow_schema::Schema; +use arrow::datatypes::Schema; use bytes::Bytes; use object_store::buffered::BufWriter; use object_store::path::Path; use object_store::ObjectStore; use tokio::io::AsyncWrite; -pub(crate) mod demux; -pub(crate) mod orchestration; +pub mod demux; +pub mod orchestration; /// A buffer with interior mutability shared by the SerializedFileWriter and /// ObjectStore writer #[derive(Clone)] -pub(crate) struct SharedBuffer { +pub struct SharedBuffer { /// The inner buffer for reading and writing /// /// The lock is used to obtain internal mutability, so no worry about the /// lock contention. - pub(crate) buffer: Arc>>, + pub buffer: Arc>>, } impl SharedBuffer { @@ -79,7 +79,7 @@ pub trait BatchSerializer: Sync + Send { /// with the specified compression. /// We drop the `AbortableWrite` struct and the writer will not try to cleanup on failure. /// Users can configure automatic cleanup with their cloud provider. -pub(crate) async fn create_writer( +pub async fn create_writer( file_compression_type: FileCompressionType, location: &Path, object_store: Arc, @@ -91,7 +91,7 @@ pub(crate) async fn create_writer( /// Converts table schema to writer schema, which may differ in the case /// of hive style partitioning where some columns are removed from the /// underlying files. -pub(crate) fn get_writer_schema(config: &FileSinkConfig) -> Arc { +pub fn get_writer_schema(config: &FileSinkConfig) -> Arc { if !config.table_partition_cols.is_empty() && !config.keep_partition_by_columns { let schema = config.output_schema(); let partition_names: Vec<_> = diff --git a/datafusion/core/src/datasource/file_format/write/orchestration.rs b/datafusion/datasource/src/write/orchestration.rs similarity index 98% rename from datafusion/core/src/datasource/file_format/write/orchestration.rs rename to datafusion/datasource/src/write/orchestration.rs index 75836d1b48b0..1364e7d9f236 100644 --- a/datafusion/core/src/datasource/file_format/write/orchestration.rs +++ b/datafusion/datasource/src/write/orchestration.rs @@ -23,8 +23,8 @@ use std::sync::Arc; use super::demux::DemuxedStreamReceiver; use super::{create_writer, BatchSerializer}; -use crate::datasource::file_format::file_compression_type::FileCompressionType; -use crate::error::Result; +use crate::file_compression_type::FileCompressionType; +use datafusion_common::error::Result; use arrow::array::RecordBatch; use datafusion_common::{internal_datafusion_err, internal_err, DataFusionError}; @@ -237,7 +237,7 @@ pub(crate) async fn stateless_serialize_and_write_files( /// Orchestrates multipart put of a dynamic number of output files from a single input stream /// for any statelessly serialized file type. That is, any file type for which each [RecordBatch] /// can be serialized independently of all other [RecordBatch]s. -pub(crate) async fn spawn_writer_tasks_and_join( +pub async fn spawn_writer_tasks_and_join( context: &Arc, serializer: Arc, compression: FileCompressionType, diff --git a/datafusion/execution/src/task.rs b/datafusion/execution/src/task.rs index 7cdb53c90d0e..b11596c4a30f 100644 --- a/datafusion/execution/src/task.rs +++ b/datafusion/execution/src/task.rs @@ -214,6 +214,7 @@ mod tests { extensions_options! 
{ struct TestExtension { value: usize, default = 42 + option_value: Option, default = None } } @@ -229,6 +230,7 @@ mod tests { let mut config = ConfigOptions::new().with_extensions(extensions); config.set("test.value", "24")?; + config.set("test.option_value", "42")?; let session_config = SessionConfig::from(config); let task_context = TaskContext::new( @@ -249,6 +251,39 @@ mod tests { assert!(test.is_some()); assert_eq!(test.unwrap().value, 24); + assert_eq!(test.unwrap().option_value, Some(42)); + + Ok(()) + } + + #[test] + fn task_context_extensions_default() -> Result<()> { + let runtime = Arc::new(RuntimeEnv::default()); + let mut extensions = Extensions::new(); + extensions.insert(TestExtension::default()); + + let config = ConfigOptions::new().with_extensions(extensions); + let session_config = SessionConfig::from(config); + + let task_context = TaskContext::new( + Some("task_id".to_string()), + "session_id".to_string(), + session_config, + HashMap::default(), + HashMap::default(), + HashMap::default(), + runtime, + ); + + let test = task_context + .session_config() + .options() + .extensions + .get::(); + assert!(test.is_some()); + + assert_eq!(test.unwrap().value, 42); + assert_eq!(test.unwrap().option_value, None); Ok(()) } diff --git a/datafusion/expr-common/Cargo.toml b/datafusion/expr-common/Cargo.toml index 109d8e0b89a6..abc78a9f084b 100644 --- a/datafusion/expr-common/Cargo.toml +++ b/datafusion/expr-common/Cargo.toml @@ -39,5 +39,6 @@ path = "src/lib.rs" [dependencies] arrow = { workspace = true } datafusion-common = { workspace = true } +indexmap = { workspace = true } itertools = { workspace = true } paste = "^1.0" diff --git a/datafusion/expr-common/src/signature.rs b/datafusion/expr-common/src/signature.rs index 1bfae28af840..ba6fadbf7235 100644 --- a/datafusion/expr-common/src/signature.rs +++ b/datafusion/expr-common/src/signature.rs @@ -19,11 +19,14 @@ //! and return types of functions in DataFusion. use std::fmt::Display; -use std::num::NonZeroUsize; +use std::hash::Hash; use crate::type_coercion::aggregates::NUMERICS; use arrow::datatypes::{DataType, IntervalUnit, TimeUnit}; -use datafusion_common::types::{LogicalTypeRef, NativeType}; +use datafusion_common::internal_err; +use datafusion_common::types::{LogicalType, LogicalTypeRef, NativeType}; +use datafusion_common::utils::ListCoercion; +use indexmap::IndexSet; use itertools::Itertools; /// Constant that is used as a placeholder for any valid timezone. @@ -127,12 +130,11 @@ pub enum TypeSignature { Exact(Vec), /// One or more arguments belonging to the [`TypeSignatureClass`], in order. /// - /// For example, `Coercible(vec![logical_float64()])` accepts - /// arguments like `vec![Int32]` or `vec![Float32]` - /// since i32 and f32 can be cast to f64 + /// [`Coercion`] contains not only the desired type but also the allowed casts. + /// For example, if you expect a function has string type, but you also allow it to be casted from binary type. /// /// For functions that take no arguments (e.g. `random()`) see [`TypeSignature::Nullary`]. - Coercible(Vec), + Coercible(Vec), /// One or more arguments coercible to a single, comparable type. 
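
Editor's note on the `TaskContext` test change earlier in this hunk: the new `option_value` field exercises `Option<usize>` support in `extensions_options!`. A sketch of defining and setting such an extension outside the test; the `ConfigExtension` impl with a prefix mirrors the existing test setup, which is not part of this hunk, and the extension name is illustrative.

```rust
use datafusion_common::config::{ConfigExtension, ConfigOptions, Extensions};
use datafusion_common::{extensions_options, Result};

extensions_options! {
    struct MyExtension {
        value: usize, default = 42
        option_value: Option<usize>, default = None
    }
}

impl ConfigExtension for MyExtension {
    const PREFIX: &'static str = "my_ext";
}

fn configure() -> Result<ConfigOptions> {
    let mut extensions = Extensions::new();
    extensions.insert(MyExtension::default());

    let mut config = ConfigOptions::new().with_extensions(extensions);
    // Keys are "<PREFIX>.<field>"; unset Option fields stay None.
    config.set("my_ext.value", "24")?;
    config.set("my_ext.option_value", "7")?;

    let ext = config.extensions.get::<MyExtension>().expect("registered");
    assert_eq!(ext.value, 24);
    assert_eq!(ext.option_value, Some(7));
    Ok(config)
}
```
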
/// /// Each argument will be coerced to a single type using the @@ -209,14 +211,13 @@ impl TypeSignature { #[derive(Debug, Clone, Eq, PartialEq, PartialOrd, Hash)] pub enum TypeSignatureClass { Timestamp, - Date, Time, Interval, Duration, Native(LogicalTypeRef), // TODO: // Numeric - // Integer + Integer, } impl Display for TypeSignatureClass { @@ -225,27 +226,98 @@ impl Display for TypeSignatureClass { } } +impl TypeSignatureClass { + /// Get example acceptable types for this `TypeSignatureClass` + /// + /// This is used for `information_schema` and can be used to generate + /// documentation or error messages. + fn get_example_types(&self) -> Vec { + match self { + TypeSignatureClass::Native(l) => get_data_types(l.native()), + TypeSignatureClass::Timestamp => { + vec![ + DataType::Timestamp(TimeUnit::Nanosecond, None), + DataType::Timestamp( + TimeUnit::Nanosecond, + Some(TIMEZONE_WILDCARD.into()), + ), + ] + } + TypeSignatureClass::Time => { + vec![DataType::Time64(TimeUnit::Nanosecond)] + } + TypeSignatureClass::Interval => { + vec![DataType::Interval(IntervalUnit::DayTime)] + } + TypeSignatureClass::Duration => { + vec![DataType::Duration(TimeUnit::Nanosecond)] + } + TypeSignatureClass::Integer => { + vec![DataType::Int64] + } + } + } + + /// Does the specified `NativeType` match this type signature class? + pub fn matches_native_type( + self: &TypeSignatureClass, + logical_type: &NativeType, + ) -> bool { + if logical_type == &NativeType::Null { + return true; + } + + match self { + TypeSignatureClass::Native(t) if t.native() == logical_type => true, + TypeSignatureClass::Timestamp if logical_type.is_timestamp() => true, + TypeSignatureClass::Time if logical_type.is_time() => true, + TypeSignatureClass::Interval if logical_type.is_interval() => true, + TypeSignatureClass::Duration if logical_type.is_duration() => true, + TypeSignatureClass::Integer if logical_type.is_integer() => true, + _ => false, + } + } + + /// What type would `origin_type` be casted to when casting to the specified native type? + pub fn default_casted_type( + &self, + native_type: &NativeType, + origin_type: &DataType, + ) -> datafusion_common::Result { + match self { + TypeSignatureClass::Native(logical_type) => { + logical_type.native().default_cast_for(origin_type) + } + // If the given type is already a timestamp, we don't change the unit and timezone + TypeSignatureClass::Timestamp if native_type.is_timestamp() => { + Ok(origin_type.to_owned()) + } + TypeSignatureClass::Time if native_type.is_time() => { + Ok(origin_type.to_owned()) + } + TypeSignatureClass::Interval if native_type.is_interval() => { + Ok(origin_type.to_owned()) + } + TypeSignatureClass::Duration if native_type.is_duration() => { + Ok(origin_type.to_owned()) + } + TypeSignatureClass::Integer if native_type.is_integer() => { + Ok(origin_type.to_owned()) + } + _ => internal_err!("May miss the matching logic in `matches_native_type`"), + } + } +} + #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)] pub enum ArrayFunctionSignature { - /// Specialized Signature for ArrayAppend and similar functions - /// The first argument should be List/LargeList/FixedSizedList, and the second argument should be non-list or list. - /// The second argument's list dimension should be one dimension less than the first argument's list dimension. - /// List dimension of the List/LargeList is equivalent to the number of List. - /// List dimension of the non-list is 0. 
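A short, illustrative check of how the new `matches_native_type` helper behaves, assuming the match arms shown above (`Null` is accepted by every class, and `Integer` accepts any integer-width native type); this snippet is not part of the patch:

```rust
use datafusion_common::types::{logical_int64, NativeType};
use datafusion_expr_common::signature::TypeSignatureClass;

fn main() {
    let integer = TypeSignatureClass::Integer;
    assert!(integer.matches_native_type(&NativeType::Int32));
    assert!(integer.matches_native_type(&NativeType::Null)); // Null matches every class
    assert!(!integer.matches_native_type(&NativeType::String));

    let int64 = TypeSignatureClass::Native(logical_int64());
    assert!(int64.matches_native_type(&NativeType::Int64));
}
```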
- ArrayAndElement, - /// Specialized Signature for ArrayPrepend and similar functions - /// The first argument should be non-list or list, and the second argument should be List/LargeList. - /// The first argument's list dimension should be one dimension less than the second argument's list dimension. - ElementAndArray, - /// Specialized Signature for Array functions of the form (List/LargeList, Index+) - /// The first argument should be List/LargeList/FixedSizedList, and the next n arguments should be Int64. - ArrayAndIndexes(NonZeroUsize), - /// Specialized Signature for Array functions of the form (List/LargeList, Element, Optional Index) - ArrayAndElementAndOptionalIndex, - /// Specialized Signature for ArrayEmpty and similar functions - /// The function takes a single argument that must be a List/LargeList/FixedSizeList - /// or something that can be coerced to one of those types. - Array, + /// A function takes at least one List/LargeList/FixedSizeList argument. + Array { + /// A full list of the arguments accepted by this function. + arguments: Vec, + /// Additional information about how array arguments should be coerced. + array_coercion: Option, + }, /// A function takes a single argument that must be a List/LargeList/FixedSizeList /// which gets coerced to List, with element type recursively coerced to List too if it is list-like. RecursiveArray, @@ -257,25 +329,15 @@ pub enum ArrayFunctionSignature { impl Display for ArrayFunctionSignature { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { - ArrayFunctionSignature::ArrayAndElement => { - write!(f, "array, element") - } - ArrayFunctionSignature::ArrayAndElementAndOptionalIndex => { - write!(f, "array, element, [index]") - } - ArrayFunctionSignature::ElementAndArray => { - write!(f, "element, array") - } - ArrayFunctionSignature::ArrayAndIndexes(count) => { - write!(f, "array")?; - for _ in 0..count.get() { - write!(f, ", index")?; + ArrayFunctionSignature::Array { arguments, .. } => { + for (idx, argument) in arguments.iter().enumerate() { + write!(f, "{argument}")?; + if idx != arguments.len() - 1 { + write!(f, ", ")?; + } } Ok(()) } - ArrayFunctionSignature::Array => { - write!(f, "array") - } ArrayFunctionSignature::RecursiveArray => { write!(f, "recursive_array") } @@ -286,6 +348,34 @@ impl Display for ArrayFunctionSignature { } } +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)] +pub enum ArrayFunctionArgument { + /// A non-list or list argument. The list dimensions should be one less than the Array's list + /// dimensions. + Element, + /// An Int64 index argument. + Index, + /// An argument of type List/LargeList/FixedSizeList. All Array arguments must be coercible + /// to the same type. 
+ Array, +} + +impl Display for ArrayFunctionArgument { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + ArrayFunctionArgument::Element => { + write!(f, "element") + } + ArrayFunctionArgument::Index => { + write!(f, "index") + } + ArrayFunctionArgument::Array => { + write!(f, "array") + } + } + } +} + impl TypeSignature { pub fn to_string_repr(&self) -> Vec { match self { @@ -310,8 +400,8 @@ impl TypeSignature { TypeSignature::Comparable(num) => { vec![format!("Comparable({num})")] } - TypeSignature::Coercible(types) => { - vec![Self::join_types(types, ", ")] + TypeSignature::Coercible(coercions) => { + vec![Self::join_types(coercions, ", ")] } TypeSignature::Exact(types) => { vec![Self::join_types(types, ", ")] @@ -365,44 +455,45 @@ impl TypeSignature { } } - /// get all possible types for the given `TypeSignature` + #[deprecated(since = "46.0.0", note = "See get_example_types instead")] pub fn get_possible_types(&self) -> Vec> { + self.get_example_types() + } + + /// Return example acceptable types for this `TypeSignature`' + /// + /// Returns a `Vec` for each argument to the function + /// + /// This is used for `information_schema` and can be used to generate + /// documentation or error messages. + pub fn get_example_types(&self) -> Vec> { match self { TypeSignature::Exact(types) => vec![types.clone()], TypeSignature::OneOf(types) => types .iter() - .flat_map(|type_sig| type_sig.get_possible_types()) + .flat_map(|type_sig| type_sig.get_example_types()) .collect(), TypeSignature::Uniform(arg_count, types) => types .iter() .cloned() .map(|data_type| vec![data_type; *arg_count]) .collect(), - TypeSignature::Coercible(types) => types + TypeSignature::Coercible(coercions) => coercions .iter() - .map(|logical_type| match logical_type { - TypeSignatureClass::Native(l) => get_data_types(l.native()), - TypeSignatureClass::Timestamp => { - vec![ - DataType::Timestamp(TimeUnit::Nanosecond, None), - DataType::Timestamp( - TimeUnit::Nanosecond, - Some(TIMEZONE_WILDCARD.into()), - ), - ] - } - TypeSignatureClass::Date => { - vec![DataType::Date64] - } - TypeSignatureClass::Time => { - vec![DataType::Time64(TimeUnit::Nanosecond)] - } - TypeSignatureClass::Interval => { - vec![DataType::Interval(IntervalUnit::DayTime)] - } - TypeSignatureClass::Duration => { - vec![DataType::Duration(TimeUnit::Nanosecond)] + .map(|c| { + let mut all_types: IndexSet = + c.desired_type().get_example_types().into_iter().collect(); + + if let Some(implicit_coercion) = c.implicit_coercion() { + let allowed_casts: Vec = implicit_coercion + .allowed_source_types + .iter() + .flat_map(|t| t.get_example_types()) + .collect(); + all_types.extend(allowed_casts); } + + all_types.into_iter().collect::>() }) .multi_cartesian_product() .collect(), @@ -460,6 +551,186 @@ fn get_data_types(native_type: &NativeType) -> Vec { } } +/// Represents type coercion rules for function arguments, specifying both the desired type +/// and optional implicit coercion rules for source types. 
+/// +/// # Examples +/// +/// ``` +/// use datafusion_expr_common::signature::{Coercion, TypeSignatureClass}; +/// use datafusion_common::types::{NativeType, logical_binary, logical_string}; +/// +/// // Exact coercion that only accepts timestamp types +/// let exact = Coercion::new_exact(TypeSignatureClass::Timestamp); +/// +/// // Implicit coercion that accepts string types but can coerce from binary types +/// let implicit = Coercion::new_implicit( +/// TypeSignatureClass::Native(logical_string()), +/// vec![TypeSignatureClass::Native(logical_binary())], +/// NativeType::String +/// ); +/// ``` +/// +/// There are two variants: +/// +/// * `Exact` - Only accepts arguments that exactly match the desired type +/// * `Implicit` - Accepts the desired type and can coerce from specified source types +#[derive(Debug, Clone, Eq, PartialOrd)] +pub enum Coercion { + /// Coercion that only accepts arguments exactly matching the desired type. + Exact { + /// The required type for the argument + desired_type: TypeSignatureClass, + }, + + /// Coercion that accepts the desired type and can implicitly coerce from other types. + Implicit { + /// The primary desired type for the argument + desired_type: TypeSignatureClass, + /// Rules for implicit coercion from other types + implicit_coercion: ImplicitCoercion, + }, +} + +impl Coercion { + pub fn new_exact(desired_type: TypeSignatureClass) -> Self { + Self::Exact { desired_type } + } + + /// Create a new coercion with implicit coercion rules. + /// + /// `allowed_source_types` defines the possible types that can be coerced to `desired_type`. + /// `default_casted_type` is the default type to be used for coercion if we cast from other types via `allowed_source_types`. + pub fn new_implicit( + desired_type: TypeSignatureClass, + allowed_source_types: Vec, + default_casted_type: NativeType, + ) -> Self { + Self::Implicit { + desired_type, + implicit_coercion: ImplicitCoercion { + allowed_source_types, + default_casted_type, + }, + } + } + + pub fn allowed_source_types(&self) -> &[TypeSignatureClass] { + match self { + Coercion::Exact { .. } => &[], + Coercion::Implicit { + implicit_coercion, .. + } => implicit_coercion.allowed_source_types.as_slice(), + } + } + + pub fn default_casted_type(&self) -> Option<&NativeType> { + match self { + Coercion::Exact { .. } => None, + Coercion::Implicit { + implicit_coercion, .. + } => Some(&implicit_coercion.default_casted_type), + } + } + + pub fn desired_type(&self) -> &TypeSignatureClass { + match self { + Coercion::Exact { desired_type } => desired_type, + Coercion::Implicit { desired_type, .. } => desired_type, + } + } + + pub fn implicit_coercion(&self) -> Option<&ImplicitCoercion> { + match self { + Coercion::Exact { .. } => None, + Coercion::Implicit { + implicit_coercion, .. 
+ } => Some(implicit_coercion), + } + } +} + +impl Display for Coercion { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "Coercion({}", self.desired_type())?; + if let Some(implicit_coercion) = self.implicit_coercion() { + write!(f, ", implicit_coercion={implicit_coercion}",) + } else { + write!(f, ")") + } + } +} + +impl PartialEq for Coercion { + fn eq(&self, other: &Self) -> bool { + self.desired_type() == other.desired_type() + && self.implicit_coercion() == other.implicit_coercion() + } +} + +impl Hash for Coercion { + fn hash(&self, state: &mut H) { + self.desired_type().hash(state); + self.implicit_coercion().hash(state); + } +} + +/// Defines rules for implicit type coercion, specifying which source types can be +/// coerced and the default type to use when coercing. +/// +/// This is used by functions to specify which types they can accept via implicit +/// coercion in addition to their primary desired type. +/// +/// # Examples +/// +/// ``` +/// use arrow::datatypes::TimeUnit; +/// +/// use datafusion_expr_common::signature::{Coercion, ImplicitCoercion, TypeSignatureClass}; +/// use datafusion_common::types::{NativeType, logical_binary}; +/// +/// // Allow coercing from binary types to timestamp, coerce to specific timestamp unit and timezone +/// let implicit = Coercion::new_implicit( +/// TypeSignatureClass::Timestamp, +/// vec![TypeSignatureClass::Native(logical_binary())], +/// NativeType::Timestamp(TimeUnit::Second, None), +/// ); +/// ``` +#[derive(Debug, Clone, Eq, PartialOrd)] +pub struct ImplicitCoercion { + /// The types that can be coerced from via implicit casting + allowed_source_types: Vec, + + /// The default type to use when coercing from allowed source types. + /// This is particularly important for types like Timestamp that have multiple + /// possible configurations (different time units and timezones). + default_casted_type: NativeType, +} + +impl Display for ImplicitCoercion { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "ImplicitCoercion({:?}, default_type={:?})", + self.allowed_source_types, self.default_casted_type + ) + } +} + +impl PartialEq for ImplicitCoercion { + fn eq(&self, other: &Self) -> bool { + self.allowed_source_types == other.allowed_source_types + && self.default_casted_type == other.default_casted_type + } +} + +impl Hash for ImplicitCoercion { + fn hash(&self, state: &mut H) { + self.allowed_source_types.hash(state); + self.default_casted_type.hash(state); + } +} + /// Defines the supported argument types ([`TypeSignature`]) and [`Volatility`] for a function. 
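For reference, a small example (not part of the patch) of how a `Coercible` signature expands through the renamed `get_example_types`, following the same pattern as the updated unit test below:

```rust
use datafusion_common::types::{logical_int64, logical_string};
use datafusion_expr_common::signature::{Coercion, TypeSignature, TypeSignatureClass};

fn main() {
    let sig = TypeSignature::Coercible(vec![
        Coercion::new_exact(TypeSignatureClass::Native(logical_string())),
        Coercion::new_exact(TypeSignatureClass::Native(logical_int64())),
    ]);
    // Each inner Vec is one example argument list the signature accepts.
    for candidate in sig.get_example_types() {
        println!("{candidate:?}");
    }
}
```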
/// /// DataFusion will automatically coerce (cast) argument types to one of the supported @@ -536,11 +807,9 @@ impl Signature { volatility, } } + /// Target coerce types in order - pub fn coercible( - target_types: Vec, - volatility: Volatility, - ) -> Self { + pub fn coercible(target_types: Vec, volatility: Volatility) -> Self { Self { type_signature: TypeSignature::Coercible(target_types), volatility, @@ -580,7 +849,13 @@ impl Signature { pub fn array_and_element(volatility: Volatility) -> Self { Signature { type_signature: TypeSignature::ArraySignature( - ArrayFunctionSignature::ArrayAndElement, + ArrayFunctionSignature::Array { + arguments: vec![ + ArrayFunctionArgument::Array, + ArrayFunctionArgument::Element, + ], + array_coercion: Some(ListCoercion::FixedSizedListToList), + }, ), volatility, } @@ -588,30 +863,38 @@ impl Signature { /// Specialized Signature for Array functions with an optional index pub fn array_and_element_and_optional_index(volatility: Volatility) -> Self { Signature { - type_signature: TypeSignature::ArraySignature( - ArrayFunctionSignature::ArrayAndElementAndOptionalIndex, - ), - volatility, - } - } - /// Specialized Signature for ArrayPrepend and similar functions - pub fn element_and_array(volatility: Volatility) -> Self { - Signature { - type_signature: TypeSignature::ArraySignature( - ArrayFunctionSignature::ElementAndArray, - ), + type_signature: TypeSignature::OneOf(vec![ + TypeSignature::ArraySignature(ArrayFunctionSignature::Array { + arguments: vec![ + ArrayFunctionArgument::Array, + ArrayFunctionArgument::Element, + ], + array_coercion: None, + }), + TypeSignature::ArraySignature(ArrayFunctionSignature::Array { + arguments: vec![ + ArrayFunctionArgument::Array, + ArrayFunctionArgument::Element, + ArrayFunctionArgument::Index, + ], + array_coercion: None, + }), + ]), volatility, } } + /// Specialized Signature for ArrayElement and similar functions pub fn array_and_index(volatility: Volatility) -> Self { - Self::array_and_indexes(volatility, NonZeroUsize::new(1).expect("1 is non-zero")) - } - /// Specialized Signature for ArraySlice and similar functions - pub fn array_and_indexes(volatility: Volatility, count: NonZeroUsize) -> Self { Signature { type_signature: TypeSignature::ArraySignature( - ArrayFunctionSignature::ArrayAndIndexes(count), + ArrayFunctionSignature::Array { + arguments: vec![ + ArrayFunctionArgument::Array, + ArrayFunctionArgument::Index, + ], + array_coercion: None, + }, ), volatility, } @@ -619,7 +902,12 @@ impl Signature { /// Specialized Signature for ArrayEmpty and similar functions pub fn array(volatility: Volatility) -> Self { Signature { - type_signature: TypeSignature::ArraySignature(ArrayFunctionSignature::Array), + type_signature: TypeSignature::ArraySignature( + ArrayFunctionSignature::Array { + arguments: vec![ArrayFunctionArgument::Array], + array_coercion: None, + }, + ), volatility, } } @@ -696,14 +984,14 @@ mod tests { #[test] fn test_get_possible_types() { let type_signature = TypeSignature::Exact(vec![DataType::Int32, DataType::Int64]); - let possible_types = type_signature.get_possible_types(); + let possible_types = type_signature.get_example_types(); assert_eq!(possible_types, vec![vec![DataType::Int32, DataType::Int64]]); let type_signature = TypeSignature::OneOf(vec![ TypeSignature::Exact(vec![DataType::Int32, DataType::Int64]), TypeSignature::Exact(vec![DataType::Float32, DataType::Float64]), ]); - let possible_types = type_signature.get_possible_types(); + let possible_types = 
type_signature.get_example_types(); assert_eq!( possible_types, vec![ @@ -717,7 +1005,7 @@ mod tests { TypeSignature::Exact(vec![DataType::Float32, DataType::Float64]), TypeSignature::Exact(vec![DataType::Utf8]), ]); - let possible_types = type_signature.get_possible_types(); + let possible_types = type_signature.get_example_types(); assert_eq!( possible_types, vec![ @@ -729,7 +1017,7 @@ mod tests { let type_signature = TypeSignature::Uniform(2, vec![DataType::Float32, DataType::Int64]); - let possible_types = type_signature.get_possible_types(); + let possible_types = type_signature.get_example_types(); assert_eq!( possible_types, vec![ @@ -739,10 +1027,10 @@ mod tests { ); let type_signature = TypeSignature::Coercible(vec![ - TypeSignatureClass::Native(logical_string()), - TypeSignatureClass::Native(logical_int64()), + Coercion::new_exact(TypeSignatureClass::Native(logical_string())), + Coercion::new_exact(TypeSignatureClass::Native(logical_int64())), ]); - let possible_types = type_signature.get_possible_types(); + let possible_types = type_signature.get_example_types(); assert_eq!( possible_types, vec![ @@ -754,14 +1042,14 @@ mod tests { let type_signature = TypeSignature::Variadic(vec![DataType::Int32, DataType::Int64]); - let possible_types = type_signature.get_possible_types(); + let possible_types = type_signature.get_example_types(); assert_eq!( possible_types, vec![vec![DataType::Int32], vec![DataType::Int64]] ); let type_signature = TypeSignature::Numeric(2); - let possible_types = type_signature.get_possible_types(); + let possible_types = type_signature.get_example_types(); assert_eq!( possible_types, vec![ @@ -779,7 +1067,7 @@ mod tests { ); let type_signature = TypeSignature::String(2); - let possible_types = type_signature.get_possible_types(); + let possible_types = type_signature.get_example_types(); assert_eq!( possible_types, vec![ diff --git a/datafusion/expr-common/src/type_coercion/binary.rs b/datafusion/expr-common/src/type_coercion/binary.rs index 3be35490a4d0..64c26192ae0f 100644 --- a/datafusion/expr-common/src/type_coercion/binary.rs +++ b/datafusion/expr-common/src/type_coercion/binary.rs @@ -537,8 +537,16 @@ fn type_union_resolution_coercion( } (DataType::Dictionary(index_type, value_type), other_type) | (other_type, DataType::Dictionary(index_type, value_type)) => { - let new_value_type = type_union_resolution_coercion(value_type, other_type); - new_value_type.map(|t| DataType::Dictionary(index_type.clone(), Box::new(t))) + match type_union_resolution_coercion(value_type, other_type) { + // Dict with View type is redundant, use value type instead + // TODO: Add binary view, list view with tests + Some(DataType::Utf8View) => Some(DataType::Utf8View), + Some(new_value_type) => Some(DataType::Dictionary( + index_type.clone(), + Box::new(new_value_type), + )), + None => None, + } } (DataType::Struct(lhs), DataType::Struct(rhs)) => { if lhs.len() != rhs.len() { @@ -589,6 +597,7 @@ fn type_union_resolution_coercion( .or_else(|| temporal_coercion_nonstrict_timezone(lhs_type, rhs_type)) .or_else(|| string_coercion(lhs_type, rhs_type)) .or_else(|| numeric_string_coercion(lhs_type, rhs_type)) + .or_else(|| binary_coercion(lhs_type, rhs_type)) } } } diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 305519a1f4b4..df79b3568ce6 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -696,7 +696,11 @@ impl<'a> TreeNodeContainer<'a, Expr> for Sort { pub struct AggregateFunction { /// Name of the function pub func: Arc, - /// 
List of expressions to feed to the functions as arguments + pub params: AggregateFunctionParams, +} + +#[derive(Clone, PartialEq, Eq, PartialOrd, Hash, Debug)] +pub struct AggregateFunctionParams { pub args: Vec, /// Whether this is a DISTINCT aggregation or not pub distinct: bool, @@ -719,11 +723,13 @@ impl AggregateFunction { ) -> Self { Self { func, - args, - distinct, - filter, - order_by, - null_treatment, + params: AggregateFunctionParams { + args, + distinct, + filter, + order_by, + null_treatment, + }, } } } @@ -813,6 +819,11 @@ impl From> for WindowFunctionDefinition { pub struct WindowFunction { /// Name of the function pub fun: WindowFunctionDefinition, + pub params: WindowFunctionParams, +} + +#[derive(Clone, PartialEq, Eq, PartialOrd, Hash, Debug)] +pub struct WindowFunctionParams { /// List of expressions to feed to the functions as arguments pub args: Vec, /// List of partition by expressions @@ -831,11 +842,13 @@ impl WindowFunction { pub fn new(fun: impl Into, args: Vec) -> Self { Self { fun: fun.into(), - args, - partition_by: Vec::default(), - order_by: Vec::default(), - window_frame: WindowFrame::new(None), - null_treatment: None, + params: WindowFunctionParams { + args, + partition_by: Vec::default(), + order_by: Vec::default(), + window_frame: WindowFrame::new(None), + null_treatment: None, + }, } } } @@ -1864,19 +1877,25 @@ impl NormalizeEq for Expr { ( Expr::AggregateFunction(AggregateFunction { func: self_func, - args: self_args, - distinct: self_distinct, - filter: self_filter, - order_by: self_order_by, - null_treatment: self_null_treatment, + params: + AggregateFunctionParams { + args: self_args, + distinct: self_distinct, + filter: self_filter, + order_by: self_order_by, + null_treatment: self_null_treatment, + }, }), Expr::AggregateFunction(AggregateFunction { func: other_func, - args: other_args, - distinct: other_distinct, - filter: other_filter, - order_by: other_order_by, - null_treatment: other_null_treatment, + params: + AggregateFunctionParams { + args: other_args, + distinct: other_distinct, + filter: other_filter, + order_by: other_order_by, + null_treatment: other_null_treatment, + }, }), ) => { self_func.name() == other_func.name() @@ -1910,21 +1929,30 @@ impl NormalizeEq for Expr { ( Expr::WindowFunction(WindowFunction { fun: self_fun, - args: self_args, - partition_by: self_partition_by, - order_by: self_order_by, - window_frame: self_window_frame, - null_treatment: self_null_treatment, + params: self_params, }), Expr::WindowFunction(WindowFunction { fun: other_fun, - args: other_args, - partition_by: other_partition_by, - order_by: other_order_by, - window_frame: other_window_frame, - null_treatment: other_null_treatment, + params: other_params, }), ) => { + let ( + WindowFunctionParams { + args: self_args, + window_frame: self_window_frame, + partition_by: self_partition_by, + order_by: self_order_by, + null_treatment: self_null_treatment, + }, + WindowFunctionParams { + args: other_args, + window_frame: other_window_frame, + partition_by: other_partition_by, + order_by: other_order_by, + null_treatment: other_null_treatment, + }, + ) = (self_params, other_params); + self_fun.name() == other_fun.name() && self_window_frame == other_window_frame && self_null_treatment == other_null_treatment @@ -2154,24 +2182,27 @@ impl HashNode for Expr { } Expr::AggregateFunction(AggregateFunction { func, - args: _args, - distinct, - filter: _filter, - order_by: _order_by, - null_treatment, + params: + AggregateFunctionParams { + args: _args, + distinct, + 
filter: _, + order_by: _, + null_treatment, + }, }) => { func.hash(state); distinct.hash(state); null_treatment.hash(state); } - Expr::WindowFunction(WindowFunction { - fun, - args: _args, - partition_by: _partition_by, - order_by: _order_by, - window_frame, - null_treatment, - }) => { + Expr::WindowFunction(WindowFunction { fun, params }) => { + let WindowFunctionParams { + args: _args, + partition_by: _, + order_by: _, + window_frame, + null_treatment, + } = params; fun.hash(state); window_frame.hash(state); null_treatment.hash(state); @@ -2264,35 +2295,15 @@ impl Display for SchemaDisplay<'_> { | Expr::Placeholder(_) | Expr::Wildcard { .. } => write!(f, "{}", self.0), - Expr::AggregateFunction(AggregateFunction { - func, - args, - distinct, - filter, - order_by, - null_treatment, - }) => { - write!( - f, - "{}({}{})", - func.name(), - if *distinct { "DISTINCT " } else { "" }, - schema_name_from_exprs_comma_separated_without_space(args)? - )?; - - if let Some(null_treatment) = null_treatment { - write!(f, " {}", null_treatment)?; + Expr::AggregateFunction(AggregateFunction { func, params }) => { + match func.schema_name(params) { + Ok(name) => { + write!(f, "{name}") + } + Err(e) => { + write!(f, "got error from schema_name {}", e) + } } - - if let Some(filter) = filter { - write!(f, " FILTER (WHERE {filter})")?; - }; - - if let Some(order_by) = order_by { - write!(f, " ORDER BY [{}]", schema_name_from_sorts(order_by)?)?; - }; - - Ok(()) } // Expr is not shown since it is aliased Expr::Alias(Alias { @@ -2472,39 +2483,52 @@ impl Display for SchemaDisplay<'_> { Ok(()) } - Expr::WindowFunction(WindowFunction { - fun, - args, - partition_by, - order_by, - window_frame, - null_treatment, - }) => { - write!( - f, - "{}({})", - fun, - schema_name_from_exprs_comma_separated_without_space(args)? - )?; - - if let Some(null_treatment) = null_treatment { - write!(f, " {}", null_treatment)?; + Expr::WindowFunction(WindowFunction { fun, params }) => match fun { + WindowFunctionDefinition::AggregateUDF(fun) => { + match fun.window_function_schema_name(params) { + Ok(name) => { + write!(f, "{name}") + } + Err(e) => { + write!(f, "got error from window_function_schema_name {}", e) + } + } } + _ => { + let WindowFunctionParams { + args, + partition_by, + order_by, + window_frame, + null_treatment, + } = params; - if !partition_by.is_empty() { write!( f, - " PARTITION BY [{}]", - schema_name_from_exprs(partition_by)? + "{}({})", + fun, + schema_name_from_exprs_comma_separated_without_space(args)? )?; - } - if !order_by.is_empty() { - write!(f, " ORDER BY [{}]", schema_name_from_sorts(order_by)?)?; - }; + if let Some(null_treatment) = null_treatment { + write!(f, " {}", null_treatment)?; + } - write!(f, " {window_frame}") - } + if !partition_by.is_empty() { + write!( + f, + " PARTITION BY [{}]", + schema_name_from_exprs(partition_by)? 
+ )?; + } + + if !order_by.is_empty() { + write!(f, " ORDER BY [{}]", schema_name_from_sorts(order_by)?)?; + }; + + write!(f, " {window_frame}") + } + }, } } } @@ -2626,53 +2650,56 @@ impl Display for Expr { // Expr::ScalarFunction(ScalarFunction { func, args }) => { // write!(f, "{}", func.display_name(args).unwrap()) // } - Expr::WindowFunction(WindowFunction { - fun, - args, - partition_by, - order_by, - window_frame, - null_treatment, - }) => { - fmt_function(f, &fun.to_string(), false, args, true)?; - - if let Some(nt) = null_treatment { - write!(f, "{}", nt)?; + Expr::WindowFunction(WindowFunction { fun, params }) => match fun { + WindowFunctionDefinition::AggregateUDF(fun) => { + match fun.window_function_display_name(params) { + Ok(name) => { + write!(f, "{}", name) + } + Err(e) => { + write!(f, "got error from window_function_display_name {}", e) + } + } } + WindowFunctionDefinition::WindowUDF(fun) => { + let WindowFunctionParams { + args, + partition_by, + order_by, + window_frame, + null_treatment, + } = params; + + fmt_function(f, &fun.to_string(), false, args, true)?; + + if let Some(nt) = null_treatment { + write!(f, "{}", nt)?; + } - if !partition_by.is_empty() { - write!(f, " PARTITION BY [{}]", expr_vec_fmt!(partition_by))?; - } - if !order_by.is_empty() { - write!(f, " ORDER BY [{}]", expr_vec_fmt!(order_by))?; - } - write!( - f, - " {} BETWEEN {} AND {}", - window_frame.units, window_frame.start_bound, window_frame.end_bound - )?; - Ok(()) - } - Expr::AggregateFunction(AggregateFunction { - func, - distinct, - ref args, - filter, - order_by, - null_treatment, - .. - }) => { - fmt_function(f, func.name(), *distinct, args, true)?; - if let Some(nt) = null_treatment { - write!(f, " {}", nt)?; - } - if let Some(fe) = filter { - write!(f, " FILTER (WHERE {fe})")?; + if !partition_by.is_empty() { + write!(f, " PARTITION BY [{}]", expr_vec_fmt!(partition_by))?; + } + if !order_by.is_empty() { + write!(f, " ORDER BY [{}]", expr_vec_fmt!(order_by))?; + } + write!( + f, + " {} BETWEEN {} AND {}", + window_frame.units, + window_frame.start_bound, + window_frame.end_bound + ) } - if let Some(ob) = order_by { - write!(f, " ORDER BY [{}]", expr_vec_fmt!(ob))?; + }, + Expr::AggregateFunction(AggregateFunction { func, params }) => { + match func.display_name(params) { + Ok(name) => { + write!(f, "{}", name) + } + Err(e) => { + write!(f, "got error from display_name {}", e) + } } - Ok(()) } Expr::Between(Between { expr, diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index a2de5e7b259f..f47de4a8178f 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -19,7 +19,7 @@ use crate::expr::{ AggregateFunction, BinaryExpr, Cast, Exists, GroupingSet, InList, InSubquery, - Placeholder, TryCast, Unnest, WildcardOptions, WindowFunction, + Placeholder, TryCast, Unnest, WildcardOptions, WindowFunction, WindowFunctionParams, }; use crate::function::{ AccumulatorArgs, AccumulatorFactoryFunction, PartitionEvaluatorFactory, @@ -27,7 +27,7 @@ use crate::function::{ }; use crate::{ conditional_expressions::CaseBuilder, expr::Sort, logical_plan::Subquery, - AggregateUDF, Expr, LogicalPlan, Operator, PartitionEvaluator, + AggregateUDF, Expr, LogicalPlan, Operator, PartitionEvaluator, ScalarFunctionArgs, ScalarFunctionImplementation, ScalarUDF, Signature, Volatility, }; use crate::{ @@ -477,12 +477,8 @@ impl ScalarUDFImpl for SimpleScalarUDF { Ok(self.return_type.clone()) } - fn invoke_batch( - &self, - args: &[ColumnarValue], - _number_rows: usize, - 
) -> Result { - (self.fun)(args) + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + (self.fun)(&args.args) } } @@ -830,20 +826,28 @@ impl ExprFuncBuilder { let fun_expr = match fun { ExprFuncKind::Aggregate(mut udaf) => { - udaf.order_by = order_by; - udaf.filter = filter.map(Box::new); - udaf.distinct = distinct; - udaf.null_treatment = null_treatment; + udaf.params.order_by = order_by; + udaf.params.filter = filter.map(Box::new); + udaf.params.distinct = distinct; + udaf.params.null_treatment = null_treatment; Expr::AggregateFunction(udaf) } - ExprFuncKind::Window(mut udwf) => { + ExprFuncKind::Window(WindowFunction { + fun, + params: WindowFunctionParams { args, .. }, + }) => { let has_order_by = order_by.as_ref().map(|o| !o.is_empty()); - udwf.order_by = order_by.unwrap_or_default(); - udwf.partition_by = partition_by.unwrap_or_default(); - udwf.window_frame = - window_frame.unwrap_or(WindowFrame::new(has_order_by)); - udwf.null_treatment = null_treatment; - Expr::WindowFunction(udwf) + Expr::WindowFunction(WindowFunction { + fun, + params: WindowFunctionParams { + args, + partition_by: partition_by.unwrap_or_default(), + order_by: order_by.unwrap_or_default(), + window_frame: window_frame + .unwrap_or(WindowFrame::new(has_order_by)), + null_treatment, + }, + }) } }; diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index 49791427131f..ce1dd2f34c05 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -17,8 +17,9 @@ use super::{Between, Expr, Like}; use crate::expr::{ - AggregateFunction, Alias, BinaryExpr, Cast, InList, InSubquery, Placeholder, - ScalarFunction, TryCast, Unnest, WindowFunction, + AggregateFunction, AggregateFunctionParams, Alias, BinaryExpr, Cast, InList, + InSubquery, Placeholder, ScalarFunction, TryCast, Unnest, WindowFunction, + WindowFunctionParams, }; use crate::type_coercion::functions::{ data_types_with_aggregate_udf, data_types_with_scalar_udf, data_types_with_window_udf, @@ -153,7 +154,10 @@ impl ExprSchemable for Expr { Expr::WindowFunction(window_function) => self .data_type_and_nullable_with_window_function(schema, window_function) .map(|(return_type, _)| return_type), - Expr::AggregateFunction(AggregateFunction { func, args, .. }) => { + Expr::AggregateFunction(AggregateFunction { + func, + params: AggregateFunctionParams { args, .. }, + }) => { let data_types = args .iter() .map(|e| e.get_type(schema)) @@ -507,7 +511,11 @@ impl Expr { schema: &dyn ExprSchema, window_function: &WindowFunction, ) -> Result<(DataType, bool)> { - let WindowFunction { fun, args, .. } = window_function; + let WindowFunction { + fun, + params: WindowFunctionParams { args, .. }, + .. 
+ } = window_function; let data_types = args .iter() diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs index aaa65c676a42..2f04f234eb1d 100644 --- a/datafusion/expr/src/lib.rs +++ b/datafusion/expr/src/lib.rs @@ -71,8 +71,8 @@ pub use datafusion_expr_common::columnar_value::ColumnarValue; pub use datafusion_expr_common::groups_accumulator::{EmitTo, GroupsAccumulator}; pub use datafusion_expr_common::operator::Operator; pub use datafusion_expr_common::signature::{ - ArrayFunctionSignature, Signature, TypeSignature, TypeSignatureClass, Volatility, - TIMEZONE_WILDCARD, + ArrayFunctionArgument, ArrayFunctionSignature, Signature, TypeSignature, + TypeSignatureClass, Volatility, TIMEZONE_WILDCARD, }; pub use datafusion_expr_common::type_coercion::binary; pub use expr::{ diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index 4fdfb84aea42..da30f2d7a712 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -380,18 +380,49 @@ impl LogicalPlanBuilder { }))) } - /// Create a [DmlStatement] for inserting the contents of this builder into the named table + /// Create a [`DmlStatement`] for inserting the contents of this builder into the named table. + /// + /// Note, use a [`DefaultTableSource`] to insert into a [`TableProvider`] + /// + /// [`DefaultTableSource`]: https://docs.rs/datafusion/latest/datafusion/datasource/default_table_source/struct.DefaultTableSource.html + /// [`TableProvider`]: https://docs.rs/datafusion/latest/datafusion/catalog/trait.TableProvider.html + /// + /// # Example: + /// ``` + /// # use datafusion_expr::{lit, LogicalPlanBuilder, + /// # logical_plan::builder::LogicalTableSource, + /// # }; + /// # use std::sync::Arc; + /// # use arrow::datatypes::{Schema, DataType, Field}; + /// # use datafusion_expr::dml::InsertOp; + /// # + /// # fn test() -> datafusion_common::Result<()> { + /// # let employee_schema = Arc::new(Schema::new(vec![ + /// # Field::new("id", DataType::Int32, false), + /// # ])) as _; + /// # let table_source = Arc::new(LogicalTableSource::new(employee_schema)); + /// // VALUES (1), (2) + /// let input = LogicalPlanBuilder::values(vec![vec![lit(1)], vec![lit(2)]])? 
+ /// .build()?; + /// // INSERT INTO MyTable VALUES (1), (2) + /// let insert_plan = LogicalPlanBuilder::insert_into( + /// input, + /// "MyTable", + /// table_source, + /// InsertOp::Append, + /// )?; + /// # Ok(()) + /// # } + /// ``` pub fn insert_into( input: LogicalPlan, table_name: impl Into, - table_schema: &Schema, + target: Arc, insert_op: InsertOp, ) -> Result { - let table_schema = table_schema.clone().to_dfschema_ref()?; - Ok(Self::new(LogicalPlan::Dml(DmlStatement::new( table_name.into(), - table_schema, + target, WriteOp::Insert(insert_op), Arc::new(input), )))) @@ -722,6 +753,21 @@ impl LogicalPlanBuilder { union(Arc::unwrap_or_clone(self.plan), plan).map(Self::new) } + /// Apply a union by name, preserving duplicate rows + pub fn union_by_name(self, plan: LogicalPlan) -> Result { + union_by_name(Arc::unwrap_or_clone(self.plan), plan).map(Self::new) + } + + /// Apply a union by name, removing duplicate rows + pub fn union_by_name_distinct(self, plan: LogicalPlan) -> Result { + let left_plan: LogicalPlan = Arc::unwrap_or_clone(self.plan); + let right_plan: LogicalPlan = plan; + + Ok(Self::new(LogicalPlan::Distinct(Distinct::All(Arc::new( + union_by_name(left_plan, right_plan)?, + ))))) + } + /// Apply a union, removing duplicate rows pub fn union_distinct(self, plan: LogicalPlan) -> Result { let left_plan: LogicalPlan = Arc::unwrap_or_clone(self.plan); @@ -834,10 +880,16 @@ impl LogicalPlanBuilder { plan: &LogicalPlan, column: impl Into, ) -> Result { + let column = column.into(); + if column.relation.is_some() { + // column is already normalized + return Ok(column); + } + let schema = plan.schema(); let fallback_schemas = plan.fallback_normalize_schemas(); let using_columns = plan.using_columns()?; - column.into().normalize_with_schemas_and_ambiguity_check( + column.normalize_with_schemas_and_ambiguity_check( &[&[schema], &fallback_schemas], &using_columns, ) @@ -1540,6 +1592,18 @@ pub fn union(left_plan: LogicalPlan, right_plan: LogicalPlan) -> Result Result { + Ok(LogicalPlan::Union(Union::try_new_by_name(vec![ + Arc::new(left_plan), + Arc::new(right_plan), + ])?)) +} + /// Create Projection /// # Errors /// This function errors under any of the following conditions: diff --git a/datafusion/expr/src/logical_plan/dml.rs b/datafusion/expr/src/logical_plan/dml.rs index 669bc8e8a7d3..d4d50ac4eae4 100644 --- a/datafusion/expr/src/logical_plan/dml.rs +++ b/datafusion/expr/src/logical_plan/dml.rs @@ -25,7 +25,7 @@ use arrow::datatypes::{DataType, Field, Schema}; use datafusion_common::file_options::file_type::FileType; use datafusion_common::{DFSchemaRef, TableReference}; -use crate::LogicalPlan; +use crate::{LogicalPlan, TableSource}; /// Operator that copies the contents of a database to file(s) #[derive(Clone)] @@ -91,12 +91,12 @@ impl Hash for CopyTo { /// The operator that modifies the content of a database (adapted from /// substrait WriteRel) -#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[derive(Clone)] pub struct DmlStatement { /// The table name pub table_name: TableReference, - /// The schema of the table (must align with Rel input) - pub table_schema: DFSchemaRef, + /// this is target table to insert into + pub target: Arc, /// The type of operation to perform pub op: WriteOp, /// The relation that determines the tuples to add/remove/modify the schema must match with table_schema @@ -104,18 +104,51 @@ pub struct DmlStatement { /// The schema of the output relation pub output_schema: DFSchemaRef, } +impl Eq for DmlStatement {} +impl Hash for DmlStatement { + fn 
hash(&self, state: &mut H) { + self.table_name.hash(state); + self.target.schema().hash(state); + self.op.hash(state); + self.input.hash(state); + self.output_schema.hash(state); + } +} + +impl PartialEq for DmlStatement { + fn eq(&self, other: &Self) -> bool { + self.table_name == other.table_name + && self.target.schema() == other.target.schema() + && self.op == other.op + && self.input == other.input + && self.output_schema == other.output_schema + } +} + +impl Debug for DmlStatement { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + f.debug_struct("DmlStatement") + .field("table_name", &self.table_name) + .field("target", &"...") + .field("target_schema", &self.target.schema()) + .field("op", &self.op) + .field("input", &self.input) + .field("output_schema", &self.output_schema) + .finish() + } +} impl DmlStatement { /// Creates a new DML statement with the output schema set to a single `count` column. pub fn new( table_name: TableReference, - table_schema: DFSchemaRef, + target: Arc, op: WriteOp, input: Arc, ) -> Self { Self { table_name, - table_schema, + target, op, input, diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index daf1a1375eac..870b0751c923 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -18,7 +18,7 @@ //! Logical plan types use std::cmp::Ordering; -use std::collections::{HashMap, HashSet}; +use std::collections::{BTreeMap, HashMap, HashSet}; use std::fmt::{self, Debug, Display, Formatter}; use std::hash::{Hash, Hasher}; use std::sync::{Arc, LazyLock}; @@ -30,7 +30,7 @@ use super::invariants::{ }; use super::DdlStatement; use crate::builder::{change_redundant_column, unnest_with_options}; -use crate::expr::{Placeholder, Sort as SortExpr, WindowFunction}; +use crate::expr::{Placeholder, Sort as SortExpr, WindowFunction, WindowFunctionParams}; use crate::expr_rewriter::{ create_col_from_scalar_expr, normalize_cols, normalize_sorts, NamePreserver, }; @@ -705,6 +705,13 @@ impl LogicalPlan { // If inputs are not pruned do not change schema Ok(LogicalPlan::Union(Union { inputs, schema })) } else { + // A note on `Union`s constructed via `try_new_by_name`: + // + // At this point, the schema for each input should have + // the same width. Thus, we do not need to save whether a + // `Union` was created `BY NAME`, and can safely rely on the + // `try_new` initializer to derive the new schema based on + // column positions. Ok(LogicalPlan::Union(Union::try_new(inputs)?)) } } @@ -784,7 +791,7 @@ impl LogicalPlan { } LogicalPlan::Dml(DmlStatement { table_name, - table_schema, + target, op, .. }) => { @@ -792,7 +799,7 @@ impl LogicalPlan { let input = self.only_input(inputs)?; Ok(LogicalPlan::Dml(DmlStatement::new( table_name.clone(), - Arc::clone(table_schema), + Arc::clone(target), op.clone(), Arc::new(input), ))) @@ -2422,8 +2429,7 @@ impl Window { .filter_map(|(idx, expr)| { if let Expr::WindowFunction(WindowFunction { fun: WindowFunctionDefinition::WindowUDF(udwf), - partition_by, - .. + params: WindowFunctionParams { partition_by, .. }, }) = expr { // When there is no PARTITION BY, row number will be unique @@ -2648,7 +2654,7 @@ pub struct Union { impl Union { /// Constructs new Union instance deriving schema from inputs. 
fn try_new(inputs: Vec>) -> Result { - let schema = Self::derive_schema_from_inputs(&inputs, false)?; + let schema = Self::derive_schema_from_inputs(&inputs, false, false)?; Ok(Union { inputs, schema }) } @@ -2657,21 +2663,143 @@ impl Union { /// take type from the first input. // TODO (https://github.com/apache/datafusion/issues/14380): Avoid creating uncoerced union at all. pub fn try_new_with_loose_types(inputs: Vec>) -> Result { - let schema = Self::derive_schema_from_inputs(&inputs, true)?; + let schema = Self::derive_schema_from_inputs(&inputs, true, false)?; Ok(Union { inputs, schema }) } + /// Constructs a new Union instance that combines rows from different tables by name, + /// instead of by position. This means that the specified inputs need not have schemas + /// that are all the same width. + pub fn try_new_by_name(inputs: Vec>) -> Result { + let schema = Self::derive_schema_from_inputs(&inputs, true, true)?; + let inputs = Self::rewrite_inputs_from_schema(&schema, inputs)?; + + Ok(Union { inputs, schema }) + } + + /// When constructing a `UNION BY NAME`, we may need to wrap inputs + /// in an additional `Projection` to account for absence of columns + /// in input schemas. + fn rewrite_inputs_from_schema( + schema: &DFSchema, + inputs: Vec>, + ) -> Result>> { + let schema_width = schema.iter().count(); + let mut wrapped_inputs = Vec::with_capacity(inputs.len()); + for input in inputs { + // If the input plan's schema contains the same number of fields + // as the derived schema, then it does not to be wrapped in an + // additional `Projection`. + if input.schema().iter().count() == schema_width { + wrapped_inputs.push(input); + continue; + } + + // Any columns that exist within the derived schema but do not exist + // within an input's schema should be replaced with `NULL` aliased + // to the appropriate column in the derived schema. + let mut expr = Vec::with_capacity(schema_width); + for column in schema.columns() { + if input + .schema() + .has_column_with_unqualified_name(column.name()) + { + expr.push(Expr::Column(column)); + } else { + expr.push(Expr::Literal(ScalarValue::Null).alias(column.name())); + } + } + wrapped_inputs.push(Arc::new(LogicalPlan::Projection(Projection::try_new( + expr, input, + )?))); + } + + Ok(wrapped_inputs) + } + /// Constructs new Union instance deriving schema from inputs. /// - /// `loose_types` if true, inputs do not have to have matching types and produced schema will - /// take type from the first input. TODO () this is not necessarily reasonable behavior. + /// If `loose_types` is true, inputs do not need to have matching types and + /// the produced schema will use the type from the first input. + /// TODO (): This is not necessarily reasonable behavior. + /// + /// If `by_name` is `true`, input schemas need not be the same width. That is, + /// the constructed schema follows `UNION BY NAME` semantics. 
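A minimal usage sketch (not part of the patch) of the new `union_by_name` builder method, using two single-column `VALUES` plans so the example stays self-contained; the variable names are illustrative:

```rust
use datafusion_common::Result;
use datafusion_expr::{lit, LogicalPlanBuilder};

fn union_by_name_example() -> Result<()> {
    // Both inputs expose a single column named "column1".
    let left = LogicalPlanBuilder::values(vec![vec![lit(1)]])?.build()?;
    let right = LogicalPlanBuilder::values(vec![vec![lit(2)]])?.build()?;

    // Rows are matched up by column name rather than by position.
    let plan = LogicalPlanBuilder::from(left)
        .union_by_name(right)?
        .build()?;
    println!("{}", plan.display_indent());
    Ok(())
}
```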
fn derive_schema_from_inputs( inputs: &[Arc], loose_types: bool, + by_name: bool, ) -> Result { if inputs.len() < 2 { return plan_err!("UNION requires at least two inputs"); } + + if by_name { + Self::derive_schema_from_inputs_by_name(inputs, loose_types) + } else { + Self::derive_schema_from_inputs_by_position(inputs, loose_types) + } + } + + fn derive_schema_from_inputs_by_name( + inputs: &[Arc], + loose_types: bool, + ) -> Result { + type FieldData<'a> = (&'a DataType, bool, Vec<&'a HashMap>); + // Prefer `BTreeMap` as it produces items in order by key when iterated over + let mut cols: BTreeMap<&str, FieldData> = BTreeMap::new(); + for input in inputs.iter() { + for field in input.schema().fields() { + match cols.entry(field.name()) { + std::collections::btree_map::Entry::Occupied(mut occupied) => { + let (data_type, is_nullable, metadata) = occupied.get_mut(); + if !loose_types && *data_type != field.data_type() { + return plan_err!( + "Found different types for field {}", + field.name() + ); + } + + metadata.push(field.metadata()); + // If the field is nullable in any one of the inputs, + // then the field in the final schema is also nullable. + *is_nullable |= field.is_nullable(); + } + std::collections::btree_map::Entry::Vacant(vacant) => { + vacant.insert(( + field.data_type(), + field.is_nullable(), + vec![field.metadata()], + )); + } + } + } + } + + let union_fields = cols + .into_iter() + .map(|(name, (data_type, is_nullable, unmerged_metadata))| { + let mut field = Field::new(name, data_type.clone(), is_nullable); + field.set_metadata(intersect_maps(unmerged_metadata)); + + (None, Arc::new(field)) + }) + .collect::, _)>>(); + + let union_schema_metadata = + intersect_maps(inputs.iter().map(|input| input.schema().metadata())); + + // Functional Dependencies are not preserved after UNION operation + let schema = DFSchema::new_with_metadata(union_fields, union_schema_metadata)?; + let schema = Arc::new(schema); + + Ok(schema) + } + + fn derive_schema_from_inputs_by_position( + inputs: &[Arc], + loose_types: bool, + ) -> Result { let first_schema = inputs[0].schema(); let fields_count = first_schema.fields().len(); for input in inputs.iter().skip(1) { @@ -2727,7 +2855,7 @@ impl Union { let union_schema_metadata = intersect_maps(inputs.iter().map(|input| input.schema().metadata())); - // Functional Dependencies doesn't preserve after UNION operation + // Functional Dependencies are not preserved after UNION operation let schema = DFSchema::new_with_metadata(union_fields, union_schema_metadata)?; let schema = Arc::new(schema); diff --git a/datafusion/expr/src/logical_plan/tree_node.rs b/datafusion/expr/src/logical_plan/tree_node.rs index 9a6103afd4b4..dfc18c74c70a 100644 --- a/datafusion/expr/src/logical_plan/tree_node.rs +++ b/datafusion/expr/src/logical_plan/tree_node.rs @@ -228,14 +228,14 @@ impl TreeNode for LogicalPlan { }), LogicalPlan::Dml(DmlStatement { table_name, - table_schema, + target, op, input, output_schema, }) => input.map_elements(f)?.update_data(|input| { LogicalPlan::Dml(DmlStatement { table_name, - table_schema, + target, op, input, output_schema, diff --git a/datafusion/expr/src/planner.rs b/datafusion/expr/src/planner.rs index 42047e8e6caa..04cc26c910cb 100644 --- a/datafusion/expr/src/planner.rs +++ b/datafusion/expr/src/planner.rs @@ -29,13 +29,18 @@ use sqlparser::ast; use crate::{AggregateUDF, Expr, GetFieldAccess, ScalarUDF, TableSource, WindowUDF}; -/// Provides the `SQL` query planner meta-data about tables and -/// functions referenced in SQL 
statements, without a direct dependency on other -/// DataFusion structures +/// Provides the `SQL` query planner meta-data about tables and +/// functions referenced in SQL statements, without a direct dependency on the +/// `datafusion` Catalog structures such as [`TableProvider`] +/// +/// [`TableProvider`]: https://docs.rs/datafusion/latest/datafusion/catalog/trait.TableProvider.html pub trait ContextProvider { - /// Getter for a datasource + /// Returns a table by reference, if it exists fn get_table_source(&self, name: TableReference) -> Result>; + /// Return the type of a file based on its extension (e.g. `.parquet`) + /// + /// This is used to plan `COPY` statements fn get_file_type(&self, _ext: &str) -> Result> { not_impl_err!("Registered file types are not supported") } @@ -49,11 +54,20 @@ pub trait ContextProvider { not_impl_err!("Table Functions are not supported") } - /// This provides a worktable (an intermediate table that is used to store the results of a CTE during execution) - /// We don't directly implement this in the logical plan's ['SqlToRel`] - /// because the sql code needs access to a table that contains execution-related types that can't be a direct dependency - /// of the sql crate (namely, the `CteWorktable`). + /// Provides an intermediate table that is used to store the results of a CTE during execution + /// + /// CTE stands for "Common Table Expression" + /// + /// # Notes + /// We don't directly implement this in [`SqlToRel`] as implementing this function + /// often requires access to a table that contains + /// execution-related types that can't be a direct dependency + /// of the sql crate (for example [`CteWorkTable`]). + /// /// The [`ContextProvider`] provides a way to "hide" this dependency. + /// + /// [`SqlToRel`]: https://docs.rs/datafusion/latest/datafusion/sql/planner/struct.SqlToRel.html + /// [`CteWorkTable`]: https://docs.rs/datafusion/latest/datafusion/datasource/cte_worktable/struct.CteWorkTable.html fn create_cte_work_table( &self, _name: &str, @@ -62,39 +76,44 @@ pub trait ContextProvider { not_impl_err!("Recursive CTE is not implemented") } - /// Getter for expr planners + /// Return [`ExprPlanner`] extensions for planning expressions fn get_expr_planners(&self) -> &[Arc] { &[] } - /// Getter for the data type planner + /// Return [`TypePlanner`] extensions for planning data types fn get_type_planner(&self) -> Option> { None } - /// Getter for a UDF description + /// Return the scalar function with a given name, if any fn get_function_meta(&self, name: &str) -> Option>; - /// Getter for a UDAF description + + /// Return the aggregate function with a given name, if any fn get_aggregate_meta(&self, name: &str) -> Option>; - /// Getter for a UDWF + + /// Return the window function with a given name, if any fn get_window_meta(&self, name: &str) -> Option>; - /// Getter for system/user-defined variable type + + /// Return the system/user-defined variable type, if any + /// + /// A user defined variable is typically accessed via `@var_name` fn get_variable_type(&self, variable_names: &[String]) -> Option; - /// Get configuration options + /// Return overall configuration options fn options(&self) -> &ConfigOptions; - /// Get all user defined scalar function names + /// Return all scalar function names fn udf_names(&self) -> Vec; - /// Get all user defined aggregate function names + /// Return all aggregate function names fn udaf_names(&self) -> Vec; - /// Get all user defined window function names + /// Return all window function names fn 
udwf_names(&self) -> Vec; } -/// This trait allows users to customize the behavior of the SQL planner +/// Customize planning of SQL AST expressions to [`Expr`]s pub trait ExprPlanner: Debug + Send + Sync { /// Plan the binary operation between two expressions, returns original /// BinaryExpr if not possible @@ -106,9 +125,9 @@ pub trait ExprPlanner: Debug + Send + Sync { Ok(PlannerResult::Original(expr)) } - /// Plan the field access expression + /// Plan the field access expression, such as `foo.bar` /// - /// returns original FieldAccessExpr if not possible + /// returns original [`RawFieldAccessExpr`] if not possible fn plan_field_access( &self, expr: RawFieldAccessExpr, @@ -117,7 +136,7 @@ pub trait ExprPlanner: Debug + Send + Sync { Ok(PlannerResult::Original(expr)) } - /// Plan the array literal, returns OriginalArray if not possible + /// Plan an array literal, such as `[1, 2, 3]` /// /// Returns origin expression arguments if not possible fn plan_array_literal( @@ -128,13 +147,14 @@ pub trait ExprPlanner: Debug + Send + Sync { Ok(PlannerResult::Original(exprs)) } - // Plan the POSITION expression, e.g., POSITION( in ) - // returns origin expression arguments if not possible + /// Plan a `POSITION` expression, such as `POSITION( in )` + /// + /// returns origin expression arguments if not possible fn plan_position(&self, args: Vec) -> Result>> { Ok(PlannerResult::Original(args)) } - /// Plan the dictionary literal `{ key: value, ...}` + /// Plan a dictionary literal, such as `{ key: value, ...}` /// /// Returns origin expression arguments if not possible fn plan_dictionary_literal( @@ -145,27 +165,26 @@ pub trait ExprPlanner: Debug + Send + Sync { Ok(PlannerResult::Original(expr)) } - /// Plan an extract expression, e.g., `EXTRACT(month FROM foo)` + /// Plan an extract expression, such as`EXTRACT(month FROM foo)` /// /// Returns origin expression arguments if not possible fn plan_extract(&self, args: Vec) -> Result>> { Ok(PlannerResult::Original(args)) } - /// Plan an substring expression, e.g., `SUBSTRING( [FROM ] [FOR ])` + /// Plan an substring expression, such as `SUBSTRING( [FROM ] [FOR ])` /// /// Returns origin expression arguments if not possible fn plan_substring(&self, args: Vec) -> Result>> { Ok(PlannerResult::Original(args)) } - /// Plans a struct `struct(expression1[, ..., expression_n])` - /// literal based on the given input expressions. - /// This function takes a vector of expressions and a boolean flag indicating whether - /// the struct uses the optional name + /// Plans a struct literal, such as `{'field1' : expr1, 'field2' : expr2, ...}` + /// + /// This function takes a vector of expressions and a boolean flag + /// indicating whether the struct uses the optional name /// - /// Returns a `PlannerResult` containing either the planned struct expressions or the original - /// input expressions if planning is not possible. + /// Returns the original input expressions if planning is not possible. 
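As a sketch of the extension point documented above (not part of the patch), the smallest possible `ExprPlanner` keeps every default hook, so expressions fall through to DataFusion's built-in planning; a real planner would override one or more of the `plan_*` methods:

```rust
use datafusion_expr::planner::ExprPlanner;

/// Keeps every default hook; nothing is rewritten during planning.
#[derive(Debug)]
struct NoopExprPlanner;

impl ExprPlanner for NoopExprPlanner {}
```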
fn plan_struct_literal( &self, args: Vec, @@ -174,26 +193,26 @@ pub trait ExprPlanner: Debug + Send + Sync { Ok(PlannerResult::Original(args)) } - /// Plans an overlay expression eg `overlay(str PLACING substr FROM pos [FOR count])` + /// Plans an overlay expression, such as `overlay(str PLACING substr FROM pos [FOR count])` /// /// Returns origin expression arguments if not possible fn plan_overlay(&self, args: Vec) -> Result>> { Ok(PlannerResult::Original(args)) } - /// Plan a make_map expression, e.g., `make_map(key1, value1, key2, value2, ...)` + /// Plans a `make_map` expression, such as `make_map(key1, value1, key2, value2, ...)` /// /// Returns origin expression arguments if not possible fn plan_make_map(&self, args: Vec) -> Result>> { Ok(PlannerResult::Original(args)) } - /// Plans compound identifier eg `db.schema.table` for non-empty nested names + /// Plans compound identifier such as `db.schema.table` for non-empty nested names /// - /// Note: + /// # Note: /// Currently compound identifier for outer query schema is not supported. /// - /// Returns planned expression + /// Returns original expression if not possible fn plan_compound_identifier( &self, _field: &Field, @@ -205,7 +224,7 @@ pub trait ExprPlanner: Debug + Send + Sync { ) } - /// Plans `ANY` expression, e.g., `expr = ANY(array_expr)` + /// Plans `ANY` expression, such as `expr = ANY(array_expr)` /// /// Returns origin binary expression if not possible fn plan_any(&self, expr: RawBinaryExpr) -> Result> { @@ -256,9 +275,9 @@ pub enum PlannerResult { Original(T), } -/// This trait allows users to customize the behavior of the data type planning +/// Customize planning SQL types to DataFusion (Arrow) types. pub trait TypePlanner: Debug + Send + Sync { - /// Plan SQL type to DataFusion data type + /// Plan SQL [`ast::DataType`] to DataFusion [`DataType`] /// /// Returns None if not possible fn plan_type(&self, _sql_type: &ast::DataType) -> Result> { diff --git a/datafusion/expr/src/table_source.rs b/datafusion/expr/src/table_source.rs index d62484153f53..d6155cfb5dc0 100644 --- a/datafusion/expr/src/table_source.rs +++ b/datafusion/expr/src/table_source.rs @@ -71,24 +71,33 @@ impl std::fmt::Display for TableType { } } -/// Access schema information and filter push-down capabilities. +/// Planning time information about a table. /// -/// The TableSource trait is used during logical query planning and -/// optimizations and provides a subset of the functionality of the -/// `TableProvider` trait in the (core) `datafusion` crate. The `TableProvider` -/// trait provides additional capabilities needed for physical query execution -/// (such as the ability to perform a scan). +/// This trait is used during logical query planning and optimizations, and +/// provides a subset of the [`TableProvider`] trait, such as schema information +/// and filter push-down capabilities. The [`TableProvider`] trait provides +/// additional information needed for physical query execution, such as the +/// ability to perform a scan or insert data. +/// +/// # See Also: +/// +/// [`DefaultTableSource`] to go from [`TableProvider`], to `TableSource` +/// +/// # Rationale /// /// The reason for having two separate traits is to avoid having the logical /// plan code be dependent on the DataFusion execution engine. Some projects use /// DataFusion's logical plans and have their own execution engine. 
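For illustration (not part of the patch), a planning-only table can be exposed through the existing `LogicalTableSource`, which implements `TableSource` from just a schema; the table name here is made up:

```rust
use std::sync::Arc;
use arrow::datatypes::{DataType, Field, Schema};
use datafusion_common::Result;
use datafusion_expr::logical_plan::builder::LogicalTableSource;
use datafusion_expr::LogicalPlanBuilder;

fn scan_example() -> Result<()> {
    // A schema is all the planner needs; no data or execution engine is involved.
    let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)]));
    let source = Arc::new(LogicalTableSource::new(schema));

    let plan = LogicalPlanBuilder::scan("people", source, None)?.build()?;
    println!("{}", plan.display_indent());
    Ok(())
}
```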
+/// +/// [`TableProvider`]: https://docs.rs/datafusion/latest/datafusion/datasource/provider/trait.TableProvider.html +/// [`DefaultTableSource`]: https://docs.rs/datafusion/latest/datafusion/datasource/default_table_source/struct.DefaultTableSource.html pub trait TableSource: Sync + Send { fn as_any(&self) -> &dyn Any; /// Get a reference to the schema for this table fn schema(&self) -> SchemaRef; - /// Get primary key indices, if one exists. + /// Get primary key indices, if any fn constraints(&self) -> Option<&Constraints> { None } @@ -110,6 +119,8 @@ pub trait TableSource: Sync + Send { } /// Get the Logical plan of this table provider, if available. + /// + /// For example, a view may have a logical plan, but a CSV file does not. fn get_logical_plan(&self) -> Option> { None } diff --git a/datafusion/expr/src/test/function_stub.rs b/datafusion/expr/src/test/function_stub.rs index 71ab1ad6ef9b..a753f4c376c6 100644 --- a/datafusion/expr/src/test/function_stub.rs +++ b/datafusion/expr/src/test/function_stub.rs @@ -25,7 +25,7 @@ use arrow::datatypes::{ DataType, Field, DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION, }; -use datafusion_common::{exec_err, not_impl_err, Result}; +use datafusion_common::{exec_err, not_impl_err, utils::take_function_args, Result}; use crate::type_coercion::aggregates::{avg_return_type, coerce_avg_type, NUMERICS}; use crate::Volatility::Immutable; @@ -125,9 +125,7 @@ impl AggregateUDFImpl for Sum { } fn coerce_types(&self, arg_types: &[DataType]) -> Result> { - if arg_types.len() != 1 { - return exec_err!("SUM expects exactly one argument"); - } + let [array] = take_function_args(self.name(), arg_types)?; // Refer to https://www.postgresql.org/docs/8.2/functions-aggregate.html doc // smallint, int, bigint, real, double precision, decimal, or interval. @@ -147,7 +145,7 @@ impl AggregateUDFImpl for Sum { } } - Ok(vec![coerced_type(&arg_types[0])?]) + Ok(vec![coerced_type(array)?]) } fn return_type(&self, arg_types: &[DataType]) -> Result { diff --git a/datafusion/expr/src/tree_node.rs b/datafusion/expr/src/tree_node.rs index eacace5ed046..50af62060346 100644 --- a/datafusion/expr/src/tree_node.rs +++ b/datafusion/expr/src/tree_node.rs @@ -18,8 +18,9 @@ //! Tree node implementation for Logical Expressions use crate::expr::{ - AggregateFunction, Alias, Between, BinaryExpr, Case, Cast, GroupingSet, InList, - InSubquery, Like, Placeholder, ScalarFunction, TryCast, Unnest, WindowFunction, + AggregateFunction, AggregateFunctionParams, Alias, Between, BinaryExpr, Case, Cast, + GroupingSet, InList, InSubquery, Like, Placeholder, ScalarFunction, TryCast, Unnest, + WindowFunction, WindowFunctionParams, }; use crate::{Expr, ExprFunctionExt}; @@ -87,14 +88,14 @@ impl TreeNode for Expr { }) => (expr, low, high).apply_ref_elements(f), Expr::Case(Case { expr, when_then_expr, else_expr }) => (expr, when_then_expr, else_expr).apply_ref_elements(f), - Expr::AggregateFunction(AggregateFunction { args, filter, order_by, .. }) => + Expr::AggregateFunction(AggregateFunction { params: AggregateFunctionParams { args, filter, order_by, ..}, .. }) => (args, filter, order_by).apply_ref_elements(f), Expr::WindowFunction(WindowFunction { - args, - partition_by, - order_by, - .. - }) => { + params : WindowFunctionParams { + args, + partition_by, + order_by, + ..}, ..}) => { (args, partition_by, order_by).apply_ref_elements(f) } Expr::InList(InList { expr, list, .. }) => { @@ -223,12 +224,15 @@ impl TreeNode for Expr { })? 
} Expr::WindowFunction(WindowFunction { - args, fun, - partition_by, - order_by, - window_frame, - null_treatment, + params: + WindowFunctionParams { + args, + partition_by, + order_by, + window_frame, + null_treatment, + }, }) => (args, partition_by, order_by).map_elements(f)?.update_data( |(new_args, new_partition_by, new_order_by)| { Expr::WindowFunction(WindowFunction::new(fun, new_args)) @@ -241,12 +245,15 @@ impl TreeNode for Expr { }, ), Expr::AggregateFunction(AggregateFunction { - args, func, - distinct, - filter, - order_by, - null_treatment, + params: + AggregateFunctionParams { + args, + distinct, + filter, + order_by, + null_treatment, + }, }) => (args, filter, order_by).map_elements(f)?.map_data( |(new_args, new_filter, new_order_by)| { Ok(Expr::AggregateFunction(AggregateFunction::new_udf( diff --git a/datafusion/expr/src/type_coercion/functions.rs b/datafusion/expr/src/type_coercion/functions.rs index 7ac836ef3aeb..b471feca043f 100644 --- a/datafusion/expr/src/type_coercion/functions.rs +++ b/datafusion/expr/src/type_coercion/functions.rs @@ -21,18 +21,15 @@ use arrow::{ compute::can_cast_types, datatypes::{DataType, TimeUnit}, }; -use datafusion_common::utils::coerced_fixed_size_list_to_list; +use datafusion_common::types::LogicalType; +use datafusion_common::utils::{coerced_fixed_size_list_to_list, ListCoercion}; use datafusion_common::{ - exec_err, internal_datafusion_err, internal_err, not_impl_err, plan_err, - types::{LogicalType, NativeType}, - utils::list_ndims, - Result, + exec_err, internal_datafusion_err, internal_err, plan_err, types::NativeType, + utils::list_ndims, Result, }; +use datafusion_expr_common::signature::ArrayFunctionArgument; use datafusion_expr_common::{ - signature::{ - ArrayFunctionSignature, TypeSignatureClass, FIXED_SIZE_LIST_WILDCARD, - TIMEZONE_WILDCARD, - }, + signature::{ArrayFunctionSignature, FIXED_SIZE_LIST_WILDCARD, TIMEZONE_WILDCARD}, type_coercion::binary::comparison_coercion_numeric, type_coercion::binary::string_coercion, }; @@ -357,88 +354,81 @@ fn get_valid_types( signature: &TypeSignature, current_types: &[DataType], ) -> Result>> { - fn array_element_and_optional_index( + fn array_valid_types( function_name: &str, current_types: &[DataType], + arguments: &[ArrayFunctionArgument], + array_coercion: Option<&ListCoercion>, ) -> Result>> { - // make sure there's 2 or 3 arguments - if !(current_types.len() == 2 || current_types.len() == 3) { + if current_types.len() != arguments.len() { return Ok(vec![vec![]]); } - let first_two_types = ¤t_types[0..2]; - let mut valid_types = - array_append_or_prepend_valid_types(function_name, first_two_types, true)?; - - // Early return if there are only 2 arguments - if current_types.len() == 2 { - return Ok(valid_types); - } - - let valid_types_with_index = valid_types - .iter() - .map(|t| { - let mut t = t.clone(); - t.push(DataType::Int64); - t - }) - .collect::>(); - - valid_types.extend(valid_types_with_index); - - Ok(valid_types) - } - - fn array_append_or_prepend_valid_types( - function_name: &str, - current_types: &[DataType], - is_append: bool, - ) -> Result>> { - if current_types.len() != 2 { - return Ok(vec![vec![]]); - } - - let (array_type, elem_type) = if is_append { - (¤t_types[0], ¤t_types[1]) - } else { - (¤t_types[1], ¤t_types[0]) + let array_idx = arguments.iter().enumerate().find_map(|(idx, arg)| { + if *arg == ArrayFunctionArgument::Array { + Some(idx) + } else { + None + } + }); + let Some(array_idx) = array_idx else { + return Err(internal_datafusion_err!("Function 
'{function_name}' expected at least one argument array argument")); }; - - // We follow Postgres on `array_append(Null, T)`, which is not valid. - if array_type.eq(&DataType::Null) { + let Some(array_type) = array(¤t_types[array_idx]) else { return Ok(vec![vec![]]); - } + }; // We need to find the coerced base type, mainly for cases like: // `array_append(List(null), i64)` -> `List(i64)` - let array_base_type = datafusion_common::utils::base_type(array_type); - let elem_base_type = datafusion_common::utils::base_type(elem_type); - let new_base_type = comparison_coercion(&array_base_type, &elem_base_type); - - let new_base_type = new_base_type.ok_or_else(|| { - internal_datafusion_err!( - "Function '{function_name}' does not support coercion from {array_base_type:?} to {elem_base_type:?}" - ) - })?; - + let mut new_base_type = datafusion_common::utils::base_type(&array_type); + for (current_type, argument_type) in current_types.iter().zip(arguments.iter()) { + match argument_type { + ArrayFunctionArgument::Element | ArrayFunctionArgument::Array => { + new_base_type = + coerce_array_types(function_name, current_type, &new_base_type)?; + } + ArrayFunctionArgument::Index => {} + } + } let new_array_type = datafusion_common::utils::coerced_type_with_base_type_only( - array_type, + &array_type, &new_base_type, + array_coercion, ); - match new_array_type { + let new_elem_type = match new_array_type { DataType::List(ref field) | DataType::LargeList(ref field) - | DataType::FixedSizeList(ref field, _) => { - let new_elem_type = field.data_type(); - if is_append { - Ok(vec![vec![new_array_type.clone(), new_elem_type.clone()]]) - } else { - Ok(vec![vec![new_elem_type.to_owned(), new_array_type.clone()]]) + | DataType::FixedSizeList(ref field, _) => field.data_type(), + _ => return Ok(vec![vec![]]), + }; + + let mut valid_types = Vec::with_capacity(arguments.len()); + for (current_type, argument_type) in current_types.iter().zip(arguments.iter()) { + let valid_type = match argument_type { + ArrayFunctionArgument::Element => new_elem_type.clone(), + ArrayFunctionArgument::Index => DataType::Int64, + ArrayFunctionArgument::Array => { + let Some(current_type) = array(current_type) else { + return Ok(vec![vec![]]); + }; + let new_type = + datafusion_common::utils::coerced_type_with_base_type_only( + ¤t_type, + &new_base_type, + array_coercion, + ); + // All array arguments must be coercible to the same type + if new_type != new_array_type { + return Ok(vec![vec![]]); + } + new_type } - } - _ => Ok(vec![vec![]]), + }; + valid_types.push(valid_type); } + + Ok(vec![valid_types]) } fn array(array_type: &DataType) -> Option { @@ -449,6 +439,20 @@ fn get_valid_types( } } + fn coerce_array_types( + function_name: &str, + current_type: &DataType, + base_type: &DataType, + ) -> Result { + let current_base_type = datafusion_common::utils::base_type(current_type); + let new_base_type = comparison_coercion(base_type, ¤t_base_type); + new_base_type.ok_or_else(|| { + internal_datafusion_err!( + "Function '{function_name}' does not support coercion from {base_type:?} to {current_base_type:?}" + ) + }) + } + fn recursive_array(array_type: &DataType) -> Option { match array_type { DataType::List(_) @@ -596,75 +600,36 @@ fn get_valid_types( vec![vec![target_type; *num]] } } - TypeSignature::Coercible(target_types) => { - function_length_check( - function_name, - current_types.len(), - target_types.len(), - )?; - - // Aim to keep this logic as SIMPLE as possible! 
- // Make sure the corresponding test is covered - // If this function becomes COMPLEX, create another new signature! - fn can_coerce_to( - function_name: &str, - current_type: &DataType, - target_type_class: &TypeSignatureClass, - ) -> Result { - let logical_type: NativeType = current_type.into(); - - match target_type_class { - TypeSignatureClass::Native(native_type) => { - let target_type = native_type.native(); - if &logical_type == target_type { - return target_type.default_cast_for(current_type); - } - - if logical_type == NativeType::Null { - return target_type.default_cast_for(current_type); - } - - if target_type.is_integer() && logical_type.is_integer() { - return target_type.default_cast_for(current_type); - } - - internal_err!( - "Function '{function_name}' expects {target_type_class} but received {current_type}" - ) - } - // Not consistent with Postgres and DuckDB but to avoid regression we implicit cast string to timestamp - TypeSignatureClass::Timestamp - if logical_type == NativeType::String => - { - Ok(DataType::Timestamp(TimeUnit::Nanosecond, None)) - } - TypeSignatureClass::Timestamp if logical_type.is_timestamp() => { - Ok(current_type.to_owned()) - } - TypeSignatureClass::Date if logical_type.is_date() => { - Ok(current_type.to_owned()) - } - TypeSignatureClass::Time if logical_type.is_time() => { - Ok(current_type.to_owned()) - } - TypeSignatureClass::Interval if logical_type.is_interval() => { - Ok(current_type.to_owned()) - } - TypeSignatureClass::Duration if logical_type.is_duration() => { - Ok(current_type.to_owned()) - } - _ => { - not_impl_err!("Function '{function_name}' got logical_type: {logical_type} with target_type_class: {target_type_class}") - } - } - } + TypeSignature::Coercible(param_types) => { + function_length_check(function_name, current_types.len(), param_types.len())?; let mut new_types = Vec::with_capacity(current_types.len()); - for (current_type, target_type_class) in - current_types.iter().zip(target_types.iter()) - { - let target_type = can_coerce_to(function_name, current_type, target_type_class)?; - new_types.push(target_type); + for (current_type, param) in current_types.iter().zip(param_types.iter()) { + let current_native_type: NativeType = current_type.into(); + + if param.desired_type().matches_native_type(¤t_native_type) { + let casted_type = param.desired_type().default_casted_type( + ¤t_native_type, + current_type, + )?; + + new_types.push(casted_type); + } else if param + .allowed_source_types() + .iter() + .any(|t| t.matches_native_type(¤t_native_type)) { + // If the condition is met which means `implicit coercion`` is provided so we can safely unwrap + let default_casted_type = param.default_casted_type().unwrap(); + let casted_type = default_casted_type.default_cast_for(current_type)?; + new_types.push(casted_type); + } else { + return internal_err!( + "Expect {} but received {}, DataType: {}", + param.desired_type(), + current_native_type, + current_type + ); + } } vec![new_types] @@ -693,40 +658,9 @@ fn get_valid_types( vec![current_types.to_vec()] } TypeSignature::Exact(valid_types) => vec![valid_types.clone()], - TypeSignature::ArraySignature(ref function_signature) => match function_signature - { - ArrayFunctionSignature::ArrayAndElement => { - array_append_or_prepend_valid_types(function_name, current_types, true)? - } - ArrayFunctionSignature::ElementAndArray => { - array_append_or_prepend_valid_types(function_name, current_types, false)? 
- } - ArrayFunctionSignature::ArrayAndIndexes(count) => { - if current_types.len() != count.get() + 1 { - return Ok(vec![vec![]]); - } - array(¤t_types[0]).map_or_else( - || vec![vec![]], - |array_type| { - let mut inner = Vec::with_capacity(count.get() + 1); - inner.push(array_type); - for _ in 0..count.get() { - inner.push(DataType::Int64); - } - vec![inner] - }, - ) - } - ArrayFunctionSignature::ArrayAndElementAndOptionalIndex => { - array_element_and_optional_index(function_name, current_types)? - } - ArrayFunctionSignature::Array => { - if current_types.len() != 1 { - return Ok(vec![vec![]]); - } - - array(¤t_types[0]) - .map_or_else(|| vec![vec![]], |array_type| vec![vec![array_type]]) + TypeSignature::ArraySignature(ref function_signature) => match function_signature { + ArrayFunctionSignature::Array { arguments, array_coercion, } => { + array_valid_types(function_name, current_types, arguments, array_coercion.as_ref())? } ArrayFunctionSignature::RecursiveArray => { if current_types.len() != 1 { diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs index 7ffc6623ea92..2b9e2bddd184 100644 --- a/datafusion/expr/src/udaf.rs +++ b/datafusion/expr/src/udaf.rs @@ -19,7 +19,7 @@ use std::any::Any; use std::cmp::Ordering; -use std::fmt::{self, Debug, Formatter}; +use std::fmt::{self, Debug, Formatter, Write}; use std::hash::{DefaultHasher, Hash, Hasher}; use std::sync::Arc; use std::vec; @@ -29,14 +29,18 @@ use arrow::datatypes::{DataType, Field}; use datafusion_common::{exec_err, not_impl_err, Result, ScalarValue, Statistics}; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; -use crate::expr::AggregateFunction; +use crate::expr::{ + schema_name_from_exprs, schema_name_from_exprs_comma_separated_without_space, + schema_name_from_sorts, AggregateFunction, AggregateFunctionParams, + WindowFunctionParams, +}; use crate::function::{ AccumulatorArgs, AggregateFunctionSimplification, StateFieldsArgs, }; use crate::groups_accumulator::GroupsAccumulator; use crate::utils::format_state_name; use crate::utils::AggregateOrderSensitivity; -use crate::{Accumulator, Expr}; +use crate::{expr_vec_fmt, Accumulator, Expr}; use crate::{Documentation, Signature}; /// Logical representation of a user-defined [aggregate function] (UDAF). @@ -165,6 +169,30 @@ impl AggregateUDF { self.inner.name() } + /// See [`AggregateUDFImpl::schema_name`] for more details. + pub fn schema_name(&self, params: &AggregateFunctionParams) -> Result { + self.inner.schema_name(params) + } + + pub fn window_function_schema_name( + &self, + params: &WindowFunctionParams, + ) -> Result { + self.inner.window_function_schema_name(params) + } + + /// See [`AggregateUDFImpl::display_name`] for more details. + pub fn display_name(&self, params: &AggregateFunctionParams) -> Result { + self.inner.display_name(params) + } + + pub fn window_function_display_name( + &self, + params: &WindowFunctionParams, + ) -> Result { + self.inner.window_function_display_name(params) + } + pub fn is_nullable(&self) -> bool { self.inner.is_nullable() } @@ -382,6 +410,186 @@ pub trait AggregateUDFImpl: Debug + Send + Sync { /// Returns this function's name fn name(&self) -> &str; + /// Returns the name of the column this expression would create + /// + /// See [`Expr::schema_name`] for details + /// + /// Example of schema_name: count(DISTINCT column1) FILTER (WHERE column2 > 10) ORDER BY [..] 
+ fn schema_name(&self, params: &AggregateFunctionParams) -> Result { + let AggregateFunctionParams { + args, + distinct, + filter, + order_by, + null_treatment, + } = params; + + let mut schema_name = String::new(); + + schema_name.write_fmt(format_args!( + "{}({}{})", + self.name(), + if *distinct { "DISTINCT " } else { "" }, + schema_name_from_exprs_comma_separated_without_space(args)? + ))?; + + if let Some(null_treatment) = null_treatment { + schema_name.write_fmt(format_args!(" {}", null_treatment))?; + } + + if let Some(filter) = filter { + schema_name.write_fmt(format_args!(" FILTER (WHERE {filter})"))?; + }; + + if let Some(order_by) = order_by { + schema_name.write_fmt(format_args!( + " ORDER BY [{}]", + schema_name_from_sorts(order_by)? + ))?; + }; + + Ok(schema_name) + } + + /// Returns the name of the column this expression would create + /// + /// See [`Expr::schema_name`] for details + /// + /// Different from `schema_name` in that it is used for window aggregate function + /// + /// Example of schema_name: count(DISTINCT column1) FILTER (WHERE column2 > 10) [PARTITION BY [..]] [ORDER BY [..]] + fn window_function_schema_name( + &self, + params: &WindowFunctionParams, + ) -> Result { + let WindowFunctionParams { + args, + partition_by, + order_by, + window_frame, + null_treatment, + } = params; + + let mut schema_name = String::new(); + schema_name.write_fmt(format_args!( + "{}({})", + self.name(), + schema_name_from_exprs(args)? + ))?; + + if let Some(null_treatment) = null_treatment { + schema_name.write_fmt(format_args!(" {}", null_treatment))?; + } + + if !partition_by.is_empty() { + schema_name.write_fmt(format_args!( + " PARTITION BY [{}]", + schema_name_from_exprs(partition_by)? + ))?; + } + + if !order_by.is_empty() { + schema_name.write_fmt(format_args!( + " ORDER BY [{}]", + schema_name_from_sorts(order_by)? + ))?; + }; + + schema_name.write_fmt(format_args!(" {window_frame}"))?; + + Ok(schema_name) + } + + /// Returns the user-defined display name of function, given the arguments + /// + /// This can be used to customize the output column name generated by this + /// function. + /// + /// Defaults to `function_name([DISTINCT] column1, column2, ..) [null_treatment] [filter] [order_by [..]]` + fn display_name(&self, params: &AggregateFunctionParams) -> Result { + let AggregateFunctionParams { + args, + distinct, + filter, + order_by, + null_treatment, + } = params; + + let mut schema_name = String::new(); + + schema_name.write_fmt(format_args!( + "{}({}{})", + self.name(), + if *distinct { "DISTINCT " } else { "" }, + expr_vec_fmt!(args) + ))?; + + if let Some(nt) = null_treatment { + schema_name.write_fmt(format_args!(" {}", nt))?; + } + if let Some(fe) = filter { + schema_name.write_fmt(format_args!(" FILTER (WHERE {fe})"))?; + } + if let Some(order_by) = order_by { + schema_name + .write_fmt(format_args!(" ORDER BY [{}]", expr_vec_fmt!(order_by)))?; + } + + Ok(schema_name) + } + + /// Returns the user-defined display name of function, given the arguments + /// + /// This can be used to customize the output column name generated by this + /// function. + /// + /// Different from `display_name` in that it is used for window aggregate function + /// + /// Defaults to `function_name([DISTINCT] column1, column2, ..) 
[null_treatment] [partition by [..]] [order_by [..]]` + fn window_function_display_name( + &self, + params: &WindowFunctionParams, + ) -> Result { + let WindowFunctionParams { + args, + partition_by, + order_by, + window_frame, + null_treatment, + } = params; + + let mut display_name = String::new(); + + display_name.write_fmt(format_args!( + "{}({})", + self.name(), + expr_vec_fmt!(args) + ))?; + + if let Some(null_treatment) = null_treatment { + display_name.write_fmt(format_args!(" {}", null_treatment))?; + } + + if !partition_by.is_empty() { + display_name.write_fmt(format_args!( + " PARTITION BY [{}]", + expr_vec_fmt!(partition_by) + ))?; + } + + if !order_by.is_empty() { + display_name + .write_fmt(format_args!(" ORDER BY [{}]", expr_vec_fmt!(order_by)))?; + }; + + display_name.write_fmt(format_args!( + " {} BETWEEN {} AND {}", + window_frame.units, window_frame.start_bound, window_frame.end_bound + ))?; + + Ok(display_name) + } + /// Returns the function's [`Signature`] for information about what input /// types are accepted and the function's Volatility. fn signature(&self) -> &Signature; diff --git a/datafusion/expr/src/udf.rs b/datafusion/expr/src/udf.rs index b41d97520362..74c3c2775c1c 100644 --- a/datafusion/expr/src/udf.rs +++ b/datafusion/expr/src/udf.rs @@ -899,13 +899,8 @@ impl ScalarUDFImpl for AliasedScalarUDFImpl { self.inner.return_type_from_args(args) } - fn invoke_batch( - &self, - args: &[ColumnarValue], - number_rows: usize, - ) -> Result { - #[allow(deprecated)] - self.inner.invoke_batch(args, number_rows) + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + self.inner.invoke_with_args(args) } fn simplify( @@ -980,6 +975,7 @@ pub mod scalar_doc_sections { DOC_SECTION_STRUCT, DOC_SECTION_MAP, DOC_SECTION_HASHING, + DOC_SECTION_UNION, DOC_SECTION_OTHER, ] } @@ -996,6 +992,7 @@ pub mod scalar_doc_sections { DOC_SECTION_STRUCT, DOC_SECTION_MAP, DOC_SECTION_HASHING, + DOC_SECTION_UNION, DOC_SECTION_OTHER, ] } @@ -1070,4 +1067,10 @@ The following regular expression functions are supported:"#, label: "Other Functions", description: None, }; + + pub const DOC_SECTION_UNION: DocSection = DocSection { + include: true, + label: "Union Functions", + description: Some("Functions to work with the union data type, also know as tagged unions, variant types, enums or sum types. Note: Not related to the SQL UNION operator"), + }; } diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index 049926fb0bcd..86c0f9ad637c 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -22,7 +22,7 @@ use std::collections::{BTreeSet, HashSet}; use std::ops::Deref; use std::sync::Arc; -use crate::expr::{Alias, Sort, WildcardOptions, WindowFunction}; +use crate::expr::{Alias, Sort, WildcardOptions, WindowFunction, WindowFunctionParams}; use crate::expr_rewriter::strip_outer_reference; use crate::{ and, BinaryExpr, Expr, ExprSchemable, Filter, GroupingSet, LogicalPlan, Operator, @@ -588,7 +588,7 @@ pub fn group_window_expr_by_sort_keys( ) -> Result)>> { let mut result = vec![]; window_expr.into_iter().try_for_each(|expr| match &expr { - Expr::WindowFunction( WindowFunction{ partition_by, order_by, .. }) => { + Expr::WindowFunction( WindowFunction{ params: WindowFunctionParams { partition_by, order_by, ..}, .. 
}) => { let sort_key = generate_sort_key(partition_by, order_by)?; if let Some((_, values)) = result.iter_mut().find( |group: &&mut (WindowSortKey, Vec)| matches!(group, (key, _) if *key == sort_key), diff --git a/datafusion/expr/src/var_provider.rs b/datafusion/expr/src/var_provider.rs index e00cf7407237..708cd576c3ff 100644 --- a/datafusion/expr/src/var_provider.rs +++ b/datafusion/expr/src/var_provider.rs @@ -38,6 +38,12 @@ pub trait VarProvider: std::fmt::Debug { fn get_type(&self, var_names: &[String]) -> Option; } +/// Returns true if the specified string is a "system" variable such as +/// `@@version` +/// +/// See [`SessionContext::register_variable`] for more details +/// +/// [`SessionContext::register_variable`]: https://docs.rs/datafusion/latest/datafusion/execution/context/struct.SessionContext.html#method.register_variable pub fn is_system_variables(variable_names: &[String]) -> bool { !variable_names.is_empty() && variable_names[0].get(0..2) == Some("@@") } diff --git a/datafusion/ffi/Cargo.toml b/datafusion/ffi/Cargo.toml index c33c87786de8..4c396144347c 100644 --- a/datafusion/ffi/Cargo.toml +++ b/datafusion/ffi/Cargo.toml @@ -38,7 +38,6 @@ crate-type = ["cdylib", "rlib"] [dependencies] abi_stable = "0.11.3" arrow = { workspace = true, features = ["ffi"] } -arrow-schema = { workspace = true } async-ffi = { version = "0.5.0", features = ["abi_stable"] } async-trait = { workspace = true } datafusion = { workspace = true, default-features = false } diff --git a/datafusion/ffi/src/arrow_wrappers.rs b/datafusion/ffi/src/arrow_wrappers.rs index c5add8782c51..a18e6df59bf1 100644 --- a/datafusion/ffi/src/arrow_wrappers.rs +++ b/datafusion/ffi/src/arrow_wrappers.rs @@ -19,8 +19,9 @@ use std::sync::Arc; use abi_stable::StableAbi; use arrow::{ + array::{make_array, ArrayRef}, datatypes::{Schema, SchemaRef}, - ffi::{FFI_ArrowArray, FFI_ArrowSchema}, + ffi::{from_ffi, FFI_ArrowArray, FFI_ArrowSchema}, }; use log::error; @@ -68,3 +69,13 @@ pub struct WrappedArray { pub schema: WrappedSchema, } + +impl TryFrom for ArrayRef { + type Error = arrow::error::ArrowError; + + fn try_from(value: WrappedArray) -> Result { + let data = unsafe { from_ffi(value.array, &value.schema.0)? }; + + Ok(make_array(data)) + } +} diff --git a/datafusion/ffi/src/execution_plan.rs b/datafusion/ffi/src/execution_plan.rs index 6c5db1218563..8087acfa33c8 100644 --- a/datafusion/ffi/src/execution_plan.rs +++ b/datafusion/ffi/src/execution_plan.rs @@ -30,7 +30,8 @@ use datafusion::{ use tokio::runtime::Handle; use crate::{ - plan_properties::FFI_PlanProperties, record_batch_stream::FFI_RecordBatchStream, + df_result, plan_properties::FFI_PlanProperties, + record_batch_stream::FFI_RecordBatchStream, rresult, }; /// A stable struct for sharing a [`ExecutionPlan`] across FFI boundaries. 
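// Editorial sketch (not part of the diff): how the new `TryFrom<WrappedArray> for
// ArrayRef` impl added above might be used on the receiving side of the FFI
// boundary. The function name and the consumer-side import path are illustrative
// assumptions, not part of the change.
use arrow::{array::ArrayRef, error::ArrowError};
use datafusion_ffi::arrow_wrappers::WrappedArray;

fn wrapped_to_array(wrapped: WrappedArray) -> Result<ArrayRef, ArrowError> {
    // from_ffi + make_array happen inside the TryFrom impl; no unsafe is needed here.
    wrapped.try_into()
}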
@@ -112,13 +113,11 @@ unsafe extern "C" fn execute_fn_wrapper( let ctx = &(*private_data).context; let runtime = (*private_data).runtime.clone(); - match plan.execute(partition, Arc::clone(ctx)) { - Ok(rbs) => RResult::ROk(FFI_RecordBatchStream::new(rbs, runtime)), - Err(e) => RResult::RErr( - format!("Error occurred during FFI_ExecutionPlan execute: {}", e).into(), - ), - } + rresult!(plan + .execute(partition, Arc::clone(ctx)) + .map(|rbs| FFI_RecordBatchStream::new(rbs, runtime))) } + unsafe extern "C" fn name_fn_wrapper(plan: &FFI_ExecutionPlan) -> RString { let private_data = plan.private_data as *const ExecutionPlanPrivateData; let plan = &(*private_data).plan; @@ -274,16 +273,8 @@ impl ExecutionPlan for ForeignExecutionPlan { _context: Arc, ) -> Result { unsafe { - match (self.plan.execute)(&self.plan, partition) { - RResult::ROk(stream) => { - let stream = Pin::new(Box::new(stream)) as SendableRecordBatchStream; - Ok(stream) - } - RResult::RErr(e) => Err(DataFusionError::Execution(format!( - "Error occurred during FFI call to FFI_ExecutionPlan execute. {}", - e - ))), - } + df_result!((self.plan.execute)(&self.plan, partition)) + .map(|stream| Pin::new(Box::new(stream)) as SendableRecordBatchStream) } } } diff --git a/datafusion/ffi/src/lib.rs b/datafusion/ffi/src/lib.rs index b25528234773..bbcdd85ff80a 100644 --- a/datafusion/ffi/src/lib.rs +++ b/datafusion/ffi/src/lib.rs @@ -26,6 +26,9 @@ pub mod record_batch_stream; pub mod session_config; pub mod table_provider; pub mod table_source; +pub mod udf; +pub mod util; +pub mod volatility; #[cfg(feature = "integration-tests")] pub mod tests; diff --git a/datafusion/ffi/src/plan_properties.rs b/datafusion/ffi/src/plan_properties.rs index 3c7bc886aede..3592c16b8fab 100644 --- a/datafusion/ffi/src/plan_properties.rs +++ b/datafusion/ffi/src/plan_properties.rs @@ -19,8 +19,8 @@ use std::{ffi::c_void, sync::Arc}; use abi_stable::{ std_types::{ - RResult::{self, RErr, ROk}, - RStr, RVec, + RResult::{self, ROk}, + RString, RVec, }, StableAbi, }; @@ -44,7 +44,7 @@ use datafusion_proto::{ }; use prost::Message; -use crate::arrow_wrappers::WrappedSchema; +use crate::{arrow_wrappers::WrappedSchema, df_result, rresult_return}; /// A stable struct for sharing [`PlanProperties`] across FFI boundaries. #[repr(C)] @@ -54,7 +54,7 @@ pub struct FFI_PlanProperties { /// The output partitioning is a [`Partitioning`] protobuf message serialized /// into bytes to pass across the FFI boundary. pub output_partitioning: - unsafe extern "C" fn(plan: &Self) -> RResult, RStr<'static>>, + unsafe extern "C" fn(plan: &Self) -> RResult, RString>, /// Return the emission type of the plan. pub emission_type: unsafe extern "C" fn(plan: &Self) -> FFI_EmissionType, @@ -64,8 +64,7 @@ pub struct FFI_PlanProperties { /// The output ordering is a [`PhysicalSortExprNodeCollection`] protobuf message /// serialized into bytes to pass across the FFI boundary. - pub output_ordering: - unsafe extern "C" fn(plan: &Self) -> RResult, RStr<'static>>, + pub output_ordering: unsafe extern "C" fn(plan: &Self) -> RResult, RString>, /// Return the schema of the plan. 
pub schema: unsafe extern "C" fn(plan: &Self) -> WrappedSchema, @@ -84,21 +83,13 @@ struct PlanPropertiesPrivateData { unsafe extern "C" fn output_partitioning_fn_wrapper( properties: &FFI_PlanProperties, -) -> RResult, RStr<'static>> { +) -> RResult, RString> { let private_data = properties.private_data as *const PlanPropertiesPrivateData; let props = &(*private_data).props; let codec = DefaultPhysicalExtensionCodec {}; let partitioning_data = - match serialize_partitioning(props.output_partitioning(), &codec) { - Ok(p) => p, - Err(_) => { - return RErr( - "unable to serialize output_partitioning in FFI_PlanProperties" - .into(), - ) - } - }; + rresult_return!(serialize_partitioning(props.output_partitioning(), &codec)); let output_partitioning = partitioning_data.encode_to_vec(); ROk(output_partitioning.into()) @@ -122,31 +113,24 @@ unsafe extern "C" fn boundedness_fn_wrapper( unsafe extern "C" fn output_ordering_fn_wrapper( properties: &FFI_PlanProperties, -) -> RResult, RStr<'static>> { +) -> RResult, RString> { let private_data = properties.private_data as *const PlanPropertiesPrivateData; let props = &(*private_data).props; let codec = DefaultPhysicalExtensionCodec {}; - let output_ordering = - match props.output_ordering() { - Some(ordering) => { - let physical_sort_expr_nodes = - match serialize_physical_sort_exprs(ordering.to_owned(), &codec) { - Ok(v) => v, - Err(_) => return RErr( - "unable to serialize output_ordering in FFI_PlanProperties" - .into(), - ), - }; - - let ordering_data = PhysicalSortExprNodeCollection { - physical_sort_expr_nodes, - }; - - ordering_data.encode_to_vec() - } - None => Vec::default(), - }; + let output_ordering = match props.output_ordering() { + Some(ordering) => { + let physical_sort_expr_nodes = rresult_return!( + serialize_physical_sort_exprs(ordering.to_owned(), &codec) + ); + let ordering_data = PhysicalSortExprNodeCollection { + physical_sort_expr_nodes, + }; + + ordering_data.encode_to_vec() + } + None => Vec::default(), + }; ROk(output_ordering.into()) } @@ -200,40 +184,32 @@ impl TryFrom for PlanProperties { let codex = DefaultPhysicalExtensionCodec {}; let ffi_orderings = unsafe { (ffi_props.output_ordering)(&ffi_props) }; - let orderings = match ffi_orderings { - ROk(ordering_vec) => { - let proto_output_ordering = - PhysicalSortExprNodeCollection::decode(ordering_vec.as_ref()) - .map_err(|e| DataFusionError::External(Box::new(e)))?; - Some(parse_physical_sort_exprs( - &proto_output_ordering.physical_sort_expr_nodes, - &default_ctx, - &schema, - &codex, - )?) - } - RErr(e) => return Err(DataFusionError::Plan(e.to_string())), - }; - let ffi_partitioning = unsafe { (ffi_props.output_partitioning)(&ffi_props) }; - let partitioning = match ffi_partitioning { - ROk(partitioning_vec) => { - let proto_output_partitioning = - Partitioning::decode(partitioning_vec.as_ref()) - .map_err(|e| DataFusionError::External(Box::new(e)))?; - parse_protobuf_partitioning( - Some(&proto_output_partitioning), - &default_ctx, - &schema, - &codex, - )? 
- .ok_or(DataFusionError::Plan( - "Unable to deserialize partitioning protobuf in FFI_PlanProperties" - .to_string(), - )) - } - RErr(e) => Err(DataFusionError::Plan(e.to_string())), - }?; + let proto_output_ordering = + PhysicalSortExprNodeCollection::decode(df_result!(ffi_orderings)?.as_ref()) + .map_err(|e| DataFusionError::External(Box::new(e)))?; + let orderings = Some(parse_physical_sort_exprs( + &proto_output_ordering.physical_sort_expr_nodes, + &default_ctx, + &schema, + &codex, + )?); + + let partitioning_vec = + unsafe { df_result!((ffi_props.output_partitioning)(&ffi_props))? }; + let proto_output_partitioning = + Partitioning::decode(partitioning_vec.as_ref()) + .map_err(|e| DataFusionError::External(Box::new(e)))?; + let partitioning = parse_protobuf_partitioning( + Some(&proto_output_partitioning), + &default_ctx, + &schema, + &codex, + )? + .ok_or(DataFusionError::Plan( + "Unable to deserialize partitioning protobuf in FFI_PlanProperties" + .to_string(), + ))?; let eq_properties = match orderings { Some(ordering) => { diff --git a/datafusion/ffi/src/record_batch_stream.rs b/datafusion/ffi/src/record_batch_stream.rs index 466ce247678a..939c4050028c 100644 --- a/datafusion/ffi/src/record_batch_stream.rs +++ b/datafusion/ffi/src/record_batch_stream.rs @@ -35,7 +35,10 @@ use datafusion::{ use futures::{Stream, TryStreamExt}; use tokio::runtime::Handle; -use crate::arrow_wrappers::{WrappedArray, WrappedSchema}; +use crate::{ + arrow_wrappers::{WrappedArray, WrappedSchema}, + rresult, +}; /// A stable struct for sharing [`RecordBatchStream`] across FFI boundaries. /// We use the async-ffi crate for handling async calls across libraries. @@ -97,13 +100,12 @@ fn record_batch_to_wrapped_array( record_batch: RecordBatch, ) -> RResult { let struct_array = StructArray::from(record_batch); - match to_ffi(&struct_array.to_data()) { - Ok((array, schema)) => RResult::ROk(WrappedArray { + rresult!( + to_ffi(&struct_array.to_data()).map(|(array, schema)| WrappedArray { array, - schema: WrappedSchema(schema), - }), - Err(e) => RResult::RErr(e.to_string().into()), - } + schema: WrappedSchema(schema) + }) + ) } // probably want to use pub unsafe fn from_ffi(array: FFI_ArrowArray, schema: &FFI_ArrowSchema) -> Result { diff --git a/datafusion/ffi/src/table_provider.rs b/datafusion/ffi/src/table_provider.rs index 978ac10206bd..0b4080abcb55 100644 --- a/datafusion/ffi/src/table_provider.rs +++ b/datafusion/ffi/src/table_provider.rs @@ -44,6 +44,7 @@ use tokio::runtime::Handle; use crate::{ arrow_wrappers::WrappedSchema, + df_result, rresult_return, session_config::ForeignSessionConfig, table_source::{FFI_TableProviderFilterPushDown, FFI_TableType}, }; @@ -233,10 +234,7 @@ unsafe extern "C" fn scan_fn_wrapper( let runtime = &(*private_data).runtime; async move { - let config = match ForeignSessionConfig::try_from(&session_config) { - Ok(c) => c, - Err(e) => return RResult::RErr(e.to_string().into()), - }; + let config = rresult_return!(ForeignSessionConfig::try_from(&session_config)); let session = SessionStateBuilder::new() .with_default_features() .with_config(config.0) @@ -250,15 +248,13 @@ unsafe extern "C" fn scan_fn_wrapper( let codec = DefaultLogicalExtensionCodec {}; let proto_filters = - match LogicalExprList::decode(filters_serialized.as_ref()) { - Ok(f) => f, - Err(e) => return RResult::RErr(e.to_string().into()), - }; - - match parse_exprs(proto_filters.expr.iter(), &default_ctx, &codec) { - Ok(f) => f, - Err(e) => return RResult::RErr(e.to_string().into()), - } + 
rresult_return!(LogicalExprList::decode(filters_serialized.as_ref())); + + rresult_return!(parse_exprs( + proto_filters.expr.iter(), + &default_ctx, + &codec + )) } }; @@ -268,13 +264,11 @@ unsafe extern "C" fn scan_fn_wrapper( false => Some(&projections), }; - let plan = match internal_provider - .scan(&ctx.state(), maybe_projections, &filters, limit.into()) - .await - { - Ok(p) => p, - Err(e) => return RResult::RErr(e.to_string().into()), - }; + let plan = rresult_return!( + internal_provider + .scan(&ctx.state(), maybe_projections, &filters, limit.into()) + .await + ); RResult::ROk(FFI_ExecutionPlan::new( plan, @@ -298,30 +292,22 @@ unsafe extern "C" fn insert_into_fn_wrapper( let runtime = &(*private_data).runtime; async move { - let config = match ForeignSessionConfig::try_from(&session_config) { - Ok(c) => c, - Err(e) => return RResult::RErr(e.to_string().into()), - }; + let config = rresult_return!(ForeignSessionConfig::try_from(&session_config)); let session = SessionStateBuilder::new() .with_default_features() .with_config(config.0) .build(); let ctx = SessionContext::new_with_state(session); - let input = match ForeignExecutionPlan::try_from(&input) { - Ok(input) => Arc::new(input), - Err(e) => return RResult::RErr(e.to_string().into()), - }; + let input = rresult_return!(ForeignExecutionPlan::try_from(&input).map(Arc::new)); let insert_op = InsertOp::from(insert_op); - let plan = match internal_provider - .insert_into(&ctx.state(), input, insert_op) - .await - { - Ok(p) => p, - Err(e) => return RResult::RErr(e.to_string().into()), - }; + let plan = rresult_return!( + internal_provider + .insert_into(&ctx.state(), input, insert_op) + .await + ); RResult::ROk(FFI_ExecutionPlan::new( plan, @@ -456,14 +442,7 @@ impl TableProvider for ForeignTableProvider { ) .await; - match maybe_plan { - RResult::ROk(p) => ForeignExecutionPlan::try_from(&p)?, - RResult::RErr(_) => { - return Err(DataFusionError::Internal( - "Unable to perform scan via FFI".to_string(), - )) - } - } + ForeignExecutionPlan::try_from(&df_result!(maybe_plan)?)? }; Ok(Arc::new(plan)) @@ -493,12 +472,9 @@ impl TableProvider for ForeignTableProvider { }; let serialized_filters = expr_list.encode_to_vec(); - let pushdowns = pushdown_fn(&self.0, serialized_filters.into()); + let pushdowns = df_result!(pushdown_fn(&self.0, serialized_filters.into()))?; - match pushdowns { - RResult::ROk(p) => Ok(p.iter().map(|v| v.into()).collect()), - RResult::RErr(e) => Err(DataFusionError::Plan(e.to_string())), - } + Ok(pushdowns.iter().map(|v| v.into()).collect()) } } @@ -519,15 +495,7 @@ impl TableProvider for ForeignTableProvider { let maybe_plan = (self.0.insert_into)(&self.0, &session_config, &input, insert_op).await; - match maybe_plan { - RResult::ROk(p) => ForeignExecutionPlan::try_from(&p)?, - RResult::RErr(e) => { - return Err(DataFusionError::Internal(format!( - "Unable to perform insert_into via FFI: {}", - e - ))) - } - } + ForeignExecutionPlan::try_from(&df_result!(maybe_plan)?)? 
}; Ok(Arc::new(plan)) diff --git a/datafusion/ffi/src/tests/async_provider.rs b/datafusion/ffi/src/tests/async_provider.rs index eff3ed61d739..cf05d596308f 100644 --- a/datafusion/ffi/src/tests/async_provider.rs +++ b/datafusion/ffi/src/tests/async_provider.rs @@ -29,7 +29,7 @@ use std::{any::Any, fmt::Debug, sync::Arc}; use crate::table_provider::FFI_TableProvider; use arrow::array::RecordBatch; -use arrow_schema::Schema; +use arrow::datatypes::Schema; use async_trait::async_trait; use datafusion::{ catalog::{Session, TableProvider}, @@ -238,7 +238,7 @@ struct AsyncTestRecordBatchStream { } impl RecordBatchStream for AsyncTestRecordBatchStream { - fn schema(&self) -> arrow_schema::SchemaRef { + fn schema(&self) -> arrow::datatypes::SchemaRef { super::create_test_schema() } } diff --git a/datafusion/ffi/src/tests/mod.rs b/datafusion/ffi/src/tests/mod.rs index a5fc74b840d1..5a471cb8fe43 100644 --- a/datafusion/ffi/src/tests/mod.rs +++ b/datafusion/ffi/src/tests/mod.rs @@ -26,7 +26,7 @@ use abi_stable::{ StableAbi, }; -use super::table_provider::FFI_TableProvider; +use super::{table_provider::FFI_TableProvider, udf::FFI_ScalarUDF}; use arrow::array::RecordBatch; use async_provider::create_async_table_provider; use datafusion::{ @@ -34,27 +34,30 @@ use datafusion::{ common::record_batch, }; use sync_provider::create_sync_table_provider; +use udf_udaf_udwf::create_ffi_abs_func; mod async_provider; mod sync_provider; +mod udf_udaf_udwf; #[repr(C)] #[derive(StableAbi)] -#[sabi(kind(Prefix(prefix_ref = TableProviderModuleRef)))] +#[sabi(kind(Prefix(prefix_ref = ForeignLibraryModuleRef)))] /// This struct defines the module interfaces. It is to be shared by /// both the module loading program and library that implements the -/// module. It is possible to move this definition into the loading -/// program and reference it in the modules, but this example shows -/// how a user may wish to separate these concerns. -pub struct TableProviderModule { +/// module. +pub struct ForeignLibraryModule { /// Constructs the table provider pub create_table: extern "C" fn(synchronous: bool) -> FFI_TableProvider, + /// Create a scalar UDF + pub create_scalar_udf: extern "C" fn() -> FFI_ScalarUDF, + pub version: extern "C" fn() -> u64, } -impl RootModule for TableProviderModuleRef { - declare_root_module_statics! {TableProviderModuleRef} +impl RootModule for ForeignLibraryModuleRef { + declare_root_module_statics! {ForeignLibraryModuleRef} const BASE_NAME: &'static str = "datafusion_ffi"; const NAME: &'static str = "datafusion_ffi"; const VERSION_STRINGS: VersionStrings = package_version_strings!(); @@ -64,7 +67,7 @@ impl RootModule for TableProviderModuleRef { } } -fn create_test_schema() -> Arc { +pub fn create_test_schema() -> Arc { Arc::new(Schema::new(vec![ Field::new("a", DataType::Int32, true), Field::new("b", DataType::Float64, true), @@ -90,9 +93,10 @@ extern "C" fn construct_table_provider(synchronous: bool) -> FFI_TableProvider { #[export_root_module] /// This defines the entry point for using the module. 
-pub fn get_simple_memory_table() -> TableProviderModuleRef { - TableProviderModule { +pub fn get_foreign_library_module() -> ForeignLibraryModuleRef { + ForeignLibraryModule { create_table: construct_table_provider, + create_scalar_udf: create_ffi_abs_func, version: super::version, } .leak_into_prefix() diff --git a/datafusion/ffi/src/tests/udf_udaf_udwf.rs b/datafusion/ffi/src/tests/udf_udaf_udwf.rs new file mode 100644 index 000000000000..e8a13aac1308 --- /dev/null +++ b/datafusion/ffi/src/tests/udf_udaf_udwf.rs @@ -0,0 +1,27 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::udf::FFI_ScalarUDF; +use datafusion::{functions::math::abs::AbsFunc, logical_expr::ScalarUDF}; + +use std::sync::Arc; + +pub(crate) extern "C" fn create_ffi_abs_func() -> FFI_ScalarUDF { + let udf: Arc = Arc::new(AbsFunc::new().into()); + + udf.into() +} diff --git a/datafusion/ffi/src/udf.rs b/datafusion/ffi/src/udf.rs new file mode 100644 index 000000000000..bbc9cf936cee --- /dev/null +++ b/datafusion/ffi/src/udf.rs @@ -0,0 +1,351 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::{ffi::c_void, sync::Arc}; + +use abi_stable::{ + std_types::{RResult, RString, RVec}, + StableAbi, +}; +use arrow::datatypes::DataType; +use arrow::{ + array::ArrayRef, + error::ArrowError, + ffi::{from_ffi, to_ffi, FFI_ArrowSchema}, +}; +use datafusion::{ + error::DataFusionError, + logical_expr::type_coercion::functions::data_types_with_scalar_udf, +}; +use datafusion::{ + error::Result, + logical_expr::{ + ColumnarValue, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature, + }, +}; + +use crate::{ + arrow_wrappers::{WrappedArray, WrappedSchema}, + df_result, rresult, rresult_return, + util::{rvec_wrapped_to_vec_datatype, vec_datatype_to_rvec_wrapped}, + volatility::FFI_Volatility, +}; + +/// A stable struct for sharing a [`ScalarUDF`] across FFI boundaries. 
+#[repr(C)] +#[derive(Debug, StableAbi)] +#[allow(non_camel_case_types)] +pub struct FFI_ScalarUDF { + /// FFI equivalent to the `name` of a [`ScalarUDF`] + pub name: RString, + + /// FFI equivalent to the `aliases` of a [`ScalarUDF`] + pub aliases: RVec, + + /// FFI equivalent to the `volatility` of a [`ScalarUDF`] + pub volatility: FFI_Volatility, + + /// Determines the return type of the underlying [`ScalarUDF`] based on the + /// argument types. + pub return_type: unsafe extern "C" fn( + udf: &Self, + arg_types: RVec, + ) -> RResult, + + /// Execute the underlying [`ScalarUDF`] and return the result as a `FFI_ArrowArray` + /// within an AbiStable wrapper. + pub invoke_with_args: unsafe extern "C" fn( + udf: &Self, + args: RVec, + num_rows: usize, + return_type: WrappedSchema, + ) -> RResult, + + /// See [`ScalarUDFImpl`] for details on short_circuits + pub short_circuits: bool, + + /// Performs type coercion. To simplify this interface, all UDFs are treated as having + /// user defined signatures, which in turn causes coerce_types to be called. This + /// call should be transparent to most users as the internal function performs the + /// appropriate calls on the underlying [`ScalarUDF`] + pub coerce_types: unsafe extern "C" fn( + udf: &Self, + arg_types: RVec, + ) -> RResult, RString>, + + /// Used to create a clone on the provider of the udf. This should + /// only need to be called by the receiver of the udf. + pub clone: unsafe extern "C" fn(udf: &Self) -> Self, + + /// Release the memory of the private data when it is no longer being used. + pub release: unsafe extern "C" fn(udf: &mut Self), + + /// Internal data. This is only to be accessed by the provider of the udf. + /// A [`ForeignScalarUDF`] should never attempt to access this data.
+ pub private_data: *mut c_void, +} + +unsafe impl Send for FFI_ScalarUDF {} +unsafe impl Sync for FFI_ScalarUDF {} + +pub struct ScalarUDFPrivateData { + pub udf: Arc, +} + +unsafe extern "C" fn return_type_fn_wrapper( + udf: &FFI_ScalarUDF, + arg_types: RVec, +) -> RResult { + let private_data = udf.private_data as *const ScalarUDFPrivateData; + let udf = &(*private_data).udf; + + let arg_types = rresult_return!(rvec_wrapped_to_vec_datatype(&arg_types)); + + let return_type = udf + .return_type(&arg_types) + .and_then(|v| FFI_ArrowSchema::try_from(v).map_err(DataFusionError::from)) + .map(WrappedSchema); + + rresult!(return_type) +} + +unsafe extern "C" fn coerce_types_fn_wrapper( + udf: &FFI_ScalarUDF, + arg_types: RVec, +) -> RResult, RString> { + let private_data = udf.private_data as *const ScalarUDFPrivateData; + let udf = &(*private_data).udf; + + let arg_types = rresult_return!(rvec_wrapped_to_vec_datatype(&arg_types)); + + let return_types = rresult_return!(data_types_with_scalar_udf(&arg_types, udf)); + + rresult!(vec_datatype_to_rvec_wrapped(&return_types)) +} + +unsafe extern "C" fn invoke_with_args_fn_wrapper( + udf: &FFI_ScalarUDF, + args: RVec, + number_rows: usize, + return_type: WrappedSchema, +) -> RResult { + let private_data = udf.private_data as *const ScalarUDFPrivateData; + let udf = &(*private_data).udf; + + let args = args + .into_iter() + .map(|arr| { + from_ffi(arr.array, &arr.schema.0) + .map(|v| ColumnarValue::Array(arrow::array::make_array(v))) + }) + .collect::>(); + + let args = rresult_return!(args); + let return_type = rresult_return!(DataType::try_from(&return_type.0)); + + let args = ScalarFunctionArgs { + args, + number_rows, + return_type: &return_type, + }; + + let result = rresult_return!(udf + .invoke_with_args(args) + .and_then(|r| r.to_array(number_rows))); + + let (result_array, result_schema) = rresult_return!(to_ffi(&result.to_data())); + + RResult::ROk(WrappedArray { + array: result_array, + schema: WrappedSchema(result_schema), + }) +} + +unsafe extern "C" fn release_fn_wrapper(udf: &mut FFI_ScalarUDF) { + let private_data = Box::from_raw(udf.private_data as *mut ScalarUDFPrivateData); + drop(private_data); +} + +unsafe extern "C" fn clone_fn_wrapper(udf: &FFI_ScalarUDF) -> FFI_ScalarUDF { + let private_data = udf.private_data as *const ScalarUDFPrivateData; + let udf_data = &(*private_data); + + Arc::clone(&udf_data.udf).into() +} + +impl Clone for FFI_ScalarUDF { + fn clone(&self) -> Self { + unsafe { (self.clone)(self) } + } +} + +impl From> for FFI_ScalarUDF { + fn from(udf: Arc) -> Self { + let name = udf.name().into(); + let aliases = udf.aliases().iter().map(|a| a.to_owned().into()).collect(); + let volatility = udf.signature().volatility.into(); + let short_circuits = udf.short_circuits(); + + let private_data = Box::new(ScalarUDFPrivateData { udf }); + + Self { + name, + aliases, + volatility, + short_circuits, + invoke_with_args: invoke_with_args_fn_wrapper, + return_type: return_type_fn_wrapper, + coerce_types: coerce_types_fn_wrapper, + clone: clone_fn_wrapper, + release: release_fn_wrapper, + private_data: Box::into_raw(private_data) as *mut c_void, + } + } +} + +impl Drop for FFI_ScalarUDF { + fn drop(&mut self) { + unsafe { (self.release)(self) } + } +} + +/// This struct is used to access an UDF provided by a foreign +/// library across a FFI boundary. +/// +/// The ForeignScalarUDF is to be used by the caller of the UDF, so it has +/// no knowledge or access to the private data. 
All interaction with the UDF +/// must occur through the functions defined in FFI_ScalarUDF. +#[derive(Debug)] +pub struct ForeignScalarUDF { + name: String, + aliases: Vec, + udf: FFI_ScalarUDF, + signature: Signature, +} + +unsafe impl Send for ForeignScalarUDF {} +unsafe impl Sync for ForeignScalarUDF {} + +impl TryFrom<&FFI_ScalarUDF> for ForeignScalarUDF { + type Error = DataFusionError; + + fn try_from(udf: &FFI_ScalarUDF) -> Result { + let name = udf.name.to_owned().into(); + let signature = Signature::user_defined((&udf.volatility).into()); + + let aliases = udf.aliases.iter().map(|s| s.to_string()).collect(); + + Ok(Self { + name, + udf: udf.clone(), + aliases, + signature, + }) + } +} + +impl ScalarUDFImpl for ForeignScalarUDF { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &str { + &self.name + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + let arg_types = vec_datatype_to_rvec_wrapped(arg_types)?; + + let result = unsafe { (self.udf.return_type)(&self.udf, arg_types) }; + + let result = df_result!(result); + + result.and_then(|r| (&r.0).try_into().map_err(DataFusionError::from)) + } + + fn invoke_with_args(&self, invoke_args: ScalarFunctionArgs) -> Result { + let ScalarFunctionArgs { + args, + number_rows, + return_type, + } = invoke_args; + + let args = args + .into_iter() + .map(|v| v.to_array(number_rows)) + .collect::>>()? + .into_iter() + .map(|v| { + to_ffi(&v.to_data()).map(|(ffi_array, ffi_schema)| WrappedArray { + array: ffi_array, + schema: WrappedSchema(ffi_schema), + }) + }) + .collect::, ArrowError>>()? + .into(); + + let return_type = WrappedSchema(FFI_ArrowSchema::try_from(return_type)?); + + let result = unsafe { + (self.udf.invoke_with_args)(&self.udf, args, number_rows, return_type) + }; + + let result = df_result!(result)?; + let result_array: ArrayRef = result.try_into()?; + + Ok(ColumnarValue::Array(result_array)) + } + + fn aliases(&self) -> &[String] { + &self.aliases + } + + fn short_circuits(&self) -> bool { + self.udf.short_circuits + } + + fn coerce_types(&self, arg_types: &[DataType]) -> Result> { + unsafe { + let arg_types = vec_datatype_to_rvec_wrapped(arg_types)?; + let result_types = df_result!((self.udf.coerce_types)(&self.udf, arg_types))?; + Ok(rvec_wrapped_to_vec_datatype(&result_types)?) + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_round_trip_scalar_udf() -> Result<()> { + let original_udf = datafusion::functions::math::abs::AbsFunc::new(); + let original_udf = Arc::new(ScalarUDF::from(original_udf)); + + let local_udf: FFI_ScalarUDF = Arc::clone(&original_udf).into(); + + let foreign_udf: ForeignScalarUDF = (&local_udf).try_into()?; + + assert!(original_udf.name() == foreign_udf.name()); + + Ok(()) + } +} diff --git a/datafusion/ffi/src/util.rs b/datafusion/ffi/src/util.rs new file mode 100644 index 000000000000..9d5f2aefe324 --- /dev/null +++ b/datafusion/ffi/src/util.rs @@ -0,0 +1,135 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use abi_stable::std_types::RVec; +use arrow::{datatypes::DataType, ffi::FFI_ArrowSchema}; + +use crate::arrow_wrappers::WrappedSchema; + +/// This macro is a helpful conversion utility to convert from an abi_stable::RResult to a +/// DataFusion result. +#[macro_export] +macro_rules! df_result { + ( $x:expr ) => { + match $x { + abi_stable::std_types::RResult::ROk(v) => Ok(v), + abi_stable::std_types::RResult::RErr(e) => { + Err(datafusion::error::DataFusionError::Execution(e.to_string())) + } + } + }; +} + +/// This macro is a helpful conversion utility to convert from a DataFusion Result to an abi_stable::RResult +#[macro_export] +macro_rules! rresult { + ( $x:expr ) => { + match $x { + Ok(v) => abi_stable::std_types::RResult::ROk(v), + Err(e) => abi_stable::std_types::RResult::RErr( + abi_stable::std_types::RString::from(e.to_string()), + ), + } + }; +} + +/// This macro is a helpful conversion utility to convert from a DataFusion Result to an abi_stable::RResult +/// and to also return early when it is an error. Since you cannot use `?` on an RResult, this is designed +/// to mimic that pattern. +#[macro_export] +macro_rules! rresult_return { + ( $x:expr ) => { + match $x { + Ok(v) => v, + Err(e) => { + return abi_stable::std_types::RResult::RErr( + abi_stable::std_types::RString::from(e.to_string()), + ) + } + } + }; +} + +/// This is a utility function to convert a slice of [`DataType`] to its equivalent +/// FFI-friendly counterpart, [`WrappedSchema`] +pub fn vec_datatype_to_rvec_wrapped( + data_types: &[DataType], +) -> Result, arrow::error::ArrowError> { + Ok(data_types + .iter() + .map(FFI_ArrowSchema::try_from) + .collect::, arrow::error::ArrowError>>()? + .into_iter() + .map(WrappedSchema) + .collect()) +} + +/// This is a utility function to convert an FFI-friendly vector of [`WrappedSchema`] +/// to their equivalent [`DataType`].
+pub fn rvec_wrapped_to_vec_datatype( + data_types: &RVec, +) -> Result, arrow::error::ArrowError> { + data_types + .iter() + .map(|d| DataType::try_from(&d.0)) + .collect() +} + +#[cfg(test)] +mod tests { + use abi_stable::std_types::{RResult, RString}; + use datafusion::error::DataFusionError; + + fn wrap_result(result: Result) -> RResult { + RResult::ROk(rresult_return!(result)) + } + + #[test] + fn test_conversion() { + const VALID_VALUE: &str = "valid_value"; + const ERROR_VALUE: &str = "error_value"; + + let ok_r_result: RResult = + RResult::ROk(VALID_VALUE.to_string().into()); + let err_r_result: RResult = + RResult::RErr(ERROR_VALUE.to_string().into()); + + let returned_ok_result = df_result!(ok_r_result); + assert!(returned_ok_result.is_ok()); + assert!(returned_ok_result.unwrap().to_string() == VALID_VALUE); + + let returned_err_result = df_result!(err_r_result); + assert!(returned_err_result.is_err()); + assert!( + returned_err_result.unwrap_err().to_string() + == format!("Execution error: {}", ERROR_VALUE) + ); + + let ok_result: Result = Ok(VALID_VALUE.to_string()); + let err_result: Result = + Err(DataFusionError::Execution(ERROR_VALUE.to_string())); + + let returned_ok_r_result = wrap_result(ok_result); + assert!(returned_ok_r_result == RResult::ROk(VALID_VALUE.into())); + + let returned_err_r_result = wrap_result(err_result); + assert!( + returned_err_r_result + == RResult::RErr(format!("Execution error: {}", ERROR_VALUE).into()) + ); + } +} diff --git a/datafusion/ffi/src/volatility.rs b/datafusion/ffi/src/volatility.rs new file mode 100644 index 000000000000..8b565b91b76d --- /dev/null +++ b/datafusion/ffi/src/volatility.rs @@ -0,0 +1,48 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
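Illustrative sketch (not part of this diff): the macros and helpers above are intended for the library side of the FFI boundary, where functions return RResult and `?` is unavailable. A minimal sketch assuming the items from util.rs are in scope; the function name is hypothetical.

use abi_stable::std_types::{RResult, RString, RVec};

fn ffi_coerce_types(arg_types: &RVec<WrappedSchema>) -> RResult<RVec<WrappedSchema>, RString> {
    // `rresult_return!` unwraps the Ok value or early-returns an RErr, mimicking `?`.
    let types = rresult_return!(rvec_wrapped_to_vec_datatype(arg_types));
    // `rresult!` converts the final Result into the FFI-safe RResult.
    rresult!(vec_datatype_to_rvec_wrapped(&types))
}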
+ +use abi_stable::StableAbi; +use datafusion::logical_expr::Volatility; + +#[repr(C)] +#[derive(Debug, StableAbi)] +#[allow(non_camel_case_types)] +pub enum FFI_Volatility { + Immutable, + Stable, + Volatile, +} + +impl From for FFI_Volatility { + fn from(value: Volatility) -> Self { + match value { + Volatility::Immutable => Self::Immutable, + Volatility::Stable => Self::Stable, + Volatility::Volatile => Self::Volatile, + } + } +} + +impl From<&FFI_Volatility> for Volatility { + fn from(value: &FFI_Volatility) -> Self { + match value { + FFI_Volatility::Immutable => Self::Immutable, + FFI_Volatility::Stable => Self::Stable, + FFI_Volatility::Volatile => Self::Volatile, + } + } +} diff --git a/datafusion/ffi/tests/table_provider.rs b/datafusion/ffi/tests/ffi_integration.rs similarity index 68% rename from datafusion/ffi/tests/table_provider.rs rename to datafusion/ffi/tests/ffi_integration.rs index 9169c9f4221c..84e120df4299 100644 --- a/datafusion/ffi/tests/table_provider.rs +++ b/datafusion/ffi/tests/ffi_integration.rs @@ -21,10 +21,13 @@ mod tests { use abi_stable::library::RootModule; + use datafusion::common::record_batch; use datafusion::error::{DataFusionError, Result}; - use datafusion::prelude::SessionContext; + use datafusion::logical_expr::ScalarUDF; + use datafusion::prelude::{col, SessionContext}; use datafusion_ffi::table_provider::ForeignTableProvider; - use datafusion_ffi::tests::TableProviderModuleRef; + use datafusion_ffi::tests::{create_record_batch, ForeignLibraryModuleRef}; + use datafusion_ffi::udf::ForeignScalarUDF; use std::path::Path; use std::sync::Arc; @@ -61,11 +64,7 @@ mod tests { Ok(best_path) } - /// It is important that this test is in the `tests` directory and not in the - /// library directory so we can verify we are building a dynamic library and - /// testing it via a different executable. - #[cfg(feature = "integration-tests")] - async fn test_table_provider(synchronous: bool) -> Result<()> { + fn get_module() -> Result { let expected_version = datafusion_ffi::version(); let crate_root = Path::new(env!("CARGO_MANIFEST_DIR")); @@ -80,22 +79,30 @@ mod tests { // so you will need to change the approach here based on your use case. // let target: &std::path::Path = "../../../../target/".as_ref(); let library_path = - compute_library_path::(target_dir.as_path()) + compute_library_path::(target_dir.as_path()) .map_err(|e| DataFusionError::External(Box::new(e)))? .join("deps"); // Load the module - let table_provider_module = - TableProviderModuleRef::load_from_directory(&library_path) - .map_err(|e| DataFusionError::External(Box::new(e)))?; + let module = ForeignLibraryModuleRef::load_from_directory(&library_path) + .map_err(|e| DataFusionError::External(Box::new(e)))?; assert_eq!( - table_provider_module + module .version() .expect("Unable to call version on FFI module")(), expected_version ); + Ok(module) + } + + /// It is important that this test is in the `tests` directory and not in the + /// library directory so we can verify we are building a dynamic library and + /// testing it via a different executable. + async fn test_table_provider(synchronous: bool) -> Result<()> { + let table_provider_module = get_module()?; + // By calling the code below, the table provided will be created within // the module's code. 
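Illustrative sketch (not part of this diff): the FFI_Volatility conversions above are symmetric, so a round trip preserves the variant; a tiny sanity check of that property, assuming FFI_Volatility is in scope.

use datafusion::logical_expr::Volatility;

#[test]
fn ffi_volatility_round_trip() {
    // Volatility -> FFI_Volatility -> Volatility should be the identity.
    let ffi: FFI_Volatility = Volatility::Immutable.into();
    assert!(matches!(Volatility::from(&ffi), Volatility::Immutable));
}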
let ffi_table_provider = table_provider_module.create_table().ok_or( @@ -116,9 +123,9 @@ mod tests { let results = df.collect().await?; assert_eq!(results.len(), 3); - assert_eq!(results[0], datafusion_ffi::tests::create_record_batch(1, 5)); - assert_eq!(results[1], datafusion_ffi::tests::create_record_batch(6, 1)); - assert_eq!(results[2], datafusion_ffi::tests::create_record_batch(7, 5)); + assert_eq!(results[0], create_record_batch(1, 5)); + assert_eq!(results[1], create_record_batch(6, 1)); + assert_eq!(results[2], create_record_batch(7, 5)); Ok(()) } @@ -132,4 +139,44 @@ mod tests { async fn sync_test_table_provider() -> Result<()> { test_table_provider(true).await } + + /// This test validates that we can load an external module and use a scalar + /// udf defined in it via the foreign function interface. In this case we are + /// using the abs() function as our scalar UDF. + #[tokio::test] + async fn test_scalar_udf() -> Result<()> { + let module = get_module()?; + + let ffi_abs_func = + module + .create_scalar_udf() + .ok_or(DataFusionError::NotImplemented( + "External table provider failed to implement create_scalar_udf" + .to_string(), + ))?(); + let foreign_abs_func: ForeignScalarUDF = (&ffi_abs_func).try_into()?; + + let udf: ScalarUDF = foreign_abs_func.into(); + + let ctx = SessionContext::default(); + let df = ctx.read_batch(create_record_batch(-5, 5))?; + + let df = df + .with_column("abs_a", udf.call(vec![col("a")]))? + .with_column("abs_b", udf.call(vec![col("b")]))?; + + let result = df.collect().await?; + + let expected = record_batch!( + ("a", Int32, vec![-5, -4, -3, -2, -1]), + ("b", Float64, vec![-5., -4., -3., -2., -1.]), + ("abs_a", Int32, vec![5, 4, 3, 2, 1]), + ("abs_b", Float64, vec![5., 4., 3., 2., 1.]) + )?; + + assert!(result.len() == 1); + assert!(result[0] == expected); + + Ok(()) + } } diff --git a/datafusion/functions-aggregate/Cargo.toml b/datafusion/functions-aggregate/Cargo.toml index 78e22011b61a..007e1e76a3be 100644 --- a/datafusion/functions-aggregate/Cargo.toml +++ b/datafusion/functions-aggregate/Cargo.toml @@ -40,7 +40,6 @@ path = "src/lib.rs" [dependencies] ahash = { workspace = true } arrow = { workspace = true } -arrow-schema = { workspace = true } datafusion-common = { workspace = true } datafusion-doc = { workspace = true } datafusion-execution = { workspace = true } diff --git a/datafusion/functions-aggregate/benches/array_agg.rs b/datafusion/functions-aggregate/benches/array_agg.rs index c9792d541a4f..fb605e87ed0c 100644 --- a/datafusion/functions-aggregate/benches/array_agg.rs +++ b/datafusion/functions-aggregate/benches/array_agg.rs @@ -20,9 +20,8 @@ use std::sync::Arc; use arrow::array::{ Array, ArrayRef, ArrowPrimitiveType, AsArray, ListArray, NullBufferBuilder, }; -use arrow::datatypes::Int64Type; +use arrow::datatypes::{Field, Int64Type}; use arrow::util::bench_util::create_primitive_array; -use arrow_schema::Field; use criterion::{black_box, criterion_group, criterion_main, Criterion}; use datafusion_expr::Accumulator; use datafusion_functions_aggregate::array_agg::ArrayAggAccumulator; diff --git a/datafusion/functions-aggregate/benches/count.rs b/datafusion/functions-aggregate/benches/count.rs index e6b62e6e1856..8bde7d04c44d 100644 --- a/datafusion/functions-aggregate/benches/count.rs +++ b/datafusion/functions-aggregate/benches/count.rs @@ -16,9 +16,8 @@ // under the License. 
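Illustrative sketch (not part of this diff): the scalar UDF test above builds batches with DataFusion's record_batch! macro; for readers unfamiliar with it, the two input columns are roughly equivalent to the following plain Arrow construction (helper name is hypothetical).

use std::sync::Arc;
use arrow::array::{ArrayRef, Float64Array, Int32Array};
use arrow::record_batch::RecordBatch;

fn input_batch() -> Result<RecordBatch, arrow::error::ArrowError> {
    // Same shape as record_batch!(("a", Int32, ...), ("b", Float64, ...)).
    RecordBatch::try_from_iter(vec![
        ("a", Arc::new(Int32Array::from(vec![-5, -4, -3, -2, -1])) as ArrayRef),
        ("b", Arc::new(Float64Array::from(vec![-5.0, -4.0, -3.0, -2.0, -1.0])) as ArrayRef),
    ])
}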
use arrow::array::{ArrayRef, BooleanArray}; -use arrow::datatypes::Int32Type; +use arrow::datatypes::{DataType, Field, Int32Type, Schema}; use arrow::util::bench_util::{create_boolean_array, create_primitive_array}; -use arrow_schema::{DataType, Field, Schema}; use criterion::{black_box, criterion_group, criterion_main, Criterion}; use datafusion_expr::{function::AccumulatorArgs, AggregateUDFImpl, GroupsAccumulator}; use datafusion_functions_aggregate::count::Count; diff --git a/datafusion/functions-aggregate/benches/sum.rs b/datafusion/functions-aggregate/benches/sum.rs index 1c180126a313..fab53ae94b25 100644 --- a/datafusion/functions-aggregate/benches/sum.rs +++ b/datafusion/functions-aggregate/benches/sum.rs @@ -16,9 +16,8 @@ // under the License. use arrow::array::{ArrayRef, BooleanArray}; -use arrow::datatypes::Int64Type; +use arrow::datatypes::{DataType, Field, Int64Type, Schema}; use arrow::util::bench_util::{create_boolean_array, create_primitive_array}; -use arrow_schema::{DataType, Field, Schema}; use criterion::{black_box, criterion_group, criterion_main, Criterion}; use datafusion_expr::{function::AccumulatorArgs, AggregateUDFImpl, GroupsAccumulator}; use datafusion_functions_aggregate::sum::Sum; diff --git a/datafusion/functions-aggregate/src/approx_median.rs b/datafusion/functions-aggregate/src/approx_median.rs index 5d174a752296..787e08bae286 100644 --- a/datafusion/functions-aggregate/src/approx_median.rs +++ b/datafusion/functions-aggregate/src/approx_median.rs @@ -20,8 +20,8 @@ use std::any::Any; use std::fmt::Debug; -use arrow::{datatypes::DataType, datatypes::Field}; -use arrow_schema::DataType::{Float64, UInt64}; +use arrow::datatypes::DataType::{Float64, UInt64}; +use arrow::datatypes::{DataType, Field}; use datafusion_common::{not_impl_err, plan_err, Result}; use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; diff --git a/datafusion/functions-aggregate/src/approx_percentile_cont.rs b/datafusion/functions-aggregate/src/approx_percentile_cont.rs index 000c69d9f331..1fad5f73703c 100644 --- a/datafusion/functions-aggregate/src/approx_percentile_cont.rs +++ b/datafusion/functions-aggregate/src/approx_percentile_cont.rs @@ -27,13 +27,12 @@ use arrow::{ ArrayRef, Float32Array, Float64Array, Int16Array, Int32Array, Int64Array, Int8Array, UInt16Array, UInt32Array, UInt64Array, UInt8Array, }, - datatypes::DataType, + datatypes::{DataType, Field, Schema}, }; -use arrow_schema::{Field, Schema}; use datafusion_common::{ downcast_value, internal_err, not_impl_datafusion_err, not_impl_err, plan_err, - DataFusionError, Result, ScalarValue, + Result, ScalarValue, }; use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion_expr::type_coercion::aggregates::{INTEGERS, NUMERICS}; @@ -491,7 +490,7 @@ impl Accumulator for ApproxPercentileAccumulator { #[cfg(test)] mod tests { - use arrow_schema::DataType; + use arrow::datatypes::DataType; use datafusion_functions_aggregate_common::tdigest::TDigest; diff --git a/datafusion/functions-aggregate/src/array_agg.rs b/datafusion/functions-aggregate/src/array_agg.rs index 9fff05999122..0f12ac34bfd2 100644 --- a/datafusion/functions-aggregate/src/array_agg.rs +++ b/datafusion/functions-aggregate/src/array_agg.rs @@ -18,9 +18,8 @@ //! 
`ARRAY_AGG` aggregate implementation: [`ArrayAgg`] use arrow::array::{new_empty_array, Array, ArrayRef, AsArray, ListArray, StructArray}; -use arrow::datatypes::DataType; +use arrow::datatypes::{DataType, Field, Fields}; -use arrow_schema::{Field, Fields}; use datafusion_common::cast::as_list_array; use datafusion_common::utils::{get_row_at_idx, SingleRowListArrayBuilder}; use datafusion_common::{exec_err, ScalarValue}; @@ -604,7 +603,7 @@ mod tests { use std::sync::Arc; use arrow::array::Int64Array; - use arrow_schema::SortOptions; + use arrow::compute::SortOptions; use datafusion_common::utils::get_row_at_idx; use datafusion_common::{Result, ScalarValue}; diff --git a/datafusion/functions-aggregate/src/average.rs b/datafusion/functions-aggregate/src/average.rs index 18874f831e9d..141771b0412f 100644 --- a/datafusion/functions-aggregate/src/average.rs +++ b/datafusion/functions-aggregate/src/average.rs @@ -27,7 +27,9 @@ use arrow::datatypes::{ i256, ArrowNativeType, DataType, Decimal128Type, Decimal256Type, DecimalType, Field, Float64Type, UInt64Type, }; -use datafusion_common::{exec_err, not_impl_err, Result, ScalarValue}; +use datafusion_common::{ + exec_err, not_impl_err, utils::take_function_args, Result, ScalarValue, +}; use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion_expr::type_coercion::aggregates::{avg_return_type, coerce_avg_type}; use datafusion_expr::utils::format_state_name; @@ -247,10 +249,8 @@ impl AggregateUDFImpl for Avg { } fn coerce_types(&self, arg_types: &[DataType]) -> Result> { - if arg_types.len() != 1 { - return exec_err!("{} expects exactly one argument.", self.name()); - } - coerce_avg_type(self.name(), arg_types) + let [args] = take_function_args(self.name(), arg_types)?; + coerce_avg_type(self.name(), std::slice::from_ref(args)) } fn documentation(&self) -> Option<&Documentation> { diff --git a/datafusion/functions-aggregate/src/bit_and_or_xor.rs b/datafusion/functions-aggregate/src/bit_and_or_xor.rs index 6298071a223b..6319a9c07dd2 100644 --- a/datafusion/functions-aggregate/src/bit_and_or_xor.rs +++ b/datafusion/functions-aggregate/src/bit_and_or_xor.rs @@ -25,10 +25,9 @@ use std::mem::{size_of, size_of_val}; use ahash::RandomState; use arrow::array::{downcast_integer, Array, ArrayRef, AsArray}; use arrow::datatypes::{ - ArrowNativeType, ArrowNumericType, DataType, Int16Type, Int32Type, Int64Type, + ArrowNativeType, ArrowNumericType, DataType, Field, Int16Type, Int32Type, Int64Type, Int8Type, UInt16Type, UInt32Type, UInt64Type, UInt8Type, }; -use arrow_schema::Field; use datafusion_common::cast::as_list_array; use datafusion_common::{exec_err, not_impl_err, Result, ScalarValue}; diff --git a/datafusion/functions-aggregate/src/bool_and_or.rs b/datafusion/functions-aggregate/src/bool_and_or.rs index 29dfc68e0576..1b33a7900c00 100644 --- a/datafusion/functions-aggregate/src/bool_and_or.rs +++ b/datafusion/functions-aggregate/src/bool_and_or.rs @@ -29,7 +29,7 @@ use arrow::datatypes::Field; use datafusion_common::internal_err; use datafusion_common::{downcast_value, not_impl_err}; -use datafusion_common::{DataFusionError, Result, ScalarValue}; +use datafusion_common::{Result, ScalarValue}; use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion_expr::utils::{format_state_name, AggregateOrderSensitivity}; use datafusion_expr::{ diff --git a/datafusion/functions-aggregate/src/count.rs b/datafusion/functions-aggregate/src/count.rs index fa04e1aca2c9..cb59042ef468 100644 --- 
a/datafusion/functions-aggregate/src/count.rs +++ b/datafusion/functions-aggregate/src/count.rs @@ -44,7 +44,7 @@ use arrow::{ buffer::BooleanBuffer, }; use datafusion_common::{ - downcast_value, internal_err, not_impl_err, DataFusionError, Result, ScalarValue, + downcast_value, internal_err, not_impl_err, Result, ScalarValue, }; use datafusion_expr::function::StateFieldsArgs; use datafusion_expr::{ diff --git a/datafusion/functions-aggregate/src/min_max.rs b/datafusion/functions-aggregate/src/min_max.rs index da5ec739ad8d..90fb46883de6 100644 --- a/datafusion/functions-aggregate/src/min_max.rs +++ b/datafusion/functions-aggregate/src/min_max.rs @@ -33,10 +33,9 @@ use arrow::array::{ use arrow::compute; use arrow::datatypes::{ DataType, Decimal128Type, Decimal256Type, Float16Type, Float32Type, Float64Type, - Int16Type, Int32Type, Int64Type, Int8Type, UInt16Type, UInt32Type, UInt64Type, - UInt8Type, + Int16Type, Int32Type, Int64Type, Int8Type, IntervalUnit, UInt16Type, UInt32Type, + UInt64Type, UInt8Type, }; -use arrow_schema::IntervalUnit; use datafusion_common::stats::Precision; use datafusion_common::{ downcast_value, exec_err, internal_err, ColumnStatistics, DataFusionError, Result, diff --git a/datafusion/functions-aggregate/src/min_max/min_max_bytes.rs b/datafusion/functions-aggregate/src/min_max/min_max_bytes.rs index 725b7a29bd47..05321c2ff52d 100644 --- a/datafusion/functions-aggregate/src/min_max/min_max_bytes.rs +++ b/datafusion/functions-aggregate/src/min_max/min_max_bytes.rs @@ -19,7 +19,7 @@ use arrow::array::{ Array, ArrayRef, AsArray, BinaryBuilder, BinaryViewBuilder, BooleanArray, LargeBinaryBuilder, LargeStringBuilder, StringBuilder, StringViewBuilder, }; -use arrow_schema::DataType; +use arrow::datatypes::DataType; use datafusion_common::{internal_err, Result}; use datafusion_expr::{EmitTo, GroupsAccumulator}; use datafusion_functions_aggregate_common::aggregate::groups_accumulator::nulls::apply_filter_as_nulls; diff --git a/datafusion/functions-aggregate/src/nth_value.rs b/datafusion/functions-aggregate/src/nth_value.rs index 8252fd6baaa3..d84bd02a6baf 100644 --- a/datafusion/functions-aggregate/src/nth_value.rs +++ b/datafusion/functions-aggregate/src/nth_value.rs @@ -24,7 +24,7 @@ use std::mem::{size_of, size_of_val}; use std::sync::Arc; use arrow::array::{new_empty_array, ArrayRef, AsArray, StructArray}; -use arrow_schema::{DataType, Field, Fields}; +use arrow::datatypes::{DataType, Field, Fields}; use datafusion_common::utils::{get_row_at_idx, SingleRowListArrayBuilder}; use datafusion_common::{exec_err, internal_err, not_impl_err, Result, ScalarValue}; diff --git a/datafusion/functions-aggregate/src/string_agg.rs b/datafusion/functions-aggregate/src/string_agg.rs index 0cd403cff428..64314ef6df68 100644 --- a/datafusion/functions-aggregate/src/string_agg.rs +++ b/datafusion/functions-aggregate/src/string_agg.rs @@ -18,7 +18,7 @@ //! 
[`StringAgg`] accumulator for the `string_agg` function use arrow::array::ArrayRef; -use arrow_schema::DataType; +use arrow::datatypes::DataType; use datafusion_common::cast::as_generic_string_array; use datafusion_common::Result; use datafusion_common::{not_impl_err, ScalarValue}; diff --git a/datafusion/functions-aggregate/src/sum.rs b/datafusion/functions-aggregate/src/sum.rs index 9615ca33a5f3..76a1315c2d88 100644 --- a/datafusion/functions-aggregate/src/sum.rs +++ b/datafusion/functions-aggregate/src/sum.rs @@ -33,7 +33,9 @@ use arrow::datatypes::{ DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION, }; use arrow::{array::ArrayRef, datatypes::Field}; -use datafusion_common::{exec_err, not_impl_err, Result, ScalarValue}; +use datafusion_common::{ + exec_err, not_impl_err, utils::take_function_args, Result, ScalarValue, +}; use datafusion_expr::function::AccumulatorArgs; use datafusion_expr::function::StateFieldsArgs; use datafusion_expr::utils::format_state_name; @@ -125,9 +127,7 @@ impl AggregateUDFImpl for Sum { } fn coerce_types(&self, arg_types: &[DataType]) -> Result> { - if arg_types.len() != 1 { - return exec_err!("SUM expects exactly one argument"); - } + let [args] = take_function_args(self.name(), arg_types)?; // Refer to https://www.postgresql.org/docs/8.2/functions-aggregate.html doc // smallint, int, bigint, real, double precision, decimal, or interval. @@ -147,7 +147,7 @@ impl AggregateUDFImpl for Sum { } } - Ok(vec![coerced_type(&arg_types[0])?]) + Ok(vec![coerced_type(args)?]) } fn return_type(&self, arg_types: &[DataType]) -> Result { diff --git a/datafusion/functions-aggregate/src/variance.rs b/datafusion/functions-aggregate/src/variance.rs index 8aa7a40ce320..53e3e0cc56cd 100644 --- a/datafusion/functions-aggregate/src/variance.rs +++ b/datafusion/functions-aggregate/src/variance.rs @@ -27,9 +27,7 @@ use arrow::{ use std::mem::{size_of, size_of_val}; use std::{fmt::Debug, sync::Arc}; -use datafusion_common::{ - downcast_value, not_impl_err, plan_err, DataFusionError, Result, ScalarValue, -}; +use datafusion_common::{downcast_value, not_impl_err, plan_err, Result, ScalarValue}; use datafusion_expr::{ function::{AccumulatorArgs, StateFieldsArgs}, utils::format_state_name, diff --git a/datafusion/functions-nested/Cargo.toml b/datafusion/functions-nested/Cargo.toml index 7835985b297f..a63175b36e21 100644 --- a/datafusion/functions-nested/Cargo.toml +++ b/datafusion/functions-nested/Cargo.toml @@ -42,7 +42,6 @@ path = "src/lib.rs" [dependencies] arrow = { workspace = true } arrow-ord = { workspace = true } -arrow-schema = { workspace = true } datafusion-common = { workspace = true } datafusion-doc = { workspace = true } datafusion-execution = { workspace = true } diff --git a/datafusion/functions-nested/benches/map.rs b/datafusion/functions-nested/benches/map.rs index e60f7f388ac1..3726cac0752e 100644 --- a/datafusion/functions-nested/benches/map.rs +++ b/datafusion/functions-nested/benches/map.rs @@ -19,7 +19,7 @@ extern crate criterion; use arrow::array::{Int32Array, ListArray, StringArray}; use arrow::buffer::{OffsetBuffer, ScalarBuffer}; -use arrow_schema::{DataType, Field}; +use arrow::datatypes::{DataType, Field}; use criterion::{black_box, criterion_group, criterion_main, Criterion}; use rand::prelude::ThreadRng; use rand::Rng; diff --git a/datafusion/functions-nested/src/array_has.rs b/datafusion/functions-nested/src/array_has.rs index 5c694600b822..5a29cf962817 100644 --- a/datafusion/functions-nested/src/array_has.rs +++ 
b/datafusion/functions-nested/src/array_has.rs @@ -25,6 +25,7 @@ use arrow::datatypes::DataType; use arrow::row::{RowConverter, Rows, SortField}; use datafusion_common::cast::as_generic_list_array; use datafusion_common::utils::string_utils::string_array_to_vec; +use datafusion_common::utils::take_function_args; use datafusion_common::{exec_err, Result, ScalarValue}; use datafusion_expr::{ ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, @@ -120,15 +121,15 @@ impl ScalarUDFImpl for ArrayHas { Ok(DataType::Boolean) } - fn invoke_batch( + fn invoke_with_args( &self, - args: &[ColumnarValue], - _number_rows: usize, + args: datafusion_expr::ScalarFunctionArgs, ) -> Result { - match &args[1] { + let [first_arg, second_arg] = take_function_args(self.name(), &args.args)?; + match &second_arg { ColumnarValue::Array(array_needle) => { // the needle is already an array, convert the haystack to an array of the same length - let haystack = args[0].to_array(array_needle.len())?; + let haystack = first_arg.to_array(array_needle.len())?; let array = array_has_inner_for_array(&haystack, array_needle)?; Ok(ColumnarValue::Array(array)) } @@ -140,11 +141,11 @@ impl ScalarUDFImpl for ArrayHas { } // since the needle is a scalar, convert it to an array of size 1 - let haystack = args[0].to_array(1)?; + let haystack = first_arg.to_array(1)?; let needle = scalar_needle.to_array_of_size(1)?; let needle = Scalar::new(needle); let array = array_has_inner_for_scalar(&haystack, &needle)?; - if let ColumnarValue::Scalar(_) = &args[0] { + if let ColumnarValue::Scalar(_) = &first_arg { // If both inputs are scalar, keeps output as scalar let scalar_value = ScalarValue::try_from_array(&array, 0)?; Ok(ColumnarValue::Scalar(scalar_value)) @@ -332,12 +333,11 @@ impl ScalarUDFImpl for ArrayHasAll { Ok(DataType::Boolean) } - fn invoke_batch( + fn invoke_with_args( &self, - args: &[ColumnarValue], - _number_rows: usize, + args: datafusion_expr::ScalarFunctionArgs, ) -> Result { - make_scalar_function(array_has_all_inner)(args) + make_scalar_function(array_has_all_inner)(&args.args) } fn aliases(&self) -> &[String] { @@ -407,12 +407,11 @@ impl ScalarUDFImpl for ArrayHasAny { Ok(DataType::Boolean) } - fn invoke_batch( + fn invoke_with_args( &self, - args: &[ColumnarValue], - _number_rows: usize, + args: datafusion_expr::ScalarFunctionArgs, ) -> Result { - make_scalar_function(array_has_any_inner)(args) + make_scalar_function(array_has_any_inner)(&args.args) } fn aliases(&self) -> &[String] { diff --git a/datafusion/functions-nested/src/cardinality.rs b/datafusion/functions-nested/src/cardinality.rs index 21ab9fb35982..f2f23841586c 100644 --- a/datafusion/functions-nested/src/cardinality.rs +++ b/datafusion/functions-nested/src/cardinality.rs @@ -21,14 +21,17 @@ use crate::utils::make_scalar_function; use arrow::array::{ Array, ArrayRef, GenericListArray, MapArray, OffsetSizeTrait, UInt64Array, }; -use arrow_schema::DataType; -use arrow_schema::DataType::{FixedSizeList, LargeList, List, Map, UInt64}; +use arrow::datatypes::{ + DataType, + DataType::{FixedSizeList, LargeList, List, Map, UInt64}, +}; use datafusion_common::cast::{as_large_list_array, as_list_array, as_map_array}; +use datafusion_common::utils::take_function_args; use datafusion_common::Result; use datafusion_common::{exec_err, plan_err}; use datafusion_expr::{ - ArrayFunctionSignature, ColumnarValue, Documentation, ScalarUDFImpl, Signature, - TypeSignature, Volatility, + ArrayFunctionArgument, ArrayFunctionSignature, ColumnarValue, Documentation, 
+ ScalarUDFImpl, Signature, TypeSignature, Volatility, }; use datafusion_macros::user_doc; use std::any::Any; @@ -47,7 +50,10 @@ impl Cardinality { Self { signature: Signature::one_of( vec![ - TypeSignature::ArraySignature(ArrayFunctionSignature::Array), + TypeSignature::ArraySignature(ArrayFunctionSignature::Array { + arguments: vec![ArrayFunctionArgument::Array], + array_coercion: None, + }), TypeSignature::ArraySignature(ArrayFunctionSignature::MapArray), ], Volatility::Immutable, @@ -106,12 +112,11 @@ impl ScalarUDFImpl for Cardinality { }) } - fn invoke_batch( + fn invoke_with_args( &self, - args: &[ColumnarValue], - _number_rows: usize, + args: datafusion_expr::ScalarFunctionArgs, ) -> Result { - make_scalar_function(cardinality_inner)(args) + make_scalar_function(cardinality_inner)(&args.args) } fn aliases(&self) -> &[String] { @@ -125,21 +130,18 @@ impl ScalarUDFImpl for Cardinality { /// Cardinality SQL function pub fn cardinality_inner(args: &[ArrayRef]) -> Result { - if args.len() != 1 { - return exec_err!("cardinality expects one argument"); - } - - match &args[0].data_type() { + let [array] = take_function_args("cardinality", args)?; + match &array.data_type() { List(_) => { - let list_array = as_list_array(&args[0])?; + let list_array = as_list_array(&array)?; generic_list_cardinality::(list_array) } LargeList(_) => { - let list_array = as_large_list_array(&args[0])?; + let list_array = as_large_list_array(&array)?; generic_list_cardinality::(list_array) } Map(_, _) => { - let map_array = as_map_array(&args[0])?; + let map_array = as_map_array(&array)?; generic_map_cardinality(map_array) } other => { diff --git a/datafusion/functions-nested/src/concat.rs b/datafusion/functions-nested/src/concat.rs index 723dab9a76b7..f4b9208e5c83 100644 --- a/datafusion/functions-nested/src/concat.rs +++ b/datafusion/functions-nested/src/concat.rs @@ -25,13 +25,17 @@ use arrow::array::{ OffsetSizeTrait, }; use arrow::buffer::OffsetBuffer; -use arrow_schema::{DataType, Field}; +use arrow::datatypes::{DataType, Field}; +use datafusion_common::utils::ListCoercion; use datafusion_common::Result; use datafusion_common::{ - cast::as_generic_list_array, exec_err, not_impl_err, plan_err, utils::list_ndims, + cast::as_generic_list_array, + exec_err, not_impl_err, plan_err, + utils::{list_ndims, take_function_args}, }; use datafusion_expr::{ - ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, + ArrayFunctionArgument, ArrayFunctionSignature, ColumnarValue, Documentation, + ScalarUDFImpl, Signature, TypeSignature, Volatility, }; use datafusion_macros::user_doc; @@ -105,12 +109,11 @@ impl ScalarUDFImpl for ArrayAppend { Ok(arg_types[0].clone()) } - fn invoke_batch( + fn invoke_with_args( &self, - args: &[ColumnarValue], - _number_rows: usize, + args: datafusion_expr::ScalarFunctionArgs, ) -> Result { - make_scalar_function(array_append_inner)(args) + make_scalar_function(array_append_inner)(&args.args) } fn aliases(&self) -> &[String] { @@ -163,7 +166,18 @@ impl Default for ArrayPrepend { impl ArrayPrepend { pub fn new() -> Self { Self { - signature: Signature::element_and_array(Volatility::Immutable), + signature: Signature { + type_signature: TypeSignature::ArraySignature( + ArrayFunctionSignature::Array { + arguments: vec![ + ArrayFunctionArgument::Element, + ArrayFunctionArgument::Array, + ], + array_coercion: Some(ListCoercion::FixedSizedListToList), + }, + ), + volatility: Volatility::Immutable, + }, aliases: vec![ String::from("list_prepend"), String::from("array_push_front"), @@ 
-190,12 +204,11 @@ impl ScalarUDFImpl for ArrayPrepend { Ok(arg_types[1].clone()) } - fn invoke_batch( + fn invoke_with_args( &self, - args: &[ColumnarValue], - _number_rows: usize, + args: datafusion_expr::ScalarFunctionArgs, ) -> Result { - make_scalar_function(array_prepend_inner)(args) + make_scalar_function(array_prepend_inner)(&args.args) } fn aliases(&self) -> &[String] { @@ -309,12 +322,11 @@ impl ScalarUDFImpl for ArrayConcat { Ok(expr_type) } - fn invoke_batch( + fn invoke_with_args( &self, - args: &[ColumnarValue], - _number_rows: usize, + args: datafusion_expr::ScalarFunctionArgs, ) -> Result { - make_scalar_function(array_concat_inner)(args) + make_scalar_function(array_concat_inner)(&args.args) } fn aliases(&self) -> &[String] { @@ -415,11 +427,9 @@ fn concat_internal(args: &[ArrayRef]) -> Result { /// Array_append SQL function pub(crate) fn array_append_inner(args: &[ArrayRef]) -> Result { - if args.len() != 2 { - return exec_err!("array_append expects two arguments"); - } + let [array, _] = take_function_args("array_append", args)?; - match args[0].data_type() { + match array.data_type() { DataType::LargeList(_) => general_append_and_prepend::(args, true), _ => general_append_and_prepend::(args, true), } @@ -427,11 +437,9 @@ pub(crate) fn array_append_inner(args: &[ArrayRef]) -> Result { /// Array_prepend SQL function pub(crate) fn array_prepend_inner(args: &[ArrayRef]) -> Result { - if args.len() != 2 { - return exec_err!("array_prepend expects two arguments"); - } + let [_, array] = take_function_args("array_prepend", args)?; - match args[1].data_type() { + match array.data_type() { DataType::LargeList(_) => general_append_and_prepend::(args, false), _ => general_append_and_prepend::(args, false), } @@ -457,8 +465,8 @@ where }; let res = match list_array.value_type() { - DataType::List(_) => concat_internal::(args)?, - DataType::LargeList(_) => concat_internal::(args)?, + DataType::List(_) => concat_internal::(args)?, + DataType::LargeList(_) => concat_internal::(args)?, data_type => { return generic_append_and_prepend::( list_array, diff --git a/datafusion/functions-nested/src/dimension.rs b/datafusion/functions-nested/src/dimension.rs index 702d0fc3a77d..a7d033641413 100644 --- a/datafusion/functions-nested/src/dimension.rs +++ b/datafusion/functions-nested/src/dimension.rs @@ -20,15 +20,17 @@ use arrow::array::{ Array, ArrayRef, GenericListArray, ListArray, OffsetSizeTrait, UInt64Array, }; -use arrow::datatypes::{DataType, UInt64Type}; +use arrow::datatypes::{ + DataType, + DataType::{FixedSizeList, LargeList, List, UInt64}, + Field, UInt64Type, +}; use std::any::Any; use datafusion_common::cast::{as_large_list_array, as_list_array}; -use datafusion_common::{exec_err, plan_err, Result}; +use datafusion_common::{exec_err, plan_err, utils::take_function_args, Result}; use crate::utils::{compute_array_dims, make_scalar_function}; -use arrow_schema::DataType::{FixedSizeList, LargeList, List, UInt64}; -use arrow_schema::Field; use datafusion_expr::{ ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, }; @@ -104,12 +106,11 @@ impl ScalarUDFImpl for ArrayDims { }) } - fn invoke_batch( + fn invoke_with_args( &self, - args: &[ColumnarValue], - _number_rows: usize, + args: datafusion_expr::ScalarFunctionArgs, ) -> Result { - make_scalar_function(array_dims_inner)(args) + make_scalar_function(array_dims_inner)(&args.args) } fn aliases(&self) -> &[String] { @@ -182,12 +183,11 @@ impl ScalarUDFImpl for ArrayNdims { }) } - fn invoke_batch( + fn invoke_with_args( 
&self, - args: &[ColumnarValue], - _number_rows: usize, + args: datafusion_expr::ScalarFunctionArgs, ) -> Result { - make_scalar_function(array_ndims_inner)(args) + make_scalar_function(array_ndims_inner)(&args.args) } fn aliases(&self) -> &[String] { @@ -201,20 +201,18 @@ impl ScalarUDFImpl for ArrayNdims { /// Array_dims SQL function pub fn array_dims_inner(args: &[ArrayRef]) -> Result { - if args.len() != 1 { - return exec_err!("array_dims needs one argument"); - } + let [array] = take_function_args("array_dims", args)?; - let data = match args[0].data_type() { + let data = match array.data_type() { List(_) => { - let array = as_list_array(&args[0])?; + let array = as_list_array(&array)?; array .iter() .map(compute_array_dims) .collect::>>()? } LargeList(_) => { - let array = as_large_list_array(&args[0])?; + let array = as_large_list_array(&array)?; array .iter() .map(compute_array_dims) @@ -232,9 +230,7 @@ pub fn array_dims_inner(args: &[ArrayRef]) -> Result { /// Array_ndims SQL function pub fn array_ndims_inner(args: &[ArrayRef]) -> Result { - if args.len() != 1 { - return exec_err!("array_ndims needs one argument"); - } + let [array_dim] = take_function_args("array_ndims", args)?; fn general_list_ndims( array: &GenericListArray, @@ -252,13 +248,13 @@ pub fn array_ndims_inner(args: &[ArrayRef]) -> Result { Ok(Arc::new(UInt64Array::from(data)) as ArrayRef) } - match args[0].data_type() { + match array_dim.data_type() { List(_) => { - let array = as_list_array(&args[0])?; + let array = as_list_array(&array_dim)?; general_list_ndims::(array) } LargeList(_) => { - let array = as_large_list_array(&args[0])?; + let array = as_large_list_array(&array_dim)?; general_list_ndims::(array) } array_type => exec_err!("array_ndims does not support type {array_type:?}"), diff --git a/datafusion/functions-nested/src/distance.rs b/datafusion/functions-nested/src/distance.rs index 6a5d6b4fa0ff..cfc7fccdd70c 100644 --- a/datafusion/functions-nested/src/distance.rs +++ b/datafusion/functions-nested/src/distance.rs @@ -21,14 +21,18 @@ use crate::utils::make_scalar_function; use arrow::array::{ Array, ArrayRef, Float64Array, LargeListArray, ListArray, OffsetSizeTrait, }; -use arrow_schema::DataType; -use arrow_schema::DataType::{FixedSizeList, Float64, LargeList, List}; +use arrow::datatypes::{ + DataType, + DataType::{FixedSizeList, Float64, LargeList, List}, +}; use datafusion_common::cast::{ as_float32_array, as_float64_array, as_generic_list_array, as_int32_array, as_int64_array, }; use datafusion_common::utils::coerced_fixed_size_list_to_list; -use datafusion_common::{exec_err, internal_datafusion_err, Result}; +use datafusion_common::{ + exec_err, internal_datafusion_err, utils::take_function_args, Result, +}; use datafusion_expr::{ ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, }; @@ -108,9 +112,7 @@ impl ScalarUDFImpl for ArrayDistance { } fn coerce_types(&self, arg_types: &[DataType]) -> Result> { - if arg_types.len() != 2 { - return exec_err!("array_distance expects exactly two arguments"); - } + let [_, _] = take_function_args(self.name(), arg_types)?; let mut result = Vec::new(); for arg_type in arg_types { match arg_type { @@ -122,12 +124,11 @@ impl ScalarUDFImpl for ArrayDistance { Ok(result) } - fn invoke_batch( + fn invoke_with_args( &self, - args: &[ColumnarValue], - _number_rows: usize, + args: datafusion_expr::ScalarFunctionArgs, ) -> Result { - make_scalar_function(array_distance_inner)(args) + make_scalar_function(array_distance_inner)(&args.args) } fn 
aliases(&self) -> &[String] { @@ -140,11 +141,9 @@ impl ScalarUDFImpl for ArrayDistance { } pub fn array_distance_inner(args: &[ArrayRef]) -> Result { - if args.len() != 2 { - return exec_err!("array_distance expects exactly two arguments"); - } + let [array1, array2] = take_function_args("array_distance", args)?; - match (&args[0].data_type(), &args[1].data_type()) { + match (&array1.data_type(), &array2.data_type()) { (List(_), List(_)) => general_array_distance::(args), (LargeList(_), LargeList(_)) => general_array_distance::(args), (array_type1, array_type2) => { diff --git a/datafusion/functions-nested/src/empty.rs b/datafusion/functions-nested/src/empty.rs index b5e2df6f8952..dcefd583e937 100644 --- a/datafusion/functions-nested/src/empty.rs +++ b/datafusion/functions-nested/src/empty.rs @@ -19,10 +19,12 @@ use crate::utils::make_scalar_function; use arrow::array::{ArrayRef, BooleanArray, OffsetSizeTrait}; -use arrow_schema::DataType; -use arrow_schema::DataType::{Boolean, FixedSizeList, LargeList, List}; +use arrow::datatypes::{ + DataType, + DataType::{Boolean, FixedSizeList, LargeList, List}, +}; use datafusion_common::cast::as_generic_list_array; -use datafusion_common::{exec_err, plan_err, Result}; +use datafusion_common::{exec_err, plan_err, utils::take_function_args, Result}; use datafusion_expr::{ ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, }; @@ -96,12 +98,11 @@ impl ScalarUDFImpl for ArrayEmpty { }) } - fn invoke_batch( + fn invoke_with_args( &self, - args: &[ColumnarValue], - _number_rows: usize, + args: datafusion_expr::ScalarFunctionArgs, ) -> Result { - make_scalar_function(array_empty_inner)(args) + make_scalar_function(array_empty_inner)(&args.args) } fn aliases(&self) -> &[String] { @@ -115,14 +116,12 @@ impl ScalarUDFImpl for ArrayEmpty { /// Array_empty SQL function pub fn array_empty_inner(args: &[ArrayRef]) -> Result { - if args.len() != 1 { - return exec_err!("array_empty expects one argument"); - } + let [array] = take_function_args("array_empty", args)?; - let array_type = args[0].data_type(); + let array_type = array.data_type(); match array_type { - List(_) => general_array_empty::(&args[0]), - LargeList(_) => general_array_empty::(&args[0]), + List(_) => general_array_empty::(array), + LargeList(_) => general_array_empty::(array), _ => exec_err!("array_empty does not support type '{array_type:?}'."), } } diff --git a/datafusion/functions-nested/src/except.rs b/datafusion/functions-nested/src/except.rs index 79e2c0f23ce3..2385f6d12d43 100644 --- a/datafusion/functions-nested/src/except.rs +++ b/datafusion/functions-nested/src/except.rs @@ -20,9 +20,10 @@ use crate::utils::{check_datatypes, make_scalar_function}; use arrow::array::{cast::AsArray, Array, ArrayRef, GenericListArray, OffsetSizeTrait}; use arrow::buffer::OffsetBuffer; +use arrow::datatypes::{DataType, FieldRef}; use arrow::row::{RowConverter, SortField}; -use arrow_schema::{DataType, FieldRef}; -use datafusion_common::{exec_err, internal_err, HashSet, Result}; +use datafusion_common::utils::take_function_args; +use datafusion_common::{internal_err, HashSet, Result}; use datafusion_expr::{ ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, }; @@ -105,12 +106,11 @@ impl ScalarUDFImpl for ArrayExcept { } } - fn invoke_batch( + fn invoke_with_args( &self, - args: &[ColumnarValue], - _number_rows: usize, + args: datafusion_expr::ScalarFunctionArgs, ) -> Result { - make_scalar_function(array_except_inner)(args) + 
make_scalar_function(array_except_inner)(&args.args) } fn aliases(&self) -> &[String] { @@ -124,12 +124,7 @@ impl ScalarUDFImpl for ArrayExcept { /// Array_except SQL function pub fn array_except_inner(args: &[ArrayRef]) -> Result { - if args.len() != 2 { - return exec_err!("array_except needs two arguments"); - } - - let array1 = &args[0]; - let array2 = &args[1]; + let [array1, array2] = take_function_args("array_except", args)?; match (array1.data_type(), array2.data_type()) { (DataType::Null, _) | (_, DataType::Null) => Ok(array1.to_owned()), diff --git a/datafusion/functions-nested/src/extract.rs b/datafusion/functions-nested/src/extract.rs index 2f59dccad94a..422b1b612850 100644 --- a/datafusion/functions-nested/src/extract.rs +++ b/datafusion/functions-nested/src/extract.rs @@ -23,21 +23,26 @@ use arrow::array::{ }; use arrow::buffer::OffsetBuffer; use arrow::datatypes::DataType; -use arrow_schema::DataType::{FixedSizeList, LargeList, List}; -use arrow_schema::Field; +use arrow::datatypes::{ + DataType::{FixedSizeList, LargeList, List}, + Field, +}; use datafusion_common::cast::as_int64_array; use datafusion_common::cast::as_large_list_array; use datafusion_common::cast::as_list_array; +use datafusion_common::utils::ListCoercion; use datafusion_common::{ - exec_err, internal_datafusion_err, plan_err, DataFusionError, Result, + exec_err, internal_datafusion_err, plan_err, utils::take_function_args, + DataFusionError, Result, +}; +use datafusion_expr::{ + ArrayFunctionArgument, ArrayFunctionSignature, Expr, TypeSignature, }; -use datafusion_expr::{ArrayFunctionSignature, Expr, TypeSignature}; use datafusion_expr::{ ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, }; use datafusion_macros::user_doc; use std::any::Any; -use std::num::NonZeroUsize; use std::sync::Arc; use crate::utils::make_scalar_function; @@ -167,12 +172,11 @@ impl ScalarUDFImpl for ArrayElement { } } - fn invoke_batch( + fn invoke_with_args( &self, - args: &[ColumnarValue], - _number_rows: usize, + args: datafusion_expr::ScalarFunctionArgs, ) -> Result { - make_scalar_function(array_element_inner)(args) + make_scalar_function(array_element_inner)(&args.args) } fn aliases(&self) -> &[String] { @@ -192,24 +196,22 @@ impl ScalarUDFImpl for ArrayElement { /// For example: /// > array_element(\[1, 2, 3], 2) -> 2 fn array_element_inner(args: &[ArrayRef]) -> Result { - if args.len() != 2 { - return exec_err!("array_element needs two arguments"); - } + let [array, indexes] = take_function_args("array_element", args)?; - match &args[0].data_type() { + match &array.data_type() { List(_) => { - let array = as_list_array(&args[0])?; - let indexes = as_int64_array(&args[1])?; + let array = as_list_array(&array)?; + let indexes = as_int64_array(&indexes)?; general_array_element::(array, indexes) } LargeList(_) => { - let array = as_large_list_array(&args[0])?; - let indexes = as_int64_array(&args[1])?; + let array = as_large_list_array(&array)?; + let indexes = as_int64_array(&indexes)?; general_array_element::(array, indexes) } _ => exec_err!( "array_element does not support type: {:?}", - args[0].data_type() + array.data_type() ), } } @@ -329,16 +331,23 @@ impl ArraySlice { Self { signature: Signature::one_of( vec![ - TypeSignature::ArraySignature( - ArrayFunctionSignature::ArrayAndIndexes( - NonZeroUsize::new(2).expect("2 is non-zero"), - ), - ), - TypeSignature::ArraySignature( - ArrayFunctionSignature::ArrayAndIndexes( - NonZeroUsize::new(3).expect("3 is non-zero"), - ), - ), + 
TypeSignature::ArraySignature(ArrayFunctionSignature::Array { + arguments: vec![ + ArrayFunctionArgument::Array, + ArrayFunctionArgument::Index, + ArrayFunctionArgument::Index, + ], + array_coercion: None, + }), + TypeSignature::ArraySignature(ArrayFunctionSignature::Array { + arguments: vec![ + ArrayFunctionArgument::Array, + ArrayFunctionArgument::Index, + ArrayFunctionArgument::Index, + ArrayFunctionArgument::Index, + ], + array_coercion: None, + }), ], Volatility::Immutable, ), @@ -385,12 +394,11 @@ impl ScalarUDFImpl for ArraySlice { Ok(arg_types[0].clone()) } - fn invoke_batch( + fn invoke_with_args( &self, - args: &[ColumnarValue], - _number_rows: usize, + args: datafusion_expr::ScalarFunctionArgs, ) -> Result { - make_scalar_function(array_slice_inner)(args) + make_scalar_function(array_slice_inner)(&args.args) } fn aliases(&self) -> &[String] { @@ -664,7 +672,15 @@ pub(super) struct ArrayPopFront { impl ArrayPopFront { pub fn new() -> Self { Self { - signature: Signature::array(Volatility::Immutable), + signature: Signature { + type_signature: TypeSignature::ArraySignature( + ArrayFunctionSignature::Array { + arguments: vec![ArrayFunctionArgument::Array], + array_coercion: Some(ListCoercion::FixedSizedListToList), + }, + ), + volatility: Volatility::Immutable, + }, aliases: vec![String::from("list_pop_front")], } } @@ -686,12 +702,11 @@ impl ScalarUDFImpl for ArrayPopFront { Ok(arg_types[0].clone()) } - fn invoke_batch( + fn invoke_with_args( &self, - args: &[ColumnarValue], - _number_rows: usize, + args: datafusion_expr::ScalarFunctionArgs, ) -> Result { - make_scalar_function(array_pop_front_inner)(args) + make_scalar_function(array_pop_front_inner)(&args.args) } fn aliases(&self) -> &[String] { @@ -764,7 +779,15 @@ pub(super) struct ArrayPopBack { impl ArrayPopBack { pub fn new() -> Self { Self { - signature: Signature::array(Volatility::Immutable), + signature: Signature { + type_signature: TypeSignature::ArraySignature( + ArrayFunctionSignature::Array { + arguments: vec![ArrayFunctionArgument::Array], + array_coercion: Some(ListCoercion::FixedSizedListToList), + }, + ), + volatility: Volatility::Immutable, + }, aliases: vec![String::from("list_pop_back")], } } @@ -786,12 +809,11 @@ impl ScalarUDFImpl for ArrayPopBack { Ok(arg_types[0].clone()) } - fn invoke_batch( + fn invoke_with_args( &self, - args: &[ColumnarValue], - _number_rows: usize, + args: datafusion_expr::ScalarFunctionArgs, ) -> Result { - make_scalar_function(array_pop_back_inner)(args) + make_scalar_function(array_pop_back_inner)(&args.args) } fn aliases(&self) -> &[String] { @@ -805,23 +827,20 @@ impl ScalarUDFImpl for ArrayPopBack { /// array_pop_back SQL function fn array_pop_back_inner(args: &[ArrayRef]) -> Result { - if args.len() != 1 { - return exec_err!("array_pop_back needs one argument"); - } + let [array] = take_function_args("array_pop_back", args)?; - let array_data_type = args[0].data_type(); - match array_data_type { + match array.data_type() { List(_) => { - let array = as_list_array(&args[0])?; + let array = as_list_array(&array)?; general_pop_back_list::(array) } LargeList(_) => { - let array = as_large_list_array(&args[0])?; + let array = as_large_list_array(&array)?; general_pop_back_list::(array) } _ => exec_err!( "array_pop_back does not support type: {:?}", - array_data_type + array.data_type() ), } } @@ -895,13 +914,13 @@ impl ScalarUDFImpl for ArrayAnyValue { } } - fn invoke_batch( + fn invoke_with_args( &self, - args: &[ColumnarValue], - _number_rows: usize, + args: 
datafusion_expr::ScalarFunctionArgs, ) -> Result { - make_scalar_function(array_any_value_inner)(args) + make_scalar_function(array_any_value_inner)(&args.args) } + fn aliases(&self) -> &[String] { &self.aliases } @@ -912,17 +931,15 @@ impl ScalarUDFImpl for ArrayAnyValue { } fn array_any_value_inner(args: &[ArrayRef]) -> Result { - if args.len() != 1 { - return exec_err!("array_any_value expects one argument"); - } + let [array] = take_function_args("array_any_value", args)?; - match &args[0].data_type() { + match &array.data_type() { List(_) => { - let array = as_list_array(&args[0])?; + let array = as_list_array(&array)?; general_array_any_value::(array) } LargeList(_) => { - let array = as_large_list_array(&args[0])?; + let array = as_large_list_array(&array)?; general_array_any_value::(array) } data_type => exec_err!("array_any_value does not support type: {:?}", data_type), @@ -982,7 +999,7 @@ where #[cfg(test)] mod tests { use super::array_element_udf; - use arrow_schema::{DataType, Field}; + use arrow::datatypes::{DataType, Field}; use datafusion_common::{Column, DFSchema, ScalarValue}; use datafusion_expr::expr::ScalarFunction; use datafusion_expr::{cast, Expr, ExprSchemable}; diff --git a/datafusion/functions-nested/src/flatten.rs b/datafusion/functions-nested/src/flatten.rs index 4fd14c79644b..f288035948dc 100644 --- a/datafusion/functions-nested/src/flatten.rs +++ b/datafusion/functions-nested/src/flatten.rs @@ -20,12 +20,14 @@ use crate::utils::make_scalar_function; use arrow::array::{ArrayRef, GenericListArray, OffsetSizeTrait}; use arrow::buffer::OffsetBuffer; -use arrow_schema::DataType; -use arrow_schema::DataType::{FixedSizeList, LargeList, List, Null}; +use arrow::datatypes::{ + DataType, + DataType::{FixedSizeList, LargeList, List, Null}, +}; use datafusion_common::cast::{ as_generic_list_array, as_large_list_array, as_list_array, }; -use datafusion_common::{exec_err, Result}; +use datafusion_common::{exec_err, utils::take_function_args, Result}; use datafusion_expr::{ ArrayFunctionSignature, ColumnarValue, Documentation, ScalarUDFImpl, Signature, TypeSignature, Volatility, @@ -122,12 +124,11 @@ impl ScalarUDFImpl for Flatten { Ok(data_type) } - fn invoke_batch( + fn invoke_with_args( &self, - args: &[ColumnarValue], - _number_rows: usize, + args: datafusion_expr::ScalarFunctionArgs, ) -> Result { - make_scalar_function(flatten_inner)(args) + make_scalar_function(flatten_inner)(&args.args) } fn aliases(&self) -> &[String] { @@ -141,25 +142,22 @@ impl ScalarUDFImpl for Flatten { /// Flatten SQL function pub fn flatten_inner(args: &[ArrayRef]) -> Result { - if args.len() != 1 { - return exec_err!("flatten expects one argument"); - } + let [array] = take_function_args("flatten", args)?; - let array_type = args[0].data_type(); - match array_type { + match array.data_type() { List(_) => { - let list_arr = as_list_array(&args[0])?; + let list_arr = as_list_array(&array)?; let flattened_array = flatten_internal::(list_arr.clone(), None)?; Ok(Arc::new(flattened_array) as ArrayRef) } LargeList(_) => { - let list_arr = as_large_list_array(&args[0])?; + let list_arr = as_large_list_array(&array)?; let flattened_array = flatten_internal::(list_arr.clone(), None)?; Ok(Arc::new(flattened_array) as ArrayRef) } - Null => Ok(Arc::clone(&args[0])), + Null => Ok(Arc::clone(array)), _ => { - exec_err!("flatten does not support type '{array_type:?}'") + exec_err!("flatten does not support type '{:?}'", array.data_type()) } } } diff --git a/datafusion/functions-nested/src/length.rs 
b/datafusion/functions-nested/src/length.rs index 1081a682897f..3c3a42da0d69 100644 --- a/datafusion/functions-nested/src/length.rs +++ b/datafusion/functions-nested/src/length.rs @@ -21,8 +21,10 @@ use crate::utils::make_scalar_function; use arrow::array::{ Array, ArrayRef, Int64Array, LargeListArray, ListArray, OffsetSizeTrait, UInt64Array, }; -use arrow_schema::DataType; -use arrow_schema::DataType::{FixedSizeList, LargeList, List, UInt64}; +use arrow::datatypes::{ + DataType, + DataType::{FixedSizeList, LargeList, List, UInt64}, +}; use datafusion_common::cast::{as_generic_list_array, as_int64_array}; use datafusion_common::{exec_err, internal_datafusion_err, plan_err, Result}; use datafusion_expr::{ @@ -101,12 +103,11 @@ impl ScalarUDFImpl for ArrayLength { }) } - fn invoke_batch( + fn invoke_with_args( &self, - args: &[ColumnarValue], - _number_rows: usize, + args: datafusion_expr::ScalarFunctionArgs, ) -> Result { - make_scalar_function(array_length_inner)(args) + make_scalar_function(array_length_inner)(&args.args) } fn aliases(&self) -> &[String] { diff --git a/datafusion/functions-nested/src/make_array.rs b/datafusion/functions-nested/src/make_array.rs index 8bf5f37b8add..4daaafc5a888 100644 --- a/datafusion/functions-nested/src/make_array.rs +++ b/datafusion/functions-nested/src/make_array.rs @@ -27,8 +27,11 @@ use arrow::array::{ MutableArrayData, NullArray, OffsetSizeTrait, }; use arrow::buffer::OffsetBuffer; -use arrow_schema::DataType::{List, Null}; -use arrow_schema::{DataType, Field}; +use arrow::datatypes::DataType; +use arrow::datatypes::{ + DataType::{List, Null}, + Field, +}; use datafusion_common::utils::SingleRowListArrayBuilder; use datafusion_common::{plan_err, Result}; use datafusion_expr::binary::{ @@ -114,12 +117,11 @@ impl ScalarUDFImpl for MakeArray { } } - fn invoke_batch( + fn invoke_with_args( &self, - args: &[ColumnarValue], - _number_rows: usize, + args: datafusion_expr::ScalarFunctionArgs, ) -> Result { - make_scalar_function(make_array_inner)(args) + make_scalar_function(make_array_inner)(&args.args) } fn aliases(&self) -> &[String] { diff --git a/datafusion/functions-nested/src/map.rs b/datafusion/functions-nested/src/map.rs index 77e06b28a8d6..828f2e244112 100644 --- a/datafusion/functions-nested/src/map.rs +++ b/datafusion/functions-nested/src/map.rs @@ -21,11 +21,12 @@ use std::sync::Arc; use arrow::array::{Array, ArrayData, ArrayRef, MapArray, OffsetSizeTrait, StructArray}; use arrow::buffer::Buffer; -use arrow::datatypes::ToByteSlice; -use arrow_schema::{DataType, Field, SchemaBuilder}; +use arrow::datatypes::{DataType, Field, SchemaBuilder, ToByteSlice}; use datafusion_common::utils::{fixed_size_list_to_arrays, list_to_arrays}; -use datafusion_common::{exec_err, HashSet, Result, ScalarValue}; +use datafusion_common::{ + exec_err, utils::take_function_args, HashSet, Result, ScalarValue, +}; use datafusion_expr::expr::ScalarFunction; use datafusion_expr::{ ColumnarValue, Documentation, Expr, ScalarUDFImpl, Signature, Volatility, @@ -56,23 +57,18 @@ fn can_evaluate_to_const(args: &[ColumnarValue]) -> bool { } fn make_map_batch(args: &[ColumnarValue]) -> Result { - if args.len() != 2 { - return exec_err!( - "make_map requires exactly 2 arguments, got {} instead", - args.len() - ); - } + let [keys_arg, values_arg] = take_function_args("make_map", args)?; let can_evaluate_to_const = can_evaluate_to_const(args); // check the keys array is unique - let keys = get_first_array_ref(&args[0])?; + let keys = get_first_array_ref(keys_arg)?; if 
keys.null_count() > 0 { return exec_err!("map key cannot be null"); } let key_array = keys.as_ref(); - match &args[0] { + match keys_arg { ColumnarValue::Array(_) => { let row_keys = match key_array.data_type() { DataType::List(_) => list_to_arrays::(&keys), @@ -95,8 +91,8 @@ fn make_map_batch(args: &[ColumnarValue]) -> Result { } } - let values = get_first_array_ref(&args[1])?; - make_map_batch_internal(keys, values, can_evaluate_to_const, args[0].data_type()) + let values = get_first_array_ref(values_arg)?; + make_map_batch_internal(keys, values, can_evaluate_to_const, keys_arg.data_type()) } fn check_unique_keys(array: &dyn Array) -> Result<()> { @@ -258,21 +254,16 @@ impl ScalarUDFImpl for MapFunc { } fn return_type(&self, arg_types: &[DataType]) -> Result { - if arg_types.len() != 2 { - return exec_err!( - "map requires exactly 2 arguments, got {} instead", - arg_types.len() - ); - } + let [keys_arg, values_arg] = take_function_args(self.name(), arg_types)?; let mut builder = SchemaBuilder::new(); builder.push(Field::new( "key", - get_element_type(&arg_types[0])?.clone(), + get_element_type(keys_arg)?.clone(), false, )); builder.push(Field::new( "value", - get_element_type(&arg_types[1])?.clone(), + get_element_type(values_arg)?.clone(), true, )); let fields = builder.finish().fields; @@ -282,12 +273,11 @@ impl ScalarUDFImpl for MapFunc { )) } - fn invoke_batch( + fn invoke_with_args( &self, - args: &[ColumnarValue], - _number_rows: usize, + args: datafusion_expr::ScalarFunctionArgs, ) -> Result { - make_map_batch(args) + make_map_batch(&args.args) } fn documentation(&self) -> Option<&Documentation> { diff --git a/datafusion/functions-nested/src/map_extract.rs b/datafusion/functions-nested/src/map_extract.rs index 47d977a8c01c..55ab8447c54f 100644 --- a/datafusion/functions-nested/src/map_extract.rs +++ b/datafusion/functions-nested/src/map_extract.rs @@ -17,13 +17,13 @@ //! [`ScalarUDFImpl`] definitions for map_extract functions. 
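Illustrative sketch (not part of this diff): a recurring change in this patch replaces hand-rolled arity checks with datafusion_common::utils::take_function_args, which destructures the argument slice into a fixed-size array and returns a consistent error naming the function when the count does not match. A small sketch of the pattern; the function and argument names are illustrative only.

use arrow::datatypes::DataType;
use datafusion_common::{utils::take_function_args, Result};

fn unary_arg_type(fn_name: &str, arg_types: &[DataType]) -> Result<DataType> {
    // Exactly one argument expected; otherwise an error mentioning `fn_name` is returned.
    let [only] = take_function_args(fn_name, arg_types)?;
    Ok(only.clone())
}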
+use crate::utils::{get_map_entry_field, make_scalar_function}; use arrow::array::{ make_array, Array, ArrayRef, Capacities, ListArray, MapArray, MutableArrayData, }; use arrow::buffer::OffsetBuffer; -use arrow::datatypes::DataType; -use arrow_schema::Field; - +use arrow::datatypes::{DataType, Field}; +use datafusion_common::utils::take_function_args; use datafusion_common::{cast::as_map_array, exec_err, Result}; use datafusion_expr::{ ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, @@ -33,8 +33,6 @@ use std::any::Any; use std::sync::Arc; use std::vec; -use crate::utils::{get_map_entry_field, make_scalar_function}; - // Create static instances of ScalarUDFs for each function make_udf_expr_and_func!( MapExtract, @@ -104,10 +102,7 @@ impl ScalarUDFImpl for MapExtract { } fn return_type(&self, arg_types: &[DataType]) -> Result { - if arg_types.len() != 2 { - return exec_err!("map_extract expects two arguments"); - } - let map_type = &arg_types[0]; + let [map_type, _] = take_function_args(self.name(), arg_types)?; let map_fields = get_map_entry_field(map_type)?; Ok(DataType::List(Arc::new(Field::new_list_field( map_fields.last().unwrap().data_type().clone(), @@ -115,12 +110,11 @@ impl ScalarUDFImpl for MapExtract { )))) } - fn invoke_batch( + fn invoke_with_args( &self, - args: &[ColumnarValue], - _number_rows: usize, + args: datafusion_expr::ScalarFunctionArgs, ) -> Result { - make_scalar_function(map_extract_inner)(args) + make_scalar_function(map_extract_inner)(&args.args) } fn aliases(&self) -> &[String] { @@ -128,13 +122,11 @@ impl ScalarUDFImpl for MapExtract { } fn coerce_types(&self, arg_types: &[DataType]) -> Result> { - if arg_types.len() != 2 { - return exec_err!("map_extract expects two arguments"); - } + let [map_type, _] = take_function_args(self.name(), arg_types)?; - let field = get_map_entry_field(&arg_types[0])?; + let field = get_map_entry_field(map_type)?; Ok(vec![ - arg_types[0].clone(), + map_type.clone(), field.first().unwrap().data_type().clone(), ]) } @@ -190,24 +182,22 @@ fn general_map_extract_inner( } fn map_extract_inner(args: &[ArrayRef]) -> Result { - if args.len() != 2 { - return exec_err!("map_extract expects two arguments"); - } + let [map_arg, key_arg] = take_function_args("map_extract", args)?; - let map_array = match args[0].data_type() { - DataType::Map(_, _) => as_map_array(&args[0])?, + let map_array = match map_arg.data_type() { + DataType::Map(_, _) => as_map_array(&map_arg)?, _ => return exec_err!("The first argument in map_extract must be a map"), }; let key_type = map_array.key_type(); - if key_type != args[1].data_type() { + if key_type != key_arg.data_type() { return exec_err!( "The key type {} does not match the map key type {}", - args[1].data_type(), + key_arg.data_type(), key_type ); } - general_map_extract_inner(map_array, &args[1]) + general_map_extract_inner(map_array, key_arg) } diff --git a/datafusion/functions-nested/src/map_keys.rs b/datafusion/functions-nested/src/map_keys.rs index 60039865daae..0f15c06d86d1 100644 --- a/datafusion/functions-nested/src/map_keys.rs +++ b/datafusion/functions-nested/src/map_keys.rs @@ -19,7 +19,8 @@ use crate::utils::{get_map_entry_field, make_scalar_function}; use arrow::array::{Array, ArrayRef, ListArray}; -use arrow_schema::{DataType, Field}; +use arrow::datatypes::{DataType, Field}; +use datafusion_common::utils::take_function_args; use datafusion_common::{cast::as_map_array, exec_err, Result}; use datafusion_expr::{ ArrayFunctionSignature, ColumnarValue, Documentation, 
ScalarUDFImpl, Signature, @@ -91,10 +92,7 @@ impl ScalarUDFImpl for MapKeysFunc { } fn return_type(&self, arg_types: &[DataType]) -> Result { - if arg_types.len() != 1 { - return exec_err!("map_keys expects single argument"); - } - let map_type = &arg_types[0]; + let [map_type] = take_function_args(self.name(), arg_types)?; let map_fields = get_map_entry_field(map_type)?; Ok(DataType::List(Arc::new(Field::new_list_field( map_fields.first().unwrap().data_type().clone(), @@ -102,12 +100,11 @@ impl ScalarUDFImpl for MapKeysFunc { )))) } - fn invoke_batch( + fn invoke_with_args( &self, - args: &[ColumnarValue], - _number_rows: usize, + args: datafusion_expr::ScalarFunctionArgs, ) -> Result { - make_scalar_function(map_keys_inner)(args) + make_scalar_function(map_keys_inner)(&args.args) } fn documentation(&self) -> Option<&Documentation> { @@ -116,12 +113,10 @@ impl ScalarUDFImpl for MapKeysFunc { } fn map_keys_inner(args: &[ArrayRef]) -> Result { - if args.len() != 1 { - return exec_err!("map_keys expects single argument"); - } + let [map_arg] = take_function_args("map_keys", args)?; - let map_array = match args[0].data_type() { - DataType::Map(_, _) => as_map_array(&args[0])?, + let map_array = match map_arg.data_type() { + DataType::Map(_, _) => as_map_array(&map_arg)?, _ => return exec_err!("Argument for map_keys should be a map"), }; diff --git a/datafusion/functions-nested/src/map_values.rs b/datafusion/functions-nested/src/map_values.rs index c6d31f3d9067..f82e4bfa1a89 100644 --- a/datafusion/functions-nested/src/map_values.rs +++ b/datafusion/functions-nested/src/map_values.rs @@ -19,7 +19,8 @@ use crate::utils::{get_map_entry_field, make_scalar_function}; use arrow::array::{Array, ArrayRef, ListArray}; -use arrow_schema::{DataType, Field}; +use arrow::datatypes::{DataType, Field}; +use datafusion_common::utils::take_function_args; use datafusion_common::{cast::as_map_array, exec_err, Result}; use datafusion_expr::{ ArrayFunctionSignature, ColumnarValue, Documentation, ScalarUDFImpl, Signature, @@ -91,10 +92,7 @@ impl ScalarUDFImpl for MapValuesFunc { } fn return_type(&self, arg_types: &[DataType]) -> Result { - if arg_types.len() != 1 { - return exec_err!("map_values expects single argument"); - } - let map_type = &arg_types[0]; + let [map_type] = take_function_args(self.name(), arg_types)?; let map_fields = get_map_entry_field(map_type)?; Ok(DataType::List(Arc::new(Field::new_list_field( map_fields.last().unwrap().data_type().clone(), @@ -102,12 +100,11 @@ impl ScalarUDFImpl for MapValuesFunc { )))) } - fn invoke_batch( + fn invoke_with_args( &self, - args: &[ColumnarValue], - _number_rows: usize, + args: datafusion_expr::ScalarFunctionArgs, ) -> Result { - make_scalar_function(map_values_inner)(args) + make_scalar_function(map_values_inner)(&args.args) } fn documentation(&self) -> Option<&Documentation> { @@ -116,12 +113,10 @@ impl ScalarUDFImpl for MapValuesFunc { } fn map_values_inner(args: &[ArrayRef]) -> Result { - if args.len() != 1 { - return exec_err!("map_values expects single argument"); - } + let [map_arg] = take_function_args("map_values", args)?; - let map_array = match args[0].data_type() { - DataType::Map(_, _) => as_map_array(&args[0])?, + let map_array = match map_arg.data_type() { + DataType::Map(_, _) => as_map_array(&map_arg)?, _ => return exec_err!("Argument for map_values should be a map"), }; diff --git a/datafusion/functions-nested/src/planner.rs b/datafusion/functions-nested/src/planner.rs index 5ca51ac20f1e..369eaecb1905 100644 --- 
a/datafusion/functions-nested/src/planner.rs +++ b/datafusion/functions-nested/src/planner.rs @@ -17,14 +17,20 @@ //! SQL planning extensions like [`NestedFunctionPlanner`] and [`FieldAccessPlanner`] +use arrow::datatypes::DataType; +use datafusion_common::ExprSchema; use datafusion_common::{plan_err, utils::list_ndims, DFSchema, Result}; use datafusion_expr::expr::ScalarFunction; +use datafusion_expr::expr::{AggregateFunction, AggregateFunctionParams}; +use datafusion_expr::AggregateUDF; use datafusion_expr::{ planner::{ExprPlanner, PlannerResult, RawBinaryExpr, RawFieldAccessExpr}, sqlparser, Expr, ExprSchemable, GetFieldAccess, }; +use datafusion_functions::core::get_field as get_field_inner; use datafusion_functions::expr_fn::get_field; use datafusion_functions_aggregate::nth_value::nth_value_udaf; +use std::sync::Arc; use crate::map::map_udf; use crate::{ @@ -137,7 +143,7 @@ impl ExprPlanner for FieldAccessPlanner { fn plan_field_access( &self, expr: RawFieldAccessExpr, - _schema: &DFSchema, + schema: &DFSchema, ) -> Result> { let RawFieldAccessExpr { expr, field_access } = expr; @@ -150,19 +156,34 @@ impl ExprPlanner for FieldAccessPlanner { GetFieldAccess::ListIndex { key: index } => { match expr { // Special case for array_agg(expr)[index] to NTH_VALUE(expr, index) - Expr::AggregateFunction(agg_func) if is_array_agg(&agg_func) => { - Ok(PlannerResult::Planned(Expr::AggregateFunction( - datafusion_expr::expr::AggregateFunction::new_udf( - nth_value_udaf(), - agg_func - .args - .into_iter() - .chain(std::iter::once(*index)) - .collect(), - agg_func.distinct, - agg_func.filter, - agg_func.order_by, - agg_func.null_treatment, + Expr::AggregateFunction(AggregateFunction { + func, + params: + AggregateFunctionParams { + args, + distinct, + filter, + order_by, + null_treatment, + }, + }) if is_array_agg(&func) => Ok(PlannerResult::Planned( + Expr::AggregateFunction(AggregateFunction::new_udf( + nth_value_udaf(), + args.into_iter().chain(std::iter::once(*index)).collect(), + distinct, + filter, + order_by, + null_treatment, + )), + )), + // special case for map access with + Expr::Column(ref c) + if matches!(schema.data_type(c)?, DataType::Map(_, _)) => + { + Ok(PlannerResult::Planned(Expr::ScalarFunction( + ScalarFunction::new_udf( + get_field_inner(), + vec![expr, *index], ), ))) } @@ -184,6 +205,6 @@ impl ExprPlanner for FieldAccessPlanner { } } -fn is_array_agg(agg_func: &datafusion_expr::expr::AggregateFunction) -> bool { - agg_func.func.name() == "array_agg" +fn is_array_agg(func: &Arc) -> bool { + func.name() == "array_agg" } diff --git a/datafusion/functions-nested/src/position.rs b/datafusion/functions-nested/src/position.rs index d5c9944709b3..b186b65407c3 100644 --- a/datafusion/functions-nested/src/position.rs +++ b/datafusion/functions-nested/src/position.rs @@ -17,8 +17,11 @@ //! [`ScalarUDFImpl`] definitions for array_position and array_positions functions. 
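
A minimal sketch (not part of this patch) of what the new FieldAccessPlanner branch in planner.rs above produces: when the indexed column resolves to a Map data type in the schema, the index expression is planned as a call to the get_field scalar UDF. The column name "properties" and key "color" are hypothetical.

use datafusion_expr::{col, expr::ScalarFunction, lit, Expr};
use datafusion_functions::core::get_field as get_field_inner;

fn planned_map_access() -> Expr {
    // Roughly the planned form of `properties['color']` once the planner sees
    // that `properties` has DataType::Map(..) in the input schema.
    Expr::ScalarFunction(ScalarFunction::new_udf(
        get_field_inner(),
        vec![col("properties"), lit("color")],
    ))
}
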
-use arrow_schema::DataType::{LargeList, List, UInt64}; -use arrow_schema::{DataType, Field}; +use arrow::datatypes::DataType; +use arrow::datatypes::{ + DataType::{LargeList, List, UInt64}, + Field, +}; use datafusion_expr::{ ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, }; @@ -34,7 +37,7 @@ use arrow::array::{ use datafusion_common::cast::{ as_generic_list_array, as_int64_array, as_large_list_array, as_list_array, }; -use datafusion_common::{exec_err, internal_err, Result}; +use datafusion_common::{exec_err, internal_err, utils::take_function_args, Result}; use itertools::Itertools; use crate::utils::{compare_element_to_list, make_scalar_function}; @@ -117,12 +120,11 @@ impl ScalarUDFImpl for ArrayPosition { Ok(UInt64) } - fn invoke_batch( + fn invoke_with_args( &self, - args: &[ColumnarValue], - _number_rows: usize, + args: datafusion_expr::ScalarFunctionArgs, ) -> Result { - make_scalar_function(array_position_inner)(args) + make_scalar_function(array_position_inner)(&args.args) } fn aliases(&self) -> &[String] { @@ -271,12 +273,11 @@ impl ScalarUDFImpl for ArrayPositions { Ok(List(Arc::new(Field::new_list_field(UInt64, true)))) } - fn invoke_batch( + fn invoke_with_args( &self, - args: &[ColumnarValue], - _number_rows: usize, + args: datafusion_expr::ScalarFunctionArgs, ) -> Result { - make_scalar_function(array_positions_inner)(args) + make_scalar_function(array_positions_inner)(&args.args) } fn aliases(&self) -> &[String] { @@ -290,20 +291,16 @@ impl ScalarUDFImpl for ArrayPositions { /// Array_positions SQL function pub fn array_positions_inner(args: &[ArrayRef]) -> Result { - if args.len() != 2 { - return exec_err!("array_positions expects two arguments"); - } + let [array, element] = take_function_args("array_positions", args)?; - let element = &args[1]; - - match &args[0].data_type() { + match &array.data_type() { List(_) => { - let arr = as_list_array(&args[0])?; + let arr = as_list_array(&array)?; crate::utils::check_datatypes("array_positions", &[arr.values(), element])?; general_positions::(arr, element) } LargeList(_) => { - let arr = as_large_list_array(&args[0])?; + let arr = as_large_list_array(&array)?; crate::utils::check_datatypes("array_positions", &[arr.values(), element])?; general_positions::(arr, element) } diff --git a/datafusion/functions-nested/src/range.rs b/datafusion/functions-nested/src/range.rs index c3f52cef3366..637a78d158ab 100644 --- a/datafusion/functions-nested/src/range.rs +++ b/datafusion/functions-nested/src/range.rs @@ -27,15 +27,15 @@ use arrow::array::{ TimestampNanosecondArray, }; use arrow::buffer::OffsetBuffer; -use arrow::datatypes::{DataType, Field}; -use arrow_schema::DataType::*; -use arrow_schema::IntervalUnit::MonthDayNano; -use arrow_schema::TimeUnit::Nanosecond; +use arrow::datatypes::{ + DataType, DataType::*, Field, IntervalUnit::MonthDayNano, TimeUnit::Nanosecond, +}; use datafusion_common::cast::{ as_date32_array, as_int64_array, as_interval_mdn_array, as_timestamp_nanosecond_array, }; use datafusion_common::{ - exec_datafusion_err, exec_err, internal_err, not_impl_datafusion_err, Result, + exec_datafusion_err, exec_err, internal_err, not_impl_datafusion_err, + utils::take_function_args, Result, }; use datafusion_expr::{ ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, @@ -155,11 +155,12 @@ impl ScalarUDFImpl for Range { } } - fn invoke_batch( + fn invoke_with_args( &self, - args: &[ColumnarValue], - _number_rows: usize, + args: datafusion_expr::ScalarFunctionArgs, ) -> Result { + let 
args = &args.args; + if args.iter().any(|arg| arg.data_type().is_null()) { return Ok(ColumnarValue::Array(Arc::new(NullArray::new(1)))); } @@ -278,11 +279,12 @@ impl ScalarUDFImpl for GenSeries { } } - fn invoke_batch( + fn invoke_with_args( &self, - args: &[ColumnarValue], - _number_rows: usize, + args: datafusion_expr::ScalarFunctionArgs, ) -> Result { + let args = &args.args; + if args.iter().any(|arg| arg.data_type().is_null()) { return Ok(ColumnarValue::Array(Arc::new(NullArray::new(1)))); } @@ -435,13 +437,12 @@ fn gen_range_iter( } fn gen_range_date(args: &[ArrayRef], include_upper_bound: bool) -> Result { - if args.len() != 3 { - return exec_err!("arguments length does not match"); - } + let [start, stop, step] = take_function_args("range", args)?; + let (start_array, stop_array, step_array) = ( - Some(as_date32_array(&args[0])?), - as_date32_array(&args[1])?, - Some(as_interval_mdn_array(&args[2])?), + Some(as_date32_array(start)?), + as_date32_array(stop)?, + Some(as_interval_mdn_array(step)?), ); // values are date32s @@ -508,21 +509,17 @@ fn gen_range_date(args: &[ArrayRef], include_upper_bound: bool) -> Result Result { - if args.len() != 3 { - return exec_err!( - "Arguments length must be 3 for {}", - if include_upper_bound { - "GENERATE_SERIES" - } else { - "RANGE" - } - ); - } + let func_name = if include_upper_bound { + "GENERATE_SERIES" + } else { + "RANGE" + }; + let [start, stop, step] = take_function_args(func_name, args)?; // coerce_types fn should coerce all types to Timestamp(Nanosecond, tz) - let (start_arr, start_tz_opt) = cast_timestamp_arg(&args[0], include_upper_bound)?; - let (stop_arr, stop_tz_opt) = cast_timestamp_arg(&args[1], include_upper_bound)?; - let step_arr = as_interval_mdn_array(&args[2])?; + let (start_arr, start_tz_opt) = cast_timestamp_arg(start, include_upper_bound)?; + let (stop_arr, stop_tz_opt) = cast_timestamp_arg(stop, include_upper_bound)?; + let step_arr = as_interval_mdn_array(step)?; let start_tz = parse_tz(start_tz_opt)?; let stop_tz = parse_tz(stop_tz_opt)?; diff --git a/datafusion/functions-nested/src/remove.rs b/datafusion/functions-nested/src/remove.rs index 099cc7e1131d..7f5baa18e769 100644 --- a/datafusion/functions-nested/src/remove.rs +++ b/datafusion/functions-nested/src/remove.rs @@ -24,9 +24,9 @@ use arrow::array::{ OffsetSizeTrait, }; use arrow::buffer::OffsetBuffer; -use arrow_schema::{DataType, Field}; +use arrow::datatypes::{DataType, Field}; use datafusion_common::cast::as_int64_array; -use datafusion_common::{exec_err, Result}; +use datafusion_common::{exec_err, utils::take_function_args, Result}; use datafusion_expr::{ ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, }; @@ -101,12 +101,11 @@ impl ScalarUDFImpl for ArrayRemove { Ok(arg_types[0].clone()) } - fn invoke_batch( + fn invoke_with_args( &self, - args: &[ColumnarValue], - _number_rows: usize, + args: datafusion_expr::ScalarFunctionArgs, ) -> Result { - make_scalar_function(array_remove_inner)(args) + make_scalar_function(array_remove_inner)(&args.args) } fn aliases(&self) -> &[String] { @@ -180,12 +179,11 @@ impl ScalarUDFImpl for ArrayRemoveN { Ok(arg_types[0].clone()) } - fn invoke_batch( + fn invoke_with_args( &self, - args: &[ColumnarValue], - _number_rows: usize, + args: datafusion_expr::ScalarFunctionArgs, ) -> Result { - make_scalar_function(array_remove_n_inner)(args) + make_scalar_function(array_remove_n_inner)(&args.args) } fn aliases(&self) -> &[String] { @@ -258,12 +256,11 @@ impl ScalarUDFImpl for ArrayRemoveAll { 
Ok(arg_types[0].clone()) } - fn invoke_batch( + fn invoke_with_args( &self, - args: &[ColumnarValue], - _number_rows: usize, + args: datafusion_expr::ScalarFunctionArgs, ) -> Result { - make_scalar_function(array_remove_all_inner)(args) + make_scalar_function(array_remove_all_inner)(&args.args) } fn aliases(&self) -> &[String] { @@ -277,32 +274,26 @@ impl ScalarUDFImpl for ArrayRemoveAll { /// Array_remove SQL function pub fn array_remove_inner(args: &[ArrayRef]) -> Result { - if args.len() != 2 { - return exec_err!("array_remove expects two arguments"); - } + let [array, element] = take_function_args("array_remove", args)?; - let arr_n = vec![1; args[0].len()]; - array_remove_internal(&args[0], &args[1], arr_n) + let arr_n = vec![1; array.len()]; + array_remove_internal(array, element, arr_n) } /// Array_remove_n SQL function pub fn array_remove_n_inner(args: &[ArrayRef]) -> Result { - if args.len() != 3 { - return exec_err!("array_remove_n expects three arguments"); - } + let [array, element, max] = take_function_args("array_remove_n", args)?; - let arr_n = as_int64_array(&args[2])?.values().to_vec(); - array_remove_internal(&args[0], &args[1], arr_n) + let arr_n = as_int64_array(max)?.values().to_vec(); + array_remove_internal(array, element, arr_n) } /// Array_remove_all SQL function pub fn array_remove_all_inner(args: &[ArrayRef]) -> Result { - if args.len() != 2 { - return exec_err!("array_remove_all expects two arguments"); - } + let [array, element] = take_function_args("array_remove_all", args)?; - let arr_n = vec![i64::MAX; args[0].len()]; - array_remove_internal(&args[0], &args[1], arr_n) + let arr_n = vec![i64::MAX; array.len()]; + array_remove_internal(array, element, arr_n) } fn array_remove_internal( diff --git a/datafusion/functions-nested/src/repeat.rs b/datafusion/functions-nested/src/repeat.rs index 4772da9a4bf4..26d67ad3113f 100644 --- a/datafusion/functions-nested/src/repeat.rs +++ b/datafusion/functions-nested/src/repeat.rs @@ -25,10 +25,13 @@ use arrow::array::{ use arrow::buffer::OffsetBuffer; use arrow::compute; use arrow::compute::cast; -use arrow_schema::DataType::{LargeList, List}; -use arrow_schema::{DataType, Field}; +use arrow::datatypes::DataType; +use arrow::datatypes::{ + DataType::{LargeList, List}, + Field, +}; use datafusion_common::cast::{as_large_list_array, as_list_array, as_uint64_array}; -use datafusion_common::{exec_err, Result}; +use datafusion_common::{exec_err, utils::take_function_args, Result}; use datafusion_expr::{ ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, }; @@ -112,12 +115,11 @@ impl ScalarUDFImpl for ArrayRepeat { )))) } - fn invoke_batch( + fn invoke_with_args( &self, - args: &[ColumnarValue], - _number_rows: usize, + args: datafusion_expr::ScalarFunctionArgs, ) -> Result { - make_scalar_function(array_repeat_inner)(args) + make_scalar_function(array_repeat_inner)(&args.args) } fn aliases(&self) -> &[String] { @@ -125,17 +127,10 @@ impl ScalarUDFImpl for ArrayRepeat { } fn coerce_types(&self, arg_types: &[DataType]) -> Result> { - if arg_types.len() != 2 { - return exec_err!("array_repeat expects two arguments"); - } - - let element_type = &arg_types[0]; - let first = element_type.clone(); - - let count_type = &arg_types[1]; + let [first_type, second_type] = take_function_args(self.name(), arg_types)?; // Coerce the second argument to Int64/UInt64 if it's a numeric type - let second = match count_type { + let second = match second_type { DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => { 
DataType::Int64 } @@ -145,7 +140,7 @@ impl ScalarUDFImpl for ArrayRepeat { _ => return exec_err!("count must be an integer type"), }; - Ok(vec![first, second]) + Ok(vec![first_type.clone(), second]) } fn documentation(&self) -> Option<&Documentation> { diff --git a/datafusion/functions-nested/src/replace.rs b/datafusion/functions-nested/src/replace.rs index 939fce6fdf3f..71bfedb72d1c 100644 --- a/datafusion/functions-nested/src/replace.rs +++ b/datafusion/functions-nested/src/replace.rs @@ -21,14 +21,15 @@ use arrow::array::{ Array, ArrayRef, AsArray, Capacities, GenericListArray, MutableArrayData, NullBufferBuilder, OffsetSizeTrait, }; -use arrow::datatypes::DataType; +use arrow::datatypes::{DataType, Field}; use arrow::buffer::OffsetBuffer; -use arrow_schema::Field; use datafusion_common::cast::as_int64_array; -use datafusion_common::{exec_err, Result}; +use datafusion_common::utils::ListCoercion; +use datafusion_common::{exec_err, utils::take_function_args, Result}; use datafusion_expr::{ - ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, + ArrayFunctionArgument, ArrayFunctionSignature, ColumnarValue, Documentation, + ScalarUDFImpl, Signature, TypeSignature, Volatility, }; use datafusion_macros::user_doc; @@ -92,7 +93,19 @@ impl Default for ArrayReplace { impl ArrayReplace { pub fn new() -> Self { Self { - signature: Signature::any(3, Volatility::Immutable), + signature: Signature { + type_signature: TypeSignature::ArraySignature( + ArrayFunctionSignature::Array { + arguments: vec![ + ArrayFunctionArgument::Array, + ArrayFunctionArgument::Element, + ArrayFunctionArgument::Element, + ], + array_coercion: Some(ListCoercion::FixedSizedListToList), + }, + ), + volatility: Volatility::Immutable, + }, aliases: vec![String::from("list_replace")], } } @@ -115,12 +128,11 @@ impl ScalarUDFImpl for ArrayReplace { Ok(args[0].clone()) } - fn invoke_batch( + fn invoke_with_args( &self, - args: &[ColumnarValue], - _number_rows: usize, + args: datafusion_expr::ScalarFunctionArgs, ) -> Result { - make_scalar_function(array_replace_inner)(args) + make_scalar_function(array_replace_inner)(&args.args) } fn aliases(&self) -> &[String] { @@ -161,7 +173,20 @@ pub(super) struct ArrayReplaceN { impl ArrayReplaceN { pub fn new() -> Self { Self { - signature: Signature::any(4, Volatility::Immutable), + signature: Signature { + type_signature: TypeSignature::ArraySignature( + ArrayFunctionSignature::Array { + arguments: vec![ + ArrayFunctionArgument::Array, + ArrayFunctionArgument::Element, + ArrayFunctionArgument::Element, + ArrayFunctionArgument::Index, + ], + array_coercion: Some(ListCoercion::FixedSizedListToList), + }, + ), + volatility: Volatility::Immutable, + }, aliases: vec![String::from("list_replace_n")], } } @@ -184,12 +209,11 @@ impl ScalarUDFImpl for ArrayReplaceN { Ok(args[0].clone()) } - fn invoke_batch( + fn invoke_with_args( &self, - args: &[ColumnarValue], - _number_rows: usize, + args: datafusion_expr::ScalarFunctionArgs, ) -> Result { - make_scalar_function(array_replace_n_inner)(args) + make_scalar_function(array_replace_n_inner)(&args.args) } fn aliases(&self) -> &[String] { @@ -229,7 +253,19 @@ pub(super) struct ArrayReplaceAll { impl ArrayReplaceAll { pub fn new() -> Self { Self { - signature: Signature::any(3, Volatility::Immutable), + signature: Signature { + type_signature: TypeSignature::ArraySignature( + ArrayFunctionSignature::Array { + arguments: vec![ + ArrayFunctionArgument::Array, + ArrayFunctionArgument::Element, + ArrayFunctionArgument::Element, + ], + 
array_coercion: Some(ListCoercion::FixedSizedListToList), + }, + ), + volatility: Volatility::Immutable, + }, aliases: vec![String::from("list_replace_all")], } } @@ -252,12 +288,11 @@ impl ScalarUDFImpl for ArrayReplaceAll { Ok(args[0].clone()) } - fn invoke_batch( + fn invoke_with_args( &self, - args: &[ColumnarValue], - _number_rows: usize, + args: datafusion_expr::ScalarFunctionArgs, ) -> Result { - make_scalar_function(array_replace_all_inner)(args) + make_scalar_function(array_replace_all_inner)(&args.args) } fn aliases(&self) -> &[String] { @@ -381,42 +416,36 @@ fn general_replace( } pub(crate) fn array_replace_inner(args: &[ArrayRef]) -> Result { - if args.len() != 3 { - return exec_err!("array_replace expects three arguments"); - } + let [array, from, to] = take_function_args("array_replace", args)?; // replace at most one occurrence for each element - let arr_n = vec![1; args[0].len()]; - let array = &args[0]; + let arr_n = vec![1; array.len()]; match array.data_type() { DataType::List(_) => { let list_array = array.as_list::(); - general_replace::(list_array, &args[1], &args[2], arr_n) + general_replace::(list_array, from, to, arr_n) } DataType::LargeList(_) => { let list_array = array.as_list::(); - general_replace::(list_array, &args[1], &args[2], arr_n) + general_replace::(list_array, from, to, arr_n) } array_type => exec_err!("array_replace does not support type '{array_type:?}'."), } } pub(crate) fn array_replace_n_inner(args: &[ArrayRef]) -> Result { - if args.len() != 4 { - return exec_err!("array_replace_n expects four arguments"); - } + let [array, from, to, max] = take_function_args("array_replace_n", args)?; // replace the specified number of occurrences - let arr_n = as_int64_array(&args[3])?.values().to_vec(); - let array = &args[0]; + let arr_n = as_int64_array(max)?.values().to_vec(); match array.data_type() { DataType::List(_) => { let list_array = array.as_list::(); - general_replace::(list_array, &args[1], &args[2], arr_n) + general_replace::(list_array, from, to, arr_n) } DataType::LargeList(_) => { let list_array = array.as_list::(); - general_replace::(list_array, &args[1], &args[2], arr_n) + general_replace::(list_array, from, to, arr_n) } array_type => { exec_err!("array_replace_n does not support type '{array_type:?}'.") @@ -425,21 +454,18 @@ pub(crate) fn array_replace_n_inner(args: &[ArrayRef]) -> Result { } pub(crate) fn array_replace_all_inner(args: &[ArrayRef]) -> Result { - if args.len() != 3 { - return exec_err!("array_replace_all expects three arguments"); - } + let [array, from, to] = take_function_args("array_replace_all", args)?; // replace all occurrences (up to "i64::MAX") - let arr_n = vec![i64::MAX; args[0].len()]; - let array = &args[0]; + let arr_n = vec![i64::MAX; array.len()]; match array.data_type() { DataType::List(_) => { let list_array = array.as_list::(); - general_replace::(list_array, &args[1], &args[2], arr_n) + general_replace::(list_array, from, to, arr_n) } DataType::LargeList(_) => { let list_array = array.as_list::(); - general_replace::(list_array, &args[1], &args[2], arr_n) + general_replace::(list_array, from, to, arr_n) } array_type => { exec_err!("array_replace_all does not support type '{array_type:?}'.") diff --git a/datafusion/functions-nested/src/resize.rs b/datafusion/functions-nested/src/resize.rs index 3cd7bb5dac81..6c0b91a678e7 100644 --- a/datafusion/functions-nested/src/resize.rs +++ b/datafusion/functions-nested/src/resize.rs @@ -24,8 +24,11 @@ use arrow::array::{ }; use arrow::buffer::OffsetBuffer; use 
arrow::datatypes::ArrowNativeType; -use arrow_schema::DataType::{FixedSizeList, LargeList, List}; -use arrow_schema::{DataType, FieldRef}; +use arrow::datatypes::DataType; +use arrow::datatypes::{ + DataType::{FixedSizeList, LargeList, List}, + FieldRef, +}; use datafusion_common::cast::{as_int64_array, as_large_list_array, as_list_array}; use datafusion_common::{exec_err, internal_datafusion_err, Result, ScalarValue}; use datafusion_expr::{ @@ -109,12 +112,11 @@ impl ScalarUDFImpl for ArrayResize { } } - fn invoke_batch( + fn invoke_with_args( &self, - args: &[ColumnarValue], - _number_rows: usize, + args: datafusion_expr::ScalarFunctionArgs, ) -> Result { - make_scalar_function(array_resize_inner)(args) + make_scalar_function(array_resize_inner)(&args.args) } fn aliases(&self) -> &[String] { diff --git a/datafusion/functions-nested/src/reverse.rs b/datafusion/functions-nested/src/reverse.rs index a60f84cb0320..140cd19aeff9 100644 --- a/datafusion/functions-nested/src/reverse.rs +++ b/datafusion/functions-nested/src/reverse.rs @@ -22,10 +22,10 @@ use arrow::array::{ Array, ArrayRef, Capacities, GenericListArray, MutableArrayData, OffsetSizeTrait, }; use arrow::buffer::OffsetBuffer; -use arrow_schema::DataType::{LargeList, List, Null}; -use arrow_schema::{DataType, FieldRef}; +use arrow::datatypes::DataType::{LargeList, List, Null}; +use arrow::datatypes::{DataType, FieldRef}; use datafusion_common::cast::{as_large_list_array, as_list_array}; -use datafusion_common::{exec_err, Result}; +use datafusion_common::{exec_err, utils::take_function_args, Result}; use datafusion_expr::{ ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, }; @@ -96,12 +96,11 @@ impl ScalarUDFImpl for ArrayReverse { Ok(arg_types[0].clone()) } - fn invoke_batch( + fn invoke_with_args( &self, - args: &[ColumnarValue], - _number_rows: usize, + args: datafusion_expr::ScalarFunctionArgs, ) -> Result { - make_scalar_function(array_reverse_inner)(args) + make_scalar_function(array_reverse_inner)(&args.args) } fn aliases(&self) -> &[String] { @@ -115,20 +114,18 @@ impl ScalarUDFImpl for ArrayReverse { /// array_reverse SQL function pub fn array_reverse_inner(arg: &[ArrayRef]) -> Result { - if arg.len() != 1 { - return exec_err!("array_reverse needs one argument"); - } + let [input_array] = take_function_args("array_reverse", arg)?; - match &arg[0].data_type() { + match &input_array.data_type() { List(field) => { - let array = as_list_array(&arg[0])?; + let array = as_list_array(input_array)?; general_array_reverse::(array, field) } LargeList(field) => { - let array = as_large_list_array(&arg[0])?; + let array = as_large_list_array(input_array)?; general_array_reverse::(array, field) } - Null => Ok(Arc::clone(&arg[0])), + Null => Ok(Arc::clone(input_array)), array_type => exec_err!("array_reverse does not support type '{array_type:?}'."), } } diff --git a/datafusion/functions-nested/src/set_ops.rs b/datafusion/functions-nested/src/set_ops.rs index 079e0e3ed214..a67945b1f1e1 100644 --- a/datafusion/functions-nested/src/set_ops.rs +++ b/datafusion/functions-nested/src/set_ops.rs @@ -22,11 +22,11 @@ use crate::utils::make_scalar_function; use arrow::array::{new_empty_array, Array, ArrayRef, GenericListArray, OffsetSizeTrait}; use arrow::buffer::OffsetBuffer; use arrow::compute; +use arrow::datatypes::DataType::{FixedSizeList, LargeList, List, Null}; use arrow::datatypes::{DataType, Field, FieldRef}; use arrow::row::{RowConverter, SortField}; -use arrow_schema::DataType::{FixedSizeList, LargeList, List, Null}; use 
datafusion_common::cast::{as_large_list_array, as_list_array}; -use datafusion_common::{exec_err, internal_err, Result}; +use datafusion_common::{exec_err, internal_err, utils::take_function_args, Result}; use datafusion_expr::{ ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, }; @@ -131,12 +131,11 @@ impl ScalarUDFImpl for ArrayUnion { } } - fn invoke_batch( + fn invoke_with_args( &self, - args: &[ColumnarValue], - _number_rows: usize, + args: datafusion_expr::ScalarFunctionArgs, ) -> Result { - make_scalar_function(array_union_inner)(args) + make_scalar_function(array_union_inner)(&args.args) } fn aliases(&self) -> &[String] { @@ -211,12 +210,11 @@ impl ScalarUDFImpl for ArrayIntersect { } } - fn invoke_batch( + fn invoke_with_args( &self, - args: &[ColumnarValue], - _number_rows: usize, + args: datafusion_expr::ScalarFunctionArgs, ) -> Result { - make_scalar_function(array_intersect_inner)(args) + make_scalar_function(array_intersect_inner)(&args.args) } fn aliases(&self) -> &[String] { @@ -288,12 +286,11 @@ impl ScalarUDFImpl for ArrayDistinct { } } - fn invoke_batch( + fn invoke_with_args( &self, - args: &[ColumnarValue], - _number_rows: usize, + args: datafusion_expr::ScalarFunctionArgs, ) -> Result { - make_scalar_function(array_distinct_inner)(args) + make_scalar_function(array_distinct_inner)(&args.args) } fn aliases(&self) -> &[String] { @@ -308,23 +305,21 @@ impl ScalarUDFImpl for ArrayDistinct { /// array_distinct SQL function /// example: from list [1, 3, 2, 3, 1, 2, 4] to [1, 2, 3, 4] fn array_distinct_inner(args: &[ArrayRef]) -> Result { - if args.len() != 1 { - return exec_err!("array_distinct needs one argument"); - } + let [input_array] = take_function_args("array_distinct", args)?; // handle null - if args[0].data_type() == &Null { - return Ok(Arc::clone(&args[0])); + if input_array.data_type() == &Null { + return Ok(Arc::clone(input_array)); } // handle for list & largelist - match args[0].data_type() { + match input_array.data_type() { List(field) => { - let array = as_list_array(&args[0])?; + let array = as_list_array(&input_array)?; general_array_distinct(array, field) } LargeList(field) => { - let array = as_large_list_array(&args[0])?; + let array = as_large_list_array(&input_array)?; general_array_distinct(array, field) } array_type => exec_err!("array_distinct does not support type '{array_type:?}'"), @@ -488,24 +483,13 @@ fn general_set_op( /// Array_union SQL function fn array_union_inner(args: &[ArrayRef]) -> Result { - if args.len() != 2 { - return exec_err!("array_union needs two arguments"); - } - let array1 = &args[0]; - let array2 = &args[1]; - + let [array1, array2] = take_function_args("array_union", args)?; general_set_op(array1, array2, SetOp::Union) } /// array_intersect SQL function fn array_intersect_inner(args: &[ArrayRef]) -> Result { - if args.len() != 2 { - return exec_err!("array_intersect needs two arguments"); - } - - let array1 = &args[0]; - let array2 = &args[1]; - + let [array1, array2] = take_function_args("array_intersect", args)?; general_set_op(array1, array2, SetOp::Intersect) } diff --git a/datafusion/functions-nested/src/sort.rs b/datafusion/functions-nested/src/sort.rs index e4dcc02286f3..7dbf9f2b211e 100644 --- a/datafusion/functions-nested/src/sort.rs +++ b/datafusion/functions-nested/src/sort.rs @@ -20,9 +20,9 @@ use crate::utils::make_scalar_function; use arrow::array::{Array, ArrayRef, ListArray, NullBufferBuilder}; use arrow::buffer::OffsetBuffer; -use arrow::compute; -use 
arrow_schema::DataType::{FixedSizeList, LargeList, List}; -use arrow_schema::{DataType, Field, SortOptions}; +use arrow::datatypes::DataType::{FixedSizeList, LargeList, List}; +use arrow::datatypes::{DataType, Field}; +use arrow::{compute, compute::SortOptions}; use datafusion_common::cast::{as_list_array, as_string_array}; use datafusion_common::{exec_err, Result}; use datafusion_expr::{ @@ -121,12 +121,11 @@ impl ScalarUDFImpl for ArraySort { } } - fn invoke_batch( + fn invoke_with_args( &self, - args: &[ColumnarValue], - _number_rows: usize, + args: datafusion_expr::ScalarFunctionArgs, ) -> Result { - make_scalar_function(array_sort_inner)(args) + make_scalar_function(array_sort_inner)(&args.args) } fn aliases(&self) -> &[String] { diff --git a/datafusion/functions-nested/src/string.rs b/datafusion/functions-nested/src/string.rs index 1a0676aa39d5..99af3e95c804 100644 --- a/datafusion/functions-nested/src/string.rs +++ b/datafusion/functions-nested/src/string.rs @@ -39,7 +39,7 @@ use arrow::array::{ GenericStringArray, StringArrayType, StringViewArray, }; use arrow::compute::cast; -use arrow_schema::DataType::{ +use arrow::datatypes::DataType::{ Dictionary, FixedSizeList, LargeList, LargeUtf8, List, Null, Utf8, Utf8View, }; use datafusion_common::cast::{as_large_list_array, as_list_array}; @@ -192,12 +192,11 @@ impl ScalarUDFImpl for ArrayToString { }) } - fn invoke_batch( + fn invoke_with_args( &self, - args: &[ColumnarValue], - _number_rows: usize, + args: datafusion_expr::ScalarFunctionArgs, ) -> Result { - make_scalar_function(array_to_string_inner)(args) + make_scalar_function(array_to_string_inner)(&args.args) } fn aliases(&self) -> &[String] { @@ -286,11 +285,11 @@ impl ScalarUDFImpl for StringToArray { }) } - fn invoke_batch( + fn invoke_with_args( &self, - args: &[ColumnarValue], - _number_rows: usize, + args: datafusion_expr::ScalarFunctionArgs, ) -> Result { + let args = &args.args; match args[0].data_type() { Utf8 | Utf8View => make_scalar_function(string_to_array_inner::)(args), LargeUtf8 => make_scalar_function(string_to_array_inner::)(args), diff --git a/datafusion/functions-nested/src/utils.rs b/datafusion/functions-nested/src/utils.rs index 5dd812a23b9a..74b21a3ceb47 100644 --- a/datafusion/functions-nested/src/utils.rs +++ b/datafusion/functions-nested/src/utils.rs @@ -19,14 +19,13 @@ use std::sync::Arc; -use arrow::datatypes::DataType; +use arrow::datatypes::{DataType, Field, Fields}; use arrow::array::{ Array, ArrayRef, BooleanArray, GenericListArray, ListArray, OffsetSizeTrait, Scalar, UInt32Array, }; use arrow::buffer::OffsetBuffer; -use arrow_schema::{Field, Fields}; use datafusion_common::cast::{as_large_list_array, as_list_array}; use datafusion_common::{ exec_err, internal_datafusion_err, internal_err, plan_err, Result, ScalarValue, diff --git a/datafusion/functions/Cargo.toml b/datafusion/functions/Cargo.toml index db3e6838f6a5..c00997853bb3 100644 --- a/datafusion/functions/Cargo.toml +++ b/datafusion/functions/Cargo.toml @@ -56,7 +56,7 @@ regex_expressions = ["regex"] # enable string functions string_expressions = ["uuid"] # enable unicode functions -unicode_expressions = ["hashbrown", "unicode-segmentation"] +unicode_expressions = ["unicode-segmentation"] [lib] name = "datafusion_functions" @@ -69,7 +69,7 @@ arrow = { workspace = true } arrow-buffer = { workspace = true } base64 = { version = "0.22", optional = true } blake2 = { version = "^0.10.2", optional = true } -blake3 = { version = "1.0", optional = true } +blake3 = { version = "1.6", optional = true 
} chrono = { workspace = true } datafusion-common = { workspace = true } datafusion-doc = { workspace = true } @@ -77,7 +77,6 @@ datafusion-execution = { workspace = true } datafusion-expr = { workspace = true } datafusion-expr-common = { workspace = true } datafusion-macros = { workspace = true } -hashbrown = { workspace = true, optional = true } hex = { version = "0.4", optional = true } itertools = { workspace = true } log = { workspace = true } @@ -86,7 +85,7 @@ rand = { workspace = true } regex = { workspace = true, optional = true } sha2 = { version = "^0.10.1", optional = true } unicode-segmentation = { version = "^1.7.1", optional = true } -uuid = { version = "1.7", features = ["v4"], optional = true } +uuid = { version = "1.13", features = ["v4"], optional = true } [dev-dependencies] arrow = { workspace = true, features = ["test_utils"] } @@ -109,6 +108,21 @@ harness = false name = "encoding" required-features = ["encoding_expressions"] +[[bench]] +harness = false +name = "chr" +required-features = ["string_expressions"] + +[[bench]] +harness = false +name = "uuid" +required-features = ["string_expressions"] + +[[bench]] +harness = false +name = "to_hex" +required-features = ["string_expressions"] + [[bench]] harness = false name = "regx" @@ -134,6 +148,11 @@ harness = false name = "date_bin" required-features = ["datetime_expressions"] +[[bench]] +harness = false +name = "date_trunc" +required-features = ["datetime_expressions"] + [[bench]] harness = false name = "to_char" diff --git a/datafusion/functions/benches/chr.rs b/datafusion/functions/benches/chr.rs new file mode 100644 index 000000000000..58c5ee3d68f6 --- /dev/null +++ b/datafusion/functions/benches/chr.rs @@ -0,0 +1,52 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +extern crate criterion; + +use arrow::{array::PrimitiveArray, datatypes::Int64Type, util::test_util::seedable_rng}; +use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use datafusion_expr::ColumnarValue; +use datafusion_functions::string::chr; +use rand::Rng; + +use std::sync::Arc; + +fn criterion_benchmark(c: &mut Criterion) { + let cot_fn = chr(); + let size = 1024; + let input: PrimitiveArray = { + let null_density = 0.2; + let mut rng = seedable_rng(); + (0..size) + .map(|_| { + if rng.gen::() < null_density { + None + } else { + Some(rng.gen_range::(1i64..10_000)) + } + }) + .collect() + }; + let input = Arc::new(input); + let args = vec![ColumnarValue::Array(input)]; + c.bench_function("chr", |b| { + b.iter(|| black_box(cot_fn.invoke_batch(&args, size).unwrap())) + }); +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/functions/benches/concat.rs b/datafusion/functions/benches/concat.rs index 0f287ab36dad..45ca076e754f 100644 --- a/datafusion/functions/benches/concat.rs +++ b/datafusion/functions/benches/concat.rs @@ -16,10 +16,11 @@ // under the License. use arrow::array::ArrayRef; +use arrow::datatypes::DataType; use arrow::util::bench_util::create_string_array_with_len; use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion}; use datafusion_common::ScalarValue; -use datafusion_expr::ColumnarValue; +use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; use datafusion_functions::string::concat; use std::sync::Arc; @@ -39,8 +40,16 @@ fn criterion_benchmark(c: &mut Criterion) { let mut group = c.benchmark_group("concat function"); group.bench_function(BenchmarkId::new("concat", size), |b| { b.iter(|| { - // TODO use invoke_with_args - criterion::black_box(concat().invoke_batch(&args, size).unwrap()) + let args_cloned = args.clone(); + criterion::black_box( + concat() + .invoke_with_args(ScalarFunctionArgs { + args: args_cloned, + number_rows: size, + return_type: &DataType::Utf8, + }) + .unwrap(), + ) }) }); group.finish(); diff --git a/datafusion/functions/benches/date_trunc.rs b/datafusion/functions/benches/date_trunc.rs new file mode 100644 index 000000000000..d420b8f6ac70 --- /dev/null +++ b/datafusion/functions/benches/date_trunc.rs @@ -0,0 +1,60 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
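
For orientation, a minimal sketch (not part of this patch) of the struct the migrated benchmarks above now construct: ScalarFunctionArgs bundles the evaluated argument values, the logical row count, and the resolved return type that invoke_batch previously received separately. The single-row concat call below is hypothetical.

use arrow::datatypes::DataType;
use datafusion_common::{Result, ScalarValue};
use datafusion_expr::{ColumnarValue, ScalarFunctionArgs};
use datafusion_functions::string::concat;

fn call_concat_once() -> Result<ColumnarValue> {
    let args = vec![
        ColumnarValue::Scalar(ScalarValue::from("hello ")),
        ColumnarValue::Scalar(ScalarValue::from("world")),
    ];
    concat().invoke_with_args(ScalarFunctionArgs {
        args,                         // evaluated argument values
        number_rows: 1,               // logical row count of the batch
        return_type: &DataType::Utf8, // return type resolved for this call
    })
}
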
+ +extern crate criterion; + +use std::sync::Arc; + +use arrow::array::{Array, ArrayRef, TimestampSecondArray}; +use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use datafusion_common::ScalarValue; +use rand::rngs::ThreadRng; +use rand::Rng; + +use datafusion_expr::ColumnarValue; +use datafusion_functions::datetime::date_trunc; + +fn timestamps(rng: &mut ThreadRng) -> TimestampSecondArray { + let mut seconds = vec![]; + for _ in 0..1000 { + seconds.push(rng.gen_range(0..1_000_000)); + } + + TimestampSecondArray::from(seconds) +} + +fn criterion_benchmark(c: &mut Criterion) { + c.bench_function("date_trunc_minute_1000", |b| { + let mut rng = rand::thread_rng(); + let timestamps_array = Arc::new(timestamps(&mut rng)) as ArrayRef; + let batch_len = timestamps_array.len(); + let precision = + ColumnarValue::Scalar(ScalarValue::Utf8(Some("minute".to_string()))); + let timestamps = ColumnarValue::Array(timestamps_array); + let udf = date_trunc(); + + b.iter(|| { + black_box( + udf.invoke_batch(&[precision.clone(), timestamps.clone()], batch_len) + .expect("date_trunc should work on valid values"), + ) + }) + }); +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/functions/benches/lower.rs b/datafusion/functions/benches/lower.rs index 114ac4a16fe5..534e5739225d 100644 --- a/datafusion/functions/benches/lower.rs +++ b/datafusion/functions/benches/lower.rs @@ -18,11 +18,12 @@ extern crate criterion; use arrow::array::{ArrayRef, StringArray, StringViewBuilder}; +use arrow::datatypes::DataType; use arrow::util::bench_util::{ create_string_array_with_len, create_string_view_array_with_len, }; use criterion::{black_box, criterion_group, criterion_main, Criterion}; -use datafusion_expr::ColumnarValue; +use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; use datafusion_functions::string; use std::sync::Arc; @@ -125,8 +126,12 @@ fn criterion_benchmark(c: &mut Criterion) { let args = create_args1(size, 32); c.bench_function(&format!("lower_all_values_are_ascii: {}", size), |b| { b.iter(|| { - // TODO use invoke_with_args - black_box(lower.invoke_batch(&args, size)) + let args_cloned = args.clone(); + black_box(lower.invoke_with_args(ScalarFunctionArgs { + args: args_cloned, + number_rows: size, + return_type: &DataType::Utf8, + })) }) }); @@ -135,8 +140,12 @@ fn criterion_benchmark(c: &mut Criterion) { &format!("lower_the_first_value_is_nonascii: {}", size), |b| { b.iter(|| { - // TODO use invoke_with_args - black_box(lower.invoke_batch(&args, size)) + let args_cloned = args.clone(); + black_box(lower.invoke_with_args(ScalarFunctionArgs { + args: args_cloned, + number_rows: size, + return_type: &DataType::Utf8, + })) }) }, ); @@ -146,8 +155,12 @@ fn criterion_benchmark(c: &mut Criterion) { &format!("lower_the_middle_value_is_nonascii: {}", size), |b| { b.iter(|| { - // TODO use invoke_with_args - black_box(lower.invoke_batch(&args, size)) + let args_cloned = args.clone(); + black_box(lower.invoke_with_args(ScalarFunctionArgs { + args: args_cloned, + number_rows: size, + return_type: &DataType::Utf8, + })) }) }, ); @@ -167,8 +180,12 @@ fn criterion_benchmark(c: &mut Criterion) { &format!("lower_all_values_are_ascii_string_views: size: {}, str_len: {}, null_density: {}, mixed: {}", size, str_len, null_density, mixed), |b| b.iter(|| { - // TODO use invoke_with_args - black_box(lower.invoke_batch(&args, size)) + let args_cloned = args.clone(); + black_box(lower.invoke_with_args(ScalarFunctionArgs{ + args: args_cloned, + 
number_rows: size, + return_type: &DataType::Utf8, + })) }), ); @@ -177,8 +194,12 @@ fn criterion_benchmark(c: &mut Criterion) { &format!("lower_all_values_are_ascii_string_views: size: {}, str_len: {}, null_density: {}, mixed: {}", size, str_len, null_density, mixed), |b| b.iter(|| { - // TODO use invoke_with_args - black_box(lower.invoke_batch(&args, size)) + let args_cloned = args.clone(); + black_box(lower.invoke_with_args(ScalarFunctionArgs{ + args: args_cloned, + number_rows: size, + return_type: &DataType::Utf8, + })) }), ); @@ -187,8 +208,12 @@ fn criterion_benchmark(c: &mut Criterion) { &format!("lower_some_values_are_nonascii_string_views: size: {}, str_len: {}, non_ascii_density: {}, null_density: {}, mixed: {}", size, str_len, 0.1, null_density, mixed), |b| b.iter(|| { - // TODO use invoke_with_args - black_box(lower.invoke_batch(&args, size)) + let args_cloned = args.clone(); + black_box(lower.invoke_with_args(ScalarFunctionArgs{ + args: args_cloned, + number_rows: size, + return_type: &DataType::Utf8, + })) }), ); } diff --git a/datafusion/functions/benches/ltrim.rs b/datafusion/functions/benches/ltrim.rs index fed455eeac91..457fb499f5a1 100644 --- a/datafusion/functions/benches/ltrim.rs +++ b/datafusion/functions/benches/ltrim.rs @@ -18,12 +18,13 @@ extern crate criterion; use arrow::array::{ArrayRef, LargeStringArray, StringArray, StringViewArray}; +use arrow::datatypes::DataType; use criterion::{ black_box, criterion_group, criterion_main, measurement::Measurement, BenchmarkGroup, Criterion, SamplingMode, }; use datafusion_common::ScalarValue; -use datafusion_expr::{ColumnarValue, ScalarUDF}; +use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDF}; use datafusion_functions::string; use rand::{distributions::Alphanumeric, rngs::StdRng, Rng, SeedableRng}; use std::{fmt, sync::Arc}; @@ -141,8 +142,12 @@ fn run_with_string_type( ), |b| { b.iter(|| { - // TODO use invoke_with_args - black_box(ltrim.invoke_batch(&args, size)) + let args_cloned = args.clone(); + black_box(ltrim.invoke_with_args(ScalarFunctionArgs { + args: args_cloned, + number_rows: size, + return_type: &DataType::Utf8, + })) }) }, ); diff --git a/datafusion/functions/benches/repeat.rs b/datafusion/functions/benches/repeat.rs index e7e3c634ea82..5cc6a177d9d9 100644 --- a/datafusion/functions/benches/repeat.rs +++ b/datafusion/functions/benches/repeat.rs @@ -18,11 +18,12 @@ extern crate criterion; use arrow::array::{ArrayRef, Int64Array, OffsetSizeTrait}; +use arrow::datatypes::DataType; use arrow::util::bench_util::{ create_string_array_with_len, create_string_view_array_with_len, }; use criterion::{black_box, criterion_group, criterion_main, Criterion, SamplingMode}; -use datafusion_expr::ColumnarValue; +use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; use datafusion_functions::string; use std::sync::Arc; use std::time::Duration; @@ -73,8 +74,12 @@ fn criterion_benchmark(c: &mut Criterion) { ), |b| { b.iter(|| { - // TODO use invoke_with_args - black_box(repeat.invoke_batch(&args, repeat_times as usize)) + let args_cloned = args.clone(); + black_box(repeat.invoke_with_args(ScalarFunctionArgs { + args: args_cloned, + number_rows: repeat_times as usize, + return_type: &DataType::Utf8, + })) }) }, ); @@ -87,8 +92,12 @@ fn criterion_benchmark(c: &mut Criterion) { ), |b| { b.iter(|| { - // TODO use invoke_with_args - black_box(repeat.invoke_batch(&args, repeat_times as usize)) + let args_cloned = args.clone(); + black_box(repeat.invoke_with_args(ScalarFunctionArgs { + args: args_cloned, + 
number_rows: repeat_times as usize, + return_type: &DataType::Utf8, + })) }) }, ); @@ -101,8 +110,12 @@ fn criterion_benchmark(c: &mut Criterion) { ), |b| { b.iter(|| { - // TODO use invoke_with_args - black_box(repeat.invoke_batch(&args, repeat_times as usize)) + let args_cloned = args.clone(); + black_box(repeat.invoke_with_args(ScalarFunctionArgs { + args: args_cloned, + number_rows: repeat_times as usize, + return_type: &DataType::Utf8, + })) }) }, ); @@ -124,8 +137,12 @@ fn criterion_benchmark(c: &mut Criterion) { ), |b| { b.iter(|| { - // TODO use invoke_with_args - black_box(repeat.invoke_batch(&args, repeat_times as usize)) + let args_cloned = args.clone(); + black_box(repeat.invoke_with_args(ScalarFunctionArgs { + args: args_cloned, + number_rows: repeat_times as usize, + return_type: &DataType::Utf8, + })) }) }, ); @@ -138,8 +155,12 @@ fn criterion_benchmark(c: &mut Criterion) { ), |b| { b.iter(|| { - // TODO use invoke_with_args - black_box(repeat.invoke_batch(&args, size)) + let args_cloned = args.clone(); + black_box(repeat.invoke_with_args(ScalarFunctionArgs { + args: args_cloned, + number_rows: repeat_times as usize, + return_type: &DataType::Utf8, + })) }) }, ); @@ -152,8 +173,39 @@ fn criterion_benchmark(c: &mut Criterion) { ), |b| { b.iter(|| { - // TODO use invoke_with_args - black_box(repeat.invoke_batch(&args, repeat_times as usize)) + let args_cloned = args.clone(); + black_box(repeat.invoke_with_args(ScalarFunctionArgs { + args: args_cloned, + number_rows: repeat_times as usize, + return_type: &DataType::Utf8, + })) + }) + }, + ); + + group.finish(); + + // REPEAT overflow + let repeat_times = 1073741824; + let mut group = c.benchmark_group(format!("repeat {} times", repeat_times)); + group.sampling_mode(SamplingMode::Flat); + group.sample_size(10); + group.measurement_time(Duration::from_secs(10)); + + let args = create_args::(size, 2, repeat_times, false); + group.bench_function( + format!( + "repeat_string overflow [size={}, repeat_times={}]", + size, repeat_times + ), + |b| { + b.iter(|| { + let args_cloned = args.clone(); + black_box(repeat.invoke_with_args(ScalarFunctionArgs { + args: args_cloned, + number_rows: repeat_times as usize, + return_type: &DataType::Utf8, + })) }) }, ); diff --git a/datafusion/functions/benches/to_hex.rs b/datafusion/functions/benches/to_hex.rs new file mode 100644 index 000000000000..a45d936c0a52 --- /dev/null +++ b/datafusion/functions/benches/to_hex.rs @@ -0,0 +1,65 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +extern crate criterion; + +use arrow::datatypes::{DataType, Int32Type, Int64Type}; +use arrow::util::bench_util::create_primitive_array; +use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; +use datafusion_functions::string; +use std::sync::Arc; + +fn criterion_benchmark(c: &mut Criterion) { + let hex = string::to_hex(); + let size = 1024; + let i32_array = Arc::new(create_primitive_array::(size, 0.2)); + let batch_len = i32_array.len(); + let i32_args = vec![ColumnarValue::Array(i32_array)]; + c.bench_function(&format!("to_hex i32 array: {}", size), |b| { + b.iter(|| { + let args_cloned = i32_args.clone(); + black_box( + hex.invoke_with_args(ScalarFunctionArgs { + args: args_cloned, + number_rows: batch_len, + return_type: &DataType::Utf8, + }) + .unwrap(), + ) + }) + }); + let i64_array = Arc::new(create_primitive_array::(size, 0.2)); + let batch_len = i64_array.len(); + let i64_args = vec![ColumnarValue::Array(i64_array)]; + c.bench_function(&format!("to_hex i64 array: {}", size), |b| { + b.iter(|| { + let args_cloned = i64_args.clone(); + black_box( + hex.invoke_with_args(ScalarFunctionArgs { + args: args_cloned, + number_rows: batch_len, + return_type: &DataType::Utf8, + }) + .unwrap(), + ) + }) + }); +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/functions/benches/upper.rs b/datafusion/functions/benches/upper.rs index 9b41a15b11c7..f0bee89c7d37 100644 --- a/datafusion/functions/benches/upper.rs +++ b/datafusion/functions/benches/upper.rs @@ -17,9 +17,10 @@ extern crate criterion; +use arrow::datatypes::DataType; use arrow::util::bench_util::create_string_array_with_len; use criterion::{black_box, criterion_group, criterion_main, Criterion}; -use datafusion_expr::ColumnarValue; +use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; use datafusion_functions::string; use std::sync::Arc; @@ -38,8 +39,12 @@ fn criterion_benchmark(c: &mut Criterion) { let args = create_args(size, 32); c.bench_function("upper_all_values_are_ascii", |b| { b.iter(|| { - // TODO use invoke_with_args - black_box(upper.invoke_batch(&args, size)) + let args_cloned = args.clone(); + black_box(upper.invoke_with_args(ScalarFunctionArgs { + args: args_cloned, + number_rows: size, + return_type: &DataType::Utf8, + })) }) }); } diff --git a/datafusion/functions/benches/uuid.rs b/datafusion/functions/benches/uuid.rs new file mode 100644 index 000000000000..7b8d156fec21 --- /dev/null +++ b/datafusion/functions/benches/uuid.rs @@ -0,0 +1,39 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +extern crate criterion; + +use arrow::datatypes::DataType; +use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use datafusion_expr::ScalarFunctionArgs; +use datafusion_functions::string; + +fn criterion_benchmark(c: &mut Criterion) { + let uuid = string::uuid(); + c.bench_function("uuid", |b| { + b.iter(|| { + black_box(uuid.invoke_with_args(ScalarFunctionArgs { + args: vec![], + number_rows: 1024, + return_type: &DataType::Utf8, + })) + }) + }); +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/functions/src/core/arrow_cast.rs b/datafusion/functions/src/core/arrow_cast.rs index 0f9f11b4eff0..2686dbf8be3c 100644 --- a/datafusion/functions/src/core/arrow_cast.rs +++ b/datafusion/functions/src/core/arrow_cast.rs @@ -22,14 +22,15 @@ use arrow::error::ArrowError; use datafusion_common::{ arrow_datafusion_err, exec_err, internal_err, Result, ScalarValue, }; -use datafusion_common::{exec_datafusion_err, DataFusionError}; +use datafusion_common::{ + exec_datafusion_err, utils::take_function_args, DataFusionError, +}; use std::any::Any; -use crate::utils::take_function_args; use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo}; use datafusion_expr::{ - ColumnarValue, Documentation, Expr, ReturnInfo, ReturnTypeArgs, ScalarUDFImpl, - Signature, Volatility, + ColumnarValue, Documentation, Expr, ReturnInfo, ReturnTypeArgs, ScalarFunctionArgs, + ScalarUDFImpl, Signature, Volatility, }; use datafusion_macros::user_doc; @@ -137,11 +138,7 @@ impl ScalarUDFImpl for ArrowCastFunc { ) } - fn invoke_batch( - &self, - _args: &[ColumnarValue], - _number_rows: usize, - ) -> Result { + fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result { internal_err!("arrow_cast should have been simplified to cast") } @@ -178,13 +175,12 @@ impl ScalarUDFImpl for ArrowCastFunc { /// Returns the requested type from the arguments fn data_type_from_args(args: &[Expr]) -> Result { - if args.len() != 2 { - return exec_err!("arrow_cast needs 2 arguments, {} provided", args.len()); - } - let Expr::Literal(ScalarValue::Utf8(Some(val))) = &args[1] else { + let [_, type_arg] = take_function_args("arrow_cast", args)?; + + let Expr::Literal(ScalarValue::Utf8(Some(val))) = type_arg else { return exec_err!( "arrow_cast requires its second argument to be a constant string, got {:?}", - &args[1] + type_arg ); }; diff --git a/datafusion/functions/src/core/arrowtypeof.rs b/datafusion/functions/src/core/arrowtypeof.rs index 3c672384ffa1..2509ed246ac7 100644 --- a/datafusion/functions/src/core/arrowtypeof.rs +++ b/datafusion/functions/src/core/arrowtypeof.rs @@ -15,10 +15,9 @@ // specific language governing permissions and limitations // under the License. 
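
A small sketch (not part of this patch) of the argument shape the rewritten data_type_from_args in arrow_cast.rs above accepts: the second argument must be a Utf8 literal naming the target type, and any non-literal expression is rejected. The column name "c1" is hypothetical.

use datafusion_expr::{col, lit, Expr};

fn arrow_cast_args() -> Vec<Expr> {
    // Well-formed: the type is supplied as a constant string literal.
    vec![col("c1"), lit("Int64")]
}
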
-use crate::utils::take_function_args;
 use arrow::datatypes::DataType;
-use datafusion_common::{Result, ScalarValue};
-use datafusion_expr::{ColumnarValue, Documentation};
+use datafusion_common::{utils::take_function_args, Result, ScalarValue};
+use datafusion_expr::{ColumnarValue, Documentation, ScalarFunctionArgs};
 use datafusion_expr::{ScalarUDFImpl, Signature, Volatility};
 use datafusion_macros::user_doc;
 use std::any::Any;
@@ -76,12 +75,8 @@ impl ScalarUDFImpl for ArrowTypeOfFunc {
         Ok(DataType::Utf8)
     }
 
-    fn invoke_batch(
-        &self,
-        args: &[ColumnarValue],
-        _number_rows: usize,
-    ) -> Result<ColumnarValue> {
-        let [arg] = take_function_args(self.name(), args)?;
+    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
+        let [arg] = take_function_args(self.name(), args.args)?;
         let input_data_type = arg.data_type();
         Ok(ColumnarValue::Scalar(ScalarValue::from(format!(
             "{input_data_type}"
diff --git a/datafusion/functions/src/core/coalesce.rs b/datafusion/functions/src/core/coalesce.rs
index 602fe0fd9585..ba20c23828eb 100644
--- a/datafusion/functions/src/core/coalesce.rs
+++ b/datafusion/functions/src/core/coalesce.rs
@@ -21,7 +21,9 @@ use arrow::compute::{and, is_not_null, is_null};
 use arrow::datatypes::DataType;
 use datafusion_common::{exec_err, internal_err, Result};
 use datafusion_expr::binary::try_type_union_resolution;
-use datafusion_expr::{ColumnarValue, Documentation, ReturnInfo, ReturnTypeArgs};
+use datafusion_expr::{
+    ColumnarValue, Documentation, ReturnInfo, ReturnTypeArgs, ScalarFunctionArgs,
+};
 use datafusion_expr::{ScalarUDFImpl, Signature, Volatility};
 use datafusion_macros::user_doc;
 use itertools::Itertools;
@@ -93,11 +95,8 @@ impl ScalarUDFImpl for CoalesceFunc {
     }
 
     /// coalesce evaluates to the first value which is not NULL
-    fn invoke_batch(
-        &self,
-        args: &[ColumnarValue],
-        _number_rows: usize,
-    ) -> Result<ColumnarValue> {
+    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
+        let args = args.args;
         // do not accept 0 arguments.
         if args.is_empty() {
             return exec_err!(
diff --git a/datafusion/functions/src/core/getfield.rs b/datafusion/functions/src/core/getfield.rs
index 8533b3123d51..3ac26b98359b 100644
--- a/datafusion/functions/src/core/getfield.rs
+++ b/datafusion/functions/src/core/getfield.rs
@@ -15,16 +15,21 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use crate::utils::take_function_args; use arrow::array::{ - make_array, Array, Capacities, MutableArrayData, Scalar, StringArray, + make_array, make_comparator, Array, BooleanArray, Capacities, MutableArrayData, + Scalar, }; +use arrow::compute::SortOptions; use arrow::datatypes::DataType; +use arrow_buffer::NullBuffer; use datafusion_common::cast::{as_map_array, as_struct_array}; use datafusion_common::{ - exec_err, internal_err, plan_datafusion_err, Result, ScalarValue, + exec_err, internal_err, plan_datafusion_err, utils::take_function_args, Result, + ScalarValue, +}; +use datafusion_expr::{ + ColumnarValue, Documentation, Expr, ReturnInfo, ReturnTypeArgs, ScalarFunctionArgs, }; -use datafusion_expr::{ColumnarValue, Documentation, Expr, ReturnInfo, ReturnTypeArgs}; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; use datafusion_macros::user_doc; use std::any::Any; @@ -104,34 +109,20 @@ impl ScalarUDFImpl for GetFieldFunc { let name = match field_name { Expr::Literal(name) => name, - _ => { - return exec_err!( - "get_field function requires the argument field_name to be a string" - ); - } + other => &ScalarValue::Utf8(Some(other.schema_name().to_string())), }; Ok(format!("{base}[{name}]")) } fn schema_name(&self, args: &[Expr]) -> Result { - if args.len() != 2 { - return exec_err!( - "get_field function requires 2 arguments, got {}", - args.len() - ); - } - - let name = match &args[1] { + let [base, field_name] = take_function_args(self.name(), args)?; + let name = match field_name { Expr::Literal(name) => name, - _ => { - return exec_err!( - "get_field function requires the argument field_name to be a string" - ); - } + other => &ScalarValue::Utf8(Some(other.schema_name().to_string())), }; - Ok(format!("{}[{}]", args[0].schema_name(), name)) + Ok(format!("{}[{}]", base.schema_name(), name)) } fn signature(&self) -> &Signature { @@ -175,26 +166,17 @@ impl ScalarUDFImpl for GetFieldFunc { } } - fn invoke_batch( - &self, - args: &[ColumnarValue], - _number_rows: usize, - ) -> Result { - if args.len() != 2 { - return exec_err!( - "get_field function requires 2 arguments, got {}", - args.len() - ); - } + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + let [base, field_name] = take_function_args(self.name(), args.args)?; - if args[0].data_type().is_null() { + if base.data_type().is_null() { return Ok(ColumnarValue::Scalar(ScalarValue::Null)); } - let arrays = ColumnarValue::values_to_arrays(args)?; + let arrays = + ColumnarValue::values_to_arrays(&[base.clone(), field_name.clone()])?; let array = Arc::clone(&arrays[0]); - - let name = match &args[1] { + let name = match field_name { ColumnarValue::Scalar(name) => name, _ => { return exec_err!( @@ -203,42 +185,74 @@ impl ScalarUDFImpl for GetFieldFunc { } }; - match (array.data_type(), name) { - (DataType::Map(_, _), ScalarValue::Utf8(Some(k))) => { - let map_array = as_map_array(array.as_ref())?; - let key_scalar: Scalar>> = Scalar::new(StringArray::from(vec![k.clone()])); - let keys = arrow::compute::kernels::cmp::eq(&key_scalar, map_array.keys())?; + fn process_map_array( + array: Arc, + key_array: Arc, + ) -> Result { + let map_array = as_map_array(array.as_ref())?; + let keys = if key_array.data_type().is_nested() { + let comparator = make_comparator( + map_array.keys().as_ref(), + key_array.as_ref(), + SortOptions::default(), + )?; + let len = map_array.keys().len().min(key_array.len()); + let values = (0..len).map(|i| comparator(i, i).is_eq()).collect(); + let nulls = + 
NullBuffer::union(map_array.keys().nulls(), key_array.nulls()); + BooleanArray::new(values, nulls) + } else { + let be_compared = Scalar::new(key_array); + arrow::compute::kernels::cmp::eq(&be_compared, map_array.keys())? + }; + + let original_data = map_array.entries().column(1).to_data(); + let capacity = Capacities::Array(original_data.len()); + let mut mutable = + MutableArrayData::with_capacities(vec![&original_data], true, capacity); + + for entry in 0..map_array.len() { + let start = map_array.value_offsets()[entry] as usize; + let end = map_array.value_offsets()[entry + 1] as usize; - // note that this array has more entries than the expected output/input size - // because map_array is flattened - let original_data = map_array.entries().column(1).to_data(); - let capacity = Capacities::Array(original_data.len()); - let mut mutable = - MutableArrayData::with_capacities(vec![&original_data], true, - capacity); + let maybe_matched = keys + .slice(start, end - start) + .iter() + .enumerate() + .find(|(_, t)| t.unwrap()); - for entry in 0..map_array.len(){ - let start = map_array.value_offsets()[entry] as usize; - let end = map_array.value_offsets()[entry + 1] as usize; + if maybe_matched.is_none() { + mutable.extend_nulls(1); + continue; + } + let (match_offset, _) = maybe_matched.unwrap(); + mutable.extend(0, start + match_offset, start + match_offset + 1); + } + + let data = mutable.freeze(); + let data = make_array(data); + Ok(ColumnarValue::Array(data)) + } - let maybe_matched = - keys.slice(start, end-start). - iter().enumerate(). - find(|(_, t)| t.unwrap()); - if maybe_matched.is_none() { - mutable.extend_nulls(1); - continue - } - let (match_offset,_) = maybe_matched.unwrap(); - mutable.extend(0, start + match_offset, start + match_offset + 1); + match (array.data_type(), name) { + (DataType::Map(_, _), ScalarValue::List(arr)) => { + let key_array: Arc = arr; + process_map_array(array, key_array) + } + (DataType::Map(_, _), ScalarValue::Struct(arr)) => { + process_map_array(array, arr as Arc) + } + (DataType::Map(_, _), other) => { + let data_type = other.data_type(); + if data_type.is_nested() { + exec_err!("unsupported type {:?} for map access", data_type) + } else { + process_map_array(array, other.to_array()?) 
} - let data = mutable.freeze(); - let data = make_array(data); - Ok(ColumnarValue::Array(data)) } (DataType::Struct(_), ScalarValue::Utf8(Some(k))) => { let as_struct_array = as_struct_array(&array)?; - match as_struct_array.column_by_name(k) { + match as_struct_array.column_by_name(&k) { None => exec_err!("get indexed field {k} not found in struct"), Some(col) => Ok(ColumnarValue::Array(Arc::clone(col))), } diff --git a/datafusion/functions/src/core/greatest.rs b/datafusion/functions/src/core/greatest.rs index 6864da2d5c06..2d7ad2be3986 100644 --- a/datafusion/functions/src/core/greatest.rs +++ b/datafusion/functions/src/core/greatest.rs @@ -23,7 +23,7 @@ use arrow::compute::SortOptions; use arrow::datatypes::DataType; use datafusion_common::{internal_err, Result, ScalarValue}; use datafusion_doc::Documentation; -use datafusion_expr::ColumnarValue; +use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; use datafusion_macros::user_doc; use std::any::Any; @@ -143,8 +143,8 @@ impl ScalarUDFImpl for GreatestFunc { Ok(arg_types[0].clone()) } - fn invoke(&self, args: &[ColumnarValue]) -> Result { - super::greatest_least_utils::execute_conditional::(args) + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + super::greatest_least_utils::execute_conditional::(&args.args) } fn coerce_types(&self, arg_types: &[DataType]) -> Result> { diff --git a/datafusion/functions/src/core/least.rs b/datafusion/functions/src/core/least.rs index a26b14babf2c..662dac3e699f 100644 --- a/datafusion/functions/src/core/least.rs +++ b/datafusion/functions/src/core/least.rs @@ -23,7 +23,7 @@ use arrow::compute::SortOptions; use arrow::datatypes::DataType; use datafusion_common::{internal_err, Result, ScalarValue}; use datafusion_doc::Documentation; -use datafusion_expr::ColumnarValue; +use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; use datafusion_macros::user_doc; use std::any::Any; @@ -156,8 +156,8 @@ impl ScalarUDFImpl for LeastFunc { Ok(arg_types[0].clone()) } - fn invoke(&self, args: &[ColumnarValue]) -> Result { - super::greatest_least_utils::execute_conditional::(args) + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + super::greatest_least_utils::execute_conditional::(&args.args) } fn coerce_types(&self, arg_types: &[DataType]) -> Result> { diff --git a/datafusion/functions/src/core/mod.rs b/datafusion/functions/src/core/mod.rs index 76fb4bbe5b47..425ce78decbe 100644 --- a/datafusion/functions/src/core/mod.rs +++ b/datafusion/functions/src/core/mod.rs @@ -34,6 +34,7 @@ pub mod nvl; pub mod nvl2; pub mod planner; pub mod r#struct; +pub mod union_extract; pub mod version; // create UDFs @@ -48,6 +49,7 @@ make_udf_function!(getfield::GetFieldFunc, get_field); make_udf_function!(coalesce::CoalesceFunc, coalesce); make_udf_function!(greatest::GreatestFunc, greatest); make_udf_function!(least::LeastFunc, least); +make_udf_function!(union_extract::UnionExtractFun, union_extract); make_udf_function!(version::VersionFunc, version); pub mod expr_fn { @@ -99,6 +101,11 @@ pub mod expr_fn { pub fn get_field(arg1: Expr, arg2: impl Literal) -> Expr { super::get_field().call(vec![arg1, arg2.lit()]) } + + #[doc = "Returns the value of the field with the given name from the union when it's selected, or NULL otherwise"] + pub fn union_extract(arg1: Expr, arg2: impl Literal) -> Expr { + super::union_extract().call(vec![arg1, arg2.lit()]) + } } /// Returns all 
DataFusion functions defined in this package @@ -121,6 +128,7 @@ pub fn functions() -> Vec> { coalesce(), greatest(), least(), + union_extract(), version(), r#struct(), ] diff --git a/datafusion/functions/src/core/nullif.rs b/datafusion/functions/src/core/nullif.rs index a0f3c8b8a452..ee29714da16b 100644 --- a/datafusion/functions/src/core/nullif.rs +++ b/datafusion/functions/src/core/nullif.rs @@ -16,13 +16,11 @@ // under the License. use arrow::datatypes::DataType; -use datafusion_common::Result; -use datafusion_expr::{ColumnarValue, Documentation}; +use datafusion_expr::{ColumnarValue, Documentation, ScalarFunctionArgs}; -use crate::utils::take_function_args; use arrow::compute::kernels::cmp::eq; use arrow::compute::kernels::nullif::nullif; -use datafusion_common::ScalarValue; +use datafusion_common::{utils::take_function_args, Result, ScalarValue}; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; use datafusion_macros::user_doc; use std::any::Any; @@ -103,12 +101,8 @@ impl ScalarUDFImpl for NullIfFunc { Ok(arg_types[0].to_owned()) } - fn invoke_batch( - &self, - args: &[ColumnarValue], - _number_rows: usize, - ) -> Result { - nullif_func(args) + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + nullif_func(&args.args) } fn documentation(&self) -> Option<&Documentation> { diff --git a/datafusion/functions/src/core/nvl.rs b/datafusion/functions/src/core/nvl.rs index 5b306c8093cb..82d367072a25 100644 --- a/datafusion/functions/src/core/nvl.rs +++ b/datafusion/functions/src/core/nvl.rs @@ -15,14 +15,14 @@ // specific language governing permissions and limitations // under the License. -use crate::utils::take_function_args; use arrow::array::Array; use arrow::compute::is_not_null; use arrow::compute::kernels::zip::zip; use arrow::datatypes::DataType; -use datafusion_common::Result; +use datafusion_common::{utils::take_function_args, Result}; use datafusion_expr::{ - ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, + ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, + Volatility, }; use datafusion_macros::user_doc; use std::sync::Arc; @@ -117,12 +117,8 @@ impl ScalarUDFImpl for NVLFunc { Ok(arg_types[0].clone()) } - fn invoke_batch( - &self, - args: &[ColumnarValue], - _number_rows: usize, - ) -> Result { - nvl_func(args) + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + nvl_func(&args.args) } fn aliases(&self) -> &[String] { diff --git a/datafusion/functions/src/core/nvl2.rs b/datafusion/functions/src/core/nvl2.rs index b1f8e4e5c213..d20b01e29fba 100644 --- a/datafusion/functions/src/core/nvl2.rs +++ b/datafusion/functions/src/core/nvl2.rs @@ -15,15 +15,14 @@ // specific language governing permissions and limitations // under the License. 
-use crate::utils::take_function_args; use arrow::array::Array; use arrow::compute::is_not_null; use arrow::compute::kernels::zip::zip; use arrow::datatypes::DataType; -use datafusion_common::{internal_err, Result}; +use datafusion_common::{internal_err, utils::take_function_args, Result}; use datafusion_expr::{ type_coercion::binary::comparison_coercion, ColumnarValue, Documentation, - ScalarUDFImpl, Signature, Volatility, + ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, }; use datafusion_macros::user_doc; use std::sync::Arc; @@ -96,12 +95,8 @@ impl ScalarUDFImpl for NVL2Func { Ok(arg_types[1].clone()) } - fn invoke_batch( - &self, - args: &[ColumnarValue], - _number_rows: usize, - ) -> Result { - nvl2_func(args) + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + nvl2_func(&args.args) } fn coerce_types(&self, arg_types: &[DataType]) -> Result> { diff --git a/datafusion/functions/src/core/union_extract.rs b/datafusion/functions/src/core/union_extract.rs new file mode 100644 index 000000000000..95814197d8df --- /dev/null +++ b/datafusion/functions/src/core/union_extract.rs @@ -0,0 +1,246 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::Array; +use arrow::datatypes::{DataType, FieldRef, UnionFields}; +use datafusion_common::cast::as_union_array; +use datafusion_common::utils::take_function_args; +use datafusion_common::{ + exec_datafusion_err, exec_err, internal_err, Result, ScalarValue, +}; +use datafusion_doc::Documentation; +use datafusion_expr::{ColumnarValue, ReturnInfo, ReturnTypeArgs, ScalarFunctionArgs}; +use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; +use datafusion_macros::user_doc; + +#[user_doc( + doc_section(label = "Union Functions"), + description = "Returns the value of the given field in the union when selected, or NULL otherwise.", + syntax_example = "union_extract(union, field_name)", + sql_example = r#"```sql +❯ select union_column, union_extract(union_column, 'a'), union_extract(union_column, 'b') from table_with_union; ++--------------+----------------------------------+----------------------------------+ +| union_column | union_extract(union_column, 'a') | union_extract(union_column, 'b') | ++--------------+----------------------------------+----------------------------------+ +| {a=1} | 1 | | +| {b=3.0} | | 3.0 | +| {a=4} | 4 | | +| {b=} | | | +| {a=} | | | ++--------------+----------------------------------+----------------------------------+ +```"#, + standard_argument(name = "union", prefix = "Union"), + argument( + name = "field_name", + description = "String expression to operate on. Must be a constant." 
+ ) +)] +#[derive(Debug)] +pub struct UnionExtractFun { + signature: Signature, +} + +impl Default for UnionExtractFun { + fn default() -> Self { + Self::new() + } +} + +impl UnionExtractFun { + pub fn new() -> Self { + Self { + signature: Signature::any(2, Volatility::Immutable), + } + } +} + +impl ScalarUDFImpl for UnionExtractFun { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &str { + "union_extract" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _: &[DataType]) -> Result { + // should be using return_type_from_exprs and not calling the default implementation + internal_err!("union_extract should return type from exprs") + } + + fn return_type_from_args(&self, args: ReturnTypeArgs) -> Result { + if args.arg_types.len() != 2 { + return exec_err!( + "union_extract expects 2 arguments, got {} instead", + args.arg_types.len() + ); + } + + let DataType::Union(fields, _) = &args.arg_types[0] else { + return exec_err!( + "union_extract first argument must be a union, got {} instead", + args.arg_types[0] + ); + }; + + let Some(ScalarValue::Utf8(Some(field_name))) = &args.scalar_arguments[1] else { + return exec_err!( + "union_extract second argument must be a non-null string literal, got {} instead", + args.arg_types[1] + ); + }; + + let field = find_field(fields, field_name)?.1; + + Ok(ReturnInfo::new_nullable(field.data_type().clone())) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + let [array, target_name] = take_function_args("union_extract", args.args)?; + + let target_name = match target_name { + ColumnarValue::Scalar(ScalarValue::Utf8(Some(target_name))) => Ok(target_name), + ColumnarValue::Scalar(ScalarValue::Utf8(None)) => exec_err!("union_extract second argument must be a non-null string literal, got a null instead"), + _ => exec_err!("union_extract second argument must be a non-null string literal, got {} instead", target_name.data_type()), + }?; + + match array { + ColumnarValue::Array(array) => { + let union_array = as_union_array(&array).map_err(|_| { + exec_datafusion_err!( + "union_extract first argument must be a union, got {} instead", + array.data_type() + ) + })?; + + Ok(ColumnarValue::Array( + arrow::compute::kernels::union_extract::union_extract( + union_array, + &target_name, + )?, + )) + } + ColumnarValue::Scalar(ScalarValue::Union(value, fields, _)) => { + let (target_type_id, target) = find_field(&fields, &target_name)?; + + let result = match value { + Some((type_id, value)) if target_type_id == type_id => *value, + _ => ScalarValue::try_new_null(target.data_type())?, + }; + + Ok(ColumnarValue::Scalar(result)) + } + other => exec_err!( + "union_extract first argument must be a union, got {} instead", + other.data_type() + ), + } + } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } +} + +fn find_field<'a>(fields: &'a UnionFields, name: &str) -> Result<(i8, &'a FieldRef)> { + fields + .iter() + .find(|field| field.1.name() == name) + .ok_or_else(|| exec_datafusion_err!("field {name} not found on union")) +} + +#[cfg(test)] +mod tests { + + use arrow::datatypes::{DataType, Field, UnionFields, UnionMode}; + use datafusion_common::{Result, ScalarValue}; + use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl}; + + use super::UnionExtractFun; + + // when it becomes possible to construct union scalars in SQL, this should go to sqllogictests + #[test] + fn test_scalar_value() -> Result<()> { + let fun = UnionExtractFun::new(); + + 
let fields = UnionFields::new( + vec![1, 3], + vec![ + Field::new("str", DataType::Utf8, false), + Field::new("int", DataType::Int32, false), + ], + ); + + let result = fun.invoke_with_args(ScalarFunctionArgs { + args: vec![ + ColumnarValue::Scalar(ScalarValue::Union( + None, + fields.clone(), + UnionMode::Dense, + )), + ColumnarValue::Scalar(ScalarValue::new_utf8("str")), + ], + number_rows: 1, + return_type: &DataType::Utf8, + })?; + + assert_scalar(result, ScalarValue::Utf8(None)); + + let result = fun.invoke_with_args(ScalarFunctionArgs { + args: vec![ + ColumnarValue::Scalar(ScalarValue::Union( + Some((3, Box::new(ScalarValue::Int32(Some(42))))), + fields.clone(), + UnionMode::Dense, + )), + ColumnarValue::Scalar(ScalarValue::new_utf8("str")), + ], + number_rows: 1, + return_type: &DataType::Utf8, + })?; + + assert_scalar(result, ScalarValue::Utf8(None)); + + let result = fun.invoke_with_args(ScalarFunctionArgs { + args: vec![ + ColumnarValue::Scalar(ScalarValue::Union( + Some((1, Box::new(ScalarValue::new_utf8("42")))), + fields.clone(), + UnionMode::Dense, + )), + ColumnarValue::Scalar(ScalarValue::new_utf8("str")), + ], + number_rows: 1, + return_type: &DataType::Utf8, + })?; + + assert_scalar(result, ScalarValue::new_utf8("42")); + + Ok(()) + } + + fn assert_scalar(value: ColumnarValue, expected: ScalarValue) { + match value { + ColumnarValue::Array(array) => panic!("expected scalar got {array:?}"), + ColumnarValue::Scalar(scalar) => assert_eq!(scalar, expected), + } + } +} diff --git a/datafusion/functions/src/core/version.rs b/datafusion/functions/src/core/version.rs index 139763af7b38..34038022f2dc 100644 --- a/datafusion/functions/src/core/version.rs +++ b/datafusion/functions/src/core/version.rs @@ -17,11 +17,11 @@ //! [`VersionFunc`]: Implementation of the `version` function. 
-use crate::utils::take_function_args; use arrow::datatypes::DataType; -use datafusion_common::{Result, ScalarValue}; +use datafusion_common::{utils::take_function_args, Result, ScalarValue}; use datafusion_expr::{ - ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, + ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, + Volatility, }; use datafusion_macros::user_doc; use std::any::Any; @@ -76,12 +76,8 @@ impl ScalarUDFImpl for VersionFunc { Ok(DataType::Utf8) } - fn invoke_batch( - &self, - args: &[ColumnarValue], - _number_rows: usize, - ) -> Result { - let [] = take_function_args(self.name(), args)?; + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + let [] = take_function_args(self.name(), args.args)?; // TODO it would be great to add rust version and arrow version, // but that requires a `build.rs` script and/or adding a version const to arrow-rs let version = format!( @@ -106,8 +102,13 @@ mod test { #[tokio::test] async fn test_version_udf() { let version_udf = ScalarUDF::from(VersionFunc::new()); - #[allow(deprecated)] // TODO: migrate to invoke_with_args - let version = version_udf.invoke_batch(&[], 1).unwrap(); + let version = version_udf + .invoke_with_args(ScalarFunctionArgs { + args: vec![], + number_rows: 0, + return_type: &DataType::Utf8, + }) + .unwrap(); if let ColumnarValue::Scalar(ScalarValue::Utf8(Some(version))) = version { assert!(version.starts_with("Apache DataFusion")); diff --git a/datafusion/functions/src/crypto/basic.rs b/datafusion/functions/src/crypto/basic.rs index a15b9b57cff6..191154b8f8ff 100644 --- a/datafusion/functions/src/crypto/basic.rs +++ b/datafusion/functions/src/crypto/basic.rs @@ -24,12 +24,10 @@ use blake2::{Blake2b512, Blake2s256, Digest}; use blake3::Hasher as Blake3; use datafusion_common::cast::as_binary_array; -use crate::utils::take_function_args; use arrow::compute::StringArrayType; -use datafusion_common::plan_err; use datafusion_common::{ - cast::as_generic_binary_array, exec_err, internal_err, DataFusionError, Result, - ScalarValue, + cast::as_generic_binary_array, exec_err, internal_err, plan_err, + utils::take_function_args, DataFusionError, Result, ScalarValue, }; use datafusion_expr::ColumnarValue; use md5::Md5; diff --git a/datafusion/functions/src/crypto/digest.rs b/datafusion/functions/src/crypto/digest.rs index cc52f32614fd..4f9d4605fe07 100644 --- a/datafusion/functions/src/crypto/digest.rs +++ b/datafusion/functions/src/crypto/digest.rs @@ -20,7 +20,8 @@ use super::basic::{digest, utf8_or_binary_to_binary_type}; use arrow::datatypes::DataType; use datafusion_common::Result; use datafusion_expr::{ - ColumnarValue, Documentation, ScalarUDFImpl, Signature, TypeSignature::*, Volatility, + ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, + TypeSignature::*, Volatility, }; use datafusion_macros::user_doc; use std::any::Any; @@ -94,12 +95,8 @@ impl ScalarUDFImpl for DigestFunc { fn return_type(&self, arg_types: &[DataType]) -> Result { utf8_or_binary_to_binary_type(&arg_types[0], self.name()) } - fn invoke_batch( - &self, - args: &[ColumnarValue], - _number_rows: usize, - ) -> Result { - digest(args) + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + digest(&args.args) } fn documentation(&self) -> Option<&Documentation> { diff --git a/datafusion/functions/src/crypto/md5.rs b/datafusion/functions/src/crypto/md5.rs index 636ca65735c9..18ad0d6a7ded 100644 --- a/datafusion/functions/src/crypto/md5.rs +++ 
b/datafusion/functions/src/crypto/md5.rs @@ -20,7 +20,8 @@ use crate::crypto::basic::md5; use arrow::datatypes::DataType; use datafusion_common::{plan_err, Result}; use datafusion_expr::{ - ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, + ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, + Volatility, }; use datafusion_macros::user_doc; use std::any::Any; @@ -98,12 +99,8 @@ impl ScalarUDFImpl for Md5Func { } }) } - fn invoke_batch( - &self, - args: &[ColumnarValue], - _number_rows: usize, - ) -> Result { - md5(args) + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + md5(&args.args) } fn documentation(&self) -> Option<&Documentation> { diff --git a/datafusion/functions/src/crypto/sha224.rs b/datafusion/functions/src/crypto/sha224.rs index 341b3495f9c6..24fe5e119df3 100644 --- a/datafusion/functions/src/crypto/sha224.rs +++ b/datafusion/functions/src/crypto/sha224.rs @@ -20,7 +20,8 @@ use super::basic::{sha224, utf8_or_binary_to_binary_type}; use arrow::datatypes::DataType; use datafusion_common::Result; use datafusion_expr::{ - ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, + ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, + Volatility, }; use datafusion_macros::user_doc; use std::any::Any; @@ -80,12 +81,8 @@ impl ScalarUDFImpl for SHA224Func { utf8_or_binary_to_binary_type(&arg_types[0], self.name()) } - fn invoke_batch( - &self, - args: &[ColumnarValue], - _number_rows: usize, - ) -> Result { - sha224(args) + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + sha224(&args.args) } fn documentation(&self) -> Option<&Documentation> { diff --git a/datafusion/functions/src/crypto/sha256.rs b/datafusion/functions/src/crypto/sha256.rs index f40dd99c59fe..c48dda19cbc5 100644 --- a/datafusion/functions/src/crypto/sha256.rs +++ b/datafusion/functions/src/crypto/sha256.rs @@ -20,7 +20,8 @@ use super::basic::{sha256, utf8_or_binary_to_binary_type}; use arrow::datatypes::DataType; use datafusion_common::Result; use datafusion_expr::{ - ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, + ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, + Volatility, }; use datafusion_macros::user_doc; use std::any::Any; @@ -78,12 +79,8 @@ impl ScalarUDFImpl for SHA256Func { utf8_or_binary_to_binary_type(&arg_types[0], self.name()) } - fn invoke_batch( - &self, - args: &[ColumnarValue], - _number_rows: usize, - ) -> Result { - sha256(args) + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + sha256(&args.args) } fn documentation(&self) -> Option<&Documentation> { diff --git a/datafusion/functions/src/crypto/sha384.rs b/datafusion/functions/src/crypto/sha384.rs index e38a755826f8..11d1d130e929 100644 --- a/datafusion/functions/src/crypto/sha384.rs +++ b/datafusion/functions/src/crypto/sha384.rs @@ -20,7 +20,8 @@ use super::basic::{sha384, utf8_or_binary_to_binary_type}; use arrow::datatypes::DataType; use datafusion_common::Result; use datafusion_expr::{ - ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, + ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, + Volatility, }; use datafusion_macros::user_doc; use std::any::Any; @@ -78,12 +79,8 @@ impl ScalarUDFImpl for SHA384Func { utf8_or_binary_to_binary_type(&arg_types[0], self.name()) } - fn invoke_batch( - &self, - args: &[ColumnarValue], - _number_rows: usize, - ) -> Result { - sha384(args) + fn invoke_with_args(&self, args: 
ScalarFunctionArgs) -> Result { + sha384(&args.args) } fn documentation(&self) -> Option<&Documentation> { diff --git a/datafusion/functions/src/crypto/sha512.rs b/datafusion/functions/src/crypto/sha512.rs index 7fe2a26ebbce..26fa85a5da3a 100644 --- a/datafusion/functions/src/crypto/sha512.rs +++ b/datafusion/functions/src/crypto/sha512.rs @@ -20,7 +20,8 @@ use super::basic::{sha512, utf8_or_binary_to_binary_type}; use arrow::datatypes::DataType; use datafusion_common::Result; use datafusion_expr::{ - ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, + ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, + Volatility, }; use datafusion_macros::user_doc; use std::any::Any; @@ -78,12 +79,8 @@ impl ScalarUDFImpl for SHA512Func { utf8_or_binary_to_binary_type(&arg_types[0], self.name()) } - fn invoke_batch( - &self, - args: &[ColumnarValue], - _number_rows: usize, - ) -> Result { - sha512(args) + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + sha512(&args.args) } fn documentation(&self) -> Option<&Documentation> { diff --git a/datafusion/functions/src/datetime/date_part.rs b/datafusion/functions/src/datetime/date_part.rs index c7dbf089e530..49b7a4ec462a 100644 --- a/datafusion/functions/src/datetime/date_part.rs +++ b/datafusion/functions/src/datetime/date_part.rs @@ -27,9 +27,8 @@ use arrow::datatypes::DataType::{ }; use arrow::datatypes::TimeUnit::{Microsecond, Millisecond, Nanosecond, Second}; use arrow::datatypes::{DataType, TimeUnit}; +use datafusion_common::types::{logical_date, NativeType}; -use crate::utils::take_function_args; -use datafusion_common::not_impl_err; use datafusion_common::{ cast::{ as_date32_array, as_date64_array, as_int32_array, as_time32_millisecond_array, @@ -37,15 +36,16 @@ use datafusion_common::{ as_timestamp_microsecond_array, as_timestamp_millisecond_array, as_timestamp_nanosecond_array, as_timestamp_second_array, }, - exec_err, internal_err, + exec_err, internal_err, not_impl_err, types::logical_string, + utils::take_function_args, Result, ScalarValue, }; use datafusion_expr::{ ColumnarValue, Documentation, ReturnInfo, ReturnTypeArgs, ScalarUDFImpl, Signature, TypeSignature, Volatility, }; -use datafusion_expr_common::signature::TypeSignatureClass; +use datafusion_expr_common::signature::{Coercion, TypeSignatureClass}; use datafusion_macros::user_doc; #[user_doc( @@ -96,24 +96,29 @@ impl DatePartFunc { signature: Signature::one_of( vec![ TypeSignature::Coercible(vec![ - TypeSignatureClass::Native(logical_string()), - TypeSignatureClass::Timestamp, + Coercion::new_exact(TypeSignatureClass::Native(logical_string())), + Coercion::new_implicit( + TypeSignatureClass::Timestamp, + // Not consistent with Postgres and DuckDB but to avoid regression we implicit cast string to timestamp + vec![TypeSignatureClass::Native(logical_string())], + NativeType::Timestamp(Nanosecond, None), + ), ]), TypeSignature::Coercible(vec![ - TypeSignatureClass::Native(logical_string()), - TypeSignatureClass::Date, + Coercion::new_exact(TypeSignatureClass::Native(logical_string())), + Coercion::new_exact(TypeSignatureClass::Native(logical_date())), ]), TypeSignature::Coercible(vec![ - TypeSignatureClass::Native(logical_string()), - TypeSignatureClass::Time, + Coercion::new_exact(TypeSignatureClass::Native(logical_string())), + Coercion::new_exact(TypeSignatureClass::Time), ]), TypeSignature::Coercible(vec![ - TypeSignatureClass::Native(logical_string()), - TypeSignatureClass::Interval, + 
Coercion::new_exact(TypeSignatureClass::Native(logical_string())), + Coercion::new_exact(TypeSignatureClass::Interval), ]), TypeSignature::Coercible(vec![ - TypeSignatureClass::Native(logical_string()), - TypeSignatureClass::Duration, + Coercion::new_exact(TypeSignatureClass::Native(logical_string())), + Coercion::new_exact(TypeSignatureClass::Duration), ]), ], Volatility::Immutable, @@ -167,10 +172,7 @@ impl ScalarUDFImpl for DatePartFunc { args: &[ColumnarValue], _number_rows: usize, ) -> Result { - if args.len() != 2 { - return exec_err!("Expected two arguments in DATE_PART"); - } - let (part, array) = (&args[0], &args[1]); + let [part, array] = take_function_args(self.name(), args)?; let part = if let ColumnarValue::Scalar(ScalarValue::Utf8(Some(v))) = part { v diff --git a/datafusion/functions/src/datetime/date_trunc.rs b/datafusion/functions/src/datetime/date_trunc.rs index 4780f5f5b818..7c10cdd0029d 100644 --- a/datafusion/functions/src/datetime/date_trunc.rs +++ b/datafusion/functions/src/datetime/date_trunc.rs @@ -185,10 +185,10 @@ impl ScalarUDFImpl for DateTruncFunc { ) -> Result { let parsed_tz = parse_tz(tz_opt)?; let array = as_primitive_array::(array)?; - let array = array - .iter() - .map(|x| general_date_trunc(T::UNIT, &x, parsed_tz, granularity.as_str())) - .collect::>>()? + let array: PrimitiveArray = array + .try_unary(|x| { + general_date_trunc(T::UNIT, x, parsed_tz, granularity.as_str()) + })? .with_timezone_opt(tz_opt.clone()); Ok(ColumnarValue::Array(Arc::new(array))) } @@ -199,7 +199,16 @@ impl ScalarUDFImpl for DateTruncFunc { tz_opt: &Option>, ) -> Result { let parsed_tz = parse_tz(tz_opt)?; - let value = general_date_trunc(T::UNIT, v, parsed_tz, granularity.as_str())?; + let value = if let Some(v) = v { + Some(general_date_trunc( + T::UNIT, + *v, + parsed_tz, + granularity.as_str(), + )?) 
+ } else { + None + }; let value = ScalarValue::new_timestamp::(value, tz_opt.clone()); Ok(ColumnarValue::Scalar(value)) } @@ -417,10 +426,10 @@ fn date_trunc_coarse(granularity: &str, value: i64, tz: Option) -> Result, + value: i64, tz: Option, granularity: &str, -) -> Result, DataFusionError> { +) -> Result { let scale = match tu { Second => 1_000_000_000, Millisecond => 1_000_000, @@ -428,35 +437,31 @@ fn general_date_trunc( Nanosecond => 1, }; - let Some(value) = value else { - return Ok(None); - }; - // convert to nanoseconds let nano = date_trunc_coarse(granularity, scale * value, tz)?; let result = match tu { Second => match granularity { - "minute" => Some(nano / 1_000_000_000 / 60 * 60), - _ => Some(nano / 1_000_000_000), + "minute" => nano / 1_000_000_000 / 60 * 60, + _ => nano / 1_000_000_000, }, Millisecond => match granularity { - "minute" => Some(nano / 1_000_000 / 1_000 / 60 * 1_000 * 60), - "second" => Some(nano / 1_000_000 / 1_000 * 1_000), - _ => Some(nano / 1_000_000), + "minute" => nano / 1_000_000 / 1_000 / 60 * 1_000 * 60, + "second" => nano / 1_000_000 / 1_000 * 1_000, + _ => nano / 1_000_000, }, Microsecond => match granularity { - "minute" => Some(nano / 1_000 / 1_000_000 / 60 * 60 * 1_000_000), - "second" => Some(nano / 1_000 / 1_000_000 * 1_000_000), - "millisecond" => Some(nano / 1_000 / 1_000 * 1_000), - _ => Some(nano / 1_000), + "minute" => nano / 1_000 / 1_000_000 / 60 * 60 * 1_000_000, + "second" => nano / 1_000 / 1_000_000 * 1_000_000, + "millisecond" => nano / 1_000 / 1_000 * 1_000, + _ => nano / 1_000, }, _ => match granularity { - "minute" => Some(nano / 1_000_000_000 / 60 * 1_000_000_000 * 60), - "second" => Some(nano / 1_000_000_000 * 1_000_000_000), - "millisecond" => Some(nano / 1_000_000 * 1_000_000), - "microsecond" => Some(nano / 1_000 * 1_000), - _ => Some(nano), + "minute" => nano / 1_000_000_000 / 60 * 1_000_000_000 * 60, + "second" => nano / 1_000_000_000 * 1_000_000_000, + "millisecond" => nano / 1_000_000 * 1_000_000, + "microsecond" => nano / 1_000 * 1_000, + _ => nano, }, }; Ok(result) diff --git a/datafusion/functions/src/datetime/make_date.rs b/datafusion/functions/src/datetime/make_date.rs index 2d4db56cc788..f081dfd11ecf 100644 --- a/datafusion/functions/src/datetime/make_date.rs +++ b/datafusion/functions/src/datetime/make_date.rs @@ -26,8 +26,7 @@ use arrow::datatypes::DataType; use arrow::datatypes::DataType::{Date32, Int32, Int64, UInt32, UInt64, Utf8, Utf8View}; use chrono::prelude::*; -use crate::utils::take_function_args; -use datafusion_common::{exec_err, Result, ScalarValue}; +use datafusion_common::{exec_err, utils::take_function_args, Result, ScalarValue}; use datafusion_expr::{ ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, }; diff --git a/datafusion/functions/src/datetime/mod.rs b/datafusion/functions/src/datetime/mod.rs index eec5e3cef624..dee40215c9ea 100644 --- a/datafusion/functions/src/datetime/mod.rs +++ b/datafusion/functions/src/datetime/mod.rs @@ -137,7 +137,7 @@ pub mod expr_fn { /// # use datafusion_common::ScalarValue::TimestampNanosecond; /// # use std::sync::Arc; /// # use arrow::array::{Date32Array, RecordBatch, StringArray}; - /// # use arrow_schema::{DataType, Field, Schema}; + /// # use arrow::datatypes::{DataType, Field, Schema}; /// # #[tokio::main] /// # async fn main() -> Result<()> { /// let schema = Arc::new(Schema::new(vec![ diff --git a/datafusion/functions/src/datetime/to_char.rs b/datafusion/functions/src/datetime/to_char.rs index 485fdc7a3384..b049ca01ac97 100644 --- 
a/datafusion/functions/src/datetime/to_char.rs +++ b/datafusion/functions/src/datetime/to_char.rs @@ -28,8 +28,7 @@ use arrow::datatypes::TimeUnit::{Microsecond, Millisecond, Nanosecond, Second}; use arrow::error::ArrowError; use arrow::util::display::{ArrayFormatter, DurationFormat, FormatOptions}; -use crate::utils::take_function_args; -use datafusion_common::{exec_err, Result, ScalarValue}; +use datafusion_common::{exec_err, utils::take_function_args, Result, ScalarValue}; use datafusion_expr::TypeSignature::Exact; use datafusion_expr::{ ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, TIMEZONE_WILDCARD, diff --git a/datafusion/functions/src/datetime/to_local_time.rs b/datafusion/functions/src/datetime/to_local_time.rs index b350819a55ec..0e235735e29f 100644 --- a/datafusion/functions/src/datetime/to_local_time.rs +++ b/datafusion/functions/src/datetime/to_local_time.rs @@ -27,10 +27,12 @@ use arrow::datatypes::{ ArrowTimestampType, DataType, TimestampMicrosecondType, TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, }; - use chrono::{DateTime, MappedLocalTime, Offset, TimeDelta, TimeZone, Utc}; + use datafusion_common::cast::as_primitive_array; -use datafusion_common::{exec_err, plan_err, DataFusionError, Result, ScalarValue}; +use datafusion_common::{ + exec_err, plan_err, utils::take_function_args, DataFusionError, Result, ScalarValue, +}; use datafusion_expr::{ ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, }; @@ -113,14 +115,8 @@ impl ToLocalTimeFunc { } fn to_local_time(&self, args: &[ColumnarValue]) -> Result { - if args.len() != 1 { - return exec_err!( - "to_local_time function requires 1 argument, got {}", - args.len() - ); - } + let [time_value] = take_function_args(self.name(), args)?; - let time_value = &args[0]; let arg_type = time_value.data_type(); match arg_type { Timestamp(_, None) => { @@ -360,17 +356,12 @@ impl ScalarUDFImpl for ToLocalTimeFunc { } fn return_type(&self, arg_types: &[DataType]) -> Result { - if arg_types.len() != 1 { - return exec_err!( - "to_local_time function requires 1 argument, got {:?}", - arg_types.len() - ); - } + let [time_value] = take_function_args(self.name(), arg_types)?; - match &arg_types[0] { + match time_value { Timestamp(timeunit, _) => Ok(Timestamp(*timeunit, None)), _ => exec_err!( - "The to_local_time function can only accept timestamp as the arg, got {:?}", arg_types[0] + "The to_local_time function can only accept timestamp as the arg, got {:?}", time_value ) } } @@ -380,14 +371,9 @@ impl ScalarUDFImpl for ToLocalTimeFunc { args: &[ColumnarValue], _number_rows: usize, ) -> Result { - if args.len() != 1 { - return exec_err!( - "to_local_time function requires 1 argument, got {:?}", - args.len() - ); - } + let [time_value] = take_function_args(self.name(), args)?; - self.to_local_time(args) + self.to_local_time(&[time_value.clone()]) } fn coerce_types(&self, arg_types: &[DataType]) -> Result> { diff --git a/datafusion/functions/src/encoding/inner.rs b/datafusion/functions/src/encoding/inner.rs index a5338ff76592..51e8c6968866 100644 --- a/datafusion/functions/src/encoding/inner.rs +++ b/datafusion/functions/src/encoding/inner.rs @@ -28,6 +28,7 @@ use base64::{engine::general_purpose, Engine as _}; use datafusion_common::{ cast::{as_generic_binary_array, as_generic_string_array}, not_impl_err, plan_err, + utils::take_function_args, }; use datafusion_common::{exec_err, ScalarValue}; use datafusion_common::{DataFusionError, Result}; @@ -102,28 +103,21 @@ impl ScalarUDFImpl 
for EncodeFunc { }) } - fn invoke_batch( + fn invoke_with_args( &self, - args: &[ColumnarValue], - _number_rows: usize, + args: datafusion_expr::ScalarFunctionArgs, ) -> Result { - encode(args) + encode(&args.args) } fn coerce_types(&self, arg_types: &[DataType]) -> Result> { - if arg_types.len() != 2 { - return plan_err!( - "{} expects to get 2 arguments, but got {}", - self.name(), - arg_types.len() - ); - } + let [expression, format] = take_function_args(self.name(), arg_types)?; - if arg_types[1] != DataType::Utf8 { + if format != &DataType::Utf8 { return Err(DataFusionError::Plan("2nd argument should be Utf8".into())); } - match arg_types[0] { + match expression { DataType::Utf8 | DataType::Utf8View | DataType::Null => { Ok(vec![DataType::Utf8; 2]) } @@ -188,12 +182,11 @@ impl ScalarUDFImpl for DecodeFunc { Ok(arg_types[0].to_owned()) } - fn invoke_batch( + fn invoke_with_args( &self, - args: &[ColumnarValue], - _number_rows: usize, + args: datafusion_expr::ScalarFunctionArgs, ) -> Result { - decode(args) + decode(&args.args) } fn coerce_types(&self, arg_types: &[DataType]) -> Result> { @@ -539,13 +532,9 @@ impl FromStr for Encoding { /// Second argument is the encoding to use. /// Standard encodings are base64 and hex. fn encode(args: &[ColumnarValue]) -> Result { - if args.len() != 2 { - return exec_err!( - "{:?} args were supplied but encode takes exactly two arguments", - args.len() - ); - } - let encoding = match &args[1] { + let [expression, format] = take_function_args("encode", args)?; + + let encoding = match format { ColumnarValue::Scalar(scalar) => match scalar.try_as_str() { Some(Some(method)) => method.parse::(), _ => not_impl_err!( @@ -556,20 +545,16 @@ fn encode(args: &[ColumnarValue]) -> Result { "Second argument to encode must be a constant: Encode using dynamically decided method is not yet supported" ), }?; - encode_process(&args[0], encoding) + encode_process(expression, encoding) } /// Decodes the given data, accepts Binary, LargeBinary, Utf8, Utf8View or LargeUtf8 and returns a [`ColumnarValue`]. /// Second argument is the encoding to use. /// Standard encodings are base64 and hex. fn decode(args: &[ColumnarValue]) -> Result { - if args.len() != 2 { - return exec_err!( - "{:?} args were supplied but decode takes exactly two arguments", - args.len() - ); - } - let encoding = match &args[1] { + let [expression, format] = take_function_args("decode", args)?; + + let encoding = match format { ColumnarValue::Scalar(scalar) => match scalar.try_as_str() { Some(Some(method))=> method.parse::(), _ => not_impl_err!( @@ -580,5 +565,5 @@ fn decode(args: &[ColumnarValue]) -> Result { "Second argument to decode must be a utf8 constant: Decode using dynamically decided method is not yet supported" ), }?; - decode_process(&args[0], encoding) + decode_process(expression, encoding) } diff --git a/datafusion/functions/src/macros.rs b/datafusion/functions/src/macros.rs index 48eff4fcd423..d2849c3abba0 100644 --- a/datafusion/functions/src/macros.rs +++ b/datafusion/functions/src/macros.rs @@ -164,7 +164,8 @@ macro_rules! make_math_unary_udf { use datafusion_expr::interval_arithmetic::Interval; use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; use datafusion_expr::{ - ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, + ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, + Signature, Volatility, }; #[derive(Debug)] @@ -218,12 +219,11 @@ macro_rules! 
make_math_unary_udf { $EVALUATE_BOUNDS(inputs) } - fn invoke_batch( + fn invoke_with_args( &self, - args: &[ColumnarValue], - _number_rows: usize, + args: ScalarFunctionArgs, ) -> Result { - let args = ColumnarValue::values_to_arrays(args)?; + let args = ColumnarValue::values_to_arrays(&args.args)?; let arr: ArrayRef = match args[0].data_type() { DataType::Float64 => Arc::new( args[0] @@ -278,7 +278,8 @@ macro_rules! make_math_binary_udf { use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; use datafusion_expr::TypeSignature; use datafusion_expr::{ - ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, + ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, + Signature, Volatility, }; #[derive(Debug)] @@ -330,12 +331,11 @@ macro_rules! make_math_binary_udf { $OUTPUT_ORDERING(input) } - fn invoke_batch( + fn invoke_with_args( &self, - args: &[ColumnarValue], - _number_rows: usize, + args: ScalarFunctionArgs, ) -> Result { - let args = ColumnarValue::values_to_arrays(args)?; + let args = ColumnarValue::values_to_arrays(&args.args)?; let arr: ArrayRef = match args[0].data_type() { DataType::Float64 => { let y = args[0].as_primitive::(); diff --git a/datafusion/functions/src/math/abs.rs b/datafusion/functions/src/math/abs.rs index a375af2ad29e..0c686a59016a 100644 --- a/datafusion/functions/src/math/abs.rs +++ b/datafusion/functions/src/math/abs.rs @@ -20,18 +20,20 @@ use std::any::Any; use std::sync::Arc; -use crate::utils::take_function_args; use arrow::array::{ ArrayRef, Decimal128Array, Decimal256Array, Float32Array, Float64Array, Int16Array, Int32Array, Int64Array, Int8Array, }; use arrow::datatypes::DataType; use arrow::error::ArrowError; -use datafusion_common::{internal_datafusion_err, not_impl_err, Result}; +use datafusion_common::{ + internal_datafusion_err, not_impl_err, utils::take_function_args, Result, +}; use datafusion_expr::interval_arithmetic::Interval; use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; use datafusion_expr::{ - ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, + ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, + Volatility, }; use datafusion_macros::user_doc; @@ -167,12 +169,8 @@ impl ScalarUDFImpl for AbsFunc { } } - fn invoke_batch( - &self, - args: &[ColumnarValue], - _number_rows: usize, - ) -> Result { - let args = ColumnarValue::values_to_arrays(args)?; + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + let args = ColumnarValue::values_to_arrays(&args.args)?; let [input] = take_function_args(self.name(), args)?; let input_data_type = input.data_type(); diff --git a/datafusion/functions/src/math/cot.rs b/datafusion/functions/src/math/cot.rs index 8b4f9317fe5f..4e56212ddbee 100644 --- a/datafusion/functions/src/math/cot.rs +++ b/datafusion/functions/src/math/cot.rs @@ -24,7 +24,7 @@ use arrow::datatypes::{DataType, Float32Type, Float64Type}; use crate::utils::make_scalar_function; use datafusion_common::{exec_err, Result}; -use datafusion_expr::{ColumnarValue, Documentation}; +use datafusion_expr::{ColumnarValue, Documentation, ScalarFunctionArgs}; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; use datafusion_macros::user_doc; @@ -87,12 +87,8 @@ impl ScalarUDFImpl for CotFunc { self.doc() } - fn invoke_batch( - &self, - args: &[ColumnarValue], - _number_rows: usize, - ) -> Result { - make_scalar_function(cot, vec![])(args) + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + 
make_scalar_function(cot, vec![])(&args.args) } } diff --git a/datafusion/functions/src/math/factorial.rs b/datafusion/functions/src/math/factorial.rs index 18f10863a01b..c2ac21b78f21 100644 --- a/datafusion/functions/src/math/factorial.rs +++ b/datafusion/functions/src/math/factorial.rs @@ -30,7 +30,8 @@ use datafusion_common::{ arrow_datafusion_err, exec_err, internal_datafusion_err, DataFusionError, Result, }; use datafusion_expr::{ - ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, + ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, + Volatility, }; use datafusion_macros::user_doc; @@ -76,12 +77,8 @@ impl ScalarUDFImpl for FactorialFunc { Ok(Int64) } - fn invoke_batch( - &self, - args: &[ColumnarValue], - _number_rows: usize, - ) -> Result { - make_scalar_function(factorial, vec![])(args) + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + make_scalar_function(factorial, vec![])(&args.args) } fn documentation(&self) -> Option<&Documentation> { diff --git a/datafusion/functions/src/math/gcd.rs b/datafusion/functions/src/math/gcd.rs index 14503701f661..911e00308ab7 100644 --- a/datafusion/functions/src/math/gcd.rs +++ b/datafusion/functions/src/math/gcd.rs @@ -29,7 +29,8 @@ use datafusion_common::{ arrow_datafusion_err, exec_err, internal_datafusion_err, DataFusionError, Result, }; use datafusion_expr::{ - ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, + ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, + Volatility, }; use datafusion_macros::user_doc; @@ -77,12 +78,8 @@ impl ScalarUDFImpl for GcdFunc { Ok(Int64) } - fn invoke_batch( - &self, - args: &[ColumnarValue], - _number_rows: usize, - ) -> Result { - make_scalar_function(gcd, vec![])(args) + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + make_scalar_function(gcd, vec![])(&args.args) } fn documentation(&self) -> Option<&Documentation> { diff --git a/datafusion/functions/src/math/iszero.rs b/datafusion/functions/src/math/iszero.rs index 8e72ee285518..bc12dfb7898e 100644 --- a/datafusion/functions/src/math/iszero.rs +++ b/datafusion/functions/src/math/iszero.rs @@ -25,7 +25,8 @@ use arrow::datatypes::{DataType, Float32Type, Float64Type}; use datafusion_common::{exec_err, Result}; use datafusion_expr::TypeSignature::Exact; use datafusion_expr::{ - ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, + ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, + Volatility, }; use datafusion_macros::user_doc; @@ -77,12 +78,8 @@ impl ScalarUDFImpl for IsZeroFunc { Ok(Boolean) } - fn invoke_batch( - &self, - args: &[ColumnarValue], - _number_rows: usize, - ) -> Result { - make_scalar_function(iszero, vec![])(args) + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + make_scalar_function(iszero, vec![])(&args.args) } fn documentation(&self) -> Option<&Documentation> { diff --git a/datafusion/functions/src/math/lcm.rs b/datafusion/functions/src/math/lcm.rs index c2c72c89841d..fc6bf9461f28 100644 --- a/datafusion/functions/src/math/lcm.rs +++ b/datafusion/functions/src/math/lcm.rs @@ -27,7 +27,8 @@ use datafusion_common::{ arrow_datafusion_err, exec_err, internal_datafusion_err, DataFusionError, Result, }; use datafusion_expr::{ - ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, + ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, + Volatility, }; use datafusion_macros::user_doc; @@ -78,12 +79,8 @@ impl ScalarUDFImpl for 
LcmFunc { Ok(Int64) } - fn invoke_batch( - &self, - args: &[ColumnarValue], - _number_rows: usize, - ) -> Result { - make_scalar_function(lcm, vec![])(args) + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + make_scalar_function(lcm, vec![])(&args.args) } fn documentation(&self) -> Option<&Documentation> { diff --git a/datafusion/functions/src/math/log.rs b/datafusion/functions/src/math/log.rs index 88a624806874..fd135f4c5ec0 100644 --- a/datafusion/functions/src/math/log.rs +++ b/datafusion/functions/src/math/log.rs @@ -31,7 +31,8 @@ use datafusion_expr::expr::ScalarFunction; use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo}; use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; use datafusion_expr::{ - lit, ColumnarValue, Documentation, Expr, ScalarUDF, TypeSignature::*, + lit, ColumnarValue, Documentation, Expr, ScalarFunctionArgs, ScalarUDF, + TypeSignature::*, }; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; use datafusion_macros::user_doc; @@ -117,12 +118,8 @@ impl ScalarUDFImpl for LogFunc { } // Support overloaded log(base, x) and log(x) which defaults to log(10, x) - fn invoke_batch( - &self, - args: &[ColumnarValue], - _number_rows: usize, - ) -> Result { - let args = ColumnarValue::values_to_arrays(args)?; + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + let args = ColumnarValue::values_to_arrays(&args.args)?; let mut base = ColumnarValue::Scalar(ScalarValue::Float32(Some(10.0))); @@ -267,34 +264,44 @@ mod tests { #[test] #[should_panic] fn test_log_invalid_base_type() { - let args = [ - ColumnarValue::Array(Arc::new(Float64Array::from(vec![ - 10.0, 100.0, 1000.0, 10000.0, - ]))), // num - ColumnarValue::Array(Arc::new(Int64Array::from(vec![5, 10, 15, 20]))), - ]; - #[allow(deprecated)] // TODO: migrate to invoke_with_args - let _ = LogFunc::new().invoke_batch(&args, 4); + let args = ScalarFunctionArgs { + args: vec![ + ColumnarValue::Array(Arc::new(Float64Array::from(vec![ + 10.0, 100.0, 1000.0, 10000.0, + ]))), // num + ColumnarValue::Array(Arc::new(Int64Array::from(vec![5, 10, 15, 20]))), + ], + number_rows: 4, + return_type: &DataType::Float64, + }; + let _ = LogFunc::new().invoke_with_args(args); } #[test] fn test_log_invalid_value() { - let args = [ - ColumnarValue::Array(Arc::new(Int64Array::from(vec![10]))), // num - ]; - #[allow(deprecated)] // TODO: migrate to invoke_with_args - let result = LogFunc::new().invoke_batch(&args, 1); + let args = ScalarFunctionArgs { + args: vec![ + ColumnarValue::Array(Arc::new(Int64Array::from(vec![10]))), // num + ], + number_rows: 1, + return_type: &DataType::Float64, + }; + + let result = LogFunc::new().invoke_with_args(args); result.expect_err("expected error"); } #[test] fn test_log_scalar_f32_unary() { - let args = [ - ColumnarValue::Scalar(ScalarValue::Float32(Some(10.0))), // num - ]; - #[allow(deprecated)] // TODO: migrate to invoke_with_args + let args = ScalarFunctionArgs { + args: vec![ + ColumnarValue::Scalar(ScalarValue::Float32(Some(10.0))), // num + ], + number_rows: 1, + return_type: &DataType::Float32, + }; let result = LogFunc::new() - .invoke_batch(&args, 1) + .invoke_with_args(args) .expect("failed to initialize function log"); match result { @@ -313,12 +320,15 @@ mod tests { #[test] fn test_log_scalar_f64_unary() { - let args = [ - ColumnarValue::Scalar(ScalarValue::Float64(Some(10.0))), // num - ]; - #[allow(deprecated)] // TODO: migrate to invoke_with_args + let args = ScalarFunctionArgs { + args: vec![ + 
ColumnarValue::Scalar(ScalarValue::Float64(Some(10.0))), // num + ], + number_rows: 1, + return_type: &DataType::Float64, + }; let result = LogFunc::new() - .invoke_batch(&args, 1) + .invoke_with_args(args) .expect("failed to initialize function log"); match result { @@ -337,13 +347,16 @@ mod tests { #[test] fn test_log_scalar_f32() { - let args = [ - ColumnarValue::Scalar(ScalarValue::Float32(Some(2.0))), // num - ColumnarValue::Scalar(ScalarValue::Float32(Some(32.0))), // num - ]; - #[allow(deprecated)] // TODO: migrate to invoke_with_args + let args = ScalarFunctionArgs { + args: vec![ + ColumnarValue::Scalar(ScalarValue::Float32(Some(2.0))), // num + ColumnarValue::Scalar(ScalarValue::Float32(Some(32.0))), // num + ], + number_rows: 1, + return_type: &DataType::Float32, + }; let result = LogFunc::new() - .invoke_batch(&args, 1) + .invoke_with_args(args) .expect("failed to initialize function log"); match result { @@ -362,13 +375,16 @@ mod tests { #[test] fn test_log_scalar_f64() { - let args = [ - ColumnarValue::Scalar(ScalarValue::Float64(Some(2.0))), // num - ColumnarValue::Scalar(ScalarValue::Float64(Some(64.0))), // num - ]; - #[allow(deprecated)] // TODO: migrate to invoke_with_args + let args = ScalarFunctionArgs { + args: vec![ + ColumnarValue::Scalar(ScalarValue::Float64(Some(2.0))), // num + ColumnarValue::Scalar(ScalarValue::Float64(Some(64.0))), // num + ], + number_rows: 1, + return_type: &DataType::Float64, + }; let result = LogFunc::new() - .invoke_batch(&args, 1) + .invoke_with_args(args) .expect("failed to initialize function log"); match result { @@ -387,14 +403,17 @@ mod tests { #[test] fn test_log_f64_unary() { - let args = [ - ColumnarValue::Array(Arc::new(Float64Array::from(vec![ - 10.0, 100.0, 1000.0, 10000.0, - ]))), // num - ]; - #[allow(deprecated)] // TODO: migrate to invoke_with_args + let args = ScalarFunctionArgs { + args: vec![ + ColumnarValue::Array(Arc::new(Float64Array::from(vec![ + 10.0, 100.0, 1000.0, 10000.0, + ]))), // num + ], + number_rows: 4, + return_type: &DataType::Float64, + }; let result = LogFunc::new() - .invoke_batch(&args, 4) + .invoke_with_args(args) .expect("failed to initialize function log"); match result { @@ -416,14 +435,17 @@ mod tests { #[test] fn test_log_f32_unary() { - let args = [ - ColumnarValue::Array(Arc::new(Float32Array::from(vec![ - 10.0, 100.0, 1000.0, 10000.0, - ]))), // num - ]; - #[allow(deprecated)] // TODO: migrate to invoke_with_args + let args = ScalarFunctionArgs { + args: vec![ + ColumnarValue::Array(Arc::new(Float32Array::from(vec![ + 10.0, 100.0, 1000.0, 10000.0, + ]))), // num + ], + number_rows: 4, + return_type: &DataType::Float32, + }; let result = LogFunc::new() - .invoke_batch(&args, 4) + .invoke_with_args(args) .expect("failed to initialize function log"); match result { @@ -445,15 +467,20 @@ mod tests { #[test] fn test_log_f64() { - let args = [ - ColumnarValue::Array(Arc::new(Float64Array::from(vec![2.0, 2.0, 3.0, 5.0]))), // base - ColumnarValue::Array(Arc::new(Float64Array::from(vec![ - 8.0, 4.0, 81.0, 625.0, - ]))), // num - ]; - #[allow(deprecated)] // TODO: migrate to invoke_with_args + let args = ScalarFunctionArgs { + args: vec![ + ColumnarValue::Array(Arc::new(Float64Array::from(vec![ + 2.0, 2.0, 3.0, 5.0, + ]))), // base + ColumnarValue::Array(Arc::new(Float64Array::from(vec![ + 8.0, 4.0, 81.0, 625.0, + ]))), // num + ], + number_rows: 4, + return_type: &DataType::Float64, + }; let result = LogFunc::new() - .invoke_batch(&args, 4) + .invoke_with_args(args) .expect("failed to initialize 
function log"); match result { @@ -475,15 +502,20 @@ mod tests { #[test] fn test_log_f32() { - let args = [ - ColumnarValue::Array(Arc::new(Float32Array::from(vec![2.0, 2.0, 3.0, 5.0]))), // base - ColumnarValue::Array(Arc::new(Float32Array::from(vec![ - 8.0, 4.0, 81.0, 625.0, - ]))), // num - ]; - #[allow(deprecated)] // TODO: migrate to invoke_with_args + let args = ScalarFunctionArgs { + args: vec![ + ColumnarValue::Array(Arc::new(Float32Array::from(vec![ + 2.0, 2.0, 3.0, 5.0, + ]))), // base + ColumnarValue::Array(Arc::new(Float32Array::from(vec![ + 8.0, 4.0, 81.0, 625.0, + ]))), // num + ], + number_rows: 4, + return_type: &DataType::Float32, + }; let result = LogFunc::new() - .invoke_batch(&args, 4) + .invoke_with_args(args) .expect("failed to initialize function log"); match result { diff --git a/datafusion/functions/src/math/nans.rs b/datafusion/functions/src/math/nans.rs index 30c920c29a21..34a5c2a1c16b 100644 --- a/datafusion/functions/src/math/nans.rs +++ b/datafusion/functions/src/math/nans.rs @@ -19,7 +19,7 @@ use arrow::datatypes::{DataType, Float32Type, Float64Type}; use datafusion_common::{exec_err, Result}; -use datafusion_expr::{ColumnarValue, TypeSignature}; +use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, TypeSignature}; use arrow::array::{ArrayRef, AsArray, BooleanArray}; use datafusion_expr::{Documentation, ScalarUDFImpl, Signature, Volatility}; @@ -75,12 +75,8 @@ impl ScalarUDFImpl for IsNanFunc { Ok(DataType::Boolean) } - fn invoke_batch( - &self, - args: &[ColumnarValue], - _number_rows: usize, - ) -> Result { - let args = ColumnarValue::values_to_arrays(args)?; + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + let args = ColumnarValue::values_to_arrays(&args.args)?; let arr: ArrayRef = match args[0].data_type() { DataType::Float64 => Arc::new(BooleanArray::from_unary( diff --git a/datafusion/functions/src/math/nanvl.rs b/datafusion/functions/src/math/nanvl.rs index 33823acce751..9effb82896ee 100644 --- a/datafusion/functions/src/math/nanvl.rs +++ b/datafusion/functions/src/math/nanvl.rs @@ -26,7 +26,8 @@ use arrow::datatypes::{DataType, Float32Type, Float64Type}; use datafusion_common::{exec_err, DataFusionError, Result}; use datafusion_expr::TypeSignature::Exact; use datafusion_expr::{ - ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, + ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, + Volatility, }; use datafusion_macros::user_doc; @@ -87,12 +88,8 @@ impl ScalarUDFImpl for NanvlFunc { } } - fn invoke_batch( - &self, - args: &[ColumnarValue], - _number_rows: usize, - ) -> Result { - make_scalar_function(nanvl, vec![])(args) + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + make_scalar_function(nanvl, vec![])(&args.args) } fn documentation(&self) -> Option<&Documentation> { diff --git a/datafusion/functions/src/math/pi.rs b/datafusion/functions/src/math/pi.rs index 06f7a01544f8..5339a9b14a28 100644 --- a/datafusion/functions/src/math/pi.rs +++ b/datafusion/functions/src/math/pi.rs @@ -22,7 +22,8 @@ use arrow::datatypes::DataType::Float64; use datafusion_common::{internal_err, Result, ScalarValue}; use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; use datafusion_expr::{ - ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, + ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, + Volatility, }; use datafusion_macros::user_doc; @@ -67,12 +68,8 @@ impl ScalarUDFImpl for PiFunc { Ok(Float64) } - fn invoke_batch( - 
&self, - args: &[ColumnarValue], - _number_rows: usize, - ) -> Result { - if !args.is_empty() { + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + if !args.args.is_empty() { return internal_err!("{} function does not accept arguments", self.name()); } Ok(ColumnarValue::Scalar(ScalarValue::Float64(Some( diff --git a/datafusion/functions/src/math/power.rs b/datafusion/functions/src/math/power.rs index 7fab858d34a0..028ec2fef793 100644 --- a/datafusion/functions/src/math/power.rs +++ b/datafusion/functions/src/math/power.rs @@ -29,7 +29,9 @@ use datafusion_common::{ }; use datafusion_expr::expr::ScalarFunction; use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo}; -use datafusion_expr::{ColumnarValue, Documentation, Expr, ScalarUDF, TypeSignature}; +use datafusion_expr::{ + ColumnarValue, Documentation, Expr, ScalarFunctionArgs, ScalarUDF, TypeSignature, +}; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; use datafusion_macros::user_doc; @@ -91,12 +93,8 @@ impl ScalarUDFImpl for PowerFunc { &self.aliases } - fn invoke_batch( - &self, - args: &[ColumnarValue], - _number_rows: usize, - ) -> Result { - let args = ColumnarValue::values_to_arrays(args)?; + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + let args = ColumnarValue::values_to_arrays(&args.args)?; let arr: ArrayRef = match args[0].data_type() { DataType::Float64 => { @@ -195,13 +193,20 @@ mod tests { #[test] fn test_power_f64() { - let args = [ - ColumnarValue::Array(Arc::new(Float64Array::from(vec![2.0, 2.0, 3.0, 5.0]))), // base - ColumnarValue::Array(Arc::new(Float64Array::from(vec![3.0, 2.0, 4.0, 4.0]))), // exponent - ]; - #[allow(deprecated)] // TODO: migrate to invoke_with_args + let args = ScalarFunctionArgs { + args: vec![ + ColumnarValue::Array(Arc::new(Float64Array::from(vec![ + 2.0, 2.0, 3.0, 5.0, + ]))), // base + ColumnarValue::Array(Arc::new(Float64Array::from(vec![ + 3.0, 2.0, 4.0, 4.0, + ]))), // exponent + ], + number_rows: 4, + return_type: &DataType::Float64, + }; let result = PowerFunc::new() - .invoke_batch(&args, 4) + .invoke_with_args(args) .expect("failed to initialize function power"); match result { @@ -222,13 +227,16 @@ mod tests { #[test] fn test_power_i64() { - let args = [ - ColumnarValue::Array(Arc::new(Int64Array::from(vec![2, 2, 3, 5]))), // base - ColumnarValue::Array(Arc::new(Int64Array::from(vec![3, 2, 4, 4]))), // exponent - ]; - #[allow(deprecated)] // TODO: migrate to invoke_with_args + let args = ScalarFunctionArgs { + args: vec![ + ColumnarValue::Array(Arc::new(Int64Array::from(vec![2, 2, 3, 5]))), // base + ColumnarValue::Array(Arc::new(Int64Array::from(vec![3, 2, 4, 4]))), // exponent + ], + number_rows: 4, + return_type: &DataType::Int64, + }; let result = PowerFunc::new() - .invoke_batch(&args, 4) + .invoke_with_args(args) .expect("failed to initialize function power"); match result { diff --git a/datafusion/functions/src/math/random.rs b/datafusion/functions/src/math/random.rs index 197d065ea408..607f9fb09f2a 100644 --- a/datafusion/functions/src/math/random.rs +++ b/datafusion/functions/src/math/random.rs @@ -24,7 +24,7 @@ use arrow::datatypes::DataType::Float64; use rand::{thread_rng, Rng}; use datafusion_common::{internal_err, Result}; -use datafusion_expr::ColumnarValue; +use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; use datafusion_expr::{Documentation, ScalarUDFImpl, Signature, Volatility}; use datafusion_macros::user_doc; @@ -70,16 +70,12 @@ impl ScalarUDFImpl for RandomFunc { Ok(Float64) } - fn invoke_batch( 
- &self, - args: &[ColumnarValue], - num_rows: usize, - ) -> Result { - if !args.is_empty() { + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + if !args.args.is_empty() { return internal_err!("{} function does not accept arguments", self.name()); } let mut rng = thread_rng(); - let mut values = vec![0.0; num_rows]; + let mut values = vec![0.0; args.number_rows]; // Equivalent to set each element with rng.gen_range(0.0..1.0), but more efficient rng.fill(&mut values[..]); let array = Float64Array::from(values); diff --git a/datafusion/functions/src/math/round.rs b/datafusion/functions/src/math/round.rs index b3442c321c99..fc87b7e63a62 100644 --- a/datafusion/functions/src/math/round.rs +++ b/datafusion/functions/src/math/round.rs @@ -28,7 +28,8 @@ use datafusion_common::{exec_datafusion_err, exec_err, Result, ScalarValue}; use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; use datafusion_expr::TypeSignature::Exact; use datafusion_expr::{ - ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, + ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, + Volatility, }; use datafusion_macros::user_doc; @@ -90,12 +91,8 @@ impl ScalarUDFImpl for RoundFunc { } } - fn invoke_batch( - &self, - args: &[ColumnarValue], - _number_rows: usize, - ) -> Result { - make_scalar_function(round, vec![])(args) + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + make_scalar_function(round, vec![])(&args.args) } fn output_ordering(&self, input: &[ExprProperties]) -> Result { diff --git a/datafusion/functions/src/math/signum.rs b/datafusion/functions/src/math/signum.rs index f68834db375e..ba5422afa768 100644 --- a/datafusion/functions/src/math/signum.rs +++ b/datafusion/functions/src/math/signum.rs @@ -25,7 +25,8 @@ use arrow::datatypes::{DataType, Float32Type, Float64Type}; use datafusion_common::{exec_err, Result}; use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; use datafusion_expr::{ - ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, + ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, + Volatility, }; use datafusion_macros::user_doc; @@ -88,12 +89,8 @@ impl ScalarUDFImpl for SignumFunc { Ok(input[0].sort_properties) } - fn invoke_batch( - &self, - args: &[ColumnarValue], - _number_rows: usize, - ) -> Result { - make_scalar_function(signum, vec![])(args) + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + make_scalar_function(signum, vec![])(&args.args) } fn documentation(&self) -> Option<&Documentation> { @@ -140,10 +137,10 @@ pub fn signum(args: &[ArrayRef]) -> Result { mod test { use std::sync::Arc; - use arrow::array::{Float32Array, Float64Array}; - + use arrow::array::{ArrayRef, Float32Array, Float64Array}; + use arrow::datatypes::DataType; use datafusion_common::cast::{as_float32_array, as_float64_array}; - use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; + use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl}; use crate::math::signum::SignumFunc; @@ -160,10 +157,13 @@ mod test { f32::INFINITY, f32::NEG_INFINITY, ])); - let batch_size = array.len(); - #[allow(deprecated)] // TODO: migrate to invoke_with_args + let args = ScalarFunctionArgs { + args: vec![ColumnarValue::Array(Arc::clone(&array) as ArrayRef)], + number_rows: array.len(), + return_type: &DataType::Float32, + }; let result = SignumFunc::new() - .invoke_batch(&[ColumnarValue::Array(array)], batch_size) + .invoke_with_args(args) .expect("failed to 
initialize function signum"); match result { @@ -201,10 +201,13 @@ mod test { f64::INFINITY, f64::NEG_INFINITY, ])); - let batch_size = array.len(); - #[allow(deprecated)] // TODO: migrate to invoke_with_args + let args = ScalarFunctionArgs { + args: vec![ColumnarValue::Array(Arc::clone(&array) as ArrayRef)], + number_rows: array.len(), + return_type: &DataType::Float64, + }; let result = SignumFunc::new() - .invoke_batch(&[ColumnarValue::Array(array)], batch_size) + .invoke_with_args(args) .expect("failed to initialize function signum"); match result { diff --git a/datafusion/functions/src/math/trunc.rs b/datafusion/functions/src/math/trunc.rs index 8d791370d7f8..2ac291204a0b 100644 --- a/datafusion/functions/src/math/trunc.rs +++ b/datafusion/functions/src/math/trunc.rs @@ -28,7 +28,8 @@ use datafusion_common::{exec_err, Result}; use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; use datafusion_expr::TypeSignature::Exact; use datafusion_expr::{ - ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, + ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, + Volatility, }; use datafusion_macros::user_doc; @@ -99,12 +100,8 @@ impl ScalarUDFImpl for TruncFunc { } } - fn invoke_batch( - &self, - args: &[ColumnarValue], - _number_rows: usize, - ) -> Result { - make_scalar_function(trunc, vec![])(args) + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + make_scalar_function(trunc, vec![])(&args.args) } fn output_ordering(&self, input: &[ExprProperties]) -> Result { diff --git a/datafusion/functions/src/regex/regexpcount.rs b/datafusion/functions/src/regex/regexpcount.rs index a81133713360..8cb1a4ff3d60 100644 --- a/datafusion/functions/src/regex/regexpcount.rs +++ b/datafusion/functions/src/regex/regexpcount.rs @@ -108,11 +108,12 @@ impl ScalarUDFImpl for RegexpCountFunc { Ok(Int64) } - fn invoke_batch( + fn invoke_with_args( &self, - args: &[ColumnarValue], - _number_rows: usize, + args: datafusion_expr::ScalarFunctionArgs, ) -> Result { + let args = &args.args; + let len = args .iter() .fold(Option::::None, |acc, arg| match arg { @@ -618,6 +619,7 @@ fn count_matches( mod tests { use super::*; use arrow::array::{GenericStringArray, StringViewArray}; + use datafusion_expr::ScalarFunctionArgs; #[test] fn test_regexp_count() { @@ -655,11 +657,11 @@ mod tests { let v_sv = ScalarValue::Utf8(Some(v.to_string())); let regex_sv = ScalarValue::Utf8(Some(regex.to_string())); let expected = expected.get(pos).cloned(); - #[allow(deprecated)] // TODO: migrate to invoke_with_args - let re = RegexpCountFunc::new().invoke_batch( - &[ColumnarValue::Scalar(v_sv), ColumnarValue::Scalar(regex_sv)], - 1, - ); + let re = RegexpCountFunc::new().invoke_with_args(ScalarFunctionArgs { + args: vec![ColumnarValue::Scalar(v_sv), ColumnarValue::Scalar(regex_sv)], + number_rows: 2, + return_type: &Int64, + }); match re { Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { assert_eq!(v, expected, "regexp_count scalar test failed"); @@ -670,11 +672,11 @@ mod tests { // largeutf8 let v_sv = ScalarValue::LargeUtf8(Some(v.to_string())); let regex_sv = ScalarValue::LargeUtf8(Some(regex.to_string())); - #[allow(deprecated)] // TODO: migrate to invoke_with_args - let re = RegexpCountFunc::new().invoke_batch( - &[ColumnarValue::Scalar(v_sv), ColumnarValue::Scalar(regex_sv)], - 1, - ); + let re = RegexpCountFunc::new().invoke_with_args(ScalarFunctionArgs { + args: vec![ColumnarValue::Scalar(v_sv), ColumnarValue::Scalar(regex_sv)], + number_rows: 2, + 
return_type: &Int64, + }); match re { Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { assert_eq!(v, expected, "regexp_count scalar test failed"); @@ -685,11 +687,11 @@ mod tests { // utf8view let v_sv = ScalarValue::Utf8View(Some(v.to_string())); let regex_sv = ScalarValue::Utf8View(Some(regex.to_string())); - #[allow(deprecated)] // TODO: migrate to invoke_with_args - let re = RegexpCountFunc::new().invoke_batch( - &[ColumnarValue::Scalar(v_sv), ColumnarValue::Scalar(regex_sv)], - 1, - ); + let re = RegexpCountFunc::new().invoke_with_args(ScalarFunctionArgs { + args: vec![ColumnarValue::Scalar(v_sv), ColumnarValue::Scalar(regex_sv)], + number_rows: 2, + return_type: &Int64, + }); match re { Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { assert_eq!(v, expected, "regexp_count scalar test failed"); @@ -711,15 +713,15 @@ mod tests { let regex_sv = ScalarValue::Utf8(Some(regex.to_string())); let start_sv = ScalarValue::Int64(Some(start)); let expected = expected.get(pos).cloned(); - #[allow(deprecated)] // TODO: migrate to invoke_with_args - let re = RegexpCountFunc::new().invoke_batch( - &[ + let re = RegexpCountFunc::new().invoke_with_args(ScalarFunctionArgs { + args: vec![ ColumnarValue::Scalar(v_sv), ColumnarValue::Scalar(regex_sv), ColumnarValue::Scalar(start_sv.clone()), ], - 1, - ); + number_rows: 3, + return_type: &Int64, + }); match re { Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { assert_eq!(v, expected, "regexp_count scalar test failed"); @@ -730,15 +732,15 @@ mod tests { // largeutf8 let v_sv = ScalarValue::LargeUtf8(Some(v.to_string())); let regex_sv = ScalarValue::LargeUtf8(Some(regex.to_string())); - #[allow(deprecated)] // TODO: migrate to invoke_with_args - let re = RegexpCountFunc::new().invoke_batch( - &[ + let re = RegexpCountFunc::new().invoke_with_args(ScalarFunctionArgs { + args: vec![ ColumnarValue::Scalar(v_sv), ColumnarValue::Scalar(regex_sv), ColumnarValue::Scalar(start_sv.clone()), ], - 1, - ); + number_rows: 3, + return_type: &Int64, + }); match re { Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { assert_eq!(v, expected, "regexp_count scalar test failed"); @@ -749,15 +751,15 @@ mod tests { // utf8view let v_sv = ScalarValue::Utf8View(Some(v.to_string())); let regex_sv = ScalarValue::Utf8View(Some(regex.to_string())); - #[allow(deprecated)] // TODO: migrate to invoke_with_args - let re = RegexpCountFunc::new().invoke_batch( - &[ + let re = RegexpCountFunc::new().invoke_with_args(ScalarFunctionArgs { + args: vec![ ColumnarValue::Scalar(v_sv), ColumnarValue::Scalar(regex_sv), - ColumnarValue::Scalar(start_sv), + ColumnarValue::Scalar(start_sv.clone()), ], - 1, - ); + number_rows: 3, + return_type: &Int64, + }); match re { Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { assert_eq!(v, expected, "regexp_count scalar test failed"); @@ -781,16 +783,16 @@ mod tests { let start_sv = ScalarValue::Int64(Some(start)); let flags_sv = ScalarValue::Utf8(Some(flags.to_string())); let expected = expected.get(pos).cloned(); - #[allow(deprecated)] // TODO: migrate to invoke_with_args - let re = RegexpCountFunc::new().invoke_batch( - &[ + let re = RegexpCountFunc::new().invoke_with_args(ScalarFunctionArgs { + args: vec![ ColumnarValue::Scalar(v_sv), ColumnarValue::Scalar(regex_sv), ColumnarValue::Scalar(start_sv.clone()), ColumnarValue::Scalar(flags_sv.clone()), ], - 1, - ); + number_rows: 4, + return_type: &Int64, + }); match re { Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { assert_eq!(v, expected, "regexp_count scalar test failed"); @@ 
-802,16 +804,16 @@ mod tests { let v_sv = ScalarValue::LargeUtf8(Some(v.to_string())); let regex_sv = ScalarValue::LargeUtf8(Some(regex.to_string())); let flags_sv = ScalarValue::LargeUtf8(Some(flags.to_string())); - #[allow(deprecated)] // TODO: migrate to invoke_with_args - let re = RegexpCountFunc::new().invoke_batch( - &[ + let re = RegexpCountFunc::new().invoke_with_args(ScalarFunctionArgs { + args: vec![ ColumnarValue::Scalar(v_sv), ColumnarValue::Scalar(regex_sv), ColumnarValue::Scalar(start_sv.clone()), ColumnarValue::Scalar(flags_sv.clone()), ], - 1, - ); + number_rows: 4, + return_type: &Int64, + }); match re { Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { assert_eq!(v, expected, "regexp_count scalar test failed"); @@ -823,16 +825,16 @@ mod tests { let v_sv = ScalarValue::Utf8View(Some(v.to_string())); let regex_sv = ScalarValue::Utf8View(Some(regex.to_string())); let flags_sv = ScalarValue::Utf8View(Some(flags.to_string())); - #[allow(deprecated)] // TODO: migrate to invoke_with_args - let re = RegexpCountFunc::new().invoke_batch( - &[ + let re = RegexpCountFunc::new().invoke_with_args(ScalarFunctionArgs { + args: vec![ ColumnarValue::Scalar(v_sv), ColumnarValue::Scalar(regex_sv), - ColumnarValue::Scalar(start_sv), + ColumnarValue::Scalar(start_sv.clone()), ColumnarValue::Scalar(flags_sv.clone()), ], - 1, - ); + number_rows: 4, + return_type: &Int64, + }); match re { Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { assert_eq!(v, expected, "regexp_count scalar test failed"); @@ -905,16 +907,16 @@ mod tests { let start_sv = ScalarValue::Int64(Some(start)); let flags_sv = ScalarValue::Utf8(flags.get(pos).map(|f| f.to_string())); let expected = expected.get(pos).cloned(); - #[allow(deprecated)] // TODO: migrate to invoke_with_args - let re = RegexpCountFunc::new().invoke_batch( - &[ + let re = RegexpCountFunc::new().invoke_with_args(ScalarFunctionArgs { + args: vec![ ColumnarValue::Scalar(v_sv), ColumnarValue::Scalar(regex_sv), ColumnarValue::Scalar(start_sv.clone()), ColumnarValue::Scalar(flags_sv.clone()), ], - 1, - ); + number_rows: 4, + return_type: &Int64, + }); match re { Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { assert_eq!(v, expected, "regexp_count scalar test failed"); @@ -926,16 +928,16 @@ mod tests { let v_sv = ScalarValue::LargeUtf8(Some(v.to_string())); let regex_sv = ScalarValue::LargeUtf8(regex.get(pos).map(|s| s.to_string())); let flags_sv = ScalarValue::LargeUtf8(flags.get(pos).map(|f| f.to_string())); - #[allow(deprecated)] // TODO: migrate to invoke_with_args - let re = RegexpCountFunc::new().invoke_batch( - &[ + let re = RegexpCountFunc::new().invoke_with_args(ScalarFunctionArgs { + args: vec![ ColumnarValue::Scalar(v_sv), ColumnarValue::Scalar(regex_sv), ColumnarValue::Scalar(start_sv.clone()), ColumnarValue::Scalar(flags_sv.clone()), ], - 1, - ); + number_rows: 4, + return_type: &Int64, + }); match re { Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { assert_eq!(v, expected, "regexp_count scalar test failed"); @@ -947,16 +949,16 @@ mod tests { let v_sv = ScalarValue::Utf8View(Some(v.to_string())); let regex_sv = ScalarValue::Utf8View(regex.get(pos).map(|s| s.to_string())); let flags_sv = ScalarValue::Utf8View(flags.get(pos).map(|f| f.to_string())); - #[allow(deprecated)] // TODO: migrate to invoke_with_args - let re = RegexpCountFunc::new().invoke_batch( - &[ + let re = RegexpCountFunc::new().invoke_with_args(ScalarFunctionArgs { + args: vec![ ColumnarValue::Scalar(v_sv), ColumnarValue::Scalar(regex_sv), - 
ColumnarValue::Scalar(start_sv), + ColumnarValue::Scalar(start_sv.clone()), ColumnarValue::Scalar(flags_sv.clone()), ], - 1, - ); + number_rows: 4, + return_type: &Int64, + }); match re { Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { assert_eq!(v, expected, "regexp_count scalar test failed"); diff --git a/datafusion/functions/src/regex/regexplike.rs b/datafusion/functions/src/regex/regexplike.rs index 296ec339a623..6006309306d5 100644 --- a/datafusion/functions/src/regex/regexplike.rs +++ b/datafusion/functions/src/regex/regexplike.rs @@ -110,11 +110,12 @@ impl ScalarUDFImpl for RegexpLikeFunc { }) } - fn invoke_batch( + fn invoke_with_args( &self, - args: &[ColumnarValue], - _number_rows: usize, + args: datafusion_expr::ScalarFunctionArgs, ) -> Result { + let args = &args.args; + let len = args .iter() .fold(Option::::None, |acc, arg| match arg { diff --git a/datafusion/functions/src/regex/regexpmatch.rs b/datafusion/functions/src/regex/regexpmatch.rs index 57207ecfdacd..1119e66398d1 100644 --- a/datafusion/functions/src/regex/regexpmatch.rs +++ b/datafusion/functions/src/regex/regexpmatch.rs @@ -118,11 +118,12 @@ impl ScalarUDFImpl for RegexpMatchFunc { other => DataType::List(Arc::new(Field::new_list_field(other.clone(), true))), }) } - fn invoke_batch( + + fn invoke_with_args( &self, - args: &[ColumnarValue], - _number_rows: usize, + args: datafusion_expr::ScalarFunctionArgs, ) -> Result { + let args = &args.args; let len = args .iter() .fold(Option::::None, |acc, arg| match arg { diff --git a/datafusion/functions/src/regex/regexpreplace.rs b/datafusion/functions/src/regex/regexpreplace.rs index 130c66caeecd..3a83564ff11f 100644 --- a/datafusion/functions/src/regex/regexpreplace.rs +++ b/datafusion/functions/src/regex/regexpreplace.rs @@ -147,11 +147,13 @@ impl ScalarUDFImpl for RegexpReplaceFunc { } }) } - fn invoke_batch( + + fn invoke_with_args( &self, - args: &[ColumnarValue], - _number_rows: usize, + args: datafusion_expr::ScalarFunctionArgs, ) -> Result { + let args = &args.args; + let len = args .iter() .fold(Option::::None, |acc, arg| match arg { diff --git a/datafusion/functions/src/string/ascii.rs b/datafusion/functions/src/string/ascii.rs index 858eddc7c8f8..006492a0e07a 100644 --- a/datafusion/functions/src/string/ascii.rs +++ b/datafusion/functions/src/string/ascii.rs @@ -19,9 +19,11 @@ use crate::utils::make_scalar_function; use arrow::array::{ArrayAccessor, ArrayIter, ArrayRef, AsArray, Int32Array}; use arrow::datatypes::DataType; use arrow::error::ArrowError; +use datafusion_common::types::logical_string; use datafusion_common::{internal_err, Result}; -use datafusion_expr::{ColumnarValue, Documentation}; -use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; +use datafusion_expr::{ColumnarValue, Documentation, TypeSignatureClass}; +use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility}; +use datafusion_expr_common::signature::Coercion; use datafusion_macros::user_doc; use std::any::Any; use std::sync::Arc; @@ -61,7 +63,12 @@ impl Default for AsciiFunc { impl AsciiFunc { pub fn new() -> Self { Self { - signature: Signature::string(1, Volatility::Immutable), + signature: Signature::coercible( + vec![Coercion::new_exact(TypeSignatureClass::Native( + logical_string(), + ))], + Volatility::Immutable, + ), } } } @@ -85,12 +92,8 @@ impl ScalarUDFImpl for AsciiFunc { Ok(Int32) } - fn invoke_batch( - &self, - args: &[ColumnarValue], - _number_rows: usize, - ) -> Result { - make_scalar_function(ascii, vec![])(args) + fn 
invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + make_scalar_function(ascii, vec![])(&args.args) } fn documentation(&self) -> Option<&Documentation> { diff --git a/datafusion/functions/src/string/bit_length.rs b/datafusion/functions/src/string/bit_length.rs index 623fb2ba03f0..2a782c59963e 100644 --- a/datafusion/functions/src/string/bit_length.rs +++ b/datafusion/functions/src/string/bit_length.rs @@ -20,9 +20,9 @@ use arrow::datatypes::DataType; use std::any::Any; use crate::utils::utf8_to_int_type; -use datafusion_common::{exec_err, Result, ScalarValue}; +use datafusion_common::{utils::take_function_args, Result, ScalarValue}; use datafusion_expr::{ColumnarValue, Documentation, Volatility}; -use datafusion_expr::{ScalarUDFImpl, Signature}; +use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl, Signature}; use datafusion_macros::user_doc; #[user_doc( @@ -77,19 +77,10 @@ impl ScalarUDFImpl for BitLengthFunc { utf8_to_int_type(&arg_types[0], "bit_length") } - fn invoke_batch( - &self, - args: &[ColumnarValue], - _number_rows: usize, - ) -> Result { - if args.len() != 1 { - return exec_err!( - "bit_length function requires 1 argument, got {}", - args.len() - ); - } + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + let [array] = take_function_args(self.name(), &args.args)?; - match &args[0] { + match array { ColumnarValue::Array(v) => Ok(ColumnarValue::Array(bit_length(v.as_ref())?)), ColumnarValue::Scalar(v) => match v { ScalarValue::Utf8(v) => Ok(ColumnarValue::Scalar(ScalarValue::Int32( diff --git a/datafusion/functions/src/string/btrim.rs b/datafusion/functions/src/string/btrim.rs index 05a2f646e969..89bffa25698e 100644 --- a/datafusion/functions/src/string/btrim.rs +++ b/datafusion/functions/src/string/btrim.rs @@ -22,7 +22,8 @@ use arrow::datatypes::DataType; use datafusion_common::{exec_err, Result}; use datafusion_expr::function::Hint; use datafusion_expr::{ - ColumnarValue, Documentation, ScalarUDFImpl, Signature, TypeSignature, Volatility, + ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, + TypeSignature, Volatility, }; use datafusion_macros::user_doc; use std::any::Any; @@ -101,20 +102,16 @@ impl ScalarUDFImpl for BTrimFunc { } } - fn invoke_batch( - &self, - args: &[ColumnarValue], - _number_rows: usize, - ) -> Result { - match args[0].data_type() { + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + match args.args[0].data_type() { DataType::Utf8 | DataType::Utf8View => make_scalar_function( btrim::, vec![Hint::Pad, Hint::AcceptsSingular], - )(args), + )(&args.args), DataType::LargeUtf8 => make_scalar_function( btrim::, vec![Hint::Pad, Hint::AcceptsSingular], - )(args), + )(&args.args), other => exec_err!( "Unsupported data type {other:?} for function btrim,\ expected Utf8, LargeUtf8 or Utf8View." 
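For readers following the migration, the recurring change in the hunks above and below is the replacement of `invoke_batch(&self, &[ColumnarValue], usize)` with `invoke_with_args(&self, ScalarFunctionArgs)`, where the argument values, the row count, and the resolved return type travel together in one struct. Below is a minimal, self-contained sketch of the new entry point, assuming the `ScalarFunctionArgs { args, number_rows, return_type }` shape used throughout these hunks; `FortyTwoFunc` is hypothetical and not part of this PR.

```rust
use std::any::Any;
use std::sync::Arc;

use arrow::array::Int64Array;
use arrow::datatypes::DataType;
use datafusion_common::{internal_err, Result};
use datafusion_expr::{
    ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
};

/// Hypothetical nullary UDF, used only to illustrate the migrated trait method.
#[derive(Debug)]
struct FortyTwoFunc {
    signature: Signature,
}

impl FortyTwoFunc {
    fn new() -> Self {
        Self {
            signature: Signature::any(0, Volatility::Immutable),
        }
    }
}

impl ScalarUDFImpl for FortyTwoFunc {
    fn as_any(&self) -> &dyn Any {
        self
    }
    fn name(&self) -> &str {
        "forty_two"
    }
    fn signature(&self) -> &Signature {
        &self.signature
    }
    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
        Ok(DataType::Int64)
    }

    // Old: fn invoke_batch(&self, args: &[ColumnarValue], number_rows: usize)
    // New: the values, row count and return type arrive bundled together.
    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
        if !args.args.is_empty() {
            return internal_err!("{} function does not accept arguments", self.name());
        }
        // `args.number_rows` replaces the separate row-count parameter.
        let values = Int64Array::from(vec![42_i64; args.number_rows]);
        Ok(ColumnarValue::Array(Arc::new(values)))
    }
}
```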
diff --git a/datafusion/functions/src/string/chr.rs b/datafusion/functions/src/string/chr.rs index 3530e3f22c0f..a811de7fccf0 100644 --- a/datafusion/functions/src/string/chr.rs +++ b/datafusion/functions/src/string/chr.rs @@ -19,7 +19,7 @@ use std::any::Any; use std::sync::Arc; use arrow::array::ArrayRef; -use arrow::array::StringArray; +use arrow::array::GenericStringBuilder; use arrow::datatypes::DataType; use arrow::datatypes::DataType::Int64; use arrow::datatypes::DataType::Utf8; @@ -28,7 +28,7 @@ use crate::utils::make_scalar_function; use datafusion_common::cast::as_int64_array; use datafusion_common::{exec_err, Result}; use datafusion_expr::{ColumnarValue, Documentation, Volatility}; -use datafusion_expr::{ScalarUDFImpl, Signature}; +use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl, Signature}; use datafusion_macros::user_doc; /// Returns the character with the given code. chr(0) is disallowed because text data types cannot store that character. @@ -36,26 +36,39 @@ use datafusion_macros::user_doc; pub fn chr(args: &[ArrayRef]) -> Result { let integer_array = as_int64_array(&args[0])?; - // first map is the iterator, second is for the `Option<_>` - let result = integer_array - .iter() - .map(|integer: Option| { - integer - .map(|integer| { - if integer == 0 { - exec_err!("null character not permitted.") - } else { - match core::char::from_u32(integer as u32) { - Some(integer) => Ok(integer.to_string()), - None => { - exec_err!("requested character too large for encoding.") - } + let mut builder = GenericStringBuilder::::with_capacity( + integer_array.len(), + // 1 byte per character, assuming that is the common case + integer_array.len(), + ); + + let mut buf = [0u8; 4]; + + for integer in integer_array { + match integer { + Some(integer) => { + if integer == 0 { + return exec_err!("null character not permitted."); + } else { + match core::char::from_u32(integer as u32) { + Some(c) => { + builder.append_value(c.encode_utf8(&mut buf)); + } + None => { + return exec_err!( + "requested character too large for encoding." 
+ ); } } - }) - .transpose() - }) - .collect::>()?; + } + } + None => { + builder.append_null(); + } + } + } + + let result = builder.finish(); Ok(Arc::new(result) as ArrayRef) } @@ -70,7 +83,7 @@ pub fn chr(args: &[ArrayRef]) -> Result { | chr(Int64(128640)) | +--------------------+ | 🚀 | -+--------------------+ ++--------------------+ ```"#, standard_argument(name = "expression", prefix = "String"), related_udf(name = "ascii") @@ -111,12 +124,8 @@ impl ScalarUDFImpl for ChrFunc { Ok(Utf8) } - fn invoke_batch( - &self, - args: &[ColumnarValue], - _number_rows: usize, - ) -> Result { - make_scalar_function(chr, vec![])(args) + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + make_scalar_function(chr, vec![])(&args.args) } fn documentation(&self) -> Option<&Documentation> { diff --git a/datafusion/functions/src/string/concat.rs b/datafusion/functions/src/string/concat.rs index 9ce732efa0c7..c47d08d579e4 100644 --- a/datafusion/functions/src/string/concat.rs +++ b/datafusion/functions/src/string/concat.rs @@ -30,7 +30,7 @@ use datafusion_common::{internal_err, plan_err, Result, ScalarValue}; use datafusion_expr::expr::ScalarFunction; use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo}; use datafusion_expr::{lit, ColumnarValue, Documentation, Expr, Volatility}; -use datafusion_expr::{ScalarUDFImpl, Signature}; +use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl, Signature}; use datafusion_macros::user_doc; #[user_doc( @@ -105,11 +105,9 @@ impl ScalarUDFImpl for ConcatFunc { /// Concatenates the text representations of all the arguments. NULL arguments are ignored. /// concat('abcde', 2, NULL, 22) = 'abcde222' - fn invoke_batch( - &self, - args: &[ColumnarValue], - _number_rows: usize, - ) -> Result { + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + let ScalarFunctionArgs { args, .. 
} = args; + let mut return_datatype = DataType::Utf8; args.iter().for_each(|col| { if col.data_type() == DataType::Utf8View { @@ -169,7 +167,7 @@ impl ScalarUDFImpl for ConcatFunc { let mut data_size = 0; let mut columns = Vec::with_capacity(args.len()); - for arg in args { + for arg in &args { match arg { ColumnarValue::Scalar(ScalarValue::Utf8(maybe_value)) | ColumnarValue::Scalar(ScalarValue::LargeUtf8(maybe_value)) @@ -470,10 +468,14 @@ mod tests { None, Some("b"), ]))); - let args = &[c0, c1, c2, c3, c4]; - #[allow(deprecated)] // TODO migrate UDF invoke to invoke_batch - let result = ConcatFunc::new().invoke_batch(args, 3)?; + let args = ScalarFunctionArgs { + args: vec![c0, c1, c2, c3, c4], + number_rows: 3, + return_type: &Utf8, + }; + + let result = ConcatFunc::new().invoke_with_args(args)?; let expected = Arc::new(StringViewArray::from(vec!["foo,x,a", "bar,,", "baz,z,b"])) as ArrayRef; diff --git a/datafusion/functions/src/string/concat_ws.rs b/datafusion/functions/src/string/concat_ws.rs index 026d167cccd5..c2bad206db15 100644 --- a/datafusion/functions/src/string/concat_ws.rs +++ b/datafusion/functions/src/string/concat_ws.rs @@ -30,7 +30,7 @@ use datafusion_common::{exec_err, internal_err, plan_err, Result, ScalarValue}; use datafusion_expr::expr::ScalarFunction; use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo}; use datafusion_expr::{lit, ColumnarValue, Documentation, Expr, Volatility}; -use datafusion_expr::{ScalarUDFImpl, Signature}; +use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl, Signature}; use datafusion_macros::user_doc; #[user_doc( @@ -102,11 +102,9 @@ impl ScalarUDFImpl for ConcatWsFunc { /// Concatenates all but the first argument, with separators. The first argument is used as the separator string, and should not be NULL. Other NULL arguments are ignored. /// concat_ws(',', 'abcde', 2, NULL, 22) = 'abcde,2,22' - fn invoke_batch( - &self, - args: &[ColumnarValue], - _number_rows: usize, - ) -> Result { + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + let ScalarFunctionArgs { args, .. } = args; + // do not accept 0 arguments. 
if args.len() < 2 { return exec_err!( @@ -411,7 +409,7 @@ mod tests { use crate::string::concat_ws::ConcatWsFunc; use datafusion_common::Result; use datafusion_common::ScalarValue; - use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; + use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl}; use crate::utils::test::test_function; @@ -482,10 +480,14 @@ mod tests { None, Some("z"), ]))); - let args = &[c0, c1, c2]; - #[allow(deprecated)] // TODO migrate UDF invoke to invoke_batch - let result = ConcatWsFunc::new().invoke_batch(args, 3)?; + let args = ScalarFunctionArgs { + args: vec![c0, c1, c2], + number_rows: 3, + return_type: &Utf8, + }; + + let result = ConcatWsFunc::new().invoke_with_args(args)?; let expected = Arc::new(StringArray::from(vec!["foo,x", "bar", "baz,z"])) as ArrayRef; match &result { @@ -508,10 +510,14 @@ mod tests { Some("y"), Some("z"), ]))); - let args = &[c0, c1, c2]; - #[allow(deprecated)] // TODO migrate UDF invoke to invoke_batch - let result = ConcatWsFunc::new().invoke_batch(args, 3)?; + let args = ScalarFunctionArgs { + args: vec![c0, c1, c2], + number_rows: 3, + return_type: &Utf8, + }; + + let result = ConcatWsFunc::new().invoke_with_args(args)?; let expected = Arc::new(StringArray::from(vec![Some("foo,x"), None, Some("baz+z")])) as ArrayRef; diff --git a/datafusion/functions/src/string/contains.rs b/datafusion/functions/src/string/contains.rs index 36871f0c3282..77774cdb5e1d 100644 --- a/datafusion/functions/src/string/contains.rs +++ b/datafusion/functions/src/string/contains.rs @@ -24,7 +24,8 @@ use datafusion_common::exec_err; use datafusion_common::DataFusionError; use datafusion_common::Result; use datafusion_expr::{ - ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, + ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, + Volatility, }; use datafusion_macros::user_doc; use std::any::Any; @@ -81,12 +82,8 @@ impl ScalarUDFImpl for ContainsFunc { Ok(Boolean) } - fn invoke_batch( - &self, - args: &[ColumnarValue], - _number_rows: usize, - ) -> Result { - make_scalar_function(contains, vec![])(args) + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + make_scalar_function(contains, vec![])(&args.args) } fn documentation(&self) -> Option<&Documentation> { @@ -125,8 +122,9 @@ pub fn contains(args: &[ArrayRef]) -> Result { mod test { use super::ContainsFunc; use arrow::array::{BooleanArray, StringArray}; + use arrow::datatypes::DataType; use datafusion_common::ScalarValue; - use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; + use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl}; use std::sync::Arc; #[test] @@ -137,8 +135,14 @@ mod test { Some("yyy?()"), ]))); let scalar = ColumnarValue::Scalar(ScalarValue::Utf8(Some("x?(".to_string()))); - #[allow(deprecated)] // TODO migrate UDF to invoke - let actual = udf.invoke_batch(&[array, scalar], 2).unwrap(); + + let args = ScalarFunctionArgs { + args: vec![array, scalar], + number_rows: 2, + return_type: &DataType::Boolean, + }; + + let actual = udf.invoke_with_args(args).unwrap(); let expect = ColumnarValue::Array(Arc::new(BooleanArray::from(vec![ Some(true), Some(false), diff --git a/datafusion/functions/src/string/ends_with.rs b/datafusion/functions/src/string/ends_with.rs index 0a77ec9ebd2c..5cca79de14ff 100644 --- a/datafusion/functions/src/string/ends_with.rs +++ b/datafusion/functions/src/string/ends_with.rs @@ -24,7 +24,7 @@ use arrow::datatypes::DataType; use crate::utils::make_scalar_function; use 
datafusion_common::{internal_err, Result}; use datafusion_expr::{ColumnarValue, Documentation, Volatility}; -use datafusion_expr::{ScalarUDFImpl, Signature}; +use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl, Signature}; use datafusion_macros::user_doc; #[user_doc( @@ -84,14 +84,10 @@ impl ScalarUDFImpl for EndsWithFunc { Ok(DataType::Boolean) } - fn invoke_batch( - &self, - args: &[ColumnarValue], - _number_rows: usize, - ) -> Result { - match args[0].data_type() { + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + match args.args[0].data_type() { DataType::Utf8View | DataType::Utf8 | DataType::LargeUtf8 => { - make_scalar_function(ends_with, vec![])(args) + make_scalar_function(ends_with, vec![])(&args.args) } other => { internal_err!("Unsupported data type {other:?} for function ends_with. Expected Utf8, LargeUtf8 or Utf8View")? diff --git a/datafusion/functions/src/string/levenshtein.rs b/datafusion/functions/src/string/levenshtein.rs index 57392c114d79..a19fcc5b476c 100644 --- a/datafusion/functions/src/string/levenshtein.rs +++ b/datafusion/functions/src/string/levenshtein.rs @@ -24,9 +24,9 @@ use arrow::datatypes::DataType; use crate::utils::{make_scalar_function, utf8_to_int_type}; use datafusion_common::cast::{as_generic_string_array, as_string_view_array}; use datafusion_common::utils::datafusion_strsim; -use datafusion_common::{exec_err, Result}; +use datafusion_common::{exec_err, utils::take_function_args, Result}; use datafusion_expr::{ColumnarValue, Documentation}; -use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; +use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility}; use datafusion_macros::user_doc; #[user_doc( @@ -86,16 +86,14 @@ impl ScalarUDFImpl for LevenshteinFunc { utf8_to_int_type(&arg_types[0], "levenshtein") } - fn invoke_batch( - &self, - args: &[ColumnarValue], - _number_rows: usize, - ) -> Result { - match args[0].data_type() { + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + match args.args[0].data_type() { DataType::Utf8View | DataType::Utf8 => { - make_scalar_function(levenshtein::, vec![])(args) + make_scalar_function(levenshtein::, vec![])(&args.args) + } + DataType::LargeUtf8 => { + make_scalar_function(levenshtein::, vec![])(&args.args) } - DataType::LargeUtf8 => make_scalar_function(levenshtein::, vec![])(args), other => { exec_err!("Unsupported data type {other:?} for function levenshtein") } @@ -110,17 +108,12 @@ impl ScalarUDFImpl for LevenshteinFunc { ///Returns the Levenshtein distance between the two given strings. 
/// LEVENSHTEIN('kitten', 'sitting') = 3 pub fn levenshtein(args: &[ArrayRef]) -> Result { - if args.len() != 2 { - return exec_err!( - "levenshtein function requires two arguments, got {}", - args.len() - ); - } + let [str1, str2] = take_function_args("levenshtein", args)?; - match args[0].data_type() { + match str1.data_type() { DataType::Utf8View => { - let str1_array = as_string_view_array(&args[0])?; - let str2_array = as_string_view_array(&args[1])?; + let str1_array = as_string_view_array(&str1)?; + let str2_array = as_string_view_array(&str2)?; let result = str1_array .iter() .zip(str2_array.iter()) @@ -134,8 +127,8 @@ pub fn levenshtein(args: &[ArrayRef]) -> Result { Ok(Arc::new(result) as ArrayRef) } DataType::Utf8 => { - let str1_array = as_generic_string_array::(&args[0])?; - let str2_array = as_generic_string_array::(&args[1])?; + let str1_array = as_generic_string_array::(&str1)?; + let str2_array = as_generic_string_array::(&str2)?; let result = str1_array .iter() .zip(str2_array.iter()) @@ -149,8 +142,8 @@ pub fn levenshtein(args: &[ArrayRef]) -> Result { Ok(Arc::new(result) as ArrayRef) } DataType::LargeUtf8 => { - let str1_array = as_generic_string_array::(&args[0])?; - let str2_array = as_generic_string_array::(&args[1])?; + let str1_array = as_generic_string_array::(&str1)?; + let str2_array = as_generic_string_array::(&str2)?; let result = str1_array .iter() .zip(str2_array.iter()) diff --git a/datafusion/functions/src/string/lower.rs b/datafusion/functions/src/string/lower.rs index e90c3804b1ee..375717e23d6d 100644 --- a/datafusion/functions/src/string/lower.rs +++ b/datafusion/functions/src/string/lower.rs @@ -22,7 +22,7 @@ use crate::string::common::to_lower; use crate::utils::utf8_to_str_type; use datafusion_common::Result; use datafusion_expr::{ColumnarValue, Documentation}; -use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; +use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility}; use datafusion_macros::user_doc; #[user_doc( @@ -77,12 +77,8 @@ impl ScalarUDFImpl for LowerFunc { utf8_to_str_type(&arg_types[0], "lower") } - fn invoke_batch( - &self, - args: &[ColumnarValue], - _number_rows: usize, - ) -> Result { - to_lower(args, "lower") + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + to_lower(&args.args, "lower") } fn documentation(&self) -> Option<&Documentation> { @@ -98,10 +94,14 @@ mod tests { fn to_lower(input: ArrayRef, expected: ArrayRef) -> Result<()> { let func = LowerFunc::new(); - let batch_len = input.len(); - let args = vec![ColumnarValue::Array(input)]; - #[allow(deprecated)] // TODO migrate UDF to invoke - let result = match func.invoke_batch(&args, batch_len)? { + + let args = ScalarFunctionArgs { + number_rows: input.len(), + args: vec![ColumnarValue::Array(input)], + return_type: &DataType::Utf8, + }; + + let result = match func.invoke_with_args(args)? 
{ ColumnarValue::Array(result) => result, _ => unreachable!("lower"), }; diff --git a/datafusion/functions/src/string/ltrim.rs b/datafusion/functions/src/string/ltrim.rs index 0bc62ee5000d..75c4ff25b7df 100644 --- a/datafusion/functions/src/string/ltrim.rs +++ b/datafusion/functions/src/string/ltrim.rs @@ -24,7 +24,7 @@ use crate::utils::{make_scalar_function, utf8_to_str_type}; use datafusion_common::{exec_err, Result}; use datafusion_expr::function::Hint; use datafusion_expr::{ColumnarValue, Documentation, TypeSignature, Volatility}; -use datafusion_expr::{ScalarUDFImpl, Signature}; +use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl, Signature}; use datafusion_macros::user_doc; /// Returns the longest string with leading characters removed. If the characters are not specified, whitespace is removed. @@ -104,20 +104,16 @@ impl ScalarUDFImpl for LtrimFunc { } } - fn invoke_batch( - &self, - args: &[ColumnarValue], - _number_rows: usize, - ) -> Result { - match args[0].data_type() { + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + match args.args[0].data_type() { DataType::Utf8 | DataType::Utf8View => make_scalar_function( ltrim::, vec![Hint::Pad, Hint::AcceptsSingular], - )(args), + )(&args.args), DataType::LargeUtf8 => make_scalar_function( ltrim::, vec![Hint::Pad, Hint::AcceptsSingular], - )(args), + )(&args.args), other => exec_err!( "Unsupported data type {other:?} for function ltrim,\ expected Utf8, LargeUtf8 or Utf8View." diff --git a/datafusion/functions/src/string/octet_length.rs b/datafusion/functions/src/string/octet_length.rs index f443571112e7..46175c96cdc6 100644 --- a/datafusion/functions/src/string/octet_length.rs +++ b/datafusion/functions/src/string/octet_length.rs @@ -20,9 +20,9 @@ use arrow::datatypes::DataType; use std::any::Any; use crate::utils::utf8_to_int_type; -use datafusion_common::{exec_err, Result, ScalarValue}; +use datafusion_common::{utils::take_function_args, Result, ScalarValue}; use datafusion_expr::{ColumnarValue, Documentation, Volatility}; -use datafusion_expr::{ScalarUDFImpl, Signature}; +use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl, Signature}; use datafusion_macros::user_doc; #[user_doc( @@ -77,19 +77,10 @@ impl ScalarUDFImpl for OctetLengthFunc { utf8_to_int_type(&arg_types[0], "octet_length") } - fn invoke_batch( - &self, - args: &[ColumnarValue], - _number_rows: usize, - ) -> Result { - if args.len() != 1 { - return exec_err!( - "octet_length function requires 1 argument, got {}", - args.len() - ); - } + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + let [array] = take_function_args(self.name(), &args.args)?; - match &args[0] { + match array { ColumnarValue::Array(v) => Ok(ColumnarValue::Array(length(v.as_ref())?)), ColumnarValue::Scalar(v) => match v { ScalarValue::Utf8(v) => Ok(ColumnarValue::Scalar(ScalarValue::Int32( diff --git a/datafusion/functions/src/string/overlay.rs b/datafusion/functions/src/string/overlay.rs index 3389da0968f7..0ea5359e9621 100644 --- a/datafusion/functions/src/string/overlay.rs +++ b/datafusion/functions/src/string/overlay.rs @@ -27,7 +27,7 @@ use datafusion_common::cast::{ }; use datafusion_common::{exec_err, Result}; use datafusion_expr::{ColumnarValue, Documentation, TypeSignature, Volatility}; -use datafusion_expr::{ScalarUDFImpl, Signature}; +use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl, Signature}; use datafusion_macros::user_doc; #[user_doc( @@ -100,16 +100,14 @@ impl ScalarUDFImpl for OverlayFunc { utf8_to_str_type(&arg_types[0], "overlay") } 
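Several hunks above (`bit_length`, `levenshtein`, `octet_length`) also replace hand-rolled argument-count checks with `datafusion_common::utils::take_function_args`, which destructures the argument slice into a fixed-size array or returns a uniform "wrong number of arguments" error. A small sketch under that assumption; `string_length` and its use of arrow's `length` kernel are illustrative and not part of this PR.

```rust
use arrow::array::ArrayRef;
use arrow::compute::kernels::length::length;
use datafusion_common::{utils::take_function_args, Result};

/// Hypothetical single-argument kernel wrapper.
fn string_length(args: &[ArrayRef]) -> Result<ArrayRef> {
    // Replaces the manual `if args.len() != 1 { return exec_err!(...) }` guard;
    // the destructuring pattern fixes the expected arity at compile time.
    let [array] = take_function_args("string_length", args)?;
    Ok(length(array.as_ref())?)
}
```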
- fn invoke_batch( - &self, - args: &[ColumnarValue], - _number_rows: usize, - ) -> Result { - match args[0].data_type() { + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + match args.args[0].data_type() { DataType::Utf8View | DataType::Utf8 => { - make_scalar_function(overlay::, vec![])(args) + make_scalar_function(overlay::, vec![])(&args.args) + } + DataType::LargeUtf8 => { + make_scalar_function(overlay::, vec![])(&args.args) } - DataType::LargeUtf8 => make_scalar_function(overlay::, vec![])(args), other => exec_err!("Unsupported data type {other:?} for function overlay"), } } diff --git a/datafusion/functions/src/string/repeat.rs b/datafusion/functions/src/string/repeat.rs index eea9af2ba749..2d36cb8356a0 100644 --- a/datafusion/functions/src/string/repeat.rs +++ b/datafusion/functions/src/string/repeat.rs @@ -26,11 +26,11 @@ use arrow::array::{ use arrow::datatypes::DataType; use arrow::datatypes::DataType::{LargeUtf8, Utf8, Utf8View}; use datafusion_common::cast::as_int64_array; -use datafusion_common::types::{logical_int64, logical_string}; -use datafusion_common::{exec_err, Result}; +use datafusion_common::types::{logical_int64, logical_string, NativeType}; +use datafusion_common::{exec_err, DataFusionError, Result}; use datafusion_expr::{ColumnarValue, Documentation, Volatility}; -use datafusion_expr::{ScalarUDFImpl, Signature}; -use datafusion_expr_common::signature::TypeSignatureClass; +use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl, Signature}; +use datafusion_expr_common::signature::{Coercion, TypeSignatureClass}; use datafusion_macros::user_doc; #[user_doc( @@ -67,8 +67,13 @@ impl RepeatFunc { Self { signature: Signature::coercible( vec![ - TypeSignatureClass::Native(logical_string()), - TypeSignatureClass::Native(logical_int64()), + Coercion::new_exact(TypeSignatureClass::Native(logical_string())), + // Accept all integer types but cast them to i64 + Coercion::new_implicit( + TypeSignatureClass::Native(logical_int64()), + vec![TypeSignatureClass::Integer], + NativeType::Int64, + ), ], Volatility::Immutable, ), @@ -93,12 +98,8 @@ impl ScalarUDFImpl for RepeatFunc { utf8_to_str_type(&arg_types[0], "repeat") } - fn invoke_batch( - &self, - args: &[ColumnarValue], - _number_rows: usize, - ) -> Result { - make_scalar_function(repeat, vec![])(args) + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + make_scalar_function(repeat, vec![])(&args.args) } fn documentation(&self) -> Option<&Documentation> { @@ -113,15 +114,27 @@ fn repeat(args: &[ArrayRef]) -> Result { match args[0].data_type() { Utf8View => { let string_view_array = args[0].as_string_view(); - repeat_impl::(string_view_array, number_array) + repeat_impl::( + string_view_array, + number_array, + i32::MAX as usize, + ) } Utf8 => { let string_array = args[0].as_string::(); - repeat_impl::>(string_array, number_array) + repeat_impl::>( + string_array, + number_array, + i32::MAX as usize, + ) } LargeUtf8 => { let string_array = args[0].as_string::(); - repeat_impl::>(string_array, number_array) + repeat_impl::>( + string_array, + number_array, + i64::MAX as usize, + ) } other => exec_err!( "Unsupported data type {other:?} for function repeat. 
\ @@ -130,22 +143,51 @@ fn repeat(args: &[ArrayRef]) -> Result { } } -fn repeat_impl<'a, T, S>(string_array: S, number_array: &Int64Array) -> Result +fn repeat_impl<'a, T, S>( + string_array: S, + number_array: &Int64Array, + max_str_len: usize, +) -> Result where T: OffsetSizeTrait, S: StringArrayType<'a>, { - let mut builder: GenericStringBuilder = GenericStringBuilder::new(); - string_array - .iter() - .zip(number_array.iter()) - .for_each(|(string, number)| match (string, number) { - (Some(string), Some(number)) if number >= 0 => { - builder.append_value(string.repeat(number as usize)) + let mut total_capacity = 0; + string_array.iter().zip(number_array.iter()).try_for_each( + |(string, number)| -> Result<(), DataFusionError> { + match (string, number) { + (Some(string), Some(number)) if number >= 0 => { + let item_capacity = string.len() * number as usize; + if item_capacity > max_str_len { + return exec_err!( + "string size overflow on repeat, max size is {}, but got {}", + max_str_len, + number as usize * string.len() + ); + } + total_capacity += item_capacity; + } + _ => (), } - (Some(_), Some(_)) => builder.append_value(""), - _ => builder.append_null(), - }); + Ok(()) + }, + )?; + + let mut builder = + GenericStringBuilder::::with_capacity(string_array.len(), total_capacity); + + string_array.iter().zip(number_array.iter()).try_for_each( + |(string, number)| -> Result<(), DataFusionError> { + match (string, number) { + (Some(string), Some(number)) if number >= 0 => { + builder.append_value(string.repeat(number as usize)); + } + (Some(_), Some(_)) => builder.append_value(""), + _ => builder.append_null(), + } + Ok(()) + }, + )?; let array = builder.finish(); Ok(Arc::new(array) as ArrayRef) @@ -156,8 +198,8 @@ mod tests { use arrow::array::{Array, StringArray}; use arrow::datatypes::DataType::Utf8; - use datafusion_common::Result; use datafusion_common::ScalarValue; + use datafusion_common::{exec_err, Result}; use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; use crate::string::repeat::RepeatFunc; @@ -232,6 +274,21 @@ mod tests { Utf8, StringArray ); + test_function!( + RepeatFunc::new(), + vec![ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("Pg")))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(1073741824))), + ], + exec_err!( + "string size overflow on repeat, max size is {}, but got {}", + i32::MAX, + 2usize * 1073741824 + ), + &str, + Utf8, + StringArray + ); Ok(()) } diff --git a/datafusion/functions/src/string/replace.rs b/datafusion/functions/src/string/replace.rs index 9b6afc546994..a3488b561fd2 100644 --- a/datafusion/functions/src/string/replace.rs +++ b/datafusion/functions/src/string/replace.rs @@ -25,7 +25,7 @@ use crate::utils::{make_scalar_function, utf8_to_str_type}; use datafusion_common::cast::{as_generic_string_array, as_string_view_array}; use datafusion_common::{exec_err, Result}; use datafusion_expr::{ColumnarValue, Documentation, Volatility}; -use datafusion_expr::{ScalarUDFImpl, Signature}; +use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl, Signature}; use datafusion_macros::user_doc; #[user_doc( doc_section(label = "String Functions"), @@ -82,15 +82,13 @@ impl ScalarUDFImpl for ReplaceFunc { utf8_to_str_type(&arg_types[0], "replace") } - fn invoke_batch( - &self, - args: &[ColumnarValue], - _number_rows: usize, - ) -> Result { - match args[0].data_type() { - DataType::Utf8 => make_scalar_function(replace::, vec![])(args), - DataType::LargeUtf8 => make_scalar_function(replace::, vec![])(args), - DataType::Utf8View => 
make_scalar_function(replace_view, vec![])(args), + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + match args.args[0].data_type() { + DataType::Utf8 => make_scalar_function(replace::, vec![])(&args.args), + DataType::LargeUtf8 => { + make_scalar_function(replace::, vec![])(&args.args) + } + DataType::Utf8View => make_scalar_function(replace_view, vec![])(&args.args), other => { exec_err!("Unsupported data type {other:?} for function replace") } diff --git a/datafusion/functions/src/string/rtrim.rs b/datafusion/functions/src/string/rtrim.rs index 3fb208bb7198..71c4286150e5 100644 --- a/datafusion/functions/src/string/rtrim.rs +++ b/datafusion/functions/src/string/rtrim.rs @@ -24,7 +24,7 @@ use crate::utils::{make_scalar_function, utf8_to_str_type}; use datafusion_common::{exec_err, Result}; use datafusion_expr::function::Hint; use datafusion_expr::{ColumnarValue, Documentation, TypeSignature, Volatility}; -use datafusion_expr::{ScalarUDFImpl, Signature}; +use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl, Signature}; use datafusion_macros::user_doc; /// Returns the longest string with trailing characters removed. If the characters are not specified, whitespace is removed. @@ -104,20 +104,16 @@ impl ScalarUDFImpl for RtrimFunc { } } - fn invoke_batch( - &self, - args: &[ColumnarValue], - _number_rows: usize, - ) -> Result { - match args[0].data_type() { + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + match args.args[0].data_type() { DataType::Utf8 | DataType::Utf8View => make_scalar_function( rtrim::, vec![Hint::Pad, Hint::AcceptsSingular], - )(args), + )(&args.args), DataType::LargeUtf8 => make_scalar_function( rtrim::, vec![Hint::Pad, Hint::AcceptsSingular], - )(args), + )(&args.args), other => exec_err!( "Unsupported data type {other:?} for function rtrim,\ expected Utf8, LargeUtf8 or Utf8View." diff --git a/datafusion/functions/src/string/split_part.rs b/datafusion/functions/src/string/split_part.rs index a597e1be5d02..724d9c278cca 100644 --- a/datafusion/functions/src/string/split_part.rs +++ b/datafusion/functions/src/string/split_part.rs @@ -26,7 +26,7 @@ use datafusion_common::cast::as_int64_array; use datafusion_common::ScalarValue; use datafusion_common::{exec_err, DataFusionError, Result}; use datafusion_expr::{ColumnarValue, Documentation, TypeSignature, Volatility}; -use datafusion_expr::{ScalarUDFImpl, Signature}; +use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl, Signature}; use datafusion_macros::user_doc; use std::any::Any; use std::sync::Arc; @@ -97,11 +97,9 @@ impl ScalarUDFImpl for SplitPartFunc { utf8_to_str_type(&arg_types[0], "split_part") } - fn invoke_batch( - &self, - args: &[ColumnarValue], - _number_rows: usize, - ) -> Result { + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + let ScalarFunctionArgs { args, .. 
} = args; + // First, determine if any of the arguments is an Array let len = args.iter().find_map(|arg| match arg { ColumnarValue::Array(a) => Some(a.len()), diff --git a/datafusion/functions/src/string/starts_with.rs b/datafusion/functions/src/string/starts_with.rs index 74d0fbdc4033..f1344780eb4c 100644 --- a/datafusion/functions/src/string/starts_with.rs +++ b/datafusion/functions/src/string/starts_with.rs @@ -25,7 +25,7 @@ use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo}; use crate::utils::make_scalar_function; use datafusion_common::{internal_err, Result, ScalarValue}; use datafusion_expr::{ColumnarValue, Documentation, Expr, Like}; -use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; +use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility}; use datafusion_macros::user_doc; /// Returns true if string starts with prefix. @@ -86,14 +86,10 @@ impl ScalarUDFImpl for StartsWithFunc { Ok(DataType::Boolean) } - fn invoke_batch( - &self, - args: &[ColumnarValue], - _number_rows: usize, - ) -> Result { - match args[0].data_type() { + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + match args.args[0].data_type() { DataType::Utf8View | DataType::Utf8 | DataType::LargeUtf8 => { - make_scalar_function(starts_with, vec![])(args) + make_scalar_function(starts_with, vec![])(&args.args) } _ => internal_err!("Unsupported data types for starts_with. Expected Utf8, LargeUtf8 or Utf8View")?, } diff --git a/datafusion/functions/src/string/to_hex.rs b/datafusion/functions/src/string/to_hex.rs index 64654ef6ef10..a3a1acfcf1f0 100644 --- a/datafusion/functions/src/string/to_hex.rs +++ b/datafusion/functions/src/string/to_hex.rs @@ -16,9 +16,10 @@ // under the License. use std::any::Any; +use std::fmt::Write; use std::sync::Arc; -use arrow::array::{ArrayRef, GenericStringArray, OffsetSizeTrait}; +use arrow::array::{ArrayRef, GenericStringBuilder, OffsetSizeTrait}; use arrow::datatypes::{ ArrowNativeType, ArrowPrimitiveType, DataType, Int32Type, Int64Type, }; @@ -29,7 +30,7 @@ use datafusion_common::Result; use datafusion_common::{exec_err, plan_err}; use datafusion_expr::{ColumnarValue, Documentation}; -use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; +use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility}; use datafusion_macros::user_doc; /// Converts the number to its equivalent hexadecimal representation. 
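The `chr` and `repeat` rewrites above, and the `to_hex` hunk that follows, share one idea: pre-size a `GenericStringBuilder` and write each row directly into it instead of collecting a temporary `String` per row. A standalone sketch of that pattern is below; the capacity estimate and the function name are illustrative, and it assumes (as the `to_hex` hunk does) that `?` on `write!` converts the formatting error into a `DataFusionError`.

```rust
use std::fmt::Write;
use std::sync::Arc;

use arrow::array::{ArrayRef, GenericStringBuilder, Int64Array};
use datafusion_common::Result;

/// Illustrative hex formatter over an Int64 column.
fn to_hex_sketch(input: &Int64Array) -> Result<ArrayRef> {
    // Capacity estimate: at most 16 hex characters per i64 value.
    let mut builder =
        GenericStringBuilder::<i32>::with_capacity(input.len(), input.len() * 16);

    for value in input {
        match value {
            Some(v) => {
                // `write!` appends to the in-progress row...
                write!(builder, "{v:x}")?;
                // ...and `append_value("")` terminates that row.
                builder.append_value("");
            }
            None => builder.append_null(),
        }
    }

    Ok(Arc::new(builder.finish()) as ArrayRef)
}
```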
@@ -40,22 +41,30 @@ where { let integer_array = as_primitive_array::(&args[0])?; - let result = integer_array - .iter() - .map(|integer| { - if let Some(value) = integer { - if let Some(value_usize) = value.to_usize() { - Ok(Some(format!("{value_usize:x}"))) - } else if let Some(value_isize) = value.to_isize() { - Ok(Some(format!("{value_isize:x}"))) - } else { - exec_err!("Unsupported data type {integer:?} for function to_hex") - } + let mut result = GenericStringBuilder::::with_capacity( + integer_array.len(), + // * 8 to convert to bits, / 4 bits per hex char + integer_array.len() * (T::Native::get_byte_width() * 8 / 4), + ); + + for integer in integer_array { + if let Some(value) = integer { + if let Some(value_usize) = value.to_usize() { + write!(result, "{value_usize:x}")?; + } else if let Some(value_isize) = value.to_isize() { + write!(result, "{value_isize:x}")?; } else { - Ok(None) + return exec_err!( + "Unsupported data type {integer:?} for function to_hex" + ); } - }) - .collect::>>()?; + result.append_value(""); + } else { + result.append_null(); + } + } + + let result = result.finish(); Ok(Arc::new(result) as ArrayRef) } @@ -118,14 +127,14 @@ impl ScalarUDFImpl for ToHexFunc { }) } - fn invoke_batch( - &self, - args: &[ColumnarValue], - _number_rows: usize, - ) -> Result { - match args[0].data_type() { - DataType::Int32 => make_scalar_function(to_hex::, vec![])(args), - DataType::Int64 => make_scalar_function(to_hex::, vec![])(args), + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + match args.args[0].data_type() { + DataType::Int32 => { + make_scalar_function(to_hex::, vec![])(&args.args) + } + DataType::Int64 => { + make_scalar_function(to_hex::, vec![])(&args.args) + } other => exec_err!("Unsupported data type {other:?} for function to_hex"), } } diff --git a/datafusion/functions/src/string/upper.rs b/datafusion/functions/src/string/upper.rs index 7bab33e68a4d..d27b54d29bc6 100644 --- a/datafusion/functions/src/string/upper.rs +++ b/datafusion/functions/src/string/upper.rs @@ -20,7 +20,7 @@ use crate::utils::utf8_to_str_type; use arrow::datatypes::DataType; use datafusion_common::Result; use datafusion_expr::{ColumnarValue, Documentation}; -use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; +use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility}; use datafusion_macros::user_doc; use std::any::Any; @@ -76,12 +76,8 @@ impl ScalarUDFImpl for UpperFunc { utf8_to_str_type(&arg_types[0], "upper") } - fn invoke_batch( - &self, - args: &[ColumnarValue], - _number_rows: usize, - ) -> Result { - to_upper(args, "upper") + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + to_upper(&args.args, "upper") } fn documentation(&self) -> Option<&Documentation> { @@ -97,10 +93,14 @@ mod tests { fn to_upper(input: ArrayRef, expected: ArrayRef) -> Result<()> { let func = UpperFunc::new(); - let batch_len = input.len(); - let args = vec![ColumnarValue::Array(input)]; - #[allow(deprecated)] // TODO migrate UDF to invoke - let result = match func.invoke_batch(&args, batch_len)? { + + let args = ScalarFunctionArgs { + number_rows: input.len(), + args: vec![ColumnarValue::Array(input)], + return_type: &DataType::Utf8, + }; + + let result = match func.invoke_with_args(args)? 
{ ColumnarValue::Array(result) => result, _ => unreachable!("upper"), }; diff --git a/datafusion/functions/src/string/uuid.rs b/datafusion/functions/src/string/uuid.rs index f6d6a941068d..d1f43d548066 100644 --- a/datafusion/functions/src/string/uuid.rs +++ b/datafusion/functions/src/string/uuid.rs @@ -18,14 +18,15 @@ use std::any::Any; use std::sync::Arc; -use arrow::array::GenericStringArray; +use arrow::array::GenericStringBuilder; use arrow::datatypes::DataType; use arrow::datatypes::DataType::Utf8; +use rand::Rng; use uuid::Uuid; use datafusion_common::{internal_err, Result}; use datafusion_expr::{ColumnarValue, Documentation, Volatility}; -use datafusion_expr::{ScalarUDFImpl, Signature}; +use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl, Signature}; use datafusion_macros::user_doc; #[user_doc( @@ -79,17 +80,31 @@ impl ScalarUDFImpl for UuidFunc { /// Prints random (v4) uuid values per row /// uuid() = 'a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11' - fn invoke_batch( - &self, - args: &[ColumnarValue], - num_rows: usize, - ) -> Result { - if !args.is_empty() { + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + if !args.args.is_empty() { return internal_err!("{} function does not accept arguments", self.name()); } - let values = std::iter::repeat_with(|| Uuid::new_v4().to_string()).take(num_rows); - let array = GenericStringArray::::from_iter_values(values); - Ok(ColumnarValue::Array(Arc::new(array))) + + // Generate random u128 values + let mut rng = rand::thread_rng(); + let mut randoms = vec![0u128; args.number_rows]; + rng.fill(&mut randoms[..]); + + let mut builder = GenericStringBuilder::::with_capacity( + args.number_rows, + args.number_rows * 36, + ); + + let mut buffer = [0u8; 36]; + for x in &mut randoms { + // From Uuid::new_v4(): Mask out the version and variant bits + *x = *x & 0xFFFFFFFFFFFF4FFFBFFFFFFFFFFFFFFF | 0x40008000000000000000; + let uuid = Uuid::from_u128(*x); + let fmt = uuid::fmt::Hyphenated::from_uuid(uuid); + builder.append_value(fmt.encode_lower(&mut buffer)); + } + + Ok(ColumnarValue::Array(Arc::new(builder.finish()))) } fn documentation(&self) -> Option<&Documentation> { diff --git a/datafusion/functions/src/unicode/find_in_set.rs b/datafusion/functions/src/unicode/find_in_set.rs index 12f213a827cf..c4a9f067e9f4 100644 --- a/datafusion/functions/src/unicode/find_in_set.rs +++ b/datafusion/functions/src/unicode/find_in_set.rs @@ -25,7 +25,9 @@ use arrow::array::{ use arrow::datatypes::{ArrowNativeType, DataType, Int32Type, Int64Type}; use crate::utils::utf8_to_int_type; -use datafusion_common::{exec_err, internal_err, Result, ScalarValue}; +use datafusion_common::{ + exec_err, internal_err, utils::take_function_args, Result, ScalarValue, +}; use datafusion_expr::TypeSignature::Exact; use datafusion_expr::{ ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, @@ -96,17 +98,9 @@ impl ScalarUDFImpl for FindInSetFunc { } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { - let ScalarFunctionArgs { mut args, .. } = args; - - if args.len() != 2 { - return exec_err!( - "find_in_set was called with {} arguments. It requires 2.", - args.len() - ); - } + let ScalarFunctionArgs { args, .. 
} = args; - let str_list = args.pop().unwrap(); - let string = args.pop().unwrap(); + let [string, str_list] = take_function_args(self.name(), args)?; match (string, str_list) { // both inputs are scalars diff --git a/datafusion/functions/src/unicode/substrindex.rs b/datafusion/functions/src/unicode/substrindex.rs index 60ccd2204788..20ad33b3cfe3 100644 --- a/datafusion/functions/src/unicode/substrindex.rs +++ b/datafusion/functions/src/unicode/substrindex.rs @@ -25,7 +25,7 @@ use arrow::array::{ use arrow::datatypes::{DataType, Int32Type, Int64Type}; use crate::utils::{make_scalar_function, utf8_to_str_type}; -use datafusion_common::{exec_err, Result}; +use datafusion_common::{exec_err, utils::take_function_args, Result}; use datafusion_expr::TypeSignature::Exact; use datafusion_expr::{ ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, @@ -131,18 +131,13 @@ impl ScalarUDFImpl for SubstrIndexFunc { /// SUBSTRING_INDEX('www.apache.org', '.', -2) = apache.org /// SUBSTRING_INDEX('www.apache.org', '.', -1) = org fn substr_index(args: &[ArrayRef]) -> Result { - if args.len() != 3 { - return exec_err!( - "substr_index was called with {} arguments. It requires 3.", - args.len() - ); - } + let [str, delim, count] = take_function_args("substr_index", args)?; - match args[0].data_type() { + match str.data_type() { DataType::Utf8 => { - let string_array = args[0].as_string::(); - let delimiter_array = args[1].as_string::(); - let count_array: &PrimitiveArray = args[2].as_primitive(); + let string_array = str.as_string::(); + let delimiter_array = delim.as_string::(); + let count_array: &PrimitiveArray = count.as_primitive(); substr_index_general::( string_array, delimiter_array, @@ -150,9 +145,9 @@ fn substr_index(args: &[ArrayRef]) -> Result { ) } DataType::LargeUtf8 => { - let string_array = args[0].as_string::(); - let delimiter_array = args[1].as_string::(); - let count_array: &PrimitiveArray = args[2].as_primitive(); + let string_array = str.as_string::(); + let delimiter_array = delim.as_string::(); + let count_array: &PrimitiveArray = count.as_primitive(); substr_index_general::( string_array, delimiter_array, @@ -160,9 +155,9 @@ fn substr_index(args: &[ArrayRef]) -> Result { ) } DataType::Utf8View => { - let string_array = args[0].as_string_view(); - let delimiter_array = args[1].as_string_view(); - let count_array: &PrimitiveArray = args[2].as_primitive(); + let string_array = str.as_string_view(); + let delimiter_array = delim.as_string_view(); + let count_array: &PrimitiveArray = count.as_primitive(); substr_index_general::( string_array, delimiter_array, diff --git a/datafusion/functions/src/utils.rs b/datafusion/functions/src/utils.rs index 966fd8209a04..39d8aeeda460 100644 --- a/datafusion/functions/src/utils.rs +++ b/datafusion/functions/src/utils.rs @@ -18,51 +18,10 @@ use arrow::array::ArrayRef; use arrow::datatypes::DataType; -use datafusion_common::{exec_datafusion_err, Result, ScalarValue}; +use datafusion_common::{Result, ScalarValue}; use datafusion_expr::function::Hint; use datafusion_expr::ColumnarValue; -/// Converts a collection of function arguments into an fixed-size array of length N -/// producing a reasonable error message in case of unexpected number of arguments. 
-/// -/// # Example -/// ``` -/// # use datafusion_common::ScalarValue; -/// # use datafusion_common::Result; -/// # use datafusion_expr_common::columnar_value::ColumnarValue; -/// # use datafusion_functions::utils::take_function_args; -/// fn my_function(args: &[ColumnarValue]) -> Result<()> { -/// // function expects 2 args, so create a 2-element array -/// let [arg1, arg2] = take_function_args("my_function", args)?; -/// // ... do stuff.. -/// Ok(()) -/// } -/// -/// // Calling the function with 1 argument produces an error: -/// let ten = ColumnarValue::from(ScalarValue::from(10i32)); -/// let twenty = ColumnarValue::from(ScalarValue::from(20i32)); -/// let args = vec![ten.clone()]; -/// let err = my_function(&args).unwrap_err(); -/// assert_eq!(err.to_string(), "Execution error: my_function function requires 2 arguments, got 1"); -/// // Calling the function with 2 arguments works great -/// let args = vec![ten, twenty]; -/// my_function(&args).unwrap(); -/// ``` -pub fn take_function_args( - function_name: &str, - args: impl IntoIterator, -) -> Result<[T; N]> { - let args = args.into_iter().collect::>(); - args.try_into().map_err(|v: Vec| { - exec_datafusion_err!( - "{} function requires {} arguments, got {}", - function_name, - N, - v.len() - ) - }) -} - /// Creates a function to identify the optimal return type of a string function given /// the type of its first argument. /// diff --git a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs index 454afa24b628..f517761b1e33 100644 --- a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs +++ b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs @@ -21,11 +21,11 @@ use crate::utils::NamePreserver; use datafusion_common::config::ConfigOptions; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::Result; -use datafusion_expr::expr::{AggregateFunction, WindowFunction}; +use datafusion_expr::expr::{AggregateFunction, AggregateFunctionParams, WindowFunction}; use datafusion_expr::utils::COUNT_STAR_EXPANSION; use datafusion_expr::{lit, Expr, LogicalPlan, WindowFunctionDefinition}; -/// Rewrite `Count(Expr:Wildcard)` to `Count(Expr:Literal)`. +/// Rewrite `Count(Expr::Wildcard)` to `Count(Expr::Literal)`. /// /// Resolves issue: #[derive(Default, Debug)] @@ -55,13 +55,12 @@ fn is_count_star_aggregate(aggregate_function: &AggregateFunction) -> bool { matches!(aggregate_function, AggregateFunction { func, - args, - .. + params: AggregateFunctionParams { args, .. 
}, } if func.name() == "count" && (args.len() == 1 && is_wildcard(&args[0]) || args.is_empty())) } fn is_count_star_window_aggregate(window_function: &WindowFunction) -> bool { - let args = &window_function.args; + let args = &window_function.params.args; matches!(window_function.fun, WindowFunctionDefinition::AggregateUDF(ref udaf) if udaf.name() == "count" && (args.len() == 1 && is_wildcard(&args[0]) || args.is_empty())) @@ -75,13 +74,13 @@ fn analyze_internal(plan: LogicalPlan) -> Result> { Expr::WindowFunction(mut window_function) if is_count_star_window_aggregate(&window_function) => { - window_function.args = vec![lit(COUNT_STAR_EXPANSION)]; + window_function.params.args = vec![lit(COUNT_STAR_EXPANSION)]; Ok(Transformed::yes(Expr::WindowFunction(window_function))) } Expr::AggregateFunction(mut aggregate_function) if is_count_star_aggregate(&aggregate_function) => { - aggregate_function.args = vec![lit(COUNT_STAR_EXPANSION)]; + aggregate_function.params.args = vec![lit(COUNT_STAR_EXPANSION)]; Ok(Transformed::yes(Expr::AggregateFunction( aggregate_function, ))) diff --git a/datafusion/optimizer/src/analyzer/resolve_grouping_function.rs b/datafusion/optimizer/src/analyzer/resolve_grouping_function.rs index 16ebb8cd3972..f8a818563609 100644 --- a/datafusion/optimizer/src/analyzer/resolve_grouping_function.rs +++ b/datafusion/optimizer/src/analyzer/resolve_grouping_function.rs @@ -163,6 +163,7 @@ fn validate_args( group_by_expr: &HashMap<&Expr, usize>, ) -> Result<()> { let expr_not_in_group_by = function + .params .args .iter() .find(|expr| !group_by_expr.contains_key(expr)); @@ -183,7 +184,7 @@ fn grouping_function_on_id( is_grouping_set: bool, ) -> Result { validate_args(function, group_by_expr)?; - let args = &function.args; + let args = &function.params.args; // Postgres allows grouping function for group by without grouping sets, the result is then // always 0 diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index 85fc9b31bcdd..d1d491cc7a64 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -33,8 +33,8 @@ use datafusion_common::{ DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue, TableReference, }; use datafusion_expr::expr::{ - self, Alias, Between, BinaryExpr, Case, Exists, InList, InSubquery, Like, - ScalarFunction, Sort, WindowFunction, + self, AggregateFunctionParams, Alias, Between, BinaryExpr, Case, Exists, InList, + InSubquery, Like, ScalarFunction, Sort, WindowFunction, }; use datafusion_expr::expr_rewriter::coerce_plan_expr_for_schema; use datafusion_expr::expr_schema::cast_subquery; @@ -506,11 +506,14 @@ impl TreeNodeRewriter for TypeCoercionRewriter<'_> { } Expr::AggregateFunction(expr::AggregateFunction { func, - args, - distinct, - filter, - order_by, - null_treatment, + params: + AggregateFunctionParams { + args, + distinct, + filter, + order_by, + null_treatment, + }, }) => { let new_expr = coerce_arguments_for_signature_with_aggregate_udf( args, @@ -530,11 +533,14 @@ impl TreeNodeRewriter for TypeCoercionRewriter<'_> { } Expr::WindowFunction(WindowFunction { fun, - args, - partition_by, - order_by, - window_frame, - null_treatment, + params: + expr::WindowFunctionParams { + args, + partition_by, + order_by, + window_frame, + null_treatment, + }, }) => { let window_frame = coerce_window_frame(window_frame, self.schema, &order_by)?; @@ -1047,8 +1053,8 @@ mod test { use datafusion_expr::{ cast, col, create_udaf, is_true, 
lit, AccumulatorFactoryFunction, AggregateUDF, BinaryExpr, Case, ColumnarValue, Expr, ExprSchemable, Filter, LogicalPlan, - Operator, ScalarUDF, ScalarUDFImpl, Signature, SimpleAggregateUDF, Subquery, - Volatility, + Operator, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature, + SimpleAggregateUDF, Subquery, Volatility, }; use datafusion_functions_aggregate::average::AvgAccumulator; @@ -1266,11 +1272,7 @@ mod test { Ok(Utf8) } - fn invoke_batch( - &self, - _args: &[ColumnarValue], - _number_rows: usize, - ) -> Result { + fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result { Ok(ColumnarValue::Scalar(ScalarValue::from("a"))) } } diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index 4b9a83fd3e4c..bfa53a5ce852 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -1703,13 +1703,5 @@ mod test { fn return_type(&self, _arg_types: &[DataType]) -> Result { Ok(DataType::Float64) } - - fn invoke_batch( - &self, - _args: &[ColumnarValue], - _number_rows: usize, - ) -> Result { - unimplemented!() - } } } diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index 7cb0e7c2f1f7..c38dd35abd36 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -1003,7 +1003,8 @@ impl OptimizerRule for PushDownFilter { // Therefore, we need to ensure that any potential partition key returned is used in // ALL window functions. Otherwise, filters cannot be pushed by through that column. let extract_partition_keys = |func: &WindowFunction| { - func.partition_by + func.params + .partition_by .iter() .map(|c| Column::from_qualified_name(c.schema_name().to_string())) .collect::>() @@ -1386,8 +1387,9 @@ mod tests { use datafusion_expr::logical_plan::table_scan; use datafusion_expr::{ col, in_list, in_subquery, lit, ColumnarValue, ExprFunctionExt, Extension, - LogicalPlanBuilder, ScalarUDF, ScalarUDFImpl, Signature, TableSource, TableType, - UserDefinedLogicalNodeCore, Volatility, WindowFunctionDefinition, + LogicalPlanBuilder, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature, + TableSource, TableType, UserDefinedLogicalNodeCore, Volatility, + WindowFunctionDefinition, }; use crate::optimizer::Optimizer; @@ -3615,11 +3617,7 @@ Projection: a, b Ok(DataType::Int32) } - fn invoke_batch( - &self, - _args: &[ColumnarValue], - _number_rows: usize, - ) -> Result { + fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result { Ok(ColumnarValue::Scalar(ScalarValue::from(1))) } } diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index 29f3d7cbda39..e43e2e704080 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -44,6 +44,8 @@ use datafusion_expr::{ }; use datafusion_physical_expr::{create_physical_expr, execution_props::ExecutionProps}; +use super::inlist_simplifier::ShortenInListSimplifier; +use super::utils::*; use crate::analyzer::type_coercion::TypeCoercionRewriter; use crate::simplify_expressions::guarantees::GuaranteeRewriter; use crate::simplify_expressions::regex::simplify_regex_expr; @@ -51,9 +53,6 @@ use crate::simplify_expressions::SimplifyInfo; use indexmap::IndexSet; use regex::Regex; -use super::inlist_simplifier::ShortenInListSimplifier; -use 
super::utils::*; - /// This structure handles API for expression simplification /// /// Provides simplification information based on DFSchema and @@ -515,30 +514,27 @@ impl TreeNodeRewriter for ConstEvaluator<'_> { // NB: do not short circuit recursion even if we find a non // evaluatable node (so we can fold other children, args to - // functions, etc) + // functions, etc.) Ok(Transformed::no(expr)) } fn f_up(&mut self, expr: Expr) -> Result> { match self.can_evaluate.pop() { - // Certain expressions such as `CASE` and `COALESCE` are short circuiting - // and may not evaluate all their sub expressions. Thus if - // if any error is countered during simplification, return the original + // Certain expressions such as `CASE` and `COALESCE` are short-circuiting + // and may not evaluate all their sub expressions. Thus, if + // any error is countered during simplification, return the original // so that normal evaluation can occur - Some(true) => { - let result = self.evaluate_to_scalar(expr); - match result { - ConstSimplifyResult::Simplified(s) => { - Ok(Transformed::yes(Expr::Literal(s))) - } - ConstSimplifyResult::NotSimplified(s) => { - Ok(Transformed::no(Expr::Literal(s))) - } - ConstSimplifyResult::SimplifyRuntimeError(_, expr) => { - Ok(Transformed::yes(expr)) - } + Some(true) => match self.evaluate_to_scalar(expr) { + ConstSimplifyResult::Simplified(s) => { + Ok(Transformed::yes(Expr::Literal(s))) } - } + ConstSimplifyResult::NotSimplified(s) => { + Ok(Transformed::no(Expr::Literal(s))) + } + ConstSimplifyResult::SimplifyRuntimeError(_, expr) => { + Ok(Transformed::yes(expr)) + } + }, Some(false) => Ok(Transformed::no(expr)), _ => internal_err!("Failed to pop can_evaluate"), } @@ -586,9 +582,7 @@ impl<'a> ConstEvaluator<'a> { // added they can be checked for their ability to be evaluated // at plan time match expr { - // Has no runtime cost, but needed during planning - Expr::Alias(..) - | Expr::AggregateFunction { .. } + Expr::AggregateFunction { .. } | Expr::ScalarVariable(_, _) | Expr::Column(_) | Expr::OuterReferenceColumn(_, _) @@ -603,6 +597,7 @@ impl<'a> ConstEvaluator<'a> { Self::volatility_ok(func.signature().volatility) } Expr::Literal(_) + | Expr::Alias(..) | Expr::Unnest(_) | Expr::BinaryExpr { .. } | Expr::Not(_) diff --git a/datafusion/optimizer/src/single_distinct_to_groupby.rs b/datafusion/optimizer/src/single_distinct_to_groupby.rs index c8f3a4bc7859..191377fc2759 100644 --- a/datafusion/optimizer/src/single_distinct_to_groupby.rs +++ b/datafusion/optimizer/src/single_distinct_to_groupby.rs @@ -26,6 +26,7 @@ use datafusion_common::{ internal_err, tree_node::Transformed, DataFusionError, HashSet, Result, }; use datafusion_expr::builder::project; +use datafusion_expr::expr::AggregateFunctionParams; use datafusion_expr::{ col, expr::AggregateFunction, @@ -68,11 +69,14 @@ fn is_single_distinct_agg(aggr_expr: &[Expr]) -> Result { for expr in aggr_expr { if let Expr::AggregateFunction(AggregateFunction { func, - distinct, - args, - filter, - order_by, - null_treatment: _, + params: + AggregateFunctionParams { + distinct, + args, + filter, + order_by, + null_treatment: _, + }, }) = expr { if filter.is_some() || order_by.is_some() { @@ -179,9 +183,7 @@ impl OptimizerRule for SingleDistinctToGroupBy { .map(|aggr_expr| match aggr_expr { Expr::AggregateFunction(AggregateFunction { func, - mut args, - distinct, - .. + params: AggregateFunctionParams { mut args, distinct, .. 
} }) => { if distinct { if args.len() != 1 { diff --git a/datafusion/physical-expr/Cargo.toml b/datafusion/physical-expr/Cargo.toml index 33983676472b..38e8b44791ab 100644 --- a/datafusion/physical-expr/Cargo.toml +++ b/datafusion/physical-expr/Cargo.toml @@ -38,7 +38,6 @@ path = "src/lib.rs" [dependencies] ahash = { workspace = true } arrow = { workspace = true } -arrow-schema = { workspace = true } datafusion-common = { workspace = true, default-features = true } datafusion-expr = { workspace = true } datafusion-expr-common = { workspace = true } diff --git a/datafusion/physical-expr/benches/case_when.rs b/datafusion/physical-expr/benches/case_when.rs index 480b1043fbf5..5a88604716d2 100644 --- a/datafusion/physical-expr/benches/case_when.rs +++ b/datafusion/physical-expr/benches/case_when.rs @@ -16,9 +16,8 @@ // under the License. use arrow::array::builder::{Int32Builder, StringBuilder}; -use arrow::datatypes::{Field, Schema}; +use arrow::datatypes::{DataType, Field, Schema}; use arrow::record_batch::RecordBatch; -use arrow_schema::DataType; use criterion::{black_box, criterion_group, criterion_main, Criterion}; use datafusion_common::ScalarValue; use datafusion_expr::Operator; diff --git a/datafusion/physical-expr/benches/is_null.rs b/datafusion/physical-expr/benches/is_null.rs index ed393b8900f1..ce6ad6eac2c7 100644 --- a/datafusion/physical-expr/benches/is_null.rs +++ b/datafusion/physical-expr/benches/is_null.rs @@ -16,8 +16,7 @@ // under the License. use arrow::array::{builder::Int32Builder, RecordBatch}; -use arrow::datatypes::{Field, Schema}; -use arrow_schema::DataType; +use arrow::datatypes::{DataType, Field, Schema}; use criterion::{black_box, criterion_group, criterion_main, Criterion}; use datafusion_physical_expr::expressions::{Column, IsNotNullExpr, IsNullExpr}; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; diff --git a/datafusion/physical-expr/src/aggregate.rs b/datafusion/physical-expr/src/aggregate.rs index 84406f50051f..07a98340dbe7 100644 --- a/datafusion/physical-expr/src/aggregate.rs +++ b/datafusion/physical-expr/src/aggregate.rs @@ -40,8 +40,8 @@ use std::sync::Arc; use crate::expressions::Column; +use arrow::compute::SortOptions; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; -use arrow_schema::SortOptions; use datafusion_common::{internal_err, not_impl_err, Result, ScalarValue}; use datafusion_expr::{AggregateUDF, ReversedUDAF, SetMonotonicity}; use datafusion_expr_common::accumulator::Accumulator; diff --git a/datafusion/physical-expr/src/analysis.rs b/datafusion/physical-expr/src/analysis.rs index ceec21c71171..5abd50f6d1b4 100644 --- a/datafusion/physical-expr/src/analysis.rs +++ b/datafusion/physical-expr/src/analysis.rs @@ -301,7 +301,7 @@ fn calculate_selectivity( mod tests { use std::sync::Arc; - use arrow_schema::{DataType, Field, Schema}; + use arrow::datatypes::{DataType, Field, Schema}; use datafusion_common::{assert_contains, DFSchema}; use datafusion_expr::{ col, execution_props::ExecutionProps, interval_arithmetic::Interval, lit, Expr, diff --git a/datafusion/physical-expr/src/equivalence/class.rs b/datafusion/physical-expr/src/equivalence/class.rs index 5c749a1a5a6e..187aa8a39eb0 100644 --- a/datafusion/physical-expr/src/equivalence/class.rs +++ b/datafusion/physical-expr/src/equivalence/class.rs @@ -584,12 +584,18 @@ impl EquivalenceGroup { .collect::>(); (new_class.len() > 1).then_some(EquivalenceClass::new(new_class)) }); + // the key is the source expression and the value is the EquivalenceClass that contains 
the target expression of the source expression. let mut new_classes: IndexMap, EquivalenceClass> = IndexMap::new(); mapping.iter().for_each(|(source, target)| { + // We need to find equivalent projected expressions. + // e.g. table with columns [a,b,c] and a == b, projection: [a+c, b+c]. + // To conclude that a + c == b + c we firsty normalize all source expressions + // in the mapping, then merge all equivalent expressions into the classes. + let normalized_expr = self.normalize_expr(Arc::clone(source)); new_classes - .entry(Arc::clone(source)) + .entry(normalized_expr) .or_insert_with(EquivalenceClass::new_empty) .push(Arc::clone(target)); }); @@ -749,10 +755,10 @@ impl Display for EquivalenceGroup { #[cfg(test)] mod tests { - use super::*; use crate::equivalence::tests::create_test_params; - use crate::expressions::{lit, BinaryExpr, Literal}; + use crate::expressions::{binary, col, lit, BinaryExpr, Literal}; + use arrow::datatypes::{DataType, Field, Schema}; use datafusion_common::{Result, ScalarValue}; use datafusion_expr::Operator; @@ -1038,4 +1044,57 @@ mod tests { Ok(()) } + + #[test] + fn test_project_classes() -> Result<()> { + // - columns: [a, b, c]. + // - "a" and "b" in the same equivalence class. + // - then after a+c, b+c projection col(0) and col(1) must be + // in the same class too. + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Int32, false), + Field::new("c", DataType::Int32, false), + ])); + let mut group = EquivalenceGroup::empty(); + group.add_equal_conditions(&col("a", &schema)?, &col("b", &schema)?); + + let projected_schema = Arc::new(Schema::new(vec![ + Field::new("a+c", DataType::Int32, false), + Field::new("b+c", DataType::Int32, false), + ])); + + let mapping = ProjectionMapping { + map: vec![ + ( + binary( + col("a", &schema)?, + Operator::Plus, + col("c", &schema)?, + &schema, + )?, + col("a+c", &projected_schema)?, + ), + ( + binary( + col("b", &schema)?, + Operator::Plus, + col("c", &schema)?, + &schema, + )?, + col("b+c", &projected_schema)?, + ), + ], + }; + + let projected = group.project(&mapping); + + assert!(!projected.is_empty()); + let first_normalized = projected.normalize_expr(col("a+c", &projected_schema)?); + let second_normalized = projected.normalize_expr(col("b+c", &projected_schema)?); + + assert!(first_normalized.eq(&second_normalized)); + + Ok(()) + } } diff --git a/datafusion/physical-expr/src/equivalence/mod.rs b/datafusion/physical-expr/src/equivalence/mod.rs index a5b85064e625..fcc1c564d8c8 100644 --- a/datafusion/physical-expr/src/equivalence/mod.rs +++ b/datafusion/physical-expr/src/equivalence/mod.rs @@ -72,8 +72,8 @@ mod tests { use crate::expressions::col; use crate::PhysicalSortExpr; - use arrow::datatypes::{DataType, Field, Schema}; - use arrow_schema::{SchemaRef, SortOptions}; + use arrow::compute::SortOptions; + use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use datafusion_common::{plan_datafusion_err, Result}; use datafusion_physical_expr_common::sort_expr::{ LexOrdering, PhysicalSortRequirement, diff --git a/datafusion/physical-expr/src/equivalence/ordering.rs b/datafusion/physical-expr/src/equivalence/ordering.rs index a72759b5d49a..0f9743aecce3 100644 --- a/datafusion/physical-expr/src/equivalence/ordering.rs +++ b/datafusion/physical-expr/src/equivalence/ordering.rs @@ -22,7 +22,7 @@ use std::vec::IntoIter; use crate::equivalence::add_offset_to_expr; use crate::{LexOrdering, PhysicalExpr}; -use arrow_schema::SortOptions; +use 
arrow::compute::SortOptions; /// An `OrderingEquivalenceClass` object keeps track of different alternative /// orderings than can describe a schema. For example, consider the following table: @@ -279,8 +279,8 @@ mod tests { ScalarFunctionExpr, }; + use arrow::compute::SortOptions; use arrow::datatypes::{DataType, Field, Schema}; - use arrow_schema::SortOptions; use datafusion_common::Result; use datafusion_expr::{Operator, ScalarUDF}; use datafusion_physical_expr_common::sort_expr::LexOrdering; diff --git a/datafusion/physical-expr/src/equivalence/projection.rs b/datafusion/physical-expr/src/equivalence/projection.rs index d1e7625525ae..035678fbf1f3 100644 --- a/datafusion/physical-expr/src/equivalence/projection.rs +++ b/datafusion/physical-expr/src/equivalence/projection.rs @@ -146,8 +146,8 @@ mod tests { use crate::utils::tests::TestScalarUDF; use crate::{PhysicalExprRef, ScalarFunctionExpr}; - use arrow::datatypes::{DataType, Field, Schema}; - use arrow_schema::{SortOptions, TimeUnit}; + use arrow::compute::SortOptions; + use arrow::datatypes::{DataType, Field, Schema, TimeUnit}; use datafusion_expr::{Operator, ScalarUDF}; #[test] diff --git a/datafusion/physical-expr/src/equivalence/properties.rs b/datafusion/physical-expr/src/equivalence/properties.rs index a6417044a061..96208cc5e32c 100755 --- a/datafusion/physical-expr/src/equivalence/properties.rs +++ b/datafusion/physical-expr/src/equivalence/properties.rs @@ -32,7 +32,8 @@ use crate::{ PhysicalExprRef, PhysicalSortExpr, PhysicalSortRequirement, }; -use arrow_schema::{SchemaRef, SortOptions}; +use arrow::compute::SortOptions; +use arrow::datatypes::SchemaRef; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::{ internal_err, plan_err, Constraint, Constraints, HashMap, JoinSide, JoinType, Result, @@ -99,7 +100,7 @@ use itertools::Itertools; /// # Code Example /// ``` /// # use std::sync::Arc; -/// # use arrow_schema::{Schema, Field, DataType, SchemaRef}; +/// # use arrow::datatypes::{Schema, Field, DataType, SchemaRef}; /// # use datafusion_physical_expr::{ConstExpr, EquivalenceProperties}; /// # use datafusion_physical_expr::expressions::col; /// use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; @@ -2403,8 +2404,7 @@ mod tests { use crate::expressions::{col, BinaryExpr, Column}; use crate::ScalarFunctionExpr; - use arrow::datatypes::{DataType, Field, Schema}; - use arrow_schema::{Fields, TimeUnit}; + use arrow::datatypes::{DataType, Field, Fields, Schema, TimeUnit}; use datafusion_common::{Constraint, ScalarValue}; use datafusion_expr::Operator; diff --git a/datafusion/physical-expr/src/expressions/binary.rs b/datafusion/physical-expr/src/expressions/binary.rs index 1713842f410e..052054bad6c1 100644 --- a/datafusion/physical-expr/src/expressions/binary.rs +++ b/datafusion/physical-expr/src/expressions/binary.rs @@ -30,7 +30,7 @@ use arrow::compute::kernels::comparison::{regexp_is_match, regexp_is_match_scala use arrow::compute::kernels::concat_elements::concat_elements_utf8; use arrow::compute::{cast, ilike, like, nilike, nlike}; use arrow::datatypes::*; -use arrow_schema::ArrowError; +use arrow::error::ArrowError; use datafusion_common::cast::as_boolean_array; use datafusion_common::{internal_err, Result, ScalarValue}; use datafusion_expr::binary::BinaryTypeCoercer; diff --git a/datafusion/physical-expr/src/expressions/binary/kernels.rs b/datafusion/physical-expr/src/expressions/binary/kernels.rs index c0685c6decde..ae26f3e84241 100644 --- 
a/datafusion/physical-expr/src/expressions/binary/kernels.rs +++ b/datafusion/physical-expr/src/expressions/binary/kernels.rs @@ -27,7 +27,7 @@ use arrow::datatypes::DataType; use datafusion_common::plan_err; use datafusion_common::{Result, ScalarValue}; -use arrow_schema::ArrowError; +use arrow::error::ArrowError; use std::sync::Arc; /// Downcasts $LEFT and $RIGHT to $ARRAY_TYPE and then calls $KERNEL($LEFT, $RIGHT) diff --git a/datafusion/physical-expr/src/expressions/column.rs b/datafusion/physical-expr/src/expressions/column.rs index 0649cbd65d34..0ec985887c3f 100644 --- a/datafusion/physical-expr/src/expressions/column.rs +++ b/datafusion/physical-expr/src/expressions/column.rs @@ -23,10 +23,9 @@ use std::sync::Arc; use crate::physical_expr::PhysicalExpr; use arrow::{ - datatypes::{DataType, Schema}, + datatypes::{DataType, Schema, SchemaRef}, record_batch::RecordBatch, }; -use arrow_schema::SchemaRef; use datafusion_common::tree_node::{Transformed, TreeNode}; use datafusion_common::{internal_err, plan_err, Result}; use datafusion_expr::ColumnarValue; diff --git a/datafusion/physical-expr/src/expressions/like.rs b/datafusion/physical-expr/src/expressions/like.rs index d61cd63c35b1..b26927b77f1f 100644 --- a/datafusion/physical-expr/src/expressions/like.rs +++ b/datafusion/physical-expr/src/expressions/like.rs @@ -19,8 +19,8 @@ use std::hash::Hash; use std::{any::Any, sync::Arc}; use crate::PhysicalExpr; +use arrow::datatypes::{DataType, Schema}; use arrow::record_batch::RecordBatch; -use arrow_schema::{DataType, Schema}; use datafusion_common::{internal_err, Result}; use datafusion_expr::ColumnarValue; use datafusion_physical_expr_common::datum::apply_cmp; @@ -183,7 +183,7 @@ mod test { use super::*; use crate::expressions::col; use arrow::array::*; - use arrow_schema::Field; + use arrow::datatypes::Field; use datafusion_common::cast::as_boolean_array; macro_rules! 
test_like { diff --git a/datafusion/physical-expr/src/expressions/negative.rs b/datafusion/physical-expr/src/expressions/negative.rs index 03f2111aca33..dc863ccff511 100644 --- a/datafusion/physical-expr/src/expressions/negative.rs +++ b/datafusion/physical-expr/src/expressions/negative.rs @@ -178,8 +178,8 @@ mod tests { use crate::expressions::{col, Column}; use arrow::array::*; + use arrow::datatypes::DataType::{Float32, Float64, Int16, Int32, Int64, Int8}; use arrow::datatypes::*; - use arrow_schema::DataType::{Float32, Float64, Int16, Int32, Int64, Int8}; use datafusion_common::cast::as_primitive_array; use datafusion_common::DataFusionError; diff --git a/datafusion/physical-expr/src/intervals/cp_solver.rs b/datafusion/physical-expr/src/intervals/cp_solver.rs index 166d2564fdf3..cb29109684fe 100644 --- a/datafusion/physical-expr/src/intervals/cp_solver.rs +++ b/datafusion/physical-expr/src/intervals/cp_solver.rs @@ -29,7 +29,7 @@ use crate::expressions::Literal; use crate::utils::{build_dag, ExprTreeNode}; use crate::PhysicalExpr; -use arrow_schema::{DataType, Schema}; +use arrow::datatypes::{DataType, Schema}; use datafusion_common::{internal_err, Result}; use datafusion_expr::interval_arithmetic::{apply_operator, satisfy_greater, Interval}; use datafusion_expr::Operator; @@ -723,8 +723,7 @@ mod tests { use crate::intervals::test_utils::gen_conjunctive_numerical_expr; use arrow::array::types::{IntervalDayTime, IntervalMonthDayNano}; - use arrow::datatypes::TimeUnit; - use arrow_schema::Field; + use arrow::datatypes::{Field, TimeUnit}; use datafusion_common::ScalarValue; use itertools::Itertools; diff --git a/datafusion/physical-expr/src/intervals/test_utils.rs b/datafusion/physical-expr/src/intervals/test_utils.rs index fbd018fb9e80..c3d38a974ab0 100644 --- a/datafusion/physical-expr/src/intervals/test_utils.rs +++ b/datafusion/physical-expr/src/intervals/test_utils.rs @@ -21,7 +21,7 @@ use std::sync::Arc; use crate::expressions::{binary, BinaryExpr, Literal}; use crate::PhysicalExpr; -use arrow_schema::Schema; +use arrow::datatypes::Schema; use datafusion_common::{DataFusionError, ScalarValue}; use datafusion_expr::Operator; diff --git a/datafusion/physical-expr/src/intervals/utils.rs b/datafusion/physical-expr/src/intervals/utils.rs index 56af8238c04e..910631ef4a43 100644 --- a/datafusion/physical-expr/src/intervals/utils.rs +++ b/datafusion/physical-expr/src/intervals/utils.rs @@ -25,7 +25,7 @@ use crate::{ }; use arrow::array::types::{IntervalDayTime, IntervalMonthDayNano}; -use arrow_schema::{DataType, SchemaRef}; +use arrow::datatypes::{DataType, SchemaRef}; use datafusion_common::{internal_err, Result, ScalarValue}; use datafusion_expr::interval_arithmetic::Interval; use datafusion_expr::Operator; diff --git a/datafusion/physical-expr/src/planner.rs b/datafusion/physical-expr/src/planner.rs index 8504705f2a09..fac83dfc4524 100644 --- a/datafusion/physical-expr/src/planner.rs +++ b/datafusion/physical-expr/src/planner.rs @@ -393,7 +393,7 @@ pub fn logical2physical(expr: &Expr, schema: &Schema) -> Arc { #[cfg(test)] mod tests { use arrow::array::{ArrayRef, BooleanArray, RecordBatch, StringArray}; - use arrow_schema::{DataType, Field}; + use arrow::datatypes::{DataType, Field}; use datafusion_expr::{col, lit}; diff --git a/datafusion/physical-expr/src/utils/guarantee.rs b/datafusion/physical-expr/src/utils/guarantee.rs index 7afb78b8bf2e..8092dc3c1a61 100644 --- a/datafusion/physical-expr/src/utils/guarantee.rs +++ b/datafusion/physical-expr/src/utils/guarantee.rs @@ -417,7 +417,7 
@@ mod test { use super::*; use crate::planner::logical2physical; - use arrow_schema::{DataType, Field, Schema, SchemaRef}; + use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use datafusion_expr::expr_fn::*; use datafusion_expr::{lit, Expr}; diff --git a/datafusion/physical-expr/src/utils/mod.rs b/datafusion/physical-expr/src/utils/mod.rs index 25769bef7200..7e4c7f0e10ba 100644 --- a/datafusion/physical-expr/src/utils/mod.rs +++ b/datafusion/physical-expr/src/utils/mod.rs @@ -259,10 +259,12 @@ pub(crate) mod tests { use crate::expressions::{binary, cast, col, in_list, lit, Literal}; use arrow::array::{ArrayRef, Float32Array, Float64Array}; - use arrow_schema::{DataType, Field, Schema}; + use arrow::datatypes::{DataType, Field, Schema}; use datafusion_common::{exec_err, DataFusionError, ScalarValue}; use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; - use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; + use datafusion_expr::{ + ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, + }; use petgraph::visit::Bfs; @@ -309,12 +311,8 @@ pub(crate) mod tests { Ok(input[0].sort_properties) } - fn invoke_batch( - &self, - args: &[ColumnarValue], - _number_rows: usize, - ) -> Result { - let args = ColumnarValue::values_to_arrays(args)?; + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + let args = ColumnarValue::values_to_arrays(&args.args)?; let arr: ArrayRef = match args[0].data_type() { DataType::Float64 => Arc::new({ diff --git a/datafusion/physical-expr/src/window/standard_window_function_expr.rs b/datafusion/physical-expr/src/window/standard_window_function_expr.rs index d308812a0e35..624b747d93f9 100644 --- a/datafusion/physical-expr/src/window/standard_window_function_expr.rs +++ b/datafusion/physical-expr/src/window/standard_window_function_expr.rs @@ -18,9 +18,8 @@ use crate::{PhysicalExpr, PhysicalSortExpr}; use arrow::array::ArrayRef; -use arrow::datatypes::Field; +use arrow::datatypes::{Field, SchemaRef}; use arrow::record_batch::RecordBatch; -use arrow_schema::SchemaRef; use datafusion_common::Result; use datafusion_expr::PartitionEvaluator; diff --git a/datafusion/physical-expr/src/window/window_expr.rs b/datafusion/physical-expr/src/window/window_expr.rs index be7d080b683c..793f2e5ee586 100644 --- a/datafusion/physical-expr/src/window/window_expr.rs +++ b/datafusion/physical-expr/src/window/window_expr.rs @@ -553,7 +553,7 @@ mod tests { use crate::window::window_expr::is_row_ahead; use arrow::array::{ArrayRef, Float64Array}; - use arrow_schema::SortOptions; + use arrow::compute::SortOptions; use datafusion_common::Result; #[test] diff --git a/datafusion/physical-optimizer/Cargo.toml b/datafusion/physical-optimizer/Cargo.toml index 4dc9ac22f173..c9c86e9c8d5c 100644 --- a/datafusion/physical-optimizer/Cargo.toml +++ b/datafusion/physical-optimizer/Cargo.toml @@ -36,7 +36,6 @@ recursive_protection = ["dep:recursive"] [dependencies] arrow = { workspace = true } -arrow-schema = { workspace = true } datafusion-common = { workspace = true, default-features = true } datafusion-execution = { workspace = true } datafusion-expr = { workspace = true } @@ -44,14 +43,10 @@ datafusion-expr-common = { workspace = true, default-features = true } datafusion-physical-expr = { workspace = true } datafusion-physical-expr-common = { workspace = true } datafusion-physical-plan = { workspace = true } -futures = { workspace = true } itertools = { workspace = true } log = { workspace = true } recursive = { workspace 
= true, optional = true } -url = { workspace = true } [dev-dependencies] datafusion-expr = { workspace = true } datafusion-functions-nested = { workspace = true } -rstest = { workspace = true } -tokio = { workspace = true } diff --git a/datafusion/physical-optimizer/src/enforce_sorting/mod.rs b/datafusion/physical-optimizer/src/enforce_sorting/mod.rs index 3a73e5f91c2c..c268aa22c767 100644 --- a/datafusion/physical-optimizer/src/enforce_sorting/mod.rs +++ b/datafusion/physical-optimizer/src/enforce_sorting/mod.rs @@ -73,8 +73,8 @@ use datafusion_physical_plan::{ExecutionPlan, ExecutionPlanProperties, InputOrde use itertools::izip; -/// This rule inspects [`SortExec`]'s in the given physical plan and removes the -/// ones it can prove unnecessary. +/// This rule inspects [`SortExec`]'s in the given physical plan in order to +/// remove unnecessary sorts, and optimize sort performance across the plan. #[derive(Default, Debug)] pub struct EnforceSorting {} @@ -85,33 +85,43 @@ impl EnforceSorting { } } -/// This object is used within the [`EnforceSorting`] rule to track the closest +/// This context object is used within the [`EnforceSorting`] rule to track the closest /// [`SortExec`] descendant(s) for every child of a plan. The data attribute /// stores whether the plan is a `SortExec` or is connected to a `SortExec` /// via its children. pub type PlanWithCorrespondingSort = PlanContext<bool>; -fn update_sort_ctx_children( - mut node: PlanWithCorrespondingSort, +/// For a given node, update the [`PlanContext.data`] attribute. +/// +/// If the node is a `SortExec`, or any of the node's children are a `SortExec`, +/// then set the attribute to true. +/// +/// This requires that a bottom-up traversal has already been performed to update +/// the children. +fn update_sort_ctx_children_data( + mut node_and_ctx: PlanWithCorrespondingSort, data: bool, ) -> Result<PlanWithCorrespondingSort> { - for child_node in node.children.iter_mut() { - let plan = &child_node.plan; - child_node.data = if is_sort(plan) { - // Initiate connection: + // Update `child.data` for all children. + for child_node in node_and_ctx.children.iter_mut() { + let child_plan = &child_node.plan; + child_node.data = if is_sort(child_plan) { + // child is sort true - } else if is_limit(plan) { + } else if is_limit(child_plan) { // There is no sort linkage for this path, it starts at a limit. false } else { - let is_spm = is_sort_preserving_merge(plan); - let required_orderings = plan.required_input_ordering(); - let flags = plan.maintains_input_order(); + // If a descendant is a sort, and the child maintains the sort. 
+ let is_spm = is_sort_preserving_merge(child_plan); + let required_orderings = child_plan.required_input_ordering(); + let flags = child_plan.maintains_input_order(); // Add parent node to the tree if there is at least one child with // a sort connection: izip!(flags, required_orderings).any(|(maintains, required_ordering)| { let propagates_ordering = (maintains && required_ordering.is_none()) || is_spm; + // `connected_to_sort` only returns the correct answer with bottom-up traversal let connected_to_sort = child_node.children.iter().any(|child| child.data); propagates_ordering && connected_to_sort @@ -119,8 +129,10 @@ fn update_sort_ctx_children( } } - node.data = data; - node.update_plan_from_children() + // set data attribute on current node + node_and_ctx.data = data; + + Ok(node_and_ctx) } /// This object is used within the [`EnforceSorting`] rule to track the closest @@ -152,11 +164,15 @@ fn update_coalesce_ctx_children( }; } -/// The boolean flag `repartition_sorts` defined in the config indicates -/// whether we elect to transform [`CoalescePartitionsExec`] + [`SortExec`] cascades -/// into [`SortExec`] + [`SortPreservingMergeExec`] cascades, which enables us to -/// perform sorting in parallel. +/// Performs optimizations based upon a series of subrules. /// +/// Refer to each subrule for detailed descriptions of the optimizations performed: +/// [`ensure_sorting`], [`parallelize_sorts`], [`replace_with_order_preserving_variants()`], +/// and [`pushdown_sorts`]. +/// +/// Subrule application is ordering dependent. +/// +/// The subrule `parallelize_sorts` is only applied if `repartition_sorts` is enabled. /// Optimizer consists of 5 main parts which work sequentially /// 1. `ensure_sorting` Responsible for removing unnecessary [`SortExec`]s, [`SortPreservingMergeExec`]s /// adjusting window operators, etc. @@ -262,17 +278,65 @@ fn replace_with_partial_sort( Ok(plan) } -/// This function turns plans of the form +/// Transform [`CoalescePartitionsExec`] + [`SortExec`] into +/// [`SortExec`] + [`SortPreservingMergeExec`] as illustrated below: +/// +/// The [`CoalescePartitionsExec`] + [`SortExec`] cascades +/// combine the partitions first, and then sort: +/// ```text +/// ┌ ─ ─ ─ ─ ─ ┐ +/// ┌─┬─┬─┐ +/// ││B│A│D│... ├──┐ +/// └─┴─┴─┘ │ +/// └ ─ ─ ─ ─ ─ ┘ │ ┌────────────────────────┐ ┌ ─ ─ ─ ─ ─ ─ ┐ ┌────────┐ ┌ ─ ─ ─ ─ ─ ─ ─ ┐ +/// Partition 1 │ │ Coalesce │ ┌─┬─┬─┬─┬─┐ │ │ ┌─┬─┬─┬─┬─┐ +/// ├──▶(no ordering guarantees)│──▶││B│E│A│D│C│...───▶ Sort ├───▶││A│B│C│D│E│... │ +/// │ │ │ └─┴─┴─┴─┴─┘ │ │ └─┴─┴─┴─┴─┘ +/// ┌ ─ ─ ─ ─ ─ ┐ │ └────────────────────────┘ └ ─ ─ ─ ─ ─ ─ ┘ └────────┘ └ ─ ─ ─ ─ ─ ─ ─ ┘ +/// ┌─┬─┐ │ Partition Partition +/// ││E│C│ ... ├──┘ +/// └─┴─┘ +/// └ ─ ─ ─ ─ ─ ┘ +/// Partition 2 +/// ``` +/// +/// +/// The [`SortExec`] + [`SortPreservingMergeExec`] cascades +/// sorts each partition first, then merge partitions while retaining the sort: +/// ```text +/// ┌ ─ ─ ─ ─ ─ ┐ ┌────────┐ ┌ ─ ─ ─ ─ ─ ┐ +/// ┌─┬─┬─┐ │ │ ┌─┬─┬─┐ +/// ││B│A│D│... │──▶│ Sort │──▶││A│B│D│... │──┐ +/// └─┴─┴─┘ │ │ └─┴─┴─┘ │ +/// └ ─ ─ ─ ─ ─ ┘ └────────┘ └ ─ ─ ─ ─ ─ ┘ │ ┌─────────────────────┐ ┌ ─ ─ ─ ─ ─ ─ ─ ┐ +/// Partition 1 Partition 1 │ │ │ ┌─┬─┬─┬─┬─┐ +/// ├──▶ SortPreservingMerge ├───▶││A│B│C│D│E│... │ +/// │ │ │ └─┴─┴─┴─┴─┘ +/// ┌ ─ ─ ─ ─ ─ ┐ ┌────────┐ ┌ ─ ─ ─ ─ ─ ┐ │ └─────────────────────┘ └ ─ ─ ─ ─ ─ ─ ─ ┘ +/// ┌─┬─┐ │ │ ┌─┬─┐ │ Partition +/// ││E│C│ ... │──▶│ Sort ├──▶││C│E│ ... 
│──┘ +/// └─┴─┘ │ │ └─┴─┘ +/// └ ─ ─ ─ ─ ─ ┘ └────────┘ └ ─ ─ ─ ─ ─ ┘ +/// Partition 2 Partition 2 +/// ``` +/// +/// The latter [`SortExec`] + [`SortPreservingMergeExec`] cascade performs the +/// sort first on a per-partition basis, thereby parallelizing the sort. +/// +/// +/// The outcome is that plans of the form /// ```text /// "SortExec: expr=\[a@0 ASC\]", -/// " CoalescePartitionsExec", -/// " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", +/// " ...nodes..." +/// " CoalescePartitionsExec", +/// " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", /// ``` -/// to +/// are transformed into /// ```text /// "SortPreservingMergeExec: \[a@0 ASC\]", -/// " SortExec: expr=\[a@0 ASC\]", -/// " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", +/// " ...nodes..." +/// " SortExec: expr=\[a@0 ASC\]", +/// " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", /// ``` /// by following connections from [`CoalescePartitionsExec`]s to [`SortExec`]s. /// By performing sorting in parallel, we can increase performance in some scenarios. @@ -348,7 +412,7 @@ pub fn parallelize_sorts( } /// This function enforces sorting requirements and makes optimizations without -/// violating these requirements whenever possible. +/// violating these requirements whenever possible. Requires a bottom-up traversal. /// /// **Steps** /// 1. Analyze if there are any immediate removals of [`SortExec`]s if so, removes them (see `analyze_immediate_sort_removal`) @@ -366,7 +430,7 @@ pub fn parallelize_sorts( pub fn ensure_sorting( mut requirements: PlanWithCorrespondingSort, ) -> Result> { - requirements = update_sort_ctx_children(requirements, false)?; + requirements = update_sort_ctx_children_data(requirements, false)?; // Perform naive analysis at the beginning -- remove already-satisfied sorts: if requirements.children.is_empty() { @@ -397,7 +461,7 @@ pub fn ensure_sorting( child = update_child_to_remove_unnecessary_sort(idx, child, plan)?; } child = add_sort_above(child, required.lex_requirement().clone(), None); - child = update_sort_ctx_children(child, true)?; + child = update_sort_ctx_children_data(child, true)?; } } else if physical_ordering.is_none() || !plan.maintains_input_order()[idx] @@ -410,21 +474,26 @@ pub fn ensure_sorting( updated_children.push(child); } requirements.children = updated_children; + requirements = requirements.update_plan_from_children()?; // For window expressions, we can remove some sorts when we can // calculate the result in reverse: let child_node = &requirements.children[0]; - if is_window(plan) && child_node.data { + if is_window(&requirements.plan) && child_node.data { return adjust_window_sort_removal(requirements).map(Transformed::yes); - } else if is_sort_preserving_merge(plan) + } else if is_sort_preserving_merge(&requirements.plan) && child_node.plan.output_partitioning().partition_count() <= 1 { // This `SortPreservingMergeExec` is unnecessary, input already has a - // single partition. - let child_node = requirements.children.swap_remove(0); + // single partition and no fetch is required. 
+ let mut child_node = requirements.children.swap_remove(0); + if let Some(fetch) = requirements.plan.fetch() { + // Add the limit exec if the original SPM had a fetch: + child_node.plan = + Arc::new(LocalLimitExec::new(Arc::clone(&child_node.plan), fetch)); + } return Ok(Transformed::yes(child_node)); } - - update_sort_ctx_children(requirements, false).map(Transformed::yes) + update_sort_ctx_children_data(requirements, false).map(Transformed::yes) } /// Analyzes if there are any immediate sort removals by checking the `SortExec`s @@ -648,8 +717,9 @@ fn remove_corresponding_sort_from_sub_plan( } }) .collect::>()?; + node = node.update_plan_from_children()?; if any_connection || node.children.is_empty() { - node = update_sort_ctx_children(node, false)?; + node = update_sort_ctx_children_data(node, false)?; } // Replace with variants that do not preserve order. @@ -682,7 +752,7 @@ fn remove_corresponding_sort_from_sub_plan( Arc::new(CoalescePartitionsExec::new(plan)) as _ }; node = PlanWithCorrespondingSort::new(plan, false, vec![node]); - node = update_sort_ctx_children(node, false)?; + node = update_sort_ctx_children_data(node, false)?; } Ok(node) } diff --git a/datafusion/physical-optimizer/src/enforce_sorting/replace_with_order_preserving_variants.rs b/datafusion/physical-optimizer/src/enforce_sorting/replace_with_order_preserving_variants.rs index 1ddb279e0dd7..b770a1b39afe 100644 --- a/datafusion/physical-optimizer/src/enforce_sorting/replace_with_order_preserving_variants.rs +++ b/datafusion/physical-optimizer/src/enforce_sorting/replace_with_order_preserving_variants.rs @@ -44,7 +44,7 @@ use itertools::izip; pub type OrderPreservationContext = PlanContext; /// Updates order-preservation data for all children of the given node. -pub fn update_children(opc: &mut OrderPreservationContext) { +pub fn update_order_preservation_ctx_children_data(opc: &mut OrderPreservationContext) { for PlanContext { plan, children, @@ -238,7 +238,7 @@ pub fn replace_with_order_preserving_variants( // should only be made to fix the pipeline (streaming). 
is_spm_better: bool, ) -> Result> { - update_children(&mut requirements); + update_order_preservation_ctx_children_data(&mut requirements); if !(is_sort(&requirements.plan) && requirements.children[0].data) { return Ok(Transformed::no(requirements)); } diff --git a/datafusion/physical-optimizer/src/enforce_sorting/sort_pushdown.rs b/datafusion/physical-optimizer/src/enforce_sorting/sort_pushdown.rs index a08c2fb0bdd0..4d9104041181 100644 --- a/datafusion/physical-optimizer/src/enforce_sorting/sort_pushdown.rs +++ b/datafusion/physical-optimizer/src/enforce_sorting/sort_pushdown.rs @@ -22,7 +22,7 @@ use crate::utils::{ add_sort_above, is_sort, is_sort_preserving_merge, is_union, is_window, }; -use arrow_schema::SchemaRef; +use arrow::datatypes::SchemaRef; use datafusion_common::tree_node::{ ConcreteTreeNode, Transformed, TreeNode, TreeNodeRecursion, }; diff --git a/datafusion/physical-optimizer/src/pruning.rs b/datafusion/physical-optimizer/src/pruning.rs index 8bf0ffbd3c32..2004aeafb893 100644 --- a/datafusion/physical-optimizer/src/pruning.rs +++ b/datafusion/physical-optimizer/src/pruning.rs @@ -1590,6 +1590,7 @@ fn build_statistics_expr( )), )) } + Operator::NotLikeMatch => build_not_like_match(expr_builder)?, Operator::LikeMatch => build_like_match(expr_builder).ok_or_else(|| { plan_datafusion_err!( "LIKE expression with wildcard at the beginning is not supported" @@ -1638,6 +1639,19 @@ fn build_statistics_expr( Ok(statistics_expr) } +/// returns the string literal of the scalar value if it is a string +fn unpack_string(s: &ScalarValue) -> Option<&str> { + s.try_as_str().flatten() +} + +fn extract_string_literal(expr: &Arc) -> Option<&str> { + if let Some(lit) = expr.as_any().downcast_ref::() { + let s = unpack_string(lit.value())?; + return Some(s); + } + None +} + /// Convert `column LIKE literal` where P is a constant prefix of the literal /// to a range check on the column: `P <= column && column < P'`, where P' is the /// lowest string after all P* strings. @@ -1650,19 +1664,6 @@ fn build_like_match( // column LIKE '%foo%' => min <= '' && '' <= max => true // column LIKE 'foo' => min <= 'foo' && 'foo' <= max - /// returns the string literal of the scalar value if it is a string - fn unpack_string(s: &ScalarValue) -> Option<&str> { - s.try_as_str().flatten() - } - - fn extract_string_literal(expr: &Arc) -> Option<&str> { - if let Some(lit) = expr.as_any().downcast_ref::() { - let s = unpack_string(lit.value())?; - return Some(s); - } - None - } - // TODO Handle ILIKE perhaps by making the min lowercase and max uppercase // this may involve building the physical expressions that call lower() and upper() let min_column_expr = expr_builder.min_column_expr().ok()?; @@ -1710,6 +1711,80 @@ fn build_like_match( Some(combined) } +// For predicate `col NOT LIKE 'const_prefix%'`, we rewrite it as `(col_min NOT LIKE 'const_prefix%' OR col_max NOT LIKE 'const_prefix%')`. 
+// +// The intuition is that if both `col_min` and `col_max` begin with `const_prefix` that means +// **all** data in this row group begins with `const_prefix` as well (and therefore the predicate +// looking for rows that don't begin with `const_prefix` can never be true) +fn build_not_like_match( + expr_builder: &mut PruningExpressionBuilder<'_>, +) -> Result<Arc<dyn PhysicalExpr>> { + // col NOT LIKE 'const_prefix%' -> !(col_min LIKE 'const_prefix%' && col_max LIKE 'const_prefix%') -> (col_min NOT LIKE 'const_prefix%' || col_max NOT LIKE 'const_prefix%') + + let min_column_expr = expr_builder.min_column_expr()?; + let max_column_expr = expr_builder.max_column_expr()?; + + let scalar_expr = expr_builder.scalar_expr(); + + let pattern = extract_string_literal(scalar_expr).ok_or_else(|| { + plan_datafusion_err!("cannot extract literal from NOT LIKE expression") + })?; + + let (const_prefix, remaining) = split_constant_prefix(pattern); + if const_prefix.is_empty() || remaining != "%" { + // we cannot handle `%` at the beginning or in the middle of the pattern + // Example: For pattern "foo%bar", the row group might include values like + // ["foobar", "food", "foodbar"], making it unsafe to prune. + // Even if the min/max values in the group (e.g., "foobar" and "foodbar") + // match the pattern, intermediate values like "food" may not + // match the full pattern "foo%bar", making pruning unsafe. + // (truncating foo%bar to foo% has the same problem) + + // we cannot handle a pattern containing `_` + // Example: For pattern "foo_", row groups might contain ["fooa", "fooaa", "foob"], + // which means not every row is guaranteed to match the pattern. + return Err(plan_datafusion_err!( + "NOT LIKE expressions only support constant_prefix+wildcard`%`" + )); + } + + let min_col_not_like_expr = Arc::new(phys_expr::LikeExpr::new( + true, + false, + Arc::clone(&min_column_expr), + Arc::clone(scalar_expr), + )); + + let max_col_not_like_expr = Arc::new(phys_expr::LikeExpr::new( + true, + false, + Arc::clone(&max_column_expr), + Arc::clone(scalar_expr), + )); + + Ok(Arc::new(phys_expr::BinaryExpr::new( + min_col_not_like_expr, + Operator::Or, + max_col_not_like_expr, + ))) +} + +/// Returns the unescaped constant prefix of a LIKE pattern (possibly empty) and the remaining pattern (possibly empty) +fn split_constant_prefix(pattern: &str) -> (&str, &str) { + let char_indices = pattern.char_indices().collect::<Vec<_>>(); + for i in 0..char_indices.len() { + let (idx, char) = char_indices[i]; + if char == '%' || char == '_' { + if i != 0 && char_indices[i - 1].1 == '\\' { + // escaped by `\` + continue; + } + return (&pattern[..idx], &pattern[idx..]); + } + } + (pattern, "") +} + /// Increment a UTF8 string by one, returning `None` if it can't be incremented. /// This makes it so that the returned string will always compare greater than the input string /// or any other string with the same prefix. 
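To make the applicability rule concrete, the following illustrative assertions (not part of the patch; the statistics values are invented for the example) show how `split_constant_prefix` classifies patterns and why the rewritten guard prunes a container:

    // Only a constant prefix followed by a single trailing `%` is rewritable.
    assert_eq!(split_constant_prefix("foo%"), ("foo", "%"));       // rewritable
    assert_eq!(split_constant_prefix("foo%bar"), ("foo", "%bar")); // rejected: remainder is not "%"
    assert_eq!(split_constant_prefix("A\\%%"), ("A\\%", "%"));     // escaped `%` stays in the prefix (see the `A\%%` test below)

    // For `s NOT LIKE 'foo%'` with container stats min = "fooa", max = "food":
    //   (min NOT LIKE 'foo%') OR (max NOT LIKE 'foo%')  =>  false OR false  =>  false,
    // so no row in the container can satisfy the predicate and the container is pruned.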
@@ -4061,6 +4136,132 @@ mod tests { prune_with_expr(expr, &schema, &statistics, expected_ret); } + #[test] + fn prune_utf8_not_like_one() { + let (schema, statistics) = utf8_setup(); + + let expr = col("s1").not_like(lit("A\u{10ffff}_")); + #[rustfmt::skip] + let expected_ret = &[ + // s1 ["A", "Z"] ==> some rows could pass (must keep) + true, + // s1 ["A", "L"] ==> some rows could pass (must keep) + true, + // s1 ["N", "Z"] ==> some rows could pass (must keep) + true, + // s1 ["M", "M"] ==> some rows could pass (must keep) + true, + // s1 [NULL, NULL] ==> unknown (must keep) + true, + // s1 ["A", NULL] ==> some rows could pass (must keep) + true, + // s1 ["", "A"] ==> some rows could pass (must keep) + true, + // s1 ["", ""] ==> some rows could pass (must keep) + true, + // s1 ["AB", "A\u{10ffff}\u{10ffff}\u{10ffff}"] ==> some rows could pass (must keep) + true, + // s1 ["A\u{10ffff}\u{10ffff}", "A\u{10ffff}\u{10ffff}"] ==> no row match. (min, max) maybe truncate + // orignal (min, max) maybe ("A\u{10ffff}\u{10ffff}\u{10ffff}", "A\u{10ffff}\u{10ffff}\u{10ffff}\u{10ffff}") + true, + ]; + prune_with_expr(expr, &schema, &statistics, expected_ret); + } + + #[test] + fn prune_utf8_not_like_many() { + let (schema, statistics) = utf8_setup(); + + let expr = col("s1").not_like(lit("A\u{10ffff}%")); + #[rustfmt::skip] + let expected_ret = &[ + // s1 ["A", "Z"] ==> some rows could pass (must keep) + true, + // s1 ["A", "L"] ==> some rows could pass (must keep) + true, + // s1 ["N", "Z"] ==> some rows could pass (must keep) + true, + // s1 ["M", "M"] ==> some rows could pass (must keep) + true, + // s1 [NULL, NULL] ==> unknown (must keep) + true, + // s1 ["A", NULL] ==> some rows could pass (must keep) + true, + // s1 ["", "A"] ==> some rows could pass (must keep) + true, + // s1 ["", ""] ==> some rows could pass (must keep) + true, + // s1 ["AB", "A\u{10ffff}\u{10ffff}\u{10ffff}"] ==> some rows could pass (must keep) + true, + // s1 ["A\u{10ffff}\u{10ffff}", "A\u{10ffff}\u{10ffff}"] ==> no row match + false, + ]; + prune_with_expr(expr, &schema, &statistics, expected_ret); + + let expr = col("s1").not_like(lit("A\u{10ffff}%\u{10ffff}")); + #[rustfmt::skip] + let expected_ret = &[ + // s1 ["A", "Z"] ==> some rows could pass (must keep) + true, + // s1 ["A", "L"] ==> some rows could pass (must keep) + true, + // s1 ["N", "Z"] ==> some rows could pass (must keep) + true, + // s1 ["M", "M"] ==> some rows could pass (must keep) + true, + // s1 [NULL, NULL] ==> unknown (must keep) + true, + // s1 ["A", NULL] ==> some rows could pass (must keep) + true, + // s1 ["", "A"] ==> some rows could pass (must keep) + true, + // s1 ["", ""] ==> some rows could pass (must keep) + true, + // s1 ["AB", "A\u{10ffff}\u{10ffff}\u{10ffff}"] ==> some rows could pass (must keep) + true, + // s1 ["A\u{10ffff}\u{10ffff}", "A\u{10ffff}\u{10ffff}"] ==> some rows could pass (must keep) + true, + ]; + prune_with_expr(expr, &schema, &statistics, expected_ret); + + let expr = col("s1").not_like(lit("A\u{10ffff}%\u{10ffff}_")); + #[rustfmt::skip] + let expected_ret = &[ + // s1 ["A", "Z"] ==> some rows could pass (must keep) + true, + // s1 ["A", "L"] ==> some rows could pass (must keep) + true, + // s1 ["N", "Z"] ==> some rows could pass (must keep) + true, + // s1 ["M", "M"] ==> some rows could pass (must keep) + true, + // s1 [NULL, NULL] ==> unknown (must keep) + true, + // s1 ["A", NULL] ==> some rows could pass (must keep) + true, + // s1 ["", "A"] ==> some rows could pass (must keep) + true, + // s1 ["", ""] ==> some rows 
could pass (must keep) + true, + // s1 ["AB", "A\u{10ffff}\u{10ffff}\u{10ffff}"] ==> some rows could pass (must keep) + true, + // s1 ["A\u{10ffff}\u{10ffff}", "A\u{10ffff}\u{10ffff}"] ==> some rows could pass (must keep) + true, + ]; + prune_with_expr(expr, &schema, &statistics, expected_ret); + + let expr = col("s1").not_like(lit("A\\%%")); + let statistics = TestStatistics::new().with( + "s1", + ContainerStats::new_utf8( + vec![Some("A%a"), Some("A")], + vec![Some("A%c"), Some("A")], + ), + ); + let expected_ret = &[false, true]; + prune_with_expr(expr, &schema, &statistics, expected_ret); + } + #[test] fn test_rewrite_expr_to_prunable() { let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]); diff --git a/datafusion/physical-plan/Cargo.toml b/datafusion/physical-plan/Cargo.toml index b84243b1b56b..f0afdaa2de3d 100644 --- a/datafusion/physical-plan/Cargo.toml +++ b/datafusion/physical-plan/Cargo.toml @@ -78,3 +78,7 @@ tokio = { workspace = true, features = [ [[bench]] harness = false name = "spm" + +[[bench]] +harness = false +name = "partial_ordering" diff --git a/datafusion/physical-plan/benches/partial_ordering.rs b/datafusion/physical-plan/benches/partial_ordering.rs new file mode 100644 index 000000000000..422826abcc8b --- /dev/null +++ b/datafusion/physical-plan/benches/partial_ordering.rs @@ -0,0 +1,77 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use std::sync::Arc; + +use arrow::array::{ArrayRef, Int32Array}; +use arrow_schema::{DataType, Field, Schema, SortOptions}; +use criterion::{criterion_group, criterion_main, Criterion}; +use datafusion_physical_expr::{expressions::col, LexOrdering, PhysicalSortExpr}; +use datafusion_physical_plan::aggregates::order::GroupOrderingPartial; + +const BATCH_SIZE: usize = 8192; + +fn create_test_arrays(num_columns: usize) -> Vec { + (0..num_columns) + .map(|i| { + Arc::new(Int32Array::from_iter_values( + (0..BATCH_SIZE as i32).map(|x| x * (i + 1) as i32), + )) as ArrayRef + }) + .collect() +} +fn bench_new_groups(c: &mut Criterion) { + let mut group = c.benchmark_group("group_ordering_partial"); + + // Test with 1, 2, 4, and 8 order indices + for num_columns in [1, 2, 4, 8] { + let fields: Vec = (0..num_columns) + .map(|i| Field::new(format!("col{}", i), DataType::Int32, false)) + .collect(); + let schema = Schema::new(fields); + + let order_indices: Vec = (0..num_columns).collect(); + let ordering = LexOrdering::new( + (0..num_columns) + .map(|i| { + PhysicalSortExpr::new( + col(&format!("col{}", i), &schema).unwrap(), + SortOptions::default(), + ) + }) + .collect(), + ); + + group.bench_function(format!("order_indices_{}", num_columns), |b| { + let batch_group_values = create_test_arrays(num_columns); + let group_indices: Vec = (0..BATCH_SIZE).collect(); + + b.iter(|| { + let mut ordering = + GroupOrderingPartial::try_new(&schema, &order_indices, &ordering) + .unwrap(); + ordering + .new_groups(&batch_group_values, &group_indices, BATCH_SIZE) + .unwrap(); + }); + }); + } + group.finish(); +} + +criterion_group!(benches, bench_new_groups); +criterion_main!(benches); diff --git a/datafusion/physical-plan/src/aggregates/group_values/mod.rs b/datafusion/physical-plan/src/aggregates/group_values/mod.rs index 4cbeed9951f9..ce56ca4f7dfd 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/mod.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/mod.rs @@ -23,8 +23,7 @@ use arrow::array::types::{ TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, }; use arrow::array::{downcast_primitive, ArrayRef, RecordBatch}; -use arrow_schema::TimeUnit; -use arrow_schema::{DataType, SchemaRef}; +use arrow::datatypes::{DataType, SchemaRef, TimeUnit}; use datafusion_common::Result; use datafusion_expr::EmitTo; diff --git a/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/mod.rs b/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/mod.rs index 96885f03146c..ac96a98edfe1 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/mod.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/mod.rs @@ -32,13 +32,13 @@ use ahash::RandomState; use arrow::array::{Array, ArrayRef, RecordBatch}; use arrow::compute::cast; use arrow::datatypes::{ - BinaryViewType, Date32Type, Date64Type, Decimal128Type, Float32Type, Float64Type, - Int16Type, Int32Type, Int64Type, Int8Type, StringViewType, Time32MillisecondType, - Time32SecondType, Time64MicrosecondType, Time64NanosecondType, - TimestampMicrosecondType, TimestampMillisecondType, TimestampNanosecondType, - TimestampSecondType, UInt16Type, UInt32Type, UInt64Type, UInt8Type, + BinaryViewType, DataType, Date32Type, Date64Type, Decimal128Type, Float32Type, + Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, Schema, SchemaRef, + StringViewType, Time32MillisecondType, Time32SecondType, Time64MicrosecondType, + Time64NanosecondType, TimeUnit, 
TimestampMicrosecondType, TimestampMillisecondType, + TimestampNanosecondType, TimestampSecondType, UInt16Type, UInt32Type, UInt64Type, + UInt8Type, }; -use arrow_schema::{DataType, Schema, SchemaRef, TimeUnit}; use datafusion_common::hash_utils::create_hashes; use datafusion_common::{not_impl_err, DataFusionError, Result}; use datafusion_execution::memory_pool::proxy::{HashTableAllocExt, VecAllocExt}; @@ -1236,8 +1236,8 @@ mod tests { use std::{collections::HashMap, sync::Arc}; use arrow::array::{ArrayRef, Int64Array, RecordBatch, StringArray, StringViewArray}; + use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use arrow::{compute::concat_batches, util::pretty::pretty_format_batches}; - use arrow_schema::{DataType, Field, Schema, SchemaRef}; use datafusion_common::utils::proxy::HashTableAllocExt; use datafusion_expr::EmitTo; diff --git a/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/primitive.rs b/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/primitive.rs index c85245d05592..005dcc8da386 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/primitive.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/primitive.rs @@ -19,7 +19,7 @@ use crate::aggregates::group_values::multi_group_by::{nulls_equal_to, GroupColum use crate::aggregates::group_values::null_builder::MaybeNullBufferBuilder; use arrow::array::{cast::AsArray, Array, ArrayRef, ArrowPrimitiveType, PrimitiveArray}; use arrow::buffer::ScalarBuffer; -use arrow_schema::DataType; +use arrow::datatypes::DataType; use datafusion_execution::memory_pool::proxy::VecAllocExt; use itertools::izip; use std::iter; @@ -212,8 +212,7 @@ mod tests { use crate::aggregates::group_values::multi_group_by::primitive::PrimitiveGroupValueBuilder; use arrow::array::{ArrayRef, Int64Array, NullBufferBuilder}; - use arrow::datatypes::Int64Type; - use arrow_schema::DataType; + use arrow::datatypes::{DataType, Int64Type}; use super::GroupColumn; diff --git a/datafusion/physical-plan/src/aggregates/group_values/row.rs b/datafusion/physical-plan/src/aggregates/group_values/row.rs index a0331bf3fa3d..63751d470313 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/row.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/row.rs @@ -19,8 +19,8 @@ use crate::aggregates::group_values::GroupValues; use ahash::RandomState; use arrow::array::{Array, ArrayRef, ListArray, RecordBatch, StructArray}; use arrow::compute::cast; +use arrow::datatypes::{DataType, SchemaRef}; use arrow::row::{RowConverter, Rows, SortField}; -use arrow_schema::{DataType, SchemaRef}; use datafusion_common::hash_utils::create_hashes; use datafusion_common::Result; use datafusion_execution::memory_pool::proxy::{HashTableAllocExt, VecAllocExt}; @@ -285,7 +285,7 @@ fn dictionary_encode_if_necessary( let list = array.as_any().downcast_ref::().unwrap(); Ok(Arc::new(ListArray::try_new( - Arc::::clone(expected_field), + Arc::::clone(expected_field), list.offsets().clone(), dictionary_encode_if_necessary( Arc::::clone(list.values()), diff --git a/datafusion/physical-plan/src/aggregates/group_values/single_group_by/primitive.rs b/datafusion/physical-plan/src/aggregates/group_values/single_group_by/primitive.rs index 5a6235edb25a..d945d3ddcbf5 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/single_group_by/primitive.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/single_group_by/primitive.rs @@ -22,9 +22,8 @@ use arrow::array::{ cast::AsArray, 
ArrayRef, ArrowNativeTypeOp, ArrowPrimitiveType, NullBufferBuilder, PrimitiveArray, }; -use arrow::datatypes::i256; +use arrow::datatypes::{i256, DataType}; use arrow::record_batch::RecordBatch; -use arrow_schema::DataType; use datafusion_common::Result; use datafusion_execution::memory_pool::proxy::VecAllocExt; use datafusion_expr::EmitTo; diff --git a/datafusion/physical-plan/src/aggregates/order/mod.rs b/datafusion/physical-plan/src/aggregates/order/mod.rs index 61a0ab8b247d..0b742b3d20fd 100644 --- a/datafusion/physical-plan/src/aggregates/order/mod.rs +++ b/datafusion/physical-plan/src/aggregates/order/mod.rs @@ -16,7 +16,7 @@ // under the License. use arrow::array::ArrayRef; -use arrow_schema::Schema; +use arrow::datatypes::Schema; use datafusion_common::Result; use datafusion_expr::EmitTo; use datafusion_physical_expr_common::sort_expr::LexOrdering; diff --git a/datafusion/physical-plan/src/aggregates/order/partial.rs b/datafusion/physical-plan/src/aggregates/order/partial.rs index 30655cd0ad59..aff69277a4ce 100644 --- a/datafusion/physical-plan/src/aggregates/order/partial.rs +++ b/datafusion/physical-plan/src/aggregates/order/partial.rs @@ -16,12 +16,15 @@ // under the License. use arrow::array::ArrayRef; -use arrow::row::{OwnedRow, RowConverter, Rows, SortField}; -use arrow_schema::Schema; -use datafusion_common::Result; +use arrow::compute::SortOptions; +use arrow::datatypes::Schema; +use arrow_ord::partition::partition; +use datafusion_common::utils::{compare_rows, get_row_at_idx}; +use datafusion_common::{Result, ScalarValue}; use datafusion_execution::memory_pool::proxy::VecAllocExt; use datafusion_expr::EmitTo; use datafusion_physical_expr_common::sort_expr::LexOrdering; +use std::cmp::Ordering; use std::mem::size_of; use std::sync::Arc; @@ -69,13 +72,9 @@ pub struct GroupOrderingPartial { /// For example if grouping by `id, state` and ordered by `state` /// this would be `[1]`. order_indices: Vec, - - /// Converter for the sort key (used on the group columns - /// specified in `order_indexes`) - row_converter: RowConverter, } -#[derive(Debug, Default)] +#[derive(Debug, Default, PartialEq)] enum State { /// The ordering was temporarily taken. `Self::Taken` is left /// when state must be temporarily taken to satisfy the borrow @@ -93,7 +92,7 @@ enum State { /// Smallest group index with the sort_key current_sort: usize, /// The sort key of group_index `current_sort` - sort_key: OwnedRow, + sort_key: Vec, /// index of the current group for which values are being /// generated current: usize, @@ -103,47 +102,47 @@ enum State { Complete, } +impl State { + fn size(&self) -> usize { + match self { + State::Taken => 0, + State::Start => 0, + State::InProgress { sort_key, .. } => sort_key + .iter() + .map(|scalar_value| scalar_value.size()) + .sum(), + State::Complete => 0, + } + } +} + impl GroupOrderingPartial { + /// TODO: Remove unnecessary `input_schema` parameter. pub fn try_new( - input_schema: &Schema, + _input_schema: &Schema, order_indices: &[usize], ordering: &LexOrdering, ) -> Result { assert!(!order_indices.is_empty()); assert!(order_indices.len() <= ordering.len()); - // get only the section of ordering, that consist of group by expressions. 
- let fields = ordering[0..order_indices.len()] - .iter() - .map(|sort_expr| { - Ok(SortField::new_with_options( - sort_expr.expr.data_type(input_schema)?, - sort_expr.options, - )) - }) - .collect::>>()?; - Ok(Self { state: State::Start, order_indices: order_indices.to_vec(), - row_converter: RowConverter::new(fields)?, }) } - /// Creates sort keys from the group values + /// Select sort keys from the group values /// /// For example, if group_values had `A, B, C` but the input was /// only sorted on `B` and `C` this should return rows for (`B`, /// `C`) - fn compute_sort_keys(&mut self, group_values: &[ArrayRef]) -> Result { + fn compute_sort_keys(&mut self, group_values: &[ArrayRef]) -> Vec { // Take only the columns that are in the sort key - let sort_values: Vec<_> = self - .order_indices + self.order_indices .iter() .map(|&idx| Arc::clone(&group_values[idx])) - .collect(); - - Ok(self.row_converter.convert_columns(&sort_values)?) + .collect() } /// How many groups be emitted, or None if no data can be emitted @@ -194,6 +193,23 @@ impl GroupOrderingPartial { }; } + fn updated_sort_key( + current_sort: usize, + sort_key: Option>, + range_current_sort: usize, + range_sort_key: Vec, + ) -> Result<(usize, Vec)> { + if let Some(sort_key) = sort_key { + let sort_options = vec![SortOptions::new(false, false); sort_key.len()]; + let ordering = compare_rows(&sort_key, &range_sort_key, &sort_options)?; + if ordering == Ordering::Equal { + return Ok((current_sort, sort_key)); + } + } + + Ok((range_current_sort, range_sort_key)) + } + /// Called when new groups are added in a batch. See documentation /// on [`super::GroupOrdering::new_groups`] pub fn new_groups( @@ -207,37 +223,46 @@ impl GroupOrderingPartial { let max_group_index = total_num_groups - 1; - // compute the sort key values for each group - let sort_keys = self.compute_sort_keys(batch_group_values)?; - - let old_state = std::mem::take(&mut self.state); - let (mut current_sort, mut sort_key) = match &old_state { + let (current_sort, sort_key) = match std::mem::take(&mut self.state) { State::Taken => unreachable!("State previously taken"), - State::Start => (0, sort_keys.row(0)), + State::Start => (0, None), State::InProgress { current_sort, sort_key, .. - } => (*current_sort, sort_key.row()), + } => (current_sort, Some(sort_key)), State::Complete => { panic!("Saw new group after the end of input"); } }; - // Find latest sort key - let iter = group_indices.iter().zip(sort_keys.iter()); - for (&group_index, group_sort_key) in iter { - // Does this group have seen a new sort_key? - if sort_key != group_sort_key { - current_sort = group_index; - sort_key = group_sort_key; - } - } + // Select the sort key columns + let sort_keys = self.compute_sort_keys(batch_group_values); + + // Check if the sort keys indicate a boundary inside the batch + let ranges = partition(&sort_keys)?.ranges(); + let last_range = ranges.last().unwrap(); + + let range_current_sort = group_indices[last_range.start]; + let range_sort_key = get_row_at_idx(&sort_keys, last_range.start)?; + + let (current_sort, sort_key) = if last_range.start == 0 { + // There was no boundary in the batch. Compare with the previous sort_key (if present) + // to check if there was a boundary between the current batch and the previous one. + Self::updated_sort_key( + current_sort, + sort_key, + range_current_sort, + range_sort_key, + )? 
+ } else { + (range_current_sort, range_sort_key) + }; self.state = State::InProgress { current_sort, - sort_key: sort_key.owned(), current: max_group_index, + sort_key, }; Ok(()) @@ -245,8 +270,104 @@ impl GroupOrderingPartial { /// Return the size of memory allocated by this structure pub(crate) fn size(&self) -> usize { - size_of::() - + self.order_indices.allocated_size() - + self.row_converter.size() + size_of::() + self.order_indices.allocated_size() + self.state.size() + } +} + +#[cfg(test)] +mod tests { + use arrow::array::Int32Array; + use arrow_schema::{DataType, Field}; + use datafusion_physical_expr::{expressions::col, PhysicalSortExpr}; + + use super::*; + + #[test] + fn test_group_ordering_partial() -> Result<()> { + let schema = Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Int32, false), + ]); + + // Ordered on column a + let order_indices = vec![0]; + + let ordering = LexOrdering::new(vec![PhysicalSortExpr::new( + col("a", &schema)?, + SortOptions::default(), + )]); + + let mut group_ordering = + GroupOrderingPartial::try_new(&schema, &order_indices, &ordering)?; + + let batch_group_values: Vec = vec![ + Arc::new(Int32Array::from(vec![1, 2, 3])), + Arc::new(Int32Array::from(vec![2, 1, 3])), + ]; + + let group_indices = vec![0, 1, 2]; + let total_num_groups = 3; + + group_ordering.new_groups( + &batch_group_values, + &group_indices, + total_num_groups, + )?; + + assert_eq!( + group_ordering.state, + State::InProgress { + current_sort: 2, + sort_key: vec![ScalarValue::Int32(Some(3))], + current: 2 + } + ); + + // push without a boundary + let batch_group_values: Vec = vec![ + Arc::new(Int32Array::from(vec![3, 3, 3])), + Arc::new(Int32Array::from(vec![2, 1, 7])), + ]; + let group_indices = vec![3, 4, 5]; + let total_num_groups = 6; + + group_ordering.new_groups( + &batch_group_values, + &group_indices, + total_num_groups, + )?; + + assert_eq!( + group_ordering.state, + State::InProgress { + current_sort: 2, + sort_key: vec![ScalarValue::Int32(Some(3))], + current: 5 + } + ); + + // push with only a boundary to previous batch + let batch_group_values: Vec = vec![ + Arc::new(Int32Array::from(vec![4, 4, 4])), + Arc::new(Int32Array::from(vec![1, 1, 1])), + ]; + let group_indices = vec![6, 7, 8]; + let total_num_groups = 9; + + group_ordering.new_groups( + &batch_group_values, + &group_indices, + total_num_groups, + )?; + assert_eq!( + group_ordering.state, + State::InProgress { + current_sort: 6, + sort_key: vec![ScalarValue::Int32(Some(4))], + current: 8 + } + ); + + Ok(()) } } diff --git a/datafusion/physical-plan/src/aggregates/row_hash.rs b/datafusion/physical-plan/src/aggregates/row_hash.rs index cc95ce51c15b..05122d5a5403 100644 --- a/datafusion/physical-plan/src/aggregates/row_hash.rs +++ b/datafusion/physical-plan/src/aggregates/row_hash.rs @@ -36,8 +36,8 @@ use crate::{aggregates, metrics, ExecutionPlan, PhysicalExpr}; use crate::{RecordBatchStream, SendableRecordBatchStream}; use arrow::array::*; +use arrow::compute::SortOptions; use arrow::datatypes::SchemaRef; -use arrow_schema::SortOptions; use datafusion_common::{internal_err, DataFusionError, Result}; use datafusion_execution::disk_manager::RefCountedTempFile; use datafusion_execution::memory_pool::proxy::VecAllocExt; @@ -632,16 +632,6 @@ pub(crate) fn create_group_accumulator( } } -/// Extracts a successful Ok(_) or returns Poll::Ready(Some(Err(e))) with errors -macro_rules! 
extract_ok { - ($RES: expr) => {{ - match $RES { - Ok(v) => v, - Err(e) => return Poll::Ready(Some(Err(e))), - } - }}; -} - impl Stream for GroupedHashAggregateStream { type Item = Result; @@ -661,7 +651,7 @@ impl Stream for GroupedHashAggregateStream { let input_rows = batch.num_rows(); // Do the grouping - extract_ok!(self.group_aggregate_batch(batch)); + self.group_aggregate_batch(batch)?; self.update_skip_aggregation_probe(input_rows); @@ -673,16 +663,14 @@ impl Stream for GroupedHashAggregateStream { // emit all groups and switch to producing output if self.hit_soft_group_limit() { timer.done(); - extract_ok!(self.set_input_done_and_produce_output()); + self.set_input_done_and_produce_output()?; // make sure the exec_state just set is not overwritten below break 'reading_input; } if let Some(to_emit) = self.group_ordering.emit_to() { timer.done(); - if let Some(batch) = - extract_ok!(self.emit(to_emit, false)) - { + if let Some(batch) = self.emit(to_emit, false)? { self.exec_state = ExecutionState::ProducingOutput(batch); }; @@ -690,9 +678,9 @@ impl Stream for GroupedHashAggregateStream { break 'reading_input; } - extract_ok!(self.emit_early_if_necessary()); + self.emit_early_if_necessary()?; - extract_ok!(self.switch_to_skip_aggregation()); + self.switch_to_skip_aggregation()?; timer.done(); } @@ -703,10 +691,10 @@ impl Stream for GroupedHashAggregateStream { let timer = elapsed_compute.timer(); // Make sure we have enough capacity for `batch`, otherwise spill - extract_ok!(self.spill_previous_if_necessary(&batch)); + self.spill_previous_if_necessary(&batch)?; // Do the grouping - extract_ok!(self.group_aggregate_batch(batch)); + self.group_aggregate_batch(batch)?; // If we can begin emitting rows, do so, // otherwise keep consuming input @@ -716,16 +704,14 @@ impl Stream for GroupedHashAggregateStream { // emit all groups and switch to producing output if self.hit_soft_group_limit() { timer.done(); - extract_ok!(self.set_input_done_and_produce_output()); + self.set_input_done_and_produce_output()?; // make sure the exec_state just set is not overwritten below break 'reading_input; } if let Some(to_emit) = self.group_ordering.emit_to() { timer.done(); - if let Some(batch) = - extract_ok!(self.emit(to_emit, false)) - { + if let Some(batch) = self.emit(to_emit, false)? 
{ self.exec_state = ExecutionState::ProducingOutput(batch); }; @@ -745,7 +731,7 @@ impl Stream for GroupedHashAggregateStream { // Found end from input stream None => { // inner is done, emit all rows and switch to producing output - extract_ok!(self.set_input_done_and_produce_output()); + self.set_input_done_and_produce_output()?; } } } diff --git a/datafusion/physical-plan/src/aggregates/topk/hash_table.rs b/datafusion/physical-plan/src/aggregates/topk/hash_table.rs index 8c7ba21b37c0..c818b4608de7 100644 --- a/datafusion/physical-plan/src/aggregates/topk/hash_table.rs +++ b/datafusion/physical-plan/src/aggregates/topk/hash_table.rs @@ -25,8 +25,7 @@ use arrow::array::{ builder::PrimitiveBuilder, cast::AsArray, downcast_primitive, Array, ArrayRef, ArrowPrimitiveType, PrimitiveArray, StringArray, }; -use arrow::datatypes::i256; -use arrow_schema::DataType; +use arrow::datatypes::{i256, DataType}; use datafusion_common::DataFusionError; use datafusion_common::Result; use half::f16; diff --git a/datafusion/physical-plan/src/aggregates/topk/heap.rs b/datafusion/physical-plan/src/aggregates/topk/heap.rs index 09dae3df0a96..b202f812c6e1 100644 --- a/datafusion/physical-plan/src/aggregates/topk/heap.rs +++ b/datafusion/physical-plan/src/aggregates/topk/heap.rs @@ -23,8 +23,7 @@ use arrow::array::{ }; use arrow::array::{downcast_primitive, ArrayRef, ArrowPrimitiveType, PrimitiveArray}; use arrow::buffer::ScalarBuffer; -use arrow::datatypes::i256; -use arrow_schema::DataType; +use arrow::datatypes::{i256, DataType}; use datafusion_common::DataFusionError; use datafusion_common::Result; diff --git a/datafusion/physical-plan/src/aggregates/topk/priority_map.rs b/datafusion/physical-plan/src/aggregates/topk/priority_map.rs index 3cb12f0af089..3b954c4c72d3 100644 --- a/datafusion/physical-plan/src/aggregates/topk/priority_map.rs +++ b/datafusion/physical-plan/src/aggregates/topk/priority_map.rs @@ -20,7 +20,7 @@ use crate::aggregates::topk::hash_table::{new_hash_table, ArrowHashTable}; use crate::aggregates::topk::heap::{new_heap, ArrowHeap}; use arrow::array::ArrayRef; -use arrow_schema::DataType; +use arrow::datatypes::DataType; use datafusion_common::Result; /// A `Map` / `PriorityQueue` combo that evicts the worst values after reaching `capacity` @@ -109,10 +109,8 @@ impl PriorityMap { mod tests { use super::*; use arrow::array::{Int64Array, RecordBatch, StringArray}; + use arrow::datatypes::{Field, Schema, SchemaRef}; use arrow::util::pretty::pretty_format_batches; - use arrow_schema::Field; - use arrow_schema::Schema; - use arrow_schema::SchemaRef; use std::sync::Arc; #[test] diff --git a/datafusion/physical-plan/src/aggregates/topk_stream.rs b/datafusion/physical-plan/src/aggregates/topk_stream.rs index 8a984fc0d27b..bf02692486cc 100644 --- a/datafusion/physical-plan/src/aggregates/topk_stream.rs +++ b/datafusion/physical-plan/src/aggregates/topk_stream.rs @@ -24,8 +24,8 @@ use crate::aggregates::{ }; use crate::{RecordBatchStream, SendableRecordBatchStream}; use arrow::array::{Array, ArrayRef, RecordBatch}; +use arrow::datatypes::SchemaRef; use arrow::util::pretty::print_batches; -use arrow_schema::SchemaRef; use datafusion_common::DataFusionError; use datafusion_common::Result; use datafusion_execution::TaskContext; diff --git a/datafusion/physical-plan/src/coalesce/mod.rs b/datafusion/physical-plan/src/coalesce/mod.rs index ed60a350e300..eb4a7d875c95 100644 --- a/datafusion/physical-plan/src/coalesce/mod.rs +++ b/datafusion/physical-plan/src/coalesce/mod.rs @@ -20,7 +20,7 @@ use 
arrow::array::{ RecordBatchOptions, }; use arrow::compute::concat_batches; -use arrow_schema::SchemaRef; +use arrow::datatypes::SchemaRef; use std::sync::Arc; /// Concatenate multiple [`RecordBatch`]es diff --git a/datafusion/physical-plan/src/common.rs b/datafusion/physical-plan/src/common.rs index 541f8bcae122..b83641acf2ce 100644 --- a/datafusion/physical-plan/src/common.rs +++ b/datafusion/physical-plan/src/common.rs @@ -253,7 +253,7 @@ impl IPCWriter { /// Checks if the given projection is valid for the given schema. pub fn can_project( - schema: &arrow_schema::SchemaRef, + schema: &arrow::datatypes::SchemaRef, projection: Option<&Vec>, ) -> Result<()> { match projection { @@ -263,7 +263,7 @@ pub fn can_project( .max() .is_some_and(|&i| i >= schema.fields().len()) { - Err(arrow_schema::ArrowError::SchemaError(format!( + Err(arrow::error::ArrowError::SchemaError(format!( "project index {} out of bounds, max field {}", columns.iter().max().unwrap(), schema.fields().len() diff --git a/datafusion/physical-plan/src/display.rs b/datafusion/physical-plan/src/display.rs index dbf82eee05eb..0cc1cb02438a 100644 --- a/datafusion/physical-plan/src/display.rs +++ b/datafusion/physical-plan/src/display.rs @@ -21,7 +21,7 @@ use std::fmt; use std::fmt::Formatter; -use arrow_schema::SchemaRef; +use arrow::datatypes::SchemaRef; use datafusion_common::display::{GraphvizBuilder, PlanType, StringifiedPlan}; use datafusion_expr::display_schema; @@ -44,7 +44,7 @@ pub enum DisplayFormatType { /// # Example /// ``` /// # use std::sync::Arc; -/// # use arrow_schema::{Field, Schema, DataType}; +/// # use arrow::datatypes::{Field, Schema, DataType}; /// # use datafusion_expr::Operator; /// # use datafusion_physical_expr::expressions::{binary, col, lit}; /// # use datafusion_physical_plan::{displayable, ExecutionPlan}; diff --git a/datafusion/physical-plan/src/execution_plan.rs b/datafusion/physical-plan/src/execution_plan.rs index 2c9dde65aeba..81d39e56e713 100644 --- a/datafusion/physical-plan/src/execution_plan.rs +++ b/datafusion/physical-plan/src/execution_plan.rs @@ -285,7 +285,7 @@ pub trait ExecutionPlan: Debug + DisplayAs + Send + Sync { /// ``` /// # use std::sync::Arc; /// # use arrow::array::RecordBatch; - /// # use arrow_schema::SchemaRef; + /// # use arrow::datatypes::SchemaRef; /// # use datafusion_common::Result; /// # use datafusion_execution::{SendableRecordBatchStream, TaskContext}; /// # use datafusion_physical_plan::memory::MemoryStream; @@ -315,7 +315,7 @@ pub trait ExecutionPlan: Debug + DisplayAs + Send + Sync { /// ``` /// # use std::sync::Arc; /// # use arrow::array::RecordBatch; - /// # use arrow_schema::SchemaRef; + /// # use arrow::datatypes::SchemaRef; /// # use datafusion_common::Result; /// # use datafusion_execution::{SendableRecordBatchStream, TaskContext}; /// # use datafusion_physical_plan::memory::MemoryStream; @@ -350,7 +350,7 @@ pub trait ExecutionPlan: Debug + DisplayAs + Send + Sync { /// ``` /// # use std::sync::Arc; /// # use arrow::array::RecordBatch; - /// # use arrow_schema::SchemaRef; + /// # use arrow::datatypes::SchemaRef; /// # use futures::TryStreamExt; /// # use datafusion_common::Result; /// # use datafusion_execution::{SendableRecordBatchStream, TaskContext}; @@ -1116,7 +1116,7 @@ impl RequiredInputOrdering { mod tests { use super::*; use arrow::array::{DictionaryArray, Int32Array, NullArray, RunArray}; - use arrow_schema::{DataType, Field, Schema, SchemaRef}; + use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use std::any::Any; use std::sync::Arc; 
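The removal of the `extract_ok!` macro in row_hash.rs above relies on the fact that `?` already performs the same conversion: the standard library implements `Try`/`FromResidual` for `Poll<Option<Result<T, E>>>`, so an error propagated with `?` inside `poll_next`-style code becomes `Poll::Ready(Some(Err(e)))`. A minimal standalone sketch with illustrative names (not DataFusion code):

use std::task::Poll;

fn parse_value(input: &str) -> Result<i32, String> {
    input.parse::<i32>().map_err(|e| e.to_string())
}

// Shaped like `Stream::poll_next`: `?` on the inner `Result` early-returns
// `Poll::Ready(Some(Err(..)))`, which is exactly what `extract_ok!` spelled out by hand.
fn poll_next_like(input: &str) -> Poll<Option<Result<i32, String>>> {
    let value = parse_value(input)?;
    Poll::Ready(Some(Ok(value * 2)))
}

fn main() {
    assert_eq!(poll_next_like("21"), Poll::Ready(Some(Ok(42))));
    assert!(matches!(poll_next_like("oops"), Poll::Ready(Some(Err(_)))));
    println!("`?` propagates errors the same way the removed macro did");
}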
diff --git a/datafusion/physical-plan/src/filter.rs b/datafusion/physical-plan/src/filter.rs index 91c44a4139d2..5866f0938e41 100644 --- a/datafusion/physical-plan/src/filter.rs +++ b/datafusion/physical-plan/src/filter.rs @@ -625,8 +625,7 @@ mod tests { use crate::test; use crate::test::exec::StatisticsExec; - use arrow::datatypes::{Field, Schema}; - use arrow_schema::{UnionFields, UnionMode}; + use arrow::datatypes::{Field, Schema, UnionFields, UnionMode}; use datafusion_common::ScalarValue; #[tokio::test] diff --git a/datafusion/physical-plan/src/insert.rs b/datafusion/physical-plan/src/insert.rs index 151a4ef7a02a..719e9d9f2c07 100644 --- a/datafusion/physical-plan/src/insert.rs +++ b/datafusion/physical-plan/src/insert.rs @@ -32,8 +32,7 @@ use crate::stream::RecordBatchStreamAdapter; use crate::ExecutionPlanProperties; use arrow::array::{ArrayRef, RecordBatch, UInt64Array}; -use arrow::datatypes::SchemaRef; -use arrow_schema::{DataType, Field, Schema}; +use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use datafusion_common::{internal_err, Result}; use datafusion_execution::TaskContext; use datafusion_physical_expr::{Distribution, EquivalenceProperties}; diff --git a/datafusion/physical-plan/src/joins/hash_join.rs b/datafusion/physical-plan/src/joins/hash_join.rs index 6cdd91bb1721..2983478ada74 100644 --- a/datafusion/physical-plan/src/joins/hash_join.rs +++ b/datafusion/physical-plan/src/joins/hash_join.rs @@ -64,9 +64,9 @@ use arrow::array::{ use arrow::compute::kernels::cmp::{eq, not_distinct}; use arrow::compute::{and, concat_batches, take, FilterBuilder}; use arrow::datatypes::{Schema, SchemaRef}; +use arrow::error::ArrowError; use arrow::record_batch::RecordBatch; use arrow::util::bit_util; -use arrow_schema::ArrowError; use datafusion_common::utils::memory::estimate_memory_size; use datafusion_common::{ internal_datafusion_err, internal_err, plan_err, project_schema, DataFusionError, diff --git a/datafusion/physical-plan/src/joins/join_filter.rs b/datafusion/physical-plan/src/joins/join_filter.rs index cfc7ad2c10e0..0e46a971d90b 100644 --- a/datafusion/physical-plan/src/joins/join_filter.rs +++ b/datafusion/physical-plan/src/joins/join_filter.rs @@ -16,7 +16,7 @@ // under the License. 
use crate::joins::utils::ColumnIndex; -use arrow_schema::SchemaRef; +use arrow::datatypes::SchemaRef; use datafusion_common::JoinSide; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; use std::sync::Arc; diff --git a/datafusion/physical-plan/src/joins/nested_loop_join.rs b/datafusion/physical-plan/src/joins/nested_loop_join.rs index 07289d861bcf..6de6b3b4dff4 100644 --- a/datafusion/physical-plan/src/joins/nested_loop_join.rs +++ b/datafusion/physical-plan/src/joins/nested_loop_join.rs @@ -1037,8 +1037,8 @@ pub(crate) mod tests { }; use arrow::array::Int32Array; + use arrow::compute::SortOptions; use arrow::datatypes::{DataType, Field}; - use arrow_schema::SortOptions; use datafusion_common::{assert_batches_sorted_eq, assert_contains, ScalarValue}; use datafusion_execution::runtime_env::RuntimeEnvBuilder; use datafusion_expr::Operator; diff --git a/datafusion/physical-plan/src/joins/stream_join_utils.rs b/datafusion/physical-plan/src/joins/stream_join_utils.rs index a3b3a37aa7ef..61a71315846c 100644 --- a/datafusion/physical-plan/src/joins/stream_join_utils.rs +++ b/datafusion/physical-plan/src/joins/stream_join_utils.rs @@ -30,8 +30,7 @@ use arrow::array::{ ArrowPrimitiveType, BooleanBufferBuilder, NativeAdapter, PrimitiveArray, RecordBatch, }; use arrow::compute::concat_batches; -use arrow::datatypes::ArrowNativeType; -use arrow_schema::{Schema, SchemaRef}; +use arrow::datatypes::{ArrowNativeType, Schema, SchemaRef}; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::{ arrow_datafusion_err, DataFusionError, HashSet, JoinSide, Result, ScalarValue, diff --git a/datafusion/physical-plan/src/joins/test_utils.rs b/datafusion/physical-plan/src/joins/test_utils.rs index 9e34c27ee7f4..9932c647be0a 100644 --- a/datafusion/physical-plan/src/joins/test_utils.rs +++ b/datafusion/physical-plan/src/joins/test_utils.rs @@ -32,8 +32,8 @@ use arrow::array::{ types::IntervalDayTime, ArrayRef, Float64Array, Int32Array, IntervalDayTimeArray, RecordBatch, TimestampMillisecondArray, }; +use arrow::datatypes::{DataType, Schema}; use arrow::util::pretty::pretty_format_batches; -use arrow_schema::{DataType, Schema}; use datafusion_common::{Result, ScalarValue}; use datafusion_execution::TaskContext; use datafusion_expr::{JoinType, Operator}; diff --git a/datafusion/physical-plan/src/joins/utils.rs b/datafusion/physical-plan/src/joins/utils.rs index bccfd2a69383..00edda8fa3a8 100644 --- a/datafusion/physical-plan/src/joins/utils.rs +++ b/datafusion/physical-plan/src/joins/utils.rs @@ -624,7 +624,7 @@ pub fn build_join_schema( JoinType::LeftSemi | JoinType::LeftAnti => left_fields().unzip(), JoinType::LeftMark => { let right_field = once(( - Field::new("mark", arrow_schema::DataType::Boolean, false), + Field::new("mark", arrow::datatypes::DataType::Boolean, false), ColumnIndex { index: 0, side: JoinSide::None, @@ -1822,9 +1822,9 @@ mod tests { use std::pin::Pin; use arrow::array::Int32Array; + use arrow::compute::SortOptions; use arrow::datatypes::{DataType, Fields}; use arrow::error::{ArrowError, Result as ArrowResult}; - use arrow_schema::SortOptions; use datafusion_common::stats::Precision::{Absent, Exact, Inexact}; use datafusion_common::{arrow_datafusion_err, arrow_err, ScalarValue}; diff --git a/datafusion/physical-plan/src/limit.rs b/datafusion/physical-plan/src/limit.rs index 15f19f6456a5..f720294c7ad9 100644 --- a/datafusion/physical-plan/src/limit.rs +++ b/datafusion/physical-plan/src/limit.rs @@ -481,7 +481,7 @@ mod tests { use 
crate::aggregates::{AggregateExec, AggregateMode, PhysicalGroupBy}; use arrow::array::RecordBatchOptions; - use arrow_schema::Schema; + use arrow::datatypes::Schema; use datafusion_common::stats::Precision; use datafusion_physical_expr::expressions::col; use datafusion_physical_expr::PhysicalExpr; diff --git a/datafusion/physical-plan/src/memory.rs b/datafusion/physical-plan/src/memory.rs index 3d161a047853..0077804bdfc9 100644 --- a/datafusion/physical-plan/src/memory.rs +++ b/datafusion/physical-plan/src/memory.rs @@ -34,8 +34,7 @@ use crate::projection::{ use crate::source::{DataSource, DataSourceExec}; use arrow::array::{RecordBatch, RecordBatchOptions}; -use arrow::datatypes::SchemaRef; -use arrow_schema::Schema; +use arrow::datatypes::{Schema, SchemaRef}; use datafusion_common::{ internal_err, plan_err, project_schema, Constraints, Result, ScalarValue, }; @@ -971,7 +970,8 @@ mod memory_exec_tests { use crate::source::DataSourceExec; use crate::ExecutionPlan; - use arrow_schema::{DataType, Field, Schema, SortOptions}; + use arrow::compute::SortOptions; + use arrow::datatypes::{DataType, Field, Schema}; use datafusion_physical_expr::expressions::col; use datafusion_physical_expr::PhysicalSortExpr; use datafusion_physical_expr_common::sort_expr::LexOrdering; @@ -1144,7 +1144,7 @@ mod tests { use crate::expressions::lit; use crate::test::{self, make_partition}; - use arrow_schema::{DataType, Field}; + use arrow::datatypes::{DataType, Field}; use datafusion_common::assert_batches_eq; use datafusion_common::stats::{ColumnStatistics, Precision}; use futures::StreamExt; diff --git a/datafusion/physical-plan/src/projection.rs b/datafusion/physical-plan/src/projection.rs index 3ebfd8f8ca80..08c4d24f4c7f 100644 --- a/datafusion/physical-plan/src/projection.rs +++ b/datafusion/physical-plan/src/projection.rs @@ -1003,7 +1003,7 @@ mod tests { use crate::common::collect; use crate::test; - use arrow_schema::DataType; + use arrow::datatypes::DataType; use datafusion_common::ScalarValue; use datafusion_expr::Operator; diff --git a/datafusion/physical-plan/src/repartition/mod.rs b/datafusion/physical-plan/src/repartition/mod.rs index ffd1a5b520fa..25668fa67d5b 100644 --- a/datafusion/physical-plan/src/repartition/mod.rs +++ b/datafusion/physical-plan/src/repartition/mod.rs @@ -1600,7 +1600,8 @@ mod tests { #[cfg(test)] mod test { - use arrow_schema::{DataType, Field, Schema, SortOptions}; + use arrow::compute::SortOptions; + use arrow::datatypes::{DataType, Field, Schema}; use super::*; use crate::memory::MemorySourceConfig; diff --git a/datafusion/physical-plan/src/sorts/cursor.rs b/datafusion/physical-plan/src/sorts/cursor.rs index e6986b86046c..8ea7c43d2613 100644 --- a/datafusion/physical-plan/src/sorts/cursor.rs +++ b/datafusion/physical-plan/src/sorts/cursor.rs @@ -291,6 +291,10 @@ pub struct ArrayValues { // Otherwise, the first null index null_threshold: usize, options: SortOptions, + + /// Tracks the memory used by the values array, + /// freed on drop. + _reservation: MemoryReservation, } impl ArrayValues { @@ -298,7 +302,11 @@ impl ArrayValues { /// to `options`. 
/// /// Panics if the array is empty - pub fn new>(options: SortOptions, array: &A) -> Self { + pub fn new>( + options: SortOptions, + array: &A, + reservation: MemoryReservation, + ) -> Self { assert!(array.len() > 0, "Empty array passed to FieldCursor"); let null_threshold = match options.nulls_first { true => array.null_count(), @@ -309,6 +317,7 @@ impl ArrayValues { values: array.values(), null_threshold, options, + _reservation: reservation, } } @@ -360,6 +369,12 @@ impl CursorValues for ArrayValues { #[cfg(test)] mod tests { + use std::sync::Arc; + + use datafusion_execution::memory_pool::{ + GreedyMemoryPool, MemoryConsumer, MemoryPool, + }; + use super::*; fn new_primitive( @@ -372,10 +387,15 @@ mod tests { false => values.len() - null_count, }; + let memory_pool: Arc = Arc::new(GreedyMemoryPool::new(10000)); + let consumer = MemoryConsumer::new("test"); + let reservation = consumer.register(&memory_pool); + let values = ArrayValues { values: PrimitiveValues(values), null_threshold, options, + _reservation: reservation, }; Cursor::new(values) diff --git a/datafusion/physical-plan/src/sorts/sort.rs b/datafusion/physical-plan/src/sorts/sort.rs index 6c538801d71a..649468260e56 100644 --- a/datafusion/physical-plan/src/sorts/sort.rs +++ b/datafusion/physical-plan/src/sorts/sort.rs @@ -24,7 +24,7 @@ use std::fmt; use std::fmt::{Debug, Formatter}; use std::sync::Arc; -use crate::common::spawn_buffered; +use crate::common::{spawn_buffered, IPCWriter}; use crate::execution_plan::{Boundedness, CardinalityEffect, EmissionType}; use crate::expressions::PhysicalSortExpr; use crate::limit::LimitStream; @@ -46,9 +46,8 @@ use crate::{ use arrow::array::{Array, RecordBatch, RecordBatchOptions, UInt32Array}; use arrow::compute::{concat_batches, lexsort_to_indices, take_arrays, SortColumn}; -use arrow::datatypes::SchemaRef; +use arrow::datatypes::{DataType, SchemaRef}; use arrow::row::{RowConverter, SortField}; -use arrow_schema::DataType; use datafusion_common::{internal_err, Result}; use datafusion_execution::disk_manager::RefCountedTempFile; use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation}; @@ -280,7 +279,7 @@ impl ExternalSorter { Self { schema, in_mem_batches: vec![], - in_mem_batches_sorted: true, + in_mem_batches_sorted: false, spills: vec![], expr: expr.into(), metrics, @@ -303,27 +302,13 @@ impl ExternalSorter { } self.reserve_memory_for_merge()?; - let size = get_record_batch_memory_size(&input); - + let size = get_reserved_byte_for_record_batch(&input); if self.reservation.try_grow(size).is_err() { - let before = self.reservation.size(); - self.in_mem_sort().await?; - - // Sorting may have freed memory, especially if fetch is `Some` - // - // As such we check again, and if the memory usage has dropped by - // a factor of 2, and we can allocate the necessary capacity, - // we don't spill - // - // The factor of 2 aims to avoid a degenerate case where the - // memory required for `fetch` is just under the memory available, - // causing repeated re-sorting of data - if self.reservation.size() > before / 2 - || self.reservation.try_grow(size).is_err() - { - self.spill().await?; - self.reservation.try_grow(size)? - } + self.sort_or_spill_in_mem_batches().await?; + // We've already freed more than half of reserved memory, + // so we can grow the reservation again. There's nothing we can do + // if this try_grow fails. + self.reservation.try_grow(size)?; } self.in_mem_batches.push(input); @@ -345,6 +330,11 @@ impl ExternalSorter { /// 2. 
A combined streaming merge incorporating both in-memory /// batches and data from spill files on disk. fn sort(&mut self) -> Result { + // Release the memory reserved for merge back to the pool so + // there is some left when `in_mem_sort_stream` requests an + // allocation. + self.merge_reservation.free(); + if self.spilled_before() { let mut streams = vec![]; if !self.in_mem_batches.is_empty() { @@ -370,7 +360,7 @@ impl ExternalSorter { .with_metrics(self.metrics.baseline.clone()) .with_batch_size(self.batch_size) .with_fetch(self.fetch) - .with_reservation(self.reservation.new_empty()) + .with_reservation(self.merge_reservation.new_empty()) .build() } else { self.in_mem_sort_stream(self.metrics.baseline.clone()) @@ -409,50 +399,102 @@ impl ExternalSorter { debug!("Spilling sort data of ExternalSorter to disk whilst inserting"); - self.in_mem_sort().await?; - let spill_file = self.runtime.disk_manager.create_tmp_file("Sorting")?; let batches = std::mem::take(&mut self.in_mem_batches); - let spilled_rows = spill_record_batches( + let (spilled_rows, spilled_bytes) = spill_record_batches( batches, spill_file.path().into(), Arc::clone(&self.schema), )?; let used = self.reservation.free(); self.metrics.spill_count.add(1); - self.metrics.spilled_bytes.add(used); + self.metrics.spilled_bytes.add(spilled_bytes); self.metrics.spilled_rows.add(spilled_rows); self.spills.push(spill_file); Ok(used) } /// Sorts the in_mem_batches in place - async fn in_mem_sort(&mut self) -> Result<()> { - if self.in_mem_batches_sorted { - return Ok(()); - } - + /// + /// Sorting may have freed memory, especially if fetch is `Some`. If + /// the memory usage has dropped by a factor of 2, then we don't have + /// to spill. Otherwise, we spill to free up memory for inserting + /// more batches. + /// + /// The factor of 2 aims to avoid a degenerate case where the + /// memory required for `fetch` is just under the memory available, + // causing repeated re-sorting of data + async fn sort_or_spill_in_mem_batches(&mut self) -> Result<()> { // Release the memory reserved for merge back to the pool so - // there is some left when `in_memo_sort_stream` requests an - // allocation. + // there is some left when `in_mem_sort_stream` requests an + // allocation. At the end of this function, memory will be + // reserved again for the next spill. self.merge_reservation.free(); - self.in_mem_batches = self - .in_mem_sort_stream(self.metrics.baseline.intermediate())? - .try_collect() - .await?; + let before = self.reservation.size(); + + let mut sorted_stream = + self.in_mem_sort_stream(self.metrics.baseline.intermediate())?; + + // `self.in_mem_batches` is already taken away by the sort_stream, now it is empty. + // We'll gradually collect the sorted stream into self.in_mem_batches, or directly + // write sorted batches to disk when the memory is insufficient. + let mut spill_writer: Option = None; + while let Some(batch) = sorted_stream.next().await { + let batch = batch?; + match &mut spill_writer { + None => { + let sorted_size = get_reserved_byte_for_record_batch(&batch); + if self.reservation.try_grow(sorted_size).is_err() { + // Directly write in_mem_batches as well as all the remaining batches in + // sorted_stream to disk. Further batches fetched from `sorted_stream` will + // be handled by the `Some(writer)` matching arm. 
+ let spill_file = + self.runtime.disk_manager.create_tmp_file("Sorting")?; + let mut writer = IPCWriter::new(spill_file.path(), &self.schema)?; + // Flush everything in memory to the spill file + for batch in self.in_mem_batches.drain(..) { + writer.write(&batch)?; + } + // as well as the newly sorted batch + writer.write(&batch)?; + spill_writer = Some(writer); + self.reservation.free(); + self.spills.push(spill_file); + } else { + self.in_mem_batches.push(batch); + self.in_mem_batches_sorted = true; + } + } + Some(writer) => { + writer.write(&batch)?; + } + } + } + + // Drop early to free up memory reserved by the sorted stream, otherwise the + // upcoming `self.reserve_memory_for_merge()` may fail due to insufficient memory. + drop(sorted_stream); - let size: usize = self - .in_mem_batches - .iter() - .map(get_record_batch_memory_size) - .sum(); + if let Some(writer) = &mut spill_writer { + writer.finish()?; + self.metrics.spill_count.add(1); + self.metrics.spilled_rows.add(writer.num_rows); + self.metrics.spilled_bytes.add(writer.num_bytes); + } + + // Sorting may free up some memory especially when fetch is `Some`. If we have + // not freed more than 50% of the memory, then we have to spill to free up more + // memory for inserting more batches. + if spill_writer.is_none() && self.reservation.size() > before / 2 { + // We have not freed more than 50% of the memory, so we have to spill to + // free up more memory + self.spill().await?; + } // Reserve headroom for next sort/merge self.reserve_memory_for_merge()?; - self.reservation.try_resize(size)?; - self.in_mem_batches_sorted = true; Ok(()) } @@ -529,6 +571,12 @@ impl ExternalSorter { let elapsed_compute = metrics.elapsed_compute().clone(); let _timer = elapsed_compute.timer(); + // Please pay attention that any operation inside of `in_mem_sort_stream` will + // not perform any memory reservation. This is for avoiding the need of handling + // reservation failure and spilling in the middle of the sort/merge. The memory + // space for batches produced by the resulting stream will be reserved by the + // consumer of the stream. + if self.in_mem_batches.len() == 1 { let batch = self.in_mem_batches.swap_remove(0); let reservation = self.reservation.take(); @@ -541,7 +589,7 @@ impl ExternalSorter { let batch = concat_batches(&self.schema, &self.in_mem_batches)?; self.in_mem_batches.clear(); self.reservation - .try_resize(get_record_batch_memory_size(&batch))?; + .try_resize(get_reserved_byte_for_record_batch(&batch))?; let reservation = self.reservation.take(); return self.sort_batch_stream(batch, metrics, reservation); } @@ -550,8 +598,9 @@ impl ExternalSorter { .into_iter() .map(|batch| { let metrics = self.metrics.baseline.intermediate(); - let reservation = - self.reservation.split(get_record_batch_memory_size(&batch)); + let reservation = self + .reservation + .split(get_reserved_byte_for_record_batch(&batch)); let input = self.sort_batch_stream(batch, metrics, reservation)?; Ok(spawn_buffered(input, 1)) }) @@ -580,7 +629,10 @@ impl ExternalSorter { metrics: BaselineMetrics, reservation: MemoryReservation, ) -> Result { - assert_eq!(get_record_batch_memory_size(&batch), reservation.size()); + assert_eq!( + get_reserved_byte_for_record_batch(&batch), + reservation.size() + ); let schema = batch.schema(); let fetch = self.fetch; @@ -613,6 +665,20 @@ impl ExternalSorter { } } +/// Estimate how much memory is needed to sort a `RecordBatch`. +/// +/// This is used to pre-reserve memory for the sort/merge. 
The sort/merge process involves +/// creating sorted copies of sorted columns in record batches for speeding up comparison +/// in sorting and merging. The sorted copies are in either row format or array format. +/// Please refer to cursor.rs and stream.rs for more details. No matter what format the +/// sorted copies are, they will use more memory than the original record batch. +fn get_reserved_byte_for_record_batch(batch: &RecordBatch) -> usize { + // 2x may not be enough for some cases, but it's a good start. + // If 2x is not enough, user can set a larger value for `sort_spill_reservation_bytes` + // to compensate for the extra memory needed. + get_record_batch_memory_size(batch) * 2 +} + impl Debug for ExternalSorter { fn fmt(&self, f: &mut Formatter) -> fmt::Result { f.debug_struct("ExternalSorter") @@ -642,7 +708,15 @@ pub fn sort_batch( lexsort_to_indices(&sort_columns, fetch)? }; - let columns = take_arrays(batch.columns(), &indices, None)?; + let mut columns = take_arrays(batch.columns(), &indices, None)?; + + // The columns may be larger than the unsorted columns in `batch` especially for variable length + // data types due to exponential growth when building the sort columns. We shrink the columns + // to prevent memory reservation failures, as well as excessive memory allocation when running + // merges in `SortPreservingMergeStream`. + columns.iter_mut().for_each(|c| { + c.shrink_to_fit(); + }); let options = RecordBatchOptions::new().with_row_count(Some(indices.len())); Ok(RecordBatch::try_new_with_options( @@ -1247,6 +1321,9 @@ mod tests { .with_runtime(runtime), ); + // The input has 100 partitions, each partition has a batch containing 100 rows. + // Each row has a single Int32 column with values 0..100. The total size of the + // input is roughly 40000 bytes. let partitions = 100; let input = test::scan_partitioned(partitions); let schema = input.schema(); @@ -1272,9 +1349,16 @@ mod tests { assert_eq!(metrics.output_rows().unwrap(), 10000); assert!(metrics.elapsed_compute().unwrap() > 0); - assert_eq!(metrics.spill_count().unwrap(), 3); - assert_eq!(metrics.spilled_bytes().unwrap(), 36000); - assert_eq!(metrics.spilled_rows().unwrap(), 9000); + + let spill_count = metrics.spill_count().unwrap(); + let spilled_rows = metrics.spilled_rows().unwrap(); + let spilled_bytes = metrics.spilled_bytes().unwrap(); + // Processing 40000 bytes of data using 12288 bytes of memory requires 3 spills + // unless we do something really clever. It will spill roughly 9000+ rows and 36000 + // bytes. We leave a little wiggle room for the actual numbers. + assert!((3..=10).contains(&spill_count)); + assert!((9000..=10000).contains(&spilled_rows)); + assert!((36000..=40000).contains(&spilled_bytes)); let columns = result[0].columns(); @@ -1291,6 +1375,77 @@ mod tests { Ok(()) } + #[tokio::test] + async fn test_sort_spill_utf8_strings() -> Result<()> { + let session_config = SessionConfig::new() + .with_batch_size(100) + .with_sort_in_place_threshold_bytes(20 * 1024) + .with_sort_spill_reservation_bytes(100 * 1024); + let runtime = RuntimeEnvBuilder::new() + .with_memory_limit(500 * 1024, 1.0) + .build_arc()?; + let task_ctx = Arc::new( + TaskContext::default() + .with_session_config(session_config) + .with_runtime(runtime), + ); + + // The input has 200 partitions, each partition has a batch containing 100 rows. + // Each row has a single Utf8 column, the Utf8 string values are roughly 42 bytes. + // The total size of the input is roughly 8.4 KB. 
+ let input = test::scan_partitioned_utf8(200); + let schema = input.schema(); + + let sort_exec = Arc::new(SortExec::new( + LexOrdering::new(vec![PhysicalSortExpr { + expr: col("i", &schema)?, + options: SortOptions::default(), + }]), + Arc::new(CoalescePartitionsExec::new(input)), + )); + + let result = collect( + Arc::clone(&sort_exec) as Arc, + Arc::clone(&task_ctx), + ) + .await?; + + let num_rows = result.iter().map(|batch| batch.num_rows()).sum::(); + assert_eq!(num_rows, 20000); + + // Now, validate metrics + let metrics = sort_exec.metrics().unwrap(); + + assert_eq!(metrics.output_rows().unwrap(), 20000); + assert!(metrics.elapsed_compute().unwrap() > 0); + + let spill_count = metrics.spill_count().unwrap(); + let spilled_rows = metrics.spilled_rows().unwrap(); + let spilled_bytes = metrics.spilled_bytes().unwrap(); + // Processing 840 KB of data using 400 KB of memory requires at least 2 spills + // It will spill roughly 18000 rows and 800 KBytes. + // We leave a little wiggle room for the actual numbers. + assert!((2..=10).contains(&spill_count)); + assert!((15000..=20000).contains(&spilled_rows)); + assert!((700000..=900000).contains(&spilled_bytes)); + + // Verify that the result is sorted + let concated_result = concat_batches(&schema, &result)?; + let columns = concated_result.columns(); + let string_array = as_string_array(&columns[0]); + for i in 0..string_array.len() - 1 { + assert!(string_array.value(i) <= string_array.value(i + 1)); + } + + assert_eq!( + task_ctx.runtime_env().memory_pool.reserved(), + 0, + "The sort should have returned all memory used back to the memory manager" + ); + + Ok(()) + } + #[tokio::test] async fn test_sort_fetch_memory_calculation() -> Result<()> { // This test mirrors down the size from the example above. 
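To make the reservation-driven spill decision above concrete, here is a minimal sketch assuming the arrow and datafusion-execution crates; the pool size, the 2x growth factor, and the bookkeeping are illustrative rather than the exact values and code paths the sorter uses.

use std::sync::Arc;

use arrow::array::{ArrayRef, Int32Array, RecordBatch};
use arrow::datatypes::{DataType, Field, Schema};
use datafusion_execution::memory_pool::{GreedyMemoryPool, MemoryConsumer, MemoryPool};

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Small pool so the sketch actually hits the spill path.
    let pool: Arc<dyn MemoryPool> = Arc::new(GreedyMemoryPool::new(8 * 1024));
    let mut reservation = MemoryConsumer::new("sort-sketch").register(&pool);

    let schema = Arc::new(Schema::new(vec![Field::new("i", DataType::Int32, false)]));
    let mut buffered: Vec<RecordBatch> = vec![];
    let mut spill_events = 0usize;

    for start in (0..10_000).step_by(1_000) {
        let array = Int32Array::from_iter_values(start..start + 1_000);
        let batch = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![Arc::new(array) as ArrayRef],
        )?;

        // Mirror the idea of `get_reserved_byte_for_record_batch`: sorting needs
        // room for sorted copies, so ask for roughly 2x the batch's buffer size.
        let needed = batch.get_array_memory_size() * 2;
        if reservation.try_grow(needed).is_err() {
            // In the real sorter this is where buffered batches are sorted and
            // written to a spill file; here we only count the event.
            buffered.clear();
            reservation.free();
            spill_events += 1;
        } else {
            buffered.push(batch);
        }
    }

    println!("{} batches buffered, {} spill events", buffered.len(), spill_events);
    Ok(())
}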
diff --git a/datafusion/physical-plan/src/sorts/sort_preserving_merge.rs b/datafusion/physical-plan/src/sorts/sort_preserving_merge.rs index 03a40bace5fd..22b23e1ed619 100644 --- a/datafusion/physical-plan/src/sorts/sort_preserving_merge.rs +++ b/datafusion/physical-plan/src/sorts/sort_preserving_merge.rs @@ -399,8 +399,7 @@ mod tests { TimestampNanosecondArray, }; use arrow::compute::SortOptions; - use arrow::datatypes::{DataType, Field, Schema}; - use arrow_schema::SchemaRef; + use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use datafusion_common::{assert_batches_eq, assert_contains, DataFusionError}; use datafusion_common_runtime::SpawnedTask; use datafusion_execution::config::SessionConfig; diff --git a/datafusion/physical-plan/src/sorts/stream.rs b/datafusion/physical-plan/src/sorts/stream.rs index ab8054be59a8..e029c60b285b 100644 --- a/datafusion/physical-plan/src/sorts/stream.rs +++ b/datafusion/physical-plan/src/sorts/stream.rs @@ -159,6 +159,8 @@ pub struct FieldCursorStream { sort: PhysicalSortExpr, /// Input streams streams: FusedStreams, + /// Create new reservations for each array + reservation: MemoryReservation, phantom: PhantomData T>, } @@ -171,11 +173,16 @@ impl std::fmt::Debug for FieldCursorStream { } impl FieldCursorStream { - pub fn new(sort: PhysicalSortExpr, streams: Vec) -> Self { + pub fn new( + sort: PhysicalSortExpr, + streams: Vec, + reservation: MemoryReservation, + ) -> Self { let streams = streams.into_iter().map(|s| s.fuse()).collect(); Self { sort, streams: FusedStreams(streams), + reservation, phantom: Default::default(), } } @@ -183,8 +190,15 @@ impl FieldCursorStream { fn convert_batch(&mut self, batch: &RecordBatch) -> Result> { let value = self.sort.expr.evaluate(batch)?; let array = value.into_array(batch.num_rows())?; + let size_in_mem = array.get_buffer_memory_size(); let array = array.as_any().downcast_ref::().expect("field values"); - Ok(ArrayValues::new(self.sort.options, array)) + let mut array_reservation = self.reservation.new_empty(); + array_reservation.try_grow(size_in_mem)?; + Ok(ArrayValues::new( + self.sort.options, + array, + array_reservation, + )) } } diff --git a/datafusion/physical-plan/src/sorts/streaming_merge.rs b/datafusion/physical-plan/src/sorts/streaming_merge.rs index 909b5875c8c5..a541f79dc717 100644 --- a/datafusion/physical-plan/src/sorts/streaming_merge.rs +++ b/datafusion/physical-plan/src/sorts/streaming_merge.rs @@ -38,7 +38,8 @@ macro_rules! primitive_merge_helper { macro_rules! 
merge_helper { ($t:ty, $sort:ident, $streams:ident, $schema:ident, $tracking_metrics:ident, $batch_size:ident, $fetch:ident, $reservation:ident, $enable_round_robin_tie_breaker:ident) => {{ - let streams = FieldCursorStream::<$t>::new($sort, $streams); + let streams = + FieldCursorStream::<$t>::new($sort, $streams, $reservation.new_empty()); return Ok(Box::pin(SortPreservingMergeStream::new( Box::new(streams), $schema, diff --git a/datafusion/physical-plan/src/spill.rs b/datafusion/physical-plan/src/spill.rs index dbcc46baf8ca..b45353ae13f0 100644 --- a/datafusion/physical-plan/src/spill.rs +++ b/datafusion/physical-plan/src/spill.rs @@ -62,7 +62,7 @@ pub(crate) fn spill_record_batches( batches: Vec, path: PathBuf, schema: SchemaRef, -) -> Result { +) -> Result<(usize, usize)> { let mut writer = IPCWriter::new(path.as_ref(), schema.as_ref())?; for batch in batches { writer.write(&batch)?; @@ -74,7 +74,7 @@ pub(crate) fn spill_record_batches( writer.num_rows, human_readable_size(writer.num_bytes), ); - Ok(writer.num_rows) + Ok((writer.num_rows, writer.num_bytes)) } fn read_spill(sender: Sender>, path: &Path) -> Result<()> { @@ -213,12 +213,12 @@ mod tests { let spill_file = disk_manager.create_tmp_file("Test Spill")?; let schema = batch1.schema(); let num_rows = batch1.num_rows() + batch2.num_rows(); - let cnt = spill_record_batches( + let (spilled_rows, _) = spill_record_batches( vec![batch1, batch2], spill_file.path().into(), Arc::clone(&schema), - ); - assert_eq!(cnt.unwrap(), num_rows); + )?; + assert_eq!(spilled_rows, num_rows); let file = BufReader::new(File::open(spill_file.path())?); let reader = FileReader::try_new(file, None)?; diff --git a/datafusion/physical-plan/src/stream.rs b/datafusion/physical-plan/src/stream.rs index 5c941c76ae47..23cbb1ce49c1 100644 --- a/datafusion/physical-plan/src/stream.rs +++ b/datafusion/physical-plan/src/stream.rs @@ -512,7 +512,7 @@ mod test { assert_strong_count_converges_to_zero, BlockingExec, MockExec, PanicExec, }; - use arrow_schema::{DataType, Field, Schema}; + use arrow::datatypes::{DataType, Field, Schema}; use datafusion_common::exec_err; fn schema() -> SchemaRef { diff --git a/datafusion/physical-plan/src/streaming.rs b/datafusion/physical-plan/src/streaming.rs index 751af9921448..8bdfca2a8907 100644 --- a/datafusion/physical-plan/src/streaming.rs +++ b/datafusion/physical-plan/src/streaming.rs @@ -32,8 +32,7 @@ use crate::projection::{ use crate::stream::RecordBatchStreamAdapter; use crate::{ExecutionPlan, Partitioning, SendableRecordBatchStream}; -use arrow::datatypes::SchemaRef; -use arrow_schema::Schema; +use arrow::datatypes::{Schema, SchemaRef}; use datafusion_common::{internal_err, plan_err, Result}; use datafusion_execution::TaskContext; use datafusion_physical_expr::{EquivalenceProperties, LexOrdering, PhysicalSortExpr}; diff --git a/datafusion/physical-plan/src/test.rs b/datafusion/physical-plan/src/test.rs index 757e2df65831..ad0e43503b2b 100644 --- a/datafusion/physical-plan/src/test.rs +++ b/datafusion/physical-plan/src/test.rs @@ -21,7 +21,7 @@ use std::collections::HashMap; use std::pin::Pin; use std::sync::Arc; -use arrow::array::{ArrayRef, Int32Array, RecordBatch}; +use arrow::array::{Array, ArrayRef, Int32Array, RecordBatch}; use arrow_schema::{DataType, Field, Schema, SchemaRef}; use datafusion_execution::{SendableRecordBatchStream, TaskContext}; use futures::{Future, FutureExt}; @@ -132,11 +132,30 @@ pub fn make_partition(sz: i32) -> RecordBatch { RecordBatch::try_new(schema, vec![arr]).unwrap() } +pub fn 
make_partition_utf8(sz: i32) -> RecordBatch { + let seq_start = 0; + let seq_end = sz; + let values = (seq_start..seq_end) + .map(|i| format!("test_long_string_that_is_roughly_42_bytes_{}", i)) + .collect::>(); + let schema = Arc::new(Schema::new(vec![Field::new("i", DataType::Utf8, true)])); + let mut string_array = arrow::array::StringArray::from(values); + string_array.shrink_to_fit(); + let arr = Arc::new(string_array); + let arr = arr as ArrayRef; + + RecordBatch::try_new(schema, vec![arr]).unwrap() +} + /// Returns a `DataSourceExec` that scans `partitions` of 100 batches each pub fn scan_partitioned(partitions: usize) -> Arc { Arc::new(mem_exec(partitions)) } +pub fn scan_partitioned_utf8(partitions: usize) -> Arc { + Arc::new(mem_exec_utf8(partitions)) +} + /// Returns a `DataSourceExec` that scans `partitions` of 100 batches each pub fn mem_exec(partitions: usize) -> DataSourceExec { let data: Vec> = (0..partitions).map(|_| vec![make_partition(100)]).collect(); @@ -148,6 +167,18 @@ pub fn mem_exec(partitions: usize) -> DataSourceExec { )) } +pub fn mem_exec_utf8(partitions: usize) -> DataSourceExec { + let data: Vec> = (0..partitions) + .map(|_| vec![make_partition_utf8(100)]) + .collect(); + + let schema = data[0][0].schema(); + let projection = None; + DataSourceExec::new(Arc::new( + MemorySourceConfig::try_new(&data, schema, projection).unwrap(), + )) +} + // Construct a stream partition for test purposes #[derive(Debug)] pub struct TestPartitionStream { diff --git a/datafusion/physical-plan/src/topk/mod.rs b/datafusion/physical-plan/src/topk/mod.rs index 4cc8fc8711de..85de1eefce2e 100644 --- a/datafusion/physical-plan/src/topk/mod.rs +++ b/datafusion/physical-plan/src/topk/mod.rs @@ -28,7 +28,7 @@ use super::metrics::{BaselineMetrics, Count, ExecutionPlanMetricsSet, MetricBuil use crate::spill::get_record_batch_memory_size; use crate::{stream::RecordBatchStreamAdapter, SendableRecordBatchStream}; use arrow::array::{Array, ArrayRef, RecordBatch}; -use arrow_schema::SchemaRef; +use arrow::datatypes::SchemaRef; use datafusion_common::HashMap; use datafusion_common::Result; use datafusion_execution::{ diff --git a/datafusion/physical-plan/src/tree_node.rs b/datafusion/physical-plan/src/tree_node.rs index 96bd0de3d37c..69b0a165315e 100644 --- a/datafusion/physical-plan/src/tree_node.rs +++ b/datafusion/physical-plan/src/tree_node.rs @@ -39,9 +39,17 @@ impl DynTreeNode for dyn ExecutionPlan { } } -/// A node object beneficial for writing optimizer rules, encapsulating an [`ExecutionPlan`] node with a payload. -/// Since there are two ways to access child plans—directly from the plan and through child nodes—it's recommended +/// A node context object beneficial for writing optimizer rules. +/// This context encapsulates an [`ExecutionPlan`] node with a payload. +/// +/// Since each wrapped node has its children within both the [`PlanContext.plan.children()`], +/// as well as separately within the [`PlanContext.children`] (which are child nodes wrapped in the context), +/// it's important to keep these child plans in sync when performing mutations. +/// +/// Since there are two ways to access child plans directly, it's recommended /// to perform mutable operations via [`Self::update_plan_from_children`]. +/// After mutating the `PlanContext.children`, or after creating the `PlanContext`, +/// call `update_plan_from_children` to sync. #[derive(Debug)] pub struct PlanContext { /// The execution plan associated with this context.
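The doc comment above describes a two-step contract; here is a minimal sketch of an optimizer-rule helper that honors it (illustrative only; `mark_children` and the `bool` payload are made up for this example, not part of the patch):

use datafusion_common::Result;
use datafusion_physical_plan::tree_node::PlanContext;

fn mark_children(mut ctx: PlanContext<bool>) -> Result<PlanContext<bool>> {
    // Step 1: mutate the wrapped child contexts (their payloads and/or plans).
    for child in ctx.children.iter_mut() {
        child.data = true;
    }
    // Step 2: resynchronize `ctx.plan` with the possibly-changed child plans.
    ctx.update_plan_from_children()
}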
@@ -61,6 +69,8 @@ impl PlanContext { } } + /// Update the [`PlanContext.plan.children()`] from the [`PlanContext.children`], + /// if the `PlanContext.children` have been changed. pub fn update_plan_from_children(mut self) -> Result { let children_plans = self.children.iter().map(|c| Arc::clone(&c.plan)).collect(); self.plan = with_new_children_if_necessary(self.plan, children_plans)?; diff --git a/datafusion/physical-plan/src/union.rs b/datafusion/physical-plan/src/union.rs index 7e0f88784644..e1972d267b97 100644 --- a/datafusion/physical-plan/src/union.rs +++ b/datafusion/physical-plan/src/union.rs @@ -644,7 +644,8 @@ mod tests { use crate::test; use crate::source::DataSourceExec; - use arrow_schema::{DataType, SortOptions}; + use arrow::compute::SortOptions; + use arrow::datatypes::DataType; use datafusion_common::ScalarValue; use datafusion_physical_expr::expressions::col; use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr}; diff --git a/datafusion/physical-plan/src/values.rs b/datafusion/physical-plan/src/values.rs index ea1086c0a3d6..6ab5cc84a21f 100644 --- a/datafusion/physical-plan/src/values.rs +++ b/datafusion/physical-plan/src/values.rs @@ -233,7 +233,7 @@ mod tests { use crate::expressions::lit; use crate::test::{self, make_partition}; - use arrow_schema::{DataType, Field}; + use arrow::datatypes::{DataType, Field}; use datafusion_common::stats::{ColumnStatistics, Precision}; #[tokio::test] diff --git a/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs b/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs index 29efa7bf194b..ef880c8f4086 100644 --- a/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs +++ b/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs @@ -1204,7 +1204,8 @@ mod tests { builder::{Int64Builder, UInt64Builder}, RecordBatch, }; - use arrow_schema::{DataType, Field, Schema, SchemaRef, SortOptions}; + use arrow::compute::SortOptions; + use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use datafusion_common::{ assert_batches_eq, exec_datafusion_err, Result, ScalarValue, }; diff --git a/datafusion/physical-plan/src/windows/mod.rs b/datafusion/physical-plan/src/windows/mod.rs index bec7b3dd8005..b930033643bd 100644 --- a/datafusion/physical-plan/src/windows/mod.rs +++ b/datafusion/physical-plan/src/windows/mod.rs @@ -26,8 +26,7 @@ use crate::{ InputOrderMode, PhysicalExpr, }; -use arrow::datatypes::Schema; -use arrow_schema::{DataType, Field, SchemaRef}; +use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use datafusion_common::{exec_err, Result}; use datafusion_expr::{ PartitionEvaluator, ReversedUDWF, WindowFrame, WindowFunctionDefinition, WindowUDF, diff --git a/datafusion/physical-plan/src/windows/utils.rs b/datafusion/physical-plan/src/windows/utils.rs index 13332ea82fa1..be38976b3551 100644 --- a/datafusion/physical-plan/src/windows/utils.rs +++ b/datafusion/physical-plan/src/windows/utils.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. 
-use arrow_schema::{Schema, SchemaBuilder}; +use arrow::datatypes::{Schema, SchemaBuilder}; use datafusion_common::Result; use datafusion_physical_expr::window::WindowExpr; use std::sync::Arc; diff --git a/datafusion/proto-common/gen/Cargo.toml b/datafusion/proto-common/gen/Cargo.toml index 28e45fee9039..cfd3368b0c5e 100644 --- a/datafusion/proto-common/gen/Cargo.toml +++ b/datafusion/proto-common/gen/Cargo.toml @@ -35,4 +35,4 @@ workspace = true [dependencies] # Pin these dependencies so that the generated output is deterministic pbjson-build = "=0.7.0" -prost-build = "=0.13.4" +prost-build = "=0.13.5" diff --git a/datafusion/proto-common/proto/datafusion_common.proto b/datafusion/proto-common/proto/datafusion_common.proto index 1c2807f390bf..8e5d1283f838 100644 --- a/datafusion/proto-common/proto/datafusion_common.proto +++ b/datafusion/proto-common/proto/datafusion_common.proto @@ -108,8 +108,7 @@ message Field { // for complex data types like structs, unions repeated Field children = 4; map metadata = 5; - int64 dict_id = 6; - bool dict_ordered = 7; + bool dict_ordered = 6; } message Timestamp{ diff --git a/datafusion/proto-common/src/from_proto/mod.rs b/datafusion/proto-common/src/from_proto/mod.rs index b022e52b6a6f..93547efeb51e 100644 --- a/datafusion/proto-common/src/from_proto/mod.rs +++ b/datafusion/proto-common/src/from_proto/mod.rs @@ -320,21 +320,8 @@ impl TryFrom<&protobuf::Field> for Field { type Error = Error; fn try_from(field: &protobuf::Field) -> Result { let datatype = field.arrow_type.as_deref().required("arrow_type")?; - let field = if field.dict_id != 0 { - // https://github.com/apache/datafusion/issues/14173 - #[allow(deprecated)] - Self::new_dict( - field.name.as_str(), - datatype, - field.nullable, - field.dict_id, - field.dict_ordered, - ) - .with_metadata(field.metadata.clone()) - } else { - Self::new(field.name.as_str(), datatype, field.nullable) - .with_metadata(field.metadata.clone()) - }; + let field = Self::new(field.name.as_str(), datatype, field.nullable) + .with_metadata(field.metadata.clone()); Ok(field) } } @@ -436,36 +423,18 @@ impl TryFrom<&protobuf::ScalarValue> for ScalarValue { let id = dict_batch.id(); - let fields_using_this_dictionary = { - // See https://github.com/apache/datafusion/issues/14173 - #[allow(deprecated)] - schema.fields_with_dict_id(id) - }; + let record_batch = read_record_batch( + &buffer, + dict_batch.data().unwrap(), + Arc::new(schema.clone()), + &Default::default(), + None, + &message.version(), + )?; - let first_field = fields_using_this_dictionary.first().ok_or_else(|| { - Error::General("dictionary id not found in schema while deserializing ScalarValue::List".to_string()) - })?; + let values: ArrayRef = Arc::clone(record_batch.column(0)); - let values: ArrayRef = match first_field.data_type() { - DataType::Dictionary(_, ref value_type) => { - // Make a fake schema for the dictionary batch. 
- let value = value_type.as_ref().clone(); - let schema = Schema::new(vec![Field::new("", value, true)]); - // Read a single column - let record_batch = read_record_batch( - &buffer, - dict_batch.data().unwrap(), - Arc::new(schema), - &Default::default(), - None, - &message.version(), - )?; - Ok(Arc::clone(record_batch.column(0))) - } - _ => Err(Error::General("dictionary id not found in schema while deserializing ScalarValue::List".to_string())), - }?; - - Ok((id,values)) + Ok((id, values)) }).collect::>>()?; let record_batch = read_record_batch( diff --git a/datafusion/proto-common/src/generated/pbjson.rs b/datafusion/proto-common/src/generated/pbjson.rs index 40687de098c1..8c0a9041ba2c 100644 --- a/datafusion/proto-common/src/generated/pbjson.rs +++ b/datafusion/proto-common/src/generated/pbjson.rs @@ -3107,9 +3107,6 @@ impl serde::Serialize for Field { if !self.metadata.is_empty() { len += 1; } - if self.dict_id != 0 { - len += 1; - } if self.dict_ordered { len += 1; } @@ -3129,11 +3126,6 @@ impl serde::Serialize for Field { if !self.metadata.is_empty() { struct_ser.serialize_field("metadata", &self.metadata)?; } - if self.dict_id != 0 { - #[allow(clippy::needless_borrow)] - #[allow(clippy::needless_borrows_for_generic_args)] - struct_ser.serialize_field("dictId", ToString::to_string(&self.dict_id).as_str())?; - } if self.dict_ordered { struct_ser.serialize_field("dictOrdered", &self.dict_ordered)?; } @@ -3141,7 +3133,6 @@ impl serde::Serialize for Field { } } impl<'de> serde::Deserialize<'de> for Field { - #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, @@ -3153,8 +3144,6 @@ impl<'de> serde::Deserialize<'de> for Field { "nullable", "children", "metadata", - "dict_id", - "dictId", "dict_ordered", "dictOrdered", ]; @@ -3166,7 +3155,6 @@ impl<'de> serde::Deserialize<'de> for Field { Nullable, Children, Metadata, - DictId, DictOrdered, } impl<'de> serde::Deserialize<'de> for GeneratedField { @@ -3194,7 +3182,6 @@ impl<'de> serde::Deserialize<'de> for Field { "nullable" => Ok(GeneratedField::Nullable), "children" => Ok(GeneratedField::Children), "metadata" => Ok(GeneratedField::Metadata), - "dictId" | "dict_id" => Ok(GeneratedField::DictId), "dictOrdered" | "dict_ordered" => Ok(GeneratedField::DictOrdered), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } @@ -3220,7 +3207,6 @@ impl<'de> serde::Deserialize<'de> for Field { let mut nullable__ = None; let mut children__ = None; let mut metadata__ = None; - let mut dict_id__ = None; let mut dict_ordered__ = None; while let Some(k) = map_.next_key()? { match k { @@ -3256,14 +3242,6 @@ impl<'de> serde::Deserialize<'de> for Field { map_.next_value::>()? 
); } - GeneratedField::DictId => { - if dict_id__.is_some() { - return Err(serde::de::Error::duplicate_field("dictId")); - } - dict_id__ = - Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) - ; - } GeneratedField::DictOrdered => { if dict_ordered__.is_some() { return Err(serde::de::Error::duplicate_field("dictOrdered")); @@ -3278,7 +3256,6 @@ impl<'de> serde::Deserialize<'de> for Field { nullable: nullable__.unwrap_or_default(), children: children__.unwrap_or_default(), metadata: metadata__.unwrap_or_default(), - dict_id: dict_id__.unwrap_or_default(), dict_ordered: dict_ordered__.unwrap_or_default(), }) } diff --git a/datafusion/proto-common/src/generated/prost.rs b/datafusion/proto-common/src/generated/prost.rs index 9e4a1ecb6b09..db46b47efc1c 100644 --- a/datafusion/proto-common/src/generated/prost.rs +++ b/datafusion/proto-common/src/generated/prost.rs @@ -106,9 +106,7 @@ pub struct Field { ::prost::alloc::string::String, ::prost::alloc::string::String, >, - #[prost(int64, tag = "6")] - pub dict_id: i64, - #[prost(bool, tag = "7")] + #[prost(bool, tag = "6")] pub dict_ordered: bool, } #[derive(Clone, PartialEq, ::prost::Message)] diff --git a/datafusion/proto-common/src/to_proto/mod.rs b/datafusion/proto-common/src/to_proto/mod.rs index ced1865795aa..83c8e98cba97 100644 --- a/datafusion/proto-common/src/to_proto/mod.rs +++ b/datafusion/proto-common/src/to_proto/mod.rs @@ -97,9 +97,6 @@ impl TryFrom<&Field> for protobuf::Field { nullable: field.is_nullable(), children: Vec::new(), metadata: field.metadata().clone(), - #[allow(deprecated)] - // See https://github.com/apache/datafusion/issues/14173 to remove deprecated dict_id - dict_id: field.dict_id().unwrap_or(0), dict_ordered: field.dict_is_ordered().unwrap_or(false), }) } diff --git a/datafusion/proto/Cargo.toml b/datafusion/proto/Cargo.toml index 7f714d425342..fb5d414dcec4 100644 --- a/datafusion/proto/Cargo.toml +++ b/datafusion/proto/Cargo.toml @@ -58,5 +58,5 @@ datafusion-functions = { workspace = true, default-features = true } datafusion-functions-aggregate = { workspace = true } datafusion-functions-window-common = { workspace = true } doc-comment = { workspace = true } -strum = { version = "0.26.1", features = ["derive"] } +strum = { version = "0.27.1", features = ["derive"] } tokio = { workspace = true, features = ["rt-multi-thread"] } diff --git a/datafusion/proto/gen/Cargo.toml b/datafusion/proto/gen/Cargo.toml index 33c815d14900..467a7f487dae 100644 --- a/datafusion/proto/gen/Cargo.toml +++ b/datafusion/proto/gen/Cargo.toml @@ -35,4 +35,4 @@ workspace = true [dependencies] # Pin these dependencies so that the generated output is deterministic pbjson-build = "=0.7.0" -prost-build = "=0.13.4" +prost-build = "=0.13.5" diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 3bc884257dab..1cdfe6d216e3 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -278,7 +278,7 @@ message DmlNode{ Type dml_type = 1; LogicalPlanNode input = 2; TableReference table_name = 3; - datafusion_common.DfSchema schema = 4; + LogicalPlanNode target = 5; } message UnnestNode { diff --git a/datafusion/proto/src/generated/datafusion_proto_common.rs b/datafusion/proto/src/generated/datafusion_proto_common.rs index 9e4a1ecb6b09..db46b47efc1c 100644 --- a/datafusion/proto/src/generated/datafusion_proto_common.rs +++ b/datafusion/proto/src/generated/datafusion_proto_common.rs @@ -106,9 +106,7 @@ pub struct Field { 
::prost::alloc::string::String, ::prost::alloc::string::String, >, - #[prost(int64, tag = "6")] - pub dict_id: i64, - #[prost(bool, tag = "7")] + #[prost(bool, tag = "6")] pub dict_ordered: bool, } #[derive(Clone, PartialEq, ::prost::Message)] diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index add72e4f777e..6e09e9a797ea 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -4764,7 +4764,7 @@ impl serde::Serialize for DmlNode { if self.table_name.is_some() { len += 1; } - if self.schema.is_some() { + if self.target.is_some() { len += 1; } let mut struct_ser = serializer.serialize_struct("datafusion.DmlNode", len)?; @@ -4779,8 +4779,8 @@ impl serde::Serialize for DmlNode { if let Some(v) = self.table_name.as_ref() { struct_ser.serialize_field("tableName", v)?; } - if let Some(v) = self.schema.as_ref() { - struct_ser.serialize_field("schema", v)?; + if let Some(v) = self.target.as_ref() { + struct_ser.serialize_field("target", v)?; } struct_ser.end() } @@ -4797,7 +4797,7 @@ impl<'de> serde::Deserialize<'de> for DmlNode { "input", "table_name", "tableName", - "schema", + "target", ]; #[allow(clippy::enum_variant_names)] @@ -4805,7 +4805,7 @@ impl<'de> serde::Deserialize<'de> for DmlNode { DmlType, Input, TableName, - Schema, + Target, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -4830,7 +4830,7 @@ impl<'de> serde::Deserialize<'de> for DmlNode { "dmlType" | "dml_type" => Ok(GeneratedField::DmlType), "input" => Ok(GeneratedField::Input), "tableName" | "table_name" => Ok(GeneratedField::TableName), - "schema" => Ok(GeneratedField::Schema), + "target" => Ok(GeneratedField::Target), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -4853,7 +4853,7 @@ impl<'de> serde::Deserialize<'de> for DmlNode { let mut dml_type__ = None; let mut input__ = None; let mut table_name__ = None; - let mut schema__ = None; + let mut target__ = None; while let Some(k) = map_.next_key()? { match k { GeneratedField::DmlType => { @@ -4874,11 +4874,11 @@ impl<'de> serde::Deserialize<'de> for DmlNode { } table_name__ = map_.next_value()?; } - GeneratedField::Schema => { - if schema__.is_some() { - return Err(serde::de::Error::duplicate_field("schema")); + GeneratedField::Target => { + if target__.is_some() { + return Err(serde::de::Error::duplicate_field("target")); } - schema__ = map_.next_value()?; + target__ = map_.next_value()?; } } } @@ -4886,7 +4886,7 @@ impl<'de> serde::Deserialize<'de> for DmlNode { dml_type: dml_type__.unwrap_or_default(), input: input__, table_name: table_name__, - schema: schema__, + target: target__, }) } } diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index df32c1a70d61..f5ec45da48f2 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -409,8 +409,8 @@ pub struct DmlNode { pub input: ::core::option::Option<::prost::alloc::boxed::Box>, #[prost(message, optional, tag = "3")] pub table_name: ::core::option::Option, - #[prost(message, optional, tag = "4")] - pub schema: ::core::option::Option, + #[prost(message, optional, boxed, tag = "5")] + pub target: ::core::option::Option<::prost::alloc::boxed::Box>, } /// Nested message and enum types in `DmlNode`. 
pub mod dml_node { diff --git a/datafusion/proto/src/logical_plan/mod.rs b/datafusion/proto/src/logical_plan/mod.rs index 53b683bac66a..641dfe7b5fb8 100644 --- a/datafusion/proto/src/logical_plan/mod.rs +++ b/datafusion/proto/src/logical_plan/mod.rs @@ -55,8 +55,8 @@ use datafusion::{ }; use datafusion_common::file_options::file_type::FileType; use datafusion_common::{ - context, internal_datafusion_err, internal_err, not_impl_err, DataFusionError, - Result, TableReference, + context, internal_datafusion_err, internal_err, not_impl_err, plan_err, + DataFusionError, Result, TableReference, ToDFSchema, }; use datafusion_expr::{ dml, @@ -71,7 +71,7 @@ use datafusion_expr::{ }; use datafusion_expr::{ AggregateUDF, ColumnUnnestList, DmlStatement, FetchType, RecursiveQuery, SkipType, - Unnest, + TableSource, Unnest, }; use self::to_proto::{serialize_expr, serialize_exprs}; @@ -236,6 +236,45 @@ fn from_table_reference( Ok(table_ref.clone().try_into()?) } +/// Converts [LogicalPlan::TableScan] to [TableSource] +/// method to be used to deserialize nodes +/// serialized by [from_table_source] +fn to_table_source( + node: &Option>, + ctx: &SessionContext, + extension_codec: &dyn LogicalExtensionCodec, +) -> Result> { + if let Some(node) = node { + match node.try_into_logical_plan(ctx, extension_codec)? { + LogicalPlan::TableScan(TableScan { source, .. }) => Ok(source), + _ => plan_err!("expected TableScan node"), + } + } else { + plan_err!("LogicalPlanNode should be provided") + } +} + +/// converts [TableSource] to [LogicalPlan::TableScan] +/// using [LogicalPlan::TableScan] was the best approach to +/// serialize [TableSource] to [LogicalPlan::TableScan] +fn from_table_source( + table_name: TableReference, + target: Arc, + extension_codec: &dyn LogicalExtensionCodec, +) -> Result { + let projected_schema = target.schema().to_dfschema_ref()?; + let r = LogicalPlan::TableScan(TableScan { + table_name, + source: target, + projection: None, + projected_schema, + filters: vec![], + fetch: None, + }); + + LogicalPlanNode::try_from_logical_plan(&r, extension_codec) +} + impl AsLogicalPlan for LogicalPlanNode { fn try_decode(buf: &[u8]) -> Result where @@ -454,7 +493,7 @@ impl AsLogicalPlan for LogicalPlanNode { )? .build() } - CustomScan(scan) => { + LogicalPlanType::CustomScan(scan) => { let schema: Schema = convert_required!(scan.schema)?; let schema = Arc::new(schema); let mut projection = None; @@ -942,7 +981,7 @@ impl AsLogicalPlan for LogicalPlanNode { LogicalPlanType::Dml(dml_node) => Ok(LogicalPlan::Dml( datafusion::logical_expr::DmlStatement::new( from_table_reference(dml_node.table_name.as_ref(), "DML ")?, - Arc::new(convert_required!(dml_node.schema)?), + to_table_source(&dml_node.target, ctx, extension_codec)?, dml_node.dml_type().into(), Arc::new(into_logical_plan!(dml_node.input, ctx, extension_codec)?), ), @@ -1658,7 +1697,7 @@ impl AsLogicalPlan for LogicalPlanNode { )), LogicalPlan::Dml(DmlStatement { table_name, - table_schema, + target, op, input, .. 
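// Explanatory note (not part of the patch): `DmlStatement` now carries the full
// `TableSource` (`target`) rather than only its schema, so the proto layer
// round-trips it by wrapping the source in a `TableScan` node on encode
// (`from_table_source`) and unwrapping it again on decode (`to_table_source`).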
@@ -1669,7 +1708,11 @@ impl AsLogicalPlan for LogicalPlanNode { Ok(LogicalPlanNode { logical_plan_type: Some(LogicalPlanType::Dml(Box::new(DmlNode { input: Some(Box::new(input)), - schema: Some(table_schema.try_into()?), + target: Some(Box::new(from_table_source( + table_name.clone(), + Arc::clone(target), + extension_codec, + )?)), table_name: Some(table_name.clone().into()), dml_type: dml_type.into(), }))), diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 6d1d4f30610c..5785bc0c4966 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -22,8 +22,8 @@ use datafusion_common::{TableReference, UnnestOptions}; use datafusion_expr::dml::InsertOp; use datafusion_expr::expr::{ - self, Alias, Between, BinaryExpr, Cast, GroupingSet, InList, Like, Placeholder, - ScalarFunction, Unnest, + self, AggregateFunctionParams, Alias, Between, BinaryExpr, Cast, GroupingSet, InList, + Like, Placeholder, ScalarFunction, Unnest, }; use datafusion_expr::WriteOp; use datafusion_expr::{ @@ -300,12 +300,15 @@ pub fn serialize_expr( } Expr::WindowFunction(expr::WindowFunction { ref fun, - ref args, - ref partition_by, - ref order_by, - ref window_frame, - // TODO: support null treatment in proto - null_treatment: _, + params: + expr::WindowFunctionParams { + ref args, + ref partition_by, + ref order_by, + ref window_frame, + // TODO: support null treatment in proto + null_treatment: _, + }, }) => { let (window_function, fun_definition) = match fun { WindowFunctionDefinition::AggregateUDF(aggr_udf) => { @@ -348,11 +351,14 @@ pub fn serialize_expr( } Expr::AggregateFunction(expr::AggregateFunction { ref func, - ref args, - ref distinct, - ref filter, - ref order_by, - null_treatment: _, + params: + AggregateFunctionParams { + ref args, + ref distinct, + ref filter, + ref order_by, + null_treatment: _, + }, }) => { let mut buf = Vec::new(); let _ = codec.try_encode_udaf(func, &mut buf); diff --git a/datafusion/proto/src/physical_plan/mod.rs b/datafusion/proto/src/physical_plan/mod.rs index 06e581c0e6ad..bf91acfc4ae7 100644 --- a/datafusion/proto/src/physical_plan/mod.rs +++ b/datafusion/proto/src/physical_plan/mod.rs @@ -244,7 +244,7 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { )? 
.with_newlines_in_values(scan.newlines_in_values) .with_file_compression_type(FileCompressionType::UNCOMPRESSED); - Ok(conf.new_exec()) + Ok(conf.build()) } #[cfg_attr(not(feature = "parquet"), allow(unused_variables))] PhysicalPlanType::ParquetScan(scan) => { @@ -281,7 +281,7 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { extension_codec, Arc::new(source), )?; - Ok(base_config.new_exec()) + Ok(base_config.build()) } #[cfg(not(feature = "parquet"))] panic!("Unable to process a Parquet PhysicalPlan when `parquet` feature is not enabled") @@ -293,7 +293,7 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { extension_codec, Arc::new(AvroSource::new()), )?; - Ok(conf.new_exec()) + Ok(conf.build()) } PhysicalPlanType::CoalesceBatches(coalesce_batches) => { let input: Arc = into_physical_plan( diff --git a/datafusion/proto/tests/cases/mod.rs b/datafusion/proto/tests/cases/mod.rs index f36b7178313a..25efa2690268 100644 --- a/datafusion/proto/tests/cases/mod.rs +++ b/datafusion/proto/tests/cases/mod.rs @@ -22,8 +22,8 @@ use std::fmt::Debug; use datafusion_common::plan_err; use datafusion_expr::function::AccumulatorArgs; use datafusion_expr::{ - Accumulator, AggregateUDFImpl, ColumnarValue, PartitionEvaluator, ScalarUDFImpl, - Signature, Volatility, WindowUDFImpl, + Accumulator, AggregateUDFImpl, PartitionEvaluator, ScalarUDFImpl, Signature, + Volatility, WindowUDFImpl, }; use datafusion_functions_window_common::field::WindowUDFFieldArgs; use datafusion_functions_window_common::partition::PartitionEvaluatorArgs; @@ -69,13 +69,6 @@ impl ScalarUDFImpl for MyRegexUdf { plan_err!("regex_udf only accepts Utf8 arguments") } } - fn invoke_batch( - &self, - _args: &[ColumnarValue], - _number_rows: usize, - ) -> datafusion_common::Result { - unimplemented!() - } fn aliases(&self) -> &[String] { &self.aliases } diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index 9a60c4f3066d..9cc7514a0d33 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -1494,20 +1494,6 @@ fn round_trip_scalar_values_and_data_types() { Field::new("b", DataType::Boolean, false), ScalarValue::from(false), ) - .with_scalar( - Field::new( - "c", - DataType::Dictionary( - Box::new(DataType::UInt16), - Box::new(DataType::Utf8), - ), - false, - ), - ScalarValue::Dictionary( - Box::new(DataType::UInt16), - Box::new("value".into()), - ), - ) .build() .unwrap(), ScalarValue::try_from(&DataType::Struct(Fields::from(vec![ @@ -1518,25 +1504,6 @@ fn round_trip_scalar_values_and_data_types() { ScalarValue::try_from(&DataType::Struct(Fields::from(vec![ Field::new("a", DataType::Int32, true), Field::new("b", DataType::Boolean, false), - Field::new( - "c", - DataType::Dictionary( - Box::new(DataType::UInt16), - Box::new(DataType::Binary), - ), - false, - ), - Field::new( - "d", - DataType::new_list( - DataType::Dictionary( - Box::new(DataType::UInt16), - Box::new(DataType::Binary), - ), - false, - ), - false, - ), ]))) .unwrap(), ScalarValue::try_from(&DataType::Map( @@ -1815,45 +1782,6 @@ fn round_trip_datatype() { } } -// See https://github.com/apache/datafusion/issues/14173 to remove deprecated dict_id -#[allow(deprecated)] -#[test] -fn roundtrip_dict_id() -> Result<()> { - let dict_id = 42; - let field = Field::new( - "keys", - DataType::List(Arc::new(Field::new_dict( - "item", - DataType::Dictionary(Box::new(DataType::UInt16), Box::new(DataType::Utf8)), - true, - dict_id, - false, 
- ))), - false, - ); - let schema = Arc::new(Schema::new(vec![field])); - - // encode - let mut buf: Vec = vec![]; - let schema_proto: protobuf::Schema = schema.try_into().unwrap(); - schema_proto.encode(&mut buf).unwrap(); - - // decode - let schema_proto = protobuf::Schema::decode(buf.as_slice()).unwrap(); - let decoded: Schema = (&schema_proto).try_into()?; - - // assert - let keys = decoded.fields().iter().last().unwrap(); - match keys.data_type() { - DataType::List(field) => { - assert_eq!(field.dict_id(), Some(dict_id), "dict_id should be retained"); - } - _ => panic!("Invalid type"), - } - - Ok(()) -} - #[test] fn roundtrip_null_scalar_values() { let test_types = vec![ diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs index 38a2d0ada9e7..f1fe698fbf66 100644 --- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs @@ -742,7 +742,7 @@ fn roundtrip_parquet_exec_with_pruning_predicate() -> Result<()> { source, }; - roundtrip_test(scan_config.new_exec()) + roundtrip_test(scan_config.build()) } #[tokio::test] @@ -773,7 +773,7 @@ async fn roundtrip_parquet_exec_with_table_partition_cols() -> Result<()> { source, }; - roundtrip_test(scan_config.new_exec()) + roundtrip_test(scan_config.build()) } #[test] @@ -919,7 +919,7 @@ fn roundtrip_parquet_exec_with_custom_predicate_expr() -> Result<()> { } } - let exec_plan = scan_config.new_exec(); + let exec_plan = scan_config.build(); let ctx = SessionContext::new(); roundtrip_test_and_return(exec_plan, &ctx, &CustomPhysicalExtensionCodec {})?; diff --git a/datafusion/sql/Cargo.toml b/datafusion/sql/Cargo.toml index b53f3674d13a..c4a404975d29 100644 --- a/datafusion/sql/Cargo.toml +++ b/datafusion/sql/Cargo.toml @@ -43,7 +43,6 @@ recursive_protection = ["dep:recursive"] [dependencies] arrow = { workspace = true } -arrow-schema = { workspace = true } bigdecimal = { workspace = true } datafusion-common = { workspace = true, default-features = true } datafusion-expr = { workspace = true } diff --git a/datafusion/sql/examples/sql.rs b/datafusion/sql/examples/sql.rs index 7f1e6bf8f28c..2c0bb86cd808 100644 --- a/datafusion/sql/examples/sql.rs +++ b/datafusion/sql/examples/sql.rs @@ -17,7 +17,7 @@ use std::{collections::HashMap, sync::Arc}; -use arrow_schema::{DataType, Field, Schema}; +use arrow::datatypes::{DataType, Field, Schema}; use datafusion_common::config::ConfigOptions; use datafusion_common::{plan_err, Result, TableReference}; diff --git a/datafusion/sql/src/expr/function.rs b/datafusion/sql/src/expr/function.rs index da1a4ba81f5a..1cf3dcb289a6 100644 --- a/datafusion/sql/src/expr/function.rs +++ b/datafusion/sql/src/expr/function.rs @@ -17,7 +17,7 @@ use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; -use arrow_schema::DataType; +use arrow::datatypes::DataType; use datafusion_common::{ internal_datafusion_err, internal_err, not_impl_err, plan_datafusion_err, plan_err, DFSchema, Dependency, Result, diff --git a/datafusion/sql/src/expr/identifier.rs b/datafusion/sql/src/expr/identifier.rs index ab5c550691bd..7d358d0b6624 100644 --- a/datafusion/sql/src/expr/identifier.rs +++ b/datafusion/sql/src/expr/identifier.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. 
-use arrow_schema::Field; +use arrow::datatypes::Field; use datafusion_common::{ internal_err, not_impl_err, plan_datafusion_err, plan_err, Column, DFSchema, DataFusionError, Result, Span, TableReference, diff --git a/datafusion/sql/src/expr/mod.rs b/datafusion/sql/src/expr/mod.rs index de753da895d3..fa2619111e7e 100644 --- a/datafusion/sql/src/expr/mod.rs +++ b/datafusion/sql/src/expr/mod.rs @@ -15,8 +15,7 @@ // specific language governing permissions and limitations // under the License. -use arrow_schema::DataType; -use arrow_schema::TimeUnit; +use arrow::datatypes::{DataType, TimeUnit}; use datafusion_expr::planner::{ PlannerResult, RawBinaryExpr, RawDictionaryExpr, RawFieldAccessExpr, }; diff --git a/datafusion/sql/src/expr/value.rs b/datafusion/sql/src/expr/value.rs index 847163c6d3b3..e81bfa0dc55f 100644 --- a/datafusion/sql/src/expr/value.rs +++ b/datafusion/sql/src/expr/value.rs @@ -19,8 +19,9 @@ use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; use arrow::compute::kernels::cast_utils::{ parse_interval_month_day_nano_config, IntervalParseConfig, IntervalUnit, }; -use arrow::datatypes::{i256, DECIMAL128_MAX_PRECISION}; -use arrow_schema::{DataType, DECIMAL256_MAX_PRECISION}; +use arrow::datatypes::{ + i256, DataType, DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION, +}; use bigdecimal::num_bigint::BigInt; use bigdecimal::{BigDecimal, Signed, ToPrimitive}; use datafusion_common::{ diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs index 85d428cae84f..5fb6ef913d8c 100644 --- a/datafusion/sql/src/planner.rs +++ b/datafusion/sql/src/planner.rs @@ -20,7 +20,8 @@ use std::collections::HashMap; use std::sync::Arc; use std::vec; -use arrow_schema::*; +use arrow::datatypes::*; +use datafusion_common::error::add_possible_columns_to_diag; use datafusion_common::{ field_not_found, internal_err, plan_datafusion_err, DFSchemaRef, Diagnostic, SchemaError, @@ -223,7 +224,24 @@ impl PlannerContext { } } -/// SQL query planner +/// SQL query planner and binder +/// +/// This struct is used to convert a SQL AST into a [`LogicalPlan`]. +/// +/// You can control the behavior of the planner by providing [`ParserOptions`]. +/// +/// It performs the following tasks: +/// +/// 1. Name and type resolution (called "binding" in other systems). This +/// phase looks up table and column names using the [`ContextProvider`]. +/// 2. Mechanical translation of the AST into a [`LogicalPlan`]. +/// +/// It does not perform type coercion, or perform optimization, which are done +/// by subsequent passes. +/// +/// Key interfaces are: +/// * [`Self::sql_statement_to_plan`]: Convert a statement (e.g. `SELECT ...`) into a [`LogicalPlan`] +/// * [`Self::sql_to_expr`]: Convert an expression (e.g. `1 + 2`) into an [`Expr`] pub struct SqlToRel<'a, S: ContextProvider> { pub(crate) context_provider: &'a S, pub(crate) options: ParserOptions, @@ -368,10 +386,13 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } .map_err(|err: DataFusionError| match &err { DataFusionError::SchemaError( - SchemaError::FieldNotFound { .. 
}, + SchemaError::FieldNotFound { + field, + valid_fields, + }, _, ) => { - let diagnostic = if let Some(relation) = &col.relation { + let mut diagnostic = if let Some(relation) = &col.relation { Diagnostic::new_error( format!( "column '{}' not found in '{}'", @@ -385,6 +406,11 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { col.spans().first(), ) }; + add_possible_columns_to_diag( + &mut diagnostic, + field, + valid_fields, + ); err.with_diagnostic(diagnostic) } _ => err, diff --git a/datafusion/sql/src/set_expr.rs b/datafusion/sql/src/set_expr.rs index 2579f2397228..a55b3b039087 100644 --- a/datafusion/sql/src/set_expr.rs +++ b/datafusion/sql/src/set_expr.rs @@ -54,15 +54,18 @@ impl SqlToRel<'_, S> { return Err(err); } }; - self.validate_set_expr_num_of_columns( - op, - left_span, - right_span, - &left_plan, - &right_plan, - set_expr_span, - )?; - + if !(set_quantifier == SetQuantifier::ByName + || set_quantifier == SetQuantifier::AllByName) + { + self.validate_set_expr_num_of_columns( + op, + left_span, + right_span, + &left_plan, + &right_plan, + set_expr_span, + )?; + } self.set_operation_to_plan(op, left_plan, right_plan, set_quantifier) } SetExpr::Query(q) => self.query_to_plan(*q, planner_context), @@ -72,17 +75,11 @@ impl SqlToRel<'_, S> { pub(super) fn is_union_all(set_quantifier: SetQuantifier) -> Result { match set_quantifier { - SetQuantifier::All => Ok(true), - SetQuantifier::Distinct | SetQuantifier::None => Ok(false), - SetQuantifier::ByName => { - not_impl_err!("UNION BY NAME not implemented") - } - SetQuantifier::AllByName => { - not_impl_err!("UNION ALL BY NAME not implemented") - } - SetQuantifier::DistinctByName => { - not_impl_err!("UNION DISTINCT BY NAME not implemented") - } + SetQuantifier::All | SetQuantifier::AllByName => Ok(true), + SetQuantifier::Distinct + | SetQuantifier::ByName + | SetQuantifier::DistinctByName + | SetQuantifier::None => Ok(false), } } @@ -127,28 +124,42 @@ impl SqlToRel<'_, S> { right_plan: LogicalPlan, set_quantifier: SetQuantifier, ) -> Result { - let all = Self::is_union_all(set_quantifier)?; - match (op, all) { - (SetOperator::Union, true) => LogicalPlanBuilder::from(left_plan) - .union(right_plan)? - .build(), - (SetOperator::Union, false) => LogicalPlanBuilder::from(left_plan) - .union_distinct(right_plan)? + match (op, set_quantifier) { + (SetOperator::Union, SetQuantifier::All) => { + LogicalPlanBuilder::from(left_plan) + .union(right_plan)? + .build() + } + (SetOperator::Union, SetQuantifier::AllByName) => { + LogicalPlanBuilder::from(left_plan) + .union_by_name(right_plan)? + .build() + } + (SetOperator::Union, SetQuantifier::Distinct | SetQuantifier::None) => { + LogicalPlanBuilder::from(left_plan) + .union_distinct(right_plan)? + .build() + } + ( + SetOperator::Union, + SetQuantifier::ByName | SetQuantifier::DistinctByName, + ) => LogicalPlanBuilder::from(left_plan) + .union_by_name_distinct(right_plan)? 
.build(), - (SetOperator::Intersect, true) => { + (SetOperator::Intersect, SetQuantifier::All) => { LogicalPlanBuilder::intersect(left_plan, right_plan, true) } - (SetOperator::Intersect, false) => { + (SetOperator::Intersect, SetQuantifier::Distinct | SetQuantifier::None) => { LogicalPlanBuilder::intersect(left_plan, right_plan, false) } - (SetOperator::Except, true) => { + (SetOperator::Except, SetQuantifier::All) => { LogicalPlanBuilder::except(left_plan, right_plan, true) } - (SetOperator::Except, false) => { + (SetOperator::Except, SetQuantifier::Distinct | SetQuantifier::None) => { LogicalPlanBuilder::except(left_plan, right_plan, false) } - (SetOperator::Minus, _) => { - not_impl_err!("MINUS Set Operator not implemented") + (op, quantifier) => { + not_impl_err!("{op} {quantifier} not implemented") } } } diff --git a/datafusion/sql/src/statement.rs b/datafusion/sql/src/statement.rs index d48cc93ee39e..74055d979145 100644 --- a/datafusion/sql/src/statement.rs +++ b/datafusion/sql/src/statement.rs @@ -29,7 +29,7 @@ use crate::planner::{ }; use crate::utils::normalize_ident; -use arrow_schema::{DataType, Fields}; +use arrow::datatypes::{DataType, Fields}; use datafusion_common::error::_plan_err; use datafusion_common::parsers::CompressionTypeVariant; use datafusion_common::{ @@ -1709,14 +1709,10 @@ impl SqlToRel<'_, S> { // Do a table lookup to verify the table exists let table_ref = self.object_name_to_table_reference(table_name.clone())?; let table_source = self.context_provider.get_table_source(table_ref.clone())?; - let schema = (*table_source.schema()).clone(); - let schema = DFSchema::try_from(schema)?; - let scan = LogicalPlanBuilder::scan( - object_name_to_string(&table_name), - table_source, - None, - )? - .build()?; + let schema = table_source.schema().to_dfschema_ref()?; + let scan = + LogicalPlanBuilder::scan(table_ref.clone(), Arc::clone(&table_source), None)? 
+ .build()?; let mut planner_context = PlannerContext::new(); let source = match predicate_expr { @@ -1724,7 +1720,7 @@ impl SqlToRel<'_, S> { Some(predicate_expr) => { let filter_expr = self.sql_to_expr(predicate_expr, &schema, &mut planner_context)?; - let schema = Arc::new(schema.clone()); + let schema = Arc::new(schema); let mut using_columns = HashSet::new(); expr_to_columns(&filter_expr, &mut using_columns)?; let filter_expr = normalize_col_with_schemas_and_ambiguity_check( @@ -1738,7 +1734,7 @@ impl SqlToRel<'_, S> { let plan = LogicalPlan::Dml(DmlStatement::new( table_ref, - schema.into(), + table_source, WriteOp::Delete, Arc::new(source), )); @@ -1851,7 +1847,7 @@ impl SqlToRel<'_, S> { let plan = LogicalPlan::Dml(DmlStatement::new( table_name, - table_schema, + table_source, WriteOp::Update, Arc::new(source), )); @@ -1980,7 +1976,7 @@ impl SqlToRel<'_, S> { let plan = LogicalPlan::Dml(DmlStatement::new( table_name, - Arc::new(table_schema), + Arc::clone(&table_source), WriteOp::Insert(insert_op), Arc::new(source), )); diff --git a/datafusion/sql/src/unparser/dialect.rs b/datafusion/sql/src/unparser/dialect.rs index adfb7a0d0cd2..399f0df0a699 100644 --- a/datafusion/sql/src/unparser/dialect.rs +++ b/datafusion/sql/src/unparser/dialect.rs @@ -18,7 +18,7 @@ use std::{collections::HashMap, sync::Arc}; use super::{utils::character_length_to_sql, utils::date_part_to_sql, Unparser}; -use arrow_schema::TimeUnit; +use arrow::datatypes::TimeUnit; use datafusion_common::Result; use datafusion_expr::Expr; use regex::Regex; diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs index 909533773435..7c56969d47cd 100644 --- a/datafusion/sql/src/unparser/expr.rs +++ b/datafusion/sql/src/unparser/expr.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -use datafusion_expr::expr::Unnest; +use datafusion_expr::expr::{AggregateFunctionParams, Unnest, WindowFunctionParams}; use sqlparser::ast::Value::SingleQuotedString; use sqlparser::ast::{ self, Array, BinaryOperator, Expr as AstExpr, Function, Ident, Interval, ObjectName, @@ -34,9 +34,8 @@ use arrow::array::{ }, ArrayRef, Date32Array, Date64Array, PrimitiveArray, }; -use arrow::datatypes::{Decimal128Type, Decimal256Type, DecimalType}; +use arrow::datatypes::{DataType, Decimal128Type, Decimal256Type, DecimalType}; use arrow::util::display::array_value_to_string; -use arrow_schema::DataType; use datafusion_common::{ internal_datafusion_err, internal_err, not_impl_err, plan_err, Column, Result, ScalarValue, @@ -190,11 +189,14 @@ impl Unparser<'_> { Expr::Alias(Alias { expr, name: _, .. }) => self.expr_to_sql_inner(expr), Expr::WindowFunction(WindowFunction { fun, - args, - partition_by, - order_by, - window_frame, - null_treatment: _, + params: + WindowFunctionParams { + args, + partition_by, + order_by, + window_frame, + .. + }, }) => { let func_name = fun.name(); @@ -285,9 +287,15 @@ impl Unparser<'_> { }), Expr::AggregateFunction(agg) => { let func_name = agg.func.name(); + let AggregateFunctionParams { + distinct, + args, + filter, + .. 
+ } = &agg.params; - let args = self.function_args_to_sql(&agg.args)?; - let filter = match &agg.filter { + let args = self.function_args_to_sql(args)?; + let filter = match filter { Some(filter) => Some(Box::new(self.expr_to_sql_inner(filter)?)), None => None, }; @@ -298,8 +306,7 @@ impl Unparser<'_> { span: Span::empty(), }]), args: ast::FunctionArguments::List(ast::FunctionArgumentList { - duplicate_treatment: agg - .distinct + duplicate_treatment: distinct .then_some(ast::DuplicateTreatment::Distinct), args, clauses: vec![], @@ -1649,16 +1656,15 @@ mod tests { use std::{any::Any, sync::Arc, vec}; use arrow::array::{LargeListArray, ListArray}; - use arrow::datatypes::{Field, Int32Type, Schema, TimeUnit}; - use arrow_schema::DataType::Int8; + use arrow::datatypes::{DataType::Int8, Field, Int32Type, Schema, TimeUnit}; use ast::ObjectName; use datafusion_common::{Spans, TableReference}; use datafusion_expr::expr::WildcardOptions; use datafusion_expr::{ case, cast, col, cube, exists, grouping_set, interval_datetime_lit, interval_year_month_lit, lit, not, not_exists, out_ref_col, placeholder, rollup, - table_scan, try_cast, when, wildcard, ColumnarValue, ScalarUDF, ScalarUDFImpl, - Signature, Volatility, WindowFrame, WindowFunctionDefinition, + table_scan, try_cast, when, wildcard, ScalarUDF, ScalarUDFImpl, Signature, + Volatility, WindowFrame, WindowFunctionDefinition, }; use datafusion_expr::{interval_month_day_nano_lit, ExprFunctionExt}; use datafusion_functions::expr_fn::{get_field, named_struct}; @@ -1707,14 +1713,6 @@ mod tests { fn return_type(&self, _arg_types: &[DataType]) -> Result { Ok(DataType::Int32) } - - fn invoke_batch( - &self, - _args: &[ColumnarValue], - _number_rows: usize, - ) -> Result { - unimplemented!("DummyUDF::invoke") - } } // See sql::tests for E2E tests. 
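The hunks above and below in this file track the expression refactor in which per-call fields (args, order_by, filter, window_frame, and so on) moved into the nested `AggregateFunctionParams` / `WindowFunctionParams` structs. A minimal sketch of how call sites now destructure an aggregate call (illustrative only; `is_distinct_aggregate` is not a function from this patch):

use datafusion_expr::expr::{AggregateFunction, AggregateFunctionParams};
use datafusion_expr::Expr;

fn is_distinct_aggregate(expr: &Expr) -> bool {
    match expr {
        // Call-site details now live one level down, inside `params`.
        Expr::AggregateFunction(AggregateFunction {
            params: AggregateFunctionParams { distinct, .. },
            ..
        }) => *distinct,
        _ => false,
    }
}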
@@ -1934,30 +1932,34 @@ mod tests { ( Expr::WindowFunction(WindowFunction { fun: WindowFunctionDefinition::WindowUDF(row_number_udwf()), - args: vec![col("col")], - partition_by: vec![], - order_by: vec![], - window_frame: WindowFrame::new(None), - null_treatment: None, + params: WindowFunctionParams { + args: vec![col("col")], + partition_by: vec![], + order_by: vec![], + window_frame: WindowFrame::new(None), + null_treatment: None, + }, }), r#"row_number(col) OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)"#, ), ( Expr::WindowFunction(WindowFunction { fun: WindowFunctionDefinition::AggregateUDF(count_udaf()), - args: vec![wildcard()], - partition_by: vec![], - order_by: vec![Sort::new(col("a"), false, true)], - window_frame: WindowFrame::new_bounds( - datafusion_expr::WindowFrameUnits::Range, - datafusion_expr::WindowFrameBound::Preceding( - ScalarValue::UInt32(Some(6)), - ), - datafusion_expr::WindowFrameBound::Following( - ScalarValue::UInt32(Some(2)), + params: WindowFunctionParams { + args: vec![wildcard()], + partition_by: vec![], + order_by: vec![Sort::new(col("a"), false, true)], + window_frame: WindowFrame::new_bounds( + datafusion_expr::WindowFrameUnits::Range, + datafusion_expr::WindowFrameBound::Preceding( + ScalarValue::UInt32(Some(6)), + ), + datafusion_expr::WindowFrameBound::Following( + ScalarValue::UInt32(Some(2)), + ), ), - ), - null_treatment: None, + null_treatment: None, + }, }), r#"count(*) OVER (ORDER BY a DESC NULLS FIRST RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING)"#, ), @@ -2790,7 +2792,7 @@ mod tests { let unparser = Unparser::new(dialect.as_ref()); let func = WindowFunctionDefinition::WindowUDF(rank_udwf()); let mut window_func = WindowFunction::new(func, vec![]); - window_func.order_by = vec![Sort::new(col("a"), true, true)]; + window_func.params.order_by = vec![Sort::new(col("a"), true, true)]; let expr = Expr::WindowFunction(window_func); let ast = unparser.expr_to_sql(&expr)?; diff --git a/datafusion/sql/src/unparser/rewrite.rs b/datafusion/sql/src/unparser/rewrite.rs index db9837483168..2e3c8e9e9484 100644 --- a/datafusion/sql/src/unparser/rewrite.rs +++ b/datafusion/sql/src/unparser/rewrite.rs @@ -17,7 +17,7 @@ use std::{collections::HashSet, sync::Arc}; -use arrow_schema::Schema; +use arrow::datatypes::Schema; use datafusion_common::tree_node::TreeNodeContainer; use datafusion_common::{ tree_node::{Transformed, TransformedResult, TreeNode, TreeNodeRewriter}, diff --git a/datafusion/sql/src/utils.rs b/datafusion/sql/src/utils.rs index ab3e75a960a0..3f093afaf26a 100644 --- a/datafusion/sql/src/utils.rs +++ b/datafusion/sql/src/utils.rs @@ -19,7 +19,7 @@ use std::vec; -use arrow_schema::{ +use arrow::datatypes::{ DataType, DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION, DECIMAL_DEFAULT_SCALE, }; use datafusion_common::tree_node::{ @@ -30,7 +30,9 @@ use datafusion_common::{ HashMap, Result, ScalarValue, }; use datafusion_expr::builder::get_struct_unnested_columns; -use datafusion_expr::expr::{Alias, GroupingSet, Unnest, WindowFunction}; +use datafusion_expr::expr::{ + Alias, GroupingSet, Unnest, WindowFunction, WindowFunctionParams, +}; use datafusion_expr::utils::{expr_as_column_expr, find_column_exprs}; use datafusion_expr::{ col, expr_vec_fmt, ColumnUnnestList, Expr, ExprSchemable, LogicalPlan, @@ -240,11 +242,15 @@ pub fn window_expr_common_partition_keys(window_exprs: &[Expr]) -> Result<&[Expr let all_partition_keys = window_exprs .iter() .map(|expr| match expr { - Expr::WindowFunction(WindowFunction { partition_by, .. 
}) => Ok(partition_by), + Expr::WindowFunction(WindowFunction { + params: WindowFunctionParams { partition_by, .. }, + .. + }) => Ok(partition_by), Expr::Alias(Alias { expr, .. }) => match expr.as_ref() { - Expr::WindowFunction(WindowFunction { partition_by, .. }) => { - Ok(partition_by) - } + Expr::WindowFunction(WindowFunction { + params: WindowFunctionParams { partition_by, .. }, + .. + }) => Ok(partition_by), expr => exec_err!("Impossibly got non-window expr {expr:?}"), }, expr => exec_err!("Impossibly got non-window expr {expr:?}"), @@ -650,8 +656,7 @@ pub(crate) fn rewrite_recursive_unnest_bottom_up( mod tests { use std::{ops::Add, sync::Arc}; - use arrow::datatypes::{DataType as ArrowDataType, Field, Schema}; - use arrow_schema::Fields; + use arrow::datatypes::{DataType as ArrowDataType, Field, Fields, Schema}; use datafusion_common::{Column, DFSchema, Result}; use datafusion_expr::{ col, lit, unnest, ColumnUnnestList, EmptyRelation, LogicalPlan, diff --git a/datafusion/sql/tests/cases/diagnostic.rs b/datafusion/sql/tests/cases/diagnostic.rs index 55d3a953a728..9dae2d0c3e93 100644 --- a/datafusion/sql/tests/cases/diagnostic.rs +++ b/datafusion/sql/tests/cases/diagnostic.rs @@ -35,6 +35,7 @@ fn do_query(sql: &'static str) -> Diagnostic { collect_spans: true, ..ParserOptions::default() }; + let state = MockSessionState::default(); let context = MockContextProvider { state }; let sql_to_rel = SqlToRel::new_with_options(&context, options); @@ -200,14 +201,8 @@ fn test_ambiguous_reference() -> Result<()> { let diag = do_query(query); assert_eq!(diag.message, "column 'first_name' is ambiguous"); assert_eq!(diag.span, Some(spans["a"])); - assert_eq!( - diag.notes[0].message, - "possible reference to 'first_name' in table 'a'" - ); - assert_eq!( - diag.notes[1].message, - "possible reference to 'first_name' in table 'b'" - ); + assert_eq!(diag.notes[0].message, "possible column a.first_name"); + assert_eq!(diag.notes[1].message, "possible column b.first_name"); Ok(()) } @@ -225,3 +220,57 @@ fn test_incompatible_types_binary_arithmetic() -> Result<()> { assert_eq!(diag.notes[1].span, Some(spans["right"])); Ok(()) } + +#[test] +fn test_field_not_found_suggestion() -> Result<()> { + let query = "SELECT /*whole*/first_na/*whole*/ FROM person"; + let spans = get_spans(query); + let diag = do_query(query); + assert_eq!(diag.message, "column 'first_na' not found"); + assert_eq!(diag.span, Some(spans["whole"])); + assert_eq!(diag.notes.len(), 1); + + let mut suggested_fields: Vec = diag + .notes + .iter() + .filter_map(|note| { + if note.message.starts_with("possible column") { + Some(note.message.replace("possible column ", "")) + } else { + None + } + }) + .collect(); + suggested_fields.sort(); + assert_eq!(suggested_fields[0], "person.first_name"); + Ok(()) +} + +#[test] +fn test_ambiguous_column_suggestion() -> Result<()> { + let query = "SELECT /*whole*/id/*whole*/ FROM test_decimal, person"; + let spans = get_spans(query); + let diag = do_query(query); + + assert_eq!(diag.message, "column 'id' is ambiguous"); + assert_eq!(diag.span, Some(spans["whole"])); + + assert_eq!(diag.notes.len(), 2); + + let mut suggested_fields: Vec = diag + .notes + .iter() + .filter_map(|note| { + if note.message.starts_with("possible column") { + Some(note.message.replace("possible column ", "")) + } else { + None + } + }) + .collect(); + + suggested_fields.sort(); + assert_eq!(suggested_fields, vec!["person.id", "test_decimal.id"]); + + Ok(()) +} diff --git a/datafusion/sql/tests/cases/plan_to_sql.rs 
b/datafusion/sql/tests/cases/plan_to_sql.rs index 80af4d367b1a..5af93a01e6c9 100644 --- a/datafusion/sql/tests/cases/plan_to_sql.rs +++ b/datafusion/sql/tests/cases/plan_to_sql.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -use arrow_schema::{DataType, Field, Schema}; +use arrow::datatypes::{DataType, Field, Schema}; use datafusion_common::{assert_contains, DFSchema, DFSchemaRef, Result, TableReference}; use datafusion_expr::test::function_stub::{ count_udaf, max_udaf, min_udaf, sum, sum_udaf, diff --git a/datafusion/sql/tests/common/mod.rs b/datafusion/sql/tests/common/mod.rs index 337fc2ce4f67..ee1b761970de 100644 --- a/datafusion/sql/tests/common/mod.rs +++ b/datafusion/sql/tests/common/mod.rs @@ -21,7 +21,7 @@ use std::collections::HashMap; use std::fmt::{Debug, Display}; use std::{sync::Arc, vec}; -use arrow_schema::*; +use arrow::datatypes::*; use datafusion_common::config::ConfigOptions; use datafusion_common::file_options::file_type::FileType; use datafusion_common::{plan_err, DFSchema, GetExt, Result, TableReference}; diff --git a/datafusion/sql/tests/sql_integration.rs b/datafusion/sql/tests/sql_integration.rs index c514458d4a27..1df18302687e 100644 --- a/datafusion/sql/tests/sql_integration.rs +++ b/datafusion/sql/tests/sql_integration.rs @@ -21,8 +21,7 @@ use std::collections::HashMap; use std::sync::Arc; use std::vec; -use arrow_schema::TimeUnit::Nanosecond; -use arrow_schema::*; +use arrow::datatypes::{TimeUnit::Nanosecond, *}; use common::MockContextProvider; use datafusion_common::{ assert_contains, DataFusionError, ParamValues, Result, ScalarValue, @@ -31,8 +30,8 @@ use datafusion_expr::{ col, logical_plan::{LogicalPlan, Prepare}, test::function_stub::sum_udaf, - ColumnarValue, CreateIndex, DdlStatement, ScalarUDF, ScalarUDFImpl, Signature, - Statement, Volatility, + CreateIndex, DdlStatement, ScalarUDF, ScalarUDFImpl, Signature, Statement, + Volatility, }; use datafusion_functions::{string, unicode}; use datafusion_sql::{ @@ -556,6 +555,19 @@ Dml: op=[Delete] table=[person] quick_test(sql, plan); } +#[test] +fn plan_delete_quoted_identifier_case_sensitive() { + let sql = + "DELETE FROM \"SomeCatalog\".\"SomeSchema\".\"UPPERCASE_test\" WHERE \"Id\" = 1"; + let plan = r#" +Dml: op=[Delete] table=[SomeCatalog.SomeSchema.UPPERCASE_test] + Filter: Id = Int64(1) + TableScan: SomeCatalog.SomeSchema.UPPERCASE_test + "# + .trim(); + quick_test(sql, plan); +} + #[test] fn select_column_does_not_exist() { let sql = "SELECT doesnotexist FROM person"; @@ -2101,6 +2113,33 @@ fn union() { quick_test(sql, expected); } +#[test] +fn union_by_name_different_columns() { + let sql = "SELECT order_id from orders UNION BY NAME SELECT order_id, 1 FROM orders"; + let expected = "\ + Distinct:\ + \n Union\ + \n Projection: NULL AS Int64(1), order_id\ + \n Projection: orders.order_id\ + \n TableScan: orders\ + \n Projection: orders.order_id, Int64(1)\ + \n TableScan: orders"; + quick_test(sql, expected); +} + +#[test] +fn union_by_name_same_column_names() { + let sql = "SELECT order_id from orders UNION SELECT order_id FROM orders"; + let expected = "\ + Distinct:\ + \n Union\ + \n Projection: orders.order_id\ + \n TableScan: orders\ + \n Projection: orders.order_id\ + \n TableScan: orders"; + quick_test(sql, expected); +} + #[test] fn union_all() { let sql = "SELECT order_id from orders UNION ALL SELECT order_id FROM orders"; @@ -2112,6 +2151,31 @@ fn union_all() { quick_test(sql, expected); } +#[test] +fn 
union_all_by_name_different_columns() { + let sql = + "SELECT order_id from orders UNION ALL BY NAME SELECT order_id, 1 FROM orders"; + let expected = "\ + Union\ + \n Projection: NULL AS Int64(1), order_id\ + \n Projection: orders.order_id\ + \n TableScan: orders\ + \n Projection: orders.order_id, Int64(1)\ + \n TableScan: orders"; + quick_test(sql, expected); +} + +#[test] +fn union_all_by_name_same_column_names() { + let sql = "SELECT order_id from orders UNION ALL BY NAME SELECT order_id FROM orders"; + let expected = "Union\ + \n Projection: orders.order_id\ + \n TableScan: orders\ + \n Projection: orders.order_id\ + \n TableScan: orders"; + quick_test(sql, expected); +} + #[test] fn empty_over() { let sql = "SELECT order_id, MAX(order_id) OVER () from orders"; @@ -2634,14 +2698,6 @@ impl ScalarUDFImpl for DummyUDF { fn return_type(&self, _arg_types: &[DataType]) -> Result { Ok(self.return_type.clone()) } - - fn invoke_batch( - &self, - _args: &[ColumnarValue], - _number_rows: usize, - ) -> Result { - unimplemented!("DummyUDF::invoke") - } } /// Create logical plan, write with formatter, compare to expected output diff --git a/datafusion/sqllogictest/Cargo.toml b/datafusion/sqllogictest/Cargo.toml index 41b96e341074..f1d37c7202d6 100644 --- a/datafusion/sqllogictest/Cargo.toml +++ b/datafusion/sqllogictest/Cargo.toml @@ -40,11 +40,8 @@ async-trait = { workspace = true } bigdecimal = { workspace = true } bytes = { workspace = true, optional = true } chrono = { workspace = true, optional = true } -clap = { version = "4.5.16", features = ["derive", "env"] } +clap = { version = "4.5.30", features = ["derive", "env"] } datafusion = { workspace = true, default-features = true, features = ["avro"] } -datafusion-catalog = { workspace = true, default-features = true } -datafusion-common = { workspace = true, default-features = true } -datafusion-common-runtime = { workspace = true, default-features = true } futures = { workspace = true } half = { workspace = true, default-features = true } indicatif = "0.17" @@ -54,7 +51,7 @@ object_store = { workspace = true } postgres-protocol = { version = "0.6.7", optional = true } postgres-types = { version = "0.2.8", features = ["derive", "with-chrono-0_4"], optional = true } rust_decimal = { version = "1.36.0", features = ["tokio-pg"] } -sqllogictest = "0.26.4" +sqllogictest = "0.27.1" sqlparser = { workspace = true } tempfile = { workspace = true } testcontainers = { version = "0.23", features = ["default"], optional = true } diff --git a/datafusion/sqllogictest/bin/postgres_container.rs b/datafusion/sqllogictest/bin/postgres_container.rs index 64905022914a..411562a7ccc7 100644 --- a/datafusion/sqllogictest/bin/postgres_container.rs +++ b/datafusion/sqllogictest/bin/postgres_container.rs @@ -16,7 +16,7 @@ // under the License. use crate::Options; -use datafusion_common::Result; +use datafusion::common::Result; use log::info; use std::env::set_var; use std::future::Future; diff --git a/datafusion/sqllogictest/bin/sqllogictests.rs b/datafusion/sqllogictest/bin/sqllogictests.rs index c30aaf38ec9c..bbb88819efe0 100644 --- a/datafusion/sqllogictest/bin/sqllogictests.rs +++ b/datafusion/sqllogictest/bin/sqllogictests.rs @@ -16,10 +16,9 @@ // under the License. 
use clap::Parser; -use datafusion_common::instant::Instant; -use datafusion_common::utils::get_available_parallelism; -use datafusion_common::{exec_err, DataFusionError, Result}; -use datafusion_common_runtime::SpawnedTask; +use datafusion::common::instant::Instant; +use datafusion::common::utils::get_available_parallelism; +use datafusion::common::{exec_err, DataFusionError, Result}; use datafusion_sqllogictest::{ df_value_validator, read_dir_recursive, setup_scratch_dir, value_normalizer, DataFusion, TestContext, @@ -40,6 +39,7 @@ use sqllogictest::{ use crate::postgres_container::{ initialize_postgres_container, terminate_postgres_container, }; +use datafusion::common::runtime::SpawnedTask; use std::ffi::OsStr; use std::path::{Path, PathBuf}; @@ -330,7 +330,7 @@ async fn run_test_file_with_postgres( _mp: MultiProgress, _mp_style: ProgressStyle, ) -> Result<()> { - use datafusion_common::plan_err; + use datafusion::common::plan_err; plan_err!("Can not run with postgres as postgres feature is not enabled") } @@ -446,7 +446,7 @@ async fn run_complete_file_with_postgres( _mp: MultiProgress, _mp_style: ProgressStyle, ) -> Result<()> { - use datafusion_common::plan_err; + use datafusion::common::plan_err; plan_err!("Can not run with postgres as postgres feature is not enabled") } diff --git a/datafusion/sqllogictest/src/engines/datafusion_engine/error.rs b/datafusion/sqllogictest/src/engines/datafusion_engine/error.rs index ae56c0260564..a60ae1012f9c 100644 --- a/datafusion/sqllogictest/src/engines/datafusion_engine/error.rs +++ b/datafusion/sqllogictest/src/engines/datafusion_engine/error.rs @@ -16,7 +16,7 @@ // under the License. use arrow::error::ArrowError; -use datafusion_common::DataFusionError; +use datafusion::error::DataFusionError; use sqllogictest::TestError; use sqlparser::parser::ParserError; use thiserror::Error; diff --git a/datafusion/sqllogictest/src/engines/datafusion_engine/normalize.rs b/datafusion/sqllogictest/src/engines/datafusion_engine/normalize.rs index a9325e452ae8..eeb34186ea20 100644 --- a/datafusion/sqllogictest/src/engines/datafusion_engine/normalize.rs +++ b/datafusion/sqllogictest/src/engines/datafusion_engine/normalize.rs @@ -15,19 +15,18 @@ // specific language governing permissions and limitations // under the License. +use super::super::conversion::*; +use super::error::{DFSqlLogicTestError, Result}; use crate::engines::output::DFColumnType; use arrow::array::{Array, AsArray}; use arrow::datatypes::Fields; use arrow::util::display::ArrayFormatter; use arrow::{array, array::ArrayRef, datatypes::DataType, record_batch::RecordBatch}; -use datafusion_common::format::DEFAULT_CLI_FORMAT_OPTIONS; -use datafusion_common::DataFusionError; +use datafusion::common::format::DEFAULT_CLI_FORMAT_OPTIONS; +use datafusion::common::DataFusionError; use std::path::PathBuf; use std::sync::LazyLock; -use super::super::conversion::*; -use super::error::{DFSqlLogicTestError, Result}; - /// Converts `batches` to a result as expected by sqllogictest. 
pub fn convert_batches(batches: Vec) -> Result>> { if batches.is_empty() { diff --git a/datafusion/sqllogictest/src/engines/datafusion_engine/runner.rs b/datafusion/sqllogictest/src/engines/datafusion_engine/runner.rs index e696058484a9..a3a29eda2ee9 100644 --- a/datafusion/sqllogictest/src/engines/datafusion_engine/runner.rs +++ b/datafusion/sqllogictest/src/engines/datafusion_engine/runner.rs @@ -111,6 +111,8 @@ impl sqllogictest::AsyncDB for DataFusion { async fn sleep(dur: Duration) { tokio::time::sleep(dur).await; } + + async fn shutdown(&mut self) {} } async fn run_query(ctx: &SessionContext, sql: impl Into) -> Result { diff --git a/datafusion/sqllogictest/src/engines/postgres_engine/mod.rs b/datafusion/sqllogictest/src/engines/postgres_engine/mod.rs index 21fdedd9a513..68816626bf67 100644 --- a/datafusion/sqllogictest/src/engines/postgres_engine/mod.rs +++ b/datafusion/sqllogictest/src/engines/postgres_engine/mod.rs @@ -17,7 +17,7 @@ use async_trait::async_trait; use bytes::Bytes; -use datafusion_common_runtime::SpawnedTask; +use datafusion::common::runtime::SpawnedTask; use futures::{SinkExt, StreamExt}; use log::{debug, info}; use sqllogictest::DBOutput; @@ -53,8 +53,9 @@ pub enum Error { pub type Result = std::result::Result; pub struct Postgres { - client: tokio_postgres::Client, - _spawned_task: SpawnedTask<()>, + // None means the connection has been shutdown + client: Option, + spawned_task: Option>, /// Relative test file path relative_path: PathBuf, pb: ProgressBar, @@ -90,7 +91,7 @@ impl Postgres { let (client, connection) = res?; - let _spawned_task = SpawnedTask::spawn(async move { + let spawned_task = SpawnedTask::spawn(async move { if let Err(e) = connection.await { log::error!("Postgres connection error: {:?}", e); } @@ -113,13 +114,17 @@ impl Postgres { .await?; Ok(Self { - client, - _spawned_task, + client: Some(client), + spawned_task: Some(spawned_task), relative_path, pb, }) } + fn get_client(&mut self) -> &mut tokio_postgres::Client { + self.client.as_mut().expect("client is shutdown") + } + /// Special COPY command support. "COPY 'filename'" requires the /// server to read the file which may not be possible (maybe it is /// remote or running in some other docker container). @@ -170,7 +175,7 @@ impl Postgres { debug!("Copying data from file {filename} using sql: {new_sql}"); // start the COPY command and get location to write data to - let tx = self.client.transaction().await?; + let tx = self.get_client().transaction().await?; let sink = tx.copy_in(&new_sql).await?; let mut sink = Box::pin(sink); @@ -257,12 +262,12 @@ impl sqllogictest::AsyncDB for Postgres { } if !is_query_sql { - self.client.execute(sql, &[]).await?; + self.get_client().execute(sql, &[]).await?; self.pb.inc(1); return Ok(DBOutput::StatementComplete(0)); } let start = Instant::now(); - let rows = self.client.query(sql, &[]).await?; + let rows = self.get_client().query(sql, &[]).await?; let duration = start.elapsed(); if duration.gt(&Duration::from_millis(500)) { @@ -272,7 +277,7 @@ impl sqllogictest::AsyncDB for Postgres { self.pb.inc(1); let types: Vec = if rows.is_empty() { - self.client + self.get_client() .prepare(sql) .await? 
.columns() @@ -300,6 +305,15 @@ impl sqllogictest::AsyncDB for Postgres { fn engine_name(&self) -> &str { "postgres" } + + async fn shutdown(&mut self) { + if let Some(client) = self.client.take() { + drop(client); + } + if let Some(spawned_task) = self.spawned_task.take() { + spawned_task.join().await.ok(); + } + } } fn convert_rows(rows: Vec) -> Vec> { diff --git a/datafusion/sqllogictest/src/test_context.rs b/datafusion/sqllogictest/src/test_context.rs index f7c9346a8983..ce819f186454 100644 --- a/datafusion/sqllogictest/src/test_context.rs +++ b/datafusion/sqllogictest/src/test_context.rs @@ -22,25 +22,26 @@ use std::path::Path; use std::sync::Arc; use arrow::array::{ - ArrayRef, BinaryArray, Float64Array, Int32Array, LargeBinaryArray, LargeStringArray, - StringArray, TimestampNanosecondArray, + Array, ArrayRef, BinaryArray, Float64Array, Int32Array, LargeBinaryArray, + LargeStringArray, StringArray, TimestampNanosecondArray, UnionArray, }; -use arrow::datatypes::{DataType, Field, Schema, SchemaRef, TimeUnit}; +use arrow::buffer::ScalarBuffer; +use arrow::datatypes::{DataType, Field, Schema, SchemaRef, TimeUnit, UnionFields}; use arrow::record_batch::RecordBatch; +use datafusion::catalog::{ + CatalogProvider, MemoryCatalogProvider, MemorySchemaProvider, Session, +}; +use datafusion::common::DataFusionError; use datafusion::logical_expr::{create_udf, ColumnarValue, Expr, ScalarUDF, Volatility}; use datafusion::physical_plan::ExecutionPlan; -use datafusion::prelude::SessionConfig; +use datafusion::prelude::*; use datafusion::{ datasource::{MemTable, TableProvider, TableType}, prelude::{CsvReadOptions, SessionContext}, }; -use datafusion_catalog::CatalogProvider; -use datafusion_catalog::{memory::MemoryCatalogProvider, memory::MemorySchemaProvider}; -use datafusion_common::cast::as_float64_array; -use datafusion_common::DataFusionError; use async_trait::async_trait; -use datafusion::catalog::Session; +use datafusion::common::cast::as_float64_array; use log::info; use tempfile::TempDir; @@ -113,6 +114,10 @@ impl TestContext { info!("Registering metadata table tables"); register_metadata_tables(test_ctx.session_ctx()).await; } + "union_function.slt" => { + info!("Registering table with union column"); + register_union_table(test_ctx.session_ctx()) + } _ => { info!("Using default SessionContext"); } @@ -402,3 +407,24 @@ fn create_example_udf() -> ScalarUDF { adder, ) } + +fn register_union_table(ctx: &SessionContext) { + let union = UnionArray::try_new( + UnionFields::new(vec![3], vec![Field::new("int", DataType::Int32, false)]), + ScalarBuffer::from(vec![3, 3]), + None, + vec![Arc::new(Int32Array::from(vec![1, 2]))], + ) + .unwrap(); + + let schema = Schema::new(vec![Field::new( + "union_column", + union.data_type().clone(), + false, + )]); + + let batch = + RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(union)]).unwrap(); + + ctx.register_batch("union_table", batch).unwrap(); +} diff --git a/datafusion/sqllogictest/src/util.rs b/datafusion/sqllogictest/src/util.rs index 1bdfdd03360f..5ae640cc98a9 100644 --- a/datafusion/sqllogictest/src/util.rs +++ b/datafusion/sqllogictest/src/util.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. 
-use datafusion_common::{exec_datafusion_err, Result}; +use datafusion::common::{exec_datafusion_err, Result}; use itertools::Itertools; use log::Level::Warn; use log::{info, log_enabled, warn}; diff --git a/datafusion/sqllogictest/test_files/aggregate_skip_partial.slt b/datafusion/sqllogictest/test_files/aggregate_skip_partial.slt index a2e51cffacf7..3a4d641abf68 100644 --- a/datafusion/sqllogictest/test_files/aggregate_skip_partial.slt +++ b/datafusion/sqllogictest/test_files/aggregate_skip_partial.slt @@ -220,6 +220,7 @@ set datafusion.execution.batch_size = 4; # Inserting into nullable table with batch_size specified above # to prevent creation on single in-memory batch + statement ok CREATE TABLE aggregate_test_100_null ( c2 TINYINT NOT NULL, @@ -506,7 +507,7 @@ SELECT avg(c11) FILTER (WHERE c2 != 5) FROM aggregate_test_100 GROUP BY c1 ORDER BY c1; ---- -a 2.5 0.449071887467 +a 2.5 0.449071887467 b 2.642857142857 0.445486298629 c 2.421052631579 0.422882117723 d 2.125 0.518706191331 diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt index 8f23bfe5ea65..6b5b246aee51 100644 --- a/datafusion/sqllogictest/test_files/array.slt +++ b/datafusion/sqllogictest/test_files/array.slt @@ -2656,6 +2656,28 @@ select list_push_front(1, arrow_cast(make_array(2, 3, 4), 'LargeList(Int64)')), ---- [1, 2, 3, 4] [1.0, 2.0, 3.0, 4.0] [h, e, l, l, o] +# array_prepend scalar function #7 (element is fixed size list) +query ??? +select array_prepend(arrow_cast(make_array(1), 'FixedSizeList(1, Int64)'), make_array(arrow_cast(make_array(2), 'FixedSizeList(1, Int64)'), arrow_cast(make_array(3), 'FixedSizeList(1, Int64)'), arrow_cast(make_array(4), 'FixedSizeList(1, Int64)'))), + array_prepend(arrow_cast(make_array(1.0), 'FixedSizeList(1, Float64)'), make_array(arrow_cast([2.0], 'FixedSizeList(1, Float64)'), arrow_cast([3.0], 'FixedSizeList(1, Float64)'), arrow_cast([4.0], 'FixedSizeList(1, Float64)'))), + array_prepend(arrow_cast(make_array('h'), 'FixedSizeList(1, Utf8)'), make_array(arrow_cast(['e'], 'FixedSizeList(1, Utf8)'), arrow_cast(['l'], 'FixedSizeList(1, Utf8)'), arrow_cast(['l'], 'FixedSizeList(1, Utf8)'), arrow_cast(['o'], 'FixedSizeList(1, Utf8)'))); +---- +[[1], [2], [3], [4]] [[1.0], [2.0], [3.0], [4.0]] [[h], [e], [l], [l], [o]] + +query ??? +select array_prepend(arrow_cast(make_array(1), 'FixedSizeList(1, Int64)'), arrow_cast(make_array(make_array(2), make_array(3), make_array(4)), 'LargeList(FixedSizeList(1, Int64))')), + array_prepend(arrow_cast(make_array(1.0), 'FixedSizeList(1, Float64)'), arrow_cast(make_array([2.0], [3.0], [4.0]), 'LargeList(FixedSizeList(1, Float64))')), + array_prepend(arrow_cast(make_array('h'), 'FixedSizeList(1, Utf8)'), arrow_cast(make_array(['e'], ['l'], ['l'], ['o']), 'LargeList(FixedSizeList(1, Utf8))')); +---- +[[1], [2], [3], [4]] [[1.0], [2.0], [3.0], [4.0]] [[h], [e], [l], [l], [o]] + +query ??? +select array_prepend(arrow_cast([1], 'FixedSizeList(1, Int64)'), arrow_cast([[1], [2], [3]], 'FixedSizeList(3, FixedSizeList(1, Int64))')), + array_prepend(arrow_cast([1.0], 'FixedSizeList(1, Float64)'), arrow_cast([[2.0], [3.0], [4.0]], 'FixedSizeList(3, FixedSizeList(1, Float64))')), + array_prepend(arrow_cast(['h'], 'FixedSizeList(1, Utf8)'), arrow_cast([['e'], ['l'], ['l'], ['o']], 'FixedSizeList(4, FixedSizeList(1, Utf8))')); +---- +[[1], [1], [2], [3]] [[1.0], [2.0], [3.0], [4.0]] [[h], [e], [l], [l], [o]] + # array_prepend with columns #1 query ? 
select array_prepend(column2, column1) from arrays_values; @@ -3563,6 +3585,17 @@ select list_replace( ---- [1, 3, 3, 4] [1, 0, 4, 5, 4, 6, 7] [1, 2, 3] +# array_replace scalar function #4 (null input) +query ? +select array_replace(make_array(1, 2, 3, 4, 5), NULL, NULL); +---- +[1, 2, 3, 4, 5] + +query ? +select array_replace(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), NULL, NULL); +---- +[1, 2, 3, 4, 5] + # array_replace scalar function with columns #1 query ? select array_replace(column1, column2, column3) from arrays_with_repeating_elements; @@ -3728,6 +3761,17 @@ select ---- [1, 3, 3, 4] [1, 0, 0, 5, 4, 6, 7] [1, 2, 3] +# array_replace_n scalar function #4 (null input) +query ? +select array_replace_n(make_array(1, 2, 3, 4, 5), NULL, NULL, NULL); +---- +[1, 2, 3, 4, 5] + +query ? +select array_replace_n(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), NULL, NULL, NULL); +---- +[1, 2, 3, 4, 5] + # array_replace_n scalar function with columns #1 query ? select @@ -3904,6 +3948,17 @@ select ---- [1, 3, 3, 4] [1, 0, 0, 5, 0, 6, 7] [1, 2, 3] +# array_replace_all scalar function #4 (null input) +query ? +select array_replace_all(make_array(1, 2, 3, 4, 5), NULL, NULL); +---- +[1, 2, 3, 4, 5] + +query ? +select array_replace_all(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), NULL, NULL); +---- +[1, 2, 3, 4, 5] + # array_replace_all scalar function with columns #1 query ? select diff --git a/datafusion/sqllogictest/test_files/coalesce.slt b/datafusion/sqllogictest/test_files/coalesce.slt index a624d1def980..5f2d2f0d1da9 100644 --- a/datafusion/sqllogictest/test_files/coalesce.slt +++ b/datafusion/sqllogictest/test_files/coalesce.slt @@ -438,3 +438,8 @@ Date32 statement ok drop table test + +query T +select coalesce(arrow_cast('', 'Utf8View'), arrow_cast('', 'Dictionary(UInt32, Utf8)')); +---- +(empty) \ No newline at end of file diff --git a/datafusion/sqllogictest/test_files/copy.slt b/datafusion/sqllogictest/test_files/copy.slt index cd0a38a5e007..7dd85b3ae2d8 100644 --- a/datafusion/sqllogictest/test_files/copy.slt +++ b/datafusion/sqllogictest/test_files/copy.slt @@ -558,7 +558,6 @@ select * from validate_arrow_file_dict; c foo d bar - # Copy from table to folder of json query I COPY source_table to 'test_files/scratch/copy/table_arrow' STORED AS ARROW; @@ -632,3 +631,4 @@ COPY source_table to '/tmp/table.parquet' (row_group_size 55 + 102); # Copy using execution.keep_partition_by_columns with an invalid value query error DataFusion error: Invalid or Unsupported Configuration: provided value for 'execution.keep_partition_by_columns' was not recognized: "invalid_value" COPY source_table to '/tmp/table.parquet' OPTIONS (execution.keep_partition_by_columns invalid_value); + diff --git a/datafusion/sqllogictest/test_files/ddl.slt b/datafusion/sqllogictest/test_files/ddl.slt index 5e229075273d..aefc2672b539 100644 --- a/datafusion/sqllogictest/test_files/ddl.slt +++ b/datafusion/sqllogictest/test_files/ddl.slt @@ -637,46 +637,6 @@ select * from table_without_values; statement ok set datafusion.catalog.information_schema = true; -statement ok -CREATE OR REPLACE TABLE TABLE_WITH_NORMALIZATION(FIELD1 BIGINT, FIELD2 BIGINT); - -# Check table name is in lowercase -query TTTT -show create table table_with_normalization ----- -datafusion public table_with_normalization NULL - -# Check column name is in uppercase -query TTT -describe table_with_normalization ----- -field1 Int64 YES -field2 Int64 YES - -# Disable ident normalization -statement ok -set 
datafusion.sql_parser.enable_ident_normalization = false; - -statement ok -CREATE TABLE TABLE_WITHOUT_NORMALIZATION(FIELD1 BIGINT, FIELD2 BIGINT) AS VALUES (1,2); - -# Check table name is in uppercase -query TTTT -show create table TABLE_WITHOUT_NORMALIZATION ----- -datafusion public TABLE_WITHOUT_NORMALIZATION NULL - -# Check column name is in uppercase -query TTT -describe TABLE_WITHOUT_NORMALIZATION ----- -FIELD1 Int64 YES -FIELD2 Int64 YES - -statement ok -set datafusion.sql_parser.enable_ident_normalization = true; - - statement ok create table foo(x int); @@ -840,3 +800,31 @@ DROP TABLE t1; statement ok DROP TABLE t2; + +# Test memory table fields with correct nullable +statement ok +CREATE or replace TABLE table_with_pk ( + sn INT PRIMARY KEY NOT NULL, + ts TIMESTAMP WITH TIME ZONE NOT NULL, + currency VARCHAR(3) NOT NULL, + amount FLOAT + ) as VALUES + (0, '2022-01-01 06:00:00Z'::timestamp, 'EUR', 30.0), + (1, '2022-01-01 08:00:00Z'::timestamp, 'EUR', 50.0), + (2, '2022-01-01 11:30:00Z'::timestamp, 'TRY', 75.0), + (3, '2022-01-02 12:00:00Z'::timestamp, 'EUR', 200.0); + +query TTTTTT +show columns FROM table_with_pk; +---- +datafusion public table_with_pk sn Int32 NO +datafusion public table_with_pk ts Timestamp(Nanosecond, Some("+00:00")) NO +datafusion public table_with_pk currency Utf8 NO +datafusion public table_with_pk amount Float32 YES + +statement ok +drop table table_with_pk; + +statement ok +set datafusion.catalog.information_schema = false; + diff --git a/datafusion/sqllogictest/test_files/decimal.slt b/datafusion/sqllogictest/test_files/decimal.slt index f082a79c5508..089910785ad9 100644 --- a/datafusion/sqllogictest/test_files/decimal.slt +++ b/datafusion/sqllogictest/test_files/decimal.slt @@ -740,3 +740,10 @@ query R SELECT CAST('0' AS decimal(38,0)); ---- 0 + +query RR +SELECT + cast(cast('0' as decimal(3,0)) as decimal(2,0)), + cast(cast('5.20' as decimal(4,2)) as decimal(3,2)) +---- +0 5.2 diff --git a/datafusion/sqllogictest/test_files/errors.slt b/datafusion/sqllogictest/test_files/errors.slt index 5a94ba9c0583..a35a4d6f28dc 100644 --- a/datafusion/sqllogictest/test_files/errors.slt +++ b/datafusion/sqllogictest/test_files/errors.slt @@ -169,4 +169,19 @@ statement ok INSERT INTO tab0 VALUES(83,0,38); query error DataFusion error: Arrow error: Divide by zero error -SELECT DISTINCT - 84 FROM tab0 AS cor0 WHERE NOT + 96 / + col1 <= NULL GROUP BY col1, col0; \ No newline at end of file +SELECT DISTINCT - 84 FROM tab0 AS cor0 WHERE NOT + 96 / + col1 <= NULL GROUP BY col1, col0; + +statement ok +create table a(timestamp int, birthday int, ts int, tokens int, amp int, staamp int); + +query error DataFusion error: Schema error: No field named timetamp\. Did you mean 'a\.timestamp'\?\. +select timetamp from a; + +query error DataFusion error: Schema error: No field named dadsada\. Valid fields are a\.timestamp, a\.birthday, a\.ts, a\.tokens, a\.amp, a\.staamp\. +select dadsada from a; + +query error DataFusion error: Schema error: No field named ammp\. Did you mean 'a\.amp'\?\. 
+select ammp from a; + +statement ok +drop table a; \ No newline at end of file diff --git a/datafusion/sqllogictest/test_files/expr.slt b/datafusion/sqllogictest/test_files/expr.slt index a0264c43622f..7980b180ae68 100644 --- a/datafusion/sqllogictest/test_files/expr.slt +++ b/datafusion/sqllogictest/test_files/expr.slt @@ -324,6 +324,16 @@ SELECT ascii('x') ---- 120 +query I +SELECT ascii('222') +---- +50 + +query I +SELECT ascii('0xa') +---- +48 + query I SELECT ascii(NULL) ---- @@ -571,7 +581,7 @@ select repeat('-1.2', arrow_cast(3, 'Int32')); ---- -1.2-1.2-1.2 -query error DataFusion error: Error during planning: Internal error: Function 'repeat' expects TypeSignatureClass::Native\(LogicalType\(Native\(Int64\), Int64\)\) but received Float64 +query error DataFusion error: Error during planning: Internal error: Expect TypeSignatureClass::Native\(LogicalType\(Native\(Int64\), Int64\)\) but received NativeType::Float64, DataType: Float64 select repeat('-1.2', 3.2); query T diff --git a/datafusion/sqllogictest/test_files/functions.slt b/datafusion/sqllogictest/test_files/functions.slt index b9699dfd5c06..de1dbf74c29b 100644 --- a/datafusion/sqllogictest/test_files/functions.slt +++ b/datafusion/sqllogictest/test_files/functions.slt @@ -720,6 +720,14 @@ select count(distinct u) from uuid_table; ---- 2 +# must be valid uuidv4 format +query B +SELECT REGEXP_LIKE(uuid(), + '^[0-9a-f]{8}-[0-9a-f]{4}-4[0-9a-f]{3}-[89ab][0-9a-f]{3}-[0-9a-f]{12}$') + AS is_valid; +---- +true + statement ok drop table uuid_table diff --git a/datafusion/sqllogictest/test_files/ident_normalization.slt b/datafusion/sqllogictest/test_files/ident_normalization.slt new file mode 100644 index 000000000000..996093c3ad9c --- /dev/null +++ b/datafusion/sqllogictest/test_files/ident_normalization.slt @@ -0,0 +1,134 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +# Enable information_schema, so we can execute show create table +statement ok +set datafusion.catalog.information_schema = true; + +# Check ident normalization is enabled by default + +statement ok +CREATE OR REPLACE TABLE TABLE_WITH_NORMALIZATION(FIELD1 BIGINT, FIELD2 BIGINT); + +# Check table name is in lowercase +query TTTT +show create table table_with_normalization +---- +datafusion public table_with_normalization NULL + +# Check column name is in uppercase +query TTT +describe table_with_normalization +---- +field1 Int64 YES +field2 Int64 YES + +# Disable ident normalization +statement ok +set datafusion.sql_parser.enable_ident_normalization = false; + +statement ok +CREATE TABLE TABLE_WITHOUT_NORMALIZATION(FIELD1 BIGINT, FIELD2 BIGINT) AS VALUES (1,2); + +# Check table name is in uppercase +query TTTT +show create table TABLE_WITHOUT_NORMALIZATION +---- +datafusion public TABLE_WITHOUT_NORMALIZATION NULL + +# Check column name is in uppercase +query TTT +describe TABLE_WITHOUT_NORMALIZATION +---- +FIELD1 Int64 YES +FIELD2 Int64 YES + +statement ok +DROP TABLE TABLE_WITHOUT_NORMALIZATION + +############ +## Column Name Normalization +############ + +# Table x (lowercase) with a column named "A" (uppercase) +statement ok +create table x as select 1 "A" + +query TTT +describe x +---- +A Int64 NO + +# Expect error as 'a' is not a column -- "A" is and the identifiers +# are not normalized +query error DataFusion error: Schema error: No field named a\. Valid fields are x\."A"\. +select a from x; + +# should work (note the uppercase 'A') +query I +select A from x; +---- +1 + +statement ok +drop table x; + +############ +## Table Name Normalization +############ + +# Table Y (uppercase) with a column named a (lower case) +statement ok +create table Y as select 1 a; + +query TTT +describe Y +---- +a Int64 NO + +# Expect error as y is not a a table -- "Y" is +query error DataFusion error: Error during planning: table 'datafusion\.public\.y' not found +select * from y; + +# should work (note the uppercase 'Y') +query I +select * from Y; +---- +1 + +statement ok +drop table Y; + +############ +## Function Name Normalization +############ + +## Check function names are still normalized even though column names are not +query I +SELECT length('str'); +---- +3 + +query I +SELECT LENGTH('str'); +---- +3 + +query T +SELECT CONCAT('Hello', 'World') +---- +HelloWorld diff --git a/datafusion/sqllogictest/test_files/identifiers.slt b/datafusion/sqllogictest/test_files/identifiers.slt index 755d617e7a2a..e5eec3bf7f2c 100644 --- a/datafusion/sqllogictest/test_files/identifiers.slt +++ b/datafusion/sqllogictest/test_files/identifiers.slt @@ -90,16 +90,16 @@ drop table case_insensitive_test statement ok CREATE TABLE test("Column1" string) AS VALUES ('content1'); -statement error DataFusion error: Schema error: No field named column1\. Valid fields are test\."Column1"\. +statement error DataFusion error: Schema error: No field named column1\. Did you mean 'test\.Column1'\?\. SELECT COLumn1 from test -statement error DataFusion error: Schema error: No field named column1\. Valid fields are test\."Column1"\. +statement error DataFusion error: Schema error: No field named column1\. Did you mean 'test\.Column1'\?\. SELECT Column1 from test -statement error DataFusion error: Schema error: No field named column1\. Valid fields are test\."Column1"\. +statement error DataFusion error: Schema error: No field named column1\. Did you mean 'test\.Column1'\?\. 
SELECT column1 from test -statement error DataFusion error: Schema error: No field named column1\. Valid fields are test\."Column1"\. +statement error DataFusion error: Schema error: No field named column1\. Did you mean 'test\.Column1'\?\. SELECT "column1" from test statement ok diff --git a/datafusion/sqllogictest/test_files/insert.slt b/datafusion/sqllogictest/test_files/insert.slt index cbc989841ab3..ee76ee1c5511 100644 --- a/datafusion/sqllogictest/test_files/insert.slt +++ b/datafusion/sqllogictest/test_files/insert.slt @@ -296,8 +296,11 @@ insert into table_without_values(field1) values(3); 1 # insert NULL values for the missing column (field1), but column is non-nullable -statement error Execution error: Invalid batch column at '0' has null but schema specifies non-nullable +statement error insert into table_without_values(field2) values(300); +---- +DataFusion error: Execution error: Invalid batch column at '0' has null but schema specifies non-nullable + statement error Invalid argument error: Column 'column1' is declared as non-nullable but contains null values insert into table_without_values values(NULL, 300); @@ -358,7 +361,7 @@ statement ok create table test_column_defaults( a int, b int not null default null, - c int default 100*2+300, + c int default 100*2+300, d text default lower('DEFAULT_TEXT'), e timestamp default now() ) @@ -368,8 +371,11 @@ insert into test_column_defaults values(1, 10, 100, 'ABC', now()) ---- 1 -statement error DataFusion error: Execution error: Invalid batch column at '1' has null but schema specifies non-nullable +statement error insert into test_column_defaults(a) values(2) +---- +DataFusion error: Execution error: Invalid batch column at '1' has null but schema specifies non-nullable + query I insert into test_column_defaults(b) values(20) @@ -412,7 +418,7 @@ statement ok create table test_column_defaults( a int, b int not null default null, - c int default 100*2+300, + c int default 100*2+300, d text default lower('DEFAULT_TEXT'), e timestamp default now() ) as values(1, 10, 100, 'ABC', now()) diff --git a/datafusion/sqllogictest/test_files/insert_to_external.slt b/datafusion/sqllogictest/test_files/insert_to_external.slt index c5fa2b4e1a51..ee1d67c5e26d 100644 --- a/datafusion/sqllogictest/test_files/insert_to_external.slt +++ b/datafusion/sqllogictest/test_files/insert_to_external.slt @@ -60,6 +60,7 @@ STORED AS parquet LOCATION 'test_files/scratch/insert_to_external/parquet_types_partitioned/' PARTITIONED BY (b); +#query error here because PARTITIONED BY (b) will make the b nullable to false query I insert into dictionary_encoded_parquet_partitioned select * from dictionary_encoded_values @@ -81,6 +82,7 @@ STORED AS arrow LOCATION 'test_files/scratch/insert_to_external/arrow_dict_partitioned/' PARTITIONED BY (b); +#query error here because PARTITIONED BY (b) will make the b nullable to false query I insert into dictionary_encoded_arrow_partitioned select * from dictionary_encoded_values @@ -543,8 +545,11 @@ insert into table_without_values(field1) values(3); 1 # insert NULL values for the missing column (field1), but column is non-nullable -statement error Execution error: Invalid batch column at '0' has null but schema specifies non-nullable +statement error insert into table_without_values(field2) values(300); +---- +DataFusion error: Execution error: Invalid batch column at '0' has null but schema specifies non-nullable + statement error Invalid argument error: Column 'column1' is declared as non-nullable but contains null values insert 
into table_without_values values(NULL, 300); @@ -581,8 +586,11 @@ insert into test_column_defaults values(1, 10, 100, 'ABC', now()) ---- 1 -statement error DataFusion error: Execution error: Invalid batch column at '1' has null but schema specifies non-nullable +statement error insert into test_column_defaults(a) values(2) +---- +DataFusion error: Execution error: Invalid batch column at '1' has null but schema specifies non-nullable + query I insert into test_column_defaults(b) values(20) diff --git a/datafusion/sqllogictest/test_files/join.slt.part b/datafusion/sqllogictest/test_files/join.slt.part index c88f419a9cb2..21126a747967 100644 --- a/datafusion/sqllogictest/test_files/join.slt.part +++ b/datafusion/sqllogictest/test_files/join.slt.part @@ -94,7 +94,7 @@ statement ok set datafusion.execution.batch_size = 4096; # left semi with wrong where clause -query error DataFusion error: Schema error: No field named t2\.t2_id\. Valid fields are t1\.t1_id, t1\.t1_name, t1\.t1_int\. +query error DataFusion error: Schema error: No field named t2\.t2_id\. Did you mean 't1\.t1_id'\?\. SELECT t1.t1_id, t1.t1_name, t1.t1_int FROM t1 LEFT SEMI JOIN t2 ON t1.t1_id = t2.t2_id @@ -1312,3 +1312,78 @@ SELECT a+b*2, statement ok drop table t1; + +# Test that equivalent classes are projected correctly. + +statement ok +create table pairs(x int, y int) as values (1,1), (2,2), (3,3); + +statement ok +create table f(a int) as values (1), (2), (3); + +statement ok +create table s(b int) as values (1), (2), (3); + +statement ok +set datafusion.optimizer.repartition_joins = true; + +statement ok +set datafusion.execution.target_partitions = 16; + +# After the filter applying (x = y) we can join by both x and y, +# partitioning only once. + +query TT +explain +SELECT * FROM +(SELECT x+1 AS col0, y+1 AS col1 FROM PAIRS WHERE x == y) +JOIN f +ON col0 = f.a +JOIN s +ON col1 = s.b +---- +logical_plan +01)Inner Join: col1 = CAST(s.b AS Int64) +02)--Inner Join: col0 = CAST(f.a AS Int64) +03)----Projection: CAST(pairs.x AS Int64) + Int64(1) AS col0, CAST(pairs.y AS Int64) + Int64(1) AS col1 +04)------Filter: pairs.y = pairs.x +05)--------TableScan: pairs projection=[x, y] +06)----TableScan: f projection=[a] +07)--TableScan: s projection=[b] +physical_plan +01)CoalesceBatchesExec: target_batch_size=8192 +02)--HashJoinExec: mode=Partitioned, join_type=Inner, on=[(col1@1, CAST(s.b AS Int64)@1)], projection=[col0@0, col1@1, a@2, b@3] +03)----ProjectionExec: expr=[col0@1 as col0, col1@2 as col1, a@0 as a] +04)------CoalesceBatchesExec: target_batch_size=8192 +05)--------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(CAST(f.a AS Int64)@1, col0@0)], projection=[a@0, col0@2, col1@3] +06)----------CoalesceBatchesExec: target_batch_size=8192 +07)------------RepartitionExec: partitioning=Hash([CAST(f.a AS Int64)@1], 16), input_partitions=1 +08)--------------ProjectionExec: expr=[a@0 as a, CAST(a@0 AS Int64) as CAST(f.a AS Int64)] +09)----------------DataSourceExec: partitions=1, partition_sizes=[1] +10)----------CoalesceBatchesExec: target_batch_size=8192 +11)------------RepartitionExec: partitioning=Hash([col0@0], 16), input_partitions=16 +12)--------------ProjectionExec: expr=[CAST(x@0 AS Int64) + 1 as col0, CAST(y@1 AS Int64) + 1 as col1] +13)----------------RepartitionExec: partitioning=RoundRobinBatch(16), input_partitions=1 +14)------------------CoalesceBatchesExec: target_batch_size=8192 +15)--------------------FilterExec: y@1 = x@0 +16)----------------------DataSourceExec: partitions=1, partition_sizes=[1] 
+17)----CoalesceBatchesExec: target_batch_size=8192 +18)------RepartitionExec: partitioning=Hash([CAST(s.b AS Int64)@1], 16), input_partitions=1 +19)--------ProjectionExec: expr=[b@0 as b, CAST(b@0 AS Int64) as CAST(s.b AS Int64)] +20)----------DataSourceExec: partitions=1, partition_sizes=[1] + +statement ok +drop table pairs; + +statement ok +drop table f; + +statement ok +drop table s; + +# Reset the configs to old values. +statement ok +set datafusion.execution.target_partitions = 4; + +statement ok +set datafusion.optimizer.repartition_joins = false; diff --git a/datafusion/sqllogictest/test_files/map.slt b/datafusion/sqllogictest/test_files/map.slt index 29ef506aa070..42a4ba621801 100644 --- a/datafusion/sqllogictest/test_files/map.slt +++ b/datafusion/sqllogictest/test_files/map.slt @@ -188,7 +188,7 @@ SELECT MAP([[1,2], [3,4]], ['a', 'b']); query error SELECT MAP() -query error DataFusion error: Execution error: map requires exactly 2 arguments, got 1 instead +query error DataFusion error: Execution error: map function requires 2 arguments, got 1 SELECT MAP(['POST', 'HEAD']) query error DataFusion error: Execution error: Expected list, large_list or fixed_size_list, got Null @@ -592,6 +592,43 @@ select map_extract(column1, 1), map_extract(column1, 5), map_extract(column1, 7) [NULL] [NULL] [[1, NULL, 3]] [NULL] [NULL] [NULL] +query ? +select column1[1] from map_array_table_1; +---- +[1, NULL, 3] +NULL +NULL +NULL + +query ? +select column1[-1000 + 1001] from map_array_table_1; +---- +[1, NULL, 3] +NULL +NULL +NULL + +# test for negative scenario +query ? +SELECT column1[-1] FROM map_array_table_1; +---- +NULL +NULL +NULL +NULL + +query ? +SELECT column1[1000] FROM map_array_table_1; +---- +NULL +NULL +NULL +NULL + + +query error DataFusion error: Arrow error: Invalid argument error +SELECT column1[NULL] FROM map_array_table_1; + query ??? 
select map_extract(column1, column2), map_extract(column1, column3), map_extract(column1, column4) from map_array_table_1; ---- @@ -722,3 +759,28 @@ drop table map_array_table_1; statement ok drop table map_array_table_2; + + +statement ok +create table tt as values(MAP{[1,2,3]:1}, MAP {{'a':1, 'b':2}:2}, MAP{true: 3}); + +# accessing using an array +query I +select column1[make_array(1, 2, 3)] from tt; +---- +1 + +# accessing using a struct +query I +select column2[{a:1, b: 2}] from tt; +---- +2 + +# accessing using Bool +query I +select column3[true] from tt; +---- +3 + +statement ok +drop table tt; diff --git a/datafusion/sqllogictest/test_files/regexp.slt b/datafusion/sqllogictest/test_files/regexp.slt index ce39434e6827..44ba61e877d9 100644 --- a/datafusion/sqllogictest/test_files/regexp.slt +++ b/datafusion/sqllogictest/test_files/regexp.slt @@ -477,8 +477,8 @@ SELECT 'foo\nbar\nbaz' ~ 'bar'; true statement error -Error during planning: Cannot infer common argument type for regex operation List(Field { name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata -: {} }) ~ List(Field { name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) +Error during planning: Cannot infer common argument type for regex operation List(Field { name: "item", data_type: Int64, nullable: true, dict_is_ordered: false, metadata +: {} }) ~ List(Field { name: "item", data_type: Int64, nullable: true, dict_is_ordered: false, metadata: {} }) select [1,2] ~ [3]; query B diff --git a/datafusion/sqllogictest/test_files/scalar.slt b/datafusion/sqllogictest/test_files/scalar.slt index a6826a6ef108..66413775b393 100644 --- a/datafusion/sqllogictest/test_files/scalar.slt +++ b/datafusion/sqllogictest/test_files/scalar.slt @@ -1807,24 +1807,6 @@ SELECT acos(); statement error SELECT isnan(); -# turn off enable_ident_normalization -statement ok -set datafusion.sql_parser.enable_ident_normalization = false; - -query I -SELECT LENGTH('str'); ----- -3 - -query T -SELECT CONCAT('Hello', 'World') ----- -HelloWorld - -# turn on enable_ident_normalization -statement ok -set datafusion.sql_parser.enable_ident_normalization = true; - query I SELECT LENGTH('str'); ---- diff --git a/datafusion/sqllogictest/test_files/subquery.slt b/datafusion/sqllogictest/test_files/subquery.slt index b0c9ad93e155..264392fc1017 100644 --- a/datafusion/sqllogictest/test_files/subquery.slt +++ b/datafusion/sqllogictest/test_files/subquery.slt @@ -834,7 +834,7 @@ query TT explain SELECT t1_id, (SELECT count(*) as _cnt FROM t2 WHERE t2.t2_int = t1.t1_int) as cnt from t1 ---- logical_plan -01)Projection: t1.t1_id, CASE WHEN __scalar_sq_1.__always_true IS NULL THEN Int64(0) AS _cnt ELSE __scalar_sq_1._cnt END AS cnt +01)Projection: t1.t1_id, CASE WHEN __scalar_sq_1.__always_true IS NULL THEN Int64(0) ELSE __scalar_sq_1._cnt END AS cnt 02)--Left Join: t1.t1_int = __scalar_sq_1.t2_int 03)----TableScan: t1 projection=[t1_id, t1_int] 04)----SubqueryAlias: __scalar_sq_1 @@ -855,7 +855,7 @@ query TT explain SELECT t1_id, (SELECT count(*) + 2 as _cnt FROM t2 WHERE t2.t2_int = t1.t1_int) from t1 ---- logical_plan -01)Projection: t1.t1_id, CASE WHEN __scalar_sq_1.__always_true IS NULL THEN Int64(2) AS _cnt ELSE __scalar_sq_1._cnt END AS _cnt +01)Projection: t1.t1_id, CASE WHEN __scalar_sq_1.__always_true IS NULL THEN Int64(2) ELSE __scalar_sq_1._cnt END AS _cnt 02)--Left Join: t1.t1_int = __scalar_sq_1.t2_int 03)----TableScan: t1 projection=[t1_id, t1_int] 04)----SubqueryAlias: 
__scalar_sq_1 @@ -922,7 +922,7 @@ query TT explain SELECT t1_id, (SELECT count(*) + 2 as cnt_plus_2 FROM t2 WHERE t2.t2_int = t1.t1_int having count(*) = 0) from t1 ---- logical_plan -01)Projection: t1.t1_id, CASE WHEN __scalar_sq_1.__always_true IS NULL THEN Int64(2) AS cnt_plus_2 WHEN __scalar_sq_1.count(*) != Int64(0) THEN NULL ELSE __scalar_sq_1.cnt_plus_2 END AS cnt_plus_2 +01)Projection: t1.t1_id, CASE WHEN __scalar_sq_1.__always_true IS NULL THEN Int64(2) WHEN __scalar_sq_1.count(*) != Int64(0) THEN NULL ELSE __scalar_sq_1.cnt_plus_2 END AS cnt_plus_2 02)--Left Join: t1.t1_int = __scalar_sq_1.t2_int 03)----TableScan: t1 projection=[t1_id, t1_int] 04)----SubqueryAlias: __scalar_sq_1 diff --git a/datafusion/sqllogictest/test_files/union_by_name.slt b/datafusion/sqllogictest/test_files/union_by_name.slt new file mode 100644 index 000000000000..0ba4c32ee5be --- /dev/null +++ b/datafusion/sqllogictest/test_files/union_by_name.slt @@ -0,0 +1,288 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# Portions of this file are derived from DuckDB and are licensed +# under the MIT License (see below). + +# Copyright 2018-2025 Stichting DuckDB Foundation + +# Permission is hereby granted, free of charge, to any person +# obtaining a copy of this software and associated documentation +# files (the "Software"), to deal in the Software without restriction, +# including without limitation the rights to use, copy, modify, merge, +# publish, distribute, sublicense, and/or sell copies of the Software, +# and to permit persons to whom the Software is furnished to do so, +# subject to the following conditions: + +# The above copyright notice and this permission notice shall be included +# in all copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. 
+ +statement ok +CREATE TABLE t1 (x INT, y INT); + +statement ok +INSERT INTO t1 VALUES (3, 3), (3, 3), (1, 1); + +statement ok +CREATE TABLE t2 (y INT, z INT); + +statement ok +INSERT INTO t2 VALUES (2, 2), (4, 4); + + +# Test binding +query I +SELECT t1.x FROM t1 UNION BY NAME SELECT x FROM t1 ORDER BY t1.x; +---- +1 +3 + +query I +SELECT t1.x FROM t1 UNION ALL BY NAME SELECT x FROM t1 ORDER BY t1.x; +---- +1 +1 +3 +3 +3 +3 + +query I +SELECT x FROM t1 UNION BY NAME SELECT x FROM t1 ORDER BY t1.x; +---- +1 +3 + +query I +SELECT x FROM t1 UNION ALL BY NAME SELECT x FROM t1 ORDER BY t1.x; +---- +1 +1 +3 +3 +3 +3 + +query II +(SELECT x FROM t1 UNION ALL SELECT x FROM t1) UNION BY NAME SELECT 5 ORDER BY x; +---- +NULL 1 +NULL 3 +5 NULL + +# TODO: This should pass, but the sanity checker isn't allowing it. +# Commenting out the ordering check in the sanity checker produces the correct result. +query error +(SELECT x FROM t1 UNION ALL SELECT x FROM t1) UNION ALL BY NAME SELECT 5 ORDER BY x; +---- +DataFusion error: SanityCheckPlan +caused by +Error during planning: Plan: ["SortPreservingMergeExec: [x@1 ASC NULLS LAST]", " UnionExec", " SortExec: expr=[x@1 ASC NULLS LAST], preserve_partitioning=[true]", " ProjectionExec: expr=[NULL as Int64(5), x@0 as x]", " UnionExec", " DataSourceExec: partitions=1, partition_sizes=[1]", " DataSourceExec: partitions=1, partition_sizes=[1]", " ProjectionExec: expr=[5 as Int64(5), NULL as x]", " PlaceholderRowExec"] does not satisfy order requirements: [x@1 ASC NULLS LAST]. Child-0 order: [] + + +query II +(SELECT x FROM t1 UNION ALL SELECT y FROM t1) UNION BY NAME SELECT 5 ORDER BY x; +---- +NULL 1 +NULL 3 +5 NULL + +# TODO: This should pass, but the sanity checker isn't allowing it. +# Commenting out the ordering check in the sanity checker produces the correct result. +query error +(SELECT x FROM t1 UNION ALL SELECT y FROM t1) UNION ALL BY NAME SELECT 5 ORDER BY x; +---- +DataFusion error: SanityCheckPlan +caused by +Error during planning: Plan: ["SortPreservingMergeExec: [x@1 ASC NULLS LAST]", " UnionExec", " SortExec: expr=[x@1 ASC NULLS LAST], preserve_partitioning=[true]", " ProjectionExec: expr=[NULL as Int64(5), x@0 as x]", " UnionExec", " DataSourceExec: partitions=1, partition_sizes=[1]", " ProjectionExec: expr=[y@0 as x]", " DataSourceExec: partitions=1, partition_sizes=[1]", " ProjectionExec: expr=[5 as Int64(5), NULL as x]", " PlaceholderRowExec"] does not satisfy order requirements: [x@1 ASC NULLS LAST]. Child-0 order: [] + + + +# Ambiguous name + +statement error DataFusion error: Schema error: No field named t1.x. Valid fields are a, b. +SELECT x AS a FROM t1 UNION BY NAME SELECT x AS b FROM t1 ORDER BY t1.x; + +query II +(SELECT y FROM t1 UNION ALL SELECT x FROM t1) UNION BY NAME (SELECT z FROM t2 UNION ALL SELECT y FROM t2) ORDER BY y, z; +---- +1 NULL +3 NULL +NULL 2 +NULL 4 + +query II +(SELECT y FROM t1 UNION ALL SELECT x FROM t1) UNION ALL BY NAME (SELECT z FROM t2 UNION ALL SELECT y FROM t2) ORDER BY y, z; +---- +1 NULL +1 NULL +3 NULL +3 NULL +3 NULL +3 NULL +NULL 2 +NULL 2 +NULL 4 +NULL 4 + +# Limit + +query III +SELECT 1 UNION BY NAME SELECT * FROM unnest(range(2, 100)) UNION BY NAME SELECT 999 ORDER BY 3, 1 LIMIT 5; +---- +NULL NULL 2 +NULL NULL 3 +NULL NULL 4 +NULL NULL 5 +NULL NULL 6 + +# TODO: This should pass, but the sanity checker isn't allowing it. +# Commenting out the ordering check in the sanity checker produces the correct result. 
+query error +SELECT 1 UNION ALL BY NAME SELECT * FROM unnest(range(2, 100)) UNION ALL BY NAME SELECT 999 ORDER BY 3, 1 LIMIT 5; +---- +DataFusion error: SanityCheckPlan +caused by +Error during planning: Plan: ["SortPreservingMergeExec: [UNNEST(range(Int64(2),Int64(100)))@2 ASC NULLS LAST, Int64(1)@0 ASC NULLS LAST], fetch=5", " UnionExec", " SortExec: TopK(fetch=5), expr=[UNNEST(range(Int64(2),Int64(100)))@2 ASC NULLS LAST], preserve_partitioning=[true]", " ProjectionExec: expr=[Int64(1)@0 as Int64(1), NULL as Int64(999), UNNEST(range(Int64(2),Int64(100)))@1 as UNNEST(range(Int64(2),Int64(100)))]", " UnionExec", " ProjectionExec: expr=[1 as Int64(1), NULL as UNNEST(range(Int64(2),Int64(100)))]", " PlaceholderRowExec", " ProjectionExec: expr=[NULL as Int64(1), __unnest_placeholder(range(Int64(2),Int64(100)),depth=1)@0 as UNNEST(range(Int64(2),Int64(100)))]", " UnnestExec", " ProjectionExec: expr=[[2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99] as __unnest_placeholder(range(Int64(2),Int64(100)))]", " PlaceholderRowExec", " ProjectionExec: expr=[NULL as Int64(1), 999 as Int64(999), NULL as UNNEST(range(Int64(2),Int64(100)))]", " PlaceholderRowExec"] does not satisfy order requirements: [UNNEST(range(Int64(2),Int64(100)))@2 ASC NULLS LAST, Int64(1)@0 ASC NULLS LAST]. Child-0 order: [] + + +# Order by + +query III +SELECT x, y FROM t1 UNION BY NAME SELECT y, z FROM t2 ORDER BY y; +---- +1 1 NULL +NULL 2 2 +3 3 NULL +NULL 4 4 + +query III +SELECT x, y FROM t1 UNION ALL BY NAME SELECT y, z FROM t2 ORDER BY y; +---- +1 1 NULL +NULL 2 2 +3 3 NULL +3 3 NULL +NULL 4 4 + +query III +SELECT x, y FROM t1 UNION BY NAME SELECT y, z FROM t2 ORDER BY 3, 1; +---- +NULL 2 2 +NULL 4 4 +1 1 NULL +3 3 NULL + +query III +SELECT x, y FROM t1 UNION ALL BY NAME SELECT y, z FROM t2 ORDER BY 3, 1; +---- +NULL 2 2 +NULL 4 4 +1 1 NULL +3 3 NULL +3 3 NULL + +statement error +SELECT x, y FROM t1 UNION BY NAME SELECT y, z FROM t2 ORDER BY 4; +---- +DataFusion error: Error during planning: Order by column out of bounds, specified: 4, max: 3 + + +statement error +SELECT x, y FROM t1 UNION ALL BY NAME SELECT y, z FROM t2 ORDER BY 4; +---- +DataFusion error: Error during planning: Order by column out of bounds, specified: 4, max: 3 + + +# Multi set operations + +query IIII rowsort +(SELECT 1 UNION BY NAME SELECT x, y FROM t1) UNION BY NAME SELECT y, z FROM t2; +---- +1 NULL NULL NULL +NULL 1 1 NULL +NULL 3 3 NULL +NULL NULL 2 2 +NULL NULL 4 4 + +query IIII rowsort +(SELECT 1 UNION ALL BY NAME SELECT x, y FROM t1) UNION ALL BY NAME SELECT y, z FROM t2; +---- +1 NULL NULL NULL +NULL 1 1 NULL +NULL 3 3 NULL +NULL 3 3 NULL +NULL NULL 2 2 +NULL NULL 4 4 + +query III +SELECT x, y FROM t1 UNION BY NAME (SELECT y, z FROM t2 INTERSECT SELECT 2, 2 as two FROM t1 ORDER BY 1) ORDER BY 1; +---- +1 1 NULL +3 3 NULL +NULL 2 2 + +query III +SELECT x, y FROM t1 UNION ALL BY NAME (SELECT y, z FROM t2 INTERSECT SELECT 2, 2 as two FROM t1 ORDER BY 1) ORDER BY 1; +---- +1 1 NULL +3 3 NULL +3 3 NULL +NULL 2 2 + +query III +(SELECT x, y FROM t1 UNION BY NAME SELECT y, z FROM t2 ORDER BY 1) EXCEPT SELECT NULL, 2, 2 as two FROM t1 ORDER BY 1; +---- +1 1 NULL +3 3 NULL +NULL 4 4 + +# Alias in select list + +query II 
+SELECT x as a FROM t1 UNION BY NAME SELECT x FROM t1 ORDER BY 1, 2; +---- +1 NULL +3 NULL +NULL 1 +NULL 3 + +query II +SELECT x as a FROM t1 UNION ALL BY NAME SELECT x FROM t1 ORDER BY 1, 2; +---- +1 NULL +3 NULL +3 NULL +NULL 1 +NULL 3 +NULL 3 + +# Different types + +query T rowsort +SELECT '0' as c UNION ALL BY NAME SELECT 0 as c; +---- +0 +0 diff --git a/datafusion/sqllogictest/test_files/union_function.slt b/datafusion/sqllogictest/test_files/union_function.slt new file mode 100644 index 000000000000..9c70b1011f58 --- /dev/null +++ b/datafusion/sqllogictest/test_files/union_function.slt @@ -0,0 +1,47 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +########## +## UNION DataType Tests +########## + +query ?I +select union_column, union_extract(union_column, 'int') from union_table; +---- +{int=1} 1 +{int=2} 2 + +query error DataFusion error: Execution error: field bool not found on union +select union_extract(union_column, 'bool') from union_table; + +query error DataFusion error: Error during planning: 'union_extract' does not support zero arguments +select union_extract() from union_table; + +query error DataFusion error: Error during planning: The function 'union_extract' expected 2 arguments but received 1 +select union_extract(union_column) from union_table; + +query error DataFusion error: Error during planning: The function 'union_extract' expected 2 arguments but received 1 +select union_extract('a') from union_table; + +query error DataFusion error: Execution error: union_extract first argument must be a union, got Utf8 instead +select union_extract('a', union_column) from union_table; + +query error DataFusion error: Execution error: union_extract second argument must be a non\-null string literal, got Int64 instead +select union_extract(union_column, 1) from union_table; + +query error DataFusion error: Error during planning: The function 'union_extract' expected 2 arguments but received 3 +select union_extract(union_column, 'a', 'b') from union_table; diff --git a/datafusion/substrait/Cargo.toml b/datafusion/substrait/Cargo.toml index f13d2b77a787..3e3ea7843ac9 100644 --- a/datafusion/substrait/Cargo.toml +++ b/datafusion/substrait/Cargo.toml @@ -41,6 +41,7 @@ pbjson-types = { workspace = true } prost = { workspace = true } substrait = { version = "0.53", features = ["serde"] } url = { workspace = true } +tokio = { workspace = true, features = ["fs"] } [dev-dependencies] datafusion = { workspace = true, features = ["nested_expressions"] } diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index 89112e3fe84e..da8613781d69 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -27,7 +27,7 @@ use 
datafusion::common::{ substrait_datafusion_err, substrait_err, DFSchema, DFSchemaRef, TableReference, }; use datafusion::datasource::provider_as_source; -use datafusion::logical_expr::expr::{Exists, InSubquery, Sort}; +use datafusion::logical_expr::expr::{Exists, InSubquery, Sort, WindowFunctionParams}; use datafusion::logical_expr::{ Aggregate, BinaryExpr, Case, Cast, EmptyRelation, Expr, ExprSchemable, Extension, @@ -2223,12 +2223,19 @@ pub async fn from_window_function( Ok(Expr::WindowFunction(expr::WindowFunction { fun, - args: from_substrait_func_args(consumer, &window.arguments, input_schema).await?, - partition_by: from_substrait_rex_vec(consumer, &window.partitions, input_schema) + params: WindowFunctionParams { + args: from_substrait_func_args(consumer, &window.arguments, input_schema) + .await?, + partition_by: from_substrait_rex_vec( + consumer, + &window.partitions, + input_schema, + ) .await?, - order_by, - window_frame, - null_treatment: None, + order_by, + window_frame, + null_treatment: None, + }, })) } @@ -3361,7 +3368,7 @@ mod test { match from_substrait_rex(&consumer, &substrait, &DFSchema::empty()).await? { Expr::WindowFunction(window_function) => { - assert_eq!(window_function.order_by.len(), 1) + assert_eq!(window_function.params.order_by.len(), 1) } _ => panic!("expr was not a WindowFunction"), }; diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index 42c226174932..36e89b8205ea 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -52,7 +52,8 @@ use datafusion::common::{ use datafusion::execution::registry::SerializerRegistry; use datafusion::execution::SessionState; use datafusion::logical_expr::expr::{ - Alias, BinaryExpr, Case, Cast, GroupingSet, InList, InSubquery, WindowFunction, + AggregateFunctionParams, Alias, BinaryExpr, Case, Cast, GroupingSet, InList, + InSubquery, WindowFunction, WindowFunctionParams, }; use datafusion::logical_expr::{expr, Between, JoinConstraint, LogicalPlan, Operator}; use datafusion::prelude::Expr; @@ -368,7 +369,7 @@ pub trait SubstraitProducer: Send + Sync + Sized { } } -struct DefaultSubstraitProducer<'a> { +pub struct DefaultSubstraitProducer<'a> { extensions: Extensions, serializer_registry: &'a dyn SerializerRegistry, } @@ -1208,11 +1209,14 @@ pub fn from_aggregate_function( ) -> Result { let expr::AggregateFunction { func, - args, - distinct, - filter, - order_by, - null_treatment: _null_treatment, + params: + AggregateFunctionParams { + args, + distinct, + filter, + order_by, + null_treatment: _null_treatment, + }, } = agg_fn; let sorts = if let Some(order_by) = order_by { order_by @@ -1612,11 +1616,14 @@ pub fn from_window_function( ) -> Result { let WindowFunction { fun, - args, - partition_by, - order_by, - window_frame, - null_treatment: _, + params: + WindowFunctionParams { + args, + partition_by, + order_by, + window_frame, + null_treatment: _, + }, } = window_fn; // function reference let function_anchor = producer.register_function(fun.to_string()); diff --git a/datafusion/substrait/src/physical_plan/consumer.rs b/datafusion/substrait/src/physical_plan/consumer.rs index ce056ddac664..7bbdfc2a5d94 100644 --- a/datafusion/substrait/src/physical_plan/consumer.rs +++ b/datafusion/substrait/src/physical_plan/consumer.rs @@ -152,7 +152,7 @@ pub async fn from_substrait_rel( } } - Ok(base_config.new_exec() as Arc) + Ok(base_config.build() as Arc) } _ => not_impl_err!( "Only LocalFile reads are 
supported when parsing physical" diff --git a/datafusion/substrait/src/serializer.rs b/datafusion/substrait/src/serializer.rs index 4278671777fd..4a9e5d55ce05 100644 --- a/datafusion/substrait/src/serializer.rs +++ b/datafusion/substrait/src/serializer.rs @@ -22,42 +22,59 @@ use datafusion::error::Result; use datafusion::prelude::*; use prost::Message; +use std::path::Path; use substrait::proto::Plan; +use tokio::{ + fs::OpenOptions, + io::{AsyncReadExt, AsyncWriteExt}, +}; -use std::fs::OpenOptions; -use std::io::{Read, Write}; +/// Plans a sql and serializes the generated logical plan to bytes. +/// The bytes are then written into a file at `path`. +/// +/// Returns an error if the file already exists. +pub async fn serialize( + sql: &str, + ctx: &SessionContext, + path: impl AsRef, +) -> Result<()> { + let protobuf_out = serialize_bytes(sql, ctx).await?; -#[allow(clippy::suspicious_open_options)] -pub async fn serialize(sql: &str, ctx: &SessionContext, path: &str) -> Result<()> { - let protobuf_out = serialize_bytes(sql, ctx).await; - let mut file = OpenOptions::new().create(true).write(true).open(path)?; - file.write_all(&protobuf_out?)?; + let mut file = OpenOptions::new() + .write(true) + .create_new(true) + .open(path) + .await?; + file.write_all(&protobuf_out).await?; Ok(()) } +/// Plans a sql and serializes the generated logical plan to bytes. pub async fn serialize_bytes(sql: &str, ctx: &SessionContext) -> Result> { let df = ctx.sql(sql).await?; let plan = df.into_optimized_plan()?; let proto = producer::to_substrait_plan(&plan, &ctx.state())?; let mut protobuf_out = Vec::::new(); - proto.encode(&mut protobuf_out).map_err(|e| { - DataFusionError::Substrait(format!("Failed to encode substrait plan: {e}")) - })?; + proto + .encode(&mut protobuf_out) + .map_err(|e| DataFusionError::Substrait(format!("Failed to encode plan: {e}")))?; Ok(protobuf_out) } -pub async fn deserialize(path: &str) -> Result> { +/// Reads the file at `path` and deserializes a plan from the bytes. +pub async fn deserialize(path: impl AsRef) -> Result> { let mut protobuf_in = Vec::::new(); - let mut file = OpenOptions::new().read(true).open(path)?; + let mut file = OpenOptions::new().read(true).open(path).await?; + file.read_to_end(&mut protobuf_in).await?; - file.read_to_end(&mut protobuf_in)?; deserialize_bytes(protobuf_in).await } +/// Deserializes a plan from the bytes. 
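
The serializer now goes through tokio's async filesystem API, accepts `impl AsRef<Path>` instead of `&str`, and opens the output with `create_new(true)`, so `serialize` returns an error rather than silently overwriting an existing plan file. A minimal usage sketch, not part of the patch; the registered table name and file paths are illustrative assumptions:

    use datafusion::error::Result;
    use datafusion::prelude::*;
    use datafusion_substrait::serializer;

    #[tokio::main]
    async fn main() -> Result<()> {
        let ctx = SessionContext::new();
        // Hypothetical input; any registered table works here.
        ctx.register_csv("data", "tests/testdata/data.csv", CsvReadOptions::new())
            .await?;

        // First call creates plan.bin; a second call fails because the file
        // is now opened with OpenOptions::create_new(true).
        serializer::serialize("SELECT a, b FROM data", &ctx, "plan.bin").await?;
        assert!(serializer::serialize("SELECT a, b FROM data", &ctx, "plan.bin")
            .await
            .is_err());

        // Round-trip: read the encoded protobuf back into a substrait Plan.
        let _plan = serializer::deserialize("plan.bin").await?;
        tokio::fs::remove_file("plan.bin").await?;
        Ok(())
    }
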
pub async fn deserialize_bytes(proto_bytes: Vec) -> Result> { Ok(Box::new(Message::decode(&*proto_bytes).map_err(|e| { - DataFusionError::Substrait(format!("Failed to decode substrait plan: {e}")) + DataFusionError::Substrait(format!("Failed to decode plan: {e}")) })?)) } diff --git a/datafusion/substrait/tests/cases/roundtrip_physical_plan.rs b/datafusion/substrait/tests/cases/roundtrip_physical_plan.rs index 04c5e8ada758..f1284db2ad46 100644 --- a/datafusion/substrait/tests/cases/roundtrip_physical_plan.rs +++ b/datafusion/substrait/tests/cases/roundtrip_physical_plan.rs @@ -49,7 +49,7 @@ async fn parquet_exec() -> Result<()> { 123, )], ]); - let parquet_exec: Arc = scan_config.new_exec(); + let parquet_exec: Arc = scan_config.build(); let mut extension_info: ( Vec, diff --git a/datafusion/substrait/tests/cases/serialize.rs b/datafusion/substrait/tests/cases/serialize.rs index e28c63312788..02089b9fa92d 100644 --- a/datafusion/substrait/tests/cases/serialize.rs +++ b/datafusion/substrait/tests/cases/serialize.rs @@ -17,6 +17,7 @@ #[cfg(test)] mod tests { + use datafusion::common::assert_contains; use datafusion::datasource::provider_as_source; use datafusion::logical_expr::LogicalPlanBuilder; use datafusion_substrait::logical_plan::consumer::from_substrait_plan; @@ -31,6 +32,25 @@ mod tests { use substrait::proto::rel_common::{Emit, EmitKind}; use substrait::proto::{rel, RelCommon}; + #[tokio::test] + async fn serialize_to_file() -> Result<()> { + let ctx = create_context().await?; + let path = "tests/serialize_to_file.bin"; + let sql = "SELECT a, b FROM data"; + + // Test case 1: serializing to a non-existing file should succeed. + serializer::serialize(sql, &ctx, path).await?; + serializer::deserialize(path).await?; + + // Test case 2: serializing to an existing file should fail. + let got = serializer::serialize(sql, &ctx, path).await.unwrap_err(); + assert_contains!(got.to_string(), "File exists"); + + fs::remove_file(path)?; + + Ok(()) + } + #[tokio::test] async fn serialize_simple_select() -> Result<()> { let ctx = create_context().await?; diff --git a/datafusion/wasmtest/Cargo.toml b/datafusion/wasmtest/Cargo.toml index aae66e6b9a97..7db051ad191f 100644 --- a/datafusion/wasmtest/Cargo.toml +++ b/datafusion/wasmtest/Cargo.toml @@ -43,25 +43,16 @@ chrono = { version = "0.4", features = ["wasmbind"] } # code size when deploying. 
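
The new serialize_to_file test asserts that the second write fails with a message containing "File exists", which follows from `create_new(true)`: on Unix-like systems the open call errors with ErrorKind::AlreadyExists ("File exists", os error 17) instead of truncating the file, and that text is what `assert_contains!` checks. A standalone sketch of just that behaviour, with an illustrative path:

    use std::io::ErrorKind;
    use tokio::fs::OpenOptions;

    #[tokio::main]
    async fn main() -> std::io::Result<()> {
        let path = "example_plan.bin"; // illustrative path
        OpenOptions::new().write(true).create_new(true).open(path).await?;

        // A second open with create_new(true) refuses to touch the existing file.
        let err = OpenOptions::new()
            .write(true)
            .create_new(true)
            .open(path)
            .await
            .unwrap_err();
        assert_eq!(err.kind(), ErrorKind::AlreadyExists);

        tokio::fs::remove_file(path).await?;
        Ok(())
    }
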
console_error_panic_hook = { version = "0.1.1", optional = true } datafusion = { workspace = true } -datafusion-catalog = { workspace = true } datafusion-common = { workspace = true, default-features = true } -datafusion-common-runtime = { workspace = true } datafusion-execution = { workspace = true } datafusion-expr = { workspace = true } -datafusion-expr-common = { workspace = true } -datafusion-functions = { workspace = true } -datafusion-functions-aggregate = { workspace = true } -datafusion-functions-aggregate-common = { workspace = true } -datafusion-functions-table = { workspace = true } datafusion-optimizer = { workspace = true, default-features = true } datafusion-physical-expr = { workspace = true, default-features = true } -datafusion-physical-expr-common = { workspace = true } datafusion-physical-plan = { workspace = true } datafusion-sql = { workspace = true } # getrandom must be compiled with js feature getrandom = { version = "0.2.8", features = ["js"] } -parquet = { workspace = true } wasm-bindgen = "0.2.99" wasm-bindgen-futures = "0.4.49" diff --git a/datafusion/wasmtest/datafusion-wasm-app/package-lock.json b/datafusion/wasmtest/datafusion-wasm-app/package-lock.json index 37512e8278a7..65d8bdbb5e93 100644 --- a/datafusion/wasmtest/datafusion-wasm-app/package-lock.json +++ b/datafusion/wasmtest/datafusion-wasm-app/package-lock.json @@ -12,7 +12,7 @@ "datafusion-wasmtest": "../pkg" }, "devDependencies": { - "copy-webpack-plugin": "6.4.1", + "copy-webpack-plugin": "12.0.2", "webpack": "5.94.0", "webpack-cli": "5.1.4", "webpack-dev-server": "4.15.1" @@ -31,12 +31,6 @@ "node": ">=10.0.0" } }, - "node_modules/@gar/promisify": { - "version": "1.1.3", - "resolved": "https://registry.npmjs.org/@gar/promisify/-/promisify-1.1.3.tgz", - "integrity": "sha512-k2Ty1JcVojjJFwrg/ThKi2ujJ7XNLYaFGNB/bWT9wGR+oSMJHMa5w+CUq6p/pVrKeNNgA7pCqEcjSnHVoqJQFw==", - "dev": true - }, "node_modules/@jridgewell/gen-mapping": { "version": "0.3.5", "resolved": "https://registry.npmjs.org/@jridgewell/gen-mapping/-/gen-mapping-0.3.5.tgz", @@ -106,6 +100,7 @@ "resolved": "https://registry.npmjs.org/@nodelib/fs.scandir/-/fs.scandir-2.1.5.tgz", "integrity": "sha512-vq24Bq3ym5HEQm2NKCr3yXDwjc7vTsEThRDnkp2DK9p1uqLR+DHurm/NOTo0KG7HYHU7eppKZj3MyqYuMBf62g==", "dev": true, + "license": "MIT", "dependencies": { "@nodelib/fs.stat": "2.0.5", "run-parallel": "^1.1.9" @@ -119,6 +114,7 @@ "resolved": "https://registry.npmjs.org/@nodelib/fs.stat/-/fs.stat-2.0.5.tgz", "integrity": "sha512-RkhPPp2zrqDAQA/2jNhnztcPAlv64XdhIp7a7454A5ovI7Bukxgt7MX7udwAu3zg1DcpPU0rz3VV1SeaqvY4+A==", "dev": true, + "license": "MIT", "engines": { "node": ">= 8" } @@ -128,6 +124,7 @@ "resolved": "https://registry.npmjs.org/@nodelib/fs.walk/-/fs.walk-1.2.8.tgz", "integrity": "sha512-oGB+UxlgWcgQkgwo8GcEGwemoTFt3FIO9ababBmaGwXIoBKZ+GTy0pP185beGg7Llih/NSHSV2XAs1lnznocSg==", "dev": true, + "license": "MIT", "dependencies": { "@nodelib/fs.scandir": "2.1.5", "fastq": "^1.6.0" @@ -136,28 +133,17 @@ "node": ">= 8" } }, - "node_modules/@npmcli/fs": { - "version": "1.1.1", - "resolved": "https://registry.npmjs.org/@npmcli/fs/-/fs-1.1.1.tgz", - "integrity": "sha512-8KG5RD0GVP4ydEzRn/I4BNDuxDtqVbOdm8675T49OIG/NGhaK0pjPX7ZcDlvKYbA+ulvVK3ztfcF4uBdOxuJbQ==", - "dev": true, - "dependencies": { - "@gar/promisify": "^1.0.1", - "semver": "^7.3.5" - } - }, - "node_modules/@npmcli/move-file": { - "version": "1.1.2", - "resolved": "https://registry.npmjs.org/@npmcli/move-file/-/move-file-1.1.2.tgz", - "integrity": 
"sha512-1SUf/Cg2GzGDyaf15aR9St9TWlb+XvbZXWpDx8YKs7MLzMH/BCeopv+y9vzrzgkfykCGuWOlSu3mZhj2+FQcrg==", - "deprecated": "This functionality has been moved to @npmcli/fs", + "node_modules/@sindresorhus/merge-streams": { + "version": "2.3.0", + "resolved": "https://registry.npmjs.org/@sindresorhus/merge-streams/-/merge-streams-2.3.0.tgz", + "integrity": "sha512-LtoMMhxAlorcGhmFYI+LhPgbPZCkgP6ra1YL604EeF6U98pLlQ3iWIGMdWSC+vWmPBWBNgmDBAhnAobLROJmwg==", "dev": true, - "dependencies": { - "mkdirp": "^1.0.4", - "rimraf": "^3.0.2" - }, + "license": "MIT", "engines": { - "node": ">=10" + "node": ">=18" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" } }, "node_modules/@types/body-parser": { @@ -563,19 +549,6 @@ "acorn": "^8" } }, - "node_modules/aggregate-error": { - "version": "3.1.0", - "resolved": "https://registry.npmjs.org/aggregate-error/-/aggregate-error-3.1.0.tgz", - "integrity": "sha512-4I7Td01quW/RpocfNayFdFVk1qSuoh0E7JrbRJ16nH01HhKFQ88INq9Sd+nd72zqRySlr9BmDA8xlEJ6vJMrYA==", - "dev": true, - "dependencies": { - "clean-stack": "^2.0.0", - "indent-string": "^4.0.0" - }, - "engines": { - "node": ">=8" - } - }, "node_modules/ajv": { "version": "6.12.6", "resolved": "https://registry.npmjs.org/ajv/-/ajv-6.12.6.tgz", @@ -671,15 +644,6 @@ "integrity": "sha512-hNfzcOV8W4NdualtqBFPyVO+54DSJuZGY9qT4pRroB6S9e3iiido2ISIC5h9R2sPJ8H3FHCIiEnsv1lPXO3KtQ==", "dev": true }, - "node_modules/array-union": { - "version": "2.1.0", - "resolved": "https://registry.npmjs.org/array-union/-/array-union-2.1.0.tgz", - "integrity": "sha512-HGyxoOTYUyCM6stUe6EJgnd4EoewAI7zMdfqO+kGjnlZmBDz/cR5pf8r/cR4Wq60sL/p0IkcjUEEPwS3GFrIyw==", - "dev": true, - "engines": { - "node": ">=8" - } - }, "node_modules/balanced-match": { "version": "1.0.2", "resolved": "https://registry.npmjs.org/balanced-match/-/balanced-match-1.0.2.tgz", @@ -692,15 +656,6 @@ "integrity": "sha1-3DQxT05nkxgJP8dgJyUl+UvyXBY=", "dev": true }, - "node_modules/big.js": { - "version": "5.2.2", - "resolved": "https://registry.npmjs.org/big.js/-/big.js-5.2.2.tgz", - "integrity": "sha512-vyL2OymJxmarO8gxMr0mhChsO9QGwhynfuu4+MHTAW6czfq9humCB7rKpUjDd9YUiDPU4mzpyupFSvOClAwbmQ==", - "dev": true, - "engines": { - "node": "*" - } - }, "node_modules/binary-extensions": { "version": "2.2.0", "resolved": "https://registry.npmjs.org/binary-extensions/-/binary-extensions-2.2.0.tgz", @@ -842,35 +797,6 @@ "node": ">= 0.8" } }, - "node_modules/cacache": { - "version": "15.3.0", - "resolved": "https://registry.npmjs.org/cacache/-/cacache-15.3.0.tgz", - "integrity": "sha512-VVdYzXEn+cnbXpFgWs5hTT7OScegHVmLhJIR8Ufqk3iFD6A6j5iSX1KuBTfNEv4tdJWE2PzA6IVFtcLC7fN9wQ==", - "dev": true, - "dependencies": { - "@npmcli/fs": "^1.0.0", - "@npmcli/move-file": "^1.0.1", - "chownr": "^2.0.0", - "fs-minipass": "^2.0.0", - "glob": "^7.1.4", - "infer-owner": "^1.0.4", - "lru-cache": "^6.0.0", - "minipass": "^3.1.1", - "minipass-collect": "^1.0.2", - "minipass-flush": "^1.0.5", - "minipass-pipeline": "^1.2.2", - "mkdirp": "^1.0.3", - "p-map": "^4.0.0", - "promise-inflight": "^1.0.1", - "rimraf": "^3.0.2", - "ssri": "^8.0.1", - "tar": "^6.0.2", - "unique-filename": "^1.1.1" - }, - "engines": { - "node": ">= 10" - } - }, "node_modules/call-bind": { "version": "1.0.7", "resolved": "https://registry.npmjs.org/call-bind/-/call-bind-1.0.7.tgz", @@ -937,15 +863,6 @@ "fsevents": "~2.3.2" } }, - "node_modules/chownr": { - "version": "2.0.0", - "resolved": "https://registry.npmjs.org/chownr/-/chownr-2.0.0.tgz", - "integrity": 
"sha512-bIomtDF5KGpdogkLd9VspvFzk9KfpyyGlS8YFVZl7TGPBHL5snIOnxeshwVgPteQ9b4Eydl+pVbIyE1DcvCWgQ==", - "dev": true, - "engines": { - "node": ">=10" - } - }, "node_modules/chrome-trace-event": { "version": "1.0.2", "resolved": "https://registry.npmjs.org/chrome-trace-event/-/chrome-trace-event-1.0.2.tgz", @@ -958,15 +875,6 @@ "node": ">=6.0" } }, - "node_modules/clean-stack": { - "version": "2.2.0", - "resolved": "https://registry.npmjs.org/clean-stack/-/clean-stack-2.2.0.tgz", - "integrity": "sha512-4diC9HaTE+KRAMWhDhrGOECgWZxoevMc5TlkObMqNSsVU62PYzXZ/SMTjzyGAFF1YusgxGcSWTEXBhp0CPwQ1A==", - "dev": true, - "engines": { - "node": ">=6" - } - }, "node_modules/clone-deep": { "version": "4.0.1", "resolved": "https://registry.npmjs.org/clone-deep/-/clone-deep-4.0.1.tgz", @@ -993,12 +901,6 @@ "integrity": "sha512-GpVkmM8vF2vQUkj2LvZmD35JxeJOLCwJ9cUkugyk2nuhbv3+mJvpLYYt+0+USMxE+oj+ey/lJEnhZw75x/OMcQ==", "dev": true }, - "node_modules/commondir": { - "version": "1.0.1", - "resolved": "https://registry.npmjs.org/commondir/-/commondir-1.0.1.tgz", - "integrity": "sha512-W9pAhw0ja1Edb5GVdIF1mjZw/ASI0AlShXM83UUGe2DVr5TdAPEA1OA8m/g8zWp9x6On7gqufY+FatDbC3MDQg==", - "dev": true - }, "node_modules/compressible": { "version": "2.0.18", "resolved": "https://registry.npmjs.org/compressible/-/compressible-2.0.18.tgz", @@ -1110,32 +1012,98 @@ "dev": true }, "node_modules/copy-webpack-plugin": { - "version": "6.4.1", - "resolved": "https://registry.npmjs.org/copy-webpack-plugin/-/copy-webpack-plugin-6.4.1.tgz", - "integrity": "sha512-MXyPCjdPVx5iiWyl40Va3JGh27bKzOTNY3NjUTrosD2q7dR/cLD0013uqJ3BpFbUjyONINjb6qI7nDIJujrMbA==", + "version": "12.0.2", + "resolved": "https://registry.npmjs.org/copy-webpack-plugin/-/copy-webpack-plugin-12.0.2.tgz", + "integrity": "sha512-SNwdBeHyII+rWvee/bTnAYyO8vfVdcSTud4EIb6jcZ8inLeWucJE0DnxXQBjlQ5zlteuuvooGQy3LIyGxhvlOA==", "dev": true, + "license": "MIT", "dependencies": { - "cacache": "^15.0.5", - "fast-glob": "^3.2.4", - "find-cache-dir": "^3.3.1", - "glob-parent": "^5.1.1", - "globby": "^11.0.1", - "loader-utils": "^2.0.0", + "fast-glob": "^3.3.2", + "glob-parent": "^6.0.1", + "globby": "^14.0.0", "normalize-path": "^3.0.0", - "p-limit": "^3.0.2", - "schema-utils": "^3.0.0", - "serialize-javascript": "^5.0.1", - "webpack-sources": "^1.4.3" + "schema-utils": "^4.2.0", + "serialize-javascript": "^6.0.2" }, "engines": { - "node": ">= 10.13.0" + "node": ">= 18.12.0" }, "funding": { "type": "opencollective", "url": "https://opencollective.com/webpack" }, "peerDependencies": { - "webpack": "^4.37.0 || ^5.0.0" + "webpack": "^5.1.0" + } + }, + "node_modules/copy-webpack-plugin/node_modules/ajv": { + "version": "8.17.1", + "resolved": "https://registry.npmjs.org/ajv/-/ajv-8.17.1.tgz", + "integrity": "sha512-B/gBuNg5SiMTrPkC+A2+cW0RszwxYmn6VYxB/inlBStS5nx6xHIt/ehKRhIMhqusl7a8LjQoZnjCs5vhwxOQ1g==", + "dev": true, + "license": "MIT", + "dependencies": { + "fast-deep-equal": "^3.1.3", + "fast-uri": "^3.0.1", + "json-schema-traverse": "^1.0.0", + "require-from-string": "^2.0.2" + }, + "funding": { + "type": "github", + "url": "https://github.com/sponsors/epoberezkin" + } + }, + "node_modules/copy-webpack-plugin/node_modules/ajv-keywords": { + "version": "5.1.0", + "resolved": "https://registry.npmjs.org/ajv-keywords/-/ajv-keywords-5.1.0.tgz", + "integrity": "sha512-YCS/JNFAUyr5vAuhk1DWm1CBxRHW9LbJ2ozWeemrIqpbsqKjHVxYPyi5GC0rjZIT5JxJ3virVTS8wk4i/Z+krw==", + "dev": true, + "license": "MIT", + "dependencies": { + "fast-deep-equal": "^3.1.3" + }, + "peerDependencies": { + "ajv": "^8.8.2" + } + }, + 
"node_modules/copy-webpack-plugin/node_modules/glob-parent": { + "version": "6.0.2", + "resolved": "https://registry.npmjs.org/glob-parent/-/glob-parent-6.0.2.tgz", + "integrity": "sha512-XxwI8EOhVQgWp6iDL+3b0r86f4d6AX6zSU55HfB4ydCEuXLXc5FcYeOu+nnGftS4TEju/11rt4KJPTMgbfmv4A==", + "dev": true, + "license": "ISC", + "dependencies": { + "is-glob": "^4.0.3" + }, + "engines": { + "node": ">=10.13.0" + } + }, + "node_modules/copy-webpack-plugin/node_modules/json-schema-traverse": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/json-schema-traverse/-/json-schema-traverse-1.0.0.tgz", + "integrity": "sha512-NM8/P9n3XjXhIZn1lLhkFaACTOURQXjWhV4BA/RnOv8xvgqtqpAX9IO4mRQxSx1Rlo4tqzeqb0sOlruaOy3dug==", + "dev": true, + "license": "MIT" + }, + "node_modules/copy-webpack-plugin/node_modules/schema-utils": { + "version": "4.3.0", + "resolved": "https://registry.npmjs.org/schema-utils/-/schema-utils-4.3.0.tgz", + "integrity": "sha512-Gf9qqc58SpCA/xdziiHz35F4GNIWYWZrEshUc/G/r5BnLph6xpKuLeoJoQuj5WfBIx/eQLf+hmVPYHaxJu7V2g==", + "dev": true, + "license": "MIT", + "dependencies": { + "@types/json-schema": "^7.0.9", + "ajv": "^8.9.0", + "ajv-formats": "^2.1.1", + "ajv-keywords": "^5.1.0" + }, + "engines": { + "node": ">= 10.13.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/webpack" } }, "node_modules/core-util-is": { @@ -1241,18 +1209,6 @@ "integrity": "sha512-ZIzRpLJrOj7jjP2miAtgqIfmzbxa4ZOr5jJc601zklsfEx9oTzmmj2nVpIPRpNlRTIh8lc1kyViIY7BWSGNmKw==", "dev": true }, - "node_modules/dir-glob": { - "version": "3.0.1", - "resolved": "https://registry.npmjs.org/dir-glob/-/dir-glob-3.0.1.tgz", - "integrity": "sha512-WkrWp9GR4KXfKGYzOLmTuGVi1UWFfws377n9cc55/tb6DuqyF6pcQ5AbiHEshaDpY9v6oaSr2XCDidGmMwdzIA==", - "dev": true, - "dependencies": { - "path-type": "^4.0.0" - }, - "engines": { - "node": ">=8" - } - }, "node_modules/dns-equal": { "version": "1.0.0", "resolved": "https://registry.npmjs.org/dns-equal/-/dns-equal-1.0.0.tgz", @@ -1283,15 +1239,6 @@ "integrity": "sha512-UdREXMXzLkREF4jA8t89FQjA8WHI6ssP38PMY4/4KhXFQbtImnghh4GkCgrtiZwLKUKVD2iTVXvDVQjfomEQuA==", "dev": true }, - "node_modules/emojis-list": { - "version": "3.0.0", - "resolved": "https://registry.npmjs.org/emojis-list/-/emojis-list-3.0.0.tgz", - "integrity": "sha512-/kyM18EfinwXZbno9FyUGeFh87KC8HRQBQGildHZbEuRyWFOmv1U10o9BBp8XVZDVNNuQKyIGIu5ZYAAXJ0V2Q==", - "dev": true, - "engines": { - "node": ">= 4" - } - }, "node_modules/encodeurl": { "version": "1.0.2", "resolved": "https://registry.npmjs.org/encodeurl/-/encodeurl-1.0.2.tgz", @@ -1569,16 +1516,17 @@ "dev": true }, "node_modules/fast-glob": { - "version": "3.3.1", - "resolved": "https://registry.npmjs.org/fast-glob/-/fast-glob-3.3.1.tgz", - "integrity": "sha512-kNFPyjhh5cKjrUltxs+wFx+ZkbRaxxmZ+X0ZU31SOsxCEtP9VPgtq2teZw1DebupL5GmDaNQ6yKMMVcM41iqDg==", + "version": "3.3.3", + "resolved": "https://registry.npmjs.org/fast-glob/-/fast-glob-3.3.3.tgz", + "integrity": "sha512-7MptL8U0cqcFdzIzwOTHoilX9x5BrNqye7Z/LuC7kCMRio1EMSyqRK3BEAUD7sXRq4iT4AzTVuZdhgQ2TCvYLg==", "dev": true, + "license": "MIT", "dependencies": { "@nodelib/fs.stat": "^2.0.2", "@nodelib/fs.walk": "^1.2.3", "glob-parent": "^5.1.2", "merge2": "^1.3.0", - "micromatch": "^4.0.4" + "micromatch": "^4.0.8" }, "engines": { "node": ">=8.6.0" @@ -1590,6 +1538,23 @@ "integrity": "sha512-lhd/wF+Lk98HZoTCtlVraHtfh5XYijIjalXck7saUtuanSDyLMxnHhSXEDJqHxD7msR8D0uCmqlkwjCV8xvwHw==", "dev": true }, + "node_modules/fast-uri": { + "version": "3.0.6", + "resolved": 
"https://registry.npmjs.org/fast-uri/-/fast-uri-3.0.6.tgz", + "integrity": "sha512-Atfo14OibSv5wAp4VWNsFYE1AchQRTv9cBGWET4pZWHzYshFSS9NQI6I57rdKn9croWVMbYFbLhJ+yJvmZIIHw==", + "dev": true, + "funding": [ + { + "type": "github", + "url": "https://github.com/sponsors/fastify" + }, + { + "type": "opencollective", + "url": "https://opencollective.com/fastify" + } + ], + "license": "BSD-3-Clause" + }, "node_modules/fastest-levenshtein": { "version": "1.0.16", "resolved": "https://registry.npmjs.org/fastest-levenshtein/-/fastest-levenshtein-1.0.16.tgz", @@ -1600,10 +1565,11 @@ } }, "node_modules/fastq": { - "version": "1.15.0", - "resolved": "https://registry.npmjs.org/fastq/-/fastq-1.15.0.tgz", - "integrity": "sha512-wBrocU2LCXXa+lWBt8RoIRD89Fi8OdABODa/kEnyeyjS5aZO5/GNvI5sEINADqP/h8M29UHTHUb53sUu5Ihqdw==", + "version": "1.19.0", + "resolved": "https://registry.npmjs.org/fastq/-/fastq-1.19.0.tgz", + "integrity": "sha512-7SFSRCNjBQIZH/xZR3iy5iQYR8aGBE0h3VG6/cwlbrpdciNYBMotQav8c1XI3HjHH+NikUpP53nPdlZSdWmFzA==", "dev": true, + "license": "ISC", "dependencies": { "reusify": "^1.0.4" } @@ -1677,23 +1643,6 @@ "node": ">= 0.8" } }, - "node_modules/find-cache-dir": { - "version": "3.3.2", - "resolved": "https://registry.npmjs.org/find-cache-dir/-/find-cache-dir-3.3.2.tgz", - "integrity": "sha512-wXZV5emFEjrridIgED11OoUKLxiYjAcqot/NJdAkOhlJ+vGzwhOAfcG5OX1jP+S0PcjEn8bdMJv+g2jwQ3Onig==", - "dev": true, - "dependencies": { - "commondir": "^1.0.1", - "make-dir": "^3.0.2", - "pkg-dir": "^4.1.0" - }, - "engines": { - "node": ">=8" - }, - "funding": { - "url": "https://github.com/avajs/find-cache-dir?sponsor=1" - } - }, "node_modules/find-up": { "version": "4.1.0", "resolved": "https://registry.npmjs.org/find-up/-/find-up-4.1.0.tgz", @@ -1745,18 +1694,6 @@ "node": ">= 0.6" } }, - "node_modules/fs-minipass": { - "version": "2.1.0", - "resolved": "https://registry.npmjs.org/fs-minipass/-/fs-minipass-2.1.0.tgz", - "integrity": "sha512-V/JgOLFCS+R6Vcq0slCuaeWEdNC3ouDlJMNIsacH2VtALiu9mV4LPrHc5cDl8k5aw6J8jwgWWpiTo5RYhmIzvg==", - "dev": true, - "dependencies": { - "minipass": "^3.0.0" - }, - "engines": { - "node": ">= 8" - } - }, "node_modules/fs-monkey": { "version": "1.0.4", "resolved": "https://registry.npmjs.org/fs-monkey/-/fs-monkey-1.0.4.tgz", @@ -1862,20 +1799,21 @@ "dev": true }, "node_modules/globby": { - "version": "11.1.0", - "resolved": "https://registry.npmjs.org/globby/-/globby-11.1.0.tgz", - "integrity": "sha512-jhIXaOzy1sb8IyocaruWSn1TjmnBVs8Ayhcy83rmxNJ8q2uWKCAj3CnJY+KpGSXCueAPc0i05kVvVKtP1t9S3g==", + "version": "14.1.0", + "resolved": "https://registry.npmjs.org/globby/-/globby-14.1.0.tgz", + "integrity": "sha512-0Ia46fDOaT7k4og1PDW4YbodWWr3scS2vAr2lTbsplOt2WkKp0vQbkI9wKis/T5LV/dqPjO3bpS/z6GTJB82LA==", "dev": true, + "license": "MIT", "dependencies": { - "array-union": "^2.1.0", - "dir-glob": "^3.0.1", - "fast-glob": "^3.2.9", - "ignore": "^5.2.0", - "merge2": "^1.4.1", - "slash": "^3.0.0" + "@sindresorhus/merge-streams": "^2.1.0", + "fast-glob": "^3.3.3", + "ignore": "^7.0.3", + "path-type": "^6.0.0", + "slash": "^5.1.0", + "unicorn-magic": "^0.3.0" }, "engines": { - "node": ">=10" + "node": ">=18" }, "funding": { "url": "https://github.com/sponsors/sindresorhus" @@ -2114,10 +2052,11 @@ } }, "node_modules/ignore": { - "version": "5.2.4", - "resolved": "https://registry.npmjs.org/ignore/-/ignore-5.2.4.tgz", - "integrity": "sha512-MAb38BcSbH0eHNBxn7ql2NH/kX33OkB3lZ1BNdh7ENeRChHTYsTvWrMubiIAMNS2llXEEgZ1MUOBtXChP3kaFQ==", + "version": "7.0.3", + "resolved": 
"https://registry.npmjs.org/ignore/-/ignore-7.0.3.tgz", + "integrity": "sha512-bAH5jbK/F3T3Jls4I0SO1hmPR0dKU0a7+SY6n1yzRtG54FLO8d6w/nxLFX2Nb7dBu6cCWXPaAME6cYqFUMmuCA==", "dev": true, + "license": "MIT", "engines": { "node": ">= 4" } @@ -2141,30 +2080,6 @@ "url": "https://github.com/sponsors/sindresorhus" } }, - "node_modules/imurmurhash": { - "version": "0.1.4", - "resolved": "https://registry.npmjs.org/imurmurhash/-/imurmurhash-0.1.4.tgz", - "integrity": "sha512-JmXMZ6wuvDmLiHEml9ykzqO6lwFbof0GG4IkcGaENdCRDDmMVnny7s5HsIgHCbaq0w2MyPhDqkhTUgS2LU2PHA==", - "dev": true, - "engines": { - "node": ">=0.8.19" - } - }, - "node_modules/indent-string": { - "version": "4.0.0", - "resolved": "https://registry.npmjs.org/indent-string/-/indent-string-4.0.0.tgz", - "integrity": "sha512-EdDDZu4A2OyIK7Lr/2zG+w5jmbuk1DVBnEwREQvBzspBJkCEbRa8GxU1lghYcaGJCnRWibjDXlq779X1/y5xwg==", - "dev": true, - "engines": { - "node": ">=8" - } - }, - "node_modules/infer-owner": { - "version": "1.0.4", - "resolved": "https://registry.npmjs.org/infer-owner/-/infer-owner-1.0.4.tgz", - "integrity": "sha512-IClj+Xz94+d7irH5qRyfJonOdfTzuDaifE6ZPWfx0N0+/ATZCbuTPq2prFl526urkQd90WyUKIh1DfBQ2hMz9A==", - "dev": true - }, "node_modules/inflight": { "version": "1.0.6", "resolved": "https://registry.npmjs.org/inflight/-/inflight-1.0.6.tgz", @@ -2363,18 +2278,6 @@ "integrity": "sha512-xbbCH5dCYU5T8LcEhhuh7HJ88HXuW3qsI3Y0zOZFKfZEHcpWiHU/Jxzk629Brsab/mMiHQti9wMP+845RPe3Vg==", "dev": true }, - "node_modules/json5": { - "version": "2.2.3", - "resolved": "https://registry.npmjs.org/json5/-/json5-2.2.3.tgz", - "integrity": "sha512-XmOWe7eyHYH14cLdVPoyg+GOH3rYX++KpzrylJwSW98t3Nk+U8XOl8FWKOgwtzdb8lXGf6zYwDUzeHMWfxasyg==", - "dev": true, - "bin": { - "json5": "lib/cli.js" - }, - "engines": { - "node": ">=6" - } - }, "node_modules/kind-of": { "version": "6.0.3", "resolved": "https://registry.npmjs.org/kind-of/-/kind-of-6.0.3.tgz", @@ -2403,20 +2306,6 @@ "node": ">=6.11.5" } }, - "node_modules/loader-utils": { - "version": "2.0.4", - "resolved": "https://registry.npmjs.org/loader-utils/-/loader-utils-2.0.4.tgz", - "integrity": "sha512-xXqpXoINfFhgua9xiqD8fPFHgkoq1mmmpE92WlDbm9rNRd/EbRb+Gqf908T2DMfuHjjJlksiK2RbHVOdD/MqSw==", - "dev": true, - "dependencies": { - "big.js": "^5.2.2", - "emojis-list": "^3.0.0", - "json5": "^2.1.2" - }, - "engines": { - "node": ">=8.9.0" - } - }, "node_modules/locate-path": { "version": "5.0.0", "resolved": "https://registry.npmjs.org/locate-path/-/locate-path-5.0.0.tgz", @@ -2429,42 +2318,6 @@ "node": ">=8" } }, - "node_modules/lru-cache": { - "version": "6.0.0", - "resolved": "https://registry.npmjs.org/lru-cache/-/lru-cache-6.0.0.tgz", - "integrity": "sha512-Jo6dJ04CmSjuznwJSS3pUeWmd/H0ffTlkXXgwZi+eq1UCmqQwCh+eLsYOYCwY991i2Fah4h1BEMCx4qThGbsiA==", - "dev": true, - "dependencies": { - "yallist": "^4.0.0" - }, - "engines": { - "node": ">=10" - } - }, - "node_modules/make-dir": { - "version": "3.1.0", - "resolved": "https://registry.npmjs.org/make-dir/-/make-dir-3.1.0.tgz", - "integrity": "sha512-g3FeP20LNwhALb/6Cz6Dd4F2ngze0jz7tbzrD2wAV+o9FeNHe4rL+yK2md0J/fiSf1sa1ADhXqi5+oVwOM/eGw==", - "dev": true, - "dependencies": { - "semver": "^6.0.0" - }, - "engines": { - "node": ">=8" - }, - "funding": { - "url": "https://github.com/sponsors/sindresorhus" - } - }, - "node_modules/make-dir/node_modules/semver": { - "version": "6.3.1", - "resolved": "https://registry.npmjs.org/semver/-/semver-6.3.1.tgz", - "integrity": "sha512-BR7VvDCVHO+q2xBEWskxS6DJE1qRnb7DxzUrogb71CWoSficBxYsiAGd+Kl0mmq/MprG9yArRkyrQxTO6XjMzA==", - "dev": true, - 
"bin": { - "semver": "bin/semver.js" - } - }, "node_modules/media-typer": { "version": "0.3.0", "resolved": "https://registry.npmjs.org/media-typer/-/media-typer-0.3.0.tgz", @@ -2506,6 +2359,7 @@ "resolved": "https://registry.npmjs.org/merge2/-/merge2-1.4.1.tgz", "integrity": "sha512-8q7VEgMJW4J8tcfVPy8g09NcQwZdbwFEqhe/WZkoIzjn/3TGDwtOCYtXGxA3O8tPzpczCCDgv+P2P5y00ZJOOg==", "dev": true, + "license": "MIT", "engines": { "node": ">= 8" } @@ -2520,12 +2374,13 @@ } }, "node_modules/micromatch": { - "version": "4.0.5", - "resolved": "https://registry.npmjs.org/micromatch/-/micromatch-4.0.5.tgz", - "integrity": "sha512-DMy+ERcEW2q8Z2Po+WNXuw3c5YaUSFjAO5GsJqfEl7UjvtIuFKO6ZrKvcItdy98dwFI2N1tg3zNIdKaQT+aNdA==", + "version": "4.0.8", + "resolved": "https://registry.npmjs.org/micromatch/-/micromatch-4.0.8.tgz", + "integrity": "sha512-PXwfBhYu0hBCPw8Dn0E+WDYb7af3dSLVWKi3HGv84IdF4TyFoC0ysxFd0Goxw7nSv4T/PzEJQxsYsEiFCKo2BA==", "dev": true, + "license": "MIT", "dependencies": { - "braces": "^3.0.2", + "braces": "^3.0.3", "picomatch": "^2.3.1" }, "engines": { @@ -2592,79 +2447,6 @@ "node": "*" } }, - "node_modules/minipass": { - "version": "3.3.6", - "resolved": "https://registry.npmjs.org/minipass/-/minipass-3.3.6.tgz", - "integrity": "sha512-DxiNidxSEK+tHG6zOIklvNOwm3hvCrbUrdtzY74U6HKTJxvIDfOUL5W5P2Ghd3DTkhhKPYGqeNUIh5qcM4YBfw==", - "dev": true, - "dependencies": { - "yallist": "^4.0.0" - }, - "engines": { - "node": ">=8" - } - }, - "node_modules/minipass-collect": { - "version": "1.0.2", - "resolved": "https://registry.npmjs.org/minipass-collect/-/minipass-collect-1.0.2.tgz", - "integrity": "sha512-6T6lH0H8OG9kITm/Jm6tdooIbogG9e0tLgpY6mphXSm/A9u8Nq1ryBG+Qspiub9LjWlBPsPS3tWQ/Botq4FdxA==", - "dev": true, - "dependencies": { - "minipass": "^3.0.0" - }, - "engines": { - "node": ">= 8" - } - }, - "node_modules/minipass-flush": { - "version": "1.0.5", - "resolved": "https://registry.npmjs.org/minipass-flush/-/minipass-flush-1.0.5.tgz", - "integrity": "sha512-JmQSYYpPUqX5Jyn1mXaRwOda1uQ8HP5KAT/oDSLCzt1BYRhQU0/hDtsB1ufZfEEzMZ9aAVmsBw8+FWsIXlClWw==", - "dev": true, - "dependencies": { - "minipass": "^3.0.0" - }, - "engines": { - "node": ">= 8" - } - }, - "node_modules/minipass-pipeline": { - "version": "1.2.4", - "resolved": "https://registry.npmjs.org/minipass-pipeline/-/minipass-pipeline-1.2.4.tgz", - "integrity": "sha512-xuIq7cIOt09RPRJ19gdi4b+RiNvDFYe5JH+ggNvBqGqpQXcru3PcRmOZuHBKWK1Txf9+cQ+HMVN4d6z46LZP7A==", - "dev": true, - "dependencies": { - "minipass": "^3.0.0" - }, - "engines": { - "node": ">=8" - } - }, - "node_modules/minizlib": { - "version": "2.1.2", - "resolved": "https://registry.npmjs.org/minizlib/-/minizlib-2.1.2.tgz", - "integrity": "sha512-bAxsR8BVfj60DWXHE3u30oHzfl4G7khkSuPW+qvpd7jFRHm7dLxOjUk1EHACJ/hxLY8phGJ0YhYHZo7jil7Qdg==", - "dev": true, - "dependencies": { - "minipass": "^3.0.0", - "yallist": "^4.0.0" - }, - "engines": { - "node": ">= 8" - } - }, - "node_modules/mkdirp": { - "version": "1.0.4", - "resolved": "https://registry.npmjs.org/mkdirp/-/mkdirp-1.0.4.tgz", - "integrity": "sha512-vVqVZQyf3WLx2Shd0qJ9xuvqgAyKPLAiqITEtqW0oIUjzo3PePDd6fW9iFz30ef7Ysp/oiWqbhszeGWW2T6Gzw==", - "dev": true, - "bin": { - "mkdirp": "bin/cmd.js" - }, - "engines": { - "node": ">=10" - } - }, "node_modules/ms": { "version": "2.0.0", "resolved": "https://registry.npmjs.org/ms/-/ms-2.0.0.tgz", @@ -2815,21 +2597,6 @@ "url": "https://github.com/sponsors/sindresorhus" } }, - "node_modules/p-limit": { - "version": "3.1.0", - "resolved": "https://registry.npmjs.org/p-limit/-/p-limit-3.1.0.tgz", - "integrity": 
"sha512-TYOanM3wGwNGsZN2cVTYPArw454xnXj5qmWF1bEoAc4+cU/ol7GVh7odevjp1FNHduHc3KZMcFduxU5Xc6uJRQ==", - "dev": true, - "dependencies": { - "yocto-queue": "^0.1.0" - }, - "engines": { - "node": ">=10" - }, - "funding": { - "url": "https://github.com/sponsors/sindresorhus" - } - }, "node_modules/p-locate": { "version": "4.1.0", "resolved": "https://registry.npmjs.org/p-locate/-/p-locate-4.1.0.tgz", @@ -2857,21 +2624,6 @@ "url": "https://github.com/sponsors/sindresorhus" } }, - "node_modules/p-map": { - "version": "4.0.0", - "resolved": "https://registry.npmjs.org/p-map/-/p-map-4.0.0.tgz", - "integrity": "sha512-/bjOqmgETBYB5BoEeGVea8dmvHb2m9GLy1E9W43yeyfP6QQCZGFNa+XRceJEuDB6zqr+gKpIAmlLebMpykw/MQ==", - "dev": true, - "dependencies": { - "aggregate-error": "^3.0.0" - }, - "engines": { - "node": ">=10" - }, - "funding": { - "url": "https://github.com/sponsors/sindresorhus" - } - }, "node_modules/p-retry": { "version": "4.6.2", "resolved": "https://registry.npmjs.org/p-retry/-/p-retry-4.6.2.tgz", @@ -2943,12 +2695,16 @@ "dev": true }, "node_modules/path-type": { - "version": "4.0.0", - "resolved": "https://registry.npmjs.org/path-type/-/path-type-4.0.0.tgz", - "integrity": "sha512-gDKb8aZMDeD/tZWs9P6+q0J9Mwkdl6xMV8TjnGP3qJVJ06bdMgkbBlLU8IdfOsIsFz2BW1rNVT3XuNEl8zPAvw==", + "version": "6.0.0", + "resolved": "https://registry.npmjs.org/path-type/-/path-type-6.0.0.tgz", + "integrity": "sha512-Vj7sf++t5pBD637NSfkxpHSMfWaeig5+DKWLhcqIYx6mWQz5hdJTGDVMQiJcw1ZYkhs7AazKDGpRVji1LJCZUQ==", "dev": true, + "license": "MIT", "engines": { - "node": ">=8" + "node": ">=18" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" } }, "node_modules/picocolors": { @@ -2987,12 +2743,6 @@ "integrity": "sha512-MtEC1TqN0EU5nephaJ4rAtThHtC86dNN9qCuEhtshvpVBkAW5ZO7BASN9REnF9eoXGcRub+pFuKEpOHE+HbEMw==", "dev": true }, - "node_modules/promise-inflight": { - "version": "1.0.1", - "resolved": "https://registry.npmjs.org/promise-inflight/-/promise-inflight-1.0.1.tgz", - "integrity": "sha512-6zWPyEOFaQBJYcGMHBKTKJ3u6TBsnMFOIZSa6ce1e/ZrrsOlnHRHbabMjLiBYKp+n44X9eUI6VUPaukCXHuG4g==", - "dev": true - }, "node_modules/proxy-addr": { "version": "2.0.7", "resolved": "https://registry.npmjs.org/proxy-addr/-/proxy-addr-2.0.7.tgz", @@ -3057,7 +2807,8 @@ "type": "consulting", "url": "https://feross.org/support" } - ] + ], + "license": "MIT" }, "node_modules/randombytes": { "version": "2.1.0", @@ -3207,6 +2958,7 @@ "resolved": "https://registry.npmjs.org/reusify/-/reusify-1.0.4.tgz", "integrity": "sha512-U9nH88a3fc/ekCF1l0/UP1IosiuIjyTh7hBvXVMHYgVcfGvt897Xguj2UOLDeI5BG2m7/uwyaLVT6fbtCwTyzw==", "dev": true, + "license": "MIT", "engines": { "iojs": ">=1.0.0", "node": ">=0.10.0" @@ -3246,6 +2998,7 @@ "url": "https://feross.org/support" } ], + "license": "MIT", "dependencies": { "queue-microtask": "^1.2.2" } @@ -3298,21 +3051,6 @@ "node": ">=10" } }, - "node_modules/semver": { - "version": "7.5.4", - "resolved": "https://registry.npmjs.org/semver/-/semver-7.5.4.tgz", - "integrity": "sha512-1bCSESV6Pv+i21Hvpxp3Dx+pSD8lIPt8uVjRrxAUt/nbswYc+tK6Y2btiULjd4+fnq15PX+nqQDC7Oft7WkwcA==", - "dev": true, - "dependencies": { - "lru-cache": "^6.0.0" - }, - "bin": { - "semver": "bin/semver.js" - }, - "engines": { - "node": ">=10" - } - }, "node_modules/send": { "version": "0.19.0", "resolved": "https://registry.npmjs.org/send/-/send-0.19.0.tgz", @@ -3377,10 +3115,11 @@ } }, "node_modules/serialize-javascript": { - "version": "5.0.1", - "resolved": "https://registry.npmjs.org/serialize-javascript/-/serialize-javascript-5.0.1.tgz", - 
"integrity": "sha512-SaaNal9imEO737H2c05Og0/8LUXG7EnsZyMa8MzkmuHoELfT6txuj0cMqRj6zfPKnmQ1yasR4PCJc8x+M4JSPA==", + "version": "6.0.2", + "resolved": "https://registry.npmjs.org/serialize-javascript/-/serialize-javascript-6.0.2.tgz", + "integrity": "sha512-Saa1xPByTTq2gdeFZYLLo+RFE35NHZkAbqZeWNd3BpzppeVisAqpDjcp8dyf6uIvEqJRd46jemmyA4iFIeVk8g==", "dev": true, + "license": "BSD-3-Clause", "dependencies": { "randombytes": "^2.1.0" } @@ -3547,12 +3286,16 @@ "dev": true }, "node_modules/slash": { - "version": "3.0.0", - "resolved": "https://registry.npmjs.org/slash/-/slash-3.0.0.tgz", - "integrity": "sha512-g9Q1haeby36OSStwb4ntCGGGaKsaVSjQ68fBxoQcutl5fS1vuY18H3wSt3jFyFtrkx+Kz0V1G85A4MyAdDMi2Q==", + "version": "5.1.0", + "resolved": "https://registry.npmjs.org/slash/-/slash-5.1.0.tgz", + "integrity": "sha512-ZA6oR3T/pEyuqwMgAKT0/hAv8oAXckzbkmR0UkUosQ+Mc4RxGoJkRmwHgHufaenlyAgE1Mxgpdcrf75y6XcnDg==", "dev": true, + "license": "MIT", "engines": { - "node": ">=8" + "node": ">=14.16" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" } }, "node_modules/sockjs": { @@ -3566,12 +3309,6 @@ "websocket-driver": "^0.7.4" } }, - "node_modules/source-list-map": { - "version": "2.0.1", - "resolved": "https://registry.npmjs.org/source-list-map/-/source-list-map-2.0.1.tgz", - "integrity": "sha512-qnQ7gVMxGNxsiL4lEuJwe/To8UnK7fAnmbGEEH8RpLouuKbeEm0lhbQVFIrNSuB+G7tVrAlVsZgETT5nljf+Iw==", - "dev": true - }, "node_modules/source-map": { "version": "0.6.1", "resolved": "https://registry.npmjs.org/source-map/-/source-map-0.6.1.tgz", @@ -3635,18 +3372,6 @@ "node": ">= 6" } }, - "node_modules/ssri": { - "version": "8.0.1", - "resolved": "https://registry.npmjs.org/ssri/-/ssri-8.0.1.tgz", - "integrity": "sha512-97qShzy1AiyxvPNIkLWoGua7xoQzzPjQ0HAH4B0rWKo7SZ6USuPcrUiAFrws0UH8RrbWmgq3LMTObhPIHbbBeQ==", - "dev": true, - "dependencies": { - "minipass": "^3.1.1" - }, - "engines": { - "node": ">= 8" - } - }, "node_modules/statuses": { "version": "1.5.0", "resolved": "https://registry.npmjs.org/statuses/-/statuses-1.5.0.tgz", @@ -3704,36 +3429,10 @@ "node_modules/tapable": { "version": "2.2.1", "resolved": "https://registry.npmjs.org/tapable/-/tapable-2.2.1.tgz", - "integrity": "sha512-GNzQvQTOIP6RyTfE2Qxb8ZVlNmw0n88vp1szwWRimP02mnTsx3Wtn5qRdqY9w2XduFNUgvOwhNnQsjwCp+kqaQ==", - "dev": true, - "engines": { - "node": ">=6" - } - }, - "node_modules/tar": { - "version": "6.2.0", - "resolved": "https://registry.npmjs.org/tar/-/tar-6.2.0.tgz", - "integrity": "sha512-/Wo7DcT0u5HUV486xg675HtjNd3BXZ6xDbzsCUZPt5iw8bTQ63bP0Raut3mvro9u+CUyq7YQd8Cx55fsZXxqLQ==", - "dev": true, - "dependencies": { - "chownr": "^2.0.0", - "fs-minipass": "^2.0.0", - "minipass": "^5.0.0", - "minizlib": "^2.1.1", - "mkdirp": "^1.0.3", - "yallist": "^4.0.0" - }, - "engines": { - "node": ">=10" - } - }, - "node_modules/tar/node_modules/minipass": { - "version": "5.0.0", - "resolved": "https://registry.npmjs.org/minipass/-/minipass-5.0.0.tgz", - "integrity": "sha512-3FnjYuehv9k6ovOEbyOswadCDPX1piCfhV8ncmYtHOjuPwylVWsghTLo7rabjC3Rx5xD4HDx8Wm1xnMF7S5qFQ==", + "integrity": "sha512-GNzQvQTOIP6RyTfE2Qxb8ZVlNmw0n88vp1szwWRimP02mnTsx3Wtn5qRdqY9w2XduFNUgvOwhNnQsjwCp+kqaQ==", "dev": true, "engines": { - "node": ">=8" + "node": ">=6" } }, "node_modules/terser": { @@ -3788,15 +3487,6 @@ } } }, - "node_modules/terser-webpack-plugin/node_modules/serialize-javascript": { - "version": "6.0.2", - "resolved": "https://registry.npmjs.org/serialize-javascript/-/serialize-javascript-6.0.2.tgz", - "integrity": 
"sha512-Saa1xPByTTq2gdeFZYLLo+RFE35NHZkAbqZeWNd3BpzppeVisAqpDjcp8dyf6uIvEqJRd46jemmyA4iFIeVk8g==", - "dev": true, - "dependencies": { - "randombytes": "^2.1.0" - } - }, "node_modules/thunky": { "version": "1.1.0", "resolved": "https://registry.npmjs.org/thunky/-/thunky-1.1.0.tgz", @@ -3843,22 +3533,17 @@ "node": ">= 0.6" } }, - "node_modules/unique-filename": { - "version": "1.1.1", - "resolved": "https://registry.npmjs.org/unique-filename/-/unique-filename-1.1.1.tgz", - "integrity": "sha512-Vmp0jIp2ln35UTXuryvjzkjGdRyf9b2lTXuSYUiPmzRcl3FDtYqAwOnTJkAngD9SWhnoJzDbTKwaOrZ+STtxNQ==", - "dev": true, - "dependencies": { - "unique-slug": "^2.0.0" - } - }, - "node_modules/unique-slug": { - "version": "2.0.2", - "resolved": "https://registry.npmjs.org/unique-slug/-/unique-slug-2.0.2.tgz", - "integrity": "sha512-zoWr9ObaxALD3DOPfjPSqxt4fnZiWblxHIgeWqW8x7UqDzEtHEQLzji2cuJYQFCU6KmoJikOYAZlrTHHebjx2w==", + "node_modules/unicorn-magic": { + "version": "0.3.0", + "resolved": "https://registry.npmjs.org/unicorn-magic/-/unicorn-magic-0.3.0.tgz", + "integrity": "sha512-+QBBXBCvifc56fsbuxZQ6Sic3wqqc3WWaqxs58gvJrcOuN83HGTCwz3oS5phzU9LthRNE9VrJCFCLUgHeeFnfA==", "dev": true, - "dependencies": { - "imurmurhash": "^0.1.4" + "license": "MIT", + "engines": { + "node": ">=18" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" } }, "node_modules/unpipe": { @@ -4265,16 +3950,6 @@ "node": ">=10.0.0" } }, - "node_modules/webpack-sources": { - "version": "1.4.3", - "resolved": "https://registry.npmjs.org/webpack-sources/-/webpack-sources-1.4.3.tgz", - "integrity": "sha512-lgTS3Xhv1lCOKo7SA5TjKXMjpSM4sBjNV5+q2bqesbSPs5FjGmU6jjtBSkX9b4qW87vDIsCIlUPOEhbZrMdjeQ==", - "dev": true, - "dependencies": { - "source-list-map": "^2.0.0", - "source-map": "~0.6.1" - } - }, "node_modules/webpack/node_modules/webpack-sources": { "version": "3.2.3", "resolved": "https://registry.npmjs.org/webpack-sources/-/webpack-sources-3.2.3.tgz", @@ -4354,24 +4029,6 @@ "optional": true } } - }, - "node_modules/yallist": { - "version": "4.0.0", - "resolved": "https://registry.npmjs.org/yallist/-/yallist-4.0.0.tgz", - "integrity": "sha512-3wdGidZyq5PB084XLES5TpOSRA3wjXAlIWMhum2kRcv/41Sn2emQ0dycQW4uZXLejwKvg6EsvbdlVL+FYEct7A==", - "dev": true - }, - "node_modules/yocto-queue": { - "version": "0.1.0", - "resolved": "https://registry.npmjs.org/yocto-queue/-/yocto-queue-0.1.0.tgz", - "integrity": "sha512-rVksvsnNCdJ/ohGc6xgPwyN8eheCxsiLM8mxuE/t/mOVqJewPuO1miLpTHQiRgTKCLexL4MeAFVagts7HmNZ2Q==", - "dev": true, - "engines": { - "node": ">=10" - }, - "funding": { - "url": "https://github.com/sponsors/sindresorhus" - } } }, "dependencies": { @@ -4381,12 +4038,6 @@ "integrity": "sha512-dBVuXR082gk3jsFp7Rd/JI4kytwGHecnCoTtXFb7DB6CNHp4rg5k1bhg0nWdLGLnOV71lmDzGQaLMy8iPLY0pw==", "dev": true }, - "@gar/promisify": { - "version": "1.1.3", - "resolved": "https://registry.npmjs.org/@gar/promisify/-/promisify-1.1.3.tgz", - "integrity": "sha512-k2Ty1JcVojjJFwrg/ThKi2ujJ7XNLYaFGNB/bWT9wGR+oSMJHMa5w+CUq6p/pVrKeNNgA7pCqEcjSnHVoqJQFw==", - "dev": true - }, "@jridgewell/gen-mapping": { "version": "0.3.5", "resolved": "https://registry.npmjs.org/@jridgewell/gen-mapping/-/gen-mapping-0.3.5.tgz", @@ -4468,25 +4119,11 @@ "fastq": "^1.6.0" } }, - "@npmcli/fs": { - "version": "1.1.1", - "resolved": "https://registry.npmjs.org/@npmcli/fs/-/fs-1.1.1.tgz", - "integrity": "sha512-8KG5RD0GVP4ydEzRn/I4BNDuxDtqVbOdm8675T49OIG/NGhaK0pjPX7ZcDlvKYbA+ulvVK3ztfcF4uBdOxuJbQ==", - "dev": true, - "requires": { - "@gar/promisify": "^1.0.1", - "semver": "^7.3.5" - } - }, - 
"@npmcli/move-file": { - "version": "1.1.2", - "resolved": "https://registry.npmjs.org/@npmcli/move-file/-/move-file-1.1.2.tgz", - "integrity": "sha512-1SUf/Cg2GzGDyaf15aR9St9TWlb+XvbZXWpDx8YKs7MLzMH/BCeopv+y9vzrzgkfykCGuWOlSu3mZhj2+FQcrg==", - "dev": true, - "requires": { - "mkdirp": "^1.0.4", - "rimraf": "^3.0.2" - } + "@sindresorhus/merge-streams": { + "version": "2.3.0", + "resolved": "https://registry.npmjs.org/@sindresorhus/merge-streams/-/merge-streams-2.3.0.tgz", + "integrity": "sha512-LtoMMhxAlorcGhmFYI+LhPgbPZCkgP6ra1YL604EeF6U98pLlQ3iWIGMdWSC+vWmPBWBNgmDBAhnAobLROJmwg==", + "dev": true }, "@types/body-parser": { "version": "1.19.3", @@ -4857,16 +4494,6 @@ "dev": true, "requires": {} }, - "aggregate-error": { - "version": "3.1.0", - "resolved": "https://registry.npmjs.org/aggregate-error/-/aggregate-error-3.1.0.tgz", - "integrity": "sha512-4I7Td01quW/RpocfNayFdFVk1qSuoh0E7JrbRJ16nH01HhKFQ88INq9Sd+nd72zqRySlr9BmDA8xlEJ6vJMrYA==", - "dev": true, - "requires": { - "clean-stack": "^2.0.0", - "indent-string": "^4.0.0" - } - }, "ajv": { "version": "6.12.6", "resolved": "https://registry.npmjs.org/ajv/-/ajv-6.12.6.tgz", @@ -4937,12 +4564,6 @@ "integrity": "sha512-hNfzcOV8W4NdualtqBFPyVO+54DSJuZGY9qT4pRroB6S9e3iiido2ISIC5h9R2sPJ8H3FHCIiEnsv1lPXO3KtQ==", "dev": true }, - "array-union": { - "version": "2.1.0", - "resolved": "https://registry.npmjs.org/array-union/-/array-union-2.1.0.tgz", - "integrity": "sha512-HGyxoOTYUyCM6stUe6EJgnd4EoewAI7zMdfqO+kGjnlZmBDz/cR5pf8r/cR4Wq60sL/p0IkcjUEEPwS3GFrIyw==", - "dev": true - }, "balanced-match": { "version": "1.0.2", "resolved": "https://registry.npmjs.org/balanced-match/-/balanced-match-1.0.2.tgz", @@ -4955,12 +4576,6 @@ "integrity": "sha1-3DQxT05nkxgJP8dgJyUl+UvyXBY=", "dev": true }, - "big.js": { - "version": "5.2.2", - "resolved": "https://registry.npmjs.org/big.js/-/big.js-5.2.2.tgz", - "integrity": "sha512-vyL2OymJxmarO8gxMr0mhChsO9QGwhynfuu4+MHTAW6czfq9humCB7rKpUjDd9YUiDPU4mzpyupFSvOClAwbmQ==", - "dev": true - }, "binary-extensions": { "version": "2.2.0", "resolved": "https://registry.npmjs.org/binary-extensions/-/binary-extensions-2.2.0.tgz", @@ -5065,32 +4680,6 @@ "integrity": "sha1-0ygVQE1olpn4Wk6k+odV3ROpYEg=", "dev": true }, - "cacache": { - "version": "15.3.0", - "resolved": "https://registry.npmjs.org/cacache/-/cacache-15.3.0.tgz", - "integrity": "sha512-VVdYzXEn+cnbXpFgWs5hTT7OScegHVmLhJIR8Ufqk3iFD6A6j5iSX1KuBTfNEv4tdJWE2PzA6IVFtcLC7fN9wQ==", - "dev": true, - "requires": { - "@npmcli/fs": "^1.0.0", - "@npmcli/move-file": "^1.0.1", - "chownr": "^2.0.0", - "fs-minipass": "^2.0.0", - "glob": "^7.1.4", - "infer-owner": "^1.0.4", - "lru-cache": "^6.0.0", - "minipass": "^3.1.1", - "minipass-collect": "^1.0.2", - "minipass-flush": "^1.0.5", - "minipass-pipeline": "^1.2.2", - "mkdirp": "^1.0.3", - "p-map": "^4.0.0", - "promise-inflight": "^1.0.1", - "rimraf": "^3.0.2", - "ssri": "^8.0.1", - "tar": "^6.0.2", - "unique-filename": "^1.1.1" - } - }, "call-bind": { "version": "1.0.7", "resolved": "https://registry.npmjs.org/call-bind/-/call-bind-1.0.7.tgz", @@ -5126,12 +4715,6 @@ "readdirp": "~3.6.0" } }, - "chownr": { - "version": "2.0.0", - "resolved": "https://registry.npmjs.org/chownr/-/chownr-2.0.0.tgz", - "integrity": "sha512-bIomtDF5KGpdogkLd9VspvFzk9KfpyyGlS8YFVZl7TGPBHL5snIOnxeshwVgPteQ9b4Eydl+pVbIyE1DcvCWgQ==", - "dev": true - }, "chrome-trace-event": { "version": "1.0.2", "resolved": "https://registry.npmjs.org/chrome-trace-event/-/chrome-trace-event-1.0.2.tgz", @@ -5141,12 +4724,6 @@ "tslib": "^1.9.0" } }, - "clean-stack": { - 
"version": "2.2.0", - "resolved": "https://registry.npmjs.org/clean-stack/-/clean-stack-2.2.0.tgz", - "integrity": "sha512-4diC9HaTE+KRAMWhDhrGOECgWZxoevMc5TlkObMqNSsVU62PYzXZ/SMTjzyGAFF1YusgxGcSWTEXBhp0CPwQ1A==", - "dev": true - }, "clone-deep": { "version": "4.0.1", "resolved": "https://registry.npmjs.org/clone-deep/-/clone-deep-4.0.1.tgz", @@ -5170,12 +4747,6 @@ "integrity": "sha512-GpVkmM8vF2vQUkj2LvZmD35JxeJOLCwJ9cUkugyk2nuhbv3+mJvpLYYt+0+USMxE+oj+ey/lJEnhZw75x/OMcQ==", "dev": true }, - "commondir": { - "version": "1.0.1", - "resolved": "https://registry.npmjs.org/commondir/-/commondir-1.0.1.tgz", - "integrity": "sha512-W9pAhw0ja1Edb5GVdIF1mjZw/ASI0AlShXM83UUGe2DVr5TdAPEA1OA8m/g8zWp9x6On7gqufY+FatDbC3MDQg==", - "dev": true - }, "compressible": { "version": "2.0.18", "resolved": "https://registry.npmjs.org/compressible/-/compressible-2.0.18.tgz", @@ -5259,22 +4830,67 @@ "dev": true }, "copy-webpack-plugin": { - "version": "6.4.1", - "resolved": "https://registry.npmjs.org/copy-webpack-plugin/-/copy-webpack-plugin-6.4.1.tgz", - "integrity": "sha512-MXyPCjdPVx5iiWyl40Va3JGh27bKzOTNY3NjUTrosD2q7dR/cLD0013uqJ3BpFbUjyONINjb6qI7nDIJujrMbA==", + "version": "12.0.2", + "resolved": "https://registry.npmjs.org/copy-webpack-plugin/-/copy-webpack-plugin-12.0.2.tgz", + "integrity": "sha512-SNwdBeHyII+rWvee/bTnAYyO8vfVdcSTud4EIb6jcZ8inLeWucJE0DnxXQBjlQ5zlteuuvooGQy3LIyGxhvlOA==", "dev": true, "requires": { - "cacache": "^15.0.5", - "fast-glob": "^3.2.4", - "find-cache-dir": "^3.3.1", - "glob-parent": "^5.1.1", - "globby": "^11.0.1", - "loader-utils": "^2.0.0", + "fast-glob": "^3.3.2", + "glob-parent": "^6.0.1", + "globby": "^14.0.0", "normalize-path": "^3.0.0", - "p-limit": "^3.0.2", - "schema-utils": "^3.0.0", - "serialize-javascript": "^5.0.1", - "webpack-sources": "^1.4.3" + "schema-utils": "^4.2.0", + "serialize-javascript": "^6.0.2" + }, + "dependencies": { + "ajv": { + "version": "8.17.1", + "resolved": "https://registry.npmjs.org/ajv/-/ajv-8.17.1.tgz", + "integrity": "sha512-B/gBuNg5SiMTrPkC+A2+cW0RszwxYmn6VYxB/inlBStS5nx6xHIt/ehKRhIMhqusl7a8LjQoZnjCs5vhwxOQ1g==", + "dev": true, + "requires": { + "fast-deep-equal": "^3.1.3", + "fast-uri": "^3.0.1", + "json-schema-traverse": "^1.0.0", + "require-from-string": "^2.0.2" + } + }, + "ajv-keywords": { + "version": "5.1.0", + "resolved": "https://registry.npmjs.org/ajv-keywords/-/ajv-keywords-5.1.0.tgz", + "integrity": "sha512-YCS/JNFAUyr5vAuhk1DWm1CBxRHW9LbJ2ozWeemrIqpbsqKjHVxYPyi5GC0rjZIT5JxJ3virVTS8wk4i/Z+krw==", + "dev": true, + "requires": { + "fast-deep-equal": "^3.1.3" + } + }, + "glob-parent": { + "version": "6.0.2", + "resolved": "https://registry.npmjs.org/glob-parent/-/glob-parent-6.0.2.tgz", + "integrity": "sha512-XxwI8EOhVQgWp6iDL+3b0r86f4d6AX6zSU55HfB4ydCEuXLXc5FcYeOu+nnGftS4TEju/11rt4KJPTMgbfmv4A==", + "dev": true, + "requires": { + "is-glob": "^4.0.3" + } + }, + "json-schema-traverse": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/json-schema-traverse/-/json-schema-traverse-1.0.0.tgz", + "integrity": "sha512-NM8/P9n3XjXhIZn1lLhkFaACTOURQXjWhV4BA/RnOv8xvgqtqpAX9IO4mRQxSx1Rlo4tqzeqb0sOlruaOy3dug==", + "dev": true + }, + "schema-utils": { + "version": "4.3.0", + "resolved": "https://registry.npmjs.org/schema-utils/-/schema-utils-4.3.0.tgz", + "integrity": "sha512-Gf9qqc58SpCA/xdziiHz35F4GNIWYWZrEshUc/G/r5BnLph6xpKuLeoJoQuj5WfBIx/eQLf+hmVPYHaxJu7V2g==", + "dev": true, + "requires": { + "@types/json-schema": "^7.0.9", + "ajv": "^8.9.0", + "ajv-formats": "^2.1.1", + "ajv-keywords": "^5.1.0" + } + } } }, "core-util-is": 
{ @@ -5358,15 +4974,6 @@ "integrity": "sha512-ZIzRpLJrOj7jjP2miAtgqIfmzbxa4ZOr5jJc601zklsfEx9oTzmmj2nVpIPRpNlRTIh8lc1kyViIY7BWSGNmKw==", "dev": true }, - "dir-glob": { - "version": "3.0.1", - "resolved": "https://registry.npmjs.org/dir-glob/-/dir-glob-3.0.1.tgz", - "integrity": "sha512-WkrWp9GR4KXfKGYzOLmTuGVi1UWFfws377n9cc55/tb6DuqyF6pcQ5AbiHEshaDpY9v6oaSr2XCDidGmMwdzIA==", - "dev": true, - "requires": { - "path-type": "^4.0.0" - } - }, "dns-equal": { "version": "1.0.0", "resolved": "https://registry.npmjs.org/dns-equal/-/dns-equal-1.0.0.tgz", @@ -5394,12 +5001,6 @@ "integrity": "sha512-UdREXMXzLkREF4jA8t89FQjA8WHI6ssP38PMY4/4KhXFQbtImnghh4GkCgrtiZwLKUKVD2iTVXvDVQjfomEQuA==", "dev": true }, - "emojis-list": { - "version": "3.0.0", - "resolved": "https://registry.npmjs.org/emojis-list/-/emojis-list-3.0.0.tgz", - "integrity": "sha512-/kyM18EfinwXZbno9FyUGeFh87KC8HRQBQGildHZbEuRyWFOmv1U10o9BBp8XVZDVNNuQKyIGIu5ZYAAXJ0V2Q==", - "dev": true - }, "encodeurl": { "version": "1.0.2", "resolved": "https://registry.npmjs.org/encodeurl/-/encodeurl-1.0.2.tgz", @@ -5610,16 +5211,16 @@ "dev": true }, "fast-glob": { - "version": "3.3.1", - "resolved": "https://registry.npmjs.org/fast-glob/-/fast-glob-3.3.1.tgz", - "integrity": "sha512-kNFPyjhh5cKjrUltxs+wFx+ZkbRaxxmZ+X0ZU31SOsxCEtP9VPgtq2teZw1DebupL5GmDaNQ6yKMMVcM41iqDg==", + "version": "3.3.3", + "resolved": "https://registry.npmjs.org/fast-glob/-/fast-glob-3.3.3.tgz", + "integrity": "sha512-7MptL8U0cqcFdzIzwOTHoilX9x5BrNqye7Z/LuC7kCMRio1EMSyqRK3BEAUD7sXRq4iT4AzTVuZdhgQ2TCvYLg==", "dev": true, "requires": { "@nodelib/fs.stat": "^2.0.2", "@nodelib/fs.walk": "^1.2.3", "glob-parent": "^5.1.2", "merge2": "^1.3.0", - "micromatch": "^4.0.4" + "micromatch": "^4.0.8" } }, "fast-json-stable-stringify": { @@ -5628,6 +5229,12 @@ "integrity": "sha512-lhd/wF+Lk98HZoTCtlVraHtfh5XYijIjalXck7saUtuanSDyLMxnHhSXEDJqHxD7msR8D0uCmqlkwjCV8xvwHw==", "dev": true }, + "fast-uri": { + "version": "3.0.6", + "resolved": "https://registry.npmjs.org/fast-uri/-/fast-uri-3.0.6.tgz", + "integrity": "sha512-Atfo14OibSv5wAp4VWNsFYE1AchQRTv9cBGWET4pZWHzYshFSS9NQI6I57rdKn9croWVMbYFbLhJ+yJvmZIIHw==", + "dev": true + }, "fastest-levenshtein": { "version": "1.0.16", "resolved": "https://registry.npmjs.org/fastest-levenshtein/-/fastest-levenshtein-1.0.16.tgz", @@ -5635,9 +5242,9 @@ "dev": true }, "fastq": { - "version": "1.15.0", - "resolved": "https://registry.npmjs.org/fastq/-/fastq-1.15.0.tgz", - "integrity": "sha512-wBrocU2LCXXa+lWBt8RoIRD89Fi8OdABODa/kEnyeyjS5aZO5/GNvI5sEINADqP/h8M29UHTHUb53sUu5Ihqdw==", + "version": "1.19.0", + "resolved": "https://registry.npmjs.org/fastq/-/fastq-1.19.0.tgz", + "integrity": "sha512-7SFSRCNjBQIZH/xZR3iy5iQYR8aGBE0h3VG6/cwlbrpdciNYBMotQav8c1XI3HjHH+NikUpP53nPdlZSdWmFzA==", "dev": true, "requires": { "reusify": "^1.0.4" @@ -5699,17 +5306,6 @@ } } }, - "find-cache-dir": { - "version": "3.3.2", - "resolved": "https://registry.npmjs.org/find-cache-dir/-/find-cache-dir-3.3.2.tgz", - "integrity": "sha512-wXZV5emFEjrridIgED11OoUKLxiYjAcqot/NJdAkOhlJ+vGzwhOAfcG5OX1jP+S0PcjEn8bdMJv+g2jwQ3Onig==", - "dev": true, - "requires": { - "commondir": "^1.0.1", - "make-dir": "^3.0.2", - "pkg-dir": "^4.1.0" - } - }, "find-up": { "version": "4.1.0", "resolved": "https://registry.npmjs.org/find-up/-/find-up-4.1.0.tgz", @@ -5738,15 +5334,6 @@ "integrity": "sha512-zJ2mQYM18rEFOudeV4GShTGIQ7RbzA7ozbU9I/XBpm7kqgMywgmylMwXHxZJmkVoYkna9d2pVXVXPdYTP9ej8Q==", "dev": true }, - "fs-minipass": { - "version": "2.1.0", - "resolved": 
"https://registry.npmjs.org/fs-minipass/-/fs-minipass-2.1.0.tgz", - "integrity": "sha512-V/JgOLFCS+R6Vcq0slCuaeWEdNC3ouDlJMNIsacH2VtALiu9mV4LPrHc5cDl8k5aw6J8jwgWWpiTo5RYhmIzvg==", - "dev": true, - "requires": { - "minipass": "^3.0.0" - } - }, "fs-monkey": { "version": "1.0.4", "resolved": "https://registry.npmjs.org/fs-monkey/-/fs-monkey-1.0.4.tgz", @@ -5821,17 +5408,17 @@ "dev": true }, "globby": { - "version": "11.1.0", - "resolved": "https://registry.npmjs.org/globby/-/globby-11.1.0.tgz", - "integrity": "sha512-jhIXaOzy1sb8IyocaruWSn1TjmnBVs8Ayhcy83rmxNJ8q2uWKCAj3CnJY+KpGSXCueAPc0i05kVvVKtP1t9S3g==", + "version": "14.1.0", + "resolved": "https://registry.npmjs.org/globby/-/globby-14.1.0.tgz", + "integrity": "sha512-0Ia46fDOaT7k4og1PDW4YbodWWr3scS2vAr2lTbsplOt2WkKp0vQbkI9wKis/T5LV/dqPjO3bpS/z6GTJB82LA==", "dev": true, "requires": { - "array-union": "^2.1.0", - "dir-glob": "^3.0.1", - "fast-glob": "^3.2.9", - "ignore": "^5.2.0", - "merge2": "^1.4.1", - "slash": "^3.0.0" + "@sindresorhus/merge-streams": "^2.1.0", + "fast-glob": "^3.3.3", + "ignore": "^7.0.3", + "path-type": "^6.0.0", + "slash": "^5.1.0", + "unicorn-magic": "^0.3.0" } }, "gopd": { @@ -6003,9 +5590,9 @@ } }, "ignore": { - "version": "5.2.4", - "resolved": "https://registry.npmjs.org/ignore/-/ignore-5.2.4.tgz", - "integrity": "sha512-MAb38BcSbH0eHNBxn7ql2NH/kX33OkB3lZ1BNdh7ENeRChHTYsTvWrMubiIAMNS2llXEEgZ1MUOBtXChP3kaFQ==", + "version": "7.0.3", + "resolved": "https://registry.npmjs.org/ignore/-/ignore-7.0.3.tgz", + "integrity": "sha512-bAH5jbK/F3T3Jls4I0SO1hmPR0dKU0a7+SY6n1yzRtG54FLO8d6w/nxLFX2Nb7dBu6cCWXPaAME6cYqFUMmuCA==", "dev": true }, "import-local": { @@ -6018,24 +5605,6 @@ "resolve-cwd": "^3.0.0" } }, - "imurmurhash": { - "version": "0.1.4", - "resolved": "https://registry.npmjs.org/imurmurhash/-/imurmurhash-0.1.4.tgz", - "integrity": "sha512-JmXMZ6wuvDmLiHEml9ykzqO6lwFbof0GG4IkcGaENdCRDDmMVnny7s5HsIgHCbaq0w2MyPhDqkhTUgS2LU2PHA==", - "dev": true - }, - "indent-string": { - "version": "4.0.0", - "resolved": "https://registry.npmjs.org/indent-string/-/indent-string-4.0.0.tgz", - "integrity": "sha512-EdDDZu4A2OyIK7Lr/2zG+w5jmbuk1DVBnEwREQvBzspBJkCEbRa8GxU1lghYcaGJCnRWibjDXlq779X1/y5xwg==", - "dev": true - }, - "infer-owner": { - "version": "1.0.4", - "resolved": "https://registry.npmjs.org/infer-owner/-/infer-owner-1.0.4.tgz", - "integrity": "sha512-IClj+Xz94+d7irH5qRyfJonOdfTzuDaifE6ZPWfx0N0+/ATZCbuTPq2prFl526urkQd90WyUKIh1DfBQ2hMz9A==", - "dev": true - }, "inflight": { "version": "1.0.6", "resolved": "https://registry.npmjs.org/inflight/-/inflight-1.0.6.tgz", @@ -6180,12 +5749,6 @@ "integrity": "sha512-xbbCH5dCYU5T8LcEhhuh7HJ88HXuW3qsI3Y0zOZFKfZEHcpWiHU/Jxzk629Brsab/mMiHQti9wMP+845RPe3Vg==", "dev": true }, - "json5": { - "version": "2.2.3", - "resolved": "https://registry.npmjs.org/json5/-/json5-2.2.3.tgz", - "integrity": "sha512-XmOWe7eyHYH14cLdVPoyg+GOH3rYX++KpzrylJwSW98t3Nk+U8XOl8FWKOgwtzdb8lXGf6zYwDUzeHMWfxasyg==", - "dev": true - }, "kind-of": { "version": "6.0.3", "resolved": "https://registry.npmjs.org/kind-of/-/kind-of-6.0.3.tgz", @@ -6208,17 +5771,6 @@ "integrity": "sha512-3R/1M+yS3j5ou80Me59j7F9IMs4PXs3VqRrm0TU3AbKPxlmpoY1TNscJV/oGJXo8qCatFGTfDbY6W6ipGOYXfg==", "dev": true }, - "loader-utils": { - "version": "2.0.4", - "resolved": "https://registry.npmjs.org/loader-utils/-/loader-utils-2.0.4.tgz", - "integrity": "sha512-xXqpXoINfFhgua9xiqD8fPFHgkoq1mmmpE92WlDbm9rNRd/EbRb+Gqf908T2DMfuHjjJlksiK2RbHVOdD/MqSw==", - "dev": true, - "requires": { - "big.js": "^5.2.2", - "emojis-list": "^3.0.0", - "json5": "^2.1.2" - 
} - }, "locate-path": { "version": "5.0.0", "resolved": "https://registry.npmjs.org/locate-path/-/locate-path-5.0.0.tgz", @@ -6228,32 +5780,6 @@ "p-locate": "^4.1.0" } }, - "lru-cache": { - "version": "6.0.0", - "resolved": "https://registry.npmjs.org/lru-cache/-/lru-cache-6.0.0.tgz", - "integrity": "sha512-Jo6dJ04CmSjuznwJSS3pUeWmd/H0ffTlkXXgwZi+eq1UCmqQwCh+eLsYOYCwY991i2Fah4h1BEMCx4qThGbsiA==", - "dev": true, - "requires": { - "yallist": "^4.0.0" - } - }, - "make-dir": { - "version": "3.1.0", - "resolved": "https://registry.npmjs.org/make-dir/-/make-dir-3.1.0.tgz", - "integrity": "sha512-g3FeP20LNwhALb/6Cz6Dd4F2ngze0jz7tbzrD2wAV+o9FeNHe4rL+yK2md0J/fiSf1sa1ADhXqi5+oVwOM/eGw==", - "dev": true, - "requires": { - "semver": "^6.0.0" - }, - "dependencies": { - "semver": { - "version": "6.3.1", - "resolved": "https://registry.npmjs.org/semver/-/semver-6.3.1.tgz", - "integrity": "sha512-BR7VvDCVHO+q2xBEWskxS6DJE1qRnb7DxzUrogb71CWoSficBxYsiAGd+Kl0mmq/MprG9yArRkyrQxTO6XjMzA==", - "dev": true - } - } - }, "media-typer": { "version": "0.3.0", "resolved": "https://registry.npmjs.org/media-typer/-/media-typer-0.3.0.tgz", @@ -6294,12 +5820,12 @@ "dev": true }, "micromatch": { - "version": "4.0.5", - "resolved": "https://registry.npmjs.org/micromatch/-/micromatch-4.0.5.tgz", - "integrity": "sha512-DMy+ERcEW2q8Z2Po+WNXuw3c5YaUSFjAO5GsJqfEl7UjvtIuFKO6ZrKvcItdy98dwFI2N1tg3zNIdKaQT+aNdA==", + "version": "4.0.8", + "resolved": "https://registry.npmjs.org/micromatch/-/micromatch-4.0.8.tgz", + "integrity": "sha512-PXwfBhYu0hBCPw8Dn0E+WDYb7af3dSLVWKi3HGv84IdF4TyFoC0ysxFd0Goxw7nSv4T/PzEJQxsYsEiFCKo2BA==", "dev": true, "requires": { - "braces": "^3.0.2", + "braces": "^3.0.3", "picomatch": "^2.3.1" } }, @@ -6345,58 +5871,6 @@ "brace-expansion": "^1.1.7" } }, - "minipass": { - "version": "3.3.6", - "resolved": "https://registry.npmjs.org/minipass/-/minipass-3.3.6.tgz", - "integrity": "sha512-DxiNidxSEK+tHG6zOIklvNOwm3hvCrbUrdtzY74U6HKTJxvIDfOUL5W5P2Ghd3DTkhhKPYGqeNUIh5qcM4YBfw==", - "dev": true, - "requires": { - "yallist": "^4.0.0" - } - }, - "minipass-collect": { - "version": "1.0.2", - "resolved": "https://registry.npmjs.org/minipass-collect/-/minipass-collect-1.0.2.tgz", - "integrity": "sha512-6T6lH0H8OG9kITm/Jm6tdooIbogG9e0tLgpY6mphXSm/A9u8Nq1ryBG+Qspiub9LjWlBPsPS3tWQ/Botq4FdxA==", - "dev": true, - "requires": { - "minipass": "^3.0.0" - } - }, - "minipass-flush": { - "version": "1.0.5", - "resolved": "https://registry.npmjs.org/minipass-flush/-/minipass-flush-1.0.5.tgz", - "integrity": "sha512-JmQSYYpPUqX5Jyn1mXaRwOda1uQ8HP5KAT/oDSLCzt1BYRhQU0/hDtsB1ufZfEEzMZ9aAVmsBw8+FWsIXlClWw==", - "dev": true, - "requires": { - "minipass": "^3.0.0" - } - }, - "minipass-pipeline": { - "version": "1.2.4", - "resolved": "https://registry.npmjs.org/minipass-pipeline/-/minipass-pipeline-1.2.4.tgz", - "integrity": "sha512-xuIq7cIOt09RPRJ19gdi4b+RiNvDFYe5JH+ggNvBqGqpQXcru3PcRmOZuHBKWK1Txf9+cQ+HMVN4d6z46LZP7A==", - "dev": true, - "requires": { - "minipass": "^3.0.0" - } - }, - "minizlib": { - "version": "2.1.2", - "resolved": "https://registry.npmjs.org/minizlib/-/minizlib-2.1.2.tgz", - "integrity": "sha512-bAxsR8BVfj60DWXHE3u30oHzfl4G7khkSuPW+qvpd7jFRHm7dLxOjUk1EHACJ/hxLY8phGJ0YhYHZo7jil7Qdg==", - "dev": true, - "requires": { - "minipass": "^3.0.0", - "yallist": "^4.0.0" - } - }, - "mkdirp": { - "version": "1.0.4", - "resolved": "https://registry.npmjs.org/mkdirp/-/mkdirp-1.0.4.tgz", - "integrity": "sha512-vVqVZQyf3WLx2Shd0qJ9xuvqgAyKPLAiqITEtqW0oIUjzo3PePDd6fW9iFz30ef7Ysp/oiWqbhszeGWW2T6Gzw==", - "dev": true - }, "ms": { 
"version": "2.0.0", "resolved": "https://registry.npmjs.org/ms/-/ms-2.0.0.tgz", @@ -6508,15 +5982,6 @@ "is-wsl": "^2.2.0" } }, - "p-limit": { - "version": "3.1.0", - "resolved": "https://registry.npmjs.org/p-limit/-/p-limit-3.1.0.tgz", - "integrity": "sha512-TYOanM3wGwNGsZN2cVTYPArw454xnXj5qmWF1bEoAc4+cU/ol7GVh7odevjp1FNHduHc3KZMcFduxU5Xc6uJRQ==", - "dev": true, - "requires": { - "yocto-queue": "^0.1.0" - } - }, "p-locate": { "version": "4.1.0", "resolved": "https://registry.npmjs.org/p-locate/-/p-locate-4.1.0.tgz", @@ -6537,15 +6002,6 @@ } } }, - "p-map": { - "version": "4.0.0", - "resolved": "https://registry.npmjs.org/p-map/-/p-map-4.0.0.tgz", - "integrity": "sha512-/bjOqmgETBYB5BoEeGVea8dmvHb2m9GLy1E9W43yeyfP6QQCZGFNa+XRceJEuDB6zqr+gKpIAmlLebMpykw/MQ==", - "dev": true, - "requires": { - "aggregate-error": "^3.0.0" - } - }, "p-retry": { "version": "4.6.2", "resolved": "https://registry.npmjs.org/p-retry/-/p-retry-4.6.2.tgz", @@ -6599,9 +6055,9 @@ "dev": true }, "path-type": { - "version": "4.0.0", - "resolved": "https://registry.npmjs.org/path-type/-/path-type-4.0.0.tgz", - "integrity": "sha512-gDKb8aZMDeD/tZWs9P6+q0J9Mwkdl6xMV8TjnGP3qJVJ06bdMgkbBlLU8IdfOsIsFz2BW1rNVT3XuNEl8zPAvw==", + "version": "6.0.0", + "resolved": "https://registry.npmjs.org/path-type/-/path-type-6.0.0.tgz", + "integrity": "sha512-Vj7sf++t5pBD637NSfkxpHSMfWaeig5+DKWLhcqIYx6mWQz5hdJTGDVMQiJcw1ZYkhs7AazKDGpRVji1LJCZUQ==", "dev": true }, "picocolors": { @@ -6631,12 +6087,6 @@ "integrity": "sha512-MtEC1TqN0EU5nephaJ4rAtThHtC86dNN9qCuEhtshvpVBkAW5ZO7BASN9REnF9eoXGcRub+pFuKEpOHE+HbEMw==", "dev": true }, - "promise-inflight": { - "version": "1.0.1", - "resolved": "https://registry.npmjs.org/promise-inflight/-/promise-inflight-1.0.1.tgz", - "integrity": "sha512-6zWPyEOFaQBJYcGMHBKTKJ3u6TBsnMFOIZSa6ce1e/ZrrsOlnHRHbabMjLiBYKp+n44X9eUI6VUPaukCXHuG4g==", - "dev": true - }, "proxy-addr": { "version": "2.0.7", "resolved": "https://registry.npmjs.org/proxy-addr/-/proxy-addr-2.0.7.tgz", @@ -6850,15 +6300,6 @@ "node-forge": "^1" } }, - "semver": { - "version": "7.5.4", - "resolved": "https://registry.npmjs.org/semver/-/semver-7.5.4.tgz", - "integrity": "sha512-1bCSESV6Pv+i21Hvpxp3Dx+pSD8lIPt8uVjRrxAUt/nbswYc+tK6Y2btiULjd4+fnq15PX+nqQDC7Oft7WkwcA==", - "dev": true, - "requires": { - "lru-cache": "^6.0.0" - } - }, "send": { "version": "0.19.0", "resolved": "https://registry.npmjs.org/send/-/send-0.19.0.tgz", @@ -6918,9 +6359,9 @@ } }, "serialize-javascript": { - "version": "5.0.1", - "resolved": "https://registry.npmjs.org/serialize-javascript/-/serialize-javascript-5.0.1.tgz", - "integrity": "sha512-SaaNal9imEO737H2c05Og0/8LUXG7EnsZyMa8MzkmuHoELfT6txuj0cMqRj6zfPKnmQ1yasR4PCJc8x+M4JSPA==", + "version": "6.0.2", + "resolved": "https://registry.npmjs.org/serialize-javascript/-/serialize-javascript-6.0.2.tgz", + "integrity": "sha512-Saa1xPByTTq2gdeFZYLLo+RFE35NHZkAbqZeWNd3BpzppeVisAqpDjcp8dyf6uIvEqJRd46jemmyA4iFIeVk8g==", "dev": true, "requires": { "randombytes": "^2.1.0" @@ -7059,9 +6500,9 @@ "dev": true }, "slash": { - "version": "3.0.0", - "resolved": "https://registry.npmjs.org/slash/-/slash-3.0.0.tgz", - "integrity": "sha512-g9Q1haeby36OSStwb4ntCGGGaKsaVSjQ68fBxoQcutl5fS1vuY18H3wSt3jFyFtrkx+Kz0V1G85A4MyAdDMi2Q==", + "version": "5.1.0", + "resolved": "https://registry.npmjs.org/slash/-/slash-5.1.0.tgz", + "integrity": "sha512-ZA6oR3T/pEyuqwMgAKT0/hAv8oAXckzbkmR0UkUosQ+Mc4RxGoJkRmwHgHufaenlyAgE1Mxgpdcrf75y6XcnDg==", "dev": true }, "sockjs": { @@ -7075,12 +6516,6 @@ "websocket-driver": "^0.7.4" } }, - "source-list-map": { - "version": 
"2.0.1", - "resolved": "https://registry.npmjs.org/source-list-map/-/source-list-map-2.0.1.tgz", - "integrity": "sha512-qnQ7gVMxGNxsiL4lEuJwe/To8UnK7fAnmbGEEH8RpLouuKbeEm0lhbQVFIrNSuB+G7tVrAlVsZgETT5nljf+Iw==", - "dev": true - }, "source-map": { "version": "0.6.1", "resolved": "https://registry.npmjs.org/source-map/-/source-map-0.6.1.tgz", @@ -7137,15 +6572,6 @@ } } }, - "ssri": { - "version": "8.0.1", - "resolved": "https://registry.npmjs.org/ssri/-/ssri-8.0.1.tgz", - "integrity": "sha512-97qShzy1AiyxvPNIkLWoGua7xoQzzPjQ0HAH4B0rWKo7SZ6USuPcrUiAFrws0UH8RrbWmgq3LMTObhPIHbbBeQ==", - "dev": true, - "requires": { - "minipass": "^3.1.1" - } - }, "statuses": { "version": "1.5.0", "resolved": "https://registry.npmjs.org/statuses/-/statuses-1.5.0.tgz", @@ -7188,28 +6614,6 @@ "integrity": "sha512-GNzQvQTOIP6RyTfE2Qxb8ZVlNmw0n88vp1szwWRimP02mnTsx3Wtn5qRdqY9w2XduFNUgvOwhNnQsjwCp+kqaQ==", "dev": true }, - "tar": { - "version": "6.2.0", - "resolved": "https://registry.npmjs.org/tar/-/tar-6.2.0.tgz", - "integrity": "sha512-/Wo7DcT0u5HUV486xg675HtjNd3BXZ6xDbzsCUZPt5iw8bTQ63bP0Raut3mvro9u+CUyq7YQd8Cx55fsZXxqLQ==", - "dev": true, - "requires": { - "chownr": "^2.0.0", - "fs-minipass": "^2.0.0", - "minipass": "^5.0.0", - "minizlib": "^2.1.1", - "mkdirp": "^1.0.3", - "yallist": "^4.0.0" - }, - "dependencies": { - "minipass": { - "version": "5.0.0", - "resolved": "https://registry.npmjs.org/minipass/-/minipass-5.0.0.tgz", - "integrity": "sha512-3FnjYuehv9k6ovOEbyOswadCDPX1piCfhV8ncmYtHOjuPwylVWsghTLo7rabjC3Rx5xD4HDx8Wm1xnMF7S5qFQ==", - "dev": true - } - } - }, "terser": { "version": "5.31.6", "resolved": "https://registry.npmjs.org/terser/-/terser-5.31.6.tgz", @@ -7233,17 +6637,6 @@ "schema-utils": "^3.1.1", "serialize-javascript": "^6.0.1", "terser": "^5.26.0" - }, - "dependencies": { - "serialize-javascript": { - "version": "6.0.2", - "resolved": "https://registry.npmjs.org/serialize-javascript/-/serialize-javascript-6.0.2.tgz", - "integrity": "sha512-Saa1xPByTTq2gdeFZYLLo+RFE35NHZkAbqZeWNd3BpzppeVisAqpDjcp8dyf6uIvEqJRd46jemmyA4iFIeVk8g==", - "dev": true, - "requires": { - "randombytes": "^2.1.0" - } - } } }, "thunky": { @@ -7283,23 +6676,11 @@ "mime-types": "~2.1.24" } }, - "unique-filename": { - "version": "1.1.1", - "resolved": "https://registry.npmjs.org/unique-filename/-/unique-filename-1.1.1.tgz", - "integrity": "sha512-Vmp0jIp2ln35UTXuryvjzkjGdRyf9b2lTXuSYUiPmzRcl3FDtYqAwOnTJkAngD9SWhnoJzDbTKwaOrZ+STtxNQ==", - "dev": true, - "requires": { - "unique-slug": "^2.0.0" - } - }, - "unique-slug": { - "version": "2.0.2", - "resolved": "https://registry.npmjs.org/unique-slug/-/unique-slug-2.0.2.tgz", - "integrity": "sha512-zoWr9ObaxALD3DOPfjPSqxt4fnZiWblxHIgeWqW8x7UqDzEtHEQLzji2cuJYQFCU6KmoJikOYAZlrTHHebjx2w==", - "dev": true, - "requires": { - "imurmurhash": "^0.1.4" - } + "unicorn-magic": { + "version": "0.3.0", + "resolved": "https://registry.npmjs.org/unicorn-magic/-/unicorn-magic-0.3.0.tgz", + "integrity": "sha512-+QBBXBCvifc56fsbuxZQ6Sic3wqqc3WWaqxs58gvJrcOuN83HGTCwz3oS5phzU9LthRNE9VrJCFCLUgHeeFnfA==", + "dev": true }, "unpipe": { "version": "1.0.0", @@ -7580,16 +6961,6 @@ "wildcard": "^2.0.0" } }, - "webpack-sources": { - "version": "1.4.3", - "resolved": "https://registry.npmjs.org/webpack-sources/-/webpack-sources-1.4.3.tgz", - "integrity": "sha512-lgTS3Xhv1lCOKo7SA5TjKXMjpSM4sBjNV5+q2bqesbSPs5FjGmU6jjtBSkX9b4qW87vDIsCIlUPOEhbZrMdjeQ==", - "dev": true, - "requires": { - "source-list-map": "^2.0.0", - "source-map": "~0.6.1" - } - }, "websocket-driver": { "version": "0.7.4", "resolved": 
"https://registry.npmjs.org/websocket-driver/-/websocket-driver-0.7.4.tgz", @@ -7634,18 +7005,6 @@ "integrity": "sha512-6XQFvXTkbfUOZOKKILFG1PDK2NDQs4azKQl26T0YS5CxqWLgXajbPZ+h4gZekJyRqFU8pvnbAbbs/3TgRPy+GQ==", "dev": true, "requires": {} - }, - "yallist": { - "version": "4.0.0", - "resolved": "https://registry.npmjs.org/yallist/-/yallist-4.0.0.tgz", - "integrity": "sha512-3wdGidZyq5PB084XLES5TpOSRA3wjXAlIWMhum2kRcv/41Sn2emQ0dycQW4uZXLejwKvg6EsvbdlVL+FYEct7A==", - "dev": true - }, - "yocto-queue": { - "version": "0.1.0", - "resolved": "https://registry.npmjs.org/yocto-queue/-/yocto-queue-0.1.0.tgz", - "integrity": "sha512-rVksvsnNCdJ/ohGc6xgPwyN8eheCxsiLM8mxuE/t/mOVqJewPuO1miLpTHQiRgTKCLexL4MeAFVagts7HmNZ2Q==", - "dev": true } } } diff --git a/datafusion/wasmtest/datafusion-wasm-app/package.json b/datafusion/wasmtest/datafusion-wasm-app/package.json index 0860473276ea..5a2262400cfd 100644 --- a/datafusion/wasmtest/datafusion-wasm-app/package.json +++ b/datafusion/wasmtest/datafusion-wasm-app/package.json @@ -30,6 +30,6 @@ "webpack": "5.94.0", "webpack-cli": "5.1.4", "webpack-dev-server": "4.15.1", - "copy-webpack-plugin": "6.4.1" + "copy-webpack-plugin": "12.0.2" } } diff --git a/dev/update_function_docs.sh b/dev/update_function_docs.sh index 205ab41984a5..a9e87aacf5ad 100755 --- a/dev/update_function_docs.sh +++ b/dev/update_function_docs.sh @@ -236,7 +236,7 @@ WINDOW w AS (PARTITION BY depname ORDER BY salary DESC); The syntax for the OVER-clause is -``` +```sql function([expr]) OVER( [PARTITION BY expr[, …]] @@ -247,7 +247,7 @@ function([expr]) where **frame_clause** is one of: -``` +```sql { RANGE | ROWS | GROUPS } frame_start { RANGE | ROWS | GROUPS } BETWEEN frame_start AND frame_end ``` diff --git a/docs/build.sh b/docs/build.sh index 3fdcd0327024..14464fab40ea 100755 --- a/docs/build.sh +++ b/docs/build.sh @@ -25,4 +25,7 @@ mkdir temp cp -rf source/* temp/ # replace relative URLs with absolute URLs sed -i -e 's/\.\.\/\.\.\/\.\.\//https:\/\/github.com\/apache\/arrow-datafusion\/blob\/main\//g' temp/contributor-guide/index.md + +python rustdoc_trim.py + make SOURCEDIR=`pwd`/temp html diff --git a/docs/rustdoc_trim.py b/docs/rustdoc_trim.py new file mode 100644 index 000000000000..7ea96dbb44a5 --- /dev/null +++ b/docs/rustdoc_trim.py @@ -0,0 +1,75 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import re + +from pathlib import Path + +# Regex pattern to match Rust code blocks in Markdown +RUST_CODE_BLOCK_PATTERN = re.compile(r"```rust\s*(.*?)```", re.DOTALL) + + +def remove_hashtag_lines_in_rust_blocks(markdown_content): + """ + Removes lines starting with '# ' in Rust code blocks within a Markdown string. 
+ """ + + def _process_code_block(match): + # Extract the code block content + code_block_content = match.group(1).strip() + + # Remove lines starting with '#' + modified_code_block = "\n".join( + line + for line in code_block_content.splitlines() + if (not line.lstrip().startswith("# ")) and line.strip() != "#" + ) + + # Return the modified code block wrapped in triple backticks + return f"```rust\n{modified_code_block}\n```" + + # Replace all Rust code blocks using the _process_code_block function + return RUST_CODE_BLOCK_PATTERN.sub(_process_code_block, markdown_content) + + +# Example usage +def process_markdown_file(file_path): + # Read the Markdown file + with open(file_path, "r", encoding="utf-8") as file: + markdown_content = file.read() + + # Remove lines starting with '#' in Rust code blocks + updated_markdown_content = remove_hashtag_lines_in_rust_blocks(markdown_content) + + # Write the updated content back to the Markdown file + with open(file_path, "w", encoding="utf-8") as file: + file.write(updated_markdown_content) + + print(f"Done processing file: {file_path}") + + +root_directory = Path("./temp/library-user-guide") +for file_path in root_directory.rglob("*.md"): + print(f"Processing file: {file_path}") + process_markdown_file(file_path) + +root_directory = Path("./temp/user-guide") +for file_path in root_directory.rglob("*.md"): + print(f"Processing file: {file_path}") + process_markdown_file(file_path) + +print("All Markdown files processed.") diff --git a/docs/source/contributor-guide/gsoc_application_guidelines.md b/docs/source/contributor-guide/gsoc_application_guidelines.md new file mode 100644 index 000000000000..fddd0b5e1805 --- /dev/null +++ b/docs/source/contributor-guide/gsoc_application_guidelines.md @@ -0,0 +1,105 @@ +# GSoC Application Guidelines + +## Introduction + +Welcome to the Apache DataFusion Google Summer of Code (GSoC) application guidelines. We are excited to support contributors who are passionate about open-source data processing technologies and eager to contribute to DataFusion. This document provides detailed instructions on how to apply, what we expect from applicants, and how you can increase your chances of selection. + +## Why Contribute to Apache DataFusion? + +Apache DataFusion is a high-performance, extensible query engine for data processing, written in Rust and designed for modern analytical workloads. GSoC offers a fantastic opportunity for students and early-career developers to work with experienced mentors, learn about open-source development, and make meaningful contributions. + +## Prerequisites + +Before applying, ensure you: + +- Have read and understood the [Apache DataFusion Contributor Guide](https://datafusion.apache.org/contributor-guide/index.html). +- Have basic familiarity with Rust programming and SQL-based data processing. +- Have explored DataFusion’s GitHub repository and tried running sample queries. +- Have introduced yourself on our mailing list or Discord to discuss project ideas with potential mentors. + +## Application Process + +To apply, follow these steps: + +1. **Choose a Project Idea** + - Review the list of proposed GSoC projects for Apache DataFusion. + - If you have your own project idea, discuss it with potential mentors before submitting your proposal. +2. **Engage with the Community** + - Join our [mailing list](mailto:dev@datafusion.apache.org) and [Discord](https://discord.gg/Q9eh6S2T) to introduce yourself and ask questions. 
+  - Optional: Submit a small pull request (PR) for an issue marked with the **good first issue** tag to understand/test whether you enjoy working on Apache DataFusion, get comfortable with navigating the codebase and demonstrate your ability.
+3. **Write a Clear Proposal**
+   - You can use the template below to structure your proposal.
+   - Ensure it has sufficient detail and is feasible.
+   - Seek feedback from mentors before submission.
+
+## Application Template
+
+```
+# Apache DataFusion GSoC Application
+
+## Personal Information
+
+- **Name:**
+- **GitHub ID:**
+- **Email:**
+- **LinkedIn/Personal Website (if any):**
+- **Time Zone & Available Hours Per Week:**
+
+## Project Proposal
+
+### Title
+
+Provide a concise and descriptive project title.
+
+### Synopsis
+
+Summarize the project in a few sentences. What problem does it solve? Why is it important? If you choose an idea proposed by us, this can simply be a summary of your research on the problem and/or your understanding of it.
+
+### Benefits to the Community
+
+Explain how this project will improve Apache DataFusion and its ecosystem. If you choose an idea proposed by us, this can simply be a summary of your understanding of potential benefits.
+
+### Deliverables & Milestones
+
+Consult with project mentors to come up with a rough roadmap for what you plan to accomplish, ensuring it aligns with GSoC’s timeline.
+
+### Technical Details
+
+Discuss the technologies, tools, and methodologies you plan to use. Mention any potential challenges and how you plan to address them.
+
+### Related Work & References
+
+List any relevant research, documentation, or prior work that informs your proposal.
+
+## Personal Experience
+
+### Relevant Skills & Background
+
+Describe your experience with Rust, databases, and open-source contributions.
+
+### Past Open-Source Contributions
+
+List any prior contributions (links to PRs, issues, repositories).
+
+### Learning Plan
+
+Explain how you will learn new skills required for this project.
+
+## Mentor & Communication
+
+- **Preferred Communication Channels:** (Email, Discord, etc.)
+- **Weekly Progress Updates Plan:** Describe how you plan to remain in sync with your mentor(s).
+
+## Additional Information
+
+Add anything else you believe strengthens your application.
+
+```
+
+## Final Steps
+
+- Review your proposal for clarity and completeness.
+- Submit your proposal via the GSoC portal before the deadline.
+- Stay active in the community and be ready to discuss your application with mentors.
+
+We look forward to your application and your contributions to Apache DataFusion!
diff --git a/docs/source/contributor-guide/gsoc_project_ideas.md b/docs/source/contributor-guide/gsoc_project_ideas.md
new file mode 100644
index 000000000000..3feaba559a48
--- /dev/null
+++ b/docs/source/contributor-guide/gsoc_project_ideas.md
@@ -0,0 +1,112 @@
+# GSoC Project Ideas
+
+## Introduction
+
+Welcome to the Apache DataFusion Google Summer of Code (GSoC) 2025 project ideas list. Below you can find information about the projects. Please refer to [this page](https://datafusion.apache.org/contributor-guide/gsoc_application_guidelines.html) for application guidelines.
+
+## Projects
+
+### [Implement Continuous Monitoring of DataFusion Performance](https://github.com/apache/datafusion/issues/5504)
+
+- **Description and Outcomes:** DataFusion lacks continuous monitoring of how performance evolves over time -- we do this somewhat manually today. Even though performance has been one of our top priorities for a while now, we have not yet built a continuous monitoring system. The linked issue contains a summary of all the previous efforts that made us inch closer to having such a system, but a functioning system needs to be built on top of that progress. A student successfully completing this project would gain experience in building an end-to-end monitoring system that integrates with GitHub, scheduling/running benchmarks on some sort of cloud infrastructure, and building a versatile web UI to expose the results. The outcome of this project will benefit Apache DataFusion on an ongoing basis in its quest for ever-more performance.
+- **Category:** Tooling
+- **Difficulty:** Medium
+- **Possible Mentor(s) and/or Helper(s):** [alamb](https://github.com/alamb) and [mertak-synnada](https://github.com/mertak-synnada)
+- **Skills:** DevOps, Cloud Computing, Web Development, Integrations
+- **Expected Project Size:** 175 to 350 hours\*
+
+### [Supporting Correlated Subqueries](https://github.com/apache/datafusion/issues/5483)
+
+- **Description and Outcomes:** Correlated subqueries are an important SQL feature that enables some users to express their business logic more intuitively without thinking about "joins". Even though DataFusion has decent join support, it doesn't fully support correlated subqueries. The linked epic contains bite-size pieces of the steps necessary to achieve full support. For students interested in the internals of data systems and databases, this project is a good opportunity to apply and/or improve their computer science knowledge. The experience of adding such a feature to a widely-used foundational query engine can also serve as a good opportunity to kickstart a career in the area of databases and data systems.
+- **Category:** Core
+- **Difficulty:** Advanced
+- **Possible Mentor(s) and/or Helper(s):** [jayzhan-synnada](https://github.com/jayzhan-synnada) and [xudong963](https://github.com/xudong963)
+- **Skills:** Databases, Algorithms, Data Structures, Testing Techniques
+- **Expected Project Size:** 350 hours
+
+### Improving DataFusion DX (e.g. [1](https://github.com/apache/datafusion/issues/9371) and [2](https://github.com/apache/datafusion/issues/14429))
+
+- **Description and Outcomes:** While performance, extensibility and customizability are DataFusion's strong aspects, we have much work to do in terms of user-friendliness and ease of debuggability. This project aims to make strides in these areas by improving terminal visualizations of query plans and increasing the "deployment" of the newly-added diagnostics framework. This is a potential high-impact project with highly visible output, and it would reduce the barrier to entry for new users.
+- **Category:** DX
+- **Difficulty:** Medium
+- **Possible Mentor(s) and/or Helper(s):** [eliaperantoni](https://github.com/eliaperantoni) and [mkarbo](https://github.com/mkarbo)
+- **Skills:** Software Engineering, Terminal Visualizations
+- **Expected Project Size:** 175 to 350 hours\*
+
+### [Robust WASM Support](https://github.com/apache/datafusion/issues/13815)
+
+- **Description and Outcomes:** DataFusion can be compiled to WASM today with some care. However, it is somewhat tricky and brittle. Having robust WASM support improves the _embeddability_ aspect of DataFusion, and can enable many practical use cases. A good conclusion of this project would be the addition of a live demo sub-page to the DataFusion homepage.
+- **Category:** Build
+- **Difficulty:** Medium
+- **Possible Mentor(s) and/or Helper(s):** [alamb](https://github.com/alamb) and [waynexia](https://github.com/waynexia)
+- **Skills:** WASM, Advanced Rust, Web Development, Software Engineering
+- **Expected Project Size:** 175 to 350 hours\*
+
+### [High Performance Aggregations](https://github.com/apache/datafusion/issues/7000)
+
+- **Description and Outcomes:** An aggregation is one of the most fundamental operations within a query engine. Practical performance in many use cases, and results in many well-known benchmarks (e.g. [ClickBench](https://benchmark.clickhouse.com/)), depend heavily on aggregation performance. The DataFusion community has been working on improving aggregation performance for a while now, but there is still work to do. A student working on this project will get the chance to hone their skills on high-performance, low(ish) level coding, the intricacies of measuring performance, data structures and others.
+- **Category:** Core
+- **Difficulty:** Advanced
+- **Possible Mentor(s) and/or Helper(s):** [jayzhan-synnada](https://github.com/jayzhan-synnada) and [Rachelint](https://github.com/Rachelint)
+- **Skills:** Algorithms, Data Structures, Advanced Rust, Databases, Benchmarking Techniques
+- **Expected Project Size:** 350 hours
+
+### [Improving Python Bindings](https://github.com/apache/datafusion-python)
+
+- **Description and Outcomes:** DataFusion offers Python bindings that enable users to build data systems using Python. However, the Python bindings are still relatively low-level, and do not expose all of the APIs that end-user-focused libraries like [Pandas](https://pandas.pydata.org/) and [Polars](https://pola.rs/) offer. This project aims to improve DataFusion's Python bindings to make progress towards moving it closer to such libraries in terms of built-in APIs and functionality.
+- **Category:** Python Bindings
+- **Difficulty:** Medium
+- **Possible Mentor(s) and/or Helper(s):** [timsaucer](https://github.com/timsaucer)
+- **Skills:** APIs, FFIs, DataFrame Libraries
+- **Expected Project Size:** 175 to 350 hours\*
+
+### [Optimizing DataFusion Binary Size](https://github.com/apache/datafusion/issues/13816)
+
+- **Description and Outcomes:** DataFusion is a foundational library with a large feature set. Even though we try to avoid adding too many dependencies and implement many low-level functionalities inside the codebase, the fast-moving nature of the project results in an accumulation of dependencies over time. This inflates DataFusion's binary size, which reduces portability and embeddability. This project involves a study of the codebase, using compiler tooling, to understand where code bloat comes from, simplifying/reducing the number of dependencies by efficient in-house implementations, and avoiding code duplications.
+- **Category:** Core/Build
+- **Difficulty:** Medium
+- **Possible Mentor(s) and/or Helper(s):** [comphead](https://github.com/comphead) and [alamb](https://github.com/alamb)
+- **Skills:** Software Engineering, Refactoring, Dependency Management, Compilers
+- **Expected Project Size:** 175 to 350 hours\*
+
+### [Ergonomic SQL Features](https://github.com/apache/datafusion/issues/14514)
+
+- **Description and Outcomes:** [DuckDB](https://duckdb.org/) has many innovative features that significantly improve the SQL UX. Even though some of those features are already implemented in DataFusion, there are many others we can implement (and get inspiration from). [This page](https://duckdb.org/docs/sql/dialect/friendly_sql.html) contains a good summary of such features. Each such feature will serve as a bite-size, achievable milestone for a cool GSoC project with broad, user-facing impact on the UX. The project will start with a survey of what is already implemented and what is missing, and kick off with a prioritization proposal/implementation plan.
+- **Category:** SQL FE
+- **Difficulty:** Medium
+- **Possible Mentor(s) and/or Helper(s):** [berkaysynnada](https://github.com/berkaysynnada)
+- **Skills:** SQL, Planning, Parsing, Software Engineering
+- **Expected Project Size:** 350 hours
+
+### [Advanced Interval Analysis](https://github.com/apache/datafusion/issues/14515)
+
+- **Description and Outcomes:** DataFusion implements interval arithmetic and utilizes it for range estimations, which enables use cases in data pruning, optimizations and statistics. However, the current implementation only works efficiently for forward evaluation; i.e. calculating the output range of an expression given input ranges (ranges of columns). When propagating constraints using the same graph, the current approach requires multiple bottom-up and top-down traversals to narrow column bounds fully. This project aims to fix this deficiency by utilizing a better algorithmic approach. Note that this is a _very advanced_ project for students with a deep interest in computational methods, expression graphs, and constraint solvers.
+- **Category:** Core
+- **Difficulty:** Advanced
+- **Possible Mentor(s) and/or Helper(s):** [ozankabak](https://github.com/ozankabak) and [berkaysynnada](https://github.com/berkaysynnada)
+- **Skills:** Algorithms, Data Structures, Applied Mathematics, Software Engineering
+- **Expected Project Size:** 350 hours
+
+### [Spark-Compatible Functions Crate](https://github.com/apache/datafusion/issues/5600)
+
+- **Description and Outcomes:** In general, DataFusion aims to be compatible with PostgreSQL in terms of functions and behaviors. However, there are many users (and downstream projects, such as [DataFusion Comet](https://datafusion.apache.org/comet/)) that desire compatibility with [Apache Spark](https://spark.apache.org/). This project aims to collect Spark-compatible functions into a separate crate to help such users and/or projects. The project will be an exercise in creating the right APIs, explaining how to use them, and then telling the world about them (e.g. via creating a compatibility-tracking page cataloging such functions, writing blog posts etc.).
+- **Category:** Extensions
+- **Difficulty:** Medium
+- **Possible Mentor(s) and/or Helper(s):** [alamb](https://github.com/alamb) and [andygrove](https://github.com/andygrove)
+- **Skills:** SQL, Spark, Software Engineering
+- **Expected Project Size:** 175 to 350 hours\*
+
+### [SQL Fuzzing Framework in Rust](https://github.com/apache/datafusion/issues/14535)
+
+- **Description and Outcomes:** Fuzz testing is a very important technique we utilize often in DataFusion. Having SQL-level fuzz testing enables us to battle-test DataFusion in an end-to-end fashion. The initial version of our fuzzing framework is Java-based, but the time has come to migrate to a Rust-native solution. This will simplify the overall implementation (by avoiding things like JDBC), enable us to implement more advanced algorithms for query generation, and attract more contributors over time. This project is a good blend of software engineering, algorithms and testing techniques (i.e.
fuzzing techniques). +- **Category:** Extensions +- **Difficulty:** Advanced +- **Possible Mentor(s) and/or Helper(s):** [2010YOUY01](https://github.com/2010YOUY01) +- **Skills:** SQL, Testing Techniques, Advanced Rust, Software Engineering +- **Expected Project Size:** 175 to 350 hours\* + +\*_There is enough material to make this a 350-hour project, but it is granular enough to make it a 175-hour project as well._ + +## Contact Us + +You can join our [mailing list](mailto:dev%40datafusion.apache.org) and [Discord](https://discord.gg/Q9eh6S2T) to introduce yourself and ask questions. diff --git a/docs/source/contributor-guide/howtos.md b/docs/source/contributor-guide/howtos.md index e406804caa44..556242751ff4 100644 --- a/docs/source/contributor-guide/howtos.md +++ b/docs/source/contributor-guide/howtos.md @@ -19,6 +19,12 @@ # HOWTOs +## How to update the version of Rust used in CI tests + +- Make a PR to update the [rust-toolchain] file in the root of the repository: + +[rust-toolchain]: https://github.com/apache/datafusion/blob/main/rust-toolchain.toml + ## How to add a new scalar function Below is a checklist of what you need to do to add a new scalar function to DataFusion: diff --git a/docs/source/index.rst b/docs/source/index.rst index 739166782ad6..d9b0c126ab12 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -103,6 +103,7 @@ To get started, see user-guide/introduction user-guide/example-usage + user-guide/features user-guide/concepts-readings-events user-guide/crate-configuration user-guide/cli/index @@ -148,6 +149,8 @@ To get started, see contributor-guide/governance contributor-guide/inviting contributor-guide/specification/index + contributor-guide/gsoc_application_guidelines + contributor-guide/gsoc_project_ideas .. _toc.subprojects: diff --git a/docs/source/library-user-guide/adding-udfs.md b/docs/source/library-user-guide/adding-udfs.md index a9202976973b..a365ef6696a3 100644 --- a/docs/source/library-user-guide/adding-udfs.md +++ b/docs/source/library-user-guide/adding-udfs.md @@ -55,46 +55,62 @@ of arguments. This a lower level API with more functionality but is more complex, also documented in [`advanced_udf.rs`]. 
```rust +use std::sync::Arc; use std::any::Any; +use std::sync::LazyLock; use arrow::datatypes::DataType; +use datafusion_common::cast::as_int64_array; use datafusion_common::{DataFusionError, plan_err, Result}; -use datafusion_expr::{col, ColumnarValue, Signature, Volatility}; +use datafusion_expr::{col, ColumnarValue, ScalarFunctionArgs, Signature, Volatility}; +use datafusion::arrow::array::{ArrayRef, Int64Array}; use datafusion_expr::{ScalarUDFImpl, ScalarUDF}; - +use datafusion_macros::user_doc; +use datafusion_doc::Documentation; + +/// This struct for a simple UDF that adds one to an int32 +#[user_doc( + doc_section(label = "Math Functions"), + description = "Add one udf", + syntax_example = "add_one(1)" +)] #[derive(Debug)] struct AddOne { - signature: Signature -}; + signature: Signature, +} impl AddOne { - fn new() -> Self { - Self { - signature: Signature::uniform(1, vec![DataType::Int32], Volatility::Immutable) - } - } + fn new() -> Self { + Self { + signature: Signature::uniform(1, vec![DataType::Int32], Volatility::Immutable), + } + } } /// Implement the ScalarUDFImpl trait for AddOne impl ScalarUDFImpl for AddOne { - fn as_any(&self) -> &dyn Any { self } - fn name(&self) -> &str { "add_one" } - fn signature(&self) -> &Signature { &self.signature } - fn return_type(&self, args: &[DataType]) -> Result { - if !matches!(args.get(0), Some(&DataType::Int32)) { - return plan_err!("add_one only accepts Int32 arguments"); - } - Ok(DataType::Int32) - } - // The actual implementation would add one to the argument - fn invoke_batch(&self, args: &[ColumnarValue], _number_rows: usize) -> Result { - let args = columnar_values_to_array(args)?; + fn as_any(&self) -> &dyn Any { self } + fn name(&self) -> &str { "add_one" } + fn signature(&self) -> &Signature { &self.signature } + fn return_type(&self, args: &[DataType]) -> Result { + if !matches!(args.get(0), Some(&DataType::Int32)) { + return plan_err!("add_one only accepts Int32 arguments"); + } + Ok(DataType::Int32) + } + // The actual implementation would add one to the argument + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + let args = ColumnarValue::values_to_arrays(&args.args)?; let i64s = as_int64_array(&args[0])?; let new_array = i64s .iter() .map(|array_elem| array_elem.map(|value| value + 1)) .collect::(); - Ok(Arc::new(new_array)) + + Ok(ColumnarValue::from(Arc::new(new_array) as ArrayRef)) + } + fn documentation(&self) -> Option<&Documentation> { + self.doc() } } ``` @@ -102,15 +118,75 @@ impl ScalarUDFImpl for AddOne { We now need to register the function with DataFusion so that it can be used in the context of a query. 
```rust +# use std::sync::Arc; +# use std::any::Any; +# use std::sync::LazyLock; +# use arrow::datatypes::DataType; +# use datafusion_common::cast::as_int64_array; +# use datafusion_common::{DataFusionError, plan_err, Result}; +# use datafusion_expr::{col, ColumnarValue, ScalarFunctionArgs, Signature, Volatility}; +# use datafusion::arrow::array::{ArrayRef, Int64Array}; +# use datafusion_expr::{ScalarUDFImpl, ScalarUDF}; +# use datafusion_macros::user_doc; +# use datafusion_doc::Documentation; +# +# /// This struct for a simple UDF that adds one to an int32 +# #[user_doc( +# doc_section(label = "Math Functions"), +# description = "Add one udf", +# syntax_example = "add_one(1)" +# )] +# #[derive(Debug)] +# struct AddOne { +# signature: Signature, +# } +# +# impl AddOne { +# fn new() -> Self { +# Self { +# signature: Signature::uniform(1, vec![DataType::Int32], Volatility::Immutable), +# } +# } +# } +# +# /// Implement the ScalarUDFImpl trait for AddOne +# impl ScalarUDFImpl for AddOne { +# fn as_any(&self) -> &dyn Any { self } +# fn name(&self) -> &str { "add_one" } +# fn signature(&self) -> &Signature { &self.signature } +# fn return_type(&self, args: &[DataType]) -> Result { +# if !matches!(args.get(0), Some(&DataType::Int32)) { +# return plan_err!("add_one only accepts Int32 arguments"); +# } +# Ok(DataType::Int32) +# } +# // The actual implementation would add one to the argument +# fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { +# let args = ColumnarValue::values_to_arrays(&args.args)?; +# let i64s = as_int64_array(&args[0])?; +# +# let new_array = i64s +# .iter() +# .map(|array_elem| array_elem.map(|value| value + 1)) +# .collect::(); +# +# Ok(ColumnarValue::from(Arc::new(new_array) as ArrayRef)) +# } +# fn documentation(&self) -> Option<&Documentation> { +# self.doc() +# } +# } +use datafusion::execution::context::SessionContext; + // Create a new ScalarUDF from the implementation let add_one = ScalarUDF::from(AddOne::new()); +// Call the function `add_one(col)` +let expr = add_one.call(vec![col("a")]); + // register the UDF with the context so it can be invoked by name and from SQL let mut ctx = SessionContext::new(); ctx.register_udf(add_one.clone()); - -// Call the function `add_one(col)` -let expr = add_one.call(vec![col("a")]); ``` ### Adding a Scalar UDF by [`create_udf`] @@ -121,7 +197,6 @@ There is a an older, more concise, but also more limited API [`create_udf`] avai ```rust use std::sync::Arc; - use datafusion::arrow::array::{ArrayRef, Int64Array}; use datafusion::common::cast::as_int64_array; use datafusion::common::Result; @@ -145,6 +220,24 @@ This "works" in isolation, i.e. if you have a slice of `ArrayRef`s, you can call `ArrayRef` with 1 added to each value. ```rust +# use std::sync::Arc; +# use datafusion::arrow::array::{ArrayRef, Int64Array}; +# use datafusion::common::cast::as_int64_array; +# use datafusion::common::Result; +# use datafusion::logical_expr::ColumnarValue; +# +# pub fn add_one(args: &[ColumnarValue]) -> Result { +# // Error handling omitted for brevity +# let args = ColumnarValue::values_to_arrays(args)?; +# let i64s = as_int64_array(&args[0])?; +# +# let new_array = i64s +# .iter() +# .map(|array_elem| array_elem.map(|value| value + 1)) +# .collect::(); +# +# Ok(ColumnarValue::from(Arc::new(new_array) as ArrayRef)) +# } let input = vec![Some(1), None, Some(3)]; let input = ColumnarValue::from(Arc::new(Int64Array::from(input)) as ArrayRef); @@ -165,9 +258,26 @@ with the `SessionContext`. 
DataFusion provides the [`create_udf`] and helper functions to make this easier. ```rust +# use std::sync::Arc; +# use datafusion::arrow::array::{ArrayRef, Int64Array}; +# use datafusion::common::cast::as_int64_array; +# use datafusion::common::Result; +# use datafusion::logical_expr::ColumnarValue; +# +# pub fn add_one(args: &[ColumnarValue]) -> Result { +# // Error handling omitted for brevity +# let args = ColumnarValue::values_to_arrays(args)?; +# let i64s = as_int64_array(&args[0])?; +# +# let new_array = i64s +# .iter() +# .map(|array_elem| array_elem.map(|value| value + 1)) +# .collect::(); +# +# Ok(ColumnarValue::from(Arc::new(new_array) as ArrayRef)) +# } use datafusion::logical_expr::{Volatility, create_udf}; use datafusion::arrow::datatypes::DataType; -use std::sync::Arc; let udf = create_udf( "add_one", @@ -178,12 +288,7 @@ let udf = create_udf( ); ``` -[`scalarudf`]: https://docs.rs/datafusion/latest/datafusion/logical_expr/struct.ScalarUDF.html -[`create_udf`]: https://docs.rs/datafusion/latest/datafusion/logical_expr/fn.create_udf.html -[`process_scalar_func_inputs`]: https://docs.rs/datafusion/latest/datafusion/physical_expr/functions/fn.process_scalar_func_inputs.html -[`advanced_udf.rs`]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/advanced_udf.rs - -A few things to note: +A few things to note on `create_udf`: - The first argument is the name of the function. This is the name that will be used in SQL queries. - The second argument is a vector of `DataType`s. This is the list of argument types that the function accepts. I.e. in @@ -198,20 +303,51 @@ A few things to note: That gives us a `ScalarUDF` that we can register with the `SessionContext`: ```rust +# use std::sync::Arc; +# use datafusion::arrow::array::{ArrayRef, Int64Array}; +# use datafusion::common::cast::as_int64_array; +# use datafusion::common::Result; +# use datafusion::logical_expr::ColumnarValue; +# +# pub fn add_one(args: &[ColumnarValue]) -> Result { +# // Error handling omitted for brevity +# let args = ColumnarValue::values_to_arrays(args)?; +# let i64s = as_int64_array(&args[0])?; +# +# let new_array = i64s +# .iter() +# .map(|array_elem| array_elem.map(|value| value + 1)) +# .collect::(); +# +# Ok(ColumnarValue::from(Arc::new(new_array) as ArrayRef)) +# } +use datafusion::logical_expr::{Volatility, create_udf}; +use datafusion::arrow::datatypes::DataType; use datafusion::execution::context::SessionContext; -let mut ctx = SessionContext::new(); - -ctx.register_udf(udf); +#[tokio::main] +async fn main() { + let udf = create_udf( + "add_one", + vec![DataType::Int64], + DataType::Int64, + Volatility::Immutable, + Arc::new(add_one), + ); + + let mut ctx = SessionContext::new(); + ctx.register_udf(udf); + + // At this point, you can use the `add_one` function in your query: + let query = "SELECT add_one(1)"; + let df = ctx.sql(&query).await.unwrap(); +} ``` -At this point, you can use the `add_one` function in your query: - -```rust -let sql = "SELECT add_one(1)"; - -let df = ctx.sql( & sql).await.unwrap(); -``` +[`scalarudf`]: https://docs.rs/datafusion/latest/datafusion/logical_expr/struct.ScalarUDF.html +[`create_udf`]: https://docs.rs/datafusion/latest/datafusion/logical_expr/fn.create_udf.html +[`process_scalar_func_inputs`]: https://docs.rs/datafusion/latest/datafusion/physical_expr/functions/fn.process_scalar_func_inputs.html +[`advanced_udf.rs`]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/advanced_udf.rs ## Adding a Window UDF @@ 
-294,17 +430,61 @@ with the `SessionContext`. DataFusion provides the [`create_udwf`] helper functi There is a lower level API with more functionality but is more complex, that is documented in [`advanced_udwf.rs`]. ```rust +# use datafusion::arrow::{array::{ArrayRef, Float64Array, AsArray}, datatypes::Float64Type}; +# use datafusion::logical_expr::{PartitionEvaluator}; +# use datafusion::common::ScalarValue; +# use datafusion::error::Result; +# +# #[derive(Clone, Debug)] +# struct MyPartitionEvaluator {} +# +# impl MyPartitionEvaluator { +# fn new() -> Self { +# Self {} +# } +# } +# +# impl PartitionEvaluator for MyPartitionEvaluator { +# fn uses_window_frame(&self) -> bool { +# true +# } +# +# fn evaluate( +# &mut self, +# values: &[ArrayRef], +# range: &std::ops::Range, +# ) -> Result { +# // Again, the input argument is an array of floating +# // point numbers to calculate a moving average +# let arr: &Float64Array = values[0].as_ref().as_primitive::(); +# +# let range_len = range.end - range.start; +# +# // our smoothing function will average all the values in the +# let output = if range_len > 0 { +# let sum: f64 = arr.values().iter().skip(range.start).take(range_len).sum(); +# Some(sum / range_len as f64) +# } else { +# None +# }; +# +# Ok(ScalarValue::Float64(output)) +# } +# } +# fn make_partition_evaluator() -> Result> { +# Ok(Box::new(MyPartitionEvaluator::new())) +# } use datafusion::logical_expr::{Volatility, create_udwf}; use datafusion::arrow::datatypes::DataType; use std::sync::Arc; // here is where we define the UDWF. We also declare its signature: let smooth_it = create_udwf( -"smooth_it", -DataType::Float64, -Arc::new(DataType::Float64), -Volatility::Immutable, -Arc::new(make_partition_evaluator), + "smooth_it", + DataType::Float64, + Arc::new(DataType::Float64), + Volatility::Immutable, + Arc::new(make_partition_evaluator), ); ``` @@ -327,6 +507,62 @@ The `create_udwf` has five arguments to check: That gives us a `WindowUDF` that we can register with the `SessionContext`: ```rust +# use datafusion::arrow::{array::{ArrayRef, Float64Array, AsArray}, datatypes::Float64Type}; +# use datafusion::logical_expr::{PartitionEvaluator}; +# use datafusion::common::ScalarValue; +# use datafusion::error::Result; +# +# #[derive(Clone, Debug)] +# struct MyPartitionEvaluator {} +# +# impl MyPartitionEvaluator { +# fn new() -> Self { +# Self {} +# } +# } +# +# impl PartitionEvaluator for MyPartitionEvaluator { +# fn uses_window_frame(&self) -> bool { +# true +# } +# +# fn evaluate( +# &mut self, +# values: &[ArrayRef], +# range: &std::ops::Range, +# ) -> Result { +# // Again, the input argument is an array of floating +# // point numbers to calculate a moving average +# let arr: &Float64Array = values[0].as_ref().as_primitive::(); +# +# let range_len = range.end - range.start; +# +# // our smoothing function will average all the values in the +# let output = if range_len > 0 { +# let sum: f64 = arr.values().iter().skip(range.start).take(range_len).sum(); +# Some(sum / range_len as f64) +# } else { +# None +# }; +# +# Ok(ScalarValue::Float64(output)) +# } +# } +# fn make_partition_evaluator() -> Result> { +# Ok(Box::new(MyPartitionEvaluator::new())) +# } +# use datafusion::logical_expr::{Volatility, create_udwf}; +# use datafusion::arrow::datatypes::DataType; +# use std::sync::Arc; +# +# // here is where we define the UDWF. 
We also declare its signature: +# let smooth_it = create_udwf( +# "smooth_it", +# DataType::Float64, +# Arc::new(DataType::Float64), +# Volatility::Immutable, +# Arc::new(make_partition_evaluator), +# ); use datafusion::execution::context::SessionContext; let ctx = SessionContext::new(); @@ -336,10 +572,9 @@ ctx.register_udwf(smooth_it); At this point, you can use the `smooth_it` function in your query: -For example, if we have a [ -`cars.csv`](https://github.com/apache/datafusion/blob/main/datafusion/core/tests/data/cars.csv) whose contents like +For example, if we have a [`cars.csv`](https://github.com/apache/datafusion/blob/main/datafusion/core/tests/data/cars.csv) whose contents like -``` +```csv car,speed,time red,20.0,1996-04-12T12:05:03.000000000 red,20.3,1996-04-12T12:05:04.000000000 @@ -351,30 +586,97 @@ green,10.3,1996-04-12T12:05:04.000000000 Then, we can query like below: ```rust +# use datafusion::arrow::{array::{ArrayRef, Float64Array, AsArray}, datatypes::Float64Type}; +# use datafusion::logical_expr::{PartitionEvaluator}; +# use datafusion::common::ScalarValue; +# use datafusion::error::Result; +# +# #[derive(Clone, Debug)] +# struct MyPartitionEvaluator {} +# +# impl MyPartitionEvaluator { +# fn new() -> Self { +# Self {} +# } +# } +# +# impl PartitionEvaluator for MyPartitionEvaluator { +# fn uses_window_frame(&self) -> bool { +# true +# } +# +# fn evaluate( +# &mut self, +# values: &[ArrayRef], +# range: &std::ops::Range, +# ) -> Result { +# // Again, the input argument is an array of floating +# // point numbers to calculate a moving average +# let arr: &Float64Array = values[0].as_ref().as_primitive::(); +# +# let range_len = range.end - range.start; +# +# // our smoothing function will average all the values in the +# let output = if range_len > 0 { +# let sum: f64 = arr.values().iter().skip(range.start).take(range_len).sum(); +# Some(sum / range_len as f64) +# } else { +# None +# }; +# +# Ok(ScalarValue::Float64(output)) +# } +# } +# fn make_partition_evaluator() -> Result> { +# Ok(Box::new(MyPartitionEvaluator::new())) +# } +# use datafusion::logical_expr::{Volatility, create_udwf}; +# use datafusion::arrow::datatypes::DataType; +# use std::sync::Arc; +# use datafusion::execution::context::SessionContext; + use datafusion::datasource::file_format::options::CsvReadOptions; -// register csv table first -let csv_path = "cars.csv".to_string(); -ctx.register_csv("cars", & csv_path, CsvReadOptions::default ().has_header(true)).await?; -// do query with smooth_it -let df = ctx -.sql( -"SELECT \ - car, \ - speed, \ - smooth_it(speed) OVER (PARTITION BY car ORDER BY time) as smooth_speed,\ - time \ - from cars \ - ORDER BY \ - car", -) -.await?; -// print the results -df.show().await?; + +#[tokio::main] +async fn main() -> Result<()> { + + let ctx = SessionContext::new(); + + let smooth_it = create_udwf( + "smooth_it", + DataType::Float64, + Arc::new(DataType::Float64), + Volatility::Immutable, + Arc::new(make_partition_evaluator), + ); + ctx.register_udwf(smooth_it); + + // register csv table first + let csv_path = "../../datafusion/core/tests/data/cars.csv".to_string(); + ctx.register_csv("cars", &csv_path, CsvReadOptions::default().has_header(true)).await?; + + // do query with smooth_it + let df = ctx + .sql(r#" + SELECT + car, + speed, + smooth_it(speed) OVER (PARTITION BY car ORDER BY time) as smooth_speed, + time + FROM cars + ORDER BY car + "#) + .await?; + + // print the results + df.show().await?; + Ok(()) +} ``` -the output will be like: +The output will be like: 
-``` +```text +-------+-------+--------------------+---------------------+ | car | speed | smooth_speed | time | +-------+-------+--------------------+---------------------+ @@ -403,6 +705,7 @@ Aggregate UDFs are functions that take a group of rows and return a single value For example, we will declare a single-type, single return type UDAF that computes the geometric mean. ```rust + use datafusion::arrow::array::ArrayRef; use datafusion::scalar::ScalarValue; use datafusion::{error::Result, physical_plan::Accumulator}; @@ -427,7 +730,7 @@ impl Accumulator for GeometricMean { // This function serializes our state to `ScalarValue`, which DataFusion uses // to pass this state between execution stages. // Note that this can be arbitrary data. - fn state(&self) -> Result> { + fn state(&mut self) -> Result> { Ok(vec![ ScalarValue::from(self.prod), ScalarValue::from(self.n), @@ -436,7 +739,7 @@ impl Accumulator for GeometricMean { // DataFusion expects this function to return the final value of this aggregator. // in this case, this is the formula of the geometric mean - fn evaluate(&self) -> Result { + fn evaluate(&mut self) -> Result { let value = self.prod.powf(1.0 / self.n as f64); Ok(ScalarValue::from(value)) } @@ -491,37 +794,106 @@ impl Accumulator for GeometricMean { } ``` -### registering an Aggregate UDF +### Registering an Aggregate UDF To register a Aggregate UDF, you need to wrap the function implementation in a [`AggregateUDF`] struct and then register it with the `SessionContext`. DataFusion provides the [`create_udaf`] helper functions to make this easier. There is a lower level API with more functionality but is more complex, that is documented in [`advanced_udaf.rs`]. ```rust +# use datafusion::arrow::array::ArrayRef; +# use datafusion::scalar::ScalarValue; +# use datafusion::{error::Result, physical_plan::Accumulator}; +# +# #[derive(Debug)] +# struct GeometricMean { +# n: u32, +# prod: f64, +# } +# +# impl GeometricMean { +# pub fn new() -> Self { +# GeometricMean { n: 0, prod: 1.0 } +# } +# } +# +# impl Accumulator for GeometricMean { +# fn state(&mut self) -> Result> { +# Ok(vec![ +# ScalarValue::from(self.prod), +# ScalarValue::from(self.n), +# ]) +# } +# +# fn evaluate(&mut self) -> Result { +# let value = self.prod.powf(1.0 / self.n as f64); +# Ok(ScalarValue::from(value)) +# } +# +# fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { +# if values.is_empty() { +# return Ok(()); +# } +# let arr = &values[0]; +# (0..arr.len()).try_for_each(|index| { +# let v = ScalarValue::try_from_array(arr, index)?; +# +# if let ScalarValue::Float64(Some(value)) = v { +# self.prod *= value; +# self.n += 1; +# } else { +# unreachable!("") +# } +# Ok(()) +# }) +# } +# +# fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { +# if states.is_empty() { +# return Ok(()); +# } +# let arr = &states[0]; +# (0..arr.len()).try_for_each(|index| { +# let v = states +# .iter() +# .map(|array| ScalarValue::try_from_array(array, index)) +# .collect::>>()?; +# if let (ScalarValue::Float64(Some(prod)), ScalarValue::UInt32(Some(n))) = (&v[0], &v[1]) +# { +# self.prod *= prod; +# self.n += n; +# } else { +# unreachable!("") +# } +# Ok(()) +# }) +# } +# +# fn size(&self) -> usize { +# std::mem::size_of_val(self) +# } +# } + use datafusion::logical_expr::{Volatility, create_udaf}; use datafusion::arrow::datatypes::DataType; use std::sync::Arc; // here is where we define the UDAF. 
We also declare its signature: let geometric_mean = create_udaf( -// the name; used to represent it in plan descriptions and in the registry, to use in SQL. -"geo_mean", -// the input type; DataFusion guarantees that the first entry of `values` in `update` has this type. -vec![DataType::Float64], -// the return type; DataFusion expects this to match the type returned by `evaluate`. -Arc::new(DataType::Float64), -Volatility::Immutable, -// This is the accumulator factory; DataFusion uses it to create new accumulators. -Arc::new( | _ | Ok(Box::new(GeometricMean::new()))), -// This is the description of the state. `state()` must match the types here. -Arc::new(vec![DataType::Float64, DataType::UInt32]), + // the name; used to represent it in plan descriptions and in the registry, to use in SQL. + "geo_mean", + // the input type; DataFusion guarantees that the first entry of `values` in `update` has this type. + vec![DataType::Float64], + // the return type; DataFusion expects this to match the type returned by `evaluate`. + Arc::new(DataType::Float64), + Volatility::Immutable, + // This is the accumulator factory; DataFusion uses it to create new accumulators. + Arc::new( | _ | Ok(Box::new(GeometricMean::new()))), + // This is the description of the state. `state()` must match the types here. + Arc::new(vec![DataType::Float64, DataType::UInt32]), ); ``` -[`aggregateudf`]: https://docs.rs/datafusion/latest/datafusion/logical_expr/struct.AggregateUDF.html -[`create_udaf`]: https://docs.rs/datafusion/latest/datafusion/logical_expr/fn.create_udaf.html -[`advanced_udaf.rs`]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/advanced_udaf.rs - The `create_udaf` has six arguments to check: - The first argument is the name of the function. This is the name that will be used in SQL queries. @@ -535,22 +907,119 @@ The `create_udaf` has six arguments to check: - The fifth argument is the function implementation. This is the function that we defined above. - The sixth argument is the description of the state, which will by passed between execution stages. 
-That gives us a `AggregateUDF` that we can register with the `SessionContext`: - ```rust -use datafusion::execution::context::SessionContext; -let ctx = SessionContext::new(); +# use datafusion::arrow::array::ArrayRef; +# use datafusion::scalar::ScalarValue; +# use datafusion::{error::Result, physical_plan::Accumulator}; +# +# #[derive(Debug)] +# struct GeometricMean { +# n: u32, +# prod: f64, +# } +# +# impl GeometricMean { +# pub fn new() -> Self { +# GeometricMean { n: 0, prod: 1.0 } +# } +# } +# +# impl Accumulator for GeometricMean { +# fn state(&mut self) -> Result> { +# Ok(vec![ +# ScalarValue::from(self.prod), +# ScalarValue::from(self.n), +# ]) +# } +# +# fn evaluate(&mut self) -> Result { +# let value = self.prod.powf(1.0 / self.n as f64); +# Ok(ScalarValue::from(value)) +# } +# +# fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { +# if values.is_empty() { +# return Ok(()); +# } +# let arr = &values[0]; +# (0..arr.len()).try_for_each(|index| { +# let v = ScalarValue::try_from_array(arr, index)?; +# +# if let ScalarValue::Float64(Some(value)) = v { +# self.prod *= value; +# self.n += 1; +# } else { +# unreachable!("") +# } +# Ok(()) +# }) +# } +# +# fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { +# if states.is_empty() { +# return Ok(()); +# } +# let arr = &states[0]; +# (0..arr.len()).try_for_each(|index| { +# let v = states +# .iter() +# .map(|array| ScalarValue::try_from_array(array, index)) +# .collect::>>()?; +# if let (ScalarValue::Float64(Some(prod)), ScalarValue::UInt32(Some(n))) = (&v[0], &v[1]) +# { +# self.prod *= prod; +# self.n += n; +# } else { +# unreachable!("") +# } +# Ok(()) +# }) +# } +# +# fn size(&self) -> usize { +# std::mem::size_of_val(self) +# } +# } -ctx.register_udaf(geometric_mean); -``` +use datafusion::logical_expr::{Volatility, create_udaf}; +use datafusion::arrow::datatypes::DataType; +use std::sync::Arc; +use datafusion::execution::context::SessionContext; +use datafusion::datasource::file_format::options::CsvReadOptions; -Then, we can query like below: +#[tokio::main] +async fn main() -> Result<()> { + let geometric_mean = create_udaf( + "geo_mean", + vec![DataType::Float64], + Arc::new(DataType::Float64), + Volatility::Immutable, + Arc::new( | _ | Ok(Box::new(GeometricMean::new()))), + Arc::new(vec![DataType::Float64, DataType::UInt32]), + ); + + // That gives us a `AggregateUDF` that we can register with the `SessionContext`: + use datafusion::execution::context::SessionContext; + + let ctx = SessionContext::new(); + ctx.register_udaf(geometric_mean); + + // register csv table first + let csv_path = "../../datafusion/core/tests/data/cars.csv".to_string(); + ctx.register_csv("cars", &csv_path, CsvReadOptions::default().has_header(true)).await?; + + // Then, we can query like below: + let df = ctx.sql("SELECT geo_mean(speed) FROM cars").await?; + Ok(()) +} -```rust -let df = ctx.sql("SELECT geo_mean(a) FROM t").await?; ``` +[`aggregateudf`]: https://docs.rs/datafusion/latest/datafusion/logical_expr/struct.AggregateUDF.html +[`create_udaf`]: https://docs.rs/datafusion/latest/datafusion/logical_expr/fn.create_udaf.html +[`advanced_udaf.rs`]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/advanced_udaf.rs + ## Adding a User-Defined Table Function A User-Defined Table Function (UDTF) is a function that takes parameters and returns a `TableProvider`. @@ -592,12 +1061,17 @@ In the `call` method, you parse the input `Expr`s and return a `TableProvider`. validation of the input `Expr`s, e.g. 
checking that the number of arguments is correct. ```rust -use datafusion::common::plan_err; -use datafusion::datasource::function::TableFunctionImpl; -// Other imports here +use std::sync::Arc; +use datafusion::common::{plan_err, ScalarValue, Result}; +use datafusion::catalog::{TableFunctionImpl, TableProvider}; +use datafusion::arrow::array::{ArrayRef, Int64Array}; +use datafusion::datasource::memory::MemTable; +use arrow::record_batch::RecordBatch; +use arrow::datatypes::{DataType, Field, Schema}; +use datafusion_expr::Expr; /// A table function that returns a table provider with the value as a single column -#[derive(Default)] +#[derive(Debug)] pub struct EchoFunction {} impl TableFunctionImpl for EchoFunction { @@ -628,22 +1102,57 @@ impl TableFunctionImpl for EchoFunction { With the UDTF implemented, you can register it with the `SessionContext`: ```rust +# use std::sync::Arc; +# use datafusion::common::{plan_err, ScalarValue, Result}; +# use datafusion::catalog::{TableFunctionImpl, TableProvider}; +# use datafusion::arrow::array::{ArrayRef, Int64Array}; +# use datafusion::datasource::memory::MemTable; +# use arrow::record_batch::RecordBatch; +# use arrow::datatypes::{DataType, Field, Schema}; +# use datafusion_expr::Expr; +# +# /// A table function that returns a table provider with the value as a single column +# #[derive(Debug, Default)] +# pub struct EchoFunction {} +# +# impl TableFunctionImpl for EchoFunction { +# fn call(&self, exprs: &[Expr]) -> Result> { +# let Some(Expr::Literal(ScalarValue::Int64(Some(value)))) = exprs.get(0) else { +# return plan_err!("First argument must be an integer"); +# }; +# +# // Create the schema for the table +# let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, false)])); +# +# // Create a single RecordBatch with the value as a single column +# let batch = RecordBatch::try_new( +# schema.clone(), +# vec![Arc::new(Int64Array::from(vec![*value]))], +# )?; +# +# // Create a MemTable plan that returns the RecordBatch +# let provider = MemTable::try_new(schema, vec![vec![batch]])?; +# +# Ok(Arc::new(provider)) +# } +# } + use datafusion::execution::context::SessionContext; +use datafusion::arrow::util::pretty; -let ctx = SessionContext::new(); +#[tokio::main] +async fn main() -> Result<()> { + let ctx = SessionContext::new(); -ctx.register_udtf("echo", Arc::new(EchoFunction::default ())); -``` + ctx.register_udtf("echo", Arc::new(EchoFunction::default())); -And if all goes well, you can use it in your query: + // And if all goes well, you can use it in your query: -```rust -use datafusion::arrow::util::pretty; - -let df = ctx.sql("SELECT * FROM echo(1)").await?; + let results = ctx.sql("SELECT * FROM echo(1)").await?.collect().await?; + pretty::print_batches(&results)?; + Ok(()) +} -let results = df.collect().await?; -pretty::print_batches( & results) ?; // +---+ // | a | // +---+ diff --git a/docs/source/library-user-guide/api-health.md b/docs/source/library-user-guide/api-health.md index b9c6de370e55..87d3754b21a7 100644 --- a/docs/source/library-user-guide/api-health.md +++ b/docs/source/library-user-guide/api-health.md @@ -62,8 +62,8 @@ To mark the API as deprecated, use the `#[deprecated(since = "...", note = "..." 
For example: ```rust - #[deprecated(since = "41.0.0", note = "Use SessionStateBuilder")] - pub fn new_with_config_rt(config: SessionConfig, runtime: Arc) -> Self +#[deprecated(since = "41.0.0", note = "Use new API instead")] +pub fn api_to_deprecated(a: usize, b: usize) {} ``` Deprecated methods will remain in the codebase for a period of 6 major versions or 6 months, whichever is longer, to provide users ample time to transition away from them. diff --git a/docs/source/library-user-guide/catalogs.md b/docs/source/library-user-guide/catalogs.md index 13158d656423..906039ba2300 100644 --- a/docs/source/library-user-guide/catalogs.md +++ b/docs/source/library-user-guide/catalogs.md @@ -40,6 +40,11 @@ In the following example, we'll implement an in memory catalog, starting with th The `MemorySchemaProvider` is a simple implementation of the `SchemaProvider` trait. It stores state (i.e. tables) in a `DashMap`, which then underlies the `SchemaProvider` trait. ```rust +use std::sync::Arc; +use dashmap::DashMap; +use datafusion::catalog::{TableProvider, SchemaProvider}; + +#[derive(Debug)] pub struct MemorySchemaProvider { tables: DashMap>, } @@ -50,6 +55,20 @@ pub struct MemorySchemaProvider { Then we implement the `SchemaProvider` trait for `MemorySchemaProvider`. ```rust +# use std::sync::Arc; +# use dashmap::DashMap; +# use datafusion::catalog::TableProvider; +# +# #[derive(Debug)] +# pub struct MemorySchemaProvider { +# tables: DashMap>, +# } + +use std::any::Any; +use datafusion::catalog::SchemaProvider; +use async_trait::async_trait; +use datafusion::common::{Result, exec_err}; + #[async_trait] impl SchemaProvider for MemorySchemaProvider { fn as_any(&self) -> &dyn Any { @@ -63,8 +82,8 @@ impl SchemaProvider for MemorySchemaProvider { .collect() } - async fn table(&self, name: &str) -> Option> { - self.tables.get(name).map(|table| table.value().clone()) + async fn table(&self, name: &str) -> Result>> { + Ok(self.tables.get(name).map(|table| table.value().clone())) } fn register_table( @@ -93,12 +112,85 @@ impl SchemaProvider for MemorySchemaProvider { Without getting into a `CatalogProvider` implementation, we can create a `MemorySchemaProvider` and register `TableProvider`s with it. 
```rust +# use std::sync::Arc; +# use dashmap::DashMap; +# use datafusion::catalog::TableProvider; +# +# #[derive(Debug)] +# pub struct MemorySchemaProvider { +# tables: DashMap>, +# } +# +# use std::any::Any; +# use datafusion::catalog::SchemaProvider; +# use async_trait::async_trait; +# use datafusion::common::{Result, exec_err}; +# +# #[async_trait] +# impl SchemaProvider for MemorySchemaProvider { +# fn as_any(&self) -> &dyn Any { +# self +# } +# +# fn table_names(&self) -> Vec { +# self.tables +# .iter() +# .map(|table| table.key().clone()) +# .collect() +# } +# +# async fn table(&self, name: &str) -> Result>> { +# Ok(self.tables.get(name).map(|table| table.value().clone())) +# } +# +# fn register_table( +# &self, +# name: String, +# table: Arc, +# ) -> Result>> { +# if self.table_exist(name.as_str()) { +# return exec_err!( +# "The table {name} already exists" +# ); +# } +# Ok(self.tables.insert(name, table)) +# } +# +# fn deregister_table(&self, name: &str) -> Result>> { +# Ok(self.tables.remove(name).map(|(_, table)| table)) +# } +# +# fn table_exist(&self, name: &str) -> bool { +# self.tables.contains_key(name) +# } +# } + +use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; +use arrow::record_batch::RecordBatch; +use datafusion::datasource::MemTable; +use arrow::array::{self, Array, ArrayRef, Int32Array}; + +impl MemorySchemaProvider { + /// Instantiates a new MemorySchemaProvider with an empty collection of tables. + pub fn new() -> Self { + Self { + tables: DashMap::new(), + } + } +} + let schema_provider = Arc::new(MemorySchemaProvider::new()); -let table_provider = _; // create a table provider -schema_provider.register_table("table_name".to_string(), table_provider); +let table_provider = { + let schema = Arc::new(Schema::new(vec![Field::new("i", DataType::Int32, true)])); + let arr = Arc::new(Int32Array::from((1..=1).collect::>())); + let partitions = vec![vec![RecordBatch::try_new(schema.clone(), vec![arr as ArrayRef]).unwrap()]]; + Arc::new(MemTable::try_new(schema, partitions).unwrap()) +}; + +schema_provider.register_table("users".to_string(), table_provider); -let table = schema_provider.table("table_name").unwrap(); +let table = schema_provider.table("users"); ``` ### Asynchronous `SchemaProvider` @@ -108,27 +200,44 @@ It's often useful to fetch metadata about which tables are in a schema, from a r The trait is roughly the same except for the `table` method, and the addition of the `#[async_trait]` attribute. ```rust +# use async_trait::async_trait; +# use std::sync::Arc; +# use datafusion::catalog::{TableProvider, SchemaProvider}; +# use datafusion::common::Result; +# +# type OriginSchema = arrow::datatypes::Schema; +# +# #[derive(Debug)] +# struct Schema(OriginSchema); + #[async_trait] impl SchemaProvider for Schema { - async fn table(&self, name: &str) -> Option> { - // fetch metadata from remote source + async fn table(&self, name: &str) -> Result>> { +# todo!(); } + +# fn as_any(&self) -> &(dyn std::any::Any + 'static) { todo!() } +# fn table_names(&self) -> Vec { todo!() } +# fn table_exist(&self, _: &str) -> bool { todo!() } } ``` ## Implementing `MemoryCatalogProvider` -As mentioned, the `CatalogProvider` can manage the schemas in a catalog, and the `MemoryCatalogProvider` is a simple implementation of the `CatalogProvider` trait. It stores schemas in a `DashMap`. +As mentioned, the `CatalogProvider` can manage the schemas in a catalog, and the `MemoryCatalogProvider` is a simple implementation of the `CatalogProvider` trait. 
It stores schemas in a `DashMap`. With that the `CatalogProvider` trait can be implemented. ```rust +use std::any::Any; +use std::sync::Arc; +use dashmap::DashMap; +use datafusion::catalog::{CatalogProvider, SchemaProvider}; +use datafusion::common::Result; + +#[derive(Debug)] pub struct MemoryCatalogProvider { schemas: DashMap>, } -``` - -With that the `CatalogProvider` trait can be implemented. -```rust impl CatalogProvider for MemoryCatalogProvider { fn as_any(&self) -> &dyn Any { self @@ -167,20 +276,24 @@ impl CatalogProvider for MemoryCatalogProvider { } ``` -Again, this is fairly straightforward, as there's an underlying data structure to store the state, via key-value pairs. +Again, this is fairly straightforward, as there's an underlying data structure to store the state, via key-value pairs. With that the `CatalogProviderList` trait can be implemented. ## Implementing `MemoryCatalogProviderList` ```rust + +use std::any::Any; +use std::sync::Arc; +use dashmap::DashMap; +use datafusion::catalog::{CatalogProviderList, CatalogProvider}; +use datafusion::common::Result; + +#[derive(Debug)] pub struct MemoryCatalogProviderList { /// Collection of catalogs containing schemas and ultimately TableProviders pub catalogs: DashMap>, } -``` -With that the `CatalogProviderList` trait can be implemented. - -```rust impl CatalogProviderList for MemoryCatalogProviderList { fn as_any(&self) -> &dyn Any { self diff --git a/docs/source/library-user-guide/custom-table-providers.md b/docs/source/library-user-guide/custom-table-providers.md index a7183fb3113e..886ac9629566 100644 --- a/docs/source/library-user-guide/custom-table-providers.md +++ b/docs/source/library-user-guide/custom-table-providers.md @@ -37,19 +37,84 @@ The `ExecutionPlan` trait at its core is a way to get a stream of batches. The a There are many different types of `SendableRecordBatchStream` implemented in DataFusion -- you can use a pre existing one, such as `MemoryStream` (if your `RecordBatch`es are all in memory) or implement your own custom logic, depending on your usecase. 
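For example, if all of the data is already materialized in memory, a `SendableRecordBatchStream` can be built directly from a `Vec<RecordBatch>` using `MemoryStream`. The snippet below is a minimal sketch with an assumed single-column schema, just to show the shape of the API:

```rust
// Minimal sketch: wrap in-memory RecordBatches in a MemoryStream.
// The schema (a single non-null Int32 column named "a") is only illustrative.
use std::sync::Arc;
use datafusion::arrow::array::{ArrayRef, Int32Array};
use datafusion::arrow::datatypes::{DataType, Field, Schema};
use datafusion::arrow::record_batch::RecordBatch;
use datafusion::common::Result;
use datafusion::physical_plan::memory::MemoryStream;
use datafusion::physical_plan::SendableRecordBatchStream;

fn in_memory_stream() -> Result<SendableRecordBatchStream> {
    let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
    let batch = RecordBatch::try_new(
        schema.clone(),
        vec![Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef],
    )?;
    // `None` projection means all columns are returned as-is
    Ok(Box::pin(MemoryStream::try_new(vec![batch], schema, None)?))
}
```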
-Looking at the [example in this repo][ex], the execute method: +Looking at the full example below: ```rust +use std::any::Any; +use std::sync::{Arc, Mutex}; +use std::collections::{BTreeMap, HashMap}; +use datafusion::common::Result; +use datafusion::arrow::datatypes::{DataType, Field, Schema, SchemaRef}; +use datafusion::physical_plan::expressions::PhysicalSortExpr; +use datafusion::physical_plan::{ + ExecutionPlan, SendableRecordBatchStream, DisplayAs, DisplayFormatType, + Statistics, PlanProperties +}; +use datafusion::execution::context::TaskContext; +use datafusion::arrow::array::{UInt64Builder, UInt8Builder}; +use datafusion::physical_plan::memory::MemoryStream; +use datafusion::arrow::record_batch::RecordBatch; + +/// A User, with an id and a bank account +#[derive(Clone, Debug)] +struct User { + id: u8, + bank_account: u64, +} + +/// A custom datasource, used to represent a datastore with a single index +#[derive(Clone, Debug)] +pub struct CustomDataSource { + inner: Arc>, +} + +#[derive(Debug)] +struct CustomDataSourceInner { + data: HashMap, + bank_account_index: BTreeMap, +} + +#[derive(Debug)] struct CustomExec { db: CustomDataSource, projected_schema: SchemaRef, } +impl DisplayAs for CustomExec { + fn fmt_as(&self, _t: DisplayFormatType, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "CustomExec") + } +} + impl ExecutionPlan for CustomExec { - fn name(&self) { + fn name(&self) -> &str { "CustomExec" } + fn as_any(&self) -> &dyn Any { + self + } + + fn schema(&self) -> SchemaRef { + self.projected_schema.clone() + } + + + fn properties(&self) -> &PlanProperties { + unreachable!() + } + + fn children(&self) -> Vec<&Arc> { + Vec::new() + } + + fn with_new_children( + self: Arc, + _: Vec>, + ) -> Result> { + Ok(self) + } + fn execute( &self, _partition: usize, @@ -83,7 +148,7 @@ impl ExecutionPlan for CustomExec { } ``` -This: +This `execute` method: 1. Gets the users from the database 2. Constructs the individual output arrays (columns) @@ -98,6 +163,134 @@ With the `ExecutionPlan` implemented, we can now implement the `scan` method of The `scan` method of the `TableProvider` returns a `Result>`. We can use the `Arc` to return a reference-counted pointer to the `ExecutionPlan` we implemented. 
In the example, this is done by: ```rust + +# use std::any::Any; +# use std::sync::{Arc, Mutex}; +# use std::collections::{BTreeMap, HashMap}; +# use datafusion::common::Result; +# use datafusion::arrow::datatypes::{DataType, Field, Schema, SchemaRef}; +# use datafusion::physical_plan::expressions::PhysicalSortExpr; +# use datafusion::physical_plan::{ +# ExecutionPlan, SendableRecordBatchStream, DisplayAs, DisplayFormatType, +# Statistics, PlanProperties +# }; +# use datafusion::execution::context::TaskContext; +# use datafusion::arrow::array::{UInt64Builder, UInt8Builder}; +# use datafusion::physical_plan::memory::MemoryStream; +# use datafusion::arrow::record_batch::RecordBatch; +# +# /// A User, with an id and a bank account +# #[derive(Clone, Debug)] +# struct User { +# id: u8, +# bank_account: u64, +# } +# +# /// A custom datasource, used to represent a datastore with a single index +# #[derive(Clone, Debug)] +# pub struct CustomDataSource { +# inner: Arc>, +# } +# +# #[derive(Debug)] +# struct CustomDataSourceInner { +# data: HashMap, +# bank_account_index: BTreeMap, +# } +# +# #[derive(Debug)] +# struct CustomExec { +# db: CustomDataSource, +# projected_schema: SchemaRef, +# } +# +# impl DisplayAs for CustomExec { +# fn fmt_as(&self, _t: DisplayFormatType, f: &mut std::fmt::Formatter) -> std::fmt::Result { +# write!(f, "CustomExec") +# } +# } +# +# impl ExecutionPlan for CustomExec { +# fn name(&self) -> &str { +# "CustomExec" +# } +# +# fn as_any(&self) -> &dyn Any { +# self +# } +# +# fn schema(&self) -> SchemaRef { +# self.projected_schema.clone() +# } +# +# +# fn properties(&self) -> &PlanProperties { +# unreachable!() +# } +# +# fn children(&self) -> Vec<&Arc> { +# Vec::new() +# } +# +# fn with_new_children( +# self: Arc, +# _: Vec>, +# ) -> Result> { +# Ok(self) +# } +# +# fn execute( +# &self, +# _partition: usize, +# _context: Arc, +# ) -> Result { +# let users: Vec = { +# let db = self.db.inner.lock().unwrap(); +# db.data.values().cloned().collect() +# }; +# +# let mut id_array = UInt8Builder::with_capacity(users.len()); +# let mut account_array = UInt64Builder::with_capacity(users.len()); +# +# for user in users { +# id_array.append_value(user.id); +# account_array.append_value(user.bank_account); +# } +# +# Ok(Box::pin(MemoryStream::try_new( +# vec![RecordBatch::try_new( +# self.projected_schema.clone(), +# vec![ +# Arc::new(id_array.finish()), +# Arc::new(account_array.finish()), +# ], +# )?], +# self.schema(), +# None, +# )?)) +# } +# } + +use async_trait::async_trait; +use datafusion::logical_expr::expr::Expr; +use datafusion::datasource::{TableProvider, TableType}; +use datafusion::physical_plan::project_schema; +use datafusion::catalog::Session; + +impl CustomExec { + fn new( + projections: Option<&Vec>, + schema: SchemaRef, + db: CustomDataSource, + ) -> Self { + let projected_schema = project_schema(&schema, projections).unwrap(); + Self { + db, + projected_schema, + } + } +} + impl CustomDataSource { pub(crate) async fn create_physical_plan( &self, @@ -110,6 +303,21 @@ impl CustomDataSource { #[async_trait] impl TableProvider for CustomDataSource { + fn as_any(&self) -> &dyn Any { + self + } + + fn schema(&self) -> SchemaRef { + SchemaRef::new(Schema::new(vec![ + Field::new("id", DataType::UInt8, false), + Field::new("bank_account", DataType::UInt64, true), + ])) + } + + fn table_type(&self) -> TableType { + TableType::Base + } + async fn scan( &self, _state: &dyn Session, @@ -145,17 +353,194 @@ For filters that can be pushed down, they'll be passed to the `scan` 
method as t In order to use the custom table provider, we need to register it with DataFusion. This is done by creating a `TableProvider` and registering it with the `SessionContext`. -```rust -let ctx = SessionContext::new(); - -let custom_table_provider = CustomDataSource::new(); -ctx.register_table("custom_table", Arc::new(custom_table_provider)); -``` - This will allow you to use the custom table provider in DataFusion. For example, you could use it in a SQL query to get a `DataFrame`. ```rust -let df = ctx.sql("SELECT id, bank_account FROM custom_table")?; +# use std::any::Any; +# use std::sync::{Arc, Mutex}; +# use std::collections::{BTreeMap, HashMap}; +# use datafusion::common::Result; +# use datafusion::arrow::datatypes::{DataType, Field, Schema, SchemaRef}; +# use datafusion::physical_plan::expressions::PhysicalSortExpr; +# use datafusion::physical_plan::{ +# ExecutionPlan, SendableRecordBatchStream, DisplayAs, DisplayFormatType, +# Statistics, PlanProperties +# }; +# use datafusion::execution::context::TaskContext; +# use datafusion::arrow::array::{UInt64Builder, UInt8Builder}; +# use datafusion::physical_plan::memory::MemoryStream; +# use datafusion::arrow::record_batch::RecordBatch; +# +# /// A User, with an id and a bank account +# #[derive(Clone, Debug)] +# struct User { +# id: u8, +# bank_account: u64, +# } +# +# /// A custom datasource, used to represent a datastore with a single index +# #[derive(Clone, Debug)] +# pub struct CustomDataSource { +# inner: Arc>, +# } +# +# #[derive(Debug)] +# struct CustomDataSourceInner { +# data: HashMap, +# bank_account_index: BTreeMap, +# } +# +# #[derive(Debug)] +# struct CustomExec { +# db: CustomDataSource, +# projected_schema: SchemaRef, +# } +# +# impl DisplayAs for CustomExec { +# fn fmt_as(&self, _t: DisplayFormatType, f: &mut std::fmt::Formatter) -> std::fmt::Result { +# write!(f, "CustomExec") +# } +# } +# +# impl ExecutionPlan for CustomExec { +# fn name(&self) -> &str { +# "CustomExec" +# } +# +# fn as_any(&self) -> &dyn Any { +# self +# } +# +# fn schema(&self) -> SchemaRef { +# self.projected_schema.clone() +# } +# +# +# fn properties(&self) -> &PlanProperties { +# unreachable!() +# } +# +# fn children(&self) -> Vec<&Arc> { +# Vec::new() +# } +# +# fn with_new_children( +# self: Arc, +# _: Vec>, +# ) -> Result> { +# Ok(self) +# } +# +# fn execute( +# &self, +# _partition: usize, +# _context: Arc, +# ) -> Result { +# let users: Vec = { +# let db = self.db.inner.lock().unwrap(); +# db.data.values().cloned().collect() +# }; +# +# let mut id_array = UInt8Builder::with_capacity(users.len()); +# let mut account_array = UInt64Builder::with_capacity(users.len()); +# +# for user in users { +# id_array.append_value(user.id); +# account_array.append_value(user.bank_account); +# } +# +# Ok(Box::pin(MemoryStream::try_new( +# vec![RecordBatch::try_new( +# self.projected_schema.clone(), +# vec![ +# Arc::new(id_array.finish()), +# Arc::new(account_array.finish()), +# ], +# )?], +# self.schema(), +# None, +# )?)) +# } +# } + +# use async_trait::async_trait; +# use datafusion::logical_expr::expr::Expr; +# use datafusion::datasource::{TableProvider, TableType}; +# use datafusion::physical_plan::project_schema; +# use datafusion::catalog::Session; +# +# impl CustomExec { +# fn new( +# projections: Option<&Vec>, +# schema: SchemaRef, +# db: CustomDataSource, +# ) -> Self { +# let projected_schema = project_schema(&schema, projections).unwrap(); +# Self { +# db, +# projected_schema, +# } +# } +# } +# +# impl CustomDataSource { +# pub(crate) async 
fn create_physical_plan( +# &self, +# projections: Option<&Vec>, +# schema: SchemaRef, +# ) -> Result> { +# Ok(Arc::new(CustomExec::new(projections, schema, self.clone()))) +# } +# } +# +# #[async_trait] +# impl TableProvider for CustomDataSource { +# fn as_any(&self) -> &dyn Any { +# self +# } +# +# fn schema(&self) -> SchemaRef { +# SchemaRef::new(Schema::new(vec![ +# Field::new("id", DataType::UInt8, false), +# Field::new("bank_account", DataType::UInt64, true), +# ])) +# } +# +# fn table_type(&self) -> TableType { +# TableType::Base +# } +# +# async fn scan( +# &self, +# _state: &dyn Session, +# projection: Option<&Vec>, +# // filters and limit can be used here to inject some push-down operations if needed +# _filters: &[Expr], +# _limit: Option, +# ) -> Result> { +# return self.create_physical_plan(projection, self.schema()).await; +# } +# } + +use datafusion::execution::context::SessionContext; + +#[tokio::main] +async fn main() -> Result<()> { + let ctx = SessionContext::new(); + + let custom_table_provider = CustomDataSource { + inner: Arc::new(Mutex::new(CustomDataSourceInner { + data: Default::default(), + bank_account_index: Default::default(), + })), + }; + + ctx.register_table("customers", Arc::new(custom_table_provider)); + let df = ctx.sql("SELECT id, bank_account FROM customers").await?; + + Ok(()) +} + ``` ## Recap diff --git a/docs/source/library-user-guide/query-optimizer.md b/docs/source/library-user-guide/query-optimizer.md index c2c60af85f4c..fad8adf83d81 100644 --- a/docs/source/library-user-guide/query-optimizer.md +++ b/docs/source/library-user-guide/query-optimizer.md @@ -35,18 +35,28 @@ and applying it to a logical plan to produce an optimized logical plan. ```rust +use std::sync::Arc; +use datafusion::logical_expr::{col, lit, LogicalPlan, LogicalPlanBuilder}; +use datafusion::optimizer::{OptimizerRule, OptimizerContext, Optimizer}; + // We need a logical plan as the starting point. There are many ways to build a logical plan: // // The `datafusion-expr` crate provides a LogicalPlanBuilder // The `datafusion-sql` crate provides a SQL query planner that can create a LogicalPlan from SQL // The `datafusion` crate provides a DataFrame API that can create a LogicalPlan -let logical_plan = ... -let mut config = OptimizerContext::default(); -let optimizer = Optimizer::new(&config); -let optimized_plan = optimizer.optimize(&logical_plan, &config, observe)?; +let initial_logical_plan = LogicalPlanBuilder::empty(false).build().unwrap(); + +// use builtin rules or customized rules +let rules: Vec> = vec![]; + +let optimizer = Optimizer::with_rules(rules); + +let config = OptimizerContext::new().with_max_passes(16); -fn observe(plan: &LogicalPlan, rule: &dyn OptimizerRule) { +let optimized_plan = optimizer.optimize(initial_logical_plan.clone(), &config, observer); + +fn observer(plan: &LogicalPlan, rule: &dyn OptimizerRule) { println!( "After applying rule '{}':\n{}", rule.name(), @@ -55,16 +65,6 @@ fn observe(plan: &LogicalPlan, rule: &dyn OptimizerRule) { } ``` -## Providing Custom Rules - -The optimizer can be created with a custom set of rules. - -```rust -let optimizer = Optimizer::with_rules(vec![ - Arc::new(MyRule {}) -]); -``` - ## Writing Optimization Rules Please refer to the @@ -72,26 +72,71 @@ Please refer to the example to learn more about the general approach to writing optimizer rules and then move onto studying the existing rules. 
+`OptimizerRule` transforms one ['LogicalPlan'] into another which +computes the same results, but in a potentially more efficient +way. If there are no suitable transformations for the input plan, +the optimizer can simply return it as is. + All rules must implement the `OptimizerRule` trait. ```rust -/// `OptimizerRule` transforms one ['LogicalPlan'] into another which -/// computes the same results, but in a potentially more efficient -/// way. If there are no suitable transformations for the input plan, -/// the optimizer can simply return it as is. -pub trait OptimizerRule { - /// Rewrite `plan` to an optimized form - fn optimize( - &self, - plan: &LogicalPlan, - config: &dyn OptimizerConfig, - ) -> Result; +# use datafusion::common::tree_node::Transformed; +# use datafusion::common::Result; +# use datafusion::logical_expr::LogicalPlan; +# use datafusion::optimizer::{OptimizerConfig, OptimizerRule}; +# + +#[derive(Default, Debug)] +struct MyOptimizerRule {} + +impl OptimizerRule for MyOptimizerRule { + fn name(&self) -> &str { + "my_optimizer_rule" + } - /// A human readable name for this optimizer rule - fn name(&self) -> &str; + fn rewrite( + &self, + plan: LogicalPlan, + _config: &dyn OptimizerConfig, + ) -> Result> { + unimplemented!() + } } ``` +## Providing Custom Rules + +The optimizer can be created with a custom set of rules. + +```rust +# use std::sync::Arc; +# use datafusion::logical_expr::{col, lit, LogicalPlan, LogicalPlanBuilder}; +# use datafusion::optimizer::{OptimizerRule, OptimizerConfig, OptimizerContext, Optimizer}; +# use datafusion::common::tree_node::Transformed; +# use datafusion::common::Result; +# +# #[derive(Default, Debug)] +# struct MyOptimizerRule {} +# +# impl OptimizerRule for MyOptimizerRule { +# fn name(&self) -> &str { +# "my_optimizer_rule" +# } +# +# fn rewrite( +# &self, +# plan: LogicalPlan, +# _config: &dyn OptimizerConfig, +# ) -> Result> { +# unimplemented!() +# } +# } + +let optimizer = Optimizer::with_rules(vec![ + Arc::new(MyOptimizerRule {}) +]); +``` + ### General Guidelines Rules typical walk the logical plan and walk the expression trees inside operators and selectively mutate @@ -168,16 +213,19 @@ and [#3555](https://github.com/apache/datafusion/issues/3555) occur where the ex There are currently two ways to create a name for an expression in the logical plan. ```rust +# use datafusion::common::Result; +# struct Expr; + impl Expr { /// Returns the name of this expression as it should appear in a schema. This name /// will not include any CAST expressions. pub fn display_name(&self) -> Result { - create_name(self) + Ok("display_name".to_string()) } /// Returns a full and complete string representation of this expression. pub fn canonical_name(&self) -> String { - format!("{}", self) + "canonical_name".to_string() } } ``` @@ -187,93 +235,99 @@ name to be used in a schema, `display_name` should be used. ### Utilities -There are a number of utility methods provided that take care of some common tasks. +There are a number of [utility methods][util] provided that take care of some common tasks. -### ExprVisitor +[util]: https://github.com/apache/datafusion/blob/main/datafusion/expr/src/utils.rs -The `ExprVisitor` and `ExprVisitable` traits provide a mechanism for applying a visitor pattern to an expression tree. +### Recursively walk an expression tree -Here is an example that demonstrates this. +The [TreeNode API] provides a convenient way to recursively walk an expression or plan tree. 
-```rust -fn extract_subquery_filters(expression: &Expr, extracted: &mut Vec) -> Result<()> { - struct InSubqueryVisitor<'a> { - accum: &'a mut Vec, - } +For example, to find all subquery references in a logical plan, the following code can be used: - impl ExpressionVisitor for InSubqueryVisitor<'_> { - fn pre_visit(self, expr: &Expr) -> Result> { +```rust +# use datafusion::prelude::*; +# use datafusion::common::tree_node::{TreeNode, TreeNodeRecursion}; +# use datafusion::common::Result; +// Return all subquery references in an expression +fn extract_subquery_filters(expression: &Expr) -> Result> { + let mut extracted = vec![]; + expression.apply(|expr| { if let Expr::InSubquery(_) = expr { - self.accum.push(expr.to_owned()); + extracted.push(expr); } - Ok(Recursion::Continue(self)) - } - } + Ok(TreeNodeRecursion::Continue) + })?; + Ok(extracted) +} +``` - expression.accept(InSubqueryVisitor { accum: extracted })?; - Ok(()) +Likewise you can use the [TreeNode API] to rewrite a `LogicalPlan` or `ExecutionPlan` + +```rust +# use datafusion::prelude::*; +# use datafusion::logical_expr::{LogicalPlan, Join}; +# use datafusion::common::tree_node::{TreeNode, TreeNodeRecursion}; +# use datafusion::common::Result; +// Return all joins in a logical plan +fn find_joins(overall_plan: &LogicalPlan) -> Result> { + let mut extracted = vec![]; + overall_plan.apply(|plan| { + if let LogicalPlan::Join(join) = plan { + extracted.push(join); + } + Ok(TreeNodeRecursion::Continue) + })?; + Ok(extracted) } ``` -### Rewriting Expressions +### Rewriting expressions -The `MyExprRewriter` trait can be implemented to provide a way to rewrite expressions. This rule can then be applied -to an expression by calling `Expr::rewrite` (from the `ExprRewritable` trait). +The [TreeNode API] also provides a convenient way to rewrite expressions and +plans as well. For example to rewrite all expressions like -The `rewrite` method will perform a depth first walk of the expression and its children to rewrite an expression, -consuming `self` producing a new expression. +```sql +col BETWEEN x AND y +``` -```rust -let mut expr_rewriter = MyExprRewriter {}; -let expr = expr.rewrite(&mut expr_rewriter)?; +into + +```sql +col >= x AND col <= y ``` -Here is an example implementation which will rewrite `expr BETWEEN a AND b` as `expr >= a AND expr <= b`. Note that the -implementation does not need to perform any recursion since this is handled by the `rewrite` method. 
+you can use the following code: ```rust -struct MyExprRewriter {} - -impl ExprRewriter for MyExprRewriter { - fn mutate(&mut self, expr: Expr) -> Result { - match expr { - Expr::Between { +# use datafusion::prelude::*; +# use datafusion::logical_expr::{Between}; +# use datafusion::logical_expr::expr_fn::*; +# use datafusion::common::tree_node::{Transformed, TreeNode, TreeNodeRecursion}; +# use datafusion::common::Result; +// Recursively rewrite all BETWEEN expressions +// returns Transformed::yes if any changes were made +fn rewrite_between(expr: Expr) -> Result> { + // transform_up does a bottom up rewrite + expr.transform_up(|expr| { + // only handle BETWEEN expressions + let Expr::Between(Between { negated, expr, low, high, - } => { - let expr: Expr = expr.as_ref().clone(); - let low: Expr = low.as_ref().clone(); - let high: Expr = high.as_ref().clone(); - if negated { - Ok(expr.clone().lt(low).or(expr.clone().gt(high))) - } else { - Ok(expr.clone().gt_eq(low).and(expr.clone().lt_eq(high))) - } - } - _ => Ok(expr.clone()), - } - } -} -``` - -### optimize_children - -Typically a rule is applied recursively to all operators within a query plan. Rather than duplicate -that logic in each rule, an `optimize_children` method is provided. This recursively invokes the `optimize` method on -the plan's children and then returns a node of the same type. - -```rust -fn optimize( - &self, - plan: &LogicalPlan, - _config: &mut OptimizerConfig, -) -> Result { - // recurse down and optimize children first - let plan = utils::optimize_children(self, plan, _config)?; - - ... + }) = expr else { + return Ok(Transformed::no(expr)) + }; + let rewritten_expr = if negated { + // don't rewrite NOT BETWEEN + Expr::Between(Between::new(expr, negated, low, high)) + } else { + // rewrite to (expr >= low) AND (expr <= high) + expr.clone().gt_eq(*low).and(expr.lt_eq(*high)) + }; + Ok(Transformed::yes(rewritten_expr)) + }) } ``` diff --git a/docs/source/library-user-guide/working-with-exprs.md b/docs/source/library-user-guide/working-with-exprs.md index e0b6f434a032..1a6e9123086d 100644 --- a/docs/source/library-user-guide/working-with-exprs.md +++ b/docs/source/library-user-guide/working-with-exprs.md @@ -61,12 +61,34 @@ We'll use a `ScalarUDF` expression as our example. 
This necessitates implementin So assuming you've written that function, you can use it to create an `Expr`: ```rust +# use std::sync::Arc; +# use datafusion::arrow::array::{ArrayRef, Int64Array}; +# use datafusion::common::cast::as_int64_array; +# use datafusion::common::Result; +# use datafusion::logical_expr::ColumnarValue; +# +# pub fn add_one(args: &[ColumnarValue]) -> Result { +# // Error handling omitted for brevity +# let args = ColumnarValue::values_to_arrays(args)?; +# let i64s = as_int64_array(&args[0])?; +# +# let new_array = i64s +# .iter() +# .map(|array_elem| array_elem.map(|value| value + 1)) +# .collect::(); +# +# Ok(ColumnarValue::from(Arc::new(new_array) as ArrayRef)) +# } +use datafusion::logical_expr::{Volatility, create_udf}; +use datafusion::arrow::datatypes::DataType; +use datafusion::logical_expr::{col, lit}; + let add_one_udf = create_udf( "add_one", vec![DataType::Int64], - Arc::new(DataType::Int64), + DataType::Int64, Volatility::Immutable, - make_scalar_function(add_one), // <-- the function we wrote + Arc::new(add_one), ); // make the expr `add_one(5)` @@ -99,11 +121,16 @@ In our example, we'll use rewriting to update our `add_one` UDF, to be rewritten To implement the inlining, we'll need to write a function that takes an `Expr` and returns a `Result`. If the expression is _not_ to be rewritten `Transformed::no` is used to wrap the original `Expr`. If the expression _is_ to be rewritten, `Transformed::yes` is used to wrap the new `Expr`. ```rust -fn rewrite_add_one(expr: Expr) -> Result { +use datafusion::common::Result; +use datafusion::common::tree_node::{Transformed, TreeNode}; +use datafusion::logical_expr::{col, lit, Expr}; +use datafusion::logical_expr::{ScalarUDF}; + +fn rewrite_add_one(expr: Expr) -> Result> { expr.transform(&|expr| { Ok(match expr { - Expr::ScalarUDF(scalar_fun) if scalar_fun.fun.name == "add_one" => { - let input_arg = scalar_fun.args[0].clone(); + Expr::ScalarFunction(scalar_func) if scalar_func.func.inner().name() == "add_one" => { + let input_arg = scalar_func.args[0].clone(); let new_expression = input_arg + lit(1i64); Transformed::yes(new_expression) @@ -124,6 +151,27 @@ We'll call our rule `AddOneInliner` and implement the `OptimizerRule` trait. The - `try_optimize` - takes a `LogicalPlan` and returns an `Option`. If the rule is able to optimize the plan, it returns `Some(LogicalPlan)` with the optimized plan. If the rule is not able to optimize the plan, it returns `None`. 
```rust +use std::sync::Arc; +use datafusion::common::Result; +use datafusion::common::tree_node::{Transformed, TreeNode}; +use datafusion::logical_expr::{col, lit, Expr, LogicalPlan, LogicalPlanBuilder}; +use datafusion::optimizer::{OptimizerRule, OptimizerConfig, OptimizerContext, Optimizer}; + +# fn rewrite_add_one(expr: Expr) -> Result> { +# expr.transform(&|expr| { +# Ok(match expr { +# Expr::ScalarFunction(scalar_func) if scalar_func.func.inner().name() == "add_one" => { +# let input_arg = scalar_func.args[0].clone(); +# let new_expression = input_arg + lit(1i64); +# +# Transformed::yes(new_expression) +# } +# _ => Transformed::no(expr), +# }) +# }) +# } + +#[derive(Default, Debug)] struct AddOneInliner {} impl OptimizerRule for AddOneInliner { @@ -131,23 +179,26 @@ impl OptimizerRule for AddOneInliner { "add_one_inliner" } - fn try_optimize( + fn rewrite( &self, - plan: &LogicalPlan, - config: &dyn OptimizerConfig, - ) -> Result> { + plan: LogicalPlan, + _config: &dyn OptimizerConfig, + ) -> Result> { // Map over the expressions and rewrite them - let new_expressions = plan + let new_expressions: Vec = plan .expressions() .into_iter() .map(|expr| rewrite_add_one(expr)) - .collect::>>()?; + .collect::>>()? // returns Vec> + .into_iter() + .map(|transformed| transformed.data) + .collect(); let inputs = plan.inputs().into_iter().cloned().collect::>(); - let plan = plan.with_new_exprs(&new_expressions, &inputs); + let plan: Result = plan.with_new_exprs(new_expressions, inputs); - plan.map(Some) + plan.map(|p| Transformed::yes(p)) } } ``` @@ -161,25 +212,111 @@ We're almost there. Let's just test our rule works properly. Testing the rule is fairly simple, we can create a SessionState with our rule and then create a DataFrame and run a query. The logical plan will be optimized by our rule. ```rust -use datafusion::prelude::*; - -let rules = Arc::new(AddOneInliner {}); -let state = ctx.state().with_optimizer_rules(vec![rules]); - -let ctx = SessionContext::with_state(state); -ctx.register_udf(add_one); - -let sql = "SELECT add_one(1) AS added_one"; -let plan = ctx.sql(sql).await?.logical_plan(); - -println!("{:?}", plan); +# use std::sync::Arc; +# use datafusion::common::Result; +# use datafusion::common::tree_node::{Transformed, TreeNode}; +# use datafusion::logical_expr::{col, lit, Expr, LogicalPlan, LogicalPlanBuilder}; +# use datafusion::optimizer::{OptimizerRule, OptimizerConfig, OptimizerContext, Optimizer}; +# use datafusion::arrow::array::{ArrayRef, Int64Array}; +# use datafusion::common::cast::as_int64_array; +# use datafusion::logical_expr::ColumnarValue; +# use datafusion::logical_expr::{Volatility, create_udf}; +# use datafusion::arrow::datatypes::DataType; +# +# fn rewrite_add_one(expr: Expr) -> Result> { +# expr.transform(&|expr| { +# Ok(match expr { +# Expr::ScalarFunction(scalar_func) if scalar_func.func.inner().name() == "add_one" => { +# let input_arg = scalar_func.args[0].clone(); +# let new_expression = input_arg + lit(1i64); +# +# Transformed::yes(new_expression) +# } +# _ => Transformed::no(expr), +# }) +# }) +# } +# +# #[derive(Default, Debug)] +# struct AddOneInliner {} +# +# impl OptimizerRule for AddOneInliner { +# fn name(&self) -> &str { +# "add_one_inliner" +# } +# +# fn rewrite( +# &self, +# plan: LogicalPlan, +# _config: &dyn OptimizerConfig, +# ) -> Result> { +# // Map over the expressions and rewrite them +# let new_expressions: Vec = plan +# .expressions() +# .into_iter() +# .map(|expr| rewrite_add_one(expr)) +# .collect::>>()? 
// returns Vec> +# .into_iter() +# .map(|transformed| transformed.data) +# .collect(); +# +# let inputs = plan.inputs().into_iter().cloned().collect::>(); +# +# let plan: Result = plan.with_new_exprs(new_expressions, inputs); +# +# plan.map(|p| Transformed::yes(p)) +# } +# } +# +# pub fn add_one(args: &[ColumnarValue]) -> Result { +# // Error handling omitted for brevity +# let args = ColumnarValue::values_to_arrays(args)?; +# let i64s = as_int64_array(&args[0])?; +# +# let new_array = i64s +# .iter() +# .map(|array_elem| array_elem.map(|value| value + 1)) +# .collect::(); +# +# Ok(ColumnarValue::from(Arc::new(new_array) as ArrayRef)) +# } + +use datafusion::execution::context::SessionContext; + +#[tokio::main] +async fn main() -> Result<()> { + + let ctx = SessionContext::new(); + // ctx.add_optimizer_rule(Arc::new(AddOneInliner {})); + + let add_one_udf = create_udf( + "add_one", + vec![DataType::Int64], + DataType::Int64, + Volatility::Immutable, + Arc::new(add_one), + ); + ctx.register_udf(add_one_udf); + + let sql = "SELECT add_one(5) AS added_one"; + // let plan = ctx.sql(sql).await?.into_unoptimized_plan().clone(); + let plan = ctx.sql(sql).await?.into_optimized_plan()?.clone(); + + let expected = r#"Projection: Int64(6) AS added_one + EmptyRelation"#; + + assert_eq!(plan.to_string(), expected); + + Ok(()) +} ``` -This results in the following output: +This plan is optimized as: ```text -Projection: Int64(1) + Int64(1) AS added_one - EmptyRelation +Projection: add_one(Int64(5)) AS added_one + -> Projection: Int64(5) + Int64(1) AS added_one + -> Projection: Int64(6) AS added_one ``` I.e. the `add_one` UDF has been inlined into the projection. @@ -189,27 +326,23 @@ I.e. the `add_one` UDF has been inlined into the projection. The `arrow::datatypes::DataType` of the expression can be obtained by calling the `get_type` given something that implements `Expr::Schemable`, for example a `DFschema` object: ```rust -use arrow_schema::DataType; -use datafusion::common::{DFField, DFSchema}; +use arrow::datatypes::{DataType, Field}; +use datafusion::common::DFSchema; use datafusion::logical_expr::{col, ExprSchemable}; use std::collections::HashMap; +// Get the type of an expression that adds 2 columns. 
Adding an Int32 +// and Float32 results in Float32 type let expr = col("c1") + col("c2"); -let schema = DFSchema::new_with_metadata( +let schema = DFSchema::from_unqualified_fields( vec![ - DFField::new_unqualified("c1", DataType::Int32, true), - DFField::new_unqualified("c2", DataType::Float32, true), - ], + Field::new("c1", DataType::Int32, true), + Field::new("c2", DataType::Float32, true), + ] + .into(), HashMap::new(), -) -.unwrap(); -print!("type = {}", expr.get_type(&schema).unwrap()); -``` - -This results in the following output: - -```text -type = Float32 +).unwrap(); +assert_eq!("Float32", format!("{}", expr.get_type(&schema).unwrap())); ``` ## Conclusion diff --git a/docs/source/user-guide/cli/usage.md b/docs/source/user-guide/cli/usage.md index 6a620fc69252..fb238dad10bb 100644 --- a/docs/source/user-guide/cli/usage.md +++ b/docs/source/user-guide/cli/usage.md @@ -127,7 +127,7 @@ supports additional statements and commands: Show configuration options -```SQL +```sql > show all; +-------------------------------------------------+---------+ @@ -163,7 +163,7 @@ Show specific configuration option - Set configuration options -```SQL +```sql > SET datafusion.execution.batch_size to 1024; ``` diff --git a/docs/source/user-guide/concepts-readings-events.md b/docs/source/user-guide/concepts-readings-events.md index 609dcadf2a8d..102090601b23 100644 --- a/docs/source/user-guide/concepts-readings-events.md +++ b/docs/source/user-guide/concepts-readings-events.md @@ -150,9 +150,9 @@ This is a list of DataFusion related blog posts, articles, and other resources. # 🌎 Community Events -- **2025-01-25** (Upcoming) [Amsterdam Apache DataFusion Meetup](https://github.com/apache/datafusion/discussions/12988) -- **2025-01-15** (Upcoming) [Boston Apache DataFusion Meetup](https://github.com/apache/datafusion/discussions/13165) -- **2024-12-18** (Upcoming) [Chicago Apache DataFusion Meetup](https://lu.ma/eq5myc5i) +- **2025-01-23** [Amsterdam Apache DataFusion Meetup](https://github.com/apache/datafusion/discussions/12988) +- **2025-01-15** [Boston Apache DataFusion Meetup](https://github.com/apache/datafusion/discussions/13165) +- **2024-12-18** [Chicago Apache DataFusion Meetup](https://lu.ma/eq5myc5i) - **2024-10-14** [Seattle Apache DataFusion Meetup](https://lu.ma/tnwl866b) - **2024-09-27** [Belgrade Apache DataFusion Meetup](https://lu.ma/tmwuz4lg), [recap](https://github.com/apache/datafusion/discussions/11431#discussioncomment-10832070), [slides](https://github.com/apache/datafusion/discussions/11431#discussioncomment-10826169), [recordings](https://www.youtube.com/watch?v=4huEsFFv6bQ&list=PLrhIfEjaw9ilQEczOQlHyMznabtVRptyX) - **2024-06-26** [New York City Apache DataFusion Meetup](https://lu.ma/2iwba0xm). 
[slides](https://docs.google.com/presentation/d/1dOLPAFPEMLhLv4NN6O9QSDIyyeiIySqAjky5cVgdWAE/edit#slide=id.g26bebde4fcc_3_7) diff --git a/docs/source/user-guide/crate-configuration.md b/docs/source/user-guide/crate-configuration.md index 9d22e3403097..f4a1910f5f78 100644 --- a/docs/source/user-guide/crate-configuration.md +++ b/docs/source/user-guide/crate-configuration.md @@ -68,7 +68,9 @@ codegen-units = 1 Then, in `main.rs.` update the memory allocator with the below after your imports: -```rust ,ignore + + +```no-run use datafusion::prelude::*; #[global_allocator] diff --git a/docs/source/user-guide/explain-usage.md b/docs/source/user-guide/explain-usage.md index 32a87ae9198d..d89ed5f0e7ea 100644 --- a/docs/source/user-guide/explain-usage.md +++ b/docs/source/user-guide/explain-usage.md @@ -49,7 +49,7 @@ LIMIT 5; The output will look like -``` +```text +---------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ | plan_type | plan | +---------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ @@ -70,7 +70,7 @@ Elapsed 0.060 seconds. There are two sections: logical plan and physical plan -- **Logical Plan:** is a plan generated for a specific SQL query, DataFrame, or other language without the +- **Logical Plan:** is a plan generated for a specific SQL query, DataFrame, or other language without the knowledge of the underlying data organization. - **Physical Plan:** is a plan generated from a logical plan along with consideration of the hardware configuration (e.g number of CPUs) and the underlying data organization (e.g number of files). @@ -87,7 +87,7 @@ query run faster depends on the reason it is slow and beyond the scope of this d A query plan is an upside down tree, and we always read from bottom up. 
The physical plan in Figure 1 in tree format will look like -``` +```text ▲ │ │ @@ -174,7 +174,7 @@ above but with `EXPLAIN ANALYZE` (note the output is edited for clarity) [`executionplan::metrics`]: https://docs.rs/datafusion/latest/datafusion/physical_plan/trait.ExecutionPlan.html#method.metrics -``` +```sql > EXPLAIN ANALYZE SELECT "WatchID" AS wid, "hits.parquet"."ClientIP" AS ip FROM 'hits.parquet' WHERE starts_with("URL", 'http://domcheloveplanet.ru/') @@ -267,7 +267,7 @@ LIMIT 10; We can again see the query plan by using `EXPLAIN`: -``` +```sql > EXPLAIN SELECT "UserID", COUNT(*) FROM 'hits.parquet' GROUP BY "UserID" ORDER BY COUNT(*) DESC LIMIT 10; +---------------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ | plan_type | plan | diff --git a/docs/source/user-guide/sql/sql_status.md b/docs/source/user-guide/features.md similarity index 65% rename from docs/source/user-guide/sql/sql_status.md rename to docs/source/user-guide/features.md index cb9bc0bb67b3..1f73ce7eac11 100644 --- a/docs/source/user-guide/sql/sql_status.md +++ b/docs/source/user-guide/features.md @@ -17,23 +17,28 @@ under the License. --> -# Status +# Features ## General - [x] SQL Parser - [x] SQL Query Planner +- [x] DataFrame API +- [x] Parallel query execution +- [x] Streaming Execution + +## Optimizations + - [x] Query Optimizer - [x] Constant folding - [x] Join Reordering - [x] Limit Pushdown - [x] Projection push down - [x] Predicate push down -- [x] Type coercion -- [x] Parallel query execution ## SQL Support +- [x] Type coercion - [x] Projection (`SELECT`) - [x] Filter (`WHERE`) - [x] Filter post-aggregate (`HAVING`) @@ -42,23 +47,23 @@ - [x] Aggregate (`GROUP BY`) - [x] cast /try_cast - [x] [`VALUES` lists](https://www.postgresql.org/docs/current/queries-values.html) -- [x] [String Functions](./scalar_functions.md#string-functions) -- [x] [Conditional Functions](./scalar_functions.md#conditional-functions) -- [x] [Time and Date Functions](./scalar_functions.md#time-and-date-functions) -- [x] [Math Functions](./scalar_functions.md#math-functions) -- [x] [Aggregate Functions](./aggregate_functions.md) (`SUM`, `MEDIAN`, and many more) +- [x] [String Functions](./sql/scalar_functions.md#string-functions) +- [x] [Conditional Functions](./sql/scalar_functions.md#conditional-functions) +- [x] [Time and Date Functions](./sql/scalar_functions.md#time-and-date-functions) +- [x] [Math Functions](./sql/scalar_functions.md#math-functions) +- [x] [Aggregate Functions](./sql/aggregate_functions.md) (`SUM`, `MEDIAN`, and many more) - [x] Schema Queries - [x] `SHOW TABLES` - [x] `SHOW COLUMNS FROM ` - [x] `SHOW CREATE TABLE ` - - [x] Basic SQL [Information Schema](./information_schema.md) (`TABLES`, `VIEWS`, `COLUMNS`) - - [ ] Full SQL [Information Schema](./information_schema.md) support -- [ ] Support for nested types (`ARRAY`/`LIST` and `STRUCT`. 
See [#2326](https://github.com/apache/datafusion/issues/2326) for details) + - [x] Basic SQL [Information Schema](./sql/information_schema.md) (`TABLES`, `VIEWS`, `COLUMNS`) + - [ ] Full SQL [Information Schema](./sql/information_schema.md) support +- [x] Support for nested types (`ARRAY`/`LIST` and `STRUCT`. - [x] Read support - [x] Write support - [x] Field access (`col['field']` and [`col[1]`]) - - [x] [Array Functions](./scalar_functions.md#array-functions) - - [ ] [Struct Functions](./scalar_functions.md#struct-functions) + - [x] [Array Functions](./sql/scalar_functions.md#array-functions) + - [x] [Struct Functions](./sql/scalar_functions.md#struct-functions) - [x] `struct` - [ ] [Postgres JSON operators](https://github.com/apache/datafusion/issues/6631) (`->`, `->>`, etc.) - [x] Subqueries @@ -73,12 +78,12 @@ - [x] Catalogs - [x] Schemas (`CREATE / DROP SCHEMA`) - [x] Tables (`CREATE / DROP TABLE`, `CREATE TABLE AS SELECT`) -- [ ] Data Insert +- [x] Data Insert - [x] `INSERT INTO` - - [ ] `COPY .. INTO ..` + - [x] `COPY .. INTO ..` - [x] CSV - - [ ] JSON - - [ ] Parquet + - [x] JSON + - [x] Parquet - [ ] Avro ## Runtime @@ -87,16 +92,22 @@ - [x] Streaming Window Evaluation - [x] Memory limits enforced - [x] Spilling (to disk) Sort -- [ ] Spilling (to disk) Grouping +- [x] Spilling (to disk) Grouping - [ ] Spilling (to disk) Joins ## Data Sources -In addition to allowing arbitrary datasources via the `TableProvider` +In addition to allowing arbitrary datasources via the [`TableProvider`] trait, DataFusion includes built in support for the following formats: - [x] CSV -- [x] Parquet (for all primitive and nested types) +- [x] Parquet + - [x] Primitive and Nested Types + - [x] Row Group and Data Page pruning on min/max statistics + - [x] Row Group pruning on Bloom Filters + - [x] Predicate push down (late materialization) [not by default](https://github.com/apache/datafusion/issues/3463) - [x] JSON - [x] Avro - [x] Arrow + +[`tableprovider`]: https://docs.rs/datafusion/latest/datafusion/catalog/trait.TableProvider.html diff --git a/docs/source/user-guide/introduction.md b/docs/source/user-guide/introduction.md index c97042fdc525..bed9233b9c23 100644 --- a/docs/source/user-guide/introduction.md +++ b/docs/source/user-guide/introduction.md @@ -119,6 +119,7 @@ Here are some active projects using DataFusion: - [ROAPI](https://github.com/roapi/roapi) - [Sail](https://github.com/lakehq/sail) Unifying stream, batch, and AI workloads with Apache Spark compatibility - [Seafowl](https://github.com/splitgraph/seafowl) CDN-friendly analytical database +- [Sleeper](https://github.com/gchq/sleeper) Serverless, cloud-native, log-structured merge tree based, scalable key-value store - [Spice.ai](https://github.com/spiceai/spiceai) Unified SQL query interface & materialization engine - [Synnada](https://synnada.ai/) Streaming-first framework for data products - [VegaFusion](https://vegafusion.io/) Server-side acceleration for the [Vega](https://vega.github.io/) visualization grammar diff --git a/docs/source/user-guide/sql/aggregate_functions.md b/docs/source/user-guide/sql/aggregate_functions.md index 221bb0572eb8..7d88d3168d23 100644 --- a/docs/source/user-guide/sql/aggregate_functions.md +++ b/docs/source/user-guide/sql/aggregate_functions.md @@ -58,7 +58,7 @@ Aggregate functions operate on a set of values to compute a single result. Returns an array created from the expression elements. If ordering is required, elements are inserted in the specified order. 
-``` +```sql array_agg(expression [ORDER BY expression]) ``` @@ -81,7 +81,7 @@ array_agg(expression [ORDER BY expression]) Returns the average of numeric values in the specified column. -``` +```sql avg(expression) ``` @@ -108,7 +108,7 @@ avg(expression) Computes the bitwise AND of all non-null input values. -``` +```sql bit_and(expression) ``` @@ -120,7 +120,7 @@ bit_and(expression) Computes the bitwise OR of all non-null input values. -``` +```sql bit_or(expression) ``` @@ -132,7 +132,7 @@ bit_or(expression) Computes the bitwise exclusive OR of all non-null input values. -``` +```sql bit_xor(expression) ``` @@ -144,7 +144,7 @@ bit_xor(expression) Returns true if all non-null input values are true, otherwise false. -``` +```sql bool_and(expression) ``` @@ -167,7 +167,7 @@ bool_and(expression) Returns true if all non-null input values are true, otherwise false. -``` +```sql bool_and(expression) ``` @@ -190,7 +190,7 @@ bool_and(expression) Returns the number of non-null values in the specified column. To include null values in the total count, use `count(*)`. -``` +```sql count(expression) ``` @@ -220,7 +220,7 @@ count(expression) Returns the first element in an aggregation group according to the requested ordering. If no ordering is given, returns an arbitrary element from the group. -``` +```sql first_value(expression [ORDER BY expression]) ``` @@ -243,7 +243,7 @@ first_value(expression [ORDER BY expression]) Returns 1 if the data is aggregated across the specified column, or 0 if it is not aggregated in the result set. -``` +```sql grouping(expression) ``` @@ -270,7 +270,7 @@ grouping(expression) Returns the last element in an aggregation group according to the requested ordering. If no ordering is given, returns an arbitrary element from the group. -``` +```sql last_value(expression [ORDER BY expression]) ``` @@ -293,7 +293,7 @@ last_value(expression [ORDER BY expression]) Returns the maximum value in the specified column. -``` +```sql max(expression) ``` @@ -320,7 +320,7 @@ _Alias of [avg](#avg)._ Returns the median value in the specified column. -``` +```sql median(expression) ``` @@ -343,7 +343,7 @@ median(expression) Returns the minimum value in the specified column. -``` +```sql min(expression) ``` @@ -366,7 +366,7 @@ min(expression) Concatenates the values of string expressions and places separator values between them. -``` +```sql string_agg(expression, delimiter) ``` @@ -391,7 +391,7 @@ string_agg(expression, delimiter) Returns the sum of all values in the specified column. -``` +```sql sum(expression) ``` @@ -414,7 +414,7 @@ sum(expression) Returns the statistical sample variance of a set of numbers. -``` +```sql var(expression) ``` @@ -431,7 +431,7 @@ var(expression) Returns the statistical population variance of a set of numbers. -``` +```sql var_pop(expression) ``` @@ -479,7 +479,7 @@ _Alias of [var](#var)._ Returns the coefficient of correlation between two numeric values. -``` +```sql corr(expression1, expression2) ``` @@ -507,7 +507,7 @@ _Alias of [covar_samp](#covar_samp)._ Returns the sample covariance of a set of number pairs. -``` +```sql covar_samp(expression1, expression2) ``` @@ -531,7 +531,7 @@ covar_samp(expression1, expression2) Returns the sample covariance of a set of number pairs. -``` +```sql covar_samp(expression1, expression2) ``` @@ -559,7 +559,7 @@ covar_samp(expression1, expression2) Returns the nth value in a group of values. 
-``` +```sql nth_value(expression, n ORDER BY expression) ``` @@ -588,7 +588,7 @@ nth_value(expression, n ORDER BY expression) Computes the average of the independent variable (input) expression_x for the non-null paired data points. -``` +```sql regr_avgx(expression_y, expression_x) ``` @@ -601,7 +601,7 @@ regr_avgx(expression_y, expression_x) Computes the average of the dependent variable (output) expression_y for the non-null paired data points. -``` +```sql regr_avgy(expression_y, expression_x) ``` @@ -614,7 +614,7 @@ regr_avgy(expression_y, expression_x) Counts the number of non-null paired data points. -``` +```sql regr_count(expression_y, expression_x) ``` @@ -627,7 +627,7 @@ regr_count(expression_y, expression_x) Computes the y-intercept of the linear regression line. For the equation (y = kx + b), this function returns b. -``` +```sql regr_intercept(expression_y, expression_x) ``` @@ -640,7 +640,7 @@ regr_intercept(expression_y, expression_x) Computes the square of the correlation coefficient between the independent and dependent variables. -``` +```sql regr_r2(expression_y, expression_x) ``` @@ -653,7 +653,7 @@ regr_r2(expression_y, expression_x) Returns the slope of the linear regression line for non-null pairs in aggregate columns. Given input column Y and X: regr_slope(Y, X) returns the slope (k in Y = k\*X + b) using minimal RSS fitting. -``` +```sql regr_slope(expression_y, expression_x) ``` @@ -666,7 +666,7 @@ regr_slope(expression_y, expression_x) Computes the sum of squares of the independent variable. -``` +```sql regr_sxx(expression_y, expression_x) ``` @@ -679,7 +679,7 @@ regr_sxx(expression_y, expression_x) Computes the sum of products of paired data points. -``` +```sql regr_sxy(expression_y, expression_x) ``` @@ -692,7 +692,7 @@ regr_sxy(expression_y, expression_x) Computes the sum of squares of the dependent variable. -``` +```sql regr_syy(expression_y, expression_x) ``` @@ -705,7 +705,7 @@ regr_syy(expression_y, expression_x) Returns the standard deviation of a set of numbers. -``` +```sql stddev(expression) ``` @@ -732,7 +732,7 @@ stddev(expression) Returns the population standard deviation of a set of numbers. -``` +```sql stddev_pop(expression) ``` @@ -766,7 +766,7 @@ _Alias of [stddev](#stddev)._ Returns the approximate number of distinct input values calculated using the HyperLogLog algorithm. -``` +```sql approx_distinct(expression) ``` @@ -789,7 +789,7 @@ approx_distinct(expression) Returns the approximate median (50th percentile) of input values. It is an alias of `approx_percentile_cont(x, 0.5)`. -``` +```sql approx_median(expression) ``` @@ -812,7 +812,7 @@ approx_median(expression) Returns the approximate percentile of input values using the t-digest algorithm. -``` +```sql approx_percentile_cont(expression, percentile, centroids) ``` @@ -837,7 +837,7 @@ approx_percentile_cont(expression, percentile, centroids) Returns the weighted approximate percentile of input values using the t-digest algorithm. -``` +```sql approx_percentile_cont_with_weight(expression, weight, percentile) ``` diff --git a/docs/source/user-guide/sql/ddl.md b/docs/source/user-guide/sql/ddl.md index e16b9681eb80..71475cff9a39 100644 --- a/docs/source/user-guide/sql/ddl.md +++ b/docs/source/user-guide/sql/ddl.md @@ -55,7 +55,7 @@ file system or remote object store as a named table which can be queried. 
 
 The supported syntax is:

-```
+```sql
 CREATE [UNBOUNDED] EXTERNAL TABLE
 [ IF NOT EXISTS ]
 [ () ]
@@ -185,7 +185,7 @@ OPTIONS ('has_header' 'true');

 Where `WITH ORDER` clause specifies the sort order:

-```
+```sql
 WITH ORDER (sort_expression1 [ASC | DESC] [NULLS { FIRST | LAST }]
 [, sort_expression2 [ASC | DESC] [NULLS { FIRST | LAST }] ...])
 ```
@@ -198,7 +198,7 @@ WITH ORDER (sort_expression1 [ASC | DESC] [NULLS { FIRST | LAST }]

 If data sources are already partitioned in Hive style, `PARTITIONED BY` can be used for partition pruning.

-```
+```text
 /mnt/nyctaxi/year=2022/month=01/tripdata.parquet
 /mnt/nyctaxi/year=2021/month=12/tripdata.parquet
 /mnt/nyctaxi/year=2021/month=11/tripdata.parquet
diff --git a/docs/source/user-guide/sql/dml.md b/docs/source/user-guide/sql/dml.md
index dd016cabbfb7..4eda59d6dea1 100644
--- a/docs/source/user-guide/sql/dml.md
+++ b/docs/source/user-guide/sql/dml.md
@@ -28,7 +28,7 @@

 Copies the contents of a table or query to file(s). Supported file formats are `parquet`, `csv`, `json`, and `arrow`.

-COPY { table_name | query } 
+COPY { table_name | query }
 TO 'file_name'
 [ STORED AS format ]
 [ PARTITIONED BY column_name [, ...] ]
@@ -91,7 +91,7 @@ of hive-style partitioned parquet files:
 If the data contains values of `x` and `y` in column1 and only `a` in
 column2, output files will appear in the following directory structure:
 
-```
+```text
 dir_name/
   column1=x/
     column2=a/
diff --git a/docs/source/user-guide/sql/explain.md b/docs/source/user-guide/sql/explain.md
index 709e6311c28e..3f2c7de43eac 100644
--- a/docs/source/user-guide/sql/explain.md
+++ b/docs/source/user-guide/sql/explain.md
@@ -32,8 +32,9 @@ EXPLAIN [ANALYZE] [VERBOSE] statement
 Shows the execution plan of a statement.
 If you need more detailed output, use `EXPLAIN VERBOSE`.
 
-```
+```sql
 EXPLAIN SELECT SUM(x) FROM table GROUP BY b;
+
 +---------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------+
 | plan_type     | plan                                                                                                                                                           |
 +---------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------+
@@ -56,8 +57,9 @@ EXPLAIN SELECT SUM(x) FROM table GROUP BY b;
 Shows the execution plan and metrics of a statement.
 If you need more information output, use `EXPLAIN ANALYZE VERBOSE`.
 
-```
+```sql
 EXPLAIN ANALYZE SELECT SUM(x) FROM table GROUP BY b;
+
 +-------------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------+
 | plan_type         | plan                                                                                                                                                      |
 +-------------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------+
diff --git a/docs/source/user-guide/sql/index.rst b/docs/source/user-guide/sql/index.rst
index 0508fa12f0f3..8e3f51bf8b0b 100644
--- a/docs/source/user-guide/sql/index.rst
+++ b/docs/source/user-guide/sql/index.rst
@@ -33,5 +33,5 @@ SQL Reference
    window_functions
    scalar_functions
    special_functions
-   sql_status
    write_options
+   prepared_statements
diff --git a/docs/source/user-guide/sql/prepared_statements.md b/docs/source/user-guide/sql/prepared_statements.md
new file mode 100644
index 000000000000..6677b212fdf2
--- /dev/null
+++ b/docs/source/user-guide/sql/prepared_statements.md
@@ -0,0 +1,139 @@
+
+
+# Prepared Statements
+
+The `PREPARE` statement allows for the creation and storage of a SQL statement with placeholder arguments.
+
+Prepared statements can then be executed repeatedly and efficiently.
+
+**SQL Example**
+
+Create a prepared statement `greater_than` that selects all records where column "a" is greater than the parameter:
+
+```sql
+PREPARE greater_than(INT) AS SELECT * FROM example WHERE a > $1;
+```
+
+The prepared statement can then be executed with parameters as needed:
+
+```sql
+EXECUTE greater_than(20);
+```
+
+**Rust Example**
+
+```rust
+use datafusion::prelude::*;
+
+#[tokio::main]
+async fn main() -> datafusion::error::Result<()> {
+  // Register the table
+  let ctx = SessionContext::new();
+  ctx.register_csv("example", "tests/data/example.csv", CsvReadOptions::new()).await?;
+
+  // Create the prepared statement `greater_than`
+  let prepare_sql = "PREPARE greater_than(INT) AS SELECT * FROM example WHERE a > $1";
+  ctx.sql(prepare_sql).await?;
+
+  // Execute the prepared statement `greater_than`
+  let execute_sql = "EXECUTE greater_than(20)";
+  let df = ctx.sql(execute_sql).await?;
+
+  // Execute and print results
+  df.show().await?;
+  Ok(())
+}
+```
+
+## Inferred Types
+
+If the parameter type is not specified, it can be inferred at execution time:
+
+**SQL Example**
+
+Create the prepared statement `greater_than`
+
+```sql
+PREPARE greater_than AS SELECT * FROM example WHERE a > $1;
+```
+
+Execute the prepared statement `greater_than`
+
+```sql
+EXECUTE greater_than(20);
+```
+
+**Rust Example**
+
+```rust
+# use datafusion::prelude::*;
+# #[tokio::main]
+# async fn main() -> datafusion::error::Result<()> {
+#    let ctx = SessionContext::new();
+#    ctx.register_csv("example", "tests/data/example.csv", CsvReadOptions::new()).await?;
+#
+    // Create the prepared statement `greater_than`
+    let prepare_sql = "PREPARE greater_than AS SELECT * FROM example WHERE a > $1";
+    ctx.sql(prepare_sql).await?;
+
+    // Execute the prepared statement `greater_than`
+    let execute_sql = "EXECUTE greater_than(20)";
+    let df = ctx.sql(execute_sql).await?;
+#
+#    Ok(())
+# }
+```
+
+## Positional Arguments
+
+In the case of multiple parameters, prepared statements can use positional arguments:
+
+**SQL Example**
+
+Create the prepared statement `greater_than`
+
+```sql
+PREPARE greater_than(INT, DOUBLE) AS SELECT * FROM example WHERE a > $1 AND b > $2;
+```
+
+Execute the prepared statement `greater_than`
+
+```sql
+EXECUTE greater_than(20, 23.3);
+```
+
+**Rust Example**
+
+```rust
+# use datafusion::prelude::*;
+# #[tokio::main]
+# async fn main() -> datafusion::error::Result<()> {
+#    let ctx = SessionContext::new();
+#    ctx.register_csv("example", "tests/data/example.csv", CsvReadOptions::new()).await?;
+  // Create the prepared statement `greater_than`
+  let prepare_sql = "PREPARE greater_than(INT, DOUBLE) AS SELECT * FROM example WHERE a > $1 AND b > $2";
+  ctx.sql(prepare_sql).await?;
+
+  // Execute the prepared statement `greater_than`
+  let execute_sql = "EXECUTE greater_than(20, 23.3)";
+  let df = ctx.sql(execute_sql).await?;
+#    Ok(())
+# }
+```
diff --git a/docs/source/user-guide/sql/scalar_functions.md b/docs/source/user-guide/sql/scalar_functions.md
index b769b8b7bdb0..fb4043c33efc 100644
--- a/docs/source/user-guide/sql/scalar_functions.md
+++ b/docs/source/user-guide/sql/scalar_functions.md
@@ -73,7 +73,7 @@ dev/update_function_docs.sh file for updating surrounding text.
 
 Returns the absolute value of a number.
 
-```
+```sql
 abs(numeric_expression)
 ```
 
@@ -85,7 +85,7 @@ abs(numeric_expression)
 
 Returns the arc cosine or inverse cosine of a number.
 
-```
+```sql
 acos(numeric_expression)
 ```
 
@@ -97,7 +97,7 @@ acos(numeric_expression)
 
 Returns the area hyperbolic cosine or inverse hyperbolic cosine of a number.
 
-```
+```sql
 acosh(numeric_expression)
 ```
 
@@ -109,7 +109,7 @@ acosh(numeric_expression)
 
 Returns the arc sine or inverse sine of a number.
 
-```
+```sql
 asin(numeric_expression)
 ```
 
@@ -121,7 +121,7 @@ asin(numeric_expression)
 
 Returns the area hyperbolic sine or inverse hyperbolic sine of a number.
 
-```
+```sql
 asinh(numeric_expression)
 ```
 
@@ -133,7 +133,7 @@ asinh(numeric_expression)
 
 Returns the arc tangent or inverse tangent of a number.
 
-```
+```sql
 atan(numeric_expression)
 ```
 
@@ -145,7 +145,7 @@ atan(numeric_expression)
 
 Returns the arc tangent or inverse tangent of `expression_y / expression_x`.
 
-```
+```sql
 atan2(expression_y, expression_x)
 ```
 
@@ -160,7 +160,7 @@ atan2(expression_y, expression_x)
 
 Returns the area hyperbolic tangent or inverse hyperbolic tangent of a number.
 
-```
+```sql
 atanh(numeric_expression)
 ```
 
@@ -172,7 +172,7 @@ atanh(numeric_expression)
 
 Returns the cube root of a number.
 
-```
+```sql
 cbrt(numeric_expression)
 ```
 
@@ -184,7 +184,7 @@ cbrt(numeric_expression)
 
 Returns the nearest integer greater than or equal to a number.
 
-```
+```sql
 ceil(numeric_expression)
 ```
 
@@ -196,7 +196,7 @@ ceil(numeric_expression)
 
 Returns the cosine of a number.
 
-```
+```sql
 cos(numeric_expression)
 ```
 
@@ -208,7 +208,7 @@ cos(numeric_expression)
 
 Returns the hyperbolic cosine of a number.
 
-```
+```sql
 cosh(numeric_expression)
 ```
 
@@ -220,7 +220,7 @@ cosh(numeric_expression)
 
 Returns the cotangent of a number.
 
-```
+```sql
 cot(numeric_expression)
 ```
 
@@ -232,7 +232,7 @@ cot(numeric_expression)
 
 Converts radians to degrees.
 
-```
+```sql
 degrees(numeric_expression)
 ```
 
@@ -244,7 +244,7 @@ degrees(numeric_expression)
 
 Returns the base-e exponential of a number.
 
-```
+```sql
 exp(numeric_expression)
 ```
 
@@ -256,7 +256,7 @@ exp(numeric_expression)
 
 Factorial. Returns 1 if value is less than 2.
 
-```
+```sql
 factorial(numeric_expression)
 ```
 
@@ -268,7 +268,7 @@ factorial(numeric_expression)
 
 Returns the nearest integer less than or equal to a number.
 
-```
+```sql
 floor(numeric_expression)
 ```
 
@@ -280,7 +280,7 @@ floor(numeric_expression)
 
 Returns the greatest common divisor of `expression_x` and `expression_y`. Returns 0 if both inputs are zero.
 
-```
+```sql
 gcd(expression_x, expression_y)
 ```
 
@@ -293,7 +293,7 @@ gcd(expression_x, expression_y)
 
 Returns true if a given number is +NaN or -NaN otherwise returns false.
 
-```
+```sql
 isnan(numeric_expression)
 ```
 
@@ -305,7 +305,7 @@ isnan(numeric_expression)
 
 Returns true if a given number is +0.0 or -0.0 otherwise returns false.
 
-```
+```sql
 iszero(numeric_expression)
 ```
 
@@ -317,7 +317,7 @@ iszero(numeric_expression)
 
 Returns the least common multiple of `expression_x` and `expression_y`. Returns 0 if either input is zero.
 
-```
+```sql
 lcm(expression_x, expression_y)
 ```
 
@@ -330,7 +330,7 @@ lcm(expression_x, expression_y)
 
 Returns the natural logarithm of a number.
 
-```
+```sql
 ln(numeric_expression)
 ```
 
@@ -342,7 +342,7 @@ ln(numeric_expression)
 
 Returns the base-x logarithm of a number. A specific base can be provided; if it is omitted, the base-10 logarithm is returned.
 
-```
+```sql
 log(base, numeric_expression)
 log(numeric_expression)
 ```
@@ -356,7 +356,7 @@ log(numeric_expression)
 
 Returns the base-10 logarithm of a number.
 
-```
+```sql
 log10(numeric_expression)
 ```
 
@@ -368,7 +368,7 @@ log10(numeric_expression)
 
 Returns the base-2 logarithm of a number.
 
-```
+```sql
 log2(numeric_expression)
 ```
 
@@ -381,7 +381,7 @@ log2(numeric_expression)
 Returns the first argument if it's not _NaN_.
 Returns the second argument otherwise.
 
-```
+```sql
 nanvl(expression_x, expression_y)
 ```
 
@@ -394,7 +394,7 @@ nanvl(expression_x, expression_y)
 
 Returns an approximate value of π.
 
-```
+```sql
 pi()
 ```
 
@@ -406,7 +406,7 @@ _Alias of [power](#power)._
 
 Returns a base expression raised to the power of an exponent.
 
-```
+```sql
 power(base, exponent)
 ```
 
@@ -423,7 +423,7 @@ power(base, exponent)
 
 Converts degrees to radians.
 
-```
+```sql
 radians(numeric_expression)
 ```
 
@@ -436,7 +436,7 @@ radians(numeric_expression)
 Returns a random float value in the range [0, 1).
 The random seed is unique to each row.
 
-```
+```sql
 random()
 ```
 
@@ -444,7 +444,7 @@ random()
 
 Rounds a number to the nearest integer.
 
-```
+```sql
 round(numeric_expression[, decimal_places])
 ```
 
@@ -459,7 +459,7 @@ Returns the sign of a number.
 Negative numbers return `-1`.
 Zero and positive numbers return `1`.
 
-```
+```sql
 signum(numeric_expression)
 ```
 
@@ -471,7 +471,7 @@ signum(numeric_expression)
 
 Returns the sine of a number.
 
-```
+```sql
 sin(numeric_expression)
 ```
 
@@ -483,7 +483,7 @@ sin(numeric_expression)
 
 Returns the hyperbolic sine of a number.
 
-```
+```sql
 sinh(numeric_expression)
 ```
 
@@ -495,7 +495,7 @@ sinh(numeric_expression)
 
 Returns the square root of a number.
 
-```
+```sql
 sqrt(numeric_expression)
 ```
 
@@ -507,7 +507,7 @@ sqrt(numeric_expression)
 
 Returns the tangent of a number.
 
-```
+```sql
 tan(numeric_expression)
 ```
 
@@ -519,7 +519,7 @@ tan(numeric_expression)
 
 Returns the hyperbolic tangent of a number.
 
-```
+```sql
 tanh(numeric_expression)
 ```
 
@@ -531,7 +531,7 @@ tanh(numeric_expression)
 
 Truncates a number to a whole number or to the specified number of decimal places.
 
-```
+```sql
 trunc(numeric_expression[, decimal_places])
 ```
 
@@ -558,7 +558,7 @@ trunc(numeric_expression[, decimal_places])
 
 Returns the first of its arguments that is not _null_. Returns _null_ if all arguments are _null_. This function is often used to substitute a default value for _null_ values.
 
-```
+```sql
 coalesce(expression1[, ..., expression_n])
 ```
 
@@ -581,7 +581,7 @@ coalesce(expression1[, ..., expression_n])
 
 Returns the greatest value in a list of expressions. Returns _null_ if all expressions are _null_.
 
-```
+```sql
 greatest(expression1[, ..., expression_n])
 ```
 
@@ -608,7 +608,7 @@ _Alias of [nvl](#nvl)._
 
 Returns the smallest value in a list of expressions. Returns _null_ if all expressions are _null_.
 
-```
+```sql
 least(expression1[, ..., expression_n])
 ```
 
@@ -632,7 +632,7 @@ least(expression1[, ..., expression_n])
 Returns _null_ if _expression1_ equals _expression2_; otherwise it returns _expression1_.
 This can be used to perform the inverse operation of [`coalesce`](#coalesce).
 
-```
+```sql
 nullif(expression1, expression2)
 ```
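
A typical use of the inverse-of-`coalesce` behaviour described above is guarding a division against a zero divisor. The query below is an illustrative sketch with literal values only (it is not part of the generated documentation):

```sql
-- nullif(0, 0) evaluates to NULL, so the division yields NULL instead of a divide-by-zero error
SELECT 100 / nullif(0, 0);

-- when the two arguments differ, the first argument is returned unchanged
SELECT nullif('a', 'b');
```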
 
@@ -662,7 +662,7 @@ nullif(expression1, expression2)
 
 Returns _expression2_ if _expression1_ is NULL otherwise it returns _expression1_.
 
-```
+```sql
 nvl(expression1, expression2)
 ```
 
@@ -696,7 +696,7 @@ nvl(expression1, expression2)
 
 Returns _expression2_ if _expression1_ is not NULL; otherwise it returns _expression3_.
 
-```
+```sql
 nvl2(expression1, expression2, expression3)
 ```
 
@@ -769,7 +769,7 @@ nvl2(expression1, expression2, expression3)
 
 Returns the Unicode character code of the first character in a string.
 
-```
+```sql
 ascii(str)
 ```
 
@@ -802,7 +802,7 @@ ascii(str)
 
 Returns the bit length of a string.
 
-```
+```sql
 bit_length(str)
 ```
 
@@ -830,7 +830,7 @@ bit_length(str)
 
 Trims the specified trim string from the start and end of a string. If no trim string is provided, all whitespace is removed from the start and end of the input string.
 
-```
+```sql
 btrim(str[, trim_str])
 ```
 
@@ -877,7 +877,7 @@ _Alias of [character_length](#character_length)._
 
 Returns the number of characters in a string.
 
-```
+```sql
 character_length(str)
 ```
 
@@ -910,7 +910,7 @@ character_length(str)
 
 Returns the character with the specified ASCII or Unicode code value.
 
-```
+```sql
 chr(expression)
 ```
 
@@ -937,7 +937,7 @@ chr(expression)
 
 Concatenates multiple strings together.
 
-```
+```sql
 concat(str[, ..., str_n])
 ```
 
@@ -965,7 +965,7 @@ concat(str[, ..., str_n])
 
 Concatenates multiple strings together with a specified separator.
 
-```
+```sql
 concat_ws(separator, str[, ..., str_n])
 ```
 
@@ -994,7 +994,7 @@ concat_ws(separator, str[, ..., str_n])
 
 Return true if search_str is found within string (case-sensitive).
 
-```
+```sql
 contains(str, search_str)
 ```
 
@@ -1018,7 +1018,7 @@ contains(str, search_str)
 
 Tests if a string ends with a substring.
 
-```
+```sql
 ends_with(str, substr)
 ```
 
@@ -1048,7 +1048,7 @@ ends_with(str, substr)
 
 Returns a value in the range of 1 to N if the string str is in the string list strlist consisting of N substrings.
 
-```
+```sql
 find_in_set(str, strlist)
 ```
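
For illustration, a hypothetical query against literal values (assuming the MySQL-style behaviour of returning 0 when the string is not found in the list):

```sql
SELECT find_in_set('b', 'a,b,c,d');  -- 2: 'b' is the second substring in the list
SELECT find_in_set('z', 'a,b,c,d');  -- 0: 'z' does not appear in the list
```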
 
@@ -1072,7 +1072,7 @@ find_in_set(str, strlist)
 
 Capitalizes the first character in each word in the input string. Words are delimited by non-alphanumeric characters.
 
-```
+```sql
 initcap(str)
 ```
 
@@ -1104,7 +1104,7 @@ _Alias of [strpos](#strpos)._
 
 Returns a specified number of characters from the left side of a string.
 
-```
+```sql
 left(str, n)
 ```
 
@@ -1136,7 +1136,7 @@ _Alias of [character_length](#character_length)._
 
 Returns the [`Levenshtein distance`](https://en.wikipedia.org/wiki/Levenshtein_distance) between the two given strings.
 
-```
+```sql
 levenshtein(str1, str2)
 ```
 
@@ -1160,7 +1160,7 @@ levenshtein(str1, str2)
 
 Converts a string to lower-case.
 
-```
+```sql
 lower(str)
 ```
 
@@ -1188,7 +1188,7 @@ lower(str)
 
 Pads the left side of a string with another string to a specified string length.
 
-```
+```sql
 lpad(str, n[, padding_str])
 ```
 
@@ -1217,7 +1217,7 @@ lpad(str, n[, padding_str])
 
 Trims the specified trim string from the beginning of a string. If no trim string is provided, all whitespace is removed from the start of the input string.
 
-```
+```sql
 ltrim(str[, trim_str])
 ```
 
@@ -1258,7 +1258,7 @@ trim(LEADING trim_str FROM str)
 
 Returns the length of a string in bytes.
 
-```
+```sql
 octet_length(str)
 ```
 
@@ -1290,7 +1290,7 @@ _Alias of [strpos](#strpos)._
 
 Returns a string with the input string repeated a specified number of times.
 
-```
+```sql
 repeat(str, n)
 ```
 
@@ -1314,7 +1314,7 @@ repeat(str, n)
 
 Replaces all occurrences of a specified substring in a string with a new substring.
 
-```
+```sql
 replace(str, substr, replacement)
 ```
 
@@ -1339,7 +1339,7 @@ replace(str, substr, replacement)
 
 Reverses the character order of a string.
 
-```
+```sql
 reverse(str)
 ```
 
@@ -1362,7 +1362,7 @@ reverse(str)
 
 Returns a specified number of characters from the right side of a string.
 
-```
+```sql
 right(str, n)
 ```
 
@@ -1390,7 +1390,7 @@ right(str, n)
 
 Pads the right side of a string with another string to a specified string length.
 
-```
+```sql
 rpad(str, n[, padding_str])
 ```
 
@@ -1419,7 +1419,7 @@ rpad(str, n[, padding_str])
 
 Trims the specified trim string from the end of a string. If no trim string is provided, all whitespace is removed from the end of the input string.
 
-```
+```sql
 rtrim(str[, trim_str])
 ```
 
@@ -1460,7 +1460,7 @@ trim(TRAILING trim_str FROM str)
 
 Splits a string based on a specified delimiter and returns the substring in the specified position.
 
-```
+```sql
 split_part(str, delimiter, pos)
 ```
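
An illustrative example with literal arguments (positions are 1-based):

```sql
-- split '2024-01-15' on '-' and take the second piece
SELECT split_part('2024-01-15', '-', 2);  -- '01'
```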
 
@@ -1485,7 +1485,7 @@ split_part(str, delimiter, pos)
 
 Tests if a string starts with a substring.
 
-```
+```sql
 starts_with(str, substr)
 ```
 
@@ -1509,7 +1509,7 @@ starts_with(str, substr)
 
 Returns the starting position of a specified substring in a string. Positions begin at 1. If the substring does not exist in the string, the function returns 0.
 
-```
+```sql
 strpos(str, substr)
 ```
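
An illustrative example with literal arguments:

```sql
SELECT strpos('datafusion', 'fusion');  -- 5: positions start at 1
SELECT strpos('datafusion', 'xyz');     -- 0: substring not found
```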
 
@@ -1544,7 +1544,7 @@ position(substr in origstr)
 
 Extracts a substring of a specified number of characters from a specific starting position in a string.
 
-```
+```sql
 substr(str, start_pos[, length])
 ```
 
@@ -1581,7 +1581,7 @@ Returns the substring from str before count occurrences of the delimiter delim.
 If count is positive, everything to the left of the final delimiter (counting from the left) is returned.
 If count is negative, everything to the right of the final delimiter (counting from the right) is returned.
 
-```
+```sql
 substr_index(str, delim, count)
 ```
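
The positive/negative count rule above is easiest to see with literal values (the values are illustrative only):

```sql
SELECT substr_index('www.apache.org', '.', 1);   -- 'www': everything before the first '.'
SELECT substr_index('www.apache.org', '.', -1);  -- 'org': everything after the last '.'
```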
 
@@ -1624,7 +1624,7 @@ _Alias of [substr_index](#substr_index)._
 
 Converts an integer to a hexadecimal string.
 
-```
+```sql
 to_hex(int)
 ```
 
@@ -1647,7 +1647,7 @@ to_hex(int)
 
 Translates characters in a string to specified translation characters.
 
-```
+```sql
 translate(str, chars, translation)
 ```
 
@@ -1676,7 +1676,7 @@ _Alias of [btrim](#btrim)._
 
 Converts a string to upper-case.
 
-```
+```sql
 upper(str)
 ```
 
@@ -1704,7 +1704,7 @@ upper(str)
 
 Returns a `UUID v4` string value which is unique per row.
 
-```
+```sql
 uuid()
 ```
 
@@ -1728,7 +1728,7 @@ uuid()
 
 Decode binary data from textual representation in string.
 
-```
+```sql
 decode(expression, format)
 ```
 
@@ -1745,7 +1745,7 @@ decode(expression, format)
 
 Encode binary data into a textual representation.
 
-```
+```sql
 encode(expression, format)
 ```
 
@@ -1774,7 +1774,7 @@ The following regular expression functions are supported:
 
 Returns the number of matches that a [regular expression](https://docs.rs/regex/latest/regex/#syntax) has in a string.
 
-```
+```sql
 regexp_count(str, regexp[, start, flags])
 ```
 
@@ -1805,7 +1805,7 @@ regexp_count(str, regexp[, start, flags])
 
 Returns true if a [regular expression](https://docs.rs/regex/latest/regex/#syntax) has at least one match in a string, false otherwise.
 
-```
+```sql
 regexp_like(str, regexp[, flags])
 ```
 
@@ -1843,7 +1843,7 @@ Additional examples can be found [here](https://github.com/apache/datafusion/blo
 
 Returns the first [regular expression](https://docs.rs/regex/latest/regex/#syntax) matches in a string.
 
-```
+```sql
 regexp_match(str, regexp[, flags])
 ```
 
@@ -1882,7 +1882,7 @@ Additional examples can be found [here](https://github.com/apache/datafusion/blo
 
 Replaces substrings in a string that match a [regular expression](https://docs.rs/regex/latest/regex/#syntax).
 
-```
+```sql
 regexp_replace(str, regexp, replacement[, flags])
 ```
 
@@ -1950,7 +1950,7 @@ Returns the current UTC date.
 
 The `current_date()` return value is determined at query time and will return the same date, no matter when in the query plan the function executes.
 
-```
+```sql
 current_date()
 ```
 
@@ -1964,7 +1964,7 @@ Returns the current UTC time.
 
 The `current_time()` return value is determined at query time and will return the same time, no matter when in the query plan the function executes.
 
-```
+```sql
 current_time()
 ```
 
@@ -1978,7 +1978,7 @@ Calculates time intervals and returns the start of the interval nearest to the s
 
 For example, if you "bin" or "window" data into 15 minute intervals, an input timestamp of `2023-01-01T18:18:18Z` will be updated to the start time of the 15 minute bin it is in: `2023-01-01T18:15:00Z`.
 
-```
+```sql
 date_bin(interval, expression, origin-timestamp)
 ```
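
To make the binning behaviour concrete, an illustrative query using the timestamp from the paragraph above (the origin is assumed to be the unix epoch):

```sql
-- bin the input into 15 minute buckets anchored at 1970-01-01T00:00:00Z
SELECT date_bin(INTERVAL '15 minutes', TIMESTAMP '2023-01-01T18:18:18Z', TIMESTAMP '1970-01-01T00:00:00Z');
-- returns the start of the containing bin: 2023-01-01T18:15:00
```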
 
@@ -2034,7 +2034,7 @@ _Alias of [to_char](#to_char)._
 
 Returns the specified part of the date as an integer.
 
-```
+```sql
 date_part(part, expression)
 ```
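
Illustrative examples with a literal timestamp:

```sql
SELECT date_part('year', TIMESTAMP '2023-01-01T18:18:18Z');  -- 2023
SELECT date_part('hour', TIMESTAMP '2023-01-01T18:18:18Z');  -- 18
```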
 
@@ -2073,7 +2073,7 @@ extract(field FROM source)
 
 Truncates a timestamp value to a specified precision.
 
-```
+```sql
 date_trunc(precision, expression)
 ```
 
@@ -2108,7 +2108,7 @@ _Alias of [date_trunc](#date_trunc)._
 
 Converts an integer to RFC3339 timestamp format (`YYYY-MM-DDT00:00:00.000000000Z`). Integers and unsigned integers are interpreted as nanoseconds since the unix epoch (`1970-01-01T00:00:00Z`) and return the corresponding timestamp.
 
-```
+```sql
 from_unixtime(expression[, timezone])
 ```
 
@@ -2132,7 +2132,7 @@ from_unixtime(expression[, timezone])
 
 Make a date from year/month/day component parts.
 
-```
+```sql
 make_date(year, month, day)
 ```
 
@@ -2167,7 +2167,7 @@ Returns the current UTC timestamp.
 
 The `now()` return value is determined at query time and will return the same timestamp, no matter when in the query plan the function executes.
 
-```
+```sql
 now()
 ```
 
@@ -2179,7 +2179,7 @@ now()
 
 Returns a string representation of a date, time, timestamp or duration based on a [Chrono format](https://docs.rs/chrono/latest/chrono/format/strftime/index.html). Unlike the PostgreSQL equivalent of this function, numerical formatting is not supported.
 
-```
+```sql
 to_char(expression, format)
 ```
 
@@ -2216,7 +2216,7 @@ Returns the corresponding date.
 
 Note: `to_date` returns Date32, which represents its values as the number of days since the unix epoch (`1970-01-01`) stored as a signed 32 bit value. The largest supported date value is `9999-12-31`.
 
-```
+```sql
 to_date('2017-05-31', '%Y-%m-%d')
 ```
 
@@ -2250,7 +2250,7 @@ Additional examples can be found [here](https://github.com/apache/datafusion/blo
 
 Converts a timestamp with a timezone to a timestamp without a timezone (with no offset or timezone information). This function handles daylight saving time changes.
 
-```
+```sql
 to_local_time(expression)
 ```
 
@@ -2313,7 +2313,7 @@ Converts a value to a timestamp (`YYYY-MM-DDT00:00:00Z`). Supports strings, inte
 
 Note: `to_timestamp` returns `Timestamp(Nanosecond)`. The supported range for integer input is between `-9223372037` and `9223372036`. Supported range for string input is between `1677-09-21T00:12:44.0` and `2262-04-11T23:47:16.0`. Please use `to_timestamp_seconds` for the input outside of supported bounds.
 
-```
+```sql
 to_timestamp(expression[, ..., format_n])
 ```
 
@@ -2345,7 +2345,7 @@ Additional examples can be found [here](https://github.com/apache/datafusion/blo
 
 Converts a value to a timestamp (`YYYY-MM-DDT00:00:00.000000Z`). Supports strings, integer, and unsigned integer types as input. Strings are parsed as RFC3339 (e.g. '2023-07-20T05:44:00') if no [Chrono format](https://docs.rs/chrono/latest/chrono/format/strftime/index.html)s are provided. Integers and unsigned integers are interpreted as microseconds since the unix epoch (`1970-01-01T00:00:00Z`). Returns the corresponding timestamp.
 
-```
+```sql
 to_timestamp_micros(expression[, ..., format_n])
 ```
 
@@ -2377,7 +2377,7 @@ Additional examples can be found [here](https://github.com/apache/datafusion/blo
 
 Converts a value to a timestamp (`YYYY-MM-DDT00:00:00.000Z`). Supports strings, integer, and unsigned integer types as input. Strings are parsed as RFC3339 (e.g. '2023-07-20T05:44:00') if no [Chrono formats](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) are provided. Integers and unsigned integers are interpreted as milliseconds since the unix epoch (`1970-01-01T00:00:00Z`). Returns the corresponding timestamp.
 
-```
+```sql
 to_timestamp_millis(expression[, ..., format_n])
 ```
 
@@ -2409,7 +2409,7 @@ Additional examples can be found [here](https://github.com/apache/datafusion/blo
 
 Converts a value to a timestamp (`YYYY-MM-DDT00:00:00.000000000Z`). Supports strings, integer, and unsigned integer types as input. Strings are parsed as RFC3339 (e.g. '2023-07-20T05:44:00') if no [Chrono format](https://docs.rs/chrono/latest/chrono/format/strftime/index.html)s are provided. Integers and unsigned integers are interpreted as nanoseconds since the unix epoch (`1970-01-01T00:00:00Z`). Returns the corresponding timestamp.
 
-```
+```sql
 to_timestamp_nanos(expression[, ..., format_n])
 ```
 
@@ -2441,7 +2441,7 @@ Additional examples can be found [here](https://github.com/apache/datafusion/blo
 
 Converts a value to a timestamp (`YYYY-MM-DDT00:00:00.000Z`). Supports strings, integer, and unsigned integer types as input. Strings are parsed as RFC3339 (e.g. '2023-07-20T05:44:00') if no [Chrono format](https://docs.rs/chrono/latest/chrono/format/strftime/index.html)s are provided. Integers and unsigned integers are interpreted as seconds since the unix epoch (`1970-01-01T00:00:00Z`). Returns the corresponding timestamp.
 
-```
+```sql
 to_timestamp_seconds(expression[, ..., format_n])
 ```
 
@@ -2473,7 +2473,7 @@ Additional examples can be found [here](https://github.com/apache/datafusion/blo
 
 Converts a value to seconds since the unix epoch (`1970-01-01T00:00:00Z`). Supports strings, dates, timestamps and double types as input. Strings are parsed as RFC3339 (e.g. '2023-07-20T05:44:00') if no [Chrono formats](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) are provided.
 
-```
+```sql
 to_unixtime(expression[, ..., format_n])
 ```
 
@@ -2600,7 +2600,7 @@ _Alias of [current_date](#current_date)._
 
 Returns the first non-null element in the array.
 
-```
+```sql
 array_any_value(array)
 ```
 
@@ -2627,7 +2627,7 @@ array_any_value(array)
 
 Appends an element to the end of an array.
 
-```
+```sql
 array_append(array, element)
 ```
 
@@ -2661,7 +2661,7 @@ _Alias of [array_concat](#array_concat)._
 
 Concatenates arrays.
 
-```
+```sql
 array_concat(array[, ..., array_n])
 ```
 
@@ -2695,7 +2695,7 @@ _Alias of [array_has](#array_has)._
 
 Returns an array of the array's dimensions.
 
-```
+```sql
 array_dims(array)
 ```
 
@@ -2722,7 +2722,7 @@ array_dims(array)
 
 Returns the Euclidean distance between two input arrays of equal length.
 
-```
+```sql
 array_distance(array1, array2)
 ```
 
@@ -2750,7 +2750,7 @@ array_distance(array1, array2)
 
 Returns distinct values from the array after removing duplicates.
 
-```
+```sql
 array_distinct(array)
 ```
 
@@ -2777,7 +2777,7 @@ array_distinct(array)
 
 Extracts the element with the index n from the array.
 
-```
+```sql
 array_element(array, index)
 ```
 
@@ -2811,7 +2811,7 @@ _Alias of [empty](#empty)._
 
 Returns an array of the elements that appear in the first array but not in the second.
 
-```
+```sql
 array_except(array1, array2)
 ```
 
@@ -2849,7 +2849,7 @@ _Alias of [array_element](#array_element)._
 
 Returns true if the array contains the element.
 
-```
+```sql
 array_has(array, element)
 ```
 
@@ -2879,7 +2879,7 @@ array_has(array, element)
 
 Returns true if all elements of sub-array exist in array.
 
-```
+```sql
 array_has_all(array, sub-array)
 ```
 
@@ -2907,7 +2907,7 @@ array_has_all(array, sub-array)
 
 Returns true if any elements exist in both arrays.
 
-```
+```sql
 array_has_any(array, sub-array)
 ```
 
@@ -2940,7 +2940,7 @@ _Alias of [array_position](#array_position)._
 
 Returns an array of elements in the intersection of array1 and array2.
 
-```
+```sql
 array_intersect(array1, array2)
 ```
 
@@ -2978,7 +2978,7 @@ _Alias of [array_to_string](#array_to_string)._
 
 Returns the length of the array dimension.
 
-```
+```sql
 array_length(array, dimension)
 ```
 
@@ -3006,7 +3006,7 @@ array_length(array, dimension)
 
 Returns the number of dimensions of the array.
 
-```
+```sql
 array_ndims(array, element)
 ```
 
@@ -3034,7 +3034,7 @@ array_ndims(array, element)
 
 Returns the array without the last element.
 
-```
+```sql
 array_pop_back(array)
 ```
 
@@ -3061,7 +3061,7 @@ array_pop_back(array)
 
 Returns the array without the first element.
 
-```
+```sql
 array_pop_front(array)
 ```
 
@@ -3088,7 +3088,7 @@ array_pop_front(array)
 
 Returns the position of the first occurrence of the specified element in the array.
 
-```
+```sql
 array_position(array, element)
 array_position(array, element, index)
 ```
@@ -3126,7 +3126,7 @@ array_position(array, element, index)
 
 Searches for an element in the array, returns all occurrences.
 
-```
+```sql
 array_positions(array, element)
 ```
 
@@ -3154,7 +3154,7 @@ array_positions(array, element)
 
 Prepends an element to the beginning of an array.
 
-```
+```sql
 array_prepend(element, array)
 ```
 
@@ -3192,7 +3192,7 @@ _Alias of [array_prepend](#array_prepend)._
 
 Removes the first element from the array equal to the given value.
 
-```
+```sql
 array_remove(array, element)
 ```
 
@@ -3220,7 +3220,7 @@ array_remove(array, element)
 
 Removes all elements from the array equal to the given value.
 
-```
+```sql
 array_remove_all(array, element)
 ```
 
@@ -3248,7 +3248,7 @@ array_remove_all(array, element)
 
 Removes the first `max` elements from the array equal to the given value.
 
-```
+```sql
 array_remove_n(array, element, max)
 ```
 
@@ -3277,7 +3277,7 @@ array_remove_n(array, element, max))
 
 Returns an array containing the element repeated `count` times.
 
-```
+```sql
 array_repeat(element, count)
 ```
 
@@ -3311,7 +3311,7 @@ array_repeat(element, count)
 
 Replaces the first occurrence of the specified element with another specified element.
 
-```
+```sql
 array_replace(array, from, to)
 ```
 
@@ -3340,7 +3340,7 @@ array_replace(array, from, to)
 
 Replaces all occurrences of the specified element with another specified element.
 
-```
+```sql
 array_replace_all(array, from, to)
 ```
 
@@ -3369,7 +3369,7 @@ array_replace_all(array, from, to)
 
 Replaces the first `max` occurrences of the specified element with another specified element.
 
-```
+```sql
 array_replace_n(array, from, to, max)
 ```
 
@@ -3399,7 +3399,7 @@ array_replace_n(array, from, to, max)
 
 Resizes the list to contain size elements. Initializes new elements with value or empty if value is not set.
 
-```
+```sql
 array_resize(array, size, value)
 ```
 
@@ -3428,7 +3428,7 @@ array_resize(array, size, value)
 
 Returns the array with the order of the elements reversed.
 
-```
+```sql
 array_reverse(array)
 ```
 
@@ -3455,7 +3455,7 @@ array_reverse(array)
 
 Returns a slice of the array based on 1-indexed start and end positions.
 
-```
+```sql
 array_slice(array, begin, end)
 ```
 
@@ -3485,7 +3485,7 @@ array_slice(array, begin, end)
 
 Sort array.
 
-```
+```sql
 array_sort(array, desc, nulls_first)
 ```
 
@@ -3514,7 +3514,7 @@ array_sort(array, desc, nulls_first)
 
 Converts each element to its text representation.
 
-```
+```sql
 array_to_string(array, delimiter[, null_string])
 ```
 
@@ -3545,7 +3545,7 @@ array_to_string(array, delimiter[, null_string])
 
 Returns an array of elements that are present in both arrays (all elements from both arrays) without duplicates.
 
-```
+```sql
 array_union(array1, array2)
 ```
 
@@ -3583,7 +3583,7 @@ _Alias of [array_has_any](#array_has_any)._
 
 Returns the total number of elements in the array.
 
-```
+```sql
 cardinality(array)
 ```
 
@@ -3606,7 +3606,7 @@ cardinality(array)
 
 Returns 1 for an empty array or 0 for a non-empty array.
 
-```
+```sql
 empty(array)
 ```
 
@@ -3639,7 +3639,7 @@ Converts an array of arrays to a flat array.
 
 The flattened array contains all the elements from all source arrays.
 
-```
+```sql
 flatten(array)
 ```
 
@@ -3662,7 +3662,7 @@ flatten(array)
 
 Similar to the range function, but it includes the upper bound.
 
-```
+```sql
 generate_series(start, stop, step)
 ```
 
@@ -3847,7 +3847,7 @@ _Alias of [array_union](#array_union)._
 
 Returns an array using the specified input expressions.
 
-```
+```sql
 make_array(expression1[, ..., expression_n])
 ```
 
@@ -3878,7 +3878,7 @@ _Alias of [make_array](#make_array)._
 
 Returns an Arrow array between start and stop with step. The range start..end contains all values with start <= x < end. It is empty if start >= end. Step cannot be 0.
 
-```
+```sql
 range(start, stop, step)
 ```
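
The half-open interval semantics described above can be illustrated with literal arguments:

```sql
SELECT range(1, 10, 2);  -- [1, 3, 5, 7, 9]: the stop value 10 is excluded
SELECT range(3, 3, 1);   -- []: empty because start >= end
```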
 
@@ -3910,7 +3910,7 @@ range(start, stop, step)
 
 Splits a string into an array of substrings based on a delimiter. Any substrings matching the optional `null_str` argument are replaced with NULL.
 
-```
+```sql
 string_to_array(str, delimiter[, null_str])
 ```
 
@@ -3955,7 +3955,7 @@ _Alias of [string_to_array](#string_to_array)._
 
 Returns an Arrow struct using the specified name and input expressions pairs.
 
-```
+```sql
 named_struct(expression1_name, expression1_input[, ..., expression_n_name, expression_n_input])
 ```
 
@@ -3996,7 +3996,7 @@ Returns an Arrow struct using the specified input expressions optionally named.
 Fields in the returned struct use the optional name or the `cN` naming convention.
 For example: `c0`, `c1`, `c2`, etc.
 
-```
+```sql
 struct(expression1[, ..., expression_n])
 ```
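
An illustrative example of the `cN` fallback naming (the shown output formatting is approximate):

```sql
-- without explicit names the fields become c0, c1, ...
SELECT struct(1, 'apache');
-- {c0: 1, c1: apache}
```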
 
@@ -4059,7 +4059,7 @@ Returns an Arrow map with the specified key-value pairs.
 
 The `make_map` function creates a map from two lists: one for keys and one for values. Each key must be unique and non-null.
 
-```
+```sql
 map(key, value)
 map(key: value)
 make_map(['key1', 'key2'], ['value1', 'value2'])
@@ -4106,7 +4106,7 @@ SELECT MAKE_MAP(['key1', 'key2'], ['value1', null]);
 
 Returns a list containing the value for the given key or an empty list if the key is not present in the map.
 
-```
+```sql
 map_extract(map, key)
 ```
 
@@ -4139,7 +4139,7 @@ SELECT map_extract(MAP {'x': 10, 'y': NULL, 'z': 30}, 'y');
 
 Returns a list of all keys in the map.
 
-```
+```sql
 map_keys(map)
 ```
 
@@ -4163,7 +4163,7 @@ SELECT map_keys(map([100, 5], [42, 43]));
 
 Returns a list of all values in the map.
 
-```
+```sql
 map_values(map)
 ```
 
@@ -4196,7 +4196,7 @@ SELECT map_values(map([100, 5], [42, 43]));
 
 Computes the binary hash of an expression using the specified algorithm.
 
-```
+```sql
 digest(expression, algorithm)
 ```
 
@@ -4228,7 +4228,7 @@ digest(expression, algorithm)
 
 Computes an MD5 128-bit checksum for a string expression.
 
-```
+```sql
 md5(expression)
 ```
 
@@ -4251,7 +4251,7 @@ md5(expression)
 
 Computes the SHA-224 hash of a binary string.
 
-```
+```sql
 sha224(expression)
 ```
 
@@ -4274,7 +4274,7 @@ sha224(expression)
 
 Computes the SHA-256 hash of a binary string.
 
-```
+```sql
 sha256(expression)
 ```
 
@@ -4297,7 +4297,7 @@ sha256(expression)
 
 Computes the SHA-384 hash of a binary string.
 
-```
+```sql
 sha384(expression)
 ```
 
@@ -4320,7 +4320,7 @@ sha384(expression)
 
 Computes the SHA-512 hash of a binary string.
 
-```
+```sql
 sha512(expression)
 ```
 
@@ -4339,6 +4339,40 @@ sha512(expression)
 +-------------------------------------------+
 ```
 
+## Union Functions
+
+Functions to work with the union data type, also known as tagged unions, variant types, enums, or sum types. Note: not related to the SQL UNION operator.
+
+- [union_extract](#union_extract)
+
+### `union_extract`
+
+Returns the value of the given field in the union when selected, or NULL otherwise.
+
+```sql
+union_extract(union, field_name)
+```
+
+#### Arguments
+
+- **union**: Union expression to operate on. Can be a constant, column, or function, and any combination of operators.
+- **field_name**: String expression to operate on. Must be a constant.
+
+#### Example
+
+```sql
+❯ select union_column, union_extract(union_column, 'a'), union_extract(union_column, 'b') from table_with_union;
++--------------+----------------------------------+----------------------------------+
+| union_column | union_extract(union_column, 'a') | union_extract(union_column, 'b') |
++--------------+----------------------------------+----------------------------------+
+| {a=1}        | 1                                |                                  |
+| {b=3.0}      |                                  | 3.0                              |
+| {a=4}        | 4                                |                                  |
+| {b=}         |                                  |                                  |
+| {a=}         |                                  |                                  |
++--------------+----------------------------------+----------------------------------+
+```
+
 ## Other Functions
 
 - [arrow_cast](#arrow_cast)
@@ -4350,7 +4384,7 @@ sha512(expression)
 
 Casts a value to a specific Arrow data type.
 
-```
+```sql
 arrow_cast(expression, datatype)
 ```
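
Illustrative examples (the Arrow type name is passed as a string literal):

```sql
SELECT arrow_cast(1, 'Float64');       -- 1.0 cast to a Float64 value
SELECT arrow_cast('10', 'Int64') + 5;  -- 15
```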
 
@@ -4378,7 +4412,7 @@ arrow_cast(expression, datatype)
 
 Returns the name of the underlying [Arrow data type](https://docs.rs/arrow/latest/arrow/datatypes/enum.DataType.html) of the expression.
 
-```
+```sql
 arrow_typeof(expression)
 ```
 
@@ -4404,7 +4438,7 @@ Note: most users invoke `get_field` indirectly via field access
 syntax such as `my_struct_col['field_name']` which results in a call to
 `get_field(my_struct_col, 'field_name')`.
 
-```
+```sql
 get_field(expression1, expression2)
 ```
 
@@ -4444,7 +4478,7 @@ get_field(expression1, expression2)
 
 Returns the version of DataFusion.
 
-```
+```sql
 version()
 ```
 
diff --git a/docs/source/user-guide/sql/window_functions.md b/docs/source/user-guide/sql/window_functions.md
index a68fdbda6709..1c02804f0dee 100644
--- a/docs/source/user-guide/sql/window_functions.md
+++ b/docs/source/user-guide/sql/window_functions.md
@@ -115,7 +115,7 @@ WINDOW w AS (PARTITION BY depname ORDER BY salary DESC);
 
 The syntax for the OVER-clause is
 
-```
+```sql
 function([expr])
   OVER(
     [PARTITION BY expr[, …]]
@@ -126,7 +126,7 @@ function([expr])
 
 where **frame_clause** is one of:
 
-```
+```sql
   { RANGE | ROWS | GROUPS } frame_start
   { RANGE | ROWS | GROUPS } BETWEEN frame_start AND frame_end
 ```
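
As an illustrative sketch (assuming the `empsalary` example table used earlier on this page), an explicit frame clause that limits the aggregate to the current row and the preceding one could look like:

```sql
SELECT depname,
       salary,
       -- sum over a frame of at most two rows: the previous row and the current row
       sum(salary) OVER (
         PARTITION BY depname
         ORDER BY salary
         ROWS BETWEEN 1 PRECEDING AND CURRENT ROW
       ) AS rolling_two_row_sum
FROM empsalary;
```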
@@ -162,7 +162,7 @@ All [aggregate functions](aggregate_functions.md) can be used as window function
 
 Relative rank of the current row: (number of rows preceding or peer with current row) / (total rows).
 
-```
+```sql
 cume_dist()
 ```
 
@@ -170,7 +170,7 @@ cume_dist()
 
 Returns the rank of the current row without gaps. This function ranks rows in a dense manner, meaning consecutive ranks are assigned even for identical values.
 
-```
+```sql
 dense_rank()
 ```
 
@@ -178,7 +178,7 @@ dense_rank()
 
 Integer ranging from 1 to the argument value, dividing the partition as equally as possible
 
-```
+```sql
 ntile(expression)
 ```
 
@@ -190,7 +190,7 @@ ntile(expression)
 
 Returns the percentage rank of the current row within its partition. The value ranges from 0 to 1 and is computed as `(rank - 1) / (total_rows - 1)`.
 
-```
+```sql
 percent_rank()
 ```
 
@@ -198,7 +198,7 @@ percent_rank()
 
 Returns the rank of the current row within its partition, allowing gaps between ranks. This function provides a ranking similar to `row_number`, but skips ranks for identical values.
 
-```
+```sql
 rank()
 ```
 
@@ -206,7 +206,7 @@ rank()
 
 Number of the current row within its partition, counting from 1.
 
-```
+```sql
 row_number()
 ```
 
@@ -222,7 +222,7 @@ row_number()
 
 Returns value evaluated at the row that is the first row of the window frame.
 
-```
+```sql
 first_value(expression)
 ```
 
@@ -234,7 +234,7 @@ first_value(expression)
 
 Returns value evaluated at the row that is offset rows before the current row within the partition; if there is no such row, instead return default (which must be of the same type as value).
 
-```
+```sql
 lag(expression, offset, default)
 ```
 
@@ -248,7 +248,7 @@ lag(expression, offset, default)
 
 Returns value evaluated at the row that is the last row of the window frame.
 
-```
+```sql
 last_value(expression)
 ```
 
@@ -260,7 +260,7 @@ last_value(expression)
 
 Returns value evaluated at the row that is offset rows after the current row within the partition; if there is no such row, instead return default (which must be of the same type as value).
 
-```
+```sql
 lead(expression, offset, default)
 ```
 
@@ -274,7 +274,7 @@ lead(expression, offset, default)
 
 Returns value evaluated at the row that is the nth row of the window frame (counting from 1); null if no such row.
 
-```
+```sql
 nth_value(expression, n)
 ```
 
diff --git a/rust-toolchain.toml b/rust-toolchain.toml
new file mode 100644
index 000000000000..bd764d201018
--- /dev/null
+++ b/rust-toolchain.toml
@@ -0,0 +1,23 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# This file specifies the default version of Rust used
+# to compile this workspace and run CI jobs.
+
+[toolchain]
+channel = "1.84.1"
+components = ["rustfmt", "clippy"]
diff --git a/taplo.toml b/taplo.toml
index b7089c501680..47b33161c37e 100644
--- a/taplo.toml
+++ b/taplo.toml
@@ -18,6 +18,7 @@
 ## https://taplo.tamasfe.dev/configuration/file.html
 
 include = ["**/Cargo.toml"]
+exclude = ["target/*"]
 
 [formatting]
 # Align consecutive entries vertically.