diff --git a/Cargo.toml b/Cargo.toml index 0c76ff196a10..cc94b4292a50 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -93,7 +93,7 @@ arrow-ipc = { version = "53.3.0", default-features = false, features = [ arrow-ord = { version = "53.3.0", default-features = false } arrow-schema = { version = "53.3.0", default-features = false } async-trait = "0.1.73" -bigdecimal = "0.4.6" +bigdecimal = "0.4.7" bytes = "1.4" chrono = { version = "0.4.38", default-features = false } ctor = "0.2.0" diff --git a/README.md b/README.md index 553097552418..2e4f2c347fe5 100644 --- a/README.md +++ b/README.md @@ -126,14 +126,17 @@ Optional features: ## Rust Version Compatibility Policy -DataFusion's Minimum Required Stable Rust Version (MSRV) policy is to support stable [4 latest -Rust versions](https://releases.rs) OR the stable minor Rust version as of 4 months, whichever is lower. +Rust toolchain releases are tracked at [Rust Versions](https://releases.rs) and follow +[semantic versioning](https://semver.org/). A Rust toolchain release is identified +by a version string like `1.80.0`, or more generally `major.minor.patch`. + +DataFusion supports the last 4 stable Rust minor versions released, as well as any such version released within the last 4 months. For example, given the releases `1.78.0`, `1.79.0`, `1.80.0`, `1.80.1` and `1.81.0`, DataFusion will support `1.78.0`, which is 3 minor versions prior to the most recent minor version, `1.81`. -If a hotfix is released for the minimum supported Rust version (MSRV), the MSRV will be the minor version with all hotfixes, even if it surpasses the four-month window. +Note: If a Rust hotfix is released for the current MSRV, the MSRV will be updated to the specific minor version that includes all applicable hotfixes, taking precedence over the other policies.
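+
+For illustration only, the selection rule can be sketched in code (a hypothetical
+`is_supported` helper, not part of this repository):
+
+```rust
+/// Minimal sketch of the policy above: a minor version is supported if it is
+/// one of the last 4 stable minor releases OR was released in the last 4 months.
+fn is_supported(minor: u32, newest_minor: u32, age_in_months: u32) -> bool {
+    newest_minor.saturating_sub(minor) < 4 || age_in_months <= 4
+}
+
+fn main() {
+    // Given releases 1.78 through 1.81, 1.78 is still supported:
+    assert!(is_supported(78, 81, 6));
+    // 1.77 is neither among the last 4 minors nor recent enough:
+    assert!(!is_supported(77, 81, 8));
+}
+```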
-We enforce this policy using a [MSRV CI Check](https://github.com/search?q=repo%3Aapache%2Fdatafusion+rust-version+language%3ATOML+path%3A%2F%5ECargo.toml%2F&type=code) +DataFusion enforces its MSRV policy using an [MSRV CI Check](https://github.com/search?q=repo%3Aapache%2Fdatafusion+rust-version+language%3ATOML+path%3A%2F%5ECargo.toml%2F&type=code) ## DataFusion API evolution policy diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index c871b2fdda08..a0a89fb3d14f 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -133,14 +133,14 @@ dependencies = [ [[package]] name = "apache-avro" -version = "0.16.0" +version = "0.17.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ceb7c683b2f8f40970b70e39ff8be514c95b96fcb9c4af87e1ed2cb2e10801a0" +checksum = "1aef82843a0ec9f8b19567445ad2421ceeb1d711514384bdd3d49fe37102ee13" dependencies = [ + "bigdecimal", "bzip2", "crc32fast", "digest", - "lazy_static", "libflate", "log", "num-bigint", @@ -148,15 +148,16 @@ dependencies = [ "rand", "regex-lite", "serde", + "serde_bytes", "serde_json", "snap", - "strum 0.25.0", - "strum_macros 0.25.3", + "strum", + "strum_macros", "thiserror 1.0.69", "typed-builder", "uuid", "xz2", - "zstd 0.12.4", + "zstd", ] [[package]] @@ -418,8 +419,8 @@ dependencies = [ "pin-project-lite", "tokio", "xz2", - "zstd 0.13.2", - "zstd-safe 7.2.1", + "zstd", + "zstd-safe", ] [[package]] @@ -800,6 +801,20 @@ dependencies = [ "vsimd", ] +[[package]] +name = "bigdecimal" +version = "0.4.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f31f3af01c5c65a07985c804d3366560e6fa7883d640a122819b14ec327482c" +dependencies = [ + "autocfg", + "libm", + "num-bigint", + "num-integer", + "num-traits", + "serde", +] + [[package]] name = "bitflags" version = "1.3.2" @@ -866,9 +881,9 @@ dependencies = [ [[package]] name = "bstr" -version = "1.11.0" +version = "1.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1a68f1f47cdf0ec8ee4b941b2eee2a80cb796db73118c0dd09ac63fbe405be22" +checksum = "786a307d683a5bf92e6fd5fd69a7eb613751668d1d8d67d802846dfe367c62c8" dependencies = [ "memchr", "regex-automata", @@ -926,9 +941,9 @@ dependencies = [ [[package]] name = "cc" -version = "1.2.2" +version = "1.2.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f34d93e62b03caf570cccc334cbc6c2fceca82f39211051345108adcba3eebdc" +checksum = "27f657647bcff5394bf56c7317665bbf790a137a50eaaa5c6bfbb9e27a518f2d" dependencies = [ "jobserver", "libc", @@ -955,9 +970,9 @@ checksum = "613afe47fcd5fac7ccf1db93babcb082c5994d996f20b8b159f2ad1658eb5724" [[package]] name = "chrono" -version = "0.4.38" +version = "0.4.39" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a21f936df1771bf62b77f047b726c4625ff2e8aa607c01ec06e5a05bd8463401" +checksum = "7e36cc9d416881d2e24f9a963be5fb1cd90966419ac844274161d10488b3e825" dependencies = [ "android-tzdata", "iana-time-zone", @@ -989,9 +1004,9 @@ dependencies = [ [[package]] name = "clap" -version = "4.5.22" +version = "4.5.23" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "69371e34337c4c984bbe322360c2547210bf632eb2814bbe78a6e87a2935bd2b" +checksum = "3135e7ec2ef7b10c6ed8950f0f792ed96ee093fa088608f1c76e569722700c84" dependencies = [ "clap_builder", "clap_derive", ] [[package]] name = "clap_builder" -version = "4.5.22" +version = "4.5.23" source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "6e24c1b4099818523236a8ca881d2b45db98dadfb4625cf6608c12069fcbbde1" +checksum = "30582fc632330df2bd26877bde0c1f4470d57c582bbc070376afcd04d8cb4838" dependencies = [ "anstream", "anstyle", @@ -1015,7 +1030,7 @@ version = "4.5.18" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4ac6a0c7b1a9e9a5186361f67dfa1b88213572f427fb9ab038efb2bd8c582dab" dependencies = [ - "heck 0.5.0", + "heck", "proc-macro2", "quote", "syn", @@ -1023,9 +1038,9 @@ dependencies = [ [[package]] name = "clap_lex" -version = "0.7.3" +version = "0.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "afb84c814227b90d6895e01398aee0d8033c00e7466aca416fb6a8e0eb19d8a7" +checksum = "f46ad14479a25103f283c0f10005961cf086d8dc42205bb44c46ac563475dca6" [[package]] name = "clipboard-win" @@ -1048,8 +1063,8 @@ version = "7.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "24f165e7b643266ea80cb858aed492ad9280e3e05ce24d4a99d7d7b889b6a4d9" dependencies = [ - "strum 0.26.3", - "strum_macros 0.26.4", + "strum", + "strum_macros", "unicode-width 0.2.0", ] @@ -1245,7 +1260,6 @@ dependencies = [ "object_store", "parking_lot", "parquet", - "paste", "rand", "sqlparser", "tempfile", @@ -1254,7 +1268,7 @@ dependencies = [ "url", "uuid", "xz2", - "zstd 0.13.2", + "zstd", ] [[package]] @@ -1311,7 +1325,6 @@ dependencies = [ "arrow-array", "arrow-buffer", "arrow-schema", - "chrono", "half", "hashbrown 0.14.5", "indexmap", @@ -1342,7 +1355,6 @@ name = "datafusion-execution" version = "43.0.0" dependencies = [ "arrow", - "chrono", "dashmap", "datafusion-common", "datafusion-expr", @@ -1359,10 +1371,7 @@ dependencies = [ name = "datafusion-expr" version = "43.0.0" dependencies = [ - "ahash", "arrow", - "arrow-array", - "arrow-buffer", "chrono", "datafusion-common", "datafusion-doc", @@ -1375,8 +1384,6 @@ dependencies = [ "recursive", "serde_json", "sqlparser", - "strum 0.26.3", - "strum_macros 0.26.4", ] [[package]] @@ -1386,7 +1393,6 @@ dependencies = [ "arrow", "datafusion-common", "itertools", - "paste", ] [[package]] @@ -1445,7 +1451,6 @@ dependencies = [ "datafusion-common", "datafusion-expr-common", "datafusion-physical-expr-common", - "rand", ] [[package]] @@ -1466,28 +1471,18 @@ dependencies = [ "itertools", "log", "paste", - "rand", ] [[package]] name = "datafusion-functions-table" version = "43.0.0" dependencies = [ - "ahash", "arrow", - "arrow-schema", "async-trait", "datafusion-catalog", "datafusion-common", - "datafusion-execution", "datafusion-expr", - "datafusion-functions-aggregate-common", - "datafusion-physical-expr", - "datafusion-physical-expr-common", "datafusion-physical-plan", - "half", - "indexmap", - "log", "parking_lot", "paste", ] @@ -1497,8 +1492,10 @@ name = "datafusion-functions-window" version = "43.0.0" dependencies = [ "datafusion-common", + "datafusion-doc", "datafusion-expr", "datafusion-functions-window-common", + "datafusion-macros", "datafusion-physical-expr", "datafusion-physical-expr-common", "log", @@ -1517,8 +1514,6 @@ dependencies = [ name = "datafusion-macros" version = "43.0.0" dependencies = [ - "datafusion-doc", - "proc-macro2", "quote", "syn", ] @@ -1528,7 +1523,6 @@ name = "datafusion-optimizer" version = "43.0.0" dependencies = [ "arrow", - "async-trait", "chrono", "datafusion-common", "datafusion-expr", @@ -1616,10 +1610,8 @@ dependencies = [ "indexmap", "itertools", "log", - "once_cell", "parking_lot", "pin-project-lite", - "rand", "tokio", ] @@ -1762,9 +1754,9 @@ checksum = 
"a5d9305ccc6942a704f4335694ecd3de2ea531b114ac2d51f5f843750787a92f" [[package]] name = "fastrand" -version = "2.2.0" +version = "2.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "486f806e73c5707928240ddc295403b1b93c96a02038563881c4a2fd84b81ac4" +checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be" [[package]] name = "fd-lock" @@ -2022,12 +2014,6 @@ version = "0.15.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bf151400ff0baff5465007dd2f3e717f3fe502074ca563069ce3a6629d07b289" -[[package]] -name = "heck" -version = "0.4.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8" - [[package]] name = "heck" version = "0.5.0" @@ -2202,11 +2188,11 @@ dependencies = [ "http 1.2.0", "hyper 1.5.1", "hyper-util", - "rustls 0.23.19", + "rustls 0.23.20", "rustls-native-certs 0.8.1", "rustls-pki-types", "tokio", - "tokio-rustls 0.26.0", + "tokio-rustls 0.26.1", "tower-service", ] @@ -2445,9 +2431,9 @@ dependencies = [ [[package]] name = "js-sys" -version = "0.3.74" +version = "0.3.76" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a865e038f7f6ed956f788f0d7d60c541fff74c7bd74272c5d4cf15c63743e705" +checksum = "6717b6b5b077764fb5966237269cb3c64edddde4b14ce42647430a78ced9e7b7" dependencies = [ "once_cell", "wasm-bindgen", @@ -2461,9 +2447,9 @@ checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" [[package]] name = "lexical-core" -version = "1.0.2" +version = "1.0.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0431c65b318a590c1de6b8fd6e72798c92291d27762d94c9e6c37ed7a73d8458" +checksum = "b765c31809609075565a70b4b71402281283aeda7ecaf4818ac14a7b2ade8958" dependencies = [ "lexical-parse-float", "lexical-parse-integer", @@ -2474,9 +2460,9 @@ dependencies = [ [[package]] name = "lexical-parse-float" -version = "1.0.2" +version = "1.0.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eb17a4bdb9b418051aa59d41d65b1c9be5affab314a872e5ad7f06231fb3b4e0" +checksum = "de6f9cb01fb0b08060209a057c048fcbab8717b4c1ecd2eac66ebfe39a65b0f2" dependencies = [ "lexical-parse-integer", "lexical-util", @@ -2485,9 +2471,9 @@ dependencies = [ [[package]] name = "lexical-parse-integer" -version = "1.0.2" +version = "1.0.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5df98f4a4ab53bf8b175b363a34c7af608fe31f93cc1fb1bf07130622ca4ef61" +checksum = "72207aae22fc0a121ba7b6d479e42cbfea549af1479c3f3a4f12c70dd66df12e" dependencies = [ "lexical-util", "static_assertions", @@ -2495,18 +2481,18 @@ dependencies = [ [[package]] name = "lexical-util" -version = "1.0.3" +version = "1.0.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "85314db53332e5c192b6bca611fb10c114a80d1b831ddac0af1e9be1b9232ca0" +checksum = "5a82e24bf537fd24c177ffbbdc6ebcc8d54732c35b50a3f28cc3f4e4c949a0b3" dependencies = [ "static_assertions", ] [[package]] name = "lexical-write-float" -version = "1.0.2" +version = "1.0.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6e7c3ad4e37db81c1cbe7cf34610340adc09c322871972f74877a712abc6c809" +checksum = "c5afc668a27f460fb45a81a757b6bf2f43c2d7e30cb5a2dcd3abf294c78d62bd" dependencies = [ "lexical-util", "lexical-write-integer", @@ -2515,9 +2501,9 @@ dependencies = [ [[package]] name = "lexical-write-integer" -version = "1.0.2" +version = "1.0.5" 
source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eb89e9f6958b83258afa3deed90b5de9ef68eef090ad5086c791cd2345610162" +checksum = "629ddff1a914a836fb245616a7888b62903aae58fa771e1d83943035efa0f978" dependencies = [ "lexical-util", "static_assertions", @@ -2525,9 +2511,9 @@ dependencies = [ [[package]] name = "libc" -version = "0.2.167" +version = "0.2.168" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "09d6582e104315a817dff97f75133544b2e094ee22447d2acf4a74e189ba06fc" +checksum = "5aaeb2981e0606ca11d79718f8bb01164f1d6ed75080182d3abf017e6d244b6d" [[package]] name = "libflate" @@ -2727,6 +2713,7 @@ checksum = "a5e44f723f1133c9deac646763579fdb3ac745e418f2a7af9cd0c431da1f20b9" dependencies = [ "num-integer", "num-traits", + "serde", ] [[package]] @@ -2913,7 +2900,7 @@ dependencies = [ "thrift", "tokio", "twox-hash", - "zstd 0.13.2", + "zstd", "zstd-sys", ] @@ -3103,9 +3090,9 @@ dependencies = [ "quinn-proto", "quinn-udp", "rustc-hash", - "rustls 0.23.19", + "rustls 0.23.20", "socket2", - "thiserror 2.0.4", + "thiserror 2.0.6", "tokio", "tracing", ] @@ -3121,10 +3108,10 @@ dependencies = [ "rand", "ring", "rustc-hash", - "rustls 0.23.19", + "rustls 0.23.20", "rustls-pki-types", "slab", - "thiserror 2.0.4", + "thiserror 2.0.6", "tinyvec", "tracing", "web-time", @@ -3132,9 +3119,9 @@ dependencies = [ [[package]] name = "quinn-udp" -version = "0.5.7" +version = "0.5.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7d5a626c6807713b15cac82a6acaccd6043c9a5408c24baae07611fec3f243da" +checksum = "52cd4b1eff68bf27940dd39811292c49e007f4d0b4c357358dc9b0197be6b527" dependencies = [ "cfg_aliases 0.2.1", "libc", @@ -3215,9 +3202,9 @@ dependencies = [ [[package]] name = "redox_syscall" -version = "0.5.7" +version = "0.5.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9b6dfecf2c74bce2466cabf93f6664d6998a69eb21e39f4207930065b27b771f" +checksum = "03a862b389f93e68874fbf580b9de08dd02facb9a788ebadaf4a3fd33cf58834" dependencies = [ "bitflags 2.6.0", ] @@ -3299,7 +3286,7 @@ dependencies = [ "percent-encoding", "pin-project-lite", "quinn", - "rustls 0.23.19", + "rustls 0.23.20", "rustls-native-certs 0.8.1", "rustls-pemfile 2.2.0", "rustls-pki-types", @@ -3308,7 +3295,7 @@ dependencies = [ "serde_urlencoded", "sync_wrapper", "tokio", - "tokio-rustls 0.26.0", + "tokio-rustls 0.26.1", "tokio-util", "tower-service", "url", @@ -3393,15 +3380,15 @@ dependencies = [ [[package]] name = "rustix" -version = "0.38.41" +version = "0.38.42" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d7f649912bc1495e167a6edee79151c84b1bad49748cb4f1f1167f459f6224f6" +checksum = "f93dc38ecbab2eb790ff964bb77fa94faf256fd3e73285fd7ba0903b76bedb85" dependencies = [ "bitflags 2.6.0", "errno", "libc", "linux-raw-sys", - "windows-sys 0.52.0", + "windows-sys 0.59.0", ] [[package]] @@ -3418,9 +3405,9 @@ dependencies = [ [[package]] name = "rustls" -version = "0.23.19" +version = "0.23.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "934b404430bb06b3fae2cba809eb45a1ab1aecd64491213d7c3301b88393f8d1" +checksum = "5065c3f250cbd332cd894be57c40fa52387247659b14a2d6041d121547903b1b" dependencies = [ "once_cell", "ring", @@ -3608,9 +3595,9 @@ dependencies = [ [[package]] name = "semver" -version = "1.0.23" +version = "1.0.24" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "61697e0a1c7e512e84a621326239844a24d8207b4669b41bc18b32ea5cbf988b" +checksum = 
"3cb6eb87a131f756572d7fb904f6e7b68633f09cca868c5df1c4b8d1a694bbba" [[package]] name = "seq-macro" @@ -3620,18 +3607,27 @@ checksum = "a3f0bf26fd526d2a95683cd0f87bf103b8539e2ca1ef48ce002d67aad59aa0b4" [[package]] name = "serde" -version = "1.0.215" +version = "1.0.216" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6513c1ad0b11a9376da888e3e0baa0077f1aed55c17f50e7b2397136129fb88f" +checksum = "0b9781016e935a97e8beecf0c933758c97a5520d32930e460142b4cd80c6338e" dependencies = [ "serde_derive", ] +[[package]] +name = "serde_bytes" +version = "0.11.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "387cc504cb06bb40a96c8e04e951fe01854cf6bc921053c954e4a606d9675c6a" +dependencies = [ + "serde", +] + [[package]] name = "serde_derive" -version = "1.0.215" +version = "1.0.216" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ad1e866f866923f252f05c889987993144fb74e722403468a4ebd70c3cd756c0" +checksum = "46f859dbbf73865c6627ed570e78961cd3ac92407a2d117204c49232485da55e" dependencies = [ "proc-macro2", "quote", @@ -3724,7 +3720,7 @@ version = "0.8.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "03c3c6b7927ffe7ecaa769ee0e3994da3b8cafc8f444578982c83ecb161af917" dependencies = [ - "heck 0.5.0", + "heck", "proc-macro2", "quote", "syn", @@ -3804,33 +3800,11 @@ version = "0.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f" -[[package]] -name = "strum" -version = "0.25.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "290d54ea6f91c969195bdbcd7442c8c2a2ba87da8bf60a7ee86a235d4bc1e125" - [[package]] name = "strum" version = "0.26.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8fec0f0aef304996cf250b31b5a10dee7980c85da9d759361292b8bca5a18f06" -dependencies = [ - "strum_macros 0.26.4", -] - -[[package]] -name = "strum_macros" -version = "0.25.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "23dc1fa9ac9c169a78ba62f0b841814b7abae11bdd047b9c58f893439e309ea0" -dependencies = [ - "heck 0.4.1", - "proc-macro2", - "quote", - "rustversion", - "syn", -] [[package]] name = "strum_macros" @@ -3838,7 +3812,7 @@ version = "0.26.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4c6bee85a5a24955dc440386795aa378cd9cf82acd5f764469152d2270e581be" dependencies = [ - "heck 0.5.0", + "heck", "proc-macro2", "quote", "rustversion", @@ -3912,11 +3886,11 @@ dependencies = [ [[package]] name = "thiserror" -version = "2.0.4" +version = "2.0.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2f49a1853cf82743e3b7950f77e0f4d622ca36cf4317cba00c767838bac8d490" +checksum = "8fec2a1820ebd077e2b90c4df007bebf344cd394098a13c563957d0afc83ea47" dependencies = [ - "thiserror-impl 2.0.4", + "thiserror-impl 2.0.6", ] [[package]] @@ -3932,9 +3906,9 @@ dependencies = [ [[package]] name = "thiserror-impl" -version = "2.0.4" +version = "2.0.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8381894bb3efe0c4acac3ded651301ceee58a15d47c2e34885ed1908ad667061" +checksum = "d65750cab40f4ff1929fb1ba509e9914eb756131cef4210da8d5d700d26f6312" dependencies = [ "proc-macro2", "quote", @@ -4057,12 +4031,11 @@ dependencies = [ [[package]] name = "tokio-rustls" -version = "0.26.0" +version = "0.26.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"0c7bc40d0e5a97695bb96e27995cd3a08538541b0a846f65bba7a359f36700d4" +checksum = "5f6d0975eaace0cf0fcadee4e4aaa5da15b5c079146f2cffb67c113be122bf37" dependencies = [ - "rustls 0.23.19", - "rustls-pki-types", + "rustls 0.23.20", "tokio", ] @@ -4151,18 +4124,18 @@ dependencies = [ [[package]] name = "typed-builder" -version = "0.16.2" +version = "0.19.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "34085c17941e36627a879208083e25d357243812c30e7d7387c3b954f30ade16" +checksum = "a06fbd5b8de54c5f7c91f6fe4cebb949be2125d7758e630bb58b1d831dbce600" dependencies = [ "typed-builder-macro", ] [[package]] name = "typed-builder-macro" -version = "0.16.2" +version = "0.19.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f03ca4cb38206e2bef0700092660bb74d696f808514dae47fa1467cbfe26e96e" +checksum = "f9534daa9fd3ed0bd911d462a37f172228077e7abf18c18a5f67199d959205f8" dependencies = [ "proc-macro2", "quote", @@ -4298,9 +4271,9 @@ checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" [[package]] name = "wasm-bindgen" -version = "0.2.97" +version = "0.2.99" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d15e63b4482863c109d70a7b8706c1e364eb6ea449b201a76c5b89cedcec2d5c" +checksum = "a474f6281d1d70c17ae7aa6a613c87fce69a127e2624002df63dcb39d6cf6396" dependencies = [ "cfg-if", "once_cell", @@ -4309,13 +4282,12 @@ dependencies = [ [[package]] name = "wasm-bindgen-backend" -version = "0.2.97" +version = "0.2.99" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8d36ef12e3aaca16ddd3f67922bc63e48e953f126de60bd33ccc0101ef9998cd" +checksum = "5f89bb38646b4f81674e8f5c3fb81b562be1fd936d84320f3264486418519c79" dependencies = [ "bumpalo", "log", - "once_cell", "proc-macro2", "quote", "syn", @@ -4324,9 +4296,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-futures" -version = "0.4.47" +version = "0.4.49" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9dfaf8f50e5f293737ee323940c7d8b08a66a95a419223d9f41610ca08b0833d" +checksum = "38176d9b44ea84e9184eff0bc34cc167ed044f816accfe5922e54d84cf48eca2" dependencies = [ "cfg-if", "js-sys", @@ -4337,9 +4309,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro" -version = "0.2.97" +version = "0.2.99" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "705440e08b42d3e4b36de7d66c944be628d579796b8090bfa3471478a2260051" +checksum = "2cc6181fd9a7492eef6fef1f33961e3695e4579b9872a6f7c83aee556666d4fe" dependencies = [ "quote", "wasm-bindgen-macro-support", @@ -4347,9 +4319,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro-support" -version = "0.2.97" +version = "0.2.99" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "98c9ae5a76e46f4deecd0f0255cc223cfa18dc9b261213b8aa0c7b36f61b3f1d" +checksum = "30d7a95b763d3c45903ed6c81f156801839e5ee968bb07e534c44df0fcd330c2" dependencies = [ "proc-macro2", "quote", @@ -4360,9 +4332,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-shared" -version = "0.2.97" +version = "0.2.99" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6ee99da9c5ba11bd675621338ef6fa52296b76b83305e9b6e5c77d4c286d6d49" +checksum = "943aab3fdaaa029a6e0271b35ea10b72b943135afe9bffca82384098ad0e06a6" [[package]] name = "wasm-streams" @@ -4379,9 +4351,9 @@ dependencies = [ [[package]] name = "web-sys" -version = "0.3.74" +version = "0.3.76" source = "registry+https://github.com/rust-lang/crates.io-index" 
-checksum = "a98bc3c33f0fe7e59ad7cd041b89034fa82a7c2d4365ca538dda6cdaf513863c" +checksum = "04dd7223427d52553d3702c004d3b2fe07c148165faa56313cb00211e31c12bc" dependencies = [ "js-sys", "wasm-bindgen", @@ -4723,32 +4695,13 @@ dependencies = [ "syn", ] -[[package]] -name = "zstd" -version = "0.12.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1a27595e173641171fc74a1232b7b1c7a7cb6e18222c11e9dfb9888fa424c53c" -dependencies = [ - "zstd-safe 6.0.6", -] - [[package]] name = "zstd" version = "0.13.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fcf2b778a664581e31e389454a7072dab1647606d44f7feea22cd5abb9c9f3f9" dependencies = [ - "zstd-safe 7.2.1", -] - -[[package]] -name = "zstd-safe" -version = "6.0.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ee98ffd0b48ee95e6c5168188e44a54550b1564d9d530ee21d5f0eaed1069581" -dependencies = [ - "libc", - "zstd-sys", + "zstd-safe", ] [[package]] diff --git a/datafusion-examples/examples/advanced_udf.rs b/datafusion-examples/examples/advanced_udf.rs index aee3be6c9285..ae35cff6facf 100644 --- a/datafusion-examples/examples/advanced_udf.rs +++ b/datafusion-examples/examples/advanced_udf.rs @@ -27,9 +27,11 @@ use arrow::record_batch::RecordBatch; use datafusion::error::Result; use datafusion::logical_expr::Volatility; use datafusion::prelude::*; -use datafusion_common::{internal_err, ScalarValue}; +use datafusion_common::{exec_err, internal_err, ScalarValue}; use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; -use datafusion_expr::{ColumnarValue, ScalarUDF, ScalarUDFImpl, Signature}; +use datafusion_expr::{ + ColumnarValue, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature, +}; /// This example shows how to use the full ScalarUDFImpl API to implement a user /// defined function. As in the `simple_udf.rs` example, this struct implements @@ -83,23 +85,27 @@ impl ScalarUDFImpl for PowUdf { Ok(DataType::Float64) } - /// This is the function that actually calculates the results. + /// This function actually calculates the results of the scalar function. + /// + /// This is the same way that functions provided with DataFusion are invoked, + /// which permits important special cases: /// - /// This is the same way that functions built into DataFusion are invoked, - /// which permits important special cases when one or both of the arguments - /// are single values (constants). For example `pow(a, 2)` + ///1. When one or both of the arguments are single values (constants). + /// For example `pow(a, 2)` + /// 2. When the input arrays can be reused (avoid allocating a new output array) /// /// However, it also means the implementation is more complex than when /// using `create_udf`. - fn invoke_batch( - &self, - args: &[ColumnarValue], - _number_rows: usize, - ) -> Result { + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + // The other fields of the `args` struct are used for more specialized + // uses, and are not needed in this example + let ScalarFunctionArgs { mut args, .. 
} = args; // DataFusion has arranged for the correct inputs to be passed to this // function, but we check again to make sure assert_eq!(args.len(), 2); - let (base, exp) = (&args[0], &args[1]); + // take ownership of arguments by popping in reverse order + let exp = args.pop().unwrap(); + let base = args.pop().unwrap(); assert_eq!(base.data_type(), DataType::Float64); assert_eq!(exp.data_type(), DataType::Float64); @@ -118,7 +124,7 @@ impl ScalarUDFImpl for PowUdf { ) => { // compute the output. Note DataFusion treats `None` as NULL. let res = match (base, exp) { - (Some(base), Some(exp)) => Some(base.powf(*exp)), + (Some(base), Some(exp)) => Some(base.powf(exp)), // one or both arguments were NULL _ => None, }; @@ -140,31 +146,33 @@ impl ScalarUDFImpl for PowUdf { // kernel creates very fast "vectorized" code and // handles things like null values for us. let res: Float64Array = - compute::unary(base_array, |base| base.powf(*exp)); + compute::unary(base_array, |base| base.powf(exp)); Arc::new(res) } }; Ok(ColumnarValue::Array(result_array)) } - // special case if the base is a constant (note this code is quite - // similar to the previous case, so we omit comments) + // special case if the base is a constant. + // + // Note this case is very similar to the previous case, so we could + // use the same pattern. However, for this case we demonstrate an + // even more advanced pattern to potentially avoid allocating a new array ( ColumnarValue::Scalar(ScalarValue::Float64(base)), ColumnarValue::Array(exp_array), ) => { let res = match base { None => new_null_array(exp_array.data_type(), exp_array.len()), - Some(base) => { - let exp_array = exp_array.as_primitive::(); - let res: Float64Array = - compute::unary(exp_array, |exp| base.powf(exp)); - Arc::new(res) - } + Some(base) => maybe_pow_in_place(base, exp_array)?, }; Ok(ColumnarValue::Array(res)) } - // Both arguments are arrays so we have to perform the calculation for every row + // Both arguments are arrays so we have to perform the calculation + // for every row + // + // Note this could also be done in place using `binary_mut` as + // is done in `maybe_pow_in_place`, but here we use `binary` for simplicity (ColumnarValue::Array(base_array), ColumnarValue::Array(exp_array)) => { let res: Float64Array = compute::binary( base_array.as_primitive::(), @@ -191,6 +199,52 @@ impl ScalarUDFImpl for PowUdf { } } +/// Evaluate `base ^ exp` *without* allocating a new array, if possible +fn maybe_pow_in_place(base: f64, exp_array: ArrayRef) -> Result { + // Calling `unary` creates a new array for the results. Avoiding + // allocations is a common optimization in performance critical code. + // arrow-rs allows this optimization via the `unary_mut` + // and `binary_mut` kernels in certain cases + // + // These kernels can only be used if there are no other references to + // the arrays (exp_array has to be the last remaining reference). + let owned_array = exp_array + // as in the previous example, we first downcast to &Float64Array + .as_primitive::() + // non-obviously, we call clone here to get an owned `Float64Array`. + // Calling clone() is relatively inexpensive as it increments + // some ref counts but doesn't clone the data. + // + // Once we have the owned Float64Array we can drop the original + // exp_array (untyped) reference + .clone(); + + // We *MUST* drop the reference to `exp_array` explicitly so that + // owned_array is the only reference remaining in this function.
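+ // (Dropping an `ArrayRef` just decrements the underlying `Arc` reference + // count; the buffers themselves only become reusable once the last + // reference is gone.)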
+ // + // Note that depending on the query there may still be other references + // to the underlying buffers, which would prevent reuse. The only way to + // know for sure is to check the result of `compute::unary_mut` + drop(exp_array); + + // If we have the only reference, compute the result directly into the same + // allocation as was used for the input array + match compute::unary_mut(owned_array, |exp| base.powf(exp)) { + Err(_orig_array) => { + // unary_mut will return the original array if there are other + // references into the underlying buffer (and thus reuse is + // impossible) + // + // In a real implementation, this case should fall back to + // calling `unary` and allocating a new array; in this example + // we will return an error for demonstration purposes + exec_err!("Could not reuse array for maybe_pow_in_place") + } + // a result of `Ok` means the operation ran successfully + Ok(res) => Ok(Arc::new(res)), + } +} + /// In this example we register `PowUdf` as a user defined function /// and invoke it via the DataFrame API and SQL #[tokio::main] async fn main() -> Result<()> { @@ -215,9 +269,29 @@ async fn main() -> Result<()> { // print the results df.show().await?; - // You can also invoke both pow(2, 10) and its alias my_pow(a, b) using SQL - let sql_df = ctx.sql("SELECT pow(2, 10), my_pow(a, b) FROM t").await?; - sql_df.show().await?; + // You can also invoke both pow(2, 10) and its alias my_pow(a, b) using SQL + ctx.sql("SELECT pow(2, 10), my_pow(a, b) FROM t") + .await? + .show() + .await?; + + // You can also exercise `maybe_pow_in_place` by passing a constant base and a + // column `a` as the exponent. If there is only a single + // reference to `a`, the code works well + ctx.sql("SELECT pow(2, a) FROM t").await?.show().await?; + + // However, if there are multiple references to `a` in the evaluation + // the array storage cannot be reused + let err = ctx + .sql("SELECT pow(2, a), pow(3, a) FROM t") + .await? + .show() + .await + .unwrap_err(); + assert_eq!( + err.to_string(), + "Execution error: Could not reuse array for maybe_pow_in_place" + ); Ok(()) } diff --git a/datafusion-examples/examples/parse_sql_expr.rs b/datafusion-examples/examples/parse_sql_expr.rs index e23e5accae39..d8f0778e19e3 100644 --- a/datafusion-examples/examples/parse_sql_expr.rs +++ b/datafusion-examples/examples/parse_sql_expr.rs @@ -121,11 +121,11 @@ async fn query_parquet_demo() -> Result<()> { assert_batches_eq!( &[ - "+------------+----------------------+", - "| double_col | sum(?table?.int_col) |", - "+------------+----------------------+", - "| 10.1 | 4 |", - "+------------+----------------------+", + "+------------+-------------+", + "| double_col | sum_int_col |", + "+------------+-------------+", + "| 10.1 | 4 |", + "+------------+-------------+", ], &result ); diff --git a/datafusion-examples/examples/regexp.rs b/datafusion-examples/examples/regexp.rs index 02e74bae22af..5419efd2faea 100644 --- a/datafusion-examples/examples/regexp.rs +++ b/datafusion-examples/examples/regexp.rs @@ -148,7 +148,7 @@ async fn main() -> Result<()> { // invalid flags will result in an error let result = ctx - .sql(r"select regexp_like('\b4(?!000)\d\d\d\b', 4010, 'g')") + .sql(r"select regexp_like('\b4(?!000)\d\d\d\b', '4010', 'g')") .await?
.collect() .await; diff --git a/datafusion/catalog/src/table.rs b/datafusion/catalog/src/table.rs index b6752191d9a7..3c8960495588 100644 --- a/datafusion/catalog/src/table.rs +++ b/datafusion/catalog/src/table.rs @@ -33,7 +33,19 @@ use datafusion_expr::{ }; use datafusion_physical_plan::ExecutionPlan; -/// Source table +/// A named table which can be queried. +/// +/// Please see [`CatalogProvider`] for details of implementing a custom catalog. +/// +/// [`TableProvider`] represents a source of data which can provide data as +/// Apache Arrow `RecordBatch`es. Implementations of this trait provide +/// important information for planning such as: +/// +/// 1. [`Self::schema`]: The schema (columns and their types) of the table +/// 2. [`Self::supports_filters_pushdown`]: Should filters be pushed into this scan +/// 3. [`Self::scan`]: An [`ExecutionPlan`] that can read data +/// +/// [`CatalogProvider`]: super::CatalogProvider #[async_trait] pub trait TableProvider: Debug + Sync + Send { /// Returns the table provider as [`Any`](std::any::Any) so that it can be diff --git a/datafusion/common/Cargo.toml b/datafusion/common/Cargo.toml index d76848dfe95e..82909404e455 100644 --- a/datafusion/common/Cargo.toml +++ b/datafusion/common/Cargo.toml @@ -43,7 +43,7 @@ force_hash_collisions = [] [dependencies] ahash = { workspace = true } -apache-avro = { version = "0.16", default-features = false, features = [ +apache-avro = { version = "0.17", default-features = false, features = [ "bzip", "snappy", "xz", @@ -53,7 +53,6 @@ arrow = { workspace = true } arrow-array = { workspace = true } arrow-buffer = { workspace = true } arrow-schema = { workspace = true } -chrono = { workspace = true } half = { workspace = true } hashbrown = { workspace = true } indexmap = { workspace = true } @@ -70,4 +69,5 @@ tokio = { workspace = true } web-time = "1.1.0" [dev-dependencies] +chrono = { workspace = true } rand = { workspace = true } diff --git a/datafusion/core/Cargo.toml b/datafusion/core/Cargo.toml index 4706afc897c2..1995ab4ca075 100644 --- a/datafusion/core/Cargo.toml +++ b/datafusion/core/Cargo.toml @@ -77,7 +77,7 @@ unicode_expressions = [ ] [dependencies] -apache-avro = { version = "0.16", optional = true } +apache-avro = { version = "0.17", optional = true } arrow = { workspace = true } arrow-array = { workspace = true } arrow-ipc = { workspace = true } @@ -120,7 +120,6 @@ num-traits = { version = "0.2", optional = true } object_store = { workspace = true } parking_lot = { workspace = true } parquet = { workspace = true, optional = true, default-features = true } -paste = "1.0.15" rand = { workspace = true } sqlparser = { workspace = true } tempfile = { workspace = true } @@ -140,17 +139,13 @@ datafusion-functions-window-common = { workspace = true } doc-comment = { workspace = true } env_logger = { workspace = true } paste = "^1.0" -postgres-protocol = "0.6.4" -postgres-types = { version = "0.2.4", features = ["derive", "with-chrono-0_4"] } rand = { workspace = true, features = ["small_rng"] } rand_distr = "0.4.3" regex = { workspace = true } rstest = { workspace = true } -rust_decimal = { version = "1.27.0", features = ["tokio-pg"] } serde_json = { workspace = true } test-utils = { path = "../../test-utils" } tokio = { workspace = true, features = ["rt-multi-thread", "parking_lot", "fs"] } -tokio-postgres = "0.7.7" [target.'cfg(not(target_os = "windows"))'.dev-dependencies] nix = { version = "0.29.0", features = ["fs"] } diff --git a/datafusion/core/src/bin/print_functions_docs.rs
b/datafusion/core/src/bin/print_functions_docs.rs index b58f6e47d333..8b453d5e9698 100644 --- a/datafusion/core/src/bin/print_functions_docs.rs +++ b/datafusion/core/src/bin/print_functions_docs.rs @@ -88,6 +88,7 @@ fn print_window_docs() -> Result { // the migration of UDF documentation generation from code based // to attribute based // To be removed +#[allow(dead_code)] fn save_doc_code_text(documentation: &Documentation, name: &str) { let attr_text = documentation.to_doc_attribute(); @@ -182,7 +183,7 @@ fn print_docs( }; // Temporary for doc gen migration, see `save_doc_code_text` comments - save_doc_code_text(documentation, &name); + // save_doc_code_text(documentation, &name); // first, the name, description and syntax example let _ = write!( diff --git a/datafusion/core/src/datasource/avro_to_arrow/arrow_array_reader.rs b/datafusion/core/src/datasource/avro_to_arrow/arrow_array_reader.rs index f3358bce7623..8f0e3792ffec 100644 --- a/datafusion/core/src/datasource/avro_to_arrow/arrow_array_reader.rs +++ b/datafusion/core/src/datasource/avro_to_arrow/arrow_array_reader.rs @@ -138,7 +138,11 @@ impl AvroArrowArrayReader<'_, R> { } AvroSchema::Array(schema) => { let sub_parent_field_name = format!("{}.element", parent_field_name); - Self::child_schema_lookup(&sub_parent_field_name, schema, schema_lookup)?; + Self::child_schema_lookup( + &sub_parent_field_name, + &schema.items, + schema_lookup, + )?; } _ => (), } diff --git a/datafusion/core/src/datasource/avro_to_arrow/schema.rs b/datafusion/core/src/datasource/avro_to_arrow/schema.rs index 039a6aacc07e..991f648e58bd 100644 --- a/datafusion/core/src/datasource/avro_to_arrow/schema.rs +++ b/datafusion/core/src/datasource/avro_to_arrow/schema.rs @@ -73,11 +73,15 @@ fn schema_to_field_with_props( AvroSchema::Bytes => DataType::Binary, AvroSchema::String => DataType::Utf8, AvroSchema::Array(item_schema) => DataType::List(Arc::new( - schema_to_field_with_props(item_schema, Some("element"), false, None)?, + schema_to_field_with_props(&item_schema.items, Some("element"), false, None)?, )), AvroSchema::Map(value_schema) => { - let value_field = - schema_to_field_with_props(value_schema, Some("value"), false, None)?; + let value_field = schema_to_field_with_props( + &value_schema.types, + Some("value"), + false, + None, + )?; DataType::Dictionary( Box::new(DataType::Utf8), Box::new(value_field.data_type().clone()), @@ -144,14 +148,17 @@ fn schema_to_field_with_props( AvroSchema::Decimal(DecimalSchema { precision, scale, .. 
}) => DataType::Decimal128(*precision as u8, *scale as i8), + AvroSchema::BigDecimal => DataType::LargeBinary, AvroSchema::Uuid => DataType::FixedSizeBinary(16), AvroSchema::Date => DataType::Date32, AvroSchema::TimeMillis => DataType::Time32(TimeUnit::Millisecond), AvroSchema::TimeMicros => DataType::Time64(TimeUnit::Microsecond), AvroSchema::TimestampMillis => DataType::Timestamp(TimeUnit::Millisecond, None), AvroSchema::TimestampMicros => DataType::Timestamp(TimeUnit::Microsecond, None), + AvroSchema::TimestampNanos => DataType::Timestamp(TimeUnit::Nanosecond, None), AvroSchema::LocalTimestampMillis => todo!(), AvroSchema::LocalTimestampMicros => todo!(), + AvroSchema::LocalTimestampNanos => todo!(), AvroSchema::Duration => DataType::Duration(TimeUnit::Millisecond), }; @@ -371,6 +378,7 @@ mod test { aliases: Some(vec![alias("foofixed"), alias("barfixed")]), size: 1, doc: None, + default: None, attributes: Default::default(), }); let props = external_props(&fixed_schema); diff --git a/datafusion/core/src/datasource/physical_plan/csv.rs b/datafusion/core/src/datasource/physical_plan/csv.rs index 48a2474c40ae..a00e74cf4fcd 100644 --- a/datafusion/core/src/datasource/physical_plan/csv.rs +++ b/datafusion/core/src/datasource/physical_plan/csv.rs @@ -614,11 +614,13 @@ impl FileOpener for CsvOpener { } let store = Arc::clone(&self.config.object_store); + let terminator = self.config.terminator; Ok(Box::pin(async move { // Current partition contains bytes [start_byte, end_byte) (might contain incomplete lines at boundaries) - let calculated_range = calculate_range(&file_meta, &store).await?; + let calculated_range = + calculate_range(&file_meta, &store, terminator).await?; let range = match calculated_range { RangeCalculation::Range(None) => None, diff --git a/datafusion/core/src/datasource/physical_plan/json.rs b/datafusion/core/src/datasource/physical_plan/json.rs index 5d277a55678a..5e68193a5941 100644 --- a/datafusion/core/src/datasource/physical_plan/json.rs +++ b/datafusion/core/src/datasource/physical_plan/json.rs @@ -275,7 +275,7 @@ impl FileOpener for JsonOpener { let file_compression_type = self.file_compression_type.to_owned(); Ok(Box::pin(async move { - let calculated_range = calculate_range(&file_meta, &store).await?; + let calculated_range = calculate_range(&file_meta, &store, None).await?; let range = match calculated_range { RangeCalculation::Range(None) => None, diff --git a/datafusion/core/src/datasource/physical_plan/mod.rs b/datafusion/core/src/datasource/physical_plan/mod.rs index 449b7bb43519..3146d124d9f1 100644 --- a/datafusion/core/src/datasource/physical_plan/mod.rs +++ b/datafusion/core/src/datasource/physical_plan/mod.rs @@ -426,9 +426,11 @@ enum RangeCalculation { async fn calculate_range( file_meta: &FileMeta, store: &Arc, + terminator: Option, ) -> Result { let location = file_meta.location(); let file_size = file_meta.object_meta.size; + let newline = terminator.unwrap_or(b'\n'); match file_meta.range { None => Ok(RangeCalculation::Range(None)), @@ -436,13 +438,13 @@ async fn calculate_range( let (start, end) = (start as usize, end as usize); let start_delta = if start != 0 { - find_first_newline(store, location, start - 1, file_size).await? + find_first_newline(store, location, start - 1, file_size, newline).await? } else { 0 }; let end_delta = if end != file_size { - find_first_newline(store, location, end - 1, file_size).await? + find_first_newline(store, location, end - 1, file_size, newline).await? 
} else { 0 }; @@ -462,7 +464,7 @@ async fn calculate_range( /// within an object, such as a file, in an object store. /// /// This function scans the contents of the object starting from the specified `start` position -/// up to the `end` position, looking for the first occurrence of a newline (`'\n'`) character. +/// up to the `end` position, looking for the first occurrence of a newline character. /// It returns the position of the first newline relative to the start of the range. /// /// Returns a `Result` wrapping a `usize` that represents the position of the first newline character found within the specified range. If no newline is found, it returns the length of the scanned data, effectively indicating the end of the range. @@ -474,6 +476,7 @@ async fn find_first_newline( location: &Path, start: usize, end: usize, + newline: u8, ) -> Result { let options = GetOptions { range: Some(GetRange::Bounded(start..end)), @@ -486,7 +489,7 @@ async fn find_first_newline( let mut index = 0; while let Some(chunk) = result_stream.next().await.transpose()? { - if let Some(position) = chunk.iter().position(|&byte| byte == b'\n') { + if let Some(position) = chunk.iter().position(|&byte| byte == newline) { return Ok(index + position); } diff --git a/datafusion/core/src/execution/session_state.rs b/datafusion/core/src/execution/session_state.rs index 4ccad5ffd323..cef5d4c1ee2a 100644 --- a/datafusion/core/src/execution/session_state.rs +++ b/datafusion/core/src/execution/session_state.rs @@ -68,7 +68,7 @@ use datafusion_sql::planner::{ContextProvider, ParserOptions, PlannerContext, Sq use itertools::Itertools; use log::{debug, info}; use object_store::ObjectStore; -use sqlparser::ast::Expr as SQLExpr; +use sqlparser::ast::{Expr as SQLExpr, ExprWithAlias as SQLExprWithAlias}; use sqlparser::dialect::dialect_from_str; use std::any::Any; use std::collections::hash_map::Entry; @@ -500,11 +500,22 @@ impl SessionState { sql: &str, dialect: &str, ) -> datafusion_common::Result { + self.sql_to_expr_with_alias(sql, dialect).map(|x| x.expr) + } + + /// parse a sql string into a sqlparser-rs AST [`SQLExprWithAlias`]. + /// + /// See [`Self::create_logical_expr`] for parsing sql to [`Expr`]. + pub fn sql_to_expr_with_alias( + &self, + sql: &str, + dialect: &str, + ) -> datafusion_common::Result { let dialect = dialect_from_str(dialect).ok_or_else(|| { plan_datafusion_err!( "Unsupported SQL dialect: {dialect}. Available dialects: \ - Generic, MySQL, PostgreSQL, Hive, SQLite, Snowflake, Redshift, \ - MsSQL, ClickHouse, BigQuery, Ansi." + Generic, MySQL, PostgreSQL, Hive, SQLite, Snowflake, Redshift, \ + MsSQL, ClickHouse, BigQuery, Ansi." 
) })?; @@ -603,7 +614,7 @@ impl SessionState { ) -> datafusion_common::Result { let dialect = self.config.options().sql_parser.dialect.as_str(); - let sql_expr = self.sql_to_expr(sql, dialect)?; + let sql_expr = self.sql_to_expr_with_alias(sql, dialect)?; let provider = SessionContextProvider { state: self, }; let query = SqlToRel::new_with_options(&provider, self.get_parser_options()); - query.sql_to_expr(sql_expr, df_schema, &mut PlannerContext::new()) + query.sql_to_expr_with_alias(sql_expr, df_schema, &mut PlannerContext::new()) } /// Returns the [`Analyzer`] for this session diff --git a/datafusion/core/src/physical_optimizer/sort_pushdown.rs b/datafusion/core/src/physical_optimizer/sort_pushdown.rs index d48c7118cb8e..6c761f674b3b 100644 --- a/datafusion/core/src/physical_optimizer/sort_pushdown.rs +++ b/datafusion/core/src/physical_optimizer/sort_pushdown.rs @@ -28,6 +28,7 @@ use crate::physical_plan::repartition::RepartitionExec; use crate::physical_plan::sorts::sort::SortExec; use crate::physical_plan::tree_node::PlanContext; use crate::physical_plan::{ExecutionPlan, ExecutionPlanProperties}; +use arrow_schema::SchemaRef; use datafusion_common::tree_node::{ ConcreteTreeNode, Transformed, TreeNode, TreeNodeRecursion, @@ -38,6 +39,8 @@ use datafusion_physical_expr::expressions::Column; use datafusion_physical_expr::utils::collect_columns; use datafusion_physical_expr::PhysicalSortRequirement; use datafusion_physical_expr_common::sort_expr::{LexOrdering, LexRequirement}; +use datafusion_physical_plan::joins::utils::ColumnIndex; +use datafusion_physical_plan::joins::HashJoinExec; /// This is a "data class" we use within the [`EnforceSorting`] rule to push /// down [`SortExec`] in the plan. In some cases, we can reduce the total @@ -294,6 +297,8 @@ fn pushdown_requirement_to_children( .then(|| LexRequirement::new(parent_required.to_vec())); Ok(Some(vec![req])) } + } else if let Some(hash_join) = plan.as_any().downcast_ref::() { + handle_hash_join(hash_join, parent_required) } else { handle_custom_pushdown(plan, parent_required, maintains_input_order) } @@ -606,6 +611,102 @@ fn handle_custom_pushdown( } } +// For hash joins we only maintain the input order for the right child, +// for join types Inner, Right, RightSemi and RightAnti +fn handle_hash_join( + plan: &HashJoinExec, + parent_required: &LexRequirement, +) -> Result>>> { + // If there's no requirement from the parent or the plan has no children + // or the join type is not Inner, Right, RightSemi, RightAnti, return early + if parent_required.is_empty() || !plan.maintains_input_order()[1] { + return Ok(None); + } + + // Collect all unique column indices used in the parent-required sorting expression + let all_indices: HashSet = parent_required + .iter() + .flat_map(|order| { + collect_columns(&order.expr) + .into_iter() + .map(|col| col.index()) + .collect::>() + }) + .collect(); + + let column_indices = build_join_column_index(plan); + let projected_indices: Vec<_> = if let Some(projection) = &plan.projection { + projection.iter().map(|&i| &column_indices[i]).collect() + } else { + column_indices.iter().collect() + }; + let len_of_left_fields = projected_indices + .iter() + .filter(|ci| ci.side == JoinSide::Left) + .count(); + + let all_from_right_child = all_indices.iter().all(|i| *i >= len_of_left_fields); + + // If all columns are from the right child, update the parent requirements + if all_from_right_child { + // Transform the parent-required expression for the child schema by adjusting
columns + let updated_parent_req = parent_required + .iter() + .map(|req| { + let child_schema = plan.children()[1].schema(); + let updated_columns = Arc::clone(&req.expr) + .transform_up(|expr| { + if let Some(col) = expr.as_any().downcast_ref::() { + let index = projected_indices[col.index()].index; + Ok(Transformed::yes(Arc::new(Column::new( + child_schema.field(index).name(), + index, + )))) + } else { + Ok(Transformed::no(expr)) + } + })? + .data; + Ok(PhysicalSortRequirement::new(updated_columns, req.options)) + }) + .collect::>>()?; + + // Populating with the updated requirements for children that maintain order + Ok(Some(vec![ + None, + Some(LexRequirement::new(updated_parent_req)), + ])) + } else { + Ok(None) + } +} + +// This function builds the column index for the hash join, used when +// pushing down sort requirements to the right child +fn build_join_column_index(plan: &HashJoinExec) -> Vec { + let map_fields = |schema: SchemaRef, side: JoinSide| { + schema + .fields() + .iter() + .enumerate() + .map(|(index, _)| ColumnIndex { index, side }) + .collect::>() + }; + + match plan.join_type() { + JoinType::Inner | JoinType::Right => { + map_fields(plan.left().schema(), JoinSide::Left) + .into_iter() + .chain(map_fields(plan.right().schema(), JoinSide::Right)) + .collect::>() + } + JoinType::RightSemi | JoinType::RightAnti => { + map_fields(plan.right().schema(), JoinSide::Right) + } + _ => unreachable!("unexpected join type: {}", plan.join_type()), + } +} + /// Define the Requirements Compatibility #[derive(Debug)] enum RequirementsCompatibility { diff --git a/datafusion/execution/Cargo.toml b/datafusion/execution/Cargo.toml index ac1eb729b6ff..bb86868a8214 100644 --- a/datafusion/execution/Cargo.toml +++ b/datafusion/execution/Cargo.toml @@ -37,7 +37,6 @@ path = "src/lib.rs" [dependencies] arrow = { workspace = true } -chrono = { workspace = true } dashmap = { workspace = true } datafusion-common = { workspace = true, default-features = true } datafusion-expr = { workspace = true } @@ -48,3 +47,6 @@ parking_lot = { workspace = true } rand = { workspace = true } tempfile = { workspace = true } url = { workspace = true } + +[dev-dependencies] +chrono = { workspace = true } diff --git a/datafusion/expr-common/Cargo.toml b/datafusion/expr-common/Cargo.toml index 109d8e0b89a6..1ccc6fc17293 100644 --- a/datafusion/expr-common/Cargo.toml +++ b/datafusion/expr-common/Cargo.toml @@ -40,4 +40,6 @@ path = "src/lib.rs" arrow = { workspace = true } datafusion-common = { workspace = true } itertools = { workspace = true } + +[dev-dependencies] paste = "^1.0" diff --git a/datafusion/expr/Cargo.toml b/datafusion/expr/Cargo.toml index 438662e0642b..2f41292f680f 100644 --- a/datafusion/expr/Cargo.toml +++ b/datafusion/expr/Cargo.toml @@ -38,10 +38,7 @@ path = "src/lib.rs" [features] [dependencies] -ahash = { workspace = true } arrow = { workspace = true } -arrow-array = { workspace = true } -arrow-buffer = { workspace = true } chrono = { workspace = true } datafusion-common = { workspace = true } datafusion-doc = { workspace = true } @@ -54,8 +51,6 @@ paste = "^1.0" recursive = { workspace = true } serde_json = { workspace = true } sqlparser = { workspace = true } -strum = { version = "0.26.1", features = ["derive"] } -strum_macros = "0.26.0" [dev-dependencies] ctor = { workspace = true } diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index 681eb3c0afd5..a44dd24039dc 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -416,9 +416,10 @@ pub
struct SimpleScalarUDF { impl Debug for SimpleScalarUDF { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - f.debug_struct("ScalarUDF") + f.debug_struct("SimpleScalarUDF") .field("name", &self.name) .field("signature", &self.signature) + .field("return_type", &self.return_type) .field("fun", &"") .finish() } @@ -524,9 +525,10 @@ pub struct SimpleAggregateUDF { impl Debug for SimpleAggregateUDF { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - f.debug_struct("AggregateUDF") + f.debug_struct("SimpleAggregateUDF") .field("name", &self.name) .field("signature", &self.signature) + .field("return_type", &self.return_type) .field("fun", &"") .finish() } diff --git a/datafusion/expr/src/udf.rs b/datafusion/expr/src/udf.rs index bf9c9f407eff..809c78f30eff 100644 --- a/datafusion/expr/src/udf.rs +++ b/datafusion/expr/src/udf.rs @@ -326,13 +326,15 @@ where } } +/// Arguments passed to [`ScalarUDFImpl::invoke_with_args`] when invoking a +/// scalar function. pub struct ScalarFunctionArgs<'a> { - // The evaluated arguments to the function - pub args: &'a [ColumnarValue], - // The number of rows in record batch being evaluated + /// The evaluated arguments to the function + pub args: Vec, + /// The number of rows in the record batch being evaluated pub number_rows: usize, - // The return type of the scalar function returned (from `return_type` or `return_type_from_exprs`) - // when creating the physical expression from the logical expression + /// The return type of the scalar function (as determined by `return_type` or `return_type_from_exprs` + /// when creating the physical expression from the logical expression) pub return_type: &'a DataType, } @@ -539,7 +541,7 @@ pub trait ScalarUDFImpl: Debug + Send + Sync { /// [`ColumnarValue::values_to_arrays`] can be used to convert the arguments /// to arrays, which will likely be simpler code, but be slower.
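+    /// For example, a hypothetical implementation might convert the arguments
+    /// to arrays and call an ordinary kernel (illustrative sketch only; the
+    /// `my_kernel` function is invented for this example):
+    ///
+    /// ```ignore
+    /// fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
+    ///     // convert the (possibly scalar) arguments to ArrayRefs
+    ///     let arrays = ColumnarValue::values_to_arrays(&args.args)?;
+    ///     // compute the result using a regular arrow kernel
+    ///     let result: ArrayRef = my_kernel(&arrays)?;
+    ///     Ok(ColumnarValue::Array(result))
+    /// }
+    /// ```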
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { - self.invoke_batch(args.args, args.number_rows) + self.invoke_batch(&args.args, args.number_rows) } /// Invoke the function without `args`, instead the number of rows are provided, diff --git a/datafusion/ffi/Cargo.toml b/datafusion/ffi/Cargo.toml index b74bb230a0f3..fb4701cd8988 100644 --- a/datafusion/ffi/Cargo.toml +++ b/datafusion/ffi/Cargo.toml @@ -42,10 +42,10 @@ async-ffi = { version = "0.5.0", features = ["abi_stable"] } async-trait = { workspace = true } datafusion = { workspace = true, default-features = false } datafusion-proto = { workspace = true } -doc-comment = { workspace = true } futures = { workspace = true } log = { workspace = true } prost = { workspace = true } [dev-dependencies] +doc-comment = { workspace = true } tokio = { workspace = true } diff --git a/datafusion/functions-aggregate-common/Cargo.toml b/datafusion/functions-aggregate-common/Cargo.toml index 664746808fb4..cf6eb99e60c6 100644 --- a/datafusion/functions-aggregate-common/Cargo.toml +++ b/datafusion/functions-aggregate-common/Cargo.toml @@ -42,10 +42,10 @@ arrow = { workspace = true } datafusion-common = { workspace = true } datafusion-expr-common = { workspace = true } datafusion-physical-expr-common = { workspace = true } -rand = { workspace = true } [dev-dependencies] criterion = "0.5" +rand = { workspace = true } [[bench]] harness = false diff --git a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/accumulate.rs b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/accumulate.rs index ac4d0e75535e..e629e99e1657 100644 --- a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/accumulate.rs +++ b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/accumulate.rs @@ -371,6 +371,75 @@ pub fn accumulate( } } +/// Accumulates over multiple value columns (e.g. `corr(c1, c2)`). +/// +/// This method assumes that for any input record index, if any of the value columns +/// is null, or the record is filtered out by `opt_filter`, then the record is ignored +/// (it won't be accumulated by `value_fn`). +/// +/// # Arguments +/// +/// * `group_indices` - To which groups do the rows in `value_columns` belong +/// * `value_columns` - The input arrays to accumulate +/// * `opt_filter` - Optional filter array. If present, only rows where filter is `Some(true)` are included +/// * `value_fn` - Callback function for each valid row, with parameters: +/// * `group_idx`: The group index for the current row +/// * `batch_idx`: The index of the current row in the input arrays +/// * `columns`: Reference to all input arrays for accessing values +pub fn accumulate_multiple( + group_indices: &[usize], + value_columns: &[&PrimitiveArray], + opt_filter: Option<&BooleanArray>, + mut value_fn: F, +) where + T: ArrowPrimitiveType + Send, + F: FnMut(usize, usize, &[&PrimitiveArray]) + Send, +{ + // Calculate `valid_indices` to accumulate, non-valid indices are ignored. + // `valid_indices` is a bit mask corresponding to the `group_indices`. An index + // is considered valid if: + // 1. All columns are non-null at this index. + // 2. Not filtered out by `opt_filter` + + // Take AND from all null buffers of `value_columns`. + let combined_nulls = value_columns + .iter() + .map(|arr| arr.logical_nulls()) + .fold(None, |acc, nulls| { + NullBuffer::union(acc.as_ref(), nulls.as_ref()) + }); + + // Take AND from previous combined nulls and `opt_filter`.
+    let valid_indices = match (combined_nulls, opt_filter) {
+        (None, None) => None,
+        (None, Some(filter)) => Some(filter.clone()),
+        (Some(nulls), None) => Some(BooleanArray::new(nulls.inner().clone(), None)),
+        (Some(nulls), Some(filter)) => {
+            let combined = nulls.inner() & filter.values();
+            Some(BooleanArray::new(combined, None))
+        }
+    };

+    for col in value_columns.iter() {
+        debug_assert_eq!(col.len(), group_indices.len());
+    }

+    match valid_indices {
+        None => {
+            for (batch_idx, &group_idx) in group_indices.iter().enumerate() {
+                value_fn(group_idx, batch_idx, value_columns);
+            }
+        }
+        Some(valid_indices) => {
+            for (batch_idx, &group_idx) in group_indices.iter().enumerate() {
+                if valid_indices.value(batch_idx) {
+                    value_fn(group_idx, batch_idx, value_columns);
+                }
+            }
+        }
+    }
+}
+
 /// This function is called to update the accumulator state per row
 /// when the value is not needed (e.g. COUNT)
 ///
@@ -528,7 +597,7 @@ fn initialize_builder(
 mod test {
     use super::*;

-    use arrow::array::UInt32Array;
+    use arrow::array::{Int32Array, UInt32Array};
     use rand::{rngs::ThreadRng, Rng};
     use std::collections::HashSet;

@@ -940,4 +1009,107 @@ mod test {
             .collect()
         }
     }
+
+    #[test]
+    fn test_accumulate_multiple_no_nulls_no_filter() {
+        let group_indices = vec![0, 1, 0, 1];
+        let values1 = Int32Array::from(vec![1, 2, 3, 4]);
+        let values2 = Int32Array::from(vec![10, 20, 30, 40]);
+        let value_columns = [values1, values2];
+
+        let mut accumulated = vec![];
+        accumulate_multiple(
+            &group_indices,
+            &value_columns.iter().collect::<Vec<_>>(),
+            None,
+            |group_idx, batch_idx, columns| {
+                let values = columns.iter().map(|col| col.value(batch_idx)).collect();
+                accumulated.push((group_idx, values));
+            },
+        );
+
+        let expected = vec![
+            (0, vec![1, 10]),
+            (1, vec![2, 20]),
+            (0, vec![3, 30]),
+            (1, vec![4, 40]),
+        ];
+        assert_eq!(accumulated, expected);
+    }
+
+    #[test]
+    fn test_accumulate_multiple_with_nulls() {
+        let group_indices = vec![0, 1, 0, 1];
+        let values1 = Int32Array::from(vec![Some(1), None, Some(3), Some(4)]);
+        let values2 = Int32Array::from(vec![Some(10), Some(20), None, Some(40)]);
+        let value_columns = [values1, values2];
+
+        let mut accumulated = vec![];
+        accumulate_multiple(
+            &group_indices,
+            &value_columns.iter().collect::<Vec<_>>(),
+            None,
+            |group_idx, batch_idx, columns| {
+                let values = columns.iter().map(|col| col.value(batch_idx)).collect();
+                accumulated.push((group_idx, values));
+            },
+        );
+
+        // Only rows where both columns are non-null should be accumulated
+        let expected = vec![(0, vec![1, 10]), (1, vec![4, 40])];
+        assert_eq!(accumulated, expected);
+    }
+
+    #[test]
+    fn test_accumulate_multiple_with_filter() {
+        let group_indices = vec![0, 1, 0, 1];
+        let values1 = Int32Array::from(vec![1, 2, 3, 4]);
+        let values2 = Int32Array::from(vec![10, 20, 30, 40]);
+        let value_columns = [values1, values2];
+
+        let filter = BooleanArray::from(vec![true, false, true, false]);
+
+        let mut accumulated = vec![];
+        accumulate_multiple(
+            &group_indices,
+            &value_columns.iter().collect::<Vec<_>>(),
+            Some(&filter),
+            |group_idx, batch_idx, columns| {
+                let values = columns.iter().map(|col| col.value(batch_idx)).collect();
+                accumulated.push((group_idx, values));
+            },
+        );
+
+        // Only rows where the filter is true should be accumulated
+        let expected = vec![(0, vec![1, 10]), (0, vec![3, 30])];
+        assert_eq!(accumulated, expected);
+    }
+
+    #[test]
+    fn test_accumulate_multiple_with_nulls_and_filter() {
+        let group_indices = vec![0, 1, 0, 1];
+        let values1 = Int32Array::from(vec![Some(1), None, Some(3),
Some(4)]); + let values2 = Int32Array::from(vec![Some(10), Some(20), None, Some(40)]); + let value_columns = [values1, values2]; + + let filter = BooleanArray::from(vec![true, true, true, false]); + + let mut accumulated = vec![]; + accumulate_multiple( + &group_indices, + &value_columns.iter().collect::>(), + Some(&filter), + |group_idx, batch_idx, columns| { + let values = columns.iter().map(|col| col.value(batch_idx)).collect(); + accumulated.push((group_idx, values)); + }, + ); + + // Only rows where both: + // 1. Filter is true + // 2. Both columns are non-null + // should be accumulated + let expected = [(0, vec![1, 10])]; + assert_eq!(accumulated, expected); + } } diff --git a/datafusion/functions-aggregate/src/approx_distinct.rs b/datafusion/functions-aggregate/src/approx_distinct.rs index 74691ba740fd..1d378fff176f 100644 --- a/datafusion/functions-aggregate/src/approx_distinct.rs +++ b/datafusion/functions-aggregate/src/approx_distinct.rs @@ -31,7 +31,6 @@ use datafusion_common::ScalarValue; use datafusion_common::{ downcast_value, internal_err, not_impl_err, DataFusionError, Result, }; -use datafusion_doc::DocSection; use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion_expr::utils::format_state_name; use datafusion_expr::{ @@ -42,7 +41,6 @@ use std::any::Any; use std::fmt::{Debug, Formatter}; use std::hash::Hash; use std::marker::PhantomData; -use std::sync::OnceLock; make_udaf_expr_and_func!( ApproxDistinct, diff --git a/datafusion/functions-aggregate/src/approx_median.rs b/datafusion/functions-aggregate/src/approx_median.rs index d4441da61292..5d174a752296 100644 --- a/datafusion/functions-aggregate/src/approx_median.rs +++ b/datafusion/functions-aggregate/src/approx_median.rs @@ -19,13 +19,11 @@ use std::any::Any; use std::fmt::Debug; -use std::sync::OnceLock; use arrow::{datatypes::DataType, datatypes::Field}; use arrow_schema::DataType::{Float64, UInt64}; use datafusion_common::{not_impl_err, plan_err, Result}; -use datafusion_doc::DocSection; use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion_expr::type_coercion::aggregates::NUMERICS; use datafusion_expr::utils::format_state_name; diff --git a/datafusion/functions-aggregate/src/approx_percentile_cont.rs b/datafusion/functions-aggregate/src/approx_percentile_cont.rs index 13407fecf220..61424e8f2445 100644 --- a/datafusion/functions-aggregate/src/approx_percentile_cont.rs +++ b/datafusion/functions-aggregate/src/approx_percentile_cont.rs @@ -18,7 +18,7 @@ use std::any::Any; use std::fmt::{Debug, Formatter}; use std::mem::size_of_val; -use std::sync::{Arc, OnceLock}; +use std::sync::Arc; use arrow::array::{Array, RecordBatch}; use arrow::compute::{filter, is_not_null}; @@ -35,7 +35,6 @@ use datafusion_common::{ downcast_value, internal_err, not_impl_datafusion_err, not_impl_err, plan_err, DataFusionError, Result, ScalarValue, }; -use datafusion_doc::DocSection; use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion_expr::type_coercion::aggregates::{INTEGERS, NUMERICS}; use datafusion_expr::utils::format_state_name; diff --git a/datafusion/functions-aggregate/src/approx_percentile_cont_with_weight.rs b/datafusion/functions-aggregate/src/approx_percentile_cont_with_weight.rs index 485874aeb284..10b9b06f1f94 100644 --- a/datafusion/functions-aggregate/src/approx_percentile_cont_with_weight.rs +++ b/datafusion/functions-aggregate/src/approx_percentile_cont_with_weight.rs @@ -18,7 +18,7 @@ use std::any::Any; use std::fmt::{Debug, Formatter}; 
use std::mem::size_of_val; -use std::sync::{Arc, OnceLock}; +use std::sync::Arc; use arrow::{ array::ArrayRef, @@ -27,7 +27,6 @@ use arrow::{ use datafusion_common::ScalarValue; use datafusion_common::{not_impl_err, plan_err, Result}; -use datafusion_doc::DocSection; use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion_expr::type_coercion::aggregates::NUMERICS; use datafusion_expr::Volatility::Immutable; diff --git a/datafusion/functions-aggregate/src/array_agg.rs b/datafusion/functions-aggregate/src/array_agg.rs index 98530a9fc236..b75de83f6ace 100644 --- a/datafusion/functions-aggregate/src/array_agg.rs +++ b/datafusion/functions-aggregate/src/array_agg.rs @@ -25,7 +25,6 @@ use datafusion_common::cast::as_list_array; use datafusion_common::utils::{get_row_at_idx, SingleRowListArrayBuilder}; use datafusion_common::{exec_err, ScalarValue}; use datafusion_common::{internal_err, Result}; -use datafusion_doc::DocSection; use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion_expr::utils::format_state_name; use datafusion_expr::{Accumulator, Signature, Volatility}; @@ -36,7 +35,7 @@ use datafusion_macros::user_doc; use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; use std::collections::{HashSet, VecDeque}; use std::mem::{size_of, size_of_val}; -use std::sync::{Arc, OnceLock}; +use std::sync::Arc; make_udaf_expr_and_func!( ArrayAgg, diff --git a/datafusion/functions-aggregate/src/average.rs b/datafusion/functions-aggregate/src/average.rs index 65ca441517a0..18874f831e9d 100644 --- a/datafusion/functions-aggregate/src/average.rs +++ b/datafusion/functions-aggregate/src/average.rs @@ -42,14 +42,13 @@ use datafusion_functions_aggregate_common::aggregate::groups_accumulator::nulls: filtered_null_mask, set_nulls, }; -use datafusion_doc::DocSection; use datafusion_functions_aggregate_common::utils::DecimalAverager; use datafusion_macros::user_doc; use log::debug; use std::any::Any; use std::fmt::Debug; use std::mem::{size_of, size_of_val}; -use std::sync::{Arc, OnceLock}; +use std::sync::Arc; make_udaf_expr_and_func!( Avg, diff --git a/datafusion/functions-aggregate/src/bool_and_or.rs b/datafusion/functions-aggregate/src/bool_and_or.rs index 1b5b20f43b3e..29dfc68e0576 100644 --- a/datafusion/functions-aggregate/src/bool_and_or.rs +++ b/datafusion/functions-aggregate/src/bool_and_or.rs @@ -19,7 +19,6 @@ use std::any::Any; use std::mem::size_of_val; -use std::sync::OnceLock; use arrow::array::ArrayRef; use arrow::array::BooleanArray; @@ -38,7 +37,6 @@ use datafusion_expr::{ Signature, Volatility, }; -use datafusion_doc::DocSection; use datafusion_functions_aggregate_common::aggregate::groups_accumulator::bool_op::BooleanGroupsAccumulator; use datafusion_macros::user_doc; diff --git a/datafusion/functions-aggregate/src/correlation.rs b/datafusion/functions-aggregate/src/correlation.rs index b40555bf6c7f..72c1f6dbaed2 100644 --- a/datafusion/functions-aggregate/src/correlation.rs +++ b/datafusion/functions-aggregate/src/correlation.rs @@ -20,18 +20,25 @@ use std::any::Any; use std::fmt::Debug; use std::mem::size_of_val; -use std::sync::{Arc, OnceLock}; +use std::sync::Arc; -use arrow::compute::{and, filter, is_not_null}; +use arrow::array::{ + downcast_array, Array, AsArray, BooleanArray, BooleanBufferBuilder, Float64Array, + UInt64Array, +}; +use arrow::compute::{and, filter, is_not_null, kernels::cast}; +use arrow::datatypes::{Float64Type, UInt64Type}; use arrow::{ array::ArrayRef, datatypes::{DataType, Field}, }; 
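With the two trait overrides added just below, grouped `corr` queries no longer fall back to the row-based `Accumulator`. One way to exercise the new path end to end (a sketch; the table and column names are made up):

```rust
use datafusion::prelude::{CsvReadOptions, SessionContext};

#[tokio::main]
async fn main() -> datafusion::error::Result<()> {
    let ctx = SessionContext::new();
    ctx.register_csv("t", "data.csv", CsvReadOptions::new()).await?;
    // With GROUP BY, corr(c1, c2) should now be evaluated by the vectorized
    // CorrelationGroupsAccumulator added in this file
    ctx.sql("SELECT c3, corr(c1, c2) FROM t GROUP BY c3")
        .await?
        .show()
        .await?;
    Ok(())
}
```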
+use datafusion_expr::{EmitTo, GroupsAccumulator};
+use datafusion_functions_aggregate_common::aggregate::groups_accumulator::accumulate::accumulate_multiple;
+use log::debug;

 use crate::covariance::CovarianceAccumulator;
 use crate::stddev::StddevAccumulator;
 use datafusion_common::{plan_err, Result, ScalarValue};
-use datafusion_doc::DocSection;
 use datafusion_expr::{
     function::{AccumulatorArgs, StateFieldsArgs},
     type_coercion::aggregates::NUMERICS,
@@ -129,6 +136,18 @@ impl AggregateUDFImpl for Correlation {
     fn documentation(&self) -> Option<&Documentation> {
         self.doc()
     }
+
+    fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool {
+        true
+    }
+
+    fn create_groups_accumulator(
+        &self,
+        _args: AccumulatorArgs,
+    ) -> Result<Box<dyn GroupsAccumulator>> {
+        debug!("GroupsAccumulator is created for aggregate function `corr(c1, c2)`");
+        Ok(Box::new(CorrelationGroupsAccumulator::new()))
+    }
 }

 /// An accumulator to compute correlation
@@ -253,3 +272,308 @@ impl Accumulator for CorrelationAccumulator {
         Ok(())
     }
 }
+
+#[derive(Default)]
+pub struct CorrelationGroupsAccumulator {
+    // Number of elements for each group.
+    // This is also used to track nulls: if a group has 0 valid values accumulated,
+    // the final aggregation result will be null.
+    count: Vec<u64>,
+    // Sum of x values for each group
+    sum_x: Vec<f64>,
+    // Sum of y values
+    sum_y: Vec<f64>,
+    // Sum of x*y
+    sum_xy: Vec<f64>,
+    // Sum of x^2
+    sum_xx: Vec<f64>,
+    // Sum of y^2
+    sum_yy: Vec<f64>,
+}
+
+impl CorrelationGroupsAccumulator {
+    pub fn new() -> Self {
+        Default::default()
+    }
+}
+
+/// Specialized version of `accumulate_multiple` for correlation's `merge_batch`
+///
+/// Note: Arrays in `state_arrays` should not have null values, because they are
+/// all intermediate states created within the accumulator rather than inputs
+/// from outside.
+fn accumulate_correlation_states(
+    group_indices: &[usize],
+    state_arrays: (
+        &UInt64Array,  // count
+        &Float64Array, // sum_x
+        &Float64Array, // sum_y
+        &Float64Array, // sum_xy
+        &Float64Array, // sum_xx
+        &Float64Array, // sum_yy
+    ),
+    mut value_fn: impl FnMut(usize, u64, &[f64]),
+) {
+    let (counts, sum_x, sum_y, sum_xy, sum_xx, sum_yy) = state_arrays;
+
+    assert_eq!(counts.null_count(), 0);
+    assert_eq!(sum_x.null_count(), 0);
+    assert_eq!(sum_y.null_count(), 0);
+    assert_eq!(sum_xy.null_count(), 0);
+    assert_eq!(sum_xx.null_count(), 0);
+    assert_eq!(sum_yy.null_count(), 0);
+
+    let counts_values = counts.values().as_ref();
+    let sum_x_values = sum_x.values().as_ref();
+    let sum_y_values = sum_y.values().as_ref();
+    let sum_xy_values = sum_xy.values().as_ref();
+    let sum_xx_values = sum_xx.values().as_ref();
+    let sum_yy_values = sum_yy.values().as_ref();
+
+    for (idx, &group_idx) in group_indices.iter().enumerate() {
+        let row = [
+            sum_x_values[idx],
+            sum_y_values[idx],
+            sum_xy_values[idx],
+            sum_xx_values[idx],
+            sum_yy_values[idx],
+        ];
+        value_fn(group_idx, counts_values[idx], &row);
+    }
+}
+
+/// GroupsAccumulator implementation for `corr(x, y)` that computes the Pearson
+/// correlation coefficient between two numeric columns.
+///
+/// Online algorithm for correlation:
+///
+/// r = (n * sum_xy - sum_x * sum_y) / sqrt((n * sum_xx - sum_x^2) * (n * sum_yy - sum_y^2))
+/// where:
+/// n = number of observations
+/// sum_x = sum of x values
+/// sum_y = sum of y values
+/// sum_xy = sum of (x * y)
+/// sum_xx = sum of x^2 values
+/// sum_yy = sum of y^2 values
+///
+/// Reference:
+impl GroupsAccumulator for CorrelationGroupsAccumulator {
+    fn update_batch(
+        &mut self,
+        values: &[ArrayRef],
+        group_indices: &[usize],
+        opt_filter: Option<&BooleanArray>,
+        total_num_groups: usize,
+    ) -> Result<()> {
+        self.count.resize(total_num_groups, 0);
+        self.sum_x.resize(total_num_groups, 0.0);
+        self.sum_y.resize(total_num_groups, 0.0);
+        self.sum_xy.resize(total_num_groups, 0.0);
+        self.sum_xx.resize(total_num_groups, 0.0);
+        self.sum_yy.resize(total_num_groups, 0.0);
+
+        let array_x = &cast(&values[0], &DataType::Float64)?;
+        let array_x = downcast_array::<Float64Array>(array_x);
+        let array_y = &cast(&values[1], &DataType::Float64)?;
+        let array_y = downcast_array::<Float64Array>(array_y);
+
+        accumulate_multiple(
+            group_indices,
+            &[&array_x, &array_y],
+            opt_filter,
+            |group_index, batch_index, columns| {
+                let x = columns[0].value(batch_index);
+                let y = columns[1].value(batch_index);
+                self.count[group_index] += 1;
+                self.sum_x[group_index] += x;
+                self.sum_y[group_index] += y;
+                self.sum_xy[group_index] += x * y;
+                self.sum_xx[group_index] += x * x;
+                self.sum_yy[group_index] += y * y;
+            },
+        );
+
+        Ok(())
+    }
+
+    fn merge_batch(
+        &mut self,
+        values: &[ArrayRef],
+        group_indices: &[usize],
+        opt_filter: Option<&BooleanArray>,
+        total_num_groups: usize,
+    ) -> Result<()> {
+        // Resize vectors to accommodate the total number of groups
+        self.count.resize(total_num_groups, 0);
+        self.sum_x.resize(total_num_groups, 0.0);
+        self.sum_y.resize(total_num_groups, 0.0);
+        self.sum_xy.resize(total_num_groups, 0.0);
+        self.sum_xx.resize(total_num_groups, 0.0);
+        self.sum_yy.resize(total_num_groups, 0.0);
+
+        // Extract arrays from input values
+        let partial_counts = values[0].as_primitive::<UInt64Type>();
+        let partial_sum_x = values[1].as_primitive::<Float64Type>();
+        let partial_sum_y = values[2].as_primitive::<Float64Type>();
+        let partial_sum_xy = values[3].as_primitive::<Float64Type>();
+        let partial_sum_xx = values[4].as_primitive::<Float64Type>();
+        let partial_sum_yy = values[5].as_primitive::<Float64Type>();
+
+        assert!(opt_filter.is_none(), "aggregate filter should be applied in the partial stage; there should be no filter in the final stage");
+
+        accumulate_correlation_states(
+            group_indices,
+            (
+                partial_counts,
+                partial_sum_x,
+                partial_sum_y,
+                partial_sum_xy,
+                partial_sum_xx,
+                partial_sum_yy,
+            ),
+            |group_index, count, values| {
+                self.count[group_index] += count;
+                self.sum_x[group_index] += values[0];
+                self.sum_y[group_index] += values[1];
+                self.sum_xy[group_index] += values[2];
+                self.sum_xx[group_index] += values[3];
+                self.sum_yy[group_index] += values[4];
+            },
+        );
+
+        Ok(())
+    }
+
+    fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef> {
+        let n = match emit_to {
+            EmitTo::All => self.count.len(),
+            EmitTo::First(n) => n,
+        };
+
+        let mut values = Vec::with_capacity(n);
+        let mut nulls = BooleanBufferBuilder::new(n);
+
+        // Notes for `Null` handling:
+        // - If the `count` state of a group is 0, no valid records were accumulated
+        //   for this group, so the aggregation result is `Null`.
+        // - Correlation can't be calculated when a group has only 1 record, or when
+        //   the `denominator` state is 0. In these cases, the final aggregation
+        //   result should be `Null` (following PostgreSQL's behavior).
+        //
+        // TODO: The old DataFusion implementation returned 0.0 for these invalid
+        // cases. Update this to match PostgreSQL's behavior.
+        for i in 0..n {
+            if self.count[i] < 2 {
+                // TODO: Evaluate as `Null` (see notes above)
+                values.push(0.0);
+                nulls.append(false);
+                continue;
+            }
+
+            let count = self.count[i];
+            let sum_x = self.sum_x[i];
+            let sum_y = self.sum_y[i];
+            let sum_xy = self.sum_xy[i];
+            let sum_xx = self.sum_xx[i];
+            let sum_yy = self.sum_yy[i];
+
+            let mean_x = sum_x / count as f64;
+            let mean_y = sum_y / count as f64;
+
+            let numerator = sum_xy - sum_x * mean_y;
+            let denominator =
+                ((sum_xx - sum_x * mean_x) * (sum_yy - sum_y * mean_y)).sqrt();
+
+            if denominator == 0.0 {
+                // TODO: Evaluate as `Null` (see notes above)
+                values.push(0.0);
+                nulls.append(false);
+            } else {
+                values.push(numerator / denominator);
+                nulls.append(true);
+            }
+        }
+
+        Ok(Arc::new(Float64Array::new(
+            values.into(),
+            Some(nulls.finish().into()),
+        )))
+    }
+
+    fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
+        let n = match emit_to {
+            EmitTo::All => self.count.len(),
+            EmitTo::First(n) => n,
+        };
+
+        Ok(vec![
+            Arc::new(UInt64Array::from(self.count[0..n].to_vec())),
+            Arc::new(Float64Array::from(self.sum_x[0..n].to_vec())),
+            Arc::new(Float64Array::from(self.sum_y[0..n].to_vec())),
+            Arc::new(Float64Array::from(self.sum_xy[0..n].to_vec())),
+            Arc::new(Float64Array::from(self.sum_xx[0..n].to_vec())),
+            Arc::new(Float64Array::from(self.sum_yy[0..n].to_vec())),
+        ])
+    }
+
+    fn size(&self) -> usize {
+        size_of_val(&self.count)
+            + size_of_val(&self.sum_x)
+            + size_of_val(&self.sum_y)
+            + size_of_val(&self.sum_xy)
+            + size_of_val(&self.sum_xx)
+            + size_of_val(&self.sum_yy)
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use arrow::array::{Float64Array, UInt64Array};
+
+    #[test]
+    fn test_accumulate_correlation_states() {
+        // Test data
+        let group_indices = vec![0, 1, 0, 1];
+        let counts = UInt64Array::from(vec![1, 2, 3, 4]);
+        let sum_x = Float64Array::from(vec![10.0, 20.0, 30.0, 40.0]);
+        let sum_y = Float64Array::from(vec![1.0, 2.0, 3.0, 4.0]);
+        let sum_xy = Float64Array::from(vec![10.0, 40.0, 90.0, 160.0]);
+        let sum_xx = Float64Array::from(vec![100.0, 400.0, 900.0, 1600.0]);
+        let sum_yy = Float64Array::from(vec![1.0, 4.0, 9.0, 16.0]);
+
+        let mut accumulated = vec![];
+        accumulate_correlation_states(
+            &group_indices,
+            (&counts, &sum_x, &sum_y, &sum_xy, &sum_xx, &sum_yy),
+            |group_idx, count, values| {
+                accumulated.push((group_idx, count, values.to_vec()));
+            },
+        );
+
+        let expected = vec![
+            (0, 1, vec![10.0, 1.0, 10.0, 100.0, 1.0]),
+            (1, 2, vec![20.0, 2.0, 40.0, 400.0, 4.0]),
+            (0, 3, vec![30.0, 3.0, 90.0, 900.0, 9.0]),
+            (1, 4, vec![40.0, 4.0, 160.0, 1600.0, 16.0]),
+        ];
+        assert_eq!(accumulated, expected);
+
+        // Test that the function panics when given null values
+        let counts = UInt64Array::from(vec![Some(1), None, Some(3), Some(4)]);
+        let sum_x = Float64Array::from(vec![10.0, 20.0, 30.0, 40.0]);
+        let sum_y = Float64Array::from(vec![1.0, 2.0, 3.0, 4.0]);
+        let sum_xy = Float64Array::from(vec![10.0, 40.0, 90.0, 160.0]);
+        let sum_xx = Float64Array::from(vec![100.0, 400.0, 900.0, 1600.0]);
+        let sum_yy = Float64Array::from(vec![1.0, 4.0, 9.0, 16.0]);
+
+        let result = std::panic::catch_unwind(|| {
+            accumulate_correlation_states(
+                &group_indices,
+                (&counts, &sum_x, &sum_y, &sum_xy, &sum_xx, &sum_yy),
+                |_, _, _| {},
+            )
+        });
+        assert!(result.is_err());
+    }
+}
diff --git a/datafusion/functions-aggregate/src/count.rs b/datafusion/functions-aggregate/src/count.rs
index
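The arithmetic in `evaluate` above is the textbook single-pass form of the formula quoted in the doc comment. As a standalone check, the same computation for one group (a sketch that mirrors the code, not shared with it):

```rust
/// Pearson correlation from the running sums kept per group, mirroring
/// `CorrelationGroupsAccumulator::evaluate`; returns None where the code
/// above emits a null slot (count < 2 or zero denominator)
fn pearson_from_sums(
    count: u64,
    sum_x: f64,
    sum_y: f64,
    sum_xy: f64,
    sum_xx: f64,
    sum_yy: f64,
) -> Option<f64> {
    if count < 2 {
        return None;
    }
    let n = count as f64;
    let mean_x = sum_x / n;
    let mean_y = sum_y / n;
    // Equivalent to (n*sum_xy - sum_x*sum_y) / sqrt((n*sum_xx - sum_x^2) * (n*sum_yy - sum_y^2))
    // after dividing the numerator and denominator by n
    let numerator = sum_xy - sum_x * mean_y;
    let denominator = ((sum_xx - sum_x * mean_x) * (sum_yy - sum_y * mean_y)).sqrt();
    (denominator != 0.0).then(|| numerator / denominator)
}

fn main() {
    // x = [1, 2, 3], y = [2, 4, 6] is perfectly correlated: r = 1
    let r = pearson_from_sums(3, 6.0, 12.0, 28.0, 14.0, 56.0).unwrap();
    assert!((r - 1.0).abs() < 1e-12);
}
```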
550df8cb4f7d..b4164c211c35 100644 --- a/datafusion/functions-aggregate/src/count.rs +++ b/datafusion/functions-aggregate/src/count.rs @@ -17,7 +17,6 @@ use ahash::RandomState; use datafusion_common::stats::Precision; -use datafusion_doc::DocSection; use datafusion_functions_aggregate_common::aggregate::count_distinct::BytesViewDistinctCountAccumulator; use datafusion_macros::user_doc; use datafusion_physical_expr::expressions; @@ -25,7 +24,7 @@ use std::collections::HashSet; use std::fmt::Debug; use std::mem::{size_of, size_of_val}; use std::ops::BitAnd; -use std::sync::{Arc, OnceLock}; +use std::sync::Arc; use arrow::{ array::{ArrayRef, AsArray}, diff --git a/datafusion/functions-aggregate/src/covariance.rs b/datafusion/functions-aggregate/src/covariance.rs index adb546e4d906..ffbf2ceef052 100644 --- a/datafusion/functions-aggregate/src/covariance.rs +++ b/datafusion/functions-aggregate/src/covariance.rs @@ -19,7 +19,6 @@ use std::fmt::Debug; use std::mem::size_of_val; -use std::sync::OnceLock; use arrow::{ array::{ArrayRef, Float64Array, UInt64Array}, @@ -31,7 +30,6 @@ use datafusion_common::{ downcast_value, plan_err, unwrap_or_internal_err, DataFusionError, Result, ScalarValue, }; -use datafusion_doc::DocSection; use datafusion_expr::{ function::{AccumulatorArgs, StateFieldsArgs}, type_coercion::aggregates::NUMERICS, diff --git a/datafusion/functions-aggregate/src/first_last.rs b/datafusion/functions-aggregate/src/first_last.rs index f3e66edbc009..9ad55d91a68b 100644 --- a/datafusion/functions-aggregate/src/first_last.rs +++ b/datafusion/functions-aggregate/src/first_last.rs @@ -20,7 +20,7 @@ use std::any::Any; use std::fmt::Debug; use std::mem::size_of_val; -use std::sync::{Arc, OnceLock}; +use std::sync::Arc; use arrow::array::{ArrayRef, AsArray, BooleanArray}; use arrow::compute::{self, lexsort_to_indices, take_arrays, SortColumn}; @@ -29,7 +29,6 @@ use datafusion_common::utils::{compare_rows, get_row_at_idx}; use datafusion_common::{ arrow_datafusion_err, internal_err, DataFusionError, Result, ScalarValue, }; -use datafusion_doc::DocSection; use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion_expr::utils::{format_state_name, AggregateOrderSensitivity}; use datafusion_expr::{ diff --git a/datafusion/functions-aggregate/src/grouping.rs b/datafusion/functions-aggregate/src/grouping.rs index 36bdf68c1b0e..445774ff11e7 100644 --- a/datafusion/functions-aggregate/src/grouping.rs +++ b/datafusion/functions-aggregate/src/grouping.rs @@ -19,12 +19,10 @@ use std::any::Any; use std::fmt; -use std::sync::OnceLock; use arrow::datatypes::DataType; use arrow::datatypes::Field; use datafusion_common::{not_impl_err, Result}; -use datafusion_doc::DocSection; use datafusion_expr::function::AccumulatorArgs; use datafusion_expr::function::StateFieldsArgs; use datafusion_expr::utils::format_state_name; diff --git a/datafusion/functions-aggregate/src/median.rs b/datafusion/functions-aggregate/src/median.rs index db5fbf00165f..70f192c32ae1 100644 --- a/datafusion/functions-aggregate/src/median.rs +++ b/datafusion/functions-aggregate/src/median.rs @@ -18,7 +18,7 @@ use std::cmp::Ordering; use std::fmt::{Debug, Formatter}; use std::mem::{size_of, size_of_val}; -use std::sync::{Arc, OnceLock}; +use std::sync::Arc; use arrow::array::{downcast_integer, ArrowNumericType}; use arrow::{ @@ -34,7 +34,6 @@ use arrow::array::ArrowNativeTypeOp; use arrow::datatypes::{ArrowNativeType, ArrowPrimitiveType}; use datafusion_common::{DataFusionError, HashSet, Result, ScalarValue}; -use 
datafusion_doc::DocSection; use datafusion_expr::function::StateFieldsArgs; use datafusion_expr::{ function::AccumulatorArgs, utils::format_state_name, Accumulator, AggregateUDFImpl, diff --git a/datafusion/functions-aggregate/src/min_max.rs b/datafusion/functions-aggregate/src/min_max.rs index acbeebaad68b..a0f7634c5fa8 100644 --- a/datafusion/functions-aggregate/src/min_max.rs +++ b/datafusion/functions-aggregate/src/min_max.rs @@ -55,7 +55,6 @@ use arrow::datatypes::{ use crate::min_max::min_max_bytes::MinMaxBytesAccumulator; use datafusion_common::ScalarValue; -use datafusion_doc::DocSection; use datafusion_expr::{ function::AccumulatorArgs, Accumulator, AggregateUDFImpl, Documentation, Signature, Volatility, @@ -65,7 +64,6 @@ use datafusion_macros::user_doc; use half::f16; use std::mem::size_of_val; use std::ops::Deref; -use std::sync::OnceLock; fn get_min_max_result_type(input_types: &[DataType]) -> Result> { // make sure that the input types only has one element. diff --git a/datafusion/functions-aggregate/src/nth_value.rs b/datafusion/functions-aggregate/src/nth_value.rs index 15b9e97516ca..8252fd6baaa3 100644 --- a/datafusion/functions-aggregate/src/nth_value.rs +++ b/datafusion/functions-aggregate/src/nth_value.rs @@ -21,14 +21,13 @@ use std::any::Any; use std::collections::VecDeque; use std::mem::{size_of, size_of_val}; -use std::sync::{Arc, OnceLock}; +use std::sync::Arc; use arrow::array::{new_empty_array, ArrayRef, AsArray, StructArray}; use arrow_schema::{DataType, Field, Fields}; use datafusion_common::utils::{get_row_at_idx, SingleRowListArrayBuilder}; use datafusion_common::{exec_err, internal_err, not_impl_err, Result, ScalarValue}; -use datafusion_doc::DocSection; use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion_expr::utils::format_state_name; use datafusion_expr::{ diff --git a/datafusion/functions-aggregate/src/stddev.rs b/datafusion/functions-aggregate/src/stddev.rs index 09a39e342cce..adf86a128cfb 100644 --- a/datafusion/functions-aggregate/src/stddev.rs +++ b/datafusion/functions-aggregate/src/stddev.rs @@ -20,14 +20,13 @@ use std::any::Any; use std::fmt::{Debug, Formatter}; use std::mem::align_of_val; -use std::sync::{Arc, OnceLock}; +use std::sync::Arc; use arrow::array::Float64Array; use arrow::{array::ArrayRef, datatypes::DataType, datatypes::Field}; use datafusion_common::{internal_err, not_impl_err, Result}; use datafusion_common::{plan_err, ScalarValue}; -use datafusion_doc::DocSection; use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion_expr::utils::format_state_name; use datafusion_expr::{ diff --git a/datafusion/functions-aggregate/src/string_agg.rs b/datafusion/functions-aggregate/src/string_agg.rs index 5a52bec55f15..7643b44e11d5 100644 --- a/datafusion/functions-aggregate/src/string_agg.rs +++ b/datafusion/functions-aggregate/src/string_agg.rs @@ -22,7 +22,6 @@ use arrow_schema::DataType; use datafusion_common::cast::as_generic_string_array; use datafusion_common::Result; use datafusion_common::{not_impl_err, ScalarValue}; -use datafusion_doc::DocSection; use datafusion_expr::function::AccumulatorArgs; use datafusion_expr::{ Accumulator, AggregateUDFImpl, Documentation, Signature, TypeSignature, Volatility, @@ -31,7 +30,6 @@ use datafusion_macros::user_doc; use datafusion_physical_expr::expressions::Literal; use std::any::Any; use std::mem::size_of_val; -use std::sync::OnceLock; make_udaf_expr_and_func!( StringAgg, diff --git a/datafusion/functions-aggregate/src/sum.rs 
b/datafusion/functions-aggregate/src/sum.rs index ccc6ee3cf925..6c2854f6bc24 100644 --- a/datafusion/functions-aggregate/src/sum.rs +++ b/datafusion/functions-aggregate/src/sum.rs @@ -22,7 +22,6 @@ use datafusion_expr::utils::AggregateOrderSensitivity; use std::any::Any; use std::collections::HashSet; use std::mem::{size_of, size_of_val}; -use std::sync::OnceLock; use arrow::array::Array; use arrow::array::ArrowNativeTypeOp; @@ -35,7 +34,6 @@ use arrow::datatypes::{ }; use arrow::{array::ArrayRef, datatypes::Field}; use datafusion_common::{exec_err, not_impl_err, Result, ScalarValue}; -use datafusion_doc::DocSection; use datafusion_expr::function::AccumulatorArgs; use datafusion_expr::function::StateFieldsArgs; use datafusion_expr::utils::format_state_name; diff --git a/datafusion/functions-aggregate/src/variance.rs b/datafusion/functions-aggregate/src/variance.rs index 70b10734088f..8aa7a40ce320 100644 --- a/datafusion/functions-aggregate/src/variance.rs +++ b/datafusion/functions-aggregate/src/variance.rs @@ -25,13 +25,11 @@ use arrow::{ datatypes::{DataType, Field}, }; use std::mem::{size_of, size_of_val}; -use std::sync::OnceLock; use std::{fmt::Debug, sync::Arc}; use datafusion_common::{ downcast_value, not_impl_err, plan_err, DataFusionError, Result, ScalarValue, }; -use datafusion_doc::DocSection; use datafusion_expr::{ function::{AccumulatorArgs, StateFieldsArgs}, utils::format_state_name, diff --git a/datafusion/functions-nested/Cargo.toml b/datafusion/functions-nested/Cargo.toml index bdfb07031b8c..5310493b4e45 100644 --- a/datafusion/functions-nested/Cargo.toml +++ b/datafusion/functions-nested/Cargo.toml @@ -54,10 +54,10 @@ datafusion-physical-expr-common = { workspace = true } itertools = { workspace = true, features = ["use_std"] } log = { workspace = true } paste = "1.0.14" -rand = "0.8.5" [dev-dependencies] criterion = { version = "0.5", features = ["async_tokio"] } +rand = "0.8.5" [[bench]] harness = false diff --git a/datafusion/functions-table/Cargo.toml b/datafusion/functions-table/Cargo.toml index f667bdde5835..f722d698f3d3 100644 --- a/datafusion/functions-table/Cargo.toml +++ b/datafusion/functions-table/Cargo.toml @@ -38,25 +38,14 @@ path = "src/lib.rs" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] -ahash = { workspace = true } arrow = { workspace = true } -arrow-schema = { workspace = true } async-trait = { workspace = true } datafusion-catalog = { workspace = true } datafusion-common = { workspace = true } -datafusion-execution = { workspace = true } datafusion-expr = { workspace = true } -datafusion-functions-aggregate-common = { workspace = true } -datafusion-physical-expr = { workspace = true } -datafusion-physical-expr-common = { workspace = true } datafusion-physical-plan = { workspace = true } -half = { workspace = true } -indexmap = { workspace = true } -log = { workspace = true } parking_lot = { workspace = true } paste = "1.0.14" [dev-dependencies] arrow = { workspace = true, features = ["test_utils"] } -criterion = "0.5" -rand = { workspace = true } diff --git a/datafusion/functions-window/Cargo.toml b/datafusion/functions-window/Cargo.toml index 262c21fcec65..fc1bc51bcc66 100644 --- a/datafusion/functions-window/Cargo.toml +++ b/datafusion/functions-window/Cargo.toml @@ -39,8 +39,10 @@ path = "src/lib.rs" [dependencies] datafusion-common = { workspace = true } +datafusion-doc = { workspace = true } datafusion-expr = { workspace = true } datafusion-functions-window-common = { workspace = 
true } +datafusion-macros = { workspace = true } datafusion-physical-expr = { workspace = true } datafusion-physical-expr-common = { workspace = true } log = { workspace = true } diff --git a/datafusion/functions-window/src/cume_dist.rs b/datafusion/functions-window/src/cume_dist.rs index 2523fd1cfe57..d777f7932b0e 100644 --- a/datafusion/functions-window/src/cume_dist.rs +++ b/datafusion/functions-window/src/cume_dist.rs @@ -21,18 +21,18 @@ use datafusion_common::arrow::array::{ArrayRef, Float64Array}; use datafusion_common::arrow::datatypes::DataType; use datafusion_common::arrow::datatypes::Field; use datafusion_common::Result; -use datafusion_expr::window_doc_sections::DOC_SECTION_RANKING; use datafusion_expr::{ Documentation, PartitionEvaluator, Signature, Volatility, WindowUDFImpl, }; use datafusion_functions_window_common::field; use datafusion_functions_window_common::partition::PartitionEvaluatorArgs; +use datafusion_macros::user_doc; use field::WindowUDFFieldArgs; use std::any::Any; use std::fmt::Debug; use std::iter; use std::ops::Range; -use std::sync::{Arc, OnceLock}; +use std::sync::Arc; define_udwf_and_expr!( CumeDist, @@ -41,6 +41,11 @@ define_udwf_and_expr!( ); /// CumeDist calculates the cume_dist in the window function with order by +#[user_doc( + doc_section(label = "Ranking Functions"), + description = "Relative rank of the current row: (number of rows preceding or peer with current row) / (total rows).", + syntax_example = "cume_dist()" +)] #[derive(Debug)] pub struct CumeDist { signature: Signature, @@ -86,19 +91,10 @@ impl WindowUDFImpl for CumeDist { } fn documentation(&self) -> Option<&Documentation> { - Some(get_cume_dist_doc()) + self.doc() } } -static DOCUMENTATION: OnceLock = OnceLock::new(); - -fn get_cume_dist_doc() -> &'static Documentation { - DOCUMENTATION.get_or_init(|| { - Documentation::builder(DOC_SECTION_RANKING, "Relative rank of the current row: (number of rows preceding or peer with current row) / (total rows).", "cume_dist()") - .build() - }) -} - #[derive(Debug, Default)] pub(crate) struct CumeDistEvaluator; diff --git a/datafusion/functions-window/src/ntile.rs b/datafusion/functions-window/src/ntile.rs index 06bf32f9859f..180f7ab02c03 100644 --- a/datafusion/functions-window/src/ntile.rs +++ b/datafusion/functions-window/src/ntile.rs @@ -19,7 +19,7 @@ use std::any::Any; use std::fmt::Debug; -use std::sync::{Arc, OnceLock}; +use std::sync::Arc; use crate::utils::{ get_scalar_value_from_args, get_signed_integer, get_unsigned_integer, @@ -27,12 +27,12 @@ use crate::utils::{ use datafusion_common::arrow::array::{ArrayRef, UInt64Array}; use datafusion_common::arrow::datatypes::{DataType, Field}; use datafusion_common::{exec_err, DataFusionError, Result}; -use datafusion_expr::window_doc_sections::DOC_SECTION_RANKING; use datafusion_expr::{ Documentation, Expr, PartitionEvaluator, Signature, Volatility, WindowUDFImpl, }; use datafusion_functions_window_common::field; use datafusion_functions_window_common::partition::PartitionEvaluatorArgs; +use datafusion_macros::user_doc; use field::WindowUDFFieldArgs; get_or_init_udwf!( @@ -45,6 +45,15 @@ pub fn ntile(arg: Expr) -> Expr { ntile_udwf().call(vec![arg]) } +#[user_doc( + doc_section(label = "Ranking Functions"), + description = "Integer ranging from 1 to the argument value, dividing the partition as equally as possible", + syntax_example = "ntile(expression)", + argument( + name = "expression", + description = "An integer describing the number groups the partition should be split into" + ) +)] 
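After the migration to `#[user_doc]`, the documentation remains reachable through the same public accessor. A small sketch (assuming the `cume_dist_udwf` getter generated by `define_udwf_and_expr!` above, and that `Documentation::description` is a public field as in `datafusion_doc`):

```rust
use datafusion_functions_window::cume_dist::cume_dist_udwf;

fn main() {
    let udwf = cume_dist_udwf();
    // The `doc()` method generated by #[user_doc] backs this call,
    // replacing the hand-rolled OnceLock getters deleted above
    if let Some(doc) = udwf.documentation() {
        println!("{}: {}", udwf.name(), doc.description);
    }
}
```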
#[derive(Debug)] pub struct Ntile { signature: Signature, @@ -78,16 +87,6 @@ impl Default for Ntile { } } -static DOCUMENTATION: OnceLock = OnceLock::new(); - -fn get_ntile_doc() -> &'static Documentation { - DOCUMENTATION.get_or_init(|| { - Documentation::builder(DOC_SECTION_RANKING, "Integer ranging from 1 to the argument value, dividing the partition as equally as possible", "ntile(expression)") - .with_argument("expression","An integer describing the number groups the partition should be split into") - .build() - }) -} - impl WindowUDFImpl for Ntile { fn as_any(&self) -> &dyn Any { self @@ -135,7 +134,7 @@ impl WindowUDFImpl for Ntile { } fn documentation(&self) -> Option<&Documentation> { - Some(get_ntile_doc()) + self.doc() } } diff --git a/datafusion/functions-window/src/row_number.rs b/datafusion/functions-window/src/row_number.rs index 72d4e0232365..8f462528dbed 100644 --- a/datafusion/functions-window/src/row_number.rs +++ b/datafusion/functions-window/src/row_number.rs @@ -23,17 +23,16 @@ use datafusion_common::arrow::compute::SortOptions; use datafusion_common::arrow::datatypes::DataType; use datafusion_common::arrow::datatypes::Field; use datafusion_common::{Result, ScalarValue}; -use datafusion_expr::window_doc_sections::DOC_SECTION_RANKING; use datafusion_expr::{ Documentation, PartitionEvaluator, Signature, Volatility, WindowUDFImpl, }; use datafusion_functions_window_common::field; use datafusion_functions_window_common::partition::PartitionEvaluatorArgs; +use datafusion_macros::user_doc; use field::WindowUDFFieldArgs; use std::any::Any; use std::fmt::Debug; use std::ops::Range; -use std::sync::OnceLock; define_udwf_and_expr!( RowNumber, @@ -42,6 +41,11 @@ define_udwf_and_expr!( ); /// row_number expression +#[user_doc( + doc_section(label = "Ranking Functions"), + description = "Number of the current row within its partition, counting from 1.", + syntax_example = "row_number()" +)] #[derive(Debug)] pub struct RowNumber { signature: Signature, @@ -62,19 +66,6 @@ impl Default for RowNumber { } } -static DOCUMENTATION: OnceLock = OnceLock::new(); - -fn get_row_number_doc() -> &'static Documentation { - DOCUMENTATION.get_or_init(|| { - Documentation::builder( - DOC_SECTION_RANKING, - "Number of the current row within its partition, counting from 1.", - "row_number()", - ) - .build() - }) -} - impl WindowUDFImpl for RowNumber { fn as_any(&self) -> &dyn Any { self @@ -107,7 +98,7 @@ impl WindowUDFImpl for RowNumber { } fn documentation(&self) -> Option<&Documentation> { - Some(get_row_number_doc()) + self.doc() } } diff --git a/datafusion/functions/Cargo.toml b/datafusion/functions/Cargo.toml index 36d4af9ab55b..575e8484a92f 100644 --- a/datafusion/functions/Cargo.toml +++ b/datafusion/functions/Cargo.toml @@ -207,3 +207,8 @@ required-features = ["unicode_expressions"] harness = false name = "trunc" required-features = ["math_expressions"] + +[[bench]] +harness = false +name = "initcap" +required-features = ["string_expressions"] diff --git a/datafusion/functions/benches/initcap.rs b/datafusion/functions/benches/initcap.rs new file mode 100644 index 000000000000..c88b6b513980 --- /dev/null +++ b/datafusion/functions/benches/initcap.rs @@ -0,0 +1,93 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +extern crate criterion; + +use arrow::array::OffsetSizeTrait; +use arrow::datatypes::DataType; +use arrow::util::bench_util::{ + create_string_array_with_len, create_string_view_array_with_len, +}; +use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; +use datafusion_functions::string; +use std::sync::Arc; + +fn create_args( + size: usize, + str_len: usize, + force_view_types: bool, +) -> Vec { + if force_view_types { + let string_array = + Arc::new(create_string_view_array_with_len(size, 0.2, str_len, false)); + + vec![ColumnarValue::Array(string_array)] + } else { + let string_array = + Arc::new(create_string_array_with_len::(size, 0.2, str_len)); + + vec![ColumnarValue::Array(string_array)] + } +} + +fn criterion_benchmark(c: &mut Criterion) { + let initcap = string::initcap(); + for size in [1024, 4096] { + let args = create_args::(size, 8, true); + c.bench_function( + format!("initcap string view shorter than 12 [size={}]", size).as_str(), + |b| { + b.iter(|| { + black_box(initcap.invoke_with_args(ScalarFunctionArgs { + args: args.clone(), + number_rows: size, + return_type: &DataType::Utf8View, + })) + }) + }, + ); + + let args = create_args::(size, 16, true); + c.bench_function( + format!("initcap string view longer than 12 [size={}]", size).as_str(), + |b| { + b.iter(|| { + black_box(initcap.invoke_with_args(ScalarFunctionArgs { + args: args.clone(), + number_rows: size, + return_type: &DataType::Utf8View, + })) + }) + }, + ); + + let args = create_args::(size, 16, false); + c.bench_function(format!("initcap string [size={}]", size).as_str(), |b| { + b.iter(|| { + black_box(initcap.invoke_with_args(ScalarFunctionArgs { + args: args.clone(), + number_rows: size, + return_type: &DataType::Utf8, + })) + }) + }); + } +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/functions/src/core/mod.rs b/datafusion/functions/src/core/mod.rs index 24d26c539539..bd8305cd56d8 100644 --- a/datafusion/functions/src/core/mod.rs +++ b/datafusion/functions/src/core/mod.rs @@ -35,17 +35,17 @@ pub mod r#struct; pub mod version; // create UDFs -make_udf_function!(arrow_cast::ArrowCastFunc, ARROW_CAST, arrow_cast); -make_udf_function!(nullif::NullIfFunc, NULLIF, nullif); -make_udf_function!(nvl::NVLFunc, NVL, nvl); -make_udf_function!(nvl2::NVL2Func, NVL2, nvl2); -make_udf_function!(arrowtypeof::ArrowTypeOfFunc, ARROWTYPEOF, arrow_typeof); -make_udf_function!(r#struct::StructFunc, STRUCT, r#struct); -make_udf_function!(named_struct::NamedStructFunc, NAMED_STRUCT, named_struct); -make_udf_function!(getfield::GetFieldFunc, GET_FIELD, get_field); -make_udf_function!(coalesce::CoalesceFunc, COALESCE, coalesce); -make_udf_function!(greatest::GreatestFunc, GREATEST, greatest); -make_udf_function!(version::VersionFunc, VERSION, version); +make_udf_function!(arrow_cast::ArrowCastFunc, arrow_cast); 
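These call sites drop the screaming-case `$GNAME` identifier because the singleton now lives in a local `static` inside the generated function. Per the updated `make_udf_function!` definition later in this diff, a call such as `make_udf_function!(nullif::NullIfFunc, nullif)` expands to roughly:

```rust
/// Return a [`ScalarUDF`](datafusion_expr::ScalarUDF) implementation of nullif
pub fn nullif() -> std::sync::Arc<datafusion_expr::ScalarUDF> {
    // Singleton instance of the function
    static INSTANCE: std::sync::LazyLock<std::sync::Arc<datafusion_expr::ScalarUDF>> =
        std::sync::LazyLock::new(|| {
            std::sync::Arc::new(datafusion_expr::ScalarUDF::new_from_impl(
                nullif::NullIfFunc::new(),
            ))
        });
    std::sync::Arc::clone(&INSTANCE)
}
```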
+make_udf_function!(nullif::NullIfFunc, nullif); +make_udf_function!(nvl::NVLFunc, nvl); +make_udf_function!(nvl2::NVL2Func, nvl2); +make_udf_function!(arrowtypeof::ArrowTypeOfFunc, arrow_typeof); +make_udf_function!(r#struct::StructFunc, r#struct); +make_udf_function!(named_struct::NamedStructFunc, named_struct); +make_udf_function!(getfield::GetFieldFunc, get_field); +make_udf_function!(coalesce::CoalesceFunc, coalesce); +make_udf_function!(greatest::GreatestFunc, greatest); +make_udf_function!(version::VersionFunc, version); pub mod expr_fn { use datafusion_expr::{Expr, Literal}; diff --git a/datafusion/functions/src/crypto/mod.rs b/datafusion/functions/src/crypto/mod.rs index 46177fc22b60..62ea3c2e2737 100644 --- a/datafusion/functions/src/crypto/mod.rs +++ b/datafusion/functions/src/crypto/mod.rs @@ -27,12 +27,12 @@ pub mod sha224; pub mod sha256; pub mod sha384; pub mod sha512; -make_udf_function!(digest::DigestFunc, DIGEST, digest); -make_udf_function!(md5::Md5Func, MD5, md5); -make_udf_function!(sha224::SHA224Func, SHA224, sha224); -make_udf_function!(sha256::SHA256Func, SHA256, sha256); -make_udf_function!(sha384::SHA384Func, SHA384, sha384); -make_udf_function!(sha512::SHA512Func, SHA512, sha512); +make_udf_function!(digest::DigestFunc, digest); +make_udf_function!(md5::Md5Func, md5); +make_udf_function!(sha224::SHA224Func, sha224); +make_udf_function!(sha256::SHA256Func, sha256); +make_udf_function!(sha384::SHA384Func, sha384); +make_udf_function!(sha512::SHA512Func, sha512); pub mod expr_fn { export_functions!(( diff --git a/datafusion/functions/src/datetime/date_part.rs b/datafusion/functions/src/datetime/date_part.rs index b8c58a11d999..6d6adf2a344d 100644 --- a/datafusion/functions/src/datetime/date_part.rs +++ b/datafusion/functions/src/datetime/date_part.rs @@ -35,7 +35,9 @@ use datafusion_common::cast::{ as_timestamp_microsecond_array, as_timestamp_millisecond_array, as_timestamp_nanosecond_array, as_timestamp_second_array, }; -use datafusion_common::{exec_err, internal_err, ExprSchema, Result, ScalarValue}; +use datafusion_common::{ + exec_err, internal_err, not_impl_err, ExprSchema, Result, ScalarValue, +}; use datafusion_expr::scalar_doc_sections::DOC_SECTION_DATETIME; use datafusion_expr::TypeSignature::Exact; use datafusion_expr::{ @@ -290,9 +292,10 @@ fn get_date_part_doc() -> &'static Documentation { /// result to a total number of seconds, milliseconds, microseconds or /// nanoseconds fn seconds_as_i32(array: &dyn Array, unit: TimeUnit) -> Result { - // Nanosecond is neither supported in Postgres nor DuckDB, to avoid to deal with overflow and precision issue we don't support nanosecond + // Nanosecond is neither supported in Postgres nor DuckDB, to avoid dealing + // with overflow and precision issue we don't support nanosecond if unit == Nanosecond { - return internal_err!("unit {unit:?} not supported"); + return not_impl_err!("Date part {unit:?} not supported"); } let conversion_factor = match unit { diff --git a/datafusion/functions/src/datetime/mod.rs b/datafusion/functions/src/datetime/mod.rs index db4e365267dd..96ca63010ee4 100644 --- a/datafusion/functions/src/datetime/mod.rs +++ b/datafusion/functions/src/datetime/mod.rs @@ -37,43 +37,23 @@ pub mod to_timestamp; pub mod to_unixtime; // create UDFs -make_udf_function!(current_date::CurrentDateFunc, CURRENT_DATE, current_date); -make_udf_function!(current_time::CurrentTimeFunc, CURRENT_TIME, current_time); -make_udf_function!(date_bin::DateBinFunc, DATE_BIN, date_bin); 
-make_udf_function!(date_part::DatePartFunc, DATE_PART, date_part); -make_udf_function!(date_trunc::DateTruncFunc, DATE_TRUNC, date_trunc); -make_udf_function!(make_date::MakeDateFunc, MAKE_DATE, make_date); -make_udf_function!( - from_unixtime::FromUnixtimeFunc, - FROM_UNIXTIME, - from_unixtime -); -make_udf_function!(now::NowFunc, NOW, now); -make_udf_function!(to_char::ToCharFunc, TO_CHAR, to_char); -make_udf_function!(to_date::ToDateFunc, TO_DATE, to_date); -make_udf_function!(to_local_time::ToLocalTimeFunc, TO_LOCAL_TIME, to_local_time); -make_udf_function!(to_unixtime::ToUnixtimeFunc, TO_UNIXTIME, to_unixtime); -make_udf_function!(to_timestamp::ToTimestampFunc, TO_TIMESTAMP, to_timestamp); -make_udf_function!( - to_timestamp::ToTimestampSecondsFunc, - TO_TIMESTAMP_SECONDS, - to_timestamp_seconds -); -make_udf_function!( - to_timestamp::ToTimestampMillisFunc, - TO_TIMESTAMP_MILLIS, - to_timestamp_millis -); -make_udf_function!( - to_timestamp::ToTimestampMicrosFunc, - TO_TIMESTAMP_MICROS, - to_timestamp_micros -); -make_udf_function!( - to_timestamp::ToTimestampNanosFunc, - TO_TIMESTAMP_NANOS, - to_timestamp_nanos -); +make_udf_function!(current_date::CurrentDateFunc, current_date); +make_udf_function!(current_time::CurrentTimeFunc, current_time); +make_udf_function!(date_bin::DateBinFunc, date_bin); +make_udf_function!(date_part::DatePartFunc, date_part); +make_udf_function!(date_trunc::DateTruncFunc, date_trunc); +make_udf_function!(make_date::MakeDateFunc, make_date); +make_udf_function!(from_unixtime::FromUnixtimeFunc, from_unixtime); +make_udf_function!(now::NowFunc, now); +make_udf_function!(to_char::ToCharFunc, to_char); +make_udf_function!(to_date::ToDateFunc, to_date); +make_udf_function!(to_local_time::ToLocalTimeFunc, to_local_time); +make_udf_function!(to_unixtime::ToUnixtimeFunc, to_unixtime); +make_udf_function!(to_timestamp::ToTimestampFunc, to_timestamp); +make_udf_function!(to_timestamp::ToTimestampSecondsFunc, to_timestamp_seconds); +make_udf_function!(to_timestamp::ToTimestampMillisFunc, to_timestamp_millis); +make_udf_function!(to_timestamp::ToTimestampMicrosFunc, to_timestamp_micros); +make_udf_function!(to_timestamp::ToTimestampNanosFunc, to_timestamp_nanos); // we cannot currently use the export_functions macro since it doesn't handle // functions with varargs currently diff --git a/datafusion/functions/src/datetime/to_date.rs b/datafusion/functions/src/datetime/to_date.rs index e2edea843e98..091d0ba37644 100644 --- a/datafusion/functions/src/datetime/to_date.rs +++ b/datafusion/functions/src/datetime/to_date.rs @@ -22,13 +22,11 @@ use arrow::error::ArrowError::ParseError; use arrow::{array::types::Date32Type, compute::kernels::cast_utils::Parser}; use datafusion_common::error::DataFusionError; use datafusion_common::{arrow_err, exec_err, internal_datafusion_err, Result}; -use datafusion_doc::DocSection; use datafusion_expr::{ ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, }; use datafusion_macros::user_doc; use std::any::Any; -use std::sync::OnceLock; #[user_doc( doc_section(label = "Time and Date Functions"), diff --git a/datafusion/functions/src/datetime/to_local_time.rs b/datafusion/functions/src/datetime/to_local_time.rs index eaa91d1140ba..9f95b780ea4f 100644 --- a/datafusion/functions/src/datetime/to_local_time.rs +++ b/datafusion/functions/src/datetime/to_local_time.rs @@ -562,7 +562,7 @@ mod tests { fn test_to_local_time_helper(input: ScalarValue, expected: ScalarValue) { let res = ToLocalTimeFunc::new() 
.invoke_with_args(ScalarFunctionArgs { - args: &[ColumnarValue::Scalar(input)], + args: vec![ColumnarValue::Scalar(input)], number_rows: 1, return_type: &expected.data_type(), }) diff --git a/datafusion/functions/src/encoding/mod.rs b/datafusion/functions/src/encoding/mod.rs index 48171370ad58..b0ddbd368a6b 100644 --- a/datafusion/functions/src/encoding/mod.rs +++ b/datafusion/functions/src/encoding/mod.rs @@ -21,8 +21,8 @@ use std::sync::Arc; pub mod inner; // create `encode` and `decode` UDFs -make_udf_function!(inner::EncodeFunc, ENCODE, encode); -make_udf_function!(inner::DecodeFunc, DECODE, decode); +make_udf_function!(inner::EncodeFunc, encode); +make_udf_function!(inner::DecodeFunc, decode); // Export the functions out of this package, both as expr_fn as well as a list of functions pub mod expr_fn { diff --git a/datafusion/functions/src/macros.rs b/datafusion/functions/src/macros.rs index bedec9bb2e6f..82308601490c 100644 --- a/datafusion/functions/src/macros.rs +++ b/datafusion/functions/src/macros.rs @@ -65,24 +65,23 @@ macro_rules! export_functions { }; } -/// Creates a singleton `ScalarUDF` of the `$UDF` function named `$GNAME` and a -/// function named `$NAME` which returns that singleton. +/// Creates a singleton `ScalarUDF` of the `$UDF` function and a function +/// named `$NAME` which returns that singleton. /// /// This is used to ensure creating the list of `ScalarUDF` only happens once. macro_rules! make_udf_function { - ($UDF:ty, $GNAME:ident, $NAME:ident) => { - #[doc = "Return a [`ScalarUDF`](datafusion_expr::ScalarUDF) implementation "] - #[doc = stringify!($UDF)] + ($UDF:ty, $NAME:ident) => { + #[doc = concat!("Return a [`ScalarUDF`](datafusion_expr::ScalarUDF) implementation of ", stringify!($NAME))] pub fn $NAME() -> std::sync::Arc { // Singleton instance of the function - static $GNAME: std::sync::LazyLock< + static INSTANCE: std::sync::LazyLock< std::sync::Arc, > = std::sync::LazyLock::new(|| { std::sync::Arc::new(datafusion_expr::ScalarUDF::new_from_impl( <$UDF>::new(), )) }); - std::sync::Arc::clone(&$GNAME) + std::sync::Arc::clone(&INSTANCE) } }; } @@ -134,13 +133,13 @@ macro_rules! downcast_arg { /// applies a unary floating function to the argument, and returns a value of the same type. /// /// $UDF: the name of the UDF struct that implements `ScalarUDFImpl` -/// $GNAME: a singleton instance of the UDF /// $NAME: the name of the function /// $UNARY_FUNC: the unary function to apply to the argument /// $OUTPUT_ORDERING: the output ordering calculation method of the function +/// $GET_DOC: the function to get the documentation of the UDF macro_rules! make_math_unary_udf { - ($UDF:ident, $GNAME:ident, $NAME:ident, $UNARY_FUNC:ident, $OUTPUT_ORDERING:expr, $EVALUATE_BOUNDS:expr, $GET_DOC:expr) => { - make_udf_function!($NAME::$UDF, $GNAME, $NAME); + ($UDF:ident, $NAME:ident, $UNARY_FUNC:ident, $OUTPUT_ORDERING:expr, $EVALUATE_BOUNDS:expr, $GET_DOC:expr) => { + make_udf_function!($NAME::$UDF, $NAME); mod $NAME { use std::any::Any; @@ -248,13 +247,13 @@ macro_rules! make_math_unary_udf { /// applies a binary floating function to the argument, and returns a value of the same type. /// /// $UDF: the name of the UDF struct that implements `ScalarUDFImpl` -/// $GNAME: a singleton instance of the UDF /// $NAME: the name of the function /// $BINARY_FUNC: the binary function to apply to the argument /// $OUTPUT_ORDERING: the output ordering calculation method of the function +/// $GET_DOC: the function to get the documentation of the UDF macro_rules! 
make_math_binary_udf { - ($UDF:ident, $GNAME:ident, $NAME:ident, $BINARY_FUNC:ident, $OUTPUT_ORDERING:expr, $GET_DOC:expr) => { - make_udf_function!($NAME::$UDF, $GNAME, $NAME); + ($UDF:ident, $NAME:ident, $BINARY_FUNC:ident, $OUTPUT_ORDERING:expr, $GET_DOC:expr) => { + make_udf_function!($NAME::$UDF, $NAME); mod $NAME { use std::any::Any; diff --git a/datafusion/functions/src/math/abs.rs b/datafusion/functions/src/math/abs.rs index c0c7c6f0f6b6..e3d448083e26 100644 --- a/datafusion/functions/src/math/abs.rs +++ b/datafusion/functions/src/math/abs.rs @@ -18,7 +18,7 @@ //! math expressions use std::any::Any; -use std::sync::{Arc, OnceLock}; +use std::sync::Arc; use arrow::array::{ ArrayRef, Decimal128Array, Decimal256Array, Float32Array, Float64Array, Int16Array, @@ -27,7 +27,6 @@ use arrow::array::{ use arrow::datatypes::DataType; use arrow::error::ArrowError; use datafusion_common::{exec_err, not_impl_err, DataFusionError, Result}; -use datafusion_doc::DocSection; use datafusion_expr::interval_arithmetic::Interval; use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; use datafusion_expr::{ diff --git a/datafusion/functions/src/math/mod.rs b/datafusion/functions/src/math/mod.rs index 1452bfdee5a0..4eb337a30110 100644 --- a/datafusion/functions/src/math/mod.rs +++ b/datafusion/functions/src/math/mod.rs @@ -40,10 +40,9 @@ pub mod signum; pub mod trunc; // Create UDFs -make_udf_function!(abs::AbsFunc, ABS, abs); +make_udf_function!(abs::AbsFunc, abs); make_math_unary_udf!( AcosFunc, - ACOS, acos, acos, super::acos_order, @@ -52,7 +51,6 @@ make_math_unary_udf!( ); make_math_unary_udf!( AcoshFunc, - ACOSH, acosh, acosh, super::acosh_order, @@ -61,7 +59,6 @@ make_math_unary_udf!( ); make_math_unary_udf!( AsinFunc, - ASIN, asin, asin, super::asin_order, @@ -70,7 +67,6 @@ make_math_unary_udf!( ); make_math_unary_udf!( AsinhFunc, - ASINH, asinh, asinh, super::asinh_order, @@ -79,7 +75,6 @@ make_math_unary_udf!( ); make_math_unary_udf!( AtanFunc, - ATAN, atan, atan, super::atan_order, @@ -88,7 +83,6 @@ make_math_unary_udf!( ); make_math_unary_udf!( AtanhFunc, - ATANH, atanh, atanh, super::atanh_order, @@ -97,7 +91,6 @@ make_math_unary_udf!( ); make_math_binary_udf!( Atan2, - ATAN2, atan2, atan2, super::atan2_order, @@ -105,7 +98,6 @@ make_math_binary_udf!( ); make_math_unary_udf!( CbrtFunc, - CBRT, cbrt, cbrt, super::cbrt_order, @@ -114,7 +106,6 @@ make_math_unary_udf!( ); make_math_unary_udf!( CeilFunc, - CEIL, ceil, ceil, super::ceil_order, @@ -123,7 +114,6 @@ make_math_unary_udf!( ); make_math_unary_udf!( CosFunc, - COS, cos, cos, super::cos_order, @@ -132,17 +122,15 @@ make_math_unary_udf!( ); make_math_unary_udf!( CoshFunc, - COSH, cosh, cosh, super::cosh_order, super::bounds::cosh_bounds, super::get_cosh_doc ); -make_udf_function!(cot::CotFunc, COT, cot); +make_udf_function!(cot::CotFunc, cot); make_math_unary_udf!( DegreesFunc, - DEGREES, degrees, to_degrees, super::degrees_order, @@ -151,31 +139,28 @@ make_math_unary_udf!( ); make_math_unary_udf!( ExpFunc, - EXP, exp, exp, super::exp_order, super::bounds::exp_bounds, super::get_exp_doc ); -make_udf_function!(factorial::FactorialFunc, FACTORIAL, factorial); +make_udf_function!(factorial::FactorialFunc, factorial); make_math_unary_udf!( FloorFunc, - FLOOR, floor, floor, super::floor_order, super::bounds::unbounded_bounds, super::get_floor_doc ); -make_udf_function!(log::LogFunc, LOG, log); -make_udf_function!(gcd::GcdFunc, GCD, gcd); -make_udf_function!(nans::IsNanFunc, ISNAN, isnan); -make_udf_function!(iszero::IsZeroFunc, 
ISZERO, iszero); -make_udf_function!(lcm::LcmFunc, LCM, lcm); +make_udf_function!(log::LogFunc, log); +make_udf_function!(gcd::GcdFunc, gcd); +make_udf_function!(nans::IsNanFunc, isnan); +make_udf_function!(iszero::IsZeroFunc, iszero); +make_udf_function!(lcm::LcmFunc, lcm); make_math_unary_udf!( LnFunc, - LN, ln, ln, super::ln_order, @@ -184,7 +169,6 @@ make_math_unary_udf!( ); make_math_unary_udf!( Log2Func, - LOG2, log2, log2, super::log2_order, @@ -193,31 +177,28 @@ make_math_unary_udf!( ); make_math_unary_udf!( Log10Func, - LOG10, log10, log10, super::log10_order, super::bounds::unbounded_bounds, super::get_log10_doc ); -make_udf_function!(nanvl::NanvlFunc, NANVL, nanvl); -make_udf_function!(pi::PiFunc, PI, pi); -make_udf_function!(power::PowerFunc, POWER, power); +make_udf_function!(nanvl::NanvlFunc, nanvl); +make_udf_function!(pi::PiFunc, pi); +make_udf_function!(power::PowerFunc, power); make_math_unary_udf!( RadiansFunc, - RADIANS, radians, to_radians, super::radians_order, super::bounds::radians_bounds, super::get_radians_doc ); -make_udf_function!(random::RandomFunc, RANDOM, random); -make_udf_function!(round::RoundFunc, ROUND, round); -make_udf_function!(signum::SignumFunc, SIGNUM, signum); +make_udf_function!(random::RandomFunc, random); +make_udf_function!(round::RoundFunc, round); +make_udf_function!(signum::SignumFunc, signum); make_math_unary_udf!( SinFunc, - SIN, sin, sin, super::sin_order, @@ -226,7 +207,6 @@ make_math_unary_udf!( ); make_math_unary_udf!( SinhFunc, - SINH, sinh, sinh, super::sinh_order, @@ -235,7 +215,6 @@ make_math_unary_udf!( ); make_math_unary_udf!( SqrtFunc, - SQRT, sqrt, sqrt, super::sqrt_order, @@ -244,7 +223,6 @@ make_math_unary_udf!( ); make_math_unary_udf!( TanFunc, - TAN, tan, tan, super::tan_order, @@ -253,14 +231,13 @@ make_math_unary_udf!( ); make_math_unary_udf!( TanhFunc, - TANH, tanh, tanh, super::tanh_order, super::bounds::tanh_bounds, super::get_tanh_doc ); -make_udf_function!(trunc::TruncFunc, TRUNC, trunc); +make_udf_function!(trunc::TruncFunc, trunc); pub mod expr_fn { export_functions!( diff --git a/datafusion/functions/src/regex/mod.rs b/datafusion/functions/src/regex/mod.rs index 803f51e915a9..13fbc049af58 100644 --- a/datafusion/functions/src/regex/mod.rs +++ b/datafusion/functions/src/regex/mod.rs @@ -25,14 +25,10 @@ pub mod regexpmatch; pub mod regexpreplace; // create UDFs -make_udf_function!(regexpcount::RegexpCountFunc, REGEXP_COUNT, regexp_count); -make_udf_function!(regexpmatch::RegexpMatchFunc, REGEXP_MATCH, regexp_match); -make_udf_function!(regexplike::RegexpLikeFunc, REGEXP_LIKE, regexp_like); -make_udf_function!( - regexpreplace::RegexpReplaceFunc, - REGEXP_REPLACE, - regexp_replace -); +make_udf_function!(regexpcount::RegexpCountFunc, regexp_count); +make_udf_function!(regexpmatch::RegexpMatchFunc, regexp_match); +make_udf_function!(regexplike::RegexpLikeFunc, regexp_like); +make_udf_function!(regexpreplace::RegexpReplaceFunc, regexp_replace); pub mod expr_fn { use datafusion_expr::Expr; diff --git a/datafusion/functions/src/regex/regexplike.rs b/datafusion/functions/src/regex/regexplike.rs index 49e57776c7b8..1c826b12ef8f 100644 --- a/datafusion/functions/src/regex/regexplike.rs +++ b/datafusion/functions/src/regex/regexplike.rs @@ -81,26 +81,7 @@ impl RegexpLikeFunc { pub fn new() -> Self { Self { signature: Signature::one_of( - vec![ - TypeSignature::Exact(vec![Utf8View, Utf8]), - TypeSignature::Exact(vec![Utf8View, Utf8View]), - TypeSignature::Exact(vec![Utf8View, LargeUtf8]), - TypeSignature::Exact(vec![Utf8, 
Utf8]), - TypeSignature::Exact(vec![Utf8, Utf8View]), - TypeSignature::Exact(vec![Utf8, LargeUtf8]), - TypeSignature::Exact(vec![LargeUtf8, Utf8]), - TypeSignature::Exact(vec![LargeUtf8, Utf8View]), - TypeSignature::Exact(vec![LargeUtf8, LargeUtf8]), - TypeSignature::Exact(vec![Utf8View, Utf8, Utf8]), - TypeSignature::Exact(vec![Utf8View, Utf8View, Utf8]), - TypeSignature::Exact(vec![Utf8View, LargeUtf8, Utf8]), - TypeSignature::Exact(vec![Utf8, Utf8, Utf8]), - TypeSignature::Exact(vec![Utf8, Utf8View, Utf8]), - TypeSignature::Exact(vec![Utf8, LargeUtf8, Utf8]), - TypeSignature::Exact(vec![LargeUtf8, Utf8, Utf8]), - TypeSignature::Exact(vec![LargeUtf8, Utf8View, Utf8]), - TypeSignature::Exact(vec![LargeUtf8, LargeUtf8, Utf8]), - ], + vec![TypeSignature::String(2), TypeSignature::String(3)], Volatility::Immutable, ), } @@ -211,7 +192,34 @@ pub fn regexp_like(args: &[ArrayRef]) -> Result<ArrayRef> { match args.len() { 2 => handle_regexp_like(&args[0], &args[1], None), 3 => { - let flags = args[2].as_string::<i32>(); + let flags = match args[2].data_type() { + Utf8 => args[2].as_string::<i32>(), + LargeUtf8 => { + let large_string_array = args[2].as_string::<i64>(); + let string_vec: Vec<Option<&str>> = (0..large_string_array.len()).map(|i| { + if large_string_array.is_null(i) { + None + } else { + Some(large_string_array.value(i)) + } + }) + .collect(); + + &GenericStringArray::<i32>::from(string_vec) + }, + _ => { + let string_view_array = args[2].as_string_view(); + let string_vec: Vec<Option<String>> = (0..string_view_array.len()).map(|i| { + if string_view_array.is_null(i) { + None + } else { + Some(string_view_array.value(i).to_string()) + } + }) + .collect(); + &GenericStringArray::<i32>::from(string_vec) + }, + }; if flags.iter().any(|s| s == Some("g")) { return plan_err!("regexp_like() does not support the \"global\" option"); diff --git a/datafusion/functions/src/string/ascii.rs b/datafusion/functions/src/string/ascii.rs index 4f615b5b2c58..f366329b4f86 100644 --- a/datafusion/functions/src/string/ascii.rs +++ b/datafusion/functions/src/string/ascii.rs @@ -157,7 +157,7 @@ mod tests { ($INPUT:expr, $EXPECTED:expr) => { test_function!( AsciiFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::Utf8($INPUT))], + vec![ColumnarValue::Scalar(ScalarValue::Utf8($INPUT))], $EXPECTED, i32, Int32, @@ -166,7 +166,7 @@ test_function!( AsciiFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::LargeUtf8($INPUT))], + vec![ColumnarValue::Scalar(ScalarValue::LargeUtf8($INPUT))], $EXPECTED, i32, Int32, @@ -175,7 +175,7 @@ test_function!( AsciiFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::Utf8View($INPUT))], + vec![ColumnarValue::Scalar(ScalarValue::Utf8View($INPUT))], $EXPECTED, i32, Int32, diff --git a/datafusion/functions/src/string/btrim.rs b/datafusion/functions/src/string/btrim.rs index ae79bb59f9c7..298d64f04ae9 100644 --- a/datafusion/functions/src/string/btrim.rs +++ b/datafusion/functions/src/string/btrim.rs @@ -152,9 +152,9 @@ mod tests { // String view cases for checking normal logic test_function!( BTrimFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::Utf8View(Some( + vec![ColumnarValue::Scalar(ScalarValue::Utf8View(Some( String::from("alphabet ") - ))),], + )))], Ok(Some("alphabet")), &str, Utf8View, @@ -162,7 +162,7 @@ ); test_function!( BTrimFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::Utf8View(Some( + vec![ColumnarValue::Scalar(ScalarValue::Utf8View(Some( String::from(" alphabet ") ))),], Ok(Some("alphabet")), @@ -172,7 +172,7 @@ ); test_function!( BTrimFunc::new(), - &[ + vec![
ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from( "alphabet" )))), @@ -185,7 +185,7 @@ mod tests { ); test_function!( BTrimFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from( "alphabet" )))), @@ -200,7 +200,7 @@ mod tests { ); test_function!( BTrimFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from( "alphabet" )))), @@ -214,7 +214,7 @@ mod tests { // Special string view case for checking unlined output(len > 12) test_function!( BTrimFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from( "xxxalphabetalphabetxxx" )))), @@ -228,7 +228,7 @@ mod tests { // String cases test_function!( BTrimFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::Utf8(Some( + vec![ColumnarValue::Scalar(ScalarValue::Utf8(Some( String::from("alphabet ") ))),], Ok(Some("alphabet")), @@ -238,7 +238,7 @@ mod tests { ); test_function!( BTrimFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::Utf8(Some( + vec![ColumnarValue::Scalar(ScalarValue::Utf8(Some( String::from("alphabet ") ))),], Ok(Some("alphabet")), @@ -248,7 +248,7 @@ mod tests { ); test_function!( BTrimFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("alphabet")))), ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("t")))), ], @@ -259,7 +259,7 @@ mod tests { ); test_function!( BTrimFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("alphabet")))), ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("alphabe")))), ], @@ -270,7 +270,7 @@ mod tests { ); test_function!( BTrimFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("alphabet")))), ColumnarValue::Scalar(ScalarValue::Utf8(None)), ], diff --git a/datafusion/functions/src/string/concat.rs b/datafusion/functions/src/string/concat.rs index 576c891ce467..895a7cdbf308 100644 --- a/datafusion/functions/src/string/concat.rs +++ b/datafusion/functions/src/string/concat.rs @@ -388,7 +388,7 @@ mod tests { fn test_functions() -> Result<()> { test_function!( ConcatFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("aa")), ColumnarValue::Scalar(ScalarValue::from("bb")), ColumnarValue::Scalar(ScalarValue::from("cc")), @@ -400,7 +400,7 @@ mod tests { ); test_function!( ConcatFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("aa")), ColumnarValue::Scalar(ScalarValue::Utf8(None)), ColumnarValue::Scalar(ScalarValue::from("cc")), @@ -412,7 +412,7 @@ mod tests { ); test_function!( ConcatFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::Utf8(None))], + vec![ColumnarValue::Scalar(ScalarValue::Utf8(None))], Ok(Some("")), &str, Utf8, @@ -420,7 +420,7 @@ mod tests { ); test_function!( ConcatFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("aa")), ColumnarValue::Scalar(ScalarValue::Utf8View(None)), ColumnarValue::Scalar(ScalarValue::LargeUtf8(None)), @@ -433,7 +433,7 @@ mod tests { ); test_function!( ConcatFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("aa")), ColumnarValue::Scalar(ScalarValue::LargeUtf8(None)), ColumnarValue::Scalar(ScalarValue::from("cc")), @@ -445,7 +445,7 @@ mod tests { ); test_function!( ConcatFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8View(Some("aa".to_string()))), ColumnarValue::Scalar(ScalarValue::Utf8(Some("cc".to_string()))), ], diff --git a/datafusion/functions/src/string/concat_ws.rs b/datafusion/functions/src/string/concat_ws.rs index 610c4f0be697..7db8dbec4a71 100644 --- 
a/datafusion/functions/src/string/concat_ws.rs +++ b/datafusion/functions/src/string/concat_ws.rs @@ -404,7 +404,7 @@ mod tests { fn test_functions() -> Result<()> { test_function!( ConcatWsFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("|")), ColumnarValue::Scalar(ScalarValue::from("aa")), ColumnarValue::Scalar(ScalarValue::from("bb")), @@ -417,7 +417,7 @@ mod tests { ); test_function!( ConcatWsFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("|")), ColumnarValue::Scalar(ScalarValue::Utf8(None)), ], @@ -428,7 +428,7 @@ mod tests { ); test_function!( ConcatWsFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8(None)), ColumnarValue::Scalar(ScalarValue::from("aa")), ColumnarValue::Scalar(ScalarValue::from("bb")), @@ -441,7 +441,7 @@ mod tests { ); test_function!( ConcatWsFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("|")), ColumnarValue::Scalar(ScalarValue::from("aa")), ColumnarValue::Scalar(ScalarValue::Utf8(None)), diff --git a/datafusion/functions/src/string/ends_with.rs b/datafusion/functions/src/string/ends_with.rs index fc7fc04f4363..1632fdd9943e 100644 --- a/datafusion/functions/src/string/ends_with.rs +++ b/datafusion/functions/src/string/ends_with.rs @@ -138,7 +138,7 @@ mod tests { fn test_functions() -> Result<()> { test_function!( EndsWithFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("alphabet")), ColumnarValue::Scalar(ScalarValue::from("alph")), ], @@ -149,7 +149,7 @@ mod tests { ); test_function!( EndsWithFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("alphabet")), ColumnarValue::Scalar(ScalarValue::from("bet")), ], @@ -160,7 +160,7 @@ mod tests { ); test_function!( EndsWithFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8(None)), ColumnarValue::Scalar(ScalarValue::from("alph")), ], @@ -171,7 +171,7 @@ mod tests { ); test_function!( EndsWithFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("alphabet")), ColumnarValue::Scalar(ScalarValue::Utf8(None)), ], diff --git a/datafusion/functions/src/string/initcap.rs b/datafusion/functions/src/string/initcap.rs index a9090b0a6f43..2780dcaeeb83 100644 --- a/datafusion/functions/src/string/initcap.rs +++ b/datafusion/functions/src/string/initcap.rs @@ -90,25 +90,36 @@ fn get_initcap_doc() -> &'static Documentation { DOCUMENTATION.get_or_init(|| { Documentation::builder( DOC_SECTION_STRING, - "Capitalizes the first character in each word in the input string. Words are delimited by non-alphanumeric characters.", - "initcap(str)") - .with_sql_example(r#"```sql + "Capitalizes the first character in each word in the ASCII input string. \ + Words are delimited by non-alphanumeric characters.\n\n\ + Note this function does not support UTF-8 characters.", + "initcap(str)", + ) + .with_sql_example( + r#"```sql > select initcap('apache datafusion'); +------------------------------------+ | initcap(Utf8("apache datafusion")) | +------------------------------------+ | Apache Datafusion | +------------------------------------+ -```"#) - .with_standard_argument("str", Some("String")) - .with_related_udf("lower") - .with_related_udf("upper") - .build() +```"#, + ) + .with_standard_argument("str", Some("String")) + .with_related_udf("lower") + .with_related_udf("upper") + .build() }) } -/// Converts the first letter of each word to upper case and the rest to lower case. Words are sequences of alphanumeric characters separated by non-alphanumeric characters. 
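Editor's note on the `initcap` documentation change here: "does not support UTF-8 characters" is shorthand for the fact that the ASCII-only case conversions leave non-ASCII characters unchanged (and, since they are not ASCII-alphanumeric, such characters also act as word boundaries). A minimal standalone sketch of that rule, for illustration only; `initcap_sketch` below is a hypothetical name and is not part of this patch:

```rust
// Sketch of the word-capitalization rule described in the new docs:
// ASCII-only case mapping, with any non-alphanumeric byte starting a word.
fn initcap_sketch(s: &str) -> String {
    let mut out = String::with_capacity(s.len());
    let mut prev_is_alphanumeric = false;
    for c in s.chars() {
        if prev_is_alphanumeric {
            out.push(c.to_ascii_lowercase());
        } else {
            out.push(c.to_ascii_uppercase());
        }
        prev_is_alphanumeric = c.is_ascii_alphanumeric();
    }
    out
}

fn main() {
    assert_eq!(initcap_sketch("hi THOMAS"), "Hi Thomas");
    // Non-ASCII input: 'é' passes through unchanged and, being non-ASCII,
    // also acts as a word boundary, so the following 't' is uppercased.
    assert_eq!(initcap_sketch("étude"), "éTude");
}
```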
+/// Converts the first letter of each word to upper
+/// case. Words are sequences of alphanumeric characters separated by
+/// non-alphanumeric characters.
+///
+/// Example:
+/// ```sql
/// initcap('hi THOMAS') = 'Hi Thomas'
+/// ```
fn initcap<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> { let string_array = as_generic_string_array::<T>(&args[0])?; @@ -132,21 +143,22 @@ fn initcap_utf8view(args: &[ArrayRef]) -> Result<ArrayRef> { Ok(Arc::new(result) as ArrayRef) } -fn initcap_string(string: Option<&str>) -> Option<String> { - let mut char_vector = Vec::<char>::new(); - string.map(|string: &str| { - char_vector.clear(); - let mut previous_character_letter_or_number = false; - for c in string.chars() { - if previous_character_letter_or_number { - char_vector.push(c.to_ascii_lowercase()); +fn initcap_string(input: Option<&str>) -> Option<String> { + input.map(|s| { + let mut result = String::with_capacity(s.len()); + let mut prev_is_alphanumeric = false; + + for c in s.chars() { + let transformed = if prev_is_alphanumeric { + c.to_ascii_lowercase() } else { - char_vector.push(c.to_ascii_uppercase()); - } - previous_character_letter_or_number = - c.is_ascii_uppercase() || c.is_ascii_lowercase() || c.is_ascii_digit(); + c.to_ascii_uppercase() + }; + result.push(transformed); + prev_is_alphanumeric = c.is_ascii_alphanumeric(); } - char_vector.iter().collect::<String>() + + result }) } @@ -163,7 +175,7 @@ mod tests { fn test_functions() -> Result<()> { test_function!( InitcapFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::from("hi THOMAS"))], + vec![ColumnarValue::Scalar(ScalarValue::from("hi THOMAS"))], Ok(Some("Hi Thomas")), &str, Utf8, @@ -171,7 +183,7 @@ ); test_function!( InitcapFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::from(""))], + vec![ColumnarValue::Scalar(ScalarValue::from(""))], Ok(Some("")), &str, Utf8, @@ -179,7 +191,7 @@ ); test_function!( InitcapFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::from(""))], + vec![ColumnarValue::Scalar(ScalarValue::from(""))], Ok(Some("")), &str, Utf8, @@ -187,7 +199,7 @@ ); test_function!( InitcapFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::Utf8(None))], + vec![ColumnarValue::Scalar(ScalarValue::Utf8(None))], Ok(None), &str, Utf8, @@ -195,7 +207,7 @@ ); test_function!( InitcapFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::Utf8View(Some( + vec![ColumnarValue::Scalar(ScalarValue::Utf8View(Some( "hi THOMAS".to_string() )))], Ok(Some("Hi Thomas")), @@ -205,7 +217,7 @@ ); test_function!( InitcapFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::Utf8View(Some( + vec![ColumnarValue::Scalar(ScalarValue::Utf8View(Some( "hi THOMAS wIth M0re ThAN 12 ChaRs".to_string() )))], Ok(Some("Hi Thomas With M0re Than 12 Chars")), @@ -215,7 +227,7 @@ ); test_function!( InitcapFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::Utf8View(Some( + vec![ColumnarValue::Scalar(ScalarValue::Utf8View(Some( "".to_string() )))], Ok(Some("")), @@ -225,7 +237,7 @@ ); test_function!( InitcapFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::Utf8View(None))], + vec![ColumnarValue::Scalar(ScalarValue::Utf8View(None))], Ok(None), &str, Utf8, diff --git a/datafusion/functions/src/string/ltrim.rs b/datafusion/functions/src/string/ltrim.rs index e0e83d1b01e3..0bc62ee5000d 100644 --- a/datafusion/functions/src/string/ltrim.rs +++ b/datafusion/functions/src/string/ltrim.rs @@ -22,12 +22,10 @@ use std::any::Any; use crate::string::common::*; use crate::utils::{make_scalar_function, utf8_to_str_type}; use
datafusion_common::{exec_err, Result}; -use datafusion_doc::DocSection; use datafusion_expr::function::Hint; use datafusion_expr::{ColumnarValue, Documentation, TypeSignature, Volatility}; use datafusion_expr::{ScalarUDFImpl, Signature}; use datafusion_macros::user_doc; -use std::sync::OnceLock; /// Returns the longest string with leading characters removed. If the characters are not specified, whitespace is removed. /// ltrim('zzzytest', 'xyz') = 'test' @@ -148,7 +146,7 @@ mod tests { // String view cases for checking normal logic test_function!( LtrimFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::Utf8View(Some( + vec![ColumnarValue::Scalar(ScalarValue::Utf8View(Some( String::from("alphabet ") ))),], Ok(Some("alphabet ")), @@ -158,7 +156,7 @@ mod tests { ); test_function!( LtrimFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::Utf8View(Some( + vec![ColumnarValue::Scalar(ScalarValue::Utf8View(Some( String::from(" alphabet ") ))),], Ok(Some("alphabet ")), @@ -168,7 +166,7 @@ mod tests { ); test_function!( LtrimFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from( "alphabet" )))), @@ -181,7 +179,7 @@ mod tests { ); test_function!( LtrimFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from( "alphabet" )))), @@ -196,7 +194,7 @@ mod tests { ); test_function!( LtrimFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from( "alphabet" )))), @@ -210,7 +208,7 @@ mod tests { // Special string view case for checking unlined output(len > 12) test_function!( LtrimFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from( "xxxalphabetalphabet" )))), @@ -224,7 +222,7 @@ mod tests { // String cases test_function!( LtrimFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::Utf8(Some( + vec![ColumnarValue::Scalar(ScalarValue::Utf8(Some( String::from("alphabet ") ))),], Ok(Some("alphabet ")), @@ -234,7 +232,7 @@ mod tests { ); test_function!( LtrimFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::Utf8(Some( + vec![ColumnarValue::Scalar(ScalarValue::Utf8(Some( String::from("alphabet ") ))),], Ok(Some("alphabet ")), @@ -244,7 +242,7 @@ mod tests { ); test_function!( LtrimFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("alphabet")))), ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("t")))), ], @@ -255,7 +253,7 @@ mod tests { ); test_function!( LtrimFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("alphabet")))), ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("alphabe")))), ], @@ -266,7 +264,7 @@ mod tests { ); test_function!( LtrimFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("alphabet")))), ColumnarValue::Scalar(ScalarValue::Utf8(None)), ], diff --git a/datafusion/functions/src/string/mod.rs b/datafusion/functions/src/string/mod.rs index 622802f0142b..f156f070d960 100644 --- a/datafusion/functions/src/string/mod.rs +++ b/datafusion/functions/src/string/mod.rs @@ -45,28 +45,28 @@ pub mod to_hex; pub mod upper; pub mod uuid; // create UDFs -make_udf_function!(ascii::AsciiFunc, ASCII, ascii); -make_udf_function!(bit_length::BitLengthFunc, BIT_LENGTH, bit_length); -make_udf_function!(btrim::BTrimFunc, BTRIM, btrim); -make_udf_function!(chr::ChrFunc, CHR, chr); -make_udf_function!(concat::ConcatFunc, CONCAT, concat); -make_udf_function!(concat_ws::ConcatWsFunc, CONCAT_WS, concat_ws); -make_udf_function!(ends_with::EndsWithFunc, ENDS_WITH, ends_with); 
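Editor's note: in each of these call sites the SCREAMING_CASE argument (e.g. `ENDS_WITH`) is dropped because the updated `make_udf_function!` no longer needs a caller-supplied identifier for the singleton static. A hedged sketch of how a two-argument form can hide the static inside the generated accessor; the real macro lives in `datafusion/functions/src/macros.rs` and is only partially visible in this diff, and `make_udf_function_sketch` / `ExampleFunc` below are illustrative stand-ins:

```rust
use std::sync::{Arc, LazyLock};

#[derive(Debug, Default)]
struct ExampleFunc; // illustrative stand-in for a ScalarUDFImpl type

// Two-argument form: the singleton static is declared inside the generated
// function, so no separate SCREAMING_CASE identifier has to be passed in.
macro_rules! make_udf_function_sketch {
    ($UDF:ty, $NAME:ident) => {
        fn $NAME() -> Arc<$UDF> {
            static INSTANCE: LazyLock<Arc<$UDF>> =
                LazyLock::new(|| Arc::new(<$UDF>::default()));
            Arc::clone(&INSTANCE)
        }
    };
}

make_udf_function_sketch!(ExampleFunc, example);

fn main() {
    // Every call returns a clone of the same shared instance.
    assert!(Arc::ptr_eq(&example(), &example()));
}
```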
-make_udf_function!(initcap::InitcapFunc, INITCAP, initcap); -make_udf_function!(levenshtein::LevenshteinFunc, LEVENSHTEIN, levenshtein); -make_udf_function!(ltrim::LtrimFunc, LTRIM, ltrim); -make_udf_function!(lower::LowerFunc, LOWER, lower); -make_udf_function!(octet_length::OctetLengthFunc, OCTET_LENGTH, octet_length); -make_udf_function!(overlay::OverlayFunc, OVERLAY, overlay); -make_udf_function!(repeat::RepeatFunc, REPEAT, repeat); -make_udf_function!(replace::ReplaceFunc, REPLACE, replace); -make_udf_function!(rtrim::RtrimFunc, RTRIM, rtrim); -make_udf_function!(starts_with::StartsWithFunc, STARTS_WITH, starts_with); -make_udf_function!(split_part::SplitPartFunc, SPLIT_PART, split_part); -make_udf_function!(to_hex::ToHexFunc, TO_HEX, to_hex); -make_udf_function!(upper::UpperFunc, UPPER, upper); -make_udf_function!(uuid::UuidFunc, UUID, uuid); -make_udf_function!(contains::ContainsFunc, CONTAINS, contains); +make_udf_function!(ascii::AsciiFunc, ascii); +make_udf_function!(bit_length::BitLengthFunc, bit_length); +make_udf_function!(btrim::BTrimFunc, btrim); +make_udf_function!(chr::ChrFunc, chr); +make_udf_function!(concat::ConcatFunc, concat); +make_udf_function!(concat_ws::ConcatWsFunc, concat_ws); +make_udf_function!(ends_with::EndsWithFunc, ends_with); +make_udf_function!(initcap::InitcapFunc, initcap); +make_udf_function!(levenshtein::LevenshteinFunc, levenshtein); +make_udf_function!(ltrim::LtrimFunc, ltrim); +make_udf_function!(lower::LowerFunc, lower); +make_udf_function!(octet_length::OctetLengthFunc, octet_length); +make_udf_function!(overlay::OverlayFunc, overlay); +make_udf_function!(repeat::RepeatFunc, repeat); +make_udf_function!(replace::ReplaceFunc, replace); +make_udf_function!(rtrim::RtrimFunc, rtrim); +make_udf_function!(starts_with::StartsWithFunc, starts_with); +make_udf_function!(split_part::SplitPartFunc, split_part); +make_udf_function!(to_hex::ToHexFunc, to_hex); +make_udf_function!(upper::UpperFunc, upper); +make_udf_function!(uuid::UuidFunc, uuid); +make_udf_function!(contains::ContainsFunc, contains); pub mod expr_fn { use datafusion_expr::Expr; diff --git a/datafusion/functions/src/string/octet_length.rs b/datafusion/functions/src/string/octet_length.rs index 2dbfa6746d61..26355556ff07 100644 --- a/datafusion/functions/src/string/octet_length.rs +++ b/datafusion/functions/src/string/octet_length.rs @@ -140,7 +140,7 @@ mod tests { fn test_functions() -> Result<()> { test_function!( OctetLengthFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::Int32(Some(12)))], + vec![ColumnarValue::Scalar(ScalarValue::Int32(Some(12)))], exec_err!( "The OCTET_LENGTH function can only accept strings, but got Int32." 
), @@ -150,7 +150,7 @@ mod tests { ); test_function!( OctetLengthFunc::new(), - &[ColumnarValue::Array(Arc::new(StringArray::from(vec![ + vec![ColumnarValue::Array(Arc::new(StringArray::from(vec![ String::from("chars"), String::from("chars2"), ])))], @@ -161,7 +161,7 @@ mod tests { ); test_function!( OctetLengthFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("chars")))), ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("chars")))) ], @@ -172,7 +172,7 @@ mod tests { ); test_function!( OctetLengthFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::Utf8(Some( + vec![ColumnarValue::Scalar(ScalarValue::Utf8(Some( String::from("chars") )))], Ok(Some(5)), @@ -182,7 +182,7 @@ mod tests { ); test_function!( OctetLengthFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::Utf8(Some( + vec![ColumnarValue::Scalar(ScalarValue::Utf8(Some( String::from("josé") )))], Ok(Some(5)), @@ -192,7 +192,7 @@ mod tests { ); test_function!( OctetLengthFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::Utf8(Some( + vec![ColumnarValue::Scalar(ScalarValue::Utf8(Some( String::from("") )))], Ok(Some(0)), @@ -202,7 +202,7 @@ mod tests { ); test_function!( OctetLengthFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::Utf8(None))], + vec![ColumnarValue::Scalar(ScalarValue::Utf8(None))], Ok(None), i32, Int32, @@ -210,7 +210,7 @@ mod tests { ); test_function!( OctetLengthFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::Utf8View(Some( + vec![ColumnarValue::Scalar(ScalarValue::Utf8View(Some( String::from("joséjoséjoséjosé") )))], Ok(Some(20)), @@ -220,7 +220,7 @@ mod tests { ); test_function!( OctetLengthFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::Utf8View(Some( + vec![ColumnarValue::Scalar(ScalarValue::Utf8View(Some( String::from("josé") )))], Ok(Some(5)), @@ -230,7 +230,7 @@ mod tests { ); test_function!( OctetLengthFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::Utf8View(Some( + vec![ColumnarValue::Scalar(ScalarValue::Utf8View(Some( String::from("") )))], Ok(Some(0)), diff --git a/datafusion/functions/src/string/repeat.rs b/datafusion/functions/src/string/repeat.rs index 4140a9b913ff..d16508c6af5a 100644 --- a/datafusion/functions/src/string/repeat.rs +++ b/datafusion/functions/src/string/repeat.rs @@ -171,7 +171,7 @@ mod tests { fn test_functions() -> Result<()> { test_function!( RepeatFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("Pg")))), ColumnarValue::Scalar(ScalarValue::Int64(Some(4))), ], @@ -182,7 +182,7 @@ mod tests { ); test_function!( RepeatFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8(None)), ColumnarValue::Scalar(ScalarValue::Int64(Some(4))), ], @@ -193,7 +193,7 @@ mod tests { ); test_function!( RepeatFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("Pg")))), ColumnarValue::Scalar(ScalarValue::Int64(None)), ], @@ -205,7 +205,7 @@ mod tests { test_function!( RepeatFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from("Pg")))), ColumnarValue::Scalar(ScalarValue::Int64(Some(4))), ], @@ -216,7 +216,7 @@ mod tests { ); test_function!( RepeatFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8View(None)), ColumnarValue::Scalar(ScalarValue::Int64(Some(4))), ], @@ -227,7 +227,7 @@ mod tests { ); test_function!( RepeatFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from("Pg")))), ColumnarValue::Scalar(ScalarValue::Int64(None)), ], diff --git a/datafusion/functions/src/string/replace.rs 
b/datafusion/functions/src/string/replace.rs index 2439799f96d7..9b71d3871ea8 100644 --- a/datafusion/functions/src/string/replace.rs +++ b/datafusion/functions/src/string/replace.rs @@ -157,7 +157,7 @@ mod tests { fn test_functions() -> Result<()> { test_function!( ReplaceFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("aabbdqcbb")))), ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("bb")))), ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("ccc")))), @@ -170,7 +170,7 @@ mod tests { test_function!( ReplaceFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(String::from( "aabbb" )))), @@ -185,7 +185,7 @@ mod tests { test_function!( ReplaceFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from( "aabbbcw" )))), diff --git a/datafusion/functions/src/string/rtrim.rs b/datafusion/functions/src/string/rtrim.rs index b4fe8d432432..ff8430f1530e 100644 --- a/datafusion/functions/src/string/rtrim.rs +++ b/datafusion/functions/src/string/rtrim.rs @@ -151,7 +151,7 @@ mod tests { // String view cases for checking normal logic test_function!( RtrimFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::Utf8View(Some( + vec![ColumnarValue::Scalar(ScalarValue::Utf8View(Some( String::from("alphabet ") ))),], Ok(Some("alphabet")), @@ -161,7 +161,7 @@ mod tests { ); test_function!( RtrimFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::Utf8View(Some( + vec![ColumnarValue::Scalar(ScalarValue::Utf8View(Some( String::from(" alphabet ") ))),], Ok(Some(" alphabet")), @@ -171,7 +171,7 @@ mod tests { ); test_function!( RtrimFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from( "alphabet" )))), @@ -184,7 +184,7 @@ mod tests { ); test_function!( RtrimFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from( "alphabet" )))), @@ -199,7 +199,7 @@ mod tests { ); test_function!( RtrimFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from( "alphabet" )))), @@ -213,7 +213,7 @@ mod tests { // Special string view case for checking unlined output(len > 12) test_function!( RtrimFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from( "alphabetalphabetxxx" )))), @@ -227,7 +227,7 @@ mod tests { // String cases test_function!( RtrimFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::Utf8(Some( + vec![ColumnarValue::Scalar(ScalarValue::Utf8(Some( String::from("alphabet ") ))),], Ok(Some("alphabet")), @@ -237,7 +237,7 @@ mod tests { ); test_function!( RtrimFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::Utf8(Some( + vec![ColumnarValue::Scalar(ScalarValue::Utf8(Some( String::from(" alphabet ") ))),], Ok(Some(" alphabet")), @@ -247,7 +247,7 @@ mod tests { ); test_function!( RtrimFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("alphabet")))), ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("t ")))), ], @@ -258,7 +258,7 @@ mod tests { ); test_function!( RtrimFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("alphabet")))), ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("alphabe")))), ], @@ -269,7 +269,7 @@ mod tests { ); test_function!( RtrimFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("alphabet")))), ColumnarValue::Scalar(ScalarValue::Utf8(None)), ], diff --git a/datafusion/functions/src/string/split_part.rs b/datafusion/functions/src/string/split_part.rs index 
e55325db756d..40bdd3ad01b2 100644 --- a/datafusion/functions/src/string/split_part.rs +++ b/datafusion/functions/src/string/split_part.rs @@ -270,7 +270,7 @@ mod tests { fn test_functions() -> Result<()> { test_function!( SplitPartFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from( "abc~@~def~@~ghi" )))), @@ -284,7 +284,7 @@ ); test_function!( SplitPartFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from( "abc~@~def~@~ghi" )))), @@ -298,7 +298,7 @@ ); test_function!( SplitPartFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from( "abc~@~def~@~ghi" )))), @@ -312,7 +312,7 @@ ); test_function!( SplitPartFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from( "abc~@~def~@~ghi" )))), diff --git a/datafusion/functions/src/string/starts_with.rs b/datafusion/functions/src/string/starts_with.rs index 36dbd8167b4e..7354fda09584 100644 --- a/datafusion/functions/src/string/starts_with.rs +++ b/datafusion/functions/src/string/starts_with.rs @@ -159,7 +159,7 @@ mod tests { for (args, expected) in test_cases { test_function!( StartsWithFunc::new(), - &args, + args, Ok(expected), bool, Boolean, diff --git a/datafusion/functions/src/unicode/character_length.rs b/datafusion/functions/src/unicode/character_length.rs index 726822a8f887..ad51a8ef72fb 100644 --- a/datafusion/functions/src/unicode/character_length.rs +++ b/datafusion/functions/src/unicode/character_length.rs @@ -18,7 +18,7 @@ use crate::strings::StringArrayType; use crate::utils::{make_scalar_function, utf8_to_int_type}; use arrow::array::{ - Array, ArrayRef, ArrowPrimitiveType, AsArray, OffsetSizeTrait, PrimitiveArray, + Array, ArrayRef, ArrowPrimitiveType, AsArray, OffsetSizeTrait, PrimitiveBuilder, }; use arrow::datatypes::{ArrowNativeType, DataType, Int32Type, Int64Type}; use datafusion_common::Result; @@ -136,31 +136,52 @@ fn character_length(args: &[ArrayRef]) -> Result<ArrayRef> { } } -fn character_length_general<'a, T: ArrowPrimitiveType, V: StringArrayType<'a>>( - array: V, -) -> Result<ArrayRef> +fn character_length_general<'a, T, V>(array: V) -> Result<ArrayRef> where + T: ArrowPrimitiveType, T::Native: OffsetSizeTrait, + V: StringArrayType<'a>, { + let mut builder = PrimitiveBuilder::<T>::with_capacity(array.len()); + // String characters are variable length encoded in UTF-8, counting the // number of chars requires expensive decoding, however checking if the // string is ASCII only is relatively cheap. // If strings are ASCII only, count bytes instead.
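    // (Editorial aside: for ASCII-only data the byte length equals the
    // character count, e.g. "jose".len() == "jose".chars().count() == 4,
    // while "josé".len() == 5 but "josé".chars().count() == 4, which is
    // why the per-array `is_ascii` check below is worth doing once.)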
let is_array_ascii_only = array.is_ascii(); - let iter = array.iter(); - let result = iter - .map(|string| { - string.map(|string: &str| { - if is_array_ascii_only { - T::Native::usize_as(string.len()) - } else { - T::Native::usize_as(string.chars().count()) - } - }) - }) - .collect::<PrimitiveArray<T>>(); - - Ok(Arc::new(result) as ArrayRef) + if array.null_count() == 0 { + if is_array_ascii_only { + for i in 0..array.len() { + let value = array.value(i); + builder.append_value(T::Native::usize_as(value.len())); + } + } else { + for i in 0..array.len() { + let value = array.value(i); + builder.append_value(T::Native::usize_as(value.chars().count())); + } + } + } else if is_array_ascii_only { + for i in 0..array.len() { + if array.is_null(i) { + builder.append_null(); + } else { + let value = array.value(i); + builder.append_value(T::Native::usize_as(value.len())); + } + } + } else { + for i in 0..array.len() { + if array.is_null(i) { + builder.append_null(); + } else { + let value = array.value(i); + builder.append_value(T::Native::usize_as(value.chars().count())); + } + } + } + + Ok(Arc::new(builder.finish()) as ArrayRef) } #[cfg(test)] @@ -176,7 +197,7 @@ mod tests { ($INPUT:expr, $EXPECTED:expr) => { test_function!( CharacterLengthFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::Utf8($INPUT))], + vec![ColumnarValue::Scalar(ScalarValue::Utf8($INPUT))], $EXPECTED, i32, Int32, @@ -185,7 +206,7 @@ test_function!( CharacterLengthFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::LargeUtf8($INPUT))], + vec![ColumnarValue::Scalar(ScalarValue::LargeUtf8($INPUT))], $EXPECTED, i64, Int64, @@ -194,7 +215,7 @@ test_function!( CharacterLengthFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::Utf8View($INPUT))], + vec![ColumnarValue::Scalar(ScalarValue::Utf8View($INPUT))], $EXPECTED, i32, Int32, diff --git a/datafusion/functions/src/unicode/left.rs b/datafusion/functions/src/unicode/left.rs index ef2802340b14..e583523d84a0 100644 --- a/datafusion/functions/src/unicode/left.rs +++ b/datafusion/functions/src/unicode/left.rs @@ -188,7 +188,7 @@ mod tests { fn test_functions() -> Result<()> { test_function!( LeftFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("abcde")), ColumnarValue::Scalar(ScalarValue::from(2i64)), ], @@ -199,7 +199,7 @@ ); test_function!( LeftFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("abcde")), ColumnarValue::Scalar(ScalarValue::from(200i64)), ], @@ -210,7 +210,7 @@ ); test_function!( LeftFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("abcde")), ColumnarValue::Scalar(ScalarValue::from(-2i64)), ], @@ -221,7 +221,7 @@ ); test_function!( LeftFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("abcde")), ColumnarValue::Scalar(ScalarValue::from(-200i64)), ], @@ -232,7 +232,7 @@ ); test_function!( LeftFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("abcde")), ColumnarValue::Scalar(ScalarValue::from(0i64)), ], @@ -243,7 +243,7 @@ ); test_function!( LeftFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8(None)), ColumnarValue::Scalar(ScalarValue::from(2i64)), ], @@ -254,7 +254,7 @@ ); test_function!( LeftFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("abcde")), ColumnarValue::Scalar(ScalarValue::Int64(None)), ], @@ -265,7 +265,7 @@ ); test_function!( LeftFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("joséésoj")),
ColumnarValue::Scalar(ScalarValue::from(5i64)), ], @@ -276,7 +276,7 @@ mod tests { ); test_function!( LeftFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("joséésoj")), ColumnarValue::Scalar(ScalarValue::from(-3i64)), ], diff --git a/datafusion/functions/src/unicode/lpad.rs b/datafusion/functions/src/unicode/lpad.rs index 6c8a4ec97bb0..f1750d2277ca 100644 --- a/datafusion/functions/src/unicode/lpad.rs +++ b/datafusion/functions/src/unicode/lpad.rs @@ -298,7 +298,7 @@ mod tests { ($INPUT:expr, $LENGTH:expr, $EXPECTED:expr) => { test_function!( LPadFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8($INPUT)), ColumnarValue::Scalar($LENGTH) ], @@ -310,7 +310,7 @@ mod tests { test_function!( LPadFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::LargeUtf8($INPUT)), ColumnarValue::Scalar($LENGTH) ], @@ -322,7 +322,7 @@ mod tests { test_function!( LPadFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8View($INPUT)), ColumnarValue::Scalar($LENGTH) ], @@ -337,7 +337,7 @@ mod tests { // utf8, utf8 test_function!( LPadFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8($INPUT)), ColumnarValue::Scalar($LENGTH), ColumnarValue::Scalar(ScalarValue::Utf8($REPLACE)) @@ -350,7 +350,7 @@ mod tests { // utf8, largeutf8 test_function!( LPadFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8($INPUT)), ColumnarValue::Scalar($LENGTH), ColumnarValue::Scalar(ScalarValue::LargeUtf8($REPLACE)) @@ -363,7 +363,7 @@ mod tests { // utf8, utf8view test_function!( LPadFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8($INPUT)), ColumnarValue::Scalar($LENGTH), ColumnarValue::Scalar(ScalarValue::Utf8View($REPLACE)) @@ -377,7 +377,7 @@ mod tests { // largeutf8, utf8 test_function!( LPadFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::LargeUtf8($INPUT)), ColumnarValue::Scalar($LENGTH), ColumnarValue::Scalar(ScalarValue::Utf8($REPLACE)) @@ -390,7 +390,7 @@ mod tests { // largeutf8, largeutf8 test_function!( LPadFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::LargeUtf8($INPUT)), ColumnarValue::Scalar($LENGTH), ColumnarValue::Scalar(ScalarValue::LargeUtf8($REPLACE)) @@ -403,7 +403,7 @@ mod tests { // largeutf8, utf8view test_function!( LPadFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::LargeUtf8($INPUT)), ColumnarValue::Scalar($LENGTH), ColumnarValue::Scalar(ScalarValue::Utf8View($REPLACE)) @@ -417,7 +417,7 @@ mod tests { // utf8view, utf8 test_function!( LPadFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8View($INPUT)), ColumnarValue::Scalar($LENGTH), ColumnarValue::Scalar(ScalarValue::Utf8($REPLACE)) @@ -430,7 +430,7 @@ mod tests { // utf8view, largeutf8 test_function!( LPadFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8View($INPUT)), ColumnarValue::Scalar($LENGTH), ColumnarValue::Scalar(ScalarValue::LargeUtf8($REPLACE)) @@ -443,7 +443,7 @@ mod tests { // utf8view, utf8view test_function!( LPadFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8View($INPUT)), ColumnarValue::Scalar($LENGTH), ColumnarValue::Scalar(ScalarValue::Utf8View($REPLACE)) diff --git a/datafusion/functions/src/unicode/mod.rs b/datafusion/functions/src/unicode/mod.rs index 40915bc9efde..f31ece9196d8 100644 --- a/datafusion/functions/src/unicode/mod.rs +++ b/datafusion/functions/src/unicode/mod.rs @@ -34,22 +34,18 @@ pub mod substrindex; pub mod translate; // create UDFs -make_udf_function!( - character_length::CharacterLengthFunc, - CHARACTER_LENGTH, - character_length -); 
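Editor's note: the `character_length_general` rewrite earlier in this diff replaces a `map`/`collect` chain with a `PrimitiveBuilder` and hoists the per-array `is_ascii` and null-count checks out of the row loop. A condensed, self-contained sketch of the same idea, assuming only the `arrow` crate; `char_lengths` is an illustrative name, and the branch is kept inside a single loop here for brevity where the patch specializes four separate loops:

```rust
use arrow::array::{Array, Int32Array, Int32Builder, StringArray};

fn char_lengths(strings: &StringArray) -> Int32Array {
    let mut builder = Int32Builder::with_capacity(strings.len());
    // Checking "is the whole array ASCII?" once is cheap compared to
    // decoding UTF-8 for every row.
    let all_ascii = strings.iter().flatten().all(|s| s.is_ascii());
    for i in 0..strings.len() {
        if strings.is_null(i) {
            builder.append_null();
        } else if all_ascii {
            // For ASCII, byte length == character count.
            builder.append_value(strings.value(i).len() as i32);
        } else {
            builder.append_value(strings.value(i).chars().count() as i32);
        }
    }
    builder.finish()
}

fn main() {
    let input = StringArray::from(vec![Some("jose"), None, Some("josé")]);
    let lengths = char_lengths(&input);
    assert_eq!(lengths.value(0), 4);
    assert!(lengths.is_null(1));
    assert_eq!(lengths.value(2), 4);
}
```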
-make_udf_function!(find_in_set::FindInSetFunc, FIND_IN_SET, find_in_set); -make_udf_function!(left::LeftFunc, LEFT, left); -make_udf_function!(lpad::LPadFunc, LPAD, lpad); -make_udf_function!(right::RightFunc, RIGHT, right); -make_udf_function!(reverse::ReverseFunc, REVERSE, reverse); -make_udf_function!(rpad::RPadFunc, RPAD, rpad); -make_udf_function!(strpos::StrposFunc, STRPOS, strpos); -make_udf_function!(substr::SubstrFunc, SUBSTR, substr); -make_udf_function!(substr::SubstrFunc, SUBSTRING, substring); -make_udf_function!(substrindex::SubstrIndexFunc, SUBSTR_INDEX, substr_index); -make_udf_function!(translate::TranslateFunc, TRANSLATE, translate); +make_udf_function!(character_length::CharacterLengthFunc, character_length); +make_udf_function!(find_in_set::FindInSetFunc, find_in_set); +make_udf_function!(left::LeftFunc, left); +make_udf_function!(lpad::LPadFunc, lpad); +make_udf_function!(right::RightFunc, right); +make_udf_function!(reverse::ReverseFunc, reverse); +make_udf_function!(rpad::RPadFunc, rpad); +make_udf_function!(strpos::StrposFunc, strpos); +make_udf_function!(substr::SubstrFunc, substr); +make_udf_function!(substr::SubstrFunc, substring); +make_udf_function!(substrindex::SubstrIndexFunc, substr_index); +make_udf_function!(translate::TranslateFunc, translate); pub mod expr_fn { use datafusion_expr::Expr; diff --git a/datafusion/functions/src/unicode/reverse.rs b/datafusion/functions/src/unicode/reverse.rs index 38c1f23cbd5a..8e3cf8845f98 100644 --- a/datafusion/functions/src/unicode/reverse.rs +++ b/datafusion/functions/src/unicode/reverse.rs @@ -151,7 +151,7 @@ mod tests { ($INPUT:expr, $EXPECTED:expr) => { test_function!( ReverseFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::Utf8($INPUT))], + vec![ColumnarValue::Scalar(ScalarValue::Utf8($INPUT))], $EXPECTED, &str, Utf8, @@ -160,7 +160,7 @@ mod tests { test_function!( ReverseFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::LargeUtf8($INPUT))], + vec![ColumnarValue::Scalar(ScalarValue::LargeUtf8($INPUT))], $EXPECTED, &str, LargeUtf8, @@ -169,7 +169,7 @@ mod tests { test_function!( ReverseFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::Utf8View($INPUT))], + vec![ColumnarValue::Scalar(ScalarValue::Utf8View($INPUT))], $EXPECTED, &str, Utf8, diff --git a/datafusion/functions/src/unicode/right.rs b/datafusion/functions/src/unicode/right.rs index 1586e23eb8aa..4e414fbae5cb 100644 --- a/datafusion/functions/src/unicode/right.rs +++ b/datafusion/functions/src/unicode/right.rs @@ -192,7 +192,7 @@ mod tests { fn test_functions() -> Result<()> { test_function!( RightFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("abcde")), ColumnarValue::Scalar(ScalarValue::from(2i64)), ], @@ -203,7 +203,7 @@ mod tests { ); test_function!( RightFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("abcde")), ColumnarValue::Scalar(ScalarValue::from(200i64)), ], @@ -214,7 +214,7 @@ mod tests { ); test_function!( RightFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("abcde")), ColumnarValue::Scalar(ScalarValue::from(-2i64)), ], @@ -225,7 +225,7 @@ mod tests { ); test_function!( RightFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("abcde")), ColumnarValue::Scalar(ScalarValue::from(-200i64)), ], @@ -236,7 +236,7 @@ mod tests { ); test_function!( RightFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("abcde")), ColumnarValue::Scalar(ScalarValue::from(0i64)), ], @@ -247,7 +247,7 @@ mod tests { ); test_function!( RightFunc::new(), - &[ + vec![ 
ColumnarValue::Scalar(ScalarValue::Utf8(None)), ColumnarValue::Scalar(ScalarValue::from(2i64)), ], @@ -258,7 +258,7 @@ mod tests { ); test_function!( RightFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("abcde")), ColumnarValue::Scalar(ScalarValue::Int64(None)), ], @@ -269,7 +269,7 @@ mod tests { ); test_function!( RightFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("joséésoj")), ColumnarValue::Scalar(ScalarValue::from(5i64)), ], @@ -280,7 +280,7 @@ mod tests { ); test_function!( RightFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("joséésoj")), ColumnarValue::Scalar(ScalarValue::from(-3i64)), ], diff --git a/datafusion/functions/src/unicode/rpad.rs b/datafusion/functions/src/unicode/rpad.rs index 6e6bde3e177c..d5a0079c72aa 100644 --- a/datafusion/functions/src/unicode/rpad.rs +++ b/datafusion/functions/src/unicode/rpad.rs @@ -319,7 +319,7 @@ mod tests { fn test_functions() -> Result<()> { test_function!( RPadFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("josé")), ColumnarValue::Scalar(ScalarValue::from(5i64)), ], @@ -330,7 +330,7 @@ mod tests { ); test_function!( RPadFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("hi")), ColumnarValue::Scalar(ScalarValue::from(5i64)), ], @@ -341,7 +341,7 @@ mod tests { ); test_function!( RPadFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("hi")), ColumnarValue::Scalar(ScalarValue::from(0i64)), ], @@ -352,7 +352,7 @@ mod tests { ); test_function!( RPadFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("hi")), ColumnarValue::Scalar(ScalarValue::Int64(None)), ], @@ -363,7 +363,7 @@ mod tests { ); test_function!( RPadFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8(None)), ColumnarValue::Scalar(ScalarValue::from(5i64)), ], @@ -374,7 +374,7 @@ mod tests { ); test_function!( RPadFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("hi")), ColumnarValue::Scalar(ScalarValue::from(5i64)), ColumnarValue::Scalar(ScalarValue::from("xy")), @@ -386,7 +386,7 @@ mod tests { ); test_function!( RPadFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("hi")), ColumnarValue::Scalar(ScalarValue::from(21i64)), ColumnarValue::Scalar(ScalarValue::from("abcdef")), @@ -398,7 +398,7 @@ mod tests { ); test_function!( RPadFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("hi")), ColumnarValue::Scalar(ScalarValue::from(5i64)), ColumnarValue::Scalar(ScalarValue::from(" ")), @@ -410,7 +410,7 @@ mod tests { ); test_function!( RPadFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("hi")), ColumnarValue::Scalar(ScalarValue::from(5i64)), ColumnarValue::Scalar(ScalarValue::from("")), @@ -422,7 +422,7 @@ mod tests { ); test_function!( RPadFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8(None)), ColumnarValue::Scalar(ScalarValue::from(5i64)), ColumnarValue::Scalar(ScalarValue::from("xy")), @@ -434,7 +434,7 @@ mod tests { ); test_function!( RPadFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("hi")), ColumnarValue::Scalar(ScalarValue::Int64(None)), ColumnarValue::Scalar(ScalarValue::from("xy")), @@ -446,7 +446,7 @@ mod tests { ); test_function!( RPadFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("hi")), ColumnarValue::Scalar(ScalarValue::from(5i64)), ColumnarValue::Scalar(ScalarValue::Utf8(None)), @@ -458,7 +458,7 @@ mod tests { ); test_function!( RPadFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("josé")), 
ColumnarValue::Scalar(ScalarValue::from(10i64)), ColumnarValue::Scalar(ScalarValue::from("xy")), @@ -470,7 +470,7 @@ ); test_function!( RPadFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("josé")), ColumnarValue::Scalar(ScalarValue::from(10i64)), ColumnarValue::Scalar(ScalarValue::from("éñ")), diff --git a/datafusion/functions/src/unicode/strpos.rs b/datafusion/functions/src/unicode/strpos.rs index 5d1986e44c92..569af87a4b50 100644 --- a/datafusion/functions/src/unicode/strpos.rs +++ b/datafusion/functions/src/unicode/strpos.rs @@ -218,7 +218,7 @@ mod tests { ($lhs:literal, $rhs:literal -> $result:literal; $t1:ident $t2:ident $t3:ident $t4:ident $t5:ident) => { test_function!( StrposFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::$t1(Some($lhs.to_owned()))), ColumnarValue::Scalar(ScalarValue::$t2(Some($rhs.to_owned()))), ], diff --git a/datafusion/functions/src/unicode/substr.rs b/datafusion/functions/src/unicode/substr.rs index 0ac050c707bf..687f77dbef5b 100644 --- a/datafusion/functions/src/unicode/substr.rs +++ b/datafusion/functions/src/unicode/substr.rs @@ -21,8 +21,8 @@ use std::sync::{Arc, OnceLock}; use crate::strings::{make_and_append_view, StringArrayType}; use crate::utils::{make_scalar_function, utf8_to_str_type}; use arrow::array::{ - Array, ArrayIter, ArrayRef, AsArray, GenericStringArray, Int64Array, OffsetSizeTrait, - StringViewArray, + Array, ArrayIter, ArrayRef, AsArray, GenericStringBuilder, Int64Array, + OffsetSizeTrait, StringViewArray, }; use arrow::datatypes::DataType; use arrow_buffer::{NullBufferBuilder, ScalarBuffer}; @@ -448,10 +448,9 @@ where match args.len() { 1 => { let iter = ArrayIter::new(string_array); - - let result = iter - .zip(start_array.iter()) - .map(|(string, start)| match (string, start) { + let mut result_builder = GenericStringBuilder::<T>::new(); + for (string, start) in iter.zip(start_array.iter()) { + match (string, start) { (Some(string), Some(start)) => { let (start, end) = get_true_start_end( string, @@ -460,47 +459,51 @@ where enable_ascii_fast_path, ); // start, end is byte-based let substr = &string[start..end]; - Some(substr.to_string()) + result_builder.append_value(substr); } - _ => None, - }) - .collect::<GenericStringArray<T>>(); - Ok(Arc::new(result) as ArrayRef) + _ => { + result_builder.append_null(); + } + } + } + Ok(Arc::new(result_builder.finish()) as ArrayRef) } 2 => { let iter = ArrayIter::new(string_array); let count_array = count_array_opt.unwrap(); + let mut result_builder = GenericStringBuilder::<T>::new(); - let result = iter - .zip(start_array.iter()) - .zip(count_array.iter()) - .map(|((string, start), count)| { - match (string, start, count) { - (Some(string), Some(start), Some(count)) => { - if count < 0 { - exec_err!( + for ((string, start), count) in + iter.zip(start_array.iter()).zip(count_array.iter()) + { + match (string, start, count) { + (Some(string), Some(start), Some(count)) => { + if count < 0 { + return exec_err!( "negative substring length not allowed: substr(<str>, {start}, {count})" - ) - } else { - if start == i64::MIN { - return exec_err!("negative overflow when calculating skip value"); - } - let (start, end) = get_true_start_end( - string, - start, - Some(count as u64), - enable_ascii_fast_path, - ); // start, end is byte-based - let substr = &string[start..end]; - Ok(Some(substr.to_string())) + ); + } else { + if start == i64::MIN { + return exec_err!( + "negative overflow when calculating skip value" + ); } + let (start, end) = get_true_start_end( + string, + start, + Some(count as
u64), + enable_ascii_fast_path, + ); // start, end is byte-based + let substr = &string[start..end]; + result_builder.append_value(substr); } - _ => Ok(None), } - }) - .collect::<Result<GenericStringArray<T>>>()?; - - Ok(Arc::new(result) as ArrayRef) + _ => { + result_builder.append_null(); + } + } + } + Ok(Arc::new(result_builder.finish()) as ArrayRef) } other => { exec_err!("substr was called with {other} arguments. It requires 2 or 3.") @@ -523,7 +526,7 @@ mod tests { fn test_functions() -> Result<()> { test_function!( SubstrFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8View(None)), ColumnarValue::Scalar(ScalarValue::from(1i64)), ], @@ -534,7 +537,7 @@ ); test_function!( SubstrFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from( "alphabet" )))), @@ -547,7 +550,7 @@ ); test_function!( SubstrFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from( "this és longer than 12B" )))), @@ -561,7 +564,7 @@ ); test_function!( SubstrFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from( "this is longer than 12B" )))), @@ -574,7 +577,7 @@ ); test_function!( SubstrFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from( "joséésoj" )))), @@ -587,7 +590,7 @@ ); test_function!( SubstrFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from( "alphabet" )))), @@ -601,7 +604,7 @@ ); test_function!( SubstrFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from( "alphabet" )))), @@ -615,7 +618,7 @@ ); test_function!( SubstrFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("alphabet")), ColumnarValue::Scalar(ScalarValue::from(0i64)), ], @@ -626,7 +629,7 @@ ); test_function!( SubstrFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("joséésoj")), ColumnarValue::Scalar(ScalarValue::from(5i64)), ], @@ -637,7 +640,7 @@ ); test_function!( SubstrFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("joséésoj")), ColumnarValue::Scalar(ScalarValue::from(-5i64)), ], @@ -648,7 +651,7 @@ ); test_function!( SubstrFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("alphabet")), ColumnarValue::Scalar(ScalarValue::from(1i64)), ], @@ -659,7 +662,7 @@ ); test_function!( SubstrFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("alphabet")), ColumnarValue::Scalar(ScalarValue::from(2i64)), ], @@ -670,7 +673,7 @@ ); test_function!( SubstrFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("alphabet")), ColumnarValue::Scalar(ScalarValue::from(3i64)), ], @@ -681,7 +684,7 @@ ); test_function!( SubstrFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("alphabet")), ColumnarValue::Scalar(ScalarValue::from(-3i64)), ], @@ -692,7 +695,7 @@ ); test_function!( SubstrFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("alphabet")), ColumnarValue::Scalar(ScalarValue::from(30i64)), ], @@ -703,7 +706,7 @@ ); test_function!( SubstrFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("alphabet")), ColumnarValue::Scalar(ScalarValue::Int64(None)), ], @@ -714,7 +717,7 @@ ); test_function!( SubstrFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("alphabet")), ColumnarValue::Scalar(ScalarValue::from(3i64)),
ColumnarValue::Scalar(ScalarValue::from(2i64)), @@ -726,7 +729,7 @@ mod tests { ); test_function!( SubstrFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("alphabet")), ColumnarValue::Scalar(ScalarValue::from(3i64)), ColumnarValue::Scalar(ScalarValue::from(20i64)), @@ -738,7 +741,7 @@ mod tests { ); test_function!( SubstrFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("alphabet")), ColumnarValue::Scalar(ScalarValue::from(0i64)), ColumnarValue::Scalar(ScalarValue::from(5i64)), @@ -751,7 +754,7 @@ mod tests { // starting from 5 (10 + -5) test_function!( SubstrFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("alphabet")), ColumnarValue::Scalar(ScalarValue::from(-5i64)), ColumnarValue::Scalar(ScalarValue::from(10i64)), @@ -764,7 +767,7 @@ mod tests { // starting from -1 (4 + -5) test_function!( SubstrFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("alphabet")), ColumnarValue::Scalar(ScalarValue::from(-5i64)), ColumnarValue::Scalar(ScalarValue::from(4i64)), @@ -777,7 +780,7 @@ mod tests { // starting from 0 (5 + -5) test_function!( SubstrFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("alphabet")), ColumnarValue::Scalar(ScalarValue::from(-5i64)), ColumnarValue::Scalar(ScalarValue::from(5i64)), @@ -789,7 +792,7 @@ mod tests { ); test_function!( SubstrFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("alphabet")), ColumnarValue::Scalar(ScalarValue::Int64(None)), ColumnarValue::Scalar(ScalarValue::from(20i64)), @@ -801,7 +804,7 @@ mod tests { ); test_function!( SubstrFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("alphabet")), ColumnarValue::Scalar(ScalarValue::from(3i64)), ColumnarValue::Scalar(ScalarValue::Int64(None)), @@ -813,7 +816,7 @@ mod tests { ); test_function!( SubstrFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("alphabet")), ColumnarValue::Scalar(ScalarValue::from(1i64)), ColumnarValue::Scalar(ScalarValue::from(-1i64)), @@ -825,7 +828,7 @@ mod tests { ); test_function!( SubstrFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("joséésoj")), ColumnarValue::Scalar(ScalarValue::from(5i64)), ColumnarValue::Scalar(ScalarValue::from(2i64)), @@ -851,7 +854,7 @@ mod tests { ); test_function!( SubstrFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("abc")), ColumnarValue::Scalar(ScalarValue::from(-9223372036854775808i64)), ], @@ -862,7 +865,7 @@ mod tests { ); test_function!( SubstrFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("overflow")), ColumnarValue::Scalar(ScalarValue::from(-9223372036854775808i64)), ColumnarValue::Scalar(ScalarValue::from(1i64)), diff --git a/datafusion/functions/src/unicode/substrindex.rs b/datafusion/functions/src/unicode/substrindex.rs index 825666b0455e..61cd989bb964 100644 --- a/datafusion/functions/src/unicode/substrindex.rs +++ b/datafusion/functions/src/unicode/substrindex.rs @@ -253,7 +253,7 @@ mod tests { fn test_functions() -> Result<()> { test_function!( SubstrIndexFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("www.apache.org")), ColumnarValue::Scalar(ScalarValue::from(".")), ColumnarValue::Scalar(ScalarValue::from(1i64)), @@ -265,7 +265,7 @@ mod tests { ); test_function!( SubstrIndexFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("www.apache.org")), ColumnarValue::Scalar(ScalarValue::from(".")), ColumnarValue::Scalar(ScalarValue::from(2i64)), @@ -277,7 +277,7 @@ mod tests { ); test_function!( SubstrIndexFunc::new(), - &[ + vec![ 
ColumnarValue::Scalar(ScalarValue::from("www.apache.org")), ColumnarValue::Scalar(ScalarValue::from(".")), ColumnarValue::Scalar(ScalarValue::from(-2i64)), @@ -289,7 +289,7 @@ mod tests { ); test_function!( SubstrIndexFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("www.apache.org")), ColumnarValue::Scalar(ScalarValue::from(".")), ColumnarValue::Scalar(ScalarValue::from(-1i64)), @@ -301,7 +301,7 @@ mod tests { ); test_function!( SubstrIndexFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("www.apache.org")), ColumnarValue::Scalar(ScalarValue::from(".")), ColumnarValue::Scalar(ScalarValue::from(0i64)), @@ -313,7 +313,7 @@ mod tests { ); test_function!( SubstrIndexFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("")), ColumnarValue::Scalar(ScalarValue::from(".")), ColumnarValue::Scalar(ScalarValue::from(1i64)), @@ -325,7 +325,7 @@ mod tests { ); test_function!( SubstrIndexFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("www.apache.org")), ColumnarValue::Scalar(ScalarValue::from("")), ColumnarValue::Scalar(ScalarValue::from(1i64)), diff --git a/datafusion/functions/src/unicode/translate.rs b/datafusion/functions/src/unicode/translate.rs index 780603777133..9257b0b04e61 100644 --- a/datafusion/functions/src/unicode/translate.rs +++ b/datafusion/functions/src/unicode/translate.rs @@ -201,7 +201,7 @@ mod tests { fn test_functions() -> Result<()> { test_function!( TranslateFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("12345")), ColumnarValue::Scalar(ScalarValue::from("143")), ColumnarValue::Scalar(ScalarValue::from("ax")) @@ -213,7 +213,7 @@ mod tests { ); test_function!( TranslateFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8(None)), ColumnarValue::Scalar(ScalarValue::from("143")), ColumnarValue::Scalar(ScalarValue::from("ax")) @@ -225,7 +225,7 @@ mod tests { ); test_function!( TranslateFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("12345")), ColumnarValue::Scalar(ScalarValue::Utf8(None)), ColumnarValue::Scalar(ScalarValue::from("ax")) @@ -237,7 +237,7 @@ mod tests { ); test_function!( TranslateFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("12345")), ColumnarValue::Scalar(ScalarValue::from("143")), ColumnarValue::Scalar(ScalarValue::Utf8(None)) @@ -249,7 +249,7 @@ mod tests { ); test_function!( TranslateFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("é2íñ5")), ColumnarValue::Scalar(ScalarValue::from("éñí")), ColumnarValue::Scalar(ScalarValue::from("óü")), @@ -262,7 +262,7 @@ mod tests { #[cfg(not(feature = "unicode_expressions"))] test_function!( TranslateFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("12345")), ColumnarValue::Scalar(ScalarValue::from("143")), ColumnarValue::Scalar(ScalarValue::from("ax")), diff --git a/datafusion/macros/Cargo.toml b/datafusion/macros/Cargo.toml index c5ac9d08dffa..0c4fdb10a33c 100644 --- a/datafusion/macros/Cargo.toml +++ b/datafusion/macros/Cargo.toml @@ -32,11 +32,10 @@ workspace = true [lib] name = "datafusion_macros" -path = "src/lib.rs" +# lib.rs to be re-added in the future +path = "src/user_doc.rs" proc-macro = true [dependencies] -datafusion-doc = { workspace = true } -proc-macro2 = "1.0" quote = "1.0.37" syn = { version = "2.0.79", features = ["full"] } diff --git a/datafusion/macros/src/lib.rs b/datafusion/macros/src/user_doc.rs similarity index 75% rename from datafusion/macros/src/lib.rs rename to datafusion/macros/src/user_doc.rs index 54b688ac2a49..441b3db2a133 
100644 --- a/datafusion/macros/src/lib.rs +++ b/datafusion/macros/src/user_doc.rs @@ -26,16 +26,19 @@ use syn::{parse_macro_input, DeriveInput, LitStr}; /// declared on `AggregateUDF`, `WindowUDFImpl`, `ScalarUDFImpl` traits. /// /// Example: +/// ```ignore /// #[user_doc( /// doc_section(include = "true", label = "Time and Date Functions"), -/// description = r"Converts a value to a date (`YYYY-MM-DD`)." -/// sql_example = "```sql\n\ -/// \> select to_date('2023-01-31');\n\ -/// +-----------------------------+\n\ -/// | to_date(Utf8(\"2023-01-31\")) |\n\ -/// +-----------------------------+\n\ -/// | 2023-01-31 |\n\ -/// +-----------------------------+\n\"), +/// description = r"Converts a value to a date (`YYYY-MM-DD`).", +/// syntax_example = "to_date('2017-05-31', '%Y-%m-%d')", +/// sql_example = r#"```sql +/// > select to_date('2023-01-31'); +/// +-----------------------------+ +/// | to_date(Utf8(\"2023-01-31\")) | +/// +-----------------------------+ +/// | 2023-01-31 | +/// +-----------------------------+ +/// ```"#, /// standard_argument(name = "expression", prefix = "String"), /// argument( /// name = "format_n", @@ -48,40 +51,50 @@ use syn::{parse_macro_input, DeriveInput, LitStr}; /// pub struct ToDateFunc { /// signature: Signature, /// } -/// +/// ``` /// will generate the following code /// -/// #[derive(Debug)] pub struct ToDateFunc { signature : Signature, } -/// use datafusion_doc :: DocSection; -/// use datafusion_doc :: DocumentationBuilder; -/// static DOCUMENTATION : OnceLock < Documentation > = OnceLock :: new(); -/// impl ToDateFunc -/// { -/// fn doc(& self) -> Option < & Documentation > -/// { -/// Some(DOCUMENTATION.get_or_init(|| -/// { -/// Documentation :: -/// builder(DocSection -/// { -/// include : true, label : "Time and Date Functions", description -/// : None -/// }, r"Converts a value to a date (`YYYY-MM-DD`).") -/// .with_syntax_example("to_date('2017-05-31', '%Y-%m-%d')".to_string(),"```sql\n\ -/// \> select to_date('2023-01-31');\n\ -/// +-----------------------------+\n\ -/// | to_date(Utf8(\"2023-01-31\")) |\n\ -/// +-----------------------------+\n\ -/// | 2023-01-31 |\n\ -/// +-----------------------------+\n\) -/// .with_standard_argument("expression", "String".into()) -/// .with_argument("format_n", -/// r"Optional [Chrono format](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) strings to use to parse the expression. Formats will be tried in the order -/// they appear with the first successful one being returned. 
If none of the formats successfully parse the expression -/// an error will be returned.").build() -/// })) +/// ```ignore +/// pub struct ToDateFunc { +/// signature: Signature, +/// } +/// impl ToDateFunc { +/// fn doc(&self) -> Option<&datafusion_doc::Documentation> { +/// static DOCUMENTATION: std::sync::LazyLock< +/// datafusion_doc::Documentation, +/// > = std::sync::LazyLock::new(|| { +/// datafusion_doc::Documentation::builder( +/// datafusion_doc::DocSection { +/// include: true, +/// label: "Time and Date Functions", +/// description: None, +/// }, +/// r"Converts a value to a date (`YYYY-MM-DD`).".to_string(), +/// "to_date('2017-05-31', '%Y-%m-%d')".to_string(), +/// ) +/// .with_sql_example( +/// r#"```sql +/// > select to_date('2023-01-31'); +/// +-----------------------------+ +/// | to_date(Utf8(\"2023-01-31\")) | +/// +-----------------------------+ +/// | 2023-01-31 | +/// +-----------------------------+ +/// ```"#, +/// ) +/// .with_standard_argument("expression", "String".into()) +/// .with_argument( +/// "format_n", +/// r"Optional [Chrono format](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) strings to use to parse the expression. Formats will be tried in the order +/// they appear with the first successful one being returned. If none of the formats successfully parse the expression +/// an error will be returned.", +/// ) +/// .build() +/// }); +/// Some(&DOCUMENTATION) /// } /// } +/// ``` #[proc_macro_attribute] pub fn user_doc(args: TokenStream, input: TokenStream) -> TokenStream { let mut doc_section_include: Option = None; @@ -235,19 +248,14 @@ pub fn user_doc(args: TokenStream, input: TokenStream) -> TokenStream { } }); - let lock_name: proc_macro2::TokenStream = - format!("{name}_DOCUMENTATION").parse().unwrap(); - let generated = quote! 
{ #input - static #lock_name: OnceLock<Documentation> = OnceLock::new(); - impl #name { - - fn doc(&self) -> Option<&Documentation> { - Some(#lock_name.get_or_init(|| { - Documentation::builder(DocSection { include: #doc_section_include, label: #doc_section_lbl, description: #doc_section_description }, + fn doc(&self) -> Option<&datafusion_doc::Documentation> { + static DOCUMENTATION: std::sync::LazyLock<datafusion_doc::Documentation> = + std::sync::LazyLock::new(|| { + datafusion_doc::Documentation::builder(datafusion_doc::DocSection { include: #doc_section_include, label: #doc_section_lbl, description: #doc_section_description }, #description.to_string(), #syntax_example.to_string()) #sql_example #alt_syntax_example @@ -255,7 +263,8 @@ pub fn user_doc(args: TokenStream, input: TokenStream) -> TokenStream { #(#udf_args)* #(#related_udfs)* .build() - })) + }); + Some(&DOCUMENTATION) } } }; diff --git a/datafusion/optimizer/Cargo.toml b/datafusion/optimizer/Cargo.toml index c0f17de6c5c5..9979df689b0a 100644 --- a/datafusion/optimizer/Cargo.toml +++ b/datafusion/optimizer/Cargo.toml @@ -37,7 +37,6 @@ path = "src/lib.rs" [dependencies] arrow = { workspace = true } -async-trait = { workspace = true } chrono = { workspace = true } datafusion-common = { workspace = true, default-features = true } datafusion-expr = { workspace = true } @@ -50,6 +49,7 @@ regex = { workspace = true } regex-syntax = "0.8.0" [dev-dependencies] +async-trait = { workspace = true } ctor = { workspace = true } datafusion-functions-aggregate = { workspace = true } datafusion-functions-window-common = { workspace = true } diff --git a/datafusion/physical-expr-common/src/physical_expr.rs b/datafusion/physical-expr-common/src/physical_expr.rs index 93bdcdef8ea0..c2e892d63da0 100644 --- a/datafusion/physical-expr-common/src/physical_expr.rs +++ b/datafusion/physical-expr-common/src/physical_expr.rs @@ -214,6 +214,21 @@ pub fn with_new_children_if_necessary( } } +#[deprecated(since = "44.0.0")] +pub fn down_cast_any_ref(any: &dyn Any) -> &dyn Any { + if any.is::<Arc<dyn PhysicalExpr>>() { + any.downcast_ref::<Arc<dyn PhysicalExpr>>() + .unwrap() + .as_any() + } else if any.is::<Box<dyn PhysicalExpr>>() { + any.downcast_ref::<Box<dyn PhysicalExpr>>() + .unwrap() + .as_any() + } else { + any + } +} + /// Returns [`Display`] able a list of [`PhysicalExpr`] /// /// Example output: `[a + 1, b]`
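For context on the deprecation above: `down_cast_any_ref` is the helper physical expressions historically used to implement `PartialEq<dyn Any>`, peeling `Arc`/`Box` wrappers off a `&dyn Any` before downcasting to the concrete type; it presumably remains only to ease downstream migration. A simplified, self-contained sketch of that comparison pattern (`MyExpr` and its field are invented for illustration):

```rust
use std::any::Any;
use std::sync::Arc;

// Stand-in for a concrete physical expression.
#[derive(Debug, PartialEq)]
struct MyExpr {
    name: String,
}

impl MyExpr {
    fn as_any(&self) -> &dyn Any {
        self
    }
}

// Mirrors the deprecated helper: unwrap `Arc`/`Box` layers so that
// `downcast_ref` can see the concrete type inside.
fn down_cast_any_ref(any: &dyn Any) -> &dyn Any {
    if let Some(arc) = any.downcast_ref::<Arc<MyExpr>>() {
        arc.as_any()
    } else if let Some(boxed) = any.downcast_ref::<Box<MyExpr>>() {
        boxed.as_any()
    } else {
        any
    }
}

// The historical `PartialEq<dyn Any>`-style equality check between
// expressions that may arrive behind different smart pointers.
fn dyn_eq(lhs: &MyExpr, rhs: &dyn Any) -> bool {
    down_cast_any_ref(rhs)
        .downcast_ref::<MyExpr>()
        .map(|other| lhs == other)
        .unwrap_or(false)
}
```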
diff --git a/datafusion/physical-expr/src/equivalence/class.rs b/datafusion/physical-expr/src/equivalence/class.rs index d06a495d970a..cc26d12fb029 100644 --- a/datafusion/physical-expr/src/equivalence/class.rs +++ b/datafusion/physical-expr/src/equivalence/class.rs @@ -17,8 +17,8 @@ use super::{add_offset_to_expr, collapse_lex_req, ProjectionMapping}; use crate::{ - expressions::Column, physical_exprs_contains, LexOrdering, LexRequirement, - PhysicalExpr, PhysicalExprRef, PhysicalSortExpr, PhysicalSortRequirement, + expressions::Column, LexOrdering, LexRequirement, PhysicalExpr, PhysicalExprRef, + PhysicalSortExpr, PhysicalSortRequirement, }; use std::fmt::Display; use std::sync::Arc; @@ -27,7 +27,7 @@ use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::JoinType; use datafusion_physical_expr_common::physical_expr::format_physical_expr_list; -use indexmap::IndexSet; +use indexmap::{IndexMap, IndexSet}; /// A structure representing a expression known to be constant in a physical execution plan. /// @@ -546,28 +546,20 @@ impl EquivalenceGroup { .collect::<Vec<_>>(); (new_class.len() > 1).then_some(EquivalenceClass::new(new_class)) }); - // TODO: Convert the algorithm below to a version that uses `HashMap`. - // once `Arc<dyn PhysicalExpr>` can be stored in `HashMap`. - // See issue: https://github.com/apache/datafusion/issues/8027 - let mut new_classes = vec![]; - for (source, target) in mapping.iter() { - if new_classes.is_empty() { - new_classes.push((source, vec![Arc::clone(target)])); - } - if let Some((_, values)) = - new_classes.iter_mut().find(|(key, _)| *key == source) - { - if !physical_exprs_contains(values, target) { - values.push(Arc::clone(target)); - } - } - } + // Maps each source expression to the `EquivalenceClass` that collects the target expressions it projects to. + let mut new_classes: IndexMap<Arc<dyn PhysicalExpr>, EquivalenceClass> = + IndexMap::new(); + mapping.iter().for_each(|(source, target)| { + new_classes + .entry(Arc::clone(source)) + .or_insert_with(EquivalenceClass::new_empty) + .push(Arc::clone(target)); + }); // Only add equivalence classes with at least two members as singleton // equivalence classes are meaningless. let new_classes = new_classes .into_iter() - .filter_map(|(_, values)| (values.len() > 1).then_some(values)) - .map(EquivalenceClass::new); + .filter_map(|(_, cls)| (cls.len() > 1).then_some(cls)); let classes = projected_classes.chain(new_classes).collect(); Self::new(classes) diff --git a/datafusion/physical-expr/src/scalar_function.rs b/datafusion/physical-expr/src/scalar_function.rs index 45f77325eea3..e312d5de59fb 100644 --- a/datafusion/physical-expr/src/scalar_function.rs +++ b/datafusion/physical-expr/src/scalar_function.rs @@ -134,20 +134,20 @@ impl PhysicalExpr for ScalarFunctionExpr { } fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> { - let inputs = self + let args = self .args .iter() .map(|e| e.evaluate(batch)) .collect::<Result<Vec<_>>>()?; - let input_empty = inputs.is_empty(); - let input_all_scalar = inputs + let input_empty = args.is_empty(); + let input_all_scalar = args .iter() .all(|arg| matches!(arg, ColumnarValue::Scalar(_))); // evaluate the function let output = self.fun.invoke_with_args(ScalarFunctionArgs { - args: inputs.as_slice(), + args, number_rows: batch.num_rows(), return_type: &self.return_type, })?; diff --git a/datafusion/physical-plan/Cargo.toml b/datafusion/physical-plan/Cargo.toml index bb0e21fdfd15..83dc9549531d 100644 --- a/datafusion/physical-plan/Cargo.toml +++ b/datafusion/physical-plan/Cargo.toml @@ -60,16 +60,16 @@ hashbrown = { workspace = true } indexmap = { workspace = true } itertools = { workspace = true, features = ["use_std"] } log = { workspace = true } -once_cell = "1.18.0" parking_lot = { workspace = true } pin-project-lite = "^0.2.7" -rand = { workspace = true } tokio = { workspace = true } [dev-dependencies] criterion = { version = "0.5", features = ["async_futures"] } datafusion-functions-aggregate = { workspace = true } datafusion-functions-window = { workspace = true } +once_cell = "1.18.0" +rand = { workspace = true } rstest = { workspace = true } rstest_reuse = "0.7.0" tokio = { workspace = true, features = [ diff --git a/datafusion/physical-plan/src/joins/cross_join.rs b/datafusion/physical-plan/src/joins/cross_join.rs index 3080ac0757b5..b70eeb313b2a 100644 --- a/datafusion/physical-plan/src/joins/cross_join.rs +++ b/datafusion/physical-plan/src/joins/cross_join.rs @@ -188,18 +188,21 @@ async fn load_left_input( // Load all batches and count the rows let (batches, _metrics, reservation) = stream - .try_fold((Vec::new(), metrics, reservation), |mut acc, batch| async { - let batch_size = batch.get_array_memory_size(); - // Reserve memory for incoming batch - acc.2.try_grow(batch_size)?; - // Update
metrics - acc.1.build_mem_used.add(batch_size); - acc.1.build_input_batches.add(1); - acc.1.build_input_rows.add(batch.num_rows()); - // Push batch to output - acc.0.push(batch); - Ok(acc) - }) + .try_fold( + (Vec::new(), metrics, reservation), + |(mut batches, metrics, mut reservation), batch| async { + let batch_size = batch.get_array_memory_size(); + // Reserve memory for incoming batch + reservation.try_grow(batch_size)?; + // Update metrics + metrics.build_mem_used.add(batch_size); + metrics.build_input_batches.add(1); + metrics.build_input_rows.add(batch.num_rows()); + // Push batch to output + batches.push(batch); + Ok((batches, metrics, reservation)) + }, + ) .await?; let merged_batch = concat_batches(&left_schema, &batches)?; diff --git a/datafusion/physical-plan/src/joins/hash_join.rs b/datafusion/physical-plan/src/joins/hash_join.rs index bd54c74d3f11..7aadf5d198a0 100644 --- a/datafusion/physical-plan/src/joins/hash_join.rs +++ b/datafusion/physical-plan/src/joins/hash_join.rs @@ -90,9 +90,6 @@ struct JoinLeftData { /// Counter of running probe-threads, potentially /// able to update `visited_indices_bitmap` probe_threads_counter: AtomicUsize, - /// Memory reservation that tracks memory used by `hash_map` hash table - /// `batch`. Cleared on drop. - _reservation: MemoryReservation, } impl JoinLeftData { @@ -102,14 +99,12 @@ impl JoinLeftData { batch: RecordBatch, visited_indices_bitmap: SharedBitmapBuilder, probe_threads_counter: AtomicUsize, - reservation: MemoryReservation, ) -> Self { Self { hash_map, batch, visited_indices_bitmap, probe_threads_counter, - _reservation: reservation, } } @@ -902,7 +897,6 @@ async fn collect_left_input( single_batch, Mutex::new(visited_indices_bitmap), AtomicUsize::new(probe_threads_count), - reservation, ); Ok(data) @@ -1019,6 +1013,7 @@ impl BuildSide { /// └─ ProcessProbeBatch /// /// ``` +#[derive(Debug, Clone)] enum HashJoinStreamState { /// Initial state for HashJoinStream indicating that build-side data not collected yet WaitBuildSide, @@ -1044,6 +1039,7 @@ impl HashJoinStreamState { } /// Container for HashJoinStreamState::ProcessProbeBatch related data +#[derive(Debug, Clone)] struct ProcessProbeBatchState { /// Current probe-side batch batch: RecordBatch, diff --git a/datafusion/physical-plan/src/joins/nested_loop_join.rs b/datafusion/physical-plan/src/joins/nested_loop_join.rs index 66d225014d8f..a32f3ea2d9b5 100644 --- a/datafusion/physical-plan/src/joins/nested_loop_join.rs +++ b/datafusion/physical-plan/src/joins/nested_loop_join.rs @@ -46,7 +46,6 @@ use arrow::array::{BooleanBufferBuilder, UInt32Array, UInt64Array}; use arrow::compute::concat_batches; use arrow::datatypes::{Schema, SchemaRef}; use arrow::record_batch::RecordBatch; -use arrow::util::bit_util; use datafusion_common::{ exec_datafusion_err, internal_err, JoinSide, Result, Statistics, }; @@ -445,17 +444,17 @@ async fn collect_left_input( let (batches, metrics, mut reservation) = stream .try_fold( (Vec::new(), join_metrics, reservation), - |mut acc, batch| async { + |(mut batches, metrics, mut reservation), batch| async { let batch_size = batch.get_array_memory_size(); // Reserve memory for incoming batch - acc.2.try_grow(batch_size)?; + reservation.try_grow(batch_size)?; // Update metrics - acc.1.build_mem_used.add(batch_size); - acc.1.build_input_batches.add(1); - acc.1.build_input_rows.add(batch.num_rows()); + metrics.build_mem_used.add(batch_size); + metrics.build_input_batches.add(1); + metrics.build_input_rows.add(batch.num_rows()); // Push batch to output - 
acc.0.push(batch); - Ok(acc) + batches.push(batch); + Ok((batches, metrics, reservation)) }, ) .await?; @@ -464,14 +463,13 @@ async fn collect_left_input( // Reserve memory for visited_left_side bitmap if required by join type let visited_left_side = if with_visited_left_side { - // TODO: Replace `ceil` wrapper with stable `div_cell` after - // https://github.com/rust-lang/rust/issues/88581 - let buffer_size = bit_util::ceil(merged_batch.num_rows(), 8); + let n_rows = merged_batch.num_rows(); + let buffer_size = n_rows.div_ceil(8); reservation.try_grow(buffer_size)?; metrics.build_mem_used.add(buffer_size); - let mut buffer = BooleanBufferBuilder::new(merged_batch.num_rows()); - buffer.append_n(merged_batch.num_rows(), false); + let mut buffer = BooleanBufferBuilder::new(n_rows); + buffer.append_n(n_rows, false); buffer } else { BooleanBufferBuilder::new(0) diff --git a/datafusion/physical-plan/src/stream.rs b/datafusion/physical-plan/src/stream.rs index b3054299b7f7..a05b46d22840 100644 --- a/datafusion/physical-plan/src/stream.rs +++ b/datafusion/physical-plan/src/stream.rs @@ -337,7 +337,9 @@ impl RecordBatchReceiverStream { pin_project! { /// Combines a [`Stream`] with a [`SchemaRef`] implementing - /// [`RecordBatchStream`] for the combination + /// [`SendableRecordBatchStream`] for the combination + /// + /// See [`Self::new`] for an example pub struct RecordBatchStreamAdapter { schema: SchemaRef, @@ -347,7 +349,28 @@ pin_project! { } impl RecordBatchStreamAdapter { - /// Creates a new [`RecordBatchStreamAdapter`] from the provided schema and stream + /// Creates a new [`RecordBatchStreamAdapter`] from the provided schema and stream. + /// + /// Note to create a [`SendableRecordBatchStream`] you pin the result + /// + /// # Example + /// ``` + /// # use arrow::array::record_batch; + /// # use datafusion_execution::SendableRecordBatchStream; + /// # use datafusion_physical_plan::stream::RecordBatchStreamAdapter; + /// // Create stream of Result + /// let batch = record_batch!( + /// ("a", Int32, [1, 2, 3]), + /// ("b", Float64, [Some(4.0), None, Some(5.0)]) + /// ).expect("created batch"); + /// let schema = batch.schema(); + /// let stream = futures::stream::iter(vec![Ok(batch)]); + /// // Convert the stream to a SendableRecordBatchStream + /// let adapter = RecordBatchStreamAdapter::new(schema, stream); + /// // Now you can use the adapter as a SendableRecordBatchStream + /// let batch_stream: SendableRecordBatchStream = Box::pin(adapter); + /// // ... + /// ``` pub fn new(schema: SchemaRef, stream: S) -> Self { Self { schema, stream } } diff --git a/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs b/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs index d640a9b4a54b..dd751e500ad2 100644 --- a/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs +++ b/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs @@ -65,7 +65,7 @@ use datafusion_physical_expr_common::sort_expr::{LexOrdering, LexRequirement}; use futures::stream::Stream; use futures::{ready, StreamExt}; -use hashbrown::raw::RawTable; +use hashbrown::hash_table::HashTable; use indexmap::IndexMap; use log::debug; @@ -443,16 +443,16 @@ pub struct LinearSearch { /// is ordered by a, b and the window expression contains a PARTITION BY b, a /// clause, this attribute stores [1, 0]. ordered_partition_by_indices: Vec, - /// We use this [`RawTable`] to calculate unique partitions for each new + /// We use this [`HashTable`] to calculate unique partitions for each new /// RecordBatch. 
First entry in the tuple is the hash value, the second /// entry is the unique ID for each partition (increments from 0 to n). - row_map_batch: RawTable<(u64, usize)>, - /// We use this [`RawTable`] to calculate the output columns that we can + row_map_batch: HashTable<(u64, usize)>, + /// We use this [`HashTable`] to calculate the output columns that we can /// produce at each cycle. First entry in the tuple is the hash value, the /// second entry is the unique ID for each partition (increments from 0 to n). /// The third entry stores how many new outputs are calculated for the /// corresponding partition. - row_map_out: RawTable<(u64, usize, usize)>, + row_map_out: HashTable<(u64, usize, usize)>, input_schema: SchemaRef, } @@ -611,8 +611,8 @@ impl LinearSearch { input_buffer_hashes: VecDeque::new(), random_state: Default::default(), ordered_partition_by_indices, - row_map_batch: RawTable::with_capacity(256), - row_map_out: RawTable::with_capacity(256), + row_map_batch: HashTable::with_capacity(256), + row_map_out: HashTable::with_capacity(256), input_schema, } } @@ -632,7 +632,7 @@ impl LinearSearch { // res stores PartitionKey and row indices (indices where these partition occurs in the `batch`) for each partition. let mut result: Vec<(PartitionKey, Vec)> = vec![]; for (hash, row_idx) in batch_hashes.into_iter().zip(0u32..) { - let entry = self.row_map_batch.get_mut(hash, |(_, group_idx)| { + let entry = self.row_map_batch.find_mut(hash, |(_, group_idx)| { // We can safely get the first index of the partition indices // since partition indices has one element during initialization. let row = get_row_at_idx(columns, row_idx as usize).unwrap(); @@ -642,8 +642,11 @@ impl LinearSearch { if let Some((_, group_idx)) = entry { result[*group_idx].1.push(row_idx) } else { - self.row_map_batch - .insert(hash, (hash, result.len()), |(hash, _)| *hash); + self.row_map_batch.insert_unique( + hash, + (hash, result.len()), + |(hash, _)| *hash, + ); let row = get_row_at_idx(columns, row_idx as usize)?; // This is a new partition its only index is row_idx for now. result.push((row, vec![row_idx])); @@ -668,7 +671,7 @@ impl LinearSearch { self.row_map_out.clear(); let mut partition_indices: Vec<(PartitionKey, Vec)> = vec![]; for (hash, row_idx) in self.input_buffer_hashes.iter().zip(0u32..) 
{ - let entry = self.row_map_out.get_mut(*hash, |(_, group_idx, _)| { + let entry = self.row_map_out.find_mut(*hash, |(_, group_idx, _)| { let row = get_row_at_idx(&partition_by_columns, row_idx as usize).unwrap(); row == partition_indices[*group_idx].0 @@ -694,7 +697,7 @@ impl LinearSearch { if min_out == 0 { break; } - self.row_map_out.insert( + self.row_map_out.insert_unique( *hash, (*hash, partition_indices.len(), min_out), |(hash, _, _)| *hash, diff --git a/datafusion/proto-common/Cargo.toml b/datafusion/proto-common/Cargo.toml index 102940716c12..ba99f8639d42 100644 --- a/datafusion/proto-common/Cargo.toml +++ b/datafusion/proto-common/Cargo.toml @@ -41,9 +41,7 @@ json = ["serde", "serde_json", "pbjson"] [dependencies] arrow = { workspace = true } -chrono = { workspace = true } datafusion-common = { workspace = true } -object_store = { workspace = true } pbjson = { workspace = true, optional = true } prost = { workspace = true } serde = { version = "1.0", optional = true } @@ -51,4 +49,3 @@ serde_json = { workspace = true, optional = true } [dev-dependencies] doc-comment = { workspace = true } -tokio = { workspace = true } diff --git a/datafusion/proto-common/gen/Cargo.toml b/datafusion/proto-common/gen/Cargo.toml index da5bc6029ff9..21fc9eccb40c 100644 --- a/datafusion/proto-common/gen/Cargo.toml +++ b/datafusion/proto-common/gen/Cargo.toml @@ -35,4 +35,4 @@ workspace = true [dependencies] # Pin these dependencies so that the generated output is deterministic pbjson-build = "=0.7.0" -prost-build = "=0.13.3" +prost-build = "=0.13.4" diff --git a/datafusion/proto/gen/Cargo.toml b/datafusion/proto/gen/Cargo.toml index 297406becada..dda72d20a159 100644 --- a/datafusion/proto/gen/Cargo.toml +++ b/datafusion/proto/gen/Cargo.toml @@ -35,4 +35,4 @@ workspace = true [dependencies] # Pin these dependencies so that the generated output is deterministic pbjson-build = "=0.7.0" -prost-build = "=0.13.3" +prost-build = "=0.13.4" diff --git a/datafusion/proto/src/logical_plan/mod.rs b/datafusion/proto/src/logical_plan/mod.rs index 50636048ebc9..addafeb7629d 100644 --- a/datafusion/proto/src/logical_plan/mod.rs +++ b/datafusion/proto/src/logical_plan/mod.rs @@ -737,23 +737,18 @@ impl AsLogicalPlan for LogicalPlanNode { builder.build() } LogicalPlanType::Union(union) => { - let mut input_plans: Vec = union - .inputs - .iter() - .map(|i| i.try_into_logical_plan(ctx, extension_codec)) - .collect::>()?; - - if input_plans.len() < 2 { + if union.inputs.len() < 2 { return Err( DataFusionError::Internal(String::from( "Protobuf deserialization error, Union was require at least two input.", ))); } + let (first, rest) = union.inputs.split_first().unwrap(); + let mut builder = LogicalPlanBuilder::from( + first.try_into_logical_plan(ctx, extension_codec)?, + ); - let first = input_plans.pop().ok_or_else(|| DataFusionError::Internal(String::from( - "Protobuf deserialization error, Union was require at least two input.", - )))?; - let mut builder = LogicalPlanBuilder::from(first); - for plan in input_plans { + for i in rest { + let plan = i.try_into_logical_plan(ctx, extension_codec)?; builder = builder.union(plan)?; } builder.build() diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index 8c150b20dd00..f793e96f612b 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -25,6 +25,8 @@ use arrow::datatypes::{ }; use arrow::util::pretty::pretty_format_batches; use 
datafusion::datasource::file_format::json::JsonFormatFactory; +use datafusion::optimizer::eliminate_nested_union::EliminateNestedUnion; +use datafusion::optimizer::Optimizer; use datafusion_common::parsers::CompressionTypeVariant; use prost::Message; use std::any::Any; @@ -2555,3 +2557,35 @@ async fn roundtrip_recursive_query() { format!("{}", pretty_format_batches(&output_round_trip).unwrap()) ); } + +#[tokio::test] +async fn roundtrip_union_query() -> Result<()> { + let query = "SELECT a FROM t1 + UNION (SELECT a from t1 UNION SELECT a from t2)"; + + let ctx = SessionContext::new(); + ctx.register_csv("t1", "tests/testdata/test.csv", CsvReadOptions::default()) + .await?; + ctx.register_csv("t2", "tests/testdata/test.csv", CsvReadOptions::default()) + .await?; + let dataframe = ctx.sql(query).await?; + let plan = dataframe.into_optimized_plan()?; + + let bytes = logical_plan_to_bytes(&plan)?; + + let ctx = SessionContext::new(); + ctx.register_csv("t1", "tests/testdata/test.csv", CsvReadOptions::default()) + .await?; + ctx.register_csv("t2", "tests/testdata/test.csv", CsvReadOptions::default()) + .await?; + let logical_round_trip = logical_plan_from_bytes(&bytes, &ctx)?; + // proto deserialization only supports 2-way unions, so this plan has nested unions; + // apply the EliminateNestedUnion optimizer rule to be able to compare the plans + let optimizer = Optimizer::with_rules(vec![Arc::new(EliminateNestedUnion::new())]); + let unnested = optimizer.optimize(logical_round_trip, &(ctx.state()), |_x, _y| {})?; + assert_eq!( + format!("{}", plan.display_indent_schema()), + format!("{}", unnested.display_indent_schema()), + ); + Ok(()) +}
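The test above exists because protobuf serialization flattens an n-ary `Union` into a list of inputs, while deserialization (see the `logical_plan/mod.rs` hunk earlier in this diff) rebuilds it pairwise with `LogicalPlanBuilder::union`, producing nested binary unions. A distilled sketch of that rebuild loop, assuming the at-least-two-inputs check has already passed:

```rust
use datafusion_common::Result;
use datafusion_expr::{LogicalPlan, LogicalPlanBuilder};

// Rebuild an n-ary union by folding the remaining inputs onto the first
// one, two at a time; the resulting plan therefore nests binary unions,
// which `EliminateNestedUnion` later flattens.
fn rebuild_union(inputs: Vec<LogicalPlan>) -> Result<LogicalPlan> {
    let (first, rest) = inputs
        .split_first()
        .expect("caller validated: at least two inputs");
    let mut builder = LogicalPlanBuilder::from(first.clone());
    for plan in rest {
        builder = builder.union(plan.clone())?;
    }
    builder.build()
}
```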
diff --git a/datafusion/sql/src/expr/mod.rs b/datafusion/sql/src/expr/mod.rs index 57ac96951f1f..e8ec8d7b7d1c 100644 --- a/datafusion/sql/src/expr/mod.rs +++ b/datafusion/sql/src/expr/mod.rs @@ -23,7 +23,8 @@ use datafusion_expr::planner::{ use recursive::recursive; use sqlparser::ast::{ BinaryOperator, CastFormat, CastKind, DataType as SQLDataType, DictionaryField, - Expr as SQLExpr, MapEntry, StructField, Subscript, TrimWhereField, Value, + Expr as SQLExpr, ExprWithAlias as SQLExprWithAlias, MapEntry, StructField, Subscript, + TrimWhereField, Value, }; use datafusion_common::{ @@ -50,6 +51,19 @@ mod unary_op; mod value; impl<S: ContextProvider> SqlToRel<'_, S> { + pub(crate) fn sql_expr_to_logical_expr_with_alias( + &self, + sql: SQLExprWithAlias, + schema: &DFSchema, + planner_context: &mut PlannerContext, + ) -> Result<Expr> { + let mut expr = + self.sql_expr_to_logical_expr(sql.expr, schema, planner_context)?; + if let Some(alias) = sql.alias { + expr = expr.alias(alias.value); + } + Ok(expr) + } pub(crate) fn sql_expr_to_logical_expr( &self, sql: SQLExpr, @@ -131,6 +145,20 @@ impl<S: ContextProvider> SqlToRel<'_, S> { ))) } + pub fn sql_to_expr_with_alias( + &self, + sql: SQLExprWithAlias, + schema: &DFSchema, + planner_context: &mut PlannerContext, + ) -> Result<Expr> { + let mut expr = + self.sql_expr_to_logical_expr_with_alias(sql, schema, planner_context)?; + expr = self.rewrite_partial_qualifier(expr, schema); + self.validate_schema_satisfies_exprs(schema, &[expr.clone()])?; + let (expr, _) = expr.infer_placeholder_types(schema)?; + Ok(expr) + } + /// Generate a relational expression from a SQL expression pub fn sql_to_expr( &self, @@ -1091,8 +1119,11 @@ mod tests { None } - fn get_aggregate_meta(&self, _name: &str) -> Option<Arc<AggregateUDF>> { - None + fn get_aggregate_meta(&self, name: &str) -> Option<Arc<AggregateUDF>> { + match name { + "sum" => Some(datafusion_functions_aggregate::sum::sum_udaf()), + _ => None, + } } fn get_variable_type(&self, _variable_names: &[String]) -> Option<DataType> { @@ -1112,7 +1143,7 @@ mod tests { } fn udaf_names(&self) -> Vec<String> { - Vec::new() + vec!["sum".to_string()] } fn udwf_names(&self) -> Vec<String> { @@ -1167,4 +1198,25 @@ mod tests { test_stack_overflow!(2048); test_stack_overflow!(4096); test_stack_overflow!(8192); + #[test] + fn test_sql_to_expr_with_alias() { + let schema = DFSchema::empty(); + let mut planner_context = PlannerContext::default(); + + let expr_str = "SUM(int_col) as sum_int_col"; + + let dialect = GenericDialect {}; + let mut parser = Parser::new(&dialect).try_with_sql(expr_str).unwrap(); + // from sqlparser + let sql_expr = parser.parse_expr_with_alias().unwrap(); + + let context_provider = TestContextProvider::new(); + let sql_to_rel = SqlToRel::new(&context_provider); + + let expr = sql_to_rel + .sql_expr_to_logical_expr_with_alias(sql_expr, &schema, &mut planner_context) + .unwrap(); + + assert!(matches!(expr, Expr::Alias(_))); + } } diff --git a/datafusion/sql/src/parser.rs b/datafusion/sql/src/parser.rs index bd1ed3145ef5..efec6020641c 100644 --- a/datafusion/sql/src/parser.rs +++ b/datafusion/sql/src/parser.rs @@ -20,9 +20,10 @@ use std::collections::VecDeque; use std::fmt; +use sqlparser::ast::ExprWithAlias; use sqlparser::{ ast::{ - ColumnDef, ColumnOptionDef, Expr, ObjectName, OrderByExpr, Query, + ColumnDef, ColumnOptionDef, ObjectName, OrderByExpr, Query, Statement as SQLStatement, TableConstraint, Value, }, dialect::{keywords::Keyword, Dialect, GenericDialect}, @@ -328,7 +329,7 @@ impl<'a> DFParser<'a> { pub fn parse_sql_into_expr_with_dialect( sql: &str, dialect: &dyn Dialect, - ) -> Result<Expr, ParserError> { + ) -> Result<ExprWithAlias, ParserError> { let mut parser = DFParser::new_with_dialect(sql, dialect)?; parser.parse_expr() } @@ -377,7 +378,7 @@ impl<'a> DFParser<'a> { } } - pub fn parse_expr(&mut self) -> Result<Expr, ParserError> { + pub fn parse_expr(&mut self) -> Result<ExprWithAlias, ParserError> { if let Token::Word(w) = self.parser.peek_token().token { match w.keyword { Keyword::CREATE | Keyword::COPY | Keyword::EXPLAIN => { @@ -387,7 +388,7 @@ impl<'a> DFParser<'a> { } } - self.parser.parse_expr() + self.parser.parse_expr_with_alias() } /// Parse a SQL `COPY TO` statement
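Taken together, the `expr/mod.rs` and `parser.rs` changes above let a trailing `AS alias` survive from SQL text all the way into the logical expression. A minimal end-to-end sketch, generic over any `ContextProvider` implementation (the SQL string is illustrative):

```rust
use datafusion_common::DFSchema;
use datafusion_expr::Expr;
use datafusion_sql::planner::{ContextProvider, PlannerContext, SqlToRel};
use sqlparser::dialect::GenericDialect;
use sqlparser::parser::Parser;

// Parse "1 + 1 AS two" into sqlparser's `ExprWithAlias`, then plan it;
// the new code path wraps the planned expression in `Expr::Alias("two")`.
fn plan_aliased<S: ContextProvider>(provider: &S) -> Expr {
    let sql_expr = Parser::new(&GenericDialect {})
        .try_with_sql("1 + 1 AS two")
        .unwrap()
        .parse_expr_with_alias()
        .unwrap();

    SqlToRel::new(provider)
        .sql_to_expr_with_alias(
            sql_expr,
            &DFSchema::empty(),
            &mut PlannerContext::default(),
        )
        .unwrap()
}
```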
diff --git a/datafusion/sql/src/unparser/ast.rs b/datafusion/sql/src/unparser/ast.rs index cc0812cd71e1..ad0b5f16b283 100644 --- a/datafusion/sql/src/unparser/ast.rs +++ b/datafusion/sql/src/unparser/ast.rs @@ -353,6 +353,7 @@ pub(super) struct RelationBuilder { enum TableFactorBuilder { Table(TableRelationBuilder), Derived(DerivedRelationBuilder), + Unnest(UnnestRelationBuilder), Empty, } @@ -369,6 +370,12 @@ impl RelationBuilder { self.relation = Some(TableFactorBuilder::Derived(value)); self } + + pub fn unnest(&mut self, value: UnnestRelationBuilder) -> &mut Self { + self.relation = Some(TableFactorBuilder::Unnest(value)); + self + } + pub fn empty(&mut self) -> &mut Self { self.relation = Some(TableFactorBuilder::Empty); self @@ -382,6 +389,9 @@ impl RelationBuilder { Some(TableFactorBuilder::Derived(ref mut rel_builder)) => { rel_builder.alias = value; } + Some(TableFactorBuilder::Unnest(ref mut rel_builder)) => { + rel_builder.alias = value; + } Some(TableFactorBuilder::Empty) => (), None => (), } @@ -391,6 +401,7 @@ impl RelationBuilder { Ok(match self.relation { Some(TableFactorBuilder::Table(ref value)) => Some(value.build()?), Some(TableFactorBuilder::Derived(ref value)) => Some(value.build()?), + Some(TableFactorBuilder::Unnest(ref value)) => Some(value.build()?), Some(TableFactorBuilder::Empty) => None, None => return Err(Into::into(UninitializedFieldError::from("relation"))), }) @@ -526,6 +537,68 @@ impl Default for DerivedRelationBuilder { } } +#[derive(Clone)] +pub(super) struct UnnestRelationBuilder { + pub alias: Option<ast::TableAlias>, + pub array_exprs: Vec<ast::Expr>, + with_offset: bool, + with_offset_alias: Option<ast::Ident>, + with_ordinality: bool, +} + +#[allow(dead_code)] +impl UnnestRelationBuilder { + pub fn alias(&mut self, value: Option<ast::TableAlias>) -> &mut Self { + self.alias = value; + self + } + pub fn array_exprs(&mut self, value: Vec<ast::Expr>) -> &mut Self { + self.array_exprs = value; + self + } + + pub fn with_offset(&mut self, value: bool) -> &mut Self { + self.with_offset = value; + self + } + + pub fn with_offset_alias(&mut self, value: Option<ast::Ident>) -> &mut Self { + self.with_offset_alias = value; + self + } + + pub fn with_ordinality(&mut self, value: bool) -> &mut Self { + self.with_ordinality = value; + self + } + + pub fn build(&self) -> Result<ast::TableFactor, BuilderError> { + Ok(ast::TableFactor::UNNEST { + alias: self.alias.clone(), + array_exprs: self.array_exprs.clone(), + with_offset: self.with_offset, + with_offset_alias: self.with_offset_alias.clone(), + with_ordinality: self.with_ordinality, + }) + } + + fn create_empty() -> Self { + Self { + alias: Default::default(), + array_exprs: Default::default(), + with_offset: Default::default(), + with_offset_alias: Default::default(), + with_ordinality: Default::default(), + } + } +} + +impl Default for UnnestRelationBuilder { + fn default() -> Self { + Self::create_empty() + } +} + /// Runtime error when a `build()` method is called and one or more required fields /// do not have a value. #[derive(Debug, Clone)] diff --git a/datafusion/sql/src/unparser/dialect.rs b/datafusion/sql/src/unparser/dialect.rs index e979d8fd4ebd..ae387d441fa2 100644 --- a/datafusion/sql/src/unparser/dialect.rs +++ b/datafusion/sql/src/unparser/dialect.rs @@ -157,6 +157,15 @@ pub trait Dialect: Send + Sync { fn full_qualified_col(&self) -> bool { false } + + /// Allow unparsing the unnest plan as [ast::TableFactor::UNNEST]. + /// + /// Some dialects like BigQuery require UNNEST to be used in the FROM clause, but + /// the LogicalPlan planner always puts UNNEST in the SELECT clause. This flag allows + /// the UNNEST plan to be unparsed as [ast::TableFactor::UNNEST] instead of a subquery.
+ fn unnest_as_table_factor(&self) -> bool { + false + } } /// `IntervalStyle` to use for unparsing @@ -448,6 +457,7 @@ pub struct CustomDialect { requires_derived_table_alias: bool, division_operator: BinaryOperator, full_qualified_col: bool, + unnest_as_table_factor: bool, } impl Default for CustomDialect { @@ -474,6 +484,7 @@ impl Default for CustomDialect { requires_derived_table_alias: false, division_operator: BinaryOperator::Divide, full_qualified_col: false, + unnest_as_table_factor: false, } } } @@ -582,6 +593,10 @@ impl Dialect for CustomDialect { fn full_qualified_col(&self) -> bool { self.full_qualified_col } + + fn unnest_as_table_factor(&self) -> bool { + self.unnest_as_table_factor + } } /// `CustomDialectBuilder` to build `CustomDialect` using builder pattern @@ -617,6 +632,7 @@ pub struct CustomDialectBuilder { requires_derived_table_alias: bool, division_operator: BinaryOperator, full_qualified_col: bool, + unnest_as_table_factor: bool, } impl Default for CustomDialectBuilder { @@ -649,6 +665,7 @@ impl CustomDialectBuilder { requires_derived_table_alias: false, division_operator: BinaryOperator::Divide, full_qualified_col: false, + unnest_as_table_factor: false, } } @@ -673,6 +690,7 @@ impl CustomDialectBuilder { requires_derived_table_alias: self.requires_derived_table_alias, division_operator: self.division_operator, full_qualified_col: self.full_qualified_col, + unnest_as_table_factor: self.unnest_as_table_factor, } } @@ -800,4 +818,9 @@ impl CustomDialectBuilder { self.full_qualified_col = full_qualified_col; self } + + pub fn with_unnest_as_table_factor(mut self, _unnest_as_table_factor: bool) -> Self { + self.unnest_as_table_factor = _unnest_as_table_factor; + self + } } diff --git a/datafusion/sql/src/unparser/plan.rs b/datafusion/sql/src/unparser/plan.rs index eaae4fe73d8c..e9f9f486ea9a 100644 --- a/datafusion/sql/src/unparser/plan.rs +++ b/datafusion/sql/src/unparser/plan.rs @@ -32,7 +32,9 @@ use super::{ }, Unparser, }; +use crate::unparser::ast::UnnestRelationBuilder; use crate::unparser::utils::unproject_agg_exprs; +use crate::utils::UNNEST_PLACEHOLDER; use datafusion_common::{ internal_err, not_impl_err, tree_node::{TransformedResult, TreeNode}, @@ -40,7 +42,7 @@ use datafusion_common::{ }; use datafusion_expr::{ expr::Alias, BinaryExpr, Distinct, Expr, JoinConstraint, JoinType, LogicalPlan, - LogicalPlanBuilder, Operator, Projection, SortExpr, TableScan, + LogicalPlanBuilder, Operator, Projection, SortExpr, TableScan, Unnest, }; use sqlparser::ast::{self, Ident, SetExpr}; use std::sync::Arc; @@ -312,6 +314,19 @@ impl Unparser<'_> { .select_to_sql_recursively(&new_plan, query, select, relation); } + // Projection can be top-level plan for unnest relation + // The projection generated by the `RecursiveUnnestRewriter` from a UNNEST relation will have + // only one expression, which is the placeholder column generated by the rewriter. + if self.dialect.unnest_as_table_factor() + && p.expr.len() == 1 + && Self::is_unnest_placeholder(&p.expr[0]) + { + if let LogicalPlan::Unnest(unnest) = &p.input.as_ref() { + return self + .unnest_to_table_factor_sql(unnest, query, select, relation); + } + } + // Projection can be top-level plan for derived table if select.already_projected() { return self.derive_with_dialect_alias( @@ -678,7 +693,11 @@ impl Unparser<'_> { ) } LogicalPlan::EmptyRelation(_) => { - relation.empty(); + // An EmptyRelation could be behind an UNNEST node. 
If the dialect supports UNNEST as a table factor, + // a TableRelationBuilder will be created for the UNNEST node first. + if !relation.has_relation() { + relation.empty(); + } Ok(()) } LogicalPlan::Extension(_) => not_impl_err!("Unsupported operator: {plan:?}"), @@ -708,6 +727,38 @@ impl Unparser<'_> { } } + /// Try to find the placeholder column name generated by `RecursiveUnnestRewriter` + /// Only match the pattern `Expr::Alias(Expr::Column("__unnest_placeholder(...)"))` + fn is_unnest_placeholder(expr: &Expr) -> bool { + if let Expr::Alias(Alias { expr, .. }) = expr { + if let Expr::Column(Column { name, .. }) = expr.as_ref() { + return name.starts_with(UNNEST_PLACEHOLDER); + } + } + false + } + + fn unnest_to_table_factor_sql( + &self, + unnest: &Unnest, + query: &mut Option, + select: &mut SelectBuilder, + relation: &mut RelationBuilder, + ) -> Result<()> { + let mut unnest_relation = UnnestRelationBuilder::default(); + let LogicalPlan::Projection(p) = unnest.input.as_ref() else { + return internal_err!("Unnest input is not a Projection: {unnest:?}"); + }; + let exprs = p + .expr + .iter() + .map(|e| self.expr_to_sql(e)) + .collect::>>()?; + unnest_relation.array_exprs(exprs); + relation.unnest(unnest_relation); + self.select_to_sql_recursively(p.input.as_ref(), query, select, relation) + } + fn is_scan_with_pushdown(scan: &TableScan) -> bool { scan.projection.is_some() || !scan.filters.is_empty() || scan.fetch.is_some() } diff --git a/datafusion/sql/src/unparser/utils.rs b/datafusion/sql/src/unparser/utils.rs index 518781106c3b..354a68f60964 100644 --- a/datafusion/sql/src/unparser/utils.rs +++ b/datafusion/sql/src/unparser/utils.rs @@ -133,7 +133,7 @@ pub(crate) fn find_window_nodes_within_select<'a>( /// Recursively identify Column expressions and transform them into the appropriate unnest expression /// -/// For example, if expr contains the column expr "unnest_placeholder(make_array(Int64(1),Int64(2),Int64(2),Int64(5),NULL),depth=1)" +/// For example, if expr contains the column expr "__unnest_placeholder(make_array(Int64(1),Int64(2),Int64(2),Int64(5),NULL),depth=1)" /// it will be transformed into an actual unnest expression UNNEST([1, 2, 2, 5, NULL]) pub(crate) fn unproject_unnest_expr(expr: Expr, unnest: &Unnest) -> Result { expr.transform(|sub_expr| { diff --git a/datafusion/sql/src/utils.rs b/datafusion/sql/src/utils.rs index 69e3953341ef..1c2a3ea91a2b 100644 --- a/datafusion/sql/src/utils.rs +++ b/datafusion/sql/src/utils.rs @@ -315,6 +315,8 @@ pub(crate) fn rewrite_recursive_unnests_bottom_up( .collect::>()) } +pub const UNNEST_PLACEHOLDER: &str = "__unnest_placeholder"; + /* This is only usedful when used with transform down up A full example of how the transformation works: @@ -360,9 +362,9 @@ impl RecursiveUnnestRewriter<'_> { // Full context, we are trying to plan the execution as InnerProjection->Unnest->OuterProjection // inside unnest execution, each column inside the inner projection // will be transformed into new columns. 
Thus we need to keep track of these placeholding column names - let placeholder_name = format!("unnest_placeholder({})", inner_expr_name); + let placeholder_name = format!("{UNNEST_PLACEHOLDER}({})", inner_expr_name); let post_unnest_name = - format!("unnest_placeholder({},depth={})", inner_expr_name, level); + format!("{UNNEST_PLACEHOLDER}({},depth={})", inner_expr_name, level); // This is due to the fact that unnest transformation should keep the original // column name as is, to comply with group by and order by let placeholder_column = Column::from_name(placeholder_name.clone()); @@ -693,17 +695,17 @@ mod tests { // Only the bottom most unnest exprs are transformed assert_eq!( transformed_exprs, - vec![col("unnest_placeholder(3d_col,depth=2)") + vec![col("__unnest_placeholder(3d_col,depth=2)") .alias("UNNEST(UNNEST(3d_col))") .add( - col("unnest_placeholder(3d_col,depth=2)") + col("__unnest_placeholder(3d_col,depth=2)") .alias("UNNEST(UNNEST(3d_col))") ) .add(col("i64_col"))] ); column_unnests_eq( vec![ - "unnest_placeholder(3d_col)=>[unnest_placeholder(3d_col,depth=2)|depth=2]", + "__unnest_placeholder(3d_col)=>[__unnest_placeholder(3d_col,depth=2)|depth=2]", ], &unnest_placeholder_columns, ); @@ -713,7 +715,7 @@ mod tests { assert_eq!( inner_projection_exprs, vec![ - col("3d_col").alias("unnest_placeholder(3d_col)"), + col("3d_col").alias("__unnest_placeholder(3d_col)"), col("i64_col") ] ); @@ -730,12 +732,12 @@ mod tests { assert_eq!( transformed_exprs, vec![ - (col("unnest_placeholder(3d_col,depth=1)").alias("UNNEST(3d_col)")) + (col("__unnest_placeholder(3d_col,depth=1)").alias("UNNEST(3d_col)")) .alias("2d_col") ] ); column_unnests_eq( - vec!["unnest_placeholder(3d_col)=>[unnest_placeholder(3d_col,depth=2)|depth=2, unnest_placeholder(3d_col,depth=1)|depth=1]"], + vec!["__unnest_placeholder(3d_col)=>[__unnest_placeholder(3d_col,depth=2)|depth=2, __unnest_placeholder(3d_col,depth=1)|depth=1]"], &unnest_placeholder_columns, ); // Still reference struct_col in original schema but with alias, @@ -743,7 +745,7 @@ mod tests { assert_eq!( inner_projection_exprs, vec![ - col("3d_col").alias("unnest_placeholder(3d_col)"), + col("3d_col").alias("__unnest_placeholder(3d_col)"), col("i64_col") ] ); @@ -794,19 +796,19 @@ mod tests { assert_eq!( transformed_exprs, vec![ - col("unnest_placeholder(struct_col).field1"), - col("unnest_placeholder(struct_col).field2"), + col("__unnest_placeholder(struct_col).field1"), + col("__unnest_placeholder(struct_col).field2"), ] ); column_unnests_eq( - vec!["unnest_placeholder(struct_col)"], + vec!["__unnest_placeholder(struct_col)"], &unnest_placeholder_columns, ); // Still reference struct_col in original schema but with alias, // to avoid colliding with the projection on the column itself if any assert_eq!( inner_projection_exprs, - vec![col("struct_col").alias("unnest_placeholder(struct_col)"),] + vec![col("struct_col").alias("__unnest_placeholder(struct_col)"),] ); // unnest(array_col) + 1 @@ -819,15 +821,15 @@ mod tests { )?; column_unnests_eq( vec![ - "unnest_placeholder(struct_col)", - "unnest_placeholder(array_col)=>[unnest_placeholder(array_col,depth=1)|depth=1]", + "__unnest_placeholder(struct_col)", + "__unnest_placeholder(array_col)=>[__unnest_placeholder(array_col,depth=1)|depth=1]", ], &unnest_placeholder_columns, ); // Only transform the unnest children assert_eq!( transformed_exprs, - vec![col("unnest_placeholder(array_col,depth=1)") + vec![col("__unnest_placeholder(array_col,depth=1)") .alias("UNNEST(array_col)") .add(lit(1i64))] ); @@ 
-838,8 +840,8 @@ mod tests { assert_eq!( inner_projection_exprs, vec![ - col("struct_col").alias("unnest_placeholder(struct_col)"), - col("array_col").alias("unnest_placeholder(array_col)") + col("struct_col").alias("__unnest_placeholder(struct_col)"), + col("array_col").alias("__unnest_placeholder(array_col)") ] ); @@ -907,7 +909,7 @@ mod tests { assert_eq!( transformed_exprs, vec![unnest( - col("unnest_placeholder(struct_list,depth=1)") + col("__unnest_placeholder(struct_list,depth=1)") .alias("UNNEST(struct_list)") .field("subfield1") )] @@ -915,14 +917,14 @@ mod tests { column_unnests_eq( vec![ - "unnest_placeholder(struct_list)=>[unnest_placeholder(struct_list,depth=1)|depth=1]", + "__unnest_placeholder(struct_list)=>[__unnest_placeholder(struct_list,depth=1)|depth=1]", ], &unnest_placeholder_columns, ); assert_eq!( inner_projection_exprs, - vec![col("struct_list").alias("unnest_placeholder(struct_list)")] + vec![col("struct_list").alias("__unnest_placeholder(struct_list)")] ); // continue rewrite another expr in select @@ -937,7 +939,7 @@ mod tests { assert_eq!( transformed_exprs, vec![unnest( - col("unnest_placeholder(struct_list,depth=1)") + col("__unnest_placeholder(struct_list,depth=1)") .alias("UNNEST(struct_list)") .field("subfield2") )] @@ -947,14 +949,14 @@ mod tests { // because expr1 and expr2 derive from the same unnest result column_unnests_eq( vec![ - "unnest_placeholder(struct_list)=>[unnest_placeholder(struct_list,depth=1)|depth=1]", + "__unnest_placeholder(struct_list)=>[__unnest_placeholder(struct_list,depth=1)|depth=1]", ], &unnest_placeholder_columns, ); assert_eq!( inner_projection_exprs, - vec![col("struct_list").alias("unnest_placeholder(struct_list)")] + vec![col("struct_list").alias("__unnest_placeholder(struct_list)")] ); Ok(()) diff --git a/datafusion/sql/tests/cases/plan_to_sql.rs b/datafusion/sql/tests/cases/plan_to_sql.rs index fcfee29f6ac9..236b59432a5f 100644 --- a/datafusion/sql/tests/cases/plan_to_sql.rs +++ b/datafusion/sql/tests/cases/plan_to_sql.rs @@ -525,6 +525,96 @@ fn roundtrip_statement_with_dialect() -> Result<()> { parser_dialect: Box::new(GenericDialect {}), unparser_dialect: Box::new(SqliteDialect {}), }, + TestStatementWithDialect { + sql: "SELECT * FROM UNNEST([1,2,3])", + expected: r#"SELECT * FROM (SELECT UNNEST([1, 2, 3]) AS "UNNEST(make_array(Int64(1),Int64(2),Int64(3)))")"#, + parser_dialect: Box::new(GenericDialect {}), + unparser_dialect: Box::new(UnparserDefaultDialect {}), + }, + TestStatementWithDialect { + sql: "SELECT * FROM UNNEST([1,2,3]) AS t1 (c1)", + expected: r#"SELECT * FROM (SELECT UNNEST([1, 2, 3]) AS "UNNEST(make_array(Int64(1),Int64(2),Int64(3)))") AS t1 (c1)"#, + parser_dialect: Box::new(GenericDialect {}), + unparser_dialect: Box::new(UnparserDefaultDialect {}), + }, + TestStatementWithDialect { + sql: "SELECT * FROM UNNEST([1,2,3]) AS t1 (c1)", + expected: r#"SELECT * FROM (SELECT UNNEST([1, 2, 3]) AS "UNNEST(make_array(Int64(1),Int64(2),Int64(3)))") AS t1 (c1)"#, + parser_dialect: Box::new(GenericDialect {}), + unparser_dialect: Box::new(UnparserDefaultDialect {}), + }, + TestStatementWithDialect { + sql: "SELECT * FROM UNNEST([1,2,3]), j1", + expected: r#"SELECT * FROM (SELECT UNNEST([1, 2, 3]) AS "UNNEST(make_array(Int64(1),Int64(2),Int64(3)))") CROSS JOIN j1"#, + parser_dialect: Box::new(GenericDialect {}), + unparser_dialect: Box::new(UnparserDefaultDialect {}), + }, + TestStatementWithDialect { + sql: "SELECT * FROM UNNEST([1,2,3]) u(c1) JOIN j1 ON u.c1 = j1.j1_id", + expected: r#"SELECT * FROM (SELECT 
UNNEST([1, 2, 3]) AS "UNNEST(make_array(Int64(1),Int64(2),Int64(3)))") AS u (c1) JOIN j1 ON (u.c1 = j1.j1_id)"#, + parser_dialect: Box::new(GenericDialect {}), + unparser_dialect: Box::new(UnparserDefaultDialect {}), + }, + TestStatementWithDialect { + sql: "SELECT * FROM UNNEST([1,2,3]) u(c1) UNION ALL SELECT * FROM UNNEST([4,5,6]) u(c1)", + expected: r#"SELECT * FROM (SELECT UNNEST([1, 2, 3]) AS "UNNEST(make_array(Int64(1),Int64(2),Int64(3)))") AS u (c1) UNION ALL SELECT * FROM (SELECT UNNEST([4, 5, 6]) AS "UNNEST(make_array(Int64(4),Int64(5),Int64(6)))") AS u (c1)"#, + parser_dialect: Box::new(GenericDialect {}), + unparser_dialect: Box::new(UnparserDefaultDialect {}), + }, + TestStatementWithDialect { + sql: "SELECT * FROM UNNEST([1,2,3])", + expected: r#"SELECT * FROM UNNEST([1, 2, 3])"#, + parser_dialect: Box::new(GenericDialect {}), + unparser_dialect: Box::new(CustomDialectBuilder::default().with_unnest_as_table_factor(true).build()), + }, + TestStatementWithDialect { + sql: "SELECT * FROM UNNEST([1,2,3]) AS t1 (c1)", + expected: r#"SELECT * FROM UNNEST([1, 2, 3]) AS t1 (c1)"#, + parser_dialect: Box::new(GenericDialect {}), + unparser_dialect: Box::new(CustomDialectBuilder::default().with_unnest_as_table_factor(true).build()), + }, + TestStatementWithDialect { + sql: "SELECT * FROM UNNEST([1,2,3]) AS t1 (c1)", + expected: r#"SELECT * FROM UNNEST([1, 2, 3]) AS t1 (c1)"#, + parser_dialect: Box::new(GenericDialect {}), + unparser_dialect: Box::new(CustomDialectBuilder::default().with_unnest_as_table_factor(true).build()), + }, + TestStatementWithDialect { + sql: "SELECT * FROM UNNEST([1,2,3]), j1", + expected: r#"SELECT * FROM UNNEST([1, 2, 3]) CROSS JOIN j1"#, + parser_dialect: Box::new(GenericDialect {}), + unparser_dialect: Box::new(CustomDialectBuilder::default().with_unnest_as_table_factor(true).build()), + }, + TestStatementWithDialect { + sql: "SELECT * FROM UNNEST([1,2,3]) u(c1) JOIN j1 ON u.c1 = j1.j1_id", + expected: r#"SELECT * FROM UNNEST([1, 2, 3]) AS u (c1) JOIN j1 ON (u.c1 = j1.j1_id)"#, + parser_dialect: Box::new(GenericDialect {}), + unparser_dialect: Box::new(CustomDialectBuilder::default().with_unnest_as_table_factor(true).build()), + }, + TestStatementWithDialect { + sql: "SELECT * FROM UNNEST([1,2,3]) u(c1) UNION ALL SELECT * FROM UNNEST([4,5,6]) u(c1)", + expected: r#"SELECT * FROM UNNEST([1, 2, 3]) AS u (c1) UNION ALL SELECT * FROM UNNEST([4, 5, 6]) AS u (c1)"#, + parser_dialect: Box::new(GenericDialect {}), + unparser_dialect: Box::new(CustomDialectBuilder::default().with_unnest_as_table_factor(true).build()), + }, + TestStatementWithDialect { + sql: "SELECT UNNEST([1,2,3])", + expected: r#"SELECT * FROM UNNEST([1, 2, 3])"#, + parser_dialect: Box::new(GenericDialect {}), + unparser_dialect: Box::new(CustomDialectBuilder::default().with_unnest_as_table_factor(true).build()), + }, + TestStatementWithDialect { + sql: "SELECT UNNEST([1,2,3]) as c1", + expected: r#"SELECT UNNEST([1, 2, 3]) AS c1"#, + parser_dialect: Box::new(GenericDialect {}), + unparser_dialect: Box::new(CustomDialectBuilder::default().with_unnest_as_table_factor(true).build()), + }, + TestStatementWithDialect { + sql: "SELECT UNNEST([1,2,3]), 1", + expected: r#"SELECT UNNEST([1, 2, 3]) AS UNNEST(make_array(Int64(1),Int64(2),Int64(3))), Int64(1)"#, + parser_dialect: Box::new(GenericDialect {}), + unparser_dialect: Box::new(CustomDialectBuilder::default().with_unnest_as_table_factor(true).build()), + }, ]; for query in tests { @@ -535,7 +625,8 @@ fn roundtrip_statement_with_dialect() -> Result<()> 
{ let state = MockSessionState::default() .with_aggregate_function(max_udaf()) .with_aggregate_function(min_udaf()) - .with_expr_planner(Arc::new(CoreFunctionPlanner::default())); + .with_expr_planner(Arc::new(CoreFunctionPlanner::default())) + .with_expr_planner(Arc::new(NestedFunctionPlanner)); let context = MockContextProvider { state }; let sql_to_rel = SqlToRel::new(&context); @@ -571,9 +662,9 @@ fn test_unnest_logical_plan() -> Result<()> { let sql_to_rel = SqlToRel::new(&context); let plan = sql_to_rel.sql_statement_to_plan(statement).unwrap(); let expected = r#" -Projection: unnest_placeholder(unnest_table.struct_col).field1, unnest_placeholder(unnest_table.struct_col).field2, unnest_placeholder(unnest_table.array_col,depth=1) AS UNNEST(unnest_table.array_col), unnest_table.struct_col, unnest_table.array_col - Unnest: lists[unnest_placeholder(unnest_table.array_col)|depth=1] structs[unnest_placeholder(unnest_table.struct_col)] - Projection: unnest_table.struct_col AS unnest_placeholder(unnest_table.struct_col), unnest_table.array_col AS unnest_placeholder(unnest_table.array_col), unnest_table.struct_col, unnest_table.array_col +Projection: __unnest_placeholder(unnest_table.struct_col).field1, __unnest_placeholder(unnest_table.struct_col).field2, __unnest_placeholder(unnest_table.array_col,depth=1) AS UNNEST(unnest_table.array_col), unnest_table.struct_col, unnest_table.array_col + Unnest: lists[__unnest_placeholder(unnest_table.array_col)|depth=1] structs[__unnest_placeholder(unnest_table.struct_col)] + Projection: unnest_table.struct_col AS __unnest_placeholder(unnest_table.struct_col), unnest_table.array_col AS __unnest_placeholder(unnest_table.array_col), unnest_table.struct_col, unnest_table.array_col TableScan: unnest_table"#.trim_start(); assert_eq!(plan.to_string(), expected); diff --git a/datafusion/sqllogictest/Cargo.toml b/datafusion/sqllogictest/Cargo.toml index 849003f8eeac..f254e0db41e6 100644 --- a/datafusion/sqllogictest/Cargo.toml +++ b/datafusion/sqllogictest/Cargo.toml @@ -48,15 +48,15 @@ half = { workspace = true, default-features = true } itertools = { workspace = true } log = { workspace = true } object_store = { workspace = true } -postgres-protocol = { version = "0.6.4", optional = true } -postgres-types = { version = "0.2.4", optional = true } -rust_decimal = { version = "1.27.0" } +postgres-protocol = { version = "0.6.7", optional = true } +postgres-types = { version = "0.2.8", features = ["derive", "with-chrono-0_4"], optional = true } +rust_decimal = { version = "1.36.0", features = ["tokio-pg"] } sqllogictest = "0.23.0" sqlparser = { workspace = true } tempfile = { workspace = true } thiserror = "2.0.0" tokio = { workspace = true } -tokio-postgres = { version = "0.7.7", optional = true } +tokio-postgres = { version = "0.7.12", optional = true } [features] avro = ["datafusion/avro"] diff --git a/datafusion/sqllogictest/bin/sqllogictests.rs b/datafusion/sqllogictest/bin/sqllogictests.rs index 12c0e27ea911..be3f1cb251b6 100644 --- a/datafusion/sqllogictest/bin/sqllogictests.rs +++ b/datafusion/sqllogictest/bin/sqllogictests.rs @@ -15,10 +15,6 @@ // specific language governing permissions and limitations // under the License. 
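The dialect matrix exercised in the `plan_to_sql` tests above reduces to a single knob; a minimal sketch of driving the unparser with it (the plan is assumed to contain an `Unnest` node built elsewhere):

```rust
use datafusion_common::Result;
use datafusion_expr::LogicalPlan;
use datafusion_sql::unparser::dialect::CustomDialectBuilder;
use datafusion_sql::unparser::Unparser;

// With the flag enabled, an Unnest node unparses to `FROM UNNEST(...)`
// instead of a derived-table subquery around a SELECT-list UNNEST.
fn unnest_to_sql(plan: &LogicalPlan) -> Result<String> {
    let dialect = CustomDialectBuilder::default()
        .with_unnest_as_table_factor(true)
        .build();
    let unparser = Unparser::new(&dialect);
    Ok(unparser.plan_to_sql(plan)?.to_string())
}
```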
-use std::ffi::OsStr; -use std::fs; -use std::path::{Path, PathBuf}; - use clap::Parser; use datafusion_common::utils::get_available_parallelism; use datafusion_sqllogictest::{DataFusion, TestContext}; @@ -26,6 +22,9 @@ use futures::stream::StreamExt; use itertools::Itertools; use log::info; use sqllogictest::strict_column_validator; +use std::ffi::OsStr; +use std::fs; +use std::path::{Path, PathBuf}; use datafusion_common::{exec_datafusion_err, exec_err, DataFusionError, Result}; use datafusion_common_runtime::SpawnedTask; @@ -100,14 +99,15 @@ async fn run_tests() -> Result<()> { let errors: Vec<_> = futures::stream::iter(read_test_files(&options)?) .map(|test_file| { SpawnedTask::spawn(async move { - println!("Running {:?}", test_file.relative_path); - if options.complete { - run_complete_file(test_file).await?; - } else if options.postgres_runner { - run_test_file_with_postgres(test_file).await?; - } else { - run_test_file(test_file).await?; + let file_path = test_file.relative_path.clone(); + let start = datafusion::common::instant::Instant::now(); + match (options.postgres_runner, options.complete) { + (false, false) => run_test_file(test_file).await?, + (false, true) => run_complete_file(test_file).await?, + (true, false) => run_test_file_with_postgres(test_file).await?, + (true, true) => run_complete_file_with_postgres(test_file).await?, } + println!("Executed {:?}. Took {:?}", file_path, start.elapsed()); Ok(()) as Result<()> }) .join() @@ -226,6 +226,41 @@ async fn run_complete_file(test_file: TestFile) -> Result<()> { }) } +#[cfg(feature = "postgres")] +async fn run_complete_file_with_postgres(test_file: TestFile) -> Result<()> { + use datafusion_sqllogictest::Postgres; + let TestFile { + path, + relative_path, + } = test_file; + info!( + "Using complete mode to complete with Postgres runner: {}", + path.display() + ); + setup_scratch_dir(&relative_path)?; + let mut runner = + sqllogictest::Runner::new(|| Postgres::connect(relative_path.clone())); + let col_separator = " "; + runner + .update_test_file( + path, + col_separator, + value_validator, + strict_column_validator, + ) + .await + // Can't use e directly because it isn't marked Send, so turn it into a string. 
+ .map_err(|e| { + DataFusionError::Execution(format!("Error completing {relative_path:?}: {e}")) + }) +} + +#[cfg(not(feature = "postgres"))] +async fn run_complete_file_with_postgres(_test_file: TestFile) -> Result<()> { + use datafusion_common::plan_err; + plan_err!("Can not run with postgres as postgres feature is not enabled") +} + /// Represents a parsed test file #[derive(Debug)] struct TestFile { diff --git a/datafusion/sqllogictest/src/engines/conversion.rs b/datafusion/sqllogictest/src/engines/conversion.rs index 8d2fd1e6d0f2..516ec69e0b07 100644 --- a/datafusion/sqllogictest/src/engines/conversion.rs +++ b/datafusion/sqllogictest/src/engines/conversion.rs @@ -81,16 +81,18 @@ pub(crate) fn f64_to_str(value: f64) -> String { } } -pub(crate) fn i128_to_str(value: i128, precision: &u8, scale: &i8) -> String { +pub(crate) fn decimal_128_to_str(value: i128, scale: i8) -> String { + let precision = u8::MAX; // does not matter big_decimal_to_str( - BigDecimal::from_str(&Decimal128Type::format_decimal(value, *precision, *scale)) + BigDecimal::from_str(&Decimal128Type::format_decimal(value, precision, scale)) .unwrap(), ) } -pub(crate) fn i256_to_str(value: i256, precision: &u8, scale: &i8) -> String { +pub(crate) fn decimal_256_to_str(value: i256, scale: i8) -> String { + let precision = u8::MAX; // does not matter big_decimal_to_str( - BigDecimal::from_str(&Decimal256Type::format_decimal(value, *precision, *scale)) + BigDecimal::from_str(&Decimal256Type::format_decimal(value, precision, scale)) .unwrap(), ) } @@ -104,30 +106,7 @@ pub(crate) fn big_decimal_to_str(value: BigDecimal) -> String { // Round the value to limit the number of decimal places let value = value.round(12).normalized(); // Format the value to a string - format_big_decimal(value) -} - -fn format_big_decimal(value: BigDecimal) -> String { - let (integer, scale) = value.into_bigint_and_exponent(); - let mut str = integer.to_str_radix(10); - if scale <= 0 { - // Append zeros to the right of the integer part - str.extend(std::iter::repeat('0').take(scale.unsigned_abs() as usize)); - str - } else { - let (sign, unsigned_len, unsigned_str) = if integer.is_negative() { - ("-", str.len() - 1, &str[1..]) - } else { - ("", str.len(), &str[..]) - }; - let scale = scale as usize; - if unsigned_len <= scale { - format!("{}0.{:0>scale$}", sign, unsigned_str) - } else { - str.insert(str.len() - scale, '.'); - str - } - } + value.to_plain_string() } #[cfg(test)] @@ -149,19 +128,41 @@ mod tests { #[test] fn test_big_decimal_to_str() { + assert_decimal_str_eq!(110, 3, "0.11"); assert_decimal_str_eq!(11, 3, "0.011"); assert_decimal_str_eq!(11, 2, "0.11"); assert_decimal_str_eq!(11, 1, "1.1"); assert_decimal_str_eq!(11, 0, "11"); assert_decimal_str_eq!(11, -1, "110"); assert_decimal_str_eq!(0, 0, "0"); + assert_decimal_str_eq!( + 12345678901234567890123456789012345678_i128, + 0, + "12345678901234567890123456789012345678" + ); + assert_decimal_str_eq!( + 12345678901234567890123456789012345678_i128, + 38, + "0.123456789012" + ); // Negative cases + assert_decimal_str_eq!(-110, 3, "-0.11"); assert_decimal_str_eq!(-11, 3, "-0.011"); assert_decimal_str_eq!(-11, 2, "-0.11"); assert_decimal_str_eq!(-11, 1, "-1.1"); assert_decimal_str_eq!(-11, 0, "-11"); assert_decimal_str_eq!(-11, -1, "-110"); + assert_decimal_str_eq!( + -12345678901234567890123456789012345678_i128, + 0, + "-12345678901234567890123456789012345678" + ); + assert_decimal_str_eq!( + -12345678901234567890123456789012345678_i128, + 38, + "-0.123456789012" + ); // Round to 12 decimal 
diff --git a/datafusion/sqllogictest/src/engines/datafusion_engine/normalize.rs b/datafusion/sqllogictest/src/engines/datafusion_engine/normalize.rs
index 4146c7cf8010..b80f0ef075ff 100644
--- a/datafusion/sqllogictest/src/engines/datafusion_engine/normalize.rs
+++ b/datafusion/sqllogictest/src/engines/datafusion_engine/normalize.rs
@@ -213,13 +213,13 @@ pub fn cell_to_string(col: &ArrayRef, row: usize) -> Result<String> {
         DataType::Float64 => {
             Ok(f64_to_str(get_row_value!(array::Float64Array, col, row)))
         }
-        DataType::Decimal128(precision, scale) => {
+        DataType::Decimal128(_, scale) => {
             let value = get_row_value!(array::Decimal128Array, col, row);
-            Ok(i128_to_str(value, precision, scale))
+            Ok(decimal_128_to_str(value, *scale))
         }
-        DataType::Decimal256(precision, scale) => {
+        DataType::Decimal256(_, scale) => {
             let value = get_row_value!(array::Decimal256Array, col, row);
-            Ok(i256_to_str(value, precision, scale))
+            Ok(decimal_256_to_str(value, *scale))
         }
         DataType::LargeUtf8 => Ok(varchar_to_str(get_row_value!(
             array::LargeStringArray,
diff --git a/datafusion/sqllogictest/test_files/csv_files.slt b/datafusion/sqllogictest/test_files/csv_files.slt
index 01d0f4ac39bd..5906c6a19bb8 100644
--- a/datafusion/sqllogictest/test_files/csv_files.slt
+++ b/datafusion/sqllogictest/test_files/csv_files.slt
@@ -350,15 +350,33 @@ col2 TEXT
 LOCATION '../core/tests/data/cr_terminator.csv'
 OPTIONS ('format.terminator' E'\r', 'format.has_header' 'true');
 
-# TODO: It should be passed but got the error: External error: query failed: DataFusion error: Object Store error: Generic LocalFileSystem error: Requested range was invalid
-# See the issue: https://github.com/apache/datafusion/issues/12328
-# query TT
-# select * from stored_table_with_cr_terminator;
-# ----
-# id0 value0
-# id1 value1
-# id2 value2
-# id3 value3
+# Check single-thread reading of CSV with custom line terminator
+statement ok
+SET datafusion.optimizer.repartition_file_min_size = 10485760;
+
+query TT
+select * from stored_table_with_cr_terminator;
+----
+id0 value0
+id1 value1
+id2 value2
+id3 value3
+
+# Check repartitioned reading of CSV with custom line terminator
+statement ok
+SET datafusion.optimizer.repartition_file_min_size = 1;
+
+query TT
+select * from stored_table_with_cr_terminator order by col1;
+----
+id0 value0
+id1 value1
+id2 value2
+id3 value3
+
+# Reset repartition_file_min_size to default value
+statement ok
+SET datafusion.optimizer.repartition_file_min_size = 10485760;
 
 statement ok
 drop table stored_table_with_cr_terminator;
diff --git a/datafusion/sqllogictest/test_files/encoding.slt b/datafusion/sqllogictest/test_files/encoding.slt
index fc22cc8bf7a7..24efb33f7896 100644
--- a/datafusion/sqllogictest/test_files/encoding.slt
+++ b/datafusion/sqllogictest/test_files/encoding.slt
@@ -101,4 +101,4 @@ FROM test_utf8view;
 Andrew QW5kcmV3 416e64726577 X WA 58
 Xiangpeng WGlhbmdwZW5n 5869616e6770656e67 Xiangpeng WGlhbmdwZW5n 5869616e6770656e67
 Raphael UmFwaGFlbA 5261706861656c R Ug 52
-NULL NULL NULL R Ug 52
\ No newline at end of file
+NULL NULL NULL R Ug 52
diff --git a/datafusion/sqllogictest/test_files/expr.slt b/datafusion/sqllogictest/test_files/expr.slt
index 9b8dfc2186be..2306eda77d35 100644
--- a/datafusion/sqllogictest/test_files/expr.slt
+++ b/datafusion/sqllogictest/test_files/expr.slt
@@ -832,867 +832,6 @@ SELECT
 ----
 0 NULL 0 NULL
 
-# test_extract_date_part
-
-query error
-SELECT EXTRACT("'''year'''" FROM timestamp '2020-09-08T12:00:00+00:00')
-
-query error
-SELECT
EXTRACT("'year'" FROM timestamp '2020-09-08T12:00:00+00:00') - -query I -SELECT date_part('YEAR', CAST('2000-01-01' AS DATE)) ----- -2000 - -query I -SELECT EXTRACT(year FROM timestamp '2020-09-08T12:00:00+00:00') ----- -2020 - -query I -SELECT EXTRACT("year" FROM timestamp '2020-09-08T12:00:00+00:00') ----- -2020 - -query I -SELECT EXTRACT('year' FROM timestamp '2020-09-08T12:00:00+00:00') ----- -2020 - -query I -SELECT date_part('QUARTER', CAST('2000-01-01' AS DATE)) ----- -1 - -query I -SELECT EXTRACT(quarter FROM to_timestamp('2020-09-08T12:00:00+00:00')) ----- -3 - -query I -SELECT EXTRACT("quarter" FROM to_timestamp('2020-09-08T12:00:00+00:00')) ----- -3 - -query I -SELECT EXTRACT('quarter' FROM to_timestamp('2020-09-08T12:00:00+00:00')) ----- -3 - -query I -SELECT date_part('MONTH', CAST('2000-01-01' AS DATE)) ----- -1 - -query I -SELECT EXTRACT(month FROM to_timestamp('2020-09-08T12:00:00+00:00')) ----- -9 - -query I -SELECT EXTRACT("month" FROM to_timestamp('2020-09-08T12:00:00+00:00')) ----- -9 - -query I -SELECT EXTRACT('month' FROM to_timestamp('2020-09-08T12:00:00+00:00')) ----- -9 - -query I -SELECT date_part('WEEK', CAST('2003-01-01' AS DATE)) ----- -1 - -query I -SELECT EXTRACT(WEEK FROM to_timestamp('2020-09-08T12:00:00+00:00')) ----- -37 - -query I -SELECT EXTRACT("WEEK" FROM to_timestamp('2020-09-08T12:00:00+00:00')) ----- -37 - -query I -SELECT EXTRACT('WEEK' FROM to_timestamp('2020-09-08T12:00:00+00:00')) ----- -37 - -query I -SELECT date_part('DAY', CAST('2000-01-01' AS DATE)) ----- -1 - -query I -SELECT EXTRACT(day FROM to_timestamp('2020-09-08T12:00:00+00:00')) ----- -8 - -query I -SELECT EXTRACT("day" FROM to_timestamp('2020-09-08T12:00:00+00:00')) ----- -8 - -query I -SELECT EXTRACT('day' FROM to_timestamp('2020-09-08T12:00:00+00:00')) ----- -8 - -query I -SELECT date_part('DOY', CAST('2000-01-01' AS DATE)) ----- -1 - -query I -SELECT EXTRACT(doy FROM to_timestamp('2020-09-08T12:00:00+00:00')) ----- -252 - -query I -SELECT EXTRACT("doy" FROM to_timestamp('2020-09-08T12:00:00+00:00')) ----- -252 - -query I -SELECT EXTRACT('doy' FROM to_timestamp('2020-09-08T12:00:00+00:00')) ----- -252 - -query I -SELECT date_part('DOW', CAST('2000-01-01' AS DATE)) ----- -6 - -query I -SELECT EXTRACT(dow FROM to_timestamp('2020-09-08T12:00:00+00:00')) ----- -2 - -query I -SELECT EXTRACT("dow" FROM to_timestamp('2020-09-08T12:00:00+00:00')) ----- -2 - -query I -SELECT EXTRACT('dow' FROM to_timestamp('2020-09-08T12:00:00+00:00')) ----- -2 - -query I -SELECT date_part('HOUR', CAST('2000-01-01' AS DATE)) ----- -0 - -query I -SELECT EXTRACT(hour FROM to_timestamp('2020-09-08T12:03:03+00:00')) ----- -12 - -query I -SELECT EXTRACT("hour" FROM to_timestamp('2020-09-08T12:03:03+00:00')) ----- -12 - -query I -SELECT EXTRACT('hour' FROM to_timestamp('2020-09-08T12:03:03+00:00')) ----- -12 - -query I -SELECT EXTRACT(minute FROM to_timestamp('2020-09-08T12:12:00+00:00')) ----- -12 - -query I -SELECT EXTRACT("minute" FROM to_timestamp('2020-09-08T12:12:00+00:00')) ----- -12 - -query I -SELECT EXTRACT('minute' FROM to_timestamp('2020-09-08T12:12:00+00:00')) ----- -12 - -query I -SELECT date_part('minute', to_timestamp('2020-09-08T12:12:00+00:00')) ----- -12 - -# make sure the return type is integer -query T -SELECT arrow_typeof(date_part('minute', to_timestamp('2020-09-08T12:12:00+00:00'))) ----- -Int32 - -query I -SELECT EXTRACT(second FROM timestamp '2020-09-08T12:00:12.12345678+00:00') ----- -12 - -query I -SELECT EXTRACT(millisecond FROM timestamp '2020-09-08T12:00:12.12345678+00:00') ----- 
-12123 - -query I -SELECT EXTRACT(microsecond FROM timestamp '2020-09-08T12:00:12.12345678+00:00') ----- -12123456 - -query error DataFusion error: Internal error: unit Nanosecond not supported -SELECT EXTRACT(nanosecond FROM timestamp '2020-09-08T12:00:12.12345678+00:00') - -query I -SELECT EXTRACT("second" FROM timestamp '2020-09-08T12:00:12.12345678+00:00') ----- -12 - -query I -SELECT EXTRACT("millisecond" FROM timestamp '2020-09-08T12:00:12.12345678+00:00') ----- -12123 - -query I -SELECT EXTRACT("microsecond" FROM timestamp '2020-09-08T12:00:12.12345678+00:00') ----- -12123456 - -query error DataFusion error: Internal error: unit Nanosecond not supported -SELECT EXTRACT("nanosecond" FROM timestamp '2020-09-08T12:00:12.12345678+00:00') - -query I -SELECT EXTRACT('second' FROM timestamp '2020-09-08T12:00:12.12345678+00:00') ----- -12 - -query I -SELECT EXTRACT('millisecond' FROM timestamp '2020-09-08T12:00:12.12345678+00:00') ----- -12123 - -query I -SELECT EXTRACT('microsecond' FROM timestamp '2020-09-08T12:00:12.12345678+00:00') ----- -12123456 - -query error DataFusion error: Internal error: unit Nanosecond not supported -SELECT EXTRACT('nanosecond' FROM timestamp '2020-09-08T12:00:12.12345678+00:00') - - -# Keep precision when coercing Utf8 to Timestamp -query I -SELECT date_part('second', timestamp '2020-09-08T12:00:12.12345678+00:00') ----- -12 - -query I -SELECT date_part('millisecond', timestamp '2020-09-08T12:00:12.12345678+00:00') ----- -12123 - -query I -SELECT date_part('microsecond', timestamp '2020-09-08T12:00:12.12345678+00:00') ----- -12123456 - -query error DataFusion error: Internal error: unit Nanosecond not supported -SELECT date_part('nanosecond', timestamp '2020-09-08T12:00:12.12345678+00:00') - - -query I -SELECT date_part('second', '2020-09-08T12:00:12.12345678+00:00') ----- -12 - -query I -SELECT date_part('millisecond', '2020-09-08T12:00:12.12345678+00:00') ----- -12123 - -query I -SELECT date_part('microsecond', '2020-09-08T12:00:12.12345678+00:00') ----- -12123456 - -query error DataFusion error: Internal error: unit Nanosecond not supported -SELECT date_part('nanosecond', '2020-09-08T12:00:12.12345678+00:00') - -# test_date_part_time - -## time32 seconds -query I -SELECT date_part('hour', arrow_cast('23:32:50'::time, 'Time32(Second)')) ----- -23 - -query I -SELECT extract(hour from arrow_cast('23:32:50'::time, 'Time32(Second)')) ----- -23 - -query I -SELECT date_part('minute', arrow_cast('23:32:50'::time, 'Time32(Second)')) ----- -32 - -query I -SELECT extract(minute from arrow_cast('23:32:50'::time, 'Time32(Second)')) ----- -32 - -query I -SELECT date_part('second', arrow_cast('23:32:50'::time, 'Time32(Second)')) ----- -50 - -query I -SELECT extract(second from arrow_cast('23:32:50'::time, 'Time32(Second)')) ----- -50 - -query I -SELECT date_part('millisecond', arrow_cast('23:32:50'::time, 'Time32(Second)')) ----- -50000 - -query I -SELECT extract(millisecond from arrow_cast('23:32:50'::time, 'Time32(Second)')) ----- -50000 - -query I -SELECT date_part('microsecond', arrow_cast('23:32:50'::time, 'Time32(Second)')) ----- -50000000 - -query I -SELECT extract(microsecond from arrow_cast('23:32:50'::time, 'Time32(Second)')) ----- -50000000 - -query error DataFusion error: Internal error: unit Nanosecond not supported -SELECT extract(nanosecond from arrow_cast('23:32:50'::time, 'Time32(Second)')) - -query R -SELECT date_part('epoch', arrow_cast('23:32:50'::time, 'Time32(Second)')) ----- -84770 - -query R -SELECT extract(epoch from arrow_cast('23:32:50'::time, 
'Time32(Second)')) ----- -84770 - -## time32 milliseconds -query I -SELECT date_part('hour', arrow_cast('23:32:50.123'::time, 'Time32(Millisecond)')) ----- -23 - -query I -SELECT extract(hour from arrow_cast('23:32:50.123'::time, 'Time32(Millisecond)')) ----- -23 - -query I -SELECT date_part('minute', arrow_cast('23:32:50.123'::time, 'Time32(Millisecond)')) ----- -32 - -query I -SELECT extract(minute from arrow_cast('23:32:50.123'::time, 'Time32(Millisecond)')) ----- -32 - -query I -SELECT date_part('second', arrow_cast('23:32:50.123'::time, 'Time32(Millisecond)')) ----- -50 - -query I -SELECT extract(second from arrow_cast('23:32:50.123'::time, 'Time32(Millisecond)')) ----- -50 - -query I -SELECT date_part('millisecond', arrow_cast('23:32:50.123'::time, 'Time32(Millisecond)')) ----- -50123 - -query I -SELECT extract(millisecond from arrow_cast('23:32:50.123'::time, 'Time32(Millisecond)')) ----- -50123 - -query I -SELECT date_part('microsecond', arrow_cast('23:32:50.123'::time, 'Time32(Millisecond)')) ----- -50123000 - -query I -SELECT extract(microsecond from arrow_cast('23:32:50.123'::time, 'Time32(Millisecond)')) ----- -50123000 - -query error DataFusion error: Internal error: unit Nanosecond not supported -SELECT extract(nanosecond from arrow_cast('23:32:50.123'::time, 'Time32(Millisecond)')) - -query R -SELECT date_part('epoch', arrow_cast('23:32:50.123'::time, 'Time32(Millisecond)')) ----- -84770.123 - -query R -SELECT extract(epoch from arrow_cast('23:32:50.123'::time, 'Time32(Millisecond)')) ----- -84770.123 - -## time64 microseconds -query I -SELECT date_part('hour', arrow_cast('23:32:50.123456'::time, 'Time64(Microsecond)')) ----- -23 - -query I -SELECT extract(hour from arrow_cast('23:32:50.123456'::time, 'Time64(Microsecond)')) ----- -23 - -query I -SELECT date_part('minute', arrow_cast('23:32:50.123456'::time, 'Time64(Microsecond)')) ----- -32 - -query I -SELECT extract(minute from arrow_cast('23:32:50.123456'::time, 'Time64(Microsecond)')) ----- -32 - -query I -SELECT date_part('second', arrow_cast('23:32:50.123456'::time, 'Time64(Microsecond)')) ----- -50 - -query I -SELECT extract(second from arrow_cast('23:32:50.123456'::time, 'Time64(Microsecond)')) ----- -50 - -query I -SELECT date_part('millisecond', arrow_cast('23:32:50.123456'::time, 'Time64(Microsecond)')) ----- -50123 - -query I -SELECT extract(millisecond from arrow_cast('23:32:50.123456'::time, 'Time64(Microsecond)')) ----- -50123 - -query I -SELECT date_part('microsecond', arrow_cast('23:32:50.123456'::time, 'Time64(Microsecond)')) ----- -50123456 - -query I -SELECT extract(microsecond from arrow_cast('23:32:50.123456'::time, 'Time64(Microsecond)')) ----- -50123456 - -query error DataFusion error: Internal error: unit Nanosecond not supported -SELECT extract(nanosecond from arrow_cast('23:32:50.123456'::time, 'Time64(Microsecond)')) - -query R -SELECT date_part('epoch', arrow_cast('23:32:50.123456'::time, 'Time64(Microsecond)')) ----- -84770.123456 - -query R -SELECT extract(epoch from arrow_cast('23:32:50.123456'::time, 'Time64(Microsecond)')) ----- -84770.123456 - -## time64 nanoseconds -query I -SELECT date_part('hour', arrow_cast('23:32:50.123456789'::time, 'Time64(Nanosecond)')) ----- -23 - -query I -SELECT extract(hour from arrow_cast('23:32:50.123456789'::time, 'Time64(Nanosecond)')) ----- -23 - -query I -SELECT date_part('minute', arrow_cast('23:32:50.123456789'::time, 'Time64(Nanosecond)')) ----- -32 - -query I -SELECT extract(minute from arrow_cast('23:32:50.123456789'::time, 'Time64(Nanosecond)')) 
----- -32 - -query I -SELECT date_part('second', arrow_cast('23:32:50.123456789'::time, 'Time64(Nanosecond)')) ----- -50 - -query I -select extract(second from '2024-08-09T12:13:14') ----- -14 - -query I -select extract(seconds from '2024-08-09T12:13:14') ----- -14 - -query I -SELECT extract(second from arrow_cast('23:32:50.123456789'::time, 'Time64(Nanosecond)')) ----- -50 - -query I -SELECT date_part('millisecond', arrow_cast('23:32:50.123456789'::time, 'Time64(Nanosecond)')) ----- -50123 - -query I -SELECT extract(millisecond from arrow_cast('23:32:50.123456789'::time, 'Time64(Nanosecond)')) ----- -50123 - -# just some floating point stuff happening in the result here -query I -SELECT date_part('microsecond', arrow_cast('23:32:50.123456789'::time, 'Time64(Nanosecond)')) ----- -50123456 - -query I -SELECT extract(microsecond from arrow_cast('23:32:50.123456789'::time, 'Time64(Nanosecond)')) ----- -50123456 - -query I -SELECT extract(us from arrow_cast('23:32:50.123456789'::time, 'Time64(Nanosecond)')) ----- -50123456 - -query error DataFusion error: Internal error: unit Nanosecond not supported -SELECT date_part('nanosecond', arrow_cast('23:32:50.123456789'::time, 'Time64(Nanosecond)')) - -query R -SELECT date_part('epoch', arrow_cast('23:32:50.123456789'::time, 'Time64(Nanosecond)')) ----- -84770.123456789 - -query R -SELECT extract(epoch from arrow_cast('23:32:50.123456789'::time, 'Time64(Nanosecond)')) ----- -84770.123456789 - -# test_extract_epoch - -query R -SELECT extract(epoch from '1870-01-01T07:29:10.256'::timestamp) ----- --3155646649.744 - -query R -SELECT extract(epoch from '2000-01-01T00:00:00.000'::timestamp) ----- -946684800 - -query R -SELECT extract(epoch from to_timestamp('2000-01-01T00:00:00+00:00')) ----- -946684800 - -query R -SELECT extract(epoch from NULL::timestamp) ----- -NULL - -query R -SELECT extract(epoch from arrow_cast('1970-01-01', 'Date32')) ----- -0 - -query R -SELECT extract(epoch from arrow_cast('1970-01-02', 'Date32')) ----- -86400 - -query R -SELECT extract(epoch from arrow_cast('1970-01-11', 'Date32')) ----- -864000 - -query R -SELECT extract(epoch from arrow_cast('1969-12-31', 'Date32')) ----- --86400 - -query R -SELECT extract(epoch from arrow_cast('1970-01-01', 'Date64')) ----- -0 - -query R -SELECT extract(epoch from arrow_cast('1970-01-02', 'Date64')) ----- -86400 - -query R -SELECT extract(epoch from arrow_cast('1970-01-11', 'Date64')) ----- -864000 - -query R -SELECT extract(epoch from arrow_cast('1969-12-31', 'Date64')) ----- --86400 - -# test_extract_interval - -query I -SELECT extract(year from arrow_cast('10 years', 'Interval(YearMonth)')) ----- -10 - -query I -SELECT extract(month from arrow_cast('10 years', 'Interval(YearMonth)')) ----- -0 - -query I -SELECT extract(year from arrow_cast('10 months', 'Interval(YearMonth)')) ----- -0 - -query I -SELECT extract(month from arrow_cast('10 months', 'Interval(YearMonth)')) ----- -10 - -query I -SELECT extract(year from arrow_cast('20 months', 'Interval(YearMonth)')) ----- -1 - -query I -SELECT extract(month from arrow_cast('20 months', 'Interval(YearMonth)')) ----- -8 - -query error DataFusion error: Arrow error: Compute error: Year does not support: Interval\(DayTime\) -SELECT extract(year from arrow_cast('10 days', 'Interval(DayTime)')) - -query error DataFusion error: Arrow error: Compute error: Month does not support: Interval\(DayTime\) -SELECT extract(month from arrow_cast('10 days', 'Interval(DayTime)')) - -query I -SELECT extract(day from arrow_cast('10 days', 'Interval(DayTime)')) 
----- -10 - -query I -SELECT extract(day from arrow_cast('14400 minutes', 'Interval(DayTime)')) ----- -0 - -query I -SELECT extract(minute from arrow_cast('14400 minutes', 'Interval(DayTime)')) ----- -14400 - -query I -SELECT extract(second from arrow_cast('5.1 seconds', 'Interval(DayTime)')) ----- -5 - -query I -SELECT extract(second from arrow_cast('14400 minutes', 'Interval(DayTime)')) ----- -864000 - -query I -SELECT extract(second from arrow_cast('2 months', 'Interval(MonthDayNano)')) ----- -0 - -query I -SELECT extract(second from arrow_cast('2 days', 'Interval(MonthDayNano)')) ----- -0 - -query I -SELECT extract(second from arrow_cast('2 seconds', 'Interval(MonthDayNano)')) ----- -2 - -query I -SELECT extract(seconds from arrow_cast('2 seconds', 'Interval(MonthDayNano)')) ----- -2 - -query R -SELECT extract(epoch from arrow_cast('2 seconds', 'Interval(MonthDayNano)')) ----- -2 - -query I -SELECT extract(milliseconds from arrow_cast('2 seconds', 'Interval(MonthDayNano)')) ----- -2000 - -query I -SELECT extract(second from arrow_cast('2030 milliseconds', 'Interval(MonthDayNano)')) ----- -2 - -query I -SELECT extract(second from arrow_cast(NULL, 'Interval(MonthDayNano)')) ----- -NULL - -statement ok -create table t (id int, i interval) as values - (0, interval '5 months 1 day 10 nanoseconds'), - (1, interval '1 year 3 months'), - (2, interval '3 days 2 milliseconds'), - (3, interval '2 seconds'), - (4, interval '8 months'), - (5, NULL); - -query III -select - id, - extract(second from i), - extract(month from i) -from t -order by id; ----- -0 0 5 -1 0 15 -2 0 0 -3 2 0 -4 0 8 -5 NULL NULL - -statement ok -drop table t; - -# test_extract_duration - -query I -SELECT extract(second from arrow_cast(2, 'Duration(Second)')) ----- -2 - -query I -SELECT extract(seconds from arrow_cast(2, 'Duration(Second)')) ----- -2 - -query R -SELECT extract(epoch from arrow_cast(2, 'Duration(Second)')) ----- -2 - -query I -SELECT extract(millisecond from arrow_cast(2, 'Duration(Second)')) ----- -2000 - -query I -SELECT extract(second from arrow_cast(2, 'Duration(Millisecond)')) ----- -0 - -query I -SELECT extract(second from arrow_cast(2002, 'Duration(Millisecond)')) ----- -2 - -query I -SELECT extract(millisecond from arrow_cast(2002, 'Duration(Millisecond)')) ----- -2002 - -query I -SELECT extract(day from arrow_cast(864000, 'Duration(Second)')) ----- -10 - -query error DataFusion error: Arrow error: Compute error: Month does not support: Duration\(Second\) -SELECT extract(month from arrow_cast(864000, 'Duration(Second)')) - -query error DataFusion error: Arrow error: Compute error: Year does not support: Duration\(Second\) -SELECT extract(year from arrow_cast(864000, 'Duration(Second)')) - -query I -SELECT extract(day from arrow_cast(NULL, 'Duration(Second)')) ----- -NULL - -# test_extract_date_part_func - -query B -SELECT (date_part('year', now()) = EXTRACT(year FROM now())) ----- -true - -query B -SELECT (date_part('quarter', now()) = EXTRACT(quarter FROM now())) ----- -true - -query B -SELECT (date_part('month', now()) = EXTRACT(month FROM now())) ----- -true - -query B -SELECT (date_part('week', now()) = EXTRACT(week FROM now())) ----- -true - -query B -SELECT (date_part('day', now()) = EXTRACT(day FROM now())) ----- -true - -query B -SELECT (date_part('hour', now()) = EXTRACT(hour FROM now())) ----- -true - -query B -SELECT (date_part('minute', now()) = EXTRACT(minute FROM now())) ----- -true - -query B -SELECT (date_part('second', now()) = EXTRACT(second FROM now())) ----- -true - -query B -SELECT 
(date_part('millisecond', now()) = EXTRACT(millisecond FROM now()))
-----
-true
-
-query B
-SELECT (date_part('microsecond', now()) = EXTRACT(microsecond FROM now()))
-----
-true
-
-query error DataFusion error: Internal error: unit Nanosecond not supported
-SELECT (date_part('nanosecond', now()) = EXTRACT(nanosecond FROM now()))
-
 query B
 SELECT 'a' IN ('a','b')
 ----
diff --git a/datafusion/sqllogictest/test_files/expr/date_part.slt b/datafusion/sqllogictest/test_files/expr/date_part.slt
new file mode 100644
index 000000000000..dec796aa59cb
--- /dev/null
+++ b/datafusion/sqllogictest/test_files/expr/date_part.slt
@@ -0,0 +1,1072 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# Tests for `date_part` and `EXTRACT` (which is a different syntax
+# for the same function).
+
+
+## Begin tests for date_part with columns and timestamps with timezones
+
+# The source data table has timestamps at both millisecond (a very common
+# timestamp precision) and nanosecond (the maximum precision) resolution
+statement count 0
+CREATE TABLE source_ts AS
+with t as (values
+  ('2020-01-01T00:00:00+00:00'),
+  ('2021-01-01T00:00:00+00:00'), -- year
+  ('2020-09-01T00:00:00+00:00'), -- month
+  ('2020-01-25T00:00:00+00:00'), -- day
+  ('2020-01-24T00:00:00+00:00'), -- day
+  ('2020-01-01T12:00:00+00:00'), -- hour
+  ('2020-01-01T00:30:00+00:00'), -- minute
+  ('2020-01-01T00:00:30+00:00'), -- second
+  ('2020-01-01T00:00:00.123+00:00'), -- ms
+  ('2020-01-01T00:00:00.123456+00:00'), -- us
+  ('2020-01-01T00:00:00.123456789+00:00') -- ns
+)
+SELECT
+  -- nanoseconds, with no, utc, and local timezone
+  arrow_cast(column1, 'Timestamp(Nanosecond, None)') as ts_nano_no_tz,
+  arrow_cast(column1, 'Timestamp(Nanosecond, Some("UTC"))') as ts_nano_utc,
+  arrow_cast(column1, 'Timestamp(Nanosecond, Some("America/New_York"))') as ts_nano_eastern,
+  -- milliseconds, with no, utc, and local timezone
+  arrow_cast(column1, 'Timestamp(Millisecond, None)') as ts_milli_no_tz,
+  arrow_cast(column1, 'Timestamp(Millisecond, Some("UTC"))') as ts_milli_utc,
+  arrow_cast(column1, 'Timestamp(Millisecond, Some("America/New_York"))') as ts_milli_eastern
+FROM t;
+
+
+query PPPPPP
+SELECT * FROM source_ts;
+----
+2020-01-01T00:00:00 2020-01-01T00:00:00Z 2019-12-31T19:00:00-05:00 2020-01-01T00:00:00 2020-01-01T00:00:00Z 2019-12-31T19:00:00-05:00
+2021-01-01T00:00:00 2021-01-01T00:00:00Z 2020-12-31T19:00:00-05:00 2021-01-01T00:00:00 2021-01-01T00:00:00Z 2020-12-31T19:00:00-05:00
+2020-09-01T00:00:00 2020-09-01T00:00:00Z 2020-08-31T20:00:00-04:00 2020-09-01T00:00:00 2020-09-01T00:00:00Z 2020-08-31T20:00:00-04:00
+2020-01-25T00:00:00 2020-01-25T00:00:00Z 2020-01-24T19:00:00-05:00 2020-01-25T00:00:00 2020-01-25T00:00:00Z 2020-01-24T19:00:00-05:00
+2020-01-24T00:00:00 2020-01-24T00:00:00Z 2020-01-23T19:00:00-05:00
2020-01-24T00:00:00 2020-01-24T00:00:00Z 2020-01-23T19:00:00-05:00 +2020-01-01T12:00:00 2020-01-01T12:00:00Z 2020-01-01T07:00:00-05:00 2020-01-01T12:00:00 2020-01-01T12:00:00Z 2020-01-01T07:00:00-05:00 +2020-01-01T00:30:00 2020-01-01T00:30:00Z 2019-12-31T19:30:00-05:00 2020-01-01T00:30:00 2020-01-01T00:30:00Z 2019-12-31T19:30:00-05:00 +2020-01-01T00:00:30 2020-01-01T00:00:30Z 2019-12-31T19:00:30-05:00 2020-01-01T00:00:30 2020-01-01T00:00:30Z 2019-12-31T19:00:30-05:00 +2020-01-01T00:00:00.123 2020-01-01T00:00:00.123Z 2019-12-31T19:00:00.123-05:00 2020-01-01T00:00:00.123 2020-01-01T00:00:00.123Z 2019-12-31T19:00:00.123-05:00 +2020-01-01T00:00:00.123456 2020-01-01T00:00:00.123456Z 2019-12-31T19:00:00.123456-05:00 2020-01-01T00:00:00.123 2020-01-01T00:00:00.123Z 2019-12-31T19:00:00.123-05:00 +2020-01-01T00:00:00.123456789 2020-01-01T00:00:00.123456789Z 2019-12-31T19:00:00.123456789-05:00 2020-01-01T00:00:00.123 2020-01-01T00:00:00.123Z 2019-12-31T19:00:00.123-05:00 + +# date_part (year) with columns and explicit timestamp +query IIIIII +SELECT date_part('year', ts_nano_no_tz), date_part('year', ts_nano_utc), date_part('year', ts_nano_eastern), date_part('year', ts_milli_no_tz), date_part('year', ts_milli_utc), date_part('year', ts_milli_eastern) FROM source_ts; +---- +2020 2020 2019 2020 2020 2019 +2021 2021 2020 2021 2021 2020 +2020 2020 2020 2020 2020 2020 +2020 2020 2020 2020 2020 2020 +2020 2020 2020 2020 2020 2020 +2020 2020 2020 2020 2020 2020 +2020 2020 2019 2020 2020 2019 +2020 2020 2019 2020 2020 2019 +2020 2020 2019 2020 2020 2019 +2020 2020 2019 2020 2020 2019 +2020 2020 2019 2020 2020 2019 + +# date_part (month) +query IIIIII +SELECT date_part('month', ts_nano_no_tz), date_part('month', ts_nano_utc), date_part('month', ts_nano_eastern), date_part('month', ts_milli_no_tz), date_part('month', ts_milli_utc), date_part('month', ts_milli_eastern) FROM source_ts; +---- +1 1 12 1 1 12 +1 1 12 1 1 12 +9 9 8 9 9 8 +1 1 1 1 1 1 +1 1 1 1 1 1 +1 1 1 1 1 1 +1 1 12 1 1 12 +1 1 12 1 1 12 +1 1 12 1 1 12 +1 1 12 1 1 12 +1 1 12 1 1 12 + +# date_part (day) +query IIIIII +SELECT date_part('day', ts_nano_no_tz), date_part('day', ts_nano_utc), date_part('day', ts_nano_eastern), date_part('day', ts_milli_no_tz), date_part('day', ts_milli_utc), date_part('day', ts_milli_eastern) FROM source_ts; +---- +1 1 31 1 1 31 +1 1 31 1 1 31 +1 1 31 1 1 31 +25 25 24 25 25 24 +24 24 23 24 24 23 +1 1 1 1 1 1 +1 1 31 1 1 31 +1 1 31 1 1 31 +1 1 31 1 1 31 +1 1 31 1 1 31 +1 1 31 1 1 31 + +# date_part (hour) +query IIIIII +SELECT date_part('hour', ts_nano_no_tz), date_part('hour', ts_nano_utc), date_part('hour', ts_nano_eastern), date_part('hour', ts_milli_no_tz), date_part('hour', ts_milli_utc), date_part('hour', ts_milli_eastern) FROM source_ts; +---- +0 0 19 0 0 19 +0 0 19 0 0 19 +0 0 20 0 0 20 +0 0 19 0 0 19 +0 0 19 0 0 19 +12 12 7 12 12 7 +0 0 19 0 0 19 +0 0 19 0 0 19 +0 0 19 0 0 19 +0 0 19 0 0 19 +0 0 19 0 0 19 + +# date_part (minute) +query IIIIII +SELECT date_part('minute', ts_nano_no_tz), date_part('minute', ts_nano_utc), date_part('minute', ts_nano_eastern), date_part('minute', ts_milli_no_tz), date_part('minute', ts_milli_utc), date_part('minute', ts_milli_eastern) FROM source_ts; +---- +0 0 0 0 0 0 +0 0 0 0 0 0 +0 0 0 0 0 0 +0 0 0 0 0 0 +0 0 0 0 0 0 +0 0 0 0 0 0 +30 30 30 30 30 30 +0 0 0 0 0 0 +0 0 0 0 0 0 +0 0 0 0 0 0 +0 0 0 0 0 0 + +# date_part (second) +query IIIIII +SELECT date_part('second', ts_nano_no_tz), date_part('second', ts_nano_utc), date_part('second', ts_nano_eastern), date_part('second', 
ts_milli_no_tz), date_part('second', ts_milli_utc), date_part('second', ts_milli_eastern) FROM source_ts; +---- +0 0 0 0 0 0 +0 0 0 0 0 0 +0 0 0 0 0 0 +0 0 0 0 0 0 +0 0 0 0 0 0 +0 0 0 0 0 0 +0 0 0 0 0 0 +30 30 30 30 30 30 +0 0 0 0 0 0 +0 0 0 0 0 0 +0 0 0 0 0 0 + +# date_part (millisecond) +query IIIIII +SELECT date_part('millisecond', ts_nano_no_tz), date_part('millisecond', ts_nano_utc), date_part('millisecond', ts_nano_eastern), date_part('millisecond', ts_milli_no_tz), date_part('millisecond', ts_milli_utc), date_part('millisecond', ts_milli_eastern) FROM source_ts; +---- +0 0 0 0 0 0 +0 0 0 0 0 0 +0 0 0 0 0 0 +0 0 0 0 0 0 +0 0 0 0 0 0 +0 0 0 0 0 0 +0 0 0 0 0 0 +30000 30000 30000 30000 30000 30000 +123 123 123 123 123 123 +123 123 123 123 123 123 +123 123 123 123 123 123 + +# date_part (microsecond) +query IIIIII +SELECT date_part('microsecond', ts_nano_no_tz), date_part('microsecond', ts_nano_utc), date_part('microsecond', ts_nano_eastern), date_part('microsecond', ts_milli_no_tz), date_part('microsecond', ts_milli_utc), date_part('microsecond', ts_milli_eastern) FROM source_ts; +---- +0 0 0 0 0 0 +0 0 0 0 0 0 +0 0 0 0 0 0 +0 0 0 0 0 0 +0 0 0 0 0 0 +0 0 0 0 0 0 +0 0 0 0 0 0 +30000000 30000000 30000000 30000000 30000000 30000000 +123000 123000 123000 123000 123000 123000 +123456 123456 123456 123000 123000 123000 +123456 123456 123456 123000 123000 123000 + +### Cleanup +statement ok +drop table source_ts; + + + +## "Unit style" tests for types and units on scalar values + + +query error +SELECT EXTRACT("'''year'''" FROM timestamp '2020-09-08T12:00:00+00:00') + +query error +SELECT EXTRACT("'year'" FROM timestamp '2020-09-08T12:00:00+00:00') + +query I +SELECT date_part('YEAR', CAST('2000-01-01' AS DATE)) +---- +2000 + +query I +SELECT EXTRACT(year FROM timestamp '2020-09-08T12:00:00+00:00') +---- +2020 + +query I +SELECT EXTRACT("year" FROM timestamp '2020-09-08T12:00:00+00:00') +---- +2020 + +query I +SELECT EXTRACT('year' FROM timestamp '2020-09-08T12:00:00+00:00') +---- +2020 + +query I +SELECT date_part('QUARTER', CAST('2000-01-01' AS DATE)) +---- +1 + +query I +SELECT EXTRACT(quarter FROM to_timestamp('2020-09-08T12:00:00+00:00')) +---- +3 + +query I +SELECT EXTRACT("quarter" FROM to_timestamp('2020-09-08T12:00:00+00:00')) +---- +3 + +query I +SELECT EXTRACT('quarter' FROM to_timestamp('2020-09-08T12:00:00+00:00')) +---- +3 + +query I +SELECT date_part('MONTH', CAST('2000-01-01' AS DATE)) +---- +1 + +query I +SELECT EXTRACT(month FROM to_timestamp('2020-09-08T12:00:00+00:00')) +---- +9 + +query I +SELECT EXTRACT("month" FROM to_timestamp('2020-09-08T12:00:00+00:00')) +---- +9 + +query I +SELECT EXTRACT('month' FROM to_timestamp('2020-09-08T12:00:00+00:00')) +---- +9 + +query I +SELECT date_part('WEEK', CAST('2003-01-01' AS DATE)) +---- +1 + +query I +SELECT EXTRACT(WEEK FROM to_timestamp('2020-09-08T12:00:00+00:00')) +---- +37 + +query I +SELECT EXTRACT("WEEK" FROM to_timestamp('2020-09-08T12:00:00+00:00')) +---- +37 + +query I +SELECT EXTRACT('WEEK' FROM to_timestamp('2020-09-08T12:00:00+00:00')) +---- +37 + +query I +SELECT date_part('DAY', CAST('2000-01-01' AS DATE)) +---- +1 + +query I +SELECT EXTRACT(day FROM to_timestamp('2020-09-08T12:00:00+00:00')) +---- +8 + +query I +SELECT EXTRACT("day" FROM to_timestamp('2020-09-08T12:00:00+00:00')) +---- +8 + +query I +SELECT EXTRACT('day' FROM to_timestamp('2020-09-08T12:00:00+00:00')) +---- +8 + +query I +SELECT date_part('DOY', CAST('2000-01-01' AS DATE)) +---- +1 + +query I +SELECT EXTRACT(doy FROM 
to_timestamp('2020-09-08T12:00:00+00:00')) +---- +252 + +query I +SELECT EXTRACT("doy" FROM to_timestamp('2020-09-08T12:00:00+00:00')) +---- +252 + +query I +SELECT EXTRACT('doy' FROM to_timestamp('2020-09-08T12:00:00+00:00')) +---- +252 + +query I +SELECT date_part('DOW', CAST('2000-01-01' AS DATE)) +---- +6 + +query I +SELECT EXTRACT(dow FROM to_timestamp('2020-09-08T12:00:00+00:00')) +---- +2 + +query I +SELECT EXTRACT("dow" FROM to_timestamp('2020-09-08T12:00:00+00:00')) +---- +2 + +query I +SELECT EXTRACT('dow' FROM to_timestamp('2020-09-08T12:00:00+00:00')) +---- +2 + +query I +SELECT date_part('HOUR', CAST('2000-01-01' AS DATE)) +---- +0 + +query I +SELECT EXTRACT(hour FROM to_timestamp('2020-09-08T12:03:03+00:00')) +---- +12 + +query I +SELECT EXTRACT("hour" FROM to_timestamp('2020-09-08T12:03:03+00:00')) +---- +12 + +query I +SELECT EXTRACT('hour' FROM to_timestamp('2020-09-08T12:03:03+00:00')) +---- +12 + +query I +SELECT EXTRACT(minute FROM to_timestamp('2020-09-08T12:12:00+00:00')) +---- +12 + +query I +SELECT EXTRACT("minute" FROM to_timestamp('2020-09-08T12:12:00+00:00')) +---- +12 + +query I +SELECT EXTRACT('minute' FROM to_timestamp('2020-09-08T12:12:00+00:00')) +---- +12 + +query I +SELECT date_part('minute', to_timestamp('2020-09-08T12:12:00+00:00')) +---- +12 + +# make sure the return type is integer +query T +SELECT arrow_typeof(date_part('minute', to_timestamp('2020-09-08T12:12:00+00:00'))) +---- +Int32 + +query I +SELECT EXTRACT(second FROM timestamp '2020-09-08T12:00:12.12345678+00:00') +---- +12 + +query I +SELECT EXTRACT(millisecond FROM timestamp '2020-09-08T12:00:12.12345678+00:00') +---- +12123 + +query I +SELECT EXTRACT(microsecond FROM timestamp '2020-09-08T12:00:12.12345678+00:00') +---- +12123456 + +query error DataFusion error: This feature is not implemented: Date part Nanosecond not supported +SELECT EXTRACT(nanosecond FROM timestamp '2020-09-08T12:00:12.12345678+00:00') + +query I +SELECT EXTRACT("second" FROM timestamp '2020-09-08T12:00:12.12345678+00:00') +---- +12 + +query I +SELECT EXTRACT("millisecond" FROM timestamp '2020-09-08T12:00:12.12345678+00:00') +---- +12123 + +query I +SELECT EXTRACT("microsecond" FROM timestamp '2020-09-08T12:00:12.12345678+00:00') +---- +12123456 + +query error DataFusion error: This feature is not implemented: Date part Nanosecond not supported +SELECT EXTRACT("nanosecond" FROM timestamp '2020-09-08T12:00:12.12345678+00:00') + +query I +SELECT EXTRACT('second' FROM timestamp '2020-09-08T12:00:12.12345678+00:00') +---- +12 + +query I +SELECT EXTRACT('millisecond' FROM timestamp '2020-09-08T12:00:12.12345678+00:00') +---- +12123 + +query I +SELECT EXTRACT('microsecond' FROM timestamp '2020-09-08T12:00:12.12345678+00:00') +---- +12123456 + +query error DataFusion error: This feature is not implemented: Date part Nanosecond not supported +SELECT EXTRACT('nanosecond' FROM timestamp '2020-09-08T12:00:12.12345678+00:00') + + +# Keep precision when coercing Utf8 to Timestamp +query I +SELECT date_part('second', timestamp '2020-09-08T12:00:12.12345678+00:00') +---- +12 + +query I +SELECT date_part('millisecond', timestamp '2020-09-08T12:00:12.12345678+00:00') +---- +12123 + +query I +SELECT date_part('microsecond', timestamp '2020-09-08T12:00:12.12345678+00:00') +---- +12123456 + +query error DataFusion error: This feature is not implemented: Date part Nanosecond not supported +SELECT date_part('nanosecond', timestamp '2020-09-08T12:00:12.12345678+00:00') + + +query I +SELECT date_part('second', 
'2020-09-08T12:00:12.12345678+00:00') +---- +12 + +query I +SELECT date_part('millisecond', '2020-09-08T12:00:12.12345678+00:00') +---- +12123 + +query I +SELECT date_part('microsecond', '2020-09-08T12:00:12.12345678+00:00') +---- +12123456 + +query error DataFusion error: This feature is not implemented: Date part Nanosecond not supported +SELECT date_part('nanosecond', '2020-09-08T12:00:12.12345678+00:00') + +# test_date_part_time + +## time32 seconds +query I +SELECT date_part('hour', arrow_cast('23:32:50'::time, 'Time32(Second)')) +---- +23 + +query I +SELECT extract(hour from arrow_cast('23:32:50'::time, 'Time32(Second)')) +---- +23 + +query I +SELECT date_part('minute', arrow_cast('23:32:50'::time, 'Time32(Second)')) +---- +32 + +query I +SELECT extract(minute from arrow_cast('23:32:50'::time, 'Time32(Second)')) +---- +32 + +query I +SELECT date_part('second', arrow_cast('23:32:50'::time, 'Time32(Second)')) +---- +50 + +query I +SELECT extract(second from arrow_cast('23:32:50'::time, 'Time32(Second)')) +---- +50 + +query I +SELECT date_part('millisecond', arrow_cast('23:32:50'::time, 'Time32(Second)')) +---- +50000 + +query I +SELECT extract(millisecond from arrow_cast('23:32:50'::time, 'Time32(Second)')) +---- +50000 + +query I +SELECT date_part('microsecond', arrow_cast('23:32:50'::time, 'Time32(Second)')) +---- +50000000 + +query I +SELECT extract(microsecond from arrow_cast('23:32:50'::time, 'Time32(Second)')) +---- +50000000 + +query error DataFusion error: This feature is not implemented: Date part Nanosecond not supported +SELECT extract(nanosecond from arrow_cast('23:32:50'::time, 'Time32(Second)')) + +query R +SELECT date_part('epoch', arrow_cast('23:32:50'::time, 'Time32(Second)')) +---- +84770 + +query R +SELECT extract(epoch from arrow_cast('23:32:50'::time, 'Time32(Second)')) +---- +84770 + +## time32 milliseconds +query I +SELECT date_part('hour', arrow_cast('23:32:50.123'::time, 'Time32(Millisecond)')) +---- +23 + +query I +SELECT extract(hour from arrow_cast('23:32:50.123'::time, 'Time32(Millisecond)')) +---- +23 + +query I +SELECT date_part('minute', arrow_cast('23:32:50.123'::time, 'Time32(Millisecond)')) +---- +32 + +query I +SELECT extract(minute from arrow_cast('23:32:50.123'::time, 'Time32(Millisecond)')) +---- +32 + +query I +SELECT date_part('second', arrow_cast('23:32:50.123'::time, 'Time32(Millisecond)')) +---- +50 + +query I +SELECT extract(second from arrow_cast('23:32:50.123'::time, 'Time32(Millisecond)')) +---- +50 + +query I +SELECT date_part('millisecond', arrow_cast('23:32:50.123'::time, 'Time32(Millisecond)')) +---- +50123 + +query I +SELECT extract(millisecond from arrow_cast('23:32:50.123'::time, 'Time32(Millisecond)')) +---- +50123 + +query I +SELECT date_part('microsecond', arrow_cast('23:32:50.123'::time, 'Time32(Millisecond)')) +---- +50123000 + +query I +SELECT extract(microsecond from arrow_cast('23:32:50.123'::time, 'Time32(Millisecond)')) +---- +50123000 + +query error DataFusion error: This feature is not implemented: Date part Nanosecond not supported +SELECT extract(nanosecond from arrow_cast('23:32:50.123'::time, 'Time32(Millisecond)')) + +query R +SELECT date_part('epoch', arrow_cast('23:32:50.123'::time, 'Time32(Millisecond)')) +---- +84770.123 + +query R +SELECT extract(epoch from arrow_cast('23:32:50.123'::time, 'Time32(Millisecond)')) +---- +84770.123 + +## time64 microseconds +query I +SELECT date_part('hour', arrow_cast('23:32:50.123456'::time, 'Time64(Microsecond)')) +---- +23 + +query I +SELECT extract(hour from 
arrow_cast('23:32:50.123456'::time, 'Time64(Microsecond)'))
+----
+23
+
+query I
+SELECT date_part('minute', arrow_cast('23:32:50.123456'::time, 'Time64(Microsecond)'))
+----
+32
+
+query I
+SELECT extract(minute from arrow_cast('23:32:50.123456'::time, 'Time64(Microsecond)'))
+----
+32
+
+query I
+SELECT date_part('second', arrow_cast('23:32:50.123456'::time, 'Time64(Microsecond)'))
+----
+50
+
+query I
+SELECT extract(second from arrow_cast('23:32:50.123456'::time, 'Time64(Microsecond)'))
+----
+50
+
+query I
+SELECT date_part('millisecond', arrow_cast('23:32:50.123456'::time, 'Time64(Microsecond)'))
+----
+50123
+
+query I
+SELECT extract(millisecond from arrow_cast('23:32:50.123456'::time, 'Time64(Microsecond)'))
+----
+50123
+
+query I
+SELECT date_part('microsecond', arrow_cast('23:32:50.123456'::time, 'Time64(Microsecond)'))
+----
+50123456
+
+query I
+SELECT extract(microsecond from arrow_cast('23:32:50.123456'::time, 'Time64(Microsecond)'))
+----
+50123456
+
+query error DataFusion error: This feature is not implemented: Date part Nanosecond not supported
+SELECT extract(nanosecond from arrow_cast('23:32:50.123456'::time, 'Time64(Microsecond)'))
+
+query R
+SELECT date_part('epoch', arrow_cast('23:32:50.123456'::time, 'Time64(Microsecond)'))
+----
+84770.123456
+
+query R
+SELECT extract(epoch from arrow_cast('23:32:50.123456'::time, 'Time64(Microsecond)'))
+----
+84770.123456
+
+## time64 nanoseconds
+query I
+SELECT date_part('hour', arrow_cast('23:32:50.123456789'::time, 'Time64(Nanosecond)'))
+----
+23
+
+query I
+SELECT extract(hour from arrow_cast('23:32:50.123456789'::time, 'Time64(Nanosecond)'))
+----
+23
+
+query I
+SELECT date_part('minute', arrow_cast('23:32:50.123456789'::time, 'Time64(Nanosecond)'))
+----
+32
+
+query I
+SELECT extract(minute from arrow_cast('23:32:50.123456789'::time, 'Time64(Nanosecond)'))
+----
+32
+
+query I
+SELECT date_part('second', arrow_cast('23:32:50.123456789'::time, 'Time64(Nanosecond)'))
+----
+50
+
+query I
+select extract(second from '2024-08-09T12:13:14')
+----
+14
+
+query I
+select extract(second from timestamp '2024-08-09T12:13:14')
+----
+14
+
+query I
+select extract(seconds from '2024-08-09T12:13:14')
+----
+14
+
+query I
+select extract(seconds from timestamp '2024-08-09T12:13:14')
+----
+14
+
+query I
+SELECT extract(second from arrow_cast('23:32:50.123456789'::time, 'Time64(Nanosecond)'))
+----
+50
+
+query I
+SELECT date_part('millisecond', arrow_cast('23:32:50.123456789'::time, 'Time64(Nanosecond)'))
+----
+50123
+
+query I
+SELECT extract(millisecond from arrow_cast('23:32:50.123456789'::time, 'Time64(Nanosecond)'))
+----
+50123
+
+# the trailing nanosecond digits are truncated when extracting microseconds (.123456789 -> 123456)
+query I
+SELECT date_part('microsecond', arrow_cast('23:32:50.123456789'::time, 'Time64(Nanosecond)'))
+----
+50123456
+
+query I
+SELECT extract(microsecond from arrow_cast('23:32:50.123456789'::time, 'Time64(Nanosecond)'))
+----
+50123456
+
+query I
+SELECT extract(us from arrow_cast('23:32:50.123456789'::time, 'Time64(Nanosecond)'))
+----
+50123456
+
+query error DataFusion error: This feature is not implemented: Date part Nanosecond not supported
+SELECT date_part('nanosecond', arrow_cast('23:32:50.123456789'::time, 'Time64(Nanosecond)'))
+
+query R
+SELECT date_part('epoch', arrow_cast('23:32:50.123456789'::time, 'Time64(Nanosecond)'))
+----
+84770.123456789
+
+query R
+SELECT extract(epoch from arrow_cast('23:32:50.123456789'::time, 'Time64(Nanosecond)'))
+----
+84770.123456789
+
+# test_extract_epoch
+
+query R
+SELECT
extract(epoch from '1870-01-01T07:29:10.256'::timestamp) +---- +-3155646649.744 + +query R +SELECT extract(epoch from '2000-01-01T00:00:00.000'::timestamp) +---- +946684800 + +query R +SELECT extract(epoch from to_timestamp('2000-01-01T00:00:00+00:00')) +---- +946684800 + +query R +SELECT extract(epoch from NULL::timestamp) +---- +NULL + +query R +SELECT extract(epoch from arrow_cast('1970-01-01', 'Date32')) +---- +0 + +query R +SELECT extract(epoch from arrow_cast('1970-01-02', 'Date32')) +---- +86400 + +query R +SELECT extract(epoch from arrow_cast('1970-01-11', 'Date32')) +---- +864000 + +query R +SELECT extract(epoch from arrow_cast('1969-12-31', 'Date32')) +---- +-86400 + +query R +SELECT extract(epoch from arrow_cast('1970-01-01', 'Date64')) +---- +0 + +query R +SELECT extract(epoch from arrow_cast('1970-01-02', 'Date64')) +---- +86400 + +query R +SELECT extract(epoch from arrow_cast('1970-01-11', 'Date64')) +---- +864000 + +query R +SELECT extract(epoch from arrow_cast('1969-12-31', 'Date64')) +---- +-86400 + +# test_extract_interval + +query I +SELECT extract(year from arrow_cast('10 years', 'Interval(YearMonth)')) +---- +10 + +query I +SELECT extract(month from arrow_cast('10 years', 'Interval(YearMonth)')) +---- +0 + +query I +SELECT extract(year from arrow_cast('10 months', 'Interval(YearMonth)')) +---- +0 + +query I +SELECT extract(month from arrow_cast('10 months', 'Interval(YearMonth)')) +---- +10 + +query I +SELECT extract(year from arrow_cast('20 months', 'Interval(YearMonth)')) +---- +1 + +query I +SELECT extract(month from arrow_cast('20 months', 'Interval(YearMonth)')) +---- +8 + +query error DataFusion error: Arrow error: Compute error: Year does not support: Interval\(DayTime\) +SELECT extract(year from arrow_cast('10 days', 'Interval(DayTime)')) + +query error DataFusion error: Arrow error: Compute error: Month does not support: Interval\(DayTime\) +SELECT extract(month from arrow_cast('10 days', 'Interval(DayTime)')) + +query I +SELECT extract(day from arrow_cast('10 days', 'Interval(DayTime)')) +---- +10 + +query I +SELECT extract(day from arrow_cast('14400 minutes', 'Interval(DayTime)')) +---- +0 + +query I +SELECT extract(minute from arrow_cast('14400 minutes', 'Interval(DayTime)')) +---- +14400 + +query I +SELECT extract(second from arrow_cast('5.1 seconds', 'Interval(DayTime)')) +---- +5 + +query I +SELECT extract(second from arrow_cast('14400 minutes', 'Interval(DayTime)')) +---- +864000 + +query I +SELECT extract(second from arrow_cast('2 months', 'Interval(MonthDayNano)')) +---- +0 + +query I +SELECT extract(second from arrow_cast('2 days', 'Interval(MonthDayNano)')) +---- +0 + +query I +SELECT extract(second from arrow_cast('2 seconds', 'Interval(MonthDayNano)')) +---- +2 + +query I +SELECT extract(seconds from arrow_cast('2 seconds', 'Interval(MonthDayNano)')) +---- +2 + +query R +SELECT extract(epoch from arrow_cast('2 seconds', 'Interval(MonthDayNano)')) +---- +2 + +query I +SELECT extract(milliseconds from arrow_cast('2 seconds', 'Interval(MonthDayNano)')) +---- +2000 + +query I +SELECT extract(second from arrow_cast('2030 milliseconds', 'Interval(MonthDayNano)')) +---- +2 + +query I +SELECT extract(second from arrow_cast(NULL, 'Interval(MonthDayNano)')) +---- +NULL + +statement ok +create table t (id int, i interval) as values + (0, interval '5 months 1 day 10 nanoseconds'), + (1, interval '1 year 3 months'), + (2, interval '3 days 2 milliseconds'), + (3, interval '2 seconds'), + (4, interval '8 months'), + (5, NULL); + +query III +select + id, + 
extract(second from i), + extract(month from i) +from t +order by id; +---- +0 0 5 +1 0 15 +2 0 0 +3 2 0 +4 0 8 +5 NULL NULL + +statement ok +drop table t; + +# test_extract_duration + +query I +SELECT extract(second from arrow_cast(2, 'Duration(Second)')) +---- +2 + +query I +SELECT extract(seconds from arrow_cast(2, 'Duration(Second)')) +---- +2 + +query R +SELECT extract(epoch from arrow_cast(2, 'Duration(Second)')) +---- +2 + +query I +SELECT extract(millisecond from arrow_cast(2, 'Duration(Second)')) +---- +2000 + +query I +SELECT extract(second from arrow_cast(2, 'Duration(Millisecond)')) +---- +0 + +query I +SELECT extract(second from arrow_cast(2002, 'Duration(Millisecond)')) +---- +2 + +query I +SELECT extract(millisecond from arrow_cast(2002, 'Duration(Millisecond)')) +---- +2002 + +query I +SELECT extract(day from arrow_cast(864000, 'Duration(Second)')) +---- +10 + +query error DataFusion error: Arrow error: Compute error: Month does not support: Duration\(Second\) +SELECT extract(month from arrow_cast(864000, 'Duration(Second)')) + +query error DataFusion error: Arrow error: Compute error: Year does not support: Duration\(Second\) +SELECT extract(year from arrow_cast(864000, 'Duration(Second)')) + +query I +SELECT extract(day from arrow_cast(NULL, 'Duration(Second)')) +---- +NULL + +# test_extract_date_part_func + +query B +SELECT (date_part('year', now()) = EXTRACT(year FROM now())) +---- +true + +query B +SELECT (date_part('quarter', now()) = EXTRACT(quarter FROM now())) +---- +true + +query B +SELECT (date_part('month', now()) = EXTRACT(month FROM now())) +---- +true + +query B +SELECT (date_part('week', now()) = EXTRACT(week FROM now())) +---- +true + +query B +SELECT (date_part('day', now()) = EXTRACT(day FROM now())) +---- +true + +query B +SELECT (date_part('hour', now()) = EXTRACT(hour FROM now())) +---- +true + +query B +SELECT (date_part('minute', now()) = EXTRACT(minute FROM now())) +---- +true + +query B +SELECT (date_part('second', now()) = EXTRACT(second FROM now())) +---- +true + +query B +SELECT (date_part('millisecond', now()) = EXTRACT(millisecond FROM now())) +---- +true + +query B +SELECT (date_part('microsecond', now()) = EXTRACT(microsecond FROM now())) +---- +true + +query error DataFusion error: This feature is not implemented: Date part Nanosecond not supported +SELECT (date_part('nanosecond', now()) = EXTRACT(nanosecond FROM now())) diff --git a/datafusion/sqllogictest/test_files/joins.slt b/datafusion/sqllogictest/test_files/joins.slt index e636e93007a4..49aaa877caa6 100644 --- a/datafusion/sqllogictest/test_files/joins.slt +++ b/datafusion/sqllogictest/test_files/joins.slt @@ -2864,13 +2864,13 @@ explain SELECT t1_id, t1_name FROM left_semi_anti_join_table_t1 t1 WHERE t1_id I ---- physical_plan 01)SortPreservingMergeExec: [t1_id@0 ASC NULLS LAST] -02)--SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[true] -03)----CoalesceBatchesExec: target_batch_size=2 -04)------HashJoinExec: mode=Partitioned, join_type=RightSemi, on=[(t2_id@0, t1_id@0)] -05)--------CoalesceBatchesExec: target_batch_size=2 -06)----------RepartitionExec: partitioning=Hash([t2_id@0], 2), input_partitions=2 -07)------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 -08)--------------MemoryExec: partitions=1, partition_sizes=[1] +02)--CoalesceBatchesExec: target_batch_size=2 +03)----HashJoinExec: mode=Partitioned, join_type=RightSemi, on=[(t2_id@0, t1_id@0)] +04)------CoalesceBatchesExec: target_batch_size=2 +05)--------RepartitionExec: 
partitioning=Hash([t2_id@0], 2), input_partitions=2 +06)----------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +07)------------MemoryExec: partitions=1, partition_sizes=[1] +08)------SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[true] 09)--------CoalesceBatchesExec: target_batch_size=2 10)----------RepartitionExec: partitioning=Hash([t1_id@0], 2), input_partitions=2 11)------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 @@ -2905,13 +2905,13 @@ explain SELECT t1_id, t1_name FROM left_semi_anti_join_table_t1 t1 LEFT SEMI JOI ---- physical_plan 01)SortPreservingMergeExec: [t1_id@0 ASC NULLS LAST] -02)--SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[true] -03)----CoalesceBatchesExec: target_batch_size=2 -04)------HashJoinExec: mode=Partitioned, join_type=RightSemi, on=[(t2_id@0, t1_id@0)] -05)--------CoalesceBatchesExec: target_batch_size=2 -06)----------RepartitionExec: partitioning=Hash([t2_id@0], 2), input_partitions=2 -07)------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 -08)--------------MemoryExec: partitions=1, partition_sizes=[1] +02)--CoalesceBatchesExec: target_batch_size=2 +03)----HashJoinExec: mode=Partitioned, join_type=RightSemi, on=[(t2_id@0, t1_id@0)] +04)------CoalesceBatchesExec: target_batch_size=2 +05)--------RepartitionExec: partitioning=Hash([t2_id@0], 2), input_partitions=2 +06)----------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +07)------------MemoryExec: partitions=1, partition_sizes=[1] +08)------SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[true] 09)--------CoalesceBatchesExec: target_batch_size=2 10)----------RepartitionExec: partitioning=Hash([t1_id@0], 2), input_partitions=2 11)------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 @@ -2967,10 +2967,10 @@ explain SELECT t1_id, t1_name FROM left_semi_anti_join_table_t1 t1 WHERE t1_id I ---- physical_plan 01)SortPreservingMergeExec: [t1_id@0 ASC NULLS LAST] -02)--SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[true] -03)----CoalesceBatchesExec: target_batch_size=2 -04)------HashJoinExec: mode=CollectLeft, join_type=RightSemi, on=[(t2_id@0, t1_id@0)] -05)--------MemoryExec: partitions=1, partition_sizes=[1] +02)--CoalesceBatchesExec: target_batch_size=2 +03)----HashJoinExec: mode=CollectLeft, join_type=RightSemi, on=[(t2_id@0, t1_id@0)] +04)------MemoryExec: partitions=1, partition_sizes=[1] +05)------SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[true] 06)--------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 07)----------MemoryExec: partitions=1, partition_sizes=[1] @@ -3003,10 +3003,10 @@ explain SELECT t1_id, t1_name FROM left_semi_anti_join_table_t1 t1 LEFT SEMI JOI ---- physical_plan 01)SortPreservingMergeExec: [t1_id@0 ASC NULLS LAST] -02)--SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[true] -03)----CoalesceBatchesExec: target_batch_size=2 -04)------HashJoinExec: mode=CollectLeft, join_type=RightSemi, on=[(t2_id@0, t1_id@0)] -05)--------MemoryExec: partitions=1, partition_sizes=[1] +02)--CoalesceBatchesExec: target_batch_size=2 +03)----HashJoinExec: mode=CollectLeft, join_type=RightSemi, on=[(t2_id@0, t1_id@0)] +04)------MemoryExec: partitions=1, partition_sizes=[1] +05)------SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[true] 06)--------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 07)----------MemoryExec: 
partitions=1, partition_sizes=[1] @@ -3061,13 +3061,13 @@ explain SELECT t1_id, t1_name, t1_int FROM right_semi_anti_join_table_t1 t1 WHER ---- physical_plan 01)SortPreservingMergeExec: [t1_id@0 ASC NULLS LAST] -02)--SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[true] -03)----CoalesceBatchesExec: target_batch_size=2 -04)------HashJoinExec: mode=Partitioned, join_type=RightSemi, on=[(t2_id@0, t1_id@0)], filter=t2_name@1 != t1_name@0 -05)--------CoalesceBatchesExec: target_batch_size=2 -06)----------RepartitionExec: partitioning=Hash([t2_id@0], 2), input_partitions=2 -07)------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 -08)--------------MemoryExec: partitions=1, partition_sizes=[1] +02)--CoalesceBatchesExec: target_batch_size=2 +03)----HashJoinExec: mode=Partitioned, join_type=RightSemi, on=[(t2_id@0, t1_id@0)], filter=t2_name@1 != t1_name@0 +04)------CoalesceBatchesExec: target_batch_size=2 +05)--------RepartitionExec: partitioning=Hash([t2_id@0], 2), input_partitions=2 +06)----------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +07)------------MemoryExec: partitions=1, partition_sizes=[1] +08)------SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[true] 09)--------CoalesceBatchesExec: target_batch_size=2 10)----------RepartitionExec: partitioning=Hash([t1_id@0], 2), input_partitions=2 11)------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 @@ -3083,13 +3083,13 @@ explain SELECT t1_id, t1_name, t1_int FROM right_semi_anti_join_table_t2 t2 RIGH ---- physical_plan 01)SortPreservingMergeExec: [t1_id@0 ASC NULLS LAST] -02)--SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[true] -03)----CoalesceBatchesExec: target_batch_size=2 -04)------HashJoinExec: mode=Partitioned, join_type=RightSemi, on=[(t2_id@0, t1_id@0)], filter=t2_name@0 != t1_name@1 -05)--------CoalesceBatchesExec: target_batch_size=2 -06)----------RepartitionExec: partitioning=Hash([t2_id@0], 2), input_partitions=2 -07)------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 -08)--------------MemoryExec: partitions=1, partition_sizes=[1] +02)--CoalesceBatchesExec: target_batch_size=2 +03)----HashJoinExec: mode=Partitioned, join_type=RightSemi, on=[(t2_id@0, t1_id@0)], filter=t2_name@0 != t1_name@1 +04)------CoalesceBatchesExec: target_batch_size=2 +05)--------RepartitionExec: partitioning=Hash([t2_id@0], 2), input_partitions=2 +06)----------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +07)------------MemoryExec: partitions=1, partition_sizes=[1] +08)------SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[true] 09)--------CoalesceBatchesExec: target_batch_size=2 10)----------RepartitionExec: partitioning=Hash([t1_id@0], 2), input_partitions=2 11)------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 @@ -3143,10 +3143,10 @@ explain SELECT t1_id, t1_name, t1_int FROM right_semi_anti_join_table_t1 t1 WHER ---- physical_plan 01)SortPreservingMergeExec: [t1_id@0 ASC NULLS LAST] -02)--SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[true] -03)----CoalesceBatchesExec: target_batch_size=2 -04)------HashJoinExec: mode=CollectLeft, join_type=RightSemi, on=[(t2_id@0, t1_id@0)], filter=t2_name@1 != t1_name@0 -05)--------MemoryExec: partitions=1, partition_sizes=[1] +02)--CoalesceBatchesExec: target_batch_size=2 +03)----HashJoinExec: mode=CollectLeft, join_type=RightSemi, on=[(t2_id@0, t1_id@0)], filter=t2_name@1 != 
t1_name@0 +04)------MemoryExec: partitions=1, partition_sizes=[1] +05)------SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[true] 06)--------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 07)----------MemoryExec: partitions=1, partition_sizes=[1] @@ -3160,10 +3160,10 @@ explain SELECT t1_id, t1_name, t1_int FROM right_semi_anti_join_table_t2 t2 RIGH ---- physical_plan 01)SortPreservingMergeExec: [t1_id@0 ASC NULLS LAST] -02)--SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[true] -03)----CoalesceBatchesExec: target_batch_size=2 -04)------HashJoinExec: mode=CollectLeft, join_type=RightSemi, on=[(t2_id@0, t1_id@0)], filter=t2_name@0 != t1_name@1 -05)--------MemoryExec: partitions=1, partition_sizes=[1] +02)--CoalesceBatchesExec: target_batch_size=2 +03)----HashJoinExec: mode=CollectLeft, join_type=RightSemi, on=[(t2_id@0, t1_id@0)], filter=t2_name@0 != t1_name@1 +04)------MemoryExec: partitions=1, partition_sizes=[1] +05)------SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[true] 06)--------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 07)----------MemoryExec: partitions=1, partition_sizes=[1] @@ -4058,9 +4058,9 @@ logical_plan 03)----TableScan: join_t1 projection=[t1_id, t1_name] 04)--SubqueryAlias: series 05)----Subquery: -06)------Projection: unnest_placeholder(generate_series(Int64(1),outer_ref(t1.t1_int)),depth=1) AS i -07)--------Unnest: lists[unnest_placeholder(generate_series(Int64(1),outer_ref(t1.t1_int)))|depth=1] structs[] -08)----------Projection: generate_series(Int64(1), CAST(outer_ref(t1.t1_int) AS Int64)) AS unnest_placeholder(generate_series(Int64(1),outer_ref(t1.t1_int))) +06)------Projection: __unnest_placeholder(generate_series(Int64(1),outer_ref(t1.t1_int)),depth=1) AS i +07)--------Unnest: lists[__unnest_placeholder(generate_series(Int64(1),outer_ref(t1.t1_int)))|depth=1] structs[] +08)----------Projection: generate_series(Int64(1), CAST(outer_ref(t1.t1_int) AS Int64)) AS __unnest_placeholder(generate_series(Int64(1),outer_ref(t1.t1_int))) 09)------------EmptyRelation physical_plan_error This feature is not implemented: Physical plan does not support logical expression OuterReferenceColumn(UInt32, Column { relation: Some(Bare { table: "t1" }), name: "t1_int" }) @@ -4081,9 +4081,9 @@ logical_plan 03)----TableScan: join_t1 projection=[t1_id, t1_name] 04)--SubqueryAlias: series 05)----Subquery: -06)------Projection: unnest_placeholder(generate_series(Int64(1),outer_ref(t2.t1_int)),depth=1) AS i -07)--------Unnest: lists[unnest_placeholder(generate_series(Int64(1),outer_ref(t2.t1_int)))|depth=1] structs[] -08)----------Projection: generate_series(Int64(1), CAST(outer_ref(t2.t1_int) AS Int64)) AS unnest_placeholder(generate_series(Int64(1),outer_ref(t2.t1_int))) +06)------Projection: __unnest_placeholder(generate_series(Int64(1),outer_ref(t2.t1_int)),depth=1) AS i +07)--------Unnest: lists[__unnest_placeholder(generate_series(Int64(1),outer_ref(t2.t1_int)))|depth=1] structs[] +08)----------Projection: generate_series(Int64(1), CAST(outer_ref(t2.t1_int) AS Int64)) AS __unnest_placeholder(generate_series(Int64(1),outer_ref(t2.t1_int))) 09)------------EmptyRelation physical_plan_error This feature is not implemented: Physical plan does not support logical expression OuterReferenceColumn(UInt32, Column { relation: Some(Bare { table: "t2" }), name: "t1_int" }) @@ -4313,3 +4313,86 @@ physical_plan 04)------HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(binary_col@0, binary_col@0)] 
05)--------MemoryExec: partitions=1, partition_sizes=[1]
06)--------MemoryExec: partitions=1, partition_sizes=[1]
+
+# Test hash join sort push down
+# Issue: https://github.com/apache/datafusion/issues/13559
+statement ok
+CREATE TABLE test(a INT, b INT, c INT)
+
+statement ok
+insert into test values (1,2,3), (4,5,6), (null, 7, 8), (8, null, 9), (9, 10, null)
+
+statement ok
+set datafusion.execution.target_partitions = 2;
+
+query TT
+explain select * from test where a in (select a from test where b > 3) order by c desc nulls first;
+----
+logical_plan
+01)Sort: test.c DESC NULLS FIRST
+02)--LeftSemi Join: test.a = __correlated_sq_1.a
+03)----TableScan: test projection=[a, b, c]
+04)----SubqueryAlias: __correlated_sq_1
+05)------Projection: test.a
+06)--------Filter: test.b > Int32(3)
+07)----------TableScan: test projection=[a, b]
+physical_plan
+01)SortPreservingMergeExec: [c@2 DESC]
+02)--CoalesceBatchesExec: target_batch_size=3
+03)----HashJoinExec: mode=Partitioned, join_type=RightSemi, on=[(a@0, a@0)]
+04)------CoalesceBatchesExec: target_batch_size=3
+05)--------RepartitionExec: partitioning=Hash([a@0], 2), input_partitions=2
+06)----------CoalesceBatchesExec: target_batch_size=3
+07)------------FilterExec: b@1 > 3, projection=[a@0]
+08)--------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1
+09)----------------MemoryExec: partitions=1, partition_sizes=[1]
+10)------SortExec: expr=[c@2 DESC], preserve_partitioning=[true]
+11)--------CoalesceBatchesExec: target_batch_size=3
+12)----------RepartitionExec: partitioning=Hash([a@0], 2), input_partitions=2
+13)------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1
+14)--------------MemoryExec: partitions=1, partition_sizes=[1]
+
+query TT
+explain select * from test where a in (select a from test where b > 3) order by c desc nulls last;
+----
+logical_plan
+01)Sort: test.c DESC NULLS LAST
+02)--LeftSemi Join: test.a = __correlated_sq_1.a
+03)----TableScan: test projection=[a, b, c]
+04)----SubqueryAlias: __correlated_sq_1
+05)------Projection: test.a
+06)--------Filter: test.b > Int32(3)
+07)----------TableScan: test projection=[a, b]
+physical_plan
+01)SortPreservingMergeExec: [c@2 DESC NULLS LAST]
+02)--CoalesceBatchesExec: target_batch_size=3
+03)----HashJoinExec: mode=Partitioned, join_type=RightSemi, on=[(a@0, a@0)]
+04)------CoalesceBatchesExec: target_batch_size=3
+05)--------RepartitionExec: partitioning=Hash([a@0], 2), input_partitions=2
+06)----------CoalesceBatchesExec: target_batch_size=3
+07)------------FilterExec: b@1 > 3, projection=[a@0]
+08)--------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1
+09)----------------MemoryExec: partitions=1, partition_sizes=[1]
+10)------SortExec: expr=[c@2 DESC NULLS LAST], preserve_partitioning=[true]
+11)--------CoalesceBatchesExec: target_batch_size=3
+12)----------RepartitionExec: partitioning=Hash([a@0], 2), input_partitions=2
+13)------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1
+14)--------------MemoryExec: partitions=1, partition_sizes=[1]
+
+query III
+select * from test where a in (select a from test where b > 3) order by c desc nulls first;
+----
+9 10 NULL
+4 5 6
+
+query III
+select * from test where a in (select a from test where b > 3) order by c desc nulls last;
+----
+4 5 6
+9 10 NULL
+
+statement ok
+DROP TABLE test
+
+statement ok
+set datafusion.execution.target_partitions = 1;
diff --git a/datafusion/sqllogictest/test_files/push_down_filter.slt b/datafusion/sqllogictest/test_files/push_down_filter.slt
index 86aa07b04ce1..64cc51b3c4ff 100644
--- a/datafusion/sqllogictest/test_files/push_down_filter.slt
+++ b/datafusion/sqllogictest/test_files/push_down_filter.slt
@@ -36,9 +36,9 @@ query TT
explain select uc2 from (select unnest(column2) as uc2, column1 from v) where column1 = 2;
----
logical_plan
-01)Projection: unnest_placeholder(v.column2,depth=1) AS uc2
-02)--Unnest: lists[unnest_placeholder(v.column2)|depth=1] structs[]
-03)----Projection: v.column2 AS unnest_placeholder(v.column2), v.column1
+01)Projection: __unnest_placeholder(v.column2,depth=1) AS uc2
+02)--Unnest: lists[__unnest_placeholder(v.column2)|depth=1] structs[]
+03)----Projection: v.column2 AS __unnest_placeholder(v.column2), v.column1
04)------Filter: v.column1 = Int64(2)
05)--------TableScan: v projection=[column1, column2]

@@ -53,11 +53,11 @@ query TT
explain select uc2 from (select unnest(column2) as uc2, column1 from v) where uc2 > 3;
----
logical_plan
-01)Projection: unnest_placeholder(v.column2,depth=1) AS uc2
-02)--Filter: unnest_placeholder(v.column2,depth=1) > Int64(3)
-03)----Projection: unnest_placeholder(v.column2,depth=1)
-04)------Unnest: lists[unnest_placeholder(v.column2)|depth=1] structs[]
-05)--------Projection: v.column2 AS unnest_placeholder(v.column2), v.column1
+01)Projection: __unnest_placeholder(v.column2,depth=1) AS uc2
+02)--Filter: __unnest_placeholder(v.column2,depth=1) > Int64(3)
+03)----Projection: __unnest_placeholder(v.column2,depth=1)
+04)------Unnest: lists[__unnest_placeholder(v.column2)|depth=1] structs[]
+05)--------Projection: v.column2 AS __unnest_placeholder(v.column2), v.column1
06)----------TableScan: v projection=[column1, column2]

query II
@@ -71,10 +71,10 @@ query TT
explain select uc2, column1 from (select unnest(column2) as uc2, column1 from v) where uc2 > 3 AND column1 = 2;
----
logical_plan
-01)Projection: unnest_placeholder(v.column2,depth=1) AS uc2, v.column1
-02)--Filter: unnest_placeholder(v.column2,depth=1) > Int64(3)
-03)----Unnest: lists[unnest_placeholder(v.column2)|depth=1] structs[]
-04)------Projection: v.column2 AS unnest_placeholder(v.column2), v.column1
+01)Projection: __unnest_placeholder(v.column2,depth=1) AS uc2, v.column1
+02)--Filter: __unnest_placeholder(v.column2,depth=1) > Int64(3)
+03)----Unnest: lists[__unnest_placeholder(v.column2)|depth=1] structs[]
+04)------Projection: v.column2 AS __unnest_placeholder(v.column2), v.column1
05)--------Filter: v.column1 = Int64(2)
06)----------TableScan: v projection=[column1, column2]

@@ -90,10 +90,10 @@ query TT
explain select uc2, column1 from (select unnest(column2) as uc2, column1 from v) where uc2 > 3 OR column1 = 2;
----
logical_plan
-01)Projection: unnest_placeholder(v.column2,depth=1) AS uc2, v.column1
-02)--Filter: unnest_placeholder(v.column2,depth=1) > Int64(3) OR v.column1 = Int64(2)
-03)----Unnest: lists[unnest_placeholder(v.column2)|depth=1] structs[]
-04)------Projection: v.column2 AS unnest_placeholder(v.column2), v.column1
+01)Projection: __unnest_placeholder(v.column2,depth=1) AS uc2, v.column1
+02)--Filter: __unnest_placeholder(v.column2,depth=1) > Int64(3) OR v.column1 = Int64(2)
+03)----Unnest: lists[__unnest_placeholder(v.column2)|depth=1] structs[]
+04)------Projection: v.column2 AS __unnest_placeholder(v.column2), v.column1
05)--------TableScan: v projection=[column1, column2]

statement ok
@@ -112,10 +112,10 @@ query TT
explain select * from (select column1, unnest(column2) as o from d) where o['a'] = 1;
----
logical_plan
-01)Projection: d.column1, unnest_placeholder(d.column2,depth=1) AS o
-02)--Filter: get_field(unnest_placeholder(d.column2,depth=1), Utf8("a")) = Int64(1)
-03)----Unnest: lists[unnest_placeholder(d.column2)|depth=1] structs[]
-04)------Projection: d.column1, d.column2 AS unnest_placeholder(d.column2)
+01)Projection: d.column1, __unnest_placeholder(d.column2,depth=1) AS o
+02)--Filter: get_field(__unnest_placeholder(d.column2,depth=1), Utf8("a")) = Int64(1)
+03)----Unnest: lists[__unnest_placeholder(d.column2)|depth=1] structs[]
+04)------Projection: d.column1, d.column2 AS __unnest_placeholder(d.column2)
05)--------TableScan: d projection=[column1, column2]
diff --git a/datafusion/sqllogictest/test_files/string/string_view.slt b/datafusion/sqllogictest/test_files/string/string_view.slt
index ebabaf7655ff..c37dd1ed3b4f 100644
--- a/datafusion/sqllogictest/test_files/string/string_view.slt
+++ b/datafusion/sqllogictest/test_files/string/string_view.slt
@@ -731,7 +731,7 @@ EXPLAIN SELECT
FROM test;
----
logical_plan
-01)Projection: regexp_like(test.column1_utf8view, Utf8("^https?://(?:www\.)?([^/]+)/.*$")) AS k
+01)Projection: regexp_like(test.column1_utf8view, Utf8View("^https?://(?:www\.)?([^/]+)/.*$")) AS k
02)--TableScan: test projection=[column1_utf8view]

## Ensure no casts for REGEXP_MATCH
diff --git a/datafusion/sqllogictest/test_files/table_functions.slt b/datafusion/sqllogictest/test_files/table_functions.slt
index 12402e0d70c5..79294993dded 100644
--- a/datafusion/sqllogictest/test_files/table_functions.slt
+++ b/datafusion/sqllogictest/test_files/table_functions.slt
@@ -139,4 +139,4 @@ SELECT generate_series(1, t1.end) FROM generate_series(3, 5) as t1(end)
----
[1, 2, 3, 4, 5]
[1, 2, 3, 4]
-[1, 2, 3]
\ No newline at end of file
+[1, 2, 3]
diff --git a/datafusion/sqllogictest/test_files/unnest.slt b/datafusion/sqllogictest/test_files/unnest.slt
index d409e0902f7e..1c54006bd2a0 100644
--- a/datafusion/sqllogictest/test_files/unnest.slt
+++ b/datafusion/sqllogictest/test_files/unnest.slt
@@ -594,17 +594,17 @@ query TT
explain select unnest(unnest(column3)), column3 from recursive_unnest_table;
----
logical_plan
-01)Unnest: lists[] structs[unnest_placeholder(UNNEST(recursive_unnest_table.column3))]
-02)--Projection: unnest_placeholder(recursive_unnest_table.column3,depth=1) AS UNNEST(recursive_unnest_table.column3) AS unnest_placeholder(UNNEST(recursive_unnest_table.column3)), recursive_unnest_table.column3
-03)----Unnest: lists[unnest_placeholder(recursive_unnest_table.column3)|depth=1] structs[]
-04)------Projection: recursive_unnest_table.column3 AS unnest_placeholder(recursive_unnest_table.column3), recursive_unnest_table.column3
+01)Unnest: lists[] structs[__unnest_placeholder(UNNEST(recursive_unnest_table.column3))]
+02)--Projection: __unnest_placeholder(recursive_unnest_table.column3,depth=1) AS UNNEST(recursive_unnest_table.column3) AS __unnest_placeholder(UNNEST(recursive_unnest_table.column3)), recursive_unnest_table.column3
+03)----Unnest: lists[__unnest_placeholder(recursive_unnest_table.column3)|depth=1] structs[]
+04)------Projection: recursive_unnest_table.column3 AS __unnest_placeholder(recursive_unnest_table.column3), recursive_unnest_table.column3
05)--------TableScan: recursive_unnest_table projection=[column3]
physical_plan
01)UnnestExec
02)--RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
-03)----ProjectionExec: expr=[unnest_placeholder(recursive_unnest_table.column3,depth=1)@0 as unnest_placeholder(UNNEST(recursive_unnest_table.column3)), column3@1 as column3]
+03)----ProjectionExec: expr=[__unnest_placeholder(recursive_unnest_table.column3,depth=1)@0 as __unnest_placeholder(UNNEST(recursive_unnest_table.column3)), column3@1 as column3]
04)------UnnestExec
-05)--------ProjectionExec: expr=[column3@0 as unnest_placeholder(recursive_unnest_table.column3), column3@0 as column3]
+05)--------ProjectionExec: expr=[column3@0 as __unnest_placeholder(recursive_unnest_table.column3), column3@0 as column3]
06)----------MemoryExec: partitions=1, partition_sizes=[1]

## unnest->field_access->unnest->unnest
@@ -650,19 +650,19 @@ query TT
explain select unnest(unnest(unnest(column3)['c1'])), column3 from recursive_unnest_table;
----
logical_plan
-01)Projection: unnest_placeholder(UNNEST(recursive_unnest_table.column3)[c1],depth=2) AS UNNEST(UNNEST(UNNEST(recursive_unnest_table.column3)[c1])), recursive_unnest_table.column3
-02)--Unnest: lists[unnest_placeholder(UNNEST(recursive_unnest_table.column3)[c1])|depth=2] structs[]
-03)----Projection: get_field(unnest_placeholder(recursive_unnest_table.column3,depth=1) AS UNNEST(recursive_unnest_table.column3), Utf8("c1")) AS unnest_placeholder(UNNEST(recursive_unnest_table.column3)[c1]), recursive_unnest_table.column3
-04)------Unnest: lists[unnest_placeholder(recursive_unnest_table.column3)|depth=1] structs[]
-05)--------Projection: recursive_unnest_table.column3 AS unnest_placeholder(recursive_unnest_table.column3), recursive_unnest_table.column3
+01)Projection: __unnest_placeholder(UNNEST(recursive_unnest_table.column3)[c1],depth=2) AS UNNEST(UNNEST(UNNEST(recursive_unnest_table.column3)[c1])), recursive_unnest_table.column3
+02)--Unnest: lists[__unnest_placeholder(UNNEST(recursive_unnest_table.column3)[c1])|depth=2] structs[]
+03)----Projection: get_field(__unnest_placeholder(recursive_unnest_table.column3,depth=1) AS UNNEST(recursive_unnest_table.column3), Utf8("c1")) AS __unnest_placeholder(UNNEST(recursive_unnest_table.column3)[c1]), recursive_unnest_table.column3
+04)------Unnest: lists[__unnest_placeholder(recursive_unnest_table.column3)|depth=1] structs[]
+05)--------Projection: recursive_unnest_table.column3 AS __unnest_placeholder(recursive_unnest_table.column3), recursive_unnest_table.column3
06)----------TableScan: recursive_unnest_table projection=[column3]
physical_plan
-01)ProjectionExec: expr=[unnest_placeholder(UNNEST(recursive_unnest_table.column3)[c1],depth=2)@0 as UNNEST(UNNEST(UNNEST(recursive_unnest_table.column3)[c1])), column3@1 as column3]
+01)ProjectionExec: expr=[__unnest_placeholder(UNNEST(recursive_unnest_table.column3)[c1],depth=2)@0 as UNNEST(UNNEST(UNNEST(recursive_unnest_table.column3)[c1])), column3@1 as column3]
02)--UnnestExec
-03)----ProjectionExec: expr=[get_field(unnest_placeholder(recursive_unnest_table.column3,depth=1)@0, c1) as unnest_placeholder(UNNEST(recursive_unnest_table.column3)[c1]), column3@1 as column3]
+03)----ProjectionExec: expr=[get_field(__unnest_placeholder(recursive_unnest_table.column3,depth=1)@0, c1) as __unnest_placeholder(UNNEST(recursive_unnest_table.column3)[c1]), column3@1 as column3]
04)------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
05)--------UnnestExec
-06)----------ProjectionExec: expr=[column3@0 as unnest_placeholder(recursive_unnest_table.column3), column3@0 as column3]
+06)----------ProjectionExec: expr=[column3@0 as __unnest_placeholder(recursive_unnest_table.column3), column3@0 as column3]
07)------------MemoryExec: partitions=1, partition_sizes=[1]
diff --git a/datafusion/sqllogictest/test_files/window.slt b/datafusion/sqllogictest/test_files/window.slt
index 6c48ac68ab6b..188e2ae0915f 100644
--- a/datafusion/sqllogictest/test_files/window.slt
+++ b/datafusion/sqllogictest/test_files/window.slt
@@ -5127,4 +5127,3 @@ order by id;

statement ok
drop table t1;
-
diff --git a/datafusion/wasmtest/Cargo.toml b/datafusion/wasmtest/Cargo.toml
index 2440244d08c3..69b9bd61a341 100644
--- a/datafusion/wasmtest/Cargo.toml
+++ b/datafusion/wasmtest/Cargo.toml
@@ -43,7 +43,6 @@ chrono = { version = "0.4", features = ["wasmbind"] }
# code size when deploying.
console_error_panic_hook = { version = "0.1.1", optional = true }
datafusion = { workspace = true }
-
datafusion-common = { workspace = true, default-features = true }
datafusion-execution = { workspace = true }
datafusion-expr = { workspace = true }
diff --git a/docs/source/user-guide/sql/scalar_functions.md b/docs/source/user-guide/sql/scalar_functions.md
index 208d18f0e5ab..4e74cfc54ae5 100644
--- a/docs/source/user-guide/sql/scalar_functions.md
+++ b/docs/source/user-guide/sql/scalar_functions.md
@@ -1046,7 +1046,9 @@ find_in_set(str, strlist)
### `initcap`

-Capitalizes the first character in each word in the input string. Words are delimited by non-alphanumeric characters.
+Capitalizes the first character in each word in the ASCII input string. Words are delimited by non-alphanumeric characters.
+
+Note: this function does not support non-ASCII (multi-byte UTF-8) characters.

```
initcap(str)
```
diff --git a/test-utils/Cargo.toml b/test-utils/Cargo.toml
index d2763f507ffa..4ad6e213cda3 100644
--- a/test-utils/Cargo.toml
+++ b/test-utils/Cargo.toml
@@ -30,5 +30,4 @@ arrow = { workspace = true }
chrono-tz = { version = "0.10.0", default-features = false }
datafusion-common = { workspace = true, default-features = true }
env_logger = { workspace = true }
-paste = "1.0.15"
rand = { workspace = true }
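
The `join.slt` hunk above pins down the sort push-down fix for https://github.com/apache/datafusion/issues/13559. For reviewers who want to check it by hand, the session below is a minimal sketch: it assumes a `datafusion-cli` built from this branch and replays only statements that already appear in the hunk.

```
-- Reproduces the new hash join sort push down test (issue 13559);
-- assumes datafusion-cli built from this branch.
SET datafusion.execution.target_partitions = 2;

CREATE TABLE test(a INT, b INT, c INT);
INSERT INTO test VALUES (1, 2, 3), (4, 5, 6), (NULL, 7, 8), (8, NULL, 9), (9, 10, NULL);

-- Per the expected plan, SortExec now appears on the probe side,
-- below the RightSemi HashJoinExec, instead of above the join.
EXPLAIN SELECT * FROM test
WHERE a IN (SELECT a FROM test WHERE b > 3)
ORDER BY c DESC NULLS FIRST;

-- Expected rows: (9, 10, NULL) then (4, 5, 6).
SELECT * FROM test
WHERE a IN (SELECT a FROM test WHERE b > 3)
ORDER BY c DESC NULLS FIRST;
```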